diff --git a/BUCKY_ARM_MASTER_PLAN.md b/BUCKY_ARM_MASTER_PLAN.md new file mode 100644 index 0000000..f208b65 --- /dev/null +++ b/BUCKY_ARM_MASTER_PLAN.md @@ -0,0 +1,2736 @@ +# Bucky Arm — EMG Gesture Control: Master Implementation Reference +> Version: 2026-03-01 | Target: ESP32-S3 N32R16V (Xtensa LX7 @ 240 MHz, 512 KB SRAM, 16 MB OPI PSRAM) +> Supersedes: META_EMG_RESEARCH_NOTES.md + BUCKY_ARM_IMPROVEMENT_PLAN.md +> Source paper: doi:10.1038/s41586-025-09255-w (PDF: C:/VSCode/Marvel_Projects/s41586-025-09255-w.pdf) + +--- + +## TABLE OF CONTENTS + +- [PART 0 — SYSTEM ARCHITECTURE & RESPONSIBILITY ASSIGNMENT](#part-0--system-architecture--responsibility-assignment) + - [0.1 Who Does What](#01-who-does-what) + - [0.2 Operating Modes](#02-operating-modes) + - [0.3 FSM Reference (EMG_MAIN mode)](#03-fsm-reference-emg_main-mode) + - [0.4 EMG_STANDALONE Boot Sequence](#04-emg_standalone-boot-sequence) + - [0.5 New Firmware Changes for Architecture](#05-new-firmware-changes-for-architecture) + - [0.6 New Python Script: live_predict.py](#06-new-python-script-live_predictpy) + - [0.7 Firmware Cleanup: system_mode_t Removal](#07-firmware-cleanup-system_mode_t-removal) +- [PART I — SYSTEM FOUNDATIONS](#part-i--system-foundations) + - [1. Hardware Specification](#1-hardware-specification) + - [2. Current System Snapshot](#2-current-system-snapshot) + - [2.1 Confirmed Firmware Architecture](#21--confirmed-firmware-architecture-from-codebase-exploration) + - [2.2 Bicep Channel Subsystem](#22--bicep-channel-subsystem-ch3--adc_channel_9--gpio-10) + - [3. What Meta Built — Filtered for ESP32](#3-what-meta-built--filtered-for-esp32) + - [4. Current Code State + Known Bugs](#4-current-code-state--known-bugs) +- [PART II — TARGET ARCHITECTURE](#part-ii--target-architecture) + - [5. Full Recommended Multi-Model Stack](#5-full-recommended-multi-model-stack) + - [6. Compute Budget for Full Stack](#6-compute-budget-for-full-stack) + - [7. 
Why This Architecture Works for 3-Channel EMG](#7-why-this-architecture-works-for-3-channel-emg) +- [PART III — GESTURE EXTENSIBILITY](#part-iii--gesture-extensibility) + - [8. What Changes When Adding or Removing a Gesture](#8-what-changes-when-adding-or-removing-a-gesture) + - [9. Practical Limits of 3-Channel EMG](#9-practical-limits-of-3-channel-emg) + - [10. Specific Gesture Considerations](#10-specific-gesture-considerations) +- [PART IV — CHANGE REFERENCE](#part-iv--change-reference) + - [11. Change Classification Matrix](#11-change-classification-matrix) +- [PART V — FIRMWARE CHANGES](#part-v--firmware-changes) + - [Change A — DMA-Driven ADC Sampling](#change-a--dma-driven-adc-sampling) + - [Change B — IIR Biquad Bandpass Filter](#change-b--iir-biquad-bandpass-filter) + - [Change C — Confidence Rejection](#change-c--confidence-rejection) + - [Change D — On-Device NVS Calibration](#change-d--on-device-nvs-calibration) + - [Change E — int8 MLP via TFLM](#change-e--int8-mlp-via-tflm) + - [Change F — Ensemble Inference Pipeline](#change-f--ensemble-inference-pipeline) +- [PART VI — PYTHON/TRAINING CHANGES](#part-vi--pythontraining-changes) + - [Change 0 — Forward Label Shift](#change-0--forward-label-shift) + - [Change 1 — Expanded Feature Set](#change-1--expanded-feature-set) + - [Change 2 — Electrode Repositioning](#change-2--electrode-repositioning) + - [Change 3 — Data Augmentation](#change-3--data-augmentation) + - [Change 4 — Reinhard Compression](#change-4--reinhard-compression) + - [Change 5 — Classifier Benchmark](#change-5--classifier-benchmark) + - [Change 6 — Simplified MPF Features](#change-6--simplified-mpf-features) + - [Change 7 — Ensemble Training](#change-7--ensemble-training) +- [PART VII — FEATURE SELECTION FOR ESP32 PORTING](#part-vii--feature-selection-for-esp32-porting) +- [PART VIII — MEASUREMENT AND VALIDATION](#part-viii--measurement-and-validation) +- [PART IX — EXPORT WORKFLOW](#part-ix--export-workflow) +- [PART X — 
REFERENCES](#part-x--references) + +--- + +# PART 0 — SYSTEM ARCHITECTURE & RESPONSIBILITY ASSIGNMENT + +> This section is the authoritative reference for what runs where. All implementation +> decisions in later parts should be consistent with this partition. + +## 0.1 Who Does What + +| Responsibility | Laptop (Python) | ESP32 | +|----------------|-----------------|-------| +| EMG sensor reading | — | ✓ `emg_sensor_read()` always | +| Raw data streaming (for collection) | Receives CSV, saves to HDF5 | Streams CSV over UART | +| Model training | ✓ `learning_data_collection.py` | — | +| Model export | ✓ `export_to_header()` → `model_weights.h` | Compiled into firmware | +| On-device inference | — | ✓ `inference_predict()` | +| Laptop-side live inference | ✓ `live_predict.py` (new script) | Streams ADC + executes received cmd | +| Arm actuation | — (sends gesture string back to ESP32) | ✓ `gestures_execute()` | +| Autonomous operation (no laptop) | Not needed | ✓ `EMG_STANDALONE` mode | +| Bicep flex detection | — | ✓ `bicep_detect()` (new, Section 2.2) | +| NVS calibration | — | ✓ `calibration.c` (Change D) | + +**Key rule**: The laptop is never required for real-time arm control in production. +The laptop's role is: collect data → train model → export → flash firmware → done. +After that, the ESP32 operates completely independently. + +--- + +## 0.2 Operating Modes + +Controlled by `#define MAIN_MODE` in `config/config.h`. +The enum currently reads `enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER}`. +A new value `EMG_STANDALONE` must be added. + +| `MAIN_MODE` | When to use | Laptop required? 
| Entry point | +|-------------|-------------|-----------------|-------------| +| `EMG_MAIN` | Development sessions, data collection, monitored operation | Yes — UART handshake to start any mode | `appConnector()` in `main.c` | +| `EMG_STANDALONE` | **Fully autonomous deployment** — no laptop | **No** — boots directly into predict+control | `run_standalone_loop()` (new function in `main.c`) | +| `SERVO_CALIBRATOR` | Hardware setup, testing servo range of motion | Yes (serial input) | Inline in `app_main()` | +| `GESTURE_TESTER` | Testing gesture→servo mapping via keyboard | Yes (serial input) | Inline in `app_main()` | + +**How to switch mode**: change `#define MAIN_MODE` in `config.h` and reflash. + +**To add `EMG_STANDALONE` to `config.h`** (1-line change): +```c +// config.h line 19 — current: +enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER}; + +// Update to: +enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER, EMG_STANDALONE}; +``` + +--- + +## 0.3 FSM Reference (EMG_MAIN mode) + +The `device_state_t` enum in `main.c` and the `command_t` enum control all transitions. +Currently: `{STATE_IDLE, STATE_CONNECTED, STATE_STREAMING, STATE_PREDICTING}`. +A new state `STATE_LAPTOP_PREDICT` must be added (see Section 0.5). 
+ +``` +STATE_IDLE + └─ {"cmd":"connect"} ──────────────────────────► STATE_CONNECTED + │ + {"cmd":"start"} ──────────┤ + │ STATE_STREAMING + │ ESP32 sends raw ADC CSV at 1kHz + │ Laptop: saves to HDF5 (data collection) + │ Laptop: trains model → exports model_weights.h + │ ◄──── {"cmd":"stop"} ────────────────────┘ + │ + {"cmd":"start_predict"} ─────────┤ + │ STATE_PREDICTING + │ ESP32: inference_predict() on-device + │ ESP32: gestures_execute() + │ Laptop: optional UART monitor only + │ ◄──── {"cmd":"stop"} ────────────────────┘ + │ + {"cmd":"start_laptop_predict"} ───────┘ + STATE_LAPTOP_PREDICT [NEW] + ESP32: streams raw ADC CSV (same as STREAMING) + Laptop: runs live_predict.py inference + Laptop: sends {"gesture":"fist"} back + ESP32: executes received gesture command + ◄──── {"cmd":"stop"} ────────────────────┘ + +All active states: + {"cmd":"stop"} → STATE_CONNECTED + {"cmd":"disconnect"} → STATE_IDLE + {"cmd":"connect"} → STATE_CONNECTED (from any state — reconnect) +``` + +**Convenience table of commands and their effects:** + +| JSON command | Valid from state | Result | +|---|---|---| +| `{"cmd":"connect"}` | Any | → `STATE_CONNECTED` | +| `{"cmd":"start"}` | `STATE_CONNECTED` | → `STATE_STREAMING` | +| `{"cmd":"start_predict"}` | `STATE_CONNECTED` | → `STATE_PREDICTING` | +| `{"cmd":"start_laptop_predict"}` | `STATE_CONNECTED` | → `STATE_LAPTOP_PREDICT` (new) | +| `{"cmd":"stop"}` | `STREAMING/PREDICTING/LAPTOP_PREDICT` | → `STATE_CONNECTED` | +| `{"cmd":"disconnect"}` | Any active state | → `STATE_IDLE` | + +--- + +## 0.4 EMG_STANDALONE Boot Sequence + +No UART handshake. No laptop required. Powers on → predicts → controls arm. 
+ +``` +app_main() switch MAIN_MODE == EMG_STANDALONE: + │ + ├── hand_init() // servos + ├── emg_sensor_init() // ADC setup + ├── inference_init() // clear window buffer, reset smoothing state + ├── calibration_init() // load NVS z-score params (Change D) + │ └── if not found in NVS: + │ collect 120 REST windows (~3s at 25ms hop) + │ call calibration_update() to compute and store stats + ├── bicep_load_threshold() // load NVS bicep threshold (Section 2.2) + │ └── if not found: + │ collect 3s of still bicep data + │ call bicep_calibrate() and bicep_save_threshold() + │ + └── run_standalone_loop() ← NEW function (added to main.c) + while (1): + emg_sensor_read(&sample) + inference_add_sample(sample.channels) + if stride_counter++ >= INFERENCE_HOP_SIZE: + stride_counter = 0 + gesture_t g = inference_get_gesture_enum(inference_predict(&conf)) + gestures_execute(g) + bicep_state_t b = bicep_detect() + // (future: bicep_actuate(b)) + vTaskDelay(1) +``` + +`run_standalone_loop()` is structurally identical to `run_inference_loop()` in `EMG_MAIN`, +minus all UART state-change checking and telemetry prints. It runs forever until power-off. + +**Where to add**: New function `run_standalone_loop()` in `app/main.c`, plus a new case +in the `app_main()` switch block: +```c +case EMG_STANDALONE: + run_standalone_loop(); + break; +``` + +--- + +## 0.5 New Firmware Changes for Architecture + +These changes are needed to implement the architecture above. They are **structural** +(not accuracy improvements) and should be done before any other changes. 
+ +### S1 — Add `EMG_STANDALONE` to `config.h` + +**File**: `EMG_Arm/src/config/config.h`, line 19 +```c +// Change: +enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER}; +// To: +enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER, EMG_STANDALONE}; +``` + +### S2 — Add `STATE_LAPTOP_PREDICT` to FSM (`main.c`) + +**File**: `EMG_Arm/src/app/main.c` + +```c +// In device_state_t enum — add new state: +typedef enum { + STATE_IDLE = 0, + STATE_CONNECTED, + STATE_STREAMING, + STATE_PREDICTING, + STATE_LAPTOP_PREDICT, // ← ADD: streams ADC to laptop, executes laptop's gesture commands +} device_state_t; + +// In command_t enum — add new command: +typedef enum { + CMD_NONE = 0, + CMD_CONNECT, + CMD_START, + CMD_START_PREDICT, + CMD_START_LAPTOP_PREDICT, // ← ADD + CMD_STOP, + CMD_DISCONNECT, +} command_t; +``` + +**In `parse_command()`** — add detection (place BEFORE the `"start"` check to avoid prefix collision): +```c +} else if (strncmp(value_start, "start_laptop_predict", 20) == 0) { + return CMD_START_LAPTOP_PREDICT; +} else if (strncmp(value_start, "start_predict", 13) == 0) { + return CMD_START_PREDICT; +} else if (strncmp(value_start, "start", 5) == 0) { + return CMD_START; +``` + +**In `serial_input_task()` FSM switch** — add to `STATE_CONNECTED` block: +```c +} else if (cmd == CMD_START_LAPTOP_PREDICT) { + g_device_state = STATE_LAPTOP_PREDICT; + printf("[STATE] CONNECTED -> LAPTOP_PREDICT\n"); + xQueueSend(g_cmd_queue, &cmd, 0); +} +``` + +**Add to the active-state check** in `serial_input_task()`: +```c +case STATE_STREAMING: +case STATE_PREDICTING: +case STATE_LAPTOP_PREDICT: // ← ADD to the case list + if (cmd == CMD_STOP) { ... } +``` + +**New function `run_laptop_predict_loop()`** (add alongside `stream_emg_data()` and `run_inference_loop()`): +```c +/** + * @brief Laptop-mediated prediction loop (STATE_LAPTOP_PREDICT). + * + * Streams raw ADC CSV to laptop for inference. + * Simultaneously reads gesture commands sent back by laptop. 
+ * Executes received gesture immediately. + * + * Laptop sends: {"gesture":"fist"}\n OR {"gesture":"rest"}\n etc. + * ESP32 parses the "gesture" field and calls inference_get_gesture_enum() + gestures_execute(). + */ +static void run_laptop_predict_loop(void) { + emg_sample_t sample; + char cmd_buf[64]; + int cmd_idx = 0; + + printf("{\"status\":\"info\",\"msg\":\"Laptop-predict mode started\"}\n"); + + while (g_device_state == STATE_LAPTOP_PREDICT) { + // 1. Send raw ADC sample (same format as STATE_STREAMING) + emg_sensor_read(&sample); + printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1], + sample.channels[2], sample.channels[3]); + + // 2. Non-blocking read of any incoming gesture command from laptop + // (serial_input_task already handles FSM commands; this handles gesture commands) + // Note: getchar() is non-blocking when there is no data (returns EOF). + // Gesture messages from laptop look like: {"gesture":"fist"}\n + int c = getchar(); + if (c != EOF && c != 0xFF) { + if (c == '\n' || c == '\r') { + if (cmd_idx > 0) { + cmd_buf[cmd_idx] = '\0'; + // Parse {"gesture":""} — look for "gesture" field + const char *g = strstr(cmd_buf, "\"gesture\""); + if (g) { + const char *v = strchr(g, ':'); + if (v) { + v++; + while (*v == ' ' || *v == '"') v++; + // Extract gesture name up to closing quote + char name[32] = {0}; + int ni = 0; + while (*v && *v != '"' && ni < 31) name[ni++] = *v++; + name[ni] = '\0'; + // Map name to enum and execute (reuse inference mapping) + gesture_t gesture = (gesture_t)inference_get_gesture_enum_by_name(name); + if (gesture != GESTURE_NONE) { + gestures_execute(gesture); + } + } + } + cmd_idx = 0; + } + } else if (cmd_idx < (int)sizeof(cmd_buf) - 1) { + cmd_buf[cmd_idx++] = (char)c; + } else { + cmd_idx = 0; + } + } + + vTaskDelay(1); + } +} +``` + +**Note**: `inference_get_gesture_enum_by_name(const char *name)` is just the existing +`inference_get_gesture_enum(int class_idx)` refactored to accept a string directly 
+(bypassing the class_idx lookup). Alternatively, keep the existing function and add a +simple wrapper — the string matching logic already exists in `inference.c`: +```c +// Simpler: reuse the existing strcmp chain in inference_get_gesture_enum() +// by passing the name through a helper that returns the gesture_t directly. +// Add to inference.c / inference.h: +gesture_t inference_get_gesture_by_name(const char *name); +// (same strcmp logic as inference_get_gesture_enum, but returns gesture_t directly) +``` + +**In `state_machine_loop()`** — add the new state: +```c +static void state_machine_loop(void) { + command_t cmd; + const TickType_t poll_interval = pdMS_TO_TICKS(50); + while (1) { + if (g_device_state == STATE_STREAMING) stream_emg_data(); + else if (g_device_state == STATE_PREDICTING) run_inference_loop(); + else if (g_device_state == STATE_LAPTOP_PREDICT) run_laptop_predict_loop(); // ← ADD + xQueueReceive(g_cmd_queue, &cmd, poll_interval); + } +} +``` + +**In `app_main()` switch** — add the standalone case: +```c +case EMG_STANDALONE: + run_standalone_loop(); // new function — see Section 0.4 + break; +``` + +--- + +## 0.6 New Python Script: `live_predict.py` + +**Location**: `C:/VSCode/Marvel_Projects/Bucky_Arm/live_predict.py` (new file) +**Purpose**: Laptop-side live inference. Reads raw ADC stream from ESP32, runs the Python +classifier, sends gesture commands back to ESP32 for arm control. +**When to use**: `EMG_MAIN` + `STATE_LAPTOP_PREDICT` — useful for debugging and comparing +laptop accuracy vs on-device accuracy before flashing a new model. + +```python +""" +live_predict.py — Laptop-side live EMG inference for Bucky Arm. + +Connects to ESP32, requests STATE_LAPTOP_PREDICT, reads raw ADC CSV, +runs the trained Python classifier, sends gesture commands back to ESP32. 
+ +Usage: + python live_predict.py --port COM3 --model path/to/saved_model/ +""" +import argparse +import time +import numpy as np +import serial +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from learning_data_collection import ( + EMGClassifier, EMGFeatureExtractor, SessionStorage, HAND_CHANNELS, + WINDOW_SIZE_SAMPLES, HOP_SIZE_SAMPLES, NUM_CHANNELS, +) + +BAUD_RATE = 921600 +CALIB_SEC = 3.0 # seconds of REST to collect for normalization at startup +CALIB_LABEL = "rest" # label used during calibration window + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--port", required=True, help="Serial port, e.g. COM3 or /dev/ttyUSB0") + p.add_argument("--model", required=True, help="Path to saved EMGClassifier model directory") + return p.parse_args() + +def handshake(ser): + """Send connect command, wait for ack.""" + ser.write(b'{"cmd":"connect"}\n') + deadline = time.time() + 5.0 + while time.time() < deadline: + line = ser.readline().decode("utf-8", errors="ignore").strip() + if "ack_connect" in line: + print(f"[Handshake] Connected: {line}") + return True + raise RuntimeError("No ack_connect received within 5s") + +def collect_calibration_windows(ser, n_windows, window_size, hop_size, n_channels): + """Collect n_windows worth of REST data for normalization calibration.""" + print(f"[Calib] Collecting {n_windows} REST windows — hold arm still...") + raw_buffer = np.zeros((window_size, n_channels), dtype=np.float32) + windows = [] + sample_count = 0 + while len(windows) < n_windows: + line = ser.readline().decode("utf-8", errors="ignore").strip() + try: + vals = [float(v) for v in line.split(",")] + if len(vals) != n_channels: + continue + except ValueError: + continue + raw_buffer = np.roll(raw_buffer, -1, axis=0) + raw_buffer[-1] = vals + sample_count += 1 + if sample_count >= window_size and sample_count % hop_size == 0: + windows.append(raw_buffer.copy()) + print(f"[Calib] Collected {len(windows)} 
windows. Computing normalization stats...") + return np.array(windows) # (n_windows, window_size, n_channels) + +def main(): + args = parse_args() + + # Load trained classifier + print(f"[Init] Loading classifier from {args.model}...") + classifier = EMGClassifier() + classifier.load(Path(args.model)) + extractor = classifier.feature_extractor + + ser = serial.Serial(args.port, BAUD_RATE, timeout=1.0) + time.sleep(0.5) + ser.reset_input_buffer() + + handshake(ser) + + # Request laptop-predict mode + ser.write(b'{"cmd":"start_laptop_predict"}\n') + print("[Control] Entered STATE_LAPTOP_PREDICT") + + # Calibration: collect 3s of REST for session normalization + n_calib_windows = max(10, int(CALIB_SEC * 1000 / (HOP_SIZE_SAMPLES))) + calib_raw = collect_calibration_windows( + ser, n_calib_windows, WINDOW_SIZE_SAMPLES, HOP_SIZE_SAMPLES, NUM_CHANNELS + ) + calib_features = extractor.extract_features_batch(calib_raw) + calib_mean = calib_features.mean(axis=0) + calib_std = np.where(calib_features.std(axis=0) > 1e-6, + calib_features.std(axis=0), 1e-6) + print("[Calib] Done. 
Starting live prediction...") + + # Live prediction loop + raw_buffer = np.zeros((WINDOW_SIZE_SAMPLES, NUM_CHANNELS), dtype=np.float32) + sample_count = 0 + last_gesture = None + + try: + while True: + line = ser.readline().decode("utf-8", errors="ignore").strip() + + # Skip JSON telemetry lines from ESP32 + if line.startswith("{"): + continue + + try: + vals = [float(v) for v in line.split(",")] + if len(vals) != NUM_CHANNELS: + continue + except ValueError: + continue + + # Slide window + raw_buffer = np.roll(raw_buffer, -1, axis=0) + raw_buffer[-1] = vals + sample_count += 1 + + if sample_count >= WINDOW_SIZE_SAMPLES and sample_count % HOP_SIZE_SAMPLES == 0: + # Extract features and normalize with session stats + feat = extractor.extract_features_window(raw_buffer) + feat = (feat - calib_mean) / calib_std + + proba = classifier.model.predict_proba([feat])[0] + class_idx = int(np.argmax(proba)) + gesture_name = classifier.label_names[class_idx] + confidence = float(proba[class_idx]) + + # Send gesture command to ESP32 + cmd = f'{{"gesture":"{gesture_name}"}}\n' + ser.write(cmd.encode("utf-8")) + + if gesture_name != last_gesture: + print(f"[Predict] {gesture_name:12s} conf={confidence:.2f}") + last_gesture = gesture_name + + except KeyboardInterrupt: + print("\n[Stop] Sending stop command...") + ser.write(b'{"cmd":"stop"}\n') + ser.close() + +if __name__ == "__main__": + main() +``` + +**Dependencies** (add to a `requirements.txt` in `Bucky_Arm/` if not already there): +``` +pyserial +numpy +scikit-learn +``` + +--- + +## 0.7 Firmware Cleanup: `system_mode_t` Removal + +`config.h` lines 94–100 define a `system_mode_t` typedef that is **not referenced anywhere** +in the firmware. It predates the current `device_state_t` FSM in `main.c` and conflicts +conceptually with it. Remove before starting implementation work. + +**File**: `EMG_Arm/src/config/config.h` +**Remove** (lines 93–100): +```c +/** + * @brief System operating modes. 
+ */ +typedef enum { + MODE_IDLE = 0, /**< Waiting for commands */ + MODE_DATA_STREAM, /**< Streaming EMG data to laptop */ + MODE_COMMAND, /**< Executing gesture commands from laptop */ + MODE_DEMO, /**< Running demo sequence */ + MODE_COUNT +} system_mode_t; +``` +No other file references `system_mode_t` — the deletion is safe and requires no other changes. + +--- + +# PART I — SYSTEM FOUNDATIONS + +## 1. Hardware Specification + +### ESP32-S3 N32R16V — Confirmed Hardware + +| Resource | Spec | Implication | +|----------|------|-------------| +| CPU | Dual-core Xtensa LX7 @ 240 MHz | Pin inference to Core 1, sampling to Core 0 | +| SIMD | PIE 128-bit vector extension | esp-dsp exploits this for FFT, biquad, dot-product | +| Internal SRAM | ~512 KB | All hot-path buffers, model weights, inference state | +| OPI PSRAM | 16 MB (~80 MB/s) | ADC ring buffer, raw window storage — not hot path | +| Flash | 32 MB | Code + read-only model flatbuffers (TFLM path) | +| ADC | 2× SAR ADC, 12-bit, continuous DMA mode | Change A: use `adc_continuous` driver | + +**Memory rules**: +- Tag inference code: `IRAM_ATTR` — prevents cache miss stalls +- Tag large ring buffers: `EXT_RAM_BSS_ATTR` — pushes to PSRAM automatically +- Never run hot-path loops from PSRAM (latency varies; ~10× slower than SRAM) + +### Espressif Acceleration Libraries + +| Library | Accelerates | Key Functions | +|---------|-------------|---------------| +| **esp-dsp** | IIR biquad, FFT (up to 4096-pt), vector dot-product, matrix ops — PIE SIMD | `dsps_biquad_f32`, `dsps_fft2r_fc32`, `dsps_dotprod_f32` | +| **esp-nn** | int8 FC, depthwise/pointwise Conv, activations — SIMD optimized | Used internally by esp-dl | +| **esp-dl** | High-level int8 inference: MLP, Conv1D, LSTM; activation buffer management | Small MLP / tiny CNN deployment | +| **TFLite Micro** | Standard int8 flatbuffer inference, tensor arena (static alloc) | Keras → TFLite → int8 workflow | + +### Real-Time Budget (1000 Hz, 25ms hop) + +| Stage 
| Cost | Notes | +|-------|------|-------| +| ADC DMA sampling | ~0 µs | Hardware; CPU-free | +| IIR biquad (3 ch, 2 stages) | <100 µs | `dsps_biquad_f32` | +| Feature extraction (69 feat) | ~1,200 µs | FFT-based features dominate | +| 3 specialist LDAs | ~150 µs | `dsps_dotprod_f32` per class | +| Meta-LDA (15 inputs) | ~10 µs | 75 MACs total | +| int8 MLP fallback [69→32→16→5] | ~250 µs | esp-nn FC kernels | +| Post-processing | <50 µs | EMA, vote, debounce | +| **Total (full ensemble)** | **~1,760 µs** | **14× margin within 25ms** | + +### Hard No-Gos + +| Technique | Why | +|-----------|-----| +| Full MPF with matrix logarithm | Eigendecomposition per window; fragile float32; no SIMD path | +| Conv1D(16→512) + 3×LSTM(512) | ~4 MB weights; LSTM sequential dependency — impossible | +| Any transformer / attention | O(n²); no int8 transformer kernels for MCU | +| On-device gradient updates | Inference only — no training infrastructure | +| Heap allocations on hot path | FreeRTOS heap fragmentation kills determinism | + +--- + +## 2. 
Current System Snapshot + +| Aspect | Current State | +|--------|--------------| +| Channels | 4 total; ch0–ch2 forearm (FCR, FCU, extensor), ch3 bicep (excluded from hand classifier) | +| Sampling | 1000 Hz, timer/polling (jitter — fix with Change A) | +| Window | 150 samples (150ms), 25-sample hop (25ms) | +| Features | 12: RMS, WL, ZC, SSC × 3 channels | +| Classifier | Single LDA, float32 weights in C header | +| Label alignment | RMS onset detection — missing +100ms forward shift (Change 0) | +| Normalization | Per-session z-score in Python; no on-device equivalent (Change D) | +| Smoothing | EMA (α=0.7) + majority vote (5) + debounce (3 counts) | +| Confidence rejection | None — always outputs a class (Change C) | +| Signal filtering | Analogue only via MyoWare (Change B adds software IIR) | +| Gestures | 5: fist, hook\_em, open, rest, thumbs\_up | +| Training data | 15 HDF5 sessions, 1 user | + +--- + +## 2.1 — Confirmed Firmware Architecture (From Codebase Exploration) + +> Confirmed by direct codebase inspection 2026-02-24. 
All file paths relative to +> `C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/` + +### ADC Pin Mapping (`drivers/emg_sensor.c`) + +| Channel | ADC Channel | GPIO | Muscle Location | Role in Classifier | +|---------|-------------|------|-----------------|-------------------| +| ch0 | `ADC_CHANNEL_1` | GPIO 2 | Forearm Belly (FCR) | Primary flexion signal | +| ch1 | `ADC_CHANNEL_2` | GPIO 3 | Forearm Extensors | Extension signal | +| ch2 | `ADC_CHANNEL_8` | GPIO 9 | Forearm Contractors (FCU) | Ulnar flexion signal | +| ch3 | `ADC_CHANNEL_9` | GPIO 10 | Bicep | Independent — see Section 2.2 | + +**Current ADC driver**: `adc_oneshot` (polling — **NOT DMA continuous yet**; Change A migrates this) +- Attenuation: `ADC_ATTEN_DB_12` (0–3.9V full-scale range) +- Calibration: `adc_cali_curve_fitting` scheme +- Output: calibrated millivolts as `uint16_t` packed into `emg_sample_t.channels[4]` +- Timing: `vTaskDelay(1)` in `run_inference_loop()` provides the ~1ms sample interval + +### Current Task Structure (`app/main.c`) + +| Task | Priority | Stack | Core Pinning | Role | +|------|----------|-------|--------------|------| +| `app_main` (implicit) | Default | Default | None | Runs inference loop + state machine | +| `serial_input_task` | 5 | 4096 B | **None** | Parses UART JSON commands | + +**No other tasks exist.** Change A will add `adc_sampling_task` pinned to Core 0. +The inference loop runs on `app_main`'s default task — no explicit core affinity. + +### State Machine (`app/main.c`) + +``` +STATE_IDLE ─(BLE/UART connect)─► STATE_CONNECTED + │ + {"cmd": "start"}▼ + STATE_STREAMING (sends raw ADC over UART for Python) + │ + {"cmd": "start_predict"}▼ + STATE_PREDICTING (runs run_inference_loop()) +``` +Communication: UART at 921600 baud, JSON framing.
+ +### Complete Data Flow (Exact Function Names) + +``` +emg_sensor_read(&sample) + │ drivers/emg_sensor.c + │ adc_oneshot_read() × 4 channels → adc_cali_raw_to_voltage() → uint16_t mV + │ Result: sample.channels[4] = {ch0_mV, ch1_mV, ch2_mV, ch3_mV} + │ + ▼ Called every ~1ms (vTaskDelay(1) in run_inference_loop) +inference_add_sample(sample.channels) + │ core/inference.c + │ Writes to circular window_buffer[150][4] + │ Returns true when buffer is full (after first 150 samples) + │ + ▼ Called every 25 samples (stride_counter % INFERENCE_HOP_SIZE == 0) +inference_predict(&confidence) + │ core/inference.c + │ compute_features() → LDA scores → softmax → EMA → majority vote → debounce + │ Returns: gesture class index (int), fills confidence (float) + │ + ▼ +inference_get_gesture_enum(class_idx) + │ core/inference.c + │ String match on MODEL_CLASS_NAMES[] → gesture_t enum value + │ + ▼ +gestures_execute(gesture) + core/gestures.c + switch(gesture) → servo PWM via LEDC driver + Servo pins: GPIO 1,4,5,6,7 (Thumb, Index, Middle, Ring, Pinky) +``` + +### Current Buffer State + +```c +// core/inference.c line 19: +static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS]; +// ^^^^^^^^ MUST change to float when adding IIR filter (Change B) +// +// uint16_t: 150 × 4 × 2 = 1,200 bytes in internal SRAM +// float: 150 × 4 × 4 = 2,400 bytes in internal SRAM (still trivially small) +// +// Reason for change: IIR filter outputs float; casting back to uint16_t loses +// sub-mV precision and re-introduces the quantization noise we just filtered out. +``` + +### `platformio.ini` Current State (`EMG_Arm/platformio.ini`) + +**Current `lib_deps`**: **None** — completely empty, no external library dependencies. 
Required additions per change tier: + +| Change | Library | `platformio.ini` `lib_deps` entry | +|--------|---------|----------------------------------| +| B (IIR biquad) | esp-dsp | `espressif/esp-dsp @ ^2.0.0` | +| 1 (FFT features) | esp-dsp | (same — add once for both B and 1) | +| E (int8 MLP) | TFLite Micro | `tensorflow/tflite-micro` | +| F (ensemble) | esp-dsp | (same as B) | + +Add to `platformio.ini` under `[env:esp32-s3-devkitc1-n16r16]`: +```ini +lib_deps = + espressif/esp-dsp @ ^2.0.0 + ; tensorflow/tflite-micro ← add this only when implementing Change E +``` + +--- + +## 2.2 — Bicep Channel Subsystem (ch3 / ADC_CHANNEL_9 / GPIO 10) + +### Current Status + +The bicep channel is: +- **Sampled**: `emg_sensor_read()` reads all 4 channels; `sample.channels[3]` holds bicep data +- **Excluded from hand classifier**: `HAND_NUM_CHANNELS = 3`; `compute_features()` explicitly + loops `ch = 0` to `ch < HAND_NUM_CHANNELS` (i.e., ch0, ch1, ch2 only) +- **Not yet independently processed**: the comment in `inference.c` line 68 + (`"ch3 (bicep) is excluded — it will be processed independently"`) is aspirational — + the independent processing is not yet implemented + +### Phase 1 — Binary Flex/Unflex (Current Target) + +Implement a simple RMS threshold detector as a new subsystem: + +**New files:** +``` +EMG_Arm/src/core/bicep.h +EMG_Arm/src/core/bicep.c +``` + +**bicep.h:** +```c +#pragma once +#include <stdint.h> +#include <stdbool.h> + +typedef enum { + BICEP_STATE_REST = 0, + BICEP_STATE_FLEX = 1, +} bicep_state_t; + +// Call once at session start with ~3s of relaxed bicep data. +// Returns the computed threshold (also stored internally). +float bicep_calibrate(const uint16_t *ch3_samples, int n_samples); + +// Call every 25ms (same hop as hand gesture inference). +// Computes RMS on the last BICEP_WINDOW_SAMPLES from the ch3 circular buffer.
bicep_state_t bicep_detect(void); + +// Load/save threshold to NVS (reuse calibration.c infrastructure from Change D) +bool bicep_save_threshold(float threshold_mv); +bool bicep_load_threshold(float *threshold_mv_out); +``` + +**Core logic (`bicep.c`):** +```c +#define BICEP_WINDOW_SAMPLES 50 // 50ms window at 1000Hz +#define BICEP_FLEX_MULTIPLIER 2.5f // threshold = rest_rms × 2.5 +#define BICEP_HYSTERESIS 1.3f // prevents rapid toggling at threshold boundary + +static float s_threshold_mv = 0.0f; +static bicep_state_t s_state = BICEP_STATE_REST; + +float bicep_calibrate(const uint16_t *ch3_samples, int n_samples) { + float rms_sq = 0.0f; + for (int i = 0; i < n_samples; i++) + rms_sq += (float)ch3_samples[i] * ch3_samples[i]; + float rest_rms = sqrtf(rms_sq / n_samples); + s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER; + printf("[Bicep] Calibrated: rest_rms=%.1f mV, threshold=%.1f mV\n", + rest_rms, s_threshold_mv); + return s_threshold_mv; +} + +bicep_state_t bicep_detect(void) { + // Compute RMS on last BICEP_WINDOW_SAMPLES from ch3 circular buffer + // (ch3 values are stored in window_buffer[][3] alongside hand channels) + // NOTE(review): if buffer_head is the next-write index, starting here reads the + // OLDEST samples in the window — for the most recent 50, start at + // (buffer_head - BICEP_WINDOW_SAMPLES + INFERENCE_WINDOW_SIZE) % INFERENCE_WINDOW_SIZE. + float rms_sq = 0.0f; + int idx = buffer_head; + for (int i = 0; i < BICEP_WINDOW_SAMPLES; i++) { + float v = (float)window_buffer[idx][3]; // ch3 = bicep + rms_sq += v * v; + idx = (idx + 1) % INFERENCE_WINDOW_SIZE; + } + float rms = sqrtf(rms_sq / BICEP_WINDOW_SAMPLES); + + // Hysteresis: enter FLEX only above threshold × BICEP_HYSTERESIS (1.3×); + // exit only below 1.0× threshold — prevents rapid toggling at the boundary + if (s_state == BICEP_STATE_REST && rms > s_threshold_mv * BICEP_HYSTERESIS) + s_state = BICEP_STATE_FLEX; + else if (s_state == BICEP_STATE_FLEX && rms < s_threshold_mv) + s_state = BICEP_STATE_REST; + + return s_state; +} +``` + +**Integration in `main.c` `run_inference_loop()`:** +```c +// Call alongside inference_predict() every 25ms: +if (stride_counter % INFERENCE_HOP_SIZE == 0) { + float confidence; + int class_idx = inference_predict(&confidence); + gesture_t gesture =
inference_get_gesture_enum(class_idx); + bicep_state_t bicep = bicep_detect(); + + // Combined actuation: hand gesture + bicep state + // Example: bicep flex can enable/disable certain gestures, + // or control a separate elbow/wrist joint. + gestures_execute(gesture); + // bicep_actuate(bicep); ← add when elbow motor is wired +} +``` + +**Calibration trigger (add to serial_input_task command parsing):** +```c +// {"cmd": "calibrate_bicep"} → collect 3s of rest data, call bicep_calibrate() +``` + +### Phase 2 — Continuous Angle/Velocity Prediction (Future) + +When ready to move beyond binary flex/unflex: + +1. **Collect angle-labeled data**: hold arm at 0°, 15°, 30°, 45°, 60°, 75°, 90°; + log RMS at each; collect 5+ reps per angle. +2. **Fit polynomial**: `angle = a0 + a1*rms + a2*rms²` (degree-2 usually sufficient); + use `numpy.polyfit(rms_values, angles, deg=2)`. +3. **Store coefficients in NVS**: 3 floats via `nvs_set_blob()`. +4. **On-device evaluation**: `angle = a0 + rms*(a1 + rms*a2)` — 2 MACs per inference. +5. **Velocity**: `velocity = (angle_now - angle_prev) / HOP_MS` with low-pass smoothing. + +### Including ch3 in Hand Gesture Classifier (for Wrist Rotation) + +If/when wrist rotation or supination gestures are added: +```python +# learning_data_collection.py — change this constant: +HAND_CHANNELS = [0, 1, 2, 3] # was [0, 1, 2]; include bicep for rotation gestures +``` +Feature count becomes: 4 channels × 20 per-ch + 10 cross-ch covariances + 6 correlations = **96 total**. +The bicep subsystem is then retired and ch3 becomes part of the main gesture classifier. + +--- + +## 3. What Meta Built — Filtered for ESP32 + +Meta's Nature 2025 paper (doi:10.1038/s41586-025-09255-w) describes a 16-channel wristband +running Conv1D(16→512)+3×LSTM(512). **That exact model is not portable to ESP32-S3** (~4 MB +weights). 
What IS transferable: + +| Meta Technique | Transferability | Where Used | +|----------------|-----------------|-----------| +| +100ms forward label shift after onset detection | ✓ Direct copy | Change 0 | +| Frequency features > amplitude features (Extended Data Fig. 6) | ✓ Core insight | Change 1, Change 6 | +| Deliberate electrode repositioning between sessions | ✓ Protocol | Change 2 | +| Window jitter + amplitude augmentation | ✓ Training | Change 3 | +| Reinhard compression `64x/(32+|x|)` | ✓ Optional flag | Change 4 | +| EMA α=0.7, threshold=0.35, debounce=50ms | ✓ Already implemented | Change C | +| Specialist features → meta-learner stacking | ✓ Adapted | Change 7 + F | +| Conv1D+LSTM architecture | ✗ Too large | Not implementable | +| Full MPF with matrix logarithm | ✗ Eigendecomp too costly | Not implementable | + +--- + +## 4. Current Code State + Known Bugs + +**All Python changes**: `C:/VSCode/Marvel_Projects/Bucky_Arm/learning_data_collection.py` +**Firmware**: `C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/inference.c` +**Config**: `C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/config/config.h` +**Weights**: `C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/model_weights.h` + +### Key Symbol Locations + +| Symbol | Line | Notes | +|--------|------|-------| +| Constants block | 49–94 | `NUM_CHANNELS`, `SAMPLING_RATE_HZ`, `WINDOW_SIZE_MS`, etc. 
| +| `align_labels_with_onset()` | 442 | RMS onset detection | +| `filter_transition_windows()` | 529 | Removes onset/offset ambiguity windows | +| `SessionStorage.save_session()` | 643 | Calls onset alignment, saves HDF5 | +| `SessionStorage.load_all_for_training()` | 871 | Returns 6 values (see bug below) | +| `EMGFeatureExtractor` class | 1404 | Current: RMS, WL, ZC, SSC only | +| `extract_features_single_channel()` | 1448 | Per-channel feature dict | +| `extract_features_window()` | 1482 | Flat array + cross-channel | +| `extract_features_batch()` | 1520 | Batch wrapper | +| `get_feature_names()` | 1545 | String names for features | +| `CalibrationTransform` class | 1562 | z-score at Python-side inference | +| `EMGClassifier` class | 1713 | LDA/QDA wrapper | +| `EMGClassifier.__init__()` | 1722 | Creates `EMGFeatureExtractor` | +| `EMGClassifier.train()` | 1735 | Feature extraction + model fit | +| `EMGClassifier._apply_session_normalization()` | 1774 | Per-session z-score | +| `EMGClassifier.cross_validate()` | 1822 | GroupKFold, trial-level | +| `EMGClassifier.export_to_header()` | 1956 | Writes `model_weights.h` | +| `EMGClassifier.save()` | 1910 | Persists model params | +| `EMGClassifier.load()` | 2089 | Reconstructs from saved params | +| `run_training_demo()` | 2333 | Main training entry point | +| `inference.c` `compute_features()` | 68 | C feature extraction | +| `inference.c` `inference_predict()` | 158 | C LDA + smoothing pipeline | + +### Pending Cleanups (Do Before Any Other Code Changes) + +| Item | File | Action | +|------|------|--------| +| Remove `system_mode_t` | `config/config.h` lines 93–100 | Delete the unused typedef (see Part 0, Section 0.7) | +| Add `EMG_STANDALONE` to enum | `config/config.h` line 19 | Add value to the existing MAIN_MODE enum | +| Add `STATE_LAPTOP_PREDICT` + `CMD_START_LAPTOP_PREDICT` | `app/main.c` | See Part 0, Section 0.5 for exact diffs | +| Add `run_standalone_loop()` | `app/main.c` | New function — see Part 0, 
Section 0.4 | +| Add `run_laptop_predict_loop()` | `app/main.c` | New function — see Part 0, Section 0.5 | +| Add `inference_get_gesture_by_name()` | `core/inference.c` + `core/inference.h` | Small helper — extracts existing strcmp logic | + +### Known Bug — Line 2382 + +```python +# BUG: load_all_for_training() returns 6 values; this call unpacks only 5. +# session_indices_combined is silently dropped — breaks per-session normalization. +X, y, trial_ids, label_names, loaded_sessions = storage.load_all_for_training() + +# FIX (apply with Change 1): +X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() +``` + +### Current `model_weights.h` State (as of 2026-02-14 training run) + +| Constant | Value | Note | +|----------|-------|------| +| `MODEL_NUM_CLASSES` | 5 | fist, hook_em, open, rest, thumbs_up | +| `MODEL_NUM_FEATURES` | 12 | RMS, WL, ZC, SSC × 3 forearm channels | +| `MODEL_CLASS_NAMES` | `{"fist","hook_em","open","rest","thumbs_up"}` | Alphabetical order | +| `MODEL_NORMALIZE_FEATURES` | *not defined yet* | Add when enabling cross-ch norm (Change B) | +| `MODEL_USE_REINHARD` | *not defined yet* | Add when enabling Reinhard compression (Change 4) | +| `FEAT_ZC_THRESH` | `0.1f` | Fraction of RMS for zero-crossing threshold | +| `FEAT_SSC_THRESH` | `0.1f` | Fraction of RMS for slope sign change threshold | + +The LDA_WEIGHTS and LDA_INTERCEPTS arrays are current trained values — do not modify manually. +They are regenerated by `EMGClassifier.export_to_header()` after each training run. 
+ +### Current Feature Vector (12 features — firmware contract) + +``` +ch0: [0]=rms [1]=wl [2]=zc [3]=ssc +ch1: [4]=rms [5]=wl [6]=zc [7]=ssc +ch2: [8]=rms [9]=wl [10]=zc [11]=ssc +``` + +### Target Feature Vector (69 features after Change 1) + +``` +Per channel (×3 channels, 20 features each): + [0] rms [1] wl [2] zc [3] ssc [4] mav [5] var + [6] iemg [7] wamp [8] ar1 [9] ar2 [10] ar3 [11] ar4 + [12] mnf [13] mdf [14] pkf [15] mnp [16] bp0 [17] bp1 + [18] bp2 [19] bp3 + +ch0: indices 0–19 +ch1: indices 20–39 +ch2: indices 40–59 + +Cross-channel (9 features): + [60] cov_ch0_ch0 [61] cov_ch0_ch1 [62] cov_ch0_ch2 + [63] cov_ch1_ch1 [64] cov_ch1_ch2 [65] cov_ch2_ch2 + [66] cor_ch0_ch1 [67] cor_ch0_ch2 [68] cor_ch1_ch2 +``` + +### Specialist Feature Subset Indices (for Change F + Change 7) + +``` +TD (time-domain, 36 feat): indices [0–11, 20–31, 40–51] +FD (frequency-domain, 24 feat): indices [12–19, 32–39, 52–59] +CC (cross-channel, 9 feat): indices [60–68] +``` + +--- + +# PART II — TARGET ARCHITECTURE + +## 5. Full Recommended Multi-Model Stack + +``` +ADC (DMA, Change A) + └── IIR Biquad filter per channel (Change B) + └── 150-sample circular window buffer + │ + ▼ [every 25ms] + compute_features() → 69-feature vector + │ + ▼ + calibration_apply() (Change D — NVS z-score) + │ + ├─── Stage 1: Activity Gate ──────────────────────────────────┐ + │ total_rms < REST_THRESHOLD? → return GESTURE_REST │ + │ (skips all inference during obvious idle) │ + │ │ + ▼ (only reached when gesture is active) │ + Stage 2: Parallel Specialist LDAs (Change F) │ + ├── LDA_TD [TD features, 36-dim] → prob_td[5] │ + ├── LDA_FD [FD features, 24-dim] → prob_fd[5] │ + └── LDA_CC [CC features, 9-dim] → prob_cc[5] │ + │ + ▼ │ + Stage 3: Meta-LDA stacker (Change F) │ + input: [prob_td | prob_fd | prob_cc] (15-dim) │ + output: meta_probs[5] │ + │ + ▼ │ + EMA smoothing (α=0.7) on meta_probs │ + │ │ + ├── max smoothed prob ≥ 0.50? 
────── Yes ──────────────────┐ │ + │ │ │ + └── No: Stage 4 Confidence Cascade (Change E) │ │ + run int8 MLP on full 69-feat vector │ │ + use higher-confidence winner │ │ + │ │ │ + └────────────────────────────────────────────►│ │ + │ │ + ◄────────────────────────────────────────────────────────── │ │ + │ ◄─┘ + ▼ + Stage 5: Confidence rejection (Change C) + max_prob < 0.40? → return current_output (hold / GESTURE_NONE) + │ + ▼ + Majority vote (window=5) + Debounce (count=3) + │ + ▼ + final gesture → actuation +``` + +### Model Weight Footprint + +| Model | Input Dim | Weights | Memory (float32) | +|-------|-----------|---------|-----------------| +| LDA_TD | 36 | 5×36 = 180 | 720 B | +| LDA_FD | 24 | 5×24 = 120 | 480 B | +| LDA_CC | 9 | 5×9 = 45 | 180 B | +| Meta-LDA | 15 | 5×15 = 75 | 300 B | +| int8 MLP [69→32→16→5] | 69 | ~2,900 | ~2.9 KB int8 | +| **Total** | | | **~4.6 KB** | + +All model weights fit comfortably in internal SRAM. + +--- + +## 6. Compute Budget for Full Stack + +| Stage | Cost | Cumulative | +|-------|------|-----------| +| Feature extraction (69 feat, 128-pt FFT ×3) | 1,200 µs | 1,200 µs | +| NVS calibration apply | 10 µs | 1,210 µs | +| Activity gate (RMS check) | 5 µs | 1,215 µs | +| LDA_TD (36 feat × 5 classes) | 50 µs | 1,265 µs | +| LDA_FD (24 feat × 5 classes) | 35 µs | 1,300 µs | +| LDA_CC (9 feat × 5 classes) | 15 µs | 1,315 µs | +| Meta-LDA (15 feat × 5 classes) | 10 µs | 1,325 µs | +| EMA + confidence check | 10 µs | 1,335 µs | +| int8 MLP (worst case, ~30% of hops) | 250 µs | 1,585 µs | +| Vote + debounce | 20 µs | 1,605 µs | +| **Worst-case total** | **1,760 µs** | **7% of 25ms budget** | + +--- + +## 7. Why This Architecture Works for 3-Channel EMG + +Three channels means limited spatial information. 
The ensemble compensates by extracting +**maximum diversity from the temporal and spectral dimensions**: + +- **LDA_TD** specializes in muscle activation *intensity and dynamics* (how hard and fast is each muscle firing) +- **LDA_FD** specializes in muscle activation *frequency content* (motor unit recruitment patterns — slow vs. fast twitch fibres fire at different frequencies) +- **LDA_CC** specializes in *inter-muscle coordination* (which muscles co-activate — the spatial "fingerprint" of each gesture) + +These three signal aspects are partially uncorrelated. A gesture that confuses LDA_TD (similar amplitude patterns) may be distinguishable by LDA_FD (different frequency recruitment) or LDA_CC (different co-activation pattern). The meta-LDA learns which specialist to trust for each gesture boundary. + +The int8 MLP fallback handles the residual nonlinear cases: gesture pairs where the decision boundary is curved in feature space, which LDA (linear boundary only) cannot resolve. + +--- + +# PART III — GESTURE EXTENSIBILITY + +## 8. What Changes When Adding or Removing a Gesture + +The system is designed for extensibility. Adding a gesture requires **3 firmware lines and a retrain**. + +### What Changes Automatically (No Manual Code Edits) + +| Component | How it adapts | +|-----------|--------------| +| `MODEL_NUM_CLASSES` in `model_weights.h` | Auto-computed from training data label count | +| LDA weight array dimensions | `[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES]` — regenerated by `export_to_header()` | +| `MODEL_CLASS_NAMES` array | Regenerated by `export_to_header()` | +| All ensemble LDA weight arrays | Regenerated by `export_ensemble_header()` (Change 7) | +| int8 MLP output layer | Retrained with new class count; re-exported to TFLite | +| Meta-LDA input/output dims | `META_NUM_INPUTS = 3 × MODEL_NUM_CLASSES` — auto from Python | + +### What Requires Manual Code Changes + +**Python side** (`learning_data_collection.py`): +```python +# 1. 
Add gesture name to the gesture list (1 line) +# Find where GESTURES or similar list is defined (near constants block ~line 49) +GESTURES = ['fist', 'hook_em', 'open', 'rest', 'thumbs_up', 'wrist_flex'] # example +``` + +**Firmware — `config.h`** (1 line per gesture): +```c +// Add enum value +typedef enum { + GESTURE_NONE = 0, + GESTURE_REST = 1, + GESTURE_FIST = 2, + GESTURE_OPEN = 3, + GESTURE_HOOK_EM = 4, + GESTURE_THUMBS_UP = 5, + GESTURE_WRIST_FLEX = 6, // ← add this line +} gesture_t; +``` + +**Firmware — `inference.c`** `inference_get_gesture_enum()` (2–3 lines per gesture): +```c +if (strcmp(name, "wrist_flex") == 0 || strcmp(name, "WRIST_FLEX") == 0) + return GESTURE_WRIST_FLEX; +``` + +**Firmware — `gestures.c`** (2 changes — these are easy to miss): +```c +// 1. Add to gesture_names[] static array — index MUST match gesture_t enum value: +static const char *gesture_names[GESTURE_COUNT] = { + "NONE", // GESTURE_NONE = 0 + "REST", // GESTURE_REST = 1 + "FIST", // GESTURE_FIST = 2 + "OPEN", // GESTURE_OPEN = 3 + "HOOK_EM", // GESTURE_HOOK_EM = 4 + "THUMBS_UP", // GESTURE_THUMBS_UP = 5 + "WRIST_FLEX", // GESTURE_WRIST_FLEX = 6 ← add here +}; + +// 2. Add case to gestures_execute() switch statement: +case GESTURE_WRIST_FLEX: + gesture_wrist_flex(); // implement the actuation function + break; +``` + +**Critical**: `GESTURE_COUNT` at the end of the `gesture_t` enum in `config.h` is used as the +array size for `gesture_names[]`. It updates automatically when new enum values are added before +it. Both `gesture_names[GESTURE_COUNT]` and the switch statement must be kept in sync with +`GESTURE_COUNT`. Mismatch causes a bounds-overrun or silent misclassification. + +### Complete Workflow for Adding a Gesture + +``` +1. Python: add gesture string to GESTURES list in learning_data_collection.py (1 line) + +2. Data: collect ≥10 sessions × ≥30 reps of new gesture + (follow Change 2 protocol: vary electrode placement between sessions) + +3. 
Train: python learning_data_collection.py → option 3 + OR: python train_ensemble.py (after Change 7 is implemented) + +4. Export: export_to_header() OR export_ensemble_header() + → overwrites model_weights.h / model_weights_ensemble.h with new class count + +5. config.h: add enum value before GESTURE_COUNT (1 line): + GESTURE_WRIST_FLEX = 6, // ← insert before GESTURE_COUNT + GESTURE_COUNT // stays last — auto-counts + +6. inference.c: add string mapping in inference_get_gesture_enum() (2 lines) + +7. gestures.c: add name to gesture_names[] array at correct index (1 line) + +8. gestures.c: add case to gestures_execute() switch statement (3 lines) + +9. Implement actuation function for new gesture (servo angles) + +10. Reflash and validate: pio run -t upload +``` + +**Exact files touched per new gesture (summary):** +| File | What to change | +|------|---------------| +| `learning_data_collection.py` | Add string to GESTURES list | +| `config/config.h` | Add enum value before `GESTURE_COUNT` | +| `core/inference.c` | Add `strcmp` case in `inference_get_gesture_enum()` | +| `core/gestures.c` | Add to `gesture_names[]` array + add switch case | +| `core/gestures.c` | Implement `gesture_()` function with servo angles | +| `core/model_weights.h` | Auto-generated — do not edit manually | + +### Removing a Gesture + +Removing is the same process in reverse, with one additional step: filter the HDF5 training +data to exclude sessions that contain the removed gesture's label. The simplest approach is +to pass a label whitelist to `load_all_for_training()`: + +```python +# Proposed addition to load_all_for_training() — add include_labels parameter +X, y, trial_ids, session_indices, label_names, sessions = \ + storage.load_all_for_training(include_labels=['fist', 'open', 'rest', 'thumbs_up']) + # hook_em removed — existing session files are not modified +``` + +--- + +## 9. 
Practical Limits of 3-Channel EMG + +This is the most important constraint for gesture count: + +| Gesture Count | Expected Accuracy | Notes | +|--------------|-------------------|-------| +| 3–5 gestures | >90% achievable | Current baseline target | +| 6–8 gestures | 80–90% achievable | Requires richer features + ensemble | +| 9–12 gestures | 65–80% achievable | Diminishing returns; some pairs will be confused | +| 13+ gestures | <65% | Surface EMG with 3 channels cannot reliably separate this many | + +**Why 3 channels limits gesture count**: Surface EMG captures the summed electrical activity of +many motor units under each electrode. With only 3 spatial locations, gestures that recruit +overlapping muscle groups (e.g., all finger-flexion gestures recruit FCR) produce similar +signals. The frequency and coordination features from Change 1 help, but there's a hard +information-theoretic limit imposed by channel count. + +**Rule of thumb**: aim for ≤8 gestures with the current 3-channel setup. For more, add the +bicep channel (ch3, currently excluded) to get 4 channels — see Section 10. + +--- + +## 10. Specific Gesture Considerations + +### Wrist Flexion / Extension +- **Feasibility**: High — FCR (ch0) activates strongly for flexion; extensor group (ch2) for extension +- **Differentiation from finger gestures**: frequency content differs (wrist involves slower motor units) +- **Recommendation**: Add these before wrist rotation — more reliable with surface EMG + +### Wrist Rotation (Supination / Pronation) +- **Feasibility**: Medium — the primary supinator is a deep muscle; surface electrodes capture it weakly +- **Key helper**: the bicep activates strongly during supination → **include ch3** (`HAND_CHANNELS = [0, 1, 2, 3]`) +- **Code change for 4 channels**: Python: `HAND_CHANNELS = [0, 1, 2, 3]`; firmware: `HAND_NUM_CHANNELS` auto-updates from the exported header since `MODEL_NUM_FEATURES` is recalculated +- **Caveat**: pronation vs. 
rest may be harder to distinguish than supination vs. rest + +### Pinch / Precision Grasp +- **Feasibility**: Medium — involves intrinsic hand muscles poorly captured by forearm electrodes +- Likely confused with open hand depending on electrode placement +- Collect with careful placement; validate cross-session accuracy before relying on it + +### Including ch3 (Bicep) for Wrist Gestures + +To include the bicep channel in the hand gesture classifier: +```python +# learning_data_collection.py — change this constant +HAND_CHANNELS = [0, 1, 2, 3] # was [0, 1, 2] — add bicep channel +``` +Feature count: 4 channels × 20 per-channel features + 10 cross-channel covariances + 6 correlations = **96 total features**. +The ensemble architecture handles this automatically — specialist LDA weight dimensions +recalculate at training time. + +--- + +# PART IV — CHANGE REFERENCE + +## 11. Change Classification Matrix + +| Change | Category | Priority | Files | ESP32 Reflash? | Retrain? | Risk | +|--------|----------|----------|-------|----------------|----------|------| +| **C** | Firmware | **Tier 1** | inference.c | ✓ | No | **Very Low** | +| **B** | Firmware | **Tier 1** | inference.c / filter.c | ✓ | No | Low | +| **A** | Firmware | **Tier 1** | adc_sampling.c | ✓ | No | Medium | +| **0** | Python | **Tier 1** | learning_data_collection.py | No | ✓ | Low | +| **1** | Python+C | **Tier 2** | learning_data_collection.py + inference.c | ✓ after | ✓ | Medium | +| **D** | Firmware | **Tier 2** | calibration.c/.h | ✓ | No | Medium | +| **2** | Protocol | **Tier 2** | None | No | ✓ new data | None | +| **3** | Python | **Tier 2** | learning_data_collection.py | No | ✓ | Low | +| **E** | Python+FW | **Tier 3** | train_mlp_tflite.py + firmware | ✓ | ✓ | High | +| **4** | Python+C | **Tier 3** | learning_data_collection.py + inference.c | ✓ if enabled | ✓ | Low | +| **5** | Python | **Tier 3** | learning_data_collection.py | No | No | None | +| **6** | Python | **Tier 3** | 
learning_data_collection.py | No | ✓ | Low | +| **7** | Python | **Tier 3** | new: train_ensemble.py | No | ✓ | Medium | +| **F** | Firmware | **Tier 3** | new: inference_ensemble.c | ✓ | No (needs 7 first) | Medium | + +**Recommended implementation order**: C → B → A → 0 → 1 → D → 2 → 3 → 5 (benchmark) → 7+F → E + +--- + +# PART V — FIRMWARE CHANGES + +## Change A — DMA-Driven ADC Sampling (Migration from `adc_oneshot` to `adc_continuous`) + +**Priority**: Tier 1 +**Current driver**: `adc_oneshot_read()` polling in `drivers/emg_sensor.c`. Timing is +controlled by `vTaskDelay(1)` in `run_inference_loop()` — subject to FreeRTOS scheduler +jitter of ±0.5–1ms, which corrupts frequency-domain features and ADC burst grouping. +**Why**: `adc_continuous` runs entirely in hardware DMA. Sample-to-sample jitter drops from +±1ms to <10µs. CPU overhead between samples is zero. Required for frequency features (Change 1). +**Effort**: 2–4 hours (replace `emg_sensor_read()` internals; keep public API the same) + +### ESP-IDF ADC Continuous API + +```c +// --- Initialize (call once at startup) --- +adc_continuous_handle_t adc_handle = NULL; +adc_continuous_handle_cfg_t adc_cfg = { + .max_store_buf_size = 4096, // PSRAM ring buffer size (bytes) + .conv_frame_size = 256, // bytes per conversion frame +}; +adc_continuous_new_handle(&adc_cfg, &adc_handle); + +// Actual hardware channel mapping (from emg_sensor.c): +// ch0 = ADC_CHANNEL_1 / GPIO 2 (Forearm Belly / FCR) +// ch1 = ADC_CHANNEL_2 / GPIO 3 (Forearm Extensors) +// ch2 = ADC_CHANNEL_8 / GPIO 9 (Forearm Contractors / FCU) +// ch3 = ADC_CHANNEL_9 / GPIO 10 (Bicep — independent subsystem) +adc_digi_pattern_config_t chan_cfg[4] = { + {.atten = ADC_ATTEN_DB_12, .channel = ADC_CHANNEL_1, .unit = ADC_UNIT_1, .bit_width = ADC_BITWIDTH_12}, + {.atten = ADC_ATTEN_DB_12, .channel = ADC_CHANNEL_2, .unit = ADC_UNIT_1, .bit_width = ADC_BITWIDTH_12}, + {.atten = ADC_ATTEN_DB_12, .channel = ADC_CHANNEL_8, .unit = ADC_UNIT_1, .bit_width = 
ADC_BITWIDTH_12}, + {.atten = ADC_ATTEN_DB_12, .channel = ADC_CHANNEL_9, .unit = ADC_UNIT_1, .bit_width = ADC_BITWIDTH_12}, +}; +adc_continuous_config_t cont_cfg = { + .sample_freq_hz = 4000, // 4 channels × 1000 Hz = 4000 total samples/sec + .conv_mode = ADC_CONV_SINGLE_UNIT_1, + .format = ADC_DIGI_OUTPUT_FORMAT_TYPE2, + .pattern_num = 4, + .adc_pattern = chan_cfg, +}; +adc_continuous_config(adc_handle, &cont_cfg); + +// --- ISR callback (fires each frame) --- +static SemaphoreHandle_t s_adc_sem; +static bool IRAM_ATTR adc_conv_done_cb( + adc_continuous_handle_t handle, + const adc_continuous_evt_data_t *edata, void *user_data) { + BaseType_t hp_woken = pdFALSE; + xSemaphoreGiveFromISR(s_adc_sem, &hp_woken); + return hp_woken == pdTRUE; +} +adc_continuous_evt_cbs_t cbs = { .on_conv_done = adc_conv_done_cb }; +adc_continuous_register_event_callbacks(adc_handle, &cbs, NULL); +adc_continuous_start(adc_handle); + +// --- ADC calibration (apply per sample) --- +adc_cali_handle_t cali_handle; +adc_cali_curve_fitting_config_t cali_cfg = { + .unit_id = ADC_UNIT_1, + .atten = ADC_ATTEN_DB_12, // matches ADC_ATTEN_DB_12 used in current emg_sensor.c + .bitwidth = ADC_BITWIDTH_12, +}; +adc_cali_create_scheme_curve_fitting(&cali_cfg, &cali_handle); + +// --- Sampling task (pin to Core 0) --- +void adc_sampling_task(void *arg) { + uint8_t result_buf[256]; + uint32_t out_len = 0; + while (1) { + xSemaphoreTake(s_adc_sem, portMAX_DELAY); + adc_continuous_read(adc_handle, result_buf, sizeof(result_buf), &out_len, 0); + // Parse: each entry is adc_digi_output_data_t + // Apply adc_cali_raw_to_voltage() for each sample + // Apply IIR filter (Change B) → post to inference ring buffer + } +} +``` + +**Verify**: log consecutive sample timestamps via `esp_timer_get_time()`; spacing should be 1.0ms ± 0.05ms. + +--- + +## Change B — IIR Biquad Bandpass Filter + +**Priority**: Tier 1 +**Why**: MyoWare analogue filters are not tunable. 
Software IIR removes powerline interference +(50/60 Hz), sub-20 Hz motion artifact, and >500 Hz noise — all of which inflate ZC, WL, and +other features computed at rest. +**Effort**: 2 hours + +### Step 1 — Compute Coefficients in Python (one-time, offline) + +```python +from scipy.signal import butter +import numpy as np + +fs = 1000.0 +sos = butter(N=2, Wn=[20.0, 500.0], btype='bandpass', fs=fs, output='sos') +# sos[i] = [b0, b1, b2, a0, a1, a2] +# esp-dsp Direct Form II convention: coeffs = [b0, b1, b2, -a1, -a2] +for i, s in enumerate(sos): + b0, b1, b2, a0, a1, a2 = s + print(f"Section {i}: {b0:.8f}f, {b1:.8f}f, {b2:.8f}f, {-a1:.8f}f, {-a2:.8f}f") +# Run this and paste the printed values into the C constants below +``` + +### Step 2 — Add to inference.c (after includes, before `// --- State ---`) + +```c +#include "dsps_biquad.h" + +// 2nd-order Butterworth bandpass 20–500 Hz @ 1000 Hz +// Coefficients: [b0, b1, b2, -a1, -a2] — Direct Form II, esp-dsp sign convention +// Regenerate with: scipy.signal.butter(N=2, Wn=[20,500], btype='bandpass', fs=1000, output='sos') +static const float BIQUAD_HP_COEFFS[5] = { /* paste section 0 output here */ }; +static const float BIQUAD_LP_COEFFS[5] = { /* paste section 1 output here */ }; + +// Filter delay state: 3 channels × 2 stages × 2 delay elements = 12 floats (48 bytes) +static float biquad_hp_w[HAND_NUM_CHANNELS][2]; +static float biquad_lp_w[HAND_NUM_CHANNELS][2]; +``` + +Add to `inference_init()`: +```c + memset(biquad_hp_w, 0, sizeof(biquad_hp_w)); + memset(biquad_lp_w, 0, sizeof(biquad_lp_w)); +``` + +### Step 3 — Apply Per Sample (called before writing to window_buffer) + +```c +// Apply to each channel before posting to the window buffer. +// Must be called IN ORDER for each sample (IIR has memory across calls). 
+static float IRAM_ATTR apply_bandpass(int ch, float raw) { + float hp_out, lp_out; + dsps_biquad_f32(&raw, &hp_out, 1, (float *)BIQUAD_HP_COEFFS, biquad_hp_w[ch]); + dsps_biquad_f32(&hp_out, &lp_out, 1, (float *)BIQUAD_LP_COEFFS, biquad_lp_w[ch]); + return lp_out; +} +``` + +**Note**: `window_buffer` stores `uint16_t` — change to `float` when adding this filter, so +filtered values are stored directly without lossy integer round-trip. + +**Verify**: log ZC count at rest before and after — filtered ZC should be substantially lower +(less spurious noise crossings). + +--- + +## Change C — Confidence Rejection + +**Priority**: Tier 1 — **implement this first, lowest risk of all changes** +**Why**: Without a rejection threshold, ambiguous EMG (rest-to-gesture transition, +mid-gesture fatigue, electrode lift) always produces a false actuation. +**Effort**: 15 minutes + +### Step 1 — Add Constant (top of inference.c with other constants) + +```c +#define CONFIDENCE_THRESHOLD 0.40f // Reject when max smoothed prob < this. + // Meta paper uses 0.35; 0.40 adds prosthetic safety margin. + // Tune: lower to 0.35 if real gestures are being rejected. +``` + +### Step 2 — Insert After EMA Block in `inference_predict()` (after line 214) + +```c + // Confidence rejection: if the peak smoothed probability is below threshold, + // hold the last confirmed output rather than outputting an uncertain prediction. + // Prevents false actuations during gesture transitions and electrode artifacts. + if (max_smoothed_prob < CONFIDENCE_THRESHOLD) { + *confidence = max_smoothed_prob; + return current_output; // -1 (GESTURE_NONE) until first confident prediction + } +``` + +**Verify**: arm at complete rest → confirm output stays at GESTURE_NONE and confidence logs +below 0.40. Deliberate fist → confidence rises above 0.40 within 1–3 inference cycles. + +--- + +## Change D — On-Device NVS Calibration + +**Priority**: Tier 2 +**Why**: Python `CalibrationTransform` only runs during training. 
On-device NVS calibration
+lets the ESP32 recalibrate z-score normalization at startup (3 seconds of REST) without
+retraining — solving placement drift and day-to-day impedance variation.
+**Effort**: 3–4 hours
+
+### New Files
+
+```
+EMG_Arm/src/core/calibration.h
+EMG_Arm/src/core/calibration.c
+```
+
+### calibration.h
+
+```c
+#pragma once
+#include <stdbool.h>
+#include "config/config.h"
+
+#define CALIB_MAX_FEATURES 96 // supports up to 4-channel expansion
+
+bool calibration_init(void); // load from NVS at startup
+void calibration_apply(float *feat); // z-score in-place; no-op if not calibrated
+bool calibration_update(const float X[][CALIB_MAX_FEATURES], int n_windows, int n_feat);
+void calibration_reset(void);
+bool calibration_is_valid(void);
+```
+
+### calibration.c
+
+```c
+#include "calibration.h"
+#include "nvs_flash.h"
+#include "nvs.h"
+#include <string.h>
+#include <math.h>
+#include <stdio.h>
+
+#define NVS_NAMESPACE "emg_calib"
+#define NVS_KEY_MEAN "feat_mean"
+#define NVS_KEY_STD "feat_std"
+#define NVS_KEY_NFEAT "n_feat"
+#define NVS_KEY_VALID "calib_ok"
+
+static float s_mean[CALIB_MAX_FEATURES];
+static float s_std[CALIB_MAX_FEATURES];
+static int s_n_feat = 0;
+static bool s_valid = false;
+
+bool calibration_init(void) {
+    esp_err_t err = nvs_flash_init();
+    if (err == ESP_ERR_NVS_NO_FREE_PAGES || err == ESP_ERR_NVS_NEW_VERSION_FOUND) {
+        nvs_flash_erase();
+        nvs_flash_init();
+    }
+    nvs_handle_t h;
+    if (nvs_open(NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) return false;
+
+    uint8_t valid = 0;
+    size_t mean_sz = sizeof(s_mean), std_sz = sizeof(s_std);
+    bool ok = (nvs_get_u8(h, NVS_KEY_VALID, &valid) == ESP_OK) && (valid == 1) &&
+              (nvs_get_i32(h, NVS_KEY_NFEAT, (int32_t*)&s_n_feat) == ESP_OK) &&
+              (nvs_get_blob(h, NVS_KEY_MEAN, s_mean, &mean_sz) == ESP_OK) &&
+              (nvs_get_blob(h, NVS_KEY_STD, s_std, &std_sz) == ESP_OK);
+    nvs_close(h);
+    s_valid = ok;
+    printf("[Calib] %s (%d features)\n", ok ? 
"Loaded from NVS" : "Not found — identity", s_n_feat); + return ok; +} + +void calibration_apply(float *feat) { + if (!s_valid) return; + for (int i = 0; i < s_n_feat; i++) + feat[i] = (feat[i] - s_mean[i]) / s_std[i]; +} + +bool calibration_update(const float X[][CALIB_MAX_FEATURES], int n_windows, int n_feat) { + if (n_windows < 10 || n_feat > CALIB_MAX_FEATURES) return false; + s_n_feat = n_feat; + memset(s_mean, 0, sizeof(s_mean)); + for (int w = 0; w < n_windows; w++) + for (int f = 0; f < n_feat; f++) + s_mean[f] += X[w][f]; + for (int f = 0; f < n_feat; f++) s_mean[f] /= n_windows; + + memset(s_std, 0, sizeof(s_std)); + for (int w = 0; w < n_windows; w++) + for (int f = 0; f < n_feat; f++) { + float d = X[w][f] - s_mean[f]; + s_std[f] += d * d; + } + for (int f = 0; f < n_feat; f++) { + s_std[f] = sqrtf(s_std[f] / n_windows); + if (s_std[f] < 1e-6f) s_std[f] = 1e-6f; + } + + nvs_handle_t h; + if (nvs_open(NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) return false; + nvs_set_blob(h, NVS_KEY_MEAN, s_mean, sizeof(s_mean)); + nvs_set_blob(h, NVS_KEY_STD, s_std, sizeof(s_std)); + nvs_set_i32(h, NVS_KEY_NFEAT, n_feat); + nvs_set_u8(h, NVS_KEY_VALID, 1); + nvs_commit(h); + nvs_close(h); + s_valid = true; + printf("[Calib] Updated from %d REST windows, %d features\n", n_windows, n_feat); + return true; +} +``` + +### Integration in inference.c + +In `inference_predict()`, after `compute_features(features)`, before LDA: +```c + calibration_apply(features); // z-score using NVS-stored mean/std +``` + +### Startup Flow + +```c +// In main application startup sequence: +calibration_init(); // load from NVS; no-op if not present yet + +// When user triggers recalibration (button press or serial command): +// Collect ~120 REST windows (~3 seconds at 25ms hop) +// Call calibration_update(rest_feature_buffer, 120, MODEL_NUM_FEATURES) +``` + +--- + +## Change E — int8 MLP via TFLite Micro + +**Priority**: Tier 3 — implement after Tier 1+2 changes and benchmark (Change 5) 
shows LDA plateauing +**Why**: LDA finds only linear decision boundaries. A 2-layer int8 MLP adds nonlinear +boundaries for gesture pairs that overlap in feature space. +**Effort**: 4–6 hours + +### Python Training (new file: `train_mlp_tflite.py`) + +```python +""" +Train int8 MLP for ESP32-S3 deployment via TFLite Micro. +Run AFTER Change 0 (label shift) + Change 1 (expanded features). +""" +import numpy as np +import tensorflow as tf +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS + +storage = SessionStorage() +X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training() + +extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True) +X = extractor.extract_features_batch(X_raw).astype(np.float32) + +from sklearn.preprocessing import StandardScaler +scaler = StandardScaler() +X = scaler.fit_transform(X) + +n_feat, n_cls = X.shape[1], len(np.unique(y)) + +model = tf.keras.Sequential([ + tf.keras.layers.Input(shape=(n_feat,)), + tf.keras.layers.Dense(32, activation='relu'), + tf.keras.layers.Dropout(0.2), + tf.keras.layers.Dense(16, activation='relu'), + tf.keras.layers.Dense(n_cls, activation='softmax'), +]) +model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) +model.fit(X, y, epochs=150, batch_size=64, validation_split=0.1, verbose=1) + +def representative_dataset(): + for i in range(0, len(X), 10): + yield [X[i:i+1]] + +converter = tf.lite.TFLiteConverter.from_keras_model(model) +converter.optimizations = [tf.lite.Optimize.DEFAULT] +converter.representative_dataset = representative_dataset +converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] +converter.inference_input_type = tf.int8 +converter.inference_output_type = tf.int8 +tflite_model = converter.convert() + +out = Path('EMG_Arm/src/core/emg_model_data.cc') +with open(out, 'w') 
as f: + f.write('#include "emg_model_data.h"\n') + f.write(f'const int g_model_len = {len(tflite_model)};\n') + f.write('const unsigned char g_model[] = {\n ') + f.write(', '.join(f'0x{b:02x}' for b in tflite_model)) + f.write('\n};\n') +print(f"Wrote {out} ({len(tflite_model)} bytes)") +``` + +### Firmware (inference_mlp.cc) + +```cpp +#include "inference_mlp.h" +#include "emg_model_data.h" +#include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" + +static uint8_t tensor_arena[48 * 1024]; // 48 KB — tune down if memory is tight +static tflite::MicroInterpreter *interpreter = nullptr; +static TfLiteTensor *input = nullptr, *output = nullptr; + +void inference_mlp_init(void) { + const tflite::Model *model = tflite::GetModel(g_model); + static tflite::MicroMutableOpResolver<4> resolver; + resolver.AddFullyConnected(); + resolver.AddRelu(); + resolver.AddSoftmax(); + resolver.AddDequantize(); + static tflite::MicroInterpreter interp(model, resolver, tensor_arena, sizeof(tensor_arena)); + interpreter = &interp; + interpreter->AllocateTensors(); + input = interpreter->input(0); + output = interpreter->output(0); +} + +int inference_mlp_predict(const float *features, int n_feat, float *conf_out) { + float iscale = input->params.scale; + int izp = input->params.zero_point; + for (int i = 0; i < n_feat; i++) { + int q = (int)roundf(features[i] / iscale) + izp; + input->data.int8[i] = (int8_t)(q < -128 ? -128 : q > 127 ? 
127 : q);
+ }
+ interpreter->Invoke();
+
+ float oscale = output->params.scale;
+ int ozp = output->params.zero_point;
+ float max_p = -1e9f;
+ int max_c = 0;
+ for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
+ float p = (output->data.int8[c] - ozp) * oscale;
+ if (p > max_p) { max_p = p; max_c = c; }
+ }
+ *conf_out = max_p;
+ return max_c;
+}
+```
+
+**platformio.ini addition**:
+```ini
+lib_deps =
+ tensorflow/tflite-micro
+```
+
+---
+
+## Change F — Ensemble Inference Pipeline
+
+**Priority**: Tier 3 (requires Change 1 features + Change 7 training + Change E MLP)
+**Why**: This is the full recommended architecture from Part II.
+**Effort**: 3–4 hours firmware (after Python ensemble is trained and exported)
+
+### New Files
+
+```
+EMG_Arm/src/core/inference_ensemble.c
+EMG_Arm/src/core/inference_ensemble.h
+EMG_Arm/src/core/model_weights_ensemble.h (generated by Change 7 Python script)
+```
+
+### inference_ensemble.h
+
+```c
+#pragma once
+#include <stdint.h>
+
+void inference_ensemble_init(void);
+int inference_ensemble_predict(float *confidence);
+```
+
+### inference_ensemble.c
+
+```c
+#include "inference_ensemble.h"
+#include "inference.h" // for compute_features(), calibration_apply()
+#include "inference_mlp.h" // for inference_mlp_predict()
+#include "model_weights_ensemble.h"
+#include "config/config.h"
+#include "dsps_dotprod.h"
+#include <math.h>
+#include <string.h>
+#include <stdint.h>
+
+#define ENSEMBLE_EMA_ALPHA 0.70f
+#define ENSEMBLE_CONF_THRESHOLD 0.50f // below this: escalate to MLP fallback
+#define REJECT_THRESHOLD 0.40f // below this even after MLP: hold output
+#define REST_ACTIVITY_THRESHOLD 0.05f // total_rms below this → skip inference, return REST
+
+// EMA state
+static float s_smoothed[MODEL_NUM_CLASSES];
+// Vote + debounce (reuse existing pattern from inference.c)
+static int s_vote_history[5];
+static int s_vote_head = 0;
+static int s_current_output = -1;
+static int s_pending_output = -1;
+static int s_pending_count = 0;
+
+// --- Generic LDA softmax predict ---
+// 
weights: [n_classes][n_feat], intercepts: [n_classes] +// proba_out: [n_classes] — caller-provided output +static void lda_softmax(const float *feat, int n_feat, + const float *weights_flat, const float *intercepts, + int n_classes, float *proba_out) { + float raw[MODEL_NUM_CLASSES]; + float max_raw = -1e9f, sum_exp = 0.0f; + + for (int c = 0; c < n_classes; c++) { + raw[c] = intercepts[c]; + // dsps_dotprod_f32 requires 4-byte aligned arrays and length multiple of 4; + // for safety use plain loop — compiler will auto-vectorize with -O2 + const float *w = weights_flat + c * n_feat; + for (int f = 0; f < n_feat; f++) raw[c] += feat[f] * w[f]; + if (raw[c] > max_raw) max_raw = raw[c]; + } + for (int c = 0; c < n_classes; c++) { + proba_out[c] = expf(raw[c] - max_raw); + sum_exp += proba_out[c]; + } + for (int c = 0; c < n_classes; c++) proba_out[c] /= sum_exp; +} + +void inference_ensemble_init(void) { + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + s_smoothed[c] = 1.0f / MODEL_NUM_CLASSES; + for (int i = 0; i < 5; i++) s_vote_history[i] = -1; + s_vote_head = 0; + s_current_output = -1; + s_pending_output = -1; + s_pending_count = 0; +} + +int inference_ensemble_predict(float *confidence) { + // 1. Extract features (shared with single-model path) + float features[MODEL_NUM_FEATURES]; + compute_features(features); + calibration_apply(features); + + // 2. Activity gate — skip inference during obvious REST + float total_rms_sq = 0.0f; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + float r = features[ch * ENSEMBLE_PER_CH_FEATURES]; // RMS is index 0 per channel + total_rms_sq += r * r; + } + if (sqrtf(total_rms_sq) < REST_ACTIVITY_THRESHOLD) { + *confidence = 1.0f; + return GESTURE_REST; + } + + // 3. 
Specialist LDAs + float prob_td[MODEL_NUM_CLASSES]; + float prob_fd[MODEL_NUM_CLASSES]; + float prob_cc[MODEL_NUM_CLASSES]; + + lda_softmax(features + TD_FEAT_OFFSET, TD_NUM_FEATURES, + (const float *)LDA_TD_WEIGHTS, LDA_TD_INTERCEPTS, + MODEL_NUM_CLASSES, prob_td); + lda_softmax(features + FD_FEAT_OFFSET, FD_NUM_FEATURES, + (const float *)LDA_FD_WEIGHTS, LDA_FD_INTERCEPTS, + MODEL_NUM_CLASSES, prob_fd); + lda_softmax(features + CC_FEAT_OFFSET, CC_NUM_FEATURES, + (const float *)LDA_CC_WEIGHTS, LDA_CC_INTERCEPTS, + MODEL_NUM_CLASSES, prob_cc); + + // 4. Meta-LDA stacker + float meta_in[META_NUM_INPUTS]; // = 3 * MODEL_NUM_CLASSES + memcpy(meta_in, prob_td, MODEL_NUM_CLASSES * sizeof(float)); + memcpy(meta_in + MODEL_NUM_CLASSES, prob_fd, MODEL_NUM_CLASSES * sizeof(float)); + memcpy(meta_in + 2*MODEL_NUM_CLASSES, prob_cc, MODEL_NUM_CLASSES * sizeof(float)); + + float meta_probs[MODEL_NUM_CLASSES]; + lda_softmax(meta_in, META_NUM_INPUTS, + (const float *)META_LDA_WEIGHTS, META_LDA_INTERCEPTS, + MODEL_NUM_CLASSES, meta_probs); + + // 5. EMA smoothing on meta output + float max_smooth = 0.0f; + int winner = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + s_smoothed[c] = ENSEMBLE_EMA_ALPHA * s_smoothed[c] + + (1.0f - ENSEMBLE_EMA_ALPHA) * meta_probs[c]; + if (s_smoothed[c] > max_smooth) { max_smooth = s_smoothed[c]; winner = c; } + } + + // 6. Confidence cascade: escalate to MLP if meta-LDA is uncertain + if (max_smooth < ENSEMBLE_CONF_THRESHOLD) { + float mlp_conf = 0.0f; + int mlp_winner = inference_mlp_predict(features, MODEL_NUM_FEATURES, &mlp_conf); + if (mlp_conf > max_smooth) { winner = mlp_winner; max_smooth = mlp_conf; } + } + + // 7. Reject if still uncertain + if (max_smooth < REJECT_THRESHOLD) { + *confidence = max_smooth; + return s_current_output; + } + + *confidence = max_smooth; + + // 8. 
Majority vote (window = 5) + s_vote_history[s_vote_head] = winner; + s_vote_head = (s_vote_head + 1) % 5; + int counts[MODEL_NUM_CLASSES] = {0}; + for (int i = 0; i < 5; i++) + if (s_vote_history[i] >= 0) counts[s_vote_history[i]]++; + int majority = 0, majority_cnt = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + if (counts[c] > majority_cnt) { majority_cnt = counts[c]; majority = c; } + + // 9. Debounce (3 consecutive predictions to change output) + int final = s_current_output; + if (s_current_output == -1) { + s_current_output = majority; final = majority; + } else if (majority == s_current_output) { + s_pending_output = majority; s_pending_count = 1; + } else if (majority == s_pending_output) { + if (++s_pending_count >= 3) { s_current_output = majority; final = majority; } + } else { + s_pending_output = majority; s_pending_count = 1; + } + + return final; +} +``` + +### model_weights_ensemble.h Layout (generated by Change 7) + +```c +// Auto-generated by train_ensemble.py — do not edit manually +#pragma once + +#define MODEL_NUM_CLASSES 5 // auto-computed from training data +#define MODEL_NUM_FEATURES 69 // total feature count (after Change 1) +#define ENSEMBLE_PER_CH_FEATURES 20 // features per channel + +// Specialist feature subset offsets and sizes +#define TD_FEAT_OFFSET 0 +#define TD_NUM_FEATURES 36 // time-domain: indices 0–11, 20–31, 40–51 +#define FD_FEAT_OFFSET 12 // NOTE: FD features are interleaved per-channel +#define FD_NUM_FEATURES 24 // freq-domain: indices 12–19, 32–39, 52–59 +#define CC_FEAT_OFFSET 60 +#define CC_NUM_FEATURES 9 // cross-channel: indices 60–68 + +#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES) // = 15 + +// Specialist LDA weights (flat row-major: [n_classes][n_feat]) +extern const float LDA_TD_WEIGHTS[MODEL_NUM_CLASSES][TD_NUM_FEATURES]; +extern const float LDA_TD_INTERCEPTS[MODEL_NUM_CLASSES]; + +extern const float LDA_FD_WEIGHTS[MODEL_NUM_CLASSES][FD_NUM_FEATURES]; +extern const float 
LDA_FD_INTERCEPTS[MODEL_NUM_CLASSES]; + +extern const float LDA_CC_WEIGHTS[MODEL_NUM_CLASSES][CC_NUM_FEATURES]; +extern const float LDA_CC_INTERCEPTS[MODEL_NUM_CLASSES]; + +// Meta-LDA weights +extern const float META_LDA_WEIGHTS[MODEL_NUM_CLASSES][META_NUM_INPUTS]; +extern const float META_LDA_INTERCEPTS[MODEL_NUM_CLASSES]; + +// Class names (for inference_get_gesture_enum) +extern const char *MODEL_CLASS_NAMES[MODEL_NUM_CLASSES]; +``` + +**Important note on FD features**: the frequency-domain features are interleaved at indices +[12–19] for ch0, [32–39] for ch1, [52–59] for ch2. The `lda_softmax` call for LDA_FD must +pass a **gathered** (non-contiguous) sub-vector. The cleanest approach is to gather them into +a contiguous buffer before calling lda_softmax: + +```c +// Gather FD features into contiguous buffer before LDA_FD +float fd_buf[FD_NUM_FEATURES]; +for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) + memcpy(fd_buf + ch*8, features + ch*20 + 12, 8 * sizeof(float)); +lda_softmax(fd_buf, FD_NUM_FEATURES, ...); +``` + +Similarly for TD features. This gather costs <5 µs — negligible. + +--- + +# PART VI — PYTHON/TRAINING CHANGES + +## Change 0 — Forward Label Shift + +**Priority**: Tier 1 +**Source**: Meta Nature 2025, Methods: "Discrete-gesture time alignment" +**Why**: +100ms shift after onset detection gives the classifier 100ms of pre-event "building" +signal, dramatically cleaning the decision boundary near gesture onset. +**ESP32 impact**: None. + +### Step 1 — Add Constant After Line 94 + +```python +# After: TRANSITION_END_MS = 150 +LABEL_FORWARD_SHIFT_MS = 100 # shift label boundaries +100ms after onset alignment + # Source: Kaifosh et al. Nature 2025. 
doi:10.1038/s41586-025-09255-w +``` + +### Step 2 — Apply Shift in `SessionStorage.save_session()` (after line ~704) + +Find and insert after: +```python + print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted") +``` + +Insert: +```python + if LABEL_FORWARD_SHIFT_MS > 0: + shift_windows = max(1, round(LABEL_FORWARD_SHIFT_MS / HOP_SIZE_MS)) + shifted = list(aligned_labels) + for i in range(1, len(aligned_labels)): + if aligned_labels[i] != aligned_labels[i - 1]: + for j in range(i, min(i + shift_windows, len(aligned_labels))): + if shifted[j] == aligned_labels[i]: + shifted[j] = aligned_labels[i - 1] + n_shifted = sum(1 for a, b in zip(aligned_labels, shifted) if a != b) + aligned_labels = shifted + print(f"[Storage] Forward label shift (+{LABEL_FORWARD_SHIFT_MS}ms): {n_shifted} windows adjusted") +``` + +### Step 3 — Reduce TRANSITION_START_MS + +```python +TRANSITION_START_MS = 200 # was 300 — reduce because 100ms shift already adds pre-event context +``` + +**Verify**: printout shows `N windows adjusted` where N is 5–20% of total windows per session. + +--- + +## Change 1 — Expanded Feature Set + +**Priority**: Tier 2 +**Why**: 12 → 69 features; adds frequency-domain and cross-channel information that is +structurally more informative than amplitude alone (Meta Extended Data Fig. 6). +**ESP32 impact**: retrain → export new `model_weights.h`; port selected features to C. 
+ +### Sub-change 1A — Expand `extract_features_single_channel()` (line 1448) + +Replace the entire function body: + +```python + def extract_features_single_channel(self, signal: np.ndarray) -> dict: + if getattr(self, 'reinhard', False): + signal = 64.0 * signal / (32.0 + np.abs(signal)) + + signal = signal - np.mean(signal) + N = len(signal) + + # --- Time domain --- + rms = np.sqrt(np.mean(signal ** 2)) + diff = np.diff(signal) + wl = np.sum(np.abs(diff)) + zc_thresh = self.zc_threshold_percent * rms + ssc_thresh = (self.ssc_threshold_percent * rms) ** 2 + sign_ch = signal[:-1] * signal[1:] < 0 + zc = int(np.sum(sign_ch & (np.abs(diff) > zc_thresh))) + d_l = signal[1:-1] - signal[:-2] + d_r = signal[1:-1] - signal[2:] + ssc = int(np.sum((d_l * d_r) > ssc_thresh)) + mav = np.mean(np.abs(signal)) + var = np.mean(signal ** 2) + iemg = np.sum(np.abs(signal)) + wamp = int(np.sum(np.abs(diff) > 0.15 * rms)) + + # AR(4) via Yule-Walker + ar = np.zeros(4) + if rms > 1e-6: + try: + from scipy.linalg import solve_toeplitz + r = np.array([np.dot(signal[i:], signal[:N-i]) / N for i in range(5)]) + if r[0] > 1e-10: + ar = solve_toeplitz(r[:4], -r[1:5]) + except Exception: + pass + + # --- Frequency domain (20–500 Hz) --- + freqs = np.fft.rfftfreq(N, d=1.0 / SAMPLING_RATE_HZ) + psd = np.abs(np.fft.rfft(signal)) ** 2 / N + m = (freqs >= 20) & (freqs <= 500) + f_m, p_m = freqs[m], psd[m] + tp = np.sum(p_m) + 1e-10 + mnf = float(np.sum(f_m * p_m) / tp) + cum = np.cumsum(p_m) + mdf = float(f_m[min(np.searchsorted(cum, tp / 2), len(f_m) - 1)]) + pkf = float(f_m[np.argmax(p_m)]) if len(p_m) > 0 else 0.0 + mnp = float(tp / max(len(p_m), 1)) + + # Bandpower in 4 physiological bands (mirrors firmware esp-dsp FFT bands) + bands = [(20, 80), (80, 150), (150, 300), (300, 500)] + bp = [float(np.sum(psd[(freqs >= lo) & (freqs < hi)])) for lo, hi in bands] + + return { + 'rms': rms, 'wl': wl, 'zc': zc, 'ssc': ssc, + 'mav': mav, 'var': var, 'iemg': iemg, 'wamp': wamp, + 'ar1': float(ar[0]), 
'ar2': float(ar[1]), + 'ar3': float(ar[2]), 'ar4': float(ar[3]), + 'mnf': mnf, 'mdf': mdf, 'pkf': pkf, 'mnp': mnp, + 'bp0': bp[0], 'bp1': bp[1], 'bp2': bp[2], 'bp3': bp[3], + } +``` + +### Sub-change 1B — Update `extract_features_window()` Return Block (line 1482) + +Replace the return section: + +```python + FEATURE_ORDER = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', + 'ar1', 'ar2', 'ar3', 'ar4', 'mnf', 'mdf', 'pkf', 'mnp', + 'bp0', 'bp1', 'bp2', 'bp3'] + NORMALIZE_KEYS = {'rms', 'wl', 'mav', 'iemg'} + + features = [] + for ch_features in all_ch_features: + for key in FEATURE_ORDER: + val = ch_features.get(key, 0.0) + if self.normalize and key in NORMALIZE_KEYS: + val = val / norm_factor + features.append(float(val)) + + if self.cross_channel and window.shape[1] >= 2: + sel = window[:, channel_indices].astype(np.float32) + wc = sel - sel.mean(axis=0) + cov = (wc.T @ wc) / len(wc) + ri, ci = np.triu_indices(len(channel_indices)) + features.extend(cov[ri, ci].tolist()) + stds = np.sqrt(np.diag(cov)) + 1e-10 + cor = cov / np.outer(stds, stds) + ro, co = np.triu_indices(len(channel_indices), k=1) + features.extend(cor[ro, co].tolist()) + + return np.array(features, dtype=np.float32) +``` + +### Sub-change 1C — Update `EMGFeatureExtractor.__init__()` (line 1430) + +```python + def __init__(self, zc_threshold_percent=0.1, ssc_threshold_percent=0.1, + channels=None, normalize=True, cross_channel=True, reinhard=False): + self.zc_threshold_percent = zc_threshold_percent + self.ssc_threshold_percent = ssc_threshold_percent + self.channels = channels + self.normalize = normalize + self.cross_channel = cross_channel + self.reinhard = reinhard +``` + +### Sub-change 1D — Update Feature Count in `extract_features_batch()` (line 1520) + +Replace `n_features = n_channels * 4`: +```python + per_ch = 20 + if self.cross_channel and n_channels >= 2: + n_features = n_channels * per_ch + \ + n_channels*(n_channels+1)//2 + n_channels*(n_channels-1)//2 + else: + n_features = 
n_channels * per_ch +``` + +### Sub-change 1E — Update `get_feature_names()` (line 1545) + +```python + def get_feature_names(self, n_channels=0): + ch_idx = self.channels if self.channels is not None else list(range(n_channels)) + ORDER = ['rms','wl','zc','ssc','mav','var','iemg','wamp', + 'ar1','ar2','ar3','ar4','mnf','mdf','pkf','mnp','bp0','bp1','bp2','bp3'] + names = [f'ch{ch}_{f}' for ch in ch_idx for f in ORDER] + if self.cross_channel and len(ch_idx) >= 2: + n = len(ch_idx) + names += [f'cov_ch{ch_idx[i]}_ch{ch_idx[j]}' for i in range(n) for j in range(i, n)] + names += [f'cor_ch{ch_idx[i]}_ch{ch_idx[j]}' for i in range(n) for j in range(i+1, n)] + return names +``` + +### Sub-change 1F — Update `EMGClassifier.__init__()` (line 1722) + +```python + self.feature_extractor = EMGFeatureExtractor( + channels=HAND_CHANNELS, cross_channel=True, reinhard=False) +``` + +### Sub-change 1G — Update `save()` (line 1910) and `load()` (line 2089) + +In `save()`, add to `feature_extractor_params` dict: +```python + 'cross_channel': getattr(self.feature_extractor, 'cross_channel', True), + 'reinhard': getattr(self.feature_extractor, 'reinhard', False), +``` + +In `load()`, update `EMGFeatureExtractor(...)` constructor: +```python + classifier.feature_extractor = EMGFeatureExtractor( + zc_threshold_percent = params.get('zc_threshold_percent', 0.1), + ssc_threshold_percent = params.get('ssc_threshold_percent', 0.1), + channels = params.get('channels', HAND_CHANNELS), + normalize = params.get('normalize', False), + cross_channel = params.get('cross_channel', True), + reinhard = params.get('reinhard', False), + ) +``` + +### Also Fix Bug at Line 2382 + +```python +X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() +``` + +--- + +## Change 2 — Electrode Repositioning Protocol + +**Protocol**: no code changes. 
+> *"Between sessions within a single day, the participants remove and slightly reposition the +> sEMG wristband to enable generalization across different recording positions."* +> — Meta Nature 2025 Methods + +- Session 1: standard placement +- Session 2: band 1–2 cm up the forearm +- Session 3: band 1–2 cm down the forearm +- Session 4+: slight axial rotation or return to any above position + +The per-session z-score normalization in `_apply_session_normalization()` handles the +resulting amplitude shifts. Perform **fast, natural** gestures — not slow/deliberate. + +--- + +## Change 3 — Data Augmentation + +**Priority**: Tier 2. Apply to **raw windows BEFORE feature extraction**. + +Insert before the `# === LDA CLASSIFIER ===` comment (~line 1709): + +```python +def augment_emg_batch(X, y, multiplier=3, seed=42): + """ + Augment raw EMG windows for training robustness. + Must be called on raw windows (n_windows, n_samples, n_channels), + not on pre-computed features. + Source (window jitter): Kaifosh et al. Nature 2025. 
doi:10.1038/s41586-025-09255-w + """ + rng = np.random.default_rng(seed) + aug_X, aug_y = [X], [y] + for _ in range(multiplier - 1): + Xc = X.copy().astype(np.float32) + Xc *= rng.uniform(0.80, 1.20, (len(X), 1, 1)).astype(np.float32) # amplitude + rms = np.sqrt(np.mean(Xc**2, axis=(1,2), keepdims=True)) + 1e-8 + Xc += rng.standard_normal(Xc.shape).astype(np.float32) * (0.05 * rms) # noise + Xc += rng.uniform(-20., 20., (len(X), 1, X.shape[2])).astype(np.float32) # DC jitter + shifts = rng.integers(-5, 6, size=len(X)) + for i in range(len(Xc)): + if shifts[i]: Xc[i] = np.roll(Xc[i], shifts[i], axis=0) # jitter + aug_X.append(Xc); aug_y.append(y) + return np.concatenate(aug_X), np.concatenate(aug_y) +``` + +In `EMGClassifier.train()`, replace the start of the function's feature extraction block: + +```python + if getattr(self, 'use_augmentation', True): + X_aug, y_aug = augment_emg_batch(X, y, multiplier=3) + print(f"[Classifier] Augmented: {len(X)} → {len(X_aug)} windows") + else: + X_aug, y_aug = X, y + X_features = self.feature_extractor.extract_features_batch(X_aug) + # ... then use y_aug instead of y for model.fit() +``` + +--- + +## Change 4 — Reinhard Compression (Optional) + +**Formula**: `output = 64 × x / (32 + |x|)` +**Enable in Python**: set `reinhard=True` in `EMGFeatureExtractor` constructor (Change 1F). + +**Enable in firmware** (`inference.c` `compute_features()`, after signal copy loop, before mean calc): +```c +#if MODEL_USE_REINHARD + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + float x = signal[i]; + signal[i] = 64.0f * x / (32.0f + fabsf(x)); + } +#endif +``` +Add `#define MODEL_USE_REINHARD 0` to `model_weights.h` (set to `1` when Python uses `reinhard=True`). +**Python and firmware MUST match.** Mismatch silently corrupts all predictions. 
+ +--- + +## Change 5 — Classifier Benchmark + +**Purpose**: tells you whether LDA accuracy plateau is a features problem (all classifiers similar → add features) or a model complexity problem (SVM/MLP >> LDA → implement Change E/F). + +Add after `run_training_demo()`: + +```python +def run_classifier_benchmark(): + from sklearn.svm import SVC + from sklearn.neural_network import MLPClassifier + from sklearn.pipeline import Pipeline + from sklearn.preprocessing import StandardScaler + from sklearn.model_selection import cross_val_score, GroupKFold + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis + + storage = SessionStorage() + X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training() + extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True) + X = extractor.extract_features_batch(X_raw) + X = EMGClassifier()._apply_session_normalization(X, session_indices, y=y) + + clfs = { + 'LDA (ESP32 model)': LinearDiscriminantAnalysis(), + 'QDA': QuadraticDiscriminantAnalysis(reg_param=0.1), + 'SVM-RBF': Pipeline([('s', StandardScaler()), ('m', SVC(kernel='rbf', C=10))]), + 'MLP-128-64': Pipeline([('s', StandardScaler()), + ('m', MLPClassifier(hidden_layer_sizes=(128,64), + max_iter=1000, early_stopping=True))]), + } + gkf = GroupKFold(n_splits=5) + print(f"\n{'Classifier':<22} {'Mean CV':>8} {'Std':>6}") + print("-" * 40) + for name, clf in clfs.items(): + sc = cross_val_score(clf, X, y, cv=gkf, groups=trial_ids, scoring='accuracy') + print(f" {name:<20} {sc.mean()*100:>7.1f}% ±{sc.std()*100:.1f}%") + print("\n → If LDA ≈ SVM: features are the bottleneck (add Change 1 features)") + print(" → If SVM >> LDA: model complexity bottleneck (implement Change F ensemble)") +``` + +--- + +## Change 6 — Simplified MPF Features + +**Python training only** — not worth porting to ESP32 directly (use bandpower bp0–bp3 from Change 1 as the firmware-side approximation). 
+ +Add after `EMGFeatureExtractor` class: + +```python +class MPFFeatureExtractor: + """ + Simplified 3-channel MPF: CSD upper triangle per 6 frequency bands = 36 features. + Python training only. Omits matrix logarithm (not needed for 3 channels). + Source: Kaifosh et al. Nature 2025. doi:10.1038/s41586-025-09255-w + ESP32 approximation: use bp0–bp3 from EMGFeatureExtractor (Change 1). + """ + BANDS = [(0,62),(62,125),(125,187),(187,250),(250,375),(375,500)] + + def __init__(self, channels=None, log_diagonal=True): + self.channels = channels or HAND_CHANNELS + self.log_diag = log_diagonal + self.n_ch = len(self.channels) + self._r, self._c = np.triu_indices(self.n_ch) + self.n_features = len(self.BANDS) * len(self._r) + + def extract_window(self, window): + sig = window[:, self.channels].astype(np.float64) + N = len(sig) + freqs = np.fft.rfftfreq(N, d=1.0/SAMPLING_RATE_HZ) + Xf = np.fft.rfft(sig, axis=0) + feats = [] + for lo, hi in self.BANDS: + mask = (freqs >= lo) & (freqs < hi) + if not mask.any(): + feats.extend([0.0] * len(self._r)); continue + CSD = (Xf[mask].conj().T @ Xf[mask]).real / N + if self.log_diag: + for k in range(self.n_ch): CSD[k,k] = np.log(max(CSD[k,k], 1e-10)) + feats.extend(CSD[self._r, self._c].tolist()) + return np.array(feats, dtype=np.float32) + + def extract_batch(self, X): + out = np.zeros((len(X), self.n_features), dtype=np.float32) + for i in range(len(X)): out[i] = self.extract_window(X[i]) + return out +``` + +In `EMGClassifier.train()`, after standard feature extraction: +```python + if getattr(self, 'use_mpf', False): + mpf = MPFFeatureExtractor(channels=HAND_CHANNELS) + X_features = np.hstack([X_features, mpf.extract_batch(X_aug)]) +``` + +--- + +## Change 7 — Ensemble Training + +**Priority**: Tier 3 (implements Change F's training side) +**New file**: `C:/VSCode/Marvel_Projects/Bucky_Arm/train_ensemble.py` + +```python +""" +Train the full 3-specialist-LDA + meta-LDA ensemble. 
+Requires Change 1 (expanded features) to be implemented first. +Exports model_weights_ensemble.h for firmware Change F. + +Architecture: + LDA_TD (36 time-domain feat) ─┐ + LDA_FD (24 freq-domain feat) ├─ 15 probs ─► Meta-LDA ─► final class + LDA_CC (9 cross-ch feat) ─┘ +""" +import numpy as np +from pathlib import Path +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from learning_data_collection import ( + SessionStorage, EMGFeatureExtractor, HAND_CHANNELS +) + +# ─── Load and extract features ─────────────────────────────────────────────── +storage = SessionStorage() +X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training() + +extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True) +X = extractor.extract_features_batch(X_raw).astype(np.float64) + +# Per-session normalization (same as EMGClassifier._apply_session_normalization) +from sklearn.preprocessing import StandardScaler +for sid in np.unique(session_indices): + mask = session_indices == sid + sc = StandardScaler() + X[mask] = sc.fit_transform(X[mask]) + +feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS)) +n_cls = len(np.unique(y)) + +# ─── Feature subset indices ─────────────────────────────────────────────────── +TD_FEAT = ['rms','wl','zc','ssc','mav','var','iemg','wamp','ar1','ar2','ar3','ar4'] +FD_FEAT = ['mnf','mdf','pkf','mnp','bp0','bp1','bp2','bp3'] + +td_idx = [i for i,n in enumerate(feat_names) if any(n.endswith(f'_{f}') for f in TD_FEAT)] +fd_idx = [i for i,n in enumerate(feat_names) if any(n.endswith(f'_{f}') for f in FD_FEAT)] +cc_idx = [i for i,n in enumerate(feat_names) if n.startswith('cov_') or n.startswith('cor_')] + +print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}") + +X_td = X[:, td_idx] +X_fd = X[:, fd_idx] +X_cc = X[:, 
cc_idx] + +# ─── Train specialist LDAs with out-of-fold stacking ───────────────────────── +gkf = GroupKFold(n_splits=5) + +print("Training specialist LDAs (out-of-fold for stacking)...") +lda_td = LinearDiscriminantAnalysis() +lda_fd = LinearDiscriminantAnalysis() +lda_cc = LinearDiscriminantAnalysis() + +oof_td = cross_val_predict(lda_td, X_td, y, cv=gkf, groups=trial_ids, method='predict_proba') +oof_fd = cross_val_predict(lda_fd, X_fd, y, cv=gkf, groups=trial_ids, method='predict_proba') +oof_cc = cross_val_predict(lda_cc, X_cc, y, cv=gkf, groups=trial_ids, method='predict_proba') + +# Specialist CV accuracy (for diagnostics) +for name, mdl, Xs in [('LDA_TD', lda_td, X_td), ('LDA_FD', lda_fd, X_fd), ('LDA_CC', lda_cc, X_cc)]: + sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids) + print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%") + +# ─── Train meta-LDA on out-of-fold outputs ─────────────────────────────────── +X_meta = np.hstack([oof_td, oof_fd, oof_cc]) # (n_samples, 3*n_cls = 15) +meta_lda = LinearDiscriminantAnalysis() +meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids) +print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%") + +# Fit all models on full dataset for deployment +lda_td.fit(X_td, y); lda_fd.fit(X_fd, y); lda_cc.fit(X_cc, y) +meta_lda.fit(X_meta, y) + +# ─── Export all weights to C header ────────────────────────────────────────── +def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order): + """Generate C array strings for LDA weights and intercepts.""" + # Reorder classes to match label_names order + coef = lda.coef_ # shape (n_cls, feat_dim) for LinearDiscriminantAnalysis + intercept = lda.intercept_ + lines = [] + lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{") + for c in class_order: + row = ', '.join(f'{v:.8f}f' for v in coef[c]) + lines.append(f" {{{row}}}, // {label_names[c]}") + lines.append("};") + lines.append(f"const float 
{name}_INTERCEPTS[{n_cls}] = {{") + intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order) + lines.append(f" {intercept_str}") + lines.append("};") + return '\n'.join(lines) + +class_order = list(range(n_cls)) +out_path = Path('EMG_Arm/src/core/model_weights_ensemble.h') + +with open(out_path, 'w') as f: + f.write("// Auto-generated by train_ensemble.py — do not edit\n") + f.write("#pragma once\n\n") + f.write(f"#define MODEL_NUM_CLASSES {n_cls}\n") + f.write(f"#define MODEL_NUM_FEATURES {X.shape[1]}\n") + f.write(f"#define ENSEMBLE_PER_CH_FEATURES 20\n\n") + f.write(f"#define TD_FEAT_OFFSET {min(td_idx)}\n") + f.write(f"#define TD_NUM_FEATURES {len(td_idx)}\n") + f.write(f"#define FD_FEAT_OFFSET {min(fd_idx)}\n") + f.write(f"#define FD_NUM_FEATURES {len(fd_idx)}\n") + f.write(f"#define CC_FEAT_OFFSET {min(cc_idx)}\n") + f.write(f"#define CC_NUM_FEATURES {len(cc_idx)}\n") + f.write(f"#define META_NUM_INPUTS ({3} * MODEL_NUM_CLASSES)\n\n") + + f.write(lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(meta_lda, 'META_LDA', 3*n_cls, n_cls, label_names, class_order)) + f.write('\n\n') + + names_str = ', '.join(f'"{label_names[c]}"' for c in class_order) + f.write(f"const char *MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {{{names_str}}};\n") + +print(f"Exported ensemble weights to {out_path}") +print(f"Total weight storage: {(len(td_idx)+len(fd_idx)+len(cc_idx)+3*n_cls)*n_cls*4} bytes float32") +``` + +**Note on LinearDiscriminantAnalysis with multi-class**: scikit-learn's LDA uses a +`(n_classes-1, n_features)` coef matrix for multi-class. 
Verify `lda.coef_.shape` after +fitting — if it is `(n_cls-1, n_feat)` rather than `(n_cls, n_feat)`, use the +`decision_function()` output structure and adjust the export accordingly. + +--- + +# PART VII — FEATURE SELECTION FOR ESP32 PORTING + +After Change 1 is trained, use this to decide what to port to C firmware. + +### Step 1 — Get Feature Importance + +```python +importance = np.abs(classifier.model.coef_).mean(axis=0) +feat_names = classifier.feature_extractor.get_feature_names(n_channels=len(HAND_CHANNELS)) +ranked = sorted(zip(feat_names, importance), key=lambda x: -x[1]) +print("Top 20 features by LDA discriminative weight:") +for name, score in ranked[:20]: + print(f" {name:<35} {score:.4f}") +``` + +### Step 2 — Port Decision Matrix + +| Feature | C Complexity | Prereq | Port? | +|---------|-------------|--------|-------| +| RMS, WL, ZC, SSC | ✓ Already in C | — | Keep | +| MAV, VAR, IEMG | Very easy (1 loop) | None | ✓ Yes | +| WAMP | Very easy (threshold on diff) | None | ✓ Yes | +| Cross-ch covariance | Easy (3×3 outer product) | None | ✓ Yes | +| Cross-ch correlation | Easy (normalize covariance) | Covariance | ✓ Yes | +| Bandpower bp0–bp3 | Medium (128-pt FFT via esp-dsp) | Add FFT call | ✓ Yes — highest ROI | +| MNF, MDF, PKF, MNP | Easy after FFT | Bandpower FFT | ✓ Free once FFT added | +| AR(4) | Medium (Levinson-Durbin in C) | None | Only if top-8 importance | + +Once `dsps_fft2r_fc32()` is added for bandpower, MNF/MDF/PKF/MNP come free. 
+
+### Step 3 — Adding FFT-Based Features to inference.c
+
+Add inside `compute_features()` loop, after time-domain features per channel:
+
+```c
+// 128-pt FFT for frequency-domain features per channel
+// Truncate signal from INFERENCE_WINDOW_SIZE (150) down to the 128-pt FFT length
+float fft_buf[256] = {0}; // 128 complex floats
+for (int i = 0; i < 128 && i < INFERENCE_WINDOW_SIZE; i++) {
+    fft_buf[2*i] = signal[i];   // real
+    fft_buf[2*i+1] = 0.0f;      // imag
+}
+dsps_fft2r_fc32(fft_buf, 128);
+dsps_bit_rev_fc32(fft_buf, 128);
+
+// Bandpower: bin k → freq = k * 1000/128 ≈ k * 7.8125 Hz
+// Band 0: 20–80 Hz → bins 3–10
+// Band 1: 80–150 Hz → bins 10–19
+// Band 2: 150–300 Hz→ bins 19–38
+// Band 3: 300–500 Hz→ bins 38–64
+int band_bins[5] = {3, 10, 19, 38, 64};
+float bp[4] = {0,0,0,0};
+for (int b = 0; b < 4; b++)
+    for (int k = band_bins[b]; k < band_bins[b+1]; k++) {
+        float re = fft_buf[2*k], im = fft_buf[2*k+1];
+        bp[b] += re*re + im*im;
+    }
+// Store at correct indices (base = ch * 20)
+int base = ch * 20;
+features_out[base+16] = bp[0]; features_out[base+17] = bp[1];
+features_out[base+18] = bp[2]; features_out[base+19] = bp[3];
+```
+
+---
+
+# PART VIII — MEASUREMENT AND VALIDATION
+
+## Baseline Protocol
+
+**Run this BEFORE any change and after EACH change.**
+
+```
+1. python learning_data_collection.py → option 3 (Train Classifier)
+2. Record:
+   - "Mean CV accuracy: XX.X% ± Y.Y%" (cross-validation)
+   - Confusion matrix (which gesture pairs are most confused)
+   - Per-gesture accuracy breakdown
+3. On-device test:
+   - Put on sensors, perform 10 reps of each gesture
+   - Log classification output (UART or Python serial monitor)
+   - Compute per-gesture accuracy manually
+4. Record REST false-trigger rate: hold arm at rest for 30 seconds,
+   count number of non-REST outputs
+```
+
+## Results Log
+
+| Change | CV Acc Before | CV Acc After | Delta | On-Device Acc | False Triggers/30s | Keep?
| +|--------|--------------|-------------|-------|---------------|-------------------|-------| +| Baseline | — | — | — | — | — | — | +| Change C (reject) | — | — | — | — | — | — | +| Change B (filter) | — | — | — | — | — | — | +| Change 0 (label shift) | — | — | — | — | — | — | +| Change 1 (features) | — | — | — | — | — | — | +| Change D (NVS calib) | — | — | — | — | — | — | +| Change 3 (augment) | — | — | — | — | — | — | +| Change 5 (benchmark) | — | — | — | — | — | — | +| Change 7+F (ensemble) | — | — | — | — | — | — | +| Change E (MLP) | — | — | — | — | — | — | + +## When to Add More Gestures + +| CV Accuracy | Recommendation | +|-------------|----------------| +| <80% | Do NOT add gestures — fix the existing 5 first | +| 80–90% | Adding 1–2 gestures is reasonable; expect 5–8% drop per new gesture | +| >90% | Good baseline; can add gestures; target staying above 85% | +| >95% | Excellent; can be ambitious with gesture count | + +--- + +# PART IX — EXPORT WORKFLOW + +## Path 1 — LDA / Ensemble (Changes 0–4, 7+F) + +``` +1. Train: python learning_data_collection.py → option 3 (single LDA) + OR: python train_ensemble.py (full ensemble) + +2. Export: + Single LDA: classifier.export_to_header(Path('EMG_Arm/src/core/model_weights.h')) + Ensemble: export_ensemble_header() in train_ensemble.py + → writes model_weights_ensemble.h + +3. Port new features to inference.c (if Change 1 features added): + - Follow feature selection decision matrix (Part VII) + - CRITICAL: C feature index order MUST match Python FEATURE_ORDER exactly + +4. Build + flash: pio run -t upload +``` + +## Path 2 — int8 MLP via TFLM (Change E) + +``` +1. python train_mlp_tflite.py → emg_model_data.cc +2. Add TFLM to platformio.ini lib_deps +3. Replace LDA inference call with inference_mlp_predict() in inference.c + OR use inference_ensemble_predict() which calls MLP as fallback (Change F) +4. 
pio run -t upload +``` + +## Feature Index Contract (Critical) + +The order of values written to `features_out[]` in `compute_features()` in C **must exactly +match** `FEATURE_ORDER` in `extract_features_window()` in Python, index for index. + +To verify before flashing: print both the C feature names (from `MODEL_FEATURE_NAMES` if +added to header) and Python `extractor.get_feature_names()` and diff them. + +--- + +# PART X — REFERENCES + +**Primary paper**: Kaifosh, P., Reardon, T., et al. "A high-bandwidth neuromotor prosthesis +enabled by implicit information in intrinsic motor neurons." *Nature* (2025). +doi:10.1038/s41586-025-09255-w + +**Meta codebase** (label alignment, CLER metric, model architectures): +`C:/VSCode/Marvel_Projects/Meta_Emg_Stuff/generic-neuromotor-interface/` +- `data.py`: onset detection, `searchsorted` alignment, window jitter +- `cler.py`: threshold=0.35, debounce=50ms, tolerance=±50/250ms +- `networks.py`: model architectures, left_context=20, stride=10 +- `lightning.py`: `targets[..., left_context::stride]` label shift + +**Barachant et al. 2012**: "Multiclass brain–computer interface classification by +Riemannian geometry." — matrix logarithm reference (MPF features). 
+ +**Espressif libraries**: +- esp-dsp: `github.com/espressif/esp-dsp` — biquad, FFT, dot-product +- esp-dl: `github.com/espressif/esp-dl` — quantized MLP/CNN inference +- TFLite Micro: `github.com/tensorflow/tflite-micro` + +**All project files** (existing + planned): + +``` +── Laptop / Python ───────────────────────────────────────────────────────────────────────── +C:/VSCode/Marvel_Projects/Bucky_Arm/learning_data_collection.py ← main: data collection + training +C:/VSCode/Marvel_Projects/Bucky_Arm/live_predict.py ← NEW (Part 0.6): laptop-side live inference +C:/VSCode/Marvel_Projects/Bucky_Arm/train_ensemble.py ← NEW (Change 7): ensemble training +C:/VSCode/Marvel_Projects/Bucky_Arm/train_mlp_tflite.py ← NEW (Change E): int8 MLP export + +── ESP32 Firmware — Existing ─────────────────────────────────────────────────────────────── +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/platformio.ini + └─ ADD lib_deps: espressif/esp-dsp (Changes B,1,F), tensorflow/tflite-micro (Change E) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/config/config.h + └─ MODIFY: remove system_mode_t; add EMG_STANDALONE to MAIN_MODE enum (Part 0.7, S1) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/app/main.c + └─ MODIFY: add STATE_LAPTOP_PREDICT, CMD_START_LAPTOP_PREDICT, run_laptop_predict_loop(), + run_standalone_loop() (Part 0.5) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/drivers/emg_sensor.c + └─ MODIFY (Change A): migrate from adc_oneshot to adc_continuous driver +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/inference.c + └─ MODIFY: add inference_get_gesture_by_name(), IIR filter (B), features (1), confidence rejection (C) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/inference.h + └─ MODIFY: add inference_get_gesture_by_name() declaration +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/gestures.c + └─ MODIFY: update gesture_names[] and gestures_execute() when adding gestures +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/model_weights.h + └─ 
AUTO-GENERATED by export_to_header() — do not edit manually + +── ESP32 Firmware — New Files ────────────────────────────────────────────────────────────── +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/bicep.h/.c ← Part 0 / Section 2.2 +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/calibration.h/.c ← Change D (NVS z-score) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/inference_ensemble.h/.c ← Change F +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/inference_mlp.h/.cc ← Change E +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/model_weights_ensemble.h ← AUTO-GENERATED (Change 7) +C:/VSCode/Marvel_Projects/Bucky_Arm/EMG_Arm/src/core/emg_model_data.h/.cc ← AUTO-GENERATED (Change E) +``` diff --git a/EMG_Arm/dependencies.lock b/EMG_Arm/dependencies.lock new file mode 100644 index 0000000..07788a1 --- /dev/null +++ b/EMG_Arm/dependencies.lock @@ -0,0 +1,10 @@ +dependencies: + idf: + source: + type: idf + version: 5.5.1 +direct_dependencies: +- idf +manifest_hash: 26c322f28d1cb305b28c1bcb6df69caf3919f0c18286c9ac6394e338c217fba8 +target: esp32s3 +version: 2.0.0 diff --git a/EMG_Arm/idf_component.yml b/EMG_Arm/idf_component.yml new file mode 100644 index 0000000..b582716 --- /dev/null +++ b/EMG_Arm/idf_component.yml @@ -0,0 +1,10 @@ +# ESP-IDF Component Manager dependencies +# These are ESP-IDF components (NOT PlatformIO libraries). +# Run `pio run --target menuconfig` after modifying to refresh the component index. 
+ +dependencies: + # esp-dsp: required when MODEL_EXPAND_FEATURES=1 (Change 1 — 69-feature FFT) + espressif/esp-dsp: ">=2.0.0" + + # TFLite Micro: required when MODEL_USE_MLP=1 (Change E — int8 MLP) + # tensorflow/tflite-micro: ">=2.0.0" diff --git a/EMG_Arm/platformio.ini b/EMG_Arm/platformio.ini index b9df812..b3137ef 100644 --- a/EMG_Arm/platformio.ini +++ b/EMG_Arm/platformio.ini @@ -14,4 +14,7 @@ board_build.partitions = partitions.csv monitor_speed = 921600 monitor_dtr = 1 -monitor_rts = 1 \ No newline at end of file +monitor_rts = 1 + +; ── esp-dsp: required for MODEL_EXPAND_FEATURES=1 (FFT-based features) ─────── +; Cloned locally: components/esp-dsp \ No newline at end of file diff --git a/EMG_Arm/sdkconfig.esp32-s3-devkitc1-n16r16 b/EMG_Arm/sdkconfig.esp32-s3-devkitc1-n16r16 index 9d99d73..1229419 100644 --- a/EMG_Arm/sdkconfig.esp32-s3-devkitc1-n16r16 +++ b/EMG_Arm/sdkconfig.esp32-s3-devkitc1-n16r16 @@ -594,6 +594,14 @@ CONFIG_PARTITION_TABLE_OFFSET=0x8000 CONFIG_PARTITION_TABLE_MD5=y # end of Partition Table +# +# ESP-NN +# +# CONFIG_NN_ANSI_C is not set +CONFIG_NN_OPTIMIZED=y +CONFIG_NN_OPTIMIZATIONS=1 +# end of ESP-NN + # # Compiler options # @@ -2237,6 +2245,23 @@ CONFIG_WIFI_PROV_AUTOSTOP_TIMEOUT=30 CONFIG_WIFI_PROV_STA_ALL_CHANNEL_SCAN=y # CONFIG_WIFI_PROV_STA_FAST_SCAN is not set # end of Wi-Fi Provisioning Manager + +# +# DSP Library +# +CONFIG_DSP_OPTIMIZATIONS_SUPPORTED=y +# CONFIG_DSP_ANSI is not set +CONFIG_DSP_OPTIMIZED=y +CONFIG_DSP_OPTIMIZATION=1 +# CONFIG_DSP_MAX_FFT_SIZE_512 is not set +# CONFIG_DSP_MAX_FFT_SIZE_1024 is not set +# CONFIG_DSP_MAX_FFT_SIZE_2048 is not set +CONFIG_DSP_MAX_FFT_SIZE_4096=y +# CONFIG_DSP_MAX_FFT_SIZE_8192 is not set +# CONFIG_DSP_MAX_FFT_SIZE_16384 is not set +# CONFIG_DSP_MAX_FFT_SIZE_32768 is not set +CONFIG_DSP_MAX_FFT_SIZE=4096 +# end of DSP Library # end of Component config # CONFIG_IDF_EXPERIMENTAL_FEATURES is not set diff --git a/EMG_Arm/src/CMakeLists.txt b/EMG_Arm/src/CMakeLists.txt index 
9a7dd7d..61daea5 100644 --- a/EMG_Arm/src/CMakeLists.txt +++ b/EMG_Arm/src/CMakeLists.txt @@ -19,8 +19,13 @@ set(DRIVER_SOURCES ) set(CORE_SOURCES + core/bicep.c + core/calibration.c core/gestures.c core/inference.c + core/inference_ensemble.c + core/inference_mlp.cc + core/emg_model_data.cc ) set(APP_SOURCES @@ -36,5 +41,5 @@ idf_component_register( ${APP_SOURCES} INCLUDE_DIRS . - REQUIRES esp_adc + REQUIRES esp_adc nvs_flash esp-dsp esp-tflite-micro ) diff --git a/EMG_Arm/src/app/main.c b/EMG_Arm/src/app/main.c index 38db647..ccc1528 100644 --- a/EMG_Arm/src/app/main.c +++ b/EMG_Arm/src/app/main.c @@ -20,8 +20,13 @@ #include #include "config/config.h" +#include "core/bicep.h" +#include "core/calibration.h" #include "core/gestures.h" -#include "core/inference.h" // [NEW] +#include "core/inference.h" +#include "core/inference_ensemble.h" +#include "core/inference_mlp.h" +#include "core/model_weights.h" #include "drivers/emg_sensor.h" #include "drivers/hand.h" @@ -40,10 +45,12 @@ * @brief Device state machine. 
*/ typedef enum { - STATE_IDLE = 0, /**< Waiting for connect command */ - STATE_CONNECTED, /**< Connected, waiting for start command */ - STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */ - STATE_PREDICTING, /**< [NEW] On-device inference and control */ + STATE_IDLE = 0, /**< Waiting for connect command */ + STATE_CONNECTED, /**< Connected, waiting for start command */ + STATE_STREAMING, /**< Streaming raw EMG CSV to laptop (data collection) */ + STATE_PREDICTING, /**< On-device inference + arm control */ + STATE_LAPTOP_PREDICT, /**< Streaming CSV to laptop; laptop infers + sends gesture cmds back */ + STATE_CALIBRATING, /**< Collecting rest data for calibration */ } device_state_t; /** @@ -52,8 +59,10 @@ typedef enum { typedef enum { CMD_NONE = 0, CMD_CONNECT, - CMD_START, /**< Start raw streaming */ - CMD_START_PREDICT, /**< [NEW] Start on-device prediction */ + CMD_START, /**< Start raw ADC streaming to laptop */ + CMD_START_PREDICT, /**< Start on-device inference + arm control */ + CMD_START_LAPTOP_PREDICT, /**< Start laptop-mediated inference (stream + receive cmds) */ + CMD_CALIBRATE, /**< Run rest calibration (z-score + bicep threshold) */ CMD_STOP, CMD_DISCONNECT, } command_t; @@ -65,11 +74,17 @@ typedef enum { static volatile device_state_t g_device_state = STATE_IDLE; static QueueHandle_t g_cmd_queue = NULL; +// Latest gesture command received from laptop during STATE_LAPTOP_PREDICT. +// Written by serial_input_task; read+cleared by run_laptop_predict_loop. +// gesture_t is a 32-bit int on LX7 — reads/writes are atomic; volatile is sufficient. 
+static volatile gesture_t g_laptop_gesture = GESTURE_NONE; + /******************************************************************************* * Forward Declarations ******************************************************************************/ static void send_ack_connect(void); +static gesture_t parse_laptop_gesture(const char *line); /******************************************************************************* * Command Parsing @@ -102,13 +117,17 @@ static command_t parse_command(const char *line) { value_start++; } - /* Match command strings */ + /* Match command strings — ordered longest-prefix-first to avoid false matches */ if (strncmp(value_start, "connect", 7) == 0) { return CMD_CONNECT; + } else if (strncmp(value_start, "start_laptop_predict", 20) == 0) { + return CMD_START_LAPTOP_PREDICT; } else if (strncmp(value_start, "start_predict", 13) == 0) { return CMD_START_PREDICT; } else if (strncmp(value_start, "start", 5) == 0) { return CMD_START; + } else if (strncmp(value_start, "calibrate", 9) == 0) { + return CMD_CALIBRATE; } else if (strncmp(value_start, "stop", 4) == 0) { return CMD_STOP; } else if (strncmp(value_start, "disconnect", 10) == 0) { @@ -118,6 +137,38 @@ static command_t parse_command(const char *line) { return CMD_NONE; } +/******************************************************************************* + * Laptop Gesture Parser + ******************************************************************************/ + +/** + * @brief Parse a gesture command sent by live_predict.py. + * + * Expected format: {"gesture":"fist"} + * Returns GESTURE_NONE if the line is not a valid gesture command. 
+ */ +static gesture_t parse_laptop_gesture(const char *line) { + const char *g = strstr(line, "\"gesture\""); + if (!g) return GESTURE_NONE; + + const char *v = strchr(g, ':'); + if (!v) return GESTURE_NONE; + + v++; + while (*v == ' ' || *v == '"') v++; + + // Extract gesture name up to the closing quote + char name[32] = {0}; + int ni = 0; + while (*v && *v != '"' && ni < (int)(sizeof(name) - 1)) { + name[ni++] = *v++; + } + + // Delegate to the inference module's name→enum mapping + int result = inference_get_gesture_by_name(name); + return (gesture_t)result; +} + /******************************************************************************* * Serial Input Task ******************************************************************************/ @@ -142,6 +193,15 @@ static void serial_input_task(void *pvParameters) { line_buffer[line_idx] = '\0'; command_t cmd = parse_command(line_buffer); + // When laptop-predict is active, try to parse gesture commands first. + // These are separate from FSM commands and must not block the FSM parser. 
+ if (g_device_state == STATE_LAPTOP_PREDICT) { + gesture_t g = parse_laptop_gesture(line_buffer); + if (g != GESTURE_NONE) { + g_laptop_gesture = g; + } + } + if (cmd != CMD_NONE) { if (cmd == CMD_CONNECT) { g_device_state = STATE_CONNECTED; @@ -161,6 +221,15 @@ static void serial_input_task(void *pvParameters) { g_device_state = STATE_PREDICTING; printf("[STATE] CONNECTED -> PREDICTING\n"); xQueueSend(g_cmd_queue, &cmd, 0); + } else if (cmd == CMD_START_LAPTOP_PREDICT) { + g_device_state = STATE_LAPTOP_PREDICT; + g_laptop_gesture = GESTURE_NONE; + printf("[STATE] CONNECTED -> LAPTOP_PREDICT\n"); + xQueueSend(g_cmd_queue, &cmd, 0); + } else if (cmd == CMD_CALIBRATE) { + g_device_state = STATE_CALIBRATING; + printf("[STATE] CONNECTED -> CALIBRATING\n"); + xQueueSend(g_cmd_queue, &cmd, 0); } else if (cmd == CMD_DISCONNECT) { g_device_state = STATE_IDLE; printf("[STATE] CONNECTED -> IDLE\n"); @@ -169,6 +238,8 @@ static void serial_input_task(void *pvParameters) { case STATE_STREAMING: case STATE_PREDICTING: + case STATE_LAPTOP_PREDICT: + case STATE_CALIBRATING: if (cmd == CMD_STOP) { g_device_state = STATE_CONNECTED; printf("[STATE] ACTIVE -> CONNECTED\n"); @@ -206,59 +277,198 @@ static void send_ack_connect(void) { */ static void stream_emg_data(void) { emg_sample_t sample; - const TickType_t delay_ticks = 1; while (g_device_state == STATE_STREAMING) { - emg_sensor_read(&sample); + emg_sensor_read(&sample); // blocks ~1 ms (queue-paced by DMA task) printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1], sample.channels[2], sample.channels[3]); - vTaskDelay(delay_ticks); } } +/******************************************************************************* + * Multi-Model Voting Post-Processing + * + * Shared EMA smoothing, majority vote, and debounce applied to the averaged + * probability output from all enabled models (single LDA, ensemble, MLP). 
+ ******************************************************************************/ + +#define VOTE_EMA_ALPHA 0.70f +#define VOTE_CONF_THRESHOLD 0.40f +#define VOTE_WINDOW_SIZE 5 +#define VOTE_DEBOUNCE_COUNT 3 + +static float vote_smoothed[MODEL_NUM_CLASSES]; +static int vote_history[VOTE_WINDOW_SIZE]; +static int vote_head = 0; +static int vote_current_output = -1; +static int vote_pending_output = -1; +static int vote_pending_count = 0; + +static void vote_init(void) { + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + vote_smoothed[c] = 1.0f / MODEL_NUM_CLASSES; + for (int i = 0; i < VOTE_WINDOW_SIZE; i++) + vote_history[i] = -1; + vote_head = 0; + vote_current_output = -1; + vote_pending_output = -1; + vote_pending_count = 0; +} + +/** + * @brief Apply EMA + majority vote + debounce to averaged probabilities. + * + * @param avg_proba Averaged probability vector from all models. + * @param confidence Output: smoothed confidence of the final winner. + * @return Final gesture class index (-1 if uncertain). 
+ */ +static int vote_postprocess(const float *avg_proba, float *confidence) { + /* EMA smoothing */ + float max_smooth = 0.0f; + int winner = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + vote_smoothed[c] = VOTE_EMA_ALPHA * vote_smoothed[c] + + (1.0f - VOTE_EMA_ALPHA) * avg_proba[c]; + if (vote_smoothed[c] > max_smooth) { + max_smooth = vote_smoothed[c]; + winner = c; + } + } + + /* Confidence rejection */ + if (max_smooth < VOTE_CONF_THRESHOLD) { + *confidence = max_smooth; + return vote_current_output; + } + + *confidence = max_smooth; + + /* Majority vote */ + vote_history[vote_head] = winner; + vote_head = (vote_head + 1) % VOTE_WINDOW_SIZE; + + int counts[MODEL_NUM_CLASSES]; + memset(counts, 0, sizeof(counts)); + for (int i = 0; i < VOTE_WINDOW_SIZE; i++) { + if (vote_history[i] >= 0) + counts[vote_history[i]]++; + } + int majority = 0, majority_cnt = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + if (counts[c] > majority_cnt) { + majority_cnt = counts[c]; + majority = c; + } + } + + /* Debounce */ + int final_out = (vote_current_output >= 0) ? vote_current_output : majority; + + if (vote_current_output < 0) { + vote_current_output = majority; + final_out = majority; + } else if (majority == vote_current_output) { + vote_pending_output = majority; + vote_pending_count = 1; + } else if (majority == vote_pending_output) { + if (++vote_pending_count >= VOTE_DEBOUNCE_COUNT) { + vote_current_output = majority; + final_out = majority; + } + } else { + vote_pending_output = majority; + vote_pending_count = 1; + } + + return final_out; +} + /** * @brief Run on-device inference (Prediction Mode). + * + * Uses multi-model voting: all enabled models (single LDA, ensemble, MLP) + * produce raw probability vectors which are averaged, then passed through + * shared EMA smoothing, majority vote, and debounce. 
*/ static void run_inference_loop(void) { emg_sample_t sample; - const TickType_t delay_ticks = 1; // 1ms @ 1kHz int last_gesture = -1; + int stride_counter = 0; - // Reset inference state inference_init(); - printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n"); + vote_init(); + + int n_models = 1; /* single LDA always enabled */ +#if MODEL_USE_ENSEMBLE + inference_ensemble_init(); + n_models++; +#endif +#if MODEL_USE_MLP + inference_mlp_init(); + n_models++; +#endif + + printf("{\"status\":\"info\",\"msg\":\"Multi-model inference (%d models)\"}\n", + n_models); while (g_device_state == STATE_PREDICTING) { emg_sensor_read(&sample); - // Add to buffer - // Note: sample.channels is uint16_t, matching inference engine expectation if (inference_add_sample(sample.channels)) { - // Buffer full (sliding window), run prediction - // We can optimize stride here (e.g. valid prediction only every N - // samples) For now, let's predict every sample (sliding window) or - // throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz? - // maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid - // UART spam. - - static int stride_counter = 0; stride_counter++; - if (stride_counter >= 20) { // 20ms stride - float confidence = 0; - int gesture_idx = inference_predict(&confidence); + if (stride_counter >= INFERENCE_HOP_SIZE) { stride_counter = 0; - if (gesture_idx >= 0) { - // Map class index (0-N) to gesture enum (correct hardware action) - int gesture_enum = inference_get_gesture_enum(gesture_idx); + /* 1. Extract features once */ + float features[MODEL_NUM_FEATURES]; + inference_extract_features(features); - // Execute gesture on hand + /* 2. 
Collect raw probabilities from each model */ + float avg_proba[MODEL_NUM_CLASSES]; + memset(avg_proba, 0, sizeof(avg_proba)); + + /* Model A: Single LDA (always active) */ + float proba_lda[MODEL_NUM_CLASSES]; + inference_predict_raw(features, proba_lda); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += proba_lda[c]; + +#if MODEL_USE_ENSEMBLE + /* Model B: 3-specialist + meta-LDA ensemble */ + float proba_ens[MODEL_NUM_CLASSES]; + inference_ensemble_predict_raw(features, proba_ens); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += proba_ens[c]; +#endif + +#if MODEL_USE_MLP + /* Model C: int8 MLP — returns class index + confidence. + * Spread confidence as soft one-hot for averaging. */ + float mlp_conf = 0.0f; + int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES, + &mlp_conf); + float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += (c == mlp_class) ? mlp_conf : remainder; +#endif + + /* 3. Average across models */ + float inv_n = 1.0f / n_models; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] *= inv_n; + + /* 4. Shared post-processing */ + float confidence = 0.0f; + int gesture_idx = vote_postprocess(avg_proba, &confidence); + + if (gesture_idx >= 0) { + int gesture_enum = inference_get_gesture_enum(gesture_idx); gestures_execute((gesture_t)gesture_enum); - // Send telemetry if changed or periodically? - // "Live prediction flow should change to only have each new output... - // sent" + bicep_state_t bicep = bicep_detect(); + (void)bicep; + if (gesture_idx != last_gesture) { printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n", inference_get_class_name(gesture_idx), confidence); @@ -267,11 +477,192 @@ static void run_inference_loop(void) { } } } - - vTaskDelay(delay_ticks); } } +/** + * @brief Laptop-mediated prediction loop (STATE_LAPTOP_PREDICT). + * + * Streams raw ADC CSV to laptop at 1kHz (same format as STATE_STREAMING). 
+ * The laptop runs live_predict.py, which classifies each window and sends
+ * {"gesture":"<name>"} back over UART. serial_input_task() intercepts those
+ * lines and writes to g_laptop_gesture; this loop executes whatever is set.
+ */
+static void run_laptop_predict_loop(void) {
+  emg_sample_t sample;
+  gesture_t last_gesture = GESTURE_NONE;
+
+  printf("{\"status\":\"info\",\"msg\":\"Laptop-predict mode started\"}\n");
+
+  while (g_device_state == STATE_LAPTOP_PREDICT) {
+    emg_sensor_read(&sample);
+    // Stream raw CSV to laptop (live_predict.py reads this)
+    printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
+           sample.channels[2], sample.channels[3]);
+
+    // Execute any gesture command that arrived from the laptop
+    gesture_t g = g_laptop_gesture;
+    if (g != GESTURE_NONE) {
+      g_laptop_gesture = GESTURE_NONE; // clear before executing
+      gestures_execute(g);
+      if (g != last_gesture) {
+        // Echo the executed gesture back for laptop-side logging
+        printf("{\"executed\":\"%s\"}\n", gestures_get_name(g));
+        last_gesture = g;
+      }
+    }
+  }
+}
+
+/**
+ * @brief Fully autonomous inference loop (MAIN_MODE == EMG_STANDALONE).
+ *
+ * No laptop required. Runs forever until power-off.
+ * Assumes sensor init has already been called by app_main(); inference_init() runs inside this loop.
+ */ +static void run_standalone_loop(void) { + emg_sample_t sample; + int stride_counter = 0; + int last_gesture = -1; + + inference_init(); + vote_init(); + + int n_models = 1; +#if MODEL_USE_ENSEMBLE + inference_ensemble_init(); + n_models++; +#endif +#if MODEL_USE_MLP + inference_mlp_init(); + n_models++; +#endif + + printf("[STANDALONE] Running autonomous EMG control (%d models).\n", n_models); + + while (1) { + emg_sensor_read(&sample); + + if (inference_add_sample(sample.channels)) { + stride_counter++; + if (stride_counter >= INFERENCE_HOP_SIZE) { + stride_counter = 0; + + float features[MODEL_NUM_FEATURES]; + inference_extract_features(features); + + float avg_proba[MODEL_NUM_CLASSES]; + memset(avg_proba, 0, sizeof(avg_proba)); + + float proba_lda[MODEL_NUM_CLASSES]; + inference_predict_raw(features, proba_lda); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += proba_lda[c]; + +#if MODEL_USE_ENSEMBLE + float proba_ens[MODEL_NUM_CLASSES]; + inference_ensemble_predict_raw(features, proba_ens); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += proba_ens[c]; +#endif + +#if MODEL_USE_MLP + float mlp_conf = 0.0f; + int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES, + &mlp_conf); + float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1); + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] += (c == mlp_class) ? 
mlp_conf : remainder; +#endif + + float inv_n = 1.0f / n_models; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + avg_proba[c] *= inv_n; + + float confidence = 0.0f; + int gesture_idx = vote_postprocess(avg_proba, &confidence); + + if (gesture_idx >= 0) { + int gesture_enum = inference_get_gesture_enum(gesture_idx); + gestures_execute((gesture_t)gesture_enum); + + bicep_state_t bicep = bicep_detect(); + (void)bicep; + + if (gesture_idx != last_gesture) { + printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n", + inference_get_class_name(gesture_idx), confidence); + last_gesture = gesture_idx; + } + } + } + } + } +} + +/** + * @brief Run rest calibration (STATE_CALIBRATING). + * + * Collects 3 seconds of rest EMG data, extracts features from each window, + * then updates: + * - NVS z-score calibration (Change D) via calibration_update() + * - Bicep detection threshold via bicep_calibrate_from_buffer() + * + * The user must keep their arm relaxed during this period. + * Send {"cmd": "calibrate"} from the host to trigger. 
+ */ +static void run_calibration(void) { + #define CALIB_DURATION_SAMPLES 3000 /* 3 seconds at 1 kHz */ + #define CALIB_MAX_WINDOWS \ + ((CALIB_DURATION_SAMPLES - INFERENCE_WINDOW_SIZE) / INFERENCE_HOP_SIZE + 1) + + printf("{\"status\":\"calibrating\",\"duration_ms\":3000}\n"); + fflush(stdout); + + inference_init(); /* reset buffer + filter state */ + + float *feat_matrix = (float *)malloc( + (size_t)CALIB_MAX_WINDOWS * MODEL_NUM_FEATURES * sizeof(float)); + if (!feat_matrix) { + printf("{\"status\":\"error\",\"msg\":\"Calibration malloc failed\"}\n"); + g_device_state = STATE_CONNECTED; + return; + } + + emg_sample_t sample; + int window_count = 0; + int stride_counter = 0; + + for (int s = 0; s < CALIB_DURATION_SAMPLES; s++) { + emg_sensor_read(&sample); + if (inference_add_sample(sample.channels)) { + stride_counter++; + if (stride_counter >= INFERENCE_HOP_SIZE) { + stride_counter = 0; + if (window_count < CALIB_MAX_WINDOWS) { + inference_extract_features( + feat_matrix + window_count * MODEL_NUM_FEATURES); + window_count++; + } + } + } + } + + if (window_count >= 10) { + calibration_update(feat_matrix, window_count, MODEL_NUM_FEATURES); + bicep_calibrate_from_buffer(INFERENCE_WINDOW_SIZE); + printf("{\"status\":\"calibrated\",\"windows\":%d}\n", window_count); + } else { + printf("{\"status\":\"error\",\"msg\":\"Not enough calibration data\"}\n"); + } + + free(feat_matrix); + g_device_state = STATE_CONNECTED; + + #undef CALIB_DURATION_SAMPLES + #undef CALIB_MAX_WINDOWS +} + static void state_machine_loop(void) { command_t cmd; const TickType_t poll_interval = pdMS_TO_TICKS(50); @@ -281,6 +672,10 @@ static void state_machine_loop(void) { stream_emg_data(); } else if (g_device_state == STATE_PREDICTING) { run_inference_loop(); + } else if (g_device_state == STATE_LAPTOP_PREDICT) { + run_laptop_predict_loop(); + } else if (g_device_state == STATE_CALIBRATING) { + run_calibration(); } xQueueReceive(g_cmd_queue, &cmd, poll_interval); @@ -299,7 +694,9 @@ void 
appConnector() { printf("[PROTOCOL] Waiting for host to connect...\n"); printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n"); printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device " - "inference\n\n"); + "inference\n"); + printf("[PROTOCOL] Send: {\"cmd\": \"calibrate\"} for rest " + "calibration (3s)\n\n"); state_machine_loop(); } @@ -324,13 +721,108 @@ void app_main(void) { printf("[INIT] Initializing Inference Engine...\n"); inference_init(); -#if FEATURE_FAKE_EMG - printf("[INIT] Using FAKE EMG data (sensors not connected)\n"); -#else - printf("[INIT] Using REAL EMG sensors\n"); -#endif + printf("[INIT] Loading NVS calibration...\n"); + calibration_init(); // Change D: no-op on first boot; loads if previously saved + // Bicep: load persisted threshold from NVS (if previously calibrated) + { + float bicep_thresh = 0.0f; + if (bicep_load_threshold(&bicep_thresh)) { + printf("[INIT] Bicep threshold loaded: %.1f\n", bicep_thresh); + } else { + printf("[INIT] No bicep calibration — run 'calibrate' command\n"); + } + } + + printf("[INIT] Using REAL EMG sensors\n"); printf("[INIT] Done!\n\n"); - appConnector(); -} + switch (MAIN_MODE) { + case EMG_STANDALONE: + // Fully autonomous: no laptop needed after this point. + // Boots directly into the inference + arm control loop. 
+ run_standalone_loop(); // never returns + break; + + case SERVO_CALIBRATOR: + while (1) { + int angle; + printf("Enter servo angle (0-180): "); + fflush(stdout); + + // Read a line manually, yielding while waiting for UART input + char buf[16]; + int idx = 0; + while (idx < (int)sizeof(buf) - 1) { + int ch = getchar(); + if (ch == EOF) { + vTaskDelay(pdMS_TO_TICKS(10)); + continue; + } + if (ch == '\n' || ch == '\r') + break; + buf[idx++] = (char)ch; + } + buf[idx] = '\0'; + + if (idx == 0) + continue; + + if (sscanf(buf, "%d", &angle) == 1) { + if (angle >= 0 && angle <= 180) { + hand_set_finger_angle(FINGER_THUMB, angle); + vTaskDelay(pdMS_TO_TICKS(1000)); + } else { + printf("Invalid angle. Must be between 0 and 180.\n"); + } + } else { + printf("Invalid input.\n"); + } + } + break; + + case GESTURE_TESTER: + while (1) { + fflush(stdout); + + int ch = getchar(); + + if (ch == EOF) { + vTaskDelay(pdMS_TO_TICKS(10)); + continue; + } + + if (ch == '\n' || ch == '\r') { + continue; + } + + gesture_t gesture = GESTURE_NONE; + + switch (ch) { + case 'r': gesture = GESTURE_REST; break; + case 'f': gesture = GESTURE_FIST; break; + case 'o': gesture = GESTURE_OPEN; break; + case 'h': gesture = GESTURE_HOOK_EM; break; + case 't': gesture = GESTURE_THUMBS_UP; break; + default: + printf("Invalid gesture: %c\n", ch); + continue; + } + + printf("Executing gesture: %s\n", gestures_get_name(gesture)); + gestures_execute(gesture); + + vTaskDelay(pdMS_TO_TICKS(500)); + } + + break; + + case EMG_MAIN: + appConnector(); + break; + + default: + printf("[ERROR] Unknown MAIN_MODE\n"); + break; + } +} \ No newline at end of file diff --git a/EMG_Arm/src/config/config.h b/EMG_Arm/src/config/config.h index 7fb8cd4..16057b0 100644 --- a/EMG_Arm/src/config/config.h +++ b/EMG_Arm/src/config/config.h @@ -13,19 +13,12 @@ #include "driver/ledc.h" /******************************************************************************* - * Feature Flags - * - * Compile-time switches to enable/disable 
features. - * Set to 1 to enable, 0 to disable. + * Main Modes ******************************************************************************/ -/** - * @brief Use fake EMG data (random values) instead of real ADC reads. - * - * Set to 1 while waiting for EMG sensors to arrive. - * Set to 0 when ready to use real sensors. - */ -#define FEATURE_FAKE_EMG 0 +enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER, EMG_STANDALONE}; + +#define MAIN_MODE EMG_MAIN /******************************************************************************* * GPIO Pin Definitions - Servos @@ -95,15 +88,4 @@ typedef enum { GESTURE_COUNT } gesture_t; -/** - * @brief System operating modes. - */ -typedef enum { - MODE_IDLE = 0, /**< Waiting for commands */ - MODE_DATA_STREAM, /**< Streaming EMG data to laptop */ - MODE_COMMAND, /**< Executing gesture commands from laptop */ - MODE_DEMO, /**< Running demo sequence */ - MODE_COUNT -} system_mode_t; - #endif /* CONFIG_H */ diff --git a/EMG_Arm/src/core/bicep.c b/EMG_Arm/src/core/bicep.c new file mode 100644 index 0000000..c6733f0 --- /dev/null +++ b/EMG_Arm/src/core/bicep.c @@ -0,0 +1,142 @@ +/** + * @file bicep.c + * @brief Bicep channel subsystem — binary flex/unflex detector (Phase 1). 
+ */ + +#include "bicep.h" +#include "inference.h" /* inference_get_bicep_rms() */ +#include "nvs_flash.h" +#include "nvs.h" +#include +#include +#include + +/* Tuning constants */ +#define BICEP_WINDOW_SAMPLES 50 /**< 50 ms window at 1 kHz */ +#define BICEP_FLEX_MULTIPLIER 2.5f /**< threshold = rest_rms × 2.5 */ +#define BICEP_HYSTERESIS 1.3f /**< scale factor to enter flex (prevents toggling) */ + +/* NVS storage */ +#define BICEP_NVS_NAMESPACE "bicep_calib" +#define BICEP_NVS_KEY_THRESH "threshold" +#define BICEP_NVS_KEY_VALID "calib_ok" + +/* Module state */ +static float s_threshold_mv = 0.0f; +static bicep_state_t s_state = BICEP_STATE_REST; + +/******************************************************************************* + * Public API + ******************************************************************************/ + +float bicep_calibrate(const uint16_t *ch3_samples, int n_samples) { + if (n_samples <= 0) return 0.0f; + + float rms_sq = 0.0f; + for (int i = 0; i < n_samples; i++) { + float v = (float)ch3_samples[i]; + rms_sq += v * v; + } + float rest_rms = sqrtf(rms_sq / n_samples); + s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER; + s_state = BICEP_STATE_REST; + + printf("[Bicep] Calibrated: rest_rms=%.1f mV, threshold=%.1f mV\n", + rest_rms, s_threshold_mv); + + bicep_save_threshold(s_threshold_mv); + return s_threshold_mv; +} + +bicep_state_t bicep_detect(void) { + if (s_threshold_mv <= 0.0f) { + return BICEP_STATE_REST; /* Not calibrated */ + } + + float rms = inference_get_bicep_rms(BICEP_WINDOW_SAMPLES); + + /* Hysteretic threshold: need FLEX_MULTIPLIER × threshold to enter flex, + * drop below threshold to return to rest. 
*/ + if (s_state == BICEP_STATE_REST) { + if (rms > s_threshold_mv * BICEP_HYSTERESIS) { + s_state = BICEP_STATE_FLEX; + } + } else { /* BICEP_STATE_FLEX */ + if (rms < s_threshold_mv) { + s_state = BICEP_STATE_REST; + } + } + + return s_state; +} + +bool bicep_save_threshold(float threshold_mv) { + nvs_handle_t h; + if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) { + printf("[Bicep] Failed to open NVS for write\n"); + return false; + } + + esp_err_t err = ESP_OK; + err |= nvs_set_blob(h, BICEP_NVS_KEY_THRESH, &threshold_mv, sizeof(threshold_mv)); + err |= nvs_set_u8 (h, BICEP_NVS_KEY_VALID, 1u); + err |= nvs_commit(h); + nvs_close(h); + + if (err != ESP_OK) { + printf("[Bicep] NVS write failed (err=0x%x)\n", err); + return false; + } + printf("[Bicep] Threshold %.1f mV saved to NVS\n", threshold_mv); + return true; +} + +bool bicep_load_threshold(float *threshold_mv_out) { + nvs_handle_t h; + if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) { + return false; + } + + uint8_t valid = 0; + float thresh = 0.0f; + size_t sz = sizeof(thresh); + + bool ok = (nvs_get_u8 (h, BICEP_NVS_KEY_VALID, &valid) == ESP_OK) && + (valid == 1) && + (nvs_get_blob(h, BICEP_NVS_KEY_THRESH, &thresh, &sz) == ESP_OK) && + (thresh > 0.0f); + nvs_close(h); + + if (ok) { + s_threshold_mv = thresh; + if (threshold_mv_out) *threshold_mv_out = thresh; + printf("[Bicep] Loaded threshold: %.1f mV\n", thresh); + } + return ok; +} + +float bicep_calibrate_from_buffer(int n_samples) { + float rest_rms = inference_get_bicep_rms(n_samples); + if (rest_rms < 1e-6f) { + printf("[Bicep] WARNING: rest RMS ≈ 0 — buffer may not be filled yet\n"); + return 0.0f; + } + + s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER; + s_state = BICEP_STATE_REST; + + printf("[Bicep] Calibrated (filtered): rest_rms=%.2f, threshold=%.2f\n", + rest_rms, s_threshold_mv); + + bicep_save_threshold(s_threshold_mv); + return s_threshold_mv; +} + +void bicep_set_threshold(float threshold_mv) { + 
s_threshold_mv = threshold_mv; + s_state = BICEP_STATE_REST; +} + +float bicep_get_threshold(void) { + return s_threshold_mv; +} diff --git a/EMG_Arm/src/core/bicep.h b/EMG_Arm/src/core/bicep.h new file mode 100644 index 0000000..3b6aae4 --- /dev/null +++ b/EMG_Arm/src/core/bicep.h @@ -0,0 +1,97 @@ +/** + * @file bicep.h + * @brief Bicep channel (ch3) subsystem — Phase 1: binary flex/unflex detector. + * + * Implements a simple RMS threshold detector with hysteresis for bicep activation. + * ch3 data flows through the same IIR bandpass filter and circular buffer as the + * hand gesture channels (via inference_get_bicep_rms()), so no separate ADC read + * is required. + * + * Usage: + * 1. On startup: bicep_load_threshold(&thresh) — restore persisted threshold + * 2. After 3 s of relaxed rest: + * bicep_calibrate(raw_ch3_samples, n_samples) — sets + saves threshold + * 3. Every 25 ms hop: + * bicep_state_t state = bicep_detect(); + */ + +#ifndef BICEP_H +#define BICEP_H + +#include +#include + +/** + * @brief Bicep activation state. + */ +typedef enum { + BICEP_STATE_REST = 0, + BICEP_STATE_FLEX = 1, +} bicep_state_t; + +/** + * @brief Calibrate bicep threshold from REST data. + * + * Computes rest-RMS over the provided samples, then sets the internal + * detection threshold to rest_rms × BICEP_FLEX_MULTIPLIER. + * + * @param ch3_samples Raw ADC / mV values from the bicep channel. + * @param n_samples Number of samples provided. + * @return Computed threshold in the same units as ch3_samples. + */ +float bicep_calibrate(const uint16_t *ch3_samples, int n_samples); + +/** + * @brief Detect current bicep state from the latest window. + * + * Uses inference_get_bicep_rms(BICEP_WINDOW_SAMPLES) internally, so + * inference_add_sample() must have been called to fill the buffer first. + * + * @return BICEP_STATE_FLEX or BICEP_STATE_REST. + */ +bicep_state_t bicep_detect(void); + +/** + * @brief Persist the current threshold to NVS. 
+ * + * @param threshold_mv Threshold value to save (in mV / same units as bicep RMS). + * @return true on success. + */ +bool bicep_save_threshold(float threshold_mv); + +/** + * @brief Load the persisted threshold from NVS. + * + * @param threshold_mv_out Output pointer; untouched on failure. + * @return true if a valid threshold was loaded. + */ +bool bicep_load_threshold(float *threshold_mv_out); + +/** + * @brief Set the detection threshold directly (without NVS save). + */ +void bicep_set_threshold(float threshold_mv); + +/** + * @brief Return the current threshold (0 if not calibrated). + */ +float bicep_get_threshold(void); + +/** + * @brief Calibrate bicep threshold from the filtered inference buffer. + * + * Uses inference_get_bicep_rms() to read from the bandpass-filtered + * circular buffer — the same data source that bicep_detect() uses. + * The old bicep_calibrate() accepts raw uint16_t ADC values, which are + * in a different domain (includes DC offset) and produce unusable thresholds. + * + * Call this after the inference buffer has been filled with ≥ n_samples + * of rest data via inference_add_sample(). + * + * @param n_samples Number of recent buffer samples to use for RMS. + * Clamped to INFERENCE_WINDOW_SIZE internally. + * @return Computed threshold (same units as bicep_detect sees). + */ +float bicep_calibrate_from_buffer(int n_samples); + +#endif /* BICEP_H */ diff --git a/EMG_Arm/src/core/calibration.c b/EMG_Arm/src/core/calibration.c new file mode 100644 index 0000000..9744685 --- /dev/null +++ b/EMG_Arm/src/core/calibration.c @@ -0,0 +1,138 @@ +/** + * @file calibration.c + * @brief NVS-backed z-score feature calibration (Change D). 
+ */ + +#include "calibration.h" +#include "nvs_flash.h" +#include "nvs.h" +#include +#include +#include + +#define NVS_NAMESPACE "emg_calib" +#define NVS_KEY_MEAN "feat_mean" +#define NVS_KEY_STD "feat_std" +#define NVS_KEY_NFEAT "n_feat" +#define NVS_KEY_VALID "calib_ok" + +static float s_mean[CALIB_MAX_FEATURES]; +static float s_std[CALIB_MAX_FEATURES]; +static int s_n_feat = 0; +static bool s_valid = false; + +bool calibration_init(void) { + /* Standard NVS flash initialisation boilerplate */ + esp_err_t err = nvs_flash_init(); + if (err == ESP_ERR_NVS_NO_FREE_PAGES || err == ESP_ERR_NVS_NEW_VERSION_FOUND) { + nvs_flash_erase(); + nvs_flash_init(); + } + + nvs_handle_t h; + if (nvs_open(NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) { + printf("[Calib] No NVS partition found — identity transform active\n"); + return false; + } + + uint8_t valid = 0; + int32_t n_feat = 0; + size_t mean_sz = sizeof(s_mean); + size_t std_sz = sizeof(s_std); + + bool ok = (nvs_get_u8 (h, NVS_KEY_VALID, &valid) == ESP_OK) && + (valid == 1) && + (nvs_get_i32 (h, NVS_KEY_NFEAT, &n_feat) == ESP_OK) && + (n_feat > 0 && n_feat <= CALIB_MAX_FEATURES) && + (nvs_get_blob(h, NVS_KEY_MEAN, s_mean, &mean_sz) == ESP_OK) && + (nvs_get_blob(h, NVS_KEY_STD, s_std, &std_sz) == ESP_OK); + + nvs_close(h); + + if (ok) { + s_n_feat = (int)n_feat; + s_valid = true; + printf("[Calib] Loaded from NVS (%d features)\n", s_n_feat); + } else { + printf("[Calib] No valid calibration in NVS — identity transform active\n"); + } + + return ok; +} + +void calibration_apply(float *feat) { + if (!s_valid) return; + for (int i = 0; i < s_n_feat; i++) { + feat[i] = (feat[i] - s_mean[i]) / s_std[i]; + } +} + +bool calibration_update(const float *X_flat, int n_windows, int n_feat) { + if (n_windows < 10 || n_feat <= 0 || n_feat > CALIB_MAX_FEATURES) { + printf("[Calib] calibration_update: invalid args (%d windows, %d features)\n", + n_windows, n_feat); + return false; + } + + s_n_feat = n_feat; + + /* Compute per-feature 
mean */ + memset(s_mean, 0, sizeof(s_mean)); + for (int w = 0; w < n_windows; w++) { + for (int f = 0; f < n_feat; f++) { + s_mean[f] += X_flat[w * n_feat + f]; + } + } + for (int f = 0; f < n_feat; f++) { + s_mean[f] /= n_windows; + } + + /* Compute per-feature std (with epsilon floor) */ + memset(s_std, 0, sizeof(s_std)); + for (int w = 0; w < n_windows; w++) { + for (int f = 0; f < n_feat; f++) { + float d = X_flat[w * n_feat + f] - s_mean[f]; + s_std[f] += d * d; + } + } + for (int f = 0; f < n_feat; f++) { + float var = s_std[f] / n_windows; + s_std[f] = (var > 1e-12f) ? sqrtf(var) : 1e-6f; + } + + /* Persist to NVS */ + nvs_handle_t h; + if (nvs_open(NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) { + printf("[Calib] calibration_update: failed to open NVS\n"); + return false; + } + + esp_err_t err = ESP_OK; + err |= nvs_set_blob(h, NVS_KEY_MEAN, s_mean, sizeof(s_mean)); + err |= nvs_set_blob(h, NVS_KEY_STD, s_std, sizeof(s_std)); + err |= nvs_set_i32 (h, NVS_KEY_NFEAT, (int32_t)n_feat); + err |= nvs_set_u8 (h, NVS_KEY_VALID, 1u); + err |= nvs_commit(h); + nvs_close(h); + + if (err != ESP_OK) { + printf("[Calib] calibration_update: NVS write failed (err=0x%x)\n", err); + return false; + } + + s_valid = true; + printf("[Calib] Updated: %d REST windows, %d features saved to NVS\n", + n_windows, n_feat); + return true; +} + +void calibration_reset(void) { + s_valid = false; + s_n_feat = 0; + memset(s_mean, 0, sizeof(s_mean)); + memset(s_std, 0, sizeof(s_std)); +} + +bool calibration_is_valid(void) { + return s_valid; +} diff --git a/EMG_Arm/src/core/calibration.h b/EMG_Arm/src/core/calibration.h new file mode 100644 index 0000000..f20e250 --- /dev/null +++ b/EMG_Arm/src/core/calibration.h @@ -0,0 +1,68 @@ +/** + * @file calibration.h + * @brief NVS-backed z-score feature calibration for on-device EMG inference. + * + * Change D: Stores per-feature mean and std computed from a short REST session + * in ESP32 non-volatile storage (NVS). 
At inference time, calibration_apply() + * z-scores each feature vector before the LDA classifier sees it. This removes + * day-to-day electrode placement drift without retraining the model. + * + * Typical workflow: + * 1. calibration_init() — called once at startup; loads from NVS + * 2. calibration_update() — called after collecting ~3s of REST windows + * 3. calibration_apply(feat) — called every inference hop in inference.c + */ + +#ifndef CALIBRATION_H +#define CALIBRATION_H + +#include +#include + +/* Maximum supported feature vector length. + * 96 > 69 (expanded) > 12 (legacy) — gives headroom for future expansion. */ +#define CALIB_MAX_FEATURES 96 + +/** + * @brief Initialise NVS and load stored calibration statistics. + * + * Must be called once before any inference starts (e.g., in app_main). + * @return true Calibration data found and loaded. + * @return false No stored data; calibration_apply() will be a no-op. + */ +bool calibration_init(void); + +/** + * @brief Apply stored z-score calibration to a feature vector in-place. + * + * x_i_out = (x_i - mean_i) / std_i + * + * No-op if calibration is not valid (calibration_is_valid() == false). + * @param feat Feature vector of length n_feat (set during calibration_update). + */ +void calibration_apply(float *feat); + +/** + * @brief Compute and persist calibration statistics from REST EMG windows. + * + * Computes per-feature mean and std over n_windows windows, stores to NVS. + * After this call, calibration_apply() uses the new statistics. + * + * @param X_flat Flattened feature array [n_windows × n_feat], row-major. + * @param n_windows Number of windows (minimum 10). + * @param n_feat Feature vector length (≤ CALIB_MAX_FEATURES). + * @return true on success, false if inputs are invalid or NVS write fails. + */ +bool calibration_update(const float *X_flat, int n_windows, int n_feat); + +/** + * @brief Clear calibration state (in-memory only; does not erase NVS). 
+ */ +void calibration_reset(void); + +/** + * @brief Check whether valid calibration statistics are loaded. + */ +bool calibration_is_valid(void); + +#endif /* CALIBRATION_H */ diff --git a/EMG_Arm/src/core/emg_model_data.cc b/EMG_Arm/src/core/emg_model_data.cc new file mode 100644 index 0000000..5a0cf04 --- /dev/null +++ b/EMG_Arm/src/core/emg_model_data.cc @@ -0,0 +1,6 @@ +// Auto-generated by train_mlp_tflite.py do not edit +#include "emg_model_data.h" +const int g_model_len = 6784; +alignas(8) const unsigned char g_model[] = { + 0x20, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x00, 0x00, 0x14, 0x00, 0x20, 0x00, 0x1c, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x98, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x00, 0x00, 0xd4, 0x0d, 0x00, 0x00, 0xe4, 0x0d, 0x00, 0x00, 0x10, 0x1a, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x90, 0xff, 0xff, 0xff, 0x0a, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x5f, 0x30, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x12, 0xf2, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x6b, 0x65, 0x72, 0x61, 0x73, 0x5f, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xdc, 0xff, 0xff, 0xff, 0x0d, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x43, 0x4f, 0x4e, 0x56, 0x45, 0x52, 0x53, 0x49, 0x4f, 0x4e, 0x5f, 0x4d, 0x45, 0x54, 0x41, 0x44, 0x41, 0x54, 0x41, 0x00, 0x08, 0x00, 0x0c, 
0x00, 0x08, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00, 0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x0e, 0x00, 0x00, 0x00, 0xe0, 0x0c, 0x00, 0x00, 0xd8, 0x0c, 0x00, 0x00, 0xb4, 0x0c, 0x00, 0x00, 0x54, 0x0c, 0x00, 0x00, 0x04, 0x0c, 0x00, 0x00, 0xf4, 0x09, 0x00, 0x00, 0x64, 0x09, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0xac, 0x00, 0x00, 0x00, 0xa4, 0x00, 0x00, 0x00, 0x9c, 0x00, 0x00, 0x00, 0x94, 0x00, 0x00, 0x00, 0x74, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xc6, 0xf2, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x60, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x0e, 0x00, 0x08, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xeb, 0x03, 0x00, 0x00, 0x0c, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0xce, 0xae, 0xc0, 0x93, 0x5a, 0x4f, 0x9c, 0xbf, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x32, 0x2e, 0x32, 0x31, 0x2e, 0x30, 0x00, 0x00, 0x32, 0xf3, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x31, 0x2e, 0x31, 0x34, 0x2e, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x1c, 0xf3, 0xff, 0xff, 0x20, 0xf3, 0xff, 0xff, 0x24, 0xf3, 0xff, 0xff, 0x28, 0xf3, 0xff, 0xff, 0x5e, 0xf3, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0xa0, 0x08, 0x00, 0x00, 0xed, 0xe8, 0x10, 0xdd, 0x36, 0xf0, 0x1d, 0xd1, 0xb3, 0xba, 0x34, 0x19, 0x2f, 0xf0, 0xf5, 0xa2, 0x96, 0xb0, 0xd7, 0xf5, 0x34, 0xfe, 0xfb, 0xed, 0x17, 0xd0, 0x38, 0xf4, 0xde, 0xf3, 0x05, 0x20, 0x47, 0x01, 0xee, 0xcc, 0x37, 0xe5, 0xf3, 0x46, 0xfa, 0x13, 0x16, 0xfd, 0xef, 0xeb, 0x31, 0x0c, 0x9e, 0x93, 0xe8, 0x13, 0x26, 0xf2, 0xfb, 0xe5, 0x18, 0x16, 0xc9, 0xe8, 0xd4, 0xde, 0x3a, 0x05, 0x7f, 0xbf, 0xdd, 0xb7, 
0x46, 0x04, 0x38, 0xfc, 0xcd, 0xf4, 0x38, 0x12, 0x02, 0xff, 0x03, 0x18, 0x06, 0xfb, 0x20, 0x03, 0x18, 0xf4, 0xe5, 0xf7, 0xf6, 0x7f, 0x94, 0x1b, 0x3c, 0xfc, 0x1e, 0x1a, 0x01, 0x24, 0x41, 0x34, 0x24, 0x49, 0x1d, 0xff, 0x5b, 0xfd, 0x00, 0x19, 0x00, 0xf4, 0x7f, 0xe5, 0x17, 0xc5, 0x1f, 0xc7, 0xf2, 0x29, 0xfa, 0xf0, 0xfe, 0xdd, 0xc9, 0xf8, 0x2d, 0x08, 0x13, 0x1b, 0x36, 0xd4, 0x4d, 0x24, 0x03, 0xa0, 0xee, 0xb4, 0x44, 0x38, 0x08, 0xe5, 0x02, 0xfe, 0x09, 0xe1, 0x3a, 0xfd, 0x03, 0x81, 0xc3, 0xd8, 0xe7, 0xe3, 0x08, 0xdb, 0xe4, 0xdf, 0x06, 0x16, 0xda, 0x5e, 0xfe, 0x12, 0x2a, 0x0c, 0x47, 0xfb, 0xe5, 0x9f, 0xce, 0xf3, 0xc2, 0xe3, 0x0a, 0xde, 0x15, 0x28, 0x01, 0xf0, 0xa5, 0x07, 0x0e, 0x00, 0x40, 0xa3, 0x55, 0xf9, 0x20, 0x41, 0x19, 0xed, 0x43, 0x08, 0xfd, 0xd7, 0x0d, 0x1e, 0x0d, 0xbc, 0xe2, 0xcd, 0xf9, 0xea, 0x22, 0x1d, 0x0b, 0xc1, 0xfd, 0xc0, 0x1c, 0x17, 0xf0, 0xfb, 0x6f, 0xd2, 0xf9, 0xd1, 0x0e, 0x1f, 0x09, 0xdf, 0xf4, 0x00, 0x19, 0xa5, 0xb2, 0x33, 0xed, 0x13, 0x86, 0x13, 0x15, 0x60, 0x8b, 0x6d, 0x22, 0x14, 0x1f, 0x29, 0x40, 0x3d, 0x30, 0x0b, 0x94, 0x32, 0x53, 0xe3, 0x31, 0x46, 0x6e, 0x2b, 0x06, 0xf0, 0x0f, 0xd2, 0xfc, 0xff, 0xba, 0x09, 0x0b, 0xa0, 0x30, 0x05, 0xfd, 0x24, 0x1b, 0x1d, 0x53, 0x26, 0x19, 0x91, 0x3a, 0x9b, 0xef, 0x18, 0x81, 0x36, 0x30, 0xf8, 0x05, 0x0c, 0xfb, 0x42, 0xd8, 0xed, 0xeb, 0xf2, 0x27, 0x53, 0xe3, 0x07, 0xf3, 0xe2, 0x54, 0x29, 0x0f, 0x15, 0x38, 0xdc, 0xf6, 0x08, 0x0a, 0x15, 0xe9, 0x09, 0x93, 0x23, 0x29, 0x2a, 0xc5, 0x13, 0xf8, 0x0e, 0x1b, 0x0e, 0x06, 0xec, 0x66, 0xfd, 0xfb, 0x19, 0xa7, 0xfe, 0xc2, 0xf8, 0x4a, 0x04, 0xe8, 0xe6, 0x3b, 0xf8, 0x09, 0xfd, 0x0d, 0xf1, 0xf4, 0xed, 0xef, 0x14, 0x17, 0xed, 0xf2, 0x24, 0x07, 0x81, 0xec, 0xda, 0x7f, 0xe5, 0x05, 0x91, 0x39, 0xa2, 0xf4, 0xf4, 0x15, 0x18, 0x16, 0xc3, 0xff, 0x05, 0xfb, 0x40, 0x44, 0x3b, 0x29, 0x17, 0xc4, 0xee, 0xe8, 0xb6, 0x3a, 0xcd, 0xef, 0xe6, 0xb6, 0xfb, 0xfd, 0x1f, 0x17, 0xfd, 0x5d, 0xb4, 0xfa, 0x03, 0x1f, 0xde, 0xac, 0xe9, 0x29, 0xcd, 0xf3, 0xcd, 0x01, 0x1b, 0xbf, 0xc7, 0xdc, 0x33, 0xce, 0x13, 0xf9, 
0xef, 0x1b, 0xf8, 0x06, 0x04, 0x1b, 0x0a, 0x27, 0x12, 0xe1, 0x18, 0xbb, 0xc9, 0xfd, 0x06, 0x08, 0x35, 0x20, 0x7f, 0x3a, 0xf9, 0x56, 0xed, 0xf1, 0xed, 0x12, 0x0e, 0x09, 0x17, 0x1a, 0xf5, 0xeb, 0x1f, 0x42, 0x1c, 0x07, 0xd6, 0x16, 0xdf, 0x08, 0x1d, 0x06, 0xe8, 0x19, 0x30, 0xf4, 0xf4, 0x04, 0xf4, 0xea, 0xf8, 0xc4, 0xeb, 0x50, 0xe4, 0xe5, 0x31, 0xf2, 0x11, 0xf9, 0x14, 0xec, 0xce, 0xe9, 0xdf, 0xf2, 0xfd, 0xbd, 0x0f, 0x25, 0x26, 0xf6, 0x2e, 0x57, 0x3c, 0xd1, 0x19, 0xdb, 0xea, 0xd0, 0xc5, 0x1e, 0x0b, 0xa5, 0x09, 0x25, 0xcc, 0xbb, 0xb6, 0xf8, 0xc8, 0xfe, 0x16, 0x06, 0x0c, 0xf2, 0xf9, 0xb6, 0xed, 0xfb, 0x20, 0xcf, 0x18, 0xb2, 0x16, 0x09, 0xda, 0x81, 0xd8, 0xed, 0x0f, 0x02, 0xf2, 0x04, 0x14, 0xf5, 0x04, 0xc5, 0x05, 0xfd, 0xfc, 0x1e, 0x18, 0x06, 0x0b, 0x10, 0xab, 0x32, 0xa1, 0xe1, 0xce, 0x04, 0x04, 0xf2, 0xe6, 0x0c, 0xf0, 0x2e, 0x1b, 0x07, 0x0b, 0xf9, 0xff, 0x43, 0xfa, 0xf9, 0xe3, 0x02, 0x25, 0xe6, 0xe0, 0xbd, 0xbd, 0x05, 0x05, 0xfb, 0xe2, 0xea, 0x0a, 0x24, 0xeb, 0xec, 0x06, 0xe8, 0xf5, 0xff, 0xdd, 0x07, 0x1a, 0x1f, 0x12, 0xe2, 0x57, 0x06, 0x0f, 0x1b, 0x44, 0x15, 0x19, 0x38, 0x07, 0xe8, 0xde, 0x1f, 0x1f, 0x00, 0x18, 0x0a, 0xe4, 0xce, 0xfd, 0x02, 0xbe, 0xee, 0xef, 0x14, 0x29, 0xfc, 0x19, 0x06, 0x51, 0x16, 0x06, 0x49, 0xcb, 0x0d, 0x13, 0xf9, 0x23, 0x33, 0xf7, 0x0a, 0xe0, 0x0f, 0xea, 0x7f, 0xfc, 0x32, 0x9e, 0xd0, 0x16, 0x27, 0x15, 0xc4, 0xeb, 0xda, 0xff, 0x09, 0xe9, 0x00, 0x03, 0xef, 0xe0, 0x00, 0x03, 0xc5, 0x4c, 0x55, 0xf8, 0xec, 0xd6, 0xe9, 0xf4, 0xfd, 0xbd, 0x3f, 0xc9, 0xfe, 0x25, 0x25, 0x08, 0xe9, 0x37, 0x15, 0x03, 0x35, 0xf0, 0xf7, 0xe3, 0xf3, 0xec, 0x81, 0x18, 0xe1, 0x35, 0xb0, 0x2b, 0x20, 0xdf, 0x38, 0x26, 0x1e, 0x41, 0xcc, 0x05, 0xd9, 0x17, 0xf2, 0x06, 0xe0, 0xf8, 0xa1, 0x03, 0xf4, 0x4b, 0x0a, 0x22, 0x18, 0xeb, 0xcf, 0xd7, 0xf3, 0xf6, 0x27, 0xff, 0x14, 0x0c, 0x0d, 0x18, 0x07, 0xdb, 0x33, 0xfe, 0xff, 0xfc, 0xdc, 0xec, 0xf0, 0xf2, 0x1b, 0xf9, 0xd4, 0xf0, 0x11, 0xa0, 0x16, 0xfa, 0xf1, 0xe7, 0xd7, 0xe6, 0xf5, 0x10, 0xef, 0x81, 0x1a, 0xf2, 0x17, 0x09, 0x3b, 0xcf, 0xfd, 0x0e, 
0xce, 0x2c, 0xcc, 0x0d, 0xdf, 0xdc, 0x10, 0x18, 0x0e, 0x10, 0x05, 0x25, 0x45, 0x3e, 0x18, 0x0e, 0x28, 0x05, 0xed, 0xd0, 0xdd, 0x23, 0x43, 0xa1, 0xf5, 0x1b, 0xcc, 0x0d, 0xf4, 0xe2, 0xfb, 0xc8, 0xfa, 0xef, 0xdd, 0x02, 0xf9, 0xeb, 0x0f, 0x00, 0xde, 0x17, 0x24, 0x24, 0xf8, 0xe1, 0x07, 0xf5, 0x30, 0xd7, 0xea, 0xd5, 0xea, 0x13, 0xe8, 0xe7, 0x04, 0x37, 0xeb, 0x0c, 0xde, 0xf6, 0xf0, 0xd3, 0x0c, 0x96, 0x0e, 0xe8, 0xe2, 0x1a, 0xf0, 0x2b, 0xf8, 0x0e, 0xfb, 0xe0, 0xe9, 0x03, 0x0a, 0xe3, 0xe8, 0xfc, 0xcf, 0xdb, 0x05, 0x04, 0x05, 0xfb, 0x12, 0x5c, 0xeb, 0x33, 0x7f, 0xd5, 0x37, 0x81, 0x00, 0xff, 0xbb, 0xe3, 0xd1, 0x00, 0xf3, 0xf3, 0x04, 0x0d, 0x2f, 0x12, 0xf1, 0xc3, 0x07, 0x06, 0xf7, 0xfd, 0x0a, 0xa8, 0x02, 0xfd, 0xd9, 0x0c, 0xe0, 0xfd, 0x09, 0x1b, 0x09, 0x07, 0x0f, 0xeb, 0xfe, 0x1d, 0xf1, 0x16, 0x34, 0x09, 0x2b, 0xdb, 0xfd, 0xfe, 0xd6, 0xdf, 0xef, 0x06, 0xe3, 0xc6, 0xee, 0xf3, 0x09, 0xfe, 0x00, 0xdf, 0x01, 0xf9, 0xfc, 0xfc, 0xe6, 0x2f, 0x10, 0xf2, 0x0b, 0x09, 0xfd, 0xcf, 0xf0, 0xc2, 0x3f, 0xff, 0x04, 0x35, 0xf5, 0x37, 0x02, 0x10, 0xea, 0xf4, 0x10, 0x14, 0xd8, 0xff, 0xdf, 0x04, 0x19, 0x10, 0x16, 0xc3, 0x2d, 0x0f, 0xe2, 0x0d, 0xf1, 0x39, 0x18, 0x1b, 0x08, 0xfe, 0xe8, 0x1d, 0xfb, 0x18, 0xfe, 0x02, 0xad, 0xd0, 0xc0, 0x13, 0x81, 0x0b, 0xdc, 0x2b, 0xd9, 0x49, 0x05, 0xe4, 0xed, 0x15, 0x1e, 0x4e, 0x0c, 0x0a, 0xd3, 0xe9, 0xe4, 0xd8, 0xea, 0xe2, 0xe3, 0xe9, 0x00, 0x58, 0x02, 0xfe, 0xba, 0xda, 0xb4, 0x68, 0xec, 0x10, 0x46, 0x4a, 0x26, 0x04, 0xe0, 0xd8, 0xfa, 0x1c, 0xcd, 0xe5, 0x08, 0x14, 0x29, 0x38, 0x16, 0x14, 0xbd, 0x27, 0x08, 0x12, 0x7a, 0xcb, 0x4a, 0xdc, 0x14, 0xec, 0x03, 0xfc, 0xaf, 0xf2, 0x06, 0xb7, 0x55, 0x33, 0x33, 0x10, 0xf3, 0x59, 0xfc, 0xed, 0x44, 0x30, 0x59, 0x04, 0xd4, 0xc2, 0xdd, 0xbc, 0x87, 0x19, 0x16, 0x21, 0xd8, 0xdf, 0x46, 0x21, 0xef, 0x45, 0x05, 0x3b, 0xb3, 0xd5, 0xea, 0x7f, 0xe1, 0x25, 0xf6, 0x10, 0x00, 0xc3, 0x27, 0xda, 0x0c, 0xf5, 0xf8, 0xf4, 0xe9, 0x85, 0xe8, 0xeb, 0x11, 0x4a, 0x70, 0x65, 0x23, 0x0d, 0xb2, 0x04, 0xe1, 0x25, 0x01, 0xd8, 0x17, 0xde, 0xdd, 0xe9, 0x10, 
0x00, 0xee, 0x0b, 0xa9, 0xe7, 0x15, 0xf6, 0x35, 0xa3, 0x81, 0xff, 0x06, 0x1f, 0xd5, 0x31, 0xfc, 0x5a, 0x4e, 0xde, 0xc9, 0x10, 0x21, 0x0a, 0xa4, 0xcc, 0xbd, 0xff, 0xc9, 0x3e, 0xdd, 0xd7, 0xe7, 0x4e, 0x10, 0x77, 0x56, 0xc9, 0xb2, 0xec, 0xcb, 0xfe, 0x37, 0x29, 0x3a, 0x11, 0xc8, 0x12, 0x09, 0x00, 0x2a, 0x17, 0xf7, 0x1c, 0x14, 0xcc, 0x0f, 0xf0, 0xbf, 0x58, 0xe5, 0x0f, 0x00, 0x17, 0x0f, 0xe8, 0x43, 0x6a, 0x13, 0x07, 0x05, 0x46, 0x10, 0x11, 0xdc, 0x92, 0xc6, 0x1d, 0x06, 0x0a, 0xee, 0x14, 0x3b, 0xf9, 0x36, 0xfc, 0xb1, 0x0f, 0x27, 0x25, 0x1f, 0x03, 0x11, 0x09, 0xe6, 0xdc, 0x00, 0xeb, 0x17, 0x81, 0xf5, 0xe7, 0xa8, 0x20, 0xd6, 0x1e, 0xe9, 0xc2, 0x81, 0xfd, 0xf7, 0xf5, 0xe4, 0x20, 0x08, 0xf6, 0x36, 0xfd, 0xbe, 0xfc, 0x1e, 0xe1, 0xf6, 0xfa, 0xff, 0x9c, 0x33, 0xcf, 0xdb, 0xfc, 0x0e, 0xd9, 0x25, 0xbe, 0xf9, 0xf6, 0x3e, 0x0e, 0xef, 0x0c, 0x2d, 0x1e, 0x52, 0xc3, 0xd6, 0xf2, 0x40, 0xb0, 0xfd, 0xe8, 0xe3, 0xe4, 0xf4, 0xe9, 0xed, 0x00, 0xf3, 0x08, 0xf5, 0x54, 0xf7, 0x0c, 0xef, 0xbf, 0xd8, 0xeb, 0xd4, 0x09, 0xba, 0x08, 0xfa, 0xcb, 0xe9, 0x0d, 0xda, 0xe1, 0x06, 0xed, 0xf4, 0xdc, 0xb7, 0xdb, 0xb3, 0x01, 0xcd, 0xd5, 0xf1, 0x05, 0xda, 0x20, 0xeb, 0xf8, 0x5e, 0x2c, 0x32, 0x11, 0xf2, 0xa6, 0x1b, 0x10, 0xc7, 0xba, 0xe7, 0xfc, 0xe3, 0xec, 0x01, 0x15, 0xeb, 0xf0, 0x08, 0x92, 0xe7, 0x18, 0xf4, 0xfa, 0xa3, 0x10, 0xfd, 0xd6, 0xc8, 0x28, 0xf2, 0xf9, 0x26, 0x4a, 0x16, 0x17, 0x20, 0x14, 0x09, 0x07, 0xe3, 0xb9, 0x97, 0xe7, 0x01, 0xed, 0x03, 0xe5, 0x28, 0x0d, 0xf8, 0x7f, 0xd3, 0xed, 0x02, 0xee, 0xfb, 0xbb, 0xb6, 0xc7, 0x0e, 0xe8, 0x24, 0x09, 0x1b, 0xe6, 0x0f, 0xfe, 0x81, 0xf4, 0x0f, 0xf3, 0xe9, 0x20, 0xa5, 0x0a, 0x3a, 0xaf, 0xd1, 0xbb, 0xf4, 0x14, 0x43, 0x33, 0x34, 0x38, 0x41, 0xf2, 0x02, 0xbc, 0xe6, 0xc0, 0xd3, 0xb3, 0xa5, 0xfd, 0x05, 0xe5, 0x31, 0xf4, 0x16, 0xc3, 0xca, 0xf5, 0xcc, 0x2e, 0xda, 0x03, 0x2b, 0x0b, 0xbb, 0xfa, 0xe9, 0xe2, 0x28, 0x0d, 0x20, 0x27, 0xe1, 0xd7, 0x4e, 0xfa, 0xef, 0x3d, 0xf4, 0x25, 0xb8, 0x5e, 0xd2, 0x07, 0xb4, 0xe4, 0xf9, 0x04, 0xa5, 0x01, 0x07, 0x24, 0x1f, 0x3c, 0x66, 0xe0, 
0x08, 0x81, 0x02, 0xdd, 0xc7, 0xf2, 0xb7, 0x01, 0xd0, 0xcc, 0x04, 0xe0, 0x0f, 0xfe, 0x0e, 0xdc, 0xf1, 0x23, 0xe9, 0xd9, 0xe8, 0x9a, 0xec, 0x52, 0x14, 0x1e, 0xfd, 0xe4, 0x18, 0xb7, 0xd9, 0xea, 0x76, 0xee, 0xea, 0x28, 0xd4, 0xed, 0xfc, 0x36, 0xeb, 0x39, 0xf1, 0xbd, 0x0a, 0x14, 0xfa, 0x5b, 0xdd, 0xbc, 0x39, 0x05, 0xe5, 0x14, 0xdd, 0x14, 0x01, 0xbd, 0x04, 0x18, 0x21, 0xe7, 0xd8, 0x0f, 0xdc, 0xe1, 0x24, 0xf2, 0x0b, 0xad, 0xe2, 0x27, 0xec, 0x02, 0xc9, 0x19, 0x12, 0xfa, 0xf4, 0x22, 0x13, 0x14, 0x0c, 0xfc, 0x81, 0xeb, 0xf2, 0xdc, 0x9b, 0xfb, 0x0e, 0xd2, 0x20, 0xef, 0x4c, 0xc8, 0x0c, 0x0f, 0xb8, 0xa6, 0xab, 0xdb, 0xe5, 0x0f, 0x64, 0x03, 0x12, 0x13, 0x57, 0xc1, 0x19, 0x1f, 0x27, 0x09, 0xcf, 0xc8, 0x43, 0x0b, 0xea, 0x51, 0x16, 0x03, 0x41, 0xd8, 0x33, 0x00, 0xdd, 0xc3, 0xfc, 0xf0, 0xe0, 0x03, 0x02, 0xdc, 0xf1, 0xe9, 0xeb, 0xc1, 0xd5, 0x2c, 0xfb, 0x0e, 0x0b, 0xfb, 0x1e, 0xec, 0xca, 0xc6, 0xf1, 0xff, 0xd2, 0x20, 0x01, 0xe8, 0x11, 0x01, 0xf6, 0xe8, 0x0b, 0xfe, 0x05, 0x1a, 0x24, 0xdf, 0x20, 0xfb, 0xea, 0x81, 0xe8, 0x27, 0xd9, 0x0c, 0xf6, 0xe4, 0x52, 0x3b, 0x34, 0x0d, 0xff, 0xf9, 0x07, 0x21, 0xe5, 0x01, 0xe9, 0x1d, 0x1a, 0xf2, 0xa0, 0x0e, 0x00, 0xb2, 0x7f, 0xb1, 0x07, 0x06, 0x2d, 0x22, 0x11, 0x2c, 0x0b, 0x01, 0x5e, 0xff, 0xf4, 0x09, 0xe5, 0x52, 0xcf, 0x11, 0x16, 0xc8, 0x14, 0xf2, 0x03, 0x13, 0x2b, 0x03, 0xf8, 0x57, 0xe9, 0x01, 0x18, 0x00, 0x01, 0x10, 0x0f, 0x40, 0xc3, 0xff, 0x31, 0xec, 0xbc, 0x03, 0xf5, 0x0f, 0xde, 0xee, 0x06, 0xe4, 0x43, 0xf4, 0xc7, 0x03, 0xfa, 0xf8, 0x2c, 0xdc, 0xc8, 0x25, 0xec, 0xf7, 0x2a, 0x2c, 0xa6, 0xd0, 0x33, 0x6b, 0x0f, 0xc1, 0xf9, 0x7f, 0xf0, 0xfc, 0xad, 0xe9, 0x2f, 0x34, 0xfb, 0xb3, 0xff, 0x21, 0x2f, 0x60, 0x53, 0x0f, 0x6b, 0xa8, 0x1c, 0x13, 0x17, 0x15, 0x38, 0xdd, 0xf1, 0x1b, 0x44, 0x0f, 0x5b, 0x10, 0xea, 0x66, 0x2a, 0x4f, 0x11, 0x09, 0xb7, 0x78, 0xe7, 0xf6, 0x13, 0x6a, 0x0b, 0xd6, 0xe8, 0xcc, 0xb5, 0xf5, 0x1c, 0x87, 0xfe, 0x75, 0xed, 0x0f, 0x31, 0x77, 0xda, 0x78, 0x15, 0x06, 0xf2, 0xf7, 0xcc, 0x54, 0x05, 0x46, 0x81, 0x05, 0xf0, 0xd6, 0xcb, 0x03, 0x03, 
0x15, 0xaf, 0xd3, 0xe2, 0x01, 0x02, 0x01, 0xda, 0x19, 0x0a, 0x1a, 0x3c, 0x33, 0xe1, 0x11, 0xce, 0xda, 0xf7, 0xcc, 0xf3, 0x27, 0xeb, 0xe9, 0xee, 0x26, 0xee, 0xf8, 0xcf, 0x0c, 0x1f, 0xee, 0x26, 0x19, 0x9d, 0x0a, 0x28, 0xf3, 0x0b, 0x08, 0x00, 0x31, 0x00, 0xea, 0xf3, 0x3f, 0x03, 0x02, 0xfd, 0x0a, 0xdf, 0xeb, 0xdc, 0xed, 0x0b, 0x08, 0xe2, 0x58, 0x20, 0x0e, 0xd9, 0xed, 0x36, 0xc4, 0x0d, 0xff, 0x0d, 0xb4, 0x22, 0xf9, 0x19, 0x26, 0xef, 0xe4, 0x14, 0x04, 0x03, 0xbc, 0xed, 0xe3, 0x1f, 0x05, 0xa4, 0x2f, 0x23, 0xf9, 0x1d, 0x4a, 0x27, 0x0c, 0x40, 0x25, 0xeb, 0xba, 0xe4, 0xe6, 0xef, 0x30, 0x07, 0x00, 0x40, 0xe7, 0xd5, 0x7f, 0xfe, 0xd3, 0xf2, 0xd0, 0xf6, 0xf0, 0x0b, 0xce, 0xd7, 0xf1, 0xb8, 0xed, 0xf4, 0xa4, 0x20, 0x1f, 0x0e, 0x14, 0x19, 0x4f, 0x11, 0xf1, 0xd1, 0x02, 0xd2, 0xf2, 0x23, 0x60, 0xac, 0x14, 0x03, 0x2e, 0x19, 0x0f, 0x12, 0x08, 0x81, 0xc6, 0xf5, 0x12, 0xdb, 0xf9, 0x1a, 0x30, 0x2d, 0xed, 0x11, 0x55, 0xeb, 0xfb, 0xe9, 0x4f, 0xbd, 0x53, 0x00, 0x0c, 0xc9, 0xe9, 0xf5, 0x19, 0x24, 0x04, 0x8f, 0x22, 0xe0, 0x15, 0x21, 0xe4, 0x0c, 0xf6, 0xf9, 0x5d, 0xb7, 0x4b, 0x00, 0x2b, 0x3c, 0x0f, 0xf4, 0x38, 0xce, 0xfa, 0xbd, 0xdd, 0xe2, 0x0e, 0xc7, 0x0d, 0xed, 0xd6, 0xfa, 0x64, 0x1f, 0x23, 0x56, 0x41, 0xec, 0x0f, 0x17, 0x12, 0xe0, 0x05, 0xec, 0xf0, 0x18, 0xca, 0xec, 0xeb, 0xba, 0xfe, 0x09, 0x31, 0xe7, 0x08, 0x13, 0xfa, 0xb3, 0x51, 0x06, 0xfc, 0xec, 0x05, 0xea, 0xd3, 0xd4, 0xc7, 0xd9, 0xe0, 0xf8, 0xf9, 0xff, 0xbe, 0x1e, 0x08, 0x1d, 0xfd, 0x81, 0x41, 0x00, 0x14, 0x32, 0xfa, 0x1c, 0xe5, 0xeb, 0xd8, 0xe7, 0xd6, 0xca, 0xe1, 0xf6, 0x0d, 0x17, 0x19, 0x09, 0x0c, 0x22, 0xc8, 0x26, 0xf1, 0xf3, 0xf0, 0x02, 0x1f, 0xf4, 0xcd, 0x81, 0xfe, 0xe8, 0xf2, 0xf7, 0x1d, 0x04, 0xf1, 0x2c, 0x21, 0xff, 0x0f, 0x04, 0x11, 0x02, 0x02, 0x09, 0x02, 0xe3, 0xe6, 0xec, 0x09, 0xdc, 0x0d, 0xf4, 0x03, 0x04, 0x26, 0x2e, 0x04, 0xea, 0x2b, 0x01, 0xf9, 0xe4, 0xe3, 0x03, 0x04, 0x1e, 0xb3, 0x23, 0xf1, 0x05, 0xf4, 0xa9, 0xfa, 0x0f, 0xf5, 0xe8, 0xfd, 0xf9, 0xe0, 0x10, 0xfc, 0xa7, 0xea, 0xef, 0x03, 0xfe, 0x02, 0xa7, 0xf8, 0x18, 0xd6, 
0xca, 0x11, 0x92, 0xe6, 0x00, 0x8e, 0x0c, 0xfd, 0x12, 0xd7, 0x21, 0xfd, 0x16, 0xfe, 0xd8, 0xce, 0x0c, 0xfb, 0xf6, 0x00, 0xec, 0xe2, 0xfd, 0xe8, 0x7f, 0x0f, 0x02, 0xf8, 0xee, 0xfe, 0xfc, 0x06, 0xf4, 0x22, 0x0e, 0xe0, 0xed, 0xee, 0xfa, 0x07, 0xb1, 0xf1, 0x21, 0xf1, 0x00, 0xe0, 0x1e, 0x39, 0xde, 0x2f, 0xec, 0xf6, 0x3c, 0xd5, 0xdd, 0xe8, 0xeb, 0x14, 0xf7, 0x33, 0x10, 0x10, 0x11, 0xf1, 0xe5, 0xf3, 0x2d, 0xe6, 0xa6, 0x0e, 0xd8, 0x20, 0x1e, 0xdc, 0x68, 0x00, 0x0c, 0x0a, 0x32, 0xf6, 0xfd, 0x0f, 0x1e, 0x02, 0xf2, 0xe2, 0xd7, 0x01, 0x52, 0x05, 0x20, 0x24, 0x06, 0xf0, 0x42, 0x12, 0x1b, 0x06, 0x12, 0xdd, 0xed, 0x33, 0xea, 0xf5, 0xfe, 0x00, 0xf0, 0xfe, 0x38, 0xed, 0x01, 0xfb, 0x0c, 0xad, 0x11, 0xf4, 0x25, 0x10, 0x16, 0xfc, 0xfc, 0xf2, 0xa2, 0xbe, 0xd7, 0xfd, 0xf8, 0x04, 0x04, 0xee, 0xf2, 0xfd, 0x0b, 0xe1, 0xe7, 0x0c, 0x03, 0x81, 0xef, 0x0d, 0x0f, 0xe4, 0x0a, 0xfc, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x04, 0xff, 0xff, 0xff, 0xa3, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x9a, 0xff, 0xff, 0xff, 0x4c, 0xfe, 0xff, 0xff, 0x51, 0xfc, 0xff, 0xff, 0x6e, 0xfc, 0xff, 0xff, 0x32, 0xf4, 0xff, 0xff, 0x26, 0xfe, 0xff, 0xff, 0xc5, 0xf5, 0xff, 0xff, 0xf5, 0xfb, 0xff, 0xff, 0xdd, 0xf9, 0xff, 0xff, 0x81, 0xfd, 0xff, 0xff, 0xbc, 0xfd, 0xff, 0xff, 0x76, 0xfc, 0xff, 0xff, 0xd1, 0xee, 0xff, 0xff, 0xa4, 0xf6, 0xff, 0xff, 0x0c, 0xfa, 0xff, 0xff, 0x75, 0xf8, 0xff, 0xff, 0x0d, 0xf9, 0xff, 0xff, 0x5e, 0xf9, 0xff, 0xff, 0xf2, 0xfe, 0xff, 0xff, 0xce, 0xfe, 0xff, 0xff, 0x02, 0xfb, 0xff, 0xff, 0x35, 0xf7, 0xff, 0xff, 0x4e, 0xfc, 0xff, 0xff, 0x53, 0xfd, 0xff, 0xff, 0xa1, 0xfa, 0xff, 0xff, 0x97, 0xfe, 0xff, 0xff, 0x56, 0xf8, 0xff, 0xff, 0x76, 0xfb, 0xff, 0xff, 0x92, 0xfb, 0xff, 0xff, 0x96, 0xfc, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x09, 0xc9, 0xea, 0x14, 0x16, 0x81, 0x18, 0x0c, 0xce, 0xfa, 0x07, 0x0f, 0x15, 0xe2, 0x3d, 0x06, 0x0f, 0x0d, 0x05, 0x0e, 0x05, 0x1a, 0x19, 0xe3, 0xf0, 0xf8, 0x10, 0xf6, 0x29, 0x11, 0x14, 0xd2, 0xf9, 0xf7, 0xfb, 0x00, 0x04, 0xf2, 0x01, 
0x81, 0x05, 0x15, 0xb5, 0x01, 0x03, 0xfe, 0xe5, 0x20, 0xec, 0xf8, 0x08, 0xf7, 0xe9, 0xff, 0xfe, 0x14, 0xf4, 0x08, 0xe7, 0xfb, 0x0d, 0xef, 0xfa, 0x32, 0x08, 0x29, 0xf2, 0x42, 0x26, 0x04, 0x3a, 0xf6, 0x02, 0xe0, 0x15, 0xad, 0xec, 0xf1, 0x11, 0x81, 0x01, 0xff, 0xf9, 0x17, 0xec, 0xf3, 0x39, 0x06, 0x09, 0xfa, 0xf8, 0xda, 0x02, 0xfc, 0xf7, 0x05, 0xf4, 0x02, 0x10, 0x09, 0xbd, 0x05, 0xca, 0xb1, 0x07, 0x9f, 0xdf, 0xc3, 0x00, 0x81, 0x0d, 0xb5, 0x09, 0x06, 0xeb, 0xf3, 0xae, 0xb4, 0xf5, 0x00, 0x04, 0xd9, 0x20, 0xee, 0x17, 0x06, 0x02, 0x04, 0x09, 0xec, 0x08, 0x05, 0xf5, 0x00, 0xfc, 0x05, 0x13, 0xfb, 0xe7, 0xff, 0xea, 0xfa, 0xdb, 0xfd, 0xd3, 0xd1, 0xff, 0x03, 0x01, 0x07, 0x11, 0xea, 0x81, 0xfe, 0x05, 0x01, 0x0a, 0xfa, 0xdd, 0x93, 0x06, 0x0f, 0xa6, 0x2b, 0xe8, 0x29, 0x01, 0xf9, 0xfa, 0xf3, 0x16, 0xe8, 0x34, 0x2e, 0x45, 0x06, 0x08, 0xf4, 0xfa, 0x9a, 0x13, 0x25, 0xf8, 0x7f, 0x10, 0xfa, 0x0b, 0xf0, 0x27, 0xf9, 0x06, 0xf7, 0x2a, 0xfd, 0xf0, 0xed, 0x06, 0x28, 0xf8, 0x17, 0x63, 0x31, 0xe6, 0x7f, 0x49, 0x22, 0xfe, 0x26, 0x3d, 0x2d, 0x10, 0x38, 0x6f, 0x28, 0x95, 0x70, 0x13, 0x0c, 0xfe, 0xd9, 0xfe, 0x1f, 0x43, 0x02, 0x20, 0xf1, 0x53, 0x0a, 0xfd, 0xfa, 0xff, 0x16, 0x3e, 0x25, 0x17, 0x81, 0x0c, 0x15, 0x09, 0x6f, 0x03, 0x05, 0x09, 0xc9, 0xfe, 0x04, 0xfd, 0x66, 0xfb, 0x69, 0xed, 0x76, 0xd7, 0x0c, 0x12, 0xf6, 0x0a, 0x0d, 0xed, 0xec, 0x36, 0x2a, 0x64, 0xa8, 0xe9, 0x36, 0xb9, 0x49, 0xab, 0xf3, 0x0c, 0x1e, 0x81, 0x9f, 0x0e, 0xf1, 0xf2, 0xd6, 0x3a, 0xde, 0x11, 0x22, 0x09, 0x1d, 0xff, 0x81, 0xc8, 0xfd, 0x24, 0xde, 0x1c, 0x12, 0xfb, 0x96, 0x01, 0xf3, 0xc2, 0x05, 0x09, 0xe3, 0x47, 0x07, 0xfe, 0xdf, 0xc9, 0xe4, 0xf3, 0xdd, 0xbb, 0xfa, 0x19, 0x29, 0xcc, 0x16, 0x0a, 0x1b, 0x3c, 0xcd, 0xf6, 0x81, 0xf9, 0x5e, 0x09, 0x41, 0xec, 0xeb, 0xe6, 0xe0, 0x12, 0xae, 0xed, 0xf4, 0xde, 0xdf, 0xf9, 0xf5, 0xed, 0xfa, 0xf6, 0xff, 0xcc, 0xf3, 0x20, 0x7f, 0xf9, 0xa5, 0x0d, 0xfc, 0x02, 0x1c, 0x19, 0xf6, 0xeb, 0x01, 0xf9, 0x08, 0xed, 0xf8, 0x09, 0xed, 0x13, 0x81, 0xe9, 0xdc, 0x0b, 0xcf, 0x27, 0xdf, 0x0f, 0x08, 0xa5, 0xe9, 
0xfc, 0xef, 0xf9, 0x1f, 0xfc, 0xf7, 0x02, 0xa9, 0x04, 0xf4, 0x18, 0x07, 0x2f, 0xc3, 0x18, 0x04, 0xf1, 0xd6, 0xf3, 0x84, 0x70, 0xe4, 0xec, 0x0f, 0xe0, 0x1f, 0x17, 0x8a, 0xd3, 0xe1, 0xd9, 0xf1, 0xf6, 0x20, 0xf1, 0x81, 0xc9, 0x07, 0x05, 0x06, 0x14, 0xa4, 0x0d, 0xee, 0x05, 0xef, 0x08, 0x1f, 0x02, 0xfb, 0x12, 0xea, 0x01, 0xcd, 0x16, 0xe8, 0xfb, 0xf8, 0x1b, 0xd1, 0xb4, 0xcc, 0x81, 0xad, 0xff, 0xfe, 0x0a, 0x1b, 0x00, 0x02, 0x0c, 0xfc, 0xf5, 0x87, 0xe0, 0x09, 0x02, 0x01, 0x03, 0xf2, 0x00, 0x01, 0x0a, 0xc5, 0x0c, 0x01, 0xeb, 0xfa, 0xe3, 0x10, 0x07, 0x04, 0xbe, 0xa7, 0x00, 0xfa, 0x11, 0x12, 0x14, 0xfd, 0x01, 0xfc, 0xdb, 0xfa, 0x0b, 0x81, 0xbf, 0xff, 0x0f, 0xec, 0xff, 0xf8, 0xe6, 0xe2, 0x26, 0x04, 0xef, 0x2c, 0x07, 0x53, 0xd5, 0xfe, 0x09, 0xcc, 0xfc, 0xff, 0x05, 0xf3, 0xa3, 0x68, 0x29, 0x8e, 0xd3, 0xf7, 0x7f, 0xfe, 0x3e, 0x01, 0xf8, 0xda, 0xa2, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x42, 0xf8, 0xff, 0xff, 0x85, 0x03, 0x00, 0x00, 0xbf, 0x03, 0x00, 0x00, 0xfc, 0x03, 0x00, 0x00, 0x63, 0x04, 0x00, 0x00, 0xdc, 0x08, 0x00, 0x00, 0x3a, 0xff, 0xff, 0xff, 0x50, 0xff, 0xff, 0xff, 0xae, 0x0e, 0x00, 0x00, 0x98, 0x05, 0x00, 0x00, 0x17, 0x08, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0xb1, 0x0c, 0x00, 0x00, 0x07, 0x08, 0x00, 0x00, 0xc3, 0x03, 0x00, 0x00, 0x25, 0x0d, 0x00, 0x00, 0xee, 0xfe, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x4d, 0x1e, 0x25, 0xbe, 0xa2, 0x67, 0x99, 0x81, 0xc2, 0x94, 0x1e, 0xa5, 0x98, 0x35, 0x43, 0xe8, 0xf8, 0x15, 0xd5, 0x13, 0x3f, 0xc5, 0xe8, 0x1b, 0x30, 0xea, 0xbb, 0x4a, 0x24, 0x81, 0x2c, 0xf6, 0xf0, 0x0d, 0x05, 0x3b, 0xfe, 0xfb, 0xff, 0x21, 0x81, 0xfc, 0x13, 0x0b, 0xfb, 0xec, 0xf2, 0xc9, 0x9e, 0xfc, 0x69, 0xa7, 0x16, 0xed, 0x2a, 0x5c, 0xdb, 0x24, 0x12, 0x1f, 0xde, 0xd9, 0x81, 0x2e, 0x37, 0x81, 0xe5, 0xe9, 0xb6, 0xe2, 0xc5, 0xfa, 0x1b, 0x38, 0x4a, 0x8b, 0x63, 0x44, 0x03, 0x04, 0x4a, 0xff, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x6b, 0x01, 0x00, 0x00, 0x2e, 0xff, 0xff, 0xff, 0x73, 0xfe, 0xff, 0xff, 0xd4, 
0x01, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00, 0x38, 0xff, 0xff, 0xff, 0x3c, 0xff, 0xff, 0xff, 0x0f, 0x00, 0x00, 0x00, 0x4d, 0x4c, 0x49, 0x52, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74, 0x65, 0x64, 0x2e, 0x00, 0x01, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x28, 0x01, 0x00, 0x00, 0x2c, 0x01, 0x00, 0x00, 0x30, 0x01, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x6d, 0x61, 0x69, 0x6e, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xd0, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x50, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x1a, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x1c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x96, 0xff, 0xff, 0xff, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xca, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0xba, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x16, 0x00, 0x00, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x0b, 0x00, 0x04, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x18, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 
0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x48, 0x0a, 0x00, 0x00, 0x90, 0x09, 0x00, 0x00, 0xd4, 0x08, 0x00, 0x00, 0x98, 0x07, 0x00, 0x00, 0x74, 0x06, 0x00, 0x00, 0x78, 0x04, 0x00, 0x00, 0x94, 0x02, 0x00, 0x00, 0xe0, 0x01, 0x00, 0x00, 0x24, 0x01, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xfa, 0xf5, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x50, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0xe4, 0xf5, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b, 0x1b, 0x00, 0x00, 0x00, 0x53, 0x74, 0x61, 0x74, 0x65, 0x66, 0x75, 0x6c, 0x50, 0x61, 0x72, 0x74, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x65, 0x64, 0x43, 0x61, 0x6c, 0x6c, 0x5f, 0x31, 0x3a, 0x30, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x72, 0xf6, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x78, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00, 0x5c, 0xf6, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x1f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x48, 0x9a, 0x80, 0x3e, 0x3c, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f, 0x31, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x3b, 0x73, 
0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f, 0x31, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x12, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x90, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x10, 0x00, 0x00, 0x00, 0xfc, 0xf6, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xa4, 0xcf, 0xab, 0x3d, 0x58, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x52, 0x65, 0x6c, 0x75, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xca, 0xf7, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x18, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x88, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00, 0xb4, 0xf7, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x80, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0x09, 0xfb, 0x22, 0x3d, 0x52, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x4d, 0x61, 0x74, 
0x4d, 0x75, 0x6c, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x52, 0x65, 0x6c, 0x75, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x2e, 0xf9, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x14, 0x00, 0x00, 0x00, 0xa8, 0x01, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0xbc, 0x01, 0x00, 0x00, 0x54, 0xf8, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x01, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0xef, 0xf0, 0xd1, 0x3b, 0x65, 0xce, 0xf3, 0x3b, 0x78, 0x8e, 0xf9, 0x3b, 0x13, 0x18, 0xc9, 0x3b, 0x2e, 0x4c, 0x11, 0x3c, 0x7f, 0x70, 0xeb, 0x3b, 0x0d, 0x9c, 0xfd, 0x3b, 0x03, 0x34, 0xfb, 0x3b, 0xd0, 0x75, 0x1d, 0x3c, 0x03, 0x8a, 0xe8, 0x3b, 0xf2, 0xc3, 0x12, 0x3c, 0x5c, 0x10, 0x2d, 0x3c, 0xd6, 0x52, 0x3f, 0x3c, 0x4e, 0xad, 0x07, 0x3c, 0x0d, 0xd3, 0xdf, 0x3b, 0x72, 0x11, 0xb8, 0x3b, 0xfa, 0x5a, 0xfe, 0x3b, 0xde, 0x46, 0x00, 0x3c, 0x18, 0xa1, 0xef, 0x3b, 0x4b, 0x12, 0x00, 0x3c, 0x92, 0xba, 0xe7, 0x3b, 0xf0, 0x0c, 0xfe, 0x3b, 0xfe, 0xad, 0x1e, 0x3c, 0xa3, 0x8d, 0x0e, 0x3c, 0x22, 0x3f, 0xad, 0x3b, 0xb7, 0x78, 0x0f, 0x3c, 0xfb, 0xbf, 0x05, 0x3c, 0x3a, 0xb4, 0xdc, 0x3b, 0xe6, 0xdb, 0x19, 0x3c, 0x29, 0x38, 0x18, 0x3c, 0x82, 0xdb, 0x25, 0x3c, 0x4e, 0x13, 0x20, 0x3c, 0x1b, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x02, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, 0x0e, 0xfb, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x14, 0x00, 0x00, 0x00, 0xa8, 0x01, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xd8, 0x01, 0x00, 0x00, 0x34, 0xfa, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x0c, 0x01, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x95, 0xfc, 0xfe, 0x39, 0x18, 0x0f, 0x14, 0x3a, 0x0e, 0x8d, 0x17, 0x3a, 0xc5, 0x3d, 0xf4, 0x39, 0x16, 0x79, 0x30, 0x3a, 0x5c, 0xfa, 0x0e, 0x3a, 0x29, 0x03, 0x1a, 0x3a, 0x0d, 0x8d, 0x18, 0x3a, 0xca, 0x3e, 0x3f, 0x3a, 0x76, 0x37, 0x0d, 0x3a, 0x7a, 0x41, 0x32, 0x3a, 0x6a, 0x32, 0x52, 0x3a, 0xdc, 0x5f, 0x68, 0x3a, 0xc5, 0xc9, 0x24, 0x3a, 0xa4, 0xec, 0x07, 0x3a, 0xf1, 0x8f, 0xdf, 0x39, 0x1b, 0x77, 0x1a, 0x3a, 0xdc, 0xcc, 0x1b, 0x3a, 0xba, 0x85, 0x11, 0x3a, 0x01, 0x8d, 0x1b, 0x3a, 0x7c, 0xb9, 0x0c, 0x3a, 0xb7, 0x47, 0x1a, 0x3a, 0xf3, 0xb9, 0x40, 0x3a, 0xce, 0x23, 0x2d, 0x3a, 0x39, 0x6b, 0xd2, 0x39, 0x52, 0x41, 0x2e, 0x3a, 0x99, 0x72, 0x22, 0x3a, 0x87, 0x07, 0x06, 0x3a, 0x11, 0xdf, 0x3a, 0x3a, 0x44, 0xe1, 0x38, 0x3a, 0xba, 0x71, 0x49, 0x3a, 0xee, 0x6b, 0x42, 0x3a, 0x36, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 
0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x52, 0x65, 0x6c, 0x75, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x06, 0xfd, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x14, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0xfc, 0x00, 0x00, 0x00, 0x2c, 0xfc, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0xc5, 0x78, 0x3c, 0x55, 0x54, 0xaf, 0x3c, 0x5e, 0x4b, 0x5c, 0x3c, 0x8f, 0x7a, 0xc7, 0x3c, 0x9d, 0xbc, 0xe8, 0x3c, 0xd7, 0x2f, 0x2c, 0x3c, 0xfc, 0x0e, 0xf9, 0x3b, 0x86, 0x54, 0x1e, 0x3c, 0x5e, 0xb0, 0x01, 0x3c, 0xdc, 0xe9, 0x2f, 0x3c, 0xe1, 0xc3, 0x44, 0x3c, 0x35, 0x33, 0xdc, 0x3c, 0x22, 0xe7, 0x23, 0x3c, 0xc3, 0xd0, 0x84, 0x3c, 0xed, 0x12, 0xd0, 0x3c, 0xc6, 0x0e, 0xff, 0x3b, 0x1d, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x00, 0x00, 
0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x26, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x14, 0x00, 0x00, 0x00, 0xe4, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x18, 0x01, 0x00, 0x00, 0x4c, 0xfd, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x88, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0xa9, 0x60, 0x1e, 0x3a, 0x97, 0x3e, 0x5f, 0x3a, 0xb7, 0x3f, 0x0c, 0x3a, 0x55, 0xfe, 0x7d, 0x3a, 0x94, 0x2b, 0x94, 0x3a, 0x3e, 0x3e, 0xdb, 0x39, 0xb6, 0x8f, 0x9e, 0x39, 0x7e, 0x99, 0xc9, 0x39, 0x90, 0x21, 0xa5, 0x39, 0xfb, 0xfc, 0xdf, 0x39, 0xcf, 0x89, 0xfa, 0x39, 0x55, 0x30, 0x8c, 0x3a, 0xfa, 0xb1, 0xd0, 0x39, 0xb2, 0x1c, 0x29, 0x3a, 0x04, 0x78, 0x84, 0x3a, 0x76, 0x61, 0xa2, 0x39, 0x3a, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x52, 0x65, 0x6c, 0x75, 0x3b, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x31, 0x5f, 0x32, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x5e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x01, 0x14, 
0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x7c, 0x00, 0x00, 0x00, 0x84, 0xfe, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x49, 0xdc, 0x02, 0x3c, 0x6d, 0x85, 0x66, 0x3c, 0x10, 0x13, 0x92, 0x3c, 0x51, 0xb4, 0xd5, 0x3b, 0x98, 0xdf, 0x40, 0x3c, 0x1d, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f, 0x31, 0x2f, 0x4d, 0x61, 0x74, 0x4d, 0x75, 0x6c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x16, 0x00, 0x1c, 0x00, 0x18, 0x00, 0x17, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x00, 0x16, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x14, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x7c, 0x00, 0x00, 0x00, 0x3c, 0xff, 0xff, 0xff, 0x08, 0x00, 0x00, 0x00, 0x34, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x91, 0xa6, 0x2f, 0x3a, 0x19, 0xb6, 0x9a, 0x3a, 0x6d, 0x12, 0xc4, 0x3a, 0xc8, 0x6c, 0x0f, 0x3a, 0xcb, 0x71, 0x81, 0x3a, 0x1e, 0x00, 0x00, 0x00, 0x73, 0x65, 0x71, 0x75, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x5f, 0x31, 0x2f, 0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x5f, 0x31, 0x2f, 0x42, 0x69, 0x61, 0x73, 0x41, 0x64, 0x64, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x16, 0x00, 0x20, 0x00, 0x1c, 0x00, 0x1b, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x07, 0x00, 0x16, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x18, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00, 0x48, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x60, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x45, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xd5, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, 0xc9, 0x76, 0x9b, 0x3d, 0x1e, 0x00, 0x00, 0x00, 0x73, 0x65, 0x72, 0x76, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x5f, 0x6b, 0x65, 0x72, 0x61, 0x73, 0x5f, 0x74, 0x65, 0x6e, 0x73, 0x6f, 0x72, 0x3a, 0x30, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x45, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xf0, 0xff, 0xff, 0xff, 0x19, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x19, 0x0c, 0x00, 0x10, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09 +}; diff --git a/EMG_Arm/src/core/emg_model_data.h b/EMG_Arm/src/core/emg_model_data.h new file mode 100644 index 0000000..0a12821 --- /dev/null +++ b/EMG_Arm/src/core/emg_model_data.h @@ -0,0 +1,13 @@ +// Auto-generated by train_mlp_tflite.py do not edit +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +extern const unsigned char g_model[]; +extern const int g_model_len; + +#ifdef __cplusplus +} +#endif diff --git a/EMG_Arm/src/core/gestures.c b/EMG_Arm/src/core/gestures.c index 7aa4db6..ab107b4 100644 --- a/EMG_Arm/src/core/gestures.c +++ b/EMG_Arm/src/core/gestures.c @@ -10,6 +10,7 @@ #include #include + /******************************************************************************* * Private 
Data ******************************************************************************/ @@ -51,6 +52,19 @@ void gestures_execute(gesture_t gesture) } } +gesture_t parse_gesture(const char *s) +{ + if (strcmp(s, "rest") == 0) return GESTURE_REST; + if (strcmp(s, "fist") == 0) return GESTURE_FIST; + if (strcmp(s, "open") == 0) return GESTURE_OPEN; + if (strcmp(s, "hook-em") == 0 || + strcmp(s, "hookem") == 0) return GESTURE_HOOK_EM; + if (strcmp(s, "thumbs-up") == 0 || + strcmp(s, "thumbsup") == 0) return GESTURE_THUMBS_UP; + + return GESTURE_NONE; +} + const char* gestures_get_name(gesture_t gesture) { if (gesture >= GESTURE_COUNT) { @@ -65,38 +79,49 @@ const char* gestures_get_name(gesture_t gesture) void gesture_open(void) { - hand_unflex_all(); + hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]); + hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]); + hand_set_finger_angle(FINGER_MIDDLE, minAngles[FINGER_MIDDLE]); + hand_set_finger_angle(FINGER_RING, minAngles[FINGER_RING]); + hand_set_finger_angle(FINGER_PINKY, minAngles[FINGER_PINKY]); } void gesture_fist(void) { - hand_flex_all(); + hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]); + hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]); + hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]); + hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]); + hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]); } void gesture_hook_em(void) { /* Index and pinky extended, others flexed */ - hand_flex_finger(FINGER_THUMB); - hand_unflex_finger(FINGER_INDEX); - hand_flex_finger(FINGER_MIDDLE); - hand_flex_finger(FINGER_RING); - hand_unflex_finger(FINGER_PINKY); + hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]); + hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]); + hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]); + hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]); + hand_set_finger_angle(FINGER_PINKY, 
minAngles[FINGER_PINKY]); } void gesture_thumbs_up(void) { /* Thumb extended, others flexed */ - hand_unflex_finger(FINGER_THUMB); - hand_flex_finger(FINGER_INDEX); - hand_flex_finger(FINGER_MIDDLE); - hand_flex_finger(FINGER_RING); - hand_flex_finger(FINGER_PINKY); + hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]); + hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]); + hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]); + hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]); + hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]); } void gesture_rest(void) { - /* Rest is same as open - neutral position */ - gesture_open(); + hand_set_finger_angle(FINGER_THUMB, (maxAngles[FINGER_THUMB] + minAngles[FINGER_THUMB])/2); + hand_set_finger_angle(FINGER_INDEX, (maxAngles[FINGER_INDEX] + minAngles[FINGER_INDEX])/2); + hand_set_finger_angle(FINGER_MIDDLE, (maxAngles[FINGER_MIDDLE] + minAngles[FINGER_MIDDLE])/2); + hand_set_finger_angle(FINGER_RING, (maxAngles[FINGER_RING] + minAngles[FINGER_RING])/2); + hand_set_finger_angle(FINGER_PINKY, (maxAngles[FINGER_PINKY] + minAngles[FINGER_PINKY])/2); } /******************************************************************************* diff --git a/EMG_Arm/src/core/gestures.h b/EMG_Arm/src/core/gestures.h index a77e9ea..2e88538 100644 --- a/EMG_Arm/src/core/gestures.h +++ b/EMG_Arm/src/core/gestures.h @@ -13,6 +13,7 @@ #define GESTURES_H #include +#include #include "config/config.h" /******************************************************************************* @@ -26,6 +27,8 @@ */ void gestures_execute(gesture_t gesture); +gesture_t parse_gesture(const char *s); + /** * @brief Get the name of a gesture as a string. 
* diff --git a/EMG_Arm/src/core/inference.c b/EMG_Arm/src/core/inference.c index 4e800c3..55bf63c 100644 --- a/EMG_Arm/src/core/inference.c +++ b/EMG_Arm/src/core/inference.c @@ -4,19 +4,52 @@ */ #include "inference.h" +#include "calibration.h" #include "config/config.h" #include "model_weights.h" #include #include #include +#if MODEL_EXPAND_FEATURES +#include "dsps_fft2r.h" /* esp-dsp: complex 2-radix FFT */ +#define FFT_N 256 /* Must match Python fft_n=256 */ +#endif + // --- Constants --- -#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python) -#define VOTE_WINDOW 5 // Majority vote window size -#define DEBOUNCE_COUNT 3 // Confirmations needed to change output +#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python) +#define VOTE_WINDOW 5 // Majority vote window size +#define DEBOUNCE_COUNT 3 // Confirmations needed to change output +// Change C: confidence rejection threshold. +// When the peak smoothed probability stays below this, hold the last confirmed +// output rather than outputting an uncertain prediction. Prevents false arm +// actuations during gesture transitions, rest-to-gesture onset, and electrode lift. +// Meta paper uses 0.35; 0.40 adds a prosthetic safety margin. +// Tune down to 0.35 if real gestures are being incorrectly rejected. +#define CONFIDENCE_THRESHOLD 0.40f + +// --- Change B: IIR Bandpass Filter (20–450 Hz, 2nd-order Butterworth @ 1 kHz) --- +// Two cascaded biquad sections, Direct Form II Transposed. +// Computed via scipy.signal.butter(2, [20,450], btype='bandpass', fs=1000, output='sos'). 
+// b coefficients [b0, b1, b2] per section: +#define IIR_NUM_SECTIONS 2 +static const float IIR_B[IIR_NUM_SECTIONS][3] = { + { 0.7320224766f, 1.4640449531f, 0.7320224766f }, /* section 0 */ + { 1.0000000000f, -2.0000000000f, 1.0000000000f }, /* section 1 */ +}; +// Feedback coefficients [a1, a2] per section (a0 = 1, implicit): +static const float IIR_A[IIR_NUM_SECTIONS][2] = { + { 1.5597081442f, 0.6416146818f }, /* section 0 */ + { -1.8224796027f, 0.8372542588f }, /* section 1 */ +}; // --- State --- -static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS]; +static float window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS]; +static float biquad_w[IIR_NUM_SECTIONS][NUM_CHANNELS][2]; /* biquad delay states */ + +#if MODEL_EXPAND_FEATURES +static bool s_fft_ready = false; +#endif static int buffer_head = 0; static int samples_collected = 0; @@ -30,9 +63,17 @@ static int pending_count = 0; void inference_init(void) { memset(window_buffer, 0, sizeof(window_buffer)); + memset(biquad_w, 0, sizeof(biquad_w)); buffer_head = 0; samples_collected = 0; +#if MODEL_EXPAND_FEATURES + if (!s_fft_ready) { + dsps_fft2r_init_fc32(NULL, FFT_N); + s_fft_ready = true; + } +#endif + // Initialize smoothing for (int i = 0; i < MODEL_NUM_CLASSES; i++) { smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES; @@ -47,9 +88,17 @@ void inference_init(void) { } bool inference_add_sample(uint16_t *channels) { - // Add to circular buffer + // Convert to float, apply per-channel biquad bandpass, then store in circular buffer. 
for (int i = 0; i < NUM_CHANNELS; i++) { - window_buffer[buffer_head][i] = channels[i]; + float x = (float)channels[i]; + // Cascade IIR_NUM_SECTIONS biquad sections (Direct Form II Transposed) + for (int s = 0; s < IIR_NUM_SECTIONS; s++) { + float y = IIR_B[s][0] * x + biquad_w[s][i][0]; + biquad_w[s][i][0] = IIR_B[s][1] * x - IIR_A[s][0] * y + biquad_w[s][i][1]; + biquad_w[s][i][1] = IIR_B[s][2] * x - IIR_A[s][1] * y; + x = y; + } + window_buffer[buffer_head][i] = x; } buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE; @@ -65,11 +114,265 @@ bool inference_add_sample(uint16_t *channels) { // --- Feature Extraction --- -static void compute_features(float *features_out) { - // Process each channel - // We need to iterate over the logical window (unrolling circular buffer) +/* ── helpers used by compute_features_expanded ──────────────────────────── */ - for (int ch = 0; ch < NUM_CHANNELS; ch++) { +#if MODEL_EXPAND_FEATURES + +/** Solve 4×4 linear system A·x = b via Gaussian elimination with partial pivoting. + * Returns false and leaves x zeroed if the matrix is singular. 
*/ +static bool solve_4x4(float A[4][4], const float b[4], float x[4]) { + float M[4][5]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) M[i][j] = A[i][j]; + M[i][4] = b[i]; + } + for (int col = 0; col < 4; col++) { + int pivot = col; + float maxv = fabsf(M[col][col]); + for (int row = col + 1; row < 4; row++) { + if (fabsf(M[row][col]) > maxv) { maxv = fabsf(M[row][col]); pivot = row; } + } + if (maxv < 1e-8f) { x[0]=x[1]=x[2]=x[3]=0.f; return false; } + if (pivot != col) { + float tmp[5]; + memcpy(tmp, M[col], 5*sizeof(float)); + memcpy(M[col], M[pivot], 5*sizeof(float)); + memcpy(M[pivot], tmp, 5*sizeof(float)); + } + for (int row = col + 1; row < 4; row++) { + float f = M[row][col] / M[col][col]; + for (int j = col; j < 5; j++) M[row][j] -= f * M[col][j]; + } + } + for (int row = 3; row >= 0; row--) { + x[row] = M[row][4]; + for (int j = row + 1; j < 4; j++) x[row] -= M[row][j] * x[j]; + x[row] /= M[row][row]; + } + return true; +} + +/** + * @brief Full 69-feature extraction (Change 1). + * + * Per channel (20): RMS, WL, ZC, SSC, MAV, VAR, IEMG, WAMP, + * AR1-AR4, MNF, MDF, PKF, MNP, BP0-BP3 + * Cross-channel (9 for 3 channels): corr, log-RMS-ratio, cov for each pair + * + * Feature layout must exactly match EMGFeatureExtractor._EXPANDED_KEYS + + * cross-channel order in the Python class. + */ +static void compute_features_expanded(float *features_out) { + memset(features_out, 0, MODEL_NUM_FEATURES * sizeof(float)); + + /* Persistent buffers for centered signals (3 ch × 150 samples) */ + static float ch_signals[HAND_NUM_CHANNELS][INFERENCE_WINDOW_SIZE]; + static float s_fft_buf[FFT_N * 2]; /* Complex interleaved [re,im,...] 
*/ + + float ch_rms[HAND_NUM_CHANNELS]; + float norm_sq = 0.0f; + + /* ────────────────────────────────────────────────────────────────────── + * Pass 1: per-channel TD + AR + spectral features + * ────────────────────────────────────────────────────────────────────── */ + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + /* Read channel from circular buffer (oldest-first) */ + float *sig = ch_signals[ch]; + float sum = 0.0f; + int idx = buffer_head; + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + sig[i] = window_buffer[idx][ch]; + sum += sig[i]; + idx = (idx + 1) % INFERENCE_WINDOW_SIZE; + } + /* DC removal */ + float mean = sum / INFERENCE_WINDOW_SIZE; + float sq_sum = 0.0f; + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + sig[i] -= mean; + sq_sum += sig[i] * sig[i]; + } + /* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */ +#if MODEL_USE_REINHARD + sq_sum = 0.0f; + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + float x = sig[i]; + sig[i] = 64.0f * x / (32.0f + fabsf(x)); + sq_sum += sig[i] * sig[i]; + } +#endif + + float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE); + ch_rms[ch] = rms; + norm_sq += rms * rms; + + float zc_thresh = FEAT_ZC_THRESH * rms; + float ssc_thresh = (FEAT_SSC_THRESH * rms) * (FEAT_SSC_THRESH * rms); + + /* TD features */ + float wl = 0.0f, mav = 0.0f, iemg = 0.0f; + int zc = 0, ssc = 0, wamp = 0; + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + float a = fabsf(sig[i]); + mav += a; + iemg += a; + } + mav /= INFERENCE_WINDOW_SIZE; + float var_val = sq_sum / INFERENCE_WINDOW_SIZE; /* variance (mean already 0) */ + + for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) { + float diff = sig[i+1] - sig[i]; + float adiff = fabsf(diff); + wl += adiff; + if (adiff > zc_thresh) wamp++; + if ((sig[i] > 0.0f && sig[i+1] < 0.0f) || + (sig[i] < 0.0f && sig[i+1] > 0.0f)) { + if (adiff > zc_thresh) zc++; + } + if (i < INFERENCE_WINDOW_SIZE - 2) { + float d1 = sig[i+1] - sig[i]; + float d2 = sig[i+1] - sig[i+2]; + if ((d1 * d2) > 
ssc_thresh) ssc++; + } + } + + /* AR(4) via Yule-Walker */ + float r[5] = {0}; + for (int lag = 0; lag < 5; lag++) { + for (int i = 0; i < INFERENCE_WINDOW_SIZE - lag; i++) + r[lag] += sig[i] * sig[i + lag]; + r[lag] /= INFERENCE_WINDOW_SIZE; + } + float T[4][4], b_ar[4], ar[4] = {0}; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) T[i][j] = r[abs(i - j)]; + b_ar[i] = r[i + 1]; + } + solve_4x4(T, b_ar, ar); + + /* Spectral features via FFT (zero-pad to FFT_N) */ + memset(s_fft_buf, 0, sizeof(s_fft_buf)); + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + s_fft_buf[2*i] = sig[i]; + s_fft_buf[2*i + 1] = 0.0f; + } + dsps_fft2r_fc32(s_fft_buf, FFT_N); + dsps_bit_rev_fc32(s_fft_buf, FFT_N); + + float total_power = 1e-10f; + const float freq_step = (float)EMG_SAMPLE_RATE_HZ / FFT_N; + float pwr[FFT_N / 2]; + for (int k = 0; k < FFT_N / 2; k++) { + float re = s_fft_buf[2*k], im = s_fft_buf[2*k + 1]; + pwr[k] = re*re + im*im; + total_power += pwr[k]; + } + + float mnf = 0.0f; + for (int k = 0; k < FFT_N / 2; k++) mnf += k * freq_step * pwr[k]; + mnf /= total_power; + + float cumsum = 0.0f, half = total_power / 2.0f; + int mdf_k = FFT_N / 2 - 1; + for (int k = 0; k < FFT_N / 2; k++) { + cumsum += pwr[k]; + if (cumsum >= half) { mdf_k = k; break; } + } + float mdf = mdf_k * freq_step; + + int pkf_k = 0; + for (int k = 1; k < FFT_N / 2; k++) { + if (pwr[k] > pwr[pkf_k]) pkf_k = k; + } + float pkf = pkf_k * freq_step; + float mnp = total_power / (FFT_N / 2); + + float bp0 = 0, bp1 = 0, bp2 = 0, bp3 = 0; + for (int k = 0; k < FFT_N / 2; k++) { + float f = k * freq_step; + if (f >= 20.f && f < 80.f ) bp0 += pwr[k]; + else if (f >= 80.f && f < 150.f) bp1 += pwr[k]; + else if (f >= 150.f && f < 250.f) bp2 += pwr[k]; + else if (f >= 250.f && f < 450.f) bp3 += pwr[k]; + } + bp0 /= total_power; bp1 /= total_power; + bp2 /= total_power; bp3 /= total_power; + + /* Store 20 features for this channel */ + int base = ch * 20; + features_out[base + 0] = rms; + 
features_out[base + 1] = wl; + features_out[base + 2] = (float)zc; + features_out[base + 3] = (float)ssc; + features_out[base + 4] = mav; + features_out[base + 5] = var_val; + features_out[base + 6] = iemg; + features_out[base + 7] = (float)wamp; + features_out[base + 8] = ar[0]; + features_out[base + 9] = ar[1]; + features_out[base + 10] = ar[2]; + features_out[base + 11] = ar[3]; + features_out[base + 12] = mnf; + features_out[base + 13] = mdf; + features_out[base + 14] = pkf; + features_out[base + 15] = mnp; + features_out[base + 16] = bp0; + features_out[base + 17] = bp1; + features_out[base + 18] = bp2; + features_out[base + 19] = bp3; + } + + /* ────────────────────────────────────────────────────────────────────── + * Amplitude normalization (matches Python normalize=True behaviour) + * ────────────────────────────────────────────────────────────────────── */ + float norm_factor = sqrtf(norm_sq); + if (norm_factor < 1e-6f) norm_factor = 1e-6f; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + int base = ch * 20; + features_out[base + 0] /= norm_factor; /* rms */ + features_out[base + 1] /= norm_factor; /* wl */ + features_out[base + 4] /= norm_factor; /* mav */ + features_out[base + 6] /= norm_factor; /* iemg */ + } + + /* ────────────────────────────────────────────────────────────────────── + * Cross-channel features (3 pairs × 3 features = 9) + * Order: (0,1), (0,2), (1,2) → corr, lrms, cov + * ────────────────────────────────────────────────────────────────────── */ + int xc_base = HAND_NUM_CHANNELS * 20; /* = 60 */ + int xc_idx = 0; + for (int i = 0; i < HAND_NUM_CHANNELS; i++) { + for (int j = i + 1; j < HAND_NUM_CHANNELS; j++) { + float ri = ch_rms[i] + 1e-10f; + float rj = ch_rms[j] + 1e-10f; + + float dot_ij = 0.0f; + for (int k = 0; k < INFERENCE_WINDOW_SIZE; k++) + dot_ij += ch_signals[i][k] * ch_signals[j][k]; + float n_inv = 1.0f / INFERENCE_WINDOW_SIZE; + + float corr = dot_ij * n_inv / (ri * rj); + if (corr > 1.0f) corr = 1.0f; + if (corr < 
-1.0f) corr = -1.0f; + + float lrms = logf(ri / rj); + float cov = dot_ij * n_inv / (norm_factor * norm_factor); + + features_out[xc_base + xc_idx * 3 + 0] = corr; + features_out[xc_base + xc_idx * 3 + 1] = lrms; + features_out[xc_base + xc_idx * 3 + 2] = cov; + xc_idx++; + } + } +} + +#endif /* MODEL_EXPAND_FEATURES */ + +static void compute_features(float *features_out) { + // Process forearm channels only (ch0-ch2) for hand gesture classification. + // The bicep channel (ch3) is excluded — it will be processed independently. + + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { float sum = 0; float sq_sum = 0; @@ -79,7 +382,7 @@ static void compute_features(float *features_out) { int idx = buffer_head; // Oldest sample for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { - signal[i] = (float)window_buffer[idx][ch]; + signal[i] = window_buffer[idx][ch]; sum += signal[i]; idx = (idx + 1) % INFERENCE_WINDOW_SIZE; } @@ -96,6 +399,15 @@ static void compute_features(float *features_out) { signal[i] -= mean; sq_sum += signal[i] * signal[i]; } + /* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */ +#if MODEL_USE_REINHARD + sq_sum = 0.0f; + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + float x = signal[i]; + signal[i] = 64.0f * x / (32.0f + fabsf(x)); + sq_sum += signal[i] * signal[i]; + } +#endif float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE); @@ -133,6 +445,60 @@ static void compute_features(float *features_out) { features_out[base + 2] = (float)zc; features_out[base + 3] = (float)ssc; } + +#if MODEL_NORMALIZE_FEATURES + // Normalize amplitude-dependent features (RMS, WL) by total RMS across + // channels. This makes the model robust to impedance shifts between sessions + // while preserving relative channel activation patterns. 
+ float total_rms_sq = 0; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + float ch_rms = features_out[ch * 4 + 0]; + total_rms_sq += ch_rms * ch_rms; + } + float norm_factor = sqrtf(total_rms_sq); + if (norm_factor < 1e-6f) norm_factor = 1e-6f; + + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + features_out[ch * 4 + 0] /= norm_factor; // RMS + features_out[ch * 4 + 1] /= norm_factor; // WL + } +#endif +} + +// --- Feature extraction (public wrapper used by inference_ensemble.c) --- + +void inference_extract_features(float *features_out) { +#if MODEL_EXPAND_FEATURES + compute_features_expanded(features_out); +#else + compute_features(features_out); +#endif + calibration_apply(features_out); +} + +// --- Raw LDA probability (for multi-model voting) --- + +void inference_predict_raw(const float *features, float *proba_out) { + float raw_scores[MODEL_NUM_CLASSES]; + float max_score = -1e9f; + + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + float score = LDA_INTERCEPTS[c]; + for (int f = 0; f < MODEL_NUM_FEATURES; f++) { + score += features[f] * LDA_WEIGHTS[c][f]; + } + raw_scores[c] = score; + if (score > max_score) max_score = score; + } + + float sum_exp = 0.0f; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + proba_out[c] = expf(raw_scores[c] - max_score); + sum_exp += proba_out[c]; + } + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + proba_out[c] /= sum_exp; + } } // --- Prediction --- @@ -144,7 +510,14 @@ int inference_predict(float *confidence) { // 1. Extract Features float features[MODEL_NUM_FEATURES]; +#if MODEL_EXPAND_FEATURES + compute_features_expanded(features); +#else compute_features(features); +#endif + + // 1b. Change D: z-score normalise using NVS-stored session calibration + calibration_apply(features); // 2. LDA Inference (Linear Score) float raw_scores[MODEL_NUM_CLASSES]; @@ -195,6 +568,15 @@ int inference_predict(float *confidence) { } } + // Change C: confidence rejection. 
+ // If the strongest smoothed probability is still too low, the classifier is + // uncertain — return the last confirmed output without updating vote/debounce state. + // This prevents low-confidence transitions from actuating the arm spuriously. + if (max_smoothed_prob < CONFIDENCE_THRESHOLD) { + *confidence = max_smoothed_prob; + return current_output; // -1 (GESTURE_NONE) until first confident prediction + } + // 3b. Majority Vote vote_history[vote_head] = smoothed_winner; vote_head = (vote_head + 1) % VOTE_WINDOW; @@ -255,17 +637,12 @@ const char *inference_get_class_name(int class_idx) { int inference_get_gesture_enum(int class_idx) { const char *name = inference_get_class_name(class_idx); + return inference_get_gesture_by_name(name); +} - // Map string name to gesture_t enum - // Strings must match those in Python list: ["fist", "hook_em", "open", - // "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums - // are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5 - - // Case-insensitive check would be safer, but let's assume Python output is - // lowercase as seen in scripts or uppercase if specified. In - // learning_data_collection.py, they seem to be "rest", "open", "fist", etc. - - // Simple string matching +int inference_get_gesture_by_name(const char *name) { + // Accepts both lowercase (Python output) and uppercase (C enum name style). + // Add a new case here whenever a gesture is added to gesture_t in config.h. 
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0) return GESTURE_REST; if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0) @@ -276,6 +653,21 @@ int inference_get_gesture_enum(int class_idx) { return GESTURE_HOOK_EM; if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0) return GESTURE_THUMBS_UP; return GESTURE_NONE; } + +float inference_get_bicep_rms(int n_samples) { + if (samples_collected < INFERENCE_WINDOW_SIZE) return 0.0f; + if (n_samples > INFERENCE_WINDOW_SIZE) n_samples = INFERENCE_WINDOW_SIZE; + + float sum_sq = 0.0f; + // Walk backwards from buffer_head (oldest = buffer_head, newest = buffer_head - 1) + int start = (buffer_head - n_samples + INFERENCE_WINDOW_SIZE) % INFERENCE_WINDOW_SIZE; + int idx = start; + for (int i = 0; i < n_samples; i++) { + float v = window_buffer[idx][3]; // channel 3 = bicep + sum_sq += v * v; + idx = (idx + 1) % INFERENCE_WINDOW_SIZE; + } + return sqrtf(sum_sq / n_samples); +} diff --git a/EMG_Arm/src/core/inference.h b/EMG_Arm/src/core/inference.h index 32cf3b0..6a0fea6 100644 --- a/EMG_Arm/src/core/inference.h +++ b/EMG_Arm/src/core/inference.h @@ -9,9 +9,11 @@ #include <stdint.h> #include <stdbool.h> -// --- Configuration --- -#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python) -#define NUM_CHANNELS 4 // Number of EMG channels +// --- Configuration (must match Python WINDOW_SIZE_MS / HOP_SIZE_MS) --- +#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (150ms at 1kHz) +#define INFERENCE_HOP_SIZE 25 // Hop/stride in samples (25ms at 1kHz) +#define NUM_CHANNELS 4 // Total EMG channels (buffer stores all) +#define HAND_NUM_CHANNELS 3 // Forearm channels for hand classifier (ch0-ch2) @@ -44,4 +46,53 @@ const char *inference_get_class_name(int class_idx); */ int inference_get_gesture_enum(int class_idx); +/** + * @brief Map a gesture name string directly to gesture_t enum value.
+ * + * Used by the laptop-predict path to convert the name sent by live_predict.py + * into a gesture_t without needing a class index. + * + * @param name Lowercase gesture name, e.g. "fist", "rest", "open" + * @return gesture_t value, or GESTURE_NONE if unrecognised + */ +int inference_get_gesture_by_name(const char *name); + +/** + * @brief Compute LDA softmax probabilities without smoothing/voting/debounce. + * + * Used by the multi-model voting path in main.c. The caller is responsible + * for post-processing (EMA, majority vote, debounce). + * + * @param features Calibrated feature vector (MODEL_NUM_FEATURES floats). + * @param proba_out Output probability array (MODEL_NUM_CLASSES floats). + */ +void inference_predict_raw(const float *features, float *proba_out); + +/** + * @brief Extract and calibrate features from the current window. + * + * Dispatches to compute_features() or compute_features_expanded() depending + * on MODEL_EXPAND_FEATURES, then applies calibration_apply(). The resulting + * float array is identical to what inference_predict() uses internally. + * + * Called by inference_ensemble.c so that the ensemble path does not duplicate + * the feature-extraction logic. + * + * @param features_out Caller-allocated array of MODEL_NUM_FEATURES floats. + */ +void inference_extract_features(float *features_out); + +/** + * @brief Compute RMS of the last n_samples from channel 3 (bicep) in the + * inference circular buffer. + * + * Used by the bicep subsystem to obtain current activation level without + * exposing the internal window_buffer. + * + * @param n_samples Number of samples to include (clamped to INFERENCE_WINDOW_SIZE). + * @return RMS value in the same units as the filtered buffer (mV after Change B). + * Returns 0 if the buffer is not yet filled. 
+ */ +float inference_get_bicep_rms(int n_samples); + +#endif /* INFERENCE_H */ diff --git a/EMG_Arm/src/core/inference_ensemble.c b/EMG_Arm/src/core/inference_ensemble.c new file mode 100644 index 0000000..0e4713e --- /dev/null +++ b/EMG_Arm/src/core/inference_ensemble.c @@ -0,0 +1,262 @@ +/** + * @file inference_ensemble.c + * @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F). + * + * Guarded by MODEL_USE_ENSEMBLE in model_weights.h. + * When 0, provides empty stubs so the file compiles unconditionally. + */ + +#include "inference_ensemble.h" +#include "inference.h" +#include "model_weights.h" + +#if MODEL_USE_ENSEMBLE + +#include "inference_mlp.h" +#include "model_weights_ensemble.h" +#include "calibration.h" +#include <math.h> +#include <string.h> +#include <stdint.h> + +#define ENSEMBLE_EMA_ALPHA 0.70f +#define ENSEMBLE_CONF_THRESHOLD 0.50f /**< below this: escalate to MLP fallback */ +#define REJECT_THRESHOLD 0.40f /**< below this even after MLP: hold output */ +#define REST_ACTIVITY_THRESHOLD 0.05f /**< total RMS gate — skip inference during rest */ + +/* Class index for "rest" — looked up once in init so we don't return + * the gesture_t enum value (GESTURE_REST = 1) which would be misinterpreted + * as class index 1 ("hook_em"). See Bug 2 in the code review. */ +static int s_rest_class_idx = 0; + +/* EMA probability state */ +static float s_smoothed[MODEL_NUM_CLASSES]; + +/* Majority vote ring buffer (window = 5) */ +static int s_vote_history[5]; +static int s_vote_head = 0; + +/* Debounce state */ +static int s_current_output = -1; +static int s_pending_output = -1; +static int s_pending_count = 0; + +/* ── Generic LDA softmax ──────────────────────────────────────────────────── */ + +/** + * Compute softmax class probabilities from a flat feature vector. + * + * @param feat Feature vector (contiguous, length n_feat). + * @param n_feat Number of features. + * @param weights_flat Row-major weight matrix, shape [n_classes][n_feat].
+ * @param intercepts Intercept vector, length n_classes. + * @param n_classes Number of output classes. + * @param proba_out Output probabilities, length n_classes (caller-allocated). + */ +static void lda_softmax(const float *feat, int n_feat, + const float *weights_flat, const float *intercepts, + int n_classes, float *proba_out) { + float raw[MODEL_NUM_CLASSES]; + float max_raw = -1e9f; + float sum_exp = 0.0f; + + for (int c = 0; c < n_classes; c++) { + raw[c] = intercepts[c]; + const float *w = weights_flat + c * n_feat; + for (int f = 0; f < n_feat; f++) { + raw[c] += feat[f] * w[f]; + } + if (raw[c] > max_raw) max_raw = raw[c]; + } + for (int c = 0; c < n_classes; c++) { + proba_out[c] = expf(raw[c] - max_raw); + sum_exp += proba_out[c]; + } + for (int c = 0; c < n_classes; c++) { + proba_out[c] /= sum_exp; + } +} + +/* ── Public API ───────────────────────────────────────────────────────────── */ + +void inference_ensemble_init(void) { + /* Find the class index for "rest" once, so we can return it correctly + * from the activity gate without confusing class indices with gesture_t. 
*/ + s_rest_class_idx = 0; + for (int i = 0; i < MODEL_NUM_CLASSES; i++) { + if (strcmp(MODEL_CLASS_NAMES[i], "rest") == 0) { + s_rest_class_idx = i; + break; + } + } + + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + s_smoothed[c] = 1.0f / MODEL_NUM_CLASSES; + } + for (int i = 0; i < 5; i++) { + s_vote_history[i] = -1; + } + s_vote_head = 0; + s_current_output = -1; + s_pending_output = -1; + s_pending_count = 0; +} + +void inference_ensemble_predict_raw(const float *features, float *proba_out) { + /* Gather TD features (non-contiguous: 12 per channel × 3 channels) */ + float td_buf[TD_NUM_FEATURES]; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + memcpy(td_buf + ch * 12, + features + ch * ENSEMBLE_PER_CH_FEATURES, + 12 * sizeof(float)); + } + + /* Gather FD features (non-contiguous: 8 per channel × 3 channels) */ + float fd_buf[FD_NUM_FEATURES]; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + memcpy(fd_buf + ch * 8, + features + ch * ENSEMBLE_PER_CH_FEATURES + 12, + 8 * sizeof(float)); + } + + /* CC features are already contiguous at the end (indices 60-68) */ + const float *cc_buf = features + CC_FEAT_OFFSET; + + /* Specialist LDA predictions */ + float prob_td[MODEL_NUM_CLASSES]; + float prob_fd[MODEL_NUM_CLASSES]; + float prob_cc[MODEL_NUM_CLASSES]; + + lda_softmax(td_buf, TD_NUM_FEATURES, + (const float *)LDA_TD_WEIGHTS, LDA_TD_INTERCEPTS, + MODEL_NUM_CLASSES, prob_td); + lda_softmax(fd_buf, FD_NUM_FEATURES, + (const float *)LDA_FD_WEIGHTS, LDA_FD_INTERCEPTS, + MODEL_NUM_CLASSES, prob_fd); + lda_softmax(cc_buf, CC_NUM_FEATURES, + (const float *)LDA_CC_WEIGHTS, LDA_CC_INTERCEPTS, + MODEL_NUM_CLASSES, prob_cc); + + /* Meta-LDA stacker */ + float meta_in[META_NUM_INPUTS]; + memcpy(meta_in, prob_td, MODEL_NUM_CLASSES * sizeof(float)); + memcpy(meta_in + MODEL_NUM_CLASSES, prob_fd, MODEL_NUM_CLASSES * sizeof(float)); + memcpy(meta_in + 2*MODEL_NUM_CLASSES, prob_cc, MODEL_NUM_CLASSES * sizeof(float)); + + lda_softmax(meta_in, META_NUM_INPUTS, + (const 
float *)META_LDA_WEIGHTS, META_LDA_INTERCEPTS, + MODEL_NUM_CLASSES, proba_out); +} + +int inference_ensemble_predict(float *confidence) { + /* 1. Extract + calibrate features (shared with single-model path) */ + float features[MODEL_NUM_FEATURES]; + inference_extract_features(features); /* includes calibration_apply() */ + + /* 2. Activity gate — skip inference during obvious REST */ + float total_rms_sq = 0.0f; + for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) { + /* RMS is the first feature per channel (index 0 in each 20-feature block) */ + float r = features[ch * ENSEMBLE_PER_CH_FEATURES]; + total_rms_sq += r * r; + } + if (sqrtf(total_rms_sq) < REST_ACTIVITY_THRESHOLD) { + *confidence = 1.0f; + return s_rest_class_idx; + } + + /* 3. Run ensemble pipeline (raw probabilities) */ + float meta_probs[MODEL_NUM_CLASSES]; + inference_ensemble_predict_raw(features, meta_probs); + + /* 7. EMA smoothing on meta output */ + float max_smooth = 0.0f; + int winner = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + s_smoothed[c] = ENSEMBLE_EMA_ALPHA * s_smoothed[c] + + (1.0f - ENSEMBLE_EMA_ALPHA) * meta_probs[c]; + if (s_smoothed[c] > max_smooth) { + max_smooth = s_smoothed[c]; + winner = c; + } + } + + /* 8. Confidence cascade: escalate to MLP if meta-LDA is uncertain */ + if (max_smooth < ENSEMBLE_CONF_THRESHOLD) { + float mlp_conf = 0.0f; + int mlp_winner = inference_mlp_predict(features, MODEL_NUM_FEATURES, &mlp_conf); + if (mlp_conf > max_smooth) { + winner = mlp_winner; + max_smooth = mlp_conf; + } + } + + /* 9. Reject if still uncertain — hold current output */ + if (max_smooth < REJECT_THRESHOLD) { + *confidence = max_smooth; + return s_current_output >= 0 ? s_current_output : s_rest_class_idx; + } + + *confidence = max_smooth; + + /* 10. 
Majority vote (window = 5) */ + s_vote_history[s_vote_head] = winner; + s_vote_head = (s_vote_head + 1) % 5; + + int counts[MODEL_NUM_CLASSES]; + memset(counts, 0, sizeof(counts)); + for (int i = 0; i < 5; i++) { + if (s_vote_history[i] >= 0) { + counts[s_vote_history[i]]++; + } + } + int majority = 0, majority_cnt = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + if (counts[c] > majority_cnt) { + majority_cnt = counts[c]; + majority = c; + } + } + + /* 11. Debounce — need 3 consecutive predictions to change output */ + int final_out = (s_current_output >= 0) ? s_current_output : majority; + + if (s_current_output < 0) { + /* First prediction ever */ + s_current_output = majority; + final_out = majority; + } else if (majority == s_current_output) { + /* Staying in current gesture — reset pending */ + s_pending_output = majority; + s_pending_count = 1; + } else if (majority == s_pending_output) { + /* Same new gesture again */ + if (++s_pending_count >= 3) { + s_current_output = majority; + final_out = majority; + } + } else { + /* New gesture candidate — start counting */ + s_pending_output = majority; + s_pending_count = 1; + } + + return final_out; +} + +#else /* MODEL_USE_ENSEMBLE == 0 — compile-time stubs */ + +void inference_ensemble_init(void) {} + +void inference_ensemble_predict_raw(const float *features, float *proba_out) { + (void)features; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) + proba_out[c] = 1.0f / MODEL_NUM_CLASSES; +} + +int inference_ensemble_predict(float *confidence) { + if (confidence) *confidence = 0.0f; + return 0; +} + +#endif /* MODEL_USE_ENSEMBLE */ diff --git a/EMG_Arm/src/core/inference_ensemble.h b/EMG_Arm/src/core/inference_ensemble.h new file mode 100644 index 0000000..31ae031 --- /dev/null +++ b/EMG_Arm/src/core/inference_ensemble.h @@ -0,0 +1,44 @@ +/** + * @file inference_ensemble.h + * @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F). 
+ * + * Requires: + * - Change 1 expanded features (MODEL_EXPAND_FEATURES 1) + * - Change 7 training (train_ensemble.py) to generate model_weights_ensemble.h + * - Change E MLP (MODEL_USE_MLP 1) for confidence-cascade fallback + * + * Enable by setting MODEL_USE_ENSEMBLE 1 in model_weights.h and calling + * inference_ensemble_init() from app_main() instead of (or alongside) inference_init(). + */ + +#pragma once +#include + +/** + * @brief Initialise ensemble state (EMA, vote history, debounce). + * Must be called before inference_ensemble_predict(). + */ +void inference_ensemble_init(void); + +/** + * @brief Compute ensemble probabilities without smoothing/voting/debounce. + * + * Runs the 3 specialist LDAs + meta-LDA stacker and writes the raw meta-LDA + * probabilities to proba_out. Used by the multi-model voting path in main.c. + * + * @param features Calibrated feature vector (MODEL_NUM_FEATURES floats). + * @param proba_out Output probability array (MODEL_NUM_CLASSES floats). + */ +void inference_ensemble_predict_raw(const float *features, float *proba_out); + +/** + * @brief Run one inference hop through the full ensemble pipeline. + * + * Internally calls inference_extract_features() to pull the latest window, + * routes through the three specialist LDAs, the meta-LDA stacker, EMA + * smoothing, majority vote, and debounce. + * + * @param confidence Output: winning class smoothed probability [0,1]. + * @return Winning class index into MODEL_CLASS_NAMES (same as inference_predict() return) — NOT a gesture_t value; map via inference_get_gesture_enum(). + */ +int inference_ensemble_predict(float *confidence); diff --git a/EMG_Arm/src/core/inference_mlp.cc b/EMG_Arm/src/core/inference_mlp.cc new file mode 100644 index 0000000..8cd86f2 --- /dev/null +++ b/EMG_Arm/src/core/inference_mlp.cc @@ -0,0 +1,75 @@ +/** + * @file inference_mlp.cc + * @brief int8 MLP inference via TFLite Micro (Change E). + * + * Compiled as C++ because TFLite Micro headers require C++. + * Guarded by MODEL_USE_MLP — provides compile-time stubs when 0.
+ */ + +#include "inference_mlp.h" +#include "model_weights.h" + +#if MODEL_USE_MLP + +#include "emg_model_data.h" +#include "tensorflow/lite/micro/micro_interpreter.h" +#include "tensorflow/lite/micro/micro_mutable_op_resolver.h" +#include "tensorflow/lite/schema/schema_generated.h" +#include <cmath> + +static uint8_t tensor_arena[48 * 1024]; /* 48 KB — tune down if memory is tight */ +static tflite::MicroInterpreter *s_interpreter = nullptr; +static TfLiteTensor *s_input = nullptr; +static TfLiteTensor *s_output = nullptr; + +void inference_mlp_init(void) { + const tflite::Model *model = tflite::GetModel(g_model); + static tflite::MicroMutableOpResolver<4> resolver; + resolver.AddFullyConnected(); + resolver.AddRelu(); + resolver.AddSoftmax(); + resolver.AddDequantize(); + static tflite::MicroInterpreter interp(model, resolver, + tensor_arena, sizeof(tensor_arena)); + s_interpreter = &interp; + s_interpreter->AllocateTensors(); + s_input = s_interpreter->input(0); + s_output = s_interpreter->output(0); +} + +int inference_mlp_predict(const float *features, int n_feat, float *conf_out) { + /* Quantise input: int8 = round(float / scale) + zero_point */ + const float iscale = s_input->params.scale; + const int izp = s_input->params.zero_point; + for (int i = 0; i < n_feat; i++) { + int q = static_cast<int>(roundf(features[i] / iscale)) + izp; + if (q < -128) q = -128; + if (q > 127) q = 127; + s_input->data.int8[i] = static_cast<int8_t>(q); + } + + s_interpreter->Invoke(); + + /* Dequantise output and find winning class */ + const float oscale = s_output->params.scale; + const int ozp = s_output->params.zero_point; + float max_p = -1e9f; + int max_c = 0; + const int n_cls = s_output->dims->data[1]; + for (int c = 0; c < n_cls; c++) { + float p = (s_output->data.int8[c] - ozp) * oscale; + if (p > max_p) { max_p = p; max_c = c; } + } + *conf_out = max_p; + return max_c; +} + +#else /* MODEL_USE_MLP == 0 — compile-time stubs */ + +void inference_mlp_init(void) {} + +int inference_mlp_predict(const
float * /*features*/, int /*n_feat*/, float *conf_out) { + if (conf_out) *conf_out = 0.0f; + return 0; +} + +#endif /* MODEL_USE_MLP */ diff --git a/EMG_Arm/src/core/inference_mlp.h b/EMG_Arm/src/core/inference_mlp.h new file mode 100644 index 0000000..559da6d --- /dev/null +++ b/EMG_Arm/src/core/inference_mlp.h @@ -0,0 +1,40 @@ +/** + * @file inference_mlp.h + * @brief int8 MLP inference via TFLite Micro (Change E). + * + * Enable by: + * 1. Setting MODEL_USE_MLP 1 in model_weights.h. + * 2. Running train_mlp_tflite.py to generate emg_model_data.cc. + * 3. Adding TFLite Micro to platformio.ini lib_deps. + * + * When MODEL_USE_MLP 0, both functions are empty stubs that compile without + * TFLite Micro headers. + */ + +#pragma once +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Initialise TFLite Micro interpreter and allocate tensors. + * Must be called before inference_mlp_predict(). + * No-op when MODEL_USE_MLP 0. + */ +void inference_mlp_init(void); + +/** + * @brief Run one forward pass through the int8 MLP. + * + * @param features Float32 feature vector (same order as Python extractor). + * @param n_feat Number of features (must match model input size). + * @param conf_out Output: winning class softmax probability [0,1]. + * @return Winning class index. Returns 0 when MODEL_USE_MLP 0. + */ +int inference_mlp_predict(const float *features, int n_feat, float *conf_out); + +#ifdef __cplusplus +} +#endif diff --git a/EMG_Arm/src/core/model_weights.h b/EMG_Arm/src/core/model_weights.h index 94a35e2..a632a77 100644 --- a/EMG_Arm/src/core/model_weights.h +++ b/EMG_Arm/src/core/model_weights.h @@ -1,7 +1,7 @@ /** * @file model_weights.h * @brief Trained LDA model weights exported from Python. 
- * @date 2026-01-27 21:35:17 + * @date 2026-03-08 20:27:52 */ #ifndef MODEL_WEIGHTS_H @@ -11,7 +11,14 @@ /* Metadata */ #define MODEL_NUM_CLASSES 5 -#define MODEL_NUM_FEATURES 16 +#define MODEL_NUM_FEATURES 69 +#define MODEL_NORMALIZE_FEATURES 1 + +/* Compile-time feature flags */ +#define MODEL_EXPAND_FEATURES 1 +#define MODEL_USE_REINHARD 1 +#define MODEL_USE_MLP 1 +#define MODEL_USE_ENSEMBLE 1 /* Class Names */ static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = { @@ -28,35 +35,70 @@ static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = { /* LDA Intercepts/Biases */ static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = { - -14.097581f, -2.018629f, -4.478267f, 1.460458f, -5.562349f + -2.361446f, -2.300285f, -3.189113f, -1.744772f, -2.137618f }; /* LDA Coefficients (Weights) */ static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = { /* fist */ { - 0.070110f, -0.002554f, 0.043924f, 0.020555f, -0.660305f, 0.010691f, -0.074429f, -0.037253f, - 0.057908f, -0.002655f, 0.042119f, -0.052956f, 0.063822f, 0.006184f, -0.025462f, 0.040815f, + -1.005012f, 1.102594f, 0.220576f, -0.171785f, 0.160110f, 1.460799f, 0.160097f, 0.023789f, + 0.316695f, 0.551076f, 0.193143f, -0.030601f, -1.068705f, -0.414611f, -0.040128f, 1.215391f, + 0.945437f, 0.545068f, 0.736472f, 0.328106f, 1.212364f, -0.769889f, 0.428595f, -0.033948f, + 0.181605f, -3.236872f, 0.181581f, -0.061825f, 1.166229f, 0.338513f, 0.207578f, 0.041526f, + 0.793306f, 0.256971f, -0.062490f, 2.940555f, 0.251149f, 0.096629f, 0.124341f, -0.125052f, + 0.362799f, 0.430005f, 0.106206f, -0.201260f, 0.529469f, -1.578447f, 0.529456f, -0.040727f, + 1.636684f, 0.129321f, 0.011299f, 0.086495f, 1.510614f, -0.578640f, -0.025043f, 0.275174f, + 0.300145f, 0.366742f, -0.104204f, -0.125212f, 0.819188f, 2.292691f, -0.957544f, -0.321092f, + -0.173973f, 0.306428f, -1.015226f, 0.770284f, 0.797008f }, /* hook_em */ { - -0.002511f, 0.001034f, 0.027889f, 0.026006f, 0.183681f, -0.000773f, 0.016791f, -0.027926f, - 
-0.023321f, 0.000770f, 0.059023f, -0.056021f, 0.237063f, -0.007423f, 0.082101f, -0.021472f, + 1.694859f, -0.653387f, 0.521604f, 0.251859f, -0.706711f, 1.182055f, -0.706678f, 0.008488f, + 2.450049f, 0.068430f, -0.002051f, 0.214216f, 2.641921f, -0.677963f, -0.020216f, -0.084562f, + 0.314933f, 0.166270f, -0.174888f, 0.023982f, -5.326996f, -0.143884f, -0.165025f, 0.014950f, + 0.167851f, 3.526903f, 0.167886f, 0.139922f, 0.702011f, 0.197141f, 0.087553f, 0.093573f, + 1.414652f, 0.237107f, -0.044205f, -3.707572f, 0.146742f, 0.025417f, -0.085462f, -0.227836f, + 1.745303f, 0.408549f, -0.224446f, 0.081996f, 0.432159f, 0.151780f, 0.432186f, 0.009932f, + -1.447145f, -0.136321f, -0.071784f, -0.118749f, -1.872055f, 0.490270f, 0.017250f, -0.667072f, + -0.381349f, -0.429662f, -0.084705f, -0.123111f, -0.213680f, -3.070284f, 0.267574f, 0.343828f, + 1.024997f, -0.347210f, 1.222141f, 4.897431f, -1.097906f }, /* open */ { - -0.006170f, 0.000208f, -0.041151f, 0.013271f, 0.054508f, -0.002356f, 0.000170f, 0.012941f, - -0.106180f, 0.003538f, -0.013656f, -0.017712f, 0.131131f, -0.002623f, -0.007022f, 0.024497f, + 2.756912f, 0.029446f, -0.408740f, -0.020121f, 0.375563f, -3.479846f, 0.375601f, -0.022014f, + -1.357556f, -0.445301f, -0.119886f, -0.020492f, -0.494526f, 0.247305f, 0.028053f, -0.106328f, + -0.585811f, -0.405593f, -0.486449f, -0.319108f, -3.159308f, 0.871359f, 0.611153f, -0.191602f, + 0.059082f, 4.326131f, 0.059172f, -0.006684f, -0.502170f, -0.110067f, -0.057217f, -0.224138f, + -1.214966f, -1.056423f, -0.019260f, -0.878641f, 0.167700f, 0.576919f, -0.143247f, -0.914332f, + -2.970328f, -0.403648f, -0.014340f, 0.048198f, 0.332494f, 0.479158f, 0.332573f, -0.002956f, + -0.351570f, 0.181890f, 0.109983f, -0.013305f, 0.133854f, -0.180451f, -0.046355f, -0.355403f, + -0.123445f, 0.146366f, 0.317895f, 0.342519f, -0.399778f, -6.284204f, 0.422619f, -0.109523f, + -0.996114f, 0.059564f, 0.043680f, -3.046475f, -0.178861f }, /* rest */ { - -0.011094f, 0.000160f, -0.012547f, -0.011058f, 0.130577f, 
-0.001942f, 0.020823f, -0.001961f, - 0.018021f, -0.000404f, -0.065598f, 0.039676f, 0.018679f, -0.001522f, 0.023302f, -0.008474f, + -1.609762f, -0.516942f, -0.285630f, -0.108276f, 0.196206f, 0.488562f, 0.196196f, 0.026133f, + -1.189232f, -0.034391f, -0.062033f, -0.130823f, -1.033252f, 0.697769f, 0.032136f, -0.605629f, + -0.007999f, -0.097476f, 0.094768f, 0.333748f, 1.317443f, -0.880336f, -0.213594f, -0.073489f, + -0.207093f, -4.451737f, -0.207141f, 0.000240f, -1.645825f, -0.245072f, -0.026848f, 0.029140f, + -0.364751f, 0.395671f, 0.066683f, 3.611272f, -0.209066f, -0.214834f, 0.027566f, 0.514329f, + 1.066669f, -0.684801f, -0.070219f, -0.046770f, -0.494221f, 0.722863f, -0.494243f, 0.008539f, + -0.124061f, -0.075996f, -0.084995f, -0.012796f, 0.291955f, 0.121964f, 0.042458f, -0.627127f, + 0.001798f, -0.199790f, -0.153351f, 0.027358f, -0.117621f, 1.819176f, 0.092071f, -0.017248f, + 0.251629f, 0.021265f, -0.106236f, 0.495258f, 0.069582f }, /* thumbs_up */ { - -0.016738f, 0.000488f, 0.024199f, -0.024643f, -0.044912f, 0.000153f, -0.011080f, 0.043487f, - 0.051828f, -0.001670f, 0.109633f, 0.004154f, -0.460694f, 0.008616f, -0.104097f, -0.020886f, + -0.792044f, 0.374311f, 0.170137f, 0.127379f, -0.184879f, 0.153073f, -0.184920f, -0.053070f, + 0.682094f, -0.099986f, 0.036795f, 0.061637f, 0.725941f, -0.348358f, -0.022970f, -0.004588f, + -0.632688f, -0.124638f, -0.221206f, -0.580598f, 5.030162f, 1.484454f, -0.536655f, 0.339228f, + -0.058585f, 2.787274f, -0.058607f, -0.068406f, 1.425498f, -0.005883f, -0.188273f, 0.049091f, + -0.309501f, -0.063155f, 0.013505f, -4.453858f, -0.215617f, -0.354558f, 0.060673f, 0.425839f, + -0.794436f, 0.734764f, 0.245384f, 0.149158f, -0.463152f, -0.279041f, -0.463210f, 0.019611f, + 0.350645f, -0.055642f, 0.087931f, 0.064769f, -0.304954f, 0.079455f, -0.015292f, 1.791572f, + 0.196916f, 0.237732f, 0.116177f, -0.153250f, 0.000654f, 4.127202f, 0.103829f, 0.125913f, + -0.223239f, -0.063374f, -0.047960f, -3.238916f, 0.344427f }, }; diff --git 
a/EMG_Arm/src/core/model_weights_ensemble.h b/EMG_Arm/src/core/model_weights_ensemble.h new file mode 100644 index 0000000..640d6b5 --- /dev/null +++ b/EMG_Arm/src/core/model_weights_ensemble.h @@ -0,0 +1,65 @@ +// Auto-generated by train_ensemble.py do not edit +#pragma once + +// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from +// model_weights.h to avoid redefinition conflicts. +#include "model_weights.h" + +#define ENSEMBLE_PER_CH_FEATURES 20 + +#define TD_FEAT_OFFSET 0 +#define TD_NUM_FEATURES 36 +#define FD_FEAT_OFFSET 12 +#define FD_NUM_FEATURES 24 +#define CC_FEAT_OFFSET 60 +#define CC_NUM_FEATURES 9 +#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES) + +// Feature index arrays for gather operations (TD and FD are non-contiguous) +// TD indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51] +// FD indices: [12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39, 52, 53, 54, 55, 56, 57, 58, 59] +// CC indices: [60, 61, 62, 63, 64, 65, 66, 67, 68] + +const float LDA_TD_WEIGHTS[5][36] = { + {-4.98005951f, 3.15698227f, -0.01656110f, 0.16589549f, 1.89759250f, 4.02171634f, 1.89759264f, -0.02318365f, 0.68313384f, 1.08146976f, 0.24790972f, 0.02111336f, 1.40219525f, -0.34718600f, 0.03292437f, 0.03224236f, 0.07318621f, -1.47505917f, 0.07318595f, -0.15298836f, 0.68723991f, 0.31478366f, 0.26694627f, 0.03383652f, 1.16109002f, -0.17694826f, -0.02001825f, 0.04811623f, 0.11145582f, 0.98011591f, 0.11145584f, 0.02078491f, 0.16441538f, -0.37776186f, -0.14062534f, -0.20666287f}, // fist + {7.06735173f, -0.07584252f, 0.28878350f, 0.25548171f, -2.72040928f, -1.14291179f, -2.72040944f, -0.05457462f, 0.67944659f, -1.72471154f, -0.56168071f, 0.08847797f, -0.36636016f, -0.92733589f, 0.07186714f, -0.15407635f, 0.09285758f, -0.99197678f, 0.09285799f, 0.24903098f, -0.58741170f, -1.69916941f, -0.79791300f, -0.51257426f, -4.23844985f, -1.30822564f, 0.08744698f, -0.10944734f, 
1.12572192f, 0.67459317f, 1.12572190f, 0.12677750f, -0.06012924f, 0.74169191f, 0.20452265f, 0.02420439f}, // hook_em + {-5.68183942f, -2.70544312f, -0.02268383f, -0.07857558f, 0.94259344f, -1.41885234f, 0.94259353f, 0.12889098f, -0.30338581f, 0.37920265f, 0.06556953f, 0.00592915f, -6.03034018f, 1.95694619f, 0.22083802f, 0.33970496f, 0.85733980f, 4.45235717f, 0.85733951f, 0.16720219f, 0.59151687f, 1.12441071f, 0.07520788f, -0.44451340f, -0.73662069f, 1.71355275f, -0.19822542f, 0.32955834f, 0.05661870f, -1.29613804f, 0.05661887f, -0.21196686f, -0.61895423f, -1.71535920f, -0.73493998f, -0.21009359f}, // open + {0.20260846f, -0.23909181f, -0.28033711f, -0.22450103f, 0.13810806f, -0.88146743f, 0.13810809f, 0.01613534f, -0.34304575f, 1.18327519f, 0.36712706f, -0.09916758f, 1.33818062f, -0.49450303f, 0.03771161f, -0.21199101f, -0.57758751f, -0.06415963f, -0.57758750f, -0.11776329f, -0.49294313f, 0.47593171f, 0.42732005f, 0.38381135f, 0.93751846f, -0.40085712f, -0.01569340f, -0.21022401f, -0.70330953f, 0.29469121f, -0.70330957f, -0.03953274f, -0.19099800f, 1.11867928f, 0.51910780f, 0.26138187f}, // rest + {3.56990173f, 0.11258470f, 0.22746547f, 0.04228335f, -0.43704307f, 0.04449194f, -0.43704318f, -0.08323209f, -0.45652127f, -1.76552235f, -0.38107651f, 0.05260073f, 2.92141069f, 0.06553870f, -0.39303365f, 0.12369768f, -0.07810754f, -2.03448136f, -0.07810740f, -0.06529502f, 0.10545265f, -0.60931490f, -0.28035188f, 0.28069772f, 2.16483999f, 0.36093444f, 0.16473959f, 0.07185201f, -0.08960941f, -0.79422744f, -0.08960951f, 0.13979645f, 0.85101150f, -0.45589633f, -0.17267249f, -0.03933928f}, // thumbs_up +}; +const float LDA_TD_INTERCEPTS[5] = { + -5.59817408f, -4.91484804f, -9.33780358f, -3.85428822f, -3.12479496f +}; + +const float LDA_FD_WEIGHTS[5][24] = { + {0.19477058f, 0.33798486f, 0.24311717f, 3.58490717f, 0.37449334f, 0.01733587f, 0.99771453f, 0.35846897f, -0.07609940f, -0.08551280f, -0.04212976f, -1.64981087f, -0.28267724f, -0.45724067f, 0.09742739f, -0.45412877f, 
-0.67354231f, 0.07763610f, -0.03042757f, 1.23057256f, 0.47423480f, 0.96870723f, 0.09008234f, 0.61253227f}, // fist + {-1.33051949f, -1.06668043f, -0.13054796f, 1.49029766f, 0.34559104f, 2.58938944f, 2.10036791f, 0.00627139f, 0.67799858f, -0.05916871f, -0.02308780f, -0.51389981f, 0.49698104f, 0.52080687f, -0.53642075f, -0.44702479f, 0.45831951f, -0.05527688f, -0.12666255f, -2.29960756f, 0.15527345f, -0.60190848f, -0.50386274f, -0.20863809f}, // hook_em + {0.86684408f, -0.18732078f, 0.04828207f, -4.18979500f, -0.25709076f, -1.34726223f, -1.93439169f, -0.36975670f, -0.57908022f, 0.40713764f, 0.14203746f, 3.81070011f, 0.15930714f, 0.81375493f, 0.17586089f, 1.21709829f, 0.30729211f, -0.06304829f, 0.23529519f, 2.10099132f, -1.15363931f, -0.25856679f, 0.65430329f, -0.77828374f}, // open + {0.10167142f, 0.92873579f, -0.03451580f, -0.35104059f, 0.33898659f, -0.99047210f, -1.06006192f, 0.02907091f, -0.22314490f, 0.33956274f, -0.07111302f, 0.31329279f, -0.56910921f, -0.79014171f, -0.38388114f, 0.01131879f, 0.87387134f, -0.12380692f, 0.05214671f, -0.42582976f, -0.17031139f, -1.06661596f, -0.91600447f, -0.26435070f}, // rest + {0.04198304f, -0.65837713f, -0.10663077f, -0.12701858f, -1.01290184f, 0.50100717f, 0.72188030f, -0.03150884f, 0.38356910f, -0.84413736f, 0.03745147f, -2.29871124f, 0.58565208f, 0.43331566f, 0.88777658f, -0.38229432f, -1.55494079f, 0.24865587f, -0.17540676f, -0.43032659f, 0.84763042f, 1.67359692f, 1.26211059f, 0.83625332f}, // thumbs_up +}; +const float LDA_FD_INTERCEPTS[5] = { + -5.22055861f, -3.76092770f, -8.23556500f, -3.59434123f, -2.96625341f +}; + +const float LDA_CC_WEIGHTS[5][9] = { + {-0.94928369f, 2.44942094f, 0.26441444f, -0.80355120f, -1.19821936f, 1.39871694f, -2.06386346f, 0.59615281f, 1.57632161f}, // fist + {-2.17638786f, 1.40504639f, 1.97661862f, 0.99132032f, 1.46231208f, -0.77606652f, 2.43631056f, 1.58734751f, -1.75342004f}, // hook_em + {-0.15938422f, -2.49617253f, -0.43491962f, 0.32764670f, -3.51157144f, -0.74698502f, 2.59991544f, 
-0.64230610f, -3.23294039f}, // open + {2.78062805f, -1.24792206f, -2.03404274f, -0.13924844f, 2.03596007f, -0.25590903f, -0.40921478f, -0.84411543f, 0.27031906f}, // rest + {-1.42317081f, 0.84654172f, 1.66121800f, -0.27035074f, -0.03025088f, 0.56047099f, -2.30886825f, -0.06801108f, 3.01175668f}, // thumbs_up +}; +const float LDA_CC_INTERCEPTS[5] = { + -2.77069779f, -4.07631951f, -6.23624226f, -1.94638886f, -2.30392172f +}; + +const float META_LDA_WEIGHTS[5][15] = { + {13.83327682f, -4.99575141f, -8.19270319f, -2.93944549f, 0.75007499f, 3.36110029f, 0.53962472f, -4.82143659f, -1.61056244f, 1.48044754f, 1.56410113f, -0.98537108f, 0.20587945f, 0.39305391f, -1.42106273f}, // fist + {-4.64514194f, 13.09553405f, -4.41874813f, -1.25841220f, -2.98757893f, 0.16270053f, 3.15820296f, -2.24183188f, -1.32266753f, 0.20766953f, -0.65827396f, 1.59311387f, -1.22803182f, 0.04336258f, -0.24531550f}, // hook_em + {-5.79187671f, -2.27375570f, 23.21490262f, -0.09093684f, -4.81706827f, -1.92030254f, -1.47383708f, 9.25922135f, 0.57114390f, -2.76993062f, -1.31290246f, -1.07932832f, 5.46202897f, -0.72440402f, -0.19414883f}, // open + {-2.87608878f, -1.27035778f, -3.11538939f, 4.40275586f, -1.44999243f, -0.49668812f, -1.24040434f, 0.65617176f, 1.79490014f, -1.52075301f, 0.08593208f, 0.38743884f, -1.35293868f, 0.30551755f, -0.06438460f}, // rest + {1.52549481f, -3.33085871f, -6.18851152f, -3.12231842f, 9.54599696f, -0.69931080f, -0.02514346f, -3.63196650f, -0.69837189f, 3.71692816f, 0.29062506f, -0.11110049f, -2.35881642f, -0.20019868f, 1.96143145f}, // thumbs_up +}; +const float META_LDA_INTERCEPTS[5] = { + -8.46795016f, -8.64988671f, -21.16256224f, -3.22663037f, -6.35767829f +}; diff --git a/EMG_Arm/src/drivers/emg_sensor.c b/EMG_Arm/src/drivers/emg_sensor.c index fb06fc9..d7ba970 100644 --- a/EMG_Arm/src/drivers/emg_sensor.c +++ b/EMG_Arm/src/drivers/emg_sensor.c @@ -1,109 +1,180 @@ /** * @file emg_sensor.c - * @brief EMG sensor driver implementation. 
+ * @brief EMG sensor driver — DMA-backed continuous ADC acquisition. * - * Provides EMG readings - fake data for now, real ADC when sensors arrive. + * Change A: adc_continuous (DMA) replaces adc_oneshot polling. + * A background FreeRTOS task on Core 0 reads DMA frames, assembles + * complete 4-channel sample sets, and pushes them to a FreeRTOS queue. + * emg_sensor_read() blocks on that queue; at 1 kHz per channel this + * returns within ~1 ms, providing exact timing without vTaskDelay(1). */ #include "emg_sensor.h" #include "esp_timer.h" -#include -#include #include "freertos/FreeRTOS.h" #include "freertos/task.h" -#include "esp_adc/adc_oneshot.h" +#include "freertos/queue.h" +#include "esp_adc/adc_continuous.h" #include "esp_adc/adc_cali.h" #include "esp_adc/adc_cali_scheme.h" #include "esp_err.h" +#include +#include +#include +#include -adc_oneshot_unit_handle_t adc1_handle; -adc_cali_handle_t cali_handle = NULL; +// --- ADC DMA constants --- +// Total sample rate: 4 channels × 1 kHz = 4 kHz +#define ADC_TOTAL_SAMPLE_RATE_HZ (EMG_NUM_CHANNELS * EMG_SAMPLE_RATE_HZ) +// DMA frame: 64 conversions × 4 bytes = 256 bytes, arrives ~every 16 ms +#define ADC_CONV_FRAME_SIZE (64u * sizeof(adc_digi_output_data_t)) +// Internal DMA pool: 2× conv_frame_size +#define ADC_POOL_SIZE (2u * ADC_CONV_FRAME_SIZE) +// Sample queue: buffers up to ~32 ms of assembled 4-channel sets +#define SAMPLE_QUEUE_DEPTH 32 -const uint8_t emg_channels[EMG_NUM_CHANNELS] = { - ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0 - ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1 - ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2 - ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3 +// --- Static handles --- +static adc_continuous_handle_t s_adc_handle = NULL; +static adc_cali_handle_t s_cali_handle = NULL; +static QueueHandle_t s_sample_queue = NULL; + +// Channel mapping (ADC1, GPIO 2/3/9/10) +static const adc_channel_t s_channels[EMG_NUM_CHANNELS] = { + ADC_CHANNEL_1, // GPIO 2 — FCR/Belly (ch0) + ADC_CHANNEL_2, // GPIO 3 — 
Extensors (ch1) + ADC_CHANNEL_8, // GPIO 9 — FCU/Outer Flexors (ch2) + ADC_CHANNEL_9, // GPIO 10 — Bicep (ch3) }; +/******************************************************************************* + * ADC Sampling Task (Core 0) + ******************************************************************************/ + +/** + * @brief DMA sampling task. + * + * Reads adc_continuous DMA frames, assembles complete 4-channel sample sets + * (one value per channel), applies curve-fitting calibration, and pushes + * emg_sample_t structs to s_sample_queue. Runs continuously on Core 0. + */ +static void adc_sampling_task(void *arg) { + uint8_t *buf = (uint8_t *)malloc(ADC_CONV_FRAME_SIZE); + if (!buf) { + printf("[EMG] ERROR: DMA read buffer alloc failed\n"); + vTaskDelete(NULL); + return; + } + + // Per-channel raw accumulator; holds the latest sample for each channel + int raw_set[EMG_NUM_CHANNELS]; + bool got[EMG_NUM_CHANNELS]; + memset(raw_set, 0, sizeof(raw_set)); + memset(got, 0, sizeof(got)); + + while (1) { + uint32_t out_len = 0; + esp_err_t err = adc_continuous_read( + s_adc_handle, buf, (uint32_t)ADC_CONV_FRAME_SIZE, + &out_len, pdMS_TO_TICKS(100) + ); + if (err != ESP_OK || out_len == 0) { + continue; + } + + uint32_t n = out_len / sizeof(adc_digi_output_data_t); + adc_digi_output_data_t *p = (adc_digi_output_data_t *)buf; + + for (uint32_t i = 0; i < n; i++) { + int ch = (int)p[i].type2.channel; + int raw = (int)p[i].type2.data; + + // Map ADC channel index to sensor channel + for (int c = 0; c < EMG_NUM_CHANNELS; c++) { + if ((int)s_channels[c] == ch) { + raw_set[c] = raw; + got[c] = true; + break; + } + } + + // Emit a complete sample set once all channels have been updated + bool all = true; + for (int c = 0; c < EMG_NUM_CHANNELS; c++) { + if (!got[c]) { all = false; break; } + } + if (!all) continue; + + emg_sample_t s; + s.timestamp_ms = emg_sensor_get_timestamp_ms(); + for (int c = 0; c < EMG_NUM_CHANNELS; c++) { + int mv = 0; + 
adc_cali_raw_to_voltage(s_cali_handle, raw_set[c], &mv); + s.channels[c] = (uint16_t)mv; + } + + // Non-blocking send: drop if queue is full (prefer fresh data) + xQueueSend(s_sample_queue, &s, 0); + + // Reset accumulator for the next complete set + memset(got, 0, sizeof(got)); + } + } +} + /******************************************************************************* * Public Functions ******************************************************************************/ -void emg_sensor_init(void) -{ -#if FEATURE_FAKE_EMG - /* Seed random number generator for fake data */ - srand((unsigned int)esp_timer_get_time()); -#else - // 1. --- ADC Unit Setup --- - adc_oneshot_unit_init_cfg_t init_config1 = { - .unit_id = ADC_UNIT_1, - .ulp_mode = ADC_ULP_MODE_DISABLE, - }; - ESP_ERROR_CHECK(adc_oneshot_new_unit(&init_config1, &adc1_handle)); - - // 2. --- ADC Channel Setup (GPIO 1?) --- - // Ensure the channel matches your GPIO in pinmap. For ADC1, GPIO1 is usually not CH0. - // Check your datasheet! (e.g., on S3, GPIO 1 is ADC1_CH0) - adc_oneshot_chan_cfg_t config = { - .bitwidth = ADC_BITWIDTH_DEFAULT, // 12-bit for S3 - .atten = ADC_ATTEN_DB_12, // Allows up to ~3.1V - }; - for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) - ESP_ERROR_CHECK(adc_oneshot_config_channel(adc1_handle, emg_channels[i], &config)); - - // 3. --- Calibration Setup (CORRECTED for S3) --- - // ESP32-S3 uses Curve Fitting, not Line Fitting - adc_cali_curve_fitting_config_t cali_config = { - .unit_id = ADC_UNIT_1, - .atten = ADC_ATTEN_DB_12, +void emg_sensor_init(void) { + // 1. 
Curve-fitting calibration (same scheme as before) + adc_cali_curve_fitting_config_t cali_cfg = { + .unit_id = ADC_UNIT_1, + .atten = ADC_ATTEN_DB_12, .bitwidth = ADC_BITWIDTH_DEFAULT, }; - ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_config, &cali_handle)); - - // while (1) { - // int raw_val, voltage_mv; - - // // Read Raw - // ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, ADC_CHANNEL_1, &raw_val)); - - // // Convert to mV using calibration - // ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv)); + ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_cfg, &s_cali_handle)); - // printf("Raw: %d | Voltage: %d mV\n", raw_val, voltage_mv); - // vTaskDelay(pdMS_TO_TICKS(500)); - // } -#endif -} + // 2. Create continuous ADC handle + adc_continuous_handle_cfg_t adc_cfg = { + .max_store_buf_size = ADC_POOL_SIZE, + .conv_frame_size = ADC_CONV_FRAME_SIZE, + }; + ESP_ERROR_CHECK(adc_continuous_new_handle(&adc_cfg, &s_adc_handle)); -void emg_sensor_read(emg_sample_t *sample) -{ - sample->timestamp_ms = emg_sensor_get_timestamp_ms(); - -#if FEATURE_FAKE_EMG - /* - * Generate fake EMG data: - * - Base value around 1650 (middle of 3.3V millivolt range) - * - Random noise of +/- 50 - * - Mimics real EMG baseline noise - */ + // 3. 
Configure scan pattern (4 channels, 4 kHz total) + adc_digi_pattern_config_t patterns[EMG_NUM_CHANNELS]; for (int i = 0; i < EMG_NUM_CHANNELS; i++) { - int noise = (rand() % 101) - 50; /* -50 to +50 */ - sample->channels[i] = (uint16_t)(1650 + noise); - } -#else - int raw_val, voltage_mv; - for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) { - ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, emg_channels[i], &raw_val)); - ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv)); - sample->channels[i] = (uint16_t) voltage_mv; + patterns[i].atten = ADC_ATTEN_DB_12; + patterns[i].channel = (uint8_t)s_channels[i]; + patterns[i].unit = ADC_UNIT_1; + patterns[i].bit_width = ADC_BITWIDTH_12; } + adc_continuous_config_t cont_cfg = { + .pattern_num = EMG_NUM_CHANNELS, + .adc_pattern = patterns, + .sample_freq_hz = ADC_TOTAL_SAMPLE_RATE_HZ, + .conv_mode = ADC_CONV_SINGLE_UNIT_1, + .format = ADC_DIGI_OUTPUT_FORMAT_TYPE2, + }; + ESP_ERROR_CHECK(adc_continuous_config(s_adc_handle, &cont_cfg)); -#endif + // 4. Start DMA acquisition + ESP_ERROR_CHECK(adc_continuous_start(s_adc_handle)); + + // 5. Create sample queue (assembled complete sets land here) + s_sample_queue = xQueueCreate(SAMPLE_QUEUE_DEPTH, sizeof(emg_sample_t)); + assert(s_sample_queue != NULL); + + // 6. Launch sampling task pinned to Core 0 + xTaskCreatePinnedToCore(adc_sampling_task, "adc_sample", 4096, NULL, 6, NULL, 0); } -uint32_t emg_sensor_get_timestamp_ms(void) -{ +void emg_sensor_read(emg_sample_t *sample) { + // Block until a complete 4-channel sample set arrives from the DMA task. + // At 1 kHz per channel this typically returns within ~1 ms. 
+ xQueueReceive(s_sample_queue, sample, portMAX_DELAY); +} + +uint32_t emg_sensor_get_timestamp_ms(void) { return (uint32_t)(esp_timer_get_time() / 1000); } diff --git a/EMG_Arm/src/drivers/emg_sensor.h b/EMG_Arm/src/drivers/emg_sensor.h index cec5752..5822945 100644 --- a/EMG_Arm/src/drivers/emg_sensor.h +++ b/EMG_Arm/src/drivers/emg_sensor.h @@ -2,9 +2,8 @@ * @file emg_sensor.h * @brief EMG sensor driver for reading muscle signals. * - * This module provides EMG data acquisition. Currently generates fake - * data for testing (FEATURE_FAKE_EMG=1). When sensors arrive, the - * implementation switches to real ADC reads without changing the interface. + * This module provides EMG data acquisition from ADC channels connected + * to MyoWare sensors. Outputs calibrated millivolt values (0-3300 mV). * * @note This is Layer 2 (Driver). */ @@ -34,8 +33,7 @@ typedef struct { /** * @brief Initialize the EMG sensor system. * - * If FEATURE_FAKE_EMG is enabled, just seeds the random generator. - * Otherwise, configures ADC channels for real sensor reading. + * Configures ADC channels and calibration for real sensor reading. 
*/ void emg_sensor_init(void); diff --git a/EMG_Arm/src/drivers/hand.c b/EMG_Arm/src/drivers/hand.c index 9731428..582799f 100644 --- a/EMG_Arm/src/drivers/hand.c +++ b/EMG_Arm/src/drivers/hand.c @@ -8,6 +8,9 @@ #include "hand.h" #include "hal/servo_hal.h" +float maxAngles[] = {155, 155, 180, 165, 150}; +float minAngles[] = {65, 45, 45, 30, 25}; + /******************************************************************************* * Public Functions ******************************************************************************/ diff --git a/EMG_Arm/src/drivers/hand.h b/EMG_Arm/src/drivers/hand.h index 7b2f171..59ae41f 100644 --- a/EMG_Arm/src/drivers/hand.h +++ b/EMG_Arm/src/drivers/hand.h @@ -13,6 +13,8 @@ #include "config/config.h" +extern float maxAngles[]; +extern float minAngles[]; /******************************************************************************* * Public Functions ******************************************************************************/ diff --git a/EMG_Arm/src/hal/servo_hal.c b/EMG_Arm/src/hal/servo_hal.c index 729554e..41ae17d 100644 --- a/EMG_Arm/src/hal/servo_hal.c +++ b/EMG_Arm/src/hal/servo_hal.c @@ -55,7 +55,7 @@ void servo_hal_init(void) .timer_sel = SERVO_PWM_TIMER, .intr_type = LEDC_INTR_DISABLE, .gpio_num = servo_pins[i], - .duty = SERVO_DUTY_MIN, /* Start extended (open) */ + .duty = servo_hal_degrees_to_duty(90), /* Start extended (open) */ .hpoint = 0 }; ESP_ERROR_CHECK(ledc_channel_config(&channel_config)); diff --git a/collected_data/new_system_000_20260308_180810.hdf5 b/collected_data/new_system_000_20260308_180810.hdf5 new file mode 100644 index 0000000..bde701a Binary files /dev/null and b/collected_data/new_system_000_20260308_180810.hdf5 differ diff --git a/collected_data/new_system_001_20260308_185810.hdf5 b/collected_data/new_system_001_20260308_185810.hdf5 new file mode 100644 index 0000000..fd33c8c Binary files /dev/null and b/collected_data/new_system_001_20260308_185810.hdf5 differ diff --git 
a/collected_data/new_system_002_20260308_191206.hdf5 b/collected_data/new_system_002_20260308_191206.hdf5 new file mode 100644 index 0000000..0bf36a7 Binary files /dev/null and b/collected_data/new_system_002_20260308_191206.hdf5 differ diff --git a/collected_data/updated001_20260214_195555.hdf5 b/collected_data/updated001_20260214_195555.hdf5 new file mode 100644 index 0000000..a9b53f6 Binary files /dev/null and b/collected_data/updated001_20260214_195555.hdf5 differ diff --git a/collected_data/updated002_20260214_195732.hdf5 b/collected_data/updated002_20260214_195732.hdf5 new file mode 100644 index 0000000..800c90e Binary files /dev/null and b/collected_data/updated002_20260214_195732.hdf5 differ diff --git a/collected_data/updated003_20260214_200039.hdf5 b/collected_data/updated003_20260214_200039.hdf5 new file mode 100644 index 0000000..0284a12 Binary files /dev/null and b/collected_data/updated003_20260214_200039.hdf5 differ diff --git a/collected_data/updated004_20260214_200216.hdf5 b/collected_data/updated004_20260214_200216.hdf5 new file mode 100644 index 0000000..3698ed6 Binary files /dev/null and b/collected_data/updated004_20260214_200216.hdf5 differ diff --git a/collected_data/updated005_20260214_202724.hdf5 b/collected_data/updated005_20260214_202724.hdf5 new file mode 100644 index 0000000..7659696 Binary files /dev/null and b/collected_data/updated005_20260214_202724.hdf5 differ diff --git a/collected_data/updated006_20260214_202910.hdf5 b/collected_data/updated006_20260214_202910.hdf5 new file mode 100644 index 0000000..08d4ed6 Binary files /dev/null and b/collected_data/updated006_20260214_202910.hdf5 differ diff --git a/collected_data/updated007_20260214_203049.hdf5 b/collected_data/updated007_20260214_203049.hdf5 new file mode 100644 index 0000000..8045549 Binary files /dev/null and b/collected_data/updated007_20260214_203049.hdf5 differ diff --git a/collected_data/updated008_20260214_203228.hdf5 b/collected_data/updated008_20260214_203228.hdf5 
new file mode 100644 index 0000000..88bb4cb Binary files /dev/null and b/collected_data/updated008_20260214_203228.hdf5 differ diff --git a/collected_data/updated009_20260214_203612.hdf5 b/collected_data/updated009_20260214_203612.hdf5 new file mode 100644 index 0000000..b19f5a5 Binary files /dev/null and b/collected_data/updated009_20260214_203612.hdf5 differ diff --git a/collected_data/updated010_20260214_204204.hdf5 b/collected_data/updated010_20260214_204204.hdf5 new file mode 100644 index 0000000..5a85d10 Binary files /dev/null and b/collected_data/updated010_20260214_204204.hdf5 differ diff --git a/collected_data/updated011_20260214_212146.hdf5 b/collected_data/updated011_20260214_212146.hdf5 new file mode 100644 index 0000000..fa99c1e Binary files /dev/null and b/collected_data/updated011_20260214_212146.hdf5 differ diff --git a/collected_data/updated012_20260214_212732.hdf5 b/collected_data/updated012_20260214_212732.hdf5 new file mode 100644 index 0000000..33e8d00 Binary files /dev/null and b/collected_data/updated012_20260214_212732.hdf5 differ diff --git a/collected_data/updated013_20260214_212957.hdf5 b/collected_data/updated013_20260214_212957.hdf5 new file mode 100644 index 0000000..72e7907 Binary files /dev/null and b/collected_data/updated013_20260214_212957.hdf5 differ diff --git a/collected_data/updated014_20260214_213133.hdf5 b/collected_data/updated014_20260214_213133.hdf5 new file mode 100644 index 0000000..db3e366 Binary files /dev/null and b/collected_data/updated014_20260214_213133.hdf5 differ diff --git a/collected_data/updated015_20260214_213536.hdf5 b/collected_data/updated015_20260214_213536.hdf5 new file mode 100644 index 0000000..397be8c Binary files /dev/null and b/collected_data/updated015_20260214_213536.hdf5 differ diff --git a/emg_gui.py b/emg_gui.py index 4849611..7c01f34 100644 --- a/emg_gui.py +++ b/emg_gui.py @@ -23,24 +23,26 @@ from tkinter import messagebox import threading import queue import time +import sys +import 
subprocess import numpy as np from pathlib import Path from datetime import datetime import matplotlib.pyplot as plt from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure import Figure -from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis # Import from the existing pipeline from learning_data_collection import ( # Configuration - NUM_CHANNELS, SAMPLING_RATE_HZ, WINDOW_SIZE_MS, WINDOW_OVERLAP, - GESTURE_HOLD_SEC, REST_BETWEEN_SEC, REPS_PER_GESTURE, DATA_DIR, USER_ID, + NUM_CHANNELS, SAMPLING_RATE_HZ, WINDOW_SIZE_MS, HOP_SIZE_MS, HAND_CHANNELS, + GESTURE_HOLD_SEC, REST_BETWEEN_SEC, REPS_PER_GESTURE, DATA_DIR, MODEL_DIR, USER_ID, # Classes EMGSample, EMGWindow, EMGParser, Windower, - GestureAwareEMGStream, SimulatedEMGStream, PromptScheduler, SessionStorage, SessionMetadata, - EMGFeatureExtractor, EMGClassifier, PredictionSmoother, + EMGFeatureExtractor, EMGClassifier, PredictionSmoother, CalibrationTransform, + LABEL_SHIFT_MS, ) # Import real serial stream for ESP32 hardware @@ -63,6 +65,10 @@ GESTURE_COLORS = { "thumbs_up": "#28a745", # Green } +CALIB_PREP_SEC = 3 # Seconds of "get ready" countdown before each gesture +CALIB_DURATION_SEC = 5.0 # Seconds to hold each gesture during calibration + + def get_gesture_color(gesture_name: str) -> str: """Get color for a gesture name.""" for key, color in GESTURE_COLORS.items(): @@ -99,11 +105,16 @@ class EMGApp(ctk.CTk): self.page_container.grid_columnconfigure(0, weight=1) self.page_container.grid_rowconfigure(0, weight=1) + # Calibrated classifier shared between CalibrationPage and PredictionPage. + # Set by CalibrationPage._apply_calibration(), read by PredictionPage. 
+ self.calibrated_classifier = None + # Create pages self.pages = {} self.pages["collection"] = CollectionPage(self.page_container) self.pages["inspect"] = InspectPage(self.page_container) self.pages["training"] = TrainingPage(self.page_container) + self.pages["calibration"] = CalibrationPage(self.page_container) self.pages["prediction"] = PredictionPage(self.page_container) self.pages["visualization"] = VisualizationPage(self.page_container) @@ -170,8 +181,9 @@ class Sidebar(ctk.CTkFrame): ("collection", "1. Collect Data"), ("inspect", "2. Inspect Sessions"), ("training", "3. Train Model"), - ("prediction", "4. Live Prediction"), - ("visualization", "5. Visualize LDA"), + ("calibration", "4. Calibrate"), + ("prediction", "5. Live Prediction"), + ("visualization", "6. Visualize LDA"), ] for page_id, label in nav_items: @@ -227,9 +239,9 @@ class Sidebar(ctk.CTkFrame): sessions = storage.list_sessions() self.session_count_label.configure(text=f"Sessions: {len(sessions)}") - model_path = EMGClassifier.get_default_model_path() - if model_path.exists(): - self.model_status_label.configure(text="Model: Saved", text_color="green") + model_path = EMGClassifier.get_latest_model_path() + if model_path: + self.model_status_label.configure(text=f"Model: {model_path.stem}", text_color="green") else: self.model_status_label.configure(text="Model: Not saved", text_color="gray") @@ -294,7 +306,7 @@ class CollectionPage(BasePage): # Collection state (MUST be initialized BEFORE setup_controls) self.is_collecting = False self.is_connected = False - self.using_real_hardware = False + self.using_real_hardware = True # Always use real ESP32 hardware self.stream = None self.parser = None self.windower = None @@ -334,34 +346,14 @@ class CollectionPage(BasePage): self.user_id_entry.pack(fill="x", pady=(5, 0)) self.user_id_entry.insert(0, USER_ID) - # Data Source selection + # ESP32 Connection (hardware required) source_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent") 
source_frame.pack(fill="x", padx=20, pady=10) - ctk.CTkLabel(source_frame, text="Data Source:", font=ctk.CTkFont(size=14)).pack(anchor="w") + ctk.CTkLabel(source_frame, text="ESP32 Connection:", font=ctk.CTkFont(size=14)).pack(anchor="w") - self.source_var = ctk.StringVar(value="simulated") - - radio_frame = ctk.CTkFrame(source_frame, fg_color="transparent") - radio_frame.pack(fill="x", pady=(5, 0)) - - self.sim_radio = ctk.CTkRadioButton( - radio_frame, text="Simulated", variable=self.source_var, value="simulated", - command=self._on_source_change - ) - self.sim_radio.pack(side="left", padx=(0, 20)) - - self.real_radio = ctk.CTkRadioButton( - radio_frame, text="Real ESP32", variable=self.source_var, value="real", - command=self._on_source_change - ) - self.real_radio.pack(side="left") - - # Port selection (initially hidden, shown when "Real ESP32" selected) - self.port_frame = ctk.CTkFrame(source_frame, fg_color="transparent") - # Don't pack yet - _on_source_change will handle visibility - - port_select_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent") + # Port selection + port_select_frame = ctk.CTkFrame(source_frame, fg_color="transparent") port_select_frame.pack(fill="x", pady=(5, 0)) ctk.CTkLabel(port_select_frame, text="Port:").pack(side="left") @@ -380,14 +372,13 @@ class CollectionPage(BasePage): self.refresh_ports_btn.pack(side="left") # Connection status and button - connect_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent") + connect_frame = ctk.CTkFrame(source_frame, fg_color="transparent") connect_frame.pack(fill="x", pady=(5, 0)) self.connect_button = ctk.CTkButton( connect_frame, text="Connect", width=100, height=28, - command=self._toggle_connection, - state="disabled" # Disabled until "Real ESP32" selected + command=self._toggle_connection ) self.connect_button.pack(side="left", padx=(0, 10)) @@ -397,6 +388,9 @@ class CollectionPage(BasePage): ) self.connection_status.pack(side="left") + # Refresh ports on startup + 
self._refresh_ports() + # Gesture selection gesture_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent") gesture_frame.pack(fill="x", padx=20, pady=10) @@ -540,8 +534,6 @@ class CollectionPage(BasePage): print(f"[DEBUG] Current state:") print(f" - is_collecting: {self.is_collecting}") print(f" - is_connected: {self.is_connected}") - print(f" - using_real_hardware: {self.using_real_hardware}") - print(f" - source_var: {self.source_var.get()}") print(f" - stream exists: {self.stream is not None}") if self.stream: if hasattr(self.stream, 'state'): @@ -592,51 +584,38 @@ class CollectionPage(BasePage): messagebox.showwarning("No Gestures", "Please select at least one gesture.") return - # Determine data source and create appropriate stream - self.using_real_hardware = (self.source_var.get() == "real") - print(f"[DEBUG] using_real_hardware set to: {self.using_real_hardware}") + # Must be connected to ESP32 + print(f"[DEBUG] Checking connection: is_connected={self.is_connected}, stream exists={self.stream is not None}") + if not self.is_connected or not self.stream: + print("[DEBUG] EXIT: Not connected to device") + messagebox.showerror("Not Connected", "Please connect to the ESP32 first.") + return - if self.using_real_hardware: - print("[DEBUG] Real hardware path") - # Must be connected for real hardware - print(f"[DEBUG] Checking connection: is_connected={self.is_connected}, stream exists={self.stream is not None}") - if not self.is_connected or not self.stream: - print("[DEBUG] EXIT: Not connected to device") - messagebox.showerror("Not Connected", "Please connect to the ESP32 first.") - return - - # Send start command to begin streaming - print("[DEBUG] Calling stream.start()...") - try: - self.stream.start() - print("[DEBUG] stream.start() succeeded") - except Exception as e: - print(f"[DEBUG] stream.start() FAILED: {e}") - # Reset stream state if start failed - if self.stream: - try: - print("[DEBUG] Attempting stream.stop() to reset state...") - 
self.stream.stop() # Try to return to CONNECTED state - print("[DEBUG] stream.stop() succeeded") - except Exception as e2: - print(f"[DEBUG] stream.stop() FAILED: {e2}") - messagebox.showerror("Start Error", f"Failed to start streaming:\n{e}") - print("[DEBUG] EXIT: Stream start error") - return - else: - print("[DEBUG] Simulated stream path") - # Simulated stream (gesture-aware for realistic testing) - self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) - print("[DEBUG] Created GestureAwareEMGStream") - self.stream.start() # Start the background data generation thread - print("[DEBUG] Started simulated stream") + # Send start command to begin streaming + print("[DEBUG] Calling stream.start()...") + try: + self.stream.start() + print("[DEBUG] stream.start() succeeded") + except Exception as e: + print(f"[DEBUG] stream.start() FAILED: {e}") + # Reset stream state if start failed + if self.stream: + try: + print("[DEBUG] Attempting stream.stop() to reset state...") + self.stream.stop() # Try to return to CONNECTED state + print("[DEBUG] stream.stop() succeeded") + except Exception as e2: + print(f"[DEBUG] stream.stop() FAILED: {e2}") + messagebox.showerror("Start Error", f"Failed to start streaming:\n{e}") + print("[DEBUG] EXIT: Stream start error") + return # Initialize parser and windower self.parser = EMGParser(num_channels=NUM_CHANNELS) self.windower = Windower( window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, - overlap=WINDOW_OVERLAP + hop_size_ms=HOP_SIZE_MS ) self.scheduler = PromptScheduler( @@ -649,6 +628,7 @@ class CollectionPage(BasePage): # Reset state self.collected_windows = [] self.collected_labels = [] + self.collected_trial_ids = [] # Track trial_ids for proper train/test splitting self.collected_raw_samples = [] # Store raw samples for label alignment self.sample_buffer = [] print("[DEBUG] Reset collection state") @@ -663,12 +643,9 @@ class CollectionPage(BasePage): 
self.status_label.configure(text="Starting...") print("[DEBUG] Updated UI - button now shows 'Stop Collection'") - # Disable source selection and connection during collection - self.sim_radio.configure(state="disabled") - self.real_radio.configure(state="disabled") - if self.using_real_hardware: - self.connect_button.configure(state="disabled") - print("[DEBUG] Disabled source/connection controls") + # Disable connection controls during collection + self.connect_button.configure(state="disabled") + print("[DEBUG] Disabled connection controls") # Start collection thread self.collection_thread = threading.Thread(target=self.collection_loop, daemon=True) @@ -689,17 +666,10 @@ class CollectionPage(BasePage): # Safe cleanup - stream might already be in error state try: if self.stream: - if self.using_real_hardware: - print("[DEBUG] Calling stream.stop() for real hardware") - # Send stop command (returns to CONNECTED state) - self.stream.stop() - print("[DEBUG] stream.stop() completed") - else: - print("[DEBUG] Stopping simulated stream") - # For simulated stream, just stop it - self.stream.stop() - self.stream = None - print("[DEBUG] Simulated stream stopped and cleared") + print("[DEBUG] Calling stream.stop()") + # Send stop command (returns to CONNECTED state) + self.stream.stop() + print("[DEBUG] stream.stop() completed") except Exception as e: print(f"[DEBUG] Exception during stream cleanup: {e}") pass # Ignore cleanup errors @@ -717,15 +687,12 @@ class CollectionPage(BasePage): self.countdown_label.configure(text="") print("[DEBUG] UI reset - button shows 'Start Collection'") - # Re-enable source selection and connection button - self.sim_radio.configure(state="normal") - self.real_radio.configure(state="normal") - if self.using_real_hardware: - self.connect_button.configure(state="normal") - # Still connected, just not streaming - if self.is_connected: - device_name = self.stream.device_info.get('device', 'ESP32') if self.stream and self.stream.device_info else 
'ESP32' - self._update_connection_status("green", f"Connected ({device_name})") + # Re-enable connection button + self.connect_button.configure(state="normal") + # Still connected, just not streaming + if self.is_connected: + device_name = self.stream.device_info.get('device', 'ESP32') if self.stream and self.stream.device_info else 'ESP32' + self._update_connection_status("green", f"Connected ({device_name})") if self.collected_windows: self.save_button.configure(state="normal") @@ -735,10 +702,8 @@ class CollectionPage(BasePage): def collection_loop(self): """Background collection loop.""" - # Stream is already started (either via handshake for real HW or created for simulated) - # Just mark as ready - if self.using_real_hardware: - self.data_queue.put(('connection_status', ('green', 'Streaming'))) + # Stream is already started via handshake + self.data_queue.put(('connection_status', ('green', 'Streaming'))) self.scheduler.start_session() @@ -755,10 +720,6 @@ class CollectionPage(BasePage): current_time = time.perf_counter() if prompt: - # Update simulated stream gesture (only for GestureAwareEMGStream) - if hasattr(self.stream, 'set_gesture'): - self.stream.set_gesture(prompt.gesture_name) - # Calculate time remaining in current gesture elapsed_in_session = self.scheduler.get_elapsed_time() elapsed_in_gesture = elapsed_in_session - prompt.start_time @@ -818,13 +779,18 @@ class CollectionPage(BasePage): # Try to form a window window = self.windower.add_sample(sample) if window: - label = self.scheduler.get_label_for_time(window.start_time) + # Shift label lookup forward to align with actual muscle + # activation (accounts for reaction time + window centre) + label_time = window.start_time + LABEL_SHIFT_MS / 1000.0 + label = self.scheduler.get_label_for_time(label_time) + trial_id = self.scheduler.get_trial_id_for_time(label_time) self.collected_windows.append(window) self.collected_labels.append(label) + self.collected_trial_ids.append(trial_id) 
self.data_queue.put(('window_count', len(self.collected_windows))) else: - # Check for data timeout (only relevant for real hardware) - if self.using_real_hardware and (current_time - last_data_time > 3.0): + # Check for data timeout + if current_time - last_data_time > 3.0: if not timeout_warning_sent: self.data_queue.put(('warning', 'No data received - check ESP32 connection')) self.data_queue.put(('connection_status', ('orange', 'No data'))) @@ -885,8 +851,7 @@ class CollectionPage(BasePage): elif msg_type == 'error': # Show error and stop collection self.status_label.configure(text=f"Error: {data}", text_color="red") - if self.using_real_hardware: - self._update_connection_status("red", "Disconnected") + self._update_connection_status("red", "Disconnected") messagebox.showerror("Collection Error", data) self.stop_collection() return @@ -946,6 +911,7 @@ class CollectionPage(BasePage): windows=self.collected_windows, labels=self.collected_labels, metadata=metadata, + trial_ids=self.collected_trial_ids if self.collected_trial_ids else None, raw_samples=self.collected_raw_samples if self.collected_raw_samples else None, session_start_time=session_start_time ) @@ -972,47 +938,6 @@ class CollectionPage(BasePage): self.progress_bar.set(0) self.prompt_label.configure(text="READY", text_color="gray") - def _on_source_change(self): - """Show/hide port selection based on data source.""" - print("\n" + "="*80) - print("[DEBUG] _on_source_change() called") - print(f"[DEBUG] Before cleanup:") - print(f" - is_connected: {self.is_connected}") - print(f" - is_collecting: {self.is_collecting}") - print(f" - stream exists: {self.stream is not None}") - print(f" - source_var changing to: {self.source_var.get()}") - - # Clean up any existing connection/stream when switching modes - if self.is_connected and self.stream: - print("[DEBUG] Disconnecting existing stream...") - try: - self.stream.disconnect() - print("[DEBUG] Stream disconnected successfully") - except Exception as e: - 
print(f"[DEBUG] Stream disconnect failed: {e}") - - self.is_connected = False - self.stream = None - print("[DEBUG] Cleared is_connected and stream") - print(f"[DEBUG] NOTE: is_collecting remains: {self.is_collecting}") - - if self.source_var.get() == "real": - print("[DEBUG] Configuring for REAL hardware mode") - self.port_frame.pack(fill="x", pady=(5, 0)) - self._refresh_ports() - self.connect_button.configure(text="Connect", state="normal") - self.start_button.configure(state="disabled") # Must connect first - self._update_connection_status("gray", "Disconnected") - print("[DEBUG] Start button DISABLED (must connect first)") - else: - print("[DEBUG] Configuring for SIMULATED mode") - self.port_frame.pack_forget() - self._update_connection_status("gray", "Not using hardware") - self.connect_button.configure(state="disabled") - self.start_button.configure(state="normal") # Simulated mode doesn't need connect - print("[DEBUG] Start button ENABLED (no connection needed)") - print("="*80 + "\n") - def _refresh_ports(self): """Scan and populate available serial ports.""" ports = serial.tools.list_ports.comports() @@ -1148,59 +1073,106 @@ class CollectionPage(BasePage): # ============================================================================= class InspectPage(BasePage): - """Page for inspecting saved sessions.""" + """Page for inspecting saved sessions with scrollable signal + label view.""" + + # How many samples to show in the visible window at once + VIEW_SAMPLES = 3000 # ~3 seconds at 1 kHz def __init__(self, parent): super().__init__(parent) self.create_header( "Inspect Sessions", - "View saved session data and features" + "Browse session data — scroll through signals with gesture labels" ) # Content self.content = ctk.CTkFrame(self) self.content.grid(row=1, column=0, sticky="nsew") - self.content.grid_columnconfigure(0, weight=1) - self.content.grid_columnconfigure(1, weight=3) + self.content.grid_columnconfigure(0, weight=0, minsize=220) + 
self.content.grid_columnconfigure(1, weight=1) self.content.grid_rowconfigure(0, weight=1) - # Left panel - Session list - self.list_panel = ctk.CTkFrame(self.content) + # ── Left panel ── Session list + self.list_panel = ctk.CTkFrame(self.content, width=220) self.list_panel.grid(row=0, column=0, sticky="nsew", padx=(0, 10)) + self.list_panel.grid_propagate(False) - ctk.CTkLabel(self.list_panel, text="Sessions", font=ctk.CTkFont(size=16, weight="bold")).pack(pady=10) + ctk.CTkLabel(self.list_panel, text="Sessions", + font=ctk.CTkFont(size=16, weight="bold")).pack(pady=10) self.session_listbox = ctk.CTkScrollableFrame(self.list_panel) - self.session_listbox.pack(fill="both", expand=True, padx=10, pady=10) + self.session_listbox.pack(fill="both", expand=True, padx=10, pady=(0, 5)) - self.refresh_button = ctk.CTkButton(self.list_panel, text="Refresh", command=self.load_sessions) + self.refresh_button = ctk.CTkButton(self.list_panel, text="Refresh", + command=self.load_sessions) self.refresh_button.pack(pady=10) - # Right panel - Details + # ── Right panel ── Details + plot + slider self.details_panel = ctk.CTkFrame(self.content) self.details_panel.grid(row=0, column=1, sticky="nsew", padx=(10, 0)) + self.details_panel.grid_columnconfigure(0, weight=1) + self.details_panel.grid_rowconfigure(1, weight=1) # plot row expands self.details_label = ctk.CTkLabel( self.details_panel, text="Select a session to view details", - font=ctk.CTkFont(size=14) + font=ctk.CTkFont(size=14), + justify="left", anchor="w" ) - self.details_label.pack(pady=20) + self.details_label.grid(row=0, column=0, sticky="ew", padx=20, pady=(10, 0)) - # Plot area + # Plot area (filled on session select, dark bg to avoid white flash) + self.plot_frame = ctk.CTkFrame(self.details_panel, fg_color="#2b2b2b") + self.plot_frame.grid(row=1, column=0, sticky="nsew", padx=10, pady=5) + + # Slider + zoom row + self.controls_frame = ctk.CTkFrame(self.details_panel, fg_color="transparent") + 
self.controls_frame.grid(row=2, column=0, sticky="ew", padx=20, pady=(0, 10)) + self.controls_frame.grid_columnconfigure(1, weight=1) + + ctk.CTkLabel(self.controls_frame, text="Position:", + font=ctk.CTkFont(size=12)).grid(row=0, column=0, padx=(0, 8)) + + self.pos_slider = ctk.CTkSlider(self.controls_frame, from_=0, to=1, + command=self._on_slider) + self.pos_slider.grid(row=0, column=1, sticky="ew") + self.pos_slider.set(0) + + self.pos_label = ctk.CTkLabel(self.controls_frame, text="0.0 s", + font=ctk.CTkFont(size=12), width=80) + self.pos_label.grid(row=0, column=2, padx=(8, 0)) + + # Zoom buttons + zoom_frame = ctk.CTkFrame(self.controls_frame, fg_color="transparent") + zoom_frame.grid(row=0, column=3, padx=(16, 0)) + ctk.CTkButton(zoom_frame, text="−", width=32, command=self._zoom_out).pack(side="left", padx=2) + ctk.CTkButton(zoom_frame, text="+", width=32, command=self._zoom_in).pack(side="left", padx=2) + + # Matplotlib objects self.fig = None self.canvas = None - + self.axes = [] self.session_buttons = [] + # Loaded session state + self._signal = None # (total_samples, n_channels) continuous signal + self._labels_per_sample = None # label string per sample + self._label_names = [] + self._n_channels = 0 + self._total_samples = 0 + self._view_start = 0 # current scroll position in samples + self._view_len = self.VIEW_SAMPLES + self._slider_debounce_id = None # for debouncing slider updates + + # ── lifecycle ── + def on_show(self): - """Load sessions when page is shown.""" self.load_sessions() + # ── session list ── + def load_sessions(self): - """Load and display available sessions.""" - # Clear existing buttons for btn in self.session_buttons: btn.destroy() self.session_buttons = [] @@ -1209,76 +1181,246 @@ class InspectPage(BasePage): sessions = storage.list_sessions() if not sessions: - label = ctk.CTkLabel(self.session_listbox, text="No sessions found") - label.pack(pady=10) - self.session_buttons.append(label) + lbl = 
ctk.CTkLabel(self.session_listbox, text="No sessions found") + lbl.pack(pady=10) + self.session_buttons.append(lbl) return for session_id in sessions: info = storage.get_session_info(session_id) - btn_text = f"{session_id}\n{info['num_windows']} windows" + gestures = info['gestures'] + btn_text = f"{session_id}\n{info['num_windows']} win · {len(gestures)} gestures" btn = ctk.CTkButton( self.session_listbox, text=btn_text, - font=ctk.CTkFont(size=12), - height=60, - anchor="w", + font=ctk.CTkFont(size=11), + height=55, anchor="w", command=lambda s=session_id: self.show_session(s) ) - btn.pack(fill="x", pady=5) + btn.pack(fill="x", pady=3) self.session_buttons.append(btn) + # ── load & show session ── + def show_session(self, session_id: str): - """Display session details and plot.""" storage = SessionStorage() try: - X, y, label_names = storage.load_for_training(session_id) + # Load raw windowed data WITHOUT transition filtering so we see + # every window exactly as collected, labels included. + X, y, label_names = storage.load_for_training( + session_id, filter_transitions=False + ) except Exception as e: messagebox.showerror("Error", f"Failed to load session: {e}") return - # Clear previous plot - if self.canvas: - self.canvas.get_tk_widget().destroy() + n_windows, samples_per_window, n_channels = X.shape - # Create info text + # Build a continuous signal by concatenating windows using hop-based + # reconstruction. With 150-sample windows and 25-sample hop, consecutive + # windows overlap by 125 samples. We take only the first `hop` samples + # from each window (except the last, where we take the full window) to + # avoid duplicated overlap regions. 
+ hop = HOP_SIZE_MS # = 25 samples at 1 kHz (hop_size_ms == hop samples) + total_samples = (n_windows - 1) * hop + samples_per_window + + signal = np.zeros((total_samples, n_channels), dtype=np.float32) + labels_per_sample = np.empty(total_samples, dtype=object) + labels_per_sample[:] = "" + + for i in range(n_windows): + start = i * hop + end = start + samples_per_window + signal[start:end] = X[i] + # Label this window's hop region (non-overlapping part) + hop_end = start + hop if i < n_windows - 1 else end + labels_per_sample[start:hop_end] = label_names[y[i]] + + # Fill any remaining gaps from the last window's tail + mask = labels_per_sample == "" + if mask.any(): + labels_per_sample[mask] = label_names[y[-1]] + + # Pre-compute centered signals (global mean removal) for smooth scrolling. + # Using global mean ensures the signal doesn't jump when scrolling. + centered = signal.astype(np.float64) + for ch in range(n_channels): + centered[:, ch] -= centered[:, ch].mean() + + # Store for scrolling + self._signal = signal + self._centered = centered + self._labels_per_sample = labels_per_sample + self._label_names = label_names + self._n_channels = n_channels + self._total_samples = total_samples + self._view_start = 0 + + # Update info text info = storage.get_session_info(session_id) - info_text = f"""Session: {session_id} -User: {info['user_id']} -Time: {info['timestamp']} -Windows: {X.shape[0]} -Samples/window: {X.shape[1]} -Channels: {X.shape[2]} -Gestures: {', '.join(label_names)}""" - + duration_sec = total_samples / SAMPLING_RATE_HZ + label_counts = {ln: int(np.sum(y == i)) for i, ln in enumerate(label_names)} + counts_str = ", ".join(f"{n}: {c}" for n, c in sorted(label_counts.items())) + info_text = ( + f"Session: {session_id} | " + f"{n_windows} windows · {total_samples} samples · " + f"{duration_sec:.1f} s · {n_channels} ch\n" + f"Labels: {counts_str}" + ) self.details_label.configure(text=info_text) - # Create plot - self.fig = Figure(figsize=(10, 6), 
dpi=100, facecolor='#2b2b2b') + # Configure slider range + max_start = max(0, total_samples - self._view_len) + self.pos_slider.configure(to=max(max_start, 1)) + self.pos_slider.set(0) + self._update_pos_label() - # Plot raw signal for each channel - for ch in range(min(X.shape[2], 4)): - ax = self.fig.add_subplot(2, 2, ch + 1) - ax.set_facecolor('#2b2b2b') - ax.tick_params(colors='white') + # Build full-session plot (scroll with xlim, not rebuild) + self._build_plot() - signal = X[:, :, ch].flatten() - signal_centered = signal - signal.mean() - ax.plot(signal_centered[:2000], color='#00ff88', linewidth=0.5) - ax.set_title(f'Channel {ch}', color='white', fontsize=10) - ax.set_ylabel('Amplitude', color='white', fontsize=8) - ax.grid(True, alpha=0.3) + # ── plotting ── + def _build_plot(self): + """Build plot skeleton once. Line data is updated via set_data() on scroll.""" + # Tear down old canvas + if self.canvas: + self.canvas.get_tk_widget().destroy() + self.canvas = None + if self.fig: + plt.close(self.fig) + self.axes = [] + self._lines = [] + + n_ch = min(self._n_channels, 4) + self.fig = Figure(figsize=(12, max(2.5 * n_ch, 5)), dpi=100, + facecolor='#2b2b2b') + + duration_sec = self._total_samples / SAMPLING_RATE_HZ + + # Pre-build label colour strip as a tiny RGBA image (1 row, ~2k cols). + # This replaces hundreds of axvspan patches with a single imshow per axis, + # cutting per-frame render cost dramatically. 
+ from matplotlib.colors import to_rgba + hop_ds = max(1, self._total_samples // 2000) # downsample to ~2k pixels + n_px = (self._total_samples + hop_ds - 1) // hop_ds + label_img = np.zeros((1, n_px, 4), dtype=np.float32) + for i in range(n_px): + lbl = self._labels_per_sample[min(i * hop_ds, self._total_samples - 1)] + label_img[0, i] = to_rgba(get_gesture_color(lbl), alpha=0.25) + + for ch in range(n_ch): + ax = self.fig.add_subplot(n_ch, 1, ch + 1) + ax.set_facecolor('#1e1e1e') + self.axes.append(ax) + + # Fix y-axis to full signal range so it doesn't jump on scroll + ch_min = float(self._centered[:, ch].min()) + ch_max = float(self._centered[:, ch].max()) + margin = (ch_max - ch_min) * 0.05 + ylo, yhi = ch_min - margin, ch_max + margin + ax.set_ylim(ylo, yhi) + + # Label colour strip as a single imshow (replaces ~100 axvspan patches) + ax.imshow(label_img, aspect='auto', + extent=[0, duration_sec, ylo, yhi], + origin='lower', zorder=1, interpolation='nearest') + + # Create empty line — data filled by _fill_view_data() + line, = ax.plot([], [], color='#00ff88', linewidth=0.6, zorder=3) + self._lines.append(line) + + ax.set_ylabel(f'Ch {ch}', color='white', fontsize=10, labelpad=10) + ax.tick_params(colors='white', labelsize=8) + ax.grid(True, alpha=0.15, color='white') for spine in ax.spines.values(): - spine.set_color('white') + spine.set_color('#555555') - self.fig.tight_layout() + if ch < n_ch - 1: + ax.tick_params(labelbottom=False) - self.canvas = FigureCanvasTkAgg(self.fig, master=self.details_panel) + # X label on bottom axis + if self.axes: + self.axes[-1].set_xlabel('Time (s)', color='white', fontsize=10) + + # Legend at top + if self.axes: + from matplotlib.patches import Patch + patches = [Patch(facecolor=get_gesture_color(n), alpha=0.35, label=n) + for n in self._label_names] + self.axes[0].legend(handles=patches, loc='upper right', fontsize=8, + ncol=len(patches), framealpha=0.5, + facecolor='#333333', edgecolor='#555555', + labelcolor='white') + + 
self.fig.tight_layout(pad=1.0) + + self.canvas = FigureCanvasTkAgg(self.fig, master=self.plot_frame) + widget = self.canvas.get_tk_widget() + widget.configure(bg='#2b2b2b', highlightthickness=0) + widget.pack(fill="both", expand=True) + + # Fill initial view data and render + self._fill_view_data() self.canvas.draw() - self.canvas.get_tk_widget().pack(fill="both", expand=True, padx=20, pady=20) + + def _fill_view_data(self): + """Update line artists with only the visible window's data (~3k points).""" + s = self._view_start + e = min(s + self._view_len, self._total_samples) + time_slice = np.arange(s, e) / SAMPLING_RATE_HZ + + for ch, line in enumerate(self._lines): + line.set_data(time_slice, self._centered[s:e, ch]) + + t_start = s / SAMPLING_RATE_HZ + t_end = e / SAMPLING_RATE_HZ + for ax in self.axes: + ax.set_xlim(t_start, t_end) + + # ── slider / zoom callbacks ── + + def _on_slider(self, value): + new_start = int(float(value)) + if new_start == self._view_start: + return # No change, skip redraw + self._view_start = new_start + self._update_pos_label() + # Debounce: cancel pending draw and schedule a new one + if self._slider_debounce_id is not None: + self.after_cancel(self._slider_debounce_id) + self._slider_debounce_id = self.after(8, self._scroll_draw) + + def _scroll_draw(self): + """Fast redraw: update line data (~3k points) + xlim, no plot rebuild.""" + self._slider_debounce_id = None + if self.canvas and self._lines: + self._fill_view_data() + self.canvas.draw() + + def _update_pos_label(self): + t = self._view_start / SAMPLING_RATE_HZ + self.pos_label.configure(text=f"{t:.1f} s") + + def _zoom_in(self): + """Show fewer samples (zoom in).""" + self._view_len = max(500, self._view_len // 2) + self._clamp_view() + self._scroll_draw() + + def _zoom_out(self): + """Show more samples (zoom out).""" + self._view_len = min(self._total_samples, self._view_len * 2) + self._clamp_view() + self._scroll_draw() + + def _clamp_view(self): + max_start = max(0, 
self._total_samples - self._view_len) + self._view_start = min(self._view_start, max_start) + self.pos_slider.configure(to=max(max_start, 1)) + self.pos_slider.set(self._view_start) + self._update_pos_label() # ============================================================================= @@ -1312,6 +1454,71 @@ class TrainingPage(BasePage): ) self.sessions_label.pack(pady=10) + # Model name input + name_frame = ctk.CTkFrame(self.content, fg_color="transparent") + name_frame.pack(fill="x", padx=20, pady=(10, 0)) + + ctk.CTkLabel(name_frame, text="Model name:", font=ctk.CTkFont(size=14)).pack(side="left") + + self.model_name_var = ctk.StringVar(value="emg_lda_classifier") + self.model_name_entry = ctk.CTkEntry( + name_frame, textvariable=self.model_name_var, + width=250, placeholder_text="emg_lda_classifier" + ) + self.model_name_entry.pack(side="left", padx=(10, 5)) + + ctk.CTkLabel( + name_frame, text=".joblib", + font=ctk.CTkFont(size=14), text_color="gray" + ).pack(side="left") + + # Model type selector + type_frame = ctk.CTkFrame(self.content, fg_color="transparent") + type_frame.pack(fill="x", padx=20, pady=(10, 0)) + + ctk.CTkLabel(type_frame, text="Model type:", font=ctk.CTkFont(size=14)).pack(side="left") + + self.model_type_var = ctk.StringVar(value="LDA") + self.model_type_selector = ctk.CTkSegmentedButton( + type_frame, values=["LDA", "QDA"], + variable=self.model_type_var, + ) + self.model_type_selector.pack(side="left", padx=(10, 10)) + + self.model_type_desc = ctk.CTkLabel( + type_frame, + text="Linear — fast, exportable to ESP32", + font=ctk.CTkFont(size=11), text_color="gray" + ) + self.model_type_desc.pack(side="left") + + self.model_type_var.trace_add("write", self._on_model_type_changed) + + # QDA regularisation slider (only active when QDA is selected) + reg_frame = ctk.CTkFrame(self.content, fg_color="transparent") + reg_frame.pack(fill="x", padx=20, pady=(6, 0)) + + ctk.CTkLabel(reg_frame, text="reg_param:", 
font=ctk.CTkFont(size=14)).pack(side="left") + + self.reg_param_var = ctk.DoubleVar(value=0.1) + self.reg_param_slider = ctk.CTkSlider( + reg_frame, from_=0.0, to=1.0, variable=self.reg_param_var, + width=180, state="disabled", + command=lambda v: self.reg_param_label.configure(text=f"{v:.2f}"), + ) + self.reg_param_slider.pack(side="left", padx=(10, 6)) + + self.reg_param_label = ctk.CTkLabel( + reg_frame, text="0.10", font=ctk.CTkFont(size=13), width=40 + ) + self.reg_param_label.pack(side="left") + + self.reg_param_desc = ctk.CTkLabel( + reg_frame, text="(enable QDA to adjust — 0=flexible, 1=LDA-like)", + font=ctk.CTkFont(size=11), text_color="gray" + ) + self.reg_param_desc.pack(side="left", padx=(8, 0)) + # Train button self.train_button = ctk.CTkButton( self.content, @@ -1334,6 +1541,40 @@ class TrainingPage(BasePage): ) self.export_button.pack(pady=5) + # Advanced training (ensemble + MLP) + adv_frame = ctk.CTkFrame(self.content, fg_color="transparent") + adv_frame.pack(fill="x", padx=20, pady=(15, 0)) + + ctk.CTkLabel( + adv_frame, text="Advanced (ESP32 only):", + font=ctk.CTkFont(size=13, weight="bold") + ).pack(side="left") + + self.train_ensemble_button = ctk.CTkButton( + adv_frame, text="Train Ensemble", + font=ctk.CTkFont(size=13), height=34, + fg_color="#8B5CF6", hover_color="#7C3AED", + state="disabled", + command=self._train_ensemble + ) + self.train_ensemble_button.pack(side="left", padx=(10, 5)) + + self.train_mlp_button = ctk.CTkButton( + adv_frame, text="Train MLP", + font=ctk.CTkFont(size=13), height=34, + fg_color="#8B5CF6", hover_color="#7C3AED", + state="disabled", + command=self._train_mlp + ) + self.train_mlp_button.pack(side="left", padx=5) + + self.adv_desc = ctk.CTkLabel( + adv_frame, + text="(train base LDA first)", + font=ctk.CTkFont(size=11), text_color="gray" + ) + self.adv_desc.pack(side="left", padx=(8, 0)) + # Progress self.progress_bar = ctk.CTkProgressBar(self.content, width=400) self.progress_bar.pack(pady=10) @@ -1378,6 
+1619,34 @@ class TrainingPage(BasePage): self.sessions_label.configure(text="\n".join(info_lines)) self.train_button.configure(state="normal") + def _get_model_path(self) -> Path: + """Build model save path from the user-entered name.""" + name = self.model_name_var.get().strip() + if not name: + name = "emg_lda_classifier" + # Sanitize: remove extension if user typed one, strip unsafe chars + name = name.replace(".joblib", "").replace("/", "_").replace("\\", "_") + return MODEL_DIR / f"{name}.joblib" + + def _on_model_type_changed(self, *args): + """Update description, model name, and reg_param slider when model type changes.""" + mt = self.model_type_var.get() + if mt == "QDA": + self.model_type_desc.configure(text="Quadratic — flexible boundaries, laptop-only") + self.reg_param_slider.configure(state="normal") + self.reg_param_desc.configure(text="0=flexible quadratic, 1=LDA-like", text_color="white") + # Auto-suggest a QDA filename if still on the default LDA name + if self.model_name_var.get().strip() in ("", "emg_lda_classifier"): + self.model_name_var.set("emg_qda_classifier") + else: + self.model_type_desc.configure(text="Linear — fast, exportable to ESP32") + self.reg_param_slider.configure(state="disabled") + self.reg_param_desc.configure( + text="(enable QDA to adjust — 0=flexible, 1=LDA-like)", text_color="gray" + ) + if self.model_name_var.get().strip() in ("", "emg_qda_classifier"): + self.model_name_var.set("emg_lda_classifier") + def train_model(self): """Train the model on all sessions.""" self.train_button.configure(state="disabled") @@ -1385,11 +1654,18 @@ class TrainingPage(BasePage): self.progress_bar.set(0) self.status_label.configure(text="Loading data...") + # Capture model path, type, and reg_param on UI thread (StringVar isn't thread-safe) + model_save_path = self._get_model_path() + model_type = self.model_type_var.get().lower() + reg_param = float(self.reg_param_var.get()) + # Run in thread to not block UI - thread = 
threading.Thread(target=self._train_thread, daemon=True) + thread = threading.Thread( + target=self._train_thread, args=(model_save_path, model_type, reg_param), daemon=True + ) thread.start() - def _train_thread(self): + def _train_thread(self, model_save_path: Path, model_type: str = "lda", reg_param: float = 0.1): """Training thread.""" try: storage = SessionStorage() @@ -1398,25 +1674,29 @@ class TrainingPage(BasePage): self.after(0, lambda: self.status_label.configure(text="Loading all sessions...")) self.after(0, lambda: self.progress_bar.set(0.2)) - X, y, label_names, loaded_sessions = storage.load_all_for_training() + X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() + n_trials = len(np.unique(trial_ids)) + n_sessions = len(np.unique(session_indices)) self.after(0, lambda: self._log(f"Loaded {X.shape[0]} windows from {len(loaded_sessions)} sessions")) + self.after(0, lambda: self._log(f"Unique trials: {n_trials} (for proper train/test splitting)")) + self.after(0, lambda ns=n_sessions: self._log(f"Session normalization: {ns} sessions will be z-scored independently")) self.after(0, lambda: self._log(f"Labels: {label_names}\n")) # Train - self.after(0, lambda: self.status_label.configure(text="Training classifier...")) + self.after(0, lambda mt=model_type: self.status_label.configure(text=f"Training {mt.upper()} classifier...")) self.after(0, lambda: self.progress_bar.set(0.5)) - self.classifier = EMGClassifier() - self.classifier.train(X, y, label_names) + self.classifier = EMGClassifier(model_type=model_type, reg_param=reg_param) + self.classifier.train(X, y, label_names, session_indices=session_indices) self.after(0, lambda: self._log("Training complete!\n")) - # Cross-validation - self.after(0, lambda: self.status_label.configure(text="Running cross-validation...")) + # Cross-validation (trial-level to prevent leakage) + self.after(0, lambda: self.status_label.configure(text="Running cross-validation 
(trial-level)...")) self.after(0, lambda: self.progress_bar.set(0.7)) - cv_scores = self.classifier.cross_validate(X, y, cv=5) + cv_scores = self.classifier.cross_validate(X, y, trial_ids=trial_ids, cv=5, session_indices=session_indices) self.after(0, lambda: self._log(f"Cross-validation scores: {cv_scores.round(3)}")) self.after(0, lambda: self._log(f"Mean accuracy: {cv_scores.mean()*100:.1f}% (+/- {cv_scores.std()*100:.1f}%)\n")) @@ -1431,7 +1711,7 @@ class TrainingPage(BasePage): self.after(0, lambda: self.status_label.configure(text="Saving model...")) self.after(0, lambda: self.progress_bar.set(0.9)) - model_path = self.classifier.save(EMGClassifier.get_default_model_path()) + model_path = self.classifier.save(model_save_path) self.after(0, lambda: self._log(f"\nModel saved to: {model_path}")) self.after(0, lambda: self.progress_bar.set(1.0)) @@ -1446,14 +1726,112 @@ class TrainingPage(BasePage): finally: self.after(0, lambda: self.train_button.configure(state="normal")) - self.after(0, lambda: self.export_button.configure(state="normal")) + # Only enable export if an LDA model was trained (QDA can't export to C) + can_export = self.classifier and self.classifier.model_type == "lda" + self.after(0, lambda: self.export_button.configure( + state="normal" if can_export else "disabled" + )) + # Enable advanced training buttons after successful LDA training + if can_export: + self.after(0, lambda: self.train_ensemble_button.configure(state="normal")) + self.after(0, lambda: self.train_mlp_button.configure(state="normal")) + self.after(0, lambda: self.adv_desc.configure( + text="Ensemble: 3-specialist LDA stacker | MLP: int8 neural net" + )) + + def _train_ensemble(self): + """Train the 3-specialist + meta-LDA ensemble (runs train_ensemble.py).""" + self.train_ensemble_button.configure(state="disabled") + self._log("\n--- Training Ensemble ---") + self.status_label.configure(text="Training ensemble (3 specialist LDAs + meta-LDA)...") + self.progress_bar.set(0.3) + + 
def _run(): + try: + script = str(Path(__file__).parent / "train_ensemble.py") + result = subprocess.run( + [sys.executable, script], + capture_output=True, text=True, timeout=300 + ) + output = result.stdout + result.stderr + self.after(0, lambda: self._log(output)) + if result.returncode == 0: + self.after(0, lambda: self._log("\nEnsemble training complete!")) + self.after(0, lambda: self.status_label.configure(text="Ensemble trained!")) + else: + self.after(0, lambda: self._log(f"\nEnsemble training failed (exit code {result.returncode})")) + self.after(0, lambda: self.status_label.configure(text="Ensemble training failed")) + except Exception as e: + self.after(0, lambda: self._log(f"\nEnsemble error: {e}")) + self.after(0, lambda: self.status_label.configure(text="Ensemble training failed")) + finally: + self.after(0, lambda: self.progress_bar.set(1.0)) + self.after(0, lambda: self.train_ensemble_button.configure(state="normal")) + + threading.Thread(target=_run, daemon=True).start() + + def _train_mlp(self): + """Train the int8 MLP model (runs train_mlp_tflite.py). + + TensorFlow requires Python <=3.12. Try ``py -3.12`` first (Windows + launcher), fall back to the current interpreter. 
+ """ + self.train_mlp_button.configure(state="disabled") + self._log("\n--- Training MLP (TFLite int8) ---") + self.status_label.configure(text="Training MLP neural network...") + self.progress_bar.set(0.3) + + def _run(): + try: + script = str(Path(__file__).parent / "train_mlp_tflite.py") + # TensorFlow needs Python <=3.12; try py launcher first + python_cmd = [sys.executable] + try: + probe = subprocess.run( + ["py", "-3.12", "-c", "import tensorflow"], + capture_output=True, timeout=30, + ) + if probe.returncode == 0: + python_cmd = ["py", "-3.12"] + self.after(0, lambda: self._log("Using Python 3.12 (TensorFlow compatible)")) + except FileNotFoundError: + pass + result = subprocess.run( + python_cmd + [script], + capture_output=True, text=True, timeout=600 + ) + output = result.stdout + result.stderr + self.after(0, lambda: self._log(output)) + if result.returncode == 0: + self.after(0, lambda: self._log("\nMLP training complete!")) + self.after(0, lambda: self.status_label.configure(text="MLP trained!")) + else: + self.after(0, lambda: self._log(f"\nMLP training failed (exit code {result.returncode})")) + self.after(0, lambda: self.status_label.configure(text="MLP training failed")) + except Exception as e: + self.after(0, lambda: self._log(f"\nMLP error: {e}")) + self.after(0, lambda: self.status_label.configure(text="MLP training failed")) + finally: + self.after(0, lambda: self.progress_bar.set(1.0)) + self.after(0, lambda: self.train_mlp_button.configure(state="normal")) + + threading.Thread(target=_run, daemon=True).start() def export_model(self): - """Export trained model to C header.""" + """Export trained model to C header (LDA only).""" if not self.classifier or not self.classifier.is_trained: messagebox.showerror("Error", "No trained model to export!") return + if self.classifier.model_type != "lda": + messagebox.showerror( + "Export Not Supported", + "QDA models cannot be exported to C header.\n\n" + "QDA uses per-class covariance matrices which 
don't reduce to\n" + "simple weights/biases. Train an LDA model to export for ESP32." + ) + return + # Default path in ESP32 project default_path = Path("EMG_Arm/src/core/model_weights.h").absolute() @@ -1485,6 +1863,704 @@ class TrainingPage(BasePage): app.sidebar.update_status() +# ============================================================================= +# CALIBRATION PAGE +# ============================================================================= + +class CalibrationPage(BasePage): + """ + Session calibration — aligns the current-session EMG feature distribution + to the training distribution so the classifier works reliably across sessions. + + Workflow: + 1. Load a trained model (needs training stats stored during training). + 2. Connect to ESP32. + 3. Click "Start Calibration": hold each gesture for 5 seconds when prompted. + 4. Click "Apply Calibration": stores the fitted transform in the app so + PredictionPage uses it automatically in Laptop inference mode. + """ + + def __init__(self, parent): + super().__init__(parent) + + self.create_header( + "Calibrate", + "Align current session to training data — fixes electrode placement drift" + ) + + # Page state + self.is_calibrating = False + self.is_connected = False + self.classifier = None + self.stream = None + self.calib_thread = None + self._calib_gestures: list[str] = [] # Populated from model labels at start + + # Two-column layout + self.content = ctk.CTkFrame(self) + self.content.grid(row=1, column=0, sticky="nsew") + self.content.grid_columnconfigure(0, weight=1) + self.content.grid_columnconfigure(1, weight=1) + self.content.grid_rowconfigure(0, weight=1) + + self.left_panel = ctk.CTkFrame(self.content) + self.left_panel.grid(row=0, column=0, sticky="nsew", padx=(0, 8)) + + self.right_panel = ctk.CTkFrame(self.content) + self.right_panel.grid(row=0, column=1, sticky="nsew", padx=(8, 0)) + + self._setup_left_panel() + self._setup_right_panel() + + # 
------------------------------------------------------------------ + # Left panel — controls + # ------------------------------------------------------------------ + + def _setup_left_panel(self): + p = self.left_panel + + # Model picker + ctk.CTkLabel(p, text="Trained Model:", font=ctk.CTkFont(size=14)).pack( + anchor="w", padx=20, pady=(20, 0) + ) + model_row = ctk.CTkFrame(p, fg_color="transparent") + model_row.pack(fill="x", padx=20, pady=(5, 0)) + + self.model_var = ctk.StringVar(value="No models found") + self.model_dropdown = ctk.CTkOptionMenu(model_row, variable=self.model_var, width=240) + self.model_dropdown.pack(side="left") + + self.refresh_models_btn = ctk.CTkButton( + model_row, text="⟳", width=30, command=self._refresh_models + ) + self.refresh_models_btn.pack(side="left", padx=(5, 0)) + + self.load_model_btn = ctk.CTkButton( + p, text="Load Model", height=34, command=self._load_model + ) + self.load_model_btn.pack(fill="x", padx=20, pady=(8, 0)) + + self.model_status_label = ctk.CTkLabel( + p, text="No model loaded", font=ctk.CTkFont(size=12), text_color="orange" + ) + self.model_status_label.pack(anchor="w", padx=20, pady=(4, 0)) + + # Divider + ctk.CTkFrame(p, height=1, fg_color="gray40").pack(fill="x", padx=20, pady=14) + + # ESP32 connection + ctk.CTkLabel(p, text="ESP32 Connection:", font=ctk.CTkFont(size=14)).pack( + anchor="w", padx=20 + ) + port_row = ctk.CTkFrame(p, fg_color="transparent") + port_row.pack(fill="x", padx=20, pady=(5, 0)) + + ctk.CTkLabel(port_row, text="Port:").pack(side="left") + self.port_var = ctk.StringVar(value="Auto-detect") + self.port_dropdown = ctk.CTkOptionMenu( + port_row, variable=self.port_var, values=["Auto-detect"], width=140 + ) + self.port_dropdown.pack(side="left", padx=(8, 4)) + + self.refresh_ports_btn = ctk.CTkButton( + port_row, text="⟳", width=30, command=self._refresh_ports + ) + self.refresh_ports_btn.pack(side="left") + + conn_row = ctk.CTkFrame(p, fg_color="transparent") + conn_row.pack(fill="x", 
padx=20, pady=(5, 0)) + + self.connect_btn = ctk.CTkButton( + conn_row, text="Connect", width=100, height=28, command=self._toggle_connection + ) + self.connect_btn.pack(side="left", padx=(0, 10)) + + self.conn_status_label = ctk.CTkLabel( + conn_row, text="● Disconnected", font=ctk.CTkFont(size=11), text_color="gray" + ) + self.conn_status_label.pack(side="left") + + # Divider + ctk.CTkFrame(p, height=1, fg_color="gray40").pack(fill="x", padx=20, pady=14) + + # Action buttons + self.start_btn = ctk.CTkButton( + p, + text="Start Calibration", + font=ctk.CTkFont(size=16, weight="bold"), + height=50, + state="disabled", + command=self._start_calibration, + ) + self.start_btn.pack(fill="x", padx=20, pady=(0, 8)) + + self.apply_btn = ctk.CTkButton( + p, + text="Apply Calibration to Prediction", + font=ctk.CTkFont(size=13), + height=40, + fg_color="#28a745", + hover_color="#1e7e34", + state="disabled", + command=self._apply_calibration, + ) + self.apply_btn.pack(fill="x", padx=20, pady=(0, 8)) + + # Log box + ctk.CTkLabel(p, text="Log:", font=ctk.CTkFont(size=12)).pack( + anchor="w", padx=20, pady=(10, 0) + ) + self.log_box = ctk.CTkTextbox( + p, font=ctk.CTkFont(family="Courier", size=11), height=160 + ) + self.log_box.pack(fill="x", padx=20, pady=(4, 20)) + + self._refresh_models() + self._refresh_ports() + + # ------------------------------------------------------------------ + # Right panel — gesture display and countdown + # ------------------------------------------------------------------ + + def _setup_right_panel(self): + p = self.right_panel + + # Overall progress + ctk.CTkLabel(p, text="Overall progress:", font=ctk.CTkFont(size=13)).pack( + pady=(20, 4) + ) + self.overall_progress = ctk.CTkProgressBar(p, width=320) + self.overall_progress.pack() + self.overall_progress.set(0) + + self.progress_text = ctk.CTkLabel( + p, text="0 / 0 gestures", font=ctk.CTkFont(size=12), text_color="gray" + ) + self.progress_text.pack(pady=(4, 16)) + + # Big gesture name + 
self.gesture_label = ctk.CTkLabel( + p, + text="---", + font=ctk.CTkFont(size=64, weight="bold"), + text_color="gray", + ) + self.gesture_label.pack(pady=(10, 6)) + + # Instruction text + self.instruction_label = ctk.CTkLabel( + p, + text="Load a model and connect to begin", + font=ctk.CTkFont(size=15), + text_color="gray", + ) + self.instruction_label.pack(pady=4) + + # Countdown (remaining seconds in current gesture) + self.countdown_label = ctk.CTkLabel( + p, + text="", + font=ctk.CTkFont(size=44, weight="bold"), + text_color="#FFD700", + ) + self.countdown_label.pack(pady=8) + + # Per-gesture progress bar + ctk.CTkLabel( + p, text="Current gesture:", font=ctk.CTkFont(size=12), text_color="gray" + ).pack(pady=(8, 2)) + self.gesture_progress = ctk.CTkProgressBar(p, width=320) + self.gesture_progress.pack() + self.gesture_progress.set(0) + + # Applied status + self.calib_applied_label = ctk.CTkLabel( + p, text="", font=ctk.CTkFont(size=13, weight="bold"), text_color="green" + ) + self.calib_applied_label.pack(pady=16) + + # ------------------------------------------------------------------ + # on_show / on_hide + # ------------------------------------------------------------------ + + def on_show(self): + self._refresh_models() + # Reflect whether calibration is already applied + app = self.winfo_toplevel() + if isinstance(app, EMGApp) and app.calibrated_classifier is not None: + self.calib_applied_label.configure( + text="Calibration active — go to Live Prediction to use it" + ) + else: + self.calib_applied_label.configure(text="") + + def on_hide(self): + if self.is_calibrating: + self.is_calibrating = False + if self.stream: + try: + self.stream.stop() + except Exception: + pass + + def stop(self): + self.is_calibrating = False + if self.stream: + try: + self.stream.stop() + except Exception: + pass + + # ------------------------------------------------------------------ + # Model helpers + # ------------------------------------------------------------------ + 
+ def _refresh_models(self): + models = EMGClassifier.list_saved_models() + if models: + names = [p.name for p in models] + self.model_dropdown.configure(values=names) + latest = max(models, key=lambda p: p.stat().st_mtime) + self.model_var.set(latest.name) + else: + self.model_dropdown.configure(values=["No models found"]) + self.model_var.set("No models found") + + def _get_model_path(self): + name = self.model_var.get() + if name == "No models found": + return None + path = MODEL_DIR / name + return path if path.exists() else None + + def _load_model(self): + path = self._get_model_path() + if not path: + messagebox.showerror("No Model", "Select a model from the dropdown first.") + return + try: + self.classifier = EMGClassifier.load(path) + if self.classifier.calibration_transform.has_training_stats: + mt = self.classifier.model_type.upper() + rp = (f", reg_param={self.classifier.reg_param:.2f}" + if self.classifier.model_type == "qda" else "") + sn = getattr(self.classifier, 'session_normalized', False) + sn_str = "" if sn else " [!old — retrain recommended]" + status_color = "green" if sn else "orange" + self.model_status_label.configure( + text=f"Loaded: {path.name} [{mt}{rp}]{sn_str}", + text_color=status_color, + ) + self._log(f"Model loaded: {path.name} [{mt}{rp}]") + self._log(f"Gestures: {self.classifier.label_names}") + if not sn: + self._log("WARNING: This model was trained without session normalization.") + self._log(" Calibration will work but may be less accurate, especially for QDA.") + self._log(" Retrain to get proper calibration support.") + else: + self.model_status_label.configure( + text=f"Loaded (old model — retrain to enable calibration)", + text_color="orange", + ) + self._log("Warning: model has no training stats.") + self._log("Retrain the model to enable calibration support.") + self._update_start_button() + except Exception as e: + messagebox.showerror("Load Error", f"Failed to load model:\n{e}") + + # 
------------------------------------------------------------------ + # Connection helpers + # ------------------------------------------------------------------ + + def _refresh_ports(self): + ports = serial.tools.list_ports.comports() + port_names = ["Auto-detect"] + [p.device for p in ports] + self.port_dropdown.configure(values=port_names) + + def _get_port(self): + p = self.port_var.get() + return None if p == "Auto-detect" else p + + def _toggle_connection(self): + if self.is_connected: + self._disconnect() + else: + self._connect() + + def _connect(self): + port = self._get_port() + try: + self.conn_status_label.configure(text="● Connecting...", text_color="orange") + self.connect_btn.configure(state="disabled") + self.update() + self.stream = RealSerialStream(port=port) + device_info = self.stream.connect(timeout=5.0) + self.is_connected = True + self.conn_status_label.configure( + text=f"● Connected ({device_info.get('device', 'ESP32')})", + text_color="green", + ) + self.connect_btn.configure(text="Disconnect", state="normal") + self._log("ESP32 connected") + self._update_start_button() + except TimeoutError: + self.conn_status_label.configure(text="● Timeout", text_color="red") + self.connect_btn.configure(state="normal") + messagebox.showerror("Timeout", "ESP32 did not respond within 5 seconds.") + if self.stream: + try: + self.stream.disconnect() + except Exception: + pass + self.stream = None + except Exception as e: + self.conn_status_label.configure(text="● Failed", text_color="red") + self.connect_btn.configure(state="normal") + messagebox.showerror("Connection Error", str(e)) + if self.stream: + try: + self.stream.disconnect() + except Exception: + pass + self.stream = None + + def _disconnect(self): + try: + if self.stream: + self.stream.disconnect() + time.sleep(0.3) + except Exception: + pass + self.is_connected = False + self.stream = None + self.conn_status_label.configure(text="● Disconnected", text_color="gray") + 
self.connect_btn.configure(text="Connect", state="normal") + self._update_start_button() + + # ------------------------------------------------------------------ + # UI helpers + # ------------------------------------------------------------------ + + def _update_start_button(self): + can_start = ( + self.classifier is not None + and self.classifier.calibration_transform.has_training_stats + and self.is_connected + and not self.is_calibrating + ) + self.start_btn.configure(state="normal" if can_start else "disabled") + + def _log(self, text: str): + self.log_box.insert("end", text + "\n") + self.log_box.see("end") + + # ------------------------------------------------------------------ + # Calibration logic + # ------------------------------------------------------------------ + + def _start_calibration(self): + if self.is_calibrating: + return + + self.is_calibrating = True + self.apply_btn.configure(state="disabled") + self.start_btn.configure(state="disabled") + self.calib_applied_label.configure(text="") + self.overall_progress.set(0) + self.gesture_progress.set(0) + self._log("\n--- Starting calibration ---") + self._log(f"Each gesture: {int(CALIB_PREP_SEC)}s prep → {int(CALIB_DURATION_SEC)}s hold") + + try: + self.stream.start() + self.stream.running = True + except Exception as e: + messagebox.showerror("Stream Error", f"Could not start EMG stream:\n{e}") + self.is_calibrating = False + self._update_start_button() + return + + # Build gesture order: rest first, then others sorted + labels = self.classifier.label_names + gestures = ["rest"] + sorted(g for g in labels if g != "rest") + self._calib_gestures = gestures + self.progress_text.configure(text=f"0 / {len(gestures)} gestures") + + import threading as _threading + self.calib_thread = _threading.Thread( + target=self._calibration_thread, args=(gestures,), daemon=True + ) + self.calib_thread.start() + + def _calibration_thread(self, gestures: list): + """ + Background thread: walks through each gesture in 
two phases. + + Phase 1 — Preparation (CALIB_PREP_SEC seconds): + Show the gesture name in yellow with a whole-second countdown so + the user has time to form the gesture before recording begins. + Serial samples are drained but discarded to keep the buffer fresh. + + Phase 2 — Collection (CALIB_DURATION_SEC seconds): + Gesture name switches to its gesture colour. EMG windows are + extracted and stored. A decimal countdown shows remaining time. + + All UI mutations go through self.after() for thread safety. + """ + parser = EMGParser(num_channels=NUM_CHANNELS) + windower = Windower( + window_size_ms=WINDOW_SIZE_MS, + sample_rate=SAMPLING_RATE_HZ, + hop_size_ms=HOP_SIZE_MS, + ) + all_features = [] + all_labels = [] + rms_by_gesture: dict[str, list[float]] = {} # AC-RMS per window, keyed by gesture + n_gestures = len(gestures) + + try: + for g_idx, gesture in enumerate(gestures): + if not self.is_calibrating: + return + + display_name = gesture.upper().replace("_", " ") + gesture_color = get_gesture_color(gesture) + + # ── Phase 1: Preparation countdown ────────────────────────── + # Show gesture in yellow so the user knows what's coming and + # can start forming the gesture before recording begins. + self.after(0, lambda t=display_name: self.gesture_label.configure( + text=t, text_color="#FFD700" + )) + self.after(0, lambda: self.instruction_label.configure( + text="Get ready..." 
+ )) + self.after(0, lambda: self.gesture_progress.set(0)) + self.after(0, lambda: self.countdown_label.configure( + text=str(int(CALIB_PREP_SEC)) + )) + + prep_start = time.perf_counter() + last_ui_time = prep_start + + while self.is_calibrating: + elapsed = time.perf_counter() - prep_start + if elapsed >= CALIB_PREP_SEC: + break + + now = time.perf_counter() + if now - last_ui_time >= 0.05: + remaining = CALIB_PREP_SEC - elapsed + # Show whole-second countdown: 3 → 2 → 1 + tick = max(1, int(np.ceil(remaining))) + self.after(0, lambda s=tick: self.countdown_label.configure( + text=str(s) + )) + last_ui_time = now + + # Drain serial buffer — keeps it fresh for collection + self.stream.readline() + + if not self.is_calibrating: + return + + # Brief "GO!" flash before collection starts + self.after(0, lambda: self.countdown_label.configure(text="GO!")) + time.sleep(0.2) + + # ── Phase 2: Collection ───────────────────────────────────── + # Switch to gesture colour — this signals "recording now" + self.after(0, lambda t=display_name, c=gesture_color: ( + self.gesture_label.configure(text=t, text_color=c) + )) + self.after(0, lambda d=int(CALIB_DURATION_SEC): self.instruction_label.configure( + text=f"Hold this gesture for {d} seconds" + )) + + gesture_start = time.perf_counter() + windows_collected = 0 + last_ui_time = gesture_start + + while self.is_calibrating: + elapsed = time.perf_counter() - gesture_start + if elapsed >= CALIB_DURATION_SEC: + break + + now = time.perf_counter() + if now - last_ui_time >= 0.05: + remaining = CALIB_DURATION_SEC - elapsed + progress = elapsed / CALIB_DURATION_SEC + self.after(0, lambda r=remaining: self.countdown_label.configure( + text=f"{r:.1f}s" + )) + self.after(0, lambda p=progress: self.gesture_progress.set(p)) + last_ui_time = now + + line = self.stream.readline() + if not line: + continue + + sample = parser.parse_line(line) + if sample is None: + continue + + window = windower.add_sample(sample) + if window is not None: + 
w_np = window.to_numpy() + feat = self.classifier.feature_extractor.extract_features_window(w_np) + all_features.append(feat) + all_labels.append(gesture) + windows_collected += 1 + w_ac = w_np - w_np.mean(axis=0) # remove per-window DC offset + ac_rms = float(np.sqrt(np.mean(w_ac ** 2))) + rms_by_gesture.setdefault(gesture, []).append(ac_rms) + + # Log and advance overall progress bar + overall_prog = (g_idx + 1) / n_gestures + self.after(0, lambda g=gesture, w=windows_collected, p=overall_prog, i=g_idx, n=n_gestures: ( + self._log(f" {g}: {w} windows"), + self.overall_progress.set(p), + self.progress_text.configure(text=f"{i + 1} / {n} gestures"), + )) + + finally: + self.stream.stop() + + if not self.is_calibrating: + # User navigated away — abort + return + + self.is_calibrating = False + + if not all_features: + self.after(0, lambda: messagebox.showerror( + "No Data", "No windows were collected. Check the EMG connection." + )) + self.after(0, self._update_start_button) + return + + # Fit the calibration transform + X_calib = np.array(all_features) + try: + self.classifier.calibration_transform.fit_from_calibration(X_calib, all_labels) + + # Set rest energy gate from raw window RMS (must be done here, not in + # fit_from_calibration, because extracted features are amplitude-normalized). + # + # Scan every candidate threshold and pick the one that minimises: + # rest_miss_rate (rest windows above gate → reach LDA → may jitter) + # gesture_miss_rate (gesture windows below gate → blocked → feel hard) + # Equal weighting by default; prints the full breakdown so you can see + # whether the two distributions actually separate cleanly. 
+ if "rest" in rms_by_gesture: + rest_arr = np.array(rms_by_gesture["rest"]) + active_arr = np.concatenate([ + np.array(v) for g, v in rms_by_gesture.items() if g != "rest" + ]) + + # Print distribution summary for diagnosis + self.after(0, lambda: self._log("\nRMS energy distribution (AC, pre-gate):")) + self.after(0, lambda r=rest_arr: self._log( + f" rest — p50={np.percentile(r,50):.1f} p95={np.percentile(r,95):.1f} max={r.max():.1f}")) + for g, v in rms_by_gesture.items(): + if g == "rest": + continue + va = np.array(v) + self.after(0, lambda g=g, va=va: self._log( + f" {g:<12s}— p5={np.percentile(va,5):.1f} p50={np.percentile(va,50):.1f} min={va.min():.1f}")) + + # Scan candidates from rest min to active max + candidates = np.linspace(rest_arr.min(), active_arr.max(), 1000) + best_t, best_err = float(rest_arr.max()), float("inf") + for t in candidates: + rest_miss = float((rest_arr > t).mean()) # rest slips to LDA + gesture_miss = float((active_arr <= t).mean()) # gesture blocked + err = rest_miss + gesture_miss + if err < best_err: + best_err, best_t = err, float(t) + + rest_miss_at_best = float((rest_arr > best_t).mean()) * 100 + gesture_miss_at_best = float((active_arr <= best_t).mean()) * 100 + + self.classifier.calibration_transform.rest_energy_threshold = best_t + print(f"[Calibration] Optimal rest gate: {best_t:.2f} " + f"(rest_miss={rest_miss_at_best:.1f}%, gesture_miss={gesture_miss_at_best:.1f}%)") + self.after(0, lambda t=best_t, rm=rest_miss_at_best, gm=gesture_miss_at_best: ( + self._log(f"\nOptimal rest gate: {t:.2f}"), + self._log(f" rest above gate (may jitter): {rm:.1f}%"), + self._log(f" gestures below gate (feel hard): {gm:.1f}%"), + )) + + # Warn when rest energy overlaps any gesture — indicates bad electrode contact + if "rest" in rms_by_gesture: + for g, v in rms_by_gesture.items(): + if g != "rest" and np.array(v).min() < rest_arr.max(): + self.after(0, lambda g=g: self._log( + f"\nWARNING: rest energy overlaps {g}. 
" + f"Electrode placement may be poor — adjust and recalibrate.")) + + self.after(0, self._on_calibration_complete) + except Exception as e: + self.after(0, lambda err=e: messagebox.showerror( + "Calibration Error", f"Failed to fit transform:\n{err}" + )) + + self.after(0, self._update_start_button) + + def _on_calibration_complete(self): + """Called on the main thread when calibration data collection finishes.""" + self.gesture_label.configure(text="DONE!", text_color="#28a745") + self.instruction_label.configure( + text="Calibration collected. Click 'Apply' to activate." + ) + self.countdown_label.configure(text="") + self.gesture_progress.set(1.0) + self.apply_btn.configure(state="normal") + + # Show z-score normalization diagnostics so the user can spot bad calibration + ct = self.classifier.calibration_transform + if ct.mu_calib is not None and ct.sigma_calib is not None: + self._log(f"\nZ-score normalization fitted:") + self._log(f" mu_calib magnitude: {np.linalg.norm(ct.mu_calib):.4f}") + self._log(f" sigma_calib magnitude: {np.linalg.norm(ct.sigma_calib):.4f}") + if ct.rest_energy_threshold is not None: + self._log(f" rest energy gate: {ct.rest_energy_threshold:.4f}") + # Per-class residual in normalized space (lower = better alignment) + common = set(ct.class_means_calib) & set(ct.class_means_train) + if common: + self._log("Per-class alignment (normalized residual — lower is better):") + for cls in sorted(common): + norm_calib = (ct.class_means_calib[cls] - ct.mu_calib) / ct.sigma_calib + residual = np.linalg.norm(ct.class_means_train[cls] - norm_calib) + self._log(f" {cls}: {residual:.3f}") + + self._log("\nDone! 
Click 'Apply Calibration to Prediction' to use it.") + + def _apply_calibration(self): + if self.classifier is None or not self.classifier.calibration_transform.is_fitted: + messagebox.showerror("Not Ready", "Run calibration first.") + return + + app = self.winfo_toplevel() + if isinstance(app, EMGApp): + app.calibrated_classifier = self.classifier + self.calib_applied_label.configure( + text="Calibration applied! Disconnect, then go to Live Prediction.", + text_color="green", + ) + self._log("Calibration applied to Prediction page.") + messagebox.showinfo( + "Calibration Applied", + "Session calibration is now active.\n\n" + "Next steps:\n" + "1. Click 'Disconnect' on this page\n" + "2. Go to '5. Live Prediction'\n" + "3. Connect to ESP32 there\n" + "4. Choose Laptop inference mode\n" + "5. Start Prediction — the calibrated model will be used automatically.", + ) + + # ============================================================================= # LIVE PREDICTION PAGE # ============================================================================= @@ -1503,11 +2579,11 @@ class PredictionPage(BasePage): # State (MUST be initialized BEFORE creating UI elements) self.is_predicting = False self.is_connected = False - self.using_real_hardware = False self.classifier = None self.smoother = None self.stream = None self.data_queue = queue.Queue() + self.inference_mode = "ESP32" # "ESP32" or "Laptop" # Content self.content = ctk.CTkFrame(self) @@ -1526,33 +2602,57 @@ class PredictionPage(BasePage): ) self.model_label.pack(pady=10) - # Data Source selection + # Model file picker (for Laptop mode) + self.model_picker_frame = ctk.CTkFrame(self.status_frame, fg_color="transparent") + self.model_picker_frame.pack(fill="x", pady=(5, 0)) + + ctk.CTkLabel(self.model_picker_frame, text="Model:", font=ctk.CTkFont(size=14)).pack(side="left") + + self.model_file_var = ctk.StringVar(value="No models found") + self.model_dropdown = ctk.CTkOptionMenu( + self.model_picker_frame, 
variable=self.model_file_var, + values=["No models found"], width=280, + ) + self.model_dropdown.pack(side="left", padx=(10, 5)) + + self.refresh_models_btn = ctk.CTkButton( + self.model_picker_frame, text="⟳", width=30, + command=self._refresh_model_list + ) + self.refresh_models_btn.pack(side="left") + + # Initially hidden (only shown in Laptop mode) + self.model_picker_frame.pack_forget() + + # Inference mode selector + mode_frame = ctk.CTkFrame(self.status_frame, fg_color="transparent") + mode_frame.pack(fill="x", pady=(10, 0)) + + ctk.CTkLabel(mode_frame, text="Inference:", font=ctk.CTkFont(size=14)).pack(side="left") + + self.mode_var = ctk.StringVar(value="ESP32") + self.mode_selector = ctk.CTkSegmentedButton( + mode_frame, values=["ESP32", "Laptop"], + variable=self.mode_var, + command=self._on_mode_changed + ) + self.mode_selector.pack(side="left", padx=(10, 0)) + + self.mode_desc_label = ctk.CTkLabel( + mode_frame, + text="On-device inference (model baked into firmware)", + font=ctk.CTkFont(size=11), text_color="gray" + ) + self.mode_desc_label.pack(side="left", padx=(10, 0)) + + # ESP32 Connection (hardware required) source_frame = ctk.CTkFrame(self.status_frame, fg_color="transparent") source_frame.pack(fill="x", pady=(10, 0)) - ctk.CTkLabel(source_frame, text="Data Source:", font=ctk.CTkFont(size=14)).pack(anchor="w") + ctk.CTkLabel(source_frame, text="ESP32 Connection:", font=ctk.CTkFont(size=14)).pack(anchor="w") - self.source_var = ctk.StringVar(value="simulated") - - radio_frame = ctk.CTkFrame(source_frame, fg_color="transparent") - radio_frame.pack(fill="x", pady=(5, 0)) - - self.sim_radio = ctk.CTkRadioButton( - radio_frame, text="Simulated", variable=self.source_var, value="simulated", - command=self._on_source_change - ) - self.sim_radio.pack(side="left", padx=(0, 20)) - - self.real_radio = ctk.CTkRadioButton( - radio_frame, text="Real ESP32", variable=self.source_var, value="real", - command=self._on_source_change - ) - 
self.real_radio.pack(side="left") - - # Port selection (initially hidden) - self.port_frame = ctk.CTkFrame(source_frame, fg_color="transparent") - - port_select_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent") + # Port selection + port_select_frame = ctk.CTkFrame(source_frame, fg_color="transparent") port_select_frame.pack(fill="x", pady=(5, 0)) ctk.CTkLabel(port_select_frame, text="Port:").pack(side="left") @@ -1571,14 +2671,13 @@ class PredictionPage(BasePage): self.refresh_ports_btn.pack(side="left") # Connection status and button - connect_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent") + connect_frame = ctk.CTkFrame(source_frame, fg_color="transparent") connect_frame.pack(fill="x", pady=(5, 0)) self.connect_button = ctk.CTkButton( connect_frame, text="Connect", width=100, height=28, - command=self._toggle_connection, - state="disabled" # Disabled until "Real ESP32" selected + command=self._toggle_connection ) self.connect_button.pack(side="left", padx=(0, 10)) @@ -1648,24 +2747,84 @@ class PredictionPage(BasePage): def on_show(self): """Check model status when shown.""" - self.check_model() + self._refresh_model_list() + # If a calibrated classifier is available, surface it prominently + app = self.winfo_toplevel() + if isinstance(app, EMGApp) and app.calibrated_classifier is not None: + clf = app.calibrated_classifier + self.model_label.configure( + text=( + f"Calibrated model ready ({clf.model_type.upper()}, " + f"{len(clf.label_names)} classes) — will be used in Laptop mode" + ), + text_color="green", + ) + else: + self.check_model() + + def _refresh_model_list(self): + """Scan for saved models and populate the dropdown.""" + models = EMGClassifier.list_saved_models() + if models: + names = [p.name for p in models] + self.model_dropdown.configure(values=names) + # Default to most recent if current selection is invalid + current = self.model_file_var.get() + if current not in names: + latest = max(models, key=lambda p: 
p.stat().st_mtime) + self.model_file_var.set(latest.name) + else: + self.model_dropdown.configure(values=["No models found"]) + self.model_file_var.set("No models found") + + def _get_selected_model_path(self) -> Path | None: + """Get the full path of the user-selected model file.""" + name = self.model_file_var.get() + if name == "No models found": + return None + path = MODEL_DIR / name + return path if path.exists() else None def check_model(self): - """Check if a saved model exists.""" - model_path = EMGClassifier.get_default_model_path() - - if model_path.exists(): + """Check if a saved model exists (needed for Laptop mode).""" + if self.inference_mode == "Laptop": + # Show model picker in Laptop mode + self.model_picker_frame.pack(fill="x", pady=(5, 0), after=self.model_label) + model_path = self._get_selected_model_path() + if model_path: + self.model_label.configure( + text=f"Selected model: {model_path.name}", + text_color="green" + ) + self.start_button.configure(state="normal") + else: + self.model_label.configure( + text="No saved models. Train a model first!", + text_color="orange" + ) + self.start_button.configure(state="disabled") + else: + # ESP32 mode — hide model picker + self.model_picker_frame.pack_forget() self.model_label.configure( - text=f"Saved model found: {model_path.name}", + text="ESP32 mode: model is baked into firmware", text_color="green" ) self.start_button.configure(state="normal") - else: - self.model_label.configure( - text="No saved model. 
Train a model first (Option 3).", - text_color="orange" + + def _on_mode_changed(self, mode: str): + """Handle inference mode toggle.""" + self.inference_mode = mode + if mode == "ESP32": + self.mode_desc_label.configure( + text="On-device inference (model baked into firmware)" ) - self.start_button.configure(state="disabled") + else: + self.mode_desc_label.configure( + text="Laptop inference (streams raw EMG, runs Python model)" + ) + self._refresh_model_list() + self.check_model() def toggle_prediction(self): """Start or stop prediction.""" @@ -1683,120 +2842,6 @@ class PredictionPage(BasePage): # Reset flag after brief delay to prevent immediate re-trigger self.after(100, lambda: setattr(self, '_toggling', False)) - def start_prediction(self): - """Start live prediction.""" - # CRITICAL: Drain any stale messages from previous sessions FIRST - # This prevents old 'done' messages from stopping the new session - try: - while True: - self.data_queue.get_nowait() - except queue.Empty: - pass - - # Load model - try: - self.classifier = EMGClassifier.load(EMGClassifier.get_default_model_path()) - except Exception as e: - messagebox.showerror("Error", f"Failed to load model: {e}") - return - - # Determine data source - self.using_real_hardware = (self.source_var.get() == "real") - - # For real hardware, must be connected - if self.using_real_hardware: - if not self.is_connected or not self.stream: - messagebox.showerror("Not Connected", "Please connect to the ESP32 first.") - return - - # Send start command to begin streaming - try: - self.stream.start() - except Exception as e: - messagebox.showerror("Start Error", f"Failed to start streaming:\n{e}") - return - - # Create prediction smoother - self.smoother = PredictionSmoother( - label_names=self.classifier.label_names, - probability_smoothing=0.7, # Higher = more smoothing - majority_vote_window=5, # Past predictions to consider - debounce_count=3, # Consecutive same predictions to change output - ) - - 
self.is_predicting = True - self.start_button.configure(text="Stop", fg_color="red") - - # Disable source selection and connection during prediction - self.sim_radio.configure(state="disabled") - self.real_radio.configure(state="disabled") - if self.using_real_hardware: - self.connect_button.configure(state="disabled") - - # Start prediction thread - thread = threading.Thread(target=self._prediction_thread, daemon=True) - thread.start() - - # Start UI update - self.update_prediction_ui() - - def stop_prediction(self): - """Stop live prediction.""" - self.is_predicting = False - - # Safe cleanup - stream might already be in error state - try: - if self.stream: - if self.using_real_hardware: - # Send stop command (returns to CONNECTED state) - self.stream.stop() - else: - # For simulated stream, just stop it - self.stream.stop() - self.stream = None - except Exception: - pass # Ignore cleanup errors - - self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"]) - self.prediction_label.configure(text="---", text_color="white") - self.confidence_bar.set(0) - self.confidence_label.configure(text="Confidence: ---%") - self.sim_label.configure(text="") - self.raw_label.configure(text="", text_color="gray") - - # Re-enable source selection and connection button - self.sim_radio.configure(state="normal") - self.real_radio.configure(state="normal") - if self.using_real_hardware: - self.connect_button.configure(state="normal") - # Still connected, just not streaming - if self.is_connected: - device_name = self.stream.device_info.get('device', 'ESP32') if self.stream and self.stream.device_info else 'ESP32' - self._update_connection_status("green", f"Connected ({device_name})") - - def _on_source_change(self): - """Show/hide port selection based on data source.""" - # Clean up any existing connection/stream when switching modes - if self.is_connected and self.stream: - try: - self.stream.disconnect() - except: - pass - - self.is_connected = False - 
self.stream = None - - if self.source_var.get() == "real": - self.port_frame.pack(fill="x", pady=(5, 0)) - self._refresh_ports() - self.connect_button.configure(text="Connect", state="normal") - self._update_connection_status("gray", "Disconnected") - # Start button will be enabled after connection - else: - self.port_frame.pack_forget() - self._update_connection_status("gray", "Not using hardware") - self.connect_button.configure(state="disabled") - def _refresh_ports(self): """Scan and populate available serial ports.""" ports = serial.tools.list_ports.comports() @@ -1899,93 +2944,156 @@ class PredictionPage(BasePage): self._update_connection_status("gray", "Disconnected") self.connect_button.configure(text="Connect") - def toggle_prediction(self): - """Start or stop prediction.""" - if self.is_predicting: - self.stop_prediction() - else: - self.start_prediction() - def start_prediction(self): - """Start live prediction.""" - # Determine mode - self.using_real_hardware = (self.source_var.get() == "real") - - if self.using_real_hardware: - if not self.is_connected or not self.stream: - messagebox.showerror("Not Connected", "Please connect to ESP32 first.") - return - - print("[DEBUG] Starting Edge Prediction (On-Device)...") - try: - # Use the new interface method to start prediction - if hasattr(self.stream, 'start_predict'): - self.stream.start_predict() - self.stream.running = True - else: - # Fallback for simulated stream or older interface - # Simulated stream doesn't need 'start_predict', just 'start' - if not self.using_real_hardware: - self.stream.start() - else: - raise RuntimeError("Stream object missing start_predict method") - - except Exception as e: - messagebox.showerror("Start Error", f"Failed to start: {e}") - return + """Start live prediction (dispatches based on inference mode).""" + # Must be connected to ESP32 + if not self.is_connected or not self.stream: + messagebox.showerror("Not Connected", "Please connect to ESP32 first.") + return + if 
self.inference_mode == "ESP32": + self._start_esp32_prediction() else: - # Simulated - use PC-side inference - self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) - self.stream.start() + self._start_laptop_prediction() - # Load model for PC-side (Simulated) OR for display (optional) - # Even for Edge, we might want the label list. - if not self.using_real_hardware: - if not self.classifier: - model_path = EMGClassifier.get_default_model_path() - if model_path.exists(): - self.classifier = EMGClassifier.load(model_path) - self.model_label.configure(text="Model: Loaded", text_color="green") - else: - self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange") - - # Reset smoother - self.smoother = PredictionSmoother( - label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"], - probability_smoothing=0.7, - majority_vote_window=5, - debounce_count=3 - ) + def _start_esp32_prediction(self): + """Start on-device inference (ESP32 runs LDA internally).""" + print("[DEBUG] Starting ESP32 Prediction (On-Device)...") + try: + self.stream.start_predict() + self.stream.running = True + except Exception as e: + messagebox.showerror("Start Error", f"Failed to start ESP32 prediction: {e}") + return self.is_predicting = True self.start_button.configure(text="Stop Prediction", fg_color="red") - - # Start display loop - self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True) + self.connect_button.configure(state="disabled") + self.mode_selector.configure(state="disabled") + self.smoothing_info_label.configure( + text="Smoothing: ESP32 firmware (EMA + Majority + Debounce)" + ) + self.sim_label.configure(text="[ESP32 On-Device Inference]") + self.raw_label.configure(text="") + + self.prediction_thread = threading.Thread( + target=self._esp32_prediction_loop, daemon=True + ) + self.prediction_thread.start() + self.update_prediction_ui() 
+ + def _start_laptop_prediction(self): + """Start laptop-side inference (raw EMG stream + Python multi-model voting).""" + print("[DEBUG] Starting Laptop Prediction...") + + # Prefer calibrated classifier from CalibrationPage if available + app = self.winfo_toplevel() + if isinstance(app, EMGApp) and app.calibrated_classifier is not None: + self.classifier = app.calibrated_classifier + print( + f"[Prediction] Using calibrated {self.classifier.model_type.upper()} " + f"classifier (session-aligned)" + ) + else: + # Fall back to loading the user-selected model from disk + model_path = self._get_selected_model_path() + if not model_path: + messagebox.showerror( + "No Model", + "No saved model found. Train a model first!\n\n" + "Tip: run '4. Calibrate' before predicting for better cross-session accuracy.", + ) + return + print(f"[DEBUG] Loading model: {model_path.name}") + try: + self.classifier = EMGClassifier.load(model_path) + except Exception as e: + messagebox.showerror("Model Error", f"Failed to load model: {e}") + return + + # Load ensemble model if available + self._ensemble = None + ensemble_path = Path(__file__).parent / 'models' / 'emg_ensemble.joblib' + if ensemble_path.exists(): + try: + import joblib + self._ensemble = joblib.load(ensemble_path) + print(f"[Prediction] Loaded ensemble model (4 LDAs)") + except Exception as e: + print(f"[Prediction] Ensemble load failed: {e}") + + # Load MLP weights if available + self._mlp = None + mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz' + if mlp_path.exists(): + try: + self._mlp = dict(np.load(mlp_path, allow_pickle=True)) + print(f"[Prediction] Loaded MLP weights (numpy)") + except Exception as e: + print(f"[Prediction] MLP load failed: {e}") + + # Report active models + model_names = [self.classifier.model_type.upper()] + if self._ensemble: + model_names.append("Ensemble") + if self._mlp: + model_names.append("MLP") + print(f"[Prediction] Active models: {' + '.join(model_names)} 
({len(model_names)} total)")
+
+        # Create smoother
+        self.smoother = PredictionSmoother(
+            label_names=self.classifier.label_names,
+            probability_smoothing=0.7,
+            majority_vote_window=5,
+            debounce_count=4,
+        )
+
+        # Start raw EMG streaming from ESP32
+        try:
+            self.stream.start()
+            self.stream.running = True
+        except Exception as e:
+            messagebox.showerror("Start Error", f"Failed to start raw streaming: {e}")
+            return
+
+        self.is_predicting = True
+        self.start_button.configure(text="Stop Prediction", fg_color="red")
+        self.connect_button.configure(state="disabled")
+        self.mode_selector.configure(state="disabled")
+        self.smoothing_info_label.configure(
+            text="Smoothing: Python (EMA 0.7 + Majority 5 + Debounce 4)"
+        )
+        calib_active = self.classifier.calibration_transform.is_fitted
+        mode_str = (
+            f"[Laptop — {' + '.join(model_names)}"
+            f"{' + Calibration' if calib_active else ''}]"
+        )
+        self.sim_label.configure(text=mode_str)
+
+        self.prediction_thread = threading.Thread(
+            target=self._laptop_prediction_loop, daemon=True
+        )
         self.prediction_thread.start()
-        self.update_prediction_ui()

     def stop_prediction(self):
-        """Stop prediction."""
+        """Stop prediction (either mode)."""
         self.is_predicting = False
         if self.stream:
-            self.stream.stop() # Sends "stop" usually
-            if not self.using_real_hardware:
-                self.stream = None
-
+            self.stream.stop()
+
         self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
         self.prediction_label.configure(text="---", text_color="gray")
         self.confidence_label.configure(text="Confidence: ---%")
         self.confidence_bar.set(0)
+        self.connect_button.configure(state="normal")
+        self.mode_selector.configure(state="normal")
+        self.sim_label.configure(text="")
+        self.raw_label.configure(text="")

-    def prediction_loop(self):
-        """Loop for reading data and (optionally) running inference."""
+    def _esp32_prediction_loop(self):
+        """Read JSON predictions from ESP32 on-device inference."""
         import json
-
-        parser = 
EMGParser(num_channels=NUM_CHANNELS) - windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0) while self.is_predicting: try: @@ -1993,43 +3101,124 @@ class PredictionPage(BasePage): if not line: continue - if self.using_real_hardware: - # Edge Inference Mode: Expect JSON - try: - line = line.strip() - if line.startswith('{'): - data = json.loads(line) - - if "gesture" in data: - # Update UI with Edge Prediction - gesture = data["gesture"] - conf = float(data.get("conf", 0.0)) - - self.data_queue.put(('prediction', (gesture, conf))) - - elif "status" in data: - print(f"[ESP32] {data}") - else: - pass - - except json.JSONDecodeError: - pass - - else: - # PC Side Inference (Simulated) - sample = parser.parse_line(line) - if sample: - window = windower.add_sample(sample) - if window and self.classifier: - # Run Inference Local - raw_label, proba = self.classifier.predict(window.to_numpy()) - label, conf, _ = self.smoother.update(raw_label, proba) - - self.data_queue.put(('prediction', (label, conf))) + try: + line = line.strip() + if line.startswith('{'): + data = json.loads(line) + + if "gesture" in data: + gesture = data["gesture"] + conf = float(data.get("conf", 0.0)) + self.data_queue.put(('prediction', (gesture, conf))) + + elif "status" in data: + print(f"[ESP32] {data}") + + except json.JSONDecodeError: + pass except Exception as e: if self.is_predicting: - print(f"Prediction loop error: {e}") + print(f"ESP32 prediction loop error: {e}") + self.data_queue.put(('error', f"ESP32 error: {e}")) + break + + def _run_ensemble(self, features: np.ndarray) -> np.ndarray: + """Run ensemble prediction: 3 specialist LDAs → meta-LDA → probabilities.""" + ens = self._ensemble + x_td = features[ens['td_idx']] + x_fd = features[ens['fd_idx']] + x_cc = features[ens['cc_idx']] + p_td = ens['lda_td'].predict_proba([x_td])[0] + p_fd = ens['lda_fd'].predict_proba([x_fd])[0] + p_cc = ens['lda_cc'].predict_proba([x_cc])[0] + x_meta = 
np.concatenate([p_td, p_fd, p_cc]) + return ens['meta_lda'].predict_proba([x_meta])[0] + + def _run_mlp(self, features: np.ndarray) -> np.ndarray: + """Run MLP forward pass: Dense(32,relu) → Dense(16,relu) → Dense(5,softmax).""" + m = self._mlp + x = features.astype(np.float32) + x = np.maximum(0, x @ m['w0'] + m['b0']) # relu + x = np.maximum(0, x @ m['w1'] + m['b1']) # relu + logits = x @ m['w2'] + m['b2'] # softmax + e = np.exp(logits - logits.max()) + return e / e.sum() + + def _laptop_prediction_loop(self): + """Parse raw EMG stream, window, extract features, multi-model vote.""" + parser = EMGParser(num_channels=NUM_CHANNELS) + windower = Windower( + window_size_ms=WINDOW_SIZE_MS, + sample_rate=SAMPLING_RATE_HZ, + hop_size_ms=HOP_SIZE_MS, + ) + + while self.is_predicting: + try: + line = self.stream.readline() + if not line: + continue + + sample = parser.parse_line(line) + if sample is None: + continue + + window = windower.add_sample(sample) + if window is None: + continue + + window_data = window.to_numpy() + + # --- Base LDA prediction (includes energy gate + calibration) --- + raw_label, proba_lda = self.classifier.predict(window_data) + + # If energy gate triggered rest, skip other models + rest_gated = (raw_label == "rest" and proba_lda.max() == 1.0) + + if rest_gated: + avg_proba = proba_lda + else: + # Extract calibrated features for ensemble/MLP + features_raw = self.classifier.feature_extractor.extract_features_window(window_data) + features = self.classifier.calibration_transform.apply(features_raw) + + probas = [proba_lda] + + # --- Ensemble --- + if self._ensemble: + try: + probas.append(self._run_ensemble(features)) + except Exception: + pass + + # --- MLP --- + if self._mlp: + try: + probas.append(self._run_mlp(features)) + except Exception: + pass + + avg_proba = np.mean(probas, axis=0) + + raw_label = self.classifier.label_names[int(np.argmax(avg_proba))] + + # Apply smoothing + smoothed_label, smoothed_conf, _debug = 
self.smoother.update(raw_label, avg_proba) + + self.data_queue.put(('prediction', (smoothed_label, smoothed_conf))) + + # Show raw vs smoothed mismatch + if raw_label != smoothed_label: + self.data_queue.put(('raw_info', f"raw: {raw_label}")) + else: + self.data_queue.put(('raw_info', "")) + + except Exception as e: + if self.is_predicting: + import traceback + traceback.print_exc() + self.data_queue.put(('error', f"Prediction error: {e}")) break def update_prediction_ui(self): @@ -2040,28 +3229,27 @@ class PredictionPage(BasePage): if msg_type == 'prediction': label, conf = data - + # Update label self.prediction_label.configure( text=label.upper(), text_color=get_gesture_color(label) ) - + # Update confidence self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%") self.confidence_bar.set(conf) - - # Clear raw label since we don't have raw vs smooth distinction in edge mode - # (or we could expose it if we updated the C struct, but for now keep it simple) - self.raw_label.configure(text="", text_color="gray") + + elif msg_type == 'raw_info': + # Show raw vs smoothed mismatch (laptop mode only) + self.raw_label.configure(text=data, text_color="orange" if data else "gray") elif msg_type == 'sim_gesture': self.sim_label.configure(text=f"[Simulating: {data}]") elif msg_type == 'error': # Show error and stop prediction - if self.using_real_hardware: - self._update_connection_status("red", "Disconnected") + self._update_connection_status("red", "Disconnected") messagebox.showerror("Prediction Error", data) self.stop_prediction() return @@ -2155,14 +3343,29 @@ class VisualizationPage(BasePage): """Generate plots in background.""" try: storage = SessionStorage() - X, y, label_names, _ = storage.load_all_for_training() + X, y, _trial_ids, session_indices, label_names, _ = storage.load_all_for_training() self.after(0, lambda: self.status_label.configure(text="Extracting features...")) - # Extract features and train LDA - extractor = EMGFeatureExtractor() + # 
Extract features matching the training pipeline + extractor = EMGFeatureExtractor( + channels=HAND_CHANNELS, expanded=True, + cross_channel=True, bandpass=True, + ) X_features = extractor.extract_features_batch(X) + # Apply per-session z-score normalization (matches training pipeline) + for sid in np.unique(session_indices): + mask = session_indices == sid + X_sess = X_features[mask] + y_sess = y[mask] + class_means = [X_sess[y_sess == c].mean(axis=0) + for c in np.unique(y_sess)] + balanced_mean = np.mean(class_means, axis=0) + std = X_sess.std(axis=0) + std[std < 1e-12] = 1.0 + X_features[mask] = (X_sess - balanced_mean) / std + lda = LinearDiscriminantAnalysis() lda.fit(X_features, y) X_lda = lda.transform(X_features) diff --git a/collected_data/latency_fix_100_20260127_184804.hdf5 b/extra_data/latency_fix_100_20260127_184804.hdf5 similarity index 100% rename from collected_data/latency_fix_100_20260127_184804.hdf5 rename to extra_data/latency_fix_100_20260127_184804.hdf5 diff --git a/collected_data/latency_fix_101_20260127_185344.hdf5 b/extra_data/latency_fix_101_20260127_185344.hdf5 similarity index 100% rename from collected_data/latency_fix_101_20260127_185344.hdf5 rename to extra_data/latency_fix_101_20260127_185344.hdf5 diff --git a/collected_data/latency_fix_102_20260127_190022.hdf5 b/extra_data/latency_fix_102_20260127_190022.hdf5 similarity index 100% rename from collected_data/latency_fix_102_20260127_190022.hdf5 rename to extra_data/latency_fix_102_20260127_190022.hdf5 diff --git a/collected_data/latency_fix_103_20260127_191249.hdf5 b/extra_data/latency_fix_103_20260127_191249.hdf5 similarity index 100% rename from collected_data/latency_fix_103_20260127_191249.hdf5 rename to extra_data/latency_fix_103_20260127_191249.hdf5 diff --git a/collected_data/latency_fix_104_20260127_195150.hdf5 b/extra_data/latency_fix_104_20260127_195150.hdf5 similarity index 100% rename from collected_data/latency_fix_104_20260127_195150.hdf5 rename to 
extra_data/latency_fix_104_20260127_195150.hdf5 diff --git a/collected_data/latency_fix_105_20260127_195503.hdf5 b/extra_data/latency_fix_105_20260127_195503.hdf5 similarity index 100% rename from collected_data/latency_fix_105_20260127_195503.hdf5 rename to extra_data/latency_fix_105_20260127_195503.hdf5 diff --git a/collected_data/new_placements_000_20260127_174231.hdf5 b/extra_data/new_placements_000_20260127_174231.hdf5 similarity index 100% rename from collected_data/new_placements_000_20260127_174231.hdf5 rename to extra_data/new_placements_000_20260127_174231.hdf5 diff --git a/learning_data_collection.py b/learning_data_collection.py index 96281e6..3aecf84 100644 --- a/learning_data_collection.py +++ b/learning_data_collection.py @@ -4,18 +4,26 @@ EMG Data Collection Pipeline A complete pipeline for collecting, labeling, and classifying EMG signals. OPTIONS: - 1. Collect Data - Run a labeled collection session with timed prompts + 1. Collect Data - Run a labeled collection session with timed prompts (requires ESP32) 2. Inspect Data - Load saved sessions, view raw EMG and features 3. Train Classifier - Train LDA on collected data with cross-validation + 4. Live Prediction - Real-time gesture classification (requires ESP32) + 5. Visualize LDA - Decision boundaries and feature space plots + 6. Benchmark - Compare LDA/QDA/SVM/MLP classifiers q. 
Quit FEATURES: - - Simulated EMG stream (swap for serial.Serial with real hardware) + - Real-time EMG acquisition via ESP32 serial interface - Timed prompt system for consistent data collection - - Automatic labeling based on prompt timing + - Automatic labeling based on prompt timing with onset detection - HDF5 storage with metadata - Time-domain feature extraction (RMS, WL, ZC, SSC) - LDA classifier with evaluation metrics + - Prediction smoothing (EMA + majority vote + debounce) + +HARDWARE REQUIRED: + - ESP32 with EMG sensors connected and firmware flashed + - USB serial connection (921600 baud) """ import time @@ -28,11 +36,13 @@ from pathlib import Path from datetime import datetime import json import h5py -from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from sklearn.model_selection import cross_val_score, train_test_split +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis +from sklearn.model_selection import cross_val_score, train_test_split, cross_val_predict, GroupShuffleSplit, GroupKFold from sklearn.metrics import classification_report, confusion_matrix import joblib # For model persistence import matplotlib.pyplot as plt +from scipy.signal import butter, sosfiltfilt, sosfilt, sosfilt_zi # For label alignment + bandpass +from serial_stream import RealSerialStream # ESP32 serial communication # ============================================================================= # CONFIGURATION @@ -41,14 +51,27 @@ NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors) SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog -# Windowing configuration -WINDOW_SIZE_MS = 150 # Window size in milliseconds -WINDOW_OVERLAP = 0.0 # Overlap ratio (0.0 = no overlap, 0.5 = 50% overlap) +# Windowing configuration (must match ESP32 inference timing) +WINDOW_SIZE_MS = 150 # Window size in milliseconds (150 samples at 1kHz) 
+HOP_SIZE_MS = 25 # Hop/stride in milliseconds (25 samples at 1kHz) +MAJORITY_WINDOW = 10 + +# Hand classifier channel selection +# The hand gesture classifier uses only forearm channels (ch0-ch2). +# The bicep channel (ch3) is excluded to prevent bicep activity from +# corrupting hand gesture classification. Ch3 is reserved for independent +# bicep envelope processing (see Phase 5). +HAND_CHANNELS = [0, 1, 2] # Forearm channels only (excludes bicep ch3) # Labeling configuration GESTURE_HOLD_SEC = 3.0 # How long to hold each gesture REST_BETWEEN_SEC = 2.0 # Rest period between gestures REPS_PER_GESTURE = 3 # Repetitions per gesture in a session +LABEL_SHIFT_MS = 150 # Shift label lookup forward by this many ms to account + # for human reaction time. A 150ms window labelled at its + # start_time can straddle a prompt transition; using + # start_time + shift assigns the label based on what the + # user is actually doing at the window's centre. # Storage configuration DATA_DIR = Path("collected_data") # Directory to store session files @@ -64,6 +87,10 @@ USER_ID = "user_001" # Current user ID (change per user) ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std ONSET_SEARCH_MS = 2000 # Search window after prompt (ms) +# Change 0: after onset detection shifts the label start backward, additionally +# relabel the first LABEL_FORWARD_SHIFT_MS of each gesture run as "rest" to skip +# the EMG transient at gesture onset. Paired with reducing TRANSITION_START_MS. +LABEL_FORWARD_SHIFT_MS = 100 # ms of each gesture onset to relabel as rest # ============================================================================= # TRANSITION WINDOW FILTERING @@ -73,7 +100,7 @@ ONSET_SEARCH_MS = 2000 # Search window after prompt (ms) # This is standard practice in EMG research (see Frontiers Neurorobotics 2023). 
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training -TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts +TRANSITION_START_MS = 200 # Discard windows within this time AFTER gesture starts TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends # ============================================================================= @@ -85,7 +112,9 @@ class EMGSample: """Single sample from all channels at one point in time.""" timestamp: float # Python-side timestamp (seconds, monotonic) channels: list[float] # Raw ADC values per channel - esp_timestamp_ms: Optional[int] = None # Optional: timestamp from ESP32 + # DEPRECATED: esp_timestamp_ms is no longer used. Python-side timestamps are used + # for label alignment. Kept for backward compatibility with old serialized data. + esp_timestamp_ms: Optional[int] = None @dataclass @@ -111,80 +140,6 @@ class EMGWindow: return np.array([s.channels[ch] for s in self.samples]) -# ============================================================================= -# SIMULATED EMG STREAM (Mimics what ESP32 would send over serial) -# ============================================================================= - -class SimulatedEMGStream: - """ - Simulates ESP32 sending EMG data over serial. 
- - In reality, you'd replace this with: - import serial - ser = serial.Serial('COM3', 115200) - line = ser.readline() - - The ESP32 would send lines like: - "1234,512,489,501,523\n" (timestamp_ms, ch0, ch1, ch2, ch3) - """ - - def __init__(self, num_channels: int = 4, sample_rate: int = 1000): - self.num_channels = num_channels - self.sample_rate = sample_rate - self.running = False - self.output_queue = queue.Queue() - self._thread = None - self._esp_time_ms = 0 - - def start(self): - """Start the simulated data stream.""" - self.running = True - self._thread = threading.Thread(target=self._generate_data, daemon=True) - self._thread.start() - print(f"[SIM] Started simulated EMG stream: {self.num_channels} channels @ {self.sample_rate}Hz") - - def stop(self): - """Stop the simulated data stream.""" - self.running = False - if self._thread: - self._thread.join(timeout=1.0) - print("[SIM] Stopped simulated EMG stream") - - def readline(self) -> Optional[str]: - """ - Mimics serial.readline() - blocks until data available. 
- Returns line in format: "timestamp_ms,ch0,ch1,ch2,ch3" - """ - try: - return self.output_queue.get(timeout=1.0) - except queue.Empty: - return None - - def _generate_data(self): - """Background thread that generates fake EMG data.""" - interval = 1.0 / self.sample_rate - - while self.running: - # Simulate EMG signal: baseline noise + occasional bursts - channels = [] - for ch in range(self.num_channels): - # Base noise (typical ADC noise around 512 center for 10-bit ADC) - base = 512 + np.random.randn() * 10 - - # Occasionally add muscle activation burst (simulates gesture) - if np.random.random() < 0.01: # 1% chance each sample - base += np.random.randn() * 100 - - channels.append(int(np.clip(base, 0, 1023))) - - # Format like ESP32 would send - line = f"{self._esp_time_ms},{','.join(map(str, channels))}\n" - self.output_queue.put(line) - - self._esp_time_ms += 1 # Increment ESP32 timestamp - time.sleep(interval) - - # ============================================================================= # DATA PARSER (Converts serial lines to EMGSample objects) # ============================================================================= @@ -228,7 +183,7 @@ class EMGParser: sample = EMGSample( timestamp=time.perf_counter(), # High-resolution monotonic clock channels=channels, - esp_timestamp_ms=None # No longer using ESP32 timestamp + esp_timestamp_ms=None # Deprecated field, kept for compatibility ) self.samples_parsed += 1 @@ -256,27 +211,36 @@ class Windower: - Sweet spot: 150-250ms for EMG gesture recognition """ - def __init__(self, window_size_ms: int, sample_rate: int, overlap: float = 0.0): + def __init__(self, window_size_ms: int, sample_rate: int, hop_size_ms: int = 25): self.window_size_ms = window_size_ms self.sample_rate = sample_rate - self.overlap = overlap + self.hop_size_ms = hop_size_ms - # Calculate window size in samples + # Calculate window and step size in samples (hop-based, not overlap-based) self.window_size_samples = int(window_size_ms / 1000 
* sample_rate) - self.step_size_samples = int(self.window_size_samples * (1 - overlap)) + self.step_size_samples = int(hop_size_ms / 1000 * sample_rate) # Buffer for incoming samples self.buffer: list[EMGSample] = [] self.window_count = 0 + # Verification: Print first 10 window start indices and timestamps + self._verification_printed = False + print(f"[Windower] Window: {window_size_ms}ms = {self.window_size_samples} samples") - print(f"[Windower] Step: {self.step_size_samples} samples (overlap={overlap*100:.0f}%)") + print(f"[Windower] Hop: {hop_size_ms}ms = {self.step_size_samples} samples") def add_sample(self, sample: EMGSample) -> Optional[EMGWindow]: """ Add a sample to the buffer. Returns a window if we have enough samples. Returns None if buffer isn't full yet. + + Window timing (at 1kHz): + - Window 0: samples 0-149, start index 0, time 0.000s + - Window 1: samples 25-174, start index 25, time 0.025s + - Window 2: samples 50-199, start index 50, time 0.050s + - ... """ self.buffer.append(sample) @@ -290,6 +254,16 @@ class Windower: end_time=window_samples[-1].timestamp, samples=window_samples.copy() ) + + # Verification: Print first 10 window start indices and timestamps + if not self._verification_printed and self.window_count < 10: + start_idx = self.window_count * self.step_size_samples + start_time_sec = start_idx / self.sample_rate + print(f"[Windower] Window {self.window_count}: start_idx={start_idx}, time={start_time_sec:.3f}s") + if self.window_count == 9: + self._verification_printed = True + print(f"[Windower] Verified: 150-sample windows, {self.step_size_samples}-sample hop") + self.window_count += 1 # Slide buffer by step size @@ -323,6 +297,7 @@ class GesturePrompt: gesture_name: str # e.g., "index_flex", "rest", "fist" duration_sec: float # How long to hold this gesture start_time: float = 0.0 # Filled in by scheduler when session starts + trial_id: int = -1 # Unique ID for this trial (gesture repetition) @dataclass @@ -369,18 +344,24 @@ 
class PromptScheduler: self.session_start_time: Optional[float] = None def _build_schedule(self) -> PromptSchedule: - """Create the sequence of prompts.""" + """Create the sequence of prompts with unique trial_ids.""" prompts = [] + trial_counter = 0 - # Initial rest period - prompts.append(GesturePrompt("rest", self.rest_sec)) + # Initial rest period (trial_id = 0) + prompts.append(GesturePrompt("rest", self.rest_sec, trial_id=trial_counter)) + trial_counter += 1 # For each repetition for rep in range(self.reps): # Cycle through all gestures for gesture in self.gestures: - prompts.append(GesturePrompt(gesture, self.hold_sec)) - prompts.append(GesturePrompt("rest", self.rest_sec)) + # Gesture trial + prompts.append(GesturePrompt(gesture, self.hold_sec, trial_id=trial_counter)) + trial_counter += 1 + # Rest trial (each rest is its own trial to avoid leakage) + prompts.append(GesturePrompt("rest", self.rest_sec, trial_id=trial_counter)) + trial_counter += 1 return PromptSchedule(prompts) @@ -433,6 +414,25 @@ class PromptScheduler: return "unlabeled" + def get_trial_id_for_time(self, timestamp: float) -> int: + """ + Get the trial_id for a specific timestamp. + + Each gesture repetition has a unique trial_id. Windows from the same + trial MUST stay together during train/test splitting to prevent leakage. 
+ """ + if self.session_start_time is None: + return -1 + + elapsed = timestamp - self.session_start_time + + for prompt in self.schedule.prompts: + prompt_end = prompt.start_time + prompt.duration_sec + if prompt.start_time <= elapsed < prompt_end: + return prompt.trial_id + + return -1 + def print_schedule(self): """Print the full prompt schedule.""" print("\n" + "-" * 40) @@ -443,76 +443,10 @@ class PromptScheduler: print(f"\n Total duration: {self.schedule.total_duration:.1f}s") -# ============================================================================= -# SIMULATED EMG STREAM (Gesture-aware signal generation) -# ============================================================================= - -class GestureAwareEMGStream(SimulatedEMGStream): - """ - Enhanced simulation that generates different EMG patterns based on - which gesture is currently being prompted. - - This makes the simulated data more realistic for testing your pipeline. - Each gesture activates different "muscles" (channels) with different intensities. 
- """ - - # Define which channels activate for each gesture (0-1 intensity per channel) - GESTURE_PATTERNS = { - "rest": [0.0, 0.0, 0.0, 0.0], - "open": [0.3, 0.3, 0.3, 0.3], # Moderate all channels (extension) - "fist": [0.7, 0.7, 0.6, 0.6], # All channels active (flexion) - "hook_em": [0.8, 0.2, 0.7, 0.1], # Index + pinky extended (ch0 + ch2) - "thumbs_up": [0.1, 0.1, 0.2, 0.8], # Thumb dominant (ch3) - } - - def __init__(self, num_channels: int = 4, sample_rate: int = 1000): - super().__init__(num_channels, sample_rate) - self.current_gesture = "rest" - self._gesture_lock = threading.Lock() - - def set_gesture(self, gesture: str): - """Set the current gesture being performed.""" - with self._gesture_lock: - self.current_gesture = gesture - - def _generate_data(self): - """Generate EMG data based on current gesture.""" - interval = 1.0 / self.sample_rate - - while self.running: - with self._gesture_lock: - gesture = self.current_gesture - - # Get activation pattern for current gesture - pattern = self.GESTURE_PATTERNS.get(gesture, [0.0] * self.num_channels) - - channels = [] - for ch in range(self.num_channels): - # Base signal around 512 (10-bit ADC center) - base = 512 - - # Add noise (always present) - noise = np.random.randn() * 10 - - # Add muscle activation based on pattern - activation = pattern[ch] * np.random.randn() * 150 # Scaled EMG burst - - # Combine and clip to ADC range - value = int(np.clip(base + noise + activation, 0, 1023)) - channels.append(value) - - # Format like ESP32 would send - line = f"{self._esp_time_ms},{','.join(map(str, channels))}\n" - self.output_queue.put(line) - - self._esp_time_ms += 1 - time.sleep(interval) - - # ============================================================================= # LABEL ALIGNMENT (Simple Onset Detection) # ============================================================================= -from scipy.signal import butter, sosfiltfilt +# NOTE: butter and sosfiltfilt imported at top of file def 
align_labels_with_onset( @@ -608,9 +542,10 @@ def filter_transition_windows( labels: list[str], start_times: np.ndarray, end_times: np.ndarray, + trial_ids: Optional[np.ndarray] = None, transition_start_ms: float = TRANSITION_START_MS, transition_end_ms: float = TRANSITION_END_MS -) -> tuple[np.ndarray, np.ndarray, list[str]]: +) -> tuple[np.ndarray, np.ndarray, list[str], Optional[np.ndarray]]: """ Filter out windows that fall within transition zones at gesture boundaries. @@ -624,14 +559,15 @@ def filter_transition_windows( labels: String labels (n_windows,) start_times: Window start times in seconds (n_windows,) end_times: Window end times in seconds (n_windows,) + trial_ids: Trial IDs for train/test splitting (n_windows,) - optional transition_start_ms: Discard windows within this time after gesture start transition_end_ms: Discard windows within this time before gesture end Returns: - Filtered (X, y, labels) with transition windows removed + Filtered (X, y, labels, trial_ids) with transition windows removed """ if len(X) == 0: - return X, y, labels + return X, y, labels, trial_ids transition_start_sec = transition_start_ms / 1000.0 transition_end_sec = transition_end_ms / 1000.0 @@ -673,13 +609,14 @@ def filter_transition_windows( X_filtered = X[keep_mask] y_filtered = y[keep_mask] labels_filtered = [l for l, keep in zip(labels, keep_mask) if keep] + trial_ids_filtered = trial_ids[keep_mask] if trial_ids is not None else None n_removed = len(X) - len(X_filtered) if n_removed > 0: print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)") print(f"[Filter] Kept {len(X_filtered)} windows for training") - return X_filtered, y_filtered, labels_filtered + return X_filtered, y_filtered, labels_filtered, trial_ids_filtered # ============================================================================= @@ -718,6 +655,7 @@ class SessionStorage: windows: list[EMGWindow], labels: list[str], metadata: SessionMetadata, + trial_ids: 
Optional[list[int]] = None, raw_samples: Optional[list[EMGSample]] = None, session_start_time: Optional[float] = None, enable_alignment: bool = ENABLE_LABEL_ALIGNMENT @@ -733,6 +671,7 @@ class SessionStorage: windows: List of EMGWindow objects (no label info) labels: List of gesture labels, parallel to windows metadata: Session metadata + trial_ids: List of trial IDs, parallel to windows (for proper train/test splitting) raw_samples: Raw samples (required for alignment) session_start_time: When session started (required for alignment) enable_alignment: Whether to perform automatic label alignment @@ -774,6 +713,28 @@ class SessionStorage: changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b) print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted") + + # Change 0: relabel the first LABEL_FORWARD_SHIFT_MS of each gesture + # run as 'rest' to remove the EMG onset transient from training data. + if LABEL_FORWARD_SHIFT_MS > 0: + shift_n = max(1, round(LABEL_FORWARD_SHIFT_MS / HOP_SIZE_MS)) + shifted = list(aligned_labels) + for i in range(len(aligned_labels)): + if aligned_labels[i] == 'rest': + continue + # Count consecutive same-label windows immediately before this one + prior_same = 0 + j = i - 1 + while j >= 0 and aligned_labels[j] == aligned_labels[i]: + prior_same += 1 + j -= 1 + if prior_same < shift_n: + shifted[i] = 'rest' + n_shifted = sum(1 for a, b in zip(aligned_labels, shifted) if a != b) + aligned_labels = shifted + print(f"[Storage] Forward shift ({LABEL_FORWARD_SHIFT_MS}ms, " + f"{shift_n} windows): {n_shifted} relabeled as rest") + elif enable_alignment: print("[Storage] Warning: No raw samples, skipping alignment") @@ -807,6 +768,14 @@ class SessionStorage: window_ids = np.array([w.window_id for w in windows], dtype=np.int32) windows_grp.create_dataset('window_ids', data=window_ids) + # Store trial_ids for proper train/test splitting (no trial leakage) + if trial_ids is not None: + trial_ids_arr = np.array(trial_ids, 
dtype=np.int32) + windows_grp.create_dataset('trial_ids', data=trial_ids_arr) + f.attrs['has_trial_ids'] = True + else: + f.attrs['has_trial_ids'] = False + windows_grp.create_dataset('start_times', data=start_times) windows_grp.create_dataset('end_times', data=end_times) @@ -919,7 +888,7 @@ class SessionStorage: label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)} y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32) - X, y_pre, labels = filter_transition_windows( + X, y_pre, labels, _ = filter_transition_windows( X, y_pre, labels, start_times, end_times ) @@ -931,7 +900,7 @@ class SessionStorage: print(f"[Storage] Labels: {label_names}") return X, y, label_names - def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]: + def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, list[str], list[str]]: """ Load ALL sessions combined into a single training dataset. 
@@ -941,6 +910,8 @@ class SessionStorage: Returns: X: Combined EMG windows from all sessions (n_total_windows, samples, channels) y: Combined labels as integers (n_total_windows,) + trial_ids: Combined trial IDs for proper train/test splitting (n_total_windows,) + session_indices: Per-window session index (0..n_sessions-1) for session normalization label_names: Sorted list of unique gesture labels across all sessions session_ids: List of session IDs that were loaded @@ -958,10 +929,13 @@ class SessionStorage: all_X = [] all_labels = [] + all_trial_ids = [] # Track trial_ids for proper train/test splitting + all_session_indices = [] # Per-window session index for session normalization loaded_sessions = [] reference_shape = None total_removed = 0 total_original = 0 + trial_id_offset = 0 # Offset trial_ids across sessions to ensure global uniqueness for session_id in sessions: filepath = self.get_session_filepath(session_id) @@ -972,6 +946,15 @@ class SessionStorage: start_times = f['windows/start_times'][:] end_times = f['windows/end_times'][:] + # Load trial_ids if available (new files), otherwise generate from index + if 'windows/trial_ids' in f: + trial_ids = f['windows/trial_ids'][:] + trial_id_offset + else: + # Legacy file without trial_ids: assign unique trial_id per window + # This is conservative - treats each window as separate trial + print(f"[Storage] WARNING: {session_id} missing trial_ids, generating from indices") + trial_ids = np.arange(len(X), dtype=np.int32) + trial_id_offset + # Validate shape compatibility if reference_shape is None: reference_shape = X.shape[1:] # (samples_per_window, channels) @@ -997,14 +980,22 @@ class SessionStorage: temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)} temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32) - X, temp_y, labels = filter_transition_windows( - X, temp_y, labels, start_times, end_times + X, temp_y, labels, trial_ids = filter_transition_windows( + X, 
temp_y, labels, start_times, end_times, trial_ids=trial_ids ) total_removed += original_count - len(X) + current_session_idx = len(all_X) # 0-based index before appending all_X.append(X) all_labels.extend(labels) + all_trial_ids.extend(trial_ids.tolist()) + all_session_indices.extend([current_session_idx] * len(X)) loaded_sessions.append(session_id) + + # Update trial_id offset for next session (ensure global uniqueness) + if len(trial_ids) > 0: + trial_id_offset = max(trial_ids) + 1 + print(f"[Storage] - {session_id}: {len(X)} windows" + (f" (was {original_count})" if filter_transitions and len(X) != original_count else "")) @@ -1013,19 +1004,23 @@ class SessionStorage: # Combine all data X_combined = np.concatenate(all_X, axis=0) + trial_ids_combined = np.array(all_trial_ids, dtype=np.int32) + session_indices_combined = np.array(all_session_indices, dtype=np.int32) # Create unified label mapping across all sessions label_names = sorted(set(all_labels)) label_to_idx = {name: idx for idx, name in enumerate(label_names)} y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32) + n_unique_trials = len(np.unique(trial_ids_combined)) print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}") + print(f"[Storage] Unique trials: {n_unique_trials} (for proper train/test splitting)") if filter_transitions and total_removed > 0: print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)") print(f"[Storage] Labels: {label_names}") print(f"[Storage] Sessions loaded: {len(loaded_sessions)}") - return X_combined, y_combined, label_names, loaded_sessions + return X_combined, y_combined, trial_ids_combined, session_indices_combined, label_names, loaded_sessions def list_sessions(self) -> list[str]: """List all available session IDs.""" @@ -1046,102 +1041,23 @@ class SessionStorage: # ============================================================================= -# COLLECTION LOOP (Core 
collection pattern) -# ============================================================================= - -def run_collection_demo(duration_seconds: float = 5.0): - """ - Demonstrates the core data collection loop. - - This is the pattern you'll use with real hardware - only the - data source changes (SimulatedEMGStream -> serial.Serial). - """ - print("\n" + "=" * 60) - print("EMG DATA COLLECTION DEMO") - print("=" * 60) - - # Initialize components - stream = SimulatedEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) - parser = EMGParser(num_channels=NUM_CHANNELS) - - # Storage for collected samples - collected_samples: list[EMGSample] = [] - - # Start the stream - stream.start() - start_time = time.perf_counter() - - print(f"\nCollecting for {duration_seconds} seconds...") - print("(In real use, this would read from serial port)\n") - - try: - while (time.perf_counter() - start_time) < duration_seconds: - # Read line from stream (blocks briefly if no data) - line = stream.readline() - - if line: - # Parse into structured sample - sample = parser.parse_line(line) - - if sample: - collected_samples.append(sample) - - # Print progress every 500 samples - if len(collected_samples) % 500 == 0: - elapsed = time.perf_counter() - start_time - rate = len(collected_samples) / elapsed - print(f" Collected {len(collected_samples)} samples " - f"({rate:.1f} samples/sec)") - - except KeyboardInterrupt: - print("\n[Interrupted by user]") - - finally: - stream.stop() - - # Report results - print("\n" + "-" * 40) - print("COLLECTION RESULTS") - print("-" * 40) - print(f"Total samples collected: {len(collected_samples)}") - print(f"Parse errors: {parser.parse_errors}") - print(f"Actual duration: {time.perf_counter() - start_time:.2f}s") - - if collected_samples: - actual_rate = len(collected_samples) / (time.perf_counter() - start_time) - print(f"Effective sample rate: {actual_rate:.1f} Hz") - - # Show sample data structure - print("\nExample sample:") - s = 
collected_samples[0] - print(f" timestamp: {s.timestamp:.6f}") - print(f" channels: {s.channels}") - print(f" esp_timestamp_ms: {s.esp_timestamp_ms}") - - # Quick statistics - data = np.array([s.channels for s in collected_samples]) - print(f"\nChannel statistics (mean/std):") - for ch in range(NUM_CHANNELS): - print(f" Ch{ch}: {data[:, ch].mean():.1f} / {data[:, ch].std():.1f}") - - return collected_samples - - -# ============================================================================= -# COLLECTION SESSION +# COLLECTION SESSION (Requires ESP32 hardware) # ============================================================================= def run_labeled_collection_demo(): """ Run a labeled EMG collection session: - 1. Prompt scheduler guides the user through gestures - 2. EMG stream generates/collects signals - 3. Windower groups samples into fixed-size windows - 4. Labels are assigned based on which prompt was active - 5. Session is saved to HDF5 with user ID + 1. Connect to ESP32 via serial + 2. Prompt scheduler guides the user through gestures + 3. EMG stream collects real signals + 4. Windower groups samples into fixed-size windows + 5. Labels are assigned based on which prompt was active + 6. Session is saved to HDF5 with user ID + + REQUIRES: ESP32 hardware connected via USB. 
""" print("\n" + "=" * 60) - print("LABELED EMG COLLECTION") + print("LABELED EMG COLLECTION (ESP32 Required)") print("=" * 60) # Get user ID @@ -1164,18 +1080,28 @@ def run_labeled_collection_demo(): ) scheduler.print_schedule() - # Create components - stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) + # Connect to ESP32 + print("\n[Connecting to ESP32...]") + stream = RealSerialStream() + try: + stream.connect(timeout=5.0) + print(f" Connected: {stream.device_info}") + except Exception as e: + print(f" ERROR: Failed to connect to ESP32: {e}") + print(" Make sure the ESP32 is connected and firmware is flashed.") + return [], [] + parser = EMGParser(num_channels=NUM_CHANNELS) windower = Windower( window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, - overlap=WINDOW_OVERLAP + hop_size_ms=HOP_SIZE_MS ) - # Storage for windows and labels (kept separate to enforce training/inference separation) + # Storage for windows, labels, and trial_ids (kept separate to enforce training/inference separation) collected_windows: list[EMGWindow] = [] collected_labels: list[str] = [] + collected_trial_ids: list[int] = [] # Track trial_id for proper train/test splitting last_prompt_name = None # Start collection @@ -1201,9 +1127,6 @@ def run_labeled_collection_demo(): print(f"\n [{elapsed:5.1f}s] >>> {prompt.gesture_name.upper()} <<<") last_prompt_name = prompt.gesture_name - # Update simulated stream to generate appropriate signal - stream.set_gesture(prompt.gesture_name) - # Read and parse data line = stream.readline() if line: @@ -1212,16 +1135,21 @@ def run_labeled_collection_demo(): # Try to form a window window = windower.add_sample(sample) if window: - # Store window and label separately (training/inference separation) - label = scheduler.get_label_for_time(window.start_time) + # Store window, label, and trial_id separately (training/inference separation) + # Shift label lookup forward to align with actual muscle activation + 
label_time = window.start_time + LABEL_SHIFT_MS / 1000.0 + label = scheduler.get_label_for_time(label_time) + trial_id = scheduler.get_trial_id_for_time(label_time) collected_windows.append(window) collected_labels.append(label) + collected_trial_ids.append(trial_id) except KeyboardInterrupt: print("\n[Interrupted by user]") finally: stream.stop() + stream.disconnect() # Report results print("\n" + "=" * 60) @@ -1275,11 +1203,14 @@ def run_labeled_collection_demo(): notes="" ) - # Pass windows and labels separately (enforces separation) - filepath = storage.save_session(collected_windows, collected_labels, metadata) + # Pass windows, labels, and trial_ids separately (enforces separation) + filepath = storage.save_session( + collected_windows, collected_labels, metadata, + trial_ids=collected_trial_ids + ) print(f"\nSession saved! ID: {session_id}") - return collected_windows, collected_labels + return collected_windows, collected_labels, collected_trial_ids # ============================================================================= @@ -1405,18 +1336,19 @@ def run_storage_demo(): transitions.append((i, label_names[y[i]])) current_label = y[i] - # Define colors for gesture markers + # Define colors for gesture markers (matches GUI color scheme) def get_gesture_color(name): - if 'index' in name.lower(): - return 'green' - elif 'fist' in name.lower(): - return 'blue' - elif 'rest' in name.lower(): + name_lower = name.lower() + if 'rest' in name_lower: return 'gray' - elif 'thumb' in name.lower(): + elif 'open' in name_lower: + return 'cyan' + elif 'fist' in name_lower: + return 'blue' + elif 'hook' in name_lower: return 'orange' - elif 'middle' in name.lower(): - return 'purple' + elif 'thumb' in name_lower: + return 'green' return 'red' feature_titles = ['RMS', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)'] @@ -1505,142 +1437,591 @@ def run_storage_demo(): class EMGFeatureExtractor: """ - Extracts time-domain features from EMG windows. 
+ Extracts time-domain and frequency-domain features from EMG windows. - Features per channel: - - RMS (Root Mean Square): Signal power/amplitude - - WL (Waveform Length): Signal complexity - - ZC (Zero Crossings): Frequency content indicator - - SSC (Slope Sign Changes): Frequency content indicator + Change 1 — expanded feature set (expanded=True, default): + Per channel (20 features): + TD-4 (legacy): RMS, WL, ZC, SSC + TD extended: MAV, VAR, IEMG, WAMP + AR model: AR1, AR2, AR3, AR4 (4th-order via Yule-Walker) + Frequency: MNF, MDF, PKF, MNP (mean/median/peak freq, mean power) + Band power: BP0(20-80Hz), BP1(80-150Hz), BP2(150-250Hz), BP3(250-450Hz) + Cross-channel (cross_channel=True, default): + For each channel pair (i,j): Pearson correlation, log-RMS ratio, covariance + For 3 hand channels: 3 pairs × 3 = 9 cross-channel features + Total for HAND_CHANNELS=[0,1,2]: 20×3 + 9 = 69 features - These 4 features × N channels = 4N features per window. - For 4 channels: 16 features total. + Legacy mode (expanded=False): 4 features per channel only (RMS, WL, ZC, SSC). + Old pickled models automatically use legacy mode via __setstate__. - IMPORTANT: Per-window centering (DC offset removal) is applied before - feature extraction. This is critical because: - - EMG sensors have DC offset (e.g., ~512 for 10-bit ADC) - - Zero crossings require signal centered around 0 - - Per-window centering is causal (works in real-time inference) - - Global centering would leak information across windows - - LESSON: These features are: - - Fast to compute (good for real-time) - - Work well with LDA - - Proven effective for EMG gesture recognition + IMPORTANT: Per-window DC removal (mean subtraction) is applied before all + features. This is causal (uses only data within the current window). 
""" - def __init__(self, zc_threshold_percent: float = 0.1, ssc_threshold_percent: float = 0.1): + # Feature key ordering — determines output vector layout + _LEGACY_KEYS = ['rms', 'wl', 'zc', 'ssc'] + _EXPANDED_KEYS = [ + 'rms', 'wl', 'zc', 'ssc', # TD-4 + 'mav', 'var', 'iemg', 'wamp', # TD extended + 'ar1', 'ar2', 'ar3', 'ar4', # AR(4) model + 'mnf', 'mdf', 'pkf', 'mnp', # Frequency descriptors + 'bp0', 'bp1', 'bp2', 'bp3', # Band powers + ] + # Keys that are amplitude-dependent and should be divided by norm_factor + _NORM_KEYS = {'rms', 'wl', 'mav', 'iemg'} + + def __init__(self, + zc_threshold_percent: float = 0.1, + ssc_threshold_percent: float = 0.1, + channels: Optional[list[int]] = None, + normalize: bool = True, + expanded: bool = True, + cross_channel: bool = True, + fft_n: int = 256, + fs: float = float(SAMPLING_RATE_HZ), + reinhard: bool = False, + bandpass: bool = True): """ Args: - zc_threshold_percent: ZC threshold as fraction of signal RMS - ssc_threshold_percent: SSC threshold as fraction of signal RMS squared + zc_threshold_percent: ZC/WAMP threshold as fraction of RMS. + ssc_threshold_percent: SSC threshold as fraction of RMS squared. + channels: Channel indices to extract features from; None = all. + normalize: Divide amplitude-dependent features by total RMS across + channels (makes model robust to impedance shifts). + expanded: Use full 20-feature/channel set (Change 1). False = legacy + 4-feature/channel set for backward compatibility. + cross_channel: Append pairwise cross-channel features (correlation, + log-RMS ratio, covariance). Only when expanded=True. + fft_n: FFT size for frequency features (zero-pads window if needed). + fs: Sampling frequency in Hz (used for frequency axis). + reinhard: Change 4 — apply Reinhard tone-mapping (64·x/(32+|x|)) + before feature extraction. Must match MODEL_USE_REINHARD in + firmware model_weights.h. Default False. + bandpass: Apply 20-450 Hz bandpass filter before feature extraction. 
+ Must be True to match firmware IIR bandpass. Default True. """ - self.zc_threshold_percent = zc_threshold_percent + self.zc_threshold_percent = zc_threshold_percent self.ssc_threshold_percent = ssc_threshold_percent + self.channels = channels + self.normalize = normalize + self.expanded = expanded + self.cross_channel = cross_channel + self.fft_n = fft_n + self.fs = fs + self.reinhard = reinhard + self.bandpass = bandpass + + # Pre-compute bandpass SOS coefficients (2nd-order Butterworth, 20-450 Hz) + # Matches firmware IIR biquad bandpass in inference.c + if self.bandpass: + nyq = self.fs / 2.0 + self._bp_sos = butter(2, [20.0 / nyq, 450.0 / nyq], btype='band', output='sos') + else: + self._bp_sos = None + + def __setstate__(self, state: dict): + """Restore pickle and add defaults for attributes added in Change 1+.""" + self.__dict__.update(state) + if 'expanded' not in state: self.expanded = False + if 'cross_channel' not in state: self.cross_channel = False + if 'fft_n' not in state: self.fft_n = 256 + if 'fs' not in state: self.fs = float(SAMPLING_RATE_HZ) + if 'reinhard' not in state: self.reinhard = False + if 'bandpass' not in state: self.bandpass = False + # Reconstruct SOS coefficients for bandpass filter + if self.bandpass: + nyq = self.fs / 2.0 + self._bp_sos = butter(2, [20.0 / nyq, 450.0 / nyq], btype='band', output='sos') + else: + self._bp_sos = None + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + @staticmethod + def _ar_coefficients(signal: np.ndarray, order: int = 4) -> np.ndarray: + """4th-order AR coefficients via Yule-Walker (autocorrelation method).""" + n = len(signal) + r = np.array([float(np.dot(signal[:n - k], signal[k:])) / n + for k in range(order + 1)]) + T = np.array([[r[abs(i - j)] for j in range(order)] for i in range(order)]) + try: + return np.linalg.solve(T, r[1:order + 1]) + except np.linalg.LinAlgError: + 
return np.zeros(order) + + def _spectral_features(self, signal: np.ndarray) -> tuple: + """MNF, MDF, PKF, MNP, BP0-BP3 via rfft (zero-padded to fft_n).""" + spec = np.abs(np.fft.rfft(signal, n=self.fft_n)) ** 2 + freqs = np.fft.rfftfreq(self.fft_n, d=1.0 / self.fs) + total = float(np.sum(spec)) + 1e-10 + + mnf = float(np.dot(freqs, spec) / total) + + cumsum = np.cumsum(spec) + mid_idx = int(np.searchsorted(cumsum, total / 2.0)) + mdf = float(freqs[min(mid_idx, len(freqs) - 1)]) + + pkf = float(freqs[int(np.argmax(spec))]) + mnp = float(total / len(spec)) + + def _bp(f_lo: float, f_hi: float) -> float: + mask = (freqs >= f_lo) & (freqs < f_hi) + return float(np.sum(spec[mask]) / total) + + return mnf, mdf, pkf, mnp, _bp(20, 80), _bp(80, 150), _bp(150, 250), _bp(250, 450) + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ def extract_features_single_channel(self, signal: np.ndarray) -> dict: """ - Extract all features from a single channel signal. + Extract features from a single, already-selected channel. - Per-window centering is applied first to remove DC offset. - This is the standard approach for EMG feature extraction. + Returns a dict with 4 keys (legacy) or 20 keys (expanded). + Bandpass filter (if enabled) + per-window DC removal are applied first. """ - # Per-window centering: remove DC offset (critical for ZC/SSC) - # This only uses data from the current window (causal for real-time) - signal = signal - np.mean(signal) + # Bandpass filter to match firmware IIR (20-450 Hz, 2nd-order Butterworth). + # Uses sosfilt (causal) with sosfilt_zi to initialise the filter state + # at the signal's DC level, avoiding large startup transients on the + # short 150-sample windows. 
+ if self.bandpass and self._bp_sos is not None: + zi = sosfilt_zi(self._bp_sos) * signal[0] + signal, _ = sosfilt(self._bp_sos, signal, zi=zi) - # RMS - Root Mean Square (now measures AC power, not DC offset) - rms = np.sqrt(np.mean(signal ** 2)) + signal = signal - np.mean(signal) # DC removal - # WL - Waveform Length - wl = np.sum(np.abs(np.diff(signal))) + # Change 4 — Reinhard tone-mapping (compresses large spikes) + if self.reinhard: + signal = 64.0 * signal / (32.0 + np.abs(signal)) - # Dynamic thresholds based on signal RMS - zc_thresh = self.zc_threshold_percent * rms + rms = float(np.sqrt(np.mean(signal ** 2))) + wl = float(np.sum(np.abs(np.diff(signal)))) + + zc_thresh = self.zc_threshold_percent * rms ssc_thresh = (self.ssc_threshold_percent * rms) ** 2 - # ZC - Zero Crossings (with threshold to reject noise) - # Now meaningful because signal is centered around 0 - sign_changes = signal[:-1] * signal[1:] < 0 - amplitude_diff = np.abs(np.diff(signal)) > zc_thresh - zc = np.sum(sign_changes & amplitude_diff) + diffs = np.diff(signal) + sign_chg = signal[:-1] * signal[1:] < 0 + zc = int(np.sum(sign_chg & (np.abs(diffs) > zc_thresh))) - # SSC - Slope Sign Changes - diff_left = signal[1:-1] - signal[:-2] - diff_right = signal[1:-1] - signal[2:] - ssc = np.sum((diff_left * diff_right) > ssc_thresh) + dl = signal[1:-1] - signal[:-2] + dr = signal[1:-1] - signal[2:] + ssc = int(np.sum((dl * dr) > ssc_thresh)) - return {'rms': rms, 'wl': wl, 'zc': zc, 'ssc': ssc} + feats: dict = {'rms': rms, 'wl': wl, 'zc': float(zc), 'ssc': float(ssc)} + + if self.expanded: + mav = float(np.mean(np.abs(signal))) + var = float(np.var(signal)) + iemg = float(np.sum(np.abs(signal))) + wamp = int(np.sum(np.abs(diffs) > zc_thresh)) + + ar = self._ar_coefficients(signal, order=4) + mnf, mdf, pkf, mnp, bp0, bp1, bp2, bp3 = self._spectral_features(signal) + + feats.update({ + 'mav': mav, 'var': var, 'iemg': iemg, 'wamp': float(wamp), + 'ar1': float(ar[0]), 'ar2': float(ar[1]), + 'ar3': 
float(ar[2]), 'ar4': float(ar[3]), + 'mnf': mnf, 'mdf': mdf, 'pkf': pkf, 'mnp': mnp, + 'bp0': bp0, 'bp1': bp1, 'bp2': bp2, 'bp3': bp3, + }) + + return feats def extract_features_window(self, window: np.ndarray) -> np.ndarray: """ Extract features from a window of shape (samples, channels). - Returns flat array: [ch0_rms, ch0_wl, ch0_zc, ch0_ssc, ch1_rms, ...] + Returns a flat float32 array ordered as: + [ch_i feats..., ch_j feats..., ..., cross-channel feats...] """ - n_channels = window.shape[1] - features = [] + channel_indices = self.channels if self.channels is not None \ + else list(range(window.shape[1])) - for ch in range(n_channels): - ch_features = self.extract_features_single_channel(window[:, ch]) - features.extend([ch_features['rms'], ch_features['wl'], - ch_features['zc'], ch_features['ssc']]) + all_ch_feats = [self.extract_features_single_channel(window[:, ch]) + for ch in channel_indices] - return np.array(features) + norm_factor = 1.0 + if self.normalize: + total_rms = float(np.sqrt(sum(f['rms'] ** 2 for f in all_ch_feats))) + norm_factor = max(total_rms, 1e-6) + + feat_keys = self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS + + features: list[float] = [] + for ch_feats in all_ch_feats: + for key in feat_keys: + val = ch_feats[key] + if self.normalize and key in self._NORM_KEYS: + val = val / norm_factor + features.append(val) + + # Cross-channel features (expanded mode, ≥2 channels) + # Bug 6 fix: firmware computes cross-channel features from + # Reinhard-mapped signals when MODEL_USE_REINHARD=1. Apply the + # same tone-mapping here so correlation/covariance match. 
+        if self.expanded and self.cross_channel and len(channel_indices) >= 2:
+            centered = []
+            for ch in channel_indices:
+                sig = window[:, ch] - np.mean(window[:, ch])
+                # Apply bandpass if enabled.
+                # NOTE(review): order here is DC-removal → bandpass → Reinhard,
+                # but extract_features_single_channel runs bandpass → DC-removal
+                # → Reinhard. Confirm which order matches the firmware pipeline.
+                if self.bandpass and self._bp_sos is not None:
+                    zi = sosfilt_zi(self._bp_sos) * sig[0]
+                    sig, _ = sosfilt(self._bp_sos, sig, zi=zi)
+                if self.reinhard:
+                    sig = 64.0 * sig / (32.0 + np.abs(sig))
+                centered.append(sig)
+            rms_vals = [f['rms'] + 1e-10 for f in all_ch_feats]
+            n = window.shape[0]
+
+            for i in range(len(channel_indices)):
+                for j in range(i + 1, len(channel_indices)):
+                    si, sj = centered[i], centered[j]
+                    ri, rj = rms_vals[i], rms_vals[j]
+
+                    corr = float(np.clip(np.dot(si, sj) / (n * ri * rj), -1.0, 1.0))
+                    lrms = float(np.log(ri / rj))
+                    cov = float(np.dot(si, sj) / n)
+                    if self.normalize:
+                        cov /= (norm_factor ** 2)
+
+                    features.extend([corr, lrms, cov])
+
+        return np.array(features, dtype=np.float32)

     def extract_features_batch(self, X: np.ndarray) -> np.ndarray:
         """
-        Extract features from batch of windows.
+        Extract features from a batch of windows.

         Args:
-            X: Shape (n_windows, n_samples, n_channels)
-
+            X: (n_windows, n_samples, n_channels)
         Returns:
-            Features array of shape (n_windows, n_features)
-            where n_features = n_channels * 4
+            (n_windows, n_features) float32 array.
         """
-        n_windows = X.shape[0]
-        n_channels = X.shape[2]
-        n_features = n_channels * 4  # 4 features per channel
+        # Vectorised bandpass: apply sosfiltfilt on all windows at once along
+        # the samples axis. This is ~100x faster than per-window sosfilt calls
+        # (scipy's C loop vs Python loop). We disable per-window bandpass in
+        # extract_features_single_channel during batch extraction.
+        # NOTE(review): sosfiltfilt is zero-phase (non-causal), while the live
+        # per-window path uses causal sosfilt — training features will differ
+        # slightly from live/firmware features. Confirm this is acceptable.
+ if self.bandpass and self._bp_sos is not None: + X = sosfiltfilt(self._bp_sos, X, axis=1).astype(np.float32) - features = np.zeros((n_windows, n_features)) + n_windows = X.shape[0] + n_ch_total = X.shape[2] + n_features = self._n_features(n_ch_total) + features = np.zeros((n_windows, n_features), dtype=np.float32) - for i in range(n_windows): - features[i] = self.extract_features_window(X[i]) + # Temporarily disable per-window bandpass (already applied above) + saved_bp = self.bandpass + self.bandpass = False + try: + for i in range(n_windows): + features[i] = self.extract_features_window(X[i]) + finally: + self.bandpass = saved_bp return features - def get_feature_names(self, n_channels: int) -> list[str]: - """Get human-readable feature names.""" - names = [] - for ch in range(n_channels): - names.extend([f'ch{ch}_rms', f'ch{ch}_wl', f'ch{ch}_zc', f'ch{ch}_ssc']) + def _n_features(self, n_total_channels: int) -> int: + """Total feature vector length for the current configuration.""" + n_ch = len(self.channels) if self.channels is not None else n_total_channels + per_ch = len(self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS) + n = n_ch * per_ch + if self.expanded and self.cross_channel and n_ch >= 2: + n += 3 * (n_ch * (n_ch - 1) // 2) # 3 features × C(n_ch,2) pairs + return n + + def get_feature_names(self, n_channels: int = 0) -> list[str]: + """Human-readable feature names matching the extract_features_window layout.""" + channel_indices = self.channels if self.channels is not None \ + else list(range(n_channels)) + + feat_keys = self._EXPANDED_KEYS if self.expanded else self._LEGACY_KEYS + + names: list[str] = [] + for ch in channel_indices: + for key in feat_keys: + names.append(f'ch{ch}_{key}') + + if self.expanded and self.cross_channel and len(channel_indices) >= 2: + for i in range(len(channel_indices)): + for j in range(i + 1, len(channel_indices)): + ci, cj = channel_indices[i], channel_indices[j] + names.extend([ + f'cc_{ci}{cj}_corr', + 
f'cc_{ci}{cj}_lrms', + f'cc_{ci}{cj}_cov', + ]) + return names +# ============================================================================= +# Change 6 — MPF FEATURE EXTRACTOR (Python training only) +# ============================================================================= + +class MPFFeatureExtractor: + """ + Simplified 3-channel MPF: CSD upper triangle per 6 frequency bands = 36 features. + Python training only. Omits matrix logarithm (not needed for 3 channels). + Source: Kaifosh et al. Nature 2025. doi:10.1038/s41586-025-09255-w + ESP32 approximation: use bp0-bp3 from EMGFeatureExtractor (Change 1). + """ + BANDS = [(0, 62), (62, 125), (125, 187), (187, 250), (250, 375), (375, 500)] + + def __init__(self, channels=None, log_diagonal=True): + self.channels = channels or HAND_CHANNELS + self.log_diag = log_diagonal + self.n_ch = len(self.channels) + self._r, self._c = np.triu_indices(self.n_ch) + self.n_features = len(self.BANDS) * len(self._r) + + def extract_window(self, window): + sig = window[:, self.channels].astype(np.float64) + N = len(sig) + freqs = np.fft.rfftfreq(N, d=1.0 / SAMPLING_RATE_HZ) + Xf = np.fft.rfft(sig, axis=0) + feats = [] + for lo, hi in self.BANDS: + mask = (freqs >= lo) & (freqs < hi) + if not mask.any(): + feats.extend([0.0] * len(self._r)) + continue + CSD = (Xf[mask].conj().T @ Xf[mask]).real / N + if self.log_diag: + for k in range(self.n_ch): + CSD[k, k] = np.log(max(CSD[k, k], 1e-10)) + feats.extend(CSD[self._r, self._c].tolist()) + return np.array(feats, dtype=np.float32) + + def extract_batch(self, X): + out = np.zeros((len(X), self.n_features), dtype=np.float32) + for i in range(len(X)): + out[i] = self.extract_window(X[i]) + return out + + +# ============================================================================= +# CALIBRATION TRANSFORM (Per-session feature-space alignment) +# ============================================================================= + +class CalibrationTransform: + """ + Corrects for 
electrode placement drift between sessions via Session Z-Score Normalization. + + Training: each training session's features are independently z-scored + (subtract session mean, divide by session std) before LDA fitting. + This removes placement-dependent amplitude shifts, so the model learns + in a placement-invariant normalized feature space. + + Calibration: collect a short clip of each gesture → compute global + mean (mu_calib) and std (sigma_calib) of those features → apply the + same z-score to every live window: + + x_normalized = (x_live - mu_calib) / sigma_calib + + This projects live features into the same normalized space that training + used, regardless of how electrode placement changed. + + Workflow: + 1. fit_from_training() — called automatically in EMGClassifier.train(). + Stores per-class training centroids (in normalized + space) for diagnostics. + 2. fit_from_calibration() — called at session start after collecting + a short clip of each gesture. + Computes mu_calib / sigma_calib. + 3. apply() — called on every live feature vector. + Returns (features - mu_calib) / sigma_calib. + """ + + def __init__(self): + self.has_training_stats: bool = False + self.is_fitted: bool = False + self.class_means_train: dict = {} # {label: ndarray} from training (normalized space) + self.class_means_calib: dict = {} # {label: ndarray} from calibration (raw space) + # Stats for the z-score transform + self.mu_calib: Optional[np.ndarray] = None # Class-balanced mean of calibration features (raw space) + self.sigma_calib: Optional[np.ndarray] = None # Global std of calibration features (raw space) + self.sigma_train: Optional[np.ndarray] = None # Mean per-session sigma from training (preferred scale ref) + # Energy gate for rest detection (bypasses LDA when signal is quiet) + self.rest_energy_threshold: Optional[float] = None + + def fit_from_training(self, X_features: np.ndarray, y: np.ndarray, label_names: list): + """ + Store per-class training centroids. 
Called automatically in EMGClassifier.train(). + + Args: + X_features: (n_windows, n_features) extracted training features + y: (n_windows,) integer label indices + label_names: label string list matching y indices + """ + self.has_training_stats = True + + self.class_means_train = {} + for i, name in enumerate(label_names): + mask = y == i + if mask.any(): + self.class_means_train[name] = np.mean(X_features[mask], axis=0) + + def fit_from_calibration(self, calib_features: np.ndarray, calib_labels: list): + """ + Compute z-score normalization params from calibration-session data. + + mu_calib = class-balanced mean (average of per-class centroids) + sigma_calib = overall std of all calibration feature windows + + Using the class-balanced mean prevents near-zero-amplitude classes (rest) + from landing at the wrong normalized position when training sessions had + unequal numbers of windows per class. + + Args: + calib_features: (n_windows, n_features) from calibration clips + calib_labels: gesture label per window + """ + if not self.has_training_stats: + raise ValueError( + "Training stats not available. Load a model that was trained " + "after calibration support was added (retrain if needed)." + ) + + # Per-class calibration centroids (raw space) + self.class_means_calib = {} + label_arr = np.array(calib_labels) + for label in set(calib_labels): + mask = label_arr == label + if mask.any(): + self.class_means_calib[label] = np.mean(calib_features[mask], axis=0) + + # Class-balanced mean: average of per-class centroids (not overall mean). + # Prevents class-imbalanced calibration clips from biasing the normalization + # origin (especially important for rest, which has near-zero amplitude). + self.mu_calib = np.mean(list(self.class_means_calib.values()), axis=0) + self.sigma_calib = np.std(calib_features, axis=0) + 1e-8 + + # rest_energy_threshold is set externally from raw window RMS values + # (cannot be computed here — extracted features are amplitude-normalized). 
+        self.rest_energy_threshold = None
+
+        self.is_fitted = True
+
+        # Decide which sigma to use for scaling:
+        #   sigma_train (preferred) — mean per-session sigma from training.
+        #       Ensures the classifier sees calibration features at the SAME scale
+        #       as training features, which is critical for QDA whose per-class
+        #       covariance ellipsoids are fixed in normalized training space.
+        #   sigma_calib (fallback) — std of this calibration session.
+        #       Used only if the model was trained without session normalization.
+        sigma_used = self.sigma_train if self.sigma_train is not None else self.sigma_calib
+        sigma_source = "sigma_train" if self.sigma_train is not None else "sigma_calib (fallback)"
+        print(f"[Calibration] Z-score fit: {len(calib_features)} windows, "
+              f"{len(self.class_means_calib)} classes [scale ref: {sigma_source}]")
+        # Per-class residual in normalized space (lower = better alignment)
+        # NOTE(review): the residual below divides by sigma_calib, but apply()
+        # scales by sigma_train when available — this diagnostic may not reflect
+        # the alignment actually seen at predict time. Confirm intended.
+        common = set(self.class_means_calib) & set(self.class_means_train)
+        for c in sorted(common):
+            norm_calib = (self.class_means_calib[c] - self.mu_calib) / self.sigma_calib
+            residual = np.linalg.norm(self.class_means_train[c] - norm_calib)
+            print(f"[Calibration] {c}: normalized residual = {residual:.4f}")
+
+    def apply(self, features: np.ndarray) -> np.ndarray:
+        """
+        Z-score normalize features using calibration session statistics.
+
+        Uses sigma_train (mean per-session sigma from training) for scaling when
+        available — this keeps calibration features at the same scale as training
+        features, which is critical for QDA. Falls back to sigma_calib for old
+        models trained without session normalization.
+
+        Args:
+            features: shape (n_features,) or (n_windows, n_features)
+        Returns:
+            (features - mu_calib) / sigma, same shape as input.
+            Pass-through if not fitted.
+ """ + if not self.is_fitted: + return features + sigma = self.sigma_train if self.sigma_train is not None else self.sigma_calib + return (features - self.mu_calib) / sigma + + def reset(self): + """Remove per-session calibration (keeps training centroids intact).""" + self.mu_calib = None + self.sigma_calib = None + self.rest_energy_threshold = None + self.is_fitted = False + self.class_means_calib = {} + # sigma_train is permanent (set at train time, not session-specific) + + +# ============================================================================= +# DATA AUGMENTATION (Change 3) +# ============================================================================= + +def augment_emg_batch( + X: np.ndarray, + y: np.ndarray, + multiplier: int = 3, + seed: int = 42, +) -> tuple[np.ndarray, np.ndarray]: + """ + Augment raw EMG windows for training robustness. + + Must be called on raw windows (n_windows, n_samples, n_channels), + not on pre-computed features. Each copy independently applies: + - Amplitude scaling ×[0.80, 1.20] + - Gaussian noise 5 % of per-window RMS + - DC offset jitter ±20 counts + - Time-shift (roll) ±5 samples + + Source: Kaifosh et al. Nature 2025. 
doi:10.1038/s41586-025-09255-w + """ + rng = np.random.default_rng(seed) + aug_X, aug_y = [X], [y] + for _ in range(multiplier - 1): + Xc = X.copy().astype(np.float32) + Xc *= rng.uniform(0.80, 1.20, (len(X), 1, 1)).astype(np.float32) + rms = np.sqrt(np.mean(Xc ** 2, axis=(1, 2), keepdims=True)) + 1e-8 + Xc += rng.standard_normal(Xc.shape).astype(np.float32) * (0.05 * rms) + Xc += rng.uniform(-20., 20., (len(X), 1, X.shape[2])).astype(np.float32) + shifts = rng.integers(-5, 6, size=len(X)) + for i in range(len(Xc)): + if shifts[i]: + Xc[i] = np.roll(Xc[i], shifts[i], axis=0) + aug_X.append(Xc) + aug_y.append(y) + return np.concatenate(aug_X), np.concatenate(aug_y) + + # ============================================================================= # LDA CLASSIFIER # ============================================================================= class EMGClassifier: """ - LDA-based EMG gesture classifier. + EMG gesture classifier supporting LDA and QDA. - LESSON: Why LDA for EMG? - - Fast training and inference (good for embedded) - - Works well with small datasets - - Interpretable (can visualize decision boundaries) - - Proven effective for EMG in literature + Model types: + - LDA: Linear Discriminant Analysis — fast, exportable to ESP32 C header + - QDA: Quadratic Discriminant Analysis — more flexible boundaries, laptop-only """ - def __init__(self): - self.feature_extractor = EMGFeatureExtractor() - self.lda = LinearDiscriminantAnalysis() + def __init__(self, model_type: str = "lda", reg_param: float = 0.1): + self.model_type = model_type.lower() + self.reg_param = reg_param # only used by QDA + self.feature_extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, reinhard=True) + if self.model_type == "qda": + self.model = QuadraticDiscriminantAnalysis(reg_param=reg_param) + else: + self.model = LinearDiscriminantAnalysis() self.label_names: list[str] = [] self.is_trained = False self.feature_names: list[str] = [] + self.calibration_transform = 
CalibrationTransform() - def train(self, X: np.ndarray, y: np.ndarray, label_names: list[str]): + def train(self, X: np.ndarray, y: np.ndarray, label_names: list[str], + session_indices: Optional[np.ndarray] = None): """ Train the classifier. @@ -1648,32 +2029,90 @@ class EMGClassifier: X: Raw EMG windows (n_windows, n_samples, n_channels) y: Integer labels (n_windows,) label_names: List of label strings + session_indices: Optional per-window integer session ID (0..n_sessions-1). + When provided, each session's features are independently + z-scored before fitting, creating a placement-invariant model. """ + # Change 3: data augmentation on raw windows before feature extraction + if getattr(self, 'use_augmentation', True): + X_aug, y_aug = augment_emg_batch(X, y, multiplier=3) + print(f"[Classifier] Augmentation: {len(X)} -> {len(X_aug)} windows") + # Replicate session_indices to match the augmented size + if session_indices is not None: + session_indices = np.tile(session_indices, 3) + else: + X_aug, y_aug = X, y + print("\n[Classifier] Extracting features...") - X_features = self.feature_extractor.extract_features_batch(X) - self.feature_names = self.feature_extractor.get_feature_names(X.shape[2]) + X_features = self.feature_extractor.extract_features_batch(X_aug) + self.feature_names = self.feature_extractor.get_feature_names(X_aug.shape[2]) + + # Change 6: optionally stack MPF features + if getattr(self, 'use_mpf', False): + mpf = MPFFeatureExtractor(channels=HAND_CHANNELS) + X_features = np.hstack([X_features, mpf.extract_batch(X_aug)]) print(f"[Classifier] Feature matrix shape: {X_features.shape}") print(f"[Classifier] Features per window: {len(self.feature_names)}") - print("\n[Classifier] Training LDA...") - self.lda.fit(X_features, y) + if session_indices is not None: + n_sessions = len(np.unique(session_indices)) + print(f"\n[Classifier] Applying per-session z-score normalization ({n_sessions} sessions, class-balanced mu)...") + X_features = 
self._apply_session_normalization(X_features, session_indices, y=y_aug) + + print(f"\n[Classifier] Training {self.model_type.upper()}...") + self.model.fit(X_features, y_aug) self.label_names = label_names self.is_trained = True + # Store training distribution (in normalized space) for calibration diagnostics + self.calibration_transform.fit_from_training(X_features, y_aug, label_names) + # Training accuracy - train_acc = self.lda.score(X_features, y) + train_acc = self.model.score(X_features, y_aug) print(f"[Classifier] Training accuracy: {train_acc*100:.1f}%") return X_features + def _apply_session_normalization(self, X_features: np.ndarray, session_indices: np.ndarray, + y: Optional[np.ndarray] = None) -> np.ndarray: + """ + Z-score each session's features independently using a class-balanced mean. + + For each session: + - mu = mean of per-class centroids (class-balanced, not weighted by window count) + - sigma = overall std of all windows in the session + + Using the class-balanced mean prevents sessions with more rest windows (or any + imbalanced class) from skewing the normalization origin toward that class. 
+ """ + X_norm = X_features.copy() + session_sigmas = [] + for sid in np.unique(session_indices): + mask = session_indices == sid + X_sess = X_features[mask] + if y is not None: + # Class-balanced mean: average of per-class centroids + y_sess = y[mask] + class_means = [X_sess[y_sess == cls].mean(axis=0) + for cls in np.unique(y_sess)] + mu = np.mean(class_means, axis=0) + else: + mu = X_sess.mean(axis=0) + sigma = X_sess.std(axis=0) + 1e-8 + session_sigmas.append(sigma) + X_norm[mask] = (X_sess - mu) / sigma + # Store mean per-session sigma so calibration can use the same scale reference + self.calibration_transform.sigma_train = np.mean(session_sigmas, axis=0) + return X_norm + def evaluate(self, X: np.ndarray, y: np.ndarray) -> dict: """Evaluate classifier on test data.""" if not self.is_trained: raise ValueError("Classifier not trained!") X_features = self.feature_extractor.extract_features_batch(X) - y_pred = self.lda.predict(X_features) + y_pred = self.model.predict(X_features) accuracy = np.mean(y_pred == y) @@ -1683,11 +2122,30 @@ class EMGClassifier: 'y_true': y } - def cross_validate(self, X: np.ndarray, y: np.ndarray, cv: int = 5) -> np.ndarray: - """Perform k-fold cross-validation.""" - print(f"\n[Classifier] Running {cv}-fold cross-validation...") + def cross_validate(self, X: np.ndarray, y: np.ndarray, trial_ids: Optional[np.ndarray] = None, + cv: int = 5, session_indices: Optional[np.ndarray] = None) -> np.ndarray: + """ + Perform k-fold cross-validation with trial-level splitting. + + When trial_ids are provided, uses GroupKFold to ensure windows from the + same trial never appear in both train and test folds (prevents leakage). + + When session_indices are provided, applies the same per-session z-score + normalization used during training before running CV. 
+ """ X_features = self.feature_extractor.extract_features_batch(X) - scores = cross_val_score(self.lda, X_features, y, cv=cv) + + if session_indices is not None: + X_features = self._apply_session_normalization(X_features, session_indices, y=y) + + if trial_ids is not None: + print(f"\n[Classifier] Running {cv}-fold cross-validation (TRIAL-LEVEL, no leakage)...") + group_kfold = GroupKFold(n_splits=cv) + scores = cross_val_score(self.model, X_features, y, cv=group_kfold, groups=trial_ids) + else: + print(f"\n[Classifier] Running {cv}-fold cross-validation (window-level, legacy)...") + scores = cross_val_score(self.model, X_features, y, cv=cv) + return scores def predict(self, window: np.ndarray) -> tuple[str, np.ndarray]: @@ -1703,19 +2161,52 @@ class EMGClassifier: if not self.is_trained: raise ValueError("Classifier not trained!") - features = self.feature_extractor.extract_features_window(window) - pred_idx = self.lda.predict([features])[0] - proba = self.lda.predict_proba([features])[0] + if not hasattr(self, '_predict_count'): + self._predict_count = 0 + self._predict_count += 1 + _debug = (self._predict_count <= 30) + features_raw = self.feature_extractor.extract_features_window(window) + + # Energy gate: if raw signal is quiet enough to be rest, skip LDA entirely. + # Uses raw window RMS (pre-feature-extraction) so amplitude normalization + # inside the feature extractor doesn't mask the energy difference. 
+ ct = self.calibration_transform + if (ct.is_fitted and ct.rest_energy_threshold is not None + and "rest" in self.label_names): + w_ac = window - window.mean(axis=0) # remove per-window DC offset (matches feature extractor) + raw_rms = float(np.sqrt(np.mean(w_ac ** 2))) + if _debug: + print(f"[predict #{self._predict_count}] rms={raw_rms:.1f} gate={ct.rest_energy_threshold:.1f} " + f"{'GATED->rest' if raw_rms < ct.rest_energy_threshold else 'pass->QDA/LDA'}") + if raw_rms < ct.rest_energy_threshold: + rest_idx = self.label_names.index("rest") + proba = np.zeros(len(self.label_names)) + proba[rest_idx] = 1.0 + return "rest", proba + elif _debug: + print(f"[predict #{self._predict_count}] gate inactive (is_fitted={ct.is_fitted}, " + f"threshold={ct.rest_energy_threshold})") + + features = ct.apply(features_raw) + pred_idx = self.model.predict([features])[0] + proba = self.model.predict_proba([features])[0] + if _debug: + top = sorted(zip(self.label_names, proba), key=lambda x: -x[1])[:3] + print(f"[predict #{self._predict_count}] {self.model_type.upper()} -> {self.label_names[pred_idx]}" + f" proba: {', '.join(f'{n}={p:.2f}' for n,p in top)}") return self.label_names[pred_idx], proba def get_feature_importance(self) -> dict: - """Get feature importance based on LDA coefficients.""" + """Get feature importance based on LDA coefficients (LDA only).""" if not self.is_trained: return {} + if not hasattr(self.model, 'coef_'): + return {} + # For multi-class, average absolute coefficients across classes - coef = np.abs(self.lda.coef_).mean(axis=0) + coef = np.abs(self.model.coef_).mean(axis=0) importance = dict(zip(self.feature_names, coef)) return dict(sorted(importance.items(), key=lambda x: x[1], reverse=True)) @@ -1742,14 +2233,28 @@ class EMGClassifier: filepath.parent.mkdir(parents=True, exist_ok=True) model_data = { - 'lda': self.lda, + 'model': self.model, + 'model_type': self.model_type, 'label_names': self.label_names, 'feature_names': self.feature_names, 
'feature_extractor_params': { 'zc_threshold_percent': self.feature_extractor.zc_threshold_percent, 'ssc_threshold_percent': self.feature_extractor.ssc_threshold_percent, + 'channels': self.feature_extractor.channels, + 'normalize': self.feature_extractor.normalize, + 'expanded': self.feature_extractor.expanded, + 'cross_channel': self.feature_extractor.cross_channel, + 'bandpass': self.feature_extractor.bandpass, + 'reinhard': self.feature_extractor.reinhard, + 'fft_n': self.feature_extractor.fft_n, + 'fs': self.feature_extractor.fs, }, - 'version': '1.0', # For future compatibility + 'version': '1.3', + 'reg_param': self.reg_param, + 'session_normalized': True, + # Calibration transform training stats (used by CalibrationPage) + 'calib_class_means_train': self.calibration_transform.class_means_train, + 'calib_sigma_train': self.calibration_transform.sigma_train, } joblib.dump(model_data, filepath) @@ -1770,6 +2275,12 @@ class EMGClassifier: if not self.is_trained: raise ValueError("Cannot export untrained classifier!") + if self.model_type != "lda": + raise ValueError( + f"Cannot export {self.model_type.upper()} to C header. " + "Only LDA models can be exported (QDA lacks coef_/intercept_)." + ) + filepath = Path(filepath) filepath.parent.mkdir(parents=True, exist_ok=True) @@ -1779,8 +2290,8 @@ class EMGClassifier: # Get LDA parameters # coef_: (n_classes, n_features) - access as [class][feature] # intercept_: (n_classes,) - coefs = self.lda.coef_ - intercepts = self.lda.intercept_ + coefs = self.model.coef_ + intercepts = self.model.intercept_ # Add logic for binary classification (sklearn stores only 1 set of coefs) # For >2 classes, it stores n_classes sets. @@ -1815,9 +2326,27 @@ class EMGClassifier: # Safest: Let's trust that for our 5-gesture demo, it's multiclass. pass + # Bug 7 fix: preserve compile-time flags that are independent of + # the feature pipeline (MLP, ensemble). 
Pipeline-dependent flags + # (EXPAND_FEATURES, REINHARD) are set from the extractor config so + # they always match the exported weights. + preserved_flags = {} + _PRESERVED_FLAG_NAMES = ['MODEL_USE_MLP', 'MODEL_USE_ENSEMBLE'] + if filepath.exists(): + import re + existing = filepath.read_text() + for flag in _PRESERVED_FLAG_NAMES: + m = re.search(rf'#define\s+{flag}\s+(\d+)', existing) + if m: + preserved_flags[flag] = int(m.group(1)) + + # Auto-set pipeline flags from training config (prevents mismatch) + preserved_flags['MODEL_EXPAND_FEATURES'] = 1 if self.feature_extractor.expanded else 0 + preserved_flags['MODEL_USE_REINHARD'] = 1 if self.feature_extractor.reinhard else 0 + # Generate C content timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") - + c_content = [ "/**", f" * @file {filepath.name}", @@ -1833,11 +2362,24 @@ class EMGClassifier: "/* Metadata */", f"#define MODEL_NUM_CLASSES {n_classes}", f"#define MODEL_NUM_FEATURES {n_features}", + f"#define MODEL_NORMALIZE_FEATURES {1 if self.feature_extractor.normalize else 0}", "", - "/* Class Names */", - "static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {", ] + # Write compile-time flags (pipeline flags auto-set, architecture flags preserved) + _ALL_FLAGS = [ + 'MODEL_EXPAND_FEATURES', 'MODEL_USE_REINHARD', + 'MODEL_USE_MLP', 'MODEL_USE_ENSEMBLE', + ] + c_content.append("/* Compile-time feature flags */") + for flag in _ALL_FLAGS: + val = preserved_flags.get(flag, 0) + c_content.append(f"#define {flag} {val}") + c_content.append("") + + c_content.append("/* Class Names */") + c_content.append("static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {") + for name in self.label_names: c_content.append(f' "{name}",') c_content.append("};") @@ -1900,22 +2442,61 @@ class EMGClassifier: model_data = joblib.load(filepath) + # Determine model type (backward compat: old files have 'lda' key, no 'model_type') + model_type = model_data.get('model_type', 'lda') + reg_param = model_data.get('reg_param', 
0.1) + # Create new instance and restore state - classifier = cls() - classifier.lda = model_data['lda'] + classifier = cls(model_type=model_type, reg_param=reg_param) + classifier.model = model_data.get('model', model_data.get('lda')) classifier.label_names = model_data['label_names'] - classifier.feature_names = model_data['feature_names'] classifier.is_trained = True # Restore feature extractor params params = model_data.get('feature_extractor_params', {}) + # Infer expanded/cross_channel from feature count for old models + # that don't store these params: 12 features = legacy (4×3), + # 69 features = expanded (20×3 + 9 cross-channel) + saved_feat_names = model_data.get('feature_names', []) + n_feat = len(saved_feat_names) if saved_feat_names else 69 + default_expanded = n_feat > 12 + default_cc = n_feat > 60 # cross-channel adds 9 features (60→69) classifier.feature_extractor = EMGFeatureExtractor( zc_threshold_percent=params.get('zc_threshold_percent', 0.1), ssc_threshold_percent=params.get('ssc_threshold_percent', 0.1), + channels=params.get('channels', HAND_CHANNELS), + normalize=params.get('normalize', False), + expanded=params.get('expanded', default_expanded), + cross_channel=params.get('cross_channel', default_cc), + bandpass=params.get('bandpass', False), # False for old models + reinhard=params.get('reinhard', False), + fft_n=params.get('fft_n', 256), + fs=params.get('fs', float(SAMPLING_RATE_HZ)), ) + # Regenerate feature names from extractor if not in saved data + if saved_feat_names: + classifier.feature_names = saved_feat_names + else: + channels = params.get('channels', HAND_CHANNELS) + classifier.feature_names = classifier.feature_extractor.get_feature_names(len(channels)) + + # Restore calibration transform training stats (saved from v1.2+ models) + classifier.calibration_transform = CalibrationTransform() + class_means_train = model_data.get('calib_class_means_train', {}) + sigma_train = model_data.get('calib_sigma_train') + session_normalized 
= model_data.get('session_normalized', False) + classifier.session_normalized = session_normalized + if class_means_train: + classifier.calibration_transform.class_means_train = class_means_train + classifier.calibration_transform.has_training_stats = True + if sigma_train is not None: + classifier.calibration_transform.sigma_train = sigma_train print(f"[Classifier] Model loaded from: {filepath}") print(f"[Classifier] Labels: {classifier.label_names}") + calib_ready = classifier.calibration_transform.has_training_stats + print(f"[Classifier] Calibration support: {'yes' if calib_ready else 'no (retrain to enable)'}") + print(f"[Classifier] Session-normalized: {session_normalized}") return classifier @staticmethod @@ -1923,11 +2504,22 @@ class EMGClassifier: """Get the default path for saving/loading models.""" return MODEL_DIR / "emg_lda_classifier.joblib" + @staticmethod + def get_latest_model_path() -> Path | None: + """Get the most recently modified model file, or None if no models exist.""" + models = EMGClassifier.list_saved_models() + if not models: + return None + return max(models, key=lambda p: p.stat().st_mtime) + @staticmethod def list_saved_models() -> list[Path]: - """List all saved model files.""" + """List all saved classifier model files (excludes ensemble/auxiliary files).""" MODEL_DIR.mkdir(parents=True, exist_ok=True) - return sorted(MODEL_DIR.glob("*.joblib")) + return sorted( + p for p in MODEL_DIR.glob("*.joblib") + if "ensemble" not in p.stem + ) # ============================================================================= @@ -2147,13 +2739,14 @@ def run_training_demo(): print("TRAINING ON ALL SESSIONS COMBINED") print("=" * 60) - X, y, label_names, loaded_sessions = storage.load_all_for_training() + X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() print(f"\nDataset:") print(f" Windows: {X.shape[0]}") print(f" Samples per window: {X.shape[1]}") print(f" Channels: {X.shape[2]}") print(f" 
Classes: {label_names}") + print(f" Unique trials: {len(np.unique(trial_ids))}") # Count per class print(f"\nSamples per class:") @@ -2165,8 +2758,8 @@ def run_training_demo(): classifier = EMGClassifier() X_features = classifier.train(X, y, label_names) - # Cross-validation - cv_scores = classifier.cross_validate(X, y, cv=5) + # Cross-validation (trial-level to prevent leakage) + cv_scores = classifier.cross_validate(X, y, trial_ids=trial_ids, cv=5) print(f"\nCross-validation scores: {cv_scores}") print(f"Mean CV accuracy: {cv_scores.mean()*100:.1f}% (+/- {cv_scores.std()*100:.1f}%)") @@ -2179,14 +2772,35 @@ def run_training_demo(): bar = "█" * int(score * 20) print(f" {name:12s}: {bar} ({score:.3f})") - # Train/test split evaluation + # Train/test split evaluation (TRIAL-LEVEL to prevent leakage) print(f"\n{'-' * 40}") - print("TRAIN/TEST SPLIT EVALUATION") + print("TRAIN/TEST SPLIT EVALUATION (TRIAL-LEVEL)") print("-" * 40) - X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=0.2, random_state=42, stratify=y - ) + # Use GroupShuffleSplit to split by trial, not by window + gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42) + train_idx, test_idx = next(gss.split(X, y, groups=trial_ids)) + + X_train, X_test = X[train_idx], X[test_idx] + y_train, y_test = y[train_idx], y[test_idx] + train_trial_ids = trial_ids[train_idx] + test_trial_ids = trial_ids[test_idx] + + # VERIFICATION: Ensure no trial leakage + train_trials_set = set(train_trial_ids) + test_trials_set = set(test_trial_ids) + overlap = train_trials_set & test_trials_set + assert len(overlap) == 0, f"Trial leakage detected! 
Overlapping trials: {overlap}" + print(f" Train: {len(X_train)} windows from {len(train_trials_set)} trials") + print(f" Test: {len(X_test)} windows from {len(test_trials_set)} trials") + print(f" Trial overlap: {len(overlap)} (VERIFIED: no leakage)") + + # Log per-class distribution + print(f"\n Per-class window counts:") + for i, name in enumerate(label_names): + train_count = np.sum(y_train == i) + test_count = np.sum(y_test == i) + print(f" {name:12s}: train={train_count:4d}, test={test_count:4d}") # Train on train set test_classifier = EMGClassifier() @@ -2228,22 +2842,84 @@ def run_training_demo(): return classifier +# ============================================================================= +# Change 5 — CLASSIFIER BENCHMARK +# ============================================================================= + +def run_classifier_benchmark(): + """ + Cross-validate LDA, QDA, SVM-RBF, and MLP on the collected dataset. + + Purpose: tells you whether accuracy plateau is a features problem + (all classifiers similar → add features) or a model complexity problem + (SVM/MLP >> LDA → implement Change E / ensemble). + """ + from sklearn.svm import SVC + from sklearn.neural_network import MLPClassifier + from sklearn.pipeline import Pipeline + from sklearn.preprocessing import StandardScaler + from sklearn.model_selection import cross_val_score, GroupKFold + from sklearn.discriminant_analysis import (LinearDiscriminantAnalysis, + QuadraticDiscriminantAnalysis) + + print("\n" + "=" * 60) + print("CLASSIFIER BENCHMARK (Cross-validation)") + print("=" * 60) + + storage = SessionStorage() + X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training() + + if len(np.unique(y)) < 2: + print("Need at least 2 gesture classes. 
Collect more data first.") + return + + extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True) + X = extractor.extract_features_batch(X_raw) + X = EMGClassifier()._apply_session_normalization(X, session_indices, y=y) + + clfs = { + 'LDA (ESP32 model)': LinearDiscriminantAnalysis(), + 'QDA': QuadraticDiscriminantAnalysis(reg_param=0.1), + 'SVM-RBF': Pipeline([('s', StandardScaler()), + ('m', SVC(kernel='rbf', C=10))]), + 'MLP-128-64': Pipeline([('s', StandardScaler()), + ('m', MLPClassifier(hidden_layer_sizes=(128, 64), + max_iter=1000, + early_stopping=True))]), + } + + n_splits = min(5, len(np.unique(trial_ids))) + gkf = GroupKFold(n_splits=n_splits) + + print(f"\n{'Classifier':<22} {'Mean CV':>8} {'Std':>6}") + print("-" * 40) + for name, clf in clfs.items(): + sc = cross_val_score(clf, X, y, cv=gkf, groups=trial_ids, scoring='accuracy') + print(f" {name:<20} {sc.mean()*100:>7.1f}% ±{sc.std()*100:.1f}%") + + print() + print(" → If LDA ≈ SVM: features are the bottleneck (add Change 1 features)") + print(" → If SVM >> LDA: model complexity bottleneck (implement Change F ensemble)") + + # ============================================================================= # LIVE PREDICTION (Real-time gesture classification) # ============================================================================= def run_prediction_demo(): """ - Live prediction demo - classifies gestures in real-time. + Live prediction demo - classifies gestures in real-time from ESP32. Shows: 1. Load saved model OR train fresh on all sessions - 2. Stream simulated EMG data + 2. Connect to ESP32 and stream real EMG data 3. Classify each window as it comes in 4. Display predictions with confidence + + REQUIRES: ESP32 hardware connected via USB. 
""" print("\n" + "=" * 60) - print("LIVE PREDICTION DEMO") + print("LIVE PREDICTION DEMO (ESP32 Required)") print("=" * 60) # Check for saved model @@ -2290,23 +2966,33 @@ def run_prediction_demo(): # Load ALL sessions and train model print(f"\n[Training model on all sessions...]") - X, y, label_names, loaded_sessions = storage.load_all_for_training() + X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() + print(f"[Unique trials: {len(np.unique(trial_ids))}]") classifier = EMGClassifier() classifier.train(X, y, label_names) + # Connect to ESP32 + print("\n[Connecting to ESP32...]") + stream = RealSerialStream() + try: + stream.connect(timeout=5.0) + print(f" Connected: {stream.device_info}") + except Exception as e: + print(f" ERROR: Failed to connect to ESP32: {e}") + print(" Make sure the ESP32 is connected and firmware is flashed.") + return None + # Start live prediction print("\n" + "=" * 60) print("STARTING LIVE PREDICTION (WITH SMOOTHING)") print("=" * 60) - max_predictions = 50 # Stop after this many predictions - print(f"Running {max_predictions} predictions with smoothing enabled...\n") + print("Press Ctrl+C to stop.\n") print(" Smoothing: Probability EMA (0.7) + Majority Vote (5) + Debounce (3)\n") - stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) parser = EMGParser(num_channels=NUM_CHANNELS) - windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0) + windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, hop_size_ms=HOP_SIZE_MS) # Create prediction smoother smoother = PredictionSmoother( @@ -2316,29 +3002,11 @@ def run_prediction_demo(): debounce_count=3, # Consecutive predictions needed to change ) - # Cycle through gestures for demo (names match ESP32 gesture definitions) - gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"] - gesture_idx = 0 - gesture_duration = 2.5 # seconds per gesture - 
gesture_start = time.perf_counter() - current_gesture = gesture_cycle[0] - stream.set_gesture(current_gesture) - print(f" [Simulating: {current_gesture.upper()}]") - stream.start() prediction_count = 0 try: - while prediction_count < max_predictions: - # Change simulated gesture periodically - elapsed = time.perf_counter() - gesture_start - if elapsed > gesture_duration: - gesture_idx = (gesture_idx + 1) % len(gesture_cycle) - gesture_start = time.perf_counter() - current_gesture = gesture_cycle[gesture_idx] - stream.set_gesture(current_gesture) - print(f"\n [Simulating: {current_gesture.upper()}]") - + while True: # Read and process data line = stream.readline() if line: @@ -2372,6 +3040,7 @@ def run_prediction_demo(): finally: stream.stop() + stream.disconnect() # Show smoothing stats stats = smoother.get_stats() @@ -2381,7 +3050,6 @@ def run_prediction_demo(): print(f" Total predictions: {stats['total_predictions']}") print(f" Output changes: {stats['output_changes']}") print(f" Stability ratio: {stats['stability_ratio']*100:.1f}%") - print(f"\n (Without smoothing, output would change with every raw prediction)") return classifier @@ -2430,10 +3098,11 @@ def run_visualization_demo(): return None # Load ALL data combined - X, y, label_names, loaded_sessions = storage.load_all_for_training() + X, y, trial_ids, session_indices, label_names, loaded_sessions = storage.load_all_for_training() + print(f"[Unique trials: {len(np.unique(trial_ids))}]") - # Extract features - extractor = EMGFeatureExtractor() + # Extract features (forearm channels only, matching hand classifier) + extractor = EMGFeatureExtractor(channels=HAND_CHANNELS) X_features = extractor.extract_features_batch(X) # Train LDA @@ -2563,8 +3232,6 @@ def run_visualization_demo(): plt.tight_layout() # --- Figure 5: Confusion Matrix Heatmap --- - from sklearn.model_selection import cross_val_predict - fig5, ax5 = plt.subplots(figsize=(8, 6)) y_pred = cross_val_predict(lda, X_features, y, cv=5) @@ -2614,6 
+3281,7 @@ if __name__ == "__main__": print(" 3. Train LDA classifier") print(" 4. Live prediction demo") print(" 5. Visualize LDA model") + print(" 6. Classifier benchmark (LDA vs QDA vs SVM vs MLP)") print(" q. Quit") choice = input("\nEnter choice: ").strip().lower() @@ -2637,5 +3305,8 @@ if __name__ == "__main__": elif choice == "5": lda = run_visualization_demo() + elif choice == "6": + run_classifier_benchmark() + else: - print("\nInvalid choice. Please enter 1-5 or q.") \ No newline at end of file + print("\nInvalid choice. Please enter 1-6 or q.") \ No newline at end of file diff --git a/learning_emg_filtering.py b/learning_emg_filtering.py index b257ef7..33ff90d 100644 --- a/learning_emg_filtering.py +++ b/learning_emg_filtering.py @@ -1,13 +1,42 @@ +""" +EMG Signal Analysis Script (STANDALONE - NOT PRODUCTION CODE) +============================================================== + +!! WARNING !! +This script is for OFFLINE ANALYSIS AND VISUALIZATION ONLY. +It is NOT part of the training or inference pipeline. + +DO NOT use this script's outputs for: +- Model training +- Feature extraction thresholds +- Production inference + +The thresholds defined here (ZC_THRESHOLD_PERCENT, SSC_THRESHOLD_PERCENT) +are DIFFERENT from production values in: +- learning_data_collection.py (production: 0.1, 0.1) +- model_weights.h (production: 0.1, 0.1) + +This script uses higher thresholds (0.7, 0.6) for visualization clarity, +which would produce INCORRECT features if used for training/inference. + +To analyze collected sessions: +1. Update HDF5_PATH to your session file +2. Run: python learning_emg_filtering.py +3. View the generated plots +""" + import numpy as np import matplotlib.pyplot as plt import h5py from scipy.signal import butter, sosfiltfilt # ============================================================================= -# CONFIGURABLE PARAMETERS +# CONFIGURABLE PARAMETERS (FOR VISUALIZATION ONLY - NOT PRODUCTION VALUES!) 
# ============================================================================= -ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS -SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS +# These thresholds are intentionally different from production (0.1, 0.1) +# to provide cleaner visualizations. DO NOT use for training/inference. +ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold (VISUALIZATION ONLY) +SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold (VISUALIZATION ONLY) # ============================================================================= # LOAD DATA FROM GUI's HDF5 FORMAT diff --git a/live_predict.py b/live_predict.py new file mode 100644 index 0000000..bb61d2f --- /dev/null +++ b/live_predict.py @@ -0,0 +1,325 @@ +""" +live_predict.py — Laptop-side live EMG inference for Bucky Arm. + +Use this script when the ESP32 is in EMG_MAIN mode and you want the laptop to +run the classifier (instead of the on-device model). Useful for: + - Comparing laptop accuracy vs. on-device accuracy before flashing a new model + - Debugging the feature pipeline without reflashing firmware + - Running an updated model that hasn't been exported to C yet + +Workflow: + 1. ESP32 must be in EMG_MAIN mode (MAIN_MODE = EMG_MAIN in config.h) + 2. This script handshakes → requests STATE_LAPTOP_PREDICT + 3. ESP32 streams raw ADC CSV at 1 kHz + 4. Script collects 3s of REST for session normalization, then classifies + 5. Every 25ms (one hop), the predicted gesture is sent back: {"gesture":"fist"} + 6. 
ESP32 executes the received gesture command on the arm + +Usage: + python live_predict.py --port COM3 + python live_predict.py --port COM3 --model models/my_model.joblib + python live_predict.py --port COM3 --confidence 0.45 +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np +import serial + +# Import from the main training pipeline +sys.path.insert(0, str(Path(__file__).parent)) +from learning_data_collection import ( + EMGClassifier, + NUM_CHANNELS, + SAMPLING_RATE_HZ, + WINDOW_SIZE_MS, + HOP_SIZE_MS, + HAND_CHANNELS, +) + +# Derived constants +WINDOW_SIZE = int(WINDOW_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 150 samples +HOP_SIZE = int(HOP_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 25 samples +BAUD_RATE = 921600 +CALIB_SECS = 3.0 # seconds of REST to collect at startup for normalization + +# ────────────────────────────────────────────────────────────────────────────── + +def parse_args(): + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument("--port", required=True, + help="Serial port (e.g. COM3 on Windows, /dev/ttyUSB0 on Linux)") + p.add_argument("--model", default=None, + help="Path to .joblib model file. Defaults to the most recently trained model.") + p.add_argument("--confidence", type=float, default=0.40, + help="Reject predictions below this confidence (default: 0.40, same as firmware)") + p.add_argument("--no-calib", action="store_true", + help="Skip REST calibration (use raw features — only for quick testing)") + return p.parse_args() + + +def load_model(model_path: str | None) -> EMGClassifier: + if model_path: + path = Path(model_path) + else: + path = EMGClassifier.get_latest_model_path() + if path is None: + print("[ERROR] No trained model found in models/. 
Run training first.") + sys.exit(1) + print(f"[Model] Auto-selected latest model: {path.name}") + + classifier = EMGClassifier.load(path) + if not classifier.is_trained: + print("[ERROR] Loaded model is not trained.") + sys.exit(1) + return classifier + + +def handshake(ser: serial.Serial) -> bool: + """Send connect command and wait for ack_connect from ESP32.""" + print("[Handshake] Sending connect command...") + ser.write(b'{"cmd":"connect"}\n') + deadline = time.time() + 5.0 + while time.time() < deadline: + raw = ser.readline() + line = raw.decode("utf-8", errors="ignore").strip() + if "ack_connect" in line: + print(f"[Handshake] Connected — {line}") + return True + if line: + print(f"[Handshake] (ignored) {line}") + print("[ERROR] No ack_connect within 5s. Is the ESP32 powered and in EMG_MAIN mode?") + return False + + +def collect_calibration_windows( + ser: serial.Serial, + n_windows: int, + extractor, +) -> tuple[np.ndarray, np.ndarray]: + """ + Collect n_windows of REST EMG, extract features, and compute + per-feature mean and std for session normalization. + + Returns (mean, std) arrays of shape (n_features,). 
+ """ + print(f"[Calib] Hold arm relaxed at rest for {CALIB_SECS:.0f}s...") + + raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32) + samples = 0 + feat_list = [] + + while len(feat_list) < n_windows: + raw = ser.readline() + line = raw.decode("utf-8", errors="ignore").strip() + if not line or line.startswith("{"): + continue + try: + vals = [float(v) for v in line.split(",")] + except ValueError: + continue + if len(vals) != NUM_CHANNELS: + continue + + # Slide window + raw_buf = np.roll(raw_buf, -1, axis=0) + raw_buf[-1] = vals + samples += 1 + + if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0: + feat = extractor.extract_features_window(raw_buf) + feat_list.append(feat) + + done = len(feat_list) + if done % 10 == 0: + pct = int(100 * done / n_windows) + print(f" {pct}% ({done}/{n_windows} windows)", end="\r", flush=True) + + print(f"\n[Calib] Collected {len(feat_list)} windows.") + feats = np.array(feat_list, dtype=np.float32) + mean = feats.mean(axis=0) + std = np.where(feats.std(axis=0) > 1e-6, feats.std(axis=0), 1e-6).astype(np.float32) + print("[Calib] Session normalization computed.") + return mean, std + + +def load_ensemble(): + """Load ensemble sklearn models if available.""" + path = Path(__file__).parent / 'models' / 'emg_ensemble.joblib' + if not path.exists(): + return None + try: + import joblib + ens = joblib.load(path) + print(f"[Model] Loaded ensemble (4 LDAs)") + return ens + except Exception as e: + print(f"[Model] Ensemble load failed: {e}") + return None + + +def load_mlp(): + """Load MLP numpy weights if available.""" + path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz' + if not path.exists(): + return None + try: + mlp = dict(np.load(path, allow_pickle=True)) + print(f"[Model] Loaded MLP weights (numpy)") + return mlp + except Exception as e: + print(f"[Model] MLP load failed: {e}") + return None + + +def run_ensemble(ens, features): + """Run ensemble: 3 specialist LDAs → meta-LDA → probabilities.""" + p_td 
= ens['lda_td'].predict_proba([features[ens['td_idx']]])[0] + p_fd = ens['lda_fd'].predict_proba([features[ens['fd_idx']]])[0] + p_cc = ens['lda_cc'].predict_proba([features[ens['cc_idx']]])[0] + x_meta = np.concatenate([p_td, p_fd, p_cc]) + return ens['meta_lda'].predict_proba([x_meta])[0] + + +def run_mlp(mlp, features): + """Run MLP forward pass: Dense(32,relu) → Dense(16,relu) → Dense(5,softmax).""" + x = features.astype(np.float32) + x = np.maximum(0, x @ mlp['w0'] + mlp['b0']) + x = np.maximum(0, x @ mlp['w1'] + mlp['b1']) + logits = x @ mlp['w2'] + mlp['b2'] + e = np.exp(logits - logits.max()) + return e / e.sum() + + +def main(): + args = parse_args() + + # ── Load classifier ────────────────────────────────────────────────────── + classifier = load_model(args.model) + extractor = classifier.feature_extractor + ensemble = load_ensemble() + mlp = load_mlp() + model_names = ["LDA"] + if ensemble: + model_names.append("Ensemble") + if mlp: + model_names.append("MLP") + print(f"[Model] Active: {' + '.join(model_names)} ({len(model_names)} models)") + + # ── Open serial ────────────────────────────────────────────────────────── + try: + ser = serial.Serial(args.port, BAUD_RATE, timeout=1.0) + except serial.SerialException as e: + print(f"[ERROR] Could not open {args.port}: {e}") + sys.exit(1) + + time.sleep(0.5) + ser.reset_input_buffer() + + # ── Handshake ──────────────────────────────────────────────────────────── + if not handshake(ser): + ser.close() + sys.exit(1) + + # ── Request laptop-predict mode ────────────────────────────────────────── + ser.write(b'{"cmd":"start_laptop_predict"}\n') + print("[Control] ESP32 entering STATE_LAPTOP_PREDICT — streaming ADC...") + + # ── Calibration ────────────────────────────────────────────────────────── + calib_mean = None + calib_std = None + if not args.no_calib: + n_calib = max(20, int(CALIB_SECS * 1000 / HOP_SIZE_MS)) + calib_mean, calib_std = collect_calibration_windows(ser, n_calib, extractor) + else: + 
print("[Calib] Skipped (--no-calib). Accuracy may be reduced.") + + # ── Live prediction loop ───────────────────────────────────────────────── + print(f"\n[Predict] Running. Confidence threshold: {args.confidence:.2f}") + print("[Predict] Press Ctrl+C to stop.\n") + + raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32) + samples = 0 + last_gesture = None + n_inferences = 0 + n_rejected = 0 + + try: + while True: + raw = ser.readline() + line = raw.decode("utf-8", errors="ignore").strip() + + # Skip JSON telemetry lines from ESP32 + if not line or line.startswith("{"): + continue + + # Parse CSV sample + try: + vals = [float(v) for v in line.split(",")] + except ValueError: + continue + if len(vals) != NUM_CHANNELS: + continue + + # Slide window + raw_buf = np.roll(raw_buf, -1, axis=0) + raw_buf[-1] = vals + samples += 1 + + # Classify every HOP_SIZE samples + if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0: + feat = extractor.extract_features_window(raw_buf).astype(np.float32) + + # Apply session normalization + if calib_mean is not None: + feat = (feat - calib_mean) / calib_std + + # Run all available models and average probabilities + probas = [classifier.model.predict_proba([feat])[0]] + if ensemble: + try: + probas.append(run_ensemble(ensemble, feat)) + except Exception: + pass + if mlp: + try: + probas.append(run_mlp(mlp, feat)) + except Exception: + pass + proba = np.mean(probas, axis=0) + class_idx = int(np.argmax(proba)) + confidence = float(proba[class_idx]) + gesture = classifier.label_names[class_idx] + n_inferences += 1 + + # Reject below threshold + if confidence < args.confidence: + n_rejected += 1 + continue + + # Send gesture command to ESP32 + cmd = f'{{"gesture":"{gesture}"}}\n' + ser.write(cmd.encode("utf-8")) + + # Local logging (only on change) + if gesture != last_gesture: + reject_rate = 100 * n_rejected / n_inferences if n_inferences else 0 + print(f" → {gesture:<12} conf={confidence:.2f} " + 
f"reject_rate={reject_rate:.0f}%") + last_gesture = gesture + n_rejected = 0 + n_inferences = 0 + + except KeyboardInterrupt: + print("\n\n[Stop] Sending stop command to ESP32...") + ser.write(b'{"cmd":"stop"}\n') + time.sleep(0.2) + ser.close() + print("[Stop] Done.") + + +if __name__ == "__main__": + main() diff --git a/models/emg_ensemble.joblib b/models/emg_ensemble.joblib new file mode 100644 index 0000000..deb144a Binary files /dev/null and b/models/emg_ensemble.joblib differ diff --git a/models/emg_lda_classifier.joblib b/models/emg_lda_classifier.joblib index 6757e9a..cb0a8d9 100644 Binary files a/models/emg_lda_classifier.joblib and b/models/emg_lda_classifier.joblib differ diff --git a/models/emg_mlp_weights.npz b/models/emg_mlp_weights.npz new file mode 100644 index 0000000..2e5cf81 Binary files /dev/null and b/models/emg_mlp_weights.npz differ diff --git a/models/emg_qda_classifier.joblib b/models/emg_qda_classifier.joblib new file mode 100644 index 0000000..accf7cf Binary files /dev/null and b/models/emg_qda_classifier.joblib differ diff --git a/train_ensemble.py b/train_ensemble.py new file mode 100644 index 0000000..d98ea76 --- /dev/null +++ b/train_ensemble.py @@ -0,0 +1,200 @@ +""" +Train the full 3-specialist-LDA + meta-LDA ensemble. +Requires Change 1 (expanded features) to be implemented first. +Exports model_weights_ensemble.h for firmware Change F. + +Architecture: + LDA_TD (36 time-domain feat) ─┐ + LDA_FD (24 freq-domain feat) ├─ 15 probs ─► Meta-LDA ─► final class + LDA_CC (9 cross-ch feat) ─┘ + +Change 7 — priority Tier 3. 
+""" +import numpy as np +from pathlib import Path +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from learning_data_collection import ( + SessionStorage, EMGFeatureExtractor, HAND_CHANNELS +) + +# --- Load and extract features ----------------------------------------------- +storage = SessionStorage() +X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training() + +extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True) +X = extractor.extract_features_batch(X_raw).astype(np.float64) + +# Per-session class-balanced normalization +# Must match EMGClassifier._apply_session_normalization(): +# mean = average of per-class means (not overall mean), std = overall std. +# StandardScaler uses overall mean, which biases toward the majority class. +for sid in np.unique(session_indices): + mask = session_indices == sid + X_sess = X[mask] + y_sess = y[mask] + + # Class-balanced mean: average of per-class centroids + class_means = [] + for cls in np.unique(y_sess): + class_means.append(X_sess[y_sess == cls].mean(axis=0)) + balanced_mean = np.mean(class_means, axis=0) + + # Overall std (same as StandardScaler) + std = X_sess.std(axis=0) + std[std < 1e-12] = 1.0 # avoid division by zero + + X[mask] = (X_sess - balanced_mean) / std + +feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS)) +n_cls = len(np.unique(y)) + +# --- Feature subset indices --------------------------------------------------- +# Per-channel layout (20 features/channel): indices 0-11 TD, 12-19 FD +# Cross-channel features start at index 60 (3 channels × 20 features each) +TD_FEAT = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', 'ar1', 'ar2', 'ar3', 'ar4'] +FD_FEAT = ['mnf', 'mdf', 'pkf', 'mnp', 'bp0', 'bp1', 'bp2', 'bp3'] + +td_idx = [i for i, 
n in enumerate(feat_names) + if any(n.endswith(f'_{f}') for f in TD_FEAT) and n.startswith('ch')] +fd_idx = [i for i, n in enumerate(feat_names) + if any(n.endswith(f'_{f}') for f in FD_FEAT) and n.startswith('ch')] +cc_idx = [i for i, n in enumerate(feat_names) if n.startswith('cc_')] + +print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}") +assert len(td_idx) == 36, f"Expected 36 TD features, got {len(td_idx)}" +assert len(fd_idx) == 24, f"Expected 24 FD features, got {len(fd_idx)}" +assert len(cc_idx) == 9, f"Expected 9 CC features, got {len(cc_idx)}" + +X_td = X[:, td_idx] +X_fd = X[:, fd_idx] +X_cc = X[:, cc_idx] + +# --- Train specialist LDAs with out-of-fold stacking ------------------------- +gkf = GroupKFold(n_splits=min(5, len(np.unique(trial_ids)))) + +print("Training specialist LDAs (out-of-fold for stacking)...") +lda_td = LinearDiscriminantAnalysis() +lda_fd = LinearDiscriminantAnalysis() +lda_cc = LinearDiscriminantAnalysis() + +oof_td = cross_val_predict(lda_td, X_td, y, cv=gkf, groups=trial_ids, method='predict_proba') +oof_fd = cross_val_predict(lda_fd, X_fd, y, cv=gkf, groups=trial_ids, method='predict_proba') +oof_cc = cross_val_predict(lda_cc, X_cc, y, cv=gkf, groups=trial_ids, method='predict_proba') + +# Specialist CV accuracy (for diagnostics) +for name, mdl, Xs in [('LDA_TD', lda_td, X_td), + ('LDA_FD', lda_fd, X_fd), + ('LDA_CC', lda_cc, X_cc)]: + sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids) + print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%") + +# --- Train meta-LDA on out-of-fold outputs ------------------------------------ +X_meta = np.hstack([oof_td, oof_fd, oof_cc]) # (n_samples, 3*n_cls = 15) +meta_lda = LinearDiscriminantAnalysis() +meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids) +print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%") + +# Fit all models on full dataset for deployment +lda_td.fit(X_td, y) +lda_fd.fit(X_fd, y) 
+lda_cc.fit(X_cc, y) +meta_lda.fit(X_meta, y) + +# --- Export all weights to C header ------------------------------------------ +def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order): + """Generate C array strings for LDA weights and intercepts. + + NOTE: sklearn LDA.coef_ for multi-class has shape (n_classes-1, n_features) + when using SVD solver. If so, we use decision_function and re-derive weights. + """ + coef = lda.coef_ + intercept = lda.intercept_ + + if coef.shape[0] != n_cls: + # SVD solver returns (n_cls-1, n_feat); sklearn handles this internally + # via scalings_. We refit with 'lsqr' solver to get full coef matrix. + from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA2 + lda2 = LDA2(solver='lsqr') + # We can't refit here (no data) so just warn and pad with zeros + print(f" WARNING: {name} coef_ shape {coef.shape} != ({n_cls}, {feat_dim}). " + f"Padding with zeros. Refit with solver='lsqr' for full matrix.") + padded = np.zeros((n_cls, feat_dim)) + padded[:coef.shape[0]] = coef + coef = padded + padded_i = np.zeros(n_cls) + padded_i[:intercept.shape[0]] = intercept + intercept = padded_i + + lines = [] + lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{") + for c in class_order: + row = ', '.join(f'{v:.8f}f' for v in coef[c]) + lines.append(f" {{{row}}}, // {label_names[c]}") + lines.append("};") + lines.append(f"const float {name}_INTERCEPTS[{n_cls}] = {{") + intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order) + lines.append(f" {intercept_str}") + lines.append("};") + return '\n'.join(lines) + + +class_order = list(range(n_cls)) +out_path = Path(__file__).parent / 'EMG_Arm/src/core/model_weights_ensemble.h' +out_path.parent.mkdir(parents=True, exist_ok=True) + +td_offset = min(td_idx) if td_idx else 0 +fd_offset = min(fd_idx) if fd_idx else 0 +cc_offset = min(cc_idx) if cc_idx else 0 + +with open(out_path, 'w') as f: + f.write("// Auto-generated by train_ensemble.py — do 
not edit\n") + f.write("#pragma once\n\n") + f.write("// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from\n") + f.write("// model_weights.h to avoid redefinition conflicts.\n") + f.write('#include "model_weights.h"\n\n') + f.write(f"#define ENSEMBLE_PER_CH_FEATURES 20\n\n") + f.write(f"#define TD_FEAT_OFFSET {td_offset}\n") + f.write(f"#define TD_NUM_FEATURES {len(td_idx)}\n") + f.write(f"#define FD_FEAT_OFFSET {fd_offset}\n") + f.write(f"#define FD_NUM_FEATURES {len(fd_idx)}\n") + f.write(f"#define CC_FEAT_OFFSET {cc_offset}\n") + f.write(f"#define CC_NUM_FEATURES {len(cc_idx)}\n") + f.write(f"#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)\n\n") + + f.write("// Feature index arrays for gather operations (TD and FD are non-contiguous)\n") + f.write(f"// TD indices: {td_idx}\n") + f.write(f"// FD indices: {fd_idx}\n") + f.write(f"// CC indices: {cc_idx}\n\n") + + f.write(lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order)) + f.write('\n\n') + f.write(lda_to_c_arrays(meta_lda, 'META_LDA', 3 * n_cls, n_cls, label_names, class_order)) + f.write('\n') + +print(f"Exported ensemble weights to {out_path}") +print(f"Total weight storage: " + f"{(len(td_idx) + len(fd_idx) + len(cc_idx) + 3*n_cls) * n_cls * 4} bytes float32") + +# --- Also save sklearn models for laptop-side inference ---------------------- +import joblib +ensemble_bundle = { + 'lda_td': lda_td, + 'lda_fd': lda_fd, + 'lda_cc': lda_cc, + 'meta_lda': meta_lda, + 'td_idx': td_idx, + 'fd_idx': fd_idx, + 'cc_idx': cc_idx, + 'label_names': label_names, +} +ensemble_joblib = Path(__file__).parent / 'models' / 'emg_ensemble.joblib' +ensemble_joblib.parent.mkdir(parents=True, exist_ok=True) +joblib.dump(ensemble_bundle, ensemble_joblib) +print(f"Saved laptop 
# --- new file: train_mlp_tflite.py (reconstructed from patch hunk) -----------
"""
Train an int8-quantized MLP for ESP32-S3 deployment via TFLite Micro.

Run AFTER Change 0 (label shift) + Change 1 (expanded features).

Change E — priority Tier 3.
Outputs:
    EMG_Arm/src/core/emg_model_data.cc/.h  — int8 TFLite flatbuffer as a C array
    models/emg_mlp_weights.npz             — float weights for laptop-side inference
"""
import numpy as np
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS

try:
    import tensorflow as tf
except ImportError:
    print("ERROR: TensorFlow not installed. Run: pip install tensorflow")
    sys.exit(1)

# --- Load and extract features -----------------------------------------------
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()

extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float32)

# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py):
# average the per-class means so over-represented classes don't dominate the session offset.
for sid in np.unique(session_indices):
    mask = session_indices == sid
    X_sess = X[mask]
    y_sess = y[mask]
    class_means = [X_sess[y_sess == cls].mean(axis=0) for cls in np.unique(y_sess)]
    balanced_mean = np.mean(class_means, axis=0)
    std = X_sess.std(axis=0)
    std[std < 1e-12] = 1.0  # guard against zero-variance features
    X[mask] = (X_sess - balanced_mean) / std

n_feat = X.shape[1]
n_cls = len(np.unique(y))
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
print(f"Classes: {label_names}")

# --- Build and train MLP -----------------------------------------------------
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(n_feat,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(n_cls, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# Shuffle before fitting: Keras' validation_split takes the LAST fraction of the
# data as-is, and our samples arrive grouped by session/trial — without this
# permutation the validation slice would be a biased tail (possibly one session
# or one class). Seeded for reproducibility.
perm = np.random.default_rng(42).permutation(len(X))
X_shuf, y_shuf = X[perm], y[perm]
model.fit(X_shuf, y_shuf, epochs=150, batch_size=64, validation_split=0.1, verbose=1)

# --- Convert to int8 TFLite --------------------------------------------------
def representative_dataset():
    # Every 10th sample is enough for the quantizer to calibrate activation ranges.
    for i in range(0, len(X), 10):
        yield [X[i:i+1]]

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# --- Write emg_model_data.cc -------------------------------------------------
out_cc = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.cc'
out_h = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.h'

with open(out_cc, 'w') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#include "emg_model_data.h"\n')
    f.write(f'const int g_model_len = {len(tflite_model)};\n')
    f.write('alignas(8) const unsigned char g_model[] = {\n    ')
    # Wrap at 12 bytes per line: a single-line dump of the whole flatbuffer is
    # hundreds of KB wide and chokes editors/diff tools.
    hex_bytes = [f'0x{b:02x}' for b in tflite_model]
    rows = [', '.join(hex_bytes[i:i+12]) for i in range(0, len(hex_bytes), 12)]
    f.write(',\n    '.join(rows))
    f.write('\n};\n')

with open(out_h, 'w') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#pragma once\n\n')
    f.write('#ifdef __cplusplus\nextern "C" {\n#endif\n\n')
    f.write('extern const unsigned char g_model[];\n')
    f.write('extern const int g_model_len;\n\n')
    f.write('#ifdef __cplusplus\n}\n#endif\n')

print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
print(f"Wrote {out_h}")

# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
# get_weights() returns [] for Dropout, so the falsy filter keeps only the
# three Dense layers, in order.
layer_weights = [layer.get_weights() for layer in model.layers if layer.get_weights()]
mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
mlp_path.parent.mkdir(parents=True, exist_ok=True)
np.savez(mlp_path,
         w0=layer_weights[0][0], b0=layer_weights[0][1],  # Dense(32, relu)
         w1=layer_weights[1][0], b1=layer_weights[1][1],  # Dense(16, relu)
         w2=layer_weights[2][0], b2=layer_weights[2][1],  # Dense(n_cls, softmax)
         label_names=np.array(label_names))
print(f"Saved laptop MLP weights to {mlp_path}")

print(f"\nNext steps:")
print(f"  1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
print(f"  2. Run: pio run -t upload")