2026-01-17 23:31:15 -06:00
|
|
|
|
"""
|
|
|
|
|
|
EMG Data Collection Pipeline
|
|
|
|
|
|
============================
|
|
|
|
|
|
A complete pipeline for collecting, labeling, and classifying EMG signals.
|
|
|
|
|
|
|
|
|
|
|
|
OPTIONS:
|
2026-03-10 11:39:02 -05:00
|
|
|
|
1. Collect Data - Run a labeled collection session with timed prompts (requires ESP32)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
2. Inspect Data - Load saved sessions, view raw EMG and features
|
|
|
|
|
|
3. Train Classifier - Train LDA on collected data with cross-validation
|
2026-03-10 11:39:02 -05:00
|
|
|
|
4. Live Prediction - Real-time gesture classification (requires ESP32)
|
|
|
|
|
|
5. Visualize LDA - Decision boundaries and feature space plots
|
|
|
|
|
|
6. Benchmark - Compare LDA/QDA/SVM/MLP classifiers
|
2026-01-17 23:31:15 -06:00
|
|
|
|
q. Quit
|
|
|
|
|
|
|
|
|
|
|
|
FEATURES:
|
2026-03-10 11:39:02 -05:00
|
|
|
|
- Real-time EMG acquisition via ESP32 serial interface
|
2026-01-17 23:31:15 -06:00
|
|
|
|
- Timed prompt system for consistent data collection
|
2026-03-10 11:39:02 -05:00
|
|
|
|
- Automatic labeling based on prompt timing with onset detection
|
2026-01-17 23:31:15 -06:00
|
|
|
|
- HDF5 storage with metadata
|
|
|
|
|
|
- Time-domain feature extraction (RMS, WL, ZC, SSC)
|
|
|
|
|
|
- LDA classifier with evaluation metrics
|
2026-03-10 11:39:02 -05:00
|
|
|
|
- Prediction smoothing (EMA + majority vote + debounce)
|
|
|
|
|
|
|
|
|
|
|
|
HARDWARE REQUIRED:
|
|
|
|
|
|
- ESP32 with EMG sensors connected and firmware flashed
|
|
|
|
|
|
- USB serial connection (921600 baud)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import time
|
|
|
|
|
|
import threading
|
|
|
|
|
|
import queue
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
import json
|
|
|
|
|
|
import h5py
|
2026-03-10 11:39:02 -05:00
|
|
|
|
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
|
|
|
|
|
|
from sklearn.model_selection import cross_val_score, train_test_split, cross_val_predict, GroupShuffleSplit, GroupKFold
|
2026-01-17 23:31:15 -06:00
|
|
|
|
from sklearn.metrics import classification_report, confusion_matrix
|
|
|
|
|
|
import joblib # For model persistence
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
2026-03-10 11:39:02 -05:00
|
|
|
|
from scipy.signal import butter, sosfiltfilt, sosfilt, sosfilt_zi # For label alignment + bandpass
|
|
|
|
|
|
from serial_stream import RealSerialStream # ESP32 serial communication
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
# =============================================================================
# CONFIGURATION
# =============================================================================
NUM_CHANNELS = 4  # Number of EMG channels (MyoWare sensors)
SAMPLING_RATE_HZ = 1000  # Must match ESP32's EMG_SAMPLE_RATE_HZ
SERIAL_BAUD = 921600  # High baud rate to prevent serial buffer backlog

# Windowing configuration (must match ESP32 inference timing)
WINDOW_SIZE_MS = 150  # Window size in milliseconds (150 samples at 1kHz)
HOP_SIZE_MS = 25  # Hop/stride in milliseconds (25 samples at 1kHz)
# NOTE(review): number of recent predictions fed to the majority-vote smoother,
# presumably used by the live-prediction path — confirm against smoothing code
# (not visible in this chunk).
MAJORITY_WINDOW = 10

# Hand classifier channel selection
# The hand gesture classifier uses only forearm channels (ch0-ch2).
# The bicep channel (ch3) is excluded to prevent bicep activity from
# corrupting hand gesture classification. Ch3 is reserved for independent
# bicep envelope processing (see Phase 5).
HAND_CHANNELS = [0, 1, 2]  # Forearm channels only (excludes bicep ch3)

# Labeling configuration
GESTURE_HOLD_SEC = 3.0  # How long to hold each gesture (seconds)
REST_BETWEEN_SEC = 2.0  # Rest period between gestures (seconds)
REPS_PER_GESTURE = 3  # Repetitions per gesture in a session
LABEL_SHIFT_MS = 150  # Shift label lookup forward by this many ms to account
                      # for human reaction time. A 150ms window labelled at its
                      # start_time can straddle a prompt transition; using
                      # start_time + shift assigns the label based on what the
                      # user is actually doing at the window's centre.

# Storage configuration
DATA_DIR = Path("collected_data")  # Directory to store session files
MODEL_DIR = Path("models")  # Directory to store trained models
USER_ID = "user_001"  # Current user ID (change per user)

# =============================================================================
# LABEL ALIGNMENT CONFIGURATION
# =============================================================================
# Human reaction time causes EMG activity to lag behind label prompts.
# We detect when EMG actually rises and shift labels to match.

ENABLE_LABEL_ALIGNMENT = True  # Enable/disable automatic label alignment
ONSET_THRESHOLD = 2  # Signal must exceed baseline + threshold * std
ONSET_SEARCH_MS = 2000  # Search window after prompt (ms)
# Change 0: after onset detection shifts the label start backward, additionally
# relabel the first LABEL_FORWARD_SHIFT_MS of each gesture run as "rest" to skip
# the EMG transient at gesture onset. Paired with reducing TRANSITION_START_MS.
LABEL_FORWARD_SHIFT_MS = 100  # ms of each gesture onset to relabel as rest

# =============================================================================
# TRANSITION WINDOW FILTERING
# =============================================================================
# Windows near gesture transitions contain ambiguous data (reaction time at start,
# muscle relaxation at end). Discard these during training for cleaner labels.
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).

DISCARD_TRANSITION_WINDOWS = True  # Enable/disable transition filtering during training
TRANSITION_START_MS = 200  # Discard windows within this time AFTER gesture starts
TRANSITION_END_MS = 150  # Discard windows within this time BEFORE gesture ends
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# DATA STRUCTURES
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EMGSample:
    """Single sample from all channels at one point in time."""
    timestamp: float  # Python-side receipt time (seconds, time.perf_counter monotonic clock)
    channels: list[float]  # Raw ADC values, one per channel
    # DEPRECATED: esp_timestamp_ms is no longer used. Python-side timestamps are used
    # for label alignment. Kept for backward compatibility with old serialized data.
    esp_timestamp_ms: Optional[int] = None
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class EMGWindow:
    """
    A fixed-size run of consecutive samples - the unit of input for ML models.

    NOTE: This class intentionally contains NO label information.
    Labels are stored separately to enforce training/inference separation.
    This ensures inference code cannot accidentally access ground truth.
    """
    window_id: int
    start_time: float
    end_time: float
    samples: list[EMGSample]

    def to_numpy(self) -> np.ndarray:
        """Convert to numpy array of shape (n_samples, n_channels)."""
        rows = [sample.channels for sample in self.samples]
        return np.array(rows)

    def get_channel(self, ch: int) -> np.ndarray:
        """Get single channel as 1D array."""
        values = [sample.channels[ch] for sample in self.samples]
        return np.array(values)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# DATA PARSER (Converts serial lines to EMGSample objects)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class EMGParser:
    """
    Parses incoming serial data into structured EMGSample objects.

    LESSON: Always validate incoming data. Serial lines can be:
    - Corrupted (partial lines, garbage bytes)
    - Missing (dropped packets)
    - Out of order (buffer issues)
    """

    def __init__(self, num_channels: int):
        self.num_channels = num_channels  # Expected number of fields per line
        self.parse_errors = 0             # Lines rejected (bad field count / non-numeric)
        self.samples_parsed = 0           # Lines successfully converted to samples

    def parse_line(self, line: str) -> Optional[EMGSample]:
        """
        Parse a line from ESP32 into an EMGSample.

        Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
        Python assigns timestamp on receipt for label alignment.

        Returns None if parsing fails (increments parse_errors).
        """
        # Strip whitespace and split into fields
        parts = line.strip().split(',')

        # Validate we have the correct number of fields (channels only)
        if len(parts) != self.num_channels:
            self.parse_errors += 1
            return None

        # Keep the try body minimal: only float() can raise here.
        # (Field count was validated above, so no IndexError is possible.)
        try:
            channels = [float(p) for p in parts]
        except ValueError:
            self.parse_errors += 1
            return None

        # Create sample with Python-side timestamp (aligned with label clock)
        sample = EMGSample(
            timestamp=time.perf_counter(),  # High-resolution monotonic clock
            channels=channels,
            esp_timestamp_ms=None  # Deprecated field, kept for compatibility
        )

        self.samples_parsed += 1
        return sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# WINDOWING (Groups samples into fixed-size windows)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class Windower:
    """
    Groups incoming samples into fixed-size, overlapping windows.

    LESSON: ML models need fixed-size inputs. We can't feed them a continuous
    stream - we need to chunk it into windows of consistent size.

    Window size tradeoffs:
    - Too small (50ms): Not enough data, noisy features
    - Too large (500ms): Slow response, gesture transitions blurred
    - Sweet spot: 150-250ms for EMG gesture recognition
    """

    def __init__(self, window_size_ms: int, sample_rate: int, hop_size_ms: int = 25):
        """
        Args:
            window_size_ms: Window length in milliseconds.
            sample_rate: Sampling rate in Hz (must match the acquisition rate).
            hop_size_ms: Stride between consecutive windows in milliseconds.
        """
        self.window_size_ms = window_size_ms
        self.sample_rate = sample_rate
        self.hop_size_ms = hop_size_ms

        # Calculate window and step size in samples (hop-based, not overlap-based)
        self.window_size_samples = int(window_size_ms / 1000 * sample_rate)
        self.step_size_samples = int(hop_size_ms / 1000 * sample_rate)

        # Buffer for incoming samples
        self.buffer: list[EMGSample] = []
        self.window_count = 0

        # Verification: Print first 10 window start indices and timestamps
        self._verification_printed = False

        print(f"[Windower] Window: {window_size_ms}ms = {self.window_size_samples} samples")
        print(f"[Windower] Hop: {hop_size_ms}ms = {self.step_size_samples} samples")

    def add_sample(self, sample: EMGSample) -> Optional[EMGWindow]:
        """
        Add a sample to the buffer. Returns a window if we have enough samples.

        Returns None if buffer isn't full yet.

        Window timing (at 1kHz):
        - Window 0: samples 0-149, start index 0, time 0.000s
        - Window 1: samples 25-174, start index 25, time 0.025s
        - Window 2: samples 50-199, start index 50, time 0.050s
        - ...
        """
        self.buffer.append(sample)

        # Check if we have enough samples for a window
        if len(self.buffer) >= self.window_size_samples:
            # Extract window
            window_samples = self.buffer[:self.window_size_samples]
            window = EMGWindow(
                window_id=self.window_count,
                start_time=window_samples[0].timestamp,
                end_time=window_samples[-1].timestamp,
                samples=window_samples.copy()
            )

            # Verification: Print first 10 window start indices and timestamps
            if not self._verification_printed and self.window_count < 10:
                start_idx = self.window_count * self.step_size_samples
                start_time_sec = start_idx / self.sample_rate
                print(f"[Windower] Window {self.window_count}: start_idx={start_idx}, time={start_time_sec:.3f}s")
                if self.window_count == 9:
                    self._verification_printed = True
                    # FIX: report the actual configured window length instead of a
                    # hard-coded "150" (wrong whenever window_size_ms != 150).
                    print(f"[Windower] Verified: {self.window_size_samples}-sample windows, {self.step_size_samples}-sample hop")

            self.window_count += 1

            # Slide buffer by step size (drop the oldest hop's worth of samples)
            self.buffer = self.buffer[self.step_size_samples:]

            return window

        return None

    def flush(self) -> Optional[EMGWindow]:
        """Flush remaining samples as a partial window (if any)."""
        if len(self.buffer) > 0:
            window = EMGWindow(
                window_id=self.window_count,
                start_time=self.buffer[0].timestamp,
                end_time=self.buffer[-1].timestamp,
                samples=self.buffer.copy()
            )
            self.buffer = []
            return window
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# PROMPT SYSTEM (Timed prompts for labeling)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class GesturePrompt:
    """Defines a single gesture prompt in the collection sequence."""
    gesture_name: str  # e.g., "index_flex", "rest", "fist"
    duration_sec: float  # How long to hold this gesture
    start_time: float = 0.0  # Offset (seconds) from session start; filled in by PromptSchedule.__post_init__
    trial_id: int = -1  # Unique ID for this trial (gesture repetition); assigned by PromptScheduler._build_schedule
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class PromptSchedule:
    """A complete sequence of prompts for a collection session."""
    prompts: list[GesturePrompt]
    total_duration: float = 0.0

    def __post_init__(self):
        """Stamp each prompt with its start offset and accumulate the total duration."""
        offset = 0.0
        for p in self.prompts:
            p.start_time = offset
            offset += p.duration_sec
        self.total_duration = offset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PromptScheduler:
    """
    Manages timed prompts during data collection.

    LESSON: Timed prompts give you consistent, repeatable data collection.
    The user knows exactly when to perform each gesture, and you know
    exactly when each gesture should be happening for labeling.
    """

    def __init__(self, gestures: list[str], hold_sec: float, rest_sec: float, reps: int):
        """
        Build a prompt schedule.

        Args:
            gestures: List of gesture names (e.g., ["index_flex", "fist"])
            hold_sec: How long to hold each gesture
            rest_sec: Rest period between gestures
            reps: Number of repetitions per gesture
        """
        self.gestures = gestures
        self.hold_sec = hold_sec
        self.rest_sec = rest_sec
        self.reps = reps

        # Build the schedule
        self.schedule = self._build_schedule()
        # Set by start_session(); None means "no session running" (see guards below).
        self.session_start_time: Optional[float] = None

    def _build_schedule(self) -> PromptSchedule:
        """Create the sequence of prompts with unique trial_ids."""
        prompts = []
        trial_counter = 0

        # Initial rest period (trial_id = 0)
        prompts.append(GesturePrompt("rest", self.rest_sec, trial_id=trial_counter))
        trial_counter += 1

        # For each repetition
        for rep in range(self.reps):
            # Cycle through all gestures
            for gesture in self.gestures:
                # Gesture trial
                prompts.append(GesturePrompt(gesture, self.hold_sec, trial_id=trial_counter))
                trial_counter += 1
                # Rest trial (each rest is its own trial to avoid leakage)
                prompts.append(GesturePrompt("rest", self.rest_sec, trial_id=trial_counter))
                trial_counter += 1

        # PromptSchedule.__post_init__ computes start_time for every prompt.
        return PromptSchedule(prompts)

    def start_session(self):
        """Mark the start of a collection session."""
        # perf_counter is the same clock EMGSample timestamps use, so
        # get_label_for_time / get_trial_id_for_time can compare directly.
        self.session_start_time = time.perf_counter()
        print(f"\n[Scheduler] Session started. Duration: {self.schedule.total_duration:.1f}s")
        print(f"[Scheduler] {len(self.schedule.prompts)} prompts scheduled")

    def get_current_prompt(self) -> Optional[GesturePrompt]:
        """Get the prompt that should be active right now (None if no session / session over)."""
        if self.session_start_time is None:
            return None

        elapsed = time.perf_counter() - self.session_start_time

        # Find which prompt is active (linear scan; schedule is small)
        for prompt in self.schedule.prompts:
            prompt_end = prompt.start_time + prompt.duration_sec
            if prompt.start_time <= elapsed < prompt_end:
                return prompt

        return None  # Session complete

    def get_elapsed_time(self) -> float:
        """Get seconds elapsed since session start (0.0 if not started)."""
        if self.session_start_time is None:
            return 0.0
        return time.perf_counter() - self.session_start_time

    def is_session_complete(self) -> bool:
        """Check if we've passed the end of the schedule."""
        return self.get_elapsed_time() >= self.schedule.total_duration

    def get_label_for_time(self, timestamp: float) -> str:
        """
        Get the gesture label for a specific timestamp.

        This is used to label windows after collection.
        `timestamp` must come from the same clock as start_session()
        (time.perf_counter). Returns "unlabeled" outside the schedule.
        """
        if self.session_start_time is None:
            return "unlabeled"

        elapsed = timestamp - self.session_start_time

        for prompt in self.schedule.prompts:
            prompt_end = prompt.start_time + prompt.duration_sec
            if prompt.start_time <= elapsed < prompt_end:
                return prompt.gesture_name

        return "unlabeled"

    def get_trial_id_for_time(self, timestamp: float) -> int:
        """
        Get the trial_id for a specific timestamp.

        Each gesture repetition has a unique trial_id. Windows from the same
        trial MUST stay together during train/test splitting to prevent leakage.
        Returns -1 when no session is running or the time falls outside the schedule.
        """
        if self.session_start_time is None:
            return -1

        elapsed = timestamp - self.session_start_time

        for prompt in self.schedule.prompts:
            prompt_end = prompt.start_time + prompt.duration_sec
            if prompt.start_time <= elapsed < prompt_end:
                return prompt.trial_id

        return -1

    def print_schedule(self):
        """Print the full prompt schedule."""
        print("\n" + "-" * 40)
        print("PROMPT SCHEDULE")
        print("-" * 40)
        for i, p in enumerate(self.schedule.prompts):
            print(f"  {i+1:2d}. [{p.start_time:5.1f}s - {p.start_time + p.duration_sec:5.1f}s] {p.gesture_name}")
        print(f"\n  Total duration: {self.schedule.total_duration:.1f}s")
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-27 20:12:13 -06:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# LABEL ALIGNMENT (Simple Onset Detection)
|
|
|
|
|
|
# =============================================================================
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# NOTE: butter and sosfiltfilt imported at top of file
|
2026-01-27 20:12:13 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def align_labels_with_onset(
    labels: list[str],
    window_start_times: np.ndarray,
    raw_timestamps: np.ndarray,
    raw_channels: np.ndarray,
    sampling_rate: int,
    threshold_factor: float = 2.0,
    search_ms: float = 800
) -> list[str]:
    """
    Align labels to EMG onset by detecting when signal rises above baseline.

    Simple algorithm:
    1. High-pass filter to remove DC offset
    2. Compute RMS envelope across channels
    3. At each label transition, find where envelope exceeds baseline + threshold
    4. Move label boundary to that point
    """
    if len(labels) == 0:
        return labels.copy()

    nyq = sampling_rate / 2

    # High-pass filter to remove DC (raw ADC has ~2340mV offset)
    hp_sos = butter(2, 20.0 / nyq, btype='high', output='sos')
    hp = np.zeros_like(raw_channels)
    for c in range(raw_channels.shape[1]):
        hp[:, c] = sosfiltfilt(hp_sos, raw_channels[:, c])

    # RMS across channels, then a 10 Hz low-pass to smooth the envelope
    env = np.sqrt(np.mean(hp ** 2, axis=1))
    lp_sos = butter(2, 10.0 / nyq, btype='low', output='sos')
    env = sosfiltfilt(lp_sos, env)

    # Conversion of the configured durations into sample counts
    search_samples = int(search_ms / 1000 * sampling_rate)
    baseline_samples = int(200 / 1000 * sampling_rate)
    fallback_delay = 0.3  # seconds past the prompt when no onset is detected

    # Each entry: (boundary time, label that becomes active at that time)
    boundaries: list[tuple[float, str]] = []

    for idx in range(1, len(labels)):
        if labels[idx] == labels[idx - 1]:
            continue  # not a transition

        prompt_time = window_start_times[idx]
        # Raw-signal index closest to (at or after) the prompt time
        prompt_pos = np.searchsorted(raw_timestamps, prompt_time)

        # Baseline statistics from the stretch just before the transition
        baseline = env[max(0, prompt_pos - baseline_samples):prompt_pos]
        if baseline.size == 0:
            boundaries.append((prompt_time + fallback_delay, labels[idx]))
            continue

        thresh = np.mean(baseline) + threshold_factor * np.std(baseline)

        # First envelope sample above threshold within the search horizon
        stop = min(len(env), prompt_pos + search_samples)
        above = env[prompt_pos:stop] > thresh
        if above.any():
            onset_time = raw_timestamps[prompt_pos + int(np.argmax(above))]
        else:
            onset_time = prompt_time + fallback_delay  # fallback

        boundaries.append((onset_time, labels[idx]))

    # Re-label each window according to the most recent boundary passed
    aligned: list[str] = []
    b = 0
    active = labels[0]
    for t in window_start_times:
        while b < len(boundaries) and t >= boundaries[b][0]:
            active = boundaries[b][1]
            b += 1
        aligned.append(active)

    return aligned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def filter_transition_windows(
    X: np.ndarray,
    y: np.ndarray,
    labels: list[str],
    start_times: np.ndarray,
    end_times: np.ndarray,
    trial_ids: Optional[np.ndarray] = None,
    transition_start_ms: float = TRANSITION_START_MS,
    transition_end_ms: float = TRANSITION_END_MS
) -> tuple[np.ndarray, np.ndarray, list[str], Optional[np.ndarray]]:
    """
    Filter out windows that fall within transition zones at gesture boundaries.

    This removes ambiguous data where:
    - User is still reacting to prompt (start of gesture)
    - User is anticipating next gesture (end of gesture)

    Args:
        X: EMG data array (n_windows, samples, channels)
        y: Label indices (n_windows,)
        labels: String labels (n_windows,)
        start_times: Window start times in seconds (n_windows,)
        end_times: Window end times in seconds (n_windows,)
        trial_ids: Trial IDs for train/test splitting (n_windows,) - optional
        transition_start_ms: Discard windows within this time after gesture start
        transition_end_ms: Discard windows within this time before gesture end

    Returns:
        Filtered (X, y, labels, trial_ids) with transition windows removed
    """
    if len(X) == 0:
        return X, y, labels, trial_ids

    lead_sec = transition_start_ms / 1000.0
    tail_sec = transition_end_ms / 1000.0

    # Segment cut points: index 0, every index where the label changes,
    # and a final end marker. Each adjacent pair delimits one gesture run.
    cuts = [0]
    cuts.extend(i for i in range(1, len(labels)) if labels[i] != labels[i - 1])
    cuts.append(len(labels))

    keep = np.ones(len(X), dtype=bool)

    for seg_begin_idx, seg_end_idx in zip(cuts[:-1], cuts[1:]):
        # Time span of this gesture segment
        seg_begin_time = start_times[seg_begin_idx]
        seg_finish_time = end_times[seg_end_idx - 1]  # last window's end time

        for i in range(seg_begin_idx, seg_end_idx):
            # Too close to segment START (reaction-time zone)?
            too_early = start_times[i] < seg_begin_time + lead_sec
            # Too close to segment END (anticipation zone)?
            too_late = end_times[i] > seg_finish_time - tail_sec
            if too_early or too_late:
                keep[i] = False

    # Apply the mask to every parallel array
    X_out = X[keep]
    y_out = y[keep]
    labels_out = [lbl for lbl, k in zip(labels, keep) if k]
    trials_out = trial_ids[keep] if trial_ids is not None else None

    n_removed = len(X) - len(X_out)
    if n_removed > 0:
        print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
        print(f"[Filter] Kept {len(X_out)} windows for training")

    return X_out, y_out, labels_out, trials_out
|
2026-01-27 20:12:13 -06:00
|
|
|
|
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# SESSION STORAGE (Save/Load labeled data to HDF5)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class SessionMetadata:
    """Metadata for a collection session (stored alongside the HDF5 data)."""
    user_id: str  # Who recorded the session (e.g. "user_001")
    session_id: str  # Unique session identifier (user_id + timestamp)
    timestamp: str  # Human-readable session creation time
    sampling_rate: int  # Hz; should match SAMPLING_RATE_HZ at collection time
    window_size_ms: int  # Window length used when segmenting this session
    num_channels: int  # Channels recorded per sample
    gestures: list[str]  # Gesture names prompted during the session
    notes: str = ""  # Free-form operator notes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SessionStorage:
    """Handles saving and loading EMG collection sessions to HDF5 files.

    File layout (one .hdf5 file per session, named "<session_id>.hdf5"):
        root attrs      : session metadata (user_id, sampling_rate, ...)
        /windows        : emg_data, labels, labels_original, window_ids,
                          start_times, end_times, optionally trial_ids
        /raw_samples    : optional raw stream (timestamps, channels)
    """

    def __init__(self, data_dir: Path = DATA_DIR):
        """Create storage rooted at data_dir, creating the directory if needed."""
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

    def generate_session_id(self, user_id: str) -> str:
        """Return a unique session ID of the form "<user_id>_<YYYYmmdd_HHMMSS>"."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        return f"{user_id}_{timestamp}"

    def get_session_filepath(self, session_id: str) -> Path:
        """Return the HDF5 path for a session ID (file may not exist yet)."""
        return self.data_dir / f"{session_id}.hdf5"

    @staticmethod
    def _decode_labels(labels_raw) -> list[str]:
        """Decode an HDF5 label dataset to a list of Python str (bytes -> utf-8)."""
        return [l.decode('utf-8') if isinstance(l, bytes) else l for l in labels_raw]

    def save_session(
        self,
        windows: list[EMGWindow],
        labels: list[str],
        metadata: SessionMetadata,
        trial_ids: Optional[list[int]] = None,
        raw_samples: Optional[list[EMGSample]] = None,
        session_start_time: Optional[float] = None,
        enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
    ) -> Path:
        """
        Save a collection session to HDF5 with optional label alignment.

        When raw_samples are provided and enable_alignment is True,
        automatically detects EMG onset and corrects labels for human
        reaction time delay.

        Args:
            windows: List of EMGWindow objects (no label info)
            labels: List of gesture labels, parallel to windows
            metadata: Session metadata
            trial_ids: List of trial IDs, parallel to windows (for proper train/test splitting)
            raw_samples: Raw samples (required for alignment)
            session_start_time: When session started
                NOTE(review): currently unused by this method; kept for
                backward compatibility with existing callers.
            enable_alignment: Whether to perform automatic label alignment

        Returns:
            Path to the written HDF5 file.

        Raises:
            ValueError: If windows is empty or windows/labels lengths differ.
        """
        filepath = self.get_session_filepath(metadata.session_id)

        if not windows:
            raise ValueError("No windows to save!")
        if len(windows) != len(labels):
            raise ValueError(f"Windows ({len(windows)}) and labels ({len(labels)}) must have same length!")

        window_samples = len(windows[0].samples)

        # Prepare timing arrays (seconds, parallel to windows)
        start_times = np.array([w.start_time for w in windows], dtype=np.float64)
        end_times = np.array([w.end_time for w in windows], dtype=np.float64)

        # Label alignment using onset detection
        aligned_labels = labels
        original_labels = labels

        if enable_alignment and raw_samples:
            print("[Storage] Aligning labels to EMG onset...")

            raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
            raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)

            aligned_labels = align_labels_with_onset(
                labels=labels,
                window_start_times=start_times,
                raw_timestamps=raw_timestamps,
                raw_channels=raw_channels,
                sampling_rate=metadata.sampling_rate,
                threshold_factor=ONSET_THRESHOLD,
                search_ms=ONSET_SEARCH_MS
            )

            changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
            print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")

            # Relabel the first LABEL_FORWARD_SHIFT_MS of each gesture run as
            # 'rest' to remove the EMG onset transient from training data.
            if LABEL_FORWARD_SHIFT_MS > 0:
                shift_n = max(1, round(LABEL_FORWARD_SHIFT_MS / HOP_SIZE_MS))
                shifted = list(aligned_labels)
                for i in range(len(aligned_labels)):
                    if aligned_labels[i] == 'rest':
                        continue
                    # Count consecutive same-label windows immediately before
                    # this one, capped at shift_n (only the < shift_n test is
                    # used below, so capping preserves behavior while avoiding
                    # re-walking long runs for every window).
                    prior_same = 0
                    j = i - 1
                    while j >= 0 and prior_same < shift_n and aligned_labels[j] == aligned_labels[i]:
                        prior_same += 1
                        j -= 1
                    if prior_same < shift_n:
                        shifted[i] = 'rest'
                n_shifted = sum(1 for a, b in zip(aligned_labels, shifted) if a != b)
                aligned_labels = shifted
                print(f"[Storage] Forward shift ({LABEL_FORWARD_SHIFT_MS}ms, "
                      f"{shift_n} windows): {n_shifted} relabeled as rest")

        elif enable_alignment:
            print("[Storage] Warning: No raw samples, skipping alignment")

        with h5py.File(filepath, 'w') as f:
            # Metadata as attributes
            f.attrs['user_id'] = metadata.user_id
            f.attrs['session_id'] = metadata.session_id
            f.attrs['timestamp'] = metadata.timestamp
            f.attrs['sampling_rate'] = metadata.sampling_rate
            f.attrs['window_size_ms'] = metadata.window_size_ms
            f.attrs['num_channels'] = metadata.num_channels
            f.attrs['gestures'] = json.dumps(metadata.gestures)
            f.attrs['notes'] = metadata.notes
            f.attrs['num_windows'] = len(windows)
            f.attrs['window_samples'] = window_samples

            # Windows group
            windows_grp = f.create_group('windows')

            emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
            windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)

            # Store ALIGNED labels as primary (what training will use).
            # BUG FIX: the fixed-length string dtype is shared by both the
            # aligned and the original label datasets, so it must be sized
            # over BOTH lists -- alignment can replace the longest gesture
            # name with 'rest', and sizing only from aligned_labels would
            # truncate/corrupt the stored original labels.
            max_label_len = max(
                max(len(l) for l in aligned_labels),
                max(len(l) for l in original_labels),
            )
            dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
            windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)

            # Also store original labels for reference/debugging
            windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)

            window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
            windows_grp.create_dataset('window_ids', data=window_ids)

            # Store trial_ids for proper train/test splitting (no trial leakage)
            if trial_ids is not None:
                trial_ids_arr = np.array(trial_ids, dtype=np.int32)
                windows_grp.create_dataset('trial_ids', data=trial_ids_arr)
                f.attrs['has_trial_ids'] = True
            else:
                f.attrs['has_trial_ids'] = False

            windows_grp.create_dataset('start_times', data=start_times)
            windows_grp.create_dataset('end_times', data=end_times)

            # Store alignment metadata
            f.attrs['alignment_enabled'] = enable_alignment
            f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'

            if raw_samples:
                raw_grp = f.create_group('raw_samples')
                timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
                channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
                raw_grp.create_dataset('timestamps', data=timestamps, compression='gzip')
                raw_grp.create_dataset('channels', data=channels, compression='gzip')

        print(f"[Storage] Saved session to: {filepath}")
        print(f"[Storage] File size: {filepath.stat().st_size / 1024:.1f} KB")
        return filepath

    def load_session(self, session_id: str) -> tuple[list[EMGWindow], list[str], SessionMetadata]:
        """
        Load a collection session from HDF5.

        Returns:
            windows: List of EMGWindow objects (no label info)
            labels: List of gesture labels, parallel to windows
            metadata: Session metadata

        Raises:
            FileNotFoundError: If no file exists for session_id.
        """
        filepath = self.get_session_filepath(session_id)
        if not filepath.exists():
            raise FileNotFoundError(f"Session not found: {filepath}")

        windows = []
        with h5py.File(filepath, 'r') as f:
            metadata = SessionMetadata(
                user_id=f.attrs['user_id'],
                session_id=f.attrs['session_id'],
                timestamp=f.attrs['timestamp'],
                sampling_rate=int(f.attrs['sampling_rate']),
                window_size_ms=int(f.attrs['window_size_ms']),
                num_channels=int(f.attrs['num_channels']),
                gestures=json.loads(f.attrs['gestures']),
                notes=f.attrs.get('notes', '')
            )

            windows_grp = f['windows']
            emg_data = windows_grp['emg_data'][:]
            labels_raw = windows_grp['labels'][:]
            window_ids = windows_grp['window_ids'][:]
            start_times = windows_grp['start_times'][:]
            end_times = windows_grp['end_times'][:]

            labels_out = self._decode_labels(labels_raw)

            # Per-sample timestamps are reconstructed from the window start
            # time and the sampling period (they are not stored per sample).
            sample_period = 1.0 / metadata.sampling_rate
            for i in range(len(emg_data)):
                window_data = emg_data[i]
                samples = [
                    EMGSample(
                        timestamp=start_times[i] + j * sample_period,
                        channels=window_data[j].tolist()
                    )
                    for j in range(len(window_data))
                ]

                # Window contains NO label - labels stored separately
                window = EMGWindow(
                    window_id=int(window_ids[i]),
                    start_time=float(start_times[i]),
                    end_time=float(end_times[i]),
                    samples=samples
                )
                windows.append(window)

        print(f"[Storage] Loaded session: {session_id}")
        print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
        return windows, labels_out, metadata

    def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
        """
        Load a single session in ML-ready format: X, y, label_names.

        Args:
            session_id: The session to load
            filter_transitions: If True, remove windows in transition zones (default from config)

        Returns:
            X: EMG windows (n_windows, samples, channels)
            y: Integer labels, indices into label_names
            label_names: Sorted list of unique gesture labels (post-filtering)
        """
        filepath = self.get_session_filepath(session_id)

        with h5py.File(filepath, 'r') as f:
            X = f['windows/emg_data'][:]
            labels_raw = f['windows/labels'][:]
            start_times = f['windows/start_times'][:]
            end_times = f['windows/end_times'][:]

        labels = self._decode_labels(labels_raw)

        print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")

        # Apply transition filtering if enabled
        if filter_transitions:
            # filter_transition_windows needs an integer y; build a throwaway
            # mapping here -- the final mapping is rebuilt after filtering,
            # since filtering can remove a class entirely.
            label_names_pre = sorted(set(labels))
            label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
            y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)

            X, y_pre, labels, _ = filter_transition_windows(
                X, y_pre, labels, start_times, end_times
            )

        label_names = sorted(set(labels))
        label_to_idx = {name: idx for idx, name in enumerate(label_names)}
        y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)

        print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
        print(f"[Storage] Labels: {label_names}")
        return X, y, label_names

    def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, list[str], list[str]]:
        """
        Load ALL sessions combined into a single training dataset.

        Args:
            filter_transitions: If True, remove windows in transition zones (default from config)

        Returns:
            X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
            y: Combined labels as integers (n_total_windows,)
            trial_ids: Combined trial IDs for proper train/test splitting (n_total_windows,)
            session_indices: Per-window session index (0..n_sessions-1) for session normalization
            label_names: Sorted list of unique gesture labels across all sessions
            session_ids: List of session IDs that were loaded

        Raises:
            ValueError: If no sessions found or no sessions have compatible shapes
        """
        sessions = self.list_sessions()
        if not sessions:
            raise ValueError("No sessions found to load!")

        print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
        if filter_transitions:
            print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")

        all_X = []
        all_labels = []
        all_trial_ids = []           # Track trial_ids for proper train/test splitting
        all_session_indices = []     # Per-window session index for session normalization
        loaded_sessions = []
        reference_shape = None
        total_removed = 0
        total_original = 0
        trial_id_offset = 0          # Offset trial_ids across sessions for global uniqueness

        for session_id in sessions:
            filepath = self.get_session_filepath(session_id)

            with h5py.File(filepath, 'r') as f:
                X = f['windows/emg_data'][:]
                labels_raw = f['windows/labels'][:]
                start_times = f['windows/start_times'][:]
                end_times = f['windows/end_times'][:]

                # Load trial_ids if available (new files), otherwise generate
                # from indices. Legacy fallback is conservative: each window
                # is treated as its own trial.
                if 'windows/trial_ids' in f:
                    trial_ids = f['windows/trial_ids'][:] + trial_id_offset
                else:
                    print(f"[Storage] WARNING: {session_id} missing trial_ids, generating from indices")
                    trial_ids = np.arange(len(X), dtype=np.int32) + trial_id_offset

            # Validate shape compatibility (first loadable session sets the reference)
            if reference_shape is None:
                reference_shape = X.shape[1:]  # (samples_per_window, channels)
            elif X.shape[1:] != reference_shape:
                print(f"[Storage] WARNING: Skipping {session_id} - incompatible shape {X.shape[1:]} vs {reference_shape}")
                continue

            labels = self._decode_labels(labels_raw)

            original_count = len(X)
            total_original += original_count

            # Apply transition filtering per session (each has its own gesture boundaries)
            if filter_transitions:
                # Need temporary y for filtering function
                temp_label_names = sorted(set(labels))
                temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
                temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)

                X, temp_y, labels, trial_ids = filter_transition_windows(
                    X, temp_y, labels, start_times, end_times, trial_ids=trial_ids
                )
                total_removed += original_count - len(X)

            current_session_idx = len(all_X)  # 0-based index before appending
            all_X.append(X)
            all_labels.extend(labels)
            all_trial_ids.extend(trial_ids.tolist())
            all_session_indices.extend([current_session_idx] * len(X))
            loaded_sessions.append(session_id)

            # Update trial_id offset for next session (ensure global uniqueness)
            if len(trial_ids) > 0:
                trial_id_offset = max(trial_ids) + 1

            print(f"[Storage] - {session_id}: {len(X)} windows" +
                  (f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))

        if not all_X:
            raise ValueError("No compatible sessions found!")

        # Combine all data
        X_combined = np.concatenate(all_X, axis=0)
        trial_ids_combined = np.array(all_trial_ids, dtype=np.int32)
        session_indices_combined = np.array(all_session_indices, dtype=np.int32)

        # Create unified label mapping across all sessions
        label_names = sorted(set(all_labels))
        label_to_idx = {name: idx for idx, name in enumerate(label_names)}
        y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)

        n_unique_trials = len(np.unique(trial_ids_combined))
        print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
        print(f"[Storage] Unique trials: {n_unique_trials} (for proper train/test splitting)")
        if filter_transitions and total_removed > 0:
            print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
        print(f"[Storage] Labels: {label_names}")
        print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")

        return X_combined, y_combined, trial_ids_combined, session_indices_combined, label_names, loaded_sessions

    def list_sessions(self) -> list[str]:
        """List all available session IDs (HDF5 file stems, sorted)."""
        return sorted([f.stem for f in self.data_dir.glob("*.hdf5")])

    def get_session_info(self, session_id: str) -> dict:
        """Get quick info about a session without loading all data."""
        filepath = self.get_session_filepath(session_id)
        with h5py.File(filepath, 'r') as f:
            return {
                'user_id': f.attrs['user_id'],
                'timestamp': f.attrs['timestamp'],
                'num_windows': f.attrs['num_windows'],
                'gestures': json.loads(f.attrs['gestures']),
                'sampling_rate': f.attrs['sampling_rate'],
                'window_size_ms': f.attrs['window_size_ms'],
            }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# COLLECTION SESSION (Requires ESP32 hardware)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_labeled_collection_demo():
    """
    Run a labeled EMG collection session:
    1. Connect to ESP32 via serial
    2. Prompt scheduler guides the user through gestures
    3. EMG stream collects real signals
    4. Windower groups samples into fixed-size windows
    5. Labels are assigned based on which prompt was active
    6. Session is saved to HDF5 with user ID

    REQUIRES: ESP32 hardware connected via USB.

    Returns:
        (windows, labels, trial_ids) -- three parallel lists. All three are
        empty if the ESP32 connection fails.
    """
    print("\n" + "=" * 60)
    print("LABELED EMG COLLECTION (ESP32 Required)")
    print("=" * 60)

    # Get user ID
    user_id = input("\nEnter your user ID (e.g., user_001): ").strip()
    if not user_id:
        user_id = USER_ID  # Fall back to default
        print(f" Using default: {user_id}")
    else:
        print(f" User ID: {user_id}")

    # Define gestures to collect (names match ESP32 gesture definitions)
    gestures = ["open", "fist", "hook_em", "thumbs_up"]

    # Create the prompt scheduler
    scheduler = PromptScheduler(
        gestures=gestures,
        hold_sec=GESTURE_HOLD_SEC,
        rest_sec=REST_BETWEEN_SEC,
        reps=REPS_PER_GESTURE
    )
    scheduler.print_schedule()

    # Connect to ESP32
    print("\n[Connecting to ESP32...]")
    stream = RealSerialStream()
    try:
        stream.connect(timeout=5.0)
        print(f" Connected: {stream.device_info}")
    except Exception as e:
        print(f" ERROR: Failed to connect to ESP32: {e}")
        print(" Make sure the ESP32 is connected and firmware is flashed.")
        # BUG FIX: was `return [], []` -- the success path returns a 3-tuple
        # (windows, labels, trial_ids), so callers unpacking three values
        # would crash exactly when the hardware is absent.
        return [], [], []

    parser = EMGParser(num_channels=NUM_CHANNELS)
    windower = Windower(
        window_size_ms=WINDOW_SIZE_MS,
        sample_rate=SAMPLING_RATE_HZ,
        hop_size_ms=HOP_SIZE_MS
    )

    # Storage for windows, labels, and trial_ids (kept separate to enforce training/inference separation)
    collected_windows: list[EMGWindow] = []
    collected_labels: list[str] = []
    collected_trial_ids: list[int] = []  # Track trial_id for proper train/test splitting
    last_prompt_name = None

    # Start collection
    input("\nPress ENTER to start collection session...")
    stream.start()
    scheduler.start_session()

    print("\n" + "-" * 40)
    print("COLLECTING... Watch the prompts!")
    print("-" * 40)

    try:
        while not scheduler.is_session_complete():
            # Get current prompt
            prompt = scheduler.get_current_prompt()

            # Display prompt changes
            if prompt and prompt.gesture_name != last_prompt_name:
                elapsed = scheduler.get_elapsed_time()
                if prompt.gesture_name == "rest":
                    print(f"\n [{elapsed:5.1f}s] >>> REST <<<")
                else:
                    print(f"\n [{elapsed:5.1f}s] >>> {prompt.gesture_name.upper()} <<<")
                last_prompt_name = prompt.gesture_name

            # Read and parse data
            line = stream.readline()
            if line:
                sample = parser.parse_line(line)
                if sample:
                    # Try to form a window
                    window = windower.add_sample(sample)
                    if window:
                        # Store window, label, and trial_id separately (training/inference separation)
                        # Shift label lookup forward to align with actual muscle activation
                        label_time = window.start_time + LABEL_SHIFT_MS / 1000.0
                        label = scheduler.get_label_for_time(label_time)
                        trial_id = scheduler.get_trial_id_for_time(label_time)
                        collected_windows.append(window)
                        collected_labels.append(label)
                        collected_trial_ids.append(trial_id)

    except KeyboardInterrupt:
        print("\n[Interrupted by user]")

    finally:
        # Always release the serial port, even on interrupt
        stream.stop()
        stream.disconnect()

    # Report results
    print("\n" + "=" * 60)
    print("COLLECTION COMPLETE")
    print("=" * 60)
    print(f"Total windows collected: {len(collected_windows)}")
    print(f"Parse errors: {parser.parse_errors}")

    # Count labels (from separate labels list, not from windows)
    label_counts = {}
    for label in collected_labels:
        label_counts[label] = label_counts.get(label, 0) + 1

    print(f"\nWindows per label:")
    for label, count in sorted(label_counts.items()):
        print(f" {label}: {count}")

    # Show example windows
    if collected_windows:
        print(f"\nExample windows:")
        for i, w in enumerate(collected_windows[:3]):
            data = w.to_numpy()
            print(f" Window {w.window_id}: label='{collected_labels[i]}', "
                  f"samples={len(w.samples)}, "
                  f"ch0_mean={data[:, 0].mean():.1f}")

        # Show one window from each gesture type
        print(f"\nSignal comparison (channel 0 std dev by gesture):")
        for label in sorted(label_counts.keys()):
            # Get indices where label matches
            indices = [i for i, l in enumerate(collected_labels) if l == label]
            if indices:
                all_ch0 = np.concatenate([collected_windows[i].get_channel(0) for i in indices])
                print(f" {label}: std={all_ch0.std():.1f}")

    # --- Save the session ---
    if collected_windows:
        save_choice = input("\nSave this session? (y/n): ").strip().lower()
        if save_choice == 'y':
            storage = SessionStorage()
            session_id = storage.generate_session_id(user_id)

            metadata = SessionMetadata(
                user_id=user_id,
                session_id=session_id,
                timestamp=datetime.now().isoformat(),
                sampling_rate=SAMPLING_RATE_HZ,
                window_size_ms=WINDOW_SIZE_MS,
                num_channels=NUM_CHANNELS,
                gestures=gestures,
                notes=""
            )

            # Pass windows, labels, and trial_ids separately (enforces separation)
            storage.save_session(
                collected_windows, collected_labels, metadata,
                trial_ids=collected_trial_ids
            )
            print(f"\nSession saved! ID: {session_id}")

    return collected_windows, collected_labels, collected_trial_ids
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# INSPECT SESSIONS (Load and view saved sessions)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_storage_demo():
|
|
|
|
|
|
"""Demonstrates loading and inspecting saved sessions."""
|
|
|
|
|
|
print("\n" + "=" * 60)
|
|
|
|
|
|
print("INSPECT SAVED SESSIONS")
|
|
|
|
|
|
print("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
storage = SessionStorage()
|
|
|
|
|
|
sessions = storage.list_sessions()
|
|
|
|
|
|
|
|
|
|
|
|
if not sessions:
|
|
|
|
|
|
print("\nNo saved sessions found!")
|
|
|
|
|
|
print(f"Run option 2 first to collect and save a session.")
|
|
|
|
|
|
print(f"Sessions are stored in: {storage.data_dir.absolute()}")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\nFound {len(sessions)} saved session(s):")
|
|
|
|
|
|
print("-" * 40)
|
|
|
|
|
|
|
|
|
|
|
|
for i, session_id in enumerate(sessions):
|
|
|
|
|
|
info = storage.get_session_info(session_id)
|
|
|
|
|
|
print(f"\n [{i+1}] {session_id}")
|
|
|
|
|
|
print(f" User: {info['user_id']}")
|
|
|
|
|
|
print(f" Time: {info['timestamp']}")
|
|
|
|
|
|
print(f" Windows: {info['num_windows']}")
|
|
|
|
|
|
print(f" Gestures: {info['gestures']}")
|
|
|
|
|
|
|
|
|
|
|
|
print("\n" + "-" * 40)
|
|
|
|
|
|
choice = input("Enter session number to load (or 'q' to quit): ").strip()
|
|
|
|
|
|
|
|
|
|
|
|
if choice.lower() == 'q':
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
idx = int(choice) - 1
|
|
|
|
|
|
if idx < 0 or idx >= len(sessions):
|
|
|
|
|
|
print("Invalid selection!")
|
|
|
|
|
|
return None
|
|
|
|
|
|
session_id = sessions[idx]
|
|
|
|
|
|
except ValueError:
|
|
|
|
|
|
print("Invalid input!")
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\n{'=' * 60}")
|
|
|
|
|
|
print(f"LOADING SESSION: {session_id}")
|
|
|
|
|
|
print("=" * 60)
|
|
|
|
|
|
|
|
|
|
|
|
# Labels returned separately from windows (enforces training/inference separation)
|
|
|
|
|
|
windows, labels, metadata = storage.load_session(session_id)
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\nMetadata:")
|
|
|
|
|
|
print(f" User: {metadata.user_id}")
|
|
|
|
|
|
print(f" Timestamp: {metadata.timestamp}")
|
|
|
|
|
|
print(f" Sampling rate: {metadata.sampling_rate} Hz")
|
|
|
|
|
|
print(f" Window size: {metadata.window_size_ms} ms")
|
|
|
|
|
|
print(f" Channels: {metadata.num_channels}")
|
|
|
|
|
|
print(f" Gestures: {metadata.gestures}")
|
|
|
|
|
|
|
|
|
|
|
|
# Count from separate labels list
|
|
|
|
|
|
label_counts = {}
|
|
|
|
|
|
for label in labels:
|
|
|
|
|
|
label_counts[label] = label_counts.get(label, 0) + 1
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\nLabel distribution:")
|
|
|
|
|
|
for label, count in sorted(label_counts.items()):
|
|
|
|
|
|
print(f" {label}: {count} windows")
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\n{'-' * 40}")
|
|
|
|
|
|
print("LOADING FOR MACHINE LEARNING")
|
|
|
|
|
|
print("-" * 40)
|
|
|
|
|
|
|
|
|
|
|
|
X, y, label_names = storage.load_for_training(session_id)
|
|
|
|
|
|
|
|
|
|
|
|
print(f"\nData shapes:")
|
|
|
|
|
|
print(f" X (features): {X.shape}")
|
|
|
|
|
|
print(f" - {X.shape[0]} windows")
|
|
|
|
|
|
print(f" - {X.shape[1]} samples per window")
|
|
|
|
|
|
print(f" - {X.shape[2]} channels")
|
|
|
|
|
|
print(f" y (labels): {y.shape}")
|
|
|
|
|
|
print(f" Label mapping: {dict(enumerate(label_names))}")
|
|
|
|
|
|
|
|
|
|
|
|
# --- Feature Extraction & Visualization ---
|
|
|
|
|
|
print(f"\n{'-' * 40}")
|
|
|
|
|
|
print("EXTRACTING FEATURES FOR VISUALIZATION")
|
|
|
|
|
|
print("-" * 40)
|
|
|
|
|
|
|
|
|
|
|
|
# Note: Per-window centering is done inside EMGFeatureExtractor.
|
|
|
|
|
|
# This is the correct approach for real-time inference (causal, no future data).
|
|
|
|
|
|
# Global centering across all windows would leak information and not work in real-time.
|
|
|
|
|
|
|
|
|
|
|
|
extractor = EMGFeatureExtractor()
|
|
|
|
|
|
n_windows = X.shape[0]
|
|
|
|
|
|
n_channels = X.shape[2]
|
|
|
|
|
|
|
|
|
|
|
|
# Extract features per channel: shape (n_windows, n_channels, 4)
|
|
|
|
|
|
# Features order: [rms, wl, zc, ssc]
|
|
|
|
|
|
features_by_channel = np.zeros((n_windows, n_channels, 4))
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(n_windows):
|
|
|
|
|
|
for ch in range(n_channels):
|
|
|
|
|
|
ch_features = extractor.extract_features_single_channel(X[i, :, ch])
|
|
|
|
|
|
features_by_channel[i, ch, 0] = ch_features['rms']
|
|
|
|
|
|
features_by_channel[i, ch, 1] = ch_features['wl']
|
|
|
|
|
|
features_by_channel[i, ch, 2] = ch_features['zc']
|
|
|
|
|
|
features_by_channel[i, ch, 3] = ch_features['ssc']
|
|
|
|
|
|
|
|
|
|
|
|
print(f" Extracted features for {n_windows} windows, {n_channels} channels")
|
|
|
|
|
|
print(f" (Per-window centering applied inside feature extractor)")
|
|
|
|
|
|
|
|
|
|
|
|
# Create time axis (window indices as proxy for time)
|
|
|
|
|
|
time_axis = np.arange(n_windows)
|
|
|
|
|
|
|
|
|
|
|
|
# Find gesture transition points (where label changes)
|
|
|
|
|
|
transitions = []
|
|
|
|
|
|
current_label = y[0]
|
|
|
|
|
|
for i in range(1, len(y)):
|
|
|
|
|
|
if y[i] != current_label:
|
|
|
|
|
|
transitions.append((i, label_names[y[i]]))
|
|
|
|
|
|
current_label = y[i]
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# Define colors for gesture markers (matches GUI color scheme)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
def get_gesture_color(name):
|
2026-03-10 11:39:02 -05:00
|
|
|
|
name_lower = name.lower()
|
|
|
|
|
|
if 'rest' in name_lower:
|
2026-01-17 23:31:15 -06:00
|
|
|
|
return 'gray'
|
2026-03-10 11:39:02 -05:00
|
|
|
|
elif 'open' in name_lower:
|
|
|
|
|
|
return 'cyan'
|
|
|
|
|
|
elif 'fist' in name_lower:
|
|
|
|
|
|
return 'blue'
|
|
|
|
|
|
elif 'hook' in name_lower:
|
2026-01-17 23:31:15 -06:00
|
|
|
|
return 'orange'
|
2026-03-10 11:39:02 -05:00
|
|
|
|
elif 'thumb' in name_lower:
|
|
|
|
|
|
return 'green'
|
2026-01-17 23:31:15 -06:00
|
|
|
|
return 'red'
|
|
|
|
|
|
|
|
|
|
|
|
feature_titles = ['RMS', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
|
|
|
|
|
|
feature_colors = ['red', 'blue', 'green', 'purple']
|
|
|
|
|
|
feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count']
|
|
|
|
|
|
|
|
|
|
|
|
# --- Figure 1: Raw EMG Signal ---
|
|
|
|
|
|
fig_raw, axes_raw = plt.subplots(2, 2, figsize=(14, 8), sharex=True)
|
|
|
|
|
|
axes_raw = axes_raw.flatten()
|
|
|
|
|
|
|
|
|
|
|
|
# Concatenate all windows to show continuous signal
|
|
|
|
|
|
samples_per_window = X.shape[1]
|
|
|
|
|
|
total_samples = n_windows * samples_per_window
|
|
|
|
|
|
raw_time = np.arange(total_samples)
|
|
|
|
|
|
|
|
|
|
|
|
for ch in range(n_channels):
|
|
|
|
|
|
ax = axes_raw[ch]
|
|
|
|
|
|
|
|
|
|
|
|
# Flatten all windows into continuous signal for this channel
|
|
|
|
|
|
# Center per-channel for visualization only (subtract channel mean)
|
|
|
|
|
|
raw_signal = X[:, :, ch].flatten()
|
|
|
|
|
|
raw_signal_centered = raw_signal - raw_signal.mean()
|
|
|
|
|
|
ax.plot(raw_time, raw_signal_centered, linewidth=0.5, color='black')
|
|
|
|
|
|
ax.axhline(0, color='gray', linestyle='-', linewidth=0.5, alpha=0.5)
|
|
|
|
|
|
ax.set_title(f"Channel {ch}", fontsize=11)
|
|
|
|
|
|
ax.set_ylabel("Amplitude (centered)")
|
|
|
|
|
|
ax.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
# Add vertical lines at gesture transitions (scaled to sample index)
|
|
|
|
|
|
for trans_idx, trans_label in transitions:
|
|
|
|
|
|
sample_idx = trans_idx * samples_per_window
|
|
|
|
|
|
color = get_gesture_color(trans_label)
|
|
|
|
|
|
ax.axvline(sample_idx, color=color, linestyle='--', alpha=0.6, linewidth=1)
|
|
|
|
|
|
|
|
|
|
|
|
# Add legend for gesture colors
|
|
|
|
|
|
legend_elements = []
|
|
|
|
|
|
for label_name in label_names:
|
|
|
|
|
|
color = get_gesture_color(label_name)
|
|
|
|
|
|
legend_elements.append(plt.Line2D([0], [0], color=color, linestyle='--', label=label_name))
|
|
|
|
|
|
axes_raw[0].legend(handles=legend_elements, loc='upper right', fontsize=8)
|
|
|
|
|
|
|
|
|
|
|
|
fig_raw.suptitle("Raw EMG Signal (Centered for Display) - All Channels", fontsize=14, fontweight='bold')
|
|
|
|
|
|
fig_raw.supxlabel("Sample Index")
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
# --- Figures 2-5: Feature plots (one per feature type) ---
|
|
|
|
|
|
for feat_idx, feat_title in enumerate(feature_titles):
|
|
|
|
|
|
fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex=True)
|
|
|
|
|
|
axes = axes.flatten()
|
|
|
|
|
|
|
|
|
|
|
|
for ch in range(n_channels):
|
|
|
|
|
|
ax = axes[ch]
|
|
|
|
|
|
|
|
|
|
|
|
# Plot feature as line graph
|
|
|
|
|
|
feat_data = features_by_channel[:, ch, feat_idx]
|
|
|
|
|
|
ax.plot(time_axis, feat_data, linewidth=1, color=feature_colors[feat_idx])
|
|
|
|
|
|
ax.set_title(f"Channel {ch}", fontsize=11)
|
|
|
|
|
|
ax.grid(True, alpha=0.3)
|
|
|
|
|
|
|
|
|
|
|
|
# Add vertical lines at gesture transitions
|
|
|
|
|
|
for trans_idx, trans_label in transitions:
|
|
|
|
|
|
color = get_gesture_color(trans_label)
|
|
|
|
|
|
ax.axvline(trans_idx, color=color, linestyle='--', alpha=0.6, linewidth=1)
|
|
|
|
|
|
|
|
|
|
|
|
# Add legend for gesture colors
|
|
|
|
|
|
legend_elements = []
|
|
|
|
|
|
for label_name in label_names:
|
|
|
|
|
|
color = get_gesture_color(label_name)
|
|
|
|
|
|
legend_elements.append(plt.Line2D([0], [0], color=color, linestyle='--', label=label_name))
|
|
|
|
|
|
axes[0].legend(handles=legend_elements, loc='upper right', fontsize=8)
|
|
|
|
|
|
|
|
|
|
|
|
fig.suptitle(f"{feat_title} - All Channels", fontsize=14, fontweight='bold')
|
|
|
|
|
|
fig.supxlabel("Window Index (Time)")
|
|
|
|
|
|
fig.supylabel(feature_ylabels[feat_idx])
|
|
|
|
|
|
plt.tight_layout()
|
|
|
|
|
|
|
|
|
|
|
|
plt.show()
|
|
|
|
|
|
print(f"\n Displayed 5 figures: Raw EMG + 4 features (close windows to continue)")
|
|
|
|
|
|
|
|
|
|
|
|
return X, y, label_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# FEATURE EXTRACTION (Time-domain features for EMG)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class EMGFeatureExtractor:
    """
    Extracts time-domain and frequency-domain features from EMG windows.

    Change 1 — expanded feature set (expanded=True, default):
        Per channel (20 features):
            TD-4 (legacy):  RMS, WL, ZC, SSC
            TD extended:    MAV, VAR, IEMG, WAMP
            AR model:       AR1, AR2, AR3, AR4 (4th-order via Yule-Walker)
            Frequency:      MNF, MDF, PKF, MNP (mean/median/peak freq, mean power)
            Band power:     BP0(20-80Hz), BP1(80-150Hz), BP2(150-250Hz), BP3(250-450Hz)
        Cross-channel (cross_channel=True, default):
            For each channel pair (i,j): Pearson correlation, log-RMS ratio, covariance
            For 3 hand channels: 3 pairs × 3 = 9 cross-channel features
        Total for HAND_CHANNELS=[0,1,2]: 20×3 + 9 = 69 features

    Legacy mode (expanded=False): 4 features per channel only (RMS, WL, ZC, SSC).
    Old pickled models automatically use legacy mode via __setstate__.

    IMPORTANT: Per-window DC removal (mean subtraction) is applied before all
    features. This is causal (uses only data within the current window).
    """

    # Feature key ordering — determines output vector layout
    _LEGACY_KEYS = ['rms', 'wl', 'zc', 'ssc']
    _EXPANDED_KEYS = [
        'rms', 'wl', 'zc', 'ssc',        # TD-4
        'mav', 'var', 'iemg', 'wamp',    # TD extended
        'ar1', 'ar2', 'ar3', 'ar4',      # AR(4) model
        'mnf', 'mdf', 'pkf', 'mnp',      # Frequency descriptors
        'bp0', 'bp1', 'bp2', 'bp3',      # Band powers
    ]
    # Keys that are amplitude-dependent and should be divided by norm_factor
    _NORM_KEYS = {'rms', 'wl', 'mav', 'iemg'}

    def __init__(self,
                 zc_threshold_percent: float = 0.1,
                 ssc_threshold_percent: float = 0.1,
                 channels: Optional[list[int]] = None,
                 normalize: bool = True,
                 expanded: bool = True,
                 cross_channel: bool = True,
                 fft_n: int = 256,
                 fs: float = float(SAMPLING_RATE_HZ),
                 reinhard: bool = False,
                 bandpass: bool = True):
        """
        Args:
            zc_threshold_percent: ZC/WAMP threshold as fraction of RMS.
            ssc_threshold_percent: SSC threshold as fraction of RMS squared.
            channels: Channel indices to extract features from; None = all.
            normalize: Divide amplitude-dependent features by total RMS across
                channels (makes model robust to impedance shifts).
            expanded: Use full 20-feature/channel set (Change 1). False = legacy
                4-feature/channel set for backward compatibility.
            cross_channel: Append pairwise cross-channel features (correlation,
                log-RMS ratio, covariance). Only when expanded=True.
            fft_n: FFT size for frequency features (zero-pads window if needed).
            fs: Sampling frequency in Hz (used for frequency axis).
            reinhard: Change 4 — apply Reinhard tone-mapping (64·x/(32+|x|))
                before feature extraction. Must match MODEL_USE_REINHARD in
                firmware model_weights.h. Default False.
            bandpass: Apply 20-450 Hz bandpass filter before feature extraction.
                Must be True to match firmware IIR bandpass. Default True.
        """
        self.zc_threshold_percent = zc_threshold_percent
        self.ssc_threshold_percent = ssc_threshold_percent
        self.channels = channels
        self.normalize = normalize
        self.expanded = expanded
        self.cross_channel = cross_channel
        self.fft_n = fft_n
        self.fs = fs
        self.reinhard = reinhard
        self.bandpass = bandpass

        # Pre-compute bandpass SOS coefficients (2nd-order Butterworth, 20-450 Hz)
        # Matches firmware IIR biquad bandpass in inference.c
        if self.bandpass:
            nyq = self.fs / 2.0
            self._bp_sos = butter(2, [20.0 / nyq, 450.0 / nyq], btype='band', output='sos')
        else:
            self._bp_sos = None

    def __setstate__(self, state: dict):
        """Restore pickle and add defaults for attributes added in Change 1+."""
        self.__dict__.update(state)
        # Attributes missing from old pickles get conservative defaults
        # (legacy feature set, no bandpass) so old models keep producing the
        # feature layout they were trained with.
        if 'expanded' not in state: self.expanded = False
        if 'cross_channel' not in state: self.cross_channel = False
        if 'fft_n' not in state: self.fft_n = 256
        if 'fs' not in state: self.fs = float(SAMPLING_RATE_HZ)
        if 'reinhard' not in state: self.reinhard = False
        if 'bandpass' not in state: self.bandpass = False
        # Reconstruct SOS coefficients for bandpass filter
        if self.bandpass:
            nyq = self.fs / 2.0
            self._bp_sos = butter(2, [20.0 / nyq, 450.0 / nyq], btype='band', output='sos')
        else:
            self._bp_sos = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _ar_coefficients(signal: np.ndarray, order: int = 4) -> np.ndarray:
        """4th-order AR coefficients via Yule-Walker (autocorrelation method)."""
        n = len(signal)
        # Biased autocorrelation estimates r[0..order]
        r = np.array([float(np.dot(signal[:n - k], signal[k:])) / n
                      for k in range(order + 1)])
        # Toeplitz system T·a = r[1:order+1]; order is tiny so solve directly
        T = np.array([[r[abs(i - j)] for j in range(order)] for i in range(order)])
        try:
            return np.linalg.solve(T, r[1:order + 1])
        except np.linalg.LinAlgError:
            # Degenerate input (e.g. an all-zero window) makes T singular —
            # return zeros rather than propagate the error.
            return np.zeros(order)

    def _spectral_features(self, signal: np.ndarray) -> tuple:
        """MNF, MDF, PKF, MNP, BP0-BP3 via rfft (zero-padded to fft_n)."""
        # Power spectrum; +1e-10 in `total` guards division by zero on silence
        spec = np.abs(np.fft.rfft(signal, n=self.fft_n)) ** 2
        freqs = np.fft.rfftfreq(self.fft_n, d=1.0 / self.fs)
        total = float(np.sum(spec)) + 1e-10

        # Mean frequency: power-weighted average of the frequency axis
        mnf = float(np.dot(freqs, spec) / total)

        # Median frequency: bin where cumulative power crosses half the total
        cumsum = np.cumsum(spec)
        mid_idx = int(np.searchsorted(cumsum, total / 2.0))
        mdf = float(freqs[min(mid_idx, len(freqs) - 1)])

        pkf = float(freqs[int(np.argmax(spec))])  # peak frequency
        mnp = float(total / len(spec))            # mean power per bin

        def _bp(f_lo: float, f_hi: float) -> float:
            # Fraction of total power in [f_lo, f_hi)
            mask = (freqs >= f_lo) & (freqs < f_hi)
            return float(np.sum(spec[mask]) / total)

        return mnf, mdf, pkf, mnp, _bp(20, 80), _bp(80, 150), _bp(150, 250), _bp(250, 450)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def extract_features_single_channel(self, signal: np.ndarray) -> dict:
        """
        Extract features from a single, already-selected channel.

        Returns a dict with 4 keys (legacy) or 20 keys (expanded).
        Bandpass filter (if enabled) + per-window DC removal are applied first.
        """
        # Bandpass filter to match firmware IIR (20-450 Hz, 2nd-order Butterworth).
        # Uses sosfilt (causal) with sosfilt_zi to initialise the filter state
        # at the signal's DC level, avoiding large startup transients on the
        # short 150-sample windows.
        if self.bandpass and self._bp_sos is not None:
            zi = sosfilt_zi(self._bp_sos) * signal[0]
            signal, _ = sosfilt(self._bp_sos, signal, zi=zi)

        signal = signal - np.mean(signal)  # DC removal (causal, window-local)

        # Change 4 — Reinhard tone-mapping (compresses large spikes)
        if self.reinhard:
            signal = 64.0 * signal / (32.0 + np.abs(signal))

        rms = float(np.sqrt(np.mean(signal ** 2)))
        wl = float(np.sum(np.abs(np.diff(signal))))

        # Amplitude-relative thresholds: scale with the window's own RMS so
        # ZC/SSC counts are not dominated by low-level noise.
        zc_thresh = self.zc_threshold_percent * rms
        ssc_thresh = (self.ssc_threshold_percent * rms) ** 2

        # Zero crossings: sign change AND step larger than the threshold
        diffs = np.diff(signal)
        sign_chg = signal[:-1] * signal[1:] < 0
        zc = int(np.sum(sign_chg & (np.abs(diffs) > zc_thresh)))

        # Slope sign changes: left/right deltas around each interior sample
        # have the same sign and their product exceeds the threshold
        dl = signal[1:-1] - signal[:-2]
        dr = signal[1:-1] - signal[2:]
        ssc = int(np.sum((dl * dr) > ssc_thresh))

        feats: dict = {'rms': rms, 'wl': wl, 'zc': float(zc), 'ssc': float(ssc)}

        if self.expanded:
            mav = float(np.mean(np.abs(signal)))
            var = float(np.var(signal))
            iemg = float(np.sum(np.abs(signal)))
            # WAMP (Willison amplitude): count of steps exceeding the threshold
            wamp = int(np.sum(np.abs(diffs) > zc_thresh))

            ar = self._ar_coefficients(signal, order=4)
            mnf, mdf, pkf, mnp, bp0, bp1, bp2, bp3 = self._spectral_features(signal)

            feats.update({
                'mav': mav, 'var': var, 'iemg': iemg, 'wamp': float(wamp),
                'ar1': float(ar[0]), 'ar2': float(ar[1]),
                'ar3': float(ar[2]), 'ar4': float(ar[3]),
                'mnf': mnf, 'mdf': mdf, 'pkf': pkf, 'mnp': mnp,
                'bp0': bp0, 'bp1': bp1, 'bp2': bp2, 'bp3': bp3,
            })

        return feats

    def extract_features_window(self, window: np.ndarray) -> np.ndarray:
        """
        Extract features from a window of shape (samples, channels).

        Returns a flat float32 array ordered as:
            [ch_i feats..., ch_j feats..., ..., cross-channel feats...]
        """
        channel_indices = self.channels if self.channels is not None \
            else list(range(window.shape[1]))

        all_ch_feats = [self.extract_features_single_channel(window[:, ch])
                        for ch in channel_indices]

        # Total RMS across channels is the amplitude-normalization reference;
        # floored at 1e-6 to avoid division by ~0 on silent windows.
        norm_factor = 1.0
        if self.normalize:
            total_rms = float(np.sqrt(sum(f['rms'] ** 2 for f in all_ch_feats)))
            norm_factor = max(total_rms, 1e-6)

        feat_keys = self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS

        features: list[float] = []
        for ch_feats in all_ch_feats:
            for key in feat_keys:
                val = ch_feats[key]
                # Only amplitude-dependent keys are normalized (see _NORM_KEYS)
                if self.normalize and key in self._NORM_KEYS:
                    val = val / norm_factor
                features.append(val)

        # Cross-channel features (expanded mode, ≥2 channels)
        # Bug 6 fix: firmware computes cross-channel features from
        # Reinhard-mapped signals when MODEL_USE_REINHARD=1. Apply the
        # same tone-mapping here so correlation/covariance match.
        if self.expanded and self.cross_channel and len(channel_indices) >= 2:
            centered = []
            for ch in channel_indices:
                sig = window[:, ch] - np.mean(window[:, ch])
                # Apply bandpass if enabled (matches firmware pipeline order)
                if self.bandpass and self._bp_sos is not None:
                    zi = sosfilt_zi(self._bp_sos) * sig[0]
                    sig, _ = sosfilt(self._bp_sos, sig, zi=zi)
                if self.reinhard:
                    sig = 64.0 * sig / (32.0 + np.abs(sig))
                centered.append(sig)
            # +1e-10 guards division by zero in correlation / log-RMS ratio
            rms_vals = [f['rms'] + 1e-10 for f in all_ch_feats]
            n = window.shape[0]

            for i in range(len(channel_indices)):
                for j in range(i + 1, len(channel_indices)):
                    si, sj = centered[i], centered[j]
                    ri, rj = rms_vals[i], rms_vals[j]

                    # Clip to [-1, 1]: numeric noise can push the estimate out
                    corr = float(np.clip(np.dot(si, sj) / (n * ri * rj), -1.0, 1.0))
                    lrms = float(np.log(ri / rj))
                    cov = float(np.dot(si, sj) / n)
                    # Covariance scales as amplitude², hence norm_factor²
                    if self.normalize:
                        cov /= (norm_factor ** 2)

                    features.extend([corr, lrms, cov])

        return np.array(features, dtype=np.float32)

    def extract_features_batch(self, X: np.ndarray) -> np.ndarray:
        """
        Extract features from a batch of windows.

        Args:
            X: (n_windows, n_samples, n_channels)
        Returns:
            (n_windows, n_features) float32 array.
        """
        # Vectorised bandpass: apply sosfiltfilt on all windows at once along
        # the samples axis. This is ~100x faster than per-window sosfilt calls
        # (scipy's C loop vs Python loop). We disable per-window bandpass in
        # extract_features_single_channel during batch extraction.
        # NOTE(review): sosfiltfilt is zero-phase (forward-backward, non-causal)
        # while live extraction uses causal sosfilt — training and inference
        # features may therefore differ slightly; confirm this is acceptable.
        if self.bandpass and self._bp_sos is not None:
            X = sosfiltfilt(self._bp_sos, X, axis=1).astype(np.float32)

        n_windows = X.shape[0]
        n_ch_total = X.shape[2]
        n_features = self._n_features(n_ch_total)
        features = np.zeros((n_windows, n_features), dtype=np.float32)

        # Temporarily disable per-window bandpass (already applied above).
        # NOTE(review): this mutates self, so concurrent use of the same
        # extractor instance from another thread would race — verify callers
        # are single-threaded here.
        saved_bp = self.bandpass
        self.bandpass = False
        try:
            for i in range(n_windows):
                features[i] = self.extract_features_window(X[i])
        finally:
            self.bandpass = saved_bp

        return features

    def _n_features(self, n_total_channels: int) -> int:
        """Total feature vector length for the current configuration."""
        n_ch = len(self.channels) if self.channels is not None else n_total_channels
        per_ch = len(self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS)
        n = n_ch * per_ch
        if self.expanded and self.cross_channel and n_ch >= 2:
            n += 3 * (n_ch * (n_ch - 1) // 2)  # 3 features × C(n_ch,2) pairs
        return n

    def get_feature_names(self, n_channels: int = 0) -> list[str]:
        """Human-readable feature names matching the extract_features_window layout."""
        channel_indices = self.channels if self.channels is not None \
            else list(range(n_channels))

        feat_keys = self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS

        names: list[str] = []
        for ch in channel_indices:
            for key in feat_keys:
                names.append(f'ch{ch}_{key}')

        # Pair order must mirror the nested (i, j>i) loop in extract_features_window
        if self.expanded and self.cross_channel and len(channel_indices) >= 2:
            for i in range(len(channel_indices)):
                for j in range(i + 1, len(channel_indices)):
                    ci, cj = channel_indices[i], channel_indices[j]
                    names.extend([
                        f'cc_{ci}{cj}_corr',
                        f'cc_{ci}{cj}_lrms',
                        f'cc_{ci}{cj}_cov',
                    ])

        return names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# Change 6 — MPF FEATURE EXTRACTOR (Python training only)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class MPFFeatureExtractor:
    """
    Simplified 3-channel MPF: CSD upper triangle per 6 frequency bands = 36 features.
    Python training only. Omits matrix logarithm (not needed for 3 channels).
    Source: Kaifosh et al. Nature 2025. doi:10.1038/s41586-025-09255-w
    ESP32 approximation: use bp0-bp3 from EMGFeatureExtractor (Change 1).
    """

    # Frequency bands (Hz, half-open [lo, hi)) over which the CSD is pooled.
    BANDS = [(0, 62), (62, 125), (125, 187), (187, 250), (250, 375), (375, 500)]

    def __init__(self, channels=None, log_diagonal=True):
        # Default to the project-wide hand-channel list when none is given.
        self.channels = channels or HAND_CHANNELS
        self.log_diag = log_diagonal
        self.n_ch = len(self.channels)
        # Row/col indices of the CSD upper triangle (diagonal included).
        self._r, self._c = np.triu_indices(self.n_ch)
        self.n_features = len(self.BANDS) * len(self._r)

    def extract_window(self, window):
        """Per-band CSD upper-triangle features for one raw (samples, channels) window."""
        selected = window[:, self.channels].astype(np.float64)
        n_samples = len(selected)
        freqs = np.fft.rfftfreq(n_samples, d=1.0 / SAMPLING_RATE_HZ)
        spectrum = np.fft.rfft(selected, axis=0)

        n_pairs = len(self._r)
        values = []
        for lo, hi in self.BANDS:
            in_band = (freqs >= lo) & (freqs < hi)
            if not in_band.any():
                # No FFT bins fall in this band at this window length — pad zeros.
                values += [0.0] * n_pairs
                continue
            band_fft = spectrum[in_band]
            # Cross-spectral density matrix pooled over the band's bins.
            csd = (band_fft.conj().T @ band_fft).real / n_samples
            if self.log_diag:
                # Log-compress the auto-spectra; clamp to avoid log(0).
                for k in range(self.n_ch):
                    csd[k, k] = np.log(max(csd[k, k], 1e-10))
            values += csd[self._r, self._c].tolist()
        return np.array(values, dtype=np.float32)

    def extract_batch(self, X):
        """Apply extract_window to every window in a (n_windows, samples, channels) batch."""
        features = np.zeros((len(X), self.n_features), dtype=np.float32)
        for idx, window in enumerate(X):
            features[idx] = self.extract_window(window)
        return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# CALIBRATION TRANSFORM (Per-session feature-space alignment)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class CalibrationTransform:
    """
    Corrects for electrode placement drift between sessions via Session Z-Score Normalization.

    Training: each training session's features are independently z-scored
    (subtract session mean, divide by session std) before LDA fitting.
    This removes placement-dependent amplitude shifts, so the model learns
    in a placement-invariant normalized feature space.

    Calibration: collect a short clip of each gesture → compute global
    mean (mu_calib) and std (sigma_calib) of those features → apply the
    same z-score to every live window:

        x_normalized = (x_live - mu_calib) / sigma_calib

    This projects live features into the same normalized space that training
    used, regardless of how electrode placement changed.

    Workflow:
        1. fit_from_training()    — called automatically in EMGClassifier.train().
                                    Stores per-class training centroids (in normalized
                                    space) for diagnostics.
        2. fit_from_calibration() — called at session start after collecting
                                    a short clip of each gesture.
                                    Computes mu_calib / sigma_calib.
        3. apply()                — called on every live feature vector.
                                    Returns (features - mu_calib) / sigma_calib.
    """

    def __init__(self):
        self.has_training_stats: bool = False
        self.is_fitted: bool = False
        self.class_means_train: dict = {}  # {label: ndarray} from training (normalized space)
        self.class_means_calib: dict = {}  # {label: ndarray} from calibration (raw space)
        # Stats for the z-score transform
        self.mu_calib: Optional[np.ndarray] = None     # Class-balanced mean of calibration features (raw space)
        self.sigma_calib: Optional[np.ndarray] = None  # Global std of calibration features (raw space)
        self.sigma_train: Optional[np.ndarray] = None  # Mean per-session sigma from training (preferred scale ref)
        # Energy gate for rest detection (bypasses LDA when signal is quiet)
        self.rest_energy_threshold: Optional[float] = None

    def fit_from_training(self, X_features: np.ndarray, y: np.ndarray, label_names: list):
        """
        Store per-class training centroids. Called automatically in EMGClassifier.train().

        Args:
            X_features: (n_windows, n_features) extracted training features
            y: (n_windows,) integer label indices
            label_names: label string list matching y indices
        """
        self.has_training_stats = True

        self.class_means_train = {}
        for i, name in enumerate(label_names):
            mask = y == i
            # Classes with no windows are simply omitted from the centroids
            if mask.any():
                self.class_means_train[name] = np.mean(X_features[mask], axis=0)

    def fit_from_calibration(self, calib_features: np.ndarray, calib_labels: list):
        """
        Compute z-score normalization params from calibration-session data.

        mu_calib    = class-balanced mean (average of per-class centroids)
        sigma_calib = overall std of all calibration feature windows

        Using the class-balanced mean prevents near-zero-amplitude classes (rest)
        from landing at the wrong normalized position when training sessions had
        unequal numbers of windows per class.

        Args:
            calib_features: (n_windows, n_features) from calibration clips
            calib_labels: gesture label per window

        Raises:
            ValueError: if fit_from_training() was never run (old model).
        """
        if not self.has_training_stats:
            raise ValueError(
                "Training stats not available. Load a model that was trained "
                "after calibration support was added (retrain if needed)."
            )

        # Per-class calibration centroids (raw space)
        self.class_means_calib = {}
        label_arr = np.array(calib_labels)
        for label in set(calib_labels):
            mask = label_arr == label
            if mask.any():
                self.class_means_calib[label] = np.mean(calib_features[mask], axis=0)

        # Class-balanced mean: average of per-class centroids (not overall mean).
        # Prevents class-imbalanced calibration clips from biasing the normalization
        # origin (especially important for rest, which has near-zero amplitude).
        self.mu_calib = np.mean(list(self.class_means_calib.values()), axis=0)
        self.sigma_calib = np.std(calib_features, axis=0) + 1e-8  # epsilon avoids /0

        # rest_energy_threshold is set externally from raw window RMS values
        # (cannot be computed here — extracted features are amplitude-normalized).
        self.rest_energy_threshold = None

        self.is_fitted = True

        # Decide which sigma to use for scaling:
        #   sigma_train (preferred) — mean per-session sigma from training.
        #     Ensures the classifier sees calibration features at the SAME scale
        #     as training features, which is critical for QDA whose per-class
        #     covariance ellipsoids are fixed in normalized training space.
        #   sigma_calib (fallback) — std of this calibration session.
        #     Used only if the model was trained without session normalization.
        sigma_used = self.sigma_train if self.sigma_train is not None else self.sigma_calib
        sigma_source = "sigma_train" if self.sigma_train is not None else "sigma_calib (fallback)"
        print(f"[Calibration] Z-score fit: {len(calib_features)} windows, "
              f"{len(self.class_means_calib)} classes [scale ref: {sigma_source}]")
        # Per-class residual in normalized space (lower = better alignment).
        # Fix: normalize with sigma_used — the sigma apply() will actually use —
        # not sigma_calib, otherwise the printed residuals disagree with the
        # live transform whenever sigma_train is available.
        common = set(self.class_means_calib) & set(self.class_means_train)
        for c in sorted(common):
            norm_calib = (self.class_means_calib[c] - self.mu_calib) / sigma_used
            residual = np.linalg.norm(self.class_means_train[c] - norm_calib)
            print(f"[Calibration] {c}: normalized residual = {residual:.4f}")

    def apply(self, features: np.ndarray) -> np.ndarray:
        """
        Z-score normalize features using calibration session statistics.

        Uses sigma_train (mean per-session sigma from training) for scaling when
        available — this keeps calibration features at the same scale as training
        features, which is critical for QDA. Falls back to sigma_calib for old
        models trained without session normalization.

        Args:
            features: shape (n_features,) or (n_windows, n_features)
        Returns:
            (features - mu_calib) / sigma, same shape as input.
            Pass-through if not fitted.
        """
        if not self.is_fitted:
            return features
        sigma = self.sigma_train if self.sigma_train is not None else self.sigma_calib
        return (features - self.mu_calib) / sigma

    def reset(self):
        """Remove per-session calibration (keeps training centroids intact)."""
        self.mu_calib = None
        self.sigma_calib = None
        self.rest_energy_threshold = None
        self.is_fitted = False
        self.class_means_calib = {}
        # sigma_train is permanent (set at train time, not session-specific)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# DATA AUGMENTATION (Change 3)
|
|
|
|
|
|
# =============================================================================
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
def augment_emg_batch(
    X: np.ndarray,
    y: np.ndarray,
    multiplier: int = 3,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Expand a batch of raw EMG windows with synthetic variants.

    Operates on raw windows of shape (n_windows, n_samples, n_channels) —
    NOT on pre-computed features. Each synthetic copy independently applies:
      - Amplitude scaling ×[0.80, 1.20]
      - Gaussian noise at 5 % of per-window RMS
      - DC offset jitter ±20 counts (per channel)
      - Circular time-shift of ±5 samples

    Returns the original windows followed by (multiplier - 1) augmented
    copies, with labels replicated to match.

    Source: Kaifosh et al. Nature 2025. doi:10.1038/s41586-025-09255-w
    """
    rng = np.random.default_rng(seed)
    windows_out = [X]
    labels_out = [y]
    n = len(X)
    for _ in range(multiplier - 1):
        copy = X.copy().astype(np.float32)
        # Per-window amplitude scaling.
        copy *= rng.uniform(0.80, 1.20, (n, 1, 1)).astype(np.float32)
        # Additive noise proportional to each window's RMS energy.
        window_rms = np.sqrt(np.mean(copy ** 2, axis=(1, 2), keepdims=True)) + 1e-8
        copy += rng.standard_normal(copy.shape).astype(np.float32) * (0.05 * window_rms)
        # Per-channel DC offset jitter.
        copy += rng.uniform(-20., 20., (n, 1, X.shape[2])).astype(np.float32)
        # Circular shift along the sample axis.
        offsets = rng.integers(-5, 6, size=n)
        for idx in range(len(copy)):
            if offsets[idx]:
                copy[idx] = np.roll(copy[idx], offsets[idx], axis=0)
        windows_out.append(copy)
        labels_out.append(y)
    return np.concatenate(windows_out), np.concatenate(labels_out)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# LDA CLASSIFIER
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class EMGClassifier:
|
|
|
|
|
|
"""
|
2026-03-10 11:39:02 -05:00
|
|
|
|
EMG gesture classifier supporting LDA and QDA.
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
Model types:
|
|
|
|
|
|
- LDA: Linear Discriminant Analysis — fast, exportable to ESP32 C header
|
|
|
|
|
|
- QDA: Quadratic Discriminant Analysis — more flexible boundaries, laptop-only
|
2026-01-17 23:31:15 -06:00
|
|
|
|
"""
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
def __init__(self, model_type: str = "lda", reg_param: float = 0.1):
|
|
|
|
|
|
self.model_type = model_type.lower()
|
|
|
|
|
|
self.reg_param = reg_param # only used by QDA
|
|
|
|
|
|
self.feature_extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, reinhard=True)
|
|
|
|
|
|
if self.model_type == "qda":
|
|
|
|
|
|
self.model = QuadraticDiscriminantAnalysis(reg_param=reg_param)
|
|
|
|
|
|
else:
|
|
|
|
|
|
self.model = LinearDiscriminantAnalysis()
|
2026-01-17 23:31:15 -06:00
|
|
|
|
self.label_names: list[str] = []
|
|
|
|
|
|
self.is_trained = False
|
|
|
|
|
|
self.feature_names: list[str] = []
|
2026-03-10 11:39:02 -05:00
|
|
|
|
self.calibration_transform = CalibrationTransform()
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
    def train(self, X: np.ndarray, y: np.ndarray, label_names: list[str],
              session_indices: Optional[np.ndarray] = None) -> np.ndarray:
        """
        Train the classifier on raw EMG windows.

        Pipeline: (optional) raw-window augmentation -> feature extraction ->
        (optional) MPF feature stacking -> (optional) per-session z-scoring ->
        model fit -> calibration-statistics capture.

        Args:
            X: Raw EMG windows (n_windows, n_samples, n_channels)
            y: Integer labels (n_windows,)
            label_names: List of label strings
            session_indices: Optional per-window integer session ID (0..n_sessions-1).
                             When provided, each session's features are independently
                             z-scored before fitting, creating a placement-invariant model.

        Returns:
            The (possibly augmented and normalized) feature matrix the model
            was fitted on, shape (n_windows_out, n_features).
        """
        # Change 3: data augmentation on raw windows before feature extraction
        if getattr(self, 'use_augmentation', True):
            X_aug, y_aug = augment_emg_batch(X, y, multiplier=3)
            print(f"[Classifier] Augmentation: {len(X)} -> {len(X_aug)} windows")
            # Replicate session_indices to match the augmented size
            # (augment_emg_batch appends whole copies of X in order, so
            # tiling by the same multiplier keeps window->session alignment).
            if session_indices is not None:
                session_indices = np.tile(session_indices, 3)
        else:
            X_aug, y_aug = X, y

        print("\n[Classifier] Extracting features...")
        X_features = self.feature_extractor.extract_features_batch(X_aug)
        self.feature_names = self.feature_extractor.get_feature_names(X_aug.shape[2])

        # Change 6: optionally stack MPF features
        # NOTE(review): feature_names is not extended with MPF names here, so
        # downstream feature-importance listings cover only the base features.
        if getattr(self, 'use_mpf', False):
            mpf = MPFFeatureExtractor(channels=HAND_CHANNELS)
            X_features = np.hstack([X_features, mpf.extract_batch(X_aug)])

        print(f"[Classifier] Feature matrix shape: {X_features.shape}")
        print(f"[Classifier] Features per window: {len(self.feature_names)}")

        if session_indices is not None:
            n_sessions = len(np.unique(session_indices))
            print(f"\n[Classifier] Applying per-session z-score normalization ({n_sessions} sessions, class-balanced mu)...")
            X_features = self._apply_session_normalization(X_features, session_indices, y=y_aug)

        print(f"\n[Classifier] Training {self.model_type.upper()}...")
        self.model.fit(X_features, y_aug)
        self.label_names = label_names
        self.is_trained = True

        # Store training distribution (in normalized space) for calibration diagnostics
        self.calibration_transform.fit_from_training(X_features, y_aug, label_names)

        # Training accuracy (measured on the fitted data — optimistic;
        # use cross_validate() for an honest estimate)
        train_acc = self.model.score(X_features, y_aug)
        print(f"[Classifier] Training accuracy: {train_acc*100:.1f}%")

        return X_features
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
def _apply_session_normalization(self, X_features: np.ndarray, session_indices: np.ndarray,
|
|
|
|
|
|
y: Optional[np.ndarray] = None) -> np.ndarray:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Z-score each session's features independently using a class-balanced mean.
|
|
|
|
|
|
|
|
|
|
|
|
For each session:
|
|
|
|
|
|
- mu = mean of per-class centroids (class-balanced, not weighted by window count)
|
|
|
|
|
|
- sigma = overall std of all windows in the session
|
|
|
|
|
|
|
|
|
|
|
|
Using the class-balanced mean prevents sessions with more rest windows (or any
|
|
|
|
|
|
imbalanced class) from skewing the normalization origin toward that class.
|
|
|
|
|
|
"""
|
|
|
|
|
|
X_norm = X_features.copy()
|
|
|
|
|
|
session_sigmas = []
|
|
|
|
|
|
for sid in np.unique(session_indices):
|
|
|
|
|
|
mask = session_indices == sid
|
|
|
|
|
|
X_sess = X_features[mask]
|
|
|
|
|
|
if y is not None:
|
|
|
|
|
|
# Class-balanced mean: average of per-class centroids
|
|
|
|
|
|
y_sess = y[mask]
|
|
|
|
|
|
class_means = [X_sess[y_sess == cls].mean(axis=0)
|
|
|
|
|
|
for cls in np.unique(y_sess)]
|
|
|
|
|
|
mu = np.mean(class_means, axis=0)
|
|
|
|
|
|
else:
|
|
|
|
|
|
mu = X_sess.mean(axis=0)
|
|
|
|
|
|
sigma = X_sess.std(axis=0) + 1e-8
|
|
|
|
|
|
session_sigmas.append(sigma)
|
|
|
|
|
|
X_norm[mask] = (X_sess - mu) / sigma
|
|
|
|
|
|
# Store mean per-session sigma so calibration can use the same scale reference
|
|
|
|
|
|
self.calibration_transform.sigma_train = np.mean(session_sigmas, axis=0)
|
|
|
|
|
|
return X_norm
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
def evaluate(self, X: np.ndarray, y: np.ndarray) -> dict:
|
|
|
|
|
|
"""Evaluate classifier on test data."""
|
|
|
|
|
|
if not self.is_trained:
|
|
|
|
|
|
raise ValueError("Classifier not trained!")
|
|
|
|
|
|
|
|
|
|
|
|
X_features = self.feature_extractor.extract_features_batch(X)
|
2026-03-10 11:39:02 -05:00
|
|
|
|
y_pred = self.model.predict(X_features)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
accuracy = np.mean(y_pred == y)
|
|
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
'accuracy': accuracy,
|
|
|
|
|
|
'y_pred': y_pred,
|
|
|
|
|
|
'y_true': y
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
def cross_validate(self, X: np.ndarray, y: np.ndarray, trial_ids: Optional[np.ndarray] = None,
|
|
|
|
|
|
cv: int = 5, session_indices: Optional[np.ndarray] = None) -> np.ndarray:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Perform k-fold cross-validation with trial-level splitting.
|
|
|
|
|
|
|
|
|
|
|
|
When trial_ids are provided, uses GroupKFold to ensure windows from the
|
|
|
|
|
|
same trial never appear in both train and test folds (prevents leakage).
|
|
|
|
|
|
|
|
|
|
|
|
When session_indices are provided, applies the same per-session z-score
|
|
|
|
|
|
normalization used during training before running CV.
|
|
|
|
|
|
"""
|
2026-01-17 23:31:15 -06:00
|
|
|
|
X_features = self.feature_extractor.extract_features_batch(X)
|
2026-03-10 11:39:02 -05:00
|
|
|
|
|
|
|
|
|
|
if session_indices is not None:
|
|
|
|
|
|
X_features = self._apply_session_normalization(X_features, session_indices, y=y)
|
|
|
|
|
|
|
|
|
|
|
|
if trial_ids is not None:
|
|
|
|
|
|
print(f"\n[Classifier] Running {cv}-fold cross-validation (TRIAL-LEVEL, no leakage)...")
|
|
|
|
|
|
group_kfold = GroupKFold(n_splits=cv)
|
|
|
|
|
|
scores = cross_val_score(self.model, X_features, y, cv=group_kfold, groups=trial_ids)
|
|
|
|
|
|
else:
|
|
|
|
|
|
print(f"\n[Classifier] Running {cv}-fold cross-validation (window-level, legacy)...")
|
|
|
|
|
|
scores = cross_val_score(self.model, X_features, y, cv=cv)
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
return scores
|
|
|
|
|
|
|
|
|
|
|
|
    def predict(self, window: np.ndarray) -> tuple[str, np.ndarray]:
        """
        Predict gesture for a single window.

        Two-stage decision:
          1. Energy gate — when calibration has set a rest-energy threshold
             and "rest" is a known label, a window whose raw RMS falls below
             the threshold is returned as "rest" (probability 1.0) without
             running the model at all.
          2. Otherwise features are extracted, calibration-normalized, and
             classified by the fitted LDA/QDA model.

        The first 30 calls also print per-window debug diagnostics.

        Args:
            window: Shape (n_samples, n_channels)

        Returns:
            (predicted_label, probabilities)

        Raises:
            ValueError: if the classifier has not been trained.
        """
        if not self.is_trained:
            raise ValueError("Classifier not trained!")

        # Lazily initialized call counter drives the temporary debug prints.
        if not hasattr(self, '_predict_count'):
            self._predict_count = 0
        self._predict_count += 1
        _debug = (self._predict_count <= 30)

        features_raw = self.feature_extractor.extract_features_window(window)

        # Energy gate: if raw signal is quiet enough to be rest, skip LDA entirely.
        # Uses raw window RMS (pre-feature-extraction) so amplitude normalization
        # inside the feature extractor doesn't mask the energy difference.
        ct = self.calibration_transform
        if (ct.is_fitted and ct.rest_energy_threshold is not None
                and "rest" in self.label_names):
            w_ac = window - window.mean(axis=0)  # remove per-window DC offset (matches feature extractor)
            raw_rms = float(np.sqrt(np.mean(w_ac ** 2)))
            if _debug:
                print(f"[predict #{self._predict_count}] rms={raw_rms:.1f} gate={ct.rest_energy_threshold:.1f} "
                      f"{'GATED->rest' if raw_rms < ct.rest_energy_threshold else 'pass->QDA/LDA'}")
            if raw_rms < ct.rest_energy_threshold:
                # Below the calibrated rest threshold: short-circuit to "rest"
                # with a one-hot probability vector.
                rest_idx = self.label_names.index("rest")
                proba = np.zeros(len(self.label_names))
                proba[rest_idx] = 1.0
                return "rest", proba
        elif _debug:
            print(f"[predict #{self._predict_count}] gate inactive (is_fitted={ct.is_fitted}, "
                  f"threshold={ct.rest_energy_threshold})")

        # Normalize with the calibration transform (pass-through if unfitted).
        features = ct.apply(features_raw)
        pred_idx = self.model.predict([features])[0]
        proba = self.model.predict_proba([features])[0]
        if _debug:
            # Show the top-3 class probabilities for quick sanity checks.
            top = sorted(zip(self.label_names, proba), key=lambda x: -x[1])[:3]
            print(f"[predict #{self._predict_count}] {self.model_type.upper()} -> {self.label_names[pred_idx]}"
                  f" proba: {', '.join(f'{n}={p:.2f}' for n,p in top)}")
        return self.label_names[pred_idx], proba
|
|
|
|
|
|
|
|
|
|
|
|
def get_feature_importance(self) -> dict:
|
2026-03-10 11:39:02 -05:00
|
|
|
|
"""Get feature importance based on LDA coefficients (LDA only)."""
|
2026-01-17 23:31:15 -06:00
|
|
|
|
if not self.is_trained:
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
if not hasattr(self.model, 'coef_'):
|
|
|
|
|
|
return {}
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
# For multi-class, average absolute coefficients across classes
|
2026-03-10 11:39:02 -05:00
|
|
|
|
coef = np.abs(self.model.coef_).mean(axis=0)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
importance = dict(zip(self.feature_names, coef))
|
|
|
|
|
|
return dict(sorted(importance.items(), key=lambda x: x[1], reverse=True))
|
|
|
|
|
|
|
|
|
|
|
|
    def save(self, filepath: Path) -> Path:
        """
        Save the trained classifier to disk via joblib.

        Persisted payload (format version '1.3'):
          - the fitted sklearn model and its type ('lda'/'qda')
          - label names and feature names
          - every feature-extractor setting needed to rebuild an identical
            extractor in load()
          - calibration-transform training statistics (class centroids and
            mean per-session sigma) consumed by the calibration workflow

        Args:
            filepath: Path to save the model (e.g., 'models/emg_classifier.joblib')

        Returns:
            Path to the saved model file

        Raises:
            ValueError: if called before training.
        """
        if not self.is_trained:
            raise ValueError("Cannot save untrained classifier!")

        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        model_data = {
            'model': self.model,
            'model_type': self.model_type,
            'label_names': self.label_names,
            'feature_names': self.feature_names,
            # Everything load() needs to reconstruct an identical extractor:
            'feature_extractor_params': {
                'zc_threshold_percent': self.feature_extractor.zc_threshold_percent,
                'ssc_threshold_percent': self.feature_extractor.ssc_threshold_percent,
                'channels': self.feature_extractor.channels,
                'normalize': self.feature_extractor.normalize,
                'expanded': self.feature_extractor.expanded,
                'cross_channel': self.feature_extractor.cross_channel,
                'bandpass': self.feature_extractor.bandpass,
                'reinhard': self.feature_extractor.reinhard,
                'fft_n': self.feature_extractor.fft_n,
                'fs': self.feature_extractor.fs,
            },
            'version': '1.3',
            'reg_param': self.reg_param,
            'session_normalized': True,
            # Calibration transform training stats (used by CalibrationPage)
            'calib_class_means_train': self.calibration_transform.class_means_train,
            'calib_sigma_train': self.calibration_transform.sigma_train,
        }

        joblib.dump(model_data, filepath)
        print(f"[Classifier] Model saved to: {filepath}")
        print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
        return filepath
|
2026-01-27 21:31:49 -06:00
|
|
|
|
|
|
|
|
|
|
def export_to_header(self, filepath: Path) -> Path:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Export trained model to a C header file for ESP32 inference.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
filepath: Output .h file path
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
Path to the saved header file
|
|
|
|
|
|
"""
|
|
|
|
|
|
if not self.is_trained:
|
|
|
|
|
|
raise ValueError("Cannot export untrained classifier!")
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
if self.model_type != "lda":
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
f"Cannot export {self.model_type.upper()} to C header. "
|
|
|
|
|
|
"Only LDA models can be exported (QDA lacks coef_/intercept_)."
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2026-01-27 21:31:49 -06:00
|
|
|
|
filepath = Path(filepath)
|
|
|
|
|
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
n_classes = len(self.label_names)
|
|
|
|
|
|
n_features = len(self.feature_names)
|
|
|
|
|
|
|
|
|
|
|
|
# Get LDA parameters
|
|
|
|
|
|
# coef_: (n_classes, n_features) - access as [class][feature]
|
|
|
|
|
|
# intercept_: (n_classes,)
|
2026-03-10 11:39:02 -05:00
|
|
|
|
coefs = self.model.coef_
|
|
|
|
|
|
intercepts = self.model.intercept_
|
2026-01-27 21:31:49 -06:00
|
|
|
|
|
|
|
|
|
|
# Add logic for binary classification (sklearn stores only 1 set of coefs)
|
|
|
|
|
|
# For >2 classes, it stores n_classes sets.
|
|
|
|
|
|
if n_classes == 2:
|
|
|
|
|
|
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
|
|
|
|
|
|
# We need to expand this to 2 classes for the C inference engine to be generic.
|
|
|
|
|
|
# Class 1 decision = dot(w, x) + b
|
|
|
|
|
|
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
|
|
|
|
|
|
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
|
|
|
|
|
|
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
|
|
|
|
|
|
# Let's check sklearn docs or behavior.
|
|
|
|
|
|
# Actually, LDA in sklearn for binary case is special.
|
|
|
|
|
|
# To make C code simple (always argmax), let's explicitly store 2 rows.
|
|
|
|
|
|
# Row 1 (Index 1 in sklearn): coef, intercept
|
|
|
|
|
|
# Row 0 (Index 0): -coef, -intercept ?
|
|
|
|
|
|
# Wait, LDA is generative. The decision boundary is linear.
|
|
|
|
|
|
# Let's assume Multiclass for now or handle binary specifically.
|
|
|
|
|
|
# For simplicity in C, we prefer (n_classes, n_features).
|
|
|
|
|
|
# If coefs.shape[0] != n_classes, we need to handle it.
|
|
|
|
|
|
if coefs.shape[0] == 1:
|
|
|
|
|
|
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
|
|
|
|
|
|
# Class 1 (positive)
|
|
|
|
|
|
c1_coef = coefs[0]
|
|
|
|
|
|
c1_int = intercepts[0]
|
|
|
|
|
|
# Class 0 (negative) - Effectively -score for decision boundary at 0
|
|
|
|
|
|
# But strictly speaking LDA is comparison of log-posteriors.
|
|
|
|
|
|
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
|
|
|
|
|
|
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
|
|
|
|
|
|
# To map this to ArgMax(Score0, Score1):
|
|
|
|
|
|
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
|
|
|
|
|
|
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
|
|
|
|
|
|
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
|
|
|
|
|
|
pass
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# Bug 7 fix: preserve compile-time flags that are independent of
|
|
|
|
|
|
# the feature pipeline (MLP, ensemble). Pipeline-dependent flags
|
|
|
|
|
|
# (EXPAND_FEATURES, REINHARD) are set from the extractor config so
|
|
|
|
|
|
# they always match the exported weights.
|
|
|
|
|
|
preserved_flags = {}
|
|
|
|
|
|
_PRESERVED_FLAG_NAMES = ['MODEL_USE_MLP', 'MODEL_USE_ENSEMBLE']
|
|
|
|
|
|
if filepath.exists():
|
|
|
|
|
|
import re
|
|
|
|
|
|
existing = filepath.read_text()
|
|
|
|
|
|
for flag in _PRESERVED_FLAG_NAMES:
|
|
|
|
|
|
m = re.search(rf'#define\s+{flag}\s+(\d+)', existing)
|
|
|
|
|
|
if m:
|
|
|
|
|
|
preserved_flags[flag] = int(m.group(1))
|
|
|
|
|
|
|
|
|
|
|
|
# Auto-set pipeline flags from training config (prevents mismatch)
|
|
|
|
|
|
preserved_flags['MODEL_EXPAND_FEATURES'] = 1 if self.feature_extractor.expanded else 0
|
|
|
|
|
|
preserved_flags['MODEL_USE_REINHARD'] = 1 if self.feature_extractor.reinhard else 0
|
|
|
|
|
|
|
2026-01-27 21:31:49 -06:00
|
|
|
|
# Generate C content
|
|
|
|
|
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
2026-03-10 11:39:02 -05:00
|
|
|
|
|
2026-01-27 21:31:49 -06:00
|
|
|
|
c_content = [
|
|
|
|
|
|
"/**",
|
|
|
|
|
|
f" * @file {filepath.name}",
|
|
|
|
|
|
" * @brief Trained LDA model weights exported from Python.",
|
|
|
|
|
|
f" * @date {timestamp}",
|
|
|
|
|
|
" */",
|
|
|
|
|
|
"",
|
|
|
|
|
|
"#ifndef MODEL_WEIGHTS_H",
|
|
|
|
|
|
"#define MODEL_WEIGHTS_H",
|
|
|
|
|
|
"",
|
|
|
|
|
|
"#include <stdint.h>",
|
|
|
|
|
|
"",
|
|
|
|
|
|
"/* Metadata */",
|
|
|
|
|
|
f"#define MODEL_NUM_CLASSES {n_classes}",
|
|
|
|
|
|
f"#define MODEL_NUM_FEATURES {n_features}",
|
2026-03-10 11:39:02 -05:00
|
|
|
|
f"#define MODEL_NORMALIZE_FEATURES {1 if self.feature_extractor.normalize else 0}",
|
2026-01-27 21:31:49 -06:00
|
|
|
|
"",
|
|
|
|
|
|
]
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# Write compile-time flags (pipeline flags auto-set, architecture flags preserved)
|
|
|
|
|
|
_ALL_FLAGS = [
|
|
|
|
|
|
'MODEL_EXPAND_FEATURES', 'MODEL_USE_REINHARD',
|
|
|
|
|
|
'MODEL_USE_MLP', 'MODEL_USE_ENSEMBLE',
|
|
|
|
|
|
]
|
|
|
|
|
|
c_content.append("/* Compile-time feature flags */")
|
|
|
|
|
|
for flag in _ALL_FLAGS:
|
|
|
|
|
|
val = preserved_flags.get(flag, 0)
|
|
|
|
|
|
c_content.append(f"#define {flag} {val}")
|
|
|
|
|
|
c_content.append("")
|
|
|
|
|
|
|
|
|
|
|
|
c_content.append("/* Class Names */")
|
|
|
|
|
|
c_content.append("static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {")
|
|
|
|
|
|
|
2026-01-27 21:31:49 -06:00
|
|
|
|
for name in self.label_names:
|
|
|
|
|
|
c_content.append(f' "{name}",')
|
|
|
|
|
|
c_content.append("};")
|
|
|
|
|
|
c_content.append("")
|
|
|
|
|
|
|
|
|
|
|
|
c_content.append("/* Feature Extractor Parameters */")
|
|
|
|
|
|
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
|
|
|
|
|
|
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
|
|
|
|
|
|
c_content.append("")
|
|
|
|
|
|
|
|
|
|
|
|
c_content.append("/* LDA Intercepts/Biases */")
|
|
|
|
|
|
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
|
|
|
|
|
|
line = " "
|
|
|
|
|
|
for val in intercepts:
|
|
|
|
|
|
line += f"{val:.6f}f, "
|
|
|
|
|
|
c_content.append(line.rstrip(", "))
|
|
|
|
|
|
c_content.append("};")
|
|
|
|
|
|
c_content.append("")
|
|
|
|
|
|
|
|
|
|
|
|
c_content.append("/* LDA Coefficients (Weights) */")
|
|
|
|
|
|
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
|
|
|
|
|
|
|
|
|
|
|
|
for i, row in enumerate(coefs):
|
|
|
|
|
|
c_content.append(f" /* {self.label_names[i]} */")
|
|
|
|
|
|
c_content.append(" {")
|
|
|
|
|
|
line = " "
|
|
|
|
|
|
for j, val in enumerate(row):
|
|
|
|
|
|
line += f"{val:.6f}f, "
|
|
|
|
|
|
if (j + 1) % 8 == 0:
|
|
|
|
|
|
c_content.append(line)
|
|
|
|
|
|
line = " "
|
|
|
|
|
|
if line.strip():
|
|
|
|
|
|
c_content.append(line.rstrip(", "))
|
|
|
|
|
|
c_content.append(" },")
|
|
|
|
|
|
|
|
|
|
|
|
c_content.append("};")
|
|
|
|
|
|
c_content.append("")
|
|
|
|
|
|
c_content.append("#endif /* MODEL_WEIGHTS_H */")
|
|
|
|
|
|
|
|
|
|
|
|
with open(filepath, 'w') as f:
|
|
|
|
|
|
f.write('\n'.join(c_content))
|
|
|
|
|
|
|
|
|
|
|
|
print(f"[Classifier] Model weights exported to: {filepath}")
|
|
|
|
|
|
return filepath
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
    @classmethod
    def load(cls, filepath: Path) -> 'EMGClassifier':
        """
        Load a trained classifier from disk.

        Handles older save formats: a missing 'model_type' defaults to LDA,
        the model may live under the legacy 'lda' key, and missing feature
        extractor params are inferred from the saved feature count.

        Args:
            filepath: Path to the saved model file

        Returns:
            Loaded EMGClassifier instance ready for prediction

        Raises:
            FileNotFoundError: if filepath does not exist.
        """
        filepath = Path(filepath)
        if not filepath.exists():
            raise FileNotFoundError(f"Model file not found: {filepath}")

        model_data = joblib.load(filepath)

        # Determine model type (backward compat: old files have 'lda' key, no 'model_type')
        model_type = model_data.get('model_type', 'lda')
        reg_param = model_data.get('reg_param', 0.1)

        # Create new instance and restore state
        classifier = cls(model_type=model_type, reg_param=reg_param)
        classifier.model = model_data.get('model', model_data.get('lda'))
        classifier.label_names = model_data['label_names']
        classifier.is_trained = True

        # Restore feature extractor params
        params = model_data.get('feature_extractor_params', {})
        # Infer expanded/cross_channel from feature count for old models
        # that don't store these params: 12 features = legacy (4×3),
        # 69 features = expanded (20×3 + 9 cross-channel)
        saved_feat_names = model_data.get('feature_names', [])
        n_feat = len(saved_feat_names) if saved_feat_names else 69
        default_expanded = n_feat > 12
        default_cc = n_feat > 60  # cross-channel adds 9 features (60→69)
        classifier.feature_extractor = EMGFeatureExtractor(
            zc_threshold_percent=params.get('zc_threshold_percent', 0.1),
            ssc_threshold_percent=params.get('ssc_threshold_percent', 0.1),
            channels=params.get('channels', HAND_CHANNELS),
            normalize=params.get('normalize', False),
            expanded=params.get('expanded', default_expanded),
            cross_channel=params.get('cross_channel', default_cc),
            bandpass=params.get('bandpass', False),  # False for old models
            reinhard=params.get('reinhard', False),
            fft_n=params.get('fft_n', 256),
            fs=params.get('fs', float(SAMPLING_RATE_HZ)),
        )
        # Regenerate feature names from extractor if not in saved data
        if saved_feat_names:
            classifier.feature_names = saved_feat_names
        else:
            channels = params.get('channels', HAND_CHANNELS)
            classifier.feature_names = classifier.feature_extractor.get_feature_names(len(channels))

        # Restore calibration transform training stats (saved from v1.2+ models)
        # NOTE(review): assumes CalibrationTransform() initializes
        # has_training_stats — confirm against its definition.
        classifier.calibration_transform = CalibrationTransform()
        class_means_train = model_data.get('calib_class_means_train', {})
        sigma_train = model_data.get('calib_sigma_train')
        session_normalized = model_data.get('session_normalized', False)
        classifier.session_normalized = session_normalized
        if class_means_train:
            classifier.calibration_transform.class_means_train = class_means_train
            classifier.calibration_transform.has_training_stats = True
        if sigma_train is not None:
            classifier.calibration_transform.sigma_train = sigma_train

        print(f"[Classifier] Model loaded from: {filepath}")
        print(f"[Classifier] Labels: {classifier.label_names}")
        calib_ready = classifier.calibration_transform.has_training_stats
        print(f"[Classifier] Calibration support: {'yes' if calib_ready else 'no (retrain to enable)'}")
        print(f"[Classifier] Session-normalized: {session_normalized}")
        return classifier
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_default_model_path() -> Path:
|
|
|
|
|
|
"""Get the default path for saving/loading models."""
|
|
|
|
|
|
return MODEL_DIR / "emg_lda_classifier.joblib"
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def get_latest_model_path() -> Path | None:
|
|
|
|
|
|
"""Get the most recently modified model file, or None if no models exist."""
|
|
|
|
|
|
models = EMGClassifier.list_saved_models()
|
|
|
|
|
|
if not models:
|
|
|
|
|
|
return None
|
|
|
|
|
|
return max(models, key=lambda p: p.stat().st_mtime)
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
@staticmethod
|
|
|
|
|
|
def list_saved_models() -> list[Path]:
|
2026-03-10 11:39:02 -05:00
|
|
|
|
"""List all saved classifier model files (excludes ensemble/auxiliary files)."""
|
2026-01-17 23:31:15 -06:00
|
|
|
|
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
2026-03-10 11:39:02 -05:00
|
|
|
|
return sorted(
|
|
|
|
|
|
p for p in MODEL_DIR.glob("*.joblib")
|
|
|
|
|
|
if "ensemble" not in p.stem
|
|
|
|
|
|
)
|
2026-01-17 23:31:15 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# PREDICTION SMOOTHING (Temporal smoothing, majority vote, debouncing)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
class PredictionSmoother:
|
|
|
|
|
|
"""
|
|
|
|
|
|
Smooths predictions to prevent twitchy/unstable output.
|
|
|
|
|
|
|
|
|
|
|
|
Combines three techniques:
|
|
|
|
|
|
1. Probability Smoothing: Exponential moving average on raw probabilities
|
|
|
|
|
|
2. Majority Vote: Output most common prediction from last N predictions
|
|
|
|
|
|
3. Debouncing: Only change output after N consecutive same predictions
|
|
|
|
|
|
|
|
|
|
|
|
This prevents the robotic hand from twitching when there's an occasional
|
|
|
|
|
|
misclassification in a stream of correct predictions.
|
|
|
|
|
|
|
|
|
|
|
|
Example:
|
|
|
|
|
|
Raw predictions: FIST, FIST, OPEN, FIST, FIST, FIST
|
|
|
|
|
|
Without smoothing: Hand twitches open briefly
|
|
|
|
|
|
With smoothing: Hand stays as FIST (OPEN was filtered out)
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
|
self,
|
|
|
|
|
|
label_names: list[str],
|
|
|
|
|
|
probability_smoothing: float = 0.7,
|
|
|
|
|
|
majority_vote_window: int = 5,
|
|
|
|
|
|
debounce_count: int = 3,
|
|
|
|
|
|
):
|
|
|
|
|
|
"""
|
|
|
|
|
|
Args:
|
|
|
|
|
|
label_names: List of gesture labels (must match classifier output)
|
|
|
|
|
|
probability_smoothing: EMA factor (0-1). Higher = more smoothing.
|
|
|
|
|
|
0 = no smoothing, 0.9 = very smooth
|
|
|
|
|
|
majority_vote_window: Number of past predictions to consider for voting
|
|
|
|
|
|
debounce_count: Number of consecutive same predictions needed to change output
|
|
|
|
|
|
"""
|
|
|
|
|
|
self.label_names = label_names
|
|
|
|
|
|
self.n_classes = len(label_names)
|
|
|
|
|
|
|
|
|
|
|
|
# Probability smoothing (Exponential Moving Average)
|
|
|
|
|
|
self.prob_smoothing = probability_smoothing
|
|
|
|
|
|
self.smoothed_proba = np.ones(self.n_classes) / self.n_classes # Start uniform
|
|
|
|
|
|
|
|
|
|
|
|
# Majority vote
|
|
|
|
|
|
self.vote_window = majority_vote_window
|
|
|
|
|
|
self.prediction_history: list[str] = []
|
|
|
|
|
|
|
|
|
|
|
|
# Debouncing
|
|
|
|
|
|
self.debounce_count = debounce_count
|
|
|
|
|
|
self.current_output = None
|
|
|
|
|
|
self.pending_output = None
|
|
|
|
|
|
self.pending_count = 0
|
|
|
|
|
|
|
|
|
|
|
|
# Stats
|
|
|
|
|
|
self.total_predictions = 0
|
|
|
|
|
|
self.output_changes = 0
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, predicted_label: str, probabilities: np.ndarray) -> tuple[str, float, dict]:
    """
    Process a new prediction and return smoothed output.

    Runs the raw classifier output through three stages in order:
    probability EMA, majority vote over recent labels, and a debounce
    state machine that only lets the published label change after
    ``debounce_count`` consecutive agreeing votes.

    Args:
        predicted_label: Raw prediction from classifier (only reported in
            debug_info; the smoothed pipeline works from `probabilities`).
        probabilities: Raw probability array from classifier; assumed to be
            index-aligned with ``self.label_names`` — confirm at the caller.

    Returns:
        (smoothed_label, confidence, debug_info)
        - smoothed_label: The stable output label after all smoothing
        - confidence: Confidence in the smoothed output (0-1); this is the
          majority-vote fraction, not the EMA probability
        - debug_info: Dict with intermediate values for debugging/display
    """
    self.total_predictions += 1

    # --- 1. Probability Smoothing (EMA) ---
    # Blend new probabilities with historical smoothed probabilities.
    # Higher prob_smoothing keeps more of the past distribution.
    self.smoothed_proba = (
        self.prob_smoothing * self.smoothed_proba +
        (1 - self.prob_smoothing) * probabilities
    )

    # Get prediction from smoothed probabilities
    prob_smoothed_idx = np.argmax(self.smoothed_proba)
    prob_smoothed_label = self.label_names[prob_smoothed_idx]
    prob_smoothed_confidence = self.smoothed_proba[prob_smoothed_idx]

    # --- 2. Majority Vote ---
    # Add the EMA-smoothed label (not the raw one) to the history and
    # keep only the last `vote_window` entries.
    self.prediction_history.append(prob_smoothed_label)
    if len(self.prediction_history) > self.vote_window:
        self.prediction_history.pop(0)

    # Count votes
    vote_counts = {}
    for pred in self.prediction_history:
        vote_counts[pred] = vote_counts.get(pred, 0) + 1

    # Get majority winner (ties resolve to the first-inserted key,
    # i.e. the label that entered the window earliest).
    majority_label = max(vote_counts, key=vote_counts.get)
    majority_count = vote_counts[majority_label]
    majority_confidence = majority_count / len(self.prediction_history)

    # --- 3. Debouncing ---
    # Only change output after `debounce_count` consistent majority votes.
    if self.current_output is None:
        # First prediction: publish immediately, nothing to debounce yet.
        self.current_output = majority_label
        self.pending_output = majority_label
        self.pending_count = 1
    elif majority_label == self.current_output:
        # Same as current output; reset pending so a stale candidate
        # can't accumulate toward a spurious switch.
        self.pending_output = majority_label
        self.pending_count = 1
    elif majority_label == self.pending_output:
        # Same as pending candidate, increment its streak.
        self.pending_count += 1
        if self.pending_count >= self.debounce_count:
            # Enough consecutive predictions, change the published output.
            self.current_output = majority_label
            self.output_changes += 1
    else:
        # New candidate label, restart the streak from 1.
        self.pending_output = majority_label
        self.pending_count = 1

    # Final output: the debounced label, with the majority-vote fraction
    # serving as its confidence.
    final_label = self.current_output
    final_confidence = majority_confidence

    # Debug info: every intermediate stage, for display/diagnostics.
    debug_info = {
        'raw_label': predicted_label,
        'raw_confidence': float(np.max(probabilities)),
        'prob_smoothed_label': prob_smoothed_label,
        'prob_smoothed_confidence': float(prob_smoothed_confidence),
        'majority_label': majority_label,
        'majority_confidence': float(majority_confidence),
        'vote_counts': vote_counts,
        'pending_output': self.pending_output,
        'pending_count': self.pending_count,
        'debounce_threshold': self.debounce_count,
    }

    return final_label, final_confidence, debug_info
|
|
|
|
|
|
def reset(self):
    """Clear all smoothing state before starting a new prediction session.

    Restores the object to its freshly-constructed state: uniform EMA
    probabilities, empty vote history, no debounced output, and zeroed
    statistics counters.
    """
    # EMA back to a uniform distribution over the known classes.
    self.smoothed_proba = np.full(self.n_classes, 1.0 / self.n_classes)
    # Majority-vote window.
    self.prediction_history = []
    # Debounce state machine.
    self.current_output = None
    self.pending_output = None
    self.pending_count = 0
    # Statistics counters.
    self.total_predictions = 0
    self.output_changes = 0
|
|
|
|
|
|
def get_stats(self) -> dict:
    """Summarize how much the smoother has stabilized the output stream.

    Returns:
        dict with:
        - total_predictions: update() calls since the last reset
        - output_changes: times the debounced output switched label
        - stability_ratio: 1 minus the fraction of updates that changed
          the output (1.0 = perfectly stable; max(1, ...) guards the
          zero-prediction case against division by zero)
    """
    denom = max(1, self.total_predictions)
    stats = dict(
        total_predictions=self.total_predictions,
        output_changes=self.output_changes,
    )
    stats['stability_ratio'] = 1 - (self.output_changes / denom)
    return stats
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# TRAINING (Train LDA classifier)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_training_demo():
    """
    Train an LDA classifier on ALL collected sessions combined.

    Shows:
        1. Loading all session data combined
        2. Feature extraction
        3. Training LDA
        4. Cross-validation evaluation
        5. Feature importance analysis

    The model learns from all accumulated data, making it more robust
    as you collect more sessions over time.

    Interactive: prompts twice via input() (confirm training, confirm
    saving). Returns the trained EMGClassifier, or None if there are no
    sessions or the user cancels.
    """
    print("\n" + "=" * 60)
    print("TRAIN LDA CLASSIFIER (ALL SESSIONS)")
    print("=" * 60)

    # SessionStorage is defined elsewhere in this file; presumably it wraps
    # the HDF5 session archive — confirm against its definition.
    storage = SessionStorage()
    sessions = storage.list_sessions()

    # Guard: nothing to train on yet.
    if not sessions:
        print("\nNo saved sessions found!")
        print("Run option 1 first to collect and save training data.")
        return None

    # Show available sessions with per-session window/gesture counts.
    print(f"\nFound {len(sessions)} saved session(s):")
    print("-" * 40)

    total_windows = 0
    for session_id in sessions:
        info = storage.get_session_info(session_id)
        print(f" - {session_id}: {info['num_windows']} windows, gestures: {info['gestures']}")
        total_windows += info['num_windows']

    print(f"\nTotal windows across all sessions: {total_windows}")
    print("-" * 40)

    confirm = input("\nTrain on ALL sessions combined? (y/n): ").strip().lower()
    if confirm != 'y':
        print("Training cancelled.")
        return None

    # Load ALL data combined.
    print(f"\n{'=' * 60}")
    print("TRAINING ON ALL SESSIONS COMBINED")
    print("=" * 60)

    # X: raw EMG windows; trial_ids group windows from the same recorded
    # trial so splits can avoid leaking a trial across train/test.
    X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training()

    # X is indexed [window, sample, channel] based on the shape prints below.
    print(f"\nDataset:")
    print(f" Windows: {X.shape[0]}")
    print(f" Samples per window: {X.shape[1]}")
    print(f" Channels: {X.shape[2]}")
    print(f" Classes: {label_names}")
    print(f" Unique trials: {len(np.unique(trial_ids))}")

    # Count per class (labels y are integer indices into label_names).
    print(f"\nSamples per class:")
    for i, name in enumerate(label_names):
        count = np.sum(y == i)
        print(f" {name}: {count}")

    # Create and train classifier on the full dataset.
    classifier = EMGClassifier()
    X_features = classifier.train(X, y, label_names)

    # Cross-validation (trial-level to prevent leakage between folds).
    cv_scores = classifier.cross_validate(X, y, trial_ids=trial_ids, cv=5)
    print(f"\nCross-validation scores: {cv_scores}")
    print(f"Mean CV accuracy: {cv_scores.mean()*100:.1f}% (+/- {cv_scores.std()*100:.1f}%)")

    # Feature importance: top 8, rendered as a text bar chart.
    print(f"\n{'-' * 40}")
    print("FEATURE IMPORTANCE (top 8)")
    print("-" * 40)
    importance = classifier.get_feature_importance()
    for i, (name, score) in enumerate(list(importance.items())[:8]):
        # Scores are assumed to be in [0, 1] for the 20-char bar — TODO confirm.
        bar = "█" * int(score * 20)
        print(f" {name:12s}: {bar} ({score:.3f})")

    # Train/test split evaluation (TRIAL-LEVEL to prevent leakage).
    print(f"\n{'-' * 40}")
    print("TRAIN/TEST SPLIT EVALUATION (TRIAL-LEVEL)")
    print("-" * 40)

    # Use GroupShuffleSplit to split by trial, not by window, so windows
    # from one trial never straddle the train/test boundary.
    gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx, test_idx = next(gss.split(X, y, groups=trial_ids))

    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    train_trial_ids = trial_ids[train_idx]
    test_trial_ids = trial_ids[test_idx]

    # VERIFICATION: Ensure no trial leakage between the two sides.
    # NOTE(review): assert is stripped under `python -O`; an explicit raise
    # would be more robust for a correctness check.
    train_trials_set = set(train_trial_ids)
    test_trials_set = set(test_trial_ids)
    overlap = train_trials_set & test_trials_set
    assert len(overlap) == 0, f"Trial leakage detected! Overlapping trials: {overlap}"
    print(f" Train: {len(X_train)} windows from {len(train_trials_set)} trials")
    print(f" Test: {len(X_test)} windows from {len(test_trials_set)} trials")
    print(f" Trial overlap: {len(overlap)} (VERIFIED: no leakage)")

    # Log per-class distribution across the split.
    print(f"\n Per-class window counts:")
    for i, name in enumerate(label_names):
        train_count = np.sum(y_train == i)
        test_count = np.sum(y_test == i)
        print(f" {name:12s}: train={train_count:4d}, test={test_count:4d}")

    # Train a throwaway classifier on the train side only so the test
    # accuracy below is honest (the all-data `classifier` is what gets saved).
    test_classifier = EMGClassifier()
    test_classifier.train(X_train, y_train, label_names)

    # Evaluate on the held-out test set.
    result = test_classifier.evaluate(X_test, y_test)
    print(f"\nTest accuracy: {result['accuracy']*100:.1f}%")

    print(f"\nClassification Report:")
    print(classification_report(result['y_true'], result['y_pred'],
                                target_names=label_names))

    # Confusion matrix, printed as a fixed-width text table.
    print(f"Confusion Matrix:")
    cm = confusion_matrix(result['y_true'], result['y_pred'])
    print(f" {'':12s} ", end="")
    for name in label_names:
        print(f"{name[:8]:>8s} ", end="")
    print()
    for i, row in enumerate(cm):
        print(f" {label_names[i]:12s} ", end="")
        for val in row:
            print(f"{val:8d} ", end="")
        print()

    # --- Save the model (the one trained on ALL data, not the split one) ---
    print(f"\n{'-' * 40}")
    print("SAVE MODEL")
    print("-" * 40)

    default_path = EMGClassifier.get_default_model_path()
    print(f"Default save path: {default_path}")

    save_choice = input("\nSave this model? (y/n): ").strip().lower()
    if save_choice == 'y':
        classifier.save(default_path)
        print(f"\nModel saved! You can now use 'Live prediction' without retraining.")

    return classifier
|
|
|
|
|
|
|
2026-03-10 11:39:02 -05:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# Change 5 — CLASSIFIER BENCHMARK
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_classifier_benchmark():
    """
    Cross-validate LDA, QDA, SVM-RBF, and MLP on the collected dataset.

    Purpose: tells you whether accuracy plateau is a features problem
    (all classifiers similar → add features) or a model complexity problem
    (SVM/MLP >> LDA → implement Change E / ensemble).

    Uses trial-grouped folds (GroupKFold) so windows from the same trial
    never appear in both train and validation of a fold. Prints a table of
    mean/std CV accuracy per classifier; returns None.
    """
    # sklearn imports kept local so the rest of the pipeline does not pay
    # for them unless this option is chosen.
    from sklearn.svm import SVC
    from sklearn.neural_network import MLPClassifier
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler
    from sklearn.model_selection import cross_val_score, GroupKFold
    from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis,
                                               QuadraticDiscriminantAnalysis)

    print("\n" + "=" * 60)
    print("CLASSIFIER BENCHMARK (Cross-validation)")
    print("=" * 60)

    storage = SessionStorage()
    X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()

    # Cross-validated accuracy is meaningless with a single class.
    if len(np.unique(y)) < 2:
        print("Need at least 2 gesture classes. Collect more data first.")
        return

    # Same feature pipeline the hand classifier uses (HAND_CHANNELS and
    # the extractor are defined elsewhere in this file).
    extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True)
    X = extractor.extract_features_batch(X_raw)
    # NOTE(review): reaching into a private method of a throwaway
    # EMGClassifier instance; presumably per-session feature normalization —
    # consider promoting _apply_session_normalization to a public helper.
    X = EMGClassifier()._apply_session_normalization(X, session_indices, y=y)

    # Candidate models. SVM/MLP get a StandardScaler in-pipeline because
    # they are scale-sensitive; LDA/QDA are not.
    clfs = {
        'LDA (ESP32 model)': LinearDiscriminantAnalysis(),
        'QDA': QuadraticDiscriminantAnalysis(reg_param=0.1),
        'SVM-RBF': Pipeline([('s', StandardScaler()),
                             ('m', SVC(kernel='rbf', C=10))]),
        'MLP-128-64': Pipeline([('s', StandardScaler()),
                                ('m', MLPClassifier(hidden_layer_sizes=(128, 64),
                                                    max_iter=1000,
                                                    early_stopping=True))]),
    }

    # GroupKFold needs at least as many groups as splits.
    n_splits = min(5, len(np.unique(trial_ids)))
    gkf = GroupKFold(n_splits=n_splits)

    print(f"\n{'Classifier':<22} {'Mean CV':>8} {'Std':>6}")
    print("-" * 40)
    for name, clf in clfs.items():
        sc = cross_val_score(clf, X, y, cv=gkf, groups=trial_ids, scoring='accuracy')
        print(f" {name:<20} {sc.mean()*100:>7.1f}% ±{sc.std()*100:.1f}%")

    # Interpretation guide for the numbers above.
    print()
    print(" → If LDA ≈ SVM: features are the bottleneck (add Change 1 features)")
    print(" → If SVM >> LDA: model complexity bottleneck (implement Change F ensemble)")
|
|
|
|
|
|
|
2026-01-17 23:31:15 -06:00
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# LIVE PREDICTION (Real-time gesture classification)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_prediction_demo():
    """
    Live prediction demo - classifies gestures in real-time from ESP32.

    Shows:
        1. Load saved model OR train fresh on all sessions
        2. Connect to ESP32 and stream real EMG data
        3. Classify each window as it comes in
        4. Display predictions with confidence

    REQUIRES: ESP32 hardware connected via USB.

    Runs until Ctrl+C; prints smoothing statistics on exit and returns the
    classifier used (or None on cancel/failure).
    """
    print("\n" + "=" * 60)
    print("LIVE PREDICTION DEMO (ESP32 Required)")
    print("=" * 60)

    # Check for a previously saved model so retraining is optional.
    # NOTE(review): `saved_models` is computed but never used below.
    saved_models = EMGClassifier.list_saved_models()
    default_model = EMGClassifier.get_default_model_path()

    classifier = None

    if default_model.exists():
        print(f"\nSaved model found: {default_model}")
        print(f" File size: {default_model.stat().st_size / 1024:.1f} KB")

        load_choice = input("\nLoad saved model? (y=load, n=retrain): ").strip().lower()
        if load_choice == 'y':
            classifier = EMGClassifier.load(default_model)

    if classifier is None:
        # Need to train a new model from the stored sessions.
        storage = SessionStorage()
        sessions = storage.list_sessions()

        if not sessions:
            print("\nNo saved sessions found! Collect data first (Option 1).")
            return None

        # Show available sessions before committing to a (possibly slow) train.
        print(f"\nNo saved model (or retraining requested).")
        print(f"Will train on ALL {len(sessions)} session(s):")
        print("-" * 40)

        total_windows = 0
        for session_id in sessions:
            info = storage.get_session_info(session_id)
            print(f" - {session_id}: {info['num_windows']} windows")
            total_windows += info['num_windows']

        print(f"\nTotal training windows: {total_windows}")
        print("-" * 40)

        confirm = input("\nTrain and start prediction? (y/n): ").strip().lower()
        if confirm != 'y':
            print("Prediction cancelled.")
            return None

        # Load ALL sessions and train the model.
        print(f"\n[Training model on all sessions...]")
        X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training()
        print(f"[Unique trials: {len(np.unique(trial_ids))}]")

        classifier = EMGClassifier()
        classifier.train(X, y, label_names)

    # Connect to ESP32 over serial; bail out cleanly if no hardware.
    print("\n[Connecting to ESP32...]")
    stream = RealSerialStream()
    try:
        stream.connect(timeout=5.0)
        print(f" Connected: {stream.device_info}")
    except Exception as e:
        print(f" ERROR: Failed to connect to ESP32: {e}")
        print(" Make sure the ESP32 is connected and firmware is flashed.")
        return None

    # Start live prediction.
    print("\n" + "=" * 60)
    print("STARTING LIVE PREDICTION (WITH SMOOTHING)")
    print("=" * 60)
    print("Press Ctrl+C to stop.\n")
    print(" Smoothing: Probability EMA (0.7) + Majority Vote (5) + Debounce (3)\n")

    # Parser turns serial lines into samples; windower buffers samples
    # into fixed-size windows (constants defined elsewhere in this file).
    parser = EMGParser(num_channels=NUM_CHANNELS)
    windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, hop_size_ms=HOP_SIZE_MS)

    # Create prediction smoother — parameters mirror the banner above.
    smoother = PredictionSmoother(
        label_names=classifier.label_names,
        probability_smoothing=0.7,  # Higher = more smoothing
        majority_vote_window=5,     # Past predictions to consider
        debounce_count=3,           # Consecutive predictions needed to change
    )

    stream.start()
    prediction_count = 0

    try:
        while True:
            # Read and process data; each stage may yield nothing yet.
            line = stream.readline()
            if line:
                sample = parser.parse_line(line)
                if sample:
                    window = windower.add_sample(sample)
                    if window:
                        # Classify the window (raw prediction).
                        window_data = window.to_numpy()
                        raw_label, proba = classifier.predict(window_data)

                        # Apply smoothing (EMA + vote + debounce).
                        smoothed_label, smoothed_conf, debug = smoother.update(raw_label, proba)

                        prediction_count += 1

                        # Display both raw and smoothed predictions.
                        raw_conf = max(proba) * 100
                        smoothed_conf_pct = smoothed_conf * 100

                        # Visual bar for smoothed confidence (20 chars = 100%).
                        bar_len = round(smoothed_conf_pct / 5)
                        bar = "█" * bar_len + "░" * (20 - bar_len)

                        # "!!" flags windows where smoothing overrode the raw label.
                        raw_marker = " " if raw_label == smoothed_label else "!!"
                        print(f" #{prediction_count:3d} │ {bar} │ {smoothed_label:12s} ({smoothed_conf_pct:5.1f}%) {raw_marker} raw:{raw_label[:8]:8s}")

    except KeyboardInterrupt:
        print("\n\n[Stopped by user]")

    finally:
        # Always release the serial port, even on unexpected errors.
        stream.stop()
        stream.disconnect()

    # Show smoothing stats accumulated during the run.
    stats = smoother.get_stats()
    print(f"\n" + "-" * 40)
    print(f"SMOOTHING STATISTICS")
    print("-" * 40)
    print(f" Total predictions: {stats['total_predictions']}")
    print(f" Output changes: {stats['output_changes']}")
    print(f" Stability ratio: {stats['stability_ratio']*100:.1f}%")

    return classifier
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# LDA VISUALIZATION (Decision boundaries and feature space)
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
def run_visualization_demo():
    """
    Visualize the LDA model trained on ALL sessions with plots:
        1. 2D feature space scatter plot (LDA reduced)
        2. Decision boundaries
        3. Class distributions
        4. Confusion matrix heatmap

    Uses all accumulated session data for a complete picture of the model.

    Interactive (confirms via input, blocks on plt.show()). Returns the
    fitted LinearDiscriminantAnalysis instance, or None on cancel.
    """
    print("\n" + "=" * 60)
    print("LDA VISUALIZATION (ALL SESSIONS)")
    print("=" * 60)

    storage = SessionStorage()
    sessions = storage.list_sessions()

    if not sessions:
        print("\nNo saved sessions found! Collect data first (Option 1).")
        return None

    # Show available sessions so the user knows what will be plotted.
    print(f"\nFound {len(sessions)} saved session(s):")
    print("-" * 40)

    total_windows = 0
    for session_id in sessions:
        info = storage.get_session_info(session_id)
        print(f" - {session_id}: {info['num_windows']} windows")
        total_windows += info['num_windows']

    print(f"\nTotal windows: {total_windows}")
    print("-" * 40)

    confirm = input("\nVisualize model trained on ALL sessions? (y/n): ").strip().lower()
    if confirm != 'y':
        print("Visualization cancelled.")
        return None

    # Load ALL data combined.
    X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training()
    print(f"[Unique trials: {len(np.unique(trial_ids))}]")

    # Extract features (forearm channels only, matching hand classifier).
    extractor = EMGFeatureExtractor(channels=HAND_CHANNELS)
    X_features = extractor.extract_features_batch(X)

    # Train LDA directly (bypassing EMGClassifier) so we can use its
    # transform()/coef_ for plotting.
    print("\n[Training LDA for visualization...]")
    lda = LinearDiscriminantAnalysis()
    lda.fit(X_features, y)

    # Transform to LDA space (reduces to n_classes - 1 dimensions).
    X_lda = lda.transform(X_features)

    n_classes = len(label_names)
    print(f" LDA dimensions: {X_lda.shape[1]}")

    # Color scheme: one viridis shade per class, shared by all figures.
    colors = plt.cm.viridis(np.linspace(0.1, 0.9, n_classes))

    # --- Figure 1: LDA Feature Space (2D projection) ---
    fig1, ax1 = plt.subplots(figsize=(10, 8))

    # With only 2 classes LDA has one component; plot on a zero y-axis then.
    for i, label in enumerate(label_names):
        mask = y == i
        ax1.scatter(X_lda[mask, 0],
                    X_lda[mask, 1] if X_lda.shape[1] > 1 else np.zeros(mask.sum()),
                    c=[colors[i]], label=label, s=100, alpha=0.7, edgecolors='white', linewidth=1)

    # Add class means as large X markers with labels.
    for i, label in enumerate(label_names):
        mask = y == i
        mean_x = X_lda[mask, 0].mean()
        mean_y = X_lda[mask, 1].mean() if X_lda.shape[1] > 1 else 0
        ax1.scatter(mean_x, mean_y, c=[colors[i]], s=400, marker='X', edgecolors='black', linewidth=2)
        ax1.annotate(label.upper(), (mean_x, mean_y), fontsize=12, fontweight='bold',
                     ha='center', va='bottom', xytext=(0, 15), textcoords='offset points')

    ax1.set_xlabel("LDA Component 1", fontsize=12)
    ax1.set_ylabel("LDA Component 2", fontsize=12)
    ax1.set_title("LDA Feature Space - Gesture Clusters", fontsize=14, fontweight='bold')
    ax1.legend(loc='upper right', fontsize=10)
    ax1.grid(True, alpha=0.3)

    # --- Figure 2: Decision Boundary Heatmap ---
    # NOTE(review): shape[1] >= 1 is always true after transform(); the
    # guard is effectively a no-op.
    if X_lda.shape[1] >= 1:
        fig2, ax2 = plt.subplots(figsize=(10, 8))

        # Create mesh grid covering the data with a 1-unit margin.
        x_min, x_max = X_lda[:, 0].min() - 1, X_lda[:, 0].max() + 1
        if X_lda.shape[1] > 1:
            y_min, y_max = X_lda[:, 1].min() - 1, X_lda[:, 1].max() + 1
        else:
            y_min, y_max = -2, 2

        xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                             np.linspace(y_min, y_max, 200))

        # For prediction, we need to go back to original feature space.
        # Use simplified approach: train new LDA on LDA-transformed features.
        if X_lda.shape[1] > 1:
            lda_2d = LinearDiscriminantAnalysis()
            lda_2d.fit(X_lda[:, :2], y)
            Z = lda_2d.predict(np.c_[xx.ravel(), yy.ravel()])
        else:
            # 1D case - simple threshold.
            # NOTE(review): the next line is dead code — its result is
            # immediately overwritten by the zeros array below.
            Z = lda.predict(X_features[0:1])  # dummy
            Z = np.zeros(xx.ravel().shape)
            for i, x_val in enumerate(xx.ravel()):
                # NOTE(review): this inner shape check is redundant — it is
                # always true inside this else-branch.
                if X_lda.shape[1] == 1:
                    # Find closest class mean (nearest-centroid in 1D).
                    distances = [abs(x_val - X_lda[y == c, 0].mean()) for c in range(n_classes)]
                    Z[i] = np.argmin(distances)

        Z = Z.reshape(xx.shape)

        # Plot decision regions as translucent fills plus thin contour lines.
        ax2.contourf(xx, yy, Z, alpha=0.3, levels=np.arange(-0.5, n_classes, 1),
                     colors=[colors[i] for i in range(n_classes)])
        ax2.contour(xx, yy, Z, colors='black', linewidths=0.5, alpha=0.5)

        # Plot data points on top of the regions.
        for i, label in enumerate(label_names):
            mask = y == i
            ax2.scatter(X_lda[mask, 0],
                        X_lda[mask, 1] if X_lda.shape[1] > 1 else np.zeros(mask.sum()),
                        c=[colors[i]], label=label, s=80, alpha=0.9, edgecolors='black', linewidth=0.5)

        ax2.set_xlabel("LDA Component 1", fontsize=12)
        ax2.set_ylabel("LDA Component 2", fontsize=12)
        ax2.set_title("LDA Decision Boundaries", fontsize=14, fontweight='bold')
        ax2.legend(loc='upper right', fontsize=10)

    # --- Figure 3: Feature Importance Radar Chart ---
    fig3, ax3 = plt.subplots(figsize=(10, 8), subplot_kw=dict(projection='polar'))

    feature_names = extractor.get_feature_names(X.shape[2])
    # Mean absolute LDA coefficient per feature, as an importance proxy.
    coef = np.abs(lda.coef_).mean(axis=0)
    coef_normalized = coef / coef.max()  # Normalize to 0-1

    # Number of features spaced evenly around the circle.
    n_features = len(feature_names)
    angles = np.linspace(0, 2 * np.pi, n_features, endpoint=False).tolist()

    # Complete the loop so the polygon closes.
    coef_normalized = np.concatenate([coef_normalized, [coef_normalized[0]]])
    angles += angles[:1]

    ax3.plot(angles, coef_normalized, 'o-', linewidth=2, color='#2E86AB', markersize=8)
    ax3.fill(angles, coef_normalized, alpha=0.25, color='#2E86AB')
    ax3.set_xticks(angles[:-1])
    ax3.set_xticklabels(feature_names, fontsize=9)
    ax3.set_ylim(0, 1.1)
    ax3.set_title("Feature Importance (Radar)", fontsize=14, fontweight='bold', pad=20)

    # --- Figure 4: Class Distribution Histograms ---
    fig4, axes4 = plt.subplots(1, n_classes, figsize=(4*n_classes, 4))
    # subplots returns a bare Axes (not array) when n_classes == 1.
    if n_classes == 1:
        axes4 = [axes4]

    for i, (ax, label) in enumerate(zip(axes4, label_names)):
        mask = y == i
        ax.hist(X_lda[mask, 0], bins=15, color=colors[i], alpha=0.7, edgecolor='black')
        ax.axvline(X_lda[mask, 0].mean(), color='black', linestyle='--', linewidth=2, label='Mean')
        ax.set_xlabel("LDA Component 1")
        ax.set_ylabel("Count")
        ax.set_title(f"{label.upper()}", fontsize=12, fontweight='bold')
        ax.legend()

    fig4.suptitle("Class Distributions on LDA Axis", fontsize=14, fontweight='bold')
    plt.tight_layout()

    # --- Figure 5: Confusion Matrix Heatmap ---
    fig5, ax5 = plt.subplots(figsize=(8, 6))

    # Cross-validated predictions so the matrix is not optimistically
    # computed on training data. NOTE(review): plain 5-fold here, not
    # trial-grouped like the training demo — windows from one trial can
    # leak across folds; consider GroupKFold with trial_ids.
    y_pred = cross_val_predict(lda, X_features, y, cv=5)
    cm = confusion_matrix(y, y_pred)

    im = ax5.imshow(cm, interpolation='nearest', cmap='Blues')
    ax5.figure.colorbar(im, ax=ax5)

    ax5.set(xticks=np.arange(cm.shape[1]),
            yticks=np.arange(cm.shape[0]),
            xticklabels=label_names,
            yticklabels=label_names,
            xlabel='Predicted',
            ylabel='Actual',
            title='Confusion Matrix Heatmap')

    # Add text annotations, white on dark cells for contrast.
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax5.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    print("\n Displayed 5 visualization figures")
    return lda
|
|
|
|
|
|
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
# ENTRY POINT
|
|
|
|
|
|
# =============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Show the module docstring (pipeline overview) once at startup.
    print(__doc__)

    # Interactive menu loop: runs until the user enters 'q'. Each option
    # dispatches to one of the demo functions defined above; their return
    # values are kept in locals but not otherwise used between iterations.
    while True:
        print("\n" + "=" * 60)
        print("EMG DATA COLLECTION PIPELINE")
        print("=" * 60)
        print("\nOptions:")
        print(" 1. Collect data (labeled session)")
        print(" 2. Inspect saved sessions (view features)")
        print(" 3. Train LDA classifier")
        print(" 4. Live prediction demo")
        print(" 5. Visualize LDA model")
        print(" 6. Classifier benchmark (LDA vs QDA vs SVM vs MLP)")
        print(" q. Quit")

        choice = input("\nEnter choice: ").strip().lower()

        if choice == 'q':
            print("\nGoodbye!")
            break

        elif choice == "1":
            # Timed-prompt collection session (requires ESP32 hardware).
            windows, labels = run_labeled_collection_demo()

        elif choice == "2":
            # Browse saved HDF5 sessions and their extracted features.
            result = run_storage_demo()

        elif choice == "3":
            classifier = run_training_demo()

        elif choice == "4":
            classifier = run_prediction_demo()

        elif choice == "5":
            lda = run_visualization_demo()

        elif choice == "6":
            run_classifier_benchmark()

        else:
            print("\nInvalid choice. Please enter 1-6 or q.")
|