diff --git a/EMG_Arm/platformio.ini b/EMG_Arm/platformio.ini index ec08a09..b9df812 100644 --- a/EMG_Arm/platformio.ini +++ b/EMG_Arm/platformio.ini @@ -12,6 +12,6 @@ board_upload.maximum_size = 33554432 ; and I can give you the content for it. board_build.partitions = partitions.csv -monitor_speed = 115200 +monitor_speed = 921600 monitor_dtr = 1 monitor_rts = 1 \ No newline at end of file diff --git a/EMG_Arm/src/app/main.c b/EMG_Arm/src/app/main.c index a8bde15..e16d68b 100644 --- a/EMG_Arm/src/app/main.c +++ b/EMG_Arm/src/app/main.c @@ -232,9 +232,8 @@ static void stream_emg_data(void) /* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */ emg_sensor_read(&sample); - /* Output in CSV format matching Python expectation */ - printf("%lu,%u,%u,%u,%u\n", - (unsigned long)sample.timestamp_ms, + /* Output in CSV format - channels only, Python handles timestamps */ + printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1], sample.channels[2], @@ -283,7 +282,7 @@ void emgPrinter() { if (i != EMG_NUM_CHANNELS - 1) printf(" | "); } printf("\n"); - vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100)); + // vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100)); } } @@ -339,6 +338,6 @@ void app_main(void) printf("[INIT] Done!\n\n"); - emgPrinter(); - // appConnector(); + // emgPrinter(); + appConnector(); } diff --git a/EMG_Arm/src/config/config.h b/EMG_Arm/src/config/config.h index 521e711..7fb8cd4 100644 --- a/EMG_Arm/src/config/config.h +++ b/EMG_Arm/src/config/config.h @@ -63,7 +63,7 @@ * EMG Configuration ******************************************************************************/ -#define EMG_NUM_CHANNELS 1 /**< Number of EMG sensor channels */ +#define EMG_NUM_CHANNELS 4 /**< Number of EMG sensor channels */ #define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */ /******************************************************************************* diff --git a/EMG_Arm/src/drivers/emg_sensor.c b/EMG_Arm/src/drivers/emg_sensor.c index 
9cb6e86..fb06fc9 100644 --- a/EMG_Arm/src/drivers/emg_sensor.c +++ b/EMG_Arm/src/drivers/emg_sensor.c @@ -19,7 +19,12 @@ adc_oneshot_unit_handle_t adc1_handle; adc_cali_handle_t cali_handle = NULL; -const uint8_t emg_channels[EMG_NUM_CHANNELS] = {ADC_CHANNEL_1}; +const uint8_t emg_channels[EMG_NUM_CHANNELS] = { + ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0 + ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1 + ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2 + ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3 +}; /******************************************************************************* * Public Functions @@ -94,7 +99,6 @@ void emg_sensor_read(emg_sample_t *sample) ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv)); sample->channels[i] = (uint16_t) voltage_mv; } - printf("\n"); #endif } diff --git a/collected_data/esp32_test_001_20260119_214904.hdf5 b/collected_data/esp32_test_001_20260119_214904.hdf5 deleted file mode 100644 index f5c325e..0000000 Binary files a/collected_data/esp32_test_001_20260119_214904.hdf5 and /dev/null differ diff --git a/collected_data/latency_fix_100_20260127_184804.hdf5 b/collected_data/latency_fix_100_20260127_184804.hdf5 new file mode 100644 index 0000000..5fadabb Binary files /dev/null and b/collected_data/latency_fix_100_20260127_184804.hdf5 differ diff --git a/collected_data/latency_fix_101_20260127_185344.hdf5 b/collected_data/latency_fix_101_20260127_185344.hdf5 new file mode 100644 index 0000000..f069f3d Binary files /dev/null and b/collected_data/latency_fix_101_20260127_185344.hdf5 differ diff --git a/collected_data/latency_fix_102_20260127_190022.hdf5 b/collected_data/latency_fix_102_20260127_190022.hdf5 new file mode 100644 index 0000000..88f0bc5 Binary files /dev/null and b/collected_data/latency_fix_102_20260127_190022.hdf5 differ diff --git a/collected_data/latency_fix_103_20260127_191249.hdf5 b/collected_data/latency_fix_103_20260127_191249.hdf5 new file mode 100644 index 0000000..47de9ab Binary files /dev/null and 
b/collected_data/latency_fix_103_20260127_191249.hdf5 differ diff --git a/collected_data/latency_fix_104_20260127_195150.hdf5 b/collected_data/latency_fix_104_20260127_195150.hdf5 new file mode 100644 index 0000000..3e4f967 Binary files /dev/null and b/collected_data/latency_fix_104_20260127_195150.hdf5 differ diff --git a/collected_data/latency_fix_105_20260127_195503.hdf5 b/collected_data/latency_fix_105_20260127_195503.hdf5 new file mode 100644 index 0000000..e21253e Binary files /dev/null and b/collected_data/latency_fix_105_20260127_195503.hdf5 differ diff --git a/collected_data/new_placements_000_20260127_174231.hdf5 b/collected_data/new_placements_000_20260127_174231.hdf5 new file mode 100644 index 0000000..33208bb Binary files /dev/null and b/collected_data/new_placements_000_20260127_174231.hdf5 differ diff --git a/collected_data/user_001_20260108_170626.hdf5 b/collected_data/user_001_20260108_170626.hdf5 deleted file mode 100644 index 6e86c47..0000000 Binary files a/collected_data/user_001_20260108_170626.hdf5 and /dev/null differ diff --git a/collected_data/user_001_20260108_171851.hdf5 b/collected_data/user_001_20260108_171851.hdf5 deleted file mode 100644 index 4232de0..0000000 Binary files a/collected_data/user_001_20260108_171851.hdf5 and /dev/null differ diff --git a/collected_data/user_001_20260108_172535.hdf5 b/collected_data/user_001_20260108_172535.hdf5 deleted file mode 100644 index 91c9175..0000000 Binary files a/collected_data/user_001_20260108_172535.hdf5 and /dev/null differ diff --git a/collected_data/user_001_20260108_174542.hdf5 b/collected_data/user_001_20260108_174542.hdf5 deleted file mode 100644 index d4534b7..0000000 Binary files a/collected_data/user_001_20260108_174542.hdf5 and /dev/null differ diff --git a/collected_data/user_001_20260108_174934.hdf5 b/collected_data/user_001_20260108_174934.hdf5 deleted file mode 100644 index f42702e..0000000 Binary files a/collected_data/user_001_20260108_174934.hdf5 and /dev/null differ diff 
--git a/collected_data/user_001_20260109_215307.hdf5 b/collected_data/user_001_20260109_215307.hdf5 deleted file mode 100644 index 01140da..0000000 Binary files a/collected_data/user_001_20260109_215307.hdf5 and /dev/null differ diff --git a/collected_data/user_001_20260119_214559.hdf5 b/collected_data/user_001_20260119_214559.hdf5 deleted file mode 100644 index c138791..0000000 Binary files a/collected_data/user_001_20260119_214559.hdf5 and /dev/null differ diff --git a/collected_data/user_002_20260108_175610.hdf5 b/collected_data/user_002_20260108_175610.hdf5 deleted file mode 100644 index 57f71ad..0000000 Binary files a/collected_data/user_002_20260108_175610.hdf5 and /dev/null differ diff --git a/collected_data/user_002_20260108_220204.hdf5 b/collected_data/user_002_20260108_220204.hdf5 deleted file mode 100644 index 8526768..0000000 Binary files a/collected_data/user_002_20260108_220204.hdf5 and /dev/null differ diff --git a/collected_data/user_003_20260109_154733.hdf5 b/collected_data/user_003_20260109_154733.hdf5 deleted file mode 100644 index fac45a6..0000000 Binary files a/collected_data/user_003_20260109_154733.hdf5 and /dev/null differ diff --git a/collected_data/user_003_20260109_215459.hdf5 b/collected_data/user_003_20260109_215459.hdf5 deleted file mode 100644 index bab3588..0000000 Binary files a/collected_data/user_003_20260109_215459.hdf5 and /dev/null differ diff --git a/collected_data/user_004_20260110_154828.hdf5 b/collected_data/user_004_20260110_154828.hdf5 deleted file mode 100644 index 65e8caf..0000000 Binary files a/collected_data/user_004_20260110_154828.hdf5 and /dev/null differ diff --git a/emg_gui.py b/emg_gui.py index 32f357c..307d417 100644 --- a/emg_gui.py +++ b/emg_gui.py @@ -301,6 +301,7 @@ class CollectionPage(BasePage): self.scheduler = None self.collected_windows = [] self.collected_labels = [] + self.collected_raw_samples = [] # For label alignment self.sample_buffer = [] self.collection_thread = None self.data_queue = 
queue.Queue() @@ -511,7 +512,7 @@ class CollectionPage(BasePage): ax.tick_params(colors='white') ax.set_ylabel(f'Ch{i}', color='white', fontsize=10) ax.set_xlim(0, 500) - ax.set_ylim(0, 1024) + ax.set_ylim(0, 3300) # ESP32 outputs millivolts (0-3100 mV) ax.grid(True, alpha=0.3) for spine in ax.spines.values(): spine.set_color('white') @@ -648,6 +649,7 @@ class CollectionPage(BasePage): # Reset state self.collected_windows = [] self.collected_labels = [] + self.collected_raw_samples = [] # Store raw samples for label alignment self.sample_buffer = [] print("[DEBUG] Reset collection state") @@ -800,6 +802,9 @@ class CollectionPage(BasePage): timeout_warning_sent = False sample = self.parser.parse_line(line) if sample: + # Store raw sample for label alignment + self.collected_raw_samples.append(sample) + # Batch samples for plotting (don't send every single one) sample_batch.append(sample.channels) @@ -932,9 +937,25 @@ class CollectionPage(BasePage): notes="" ) - filepath = storage.save_session(self.collected_windows, self.collected_labels, metadata) + # Get session start time for label alignment + session_start_time = None + if self.scheduler and self.scheduler.session_start_time: + session_start_time = self.scheduler.session_start_time - messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}") + filepath = storage.save_session( + windows=self.collected_windows, + labels=self.collected_labels, + metadata=metadata, + raw_samples=self.collected_raw_samples if self.collected_raw_samples else None, + session_start_time=session_start_time + ) + + # Check if alignment was performed + alignment_msg = "" + if session_start_time and self.collected_raw_samples: + alignment_msg = "\n\nLabel alignment: enabled" + + messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}{alignment_msg}") # Update sidebar app = self.winfo_toplevel() @@ -944,6 +965,7 @@ class 
CollectionPage(BasePage): # Reset for next collection self.collected_windows = [] self.collected_labels = [] + self.collected_raw_samples = [] self.save_button.configure(state="disabled") self.status_label.configure(text="Ready to collect") self.window_count_label.configure(text="Windows: 0") diff --git a/extra_data/emg_real_001_20260125_132316.hdf5 b/extra_data/emg_real_001_20260125_132316.hdf5 new file mode 100644 index 0000000..41f5359 Binary files /dev/null and b/extra_data/emg_real_001_20260125_132316.hdf5 differ diff --git a/extra_data/emg_real_001_20260125_133141.hdf5 b/extra_data/emg_real_001_20260125_133141.hdf5 new file mode 100644 index 0000000..2e797d4 Binary files /dev/null and b/extra_data/emg_real_001_20260125_133141.hdf5 differ diff --git a/extra_data/emg_real_001_20260125_133545.hdf5 b/extra_data/emg_real_001_20260125_133545.hdf5 new file mode 100644 index 0000000..7691af2 Binary files /dev/null and b/extra_data/emg_real_001_20260125_133545.hdf5 differ diff --git a/extra_data/emg_real_001_20260125_144247.hdf5 b/extra_data/emg_real_001_20260125_144247.hdf5 new file mode 100644 index 0000000..dca7d77 Binary files /dev/null and b/extra_data/emg_real_001_20260125_144247.hdf5 differ diff --git a/extra_data/label_test_002_20260125_183955.hdf5 b/extra_data/label_test_002_20260125_183955.hdf5 new file mode 100644 index 0000000..1ace467 Binary files /dev/null and b/extra_data/label_test_002_20260125_183955.hdf5 differ diff --git a/extra_data/label_test_003_20260125_185807.hdf5 b/extra_data/label_test_003_20260125_185807.hdf5 new file mode 100644 index 0000000..68e5126 Binary files /dev/null and b/extra_data/label_test_003_20260125_185807.hdf5 differ diff --git a/extra_data/latency_debug_20260125_181559.hdf5 b/extra_data/latency_debug_20260125_181559.hdf5 new file mode 100644 index 0000000..c65cf6d Binary files /dev/null and b/extra_data/latency_debug_20260125_181559.hdf5 differ diff --git a/extra_data/latency_fix_000_20260126_230849.hdf5 
b/extra_data/latency_fix_000_20260126_230849.hdf5 new file mode 100644 index 0000000..bc8cfe1 Binary files /dev/null and b/extra_data/latency_fix_000_20260126_230849.hdf5 differ diff --git a/extra_data/latency_fix_002_20260126_233357.hdf5 b/extra_data/latency_fix_002_20260126_233357.hdf5 new file mode 100644 index 0000000..2dc53d7 Binary files /dev/null and b/extra_data/latency_fix_002_20260126_233357.hdf5 differ diff --git a/extra_data/latency_fix_003_20260126_234717.hdf5 b/extra_data/latency_fix_003_20260126_234717.hdf5 new file mode 100644 index 0000000..1ff6677 Binary files /dev/null and b/extra_data/latency_fix_003_20260126_234717.hdf5 differ diff --git a/extra_data/latency_fix_004_20260127_002041.hdf5 b/extra_data/latency_fix_004_20260127_002041.hdf5 new file mode 100644 index 0000000..e9a7bac Binary files /dev/null and b/extra_data/latency_fix_004_20260127_002041.hdf5 differ diff --git a/extra_data/latency_fix_005_20260127_002618.hdf5 b/extra_data/latency_fix_005_20260127_002618.hdf5 new file mode 100644 index 0000000..c34736e Binary files /dev/null and b/extra_data/latency_fix_005_20260127_002618.hdf5 differ diff --git a/learning_data_collection.py b/learning_data_collection.py index e26b73c..b08cbaa 100644 --- a/learning_data_collection.py +++ b/learning_data_collection.py @@ -39,7 +39,7 @@ import matplotlib.pyplot as plt # ============================================================================= NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors) SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ -SERIAL_BAUD = 115200 # Typical baud rate for ESP32 +SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog # Windowing configuration WINDOW_SIZE_MS = 150 # Window size in milliseconds @@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files MODEL_DIR = Path("models") # Directory to store trained models USER_ID = "user_001" # Current user ID (change per user) +# 
============================================================================= +# LABEL ALIGNMENT CONFIGURATION +# ============================================================================= +# Human reaction time causes EMG activity to lag behind label prompts. +# We detect when EMG actually rises and shift labels to match. + +ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment +ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std +ONSET_SEARCH_MS = 2000 # Search window after prompt (ms) + +# ============================================================================= +# TRANSITION WINDOW FILTERING +# ============================================================================= +# Windows near gesture transitions contain ambiguous data (reaction time at start, +# muscle relaxation at end). Discard these during training for cleaner labels. +# This is standard practice in EMG research (see Frontiers Neurorobotics 2023). + +DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training +TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts +TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends + # ============================================================================= # DATA STRUCTURES # ============================================================================= @@ -187,30 +208,27 @@ class EMGParser: """ Parse a line from ESP32 into an EMGSample. - Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n" + Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp) + Python assigns timestamp on receipt for label alignment. Returns None if parsing fails. 
""" try: # Strip whitespace and split parts = line.strip().split(',') - # Validate we have correct number of fields - expected_fields = 1 + self.num_channels # timestamp + channels - if len(parts) != expected_fields: + # Validate we have correct number of fields (channels only) + if len(parts) != self.num_channels: self.parse_errors += 1 return None - # Parse ESP32 timestamp - esp_timestamp_ms = int(parts[0]) - # Parse channel values - channels = [float(parts[i + 1]) for i in range(self.num_channels)] + channels = [float(parts[i]) for i in range(self.num_channels)] - # Create sample with Python-side timestamp + # Create sample with Python-side timestamp (aligned with label clock) sample = EMGSample( timestamp=time.perf_counter(), # High-resolution monotonic clock channels=channels, - esp_timestamp_ms=esp_timestamp_ms + esp_timestamp_ms=None # No longer using ESP32 timestamp ) self.samples_parsed += 1 @@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream): time.sleep(interval) +# ============================================================================= +# LABEL ALIGNMENT (Simple Onset Detection) +# ============================================================================= +from scipy.signal import butter, sosfiltfilt + + +def align_labels_with_onset( + labels: list[str], + window_start_times: np.ndarray, + raw_timestamps: np.ndarray, + raw_channels: np.ndarray, + sampling_rate: int, + threshold_factor: float = 2.0, + search_ms: float = 800 +) -> list[str]: + """ + Align labels to EMG onset by detecting when signal rises above baseline. + + Simple algorithm: + 1. High-pass filter to remove DC offset + 2. Compute RMS envelope across channels + 3. At each label transition, find where envelope exceeds baseline + threshold + 4. 
Move label boundary to that point + """ + if len(labels) == 0: + return labels.copy() + + # High-pass filter to remove DC (raw ADC has ~2340mV offset) + nyquist = sampling_rate / 2 + sos = butter(2, 20.0 / nyquist, btype='high', output='sos') + + # Filter and compute envelope (RMS across channels) + filtered = np.zeros_like(raw_channels) + for ch in range(raw_channels.shape[1]): + filtered[:, ch] = sosfiltfilt(sos, raw_channels[:, ch]) + envelope = np.sqrt(np.mean(filtered ** 2, axis=1)) + + # Smooth envelope + sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos') + envelope = sosfiltfilt(sos_lp, envelope) + + # Find transitions and detect onsets + search_samples = int(search_ms / 1000 * sampling_rate) + baseline_samples = int(200 / 1000 * sampling_rate) + + boundaries = [] # (time, new_label) + + for i in range(1, len(labels)): + if labels[i] != labels[i - 1]: + prompt_time = window_start_times[i] + + # Find index in raw signal closest to prompt time + prompt_idx = np.searchsorted(raw_timestamps, prompt_time) + + # Get baseline (before transition) + base_start = max(0, prompt_idx - baseline_samples) + baseline = envelope[base_start:prompt_idx] + if len(baseline) == 0: + boundaries.append((prompt_time + 0.3, labels[i])) + continue + + threshold = np.mean(baseline) + threshold_factor * np.std(baseline) + + # Search forward for onset + search_end = min(len(envelope), prompt_idx + search_samples) + onset_idx = None + + for j in range(prompt_idx, search_end): + if envelope[j] > threshold: + onset_idx = j + break + + if onset_idx is not None: + onset_time = raw_timestamps[onset_idx] + else: + onset_time = prompt_time + 0.3 # fallback + + boundaries.append((onset_time, labels[i])) + + # Assign labels based on detected boundaries + aligned = [] + boundary_idx = 0 + current_label = labels[0] + + for t in window_start_times: + while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]: + current_label = boundaries[boundary_idx][1] + boundary_idx += 1 
+ aligned.append(current_label) + + return aligned + + +def filter_transition_windows( + X: np.ndarray, + y: np.ndarray, + labels: list[str], + start_times: np.ndarray, + end_times: np.ndarray, + transition_start_ms: float = TRANSITION_START_MS, + transition_end_ms: float = TRANSITION_END_MS +) -> tuple[np.ndarray, np.ndarray, list[str]]: + """ + Filter out windows that fall within transition zones at gesture boundaries. + + This removes ambiguous data where: + - User is still reacting to prompt (start of gesture) + - User is anticipating next gesture (end of gesture) + + Args: + X: EMG data array (n_windows, samples, channels) + y: Label indices (n_windows,) + labels: String labels (n_windows,) + start_times: Window start times in seconds (n_windows,) + end_times: Window end times in seconds (n_windows,) + transition_start_ms: Discard windows within this time after gesture start + transition_end_ms: Discard windows within this time before gesture end + + Returns: + Filtered (X, y, labels) with transition windows removed + """ + if len(X) == 0: + return X, y, labels + + transition_start_sec = transition_start_ms / 1000.0 + transition_end_sec = transition_end_ms / 1000.0 + + # Find gesture boundaries (where label changes) + # Each boundary is the START of a new gesture segment + boundaries = [0] # First window starts a segment + for i in range(1, len(labels)): + if labels[i] != labels[i-1]: + boundaries.append(i) + boundaries.append(len(labels)) # End marker + + # For each segment, find start_time and end_time of the gesture + # Then mark windows that are within transition zones + keep_mask = np.ones(len(X), dtype=bool) + + for seg_idx in range(len(boundaries) - 1): + seg_start_idx = boundaries[seg_idx] + seg_end_idx = boundaries[seg_idx + 1] + + # Get the time boundaries of this gesture segment + gesture_start_time = start_times[seg_start_idx] + gesture_end_time = end_times[seg_end_idx - 1] # Last window's end time + + # Mark windows in transition zones + for i in 
range(seg_start_idx, seg_end_idx): + window_start = start_times[i] + window_end = end_times[i] + + # Check if window is too close to gesture START (reaction time zone) + if window_start < gesture_start_time + transition_start_sec: + keep_mask[i] = False + + # Check if window is too close to gesture END (anticipation zone) + if window_end > gesture_end_time - transition_end_sec: + keep_mask[i] = False + + # Apply filter + X_filtered = X[keep_mask] + y_filtered = y[keep_mask] + labels_filtered = [l for l, keep in zip(labels, keep_mask) if keep] + + n_removed = len(X) - len(X_filtered) + if n_removed > 0: + print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)") + print(f"[Filter] Kept {len(X_filtered)} windows for training") + + return X_filtered, y_filtered, labels_filtered + + # ============================================================================= # SESSION STORAGE (Save/Load labeled data to HDF5) # ============================================================================= @@ -527,16 +718,24 @@ class SessionStorage: windows: list[EMGWindow], labels: list[str], metadata: SessionMetadata, - raw_samples: Optional[list[EMGSample]] = None + raw_samples: Optional[list[EMGSample]] = None, + session_start_time: Optional[float] = None, + enable_alignment: bool = ENABLE_LABEL_ALIGNMENT ) -> Path: """ - Save a collection session to HDF5. + Save a collection session to HDF5 with optional label alignment. + + When raw_samples and session_start_time are provided and enable_alignment + is True, automatically detects EMG onset and corrects labels for human + reaction time delay. 
Args: windows: List of EMGWindow objects (no label info) labels: List of gesture labels, parallel to windows metadata: Session metadata - raw_samples: Optional raw samples for debugging + raw_samples: Raw samples (required for alignment) + session_start_time: When session started (required for alignment) + enable_alignment: Whether to perform automatic label alignment """ filepath = self.get_session_filepath(metadata.session_id) @@ -549,6 +748,35 @@ class SessionStorage: window_samples = len(windows[0].samples) num_channels = len(windows[0].samples[0].channels) + # Prepare timing arrays + start_times = np.array([w.start_time for w in windows], dtype=np.float64) + end_times = np.array([w.end_time for w in windows], dtype=np.float64) + + # Label alignment using onset detection + aligned_labels = labels + original_labels = labels + + if enable_alignment and raw_samples and len(raw_samples) > 0: + print("[Storage] Aligning labels to EMG onset...") + + raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64) + raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32) + + aligned_labels = align_labels_with_onset( + labels=labels, + window_start_times=start_times, + raw_timestamps=raw_timestamps, + raw_channels=raw_channels, + sampling_rate=metadata.sampling_rate, + threshold_factor=ONSET_THRESHOLD, + search_ms=ONSET_SEARCH_MS + ) + + changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b) + print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted") + elif enable_alignment: + print("[Storage] Warning: No raw samples, skipping alignment") + with h5py.File(filepath, 'w') as f: # Metadata as attributes f.attrs['user_id'] = metadata.user_id @@ -568,19 +796,24 @@ class SessionStorage: emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32) windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4) - # Labels stored separately from window data - max_label_len 
= max(len(l) for l in labels) + # Store ALIGNED labels as primary (what training will use); dtype is sized for both label sets + max_label_len = max(len(l) for l in [*aligned_labels, *original_labels]) dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1) - windows_grp.create_dataset('labels', data=labels, dtype=dt) + windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt) + + # Also store original labels for reference/debugging + windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt) window_ids = np.array([w.window_id for w in windows], dtype=np.int32) windows_grp.create_dataset('window_ids', data=window_ids) - start_times = np.array([w.start_time for w in windows], dtype=np.float64) - end_times = np.array([w.end_time for w in windows], dtype=np.float64) windows_grp.create_dataset('start_times', data=start_times) windows_grp.create_dataset('end_times', data=end_times) + # Store alignment metadata + f.attrs['alignment_enabled'] = enable_alignment + f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none' + if raw_samples: raw_grp = f.create_group('raw_samples') timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64) @@ -655,13 +888,21 @@ class SessionStorage: print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types") return windows, labels_out, metadata - def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]: - """Load a single session in ML-ready format: X, y, label_names.""" + def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]: + """ + Load a single session in ML-ready format: X, y, label_names. 
+ + Args: + session_id: The session to load + filter_transitions: If True, remove windows in transition zones (default from config) + """ filepath = self.get_session_filepath(session_id) with h5py.File(filepath, 'r') as f: X = f['windows/emg_data'][:] labels_raw = f['windows/labels'][:] + start_times = f['windows/start_times'][:] + end_times = f['windows/end_times'][:] labels = [] for l in labels_raw: @@ -670,18 +911,33 @@ class SessionStorage: else: labels.append(l) + print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)") + + # Apply transition filtering if enabled + if filter_transitions: + label_names_pre = sorted(set(labels)) + label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)} + y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32) + + X, y_pre, labels = filter_transition_windows( + X, y_pre, labels, start_times, end_times + ) + label_names = sorted(set(labels)) label_to_idx = {name: idx for idx, name in enumerate(label_names)} y = np.array([label_to_idx[l] for l in labels], dtype=np.int32) - print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}") + print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}") print(f"[Storage] Labels: {label_names}") return X, y, label_names - def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]: + def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]: """ Load ALL sessions combined into a single training dataset. 
+ Args: + filter_transitions: If True, remove windows in transition zones (default from config) + Returns: X: Combined EMG windows from all sessions (n_total_windows, samples, channels) y: Combined labels as integers (n_total_windows,) @@ -697,11 +953,15 @@ class SessionStorage: raise ValueError("No sessions found to load!") print(f"[Storage] Loading {len(sessions)} session(s) for combined training...") + if filter_transitions: + print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms") all_X = [] all_labels = [] loaded_sessions = [] reference_shape = None + total_removed = 0 + total_original = 0 for session_id in sessions: filepath = self.get_session_filepath(session_id) @@ -709,6 +969,8 @@ class SessionStorage: with h5py.File(filepath, 'r') as f: X = f['windows/emg_data'][:] labels_raw = f['windows/labels'][:] + start_times = f['windows/start_times'][:] + end_times = f['windows/end_times'][:] # Validate shape compatibility if reference_shape is None: @@ -725,10 +987,26 @@ class SessionStorage: else: labels.append(l) + original_count = len(X) + total_original += original_count + + # Apply transition filtering per session (each has its own gesture boundaries) + if filter_transitions: + # Need temporary y for filtering function + temp_label_names = sorted(set(labels)) + temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)} + temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32) + + X, temp_y, labels = filter_transition_windows( + X, temp_y, labels, start_times, end_times + ) + total_removed += original_count - len(X) + all_X.append(X) all_labels.extend(labels) loaded_sessions.append(session_id) - print(f"[Storage] - {session_id}: {X.shape[0]} windows") + print(f"[Storage] - {session_id}: {len(X)} windows" + + (f" (was {original_count})" if filter_transitions and len(X) != original_count else "")) if not all_X: raise ValueError("No compatible sessions found!") @@ -742,6 +1020,8 @@ 
class SessionStorage: y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32) print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}") + if filter_transitions and total_removed > 0: + print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)") print(f"[Storage] Labels: {label_names}") print(f"[Storage] Sessions loaded: {len(loaded_sessions)}") diff --git a/learning_emg_filtering.py b/learning_emg_filtering.py index 88b3c7b..b257ef7 100644 --- a/learning_emg_filtering.py +++ b/learning_emg_filtering.py @@ -1,41 +1,152 @@ import numpy as np import matplotlib.pyplot as plt import h5py -import pandas as pd from scipy.signal import butter, sosfiltfilt # ============================================================================= # CONFIGURABLE PARAMETERS # ============================================================================= -ZC_THRESHOLD_PERCENT = 0.6 # Zero Crossing threshold as fraction of RMS -SSC_THRESHOLD_PERCENT = 0.3 # Slope Sign Change threshold as fraction of RMS +ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS +SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS -file = h5py.File("assets/discrete_gestures_user_000_dataset_000.hdf5", "r") +# ============================================================================= +# LOAD DATA FROM GUI's HDF5 FORMAT +# ============================================================================= +# Update this path to your collected session file +HDF5_PATH = "collected_data\latency_fix_106_20260127_200043.hdf5" +file = h5py.File(HDF5_PATH, "r") + +# Print HDF5 structure for debugging +print("HDF5 Structure:") def print_tree(name, obj): - print(name) - + print(f" {name}") file.visititems(print_tree) +print() -print(list(file.keys())) +# Load metadata +fs = file.attrs['sampling_rate'] # Sampling rate in Hz +n_channels = file.attrs['num_channels'] +window_size_ms = 
file.attrs['window_size_ms'] +print(f"Sampling rate: {fs} Hz") +print(f"Channels: {n_channels}") +print(f"Window size: {window_size_ms} ms") -data = file["data"] -print(type(data)) -print(data) -print(data.dtype) +# Load windowed EMG data: shape (n_windows, samples_per_window, channels) +emg_windows = file['windows/emg_data'][:] +labels = file['windows/labels'][:] +start_times = file['windows/start_times'][:] +end_times = file['windows/end_times'][:] -raw = data[:] -print(raw.shape) -print(raw.dtype) +print(f"Windows shape: {emg_windows.shape}") +print(f"Labels: {len(labels)} (unique: {np.unique(labels)})") -emg = raw['emg'] -time = raw['time'] -print(emg.shape) +# Flatten windows to continuous signal for filtering analysis +# Shape: (n_windows * samples_per_window, channels) +n_windows, samples_per_window, _ = emg_windows.shape +emg = emg_windows.reshape(-1, n_channels) +print(f"Flattened EMG shape: {emg.shape}") -dt = np.diff(time) -fs = 1.0 / np.median(dt) -print("fs =", fs) -print("dt min/median/max =", dt.min(), np.median(dt), dt.max()) +# Reconstruct time vector (assumes continuous recording) +total_samples = emg.shape[0] +time = np.arange(total_samples) / fs +print(f"Time range: {time[0]:.2f}s to {time[-1]:.2f}s") + +# ============================================================================= +# PLOT RAW EMG DATA (before filtering) +# ============================================================================= +print("\nPlotting raw EMG data...") + +# Color map for channels +channel_colors = ['#00ff88', '#ff6b6b', '#4ecdc4', '#ffe66d'] + +# Map window indices to actual time in the flattened data +# Window i starts at sample (i * samples_per_window), which is time (i * samples_per_window / fs) +window_time_in_data = np.arange(len(labels)) * samples_per_window / fs + +# Compute global Y range across all channels for consistent comparison +emg_global_min = emg.min() +emg_global_max = emg.max() +emg_margin = (emg_global_max - emg_global_min) * 0.05 # 5% 
margin +emg_ylim = (emg_global_min - emg_margin, emg_global_max + emg_margin) +print(f"Global EMG range: {emg_global_min:.1f} to {emg_global_max:.1f} mV") + +# Full session plot with shared Y-axis +fig_raw, axes_raw = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), sharex=True, sharey=True) +if n_channels == 1: + axes_raw = [axes_raw] + +for ch in range(n_channels): + ax = axes_raw[ch] + + # Plot raw EMG signal + ax.plot(time, emg[:, ch], linewidth=0.3, color=channel_colors[ch % len(channel_colors)], alpha=0.8) + ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10) + ax.grid(True, alpha=0.3) + ax.set_xlim(time[0], time[-1]) + ax.set_ylim(emg_ylim) # Same Y range for all channels + +# Add gesture markers AFTER plotting +y_min, y_max = emg_ylim +y_text = y_min + (y_max - y_min) * 0.85 # Position text at 85% height +for ch in range(n_channels): + ax = axes_raw[ch] + + # Add gesture transition markers using window-aligned times + prev_label = None + for i, lbl in enumerate(labels): + lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl + t = window_time_in_data[i] # Time in the flattened data + + if lbl_str != prev_label: + if 'open' in lbl_str: + color = 'cyan' + elif 'fist' in lbl_str: + color = 'blue' + elif 'hook' in lbl_str: + color = 'orange' + elif 'thumb' in lbl_str: + color = 'green' + elif 'rest' in lbl_str: + color = 'gray' + else: + color = 'red' + + ax.axvline(t, color=color, linestyle='--', alpha=0.7, linewidth=1) + + # Only add text label on first channel to avoid clutter + if ch == 0 and lbl_str != 'rest': + ax.text(t + 0.2, y_text, lbl_str, fontsize=8, + color=color, rotation=0, ha='left', va='top', + bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7)) + prev_label = lbl_str + +axes_raw[-1].set_xlabel('Time (s)', fontsize=11) +fig_raw.suptitle('Raw EMG Signal (All Channels)', fontsize=14, fontweight='bold') +plt.tight_layout() +plt.show() + +# Zoomed view of first 2 seconds to see waveform detail +zoom_duration = 2.0 
+zoom_samples = int(zoom_duration * fs) + +if total_samples > zoom_samples: + fig_zoom, axes_zoom = plt.subplots(n_channels, 1, figsize=(14, 2 * n_channels), sharex=True, sharey=True) + if n_channels == 1: + axes_zoom = [axes_zoom] + + for ch in range(n_channels): + ax = axes_zoom[ch] + ax.plot(time[:zoom_samples], emg[:zoom_samples, ch], + linewidth=0.5, color=channel_colors[ch % len(channel_colors)]) + ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10) + ax.set_ylim(emg_ylim) # Same Y range as full plot + ax.grid(True, alpha=0.3) + + axes_zoom[-1].set_xlabel('Time (s)', fontsize=11) + fig_zoom.suptitle(f'Raw EMG Signal (First {zoom_duration}s - Zoomed)', fontsize=14, fontweight='bold') + plt.tight_layout() + plt.show() # 1) pick one channel emg_ch = emg[:, 0].astype(np.float32) @@ -122,10 +233,9 @@ def compute_all_features_windowed(x, window_len, threshold_zc, threshold_ssc): # ============================================================================= -# COMPUTE FEATURES FOR ALL 16 CHANNELS +# COMPUTE FEATURES FOR ALL CHANNELS # ============================================================================= -n_channels = 16 all_features = {} # Non-overlapping windows - for ML for ch in range(n_channels): @@ -143,101 +253,122 @@ for ch in range(n_channels): all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i} # Time vector for windowed features -n_windows = len(all_features[0]['rms']) -window_centers = np.arange(n_windows) * window_samples + window_samples // 2 +n_feat_windows = len(all_features[0]['rms']) +window_centers = np.arange(n_feat_windows) * window_samples + window_samples // 2 time_windows = time_ch[window_centers] print(f"\nComputed features for {n_channels} channels") -print(f"Windows per channel (non-overlapping): {n_windows}") +print(f"Windows per channel (non-overlapping): {n_feat_windows}") print(f"Window size: {window_samples} samples ({window_ms} ms)") -# 5) Load gesture labels (prompts) -prompts = 
pd.read_hdf("assets/discrete_gestures_user_000_dataset_000.hdf5", key="prompts") -print("\nUnique gestures:", prompts['name'].unique()) +# 5) Use embedded gesture labels from HDF5 +# Labels are per-window, aligned with emg_windows +unique_labels = np.unique(labels) +print(f"\nUnique gestures in session: {unique_labels}") -t_abs = time_ch # absolute timestamps - -# Find first occurrence of each gesture type -index_gestures = prompts[prompts['name'].str.contains('index')] -middle_gestures = prompts[prompts['name'].str.contains('middle')] -thumb_gestures = prompts[prompts['name'].str.contains('thumb')] - -# Define plot configurations: 1) Index+Middle combined, 2) Thumb separate -plot_configs = [ - {'name': 'Index & Middle Finger', 'start_time': index_gestures['time'].iloc[0], 'filter': 'index|middle'}, - {'name': 'Thumb', 'start_time': thumb_gestures['time'].iloc[0], 'filter': 'thumb'}, -] +# Map labels to time in the flattened data (same as raw plot) +# Each original window i maps to time (i * samples_per_window / fs) in the flattened data +# But we dropped `drop` samples, so adjust accordingly +label_times_in_data = np.arange(len(labels)) * samples_per_window / fs +print(f"Data duration: {time[-1]:.2f}s ({len(labels)} windows)") # Color function for markers def get_gesture_color(name): - if 'index' in name: - return 'green' - elif 'middle' in name: + """Assign colors to gesture types.""" + name_str = name.decode('utf-8') if isinstance(name, bytes) else name + if 'open' in name_str: + return 'cyan' + elif 'fist' in name_str: return 'blue' - elif 'thumb' in name: + elif 'hook' in name_str: return 'orange' - return 'gray' + elif 'thumb' in name_str: + return 'green' + elif 'rest' in name_str: + return 'gray' + return 'red' # ============================================================================= -# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL 16 CHANNELS +# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL CHANNELS # 
============================================================================= feature_names = ['rms', 'wl', 'zc', 'ssc'] feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)'] feature_colors = ['red', 'blue', 'green', 'purple'] -feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count'] +feature_ylabels = ['Amplitude (mV)', 'WL (a.u.)', 'Count', 'Count'] -for config in plot_configs: - t_start = config['start_time'] - 0.5 - t_end = t_start + 10.0 +# Determine subplot grid based on channel count +if n_channels <= 4: + n_rows, n_cols = 2, 2 +elif n_channels <= 9: + n_rows, n_cols = 3, 3 +else: + n_rows, n_cols = 4, 4 - # Mask for windowed time vector - mask_win = (time_windows >= t_start) & (time_windows <= t_end) - t_win_rel = time_windows[mask_win] - t_start +# Plot entire session for each feature +for feat_idx, feat_name in enumerate(feature_names): + fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), sharex=True, sharey=True) + axes = axes.flatten() - # Get gestures in window - gesture_mask = (prompts['time'] >= t_start) & (prompts['time'] <= t_end) & \ - (prompts['name'].str.contains(config['filter'])) - gestures_in_window = prompts[gesture_mask] + # Compute global Y range for this feature across all channels + all_feat_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)]) + feat_min, feat_max = all_feat_vals.min(), all_feat_vals.max() + feat_margin = (feat_max - feat_min) * 0.05 if feat_max > feat_min else 1 + feat_ylim = (feat_min - feat_margin, feat_max + feat_margin) - # Create one figure per feature - for feat_idx, feat_name in enumerate(feature_names): - fig, axes = plt.subplots(4, 4, figsize=(10, 8), sharex=True, sharey=True) - axes = axes.flatten() + for ch in range(n_channels): + ax = axes[ch] - for ch in range(n_channels): - ax = axes[ch] + # Plot windowed feature over time + feat_data = all_features[ch][feat_name] + ax.plot(time_windows, feat_data, 
linewidth=0.8, color=feature_colors[feat_idx]) + ax.set_title(f"Channel {ch}", fontsize=10) + ax.set_ylim(feat_ylim) # Same Y range for all channels + ax.grid(True, alpha=0.3) - # Plot windowed feature - feat_data = all_features[ch][feat_name][mask_win] - ax.plot(t_win_rel, feat_data, linewidth=1, color=feature_colors[feat_idx]) - ax.set_title(f"Ch {ch}", fontsize=9) + # Add gesture transition markers from labels + prev_label = None + for i, lbl in enumerate(labels): + lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl + t = label_times_in_data[i] # Time in flattened data + if lbl_str != prev_label and lbl_str != 'rest': + # Only show markers within the time_windows range + if t <= time_windows[-1]: + color = get_gesture_color(lbl) + ax.axvline(t, color=color, linestyle='--', alpha=0.6, linewidth=1) + prev_label = lbl_str - # Add gesture markers - for _, row in gestures_in_window.iterrows(): - t_g = row['time'] - t_start - color = get_gesture_color(row['name']) - ax.axvline(t_g, color=color, linestyle='--', alpha=0.5, linewidth=0.5) + # Hide unused subplots + for ch in range(n_channels, len(axes)): + axes[ch].set_visible(False) - # Set subtitle based on gesture type - if 'Index' in config['name']: - subtitle = "(Green=index, Blue=middle)" - else: - subtitle = "(Orange=thumb)" + # Legend for gesture colors + from matplotlib.lines import Line2D + legend_elements = [ + Line2D([0], [0], color='cyan', linestyle='--', label='Open'), + Line2D([0], [0], color='blue', linestyle='--', label='Fist'), + Line2D([0], [0], color='orange', linestyle='--', label='Hook Em'), + Line2D([0], [0], color='green', linestyle='--', label='Thumbs Up'), + ] + fig.legend(handles=legend_elements, loc='upper right', fontsize=9) - fig.suptitle(f"{feature_titles[feat_idx]} - {config['name']} Gestures\n{subtitle}", fontsize=12) - fig.supxlabel("Time (s)") - fig.supylabel(feature_ylabels[feat_idx]) - plt.tight_layout() - plt.show() + fig.suptitle(f"{feature_titles[feat_idx]} - All 
Channels", fontsize=14) + fig.supxlabel("Time (s)") + fig.supylabel(feature_ylabels[feat_idx]) + plt.tight_layout() + plt.show() # ============================================================================= # SUMMARY: Feature statistics across all channels # ============================================================================= print("\n" + "=" * 60) -print("FEATURE SUMMARY (all channels, all windows)") +print(f"FEATURE SUMMARY ({n_channels} channels, {n_feat_windows} windows each)") print("=" * 60) for feat_name in ['rms', 'wl', 'zc', 'ssc']: all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)]) print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | " - f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}") \ No newline at end of file + f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}") + +# Close the HDF5 file +file.close() +print("\n[Done] HDF5 file closed.") \ No newline at end of file diff --git a/models/emg_lda_classifier.joblib b/models/emg_lda_classifier.joblib index 76556d0..5f2f6fc 100644 Binary files a/models/emg_lda_classifier.joblib and b/models/emg_lda_classifier.joblib differ diff --git a/serial_stream.py b/serial_stream.py index be007b0..d3b32a7 100644 --- a/serial_stream.py +++ b/serial_stream.py @@ -99,13 +99,13 @@ class RealSerialStream: @note Requires pyserial: pip install pyserial """ - def __init__(self, port: str = None, baud_rate: int = 115200, timeout: float = 1.0): + def __init__(self, port: str = None, baud_rate: int = 921600, timeout: float = 0.05): """ @brief Initialize the serial stream. @param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux). If None, will attempt to auto-detect the ESP32. - @param baud_rate Communication speed in bits per second. Default 115200 matches ESP32. + @param baud_rate Communication speed in bits per second. Default 921600 for high-throughput streaming. 
@param timeout Read timeout in seconds for readline(). """ self.port = port @@ -233,6 +233,9 @@ class RealSerialStream: "Must call connect() first." ) + # Flush any stale data before starting fresh stream + self.serial.reset_input_buffer() + # Send start command start_cmd = {"cmd": "start"} self._send_json(start_cmd)