bunch of training and label latency fixes, also trained 70% accurate model.
This commit is contained in:
@@ -39,7 +39,7 @@ import matplotlib.pyplot as plt
|
||||
# =============================================================================
|
||||
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
|
||||
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
|
||||
SERIAL_BAUD = 115200 # Typical baud rate for ESP32
|
||||
SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
|
||||
|
||||
# Windowing configuration
|
||||
WINDOW_SIZE_MS = 150 # Window size in milliseconds
|
||||
@@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files
|
||||
MODEL_DIR = Path("models") # Directory to store trained models
|
||||
USER_ID = "user_001" # Current user ID (change per user)
|
||||
|
||||
# =============================================================================
|
||||
# LABEL ALIGNMENT CONFIGURATION
|
||||
# =============================================================================
|
||||
# Human reaction time causes EMG activity to lag behind label prompts.
|
||||
# We detect when EMG actually rises and shift labels to match.
|
||||
|
||||
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
|
||||
ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std
|
||||
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms)
|
||||
|
||||
# =============================================================================
|
||||
# TRANSITION WINDOW FILTERING
|
||||
# =============================================================================
|
||||
# Windows near gesture transitions contain ambiguous data (reaction time at start,
|
||||
# muscle relaxation at end). Discard these during training for cleaner labels.
|
||||
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
|
||||
|
||||
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
|
||||
TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts
|
||||
TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends
|
||||
|
||||
# =============================================================================
|
||||
# DATA STRUCTURES
|
||||
# =============================================================================
|
||||
@@ -187,30 +208,27 @@ class EMGParser:
|
||||
"""
|
||||
Parse a line from ESP32 into an EMGSample.
|
||||
|
||||
Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n"
|
||||
Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
|
||||
Python assigns timestamp on receipt for label alignment.
|
||||
Returns None if parsing fails.
|
||||
"""
|
||||
try:
|
||||
# Strip whitespace and split
|
||||
parts = line.strip().split(',')
|
||||
|
||||
# Validate we have correct number of fields
|
||||
expected_fields = 1 + self.num_channels # timestamp + channels
|
||||
if len(parts) != expected_fields:
|
||||
# Validate we have correct number of fields (channels only)
|
||||
if len(parts) != self.num_channels:
|
||||
self.parse_errors += 1
|
||||
return None
|
||||
|
||||
# Parse ESP32 timestamp
|
||||
esp_timestamp_ms = int(parts[0])
|
||||
|
||||
# Parse channel values
|
||||
channels = [float(parts[i + 1]) for i in range(self.num_channels)]
|
||||
channels = [float(parts[i]) for i in range(self.num_channels)]
|
||||
|
||||
# Create sample with Python-side timestamp
|
||||
# Create sample with Python-side timestamp (aligned with label clock)
|
||||
sample = EMGSample(
|
||||
timestamp=time.perf_counter(), # High-resolution monotonic clock
|
||||
channels=channels,
|
||||
esp_timestamp_ms=esp_timestamp_ms
|
||||
esp_timestamp_ms=None # No longer using ESP32 timestamp
|
||||
)
|
||||
|
||||
self.samples_parsed += 1
|
||||
@@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream):
|
||||
time.sleep(interval)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LABEL ALIGNMENT (Simple Onset Detection)
|
||||
# =============================================================================
|
||||
from scipy.signal import butter, sosfiltfilt
|
||||
|
||||
|
||||
def align_labels_with_onset(
    labels: list[str],
    window_start_times: np.ndarray,
    raw_timestamps: np.ndarray,
    raw_channels: np.ndarray,
    sampling_rate: int,
    threshold_factor: float = 2.0,
    search_ms: float = 800,
    baseline_ms: float = 200,
    fallback_delay_s: float = 0.3,
) -> list[str]:
    """
    Align labels to EMG onset by detecting when signal rises above baseline.

    Algorithm:
        1. High-pass filter (20 Hz) each channel to remove the DC offset.
        2. Compute the RMS envelope across channels and smooth it (10 Hz low-pass).
        3. At each label transition, take the envelope in the `baseline_ms`
           before the prompt and search up to `search_ms` after the prompt for
           the first sample exceeding mean + threshold_factor * std.
        4. Move the label boundary to that sample's timestamp, or to
           prompt + `fallback_delay_s` when no onset is found.

    Args:
        labels: One label per window, parallel to `window_start_times`.
        window_start_times: Start time of each window.
        raw_timestamps: Timestamp of each raw sample, on the same clock as
            `window_start_times` (assumed sorted ascending; it is searched
            with `np.searchsorted`).
        raw_channels: Raw samples, shape (n_samples, n_channels). A 1-D array
            is treated as a single channel.
        sampling_rate: Raw sampling rate in Hz.
        threshold_factor: Number of baseline standard deviations above the
            baseline mean that counts as an onset.
        search_ms: How far after the prompt to look for an onset (ms).
        baseline_ms: Length of the pre-prompt baseline window (ms).
        fallback_delay_s: Delay added to the prompt time when no onset is
            detected or no baseline samples exist (seconds).

    Returns:
        A new list of labels, one per window. If the raw recording is too
        short to be filtered, a copy of the input labels is returned unchanged.
    """
    if len(labels) == 0:
        return list(labels)

    # Work in float64 so integer ADC counts are not truncated by the filters.
    signal = np.asarray(raw_channels, dtype=np.float64)
    if signal.ndim == 1:
        signal = signal[:, np.newaxis]

    # High-pass removes DC (raw ADC has ~2340mV offset); low-pass smooths envelope
    nyquist = sampling_rate / 2
    sos_hp = butter(2, 20.0 / nyquist, btype='high', output='sos')
    sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')

    try:
        filtered = sosfiltfilt(sos_hp, signal, axis=0)
        envelope = np.sqrt(np.mean(filtered ** 2, axis=1))
        envelope = sosfiltfilt(sos_lp, envelope)
    except ValueError:
        # sosfiltfilt refuses inputs shorter than its edge padding; alignment
        # is best-effort, so keep the prompt-based labels instead of failing.
        return list(labels)

    search_samples = int(search_ms / 1000 * sampling_rate)
    baseline_samples = int(baseline_ms / 1000 * sampling_rate)

    # Detect one boundary per label transition, in prompt order.
    boundaries = []  # (time, new_label)
    for i in range(1, len(labels)):
        if labels[i] == labels[i - 1]:
            continue

        prompt_time = window_start_times[i]

        # Index of the first raw sample at or after the prompt
        prompt_idx = int(np.searchsorted(raw_timestamps, prompt_time))

        # Baseline is the envelope immediately before the prompt
        baseline = envelope[max(0, prompt_idx - baseline_samples):prompt_idx]

        onset_time = prompt_time + fallback_delay_s  # used when nothing is detected
        if baseline.size:
            threshold = np.mean(baseline) + threshold_factor * np.std(baseline)
            search_end = min(len(envelope), prompt_idx + search_samples)
            hits = np.flatnonzero(envelope[prompt_idx:search_end] > threshold)
            if hits.size:
                onset_time = raw_timestamps[prompt_idx + hits[0]]

        boundaries.append((onset_time, labels[i]))

    # Assign labels: each window takes the label of the last boundary passed.
    aligned = []
    boundary_idx = 0
    current_label = labels[0]

    for t in window_start_times:
        while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
            current_label = boundaries[boundary_idx][1]
            boundary_idx += 1
        aligned.append(current_label)

    return aligned
|
||||
|
||||
|
||||
def filter_transition_windows(
    X: np.ndarray,
    y: np.ndarray,
    labels: list[str],
    start_times: np.ndarray,
    end_times: np.ndarray,
    transition_start_ms: float = TRANSITION_START_MS,
    transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """
    Drop windows that lie in the transition zones around label changes.

    A "segment" is a run of consecutive windows sharing the same label. Within
    each segment a window is discarded when either:
        - it starts less than `transition_start_ms` after the segment's first
          window starts, or
        - it ends less than `transition_end_ms` before the segment's last
          window ends.
    The first and last segments of the recording are treated like any other.

    Args:
        X: EMG data array, one entry per window
        y: Label indices, one per window
        labels: String labels, one per window
        start_times: Window start times in seconds
        end_times: Window end times in seconds
        transition_start_ms: Width of the discarded zone after a segment starts
        transition_end_ms: Width of the discarded zone before a segment ends

    Returns:
        (X, y, labels) restricted to the kept windows. Empty input is returned
        unchanged.
    """
    if len(X) == 0:
        return X, y, labels

    lead_in_sec = transition_start_ms / 1000.0
    lead_out_sec = transition_end_ms / 1000.0

    # Segment edges: index 0, every index where the label changes, and a final
    # end marker one past the last label.
    change_points = [
        idx for idx in range(1, len(labels)) if labels[idx] != labels[idx - 1]
    ]
    edges = [0, *change_points, len(labels)]

    keep_mask = np.ones(len(X), dtype=bool)

    # Walk consecutive (first, stop) edge pairs; `stop` is exclusive.
    for first, stop in zip(edges, edges[1:]):
        # Time span of the segment, from its first window's start to its
        # last window's end.
        segment_begin = start_times[first]
        segment_finish = end_times[stop - 1]

        for idx in range(first, stop):
            win_begin = start_times[idx]
            win_finish = end_times[idx]

            # Reaction-time zone at the start, anticipation zone at the end
            too_early = win_begin < segment_begin + lead_in_sec
            too_late = win_finish > segment_finish - lead_out_sec
            if too_early or too_late:
                keep_mask[idx] = False

    X_filtered = X[keep_mask]
    y_filtered = y[keep_mask]
    labels_filtered = [name for name, kept in zip(labels, keep_mask) if kept]

    n_removed = len(X) - len(X_filtered)
    if n_removed > 0:
        print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
        print(f"[Filter] Kept {len(X_filtered)} windows for training")

    return X_filtered, y_filtered, labels_filtered
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SESSION STORAGE (Save/Load labeled data to HDF5)
|
||||
# =============================================================================
|
||||
@@ -527,16 +718,24 @@ class SessionStorage:
|
||||
windows: list[EMGWindow],
|
||||
labels: list[str],
|
||||
metadata: SessionMetadata,
|
||||
raw_samples: Optional[list[EMGSample]] = None
|
||||
raw_samples: Optional[list[EMGSample]] = None,
|
||||
session_start_time: Optional[float] = None,
|
||||
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
|
||||
) -> Path:
|
||||
"""
|
||||
Save a collection session to HDF5.
|
||||
Save a collection session to HDF5 with optional label alignment.
|
||||
|
||||
When raw_samples and session_start_time are provided and enable_alignment
|
||||
is True, automatically detects EMG onset and corrects labels for human
|
||||
reaction time delay.
|
||||
|
||||
Args:
|
||||
windows: List of EMGWindow objects (no label info)
|
||||
labels: List of gesture labels, parallel to windows
|
||||
metadata: Session metadata
|
||||
raw_samples: Optional raw samples for debugging
|
||||
raw_samples: Raw samples (required for alignment)
|
||||
session_start_time: When session started (required for alignment)
|
||||
enable_alignment: Whether to perform automatic label alignment
|
||||
"""
|
||||
filepath = self.get_session_filepath(metadata.session_id)
|
||||
|
||||
@@ -549,6 +748,35 @@ class SessionStorage:
|
||||
window_samples = len(windows[0].samples)
|
||||
num_channels = len(windows[0].samples[0].channels)
|
||||
|
||||
# Prepare timing arrays
|
||||
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
|
||||
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
|
||||
|
||||
# Label alignment using onset detection
|
||||
aligned_labels = labels
|
||||
original_labels = labels
|
||||
|
||||
if enable_alignment and raw_samples and len(raw_samples) > 0:
|
||||
print("[Storage] Aligning labels to EMG onset...")
|
||||
|
||||
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
|
||||
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
|
||||
|
||||
aligned_labels = align_labels_with_onset(
|
||||
labels=labels,
|
||||
window_start_times=start_times,
|
||||
raw_timestamps=raw_timestamps,
|
||||
raw_channels=raw_channels,
|
||||
sampling_rate=metadata.sampling_rate,
|
||||
threshold_factor=ONSET_THRESHOLD,
|
||||
search_ms=ONSET_SEARCH_MS
|
||||
)
|
||||
|
||||
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
|
||||
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
|
||||
elif enable_alignment:
|
||||
print("[Storage] Warning: No raw samples, skipping alignment")
|
||||
|
||||
with h5py.File(filepath, 'w') as f:
|
||||
# Metadata as attributes
|
||||
f.attrs['user_id'] = metadata.user_id
|
||||
@@ -568,19 +796,24 @@ class SessionStorage:
|
||||
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
|
||||
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
|
||||
|
||||
# Labels stored separately from window data
|
||||
max_label_len = max(len(l) for l in labels)
|
||||
# Store ALIGNED labels as primary (what training will use)
|
||||
max_label_len = max(len(l) for l in aligned_labels)
|
||||
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
|
||||
windows_grp.create_dataset('labels', data=labels, dtype=dt)
|
||||
windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
|
||||
|
||||
# Also store original labels for reference/debugging
|
||||
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
|
||||
|
||||
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
|
||||
windows_grp.create_dataset('window_ids', data=window_ids)
|
||||
|
||||
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
|
||||
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
|
||||
windows_grp.create_dataset('start_times', data=start_times)
|
||||
windows_grp.create_dataset('end_times', data=end_times)
|
||||
|
||||
# Store alignment metadata
|
||||
f.attrs['alignment_enabled'] = enable_alignment
|
||||
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
|
||||
|
||||
if raw_samples:
|
||||
raw_grp = f.create_group('raw_samples')
|
||||
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
|
||||
@@ -655,13 +888,21 @@ class SessionStorage:
|
||||
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
|
||||
return windows, labels_out, metadata
|
||||
|
||||
def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
||||
"""Load a single session in ML-ready format: X, y, label_names."""
|
||||
def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
||||
"""
|
||||
Load a single session in ML-ready format: X, y, label_names.
|
||||
|
||||
Args:
|
||||
session_id: The session to load
|
||||
filter_transitions: If True, remove windows in transition zones (default from config)
|
||||
"""
|
||||
filepath = self.get_session_filepath(session_id)
|
||||
|
||||
with h5py.File(filepath, 'r') as f:
|
||||
X = f['windows/emg_data'][:]
|
||||
labels_raw = f['windows/labels'][:]
|
||||
start_times = f['windows/start_times'][:]
|
||||
end_times = f['windows/end_times'][:]
|
||||
|
||||
labels = []
|
||||
for l in labels_raw:
|
||||
@@ -670,18 +911,33 @@ class SessionStorage:
|
||||
else:
|
||||
labels.append(l)
|
||||
|
||||
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
|
||||
|
||||
# Apply transition filtering if enabled
|
||||
if filter_transitions:
|
||||
label_names_pre = sorted(set(labels))
|
||||
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
|
||||
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
|
||||
|
||||
X, y_pre, labels = filter_transition_windows(
|
||||
X, y_pre, labels, start_times, end_times
|
||||
)
|
||||
|
||||
label_names = sorted(set(labels))
|
||||
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
|
||||
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
|
||||
|
||||
print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}")
|
||||
print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
|
||||
print(f"[Storage] Labels: {label_names}")
|
||||
return X, y, label_names
|
||||
|
||||
def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
|
||||
def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
|
||||
"""
|
||||
Load ALL sessions combined into a single training dataset.
|
||||
|
||||
Args:
|
||||
filter_transitions: If True, remove windows in transition zones (default from config)
|
||||
|
||||
Returns:
|
||||
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
|
||||
y: Combined labels as integers (n_total_windows,)
|
||||
@@ -697,11 +953,15 @@ class SessionStorage:
|
||||
raise ValueError("No sessions found to load!")
|
||||
|
||||
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
|
||||
if filter_transitions:
|
||||
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
|
||||
|
||||
all_X = []
|
||||
all_labels = []
|
||||
loaded_sessions = []
|
||||
reference_shape = None
|
||||
total_removed = 0
|
||||
total_original = 0
|
||||
|
||||
for session_id in sessions:
|
||||
filepath = self.get_session_filepath(session_id)
|
||||
@@ -709,6 +969,8 @@ class SessionStorage:
|
||||
with h5py.File(filepath, 'r') as f:
|
||||
X = f['windows/emg_data'][:]
|
||||
labels_raw = f['windows/labels'][:]
|
||||
start_times = f['windows/start_times'][:]
|
||||
end_times = f['windows/end_times'][:]
|
||||
|
||||
# Validate shape compatibility
|
||||
if reference_shape is None:
|
||||
@@ -725,10 +987,26 @@ class SessionStorage:
|
||||
else:
|
||||
labels.append(l)
|
||||
|
||||
original_count = len(X)
|
||||
total_original += original_count
|
||||
|
||||
# Apply transition filtering per session (each has its own gesture boundaries)
|
||||
if filter_transitions:
|
||||
# Need temporary y for filtering function
|
||||
temp_label_names = sorted(set(labels))
|
||||
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
|
||||
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
|
||||
|
||||
X, temp_y, labels = filter_transition_windows(
|
||||
X, temp_y, labels, start_times, end_times
|
||||
)
|
||||
total_removed += original_count - len(X)
|
||||
|
||||
all_X.append(X)
|
||||
all_labels.extend(labels)
|
||||
loaded_sessions.append(session_id)
|
||||
print(f"[Storage] - {session_id}: {X.shape[0]} windows")
|
||||
print(f"[Storage] - {session_id}: {len(X)} windows" +
|
||||
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
|
||||
|
||||
if not all_X:
|
||||
raise ValueError("No compatible sessions found!")
|
||||
@@ -742,6 +1020,8 @@ class SessionStorage:
|
||||
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
|
||||
|
||||
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
|
||||
if filter_transitions and total_removed > 0:
|
||||
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
|
||||
print(f"[Storage] Labels: {label_names}")
|
||||
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user