Bunch of training and label-latency fixes; also trained a 70%-accurate model.

This commit is contained in:
Surya Balaji
2026-01-27 20:12:13 -06:00
parent f656f466e7
commit 9bdf9d1109
41 changed files with 563 additions and 124 deletions

View File

@@ -12,6 +12,6 @@ board_upload.maximum_size = 33554432
; and I can give you the content for it.
board_build.partitions = partitions.csv
monitor_speed = 115200
monitor_speed = 921600
monitor_dtr = 1
monitor_rts = 1

View File

@@ -232,9 +232,8 @@ static void stream_emg_data(void)
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
emg_sensor_read(&sample);
/* Output in CSV format matching Python expectation */
printf("%lu,%u,%u,%u,%u\n",
(unsigned long)sample.timestamp_ms,
/* Output in CSV format - channels only, Python handles timestamps */
printf("%u,%u,%u,%u\n",
sample.channels[0],
sample.channels[1],
sample.channels[2],
@@ -283,7 +282,7 @@ void emgPrinter() {
if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
}
printf("\n");
vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
// vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
}
}
@@ -339,6 +338,6 @@ void app_main(void)
printf("[INIT] Done!\n\n");
emgPrinter();
// appConnector();
// emgPrinter();
appConnector();
}

View File

@@ -63,7 +63,7 @@
* EMG Configuration
******************************************************************************/
#define EMG_NUM_CHANNELS 1 /**< Number of EMG sensor channels */
#define EMG_NUM_CHANNELS 4 /**< Number of EMG sensor channels */
#define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */
/*******************************************************************************

View File

@@ -19,7 +19,12 @@
adc_oneshot_unit_handle_t adc1_handle;
adc_cali_handle_t cali_handle = NULL;
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {ADC_CHANNEL_1};
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {
ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0
ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1
ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2
ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3
};
/*******************************************************************************
* Public Functions
@@ -94,7 +99,6 @@ void emg_sensor_read(emg_sample_t *sample)
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
sample->channels[i] = (uint16_t) voltage_mv;
}
printf("\n");
#endif
}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -301,6 +301,7 @@ class CollectionPage(BasePage):
self.scheduler = None
self.collected_windows = []
self.collected_labels = []
self.collected_raw_samples = [] # For label alignment
self.sample_buffer = []
self.collection_thread = None
self.data_queue = queue.Queue()
@@ -511,7 +512,7 @@ class CollectionPage(BasePage):
ax.tick_params(colors='white')
ax.set_ylabel(f'Ch{i}', color='white', fontsize=10)
ax.set_xlim(0, 500)
ax.set_ylim(0, 1024)
ax.set_ylim(0, 3300) # ESP32 outputs millivolts (0-3100 mV)
ax.grid(True, alpha=0.3)
for spine in ax.spines.values():
spine.set_color('white')
@@ -648,6 +649,7 @@ class CollectionPage(BasePage):
# Reset state
self.collected_windows = []
self.collected_labels = []
self.collected_raw_samples = [] # Store raw samples for label alignment
self.sample_buffer = []
print("[DEBUG] Reset collection state")
@@ -800,6 +802,9 @@ class CollectionPage(BasePage):
timeout_warning_sent = False
sample = self.parser.parse_line(line)
if sample:
# Store raw sample for label alignment
self.collected_raw_samples.append(sample)
# Batch samples for plotting (don't send every single one)
sample_batch.append(sample.channels)
@@ -932,9 +937,25 @@ class CollectionPage(BasePage):
notes=""
)
filepath = storage.save_session(self.collected_windows, self.collected_labels, metadata)
# Get session start time for label alignment
session_start_time = None
if self.scheduler and self.scheduler.session_start_time:
session_start_time = self.scheduler.session_start_time
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}")
filepath = storage.save_session(
windows=self.collected_windows,
labels=self.collected_labels,
metadata=metadata,
raw_samples=self.collected_raw_samples if self.collected_raw_samples else None,
session_start_time=session_start_time
)
# Check if alignment was performed
alignment_msg = ""
if session_start_time and self.collected_raw_samples:
alignment_msg = "\n\nLabel alignment: enabled"
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}{alignment_msg}")
# Update sidebar
app = self.winfo_toplevel()
@@ -944,6 +965,7 @@ class CollectionPage(BasePage):
# Reset for next collection
self.collected_windows = []
self.collected_labels = []
self.collected_raw_samples = []
self.save_button.configure(state="disabled")
self.status_label.configure(text="Ready to collect")
self.window_count_label.configure(text="Windows: 0")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -39,7 +39,7 @@ import matplotlib.pyplot as plt
# =============================================================================
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
SERIAL_BAUD = 115200 # Typical baud rate for ESP32
SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
# Windowing configuration
WINDOW_SIZE_MS = 150 # Window size in milliseconds
@@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files
MODEL_DIR = Path("models") # Directory to store trained models
USER_ID = "user_001" # Current user ID (change per user)
# =============================================================================
# LABEL ALIGNMENT CONFIGURATION
# =============================================================================
# Human reaction time causes EMG activity to lag behind label prompts.
# We detect when EMG actually rises and shift labels to match.
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms)
# =============================================================================
# TRANSITION WINDOW FILTERING
# =============================================================================
# Windows near gesture transitions contain ambiguous data (reaction time at start,
# muscle relaxation at end). Discard these during training for cleaner labels.
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts
TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends
# =============================================================================
# DATA STRUCTURES
# =============================================================================
@@ -187,30 +208,27 @@ class EMGParser:
"""
Parse a line from ESP32 into an EMGSample.
Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n"
Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
Python assigns timestamp on receipt for label alignment.
Returns None if parsing fails.
"""
try:
# Strip whitespace and split
parts = line.strip().split(',')
# Validate we have correct number of fields
expected_fields = 1 + self.num_channels # timestamp + channels
if len(parts) != expected_fields:
# Validate we have correct number of fields (channels only)
if len(parts) != self.num_channels:
self.parse_errors += 1
return None
# Parse ESP32 timestamp
esp_timestamp_ms = int(parts[0])
# Parse channel values
channels = [float(parts[i + 1]) for i in range(self.num_channels)]
channels = [float(parts[i]) for i in range(self.num_channels)]
# Create sample with Python-side timestamp
# Create sample with Python-side timestamp (aligned with label clock)
sample = EMGSample(
timestamp=time.perf_counter(), # High-resolution monotonic clock
channels=channels,
esp_timestamp_ms=esp_timestamp_ms
esp_timestamp_ms=None # No longer using ESP32 timestamp
)
self.samples_parsed += 1
@@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream):
time.sleep(interval)
# =============================================================================
# LABEL ALIGNMENT (Simple Onset Detection)
# =============================================================================
from scipy.signal import butter, sosfiltfilt
def align_labels_with_onset(
labels: list[str],
window_start_times: np.ndarray,
raw_timestamps: np.ndarray,
raw_channels: np.ndarray,
sampling_rate: int,
threshold_factor: float = 2.0,
search_ms: float = 800
) -> list[str]:
"""
Align labels to EMG onset by detecting when signal rises above baseline.
Simple algorithm:
1. High-pass filter to remove DC offset
2. Compute RMS envelope across channels
3. At each label transition, find where envelope exceeds baseline + threshold
4. Move label boundary to that point
"""
if len(labels) == 0:
return labels.copy()
# High-pass filter to remove DC (raw ADC has ~2340mV offset)
nyquist = sampling_rate / 2
sos = butter(2, 20.0 / nyquist, btype='high', output='sos')
# Filter and compute envelope (RMS across channels)
filtered = np.zeros_like(raw_channels)
for ch in range(raw_channels.shape[1]):
filtered[:, ch] = sosfiltfilt(sos, raw_channels[:, ch])
envelope = np.sqrt(np.mean(filtered ** 2, axis=1))
# Smooth envelope
sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')
envelope = sosfiltfilt(sos_lp, envelope)
# Find transitions and detect onsets
search_samples = int(search_ms / 1000 * sampling_rate)
baseline_samples = int(200 / 1000 * sampling_rate)
boundaries = [] # (time, new_label)
for i in range(1, len(labels)):
if labels[i] != labels[i - 1]:
prompt_time = window_start_times[i]
# Find index in raw signal closest to prompt time
prompt_idx = np.searchsorted(raw_timestamps, prompt_time)
# Get baseline (before transition)
base_start = max(0, prompt_idx - baseline_samples)
baseline = envelope[base_start:prompt_idx]
if len(baseline) == 0:
boundaries.append((prompt_time + 0.3, labels[i]))
continue
threshold = np.mean(baseline) + threshold_factor * np.std(baseline)
# Search forward for onset
search_end = min(len(envelope), prompt_idx + search_samples)
onset_idx = None
for j in range(prompt_idx, search_end):
if envelope[j] > threshold:
onset_idx = j
break
if onset_idx is not None:
onset_time = raw_timestamps[onset_idx]
else:
onset_time = prompt_time + 0.3 # fallback
boundaries.append((onset_time, labels[i]))
# Assign labels based on detected boundaries
aligned = []
boundary_idx = 0
current_label = labels[0]
for t in window_start_times:
while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
current_label = boundaries[boundary_idx][1]
boundary_idx += 1
aligned.append(current_label)
return aligned
def filter_transition_windows(
    X: np.ndarray,
    y: np.ndarray,
    labels: list[str],
    start_times: np.ndarray,
    end_times: np.ndarray,
    transition_start_ms: float = TRANSITION_START_MS,
    transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """
    Drop windows that overlap gesture-transition zones.

    Windows near a label change carry ambiguous data: the user is still
    reacting to the prompt at the start of a gesture, and relaxing or
    anticipating the next one at the end. Removing them yields cleaner
    training labels.

    Args:
        X: EMG data array (n_windows, samples, channels).
        y: Integer label indices (n_windows,).
        labels: String labels, parallel to X.
        start_times: Window start times in seconds (n_windows,).
        end_times: Window end times in seconds (n_windows,).
        transition_start_ms: Margin discarded AFTER each gesture starts.
        transition_end_ms: Margin discarded BEFORE each gesture ends.

    Returns:
        (X, y, labels) with transition-zone windows removed.
    """
    if len(X) == 0:
        return X, y, labels

    start_margin = transition_start_ms / 1000.0
    end_margin = transition_end_ms / 1000.0

    # Indices where a new gesture segment begins, plus a terminal sentinel.
    seg_starts = (
        [0]
        + [i for i in range(1, len(labels)) if labels[i] != labels[i - 1]]
        + [len(labels)]
    )

    keep_mask = np.ones(len(X), dtype=bool)
    for lo, hi in zip(seg_starts[:-1], seg_starts[1:]):
        seg_begin = start_times[lo]          # gesture start time
        seg_finish = end_times[hi - 1]       # last window's end time
        for i in range(lo, hi):
            # Reaction-time zone right after the gesture starts.
            if start_times[i] < seg_begin + start_margin:
                keep_mask[i] = False
            # Anticipation zone right before the gesture ends.
            if end_times[i] > seg_finish - end_margin:
                keep_mask[i] = False

    X_filtered = X[keep_mask]
    y_filtered = y[keep_mask]
    labels_filtered = [lab for lab, k in zip(labels, keep_mask) if k]

    n_removed = len(X) - len(X_filtered)
    if n_removed > 0:
        print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
        print(f"[Filter] Kept {len(X_filtered)} windows for training")

    return X_filtered, y_filtered, labels_filtered
# =============================================================================
# SESSION STORAGE (Save/Load labeled data to HDF5)
# =============================================================================
@@ -527,16 +718,24 @@ class SessionStorage:
windows: list[EMGWindow],
labels: list[str],
metadata: SessionMetadata,
raw_samples: Optional[list[EMGSample]] = None
raw_samples: Optional[list[EMGSample]] = None,
session_start_time: Optional[float] = None,
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
) -> Path:
"""
Save a collection session to HDF5.
Save a collection session to HDF5 with optional label alignment.
When raw_samples and session_start_time are provided and enable_alignment
is True, automatically detects EMG onset and corrects labels for human
reaction time delay.
Args:
windows: List of EMGWindow objects (no label info)
labels: List of gesture labels, parallel to windows
metadata: Session metadata
raw_samples: Optional raw samples for debugging
raw_samples: Raw samples (required for alignment)
session_start_time: When session started (required for alignment)
enable_alignment: Whether to perform automatic label alignment
"""
filepath = self.get_session_filepath(metadata.session_id)
@@ -549,6 +748,35 @@ class SessionStorage:
window_samples = len(windows[0].samples)
num_channels = len(windows[0].samples[0].channels)
# Prepare timing arrays
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
# Label alignment using onset detection
aligned_labels = labels
original_labels = labels
if enable_alignment and raw_samples and len(raw_samples) > 0:
print("[Storage] Aligning labels to EMG onset...")
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
aligned_labels = align_labels_with_onset(
labels=labels,
window_start_times=start_times,
raw_timestamps=raw_timestamps,
raw_channels=raw_channels,
sampling_rate=metadata.sampling_rate,
threshold_factor=ONSET_THRESHOLD,
search_ms=ONSET_SEARCH_MS
)
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
elif enable_alignment:
print("[Storage] Warning: No raw samples, skipping alignment")
with h5py.File(filepath, 'w') as f:
# Metadata as attributes
f.attrs['user_id'] = metadata.user_id
@@ -568,19 +796,24 @@ class SessionStorage:
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
# Labels stored separately from window data
max_label_len = max(len(l) for l in labels)
# Store ALIGNED labels as primary (what training will use)
max_label_len = max(len(l) for l in aligned_labels)
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
windows_grp.create_dataset('labels', data=labels, dtype=dt)
windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
# Also store original labels for reference/debugging
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
windows_grp.create_dataset('window_ids', data=window_ids)
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
windows_grp.create_dataset('start_times', data=start_times)
windows_grp.create_dataset('end_times', data=end_times)
# Store alignment metadata
f.attrs['alignment_enabled'] = enable_alignment
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
if raw_samples:
raw_grp = f.create_group('raw_samples')
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
@@ -655,13 +888,21 @@ class SessionStorage:
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
return windows, labels_out, metadata
def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""Load a single session in ML-ready format: X, y, label_names."""
def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""
Load a single session in ML-ready format: X, y, label_names.
Args:
session_id: The session to load
filter_transitions: If True, remove windows in transition zones (default from config)
"""
filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
labels = []
for l in labels_raw:
@@ -670,18 +911,33 @@ class SessionStorage:
else:
labels.append(l)
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
# Apply transition filtering if enabled
if filter_transitions:
label_names_pre = sorted(set(labels))
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
X, y_pre, labels = filter_transition_windows(
X, y_pre, labels, start_times, end_times
)
label_names = sorted(set(labels))
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Labels: {label_names}")
return X, y, label_names
def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
"""
Load ALL sessions combined into a single training dataset.
Args:
filter_transitions: If True, remove windows in transition zones (default from config)
Returns:
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
y: Combined labels as integers (n_total_windows,)
@@ -697,11 +953,15 @@ class SessionStorage:
raise ValueError("No sessions found to load!")
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
if filter_transitions:
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
all_X = []
all_labels = []
loaded_sessions = []
reference_shape = None
total_removed = 0
total_original = 0
for session_id in sessions:
filepath = self.get_session_filepath(session_id)
@@ -709,6 +969,8 @@ class SessionStorage:
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
# Validate shape compatibility
if reference_shape is None:
@@ -725,10 +987,26 @@ class SessionStorage:
else:
labels.append(l)
original_count = len(X)
total_original += original_count
# Apply transition filtering per session (each has its own gesture boundaries)
if filter_transitions:
# Need temporary y for filtering function
temp_label_names = sorted(set(labels))
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
X, temp_y, labels = filter_transition_windows(
X, temp_y, labels, start_times, end_times
)
total_removed += original_count - len(X)
all_X.append(X)
all_labels.extend(labels)
loaded_sessions.append(session_id)
print(f"[Storage] - {session_id}: {X.shape[0]} windows")
print(f"[Storage] - {session_id}: {len(X)} windows" +
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
if not all_X:
raise ValueError("No compatible sessions found!")
@@ -742,6 +1020,8 @@ class SessionStorage:
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
if filter_transitions and total_removed > 0:
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
print(f"[Storage] Labels: {label_names}")
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")

View File

@@ -1,41 +1,152 @@
import numpy as np
import matplotlib.pyplot as plt
import h5py
import pandas as pd
from scipy.signal import butter, sosfiltfilt
# =============================================================================
# CONFIGURABLE PARAMETERS
# =============================================================================
ZC_THRESHOLD_PERCENT = 0.6 # Zero Crossing threshold as fraction of RMS
SSC_THRESHOLD_PERCENT = 0.3 # Slope Sign Change threshold as fraction of RMS
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS
file = h5py.File("assets/discrete_gestures_user_000_dataset_000.hdf5", "r")
# =============================================================================
# LOAD DATA FROM GUI's HDF5 FORMAT
# =============================================================================
# Update this path to your collected session file
HDF5_PATH = "collected_data\latency_fix_106_20260127_200043.hdf5"
file = h5py.File(HDF5_PATH, "r")
# Print HDF5 structure for debugging
print("HDF5 Structure:")
def print_tree(name, obj):
print(name)
print(f" {name}")
file.visititems(print_tree)
print()
print(list(file.keys()))
# Load metadata
fs = file.attrs['sampling_rate'] # Sampling rate in Hz
n_channels = file.attrs['num_channels']
window_size_ms = file.attrs['window_size_ms']
print(f"Sampling rate: {fs} Hz")
print(f"Channels: {n_channels}")
print(f"Window size: {window_size_ms} ms")
data = file["data"]
print(type(data))
print(data)
print(data.dtype)
# Load windowed EMG data: shape (n_windows, samples_per_window, channels)
emg_windows = file['windows/emg_data'][:]
labels = file['windows/labels'][:]
start_times = file['windows/start_times'][:]
end_times = file['windows/end_times'][:]
raw = data[:]
print(raw.shape)
print(raw.dtype)
print(f"Windows shape: {emg_windows.shape}")
print(f"Labels: {len(labels)} (unique: {np.unique(labels)})")
emg = raw['emg']
time = raw['time']
print(emg.shape)
# Flatten windows to continuous signal for filtering analysis
# Shape: (n_windows * samples_per_window, channels)
n_windows, samples_per_window, _ = emg_windows.shape
emg = emg_windows.reshape(-1, n_channels)
print(f"Flattened EMG shape: {emg.shape}")
dt = np.diff(time)
fs = 1.0 / np.median(dt)
print("fs =", fs)
print("dt min/median/max =", dt.min(), np.median(dt), dt.max())
# Reconstruct time vector (assumes continuous recording)
total_samples = emg.shape[0]
time = np.arange(total_samples) / fs
print(f"Time range: {time[0]:.2f}s to {time[-1]:.2f}s")
# =============================================================================
# PLOT RAW EMG DATA (before filtering)
# =============================================================================
print("\nPlotting raw EMG data...")
# Color map for channels
channel_colors = ['#00ff88', '#ff6b6b', '#4ecdc4', '#ffe66d']
# Map window indices to actual time in the flattened data
# Window i starts at sample (i * samples_per_window), which is time (i * samples_per_window / fs)
window_time_in_data = np.arange(len(labels)) * samples_per_window / fs
# Compute global Y range across all channels for consistent comparison
emg_global_min = emg.min()
emg_global_max = emg.max()
emg_margin = (emg_global_max - emg_global_min) * 0.05 # 5% margin
emg_ylim = (emg_global_min - emg_margin, emg_global_max + emg_margin)
print(f"Global EMG range: {emg_global_min:.1f} to {emg_global_max:.1f} mV")
# Full session plot with shared Y-axis
fig_raw, axes_raw = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), sharex=True, sharey=True)
if n_channels == 1:
axes_raw = [axes_raw]
for ch in range(n_channels):
ax = axes_raw[ch]
# Plot raw EMG signal
ax.plot(time, emg[:, ch], linewidth=0.3, color=channel_colors[ch % len(channel_colors)], alpha=0.8)
ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_xlim(time[0], time[-1])
ax.set_ylim(emg_ylim) # Same Y range for all channels
# Add gesture markers AFTER plotting
y_min, y_max = emg_ylim
y_text = y_min + (y_max - y_min) * 0.85 # Position text at 85% height
for ch in range(n_channels):
ax = axes_raw[ch]
# Add gesture transition markers using window-aligned times
prev_label = None
for i, lbl in enumerate(labels):
lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
t = window_time_in_data[i] # Time in the flattened data
if lbl_str != prev_label:
if 'open' in lbl_str:
color = 'cyan'
elif 'fist' in lbl_str:
color = 'blue'
elif 'hook' in lbl_str:
color = 'orange'
elif 'thumb' in lbl_str:
color = 'green'
elif 'rest' in lbl_str:
color = 'gray'
else:
color = 'red'
ax.axvline(t, color=color, linestyle='--', alpha=0.7, linewidth=1)
# Only add text label on first channel to avoid clutter
if ch == 0 and lbl_str != 'rest':
ax.text(t + 0.2, y_text, lbl_str, fontsize=8,
color=color, rotation=0, ha='left', va='top',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7))
prev_label = lbl_str
axes_raw[-1].set_xlabel('Time (s)', fontsize=11)
fig_raw.suptitle('Raw EMG Signal (All Channels)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# Zoomed view of first 2 seconds to see waveform detail
zoom_duration = 2.0
zoom_samples = int(zoom_duration * fs)
if total_samples > zoom_samples:
fig_zoom, axes_zoom = plt.subplots(n_channels, 1, figsize=(14, 2 * n_channels), sharex=True, sharey=True)
if n_channels == 1:
axes_zoom = [axes_zoom]
for ch in range(n_channels):
ax = axes_zoom[ch]
ax.plot(time[:zoom_samples], emg[:zoom_samples, ch],
linewidth=0.5, color=channel_colors[ch % len(channel_colors)])
ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
ax.set_ylim(emg_ylim) # Same Y range as full plot
ax.grid(True, alpha=0.3)
axes_zoom[-1].set_xlabel('Time (s)', fontsize=11)
fig_zoom.suptitle(f'Raw EMG Signal (First {zoom_duration}s - Zoomed)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# 1) pick one channel
emg_ch = emg[:, 0].astype(np.float32)
@@ -122,10 +233,9 @@ def compute_all_features_windowed(x, window_len, threshold_zc, threshold_ssc):
# =============================================================================
# COMPUTE FEATURES FOR ALL 16 CHANNELS
# COMPUTE FEATURES FOR ALL CHANNELS
# =============================================================================
n_channels = 16
all_features = {} # Non-overlapping windows - for ML
for ch in range(n_channels):
@@ -143,101 +253,122 @@ for ch in range(n_channels):
all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i}
# Time vector for windowed features
n_windows = len(all_features[0]['rms'])
window_centers = np.arange(n_windows) * window_samples + window_samples // 2
n_feat_windows = len(all_features[0]['rms'])
window_centers = np.arange(n_feat_windows) * window_samples + window_samples // 2
time_windows = time_ch[window_centers]
print(f"\nComputed features for {n_channels} channels")
print(f"Windows per channel (non-overlapping): {n_windows}")
print(f"Windows per channel (non-overlapping): {n_feat_windows}")
print(f"Window size: {window_samples} samples ({window_ms} ms)")
# 5) Load gesture labels (prompts)
prompts = pd.read_hdf("assets/discrete_gestures_user_000_dataset_000.hdf5", key="prompts")
print("\nUnique gestures:", prompts['name'].unique())
# 5) Use embedded gesture labels from HDF5
# Labels are per-window, aligned with emg_windows
unique_labels = np.unique(labels)
print(f"\nUnique gestures in session: {unique_labels}")
t_abs = time_ch # absolute timestamps
# Find first occurrence of each gesture type
index_gestures = prompts[prompts['name'].str.contains('index')]
middle_gestures = prompts[prompts['name'].str.contains('middle')]
thumb_gestures = prompts[prompts['name'].str.contains('thumb')]
# Define plot configurations: 1) Index+Middle combined, 2) Thumb separate
plot_configs = [
{'name': 'Index & Middle Finger', 'start_time': index_gestures['time'].iloc[0], 'filter': 'index|middle'},
{'name': 'Thumb', 'start_time': thumb_gestures['time'].iloc[0], 'filter': 'thumb'},
]
# Map labels to time in the flattened data (same as raw plot)
# Each original window i maps to time (i * samples_per_window / fs) in the flattened data
# But we dropped `drop` samples, so adjust accordingly
label_times_in_data = np.arange(len(labels)) * samples_per_window / fs
print(f"Data duration: {time[-1]:.2f}s ({len(labels)} windows)")
# Color function for markers
def get_gesture_color(name):
if 'index' in name:
return 'green'
elif 'middle' in name:
"""Assign colors to gesture types."""
name_str = name.decode('utf-8') if isinstance(name, bytes) else name
if 'open' in name_str:
return 'cyan'
elif 'fist' in name_str:
return 'blue'
elif 'thumb' in name:
elif 'hook' in name_str:
return 'orange'
return 'gray'
elif 'thumb' in name_str:
return 'green'
elif 'rest' in name_str:
return 'gray'
return 'red'
# =============================================================================
# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL 16 CHANNELS
# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL CHANNELS
# =============================================================================
feature_names = ['rms', 'wl', 'zc', 'ssc']
feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
feature_colors = ['red', 'blue', 'green', 'purple']
feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count']
feature_ylabels = ['Amplitude (mV)', 'WL (a.u.)', 'Count', 'Count']
for config in plot_configs:
t_start = config['start_time'] - 0.5
t_end = t_start + 10.0
# Determine subplot grid based on channel count
if n_channels <= 4:
n_rows, n_cols = 2, 2
elif n_channels <= 9:
n_rows, n_cols = 3, 3
else:
n_rows, n_cols = 4, 4
# Mask for windowed time vector
mask_win = (time_windows >= t_start) & (time_windows <= t_end)
t_win_rel = time_windows[mask_win] - t_start
# Plot entire session for each feature
for feat_idx, feat_name in enumerate(feature_names):
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), sharex=True, sharey=True)
axes = axes.flatten()
# Get gestures in window
gesture_mask = (prompts['time'] >= t_start) & (prompts['time'] <= t_end) & \
(prompts['name'].str.contains(config['filter']))
gestures_in_window = prompts[gesture_mask]
# Compute global Y range for this feature across all channels
all_feat_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
feat_min, feat_max = all_feat_vals.min(), all_feat_vals.max()
feat_margin = (feat_max - feat_min) * 0.05 if feat_max > feat_min else 1
feat_ylim = (feat_min - feat_margin, feat_max + feat_margin)
# Create one figure per feature
for feat_idx, feat_name in enumerate(feature_names):
fig, axes = plt.subplots(4, 4, figsize=(10, 8), sharex=True, sharey=True)
axes = axes.flatten()
for ch in range(n_channels):
ax = axes[ch]
for ch in range(n_channels):
ax = axes[ch]
# Plot windowed feature over time
feat_data = all_features[ch][feat_name]
ax.plot(time_windows, feat_data, linewidth=0.8, color=feature_colors[feat_idx])
ax.set_title(f"Channel {ch}", fontsize=10)
ax.set_ylim(feat_ylim) # Same Y range for all channels
ax.grid(True, alpha=0.3)
# Plot windowed feature
feat_data = all_features[ch][feat_name][mask_win]
ax.plot(t_win_rel, feat_data, linewidth=1, color=feature_colors[feat_idx])
ax.set_title(f"Ch {ch}", fontsize=9)
# Add gesture transition markers from labels
prev_label = None
for i, lbl in enumerate(labels):
lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
t = label_times_in_data[i] # Time in flattened data
if lbl_str != prev_label and lbl_str != 'rest':
# Only show markers within the time_windows range
if t <= time_windows[-1]:
color = get_gesture_color(lbl)
ax.axvline(t, color=color, linestyle='--', alpha=0.6, linewidth=1)
prev_label = lbl_str
# Add gesture markers
for _, row in gestures_in_window.iterrows():
t_g = row['time'] - t_start
color = get_gesture_color(row['name'])
ax.axvline(t_g, color=color, linestyle='--', alpha=0.5, linewidth=0.5)
# Hide unused subplots
for ch in range(n_channels, len(axes)):
axes[ch].set_visible(False)
# Set subtitle based on gesture type
if 'Index' in config['name']:
subtitle = "(Green=index, Blue=middle)"
else:
subtitle = "(Orange=thumb)"
# Legend for gesture colors
from matplotlib.lines import Line2D
legend_elements = [
Line2D([0], [0], color='cyan', linestyle='--', label='Open'),
Line2D([0], [0], color='blue', linestyle='--', label='Fist'),
Line2D([0], [0], color='orange', linestyle='--', label='Hook Em'),
Line2D([0], [0], color='green', linestyle='--', label='Thumbs Up'),
]
fig.legend(handles=legend_elements, loc='upper right', fontsize=9)
fig.suptitle(f"{feature_titles[feat_idx]} - {config['name']} Gestures\n{subtitle}", fontsize=12)
fig.supxlabel("Time (s)")
fig.supylabel(feature_ylabels[feat_idx])
plt.tight_layout()
plt.show()
fig.suptitle(f"{feature_titles[feat_idx]} - All Channels", fontsize=14)
fig.supxlabel("Time (s)")
fig.supylabel(feature_ylabels[feat_idx])
plt.tight_layout()
plt.show()
# =============================================================================
# SUMMARY: Feature statistics across all channels
# =============================================================================
print("\n" + "=" * 60)
print("FEATURE SUMMARY (all channels, all windows)")
print(f"FEATURE SUMMARY ({n_channels} channels, {n_feat_windows} windows each)")
print("=" * 60)
for feat_name in ['rms', 'wl', 'zc', 'ssc']:
all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | "
f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}")
# Close the HDF5 file
file.close()
print("\n[Done] HDF5 file closed.")

Binary file not shown.

View File

@@ -99,13 +99,13 @@ class RealSerialStream:
@note Requires pyserial: pip install pyserial
"""
def __init__(self, port: str = None, baud_rate: int = 115200, timeout: float = 1.0):
def __init__(self, port: str = None, baud_rate: int = 921600, timeout: float = 0.05):
"""
@brief Initialize the serial stream.
@param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux).
If None, will attempt to auto-detect the ESP32.
@param baud_rate Communication speed in bits per second. Default 115200 matches ESP32.
@param baud_rate Communication speed in bits per second. Default 921600 for high-throughput streaming.
@param timeout Read timeout in seconds for readline().
"""
self.port = port
@@ -233,6 +233,9 @@ class RealSerialStream:
"Must call connect() first."
)
# Flush any stale data before starting fresh stream
self.serial.reset_input_buffer()
# Send start command
start_cmd = {"cmd": "start"}
self._send_json(start_cmd)