Files
EMG_Arm/learning_data_collection.py

2641 lines
98 KiB
Python
Raw Normal View History

"""
EMG Data Collection Pipeline
============================
A complete pipeline for collecting, labeling, and classifying EMG signals.
OPTIONS:
1. Collect Data - Run a labeled collection session with timed prompts
2. Inspect Data - Load saved sessions, view raw EMG and features
3. Train Classifier - Train LDA on collected data with cross-validation
q. Quit
FEATURES:
- Simulated EMG stream (swap for serial.Serial with real hardware)
- Timed prompt system for consistent data collection
- Automatic labeling based on prompt timing
- HDF5 storage with metadata
- Time-domain feature extraction (RMS, WL, ZC, SSC)
- LDA classifier with evaluation metrics
"""
import time
import threading
import queue
import numpy as np
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
from datetime import datetime
import json
import h5py
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import joblib # For model persistence
import matplotlib.pyplot as plt
# =============================================================================
# CONFIGURATION
# =============================================================================
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
2026-01-19 22:24:04 -06:00
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
# Windowing configuration
WINDOW_SIZE_MS = 150 # Window size in milliseconds
WINDOW_OVERLAP = 0.0 # Overlap ratio (0.0 = no overlap, 0.5 = 50% overlap)
# Labeling configuration
GESTURE_HOLD_SEC = 3.0 # How long to hold each gesture
REST_BETWEEN_SEC = 2.0 # Rest period between gestures
REPS_PER_GESTURE = 3 # Repetitions per gesture in a session
# Storage configuration
DATA_DIR = Path("collected_data") # Directory to store session files
MODEL_DIR = Path("models") # Directory to store trained models
USER_ID = "user_001" # Current user ID (change per user)
# =============================================================================
# LABEL ALIGNMENT CONFIGURATION
# =============================================================================
# Human reaction time causes EMG activity to lag behind label prompts.
# We detect when EMG actually rises and shift labels to match.
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms)
# =============================================================================
# TRANSITION WINDOW FILTERING
# =============================================================================
# Windows near gesture transitions contain ambiguous data (reaction time at start,
# muscle relaxation at end). Discard these during training for cleaner labels.
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts
TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends
# =============================================================================
# DATA STRUCTURES
# =============================================================================
@dataclass
class EMGSample:
"""Single sample from all channels at one point in time."""
timestamp: float # Python-side timestamp (seconds, monotonic)
channels: list[float] # Raw ADC values per channel
esp_timestamp_ms: Optional[int] = None # Optional: timestamp from ESP32
@dataclass
class EMGWindow:
"""
A window of samples - this is what we'll feed to ML models.
NOTE: This class intentionally contains NO label information.
Labels are stored separately to enforce training/inference separation.
This ensures inference code cannot accidentally access ground truth.
"""
window_id: int
start_time: float
end_time: float
samples: list[EMGSample]
def to_numpy(self) -> np.ndarray:
"""Convert to numpy array of shape (n_samples, n_channels)."""
return np.array([s.channels for s in self.samples])
def get_channel(self, ch: int) -> np.ndarray:
"""Get single channel as 1D array."""
return np.array([s.channels[ch] for s in self.samples])
# =============================================================================
# SIMULATED EMG STREAM (Mimics what ESP32 would send over serial)
# =============================================================================
class SimulatedEMGStream:
"""
Simulates ESP32 sending EMG data over serial.
In reality, you'd replace this with:
import serial
ser = serial.Serial('COM3', 115200)
line = ser.readline()
The ESP32 would send lines like:
"1234,512,489,501,523\n" (timestamp_ms, ch0, ch1, ch2, ch3)
"""
def __init__(self, num_channels: int = 4, sample_rate: int = 1000):
self.num_channels = num_channels
self.sample_rate = sample_rate
self.running = False
self.output_queue = queue.Queue()
self._thread = None
self._esp_time_ms = 0
def start(self):
"""Start the simulated data stream."""
self.running = True
self._thread = threading.Thread(target=self._generate_data, daemon=True)
self._thread.start()
print(f"[SIM] Started simulated EMG stream: {self.num_channels} channels @ {self.sample_rate}Hz")
def stop(self):
"""Stop the simulated data stream."""
self.running = False
if self._thread:
self._thread.join(timeout=1.0)
print("[SIM] Stopped simulated EMG stream")
def readline(self) -> Optional[str]:
"""
Mimics serial.readline() - blocks until data available.
Returns line in format: "timestamp_ms,ch0,ch1,ch2,ch3"
"""
try:
return self.output_queue.get(timeout=1.0)
except queue.Empty:
return None
def _generate_data(self):
"""Background thread that generates fake EMG data."""
interval = 1.0 / self.sample_rate
while self.running:
# Simulate EMG signal: baseline noise + occasional bursts
channels = []
for ch in range(self.num_channels):
# Base noise (typical ADC noise around 512 center for 10-bit ADC)
base = 512 + np.random.randn() * 10
# Occasionally add muscle activation burst (simulates gesture)
if np.random.random() < 0.01: # 1% chance each sample
base += np.random.randn() * 100
channels.append(int(np.clip(base, 0, 1023)))
# Format like ESP32 would send
line = f"{self._esp_time_ms},{','.join(map(str, channels))}\n"
self.output_queue.put(line)
self._esp_time_ms += 1 # Increment ESP32 timestamp
time.sleep(interval)
# =============================================================================
# DATA PARSER (Converts serial lines to EMGSample objects)
# =============================================================================
class EMGParser:
"""
Parses incoming serial data into structured EMGSample objects.
LESSON: Always validate incoming data. Serial lines can be:
- Corrupted (partial lines, garbage bytes)
- Missing (dropped packets)
- Out of order (buffer issues)
"""
def __init__(self, num_channels: int):
self.num_channels = num_channels
self.parse_errors = 0
self.samples_parsed = 0
def parse_line(self, line: str) -> Optional[EMGSample]:
"""
Parse a line from ESP32 into an EMGSample.
Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
Python assigns timestamp on receipt for label alignment.
Returns None if parsing fails.
"""
try:
# Strip whitespace and split
parts = line.strip().split(',')
# Validate we have correct number of fields (channels only)
if len(parts) != self.num_channels:
self.parse_errors += 1
return None
# Parse channel values
channels = [float(parts[i]) for i in range(self.num_channels)]
# Create sample with Python-side timestamp (aligned with label clock)
sample = EMGSample(
timestamp=time.perf_counter(), # High-resolution monotonic clock
channels=channels,
esp_timestamp_ms=None # No longer using ESP32 timestamp
)
self.samples_parsed += 1
return sample
except (ValueError, IndexError) as e:
self.parse_errors += 1
return None
# =============================================================================
# WINDOWING (Groups samples into fixed-size windows)
# =============================================================================
class Windower:
"""
Groups incoming samples into fixed-size windows.
LESSON: ML models need fixed-size inputs. We can't feed them a continuous
stream - we need to chunk it into windows of consistent size.
Window size tradeoffs:
- Too small (50ms): Not enough data, noisy features
- Too large (500ms): Slow response, gesture transitions blurred
- Sweet spot: 150-250ms for EMG gesture recognition
"""
def __init__(self, window_size_ms: int, sample_rate: int, overlap: float = 0.0):
self.window_size_ms = window_size_ms
self.sample_rate = sample_rate
self.overlap = overlap
# Calculate window size in samples
self.window_size_samples = int(window_size_ms / 1000 * sample_rate)
self.step_size_samples = int(self.window_size_samples * (1 - overlap))
# Buffer for incoming samples
self.buffer: list[EMGSample] = []
self.window_count = 0
print(f"[Windower] Window: {window_size_ms}ms = {self.window_size_samples} samples")
print(f"[Windower] Step: {self.step_size_samples} samples (overlap={overlap*100:.0f}%)")
def add_sample(self, sample: EMGSample) -> Optional[EMGWindow]:
"""
Add a sample to the buffer. Returns a window if we have enough samples.
Returns None if buffer isn't full yet.
"""
self.buffer.append(sample)
# Check if we have enough samples for a window
if len(self.buffer) >= self.window_size_samples:
# Extract window
window_samples = self.buffer[:self.window_size_samples]
window = EMGWindow(
window_id=self.window_count,
start_time=window_samples[0].timestamp,
end_time=window_samples[-1].timestamp,
samples=window_samples.copy()
)
self.window_count += 1
# Slide buffer by step size
self.buffer = self.buffer[self.step_size_samples:]
return window
return None
def flush(self) -> Optional[EMGWindow]:
"""Flush remaining samples as a partial window (if any)."""
if len(self.buffer) > 0:
window = EMGWindow(
window_id=self.window_count,
start_time=self.buffer[0].timestamp,
end_time=self.buffer[-1].timestamp,
samples=self.buffer.copy()
)
self.buffer = []
return window
return None
# =============================================================================
# PROMPT SYSTEM (Timed prompts for labeling)
# =============================================================================
@dataclass
class GesturePrompt:
"""Defines a single gesture prompt in the collection sequence."""
gesture_name: str # e.g., "index_flex", "rest", "fist"
duration_sec: float # How long to hold this gesture
start_time: float = 0.0 # Filled in by scheduler when session starts
@dataclass
class PromptSchedule:
"""A complete sequence of prompts for a collection session."""
prompts: list[GesturePrompt]
total_duration: float = 0.0
def __post_init__(self):
"""Calculate start times and total duration."""
current_time = 0.0
for prompt in self.prompts:
prompt.start_time = current_time
current_time += prompt.duration_sec
self.total_duration = current_time
class PromptScheduler:
"""
Manages timed prompts during data collection.
LESSON: Timed prompts give you consistent, repeatable data collection.
The user knows exactly when to perform each gesture, and you know
exactly when each gesture should be happening for labeling.
"""
def __init__(self, gestures: list[str], hold_sec: float, rest_sec: float, reps: int):
"""
Build a prompt schedule.
Args:
gestures: List of gesture names (e.g., ["index_flex", "fist"])
hold_sec: How long to hold each gesture
rest_sec: Rest period between gestures
reps: Number of repetitions per gesture
"""
self.gestures = gestures
self.hold_sec = hold_sec
self.rest_sec = rest_sec
self.reps = reps
# Build the schedule
self.schedule = self._build_schedule()
self.session_start_time: Optional[float] = None
def _build_schedule(self) -> PromptSchedule:
"""Create the sequence of prompts."""
prompts = []
# Initial rest period
prompts.append(GesturePrompt("rest", self.rest_sec))
# For each repetition
for rep in range(self.reps):
# Cycle through all gestures
for gesture in self.gestures:
prompts.append(GesturePrompt(gesture, self.hold_sec))
prompts.append(GesturePrompt("rest", self.rest_sec))
return PromptSchedule(prompts)
def start_session(self):
"""Mark the start of a collection session."""
self.session_start_time = time.perf_counter()
print(f"\n[Scheduler] Session started. Duration: {self.schedule.total_duration:.1f}s")
print(f"[Scheduler] {len(self.schedule.prompts)} prompts scheduled")
def get_current_prompt(self) -> Optional[GesturePrompt]:
"""Get the prompt that should be active right now."""
if self.session_start_time is None:
return None
elapsed = time.perf_counter() - self.session_start_time
# Find which prompt is active
for prompt in self.schedule.prompts:
prompt_end = prompt.start_time + prompt.duration_sec
if prompt.start_time <= elapsed < prompt_end:
return prompt
return None # Session complete
def get_elapsed_time(self) -> float:
"""Get seconds elapsed since session start."""
if self.session_start_time is None:
return 0.0
return time.perf_counter() - self.session_start_time
def is_session_complete(self) -> bool:
"""Check if we've passed the end of the schedule."""
return self.get_elapsed_time() >= self.schedule.total_duration
def get_label_for_time(self, timestamp: float) -> str:
"""
Get the gesture label for a specific timestamp.
This is used to label windows after collection.
"""
if self.session_start_time is None:
return "unlabeled"
elapsed = timestamp - self.session_start_time
for prompt in self.schedule.prompts:
prompt_end = prompt.start_time + prompt.duration_sec
if prompt.start_time <= elapsed < prompt_end:
return prompt.gesture_name
return "unlabeled"
def print_schedule(self):
"""Print the full prompt schedule."""
print("\n" + "-" * 40)
print("PROMPT SCHEDULE")
print("-" * 40)
for i, p in enumerate(self.schedule.prompts):
print(f" {i+1:2d}. [{p.start_time:5.1f}s - {p.start_time + p.duration_sec:5.1f}s] {p.gesture_name}")
print(f"\n Total duration: {self.schedule.total_duration:.1f}s")
# =============================================================================
# SIMULATED EMG STREAM (Gesture-aware signal generation)
# =============================================================================
class GestureAwareEMGStream(SimulatedEMGStream):
"""
Enhanced simulation that generates different EMG patterns based on
which gesture is currently being prompted.
This makes the simulated data more realistic for testing your pipeline.
Each gesture activates different "muscles" (channels) with different intensities.
"""
# Define which channels activate for each gesture (0-1 intensity per channel)
GESTURE_PATTERNS = {
"rest": [0.0, 0.0, 0.0, 0.0],
2026-01-19 22:24:04 -06:00
"open": [0.3, 0.3, 0.3, 0.3], # Moderate all channels (extension)
"fist": [0.7, 0.7, 0.6, 0.6], # All channels active (flexion)
"hook_em": [0.8, 0.2, 0.7, 0.1], # Index + pinky extended (ch0 + ch2)
"thumbs_up": [0.1, 0.1, 0.2, 0.8], # Thumb dominant (ch3)
}
def __init__(self, num_channels: int = 4, sample_rate: int = 1000):
super().__init__(num_channels, sample_rate)
self.current_gesture = "rest"
self._gesture_lock = threading.Lock()
def set_gesture(self, gesture: str):
"""Set the current gesture being performed."""
with self._gesture_lock:
self.current_gesture = gesture
def _generate_data(self):
"""Generate EMG data based on current gesture."""
interval = 1.0 / self.sample_rate
while self.running:
with self._gesture_lock:
gesture = self.current_gesture
# Get activation pattern for current gesture
pattern = self.GESTURE_PATTERNS.get(gesture, [0.0] * self.num_channels)
channels = []
for ch in range(self.num_channels):
# Base signal around 512 (10-bit ADC center)
base = 512
# Add noise (always present)
noise = np.random.randn() * 10
# Add muscle activation based on pattern
activation = pattern[ch] * np.random.randn() * 150 # Scaled EMG burst
# Combine and clip to ADC range
value = int(np.clip(base + noise + activation, 0, 1023))
channels.append(value)
# Format like ESP32 would send
line = f"{self._esp_time_ms},{','.join(map(str, channels))}\n"
self.output_queue.put(line)
self._esp_time_ms += 1
time.sleep(interval)
# =============================================================================
# LABEL ALIGNMENT (Simple Onset Detection)
# =============================================================================
from scipy.signal import butter, sosfiltfilt
def align_labels_with_onset(
labels: list[str],
window_start_times: np.ndarray,
raw_timestamps: np.ndarray,
raw_channels: np.ndarray,
sampling_rate: int,
threshold_factor: float = 2.0,
search_ms: float = 800
) -> list[str]:
"""
Align labels to EMG onset by detecting when signal rises above baseline.
Simple algorithm:
1. High-pass filter to remove DC offset
2. Compute RMS envelope across channels
3. At each label transition, find where envelope exceeds baseline + threshold
4. Move label boundary to that point
"""
if len(labels) == 0:
return labels.copy()
# High-pass filter to remove DC (raw ADC has ~2340mV offset)
nyquist = sampling_rate / 2
sos = butter(2, 20.0 / nyquist, btype='high', output='sos')
# Filter and compute envelope (RMS across channels)
filtered = np.zeros_like(raw_channels)
for ch in range(raw_channels.shape[1]):
filtered[:, ch] = sosfiltfilt(sos, raw_channels[:, ch])
envelope = np.sqrt(np.mean(filtered ** 2, axis=1))
# Smooth envelope
sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')
envelope = sosfiltfilt(sos_lp, envelope)
# Find transitions and detect onsets
search_samples = int(search_ms / 1000 * sampling_rate)
baseline_samples = int(200 / 1000 * sampling_rate)
boundaries = [] # (time, new_label)
for i in range(1, len(labels)):
if labels[i] != labels[i - 1]:
prompt_time = window_start_times[i]
# Find index in raw signal closest to prompt time
prompt_idx = np.searchsorted(raw_timestamps, prompt_time)
# Get baseline (before transition)
base_start = max(0, prompt_idx - baseline_samples)
baseline = envelope[base_start:prompt_idx]
if len(baseline) == 0:
boundaries.append((prompt_time + 0.3, labels[i]))
continue
threshold = np.mean(baseline) + threshold_factor * np.std(baseline)
# Search forward for onset
search_end = min(len(envelope), prompt_idx + search_samples)
onset_idx = None
for j in range(prompt_idx, search_end):
if envelope[j] > threshold:
onset_idx = j
break
if onset_idx is not None:
onset_time = raw_timestamps[onset_idx]
else:
onset_time = prompt_time + 0.3 # fallback
boundaries.append((onset_time, labels[i]))
# Assign labels based on detected boundaries
aligned = []
boundary_idx = 0
current_label = labels[0]
for t in window_start_times:
while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
current_label = boundaries[boundary_idx][1]
boundary_idx += 1
aligned.append(current_label)
return aligned
def filter_transition_windows(
X: np.ndarray,
y: np.ndarray,
labels: list[str],
start_times: np.ndarray,
end_times: np.ndarray,
transition_start_ms: float = TRANSITION_START_MS,
transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""
Filter out windows that fall within transition zones at gesture boundaries.
This removes ambiguous data where:
- User is still reacting to prompt (start of gesture)
- User is anticipating next gesture (end of gesture)
Args:
X: EMG data array (n_windows, samples, channels)
y: Label indices (n_windows,)
labels: String labels (n_windows,)
start_times: Window start times in seconds (n_windows,)
end_times: Window end times in seconds (n_windows,)
transition_start_ms: Discard windows within this time after gesture start
transition_end_ms: Discard windows within this time before gesture end
Returns:
Filtered (X, y, labels) with transition windows removed
"""
if len(X) == 0:
return X, y, labels
transition_start_sec = transition_start_ms / 1000.0
transition_end_sec = transition_end_ms / 1000.0
# Find gesture boundaries (where label changes)
# Each boundary is the START of a new gesture segment
boundaries = [0] # First window starts a segment
for i in range(1, len(labels)):
if labels[i] != labels[i-1]:
boundaries.append(i)
boundaries.append(len(labels)) # End marker
# For each segment, find start_time and end_time of the gesture
# Then mark windows that are within transition zones
keep_mask = np.ones(len(X), dtype=bool)
for seg_idx in range(len(boundaries) - 1):
seg_start_idx = boundaries[seg_idx]
seg_end_idx = boundaries[seg_idx + 1]
# Get the time boundaries of this gesture segment
gesture_start_time = start_times[seg_start_idx]
gesture_end_time = end_times[seg_end_idx - 1] # Last window's end time
# Mark windows in transition zones
for i in range(seg_start_idx, seg_end_idx):
window_start = start_times[i]
window_end = end_times[i]
# Check if window is too close to gesture START (reaction time zone)
if window_start < gesture_start_time + transition_start_sec:
keep_mask[i] = False
# Check if window is too close to gesture END (anticipation zone)
if window_end > gesture_end_time - transition_end_sec:
keep_mask[i] = False
# Apply filter
X_filtered = X[keep_mask]
y_filtered = y[keep_mask]
labels_filtered = [l for l, keep in zip(labels, keep_mask) if keep]
n_removed = len(X) - len(X_filtered)
if n_removed > 0:
print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
print(f"[Filter] Kept {len(X_filtered)} windows for training")
return X_filtered, y_filtered, labels_filtered
# =============================================================================
# SESSION STORAGE (Save/Load labeled data to HDF5)
# =============================================================================
@dataclass
class SessionMetadata:
"""Metadata for a collection session."""
user_id: str
session_id: str
timestamp: str
sampling_rate: int
window_size_ms: int
num_channels: int
gestures: list[str]
notes: str = ""
class SessionStorage:
"""Handles saving and loading EMG collection sessions to HDF5 files."""
def __init__(self, data_dir: Path = DATA_DIR):
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
def generate_session_id(self, user_id: str) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return f"{user_id}_{timestamp}"
def get_session_filepath(self, session_id: str) -> Path:
return self.data_dir / f"{session_id}.hdf5"
def save_session(
self,
windows: list[EMGWindow],
labels: list[str],
metadata: SessionMetadata,
raw_samples: Optional[list[EMGSample]] = None,
session_start_time: Optional[float] = None,
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
) -> Path:
"""
Save a collection session to HDF5 with optional label alignment.
When raw_samples and session_start_time are provided and enable_alignment
is True, automatically detects EMG onset and corrects labels for human
reaction time delay.
Args:
windows: List of EMGWindow objects (no label info)
labels: List of gesture labels, parallel to windows
metadata: Session metadata
raw_samples: Raw samples (required for alignment)
session_start_time: When session started (required for alignment)
enable_alignment: Whether to perform automatic label alignment
"""
filepath = self.get_session_filepath(metadata.session_id)
if not windows:
raise ValueError("No windows to save!")
if len(windows) != len(labels):
raise ValueError(f"Windows ({len(windows)}) and labels ({len(labels)}) must have same length!")
window_samples = len(windows[0].samples)
num_channels = len(windows[0].samples[0].channels)
# Prepare timing arrays
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
# Label alignment using onset detection
aligned_labels = labels
original_labels = labels
if enable_alignment and raw_samples and len(raw_samples) > 0:
print("[Storage] Aligning labels to EMG onset...")
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
aligned_labels = align_labels_with_onset(
labels=labels,
window_start_times=start_times,
raw_timestamps=raw_timestamps,
raw_channels=raw_channels,
sampling_rate=metadata.sampling_rate,
threshold_factor=ONSET_THRESHOLD,
search_ms=ONSET_SEARCH_MS
)
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
elif enable_alignment:
print("[Storage] Warning: No raw samples, skipping alignment")
with h5py.File(filepath, 'w') as f:
# Metadata as attributes
f.attrs['user_id'] = metadata.user_id
f.attrs['session_id'] = metadata.session_id
f.attrs['timestamp'] = metadata.timestamp
f.attrs['sampling_rate'] = metadata.sampling_rate
f.attrs['window_size_ms'] = metadata.window_size_ms
f.attrs['num_channels'] = metadata.num_channels
f.attrs['gestures'] = json.dumps(metadata.gestures)
f.attrs['notes'] = metadata.notes
f.attrs['num_windows'] = len(windows)
f.attrs['window_samples'] = window_samples
# Windows group
windows_grp = f.create_group('windows')
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
# Store ALIGNED labels as primary (what training will use)
max_label_len = max(len(l) for l in aligned_labels)
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
# Also store original labels for reference/debugging
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
windows_grp.create_dataset('window_ids', data=window_ids)
windows_grp.create_dataset('start_times', data=start_times)
windows_grp.create_dataset('end_times', data=end_times)
# Store alignment metadata
f.attrs['alignment_enabled'] = enable_alignment
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
if raw_samples:
raw_grp = f.create_group('raw_samples')
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
raw_grp.create_dataset('timestamps', data=timestamps, compression='gzip')
raw_grp.create_dataset('channels', data=channels, compression='gzip')
print(f"[Storage] Saved session to: {filepath}")
print(f"[Storage] File size: {filepath.stat().st_size / 1024:.1f} KB")
return filepath
def load_session(self, session_id: str) -> tuple[list[EMGWindow], list[str], SessionMetadata]:
"""
Load a collection session from HDF5.
Returns:
windows: List of EMGWindow objects (no label info)
labels: List of gesture labels, parallel to windows
metadata: Session metadata
"""
filepath = self.get_session_filepath(session_id)
if not filepath.exists():
raise FileNotFoundError(f"Session not found: {filepath}")
windows = []
labels_out = []
with h5py.File(filepath, 'r') as f:
metadata = SessionMetadata(
user_id=f.attrs['user_id'],
session_id=f.attrs['session_id'],
timestamp=f.attrs['timestamp'],
sampling_rate=int(f.attrs['sampling_rate']),
window_size_ms=int(f.attrs['window_size_ms']),
num_channels=int(f.attrs['num_channels']),
gestures=json.loads(f.attrs['gestures']),
notes=f.attrs.get('notes', '')
)
windows_grp = f['windows']
emg_data = windows_grp['emg_data'][:]
labels_raw = windows_grp['labels'][:]
window_ids = windows_grp['window_ids'][:]
start_times = windows_grp['start_times'][:]
end_times = windows_grp['end_times'][:]
for i in range(len(emg_data)):
samples = []
window_data = emg_data[i]
for j in range(len(window_data)):
sample = EMGSample(
timestamp=start_times[i] + j * (1.0 / metadata.sampling_rate),
channels=window_data[j].tolist()
)
samples.append(sample)
# Decode label
label = labels_raw[i]
if isinstance(label, bytes):
label = label.decode('utf-8')
labels_out.append(label)
# Window contains NO label - labels stored separately
window = EMGWindow(
window_id=int(window_ids[i]),
start_time=float(start_times[i]),
end_time=float(end_times[i]),
samples=samples
)
windows.append(window)
print(f"[Storage] Loaded session: {session_id}")
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
return windows, labels_out, metadata
def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
"""
Load a single session in ML-ready format: X, y, label_names.
Args:
session_id: The session to load
filter_transitions: If True, remove windows in transition zones (default from config)
"""
filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
labels = []
for l in labels_raw:
if isinstance(l, bytes):
labels.append(l.decode('utf-8'))
else:
labels.append(l)
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
# Apply transition filtering if enabled
if filter_transitions:
label_names_pre = sorted(set(labels))
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
X, y_pre, labels = filter_transition_windows(
X, y_pre, labels, start_times, end_times
)
label_names = sorted(set(labels))
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
print(f"[Storage] Labels: {label_names}")
return X, y, label_names
def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
"""
Load ALL sessions combined into a single training dataset.
Args:
filter_transitions: If True, remove windows in transition zones (default from config)
Returns:
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
y: Combined labels as integers (n_total_windows,)
label_names: Sorted list of unique gesture labels across all sessions
session_ids: List of session IDs that were loaded
Raises:
ValueError: If no sessions found or sessions have incompatible shapes
"""
sessions = self.list_sessions()
if not sessions:
raise ValueError("No sessions found to load!")
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
if filter_transitions:
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
all_X = []
all_labels = []
loaded_sessions = []
reference_shape = None
total_removed = 0
total_original = 0
for session_id in sessions:
filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f:
X = f['windows/emg_data'][:]
labels_raw = f['windows/labels'][:]
start_times = f['windows/start_times'][:]
end_times = f['windows/end_times'][:]
# Validate shape compatibility
if reference_shape is None:
reference_shape = X.shape[1:] # (samples_per_window, channels)
elif X.shape[1:] != reference_shape:
print(f"[Storage] WARNING: Skipping {session_id} - incompatible shape {X.shape[1:]} vs {reference_shape}")
continue
# Decode labels
labels = []
for l in labels_raw:
if isinstance(l, bytes):
labels.append(l.decode('utf-8'))
else:
labels.append(l)
original_count = len(X)
total_original += original_count
# Apply transition filtering per session (each has its own gesture boundaries)
if filter_transitions:
# Need temporary y for filtering function
temp_label_names = sorted(set(labels))
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
X, temp_y, labels = filter_transition_windows(
X, temp_y, labels, start_times, end_times
)
total_removed += original_count - len(X)
all_X.append(X)
all_labels.extend(labels)
loaded_sessions.append(session_id)
print(f"[Storage] - {session_id}: {len(X)} windows" +
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
if not all_X:
raise ValueError("No compatible sessions found!")
# Combine all data
X_combined = np.concatenate(all_X, axis=0)
# Create unified label mapping across all sessions
label_names = sorted(set(all_labels))
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
if filter_transitions and total_removed > 0:
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
print(f"[Storage] Labels: {label_names}")
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")
return X_combined, y_combined, label_names, loaded_sessions
def list_sessions(self) -> list[str]:
"""List all available session IDs."""
return sorted([f.stem for f in self.data_dir.glob("*.hdf5")])
def get_session_info(self, session_id: str) -> dict:
"""Get quick info about a session without loading all data."""
filepath = self.get_session_filepath(session_id)
with h5py.File(filepath, 'r') as f:
return {
'user_id': f.attrs['user_id'],
'timestamp': f.attrs['timestamp'],
'num_windows': f.attrs['num_windows'],
'gestures': json.loads(f.attrs['gestures']),
'sampling_rate': f.attrs['sampling_rate'],
'window_size_ms': f.attrs['window_size_ms'],
}
# =============================================================================
# COLLECTION LOOP (Core collection pattern)
# =============================================================================
def run_collection_demo(duration_seconds: float = 5.0):
"""
Demonstrates the core data collection loop.
This is the pattern you'll use with real hardware - only the
data source changes (SimulatedEMGStream -> serial.Serial).
"""
print("\n" + "=" * 60)
print("EMG DATA COLLECTION DEMO")
print("=" * 60)
# Initialize components
stream = SimulatedEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
parser = EMGParser(num_channels=NUM_CHANNELS)
# Storage for collected samples
collected_samples: list[EMGSample] = []
# Start the stream
stream.start()
start_time = time.perf_counter()
print(f"\nCollecting for {duration_seconds} seconds...")
print("(In real use, this would read from serial port)\n")
try:
while (time.perf_counter() - start_time) < duration_seconds:
# Read line from stream (blocks briefly if no data)
line = stream.readline()
if line:
# Parse into structured sample
sample = parser.parse_line(line)
if sample:
collected_samples.append(sample)
# Print progress every 500 samples
if len(collected_samples) % 500 == 0:
elapsed = time.perf_counter() - start_time
rate = len(collected_samples) / elapsed
print(f" Collected {len(collected_samples)} samples "
f"({rate:.1f} samples/sec)")
except KeyboardInterrupt:
print("\n[Interrupted by user]")
finally:
stream.stop()
# Report results
print("\n" + "-" * 40)
print("COLLECTION RESULTS")
print("-" * 40)
print(f"Total samples collected: {len(collected_samples)}")
print(f"Parse errors: {parser.parse_errors}")
print(f"Actual duration: {time.perf_counter() - start_time:.2f}s")
if collected_samples:
actual_rate = len(collected_samples) / (time.perf_counter() - start_time)
print(f"Effective sample rate: {actual_rate:.1f} Hz")
# Show sample data structure
print("\nExample sample:")
s = collected_samples[0]
print(f" timestamp: {s.timestamp:.6f}")
print(f" channels: {s.channels}")
print(f" esp_timestamp_ms: {s.esp_timestamp_ms}")
# Quick statistics
data = np.array([s.channels for s in collected_samples])
print(f"\nChannel statistics (mean/std):")
for ch in range(NUM_CHANNELS):
print(f" Ch{ch}: {data[:, ch].mean():.1f} / {data[:, ch].std():.1f}")
return collected_samples
# =============================================================================
# COLLECTION SESSION
# =============================================================================
def run_labeled_collection_demo():
"""
Run a labeled EMG collection session:
1. Prompt scheduler guides the user through gestures
2. EMG stream generates/collects signals
3. Windower groups samples into fixed-size windows
4. Labels are assigned based on which prompt was active
5. Session is saved to HDF5 with user ID
"""
print("\n" + "=" * 60)
print("LABELED EMG COLLECTION")
print("=" * 60)
# Get user ID
user_id = input("\nEnter your user ID (e.g., user_001): ").strip()
if not user_id:
user_id = USER_ID # Fall back to default
print(f" Using default: {user_id}")
else:
print(f" User ID: {user_id}")
2026-01-19 22:24:04 -06:00
# Define gestures to collect (names match ESP32 gesture definitions)
gestures = ["open", "fist", "hook_em", "thumbs_up"]
# Create the prompt scheduler
scheduler = PromptScheduler(
gestures=gestures,
hold_sec=GESTURE_HOLD_SEC,
rest_sec=REST_BETWEEN_SEC,
reps=REPS_PER_GESTURE
)
scheduler.print_schedule()
# Create components
stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
parser = EMGParser(num_channels=NUM_CHANNELS)
windower = Windower(
window_size_ms=WINDOW_SIZE_MS,
sample_rate=SAMPLING_RATE_HZ,
overlap=WINDOW_OVERLAP
)
# Storage for windows and labels (kept separate to enforce training/inference separation)
collected_windows: list[EMGWindow] = []
collected_labels: list[str] = []
last_prompt_name = None
# Start collection
input("\nPress ENTER to start collection session...")
stream.start()
scheduler.start_session()
print("\n" + "-" * 40)
print("COLLECTING... Watch the prompts!")
print("-" * 40)
try:
while not scheduler.is_session_complete():
# Get current prompt
prompt = scheduler.get_current_prompt()
# Display prompt changes
if prompt and prompt.gesture_name != last_prompt_name:
elapsed = scheduler.get_elapsed_time()
if prompt.gesture_name == "rest":
print(f"\n [{elapsed:5.1f}s] >>> REST <<<")
else:
print(f"\n [{elapsed:5.1f}s] >>> {prompt.gesture_name.upper()} <<<")
last_prompt_name = prompt.gesture_name
# Update simulated stream to generate appropriate signal
stream.set_gesture(prompt.gesture_name)
# Read and parse data
line = stream.readline()
if line:
sample = parser.parse_line(line)
if sample:
# Try to form a window
window = windower.add_sample(sample)
if window:
# Store window and label separately (training/inference separation)
label = scheduler.get_label_for_time(window.start_time)
collected_windows.append(window)
collected_labels.append(label)
except KeyboardInterrupt:
print("\n[Interrupted by user]")
finally:
stream.stop()
# Report results
print("\n" + "=" * 60)
print("COLLECTION COMPLETE")
print("=" * 60)
print(f"Total windows collected: {len(collected_windows)}")
print(f"Parse errors: {parser.parse_errors}")
# Count labels (from separate labels list, not from windows)
label_counts = {}
for label in collected_labels:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"\nWindows per label:")
for label, count in sorted(label_counts.items()):
print(f" {label}: {count}")
# Show example windows
if collected_windows:
print(f"\nExample windows:")
for i, w in enumerate(collected_windows[:3]):
data = w.to_numpy()
print(f" Window {w.window_id}: label='{collected_labels[i]}', "
f"samples={len(w.samples)}, "
f"ch0_mean={data[:, 0].mean():.1f}")
# Show one window from each gesture type
print(f"\nSignal comparison (channel 0 std dev by gesture):")
for label in sorted(label_counts.keys()):
# Get indices where label matches
indices = [i for i, l in enumerate(collected_labels) if l == label]
if indices:
all_ch0 = np.concatenate([collected_windows[i].get_channel(0) for i in indices])
print(f" {label}: std={all_ch0.std():.1f}")
# --- Save the session ---
if collected_windows:
save_choice = input("\nSave this session? (y/n): ").strip().lower()
if save_choice == 'y':
storage = SessionStorage()
session_id = storage.generate_session_id(user_id)
metadata = SessionMetadata(
user_id=user_id,
session_id=session_id,
timestamp=datetime.now().isoformat(),
sampling_rate=SAMPLING_RATE_HZ,
window_size_ms=WINDOW_SIZE_MS,
num_channels=NUM_CHANNELS,
gestures=gestures,
notes=""
)
# Pass windows and labels separately (enforces separation)
filepath = storage.save_session(collected_windows, collected_labels, metadata)
print(f"\nSession saved! ID: {session_id}")
return collected_windows, collected_labels
# =============================================================================
# INSPECT SESSIONS (Load and view saved sessions)
# =============================================================================
def run_storage_demo():
"""Demonstrates loading and inspecting saved sessions."""
print("\n" + "=" * 60)
print("INSPECT SAVED SESSIONS")
print("=" * 60)
storage = SessionStorage()
sessions = storage.list_sessions()
if not sessions:
print("\nNo saved sessions found!")
print(f"Run option 2 first to collect and save a session.")
print(f"Sessions are stored in: {storage.data_dir.absolute()}")
return None
print(f"\nFound {len(sessions)} saved session(s):")
print("-" * 40)
for i, session_id in enumerate(sessions):
info = storage.get_session_info(session_id)
print(f"\n [{i+1}] {session_id}")
print(f" User: {info['user_id']}")
print(f" Time: {info['timestamp']}")
print(f" Windows: {info['num_windows']}")
print(f" Gestures: {info['gestures']}")
print("\n" + "-" * 40)
choice = input("Enter session number to load (or 'q' to quit): ").strip()
if choice.lower() == 'q':
return None
try:
idx = int(choice) - 1
if idx < 0 or idx >= len(sessions):
print("Invalid selection!")
return None
session_id = sessions[idx]
except ValueError:
print("Invalid input!")
return None
print(f"\n{'=' * 60}")
print(f"LOADING SESSION: {session_id}")
print("=" * 60)
# Labels returned separately from windows (enforces training/inference separation)
windows, labels, metadata = storage.load_session(session_id)
print(f"\nMetadata:")
print(f" User: {metadata.user_id}")
print(f" Timestamp: {metadata.timestamp}")
print(f" Sampling rate: {metadata.sampling_rate} Hz")
print(f" Window size: {metadata.window_size_ms} ms")
print(f" Channels: {metadata.num_channels}")
print(f" Gestures: {metadata.gestures}")
# Count from separate labels list
label_counts = {}
for label in labels:
label_counts[label] = label_counts.get(label, 0) + 1
print(f"\nLabel distribution:")
for label, count in sorted(label_counts.items()):
print(f" {label}: {count} windows")
print(f"\n{'-' * 40}")
print("LOADING FOR MACHINE LEARNING")
print("-" * 40)
X, y, label_names = storage.load_for_training(session_id)
print(f"\nData shapes:")
print(f" X (features): {X.shape}")
print(f" - {X.shape[0]} windows")
print(f" - {X.shape[1]} samples per window")
print(f" - {X.shape[2]} channels")
print(f" y (labels): {y.shape}")
print(f" Label mapping: {dict(enumerate(label_names))}")
# --- Feature Extraction & Visualization ---
print(f"\n{'-' * 40}")
print("EXTRACTING FEATURES FOR VISUALIZATION")
print("-" * 40)
# Note: Per-window centering is done inside EMGFeatureExtractor.
# This is the correct approach for real-time inference (causal, no future data).
# Global centering across all windows would leak information and not work in real-time.
extractor = EMGFeatureExtractor()
n_windows = X.shape[0]
n_channels = X.shape[2]
# Extract features per channel: shape (n_windows, n_channels, 4)
# Features order: [rms, wl, zc, ssc]
features_by_channel = np.zeros((n_windows, n_channels, 4))
for i in range(n_windows):
for ch in range(n_channels):
ch_features = extractor.extract_features_single_channel(X[i, :, ch])
features_by_channel[i, ch, 0] = ch_features['rms']
features_by_channel[i, ch, 1] = ch_features['wl']
features_by_channel[i, ch, 2] = ch_features['zc']
features_by_channel[i, ch, 3] = ch_features['ssc']
print(f" Extracted features for {n_windows} windows, {n_channels} channels")
print(f" (Per-window centering applied inside feature extractor)")
# Create time axis (window indices as proxy for time)
time_axis = np.arange(n_windows)
# Find gesture transition points (where label changes)
transitions = []
current_label = y[0]
for i in range(1, len(y)):
if y[i] != current_label:
transitions.append((i, label_names[y[i]]))
current_label = y[i]
# Define colors for gesture markers
def get_gesture_color(name):
if 'index' in name.lower():
return 'green'
elif 'fist' in name.lower():
return 'blue'
elif 'rest' in name.lower():
return 'gray'
elif 'thumb' in name.lower():
return 'orange'
elif 'middle' in name.lower():
return 'purple'
return 'red'
feature_titles = ['RMS', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
feature_colors = ['red', 'blue', 'green', 'purple']
feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count']
# --- Figure 1: Raw EMG Signal ---
fig_raw, axes_raw = plt.subplots(2, 2, figsize=(14, 8), sharex=True)
axes_raw = axes_raw.flatten()
# Concatenate all windows to show continuous signal
samples_per_window = X.shape[1]
total_samples = n_windows * samples_per_window
raw_time = np.arange(total_samples)
for ch in range(n_channels):
ax = axes_raw[ch]
# Flatten all windows into continuous signal for this channel
# Center per-channel for visualization only (subtract channel mean)
raw_signal = X[:, :, ch].flatten()
raw_signal_centered = raw_signal - raw_signal.mean()
ax.plot(raw_time, raw_signal_centered, linewidth=0.5, color='black')
ax.axhline(0, color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
ax.set_title(f"Channel {ch}", fontsize=11)
ax.set_ylabel("Amplitude (centered)")
ax.grid(True, alpha=0.3)
# Add vertical lines at gesture transitions (scaled to sample index)
for trans_idx, trans_label in transitions:
sample_idx = trans_idx * samples_per_window
color = get_gesture_color(trans_label)
ax.axvline(sample_idx, color=color, linestyle='--', alpha=0.6, linewidth=1)
# Add legend for gesture colors
legend_elements = []
for label_name in label_names:
color = get_gesture_color(label_name)
legend_elements.append(plt.Line2D([0], [0], color=color, linestyle='--', label=label_name))
axes_raw[0].legend(handles=legend_elements, loc='upper right', fontsize=8)
fig_raw.suptitle("Raw EMG Signal (Centered for Display) - All Channels", fontsize=14, fontweight='bold')
fig_raw.supxlabel("Sample Index")
plt.tight_layout()
# --- Figures 2-5: Feature plots (one per feature type) ---
for feat_idx, feat_title in enumerate(feature_titles):
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
axes = axes.flatten()
for ch in range(n_channels):
ax = axes[ch]
# Plot feature as line graph
feat_data = features_by_channel[:, ch, feat_idx]
ax.plot(time_axis, feat_data, linewidth=1, color=feature_colors[feat_idx])
ax.set_title(f"Channel {ch}", fontsize=11)
ax.grid(True, alpha=0.3)
# Add vertical lines at gesture transitions
for trans_idx, trans_label in transitions:
color = get_gesture_color(trans_label)
ax.axvline(trans_idx, color=color, linestyle='--', alpha=0.6, linewidth=1)
# Add legend for gesture colors
legend_elements = []
for label_name in label_names:
color = get_gesture_color(label_name)
legend_elements.append(plt.Line2D([0], [0], color=color, linestyle='--', label=label_name))
axes[0].legend(handles=legend_elements, loc='upper right', fontsize=8)
fig.suptitle(f"{feat_title} - All Channels", fontsize=14, fontweight='bold')
fig.supxlabel("Window Index (Time)")
fig.supylabel(feature_ylabels[feat_idx])
plt.tight_layout()
plt.show()
print(f"\n Displayed 5 figures: Raw EMG + 4 features (close windows to continue)")
return X, y, label_names
# =============================================================================
# FEATURE EXTRACTION (Time-domain features for EMG)
# =============================================================================
class EMGFeatureExtractor:
"""
Extracts time-domain features from EMG windows.
Features per channel:
- RMS (Root Mean Square): Signal power/amplitude
- WL (Waveform Length): Signal complexity
- ZC (Zero Crossings): Frequency content indicator
- SSC (Slope Sign Changes): Frequency content indicator
These 4 features × N channels = 4N features per window.
For 4 channels: 16 features total.
IMPORTANT: Per-window centering (DC offset removal) is applied before
feature extraction. This is critical because:
- EMG sensors have DC offset (e.g., ~512 for 10-bit ADC)
- Zero crossings require signal centered around 0
- Per-window centering is causal (works in real-time inference)
- Global centering would leak information across windows
LESSON: These features are:
- Fast to compute (good for real-time)
- Work well with LDA
- Proven effective for EMG gesture recognition
"""
def __init__(self, zc_threshold_percent: float = 0.1, ssc_threshold_percent: float = 0.1):
"""
Args:
zc_threshold_percent: ZC threshold as fraction of signal RMS
ssc_threshold_percent: SSC threshold as fraction of signal RMS squared
"""
self.zc_threshold_percent = zc_threshold_percent
self.ssc_threshold_percent = ssc_threshold_percent
def extract_features_single_channel(self, signal: np.ndarray) -> dict:
"""
Extract all features from a single channel signal.
Per-window centering is applied first to remove DC offset.
This is the standard approach for EMG feature extraction.
"""
# Per-window centering: remove DC offset (critical for ZC/SSC)
# This only uses data from the current window (causal for real-time)
signal = signal - np.mean(signal)
# RMS - Root Mean Square (now measures AC power, not DC offset)
rms = np.sqrt(np.mean(signal ** 2))
# WL - Waveform Length
wl = np.sum(np.abs(np.diff(signal)))
# Dynamic thresholds based on signal RMS
zc_thresh = self.zc_threshold_percent * rms
ssc_thresh = (self.ssc_threshold_percent * rms) ** 2
# ZC - Zero Crossings (with threshold to reject noise)
# Now meaningful because signal is centered around 0
sign_changes = signal[:-1] * signal[1:] < 0
amplitude_diff = np.abs(np.diff(signal)) > zc_thresh
zc = np.sum(sign_changes & amplitude_diff)
# SSC - Slope Sign Changes
diff_left = signal[1:-1] - signal[:-2]
diff_right = signal[1:-1] - signal[2:]
ssc = np.sum((diff_left * diff_right) > ssc_thresh)
return {'rms': rms, 'wl': wl, 'zc': zc, 'ssc': ssc}
def extract_features_window(self, window: np.ndarray) -> np.ndarray:
"""
Extract features from a window of shape (samples, channels).
Returns flat array: [ch0_rms, ch0_wl, ch0_zc, ch0_ssc, ch1_rms, ...]
"""
n_channels = window.shape[1]
features = []
for ch in range(n_channels):
ch_features = self.extract_features_single_channel(window[:, ch])
features.extend([ch_features['rms'], ch_features['wl'],
ch_features['zc'], ch_features['ssc']])
return np.array(features)
def extract_features_batch(self, X: np.ndarray) -> np.ndarray:
"""
Extract features from batch of windows.
Args:
X: Shape (n_windows, n_samples, n_channels)
Returns:
Features array of shape (n_windows, n_features)
where n_features = n_channels * 4
"""
n_windows = X.shape[0]
n_channels = X.shape[2]
n_features = n_channels * 4 # 4 features per channel
features = np.zeros((n_windows, n_features))
for i in range(n_windows):
features[i] = self.extract_features_window(X[i])
return features
def get_feature_names(self, n_channels: int) -> list[str]:
"""Get human-readable feature names."""
names = []
for ch in range(n_channels):
names.extend([f'ch{ch}_rms', f'ch{ch}_wl', f'ch{ch}_zc', f'ch{ch}_ssc'])
return names
# =============================================================================
# LDA CLASSIFIER
# =============================================================================
class EMGClassifier:
"""
LDA-based EMG gesture classifier.
LESSON: Why LDA for EMG?
- Fast training and inference (good for embedded)
- Works well with small datasets
- Interpretable (can visualize decision boundaries)
- Proven effective for EMG in literature
"""
def __init__(self):
self.feature_extractor = EMGFeatureExtractor()
self.lda = LinearDiscriminantAnalysis()
self.label_names: list[str] = []
self.is_trained = False
self.feature_names: list[str] = []
def train(self, X: np.ndarray, y: np.ndarray, label_names: list[str]):
"""
Train the classifier.
Args:
X: Raw EMG windows (n_windows, n_samples, n_channels)
y: Integer labels (n_windows,)
label_names: List of label strings
"""
print("\n[Classifier] Extracting features...")
X_features = self.feature_extractor.extract_features_batch(X)
self.feature_names = self.feature_extractor.get_feature_names(X.shape[2])
print(f"[Classifier] Feature matrix shape: {X_features.shape}")
print(f"[Classifier] Features per window: {len(self.feature_names)}")
print("\n[Classifier] Training LDA...")
self.lda.fit(X_features, y)
self.label_names = label_names
self.is_trained = True
# Training accuracy
train_acc = self.lda.score(X_features, y)
print(f"[Classifier] Training accuracy: {train_acc*100:.1f}%")
return X_features
def evaluate(self, X: np.ndarray, y: np.ndarray) -> dict:
"""Evaluate classifier on test data."""
if not self.is_trained:
raise ValueError("Classifier not trained!")
X_features = self.feature_extractor.extract_features_batch(X)
y_pred = self.lda.predict(X_features)
accuracy = np.mean(y_pred == y)
return {
'accuracy': accuracy,
'y_pred': y_pred,
'y_true': y
}
def cross_validate(self, X: np.ndarray, y: np.ndarray, cv: int = 5) -> np.ndarray:
"""Perform k-fold cross-validation."""
print(f"\n[Classifier] Running {cv}-fold cross-validation...")
X_features = self.feature_extractor.extract_features_batch(X)
scores = cross_val_score(self.lda, X_features, y, cv=cv)
return scores
def predict(self, window: np.ndarray) -> tuple[str, np.ndarray]:
"""
Predict gesture for a single window.
Args:
window: Shape (n_samples, n_channels)
Returns:
(predicted_label, probabilities)
"""
if not self.is_trained:
raise ValueError("Classifier not trained!")
features = self.feature_extractor.extract_features_window(window)
pred_idx = self.lda.predict([features])[0]
proba = self.lda.predict_proba([features])[0]
return self.label_names[pred_idx], proba
def get_feature_importance(self) -> dict:
"""Get feature importance based on LDA coefficients."""
if not self.is_trained:
return {}
# For multi-class, average absolute coefficients across classes
coef = np.abs(self.lda.coef_).mean(axis=0)
importance = dict(zip(self.feature_names, coef))
return dict(sorted(importance.items(), key=lambda x: x[1], reverse=True))
def save(self, filepath: Path) -> Path:
"""
Save the trained classifier to disk.
Saves:
- LDA model parameters
- Feature extractor settings
- Label names
- Feature names
Args:
filepath: Path to save the model (e.g., 'models/emg_classifier.joblib')
Returns:
Path to the saved model file
"""
if not self.is_trained:
raise ValueError("Cannot save untrained classifier!")
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
model_data = {
'lda': self.lda,
'label_names': self.label_names,
'feature_names': self.feature_names,
'feature_extractor_params': {
'zc_threshold_percent': self.feature_extractor.zc_threshold_percent,
'ssc_threshold_percent': self.feature_extractor.ssc_threshold_percent,
},
'version': '1.0', # For future compatibility
}
joblib.dump(model_data, filepath)
print(f"[Classifier] Model saved to: {filepath}")
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
return filepath
2026-01-27 21:31:49 -06:00
def export_to_header(self, filepath: Path) -> Path:
"""
Export trained model to a C header file for ESP32 inference.
Args:
filepath: Output .h file path
Returns:
Path to the saved header file
"""
if not self.is_trained:
raise ValueError("Cannot export untrained classifier!")
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
n_classes = len(self.label_names)
n_features = len(self.feature_names)
# Get LDA parameters
# coef_: (n_classes, n_features) - access as [class][feature]
# intercept_: (n_classes,)
coefs = self.lda.coef_
intercepts = self.lda.intercept_
# Add logic for binary classification (sklearn stores only 1 set of coefs)
# For >2 classes, it stores n_classes sets.
if n_classes == 2:
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
# We need to expand this to 2 classes for the C inference engine to be generic.
# Class 1 decision = dot(w, x) + b
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
# Let's check sklearn docs or behavior.
# Actually, LDA in sklearn for binary case is special.
# To make C code simple (always argmax), let's explicitly store 2 rows.
# Row 1 (Index 1 in sklearn): coef, intercept
# Row 0 (Index 0): -coef, -intercept ?
# Wait, LDA is generative. The decision boundary is linear.
# Let's assume Multiclass for now or handle binary specifically.
# For simplicity in C, we prefer (n_classes, n_features).
# If coefs.shape[0] != n_classes, we need to handle it.
if coefs.shape[0] == 1:
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
# Class 1 (positive)
c1_coef = coefs[0]
c1_int = intercepts[0]
# Class 0 (negative) - Effectively -score for decision boundary at 0
# But strictly speaking LDA is comparison of log-posteriors.
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
# To map this to ArgMax(Score0, Score1):
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
pass
# Generate C content
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
c_content = [
"/**",
f" * @file {filepath.name}",
" * @brief Trained LDA model weights exported from Python.",
f" * @date {timestamp}",
" */",
"",
"#ifndef MODEL_WEIGHTS_H",
"#define MODEL_WEIGHTS_H",
"",
"#include <stdint.h>",
"",
"/* Metadata */",
f"#define MODEL_NUM_CLASSES {n_classes}",
f"#define MODEL_NUM_FEATURES {n_features}",
"",
"/* Class Names */",
"static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {",
]
for name in self.label_names:
c_content.append(f' "{name}",')
c_content.append("};")
c_content.append("")
c_content.append("/* Feature Extractor Parameters */")
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
c_content.append("")
c_content.append("/* LDA Intercepts/Biases */")
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
line = " "
for val in intercepts:
line += f"{val:.6f}f, "
c_content.append(line.rstrip(", "))
c_content.append("};")
c_content.append("")
c_content.append("/* LDA Coefficients (Weights) */")
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
for i, row in enumerate(coefs):
c_content.append(f" /* {self.label_names[i]} */")
c_content.append(" {")
line = " "
for j, val in enumerate(row):
line += f"{val:.6f}f, "
if (j + 1) % 8 == 0:
c_content.append(line)
line = " "
if line.strip():
c_content.append(line.rstrip(", "))
c_content.append(" },")
c_content.append("};")
c_content.append("")
c_content.append("#endif /* MODEL_WEIGHTS_H */")
with open(filepath, 'w') as f:
f.write('\n'.join(c_content))
print(f"[Classifier] Model weights exported to: {filepath}")
return filepath
@classmethod
def load(cls, filepath: Path) -> 'EMGClassifier':
"""
Load a trained classifier from disk.
Args:
filepath: Path to the saved model file
Returns:
Loaded EMGClassifier instance ready for prediction
"""
filepath = Path(filepath)
if not filepath.exists():
raise FileNotFoundError(f"Model file not found: {filepath}")
model_data = joblib.load(filepath)
# Create new instance and restore state
classifier = cls()
classifier.lda = model_data['lda']
classifier.label_names = model_data['label_names']
classifier.feature_names = model_data['feature_names']
classifier.is_trained = True
# Restore feature extractor params
params = model_data.get('feature_extractor_params', {})
classifier.feature_extractor = EMGFeatureExtractor(
zc_threshold_percent=params.get('zc_threshold_percent', 0.1),
ssc_threshold_percent=params.get('ssc_threshold_percent', 0.1),
)
print(f"[Classifier] Model loaded from: {filepath}")
print(f"[Classifier] Labels: {classifier.label_names}")
return classifier
@staticmethod
def get_default_model_path() -> Path:
"""Get the default path for saving/loading models."""
return MODEL_DIR / "emg_lda_classifier.joblib"
@staticmethod
def list_saved_models() -> list[Path]:
"""List all saved model files."""
MODEL_DIR.mkdir(parents=True, exist_ok=True)
return sorted(MODEL_DIR.glob("*.joblib"))
# =============================================================================
# PREDICTION SMOOTHING (Temporal smoothing, majority vote, debouncing)
# =============================================================================
class PredictionSmoother:
"""
Smooths predictions to prevent twitchy/unstable output.
Combines three techniques:
1. Probability Smoothing: Exponential moving average on raw probabilities
2. Majority Vote: Output most common prediction from last N predictions
3. Debouncing: Only change output after N consecutive same predictions
This prevents the robotic hand from twitching when there's an occasional
misclassification in a stream of correct predictions.
Example:
Raw predictions: FIST, FIST, OPEN, FIST, FIST, FIST
Without smoothing: Hand twitches open briefly
With smoothing: Hand stays as FIST (OPEN was filtered out)
"""
def __init__(
self,
label_names: list[str],
probability_smoothing: float = 0.7,
majority_vote_window: int = 5,
debounce_count: int = 3,
):
"""
Args:
label_names: List of gesture labels (must match classifier output)
probability_smoothing: EMA factor (0-1). Higher = more smoothing.
0 = no smoothing, 0.9 = very smooth
majority_vote_window: Number of past predictions to consider for voting
debounce_count: Number of consecutive same predictions needed to change output
"""
self.label_names = label_names
self.n_classes = len(label_names)
# Probability smoothing (Exponential Moving Average)
self.prob_smoothing = probability_smoothing
self.smoothed_proba = np.ones(self.n_classes) / self.n_classes # Start uniform
# Majority vote
self.vote_window = majority_vote_window
self.prediction_history: list[str] = []
# Debouncing
self.debounce_count = debounce_count
self.current_output = None
self.pending_output = None
self.pending_count = 0
# Stats
self.total_predictions = 0
self.output_changes = 0
def update(self, predicted_label: str, probabilities: np.ndarray) -> tuple[str, float, dict]:
"""
Process a new prediction and return smoothed output.
Args:
predicted_label: Raw prediction from classifier
probabilities: Raw probability array from classifier
Returns:
(smoothed_label, confidence, debug_info)
- smoothed_label: The stable output label after all smoothing
- confidence: Confidence in the smoothed output (0-1)
- debug_info: Dict with intermediate values for debugging/display
"""
self.total_predictions += 1
# --- 1. Probability Smoothing (EMA) ---
# Blend new probabilities with historical smoothed probabilities
self.smoothed_proba = (
self.prob_smoothing * self.smoothed_proba +
(1 - self.prob_smoothing) * probabilities
)
# Get prediction from smoothed probabilities
prob_smoothed_idx = np.argmax(self.smoothed_proba)
prob_smoothed_label = self.label_names[prob_smoothed_idx]
prob_smoothed_confidence = self.smoothed_proba[prob_smoothed_idx]
# --- 2. Majority Vote ---
# Add to history and keep window size
self.prediction_history.append(prob_smoothed_label)
if len(self.prediction_history) > self.vote_window:
self.prediction_history.pop(0)
# Count votes
vote_counts = {}
for pred in self.prediction_history:
vote_counts[pred] = vote_counts.get(pred, 0) + 1
# Get majority winner
majority_label = max(vote_counts, key=vote_counts.get)
majority_count = vote_counts[majority_label]
majority_confidence = majority_count / len(self.prediction_history)
# --- 3. Debouncing ---
# Only change output after consistent predictions
if self.current_output is None:
# First prediction
self.current_output = majority_label
self.pending_output = majority_label
self.pending_count = 1
elif majority_label == self.current_output:
# Same as current output, reset pending
self.pending_output = majority_label
self.pending_count = 1
elif majority_label == self.pending_output:
# Same as pending, increment count
self.pending_count += 1
if self.pending_count >= self.debounce_count:
# Enough consecutive predictions, change output
self.current_output = majority_label
self.output_changes += 1
else:
# New prediction, start new pending
self.pending_output = majority_label
self.pending_count = 1
# Final output
final_label = self.current_output
final_confidence = majority_confidence
# Debug info
debug_info = {
'raw_label': predicted_label,
'raw_confidence': float(np.max(probabilities)),
'prob_smoothed_label': prob_smoothed_label,
'prob_smoothed_confidence': float(prob_smoothed_confidence),
'majority_label': majority_label,
'majority_confidence': float(majority_confidence),
'vote_counts': vote_counts,
'pending_output': self.pending_output,
'pending_count': self.pending_count,
'debounce_threshold': self.debounce_count,
}
return final_label, final_confidence, debug_info
def reset(self):
"""Reset all state (call when starting a new prediction session)."""
self.smoothed_proba = np.ones(self.n_classes) / self.n_classes
self.prediction_history = []
self.current_output = None
self.pending_output = None
self.pending_count = 0
self.total_predictions = 0
self.output_changes = 0
def get_stats(self) -> dict:
"""Get statistics about smoothing effectiveness."""
return {
'total_predictions': self.total_predictions,
'output_changes': self.output_changes,
'stability_ratio': 1 - (self.output_changes / max(1, self.total_predictions)),
}
# =============================================================================
# TRAINING (Train LDA classifier)
# =============================================================================
def run_training_demo():
"""
Train an LDA classifier on ALL collected sessions combined.
Shows:
1. Loading all session data combined
2. Feature extraction
3. Training LDA
4. Cross-validation evaluation
5. Feature importance analysis
The model learns from all accumulated data, making it more robust
as you collect more sessions over time.
"""
print("\n" + "=" * 60)
print("TRAIN LDA CLASSIFIER (ALL SESSIONS)")
print("=" * 60)
storage = SessionStorage()
sessions = storage.list_sessions()
if not sessions:
print("\nNo saved sessions found!")
print("Run option 1 first to collect and save training data.")
return None
# Show available sessions
print(f"\nFound {len(sessions)} saved session(s):")
print("-" * 40)
total_windows = 0
for session_id in sessions:
info = storage.get_session_info(session_id)
print(f" - {session_id}: {info['num_windows']} windows, gestures: {info['gestures']}")
total_windows += info['num_windows']
print(f"\nTotal windows across all sessions: {total_windows}")
print("-" * 40)
confirm = input("\nTrain on ALL sessions combined? (y/n): ").strip().lower()
if confirm != 'y':
print("Training cancelled.")
return None
# Load ALL data combined
print(f"\n{'=' * 60}")
print("TRAINING ON ALL SESSIONS COMBINED")
print("=" * 60)
X, y, label_names, loaded_sessions = storage.load_all_for_training()
print(f"\nDataset:")
print(f" Windows: {X.shape[0]}")
print(f" Samples per window: {X.shape[1]}")
print(f" Channels: {X.shape[2]}")
print(f" Classes: {label_names}")
# Count per class
print(f"\nSamples per class:")
for i, name in enumerate(label_names):
count = np.sum(y == i)
print(f" {name}: {count}")
# Create and train classifier
classifier = EMGClassifier()
X_features = classifier.train(X, y, label_names)
# Cross-validation
cv_scores = classifier.cross_validate(X, y, cv=5)
print(f"\nCross-validation scores: {cv_scores}")
print(f"Mean CV accuracy: {cv_scores.mean()*100:.1f}% (+/- {cv_scores.std()*100:.1f}%)")
# Feature importance
print(f"\n{'-' * 40}")
print("FEATURE IMPORTANCE (top 8)")
print("-" * 40)
importance = classifier.get_feature_importance()
for i, (name, score) in enumerate(list(importance.items())[:8]):
bar = "" * int(score * 20)
print(f" {name:12s}: {bar} ({score:.3f})")
# Train/test split evaluation
print(f"\n{'-' * 40}")
print("TRAIN/TEST SPLIT EVALUATION")
print("-" * 40)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42, stratify=y
)
# Train on train set
test_classifier = EMGClassifier()
test_classifier.train(X_train, y_train, label_names)
# Evaluate on test set
result = test_classifier.evaluate(X_test, y_test)
print(f"\nTest accuracy: {result['accuracy']*100:.1f}%")
print(f"\nClassification Report:")
print(classification_report(result['y_true'], result['y_pred'],
target_names=label_names))
print(f"Confusion Matrix:")
cm = confusion_matrix(result['y_true'], result['y_pred'])
print(f" {'':12s} ", end="")
for name in label_names:
print(f"{name[:8]:>8s} ", end="")
print()
for i, row in enumerate(cm):
print(f" {label_names[i]:12s} ", end="")
for val in row:
print(f"{val:8d} ", end="")
print()
# --- Save the model ---
print(f"\n{'-' * 40}")
print("SAVE MODEL")
print("-" * 40)
default_path = EMGClassifier.get_default_model_path()
print(f"Default save path: {default_path}")
save_choice = input("\nSave this model? (y/n): ").strip().lower()
if save_choice == 'y':
classifier.save(default_path)
print(f"\nModel saved! You can now use 'Live prediction' without retraining.")
return classifier
# =============================================================================
# LIVE PREDICTION (Real-time gesture classification)
# =============================================================================
def run_prediction_demo():
"""
Live prediction demo - classifies gestures in real-time.
Shows:
1. Load saved model OR train fresh on all sessions
2. Stream simulated EMG data
3. Classify each window as it comes in
4. Display predictions with confidence
"""
print("\n" + "=" * 60)
print("LIVE PREDICTION DEMO")
print("=" * 60)
# Check for saved model
saved_models = EMGClassifier.list_saved_models()
default_model = EMGClassifier.get_default_model_path()
classifier = None
if default_model.exists():
print(f"\nSaved model found: {default_model}")
print(f" File size: {default_model.stat().st_size / 1024:.1f} KB")
load_choice = input("\nLoad saved model? (y=load, n=retrain): ").strip().lower()
if load_choice == 'y':
classifier = EMGClassifier.load(default_model)
if classifier is None:
# Need to train a new model
storage = SessionStorage()
sessions = storage.list_sessions()
if not sessions:
print("\nNo saved sessions found! Collect data first (Option 1).")
return None
# Show available sessions
print(f"\nNo saved model (or retraining requested).")
print(f"Will train on ALL {len(sessions)} session(s):")
print("-" * 40)
total_windows = 0
for session_id in sessions:
info = storage.get_session_info(session_id)
print(f" - {session_id}: {info['num_windows']} windows")
total_windows += info['num_windows']
print(f"\nTotal training windows: {total_windows}")
print("-" * 40)
confirm = input("\nTrain and start prediction? (y/n): ").strip().lower()
if confirm != 'y':
print("Prediction cancelled.")
return None
# Load ALL sessions and train model
print(f"\n[Training model on all sessions...]")
X, y, label_names, loaded_sessions = storage.load_all_for_training()
classifier = EMGClassifier()
classifier.train(X, y, label_names)
# Start live prediction
print("\n" + "=" * 60)
print("STARTING LIVE PREDICTION (WITH SMOOTHING)")
print("=" * 60)
max_predictions = 50 # Stop after this many predictions
print(f"Running {max_predictions} predictions with smoothing enabled...\n")
print(" Smoothing: Probability EMA (0.7) + Majority Vote (5) + Debounce (3)\n")
stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
parser = EMGParser(num_channels=NUM_CHANNELS)
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
# Create prediction smoother
smoother = PredictionSmoother(
label_names=classifier.label_names,
probability_smoothing=0.7, # Higher = more smoothing
majority_vote_window=5, # Past predictions to consider
debounce_count=3, # Consecutive predictions needed to change
)
2026-01-19 22:24:04 -06:00
# Cycle through gestures for demo (names match ESP32 gesture definitions)
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
gesture_idx = 0
gesture_duration = 2.5 # seconds per gesture
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[0]
stream.set_gesture(current_gesture)
print(f" [Simulating: {current_gesture.upper()}]")
stream.start()
prediction_count = 0
try:
while prediction_count < max_predictions:
# Change simulated gesture periodically
elapsed = time.perf_counter() - gesture_start
if elapsed > gesture_duration:
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[gesture_idx]
stream.set_gesture(current_gesture)
print(f"\n [Simulating: {current_gesture.upper()}]")
# Read and process data
line = stream.readline()
if line:
sample = parser.parse_line(line)
if sample:
window = windower.add_sample(sample)
if window:
# Classify the window (raw prediction)
window_data = window.to_numpy()
raw_label, proba = classifier.predict(window_data)
# Apply smoothing
smoothed_label, smoothed_conf, debug = smoother.update(raw_label, proba)
prediction_count += 1
# Display both raw and smoothed predictions
raw_conf = max(proba) * 100
smoothed_conf_pct = smoothed_conf * 100
# Visual bar for smoothed confidence
bar_len = round(smoothed_conf_pct / 5)
bar = "" * bar_len + "" * (20 - bar_len)
# Show raw vs smoothed (smoothed is the stable output)
raw_marker = " " if raw_label == smoothed_label else "!!"
print(f" #{prediction_count:3d}{bar}{smoothed_label:12s} ({smoothed_conf_pct:5.1f}%) {raw_marker} raw:{raw_label[:8]:8s}")
except KeyboardInterrupt:
print("\n\n[Stopped by user]")
finally:
stream.stop()
# Show smoothing stats
stats = smoother.get_stats()
print(f"\n" + "-" * 40)
print(f"SMOOTHING STATISTICS")
print("-" * 40)
print(f" Total predictions: {stats['total_predictions']}")
print(f" Output changes: {stats['output_changes']}")
print(f" Stability ratio: {stats['stability_ratio']*100:.1f}%")
print(f"\n (Without smoothing, output would change with every raw prediction)")
return classifier
# =============================================================================
# LDA VISUALIZATION (Decision boundaries and feature space)
# =============================================================================
def run_visualization_demo():
"""
Visualize the LDA model trained on ALL sessions with plots:
1. 2D feature space scatter plot (LDA reduced)
2. Decision boundaries
3. Class distributions
4. Confusion matrix heatmap
Uses all accumulated session data for a complete picture of the model.
"""
print("\n" + "=" * 60)
print("LDA VISUALIZATION (ALL SESSIONS)")
print("=" * 60)
storage = SessionStorage()
sessions = storage.list_sessions()
if not sessions:
print("\nNo saved sessions found! Collect data first (Option 1).")
return None
# Show available sessions
print(f"\nFound {len(sessions)} saved session(s):")
print("-" * 40)
total_windows = 0
for session_id in sessions:
info = storage.get_session_info(session_id)
print(f" - {session_id}: {info['num_windows']} windows")
total_windows += info['num_windows']
print(f"\nTotal windows: {total_windows}")
print("-" * 40)
confirm = input("\nVisualize model trained on ALL sessions? (y/n): ").strip().lower()
if confirm != 'y':
print("Visualization cancelled.")
return None
# Load ALL data combined
X, y, label_names, loaded_sessions = storage.load_all_for_training()
# Extract features
extractor = EMGFeatureExtractor()
X_features = extractor.extract_features_batch(X)
# Train LDA
print("\n[Training LDA for visualization...]")
lda = LinearDiscriminantAnalysis()
lda.fit(X_features, y)
# Transform to LDA space (reduces to n_classes - 1 dimensions)
X_lda = lda.transform(X_features)
n_classes = len(label_names)
print(f" LDA dimensions: {X_lda.shape[1]}")
# Color scheme
colors = plt.cm.viridis(np.linspace(0.1, 0.9, n_classes))
# --- Figure 1: LDA Feature Space (2D projection) ---
fig1, ax1 = plt.subplots(figsize=(10, 8))
for i, label in enumerate(label_names):
mask = y == i
ax1.scatter(X_lda[mask, 0],
X_lda[mask, 1] if X_lda.shape[1] > 1 else np.zeros(mask.sum()),
c=[colors[i]], label=label, s=100, alpha=0.7, edgecolors='white', linewidth=1)
# Add class means
for i, label in enumerate(label_names):
mask = y == i
mean_x = X_lda[mask, 0].mean()
mean_y = X_lda[mask, 1].mean() if X_lda.shape[1] > 1 else 0
ax1.scatter(mean_x, mean_y, c=[colors[i]], s=400, marker='X', edgecolors='black', linewidth=2)
ax1.annotate(label.upper(), (mean_x, mean_y), fontsize=12, fontweight='bold',
ha='center', va='bottom', xytext=(0, 15), textcoords='offset points')
ax1.set_xlabel("LDA Component 1", fontsize=12)
ax1.set_ylabel("LDA Component 2", fontsize=12)
ax1.set_title("LDA Feature Space - Gesture Clusters", fontsize=14, fontweight='bold')
ax1.legend(loc='upper right', fontsize=10)
ax1.grid(True, alpha=0.3)
# --- Figure 2: Decision Boundary Heatmap ---
if X_lda.shape[1] >= 1:
fig2, ax2 = plt.subplots(figsize=(10, 8))
# Create mesh grid
x_min, x_max = X_lda[:, 0].min() - 1, X_lda[:, 0].max() + 1
if X_lda.shape[1] > 1:
y_min, y_max = X_lda[:, 1].min() - 1, X_lda[:, 1].max() + 1
else:
y_min, y_max = -2, 2
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
np.linspace(y_min, y_max, 200))
# For prediction, we need to go back to original feature space
# Use simplified approach: train new LDA on LDA-transformed features
if X_lda.shape[1] > 1:
lda_2d = LinearDiscriminantAnalysis()
lda_2d.fit(X_lda[:, :2], y)
Z = lda_2d.predict(np.c_[xx.ravel(), yy.ravel()])
else:
# 1D case - simple threshold
Z = lda.predict(X_features[0:1]) # dummy
Z = np.zeros(xx.ravel().shape)
for i, x_val in enumerate(xx.ravel()):
if X_lda.shape[1] == 1:
# Find closest class mean
distances = [abs(x_val - X_lda[y == c, 0].mean()) for c in range(n_classes)]
Z[i] = np.argmin(distances)
Z = Z.reshape(xx.shape)
# Plot decision regions
ax2.contourf(xx, yy, Z, alpha=0.3, levels=np.arange(-0.5, n_classes, 1),
colors=[colors[i] for i in range(n_classes)])
ax2.contour(xx, yy, Z, colors='black', linewidths=0.5, alpha=0.5)
# Plot data points
for i, label in enumerate(label_names):
mask = y == i
ax2.scatter(X_lda[mask, 0],
X_lda[mask, 1] if X_lda.shape[1] > 1 else np.zeros(mask.sum()),
c=[colors[i]], label=label, s=80, alpha=0.9, edgecolors='black', linewidth=0.5)
ax2.set_xlabel("LDA Component 1", fontsize=12)
ax2.set_ylabel("LDA Component 2", fontsize=12)
ax2.set_title("LDA Decision Boundaries", fontsize=14, fontweight='bold')
ax2.legend(loc='upper right', fontsize=10)
# --- Figure 3: Feature Importance Radar Chart ---
fig3, ax3 = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection='polar'))
feature_names = extractor.get_feature_names(X.shape[2])
coef = np.abs(lda.coef_).mean(axis=0)
coef_normalized = coef / coef.max() # Normalize to 0-1
# Number of features
n_features = len(feature_names)
angles = np.linspace(0, 2 * np.pi, n_features, endpoint=False).tolist()
# Complete the loop
coef_normalized = np.concatenate([coef_normalized, [coef_normalized[0]]])
angles += angles[:1]
ax3.plot(angles, coef_normalized, 'o-', linewidth=2, color='#2E86AB', markersize=8)
ax3.fill(angles, coef_normalized, alpha=0.25, color='#2E86AB')
ax3.set_xticks(angles[:-1])
ax3.set_xticklabels(feature_names, fontsize=9)
ax3.set_ylim(0, 1.1)
ax3.set_title("Feature Importance (Radar)", fontsize=14, fontweight='bold', pad=20)
# --- Figure 4: Class Distribution Histograms ---
fig4, axes4 = plt.subplots(1, n_classes, figsize=(4*n_classes, 4))
if n_classes == 1:
axes4 = [axes4]
for i, (ax, label) in enumerate(zip(axes4, label_names)):
mask = y == i
ax.hist(X_lda[mask, 0], bins=15, color=colors[i], alpha=0.7, edgecolor='black')
ax.axvline(X_lda[mask, 0].mean(), color='black', linestyle='--', linewidth=2, label='Mean')
ax.set_xlabel("LDA Component 1")
ax.set_ylabel("Count")
ax.set_title(f"{label.upper()}", fontsize=12, fontweight='bold')
ax.legend()
fig4.suptitle("Class Distributions on LDA Axis", fontsize=14, fontweight='bold')
plt.tight_layout()
# --- Figure 5: Confusion Matrix Heatmap ---
from sklearn.model_selection import cross_val_predict
fig5, ax5 = plt.subplots(figsize=(8, 6))
y_pred = cross_val_predict(lda, X_features, y, cv=5)
cm = confusion_matrix(y, y_pred)
im = ax5.imshow(cm, interpolation='nearest', cmap='Blues')
ax5.figure.colorbar(im, ax=ax5)
ax5.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=label_names,
yticklabels=label_names,
xlabel='Predicted',
ylabel='Actual',
title='Confusion Matrix Heatmap')
# Add text annotations
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax5.text(j, i, format(cm[i, j], 'd'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black",
fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
print("\n Displayed 5 visualization figures")
return lda
# =============================================================================
# ENTRY POINT
# =============================================================================
if __name__ == "__main__":
print(__doc__)
while True:
print("\n" + "=" * 60)
print("EMG DATA COLLECTION PIPELINE")
print("=" * 60)
print("\nOptions:")
print(" 1. Collect data (labeled session)")
print(" 2. Inspect saved sessions (view features)")
print(" 3. Train LDA classifier")
print(" 4. Live prediction demo")
print(" 5. Visualize LDA model")
print(" q. Quit")
choice = input("\nEnter choice: ").strip().lower()
if choice == 'q':
print("\nGoodbye!")
break
elif choice == "1":
windows, labels = run_labeled_collection_demo()
elif choice == "2":
result = run_storage_demo()
elif choice == "3":
classifier = run_training_demo()
elif choice == "4":
classifier = run_prediction_demo()
elif choice == "5":
lda = run_visualization_demo()
else:
print("\nInvalid choice. Please enter 1-5 or q.")