bunch of training and label latency fixes, also trained 70% accurate model.

This commit is contained in:
Surya Balaji
2026-01-27 20:12:13 -06:00
parent f656f466e7
commit 9bdf9d1109
41 changed files with 563 additions and 124 deletions

View File

@@ -12,6 +12,6 @@ board_upload.maximum_size = 33554432
; and I can give you the content for it. ; and I can give you the content for it.
board_build.partitions = partitions.csv board_build.partitions = partitions.csv
monitor_speed = 115200 monitor_speed = 921600
monitor_dtr = 1 monitor_dtr = 1
monitor_rts = 1 monitor_rts = 1

View File

@@ -232,9 +232,8 @@ static void stream_emg_data(void)
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */ /* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
emg_sensor_read(&sample); emg_sensor_read(&sample);
/* Output in CSV format matching Python expectation */ /* Output in CSV format - channels only, Python handles timestamps */
printf("%lu,%u,%u,%u,%u\n", printf("%u,%u,%u,%u\n",
(unsigned long)sample.timestamp_ms,
sample.channels[0], sample.channels[0],
sample.channels[1], sample.channels[1],
sample.channels[2], sample.channels[2],
@@ -283,7 +282,7 @@ void emgPrinter() {
if (i != EMG_NUM_CHANNELS - 1) printf(" | "); if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
} }
printf("\n"); printf("\n");
vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100)); // vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
} }
} }
@@ -339,6 +338,6 @@ void app_main(void)
printf("[INIT] Done!\n\n"); printf("[INIT] Done!\n\n");
emgPrinter(); // emgPrinter();
// appConnector(); appConnector();
} }

View File

@@ -63,7 +63,7 @@
* EMG Configuration * EMG Configuration
******************************************************************************/ ******************************************************************************/
#define EMG_NUM_CHANNELS 1 /**< Number of EMG sensor channels */ #define EMG_NUM_CHANNELS 4 /**< Number of EMG sensor channels */
#define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */ #define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */
/******************************************************************************* /*******************************************************************************

View File

@@ -19,7 +19,12 @@
adc_oneshot_unit_handle_t adc1_handle; adc_oneshot_unit_handle_t adc1_handle;
adc_cali_handle_t cali_handle = NULL; adc_cali_handle_t cali_handle = NULL;
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {ADC_CHANNEL_1}; const uint8_t emg_channels[EMG_NUM_CHANNELS] = {
ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0
ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1
ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2
ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3
};
/******************************************************************************* /*******************************************************************************
* Public Functions * Public Functions
@@ -94,7 +99,6 @@ void emg_sensor_read(emg_sample_t *sample)
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv)); ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
sample->channels[i] = (uint16_t) voltage_mv; sample->channels[i] = (uint16_t) voltage_mv;
} }
printf("\n");
#endif #endif
} }

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -301,6 +301,7 @@ class CollectionPage(BasePage):
self.scheduler = None self.scheduler = None
self.collected_windows = [] self.collected_windows = []
self.collected_labels = [] self.collected_labels = []
self.collected_raw_samples = [] # For label alignment
self.sample_buffer = [] self.sample_buffer = []
self.collection_thread = None self.collection_thread = None
self.data_queue = queue.Queue() self.data_queue = queue.Queue()
@@ -511,7 +512,7 @@ class CollectionPage(BasePage):
ax.tick_params(colors='white') ax.tick_params(colors='white')
ax.set_ylabel(f'Ch{i}', color='white', fontsize=10) ax.set_ylabel(f'Ch{i}', color='white', fontsize=10)
ax.set_xlim(0, 500) ax.set_xlim(0, 500)
ax.set_ylim(0, 1024) ax.set_ylim(0, 3300) # ESP32 outputs millivolts (0-3100 mV)
ax.grid(True, alpha=0.3) ax.grid(True, alpha=0.3)
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color('white') spine.set_color('white')
@@ -648,6 +649,7 @@ class CollectionPage(BasePage):
# Reset state # Reset state
self.collected_windows = [] self.collected_windows = []
self.collected_labels = [] self.collected_labels = []
self.collected_raw_samples = [] # Store raw samples for label alignment
self.sample_buffer = [] self.sample_buffer = []
print("[DEBUG] Reset collection state") print("[DEBUG] Reset collection state")
@@ -800,6 +802,9 @@ class CollectionPage(BasePage):
timeout_warning_sent = False timeout_warning_sent = False
sample = self.parser.parse_line(line) sample = self.parser.parse_line(line)
if sample: if sample:
# Store raw sample for label alignment
self.collected_raw_samples.append(sample)
# Batch samples for plotting (don't send every single one) # Batch samples for plotting (don't send every single one)
sample_batch.append(sample.channels) sample_batch.append(sample.channels)
@@ -932,9 +937,25 @@ class CollectionPage(BasePage):
notes="" notes=""
) )
filepath = storage.save_session(self.collected_windows, self.collected_labels, metadata) # Get session start time for label alignment
session_start_time = None
if self.scheduler and self.scheduler.session_start_time:
session_start_time = self.scheduler.session_start_time
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}") filepath = storage.save_session(
windows=self.collected_windows,
labels=self.collected_labels,
metadata=metadata,
raw_samples=self.collected_raw_samples if self.collected_raw_samples else None,
session_start_time=session_start_time
)
# Check if alignment was performed
alignment_msg = ""
if session_start_time and self.collected_raw_samples:
alignment_msg = "\n\nLabel alignment: enabled"
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}{alignment_msg}")
# Update sidebar # Update sidebar
app = self.winfo_toplevel() app = self.winfo_toplevel()
@@ -944,6 +965,7 @@ class CollectionPage(BasePage):
# Reset for next collection # Reset for next collection
self.collected_windows = [] self.collected_windows = []
self.collected_labels = [] self.collected_labels = []
self.collected_raw_samples = []
self.save_button.configure(state="disabled") self.save_button.configure(state="disabled")
self.status_label.configure(text="Ready to collect") self.status_label.configure(text="Ready to collect")
self.window_count_label.configure(text="Windows: 0") self.window_count_label.configure(text="Windows: 0")

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -39,7 +39,7 @@ import matplotlib.pyplot as plt
# ============================================================================= # =============================================================================
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors) NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
SERIAL_BAUD = 115200 # Typical baud rate for ESP32 SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
# Windowing configuration # Windowing configuration
WINDOW_SIZE_MS = 150 # Window size in milliseconds WINDOW_SIZE_MS = 150 # Window size in milliseconds
@@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files
MODEL_DIR = Path("models") # Directory to store trained models MODEL_DIR = Path("models") # Directory to store trained models
USER_ID = "user_001" # Current user ID (change per user) USER_ID = "user_001" # Current user ID (change per user)
# =============================================================================
# LABEL ALIGNMENT CONFIGURATION
# =============================================================================
# Human reaction time causes EMG activity to lag behind label prompts.
# We detect when EMG actually rises and shift labels to match.
# SessionStorage.save_session() uses ENABLE_LABEL_ALIGNMENT as the default for
# its enable_alignment argument and passes the two ONSET_* values to
# align_labels_with_onset().
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
ONSET_THRESHOLD = 2 # Onset fires when the envelope exceeds baseline mean + ONSET_THRESHOLD * baseline std
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms); note align_labels_with_onset() itself defaults to 800 ms
# =============================================================================
# TRANSITION WINDOW FILTERING
# =============================================================================
# Windows near gesture transitions contain ambiguous data (reaction time at start,
# muscle relaxation at end). Discard these during training for cleaner labels.
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
# DISCARD_TRANSITION_WINDOWS is the default for the filter_transitions argument
# of the training loaders; the two TRANSITION_* values are the defaults of
# filter_transition_windows().
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
TRANSITION_START_MS = 300 # Discard windows starting less than this many ms AFTER the gesture segment starts
TRANSITION_END_MS = 150 # Discard windows ending less than this many ms BEFORE the gesture segment ends
# ============================================================================= # =============================================================================
# DATA STRUCTURES # DATA STRUCTURES
# ============================================================================= # =============================================================================
@@ -187,30 +208,27 @@ class EMGParser:
""" """
Parse a line from ESP32 into an EMGSample. Parse a line from ESP32 into an EMGSample.
Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n" Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
Python assigns timestamp on receipt for label alignment.
Returns None if parsing fails. Returns None if parsing fails.
""" """
try: try:
# Strip whitespace and split # Strip whitespace and split
parts = line.strip().split(',') parts = line.strip().split(',')
# Validate we have correct number of fields # Validate we have correct number of fields (channels only)
expected_fields = 1 + self.num_channels # timestamp + channels if len(parts) != self.num_channels:
if len(parts) != expected_fields:
self.parse_errors += 1 self.parse_errors += 1
return None return None
# Parse ESP32 timestamp
esp_timestamp_ms = int(parts[0])
# Parse channel values # Parse channel values
channels = [float(parts[i + 1]) for i in range(self.num_channels)] channels = [float(parts[i]) for i in range(self.num_channels)]
# Create sample with Python-side timestamp # Create sample with Python-side timestamp (aligned with label clock)
sample = EMGSample( sample = EMGSample(
timestamp=time.perf_counter(), # High-resolution monotonic clock timestamp=time.perf_counter(), # High-resolution monotonic clock
channels=channels, channels=channels,
esp_timestamp_ms=esp_timestamp_ms esp_timestamp_ms=None # No longer using ESP32 timestamp
) )
self.samples_parsed += 1 self.samples_parsed += 1
@@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream):
time.sleep(interval) time.sleep(interval)
# =============================================================================
# LABEL ALIGNMENT (Simple Onset Detection)
# =============================================================================
from scipy.signal import butter, sosfiltfilt
def align_labels_with_onset(
    labels: list[str],
    window_start_times: np.ndarray,
    raw_timestamps: np.ndarray,
    raw_channels: np.ndarray,
    sampling_rate: int,
    threshold_factor: float = 2.0,
    search_ms: float = 800
) -> list[str]:
    """
    Align labels to EMG onset by detecting when the signal rises above baseline.

    Algorithm:
        1. High-pass filter (20 Hz) each channel to remove the DC offset.
        2. Compute the RMS envelope across channels and low-pass it (10 Hz).
        3. At each label transition, take the 200 ms of envelope before the
           prompt as baseline and search forward for the first sample that
           exceeds ``mean(baseline) + threshold_factor * std(baseline)``.
        4. Move the label boundary to that sample's timestamp. If no onset is
           found inside the search window (or there is no baseline), the
           boundary is placed 0.3 s after the prompt.

    Args:
        labels: Prompted label per window, parallel to ``window_start_times``.
        window_start_times: Start time of each window, in the same clock as
            ``raw_timestamps``.
        raw_timestamps: Timestamp of each raw sample; assumed sorted ascending
            (required by ``np.searchsorted``).
        raw_channels: Raw samples, shape (n_samples, n_channels). A 1-D array
            is treated as a single channel.
        sampling_rate: Raw sampling rate in Hz.
        threshold_factor: Number of baseline standard deviations above the
            baseline mean that counts as an onset.
        search_ms: How far after each prompt to look for an onset, in ms.

    Returns:
        A new list of labels, one per window, with transitions shifted to the
        detected onsets. If the raw signal cannot be filtered (for example it
        is shorter than the filter's padding), the prompted labels are
        returned unchanged.
    """
    if len(labels) == 0:
        return list(labels)

    fallback_delay_s = 0.3   # assumed reaction time when no onset is detected
    baseline_ms = 200        # envelope span before the prompt used as baseline

    # Work in float64 regardless of input dtype so the filtered signal is not
    # truncated (integer input) or degraded (float32 input).
    signal = np.asarray(raw_channels, dtype=np.float64)
    if signal.ndim == 1:
        signal = signal[:, np.newaxis]

    try:
        nyquist = sampling_rate / 2
        # High-pass to remove DC (raw ADC has ~2340mV offset)
        sos_hp = butter(2, 20.0 / nyquist, btype='high', output='sos')
        filtered = sosfiltfilt(sos_hp, signal, axis=0)
        # RMS across channels, then smooth
        envelope = np.sqrt(np.mean(filtered ** 2, axis=1))
        sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')
        envelope = sosfiltfilt(sos_lp, envelope)
    except (ValueError, ZeroDivisionError) as exc:
        # Too few samples for zero-phase filtering, or an unusable sampling
        # rate: alignment is impossible, so keep the prompted labels rather
        # than aborting the caller.
        print(f"[Align] Warning: could not filter raw signal ({exc}); labels left unaligned")
        return list(labels)

    search_samples = int(search_ms / 1000 * sampling_rate)
    baseline_samples = int(baseline_ms / 1000 * sampling_rate)

    # Detect one boundary (onset_time, new_label) per label transition
    boundaries = []
    for i in range(1, len(labels)):
        if labels[i] == labels[i - 1]:
            continue

        prompt_time = window_start_times[i]
        # Index of the first raw sample at or after the prompt
        prompt_idx = int(np.searchsorted(raw_timestamps, prompt_time))

        onset_time = prompt_time + fallback_delay_s
        baseline = envelope[max(0, prompt_idx - baseline_samples):prompt_idx]
        if len(baseline) > 0:
            threshold = np.mean(baseline) + threshold_factor * np.std(baseline)
            search_end = max(prompt_idx, min(len(envelope), prompt_idx + search_samples))
            above = np.flatnonzero(envelope[prompt_idx:search_end] > threshold)
            if above.size > 0:
                onset_time = raw_timestamps[prompt_idx + int(above[0])]

        boundaries.append((onset_time, labels[i]))

    # Re-label every window from the detected boundaries
    aligned = []
    boundary_idx = 0
    current_label = labels[0]
    for t in window_start_times:
        while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
            current_label = boundaries[boundary_idx][1]
            boundary_idx += 1
        aligned.append(current_label)

    return aligned
def filter_transition_windows(
    X: np.ndarray,
    y: np.ndarray,
    labels: list[str],
    start_times: np.ndarray,
    end_times: np.ndarray,
    transition_start_ms: float = TRANSITION_START_MS,
    transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """
    Filter out windows that fall within transition zones at gesture boundaries.

    A gesture segment is a maximal run of consecutive windows sharing the same
    label. Within each segment a window is dropped when it starts less than
    ``transition_start_ms`` after the segment's first window starts (user is
    still reacting to the prompt), or ends less than ``transition_end_ms``
    before the segment's last window ends (user is anticipating the next
    gesture).

    Args:
        X: EMG data array (n_windows, samples, channels)
        y: Label indices (n_windows,)
        labels: String labels (n_windows,)
        start_times: Window start times in seconds (n_windows,)
        end_times: Window end times in seconds (n_windows,)
        transition_start_ms: Discard windows within this time after gesture start
        transition_end_ms: Discard windows within this time before gesture end

    Returns:
        Filtered (X, y, labels) with transition windows removed. Empty input
        is returned unchanged.
    """
    if len(X) == 0:
        return X, y, labels

    transition_start_sec = transition_start_ms / 1000.0
    transition_end_sec = transition_end_ms / 1000.0

    starts = np.asarray(start_times, dtype=np.float64)
    ends = np.asarray(end_times, dtype=np.float64)

    # Segment edges: index 0, every index where the label changes, and the end marker
    n_labels = len(labels)
    change_points = [i for i in range(1, n_labels) if labels[i] != labels[i - 1]]
    seg_edges = [0] + change_points + [n_labels]

    keep_mask = np.ones(len(X), dtype=bool)
    for seg_lo, seg_hi in zip(seg_edges[:-1], seg_edges[1:]):
        gesture_start_time = starts[seg_lo]
        gesture_end_time = ends[seg_hi - 1]  # Last window's end time

        # Reaction-time zone at the start, anticipation zone at the end
        in_start_zone = starts[seg_lo:seg_hi] < gesture_start_time + transition_start_sec
        in_end_zone = ends[seg_lo:seg_hi] > gesture_end_time - transition_end_sec
        keep_mask[seg_lo:seg_hi] = ~(in_start_zone | in_end_zone)

    # Apply filter
    X_filtered = X[keep_mask]
    y_filtered = y[keep_mask]
    labels_filtered = [label for label, keep in zip(labels, keep_mask) if keep]

    n_removed = len(X) - len(X_filtered)
    if n_removed > 0:
        print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
        print(f"[Filter] Kept {len(X_filtered)} windows for training")

    return X_filtered, y_filtered, labels_filtered
# ============================================================================= # =============================================================================
# SESSION STORAGE (Save/Load labeled data to HDF5) # SESSION STORAGE (Save/Load labeled data to HDF5)
# ============================================================================= # =============================================================================
@@ -527,16 +718,24 @@ class SessionStorage:
windows: list[EMGWindow], windows: list[EMGWindow],
labels: list[str], labels: list[str],
metadata: SessionMetadata, metadata: SessionMetadata,
raw_samples: Optional[list[EMGSample]] = None raw_samples: Optional[list[EMGSample]] = None,
session_start_time: Optional[float] = None,
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
) -> Path: ) -> Path:
""" """
Save a collection session to HDF5. Save a collection session to HDF5 with optional label alignment.
When raw_samples and session_start_time are provided and enable_alignment
is True, automatically detects EMG onset and corrects labels for human
reaction time delay.
Args: Args:
windows: List of EMGWindow objects (no label info) windows: List of EMGWindow objects (no label info)
labels: List of gesture labels, parallel to windows labels: List of gesture labels, parallel to windows
metadata: Session metadata metadata: Session metadata
raw_samples: Optional raw samples for debugging raw_samples: Raw samples (required for alignment)
session_start_time: When session started (required for alignment)
enable_alignment: Whether to perform automatic label alignment
""" """
filepath = self.get_session_filepath(metadata.session_id) filepath = self.get_session_filepath(metadata.session_id)
@@ -549,6 +748,35 @@ class SessionStorage:
window_samples = len(windows[0].samples) window_samples = len(windows[0].samples)
num_channels = len(windows[0].samples[0].channels) num_channels = len(windows[0].samples[0].channels)
# Prepare timing arrays
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
# Label alignment using onset detection
aligned_labels = labels
original_labels = labels
if enable_alignment and raw_samples and len(raw_samples) > 0:
print("[Storage] Aligning labels to EMG onset...")
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
aligned_labels = align_labels_with_onset(
labels=labels,
window_start_times=start_times,
raw_timestamps=raw_timestamps,
raw_channels=raw_channels,
sampling_rate=metadata.sampling_rate,
threshold_factor=ONSET_THRESHOLD,
search_ms=ONSET_SEARCH_MS
)
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
elif enable_alignment:
print("[Storage] Warning: No raw samples, skipping alignment")
with h5py.File(filepath, 'w') as f: with h5py.File(filepath, 'w') as f:
# Metadata as attributes # Metadata as attributes
f.attrs['user_id'] = metadata.user_id f.attrs['user_id'] = metadata.user_id
@@ -568,19 +796,24 @@ class SessionStorage:
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32) emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4) windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
# Labels stored separately from window data # Store ALIGNED labels as primary (what training will use)
max_label_len = max(len(l) for l in labels) max_label_len = max(len(l) for l in aligned_labels)
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1) dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
windows_grp.create_dataset('labels', data=labels, dtype=dt) windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
# Also store original labels for reference/debugging
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
window_ids = np.array([w.window_id for w in windows], dtype=np.int32) window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
windows_grp.create_dataset('window_ids', data=window_ids) windows_grp.create_dataset('window_ids', data=window_ids)
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
windows_grp.create_dataset('start_times', data=start_times) windows_grp.create_dataset('start_times', data=start_times)
windows_grp.create_dataset('end_times', data=end_times) windows_grp.create_dataset('end_times', data=end_times)
# Store alignment metadata
f.attrs['alignment_enabled'] = enable_alignment
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
if raw_samples: if raw_samples:
raw_grp = f.create_group('raw_samples') raw_grp = f.create_group('raw_samples')
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64) timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
@@ -655,13 +888,21 @@ class SessionStorage:
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types") print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
return windows, labels_out, metadata return windows, labels_out, metadata
def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]: def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""Load a single session in ML-ready format: X, y, label_names.""" """
Load a single session in ML-ready format: X, y, label_names.
Args:
session_id: The session to load
filter_transitions: If True, remove windows in transition zones (default from config)
"""
filepath = self.get_session_filepath(session_id) filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f: with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:] X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:] labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
labels = [] labels = []
for l in labels_raw: for l in labels_raw:
@@ -670,18 +911,33 @@ class SessionStorage:
else: else:
labels.append(l) labels.append(l)
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
# Apply transition filtering if enabled
if filter_transitions:
label_names_pre = sorted(set(labels))
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
X, y_pre, labels = filter_transition_windows(
X, y_pre, labels, start_times, end_times
)
label_names = sorted(set(labels)) label_names = sorted(set(labels))
label_to_idx = {name: idx for idx, name in enumerate(label_names)} label_to_idx = {name: idx for idx, name in enumerate(label_names)}
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32) y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}") print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Labels: {label_names}") print(f"[Storage] Labels: {label_names}")
return X, y, label_names return X, y, label_names
def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]: def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
""" """
Load ALL sessions combined into a single training dataset. Load ALL sessions combined into a single training dataset.
Args:
filter_transitions: If True, remove windows in transition zones (default from config)
Returns: Returns:
X: Combined EMG windows from all sessions (n_total_windows, samples, channels) X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
y: Combined labels as integers (n_total_windows,) y: Combined labels as integers (n_total_windows,)
@@ -697,11 +953,15 @@ class SessionStorage:
raise ValueError("No sessions found to load!") raise ValueError("No sessions found to load!")
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...") print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
if filter_transitions:
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
all_X = [] all_X = []
all_labels = [] all_labels = []
loaded_sessions = [] loaded_sessions = []
reference_shape = None reference_shape = None
total_removed = 0
total_original = 0
for session_id in sessions: for session_id in sessions:
filepath = self.get_session_filepath(session_id) filepath = self.get_session_filepath(session_id)
@@ -709,6 +969,8 @@ class SessionStorage:
with h5py.File(filepath, 'r') as f: with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:] X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:] labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
# Validate shape compatibility # Validate shape compatibility
if reference_shape is None: if reference_shape is None:
@@ -725,10 +987,26 @@ class SessionStorage:
else: else:
labels.append(l) labels.append(l)
original_count = len(X)
total_original += original_count
# Apply transition filtering per session (each has its own gesture boundaries)
if filter_transitions:
# Need temporary y for filtering function
temp_label_names = sorted(set(labels))
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
X, temp_y, labels = filter_transition_windows(
X, temp_y, labels, start_times, end_times
)
total_removed += original_count - len(X)
all_X.append(X) all_X.append(X)
all_labels.extend(labels) all_labels.extend(labels)
loaded_sessions.append(session_id) loaded_sessions.append(session_id)
print(f"[Storage] - {session_id}: {X.shape[0]} windows") print(f"[Storage] - {session_id}: {len(X)} windows" +
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
if not all_X: if not all_X:
raise ValueError("No compatible sessions found!") raise ValueError("No compatible sessions found!")
@@ -742,6 +1020,8 @@ class SessionStorage:
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32) y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}") print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
if filter_transitions and total_removed > 0:
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
print(f"[Storage] Labels: {label_names}") print(f"[Storage] Labels: {label_names}")
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}") print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")

View File

@@ -1,41 +1,152 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import h5py import h5py
import pandas as pd
from scipy.signal import butter, sosfiltfilt from scipy.signal import butter, sosfiltfilt
# ============================================================================= # =============================================================================
# CONFIGURABLE PARAMETERS # CONFIGURABLE PARAMETERS
# ============================================================================= # =============================================================================
ZC_THRESHOLD_PERCENT = 0.6 # Zero Crossing threshold as fraction of RMS ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS
SSC_THRESHOLD_PERCENT = 0.3 # Slope Sign Change threshold as fraction of RMS SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS
file = h5py.File("assets/discrete_gestures_user_000_dataset_000.hdf5", "r") # =============================================================================
# LOAD DATA FROM GUI's HDF5 FORMAT
# =============================================================================
# Update this path to your collected session file
# Forward slash works on Windows and POSIX alike; the previous "\l" was an
# invalid escape sequence that only happened to survive as a literal backslash.
HDF5_PATH = "collected_data/latency_fix_106_20260127_200043.hdf5"
file = h5py.File(HDF5_PATH, "r")
# Print HDF5 structure for debugging
print("HDF5 Structure:")
def print_tree(name, obj): def print_tree(name, obj):
print(name) print(f" {name}")
file.visititems(print_tree) file.visititems(print_tree)
print()
print(list(file.keys())) # Load metadata
fs = file.attrs['sampling_rate'] # Sampling rate in Hz
n_channels = file.attrs['num_channels']
window_size_ms = file.attrs['window_size_ms']
print(f"Sampling rate: {fs} Hz")
print(f"Channels: {n_channels}")
print(f"Window size: {window_size_ms} ms")
data = file["data"] # Load windowed EMG data: shape (n_windows, samples_per_window, channels)
print(type(data)) emg_windows = file['windows/emg_data'][:]
print(data) labels = file['windows/labels'][:]
print(data.dtype) start_times = file['windows/start_times'][:]
end_times = file['windows/end_times'][:]
raw = data[:] print(f"Windows shape: {emg_windows.shape}")
print(raw.shape) print(f"Labels: {len(labels)} (unique: {np.unique(labels)})")
print(raw.dtype)
emg = raw['emg'] # Flatten windows to continuous signal for filtering analysis
time = raw['time'] # Shape: (n_windows * samples_per_window, channels)
print(emg.shape) n_windows, samples_per_window, _ = emg_windows.shape
emg = emg_windows.reshape(-1, n_channels)
print(f"Flattened EMG shape: {emg.shape}")
dt = np.diff(time) # Reconstruct time vector (assumes continuous recording)
fs = 1.0 / np.median(dt) total_samples = emg.shape[0]
print("fs =", fs) time = np.arange(total_samples) / fs
print("dt min/median/max =", dt.min(), np.median(dt), dt.max()) print(f"Time range: {time[0]:.2f}s to {time[-1]:.2f}s")
# =============================================================================
# PLOT RAW EMG DATA (before filtering)
# =============================================================================
print("\nPlotting raw EMG data...")
# Color map for channels (one hex color per trace)
channel_colors = ['#00ff88', '#ff6b6b', '#4ecdc4', '#ffe66d']
# Map window indices to actual time in the flattened data
# Window i starts at sample (i * samples_per_window), which is time (i * samples_per_window / fs)
# NOTE(review): this mapping only matches the recording timeline if the stored
# windows do not overlap and are contiguous - confirm against the recorder.
window_time_in_data = np.arange(len(labels)) * samples_per_window / fs
# Compute global Y range across all channels for consistent comparison
emg_global_min = emg.min()
emg_global_max = emg.max()
emg_margin = (emg_global_max - emg_global_min) * 0.05  # 5% margin
# (lower, upper) Y limits: the global range padded by the margin on both sides
emg_ylim = (emg_global_min - emg_margin, emg_global_max + emg_margin)
# NOTE(review): the "mV" unit printed here is not established anywhere in the
# visible script - confirm the stored samples are actually in millivolts.
print(f"Global EMG range: {emg_global_min:.1f} to {emg_global_max:.1f} mV")
# Full-session figure: one stacked subplot per channel, all sharing both axes
# so amplitudes can be compared directly between channels.
fig_raw, axes_raw = plt.subplots(
    n_channels, 1,
    figsize=(14, 2.5 * n_channels),
    sharex=True, sharey=True,
)
# With a single row, plt.subplots hands back a bare Axes; wrap it so the code
# below can always index/iterate.
axes_raw = [axes_raw] if n_channels == 1 else axes_raw

for ch, ax in enumerate(axes_raw):
    # Raw (unfiltered) trace for this channel; colors cycle through the palette.
    trace_color = channel_colors[ch % len(channel_colors)]
    ax.plot(time, emg[:, ch], linewidth=0.3, color=trace_color, alpha=0.8)
    ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xlim(time[0], time[-1])
    ax.set_ylim(emg_ylim)  # identical Y range on every channel
# Overlay gesture-transition markers once every channel trace has been drawn.
y_min, y_max = emg_ylim
y_text = y_min + (y_max - y_min) * 0.85  # text anchored at 85% of the Y range

# Substring -> marker color, tested in this order; a label matching none of
# them falls back to red.
marker_palette = (
    ('open', 'cyan'),
    ('fist', 'blue'),
    ('hook', 'orange'),
    ('thumb', 'green'),
    ('rest', 'gray'),
)

for ch in range(n_channels):
    ax = axes_raw[ch]
    prev_label = None
    for i, lbl in enumerate(labels):
        # Labels may come back from HDF5 as bytes; normalise to str.
        lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
        if lbl_str == prev_label:
            continue  # same gesture as the previous window: no new marker
        prev_label = lbl_str

        t = window_time_in_data[i]  # start time of window i in the flattened data
        color = next((shade for key, shade in marker_palette if key in lbl_str), 'red')
        ax.axvline(t, color=color, linestyle='--', alpha=0.7, linewidth=1)

        # Text labels only on the first channel, and never for 'rest', to
        # keep the figure readable.
        if ch == 0 and lbl_str != 'rest':
            ax.text(t + 0.2, y_text, lbl_str, fontsize=8,
                    color=color, rotation=0, ha='left', va='top',
                    bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7))

# Finalise and display the full-session figure.
axes_raw[-1].set_xlabel('Time (s)', fontsize=11)
fig_raw.suptitle('Raw EMG Signal (All Channels)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
# Zoomed view of the first 2 seconds so individual waveforms are visible.
zoom_duration = 2.0
zoom_samples = int(zoom_duration * fs)

# Skipped when the recording is not longer than the zoom span.
if total_samples > zoom_samples:
    fig_zoom, axes_zoom = plt.subplots(
        n_channels, 1,
        figsize=(14, 2 * n_channels),
        sharex=True, sharey=True,
    )
    # A single-row call returns a bare Axes; wrap it so it can be iterated.
    axes_zoom = [axes_zoom] if n_channels == 1 else axes_zoom

    zoom_time = time[:zoom_samples]
    for ch, ax in enumerate(axes_zoom):
        ax.plot(zoom_time, emg[:zoom_samples, ch],
                linewidth=0.5, color=channel_colors[ch % len(channel_colors)])
        ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
        ax.set_ylim(emg_ylim)  # same Y range as the full-session plot
        ax.grid(True, alpha=0.3)

    axes_zoom[-1].set_xlabel('Time (s)', fontsize=11)
    fig_zoom.suptitle(f'Raw EMG Signal (First {zoom_duration}s - Zoomed)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
# 1) pick one channel # 1) pick one channel
emg_ch = emg[:, 0].astype(np.float32) emg_ch = emg[:, 0].astype(np.float32)
@@ -122,10 +233,9 @@ def compute_all_features_windowed(x, window_len, threshold_zc, threshold_ssc):
# ============================================================================= # =============================================================================
# COMPUTE FEATURES FOR ALL 16 CHANNELS # COMPUTE FEATURES FOR ALL CHANNELS
# ============================================================================= # =============================================================================
n_channels = 16
all_features = {} # Non-overlapping windows - for ML all_features = {} # Non-overlapping windows - for ML
for ch in range(n_channels): for ch in range(n_channels):
@@ -143,89 +253,106 @@ for ch in range(n_channels):
all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i} all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i}
# Time vector for windowed features # Time vector for windowed features
n_windows = len(all_features[0]['rms']) n_feat_windows = len(all_features[0]['rms'])
window_centers = np.arange(n_windows) * window_samples + window_samples // 2 window_centers = np.arange(n_feat_windows) * window_samples + window_samples // 2
time_windows = time_ch[window_centers] time_windows = time_ch[window_centers]
print(f"\nComputed features for {n_channels} channels") print(f"\nComputed features for {n_channels} channels")
print(f"Windows per channel (non-overlapping): {n_windows}") print(f"Windows per channel (non-overlapping): {n_feat_windows}")
print(f"Window size: {window_samples} samples ({window_ms} ms)") print(f"Window size: {window_samples} samples ({window_ms} ms)")
# 5) Load gesture labels (prompts) # 5) Use embedded gesture labels from HDF5
prompts = pd.read_hdf("assets/discrete_gestures_user_000_dataset_000.hdf5", key="prompts") # Labels are per-window, aligned with emg_windows
print("\nUnique gestures:", prompts['name'].unique()) unique_labels = np.unique(labels)
print(f"\nUnique gestures in session: {unique_labels}")
t_abs = time_ch # absolute timestamps # Map labels to time in the flattened data (same as raw plot)
# Each original window i maps to time (i * samples_per_window / fs) in the flattened data
# Find first occurrence of each gesture type # But we dropped `drop` samples, so adjust accordingly
index_gestures = prompts[prompts['name'].str.contains('index')] label_times_in_data = np.arange(len(labels)) * samples_per_window / fs
middle_gestures = prompts[prompts['name'].str.contains('middle')] print(f"Data duration: {time[-1]:.2f}s ({len(labels)} windows)")
thumb_gestures = prompts[prompts['name'].str.contains('thumb')]
# Define plot configurations: 1) Index+Middle combined, 2) Thumb separate
plot_configs = [
{'name': 'Index & Middle Finger', 'start_time': index_gestures['time'].iloc[0], 'filter': 'index|middle'},
{'name': 'Thumb', 'start_time': thumb_gestures['time'].iloc[0], 'filter': 'thumb'},
]
# Color function for markers # Color function for markers
def get_gesture_color(name): def get_gesture_color(name):
if 'index' in name: """Assign colors to gesture types."""
return 'green' name_str = name.decode('utf-8') if isinstance(name, bytes) else name
elif 'middle' in name: if 'open' in name_str:
return 'cyan'
elif 'fist' in name_str:
return 'blue' return 'blue'
elif 'thumb' in name: elif 'hook' in name_str:
return 'orange' return 'orange'
elif 'thumb' in name_str:
return 'green'
elif 'rest' in name_str:
return 'gray' return 'gray'
return 'red'
# ============================================================================= # =============================================================================
# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL 16 CHANNELS # 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL CHANNELS
# ============================================================================= # =============================================================================
feature_names = ['rms', 'wl', 'zc', 'ssc'] feature_names = ['rms', 'wl', 'zc', 'ssc']
feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)'] feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
feature_colors = ['red', 'blue', 'green', 'purple'] feature_colors = ['red', 'blue', 'green', 'purple']
feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count'] feature_ylabels = ['Amplitude (mV)', 'WL (a.u.)', 'Count', 'Count']
for config in plot_configs: # Determine subplot grid based on channel count
t_start = config['start_time'] - 0.5 if n_channels <= 4:
t_end = t_start + 10.0 n_rows, n_cols = 2, 2
elif n_channels <= 9:
n_rows, n_cols = 3, 3
else:
n_rows, n_cols = 4, 4
# Mask for windowed time vector # Plot entire session for each feature
mask_win = (time_windows >= t_start) & (time_windows <= t_end)
t_win_rel = time_windows[mask_win] - t_start
# Get gestures in window
gesture_mask = (prompts['time'] >= t_start) & (prompts['time'] <= t_end) & \
(prompts['name'].str.contains(config['filter']))
gestures_in_window = prompts[gesture_mask]
# Create one figure per feature
for feat_idx, feat_name in enumerate(feature_names): for feat_idx, feat_name in enumerate(feature_names):
fig, axes = plt.subplots(4, 4, figsize=(10, 8), sharex=True, sharey=True) fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), sharex=True, sharey=True)
axes = axes.flatten() axes = axes.flatten()
# Compute global Y range for this feature across all channels
all_feat_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
feat_min, feat_max = all_feat_vals.min(), all_feat_vals.max()
feat_margin = (feat_max - feat_min) * 0.05 if feat_max > feat_min else 1
feat_ylim = (feat_min - feat_margin, feat_max + feat_margin)
for ch in range(n_channels): for ch in range(n_channels):
ax = axes[ch] ax = axes[ch]
# Plot windowed feature # Plot windowed feature over time
feat_data = all_features[ch][feat_name][mask_win] feat_data = all_features[ch][feat_name]
ax.plot(t_win_rel, feat_data, linewidth=1, color=feature_colors[feat_idx]) ax.plot(time_windows, feat_data, linewidth=0.8, color=feature_colors[feat_idx])
ax.set_title(f"Ch {ch}", fontsize=9) ax.set_title(f"Channel {ch}", fontsize=10)
ax.set_ylim(feat_ylim) # Same Y range for all channels
ax.grid(True, alpha=0.3)
# Add gesture markers # Add gesture transition markers from labels
for _, row in gestures_in_window.iterrows(): prev_label = None
t_g = row['time'] - t_start for i, lbl in enumerate(labels):
color = get_gesture_color(row['name']) lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
ax.axvline(t_g, color=color, linestyle='--', alpha=0.5, linewidth=0.5) t = label_times_in_data[i] # Time in flattened data
if lbl_str != prev_label and lbl_str != 'rest':
# Only show markers within the time_windows range
if t <= time_windows[-1]:
color = get_gesture_color(lbl)
ax.axvline(t, color=color, linestyle='--', alpha=0.6, linewidth=1)
prev_label = lbl_str
# Set subtitle based on gesture type # Hide unused subplots
if 'Index' in config['name']: for ch in range(n_channels, len(axes)):
subtitle = "(Green=index, Blue=middle)" axes[ch].set_visible(False)
else:
subtitle = "(Orange=thumb)"
fig.suptitle(f"{feature_titles[feat_idx]} - {config['name']} Gestures\n{subtitle}", fontsize=12) # Legend for gesture colors
from matplotlib.lines import Line2D
legend_elements = [
Line2D([0], [0], color='cyan', linestyle='--', label='Open'),
Line2D([0], [0], color='blue', linestyle='--', label='Fist'),
Line2D([0], [0], color='orange', linestyle='--', label='Hook Em'),
Line2D([0], [0], color='green', linestyle='--', label='Thumbs Up'),
]
fig.legend(handles=legend_elements, loc='upper right', fontsize=9)
fig.suptitle(f"{feature_titles[feat_idx]} - All Channels", fontsize=14)
fig.supxlabel("Time (s)") fig.supxlabel("Time (s)")
fig.supylabel(feature_ylabels[feat_idx]) fig.supylabel(feature_ylabels[feat_idx])
plt.tight_layout() plt.tight_layout()
@@ -235,9 +362,13 @@ for config in plot_configs:
# SUMMARY: Feature statistics across all channels # SUMMARY: Feature statistics across all channels
# ============================================================================= # =============================================================================
print("\n" + "=" * 60) print("\n" + "=" * 60)
print("FEATURE SUMMARY (all channels, all windows)") print(f"FEATURE SUMMARY ({n_channels} channels, {n_feat_windows} windows each)")
print("=" * 60) print("=" * 60)
for feat_name in ['rms', 'wl', 'zc', 'ssc']: for feat_name in ['rms', 'wl', 'zc', 'ssc']:
all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)]) all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | " print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | "
f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}") f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}")
# Close the HDF5 file
file.close()
print("\n[Done] HDF5 file closed.")

Binary file not shown.

View File

@@ -99,13 +99,13 @@ class RealSerialStream:
@note Requires pyserial: pip install pyserial @note Requires pyserial: pip install pyserial
""" """
def __init__(self, port: str = None, baud_rate: int = 115200, timeout: float = 1.0): def __init__(self, port: str = None, baud_rate: int = 921600, timeout: float = 0.05):
""" """
@brief Initialize the serial stream. @brief Initialize the serial stream.
@param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux). @param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux).
If None, will attempt to auto-detect the ESP32. If None, will attempt to auto-detect the ESP32.
@param baud_rate Communication speed in bits per second. Default 115200 matches ESP32. @param baud_rate Communication speed in bits per second. Default 921600 for high-throughput streaming.
@param timeout Read timeout in seconds for readline(). @param timeout Read timeout in seconds for readline().
""" """
self.port = port self.port = port
@@ -233,6 +233,9 @@ class RealSerialStream:
"Must call connect() first." "Must call connect() first."
) )
# Flush any stale data before starting fresh stream
self.serial.reset_input_buffer()
# Send start command # Send start command
start_cmd = {"cmd": "start"} start_cmd = {"cmd": "start"}
self._send_json(start_cmd) self._send_json(start_cmd)