bunch of training and label latency fixes, also trained 70% accurate model.

This commit is contained in:
Surya Balaji
2026-01-27 20:12:13 -06:00
parent f656f466e7
commit 9bdf9d1109
41 changed files with 563 additions and 124 deletions

View File

@@ -39,7 +39,7 @@ import matplotlib.pyplot as plt
# =============================================================================
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
SERIAL_BAUD = 115200 # Typical baud rate for ESP32
SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
# Windowing configuration
WINDOW_SIZE_MS = 150 # Window size in milliseconds
@@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files
MODEL_DIR = Path("models") # Directory to store trained models
USER_ID = "user_001" # Current user ID (change per user)
# =============================================================================
# LABEL ALIGNMENT CONFIGURATION
# =============================================================================
# Human reaction time causes EMG activity to lag behind label prompts.
# We detect when EMG actually rises and shift labels to match.
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms)
# =============================================================================
# TRANSITION WINDOW FILTERING
# =============================================================================
# Windows near gesture transitions contain ambiguous data (reaction time at start,
# muscle relaxation at end). Discard these during training for cleaner labels.
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts
TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends
# =============================================================================
# DATA STRUCTURES
# =============================================================================
@@ -187,30 +208,27 @@ class EMGParser:
"""
Parse a line from ESP32 into an EMGSample.
Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n"
Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
Python assigns timestamp on receipt for label alignment.
Returns None if parsing fails.
"""
try:
# Strip whitespace and split
parts = line.strip().split(',')
# Validate we have correct number of fields
expected_fields = 1 + self.num_channels # timestamp + channels
if len(parts) != expected_fields:
# Validate we have correct number of fields (channels only)
if len(parts) != self.num_channels:
self.parse_errors += 1
return None
# Parse ESP32 timestamp
esp_timestamp_ms = int(parts[0])
# Parse channel values
channels = [float(parts[i + 1]) for i in range(self.num_channels)]
channels = [float(parts[i]) for i in range(self.num_channels)]
# Create sample with Python-side timestamp
# Create sample with Python-side timestamp (aligned with label clock)
sample = EMGSample(
timestamp=time.perf_counter(), # High-resolution monotonic clock
channels=channels,
esp_timestamp_ms=esp_timestamp_ms
esp_timestamp_ms=None # No longer using ESP32 timestamp
)
self.samples_parsed += 1
@@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream):
time.sleep(interval)
# =============================================================================
# LABEL ALIGNMENT (Simple Onset Detection)
# =============================================================================
from scipy.signal import butter, sosfiltfilt
def align_labels_with_onset(
    labels: list[str],
    window_start_times: np.ndarray,
    raw_timestamps: np.ndarray,
    raw_channels: np.ndarray,
    sampling_rate: int,
    threshold_factor: float = 2.0,
    search_ms: float = 800
) -> list[str]:
    """
    Align labels to EMG onset by detecting when signal rises above baseline.

    Algorithm:
    1. High-pass filter (20 Hz) each channel to remove the DC offset
    2. Compute the RMS envelope across channels and low-pass it (10 Hz)
    3. At each label transition, take the 200 ms of envelope before the
       prompt as baseline and search forward for the first sample where the
       envelope exceeds mean(baseline) + threshold_factor * std(baseline)
    4. Move the label boundary to that sample's timestamp

    If no onset can be determined for a transition (no baseline samples, no
    threshold crossing inside the search window, or a recording too short to
    filter), the boundary falls back to the prompt time plus 0.3 s.

    Args:
        labels: Prompted gesture label for each window
        window_start_times: Start time of each window, parallel to labels
        raw_timestamps: Timestamp of each raw sample, ascending, on the same
            clock as window_start_times
        raw_channels: Raw samples, shape (n_samples, n_channels); any numeric
            dtype (converted to float64 before filtering)
        sampling_rate: Raw sampling rate in Hz
        threshold_factor: Number of baseline standard deviations above the
            baseline mean that counts as an onset
        search_ms: How far after the prompt to search for an onset (ms)

    Returns:
        New list of labels, one per window, with boundaries moved to the
        detected onsets. The input list is not modified.
    """
    if len(labels) == 0:
        return list(labels)

    fallback_delay_s = 0.3  # Assumed reaction delay when no onset is found
    baseline_ms = 200       # Length of the pre-prompt baseline window

    nyquist = sampling_rate / 2
    sos_hp = butter(2, 20.0 / nyquist, btype='high', output='sos')
    sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')

    # Filter in float64 regardless of input dtype: writing the filter output
    # back into an integer-typed buffer would truncate it.
    signal = np.asarray(raw_channels, dtype=np.float64)

    # sosfiltfilt raises ValueError when the input is shorter than its edge
    # padding. In that case there is no usable envelope, so every transition
    # uses the fixed fallback delay instead of aborting.
    envelope = None
    try:
        filtered = sosfiltfilt(sos_hp, signal, axis=0)
    except ValueError:
        filtered = None
    if filtered is not None:
        rms = np.sqrt(np.mean(filtered ** 2, axis=1))
        try:
            envelope = sosfiltfilt(sos_lp, rms)
        except ValueError:
            envelope = None

    search_samples = int(search_ms / 1000 * sampling_rate)
    baseline_samples = int(baseline_ms / 1000 * sampling_rate)

    # Locate every label transition and its position in the raw stream up
    # front, so each search can be bounded by the following transition.
    transition_idxs = [i for i in range(1, len(labels)) if labels[i] != labels[i - 1]]
    prompt_times = [window_start_times[i] for i in transition_idxs]
    prompt_idxs = [int(np.searchsorted(raw_timestamps, pt)) for pt in prompt_times]

    boundaries = []  # (time, new_label), in prompt order
    for k, win_idx in enumerate(transition_idxs):
        prompt_time = prompt_times[k]
        prompt_idx = prompt_idxs[k]
        onset_time = prompt_time + fallback_delay_s

        if envelope is not None:
            base_start = max(0, prompt_idx - baseline_samples)
            baseline = envelope[base_start:prompt_idx]
            if len(baseline) > 0:
                threshold = np.mean(baseline) + threshold_factor * np.std(baseline)

                # Never search past the next prompt: an onset found there
                # belongs to the next gesture, and would put the boundaries
                # out of time order so this segment's label gets skipped.
                search_end = min(len(envelope), prompt_idx + search_samples)
                if k + 1 < len(prompt_idxs):
                    search_end = min(search_end, prompt_idxs[k + 1])

                above = np.flatnonzero(envelope[prompt_idx:search_end] > threshold)
                if above.size:
                    onset_time = raw_timestamps[prompt_idx + above[0]]

        boundaries.append((onset_time, labels[win_idx]))

    # Assign labels based on detected boundaries
    aligned = []
    boundary_idx = 0
    current_label = labels[0]
    for t in window_start_times:
        while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
            current_label = boundaries[boundary_idx][1]
            boundary_idx += 1
        aligned.append(current_label)

    return aligned
def filter_transition_windows(
    X: np.ndarray,
    y: np.ndarray,
    labels: list[str],
    start_times: np.ndarray,
    end_times: np.ndarray,
    transition_start_ms: float = TRANSITION_START_MS,
    transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """
    Drop windows that lie in the transition zones at gesture boundaries.

    A gesture segment is a maximal run of consecutive windows sharing the
    same label. Within each segment a window is discarded when:
    - it starts less than transition_start_ms after the segment's first
      window starts (user still reacting to the prompt), or
    - it ends less than transition_end_ms before the segment's last window
      ends (user anticipating the next gesture).

    Args:
        X: EMG data array (n_windows, samples, channels)
        y: Label indices (n_windows,)
        labels: String labels (n_windows,)
        start_times: Window start times in seconds (n_windows,)
        end_times: Window end times in seconds (n_windows,)
        transition_start_ms: Discard windows within this time after gesture start
        transition_end_ms: Discard windows within this time before gesture end

    Returns:
        Filtered (X, y, labels) with transition windows removed. Empty input
        is returned unchanged.
    """
    if len(X) == 0:
        return X, y, labels

    lead_sec = transition_start_ms / 1000.0
    tail_sec = transition_end_ms / 1000.0

    # Index of the first window of every segment, and the matching
    # one-past-the-end index (the next segment's start, or the total count).
    seg_starts = [0] + [i for i in range(1, len(labels)) if labels[i] != labels[i - 1]]
    seg_stops = seg_starts[1:] + [len(labels)]

    keep_mask = np.ones(len(X), dtype=bool)
    for first, stop in zip(seg_starts, seg_stops):
        # Earliest acceptable window start / latest acceptable window end
        # for this segment.
        earliest_start = start_times[first] + lead_sec
        latest_end = end_times[stop - 1] - tail_sec

        for i in range(first, stop):
            in_reaction_zone = start_times[i] < earliest_start
            in_anticipation_zone = end_times[i] > latest_end
            if in_reaction_zone or in_anticipation_zone:
                keep_mask[i] = False

    # Apply the mask to all three parallel collections
    X_filtered = X[keep_mask]
    y_filtered = y[keep_mask]
    labels_filtered = [label for label, keep in zip(labels, keep_mask) if keep]

    n_removed = len(X) - len(X_filtered)
    if n_removed > 0:
        print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
        print(f"[Filter] Kept {len(X_filtered)} windows for training")

    return X_filtered, y_filtered, labels_filtered
# =============================================================================
# SESSION STORAGE (Save/Load labeled data to HDF5)
# =============================================================================
@@ -527,16 +718,24 @@ class SessionStorage:
windows: list[EMGWindow],
labels: list[str],
metadata: SessionMetadata,
raw_samples: Optional[list[EMGSample]] = None
raw_samples: Optional[list[EMGSample]] = None,
session_start_time: Optional[float] = None,
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
) -> Path:
"""
Save a collection session to HDF5.
Save a collection session to HDF5 with optional label alignment.
When raw_samples and session_start_time are provided and enable_alignment
is True, automatically detects EMG onset and corrects labels for human
reaction time delay.
Args:
windows: List of EMGWindow objects (no label info)
labels: List of gesture labels, parallel to windows
metadata: Session metadata
raw_samples: Optional raw samples for debugging
raw_samples: Raw samples (required for alignment)
session_start_time: When session started (required for alignment)
enable_alignment: Whether to perform automatic label alignment
"""
filepath = self.get_session_filepath(metadata.session_id)
@@ -549,6 +748,35 @@ class SessionStorage:
window_samples = len(windows[0].samples)
num_channels = len(windows[0].samples[0].channels)
# Prepare timing arrays
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
# Label alignment using onset detection
aligned_labels = labels
original_labels = labels
if enable_alignment and raw_samples and len(raw_samples) > 0:
print("[Storage] Aligning labels to EMG onset...")
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
aligned_labels = align_labels_with_onset(
labels=labels,
window_start_times=start_times,
raw_timestamps=raw_timestamps,
raw_channels=raw_channels,
sampling_rate=metadata.sampling_rate,
threshold_factor=ONSET_THRESHOLD,
search_ms=ONSET_SEARCH_MS
)
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
elif enable_alignment:
print("[Storage] Warning: No raw samples, skipping alignment")
with h5py.File(filepath, 'w') as f:
# Metadata as attributes
f.attrs['user_id'] = metadata.user_id
@@ -568,19 +796,24 @@ class SessionStorage:
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
# Labels stored separately from window data
max_label_len = max(len(l) for l in labels)
# Store ALIGNED labels as primary (what training will use)
max_label_len = max(len(l) for l in aligned_labels)
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
windows_grp.create_dataset('labels', data=labels, dtype=dt)
windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
# Also store original labels for reference/debugging
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
windows_grp.create_dataset('window_ids', data=window_ids)
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
windows_grp.create_dataset('start_times', data=start_times)
windows_grp.create_dataset('end_times', data=end_times)
# Store alignment metadata
f.attrs['alignment_enabled'] = enable_alignment
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
if raw_samples:
raw_grp = f.create_group('raw_samples')
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
@@ -655,13 +888,21 @@ class SessionStorage:
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
return windows, labels_out, metadata
def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""Load a single session in ML-ready format: X, y, label_names."""
def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""
Load a single session in ML-ready format: X, y, label_names.
Args:
session_id: The session to load
filter_transitions: If True, remove windows in transition zones (default from config)
"""
filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
labels = []
for l in labels_raw:
@@ -670,18 +911,33 @@ class SessionStorage:
else:
labels.append(l)
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
# Apply transition filtering if enabled
if filter_transitions:
label_names_pre = sorted(set(labels))
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
X, y_pre, labels = filter_transition_windows(
X, y_pre, labels, start_times, end_times
)
label_names = sorted(set(labels))
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Labels: {label_names}")
return X, y, label_names
def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
"""
Load ALL sessions combined into a single training dataset.
Args:
filter_transitions: If True, remove windows in transition zones (default from config)
Returns:
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
y: Combined labels as integers (n_total_windows,)
@@ -697,11 +953,15 @@ class SessionStorage:
raise ValueError("No sessions found to load!")
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
if filter_transitions:
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
all_X = []
all_labels = []
loaded_sessions = []
reference_shape = None
total_removed = 0
total_original = 0
for session_id in sessions:
filepath = self.get_session_filepath(session_id)
@@ -709,6 +969,8 @@ class SessionStorage:
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
# Validate shape compatibility
if reference_shape is None:
@@ -725,10 +987,26 @@ class SessionStorage:
else:
labels.append(l)
original_count = len(X)
total_original += original_count
# Apply transition filtering per session (each has its own gesture boundaries)
if filter_transitions:
# Need temporary y for filtering function
temp_label_names = sorted(set(labels))
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
X, temp_y, labels = filter_transition_windows(
X, temp_y, labels, start_times, end_times
)
total_removed += original_count - len(X)
all_X.append(X)
all_labels.extend(labels)
loaded_sessions.append(session_id)
print(f"[Storage] - {session_id}: {X.shape[0]} windows")
print(f"[Storage] - {session_id}: {len(X)} windows" +
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
if not all_X:
raise ValueError("No compatible sessions found!")
@@ -742,6 +1020,8 @@ class SessionStorage:
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
if filter_transitions and total_removed > 0:
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
print(f"[Storage] Labels: {label_names}")
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")