Multi-model prediction on laptop | need to move to on-board prediction

This commit is contained in:
Surya Balaji
2026-03-10 11:39:02 -05:00
parent 90217a1365
commit 349bcffc71
64 changed files with 8461 additions and 1127 deletions

2736
BUCKY_ARM_MASTER_PLAN.md Normal file

File diff suppressed because it is too large Load Diff

10
EMG_Arm/dependencies.lock Normal file
View File

@@ -0,0 +1,10 @@
dependencies:
idf:
source:
type: idf
version: 5.5.1
direct_dependencies:
- idf
manifest_hash: 26c322f28d1cb305b28c1bcb6df69caf3919f0c18286c9ac6394e338c217fba8
target: esp32s3
version: 2.0.0

10
EMG_Arm/idf_component.yml Normal file
View File

@@ -0,0 +1,10 @@
# ESP-IDF Component Manager dependencies
# These are ESP-IDF components (NOT PlatformIO libraries).
# Run `pio run --target menuconfig` after modifying to refresh the component index.
dependencies:
# esp-dsp: required when MODEL_EXPAND_FEATURES=1 (Change 1 — 69-feature FFT)
espressif/esp-dsp: ">=2.0.0"
# TFLite Micro: required when MODEL_USE_MLP=1 (Change E — int8 MLP)
# tensorflow/tflite-micro: ">=2.0.0"

View File

@@ -14,4 +14,7 @@ board_build.partitions = partitions.csv
monitor_speed = 921600 monitor_speed = 921600
monitor_dtr = 1 monitor_dtr = 1
monitor_rts = 1 monitor_rts = 1
; ── esp-dsp: required for MODEL_EXPAND_FEATURES=1 (FFT-based features) ───────
; Cloned locally: components/esp-dsp

View File

@@ -594,6 +594,14 @@ CONFIG_PARTITION_TABLE_OFFSET=0x8000
CONFIG_PARTITION_TABLE_MD5=y CONFIG_PARTITION_TABLE_MD5=y
# end of Partition Table # end of Partition Table
#
# ESP-NN
#
# CONFIG_NN_ANSI_C is not set
CONFIG_NN_OPTIMIZED=y
CONFIG_NN_OPTIMIZATIONS=1
# end of ESP-NN
# #
# Compiler options # Compiler options
# #
@@ -2237,6 +2245,23 @@ CONFIG_WIFI_PROV_AUTOSTOP_TIMEOUT=30
CONFIG_WIFI_PROV_STA_ALL_CHANNEL_SCAN=y CONFIG_WIFI_PROV_STA_ALL_CHANNEL_SCAN=y
# CONFIG_WIFI_PROV_STA_FAST_SCAN is not set # CONFIG_WIFI_PROV_STA_FAST_SCAN is not set
# end of Wi-Fi Provisioning Manager # end of Wi-Fi Provisioning Manager
#
# DSP Library
#
CONFIG_DSP_OPTIMIZATIONS_SUPPORTED=y
# CONFIG_DSP_ANSI is not set
CONFIG_DSP_OPTIMIZED=y
CONFIG_DSP_OPTIMIZATION=1
# CONFIG_DSP_MAX_FFT_SIZE_512 is not set
# CONFIG_DSP_MAX_FFT_SIZE_1024 is not set
# CONFIG_DSP_MAX_FFT_SIZE_2048 is not set
CONFIG_DSP_MAX_FFT_SIZE_4096=y
# CONFIG_DSP_MAX_FFT_SIZE_8192 is not set
# CONFIG_DSP_MAX_FFT_SIZE_16384 is not set
# CONFIG_DSP_MAX_FFT_SIZE_32768 is not set
CONFIG_DSP_MAX_FFT_SIZE=4096
# end of DSP Library
# end of Component config # end of Component config
# CONFIG_IDF_EXPERIMENTAL_FEATURES is not set # CONFIG_IDF_EXPERIMENTAL_FEATURES is not set

View File

@@ -19,8 +19,13 @@ set(DRIVER_SOURCES
) )
set(CORE_SOURCES set(CORE_SOURCES
core/bicep.c
core/calibration.c
core/gestures.c core/gestures.c
core/inference.c core/inference.c
core/inference_ensemble.c
core/inference_mlp.cc
core/emg_model_data.cc
) )
set(APP_SOURCES set(APP_SOURCES
@@ -36,5 +41,5 @@ idf_component_register(
${APP_SOURCES} ${APP_SOURCES}
INCLUDE_DIRS INCLUDE_DIRS
. .
REQUIRES esp_adc REQUIRES esp_adc nvs_flash esp-dsp esp-tflite-micro
) )

View File

@@ -20,8 +20,13 @@
#include <string.h> #include <string.h>
#include "config/config.h" #include "config/config.h"
#include "core/bicep.h"
#include "core/calibration.h"
#include "core/gestures.h" #include "core/gestures.h"
#include "core/inference.h" // [NEW] #include "core/inference.h"
#include "core/inference_ensemble.h"
#include "core/inference_mlp.h"
#include "core/model_weights.h"
#include "drivers/emg_sensor.h" #include "drivers/emg_sensor.h"
#include "drivers/hand.h" #include "drivers/hand.h"
@@ -40,10 +45,12 @@
* @brief Device state machine. * @brief Device state machine.
*/ */
typedef enum { typedef enum {
STATE_IDLE = 0, /**< Waiting for connect command */ STATE_IDLE = 0, /**< Waiting for connect command */
STATE_CONNECTED, /**< Connected, waiting for start command */ STATE_CONNECTED, /**< Connected, waiting for start command */
STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */ STATE_STREAMING, /**< Streaming raw EMG CSV to laptop (data collection) */
STATE_PREDICTING, /**< [NEW] On-device inference and control */ STATE_PREDICTING, /**< On-device inference + arm control */
STATE_LAPTOP_PREDICT, /**< Streaming CSV to laptop; laptop infers + sends gesture cmds back */
STATE_CALIBRATING, /**< Collecting rest data for calibration */
} device_state_t; } device_state_t;
/** /**
@@ -52,8 +59,10 @@ typedef enum {
typedef enum { typedef enum {
CMD_NONE = 0, CMD_NONE = 0,
CMD_CONNECT, CMD_CONNECT,
CMD_START, /**< Start raw streaming */ CMD_START, /**< Start raw ADC streaming to laptop */
CMD_START_PREDICT, /**< [NEW] Start on-device prediction */ CMD_START_PREDICT, /**< Start on-device inference + arm control */
CMD_START_LAPTOP_PREDICT, /**< Start laptop-mediated inference (stream + receive cmds) */
CMD_CALIBRATE, /**< Run rest calibration (z-score + bicep threshold) */
CMD_STOP, CMD_STOP,
CMD_DISCONNECT, CMD_DISCONNECT,
} command_t; } command_t;
@@ -65,11 +74,17 @@ typedef enum {
static volatile device_state_t g_device_state = STATE_IDLE; static volatile device_state_t g_device_state = STATE_IDLE;
static QueueHandle_t g_cmd_queue = NULL; static QueueHandle_t g_cmd_queue = NULL;
// Latest gesture command received from laptop during STATE_LAPTOP_PREDICT.
// Written by serial_input_task; read+cleared by run_laptop_predict_loop.
// gesture_t is a 32-bit int on LX7 — reads/writes are atomic; volatile is sufficient.
static volatile gesture_t g_laptop_gesture = GESTURE_NONE;
/******************************************************************************* /*******************************************************************************
* Forward Declarations * Forward Declarations
******************************************************************************/ ******************************************************************************/
static void send_ack_connect(void); static void send_ack_connect(void);
static gesture_t parse_laptop_gesture(const char *line);
/******************************************************************************* /*******************************************************************************
* Command Parsing * Command Parsing
@@ -102,13 +117,17 @@ static command_t parse_command(const char *line) {
value_start++; value_start++;
} }
/* Match command strings */ /* Match command strings — ordered longest-prefix-first to avoid false matches */
if (strncmp(value_start, "connect", 7) == 0) { if (strncmp(value_start, "connect", 7) == 0) {
return CMD_CONNECT; return CMD_CONNECT;
} else if (strncmp(value_start, "start_laptop_predict", 20) == 0) {
return CMD_START_LAPTOP_PREDICT;
} else if (strncmp(value_start, "start_predict", 13) == 0) { } else if (strncmp(value_start, "start_predict", 13) == 0) {
return CMD_START_PREDICT; return CMD_START_PREDICT;
} else if (strncmp(value_start, "start", 5) == 0) { } else if (strncmp(value_start, "start", 5) == 0) {
return CMD_START; return CMD_START;
} else if (strncmp(value_start, "calibrate", 9) == 0) {
return CMD_CALIBRATE;
} else if (strncmp(value_start, "stop", 4) == 0) { } else if (strncmp(value_start, "stop", 4) == 0) {
return CMD_STOP; return CMD_STOP;
} else if (strncmp(value_start, "disconnect", 10) == 0) { } else if (strncmp(value_start, "disconnect", 10) == 0) {
@@ -118,6 +137,38 @@ static command_t parse_command(const char *line) {
return CMD_NONE; return CMD_NONE;
} }
/*******************************************************************************
* Laptop Gesture Parser
******************************************************************************/
/**
* @brief Parse a gesture command sent by live_predict.py.
*
* Expected format: {"gesture":"fist"}
* Returns GESTURE_NONE if the line is not a valid gesture command.
*/
static gesture_t parse_laptop_gesture(const char *line) {
const char *g = strstr(line, "\"gesture\"");
if (!g) return GESTURE_NONE;
const char *v = strchr(g, ':');
if (!v) return GESTURE_NONE;
v++;
while (*v == ' ' || *v == '"') v++;
// Extract gesture name up to the closing quote
char name[32] = {0};
int ni = 0;
while (*v && *v != '"' && ni < (int)(sizeof(name) - 1)) {
name[ni++] = *v++;
}
// Delegate to the inference module's name→enum mapping
int result = inference_get_gesture_by_name(name);
return (gesture_t)result;
}
/******************************************************************************* /*******************************************************************************
* Serial Input Task * Serial Input Task
******************************************************************************/ ******************************************************************************/
@@ -142,6 +193,15 @@ static void serial_input_task(void *pvParameters) {
line_buffer[line_idx] = '\0'; line_buffer[line_idx] = '\0';
command_t cmd = parse_command(line_buffer); command_t cmd = parse_command(line_buffer);
// When laptop-predict is active, try to parse gesture commands first.
// These are separate from FSM commands and must not block the FSM parser.
if (g_device_state == STATE_LAPTOP_PREDICT) {
gesture_t g = parse_laptop_gesture(line_buffer);
if (g != GESTURE_NONE) {
g_laptop_gesture = g;
}
}
if (cmd != CMD_NONE) { if (cmd != CMD_NONE) {
if (cmd == CMD_CONNECT) { if (cmd == CMD_CONNECT) {
g_device_state = STATE_CONNECTED; g_device_state = STATE_CONNECTED;
@@ -161,6 +221,15 @@ static void serial_input_task(void *pvParameters) {
g_device_state = STATE_PREDICTING; g_device_state = STATE_PREDICTING;
printf("[STATE] CONNECTED -> PREDICTING\n"); printf("[STATE] CONNECTED -> PREDICTING\n");
xQueueSend(g_cmd_queue, &cmd, 0); xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_START_LAPTOP_PREDICT) {
g_device_state = STATE_LAPTOP_PREDICT;
g_laptop_gesture = GESTURE_NONE;
printf("[STATE] CONNECTED -> LAPTOP_PREDICT\n");
xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_CALIBRATE) {
g_device_state = STATE_CALIBRATING;
printf("[STATE] CONNECTED -> CALIBRATING\n");
xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_DISCONNECT) { } else if (cmd == CMD_DISCONNECT) {
g_device_state = STATE_IDLE; g_device_state = STATE_IDLE;
printf("[STATE] CONNECTED -> IDLE\n"); printf("[STATE] CONNECTED -> IDLE\n");
@@ -169,6 +238,8 @@ static void serial_input_task(void *pvParameters) {
case STATE_STREAMING: case STATE_STREAMING:
case STATE_PREDICTING: case STATE_PREDICTING:
case STATE_LAPTOP_PREDICT:
case STATE_CALIBRATING:
if (cmd == CMD_STOP) { if (cmd == CMD_STOP) {
g_device_state = STATE_CONNECTED; g_device_state = STATE_CONNECTED;
printf("[STATE] ACTIVE -> CONNECTED\n"); printf("[STATE] ACTIVE -> CONNECTED\n");
@@ -206,59 +277,198 @@ static void send_ack_connect(void) {
*/ */
static void stream_emg_data(void) { static void stream_emg_data(void) {
emg_sample_t sample; emg_sample_t sample;
const TickType_t delay_ticks = 1;
while (g_device_state == STATE_STREAMING) { while (g_device_state == STATE_STREAMING) {
emg_sensor_read(&sample); emg_sensor_read(&sample); // blocks ~1 ms (queue-paced by DMA task)
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1], printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
sample.channels[2], sample.channels[3]); sample.channels[2], sample.channels[3]);
vTaskDelay(delay_ticks);
} }
} }
/*******************************************************************************
 * Multi-Model Voting Post-Processing
 *
 * Shared EMA smoothing, majority vote, and debounce applied to the averaged
 * probability output from all enabled models (single LDA, ensemble, MLP).
 ******************************************************************************/

#define VOTE_EMA_ALPHA 0.70f      /* EMA weight kept on the previous estimate */
#define VOTE_CONF_THRESHOLD 0.40f /* minimum smoothed prob. to accept a frame */
#define VOTE_WINDOW_SIZE 5        /* majority-vote ring-buffer length (hops) */
#define VOTE_DEBOUNCE_COUNT 3     /* repeats required before output switches */

static float vote_smoothed[MODEL_NUM_CLASSES]; /* EMA-smoothed probabilities */
static int vote_history[VOTE_WINDOW_SIZE];     /* recent winners, -1 = empty */
static int vote_head = 0;                      /* ring-buffer write index */
static int vote_current_output = -1;           /* published class, -1 = none */
static int vote_pending_output = -1;           /* challenger being debounced */
static int vote_pending_count = 0;             /* consecutive challenger hits */

/** @brief Reset all voting state: uniform prior, empty history, no decision. */
static void vote_init(void) {
    const float uniform = 1.0f / MODEL_NUM_CLASSES;
    int i;
    for (i = 0; i < MODEL_NUM_CLASSES; i++) {
        vote_smoothed[i] = uniform;
    }
    for (i = 0; i < VOTE_WINDOW_SIZE; i++) {
        vote_history[i] = -1;
    }
    vote_head = 0;
    vote_current_output = -1;
    vote_pending_output = -1;
    vote_pending_count = 0;
}
/**
 * @brief Apply EMA + majority vote + debounce to averaged probabilities.
 *
 * Stateful: relies on the module-level vote_* variables, so vote_init() must
 * be called before the first invocation of each prediction session.
 *
 * Pipeline per hop:
 *   1. EMA-smooth the incoming probability vector (VOTE_EMA_ALPHA on history).
 *   2. Reject the frame if the smoothed winner is below VOTE_CONF_THRESHOLD:
 *      the previously published output is held and the history is NOT updated.
 *   3. Push the winner into a VOTE_WINDOW_SIZE ring buffer; take the majority.
 *   4. Debounce: only switch the published output after the same challenger
 *      majority is seen VOTE_DEBOUNCE_COUNT consecutive confident hops.
 *
 * @param avg_proba Averaged probability vector from all models.
 * @param confidence Output: smoothed confidence of the final winner.
 * @return Final gesture class index (-1 if uncertain).
 */
static int vote_postprocess(const float *avg_proba, float *confidence) {
    /* EMA smoothing — also track the smoothed argmax in the same pass. */
    float max_smooth = 0.0f;
    int winner = 0;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        vote_smoothed[c] = VOTE_EMA_ALPHA * vote_smoothed[c] +
                           (1.0f - VOTE_EMA_ALPHA) * avg_proba[c];
        if (vote_smoothed[c] > max_smooth) {
            max_smooth = vote_smoothed[c];
            winner = c;
        }
    }
    /* Confidence rejection: hold the current output, skip vote/debounce. */
    if (max_smooth < VOTE_CONF_THRESHOLD) {
        *confidence = max_smooth;
        return vote_current_output;
    }
    *confidence = max_smooth;
    /* Majority vote over the last VOTE_WINDOW_SIZE confident winners. */
    vote_history[vote_head] = winner;
    vote_head = (vote_head + 1) % VOTE_WINDOW_SIZE;
    int counts[MODEL_NUM_CLASSES];
    memset(counts, 0, sizeof(counts));
    for (int i = 0; i < VOTE_WINDOW_SIZE; i++) {
        if (vote_history[i] >= 0)
            counts[vote_history[i]]++;
    }
    int majority = 0, majority_cnt = 0;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        if (counts[c] > majority_cnt) {
            majority_cnt = counts[c];
            majority = c;
        }
    }
    /* Debounce */
    int final_out = (vote_current_output >= 0) ? vote_current_output : majority;
    if (vote_current_output < 0) {
        /* First confident decision of the session: adopt it immediately. */
        vote_current_output = majority;
        final_out = majority;
    } else if (majority == vote_current_output) {
        /* Majority agrees with the published output: re-arm the pending
         * slot at the current class, so a later switch needs a fresh run
         * of VOTE_DEBOUNCE_COUNT hits for the new class. */
        vote_pending_output = majority;
        vote_pending_count = 1;
    } else if (majority == vote_pending_output) {
        /* Same challenger repeated: switch once it has been seen enough. */
        if (++vote_pending_count >= VOTE_DEBOUNCE_COUNT) {
            vote_current_output = majority;
            final_out = majority;
        }
    } else {
        /* A different challenger appeared: restart its count. */
        vote_pending_output = majority;
        vote_pending_count = 1;
    }
    return final_out;
}
/** /**
* @brief Run on-device inference (Prediction Mode). * @brief Run on-device inference (Prediction Mode).
*
* Uses multi-model voting: all enabled models (single LDA, ensemble, MLP)
* produce raw probability vectors which are averaged, then passed through
* shared EMA smoothing, majority vote, and debounce.
*/ */
static void run_inference_loop(void) { static void run_inference_loop(void) {
emg_sample_t sample; emg_sample_t sample;
const TickType_t delay_ticks = 1; // 1ms @ 1kHz
int last_gesture = -1; int last_gesture = -1;
int stride_counter = 0;
// Reset inference state
inference_init(); inference_init();
printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n"); vote_init();
int n_models = 1; /* single LDA always enabled */
#if MODEL_USE_ENSEMBLE
inference_ensemble_init();
n_models++;
#endif
#if MODEL_USE_MLP
inference_mlp_init();
n_models++;
#endif
printf("{\"status\":\"info\",\"msg\":\"Multi-model inference (%d models)\"}\n",
n_models);
while (g_device_state == STATE_PREDICTING) { while (g_device_state == STATE_PREDICTING) {
emg_sensor_read(&sample); emg_sensor_read(&sample);
// Add to buffer
// Note: sample.channels is uint16_t, matching inference engine expectation
if (inference_add_sample(sample.channels)) { if (inference_add_sample(sample.channels)) {
// Buffer full (sliding window), run prediction
// We can optimize stride here (e.g. valid prediction only every N
// samples) For now, let's predict every sample (sliding window) or
// throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz?
// maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid
// UART spam.
static int stride_counter = 0;
stride_counter++; stride_counter++;
if (stride_counter >= 20) { // 20ms stride if (stride_counter >= INFERENCE_HOP_SIZE) {
float confidence = 0;
int gesture_idx = inference_predict(&confidence);
stride_counter = 0; stride_counter = 0;
if (gesture_idx >= 0) { /* 1. Extract features once */
// Map class index (0-N) to gesture enum (correct hardware action) float features[MODEL_NUM_FEATURES];
int gesture_enum = inference_get_gesture_enum(gesture_idx); inference_extract_features(features);
// Execute gesture on hand /* 2. Collect raw probabilities from each model */
float avg_proba[MODEL_NUM_CLASSES];
memset(avg_proba, 0, sizeof(avg_proba));
/* Model A: Single LDA (always active) */
float proba_lda[MODEL_NUM_CLASSES];
inference_predict_raw(features, proba_lda);
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
avg_proba[c] += proba_lda[c];
#if MODEL_USE_ENSEMBLE
/* Model B: 3-specialist + meta-LDA ensemble */
float proba_ens[MODEL_NUM_CLASSES];
inference_ensemble_predict_raw(features, proba_ens);
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
avg_proba[c] += proba_ens[c];
#endif
#if MODEL_USE_MLP
/* Model C: int8 MLP — returns class index + confidence.
* Spread confidence as soft one-hot for averaging. */
float mlp_conf = 0.0f;
int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES,
&mlp_conf);
float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1);
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
avg_proba[c] += (c == mlp_class) ? mlp_conf : remainder;
#endif
/* 3. Average across models */
float inv_n = 1.0f / n_models;
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
avg_proba[c] *= inv_n;
/* 4. Shared post-processing */
float confidence = 0.0f;
int gesture_idx = vote_postprocess(avg_proba, &confidence);
if (gesture_idx >= 0) {
int gesture_enum = inference_get_gesture_enum(gesture_idx);
gestures_execute((gesture_t)gesture_enum); gestures_execute((gesture_t)gesture_enum);
// Send telemetry if changed or periodically? bicep_state_t bicep = bicep_detect();
// "Live prediction flow should change to only have each new output... (void)bicep;
// sent"
if (gesture_idx != last_gesture) { if (gesture_idx != last_gesture) {
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n", printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
inference_get_class_name(gesture_idx), confidence); inference_get_class_name(gesture_idx), confidence);
@@ -267,11 +477,192 @@ static void run_inference_loop(void) {
} }
} }
} }
vTaskDelay(delay_ticks);
} }
} }
/**
* @brief Laptop-mediated prediction loop (STATE_LAPTOP_PREDICT).
*
* Streams raw ADC CSV to laptop at 1kHz (same format as STATE_STREAMING).
* The laptop runs live_predict.py, which classifies each window and sends
* {"gesture":"<name>"} back over UART. serial_input_task() intercepts those
* lines and writes to g_laptop_gesture; this loop executes whatever is set.
*/
static void run_laptop_predict_loop(void) {
emg_sample_t sample;
gesture_t last_gesture = GESTURE_NONE;
printf("{\"status\":\"info\",\"msg\":\"Laptop-predict mode started\"}\n");
while (g_device_state == STATE_LAPTOP_PREDICT) {
emg_sensor_read(&sample);
// Stream raw CSV to laptop (live_predict.py reads this)
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
sample.channels[2], sample.channels[3]);
// Execute any gesture command that arrived from the laptop
gesture_t g = g_laptop_gesture;
if (g != GESTURE_NONE) {
g_laptop_gesture = GESTURE_NONE; // clear before executing
gestures_execute(g);
if (g != last_gesture) {
// Echo the executed gesture back for laptop-side logging
printf("{\"executed\":\"%s\"}\n", gestures_get_name(g));
last_gesture = g;
}
}
}
}
/**
 * @brief Fully autonomous inference loop (MAIN_MODE == EMG_STANDALONE).
 *
 * No laptop required. Runs forever until power-off.
 * Assumes inference_init() and sensor init have already been called by app_main().
 *
 * Per hop (every INFERENCE_HOP_SIZE samples once the window is full):
 * extract one feature vector, average the raw probability vectors of all
 * enabled models, route the result through vote_postprocess(), then execute
 * the winning gesture and print telemetry on change.
 */
static void run_standalone_loop(void) {
    emg_sample_t sample;
    int stride_counter = 0;    /* samples since the last prediction hop */
    int last_gesture = -1;     /* last printed class index, de-dupes telemetry */
    inference_init();
    vote_init();
    int n_models = 1;          /* single LDA is always enabled */
#if MODEL_USE_ENSEMBLE
    inference_ensemble_init();
    n_models++;
#endif
#if MODEL_USE_MLP
    inference_mlp_init();
    n_models++;
#endif
    printf("[STANDALONE] Running autonomous EMG control (%d models).\n", n_models);
    while (1) {
        emg_sensor_read(&sample);
        /* inference_add_sample() returns nonzero once a full window exists. */
        if (inference_add_sample(sample.channels)) {
            stride_counter++;
            if (stride_counter >= INFERENCE_HOP_SIZE) {
                stride_counter = 0;
                float features[MODEL_NUM_FEATURES];
                inference_extract_features(features);
                /* Accumulate per-model probability vectors, then average. */
                float avg_proba[MODEL_NUM_CLASSES];
                memset(avg_proba, 0, sizeof(avg_proba));
                /* Model A: single LDA (always active). */
                float proba_lda[MODEL_NUM_CLASSES];
                inference_predict_raw(features, proba_lda);
                for (int c = 0; c < MODEL_NUM_CLASSES; c++)
                    avg_proba[c] += proba_lda[c];
#if MODEL_USE_ENSEMBLE
                /* Model B: ensemble raw probabilities. */
                float proba_ens[MODEL_NUM_CLASSES];
                inference_ensemble_predict_raw(features, proba_ens);
                for (int c = 0; c < MODEL_NUM_CLASSES; c++)
                    avg_proba[c] += proba_ens[c];
#endif
#if MODEL_USE_MLP
                /* Model C: MLP reports only (class, confidence); spread the
                 * residual mass uniformly over the other classes to form a
                 * soft one-hot vector for averaging.
                 * NOTE(review): divides by MODEL_NUM_CLASSES - 1 — assumes at
                 * least 2 classes, and that mlp_class is a valid index (TODO
                 * confirm inference_mlp_predict() cannot return < 0). */
                float mlp_conf = 0.0f;
                int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES,
                                                      &mlp_conf);
                float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1);
                for (int c = 0; c < MODEL_NUM_CLASSES; c++)
                    avg_proba[c] += (c == mlp_class) ? mlp_conf : remainder;
#endif
                float inv_n = 1.0f / n_models;
                for (int c = 0; c < MODEL_NUM_CLASSES; c++)
                    avg_proba[c] *= inv_n;
                /* Shared EMA + majority vote + debounce. */
                float confidence = 0.0f;
                int gesture_idx = vote_postprocess(avg_proba, &confidence);
                if (gesture_idx >= 0) {
                    int gesture_enum = inference_get_gesture_enum(gesture_idx);
                    gestures_execute((gesture_t)gesture_enum);
                    /* Bicep state is sampled but its result is discarded here. */
                    bicep_state_t bicep = bicep_detect();
                    (void)bicep;
                    /* Emit telemetry only when the output class changes. */
                    if (gesture_idx != last_gesture) {
                        printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
                               inference_get_class_name(gesture_idx), confidence);
                        last_gesture = gesture_idx;
                    }
                }
            }
        }
    }
}
/**
 * @brief Run rest calibration (STATE_CALIBRATING).
 *
 * Collects 3 seconds of rest EMG data, extracts features from each window,
 * then updates:
 *   - NVS z-score calibration (Change D) via calibration_update()
 *   - Bicep detection threshold via bicep_calibrate_from_buffer()
 *
 * The user must keep their arm relaxed during this period.
 * Send {"cmd": "calibrate"} from the host to trigger.
 */
static void run_calibration(void) {
#define CALIB_DURATION_SAMPLES 3000 /* 3 seconds at 1 kHz */
#define CALIB_MAX_WINDOWS \
  ((CALIB_DURATION_SAMPLES - INFERENCE_WINDOW_SIZE) / INFERENCE_HOP_SIZE + 1)

    printf("{\"status\":\"calibrating\",\"duration_ms\":3000}\n");
    fflush(stdout);

    inference_init(); /* reset buffer + filter state */

    /* One feature row per hop-aligned window, capped at CALIB_MAX_WINDOWS. */
    float *features = (float *)malloc(
        (size_t)CALIB_MAX_WINDOWS * MODEL_NUM_FEATURES * sizeof(float));
    if (features == NULL) {
        printf("{\"status\":\"error\",\"msg\":\"Calibration malloc failed\"}\n");
        g_device_state = STATE_CONNECTED;
        return;
    }

    emg_sample_t sample;
    int n_windows = 0;
    int hop = 0;
    int remaining = CALIB_DURATION_SAMPLES;
    while (remaining-- > 0) {
        emg_sensor_read(&sample);
        if (!inference_add_sample(sample.channels)) {
            continue; /* window not yet full */
        }
        if (++hop < INFERENCE_HOP_SIZE) {
            continue; /* wait for the next hop boundary */
        }
        hop = 0;
        if (n_windows < CALIB_MAX_WINDOWS) {
            inference_extract_features(features + n_windows * MODEL_NUM_FEATURES);
            n_windows++;
        }
    }

    if (n_windows >= 10) {
        calibration_update(features, n_windows, MODEL_NUM_FEATURES);
        bicep_calibrate_from_buffer(INFERENCE_WINDOW_SIZE);
        printf("{\"status\":\"calibrated\",\"windows\":%d}\n", n_windows);
    } else {
        printf("{\"status\":\"error\",\"msg\":\"Not enough calibration data\"}\n");
    }

    free(features);
    g_device_state = STATE_CONNECTED;

#undef CALIB_DURATION_SAMPLES
#undef CALIB_MAX_WINDOWS
}
static void state_machine_loop(void) { static void state_machine_loop(void) {
command_t cmd; command_t cmd;
const TickType_t poll_interval = pdMS_TO_TICKS(50); const TickType_t poll_interval = pdMS_TO_TICKS(50);
@@ -281,6 +672,10 @@ static void state_machine_loop(void) {
stream_emg_data(); stream_emg_data();
} else if (g_device_state == STATE_PREDICTING) { } else if (g_device_state == STATE_PREDICTING) {
run_inference_loop(); run_inference_loop();
} else if (g_device_state == STATE_LAPTOP_PREDICT) {
run_laptop_predict_loop();
} else if (g_device_state == STATE_CALIBRATING) {
run_calibration();
} }
xQueueReceive(g_cmd_queue, &cmd, poll_interval); xQueueReceive(g_cmd_queue, &cmd, poll_interval);
@@ -299,7 +694,9 @@ void appConnector() {
printf("[PROTOCOL] Waiting for host to connect...\n"); printf("[PROTOCOL] Waiting for host to connect...\n");
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n"); printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device " printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
"inference\n\n"); "inference\n");
printf("[PROTOCOL] Send: {\"cmd\": \"calibrate\"} for rest "
"calibration (3s)\n\n");
state_machine_loop(); state_machine_loop();
} }
@@ -324,13 +721,108 @@ void app_main(void) {
printf("[INIT] Initializing Inference Engine...\n"); printf("[INIT] Initializing Inference Engine...\n");
inference_init(); inference_init();
#if FEATURE_FAKE_EMG printf("[INIT] Loading NVS calibration...\n");
printf("[INIT] Using FAKE EMG data (sensors not connected)\n"); calibration_init(); // Change D: no-op on first boot; loads if previously saved
#else
printf("[INIT] Using REAL EMG sensors\n");
#endif
// Bicep: load persisted threshold from NVS (if previously calibrated)
{
float bicep_thresh = 0.0f;
if (bicep_load_threshold(&bicep_thresh)) {
printf("[INIT] Bicep threshold loaded: %.1f\n", bicep_thresh);
} else {
printf("[INIT] No bicep calibration — run 'calibrate' command\n");
}
}
printf("[INIT] Using REAL EMG sensors\n");
printf("[INIT] Done!\n\n"); printf("[INIT] Done!\n\n");
appConnector(); switch (MAIN_MODE) {
} case EMG_STANDALONE:
// Fully autonomous: no laptop needed after this point.
// Boots directly into the inference + arm control loop.
run_standalone_loop(); // never returns
break;
case SERVO_CALIBRATOR:
while (1) {
int angle;
printf("Enter servo angle (0-180): ");
fflush(stdout);
// Read a line manually, yielding while waiting for UART input
char buf[16];
int idx = 0;
while (idx < (int)sizeof(buf) - 1) {
int ch = getchar();
if (ch == EOF) {
vTaskDelay(pdMS_TO_TICKS(10));
continue;
}
if (ch == '\n' || ch == '\r')
break;
buf[idx++] = (char)ch;
}
buf[idx] = '\0';
if (idx == 0)
continue;
if (sscanf(buf, "%d", &angle) == 1) {
if (angle >= 0 && angle <= 180) {
hand_set_finger_angle(FINGER_THUMB, angle);
vTaskDelay(pdMS_TO_TICKS(1000));
} else {
printf("Invalid angle. Must be between 0 and 180.\n");
}
} else {
printf("Invalid input.\n");
}
}
break;
case GESTURE_TESTER:
while (1) {
fflush(stdout);
int ch = getchar();
if (ch == EOF) {
vTaskDelay(pdMS_TO_TICKS(10));
continue;
}
if (ch == '\n' || ch == '\r') {
continue;
}
gesture_t gesture = GESTURE_NONE;
switch (ch) {
case 'r': gesture = GESTURE_REST; break;
case 'f': gesture = GESTURE_FIST; break;
case 'o': gesture = GESTURE_OPEN; break;
case 'h': gesture = GESTURE_HOOK_EM; break;
case 't': gesture = GESTURE_THUMBS_UP; break;
default:
printf("Invalid gesture: %c\n", ch);
continue;
}
printf("Executing gesture: %s\n", gestures_get_name(gesture));
gestures_execute(gesture);
vTaskDelay(pdMS_TO_TICKS(500));
}
break;
case EMG_MAIN:
appConnector();
break;
default:
printf("[ERROR] Unknown MAIN_MODE\n");
break;
}
}

View File

@@ -13,19 +13,12 @@
#include "driver/ledc.h" #include "driver/ledc.h"
/******************************************************************************* /*******************************************************************************
* Feature Flags * Main Modes
*
* Compile-time switches to enable/disable features.
* Set to 1 to enable, 0 to disable.
******************************************************************************/ ******************************************************************************/
/** enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER, EMG_STANDALONE};
* @brief Use fake EMG data (random values) instead of real ADC reads.
* #define MAIN_MODE EMG_MAIN
* Set to 1 while waiting for EMG sensors to arrive.
* Set to 0 when ready to use real sensors.
*/
#define FEATURE_FAKE_EMG 0
/******************************************************************************* /*******************************************************************************
* GPIO Pin Definitions - Servos * GPIO Pin Definitions - Servos
@@ -95,15 +88,4 @@ typedef enum {
GESTURE_COUNT GESTURE_COUNT
} gesture_t; } gesture_t;
/**
* @brief System operating modes.
*/
typedef enum {
MODE_IDLE = 0, /**< Waiting for commands */
MODE_DATA_STREAM, /**< Streaming EMG data to laptop */
MODE_COMMAND, /**< Executing gesture commands from laptop */
MODE_DEMO, /**< Running demo sequence */
MODE_COUNT
} system_mode_t;
#endif /* CONFIG_H */ #endif /* CONFIG_H */

142
EMG_Arm/src/core/bicep.c Normal file
View File

@@ -0,0 +1,142 @@
/**
* @file bicep.c
* @brief Bicep channel subsystem — binary flex/unflex detector (Phase 1).
*/
#include "bicep.h"
#include "inference.h" /* inference_get_bicep_rms() */
#include "nvs_flash.h"
#include "nvs.h"
#include <math.h>
#include <stdio.h>
#include <string.h>
/* Tuning constants */
#define BICEP_WINDOW_SAMPLES 50 /**< 50 ms window at 1 kHz */
#define BICEP_FLEX_MULTIPLIER 2.5f /**< threshold = rest_rms × 2.5 */
#define BICEP_HYSTERESIS 1.3f /**< scale factor to enter flex (prevents toggling) */
/* NVS storage */
#define BICEP_NVS_NAMESPACE "bicep_calib"
#define BICEP_NVS_KEY_THRESH "threshold"
#define BICEP_NVS_KEY_VALID "calib_ok"
/* Module state */
static float s_threshold_mv = 0.0f;
static bicep_state_t s_state = BICEP_STATE_REST;
/*******************************************************************************
* Public API
******************************************************************************/
/**
 * @brief Calibrate the bicep threshold from raw REST-period samples.
 *
 * Computes the RMS of the provided samples and sets the detection threshold
 * to rest_rms × BICEP_FLEX_MULTIPLIER, persisting it to NVS.
 *
 * @param ch3_samples Raw ADC / mV values from the bicep channel (may be NULL).
 * @param n_samples   Number of samples provided.
 * @return Computed threshold in the same units as ch3_samples (0.0f on bad input).
 */
float bicep_calibrate(const uint16_t *ch3_samples, int n_samples) {
    if (ch3_samples == NULL || n_samples <= 0) return 0.0f;
    /* Accumulate in double: squared uint16 values reach ~4.3e9, so a float
     * accumulator loses low-order bits once the running sum dwarfs each term.
     * Calibration is a one-shot path, so soft-float double cost is fine. */
    double rms_sq = 0.0;
    for (int i = 0; i < n_samples; i++) {
        double v = (double)ch3_samples[i];
        rms_sq += v * v;
    }
    float rest_rms = (float)sqrt(rms_sq / n_samples);
    s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER;
    s_state = BICEP_STATE_REST; /* new threshold starts from a known state */
    printf("[Bicep] Calibrated: rest_rms=%.1f mV, threshold=%.1f mV\n",
           rest_rms, s_threshold_mv);
    bicep_save_threshold(s_threshold_mv);
    return s_threshold_mv;
}
/* Classify the current bicep window as REST or FLEX using the latest RMS
 * from the shared inference buffer. Returns REST if never calibrated. */
bicep_state_t bicep_detect(void) {
    if (s_threshold_mv <= 0.0f) {
        return BICEP_STATE_REST; /* Not calibrated */
    }
    float rms = inference_get_bicep_rms(BICEP_WINDOW_SAMPLES);
    /* Hysteresis: RMS must exceed threshold × BICEP_HYSTERESIS (1.3) to enter
     * FLEX, but only has to drop below the bare threshold to return to REST,
     * so a signal hovering near the threshold cannot toggle the state. */
    if (s_state == BICEP_STATE_REST) {
        if (rms > s_threshold_mv * BICEP_HYSTERESIS) {
            s_state = BICEP_STATE_FLEX;
        }
    } else { /* BICEP_STATE_FLEX */
        if (rms < s_threshold_mv) {
            s_state = BICEP_STATE_REST;
        }
    }
    return s_state;
}
/**
 * @brief Persist the given threshold to NVS (blob + validity flag, committed).
 *
 * @param threshold_mv Threshold value to save (mV / same units as bicep RMS).
 * @return true on success, false if opening, writing, or committing failed.
 */
bool bicep_save_threshold(float threshold_mv) {
    nvs_handle_t h;
    if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) {
        printf("[Bicep] Failed to open NVS for write\n");
        return false;
    }
    /* Check each step individually: esp_err_t codes are not bit flags, so
     * OR-ing them together yields a meaningless combined value, and a failed
     * set should not be followed by a commit. */
    esp_err_t err = nvs_set_blob(h, BICEP_NVS_KEY_THRESH, &threshold_mv,
                                 sizeof(threshold_mv));
    if (err == ESP_OK) {
        err = nvs_set_u8(h, BICEP_NVS_KEY_VALID, 1u);
    }
    if (err == ESP_OK) {
        err = nvs_commit(h);
    }
    nvs_close(h);
    if (err != ESP_OK) {
        printf("[Bicep] NVS write failed (err=0x%x)\n", err);
        return false;
    }
    printf("[Bicep] Threshold %.1f mV saved to NVS\n", threshold_mv);
    return true;
}
/**
 * @brief Load the persisted threshold from NVS into the module state.
 *
 * Accepts the stored value only if the validity flag is set, the blob is
 * exactly sizeof(float), and the value is positive.
 *
 * @param threshold_mv_out Output pointer; untouched on failure (may be NULL).
 * @return true if a valid threshold was loaded.
 */
bool bicep_load_threshold(float *threshold_mv_out) {
    nvs_handle_t h;
    if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) {
        return false;
    }
    uint8_t valid = 0;
    float thresh = 0.0f;
    size_t sz = sizeof(thresh);
    /* sz is updated by nvs_get_blob to the stored blob's actual length; a
     * stale blob of a different size would otherwise leave thresh garbage. */
    bool ok = (nvs_get_u8 (h, BICEP_NVS_KEY_VALID, &valid) == ESP_OK) &&
              (valid == 1) &&
              (nvs_get_blob(h, BICEP_NVS_KEY_THRESH, &thresh, &sz) == ESP_OK) &&
              (sz == sizeof(thresh)) &&
              (thresh > 0.0f);
    nvs_close(h);
    if (ok) {
        s_threshold_mv = thresh;
        if (threshold_mv_out) *threshold_mv_out = thresh;
        printf("[Bicep] Loaded threshold: %.1f mV\n", thresh);
    }
    return ok;
}
/* Calibrate the threshold from the filtered inference buffer instead of raw
 * samples: read the rest RMS over n_samples, scale by BICEP_FLEX_MULTIPLIER,
 * persist, and return the new threshold (0.0f if the buffer looks empty). */
float bicep_calibrate_from_buffer(int n_samples) {
    float rms_at_rest = inference_get_bicep_rms(n_samples);
    /* An RMS of effectively zero means the filtered buffer has no data yet. */
    if (rms_at_rest < 1e-6f) {
        printf("[Bicep] WARNING: rest RMS ≈ 0 — buffer may not be filled yet\n");
        return 0.0f;
    }
    float new_threshold = rms_at_rest * BICEP_FLEX_MULTIPLIER;
    s_threshold_mv = new_threshold;
    s_state = BICEP_STATE_REST;
    printf("[Bicep] Calibrated (filtered): rest_rms=%.2f, threshold=%.2f\n",
           rms_at_rest, new_threshold);
    bicep_save_threshold(new_threshold);
    return new_threshold;
}
void bicep_set_threshold(float threshold_mv) {
s_threshold_mv = threshold_mv;
s_state = BICEP_STATE_REST;
}
/* Accessor: current detection threshold in mV (0.0f until calibrated). */
float bicep_get_threshold(void) {
    return s_threshold_mv;
}

97
EMG_Arm/src/core/bicep.h Normal file
View File

@@ -0,0 +1,97 @@
/**
 * @file bicep.h
 * @brief Bicep channel (ch3) subsystem — Phase 1: binary flex/unflex detector.
 *
 * Implements a simple RMS threshold detector with hysteresis for bicep activation.
 * ch3 data flows through the same IIR bandpass filter and circular buffer as the
 * hand gesture channels (via inference_get_bicep_rms()), so no separate ADC read
 * is required.
 *
 * Usage:
 *   1. On startup: bicep_load_threshold(&thresh) — restore persisted threshold
 *   2. After 3 s of relaxed rest (buffer filled via inference_add_sample()):
 *      bicep_calibrate_from_buffer(n_samples) — sets + saves threshold.
 *      (bicep_calibrate() works on RAW samples and is deprecated — see below.)
 *   3. Every 25 ms hop:
 *      bicep_state_t state = bicep_detect();
 */
#ifndef BICEP_H
#define BICEP_H
#include <stdint.h>
#include <stdbool.h>
/**
 * @brief Bicep activation state.
 */
typedef enum {
    BICEP_STATE_REST = 0,
    BICEP_STATE_FLEX = 1,
} bicep_state_t;
/**
 * @brief Calibrate bicep threshold from REST data.
 *
 * Computes rest-RMS over the provided samples, then sets the internal
 * detection threshold to rest_rms × BICEP_FLEX_MULTIPLIER.
 *
 * @deprecated Operates on RAW ADC values, which include DC offset and live in
 * a different domain than the filtered data bicep_detect() sees, so the
 * resulting thresholds are unusable. Prefer bicep_calibrate_from_buffer().
 *
 * @param ch3_samples Raw ADC / mV values from the bicep channel.
 * @param n_samples Number of samples provided.
 * @return Computed threshold in the same units as ch3_samples.
 */
float bicep_calibrate(const uint16_t *ch3_samples, int n_samples);
/**
 * @brief Detect current bicep state from the latest window.
 *
 * Uses inference_get_bicep_rms(BICEP_WINDOW_SAMPLES) internally, so
 * inference_add_sample() must have been called to fill the buffer first.
 *
 * @return BICEP_STATE_FLEX or BICEP_STATE_REST.
 */
bicep_state_t bicep_detect(void);
/**
 * @brief Persist the current threshold to NVS.
 *
 * @param threshold_mv Threshold value to save (in mV / same units as bicep RMS).
 * @return true on success.
 */
bool bicep_save_threshold(float threshold_mv);
/**
 * @brief Load the persisted threshold from NVS.
 *
 * @param threshold_mv_out Output pointer; untouched on failure.
 * @return true if a valid threshold was loaded.
 */
bool bicep_load_threshold(float *threshold_mv_out);
/**
 * @brief Set the detection threshold directly (without NVS save).
 *        Also resets the hysteresis state machine to REST.
 */
void bicep_set_threshold(float threshold_mv);
/**
 * @brief Return the current threshold (0 if not calibrated).
 */
float bicep_get_threshold(void);
/**
 * @brief Calibrate bicep threshold from the filtered inference buffer.
 *
 * Uses inference_get_bicep_rms() to read from the bandpass-filtered
 * circular buffer — the same data source that bicep_detect() uses.
 * The old bicep_calibrate() accepts raw uint16_t ADC values, which are
 * in a different domain (includes DC offset) and produce unusable thresholds.
 *
 * Call this after the inference buffer has been filled with ≥ n_samples
 * of rest data via inference_add_sample().
 *
 * @param n_samples Number of recent buffer samples to use for RMS.
 *                  Clamped to INFERENCE_WINDOW_SIZE internally.
 * @return Computed threshold (same units as bicep_detect sees).
 */
float bicep_calibrate_from_buffer(int n_samples);
#endif /* BICEP_H */

View File

@@ -0,0 +1,138 @@
/**
* @file calibration.c
* @brief NVS-backed z-score feature calibration (Change D).
*/
#include "calibration.h"
#include "nvs_flash.h"
#include "nvs.h"
#include <math.h>
#include <string.h>
#include <stdio.h>
#define NVS_NAMESPACE "emg_calib"
#define NVS_KEY_MEAN "feat_mean"
#define NVS_KEY_STD "feat_std"
#define NVS_KEY_NFEAT "n_feat"
#define NVS_KEY_VALID "calib_ok"
static float s_mean[CALIB_MAX_FEATURES];
static float s_std[CALIB_MAX_FEATURES];
static int s_n_feat = 0;
static bool s_valid = false;
/**
 * @brief Initialise NVS flash and load stored z-score calibration statistics.
 *
 * Must be called once at startup. On any failure the module falls back to an
 * identity transform (calibration_apply() becomes a no-op).
 *
 * @return true if valid statistics were loaded from NVS.
 */
bool calibration_init(void) {
    /* Standard NVS flash initialisation boilerplate: a truncated or
     * version-mismatched partition is erased and re-initialised once. */
    esp_err_t err = nvs_flash_init();
    if (err == ESP_ERR_NVS_NO_FREE_PAGES || err == ESP_ERR_NVS_NEW_VERSION_FOUND) {
        nvs_flash_erase();
        err = nvs_flash_init(); /* check the retry too — was previously ignored */
    }
    if (err != ESP_OK) {
        printf("[Calib] nvs_flash_init failed (err=0x%x) — identity transform active\n", err);
        return false;
    }
    nvs_handle_t h;
    if (nvs_open(NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) {
        printf("[Calib] No NVS partition found — identity transform active\n");
        return false;
    }
    uint8_t valid = 0;
    int32_t n_feat = 0;
    size_t mean_sz = sizeof(s_mean);
    size_t std_sz  = sizeof(s_std);
    bool ok = (nvs_get_u8  (h, NVS_KEY_VALID, &valid) == ESP_OK) &&
              (valid == 1) &&
              (nvs_get_i32 (h, NVS_KEY_NFEAT, &n_feat) == ESP_OK) &&
              (n_feat > 0 && n_feat <= CALIB_MAX_FEATURES) &&
              (nvs_get_blob(h, NVS_KEY_MEAN, s_mean, &mean_sz) == ESP_OK) &&
              (nvs_get_blob(h, NVS_KEY_STD,  s_std,  &std_sz)  == ESP_OK) &&
              /* calibration_update() always writes the full-size arrays; a
               * different blob size indicates a stale/foreign schema. */
              (mean_sz == sizeof(s_mean)) &&
              (std_sz  == sizeof(s_std));
    nvs_close(h);
    if (ok) {
        s_n_feat = (int)n_feat;
        s_valid = true;
        printf("[Calib] Loaded from NVS (%d features)\n", s_n_feat);
    } else {
        printf("[Calib] No valid calibration in NVS — identity transform active\n");
    }
    return ok;
}
/**
 * @brief Z-score a feature vector in-place: feat[i] = (feat[i] - mean) / std.
 *
 * No-op until valid statistics have been loaded or computed.
 */
void calibration_apply(float *feat) {
    if (!s_valid) {
        return;
    }
    const float *m = s_mean;
    const float *sd = s_std;
    for (int i = 0; i < s_n_feat; ++i) {
        feat[i] = (feat[i] - m[i]) / sd[i];
    }
}
/**
 * @brief Compute per-feature mean/std from REST windows and persist to NVS.
 *
 * On success the new statistics immediately become active for
 * calibration_apply(). On any failure the in-memory state is invalidated
 * (identity transform) rather than left half-updated: the old code could
 * overwrite s_mean/s_std yet leave a previously-set s_valid == true while
 * returning false, making apply() silently use unsaved statistics.
 *
 * @param X_flat    Flattened feature array [n_windows × n_feat], row-major.
 * @param n_windows Number of windows (minimum 10).
 * @param n_feat    Feature vector length (≤ CALIB_MAX_FEATURES).
 * @return true on success, false if inputs are invalid or NVS write fails.
 */
bool calibration_update(const float *X_flat, int n_windows, int n_feat) {
    if (n_windows < 10 || n_feat <= 0 || n_feat > CALIB_MAX_FEATURES) {
        printf("[Calib] calibration_update: invalid args (%d windows, %d features)\n",
               n_windows, n_feat);
        return false;
    }
    /* The arrays are about to be overwritten: mark the state invalid until the
     * whole update (compute + persist) has succeeded. */
    s_valid = false;
    s_n_feat = n_feat;
    /* Compute per-feature mean */
    memset(s_mean, 0, sizeof(s_mean));
    for (int w = 0; w < n_windows; w++) {
        const float *row = &X_flat[w * n_feat];
        for (int f = 0; f < n_feat; f++) {
            s_mean[f] += row[f];
        }
    }
    for (int f = 0; f < n_feat; f++) {
        s_mean[f] /= n_windows;
    }
    /* Compute per-feature std (population), floored so calibration_apply()
     * never divides by ~0 for dead/constant features. */
    memset(s_std, 0, sizeof(s_std));
    for (int w = 0; w < n_windows; w++) {
        const float *row = &X_flat[w * n_feat];
        for (int f = 0; f < n_feat; f++) {
            float d = row[f] - s_mean[f];
            s_std[f] += d * d;
        }
    }
    for (int f = 0; f < n_feat; f++) {
        float var = s_std[f] / n_windows;
        s_std[f] = (var > 1e-12f) ? sqrtf(var) : 1e-6f;
    }
    /* Persist to NVS. Check each call individually so the reported error code
     * is meaningful (the old `err |= ...` pattern printed OR-ed garbage). */
    nvs_handle_t h;
    if (nvs_open(NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) {
        printf("[Calib] calibration_update: failed to open NVS\n");
        return false;
    }
    esp_err_t err = nvs_set_blob(h, NVS_KEY_MEAN, s_mean, sizeof(s_mean));
    if (err == ESP_OK) err = nvs_set_blob(h, NVS_KEY_STD, s_std, sizeof(s_std));
    if (err == ESP_OK) err = nvs_set_i32 (h, NVS_KEY_NFEAT, (int32_t)n_feat);
    if (err == ESP_OK) err = nvs_set_u8  (h, NVS_KEY_VALID, 1u);
    if (err == ESP_OK) err = nvs_commit(h);
    nvs_close(h);
    if (err != ESP_OK) {
        printf("[Calib] calibration_update: NVS write failed (err=0x%x)\n", err);
        return false;
    }
    s_valid = true;
    printf("[Calib] Updated: %d REST windows, %d features saved to NVS\n",
           n_windows, n_feat);
    return true;
}
/**
 * @brief Drop the in-memory calibration statistics (identity transform).
 *
 * The NVS copy is left untouched; calibration_init() can reload it later.
 */
void calibration_reset(void) {
    s_valid = false;
    s_n_feat = 0;
    memset(s_std, 0, sizeof(s_std));
    memset(s_mean, 0, sizeof(s_mean));
}
/* True once statistics were loaded (init) or computed (update);
 * calibration_apply() is a no-op while this is false. */
bool calibration_is_valid(void) {
    return s_valid;
}

View File

@@ -0,0 +1,68 @@
/**
 * @file calibration.h
 * @brief NVS-backed z-score feature calibration for on-device EMG inference.
 *
 * Change D: Stores per-feature mean and std computed from a short REST session
 * in ESP32 non-volatile storage (NVS). At inference time, calibration_apply()
 * z-scores each feature vector before the LDA classifier sees it. This removes
 * day-to-day electrode placement drift without retraining the model.
 *
 * Typical workflow:
 *   1. calibration_init()    — called once at startup; loads from NVS
 *   2. calibration_update()  — called after collecting ~3s of REST windows
 *   3. calibration_apply(feat) — called every inference hop in inference.c
 */
#ifndef CALIBRATION_H
#define CALIBRATION_H
#include <stdbool.h>
#include <stdint.h>
/* Maximum supported feature vector length (sizes the static mean/std arrays).
 * 96 > 69 (expanded) > 12 (legacy) — gives headroom for future expansion. */
#define CALIB_MAX_FEATURES 96
/**
 * @brief Initialise NVS and load stored calibration statistics.
 *
 * Must be called once before any inference starts (e.g., in app_main).
 * @return true  Calibration data found and loaded.
 * @return false No stored data; calibration_apply() will be a no-op.
 */
bool calibration_init(void);
/**
 * @brief Apply stored z-score calibration to a feature vector in-place.
 *
 *   x_i_out = (x_i - mean_i) / std_i
 *
 * std_i is floored during calibration_update(), so no division by zero.
 * No-op if calibration is not valid (calibration_is_valid() == false).
 * @param feat Feature vector of length n_feat (set during calibration_update).
 */
void calibration_apply(float *feat);
/**
 * @brief Compute and persist calibration statistics from REST EMG windows.
 *
 * Computes per-feature mean and std over n_windows windows, stores to NVS.
 * After this call, calibration_apply() uses the new statistics.
 *
 * @param X_flat    Flattened feature array [n_windows × n_feat], row-major.
 * @param n_windows Number of windows (minimum 10).
 * @param n_feat    Feature vector length (≤ CALIB_MAX_FEATURES).
 * @return true on success, false if inputs are invalid or NVS write fails.
 */
bool calibration_update(const float *X_flat, int n_windows, int n_feat);
/**
 * @brief Clear calibration state (in-memory only; does not erase NVS).
 */
void calibration_reset(void);
/**
 * @brief Check whether valid calibration statistics are loaded.
 */
bool calibration_is_valid(void);
#endif /* CALIBRATION_H */

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,13 @@
// Auto-generated by train_mlp_tflite.py — do not edit
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
// Serialized TFLite flatbuffer model; g_model_len is its size in bytes.
extern const unsigned char g_model[];
extern const int g_model_len;
#ifdef __cplusplus
}
#endif

View File

@@ -10,6 +10,7 @@
#include <freertos/FreeRTOS.h> #include <freertos/FreeRTOS.h>
#include <freertos/task.h> #include <freertos/task.h>
/******************************************************************************* /*******************************************************************************
* Private Data * Private Data
******************************************************************************/ ******************************************************************************/
@@ -51,6 +52,19 @@ void gestures_execute(gesture_t gesture)
} }
} }
/**
 * @brief Map a gesture name string to its gesture_t enum value.
 *
 * Accepts the hyphenated CLI forms ("hook-em", "thumbs-up"), the squashed
 * forms ("hookem", "thumbsup"), and — for consistency with the Python class
 * list and inference_get_gesture_by_name() — the underscore forms
 * ("hook_em", "thumbs_up"), which the previous version rejected.
 *
 * @param s NUL-terminated lowercase gesture name.
 * @return Matching gesture_t, or GESTURE_NONE if unrecognised.
 */
gesture_t parse_gesture(const char *s)
{
    if (strcmp(s, "rest") == 0) return GESTURE_REST;
    if (strcmp(s, "fist") == 0) return GESTURE_FIST;
    if (strcmp(s, "open") == 0) return GESTURE_OPEN;
    if (strcmp(s, "hook-em") == 0 ||
        strcmp(s, "hook_em") == 0 ||
        strcmp(s, "hookem") == 0)   return GESTURE_HOOK_EM;
    if (strcmp(s, "thumbs-up") == 0 ||
        strcmp(s, "thumbs_up") == 0 ||
        strcmp(s, "thumbsup") == 0) return GESTURE_THUMBS_UP;
    return GESTURE_NONE;
}
const char* gestures_get_name(gesture_t gesture) const char* gestures_get_name(gesture_t gesture)
{ {
if (gesture >= GESTURE_COUNT) { if (gesture >= GESTURE_COUNT) {
@@ -65,38 +79,49 @@ const char* gestures_get_name(gesture_t gesture)
void gesture_open(void) void gesture_open(void)
{ {
hand_unflex_all(); hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]);
hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]);
hand_set_finger_angle(FINGER_MIDDLE, minAngles[FINGER_MIDDLE]);
hand_set_finger_angle(FINGER_RING, minAngles[FINGER_RING]);
hand_set_finger_angle(FINGER_PINKY, minAngles[FINGER_PINKY]);
} }
void gesture_fist(void) void gesture_fist(void)
{ {
hand_flex_all(); hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]);
hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]);
hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]);
} }
void gesture_hook_em(void) void gesture_hook_em(void)
{ {
/* Index and pinky extended, others flexed */ /* Index and pinky extended, others flexed */
hand_flex_finger(FINGER_THUMB); hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]);
hand_unflex_finger(FINGER_INDEX); hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]);
hand_flex_finger(FINGER_MIDDLE); hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
hand_flex_finger(FINGER_RING); hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
hand_unflex_finger(FINGER_PINKY); hand_set_finger_angle(FINGER_PINKY, minAngles[FINGER_PINKY]);
} }
void gesture_thumbs_up(void) void gesture_thumbs_up(void)
{ {
/* Thumb extended, others flexed */ /* Thumb extended, others flexed */
hand_unflex_finger(FINGER_THUMB); hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]);
hand_flex_finger(FINGER_INDEX); hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]);
hand_flex_finger(FINGER_MIDDLE); hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
hand_flex_finger(FINGER_RING); hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
hand_flex_finger(FINGER_PINKY); hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]);
} }
void gesture_rest(void) void gesture_rest(void)
{ {
/* Rest is same as open - neutral position */ hand_set_finger_angle(FINGER_THUMB, (maxAngles[FINGER_THUMB] + minAngles[FINGER_THUMB])/2);
gesture_open(); hand_set_finger_angle(FINGER_INDEX, (maxAngles[FINGER_INDEX] + minAngles[FINGER_INDEX])/2);
hand_set_finger_angle(FINGER_MIDDLE, (maxAngles[FINGER_MIDDLE] + minAngles[FINGER_MIDDLE])/2);
hand_set_finger_angle(FINGER_RING, (maxAngles[FINGER_RING] + minAngles[FINGER_RING])/2);
hand_set_finger_angle(FINGER_PINKY, (maxAngles[FINGER_PINKY] + minAngles[FINGER_PINKY])/2);
} }
/******************************************************************************* /*******************************************************************************

View File

@@ -13,6 +13,7 @@
#define GESTURES_H #define GESTURES_H
#include <stdint.h> #include <stdint.h>
#include <string.h>
#include "config/config.h" #include "config/config.h"
/******************************************************************************* /*******************************************************************************
@@ -26,6 +27,8 @@
*/ */
void gestures_execute(gesture_t gesture); void gestures_execute(gesture_t gesture);
gesture_t parse_gesture(const char *s);
/** /**
* @brief Get the name of a gesture as a string. * @brief Get the name of a gesture as a string.
* *

View File

@@ -4,19 +4,52 @@
*/ */
#include "inference.h" #include "inference.h"
#include "calibration.h"
#include "config/config.h" #include "config/config.h"
#include "model_weights.h" #include "model_weights.h"
#include <math.h> #include <math.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#if MODEL_EXPAND_FEATURES
#include "dsps_fft2r.h" /* esp-dsp: complex 2-radix FFT */
#define FFT_N 256 /* Must match Python fft_n=256 */
#endif
// --- Constants --- // --- Constants ---
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python) #define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
#define VOTE_WINDOW 5 // Majority vote window size #define VOTE_WINDOW 5 // Majority vote window size
#define DEBOUNCE_COUNT 3 // Confirmations needed to change output #define DEBOUNCE_COUNT 3 // Confirmations needed to change output
// Change C: confidence rejection threshold.
// When the peak smoothed probability stays below this, hold the last confirmed
// output rather than outputting an uncertain prediction. Prevents false arm
// actuations during gesture transitions, rest-to-gesture onset, and electrode lift.
// Meta paper uses 0.35; 0.40 adds a prosthetic safety margin.
// Tune down to 0.35 if real gestures are being incorrectly rejected.
#define CONFIDENCE_THRESHOLD 0.40f
// --- Change B: IIR Bandpass Filter (20450 Hz, 2nd-order Butterworth @ 1 kHz) ---
// Two cascaded biquad sections, Direct Form II Transposed.
// Computed via scipy.signal.butter(2, [20,450], btype='bandpass', fs=1000, output='sos').
// b coefficients [b0, b1, b2] per section:
#define IIR_NUM_SECTIONS 2
static const float IIR_B[IIR_NUM_SECTIONS][3] = {
{ 0.7320224766f, 1.4640449531f, 0.7320224766f }, /* section 0 */
{ 1.0000000000f, -2.0000000000f, 1.0000000000f }, /* section 1 */
};
// Feedback coefficients [a1, a2] per section (a0 = 1, implicit):
static const float IIR_A[IIR_NUM_SECTIONS][2] = {
{ 1.5597081442f, 0.6416146818f }, /* section 0 */
{ -1.8224796027f, 0.8372542588f }, /* section 1 */
};
// --- State --- // --- State ---
static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS]; static float window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
static float biquad_w[IIR_NUM_SECTIONS][NUM_CHANNELS][2]; /* biquad delay states */
#if MODEL_EXPAND_FEATURES
static bool s_fft_ready = false;
#endif
static int buffer_head = 0; static int buffer_head = 0;
static int samples_collected = 0; static int samples_collected = 0;
@@ -30,9 +63,17 @@ static int pending_count = 0;
void inference_init(void) { void inference_init(void) {
memset(window_buffer, 0, sizeof(window_buffer)); memset(window_buffer, 0, sizeof(window_buffer));
memset(biquad_w, 0, sizeof(biquad_w));
buffer_head = 0; buffer_head = 0;
samples_collected = 0; samples_collected = 0;
#if MODEL_EXPAND_FEATURES
if (!s_fft_ready) {
dsps_fft2r_init_fc32(NULL, FFT_N);
s_fft_ready = true;
}
#endif
// Initialize smoothing // Initialize smoothing
for (int i = 0; i < MODEL_NUM_CLASSES; i++) { for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES; smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES;
@@ -47,9 +88,17 @@ void inference_init(void) {
} }
bool inference_add_sample(uint16_t *channels) { bool inference_add_sample(uint16_t *channels) {
// Add to circular buffer // Convert to float, apply per-channel biquad bandpass, then store in circular buffer.
for (int i = 0; i < NUM_CHANNELS; i++) { for (int i = 0; i < NUM_CHANNELS; i++) {
window_buffer[buffer_head][i] = channels[i]; float x = (float)channels[i];
// Cascade IIR_NUM_SECTIONS biquad sections (Direct Form II Transposed)
for (int s = 0; s < IIR_NUM_SECTIONS; s++) {
float y = IIR_B[s][0] * x + biquad_w[s][i][0];
biquad_w[s][i][0] = IIR_B[s][1] * x - IIR_A[s][0] * y + biquad_w[s][i][1];
biquad_w[s][i][1] = IIR_B[s][2] * x - IIR_A[s][1] * y;
x = y;
}
window_buffer[buffer_head][i] = x;
} }
buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE; buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;
@@ -65,11 +114,265 @@ bool inference_add_sample(uint16_t *channels) {
// --- Feature Extraction --- // --- Feature Extraction ---
static void compute_features(float *features_out) { /* ── helpers used by compute_features_expanded ──────────────────────────── */
// Process each channel
// We need to iterate over the logical window (unrolling circular buffer)
for (int ch = 0; ch < NUM_CHANNELS; ch++) { #if MODEL_EXPAND_FEATURES
/** Solve the 4×4 linear system A·x = b via Gaussian elimination with partial
 *  pivoting. Inputs A and b are not modified. On a (numerically) singular
 *  matrix, x is zeroed and false is returned. */
static bool solve_4x4(float A[4][4], const float b[4], float x[4]) {
    /* Work on an augmented copy [A | b] so the caller's data stays intact. */
    float aug[4][5];
    for (int r = 0; r < 4; r++) {
        for (int c = 0; c < 4; c++) {
            aug[r][c] = A[r][c];
        }
        aug[r][4] = b[r];
    }
    /* Forward elimination with row pivoting. */
    for (int k = 0; k < 4; k++) {
        int best = k;
        for (int r = k + 1; r < 4; r++) {
            if (fabsf(aug[r][k]) > fabsf(aug[best][k])) {
                best = r;
            }
        }
        if (fabsf(aug[best][k]) < 1e-8f) {
            x[0] = x[1] = x[2] = x[3] = 0.f;
            return false; /* singular (or numerically so) */
        }
        if (best != k) {
            for (int c = 0; c < 5; c++) {
                float t = aug[k][c];
                aug[k][c] = aug[best][c];
                aug[best][c] = t;
            }
        }
        for (int r = k + 1; r < 4; r++) {
            float factor = aug[r][k] / aug[k][k];
            for (int c = k; c < 5; c++) {
                aug[r][c] -= factor * aug[k][c];
            }
        }
    }
    /* Back substitution. */
    for (int r = 3; r >= 0; r--) {
        float acc = aug[r][4];
        for (int c = r + 1; c < 4; c++) {
            acc -= aug[r][c] * x[c];
        }
        x[r] = acc / aug[r][r];
    }
    return true;
}
/**
 * @brief Full 69-feature extraction (Change 1).
 *
 * Per channel (20): RMS, WL, ZC, SSC, MAV, VAR, IEMG, WAMP,
 * AR1-AR4, MNF, MDF, PKF, MNP, BP0-BP3
 * Cross-channel (9 for 3 channels): corr, log-RMS-ratio, cov for each pair
 *
 * Feature layout must exactly match EMGFeatureExtractor._EXPANDED_KEYS +
 * cross-channel order in the Python class.
 */
static void compute_features_expanded(float *features_out) {
    memset(features_out, 0, MODEL_NUM_FEATURES * sizeof(float));
    /* Persistent buffers for centered signals (3 ch × 150 samples) */
    static float ch_signals[HAND_NUM_CHANNELS][INFERENCE_WINDOW_SIZE];
    static float s_fft_buf[FFT_N * 2]; /* Complex interleaved [re,im,...] */
    /* NOTE(review): static scratch buffers make this function non-reentrant —
     * safe only while feature extraction runs on a single task. */
    float ch_rms[HAND_NUM_CHANNELS];
    float norm_sq = 0.0f;
    /* ──────────────────────────────────────────────────────────────────────
     *  Pass 1: per-channel TD + AR + spectral features
     * ────────────────────────────────────────────────────────────────────── */
    for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
        /* Read channel from circular buffer (oldest-first) */
        float *sig = ch_signals[ch];
        float sum = 0.0f;
        int idx = buffer_head;
        for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
            sig[i] = window_buffer[idx][ch];
            sum += sig[i];
            idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
        }
        /* DC removal */
        float mean = sum / INFERENCE_WINDOW_SIZE;
        float sq_sum = 0.0f;
        for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
            sig[i] -= mean;
            sq_sum += sig[i] * sig[i];
        }
        /* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */
#if MODEL_USE_REINHARD
        sq_sum = 0.0f;
        for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
            float x = sig[i];
            sig[i] = 64.0f * x / (32.0f + fabsf(x));
            sq_sum += sig[i] * sig[i];
        }
#endif
        float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
        ch_rms[ch] = rms;
        norm_sq += rms * rms;
        /* Amplitude thresholds scale with the channel RMS, so ZC/SSC/WAMP
         * counts stay comparable across sessions with different gain. */
        float zc_thresh = FEAT_ZC_THRESH * rms;
        float ssc_thresh = (FEAT_SSC_THRESH * rms) * (FEAT_SSC_THRESH * rms);
        /* TD features */
        float wl = 0.0f, mav = 0.0f, iemg = 0.0f;
        int zc = 0, ssc = 0, wamp = 0;
        for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
            float a = fabsf(sig[i]);
            mav += a;
            iemg += a;
        }
        mav /= INFERENCE_WINDOW_SIZE;
        float var_val = sq_sum / INFERENCE_WINDOW_SIZE; /* variance (mean already 0) */
        for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) {
            float diff = sig[i+1] - sig[i];
            float adiff = fabsf(diff);
            wl += adiff;
            /* NOTE(review): WAMP reuses the ZC amplitude threshold here —
             * presumably mirrors the Python extractor; TODO confirm. */
            if (adiff > zc_thresh) wamp++;
            if ((sig[i] > 0.0f && sig[i+1] < 0.0f) ||
                (sig[i] < 0.0f && sig[i+1] > 0.0f)) {
                if (adiff > zc_thresh) zc++;
            }
            if (i < INFERENCE_WINDOW_SIZE - 2) {
                float d1 = sig[i+1] - sig[i];
                float d2 = sig[i+1] - sig[i+2];
                /* Slope sign change: d1·d2 > 0 means sig[i+1] is a local extremum */
                if ((d1 * d2) > ssc_thresh) ssc++;
            }
        }
        /* AR(4) via Yule-Walker: Toeplitz autocorrelation system solved by
         * the 4×4 Gaussian-elimination helper above. */
        float r[5] = {0};
        for (int lag = 0; lag < 5; lag++) {
            for (int i = 0; i < INFERENCE_WINDOW_SIZE - lag; i++)
                r[lag] += sig[i] * sig[i + lag];
            r[lag] /= INFERENCE_WINDOW_SIZE;
        }
        float T[4][4], b_ar[4], ar[4] = {0};
        for (int i = 0; i < 4; i++) {
            for (int j = 0; j < 4; j++) T[i][j] = r[abs(i - j)];
            b_ar[i] = r[i + 1];
        }
        /* On a singular system ar[] stays zeroed — features degrade gracefully */
        solve_4x4(T, b_ar, ar);
        /* Spectral features via FFT (zero-pad to FFT_N); in-place complex
         * transform followed by bit-reversal reordering per esp-dsp API. */
        memset(s_fft_buf, 0, sizeof(s_fft_buf));
        for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
            s_fft_buf[2*i] = sig[i];
            s_fft_buf[2*i + 1] = 0.0f;
        }
        dsps_fft2r_fc32(s_fft_buf, FFT_N);
        dsps_bit_rev_fc32(s_fft_buf, FFT_N);
        /* Only the first FFT_N/2 bins (0 .. Nyquist) carry unique power for a
         * real input signal. */
        float total_power = 1e-10f;
        const float freq_step = (float)EMG_SAMPLE_RATE_HZ / FFT_N;
        float pwr[FFT_N / 2];
        for (int k = 0; k < FFT_N / 2; k++) {
            float re = s_fft_buf[2*k], im = s_fft_buf[2*k + 1];
            pwr[k] = re*re + im*im;
            total_power += pwr[k];
        }
        /* Mean frequency: power-weighted average of bin frequencies */
        float mnf = 0.0f;
        for (int k = 0; k < FFT_N / 2; k++) mnf += k * freq_step * pwr[k];
        mnf /= total_power;
        /* Median frequency: bin where cumulative power crosses 50% */
        float cumsum = 0.0f, half = total_power / 2.0f;
        int mdf_k = FFT_N / 2 - 1;
        for (int k = 0; k < FFT_N / 2; k++) {
            cumsum += pwr[k];
            if (cumsum >= half) { mdf_k = k; break; }
        }
        float mdf = mdf_k * freq_step;
        /* Peak frequency: strongest single bin */
        int pkf_k = 0;
        for (int k = 1; k < FFT_N / 2; k++) {
            if (pwr[k] > pwr[pkf_k]) pkf_k = k;
        }
        float pkf = pkf_k * freq_step;
        float mnp = total_power / (FFT_N / 2);
        /* Relative band powers over the EMG-relevant 20–450 Hz range */
        float bp0 = 0, bp1 = 0, bp2 = 0, bp3 = 0;
        for (int k = 0; k < FFT_N / 2; k++) {
            float f = k * freq_step;
            if      (f >= 20.f  && f < 80.f ) bp0 += pwr[k];
            else if (f >= 80.f  && f < 150.f) bp1 += pwr[k];
            else if (f >= 150.f && f < 250.f) bp2 += pwr[k];
            else if (f >= 250.f && f < 450.f) bp3 += pwr[k];
        }
        bp0 /= total_power; bp1 /= total_power;
        bp2 /= total_power; bp3 /= total_power;
        /* Store 20 features for this channel — order must match Python */
        int base = ch * 20;
        features_out[base + 0]  = rms;
        features_out[base + 1]  = wl;
        features_out[base + 2]  = (float)zc;
        features_out[base + 3]  = (float)ssc;
        features_out[base + 4]  = mav;
        features_out[base + 5]  = var_val;
        features_out[base + 6]  = iemg;
        features_out[base + 7]  = (float)wamp;
        features_out[base + 8]  = ar[0];
        features_out[base + 9]  = ar[1];
        features_out[base + 10] = ar[2];
        features_out[base + 11] = ar[3];
        features_out[base + 12] = mnf;
        features_out[base + 13] = mdf;
        features_out[base + 14] = pkf;
        features_out[base + 15] = mnp;
        features_out[base + 16] = bp0;
        features_out[base + 17] = bp1;
        features_out[base + 18] = bp2;
        features_out[base + 19] = bp3;
    }
    /* ──────────────────────────────────────────────────────────────────────
     *  Amplitude normalization (matches Python normalize=True behaviour)
     * ────────────────────────────────────────────────────────────────────── */
    float norm_factor = sqrtf(norm_sq);
    if (norm_factor < 1e-6f) norm_factor = 1e-6f;
    for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
        int base = ch * 20;
        features_out[base + 0] /= norm_factor;  /* rms */
        features_out[base + 1] /= norm_factor;  /* wl */
        features_out[base + 4] /= norm_factor;  /* mav */
        features_out[base + 6] /= norm_factor;  /* iemg */
    }
    /* ──────────────────────────────────────────────────────────────────────
     *  Cross-channel features (3 pairs × 3 features = 9)
     *  Order: (0,1), (0,2), (1,2) → corr, lrms, cov
     * ────────────────────────────────────────────────────────────────────── */
    int xc_base = HAND_NUM_CHANNELS * 20;  /* = 60 */
    int xc_idx = 0;
    for (int i = 0; i < HAND_NUM_CHANNELS; i++) {
        for (int j = i + 1; j < HAND_NUM_CHANNELS; j++) {
            float ri = ch_rms[i] + 1e-10f;
            float rj = ch_rms[j] + 1e-10f;
            float dot_ij = 0.0f;
            for (int k = 0; k < INFERENCE_WINDOW_SIZE; k++)
                dot_ij += ch_signals[i][k] * ch_signals[j][k];
            float n_inv = 1.0f / INFERENCE_WINDOW_SIZE;
            /* Pearson-style correlation, clamped to [-1, 1] against rounding */
            float corr = dot_ij * n_inv / (ri * rj);
            if (corr > 1.0f) corr = 1.0f;
            if (corr < -1.0f) corr = -1.0f;
            float lrms = logf(ri / rj);
            /* Covariance scaled by the global amplitude norm (norm_factor²) */
            float cov = dot_ij * n_inv / (norm_factor * norm_factor);
            features_out[xc_base + xc_idx * 3 + 0] = corr;
            features_out[xc_base + xc_idx * 3 + 1] = lrms;
            features_out[xc_base + xc_idx * 3 + 2] = cov;
            xc_idx++;
        }
    }
}
#endif /* MODEL_EXPAND_FEATURES */
static void compute_features(float *features_out) {
// Process forearm channels only (ch0-ch2) for hand gesture classification.
// The bicep channel (ch3) is excluded — it will be processed independently.
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
float sum = 0; float sum = 0;
float sq_sum = 0; float sq_sum = 0;
@@ -79,7 +382,7 @@ static void compute_features(float *features_out) {
int idx = buffer_head; // Oldest sample int idx = buffer_head; // Oldest sample
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
signal[i] = (float)window_buffer[idx][ch]; signal[i] = window_buffer[idx][ch];
sum += signal[i]; sum += signal[i];
idx = (idx + 1) % INFERENCE_WINDOW_SIZE; idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
} }
@@ -96,6 +399,15 @@ static void compute_features(float *features_out) {
signal[i] -= mean; signal[i] -= mean;
sq_sum += signal[i] * signal[i]; sq_sum += signal[i] * signal[i];
} }
/* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */
#if MODEL_USE_REINHARD
sq_sum = 0.0f;
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
float x = signal[i];
signal[i] = 64.0f * x / (32.0f + fabsf(x));
sq_sum += signal[i] * signal[i];
}
#endif
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE); float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
@@ -133,6 +445,60 @@ static void compute_features(float *features_out) {
features_out[base + 2] = (float)zc; features_out[base + 2] = (float)zc;
features_out[base + 3] = (float)ssc; features_out[base + 3] = (float)ssc;
} }
#if MODEL_NORMALIZE_FEATURES
// Normalize amplitude-dependent features (RMS, WL) by total RMS across
// channels. This makes the model robust to impedance shifts between sessions
// while preserving relative channel activation patterns.
float total_rms_sq = 0;
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
float ch_rms = features_out[ch * 4 + 0];
total_rms_sq += ch_rms * ch_rms;
}
float norm_factor = sqrtf(total_rms_sq);
if (norm_factor < 1e-6f) norm_factor = 1e-6f;
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
features_out[ch * 4 + 0] /= norm_factor; // RMS
features_out[ch * 4 + 1] /= norm_factor; // WL
}
#endif
}
// --- Feature extraction (public wrapper used by inference_ensemble.c) ---
// Selects the expanded (FFT-based) or legacy extractor at compile time, then
// applies the NVS-stored z-score calibration in-place (no-op if not loaded).
void inference_extract_features(float *features_out) {
#if MODEL_EXPAND_FEATURES
    compute_features_expanded(features_out);
#else
    compute_features(features_out);
#endif
    calibration_apply(features_out);
}
// --- Raw LDA probability (for multi-model voting) ---
// Computes softmax over the LDA linear scores, subtracting the max score
// first for numerical stability. No smoothing / voting / debounce here —
// the caller owns all post-processing.
void inference_predict_raw(const float *features, float *proba_out) {
    float scores[MODEL_NUM_CLASSES];
    float best = -1e9f;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        float acc = LDA_INTERCEPTS[c];
        const float *w = LDA_WEIGHTS[c];
        for (int f = 0; f < MODEL_NUM_FEATURES; f++) {
            acc += features[f] * w[f];
        }
        scores[c] = acc;
        if (acc > best) {
            best = acc;
        }
    }
    float denom = 0.0f;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        float e = expf(scores[c] - best);
        proba_out[c] = e;
        denom += e;
    }
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        proba_out[c] /= denom;
    }
}
// --- Prediction --- // --- Prediction ---
@@ -144,7 +510,14 @@ int inference_predict(float *confidence) {
// 1. Extract Features // 1. Extract Features
float features[MODEL_NUM_FEATURES]; float features[MODEL_NUM_FEATURES];
#if MODEL_EXPAND_FEATURES
compute_features_expanded(features);
#else
compute_features(features); compute_features(features);
#endif
// 1b. Change D: z-score normalise using NVS-stored session calibration
calibration_apply(features);
// 2. LDA Inference (Linear Score) // 2. LDA Inference (Linear Score)
float raw_scores[MODEL_NUM_CLASSES]; float raw_scores[MODEL_NUM_CLASSES];
@@ -195,6 +568,15 @@ int inference_predict(float *confidence) {
} }
} }
// Change C: confidence rejection.
// If the strongest smoothed probability is still too low, the classifier is
// uncertain — return the last confirmed output without updating vote/debounce state.
// This prevents low-confidence transitions from actuating the arm spuriously.
if (max_smoothed_prob < CONFIDENCE_THRESHOLD) {
*confidence = max_smoothed_prob;
return current_output; // -1 (GESTURE_NONE) until first confident prediction
}
// 3b. Majority Vote // 3b. Majority Vote
vote_history[vote_head] = smoothed_winner; vote_history[vote_head] = smoothed_winner;
vote_head = (vote_head + 1) % VOTE_WINDOW; vote_head = (vote_head + 1) % VOTE_WINDOW;
@@ -255,17 +637,12 @@ const char *inference_get_class_name(int class_idx) {
int inference_get_gesture_enum(int class_idx) { int inference_get_gesture_enum(int class_idx) {
const char *name = inference_get_class_name(class_idx); const char *name = inference_get_class_name(class_idx);
return inference_get_gesture_by_name(name);
}
// Map string name to gesture_t enum int inference_get_gesture_by_name(const char *name) {
// Strings must match those in Python list: ["fist", "hook_em", "open", // Accepts both lowercase (Python output) and uppercase (C enum name style).
// "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums // Add a new case here whenever a gesture is added to gesture_t in config.h.
// are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5
// Case-insensitive check would be safer, but let's assume Python output is
// lowercase as seen in scripts or uppercase if specified. In
// learning_data_collection.py, they seem to be "rest", "open", "fist", etc.
// Simple string matching
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0) if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0)
return GESTURE_REST; return GESTURE_REST;
if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0) if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0)
@@ -276,6 +653,21 @@ int inference_get_gesture_enum(int class_idx) {
return GESTURE_HOOK_EM; return GESTURE_HOOK_EM;
if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0) if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0)
return GESTURE_THUMBS_UP; return GESTURE_THUMBS_UP;
return GESTURE_NONE; return GESTURE_NONE;
} }
// RMS of the bicep channel over the most recent n_samples of the filtered
// circular buffer. Returns 0 until the buffer has been filled at least once.
float inference_get_bicep_rms(int n_samples) {
    if (samples_collected < INFERENCE_WINDOW_SIZE) return 0.0f;
    if (n_samples > INFERENCE_WINDOW_SIZE) n_samples = INFERENCE_WINDOW_SIZE;
    float sum_sq = 0.0f;
    // Start n_samples behind buffer_head (which points at the oldest slot)
    // and walk FORWARD, so exactly the newest n_samples are covered.
    int start = (buffer_head - n_samples + INFERENCE_WINDOW_SIZE) % INFERENCE_WINDOW_SIZE;
    int idx = start;
    for (int i = 0; i < n_samples; i++) {
        float v = window_buffer[idx][3]; // channel 3 = bicep (NUM_CHANNELS - 1)
        sum_sq += v * v;
        idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
    }
    return sqrtf(sum_sq / n_samples);
}

View File

@@ -9,9 +9,11 @@
#include <stdbool.h> #include <stdbool.h>
#include <stdint.h> #include <stdint.h>
// --- Configuration --- // --- Configuration (must match Python WINDOW_SIZE_MS / HOP_SIZE_MS) ---
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python) #define INFERENCE_WINDOW_SIZE 150 // Window size in samples (150ms at 1kHz)
#define NUM_CHANNELS 4 // Number of EMG channels #define INFERENCE_HOP_SIZE 25 // Hop/stride in samples (25ms at 1kHz)
#define NUM_CHANNELS 4 // Total EMG channels (buffer stores all)
#define HAND_NUM_CHANNELS 3 // Forearm channels for hand classifier (ch0-ch2)
/** /**
* @brief Initialize the inference engine. * @brief Initialize the inference engine.
@@ -44,4 +46,53 @@ const char *inference_get_class_name(int class_idx);
*/ */
int inference_get_gesture_enum(int class_idx); int inference_get_gesture_enum(int class_idx);
/**
* @brief Map a gesture name string directly to gesture_t enum value.
*
* Used by the laptop-predict path to convert the name sent by live_predict.py
* into a gesture_t without needing a class index.
*
* @param name Lowercase gesture name, e.g. "fist", "rest", "open"
* @return gesture_t value, or GESTURE_NONE if unrecognised
*/
int inference_get_gesture_by_name(const char *name);
/**
* @brief Compute LDA softmax probabilities without smoothing/voting/debounce.
*
* Used by the multi-model voting path in main.c. The caller is responsible
* for post-processing (EMA, majority vote, debounce).
*
* @param features Calibrated feature vector (MODEL_NUM_FEATURES floats).
* @param proba_out Output probability array (MODEL_NUM_CLASSES floats).
*/
void inference_predict_raw(const float *features, float *proba_out);
/**
* @brief Extract and calibrate features from the current window.
*
* Dispatches to compute_features() or compute_features_expanded() depending
* on MODEL_EXPAND_FEATURES, then applies calibration_apply(). The resulting
* float array is identical to what inference_predict() uses internally.
*
* Called by inference_ensemble.c so that the ensemble path does not duplicate
* the feature-extraction logic.
*
* @param features_out Caller-allocated array of MODEL_NUM_FEATURES floats.
*/
void inference_extract_features(float *features_out);
/**
* @brief Compute RMS of the last n_samples from channel 3 (bicep) in the
* inference circular buffer.
*
* Used by the bicep subsystem to obtain current activation level without
* exposing the internal window_buffer.
*
* @param n_samples Number of samples to include (clamped to INFERENCE_WINDOW_SIZE).
* @return RMS value in the same units as the filtered buffer (mV after Change B).
* Returns 0 if the buffer is not yet filled.
*/
float inference_get_bicep_rms(int n_samples);
#endif /* INFERENCE_H */ #endif /* INFERENCE_H */

View File

@@ -0,0 +1,262 @@
/**
* @file inference_ensemble.c
* @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F).
*
* Guarded by MODEL_USE_ENSEMBLE in model_weights.h.
* When 0, provides empty stubs so the file compiles unconditionally.
*/
#include "inference_ensemble.h"
#include "inference.h"
#include "model_weights.h"
#if MODEL_USE_ENSEMBLE
#include "inference_mlp.h"
#include "model_weights_ensemble.h"
#include "calibration.h"
#include <math.h>
#include <string.h>
#include <stdio.h>
#define ENSEMBLE_EMA_ALPHA 0.70f
#define ENSEMBLE_CONF_THRESHOLD 0.50f /**< below this: escalate to MLP fallback */
#define REJECT_THRESHOLD 0.40f /**< below this even after MLP: hold output */
#define REST_ACTIVITY_THRESHOLD 0.05f /**< total RMS gate — skip inference during rest */
/* Class index for "rest" — looked up once in init so we don't return
* the gesture_t enum value (GESTURE_REST = 1) which would be misinterpreted
* as class index 1 ("hook_em"). See Bug 2 in the code review. */
static int s_rest_class_idx = 0;
/* EMA probability state */
static float s_smoothed[MODEL_NUM_CLASSES];
/* Majority vote ring buffer (window = 5) */
static int s_vote_history[5];
static int s_vote_head = 0;
/* Debounce state */
static int s_current_output = -1;
static int s_pending_output = -1;
static int s_pending_count = 0;
/* ── Generic LDA softmax ──────────────────────────────────────────────────── */
/**
* Compute softmax class probabilities from a flat feature vector.
*
* @param feat Feature vector (contiguous, length n_feat).
* @param n_feat Number of features.
* @param weights_flat Row-major weight matrix, shape [n_classes][n_feat].
* @param intercepts Intercept vector, length n_classes.
* @param n_classes Number of output classes.
* @param proba_out Output probabilities, length n_classes (caller-allocated).
*/
static void lda_softmax(const float *feat, int n_feat,
const float *weights_flat, const float *intercepts,
int n_classes, float *proba_out) {
float raw[MODEL_NUM_CLASSES];
float max_raw = -1e9f;
float sum_exp = 0.0f;
for (int c = 0; c < n_classes; c++) {
raw[c] = intercepts[c];
const float *w = weights_flat + c * n_feat;
for (int f = 0; f < n_feat; f++) {
raw[c] += feat[f] * w[f];
}
if (raw[c] > max_raw) max_raw = raw[c];
}
for (int c = 0; c < n_classes; c++) {
proba_out[c] = expf(raw[c] - max_raw);
sum_exp += proba_out[c];
}
for (int c = 0; c < n_classes; c++) {
proba_out[c] /= sum_exp;
}
}
/* ── Public API ───────────────────────────────────────────────────────────── */
void inference_ensemble_init(void) {
    /* Resolve the class index of "rest" once, so the activity gate can
     * return a class index rather than the gesture_t enum value (whose
     * numbering differs — see Bug 2 in the code review). */
    int rest_idx = 0;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        if (strcmp(MODEL_CLASS_NAMES[c], "rest") == 0) {
            rest_idx = c;
            break;
        }
    }
    s_rest_class_idx = rest_idx;

    /* Start the EMA from a uniform distribution. */
    const float uniform = 1.0f / MODEL_NUM_CLASSES;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        s_smoothed[c] = uniform;
    }

    /* Clear the vote ring (-1 = empty slot) and debounce state. */
    for (int i = 0; i < 5; i++) {
        s_vote_history[i] = -1;
    }
    s_vote_head = 0;
    s_current_output = -1;
    s_pending_output = -1;
    s_pending_count = 0;
}
void inference_ensemble_predict_raw(const float *features, float *proba_out) {
    /* Feature layout per channel: 12 time-domain values followed by
     * 8 frequency-domain values (ENSEMBLE_PER_CH_FEATURES = 20 per channel,
     * HAND_NUM_CHANNELS = 3), with the cross-channel features contiguous at
     * the end of the vector. Gather the interleaved TD/FD slices into
     * contiguous buffers for the specialist LDAs. */
    float td_buf[TD_NUM_FEATURES];
    float fd_buf[FD_NUM_FEATURES];
    for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
        const float *src = features + ch * ENSEMBLE_PER_CH_FEATURES;
        for (int i = 0; i < 12; i++) {
            td_buf[ch * 12 + i] = src[i];
        }
        for (int i = 0; i < 8; i++) {
            fd_buf[ch * 8 + i] = src[12 + i];
        }
    }

    /* Run the three specialist LDAs. */
    float prob_td[MODEL_NUM_CLASSES];
    float prob_fd[MODEL_NUM_CLASSES];
    float prob_cc[MODEL_NUM_CLASSES];
    lda_softmax(td_buf, TD_NUM_FEATURES,
                (const float *)LDA_TD_WEIGHTS, LDA_TD_INTERCEPTS,
                MODEL_NUM_CLASSES, prob_td);
    lda_softmax(fd_buf, FD_NUM_FEATURES,
                (const float *)LDA_FD_WEIGHTS, LDA_FD_INTERCEPTS,
                MODEL_NUM_CLASSES, prob_fd);
    /* CC features are already contiguous; feed them in place. */
    lda_softmax(features + CC_FEAT_OFFSET, CC_NUM_FEATURES,
                (const float *)LDA_CC_WEIGHTS, LDA_CC_INTERCEPTS,
                MODEL_NUM_CLASSES, prob_cc);

    /* Stack the specialist probabilities and run the meta-LDA. */
    float meta_in[META_NUM_INPUTS];
    memcpy(meta_in, prob_td, sizeof(prob_td));
    memcpy(meta_in + MODEL_NUM_CLASSES, prob_fd, sizeof(prob_fd));
    memcpy(meta_in + 2 * MODEL_NUM_CLASSES, prob_cc, sizeof(prob_cc));
    lda_softmax(meta_in, META_NUM_INPUTS,
                (const float *)META_LDA_WEIGHTS, META_LDA_INTERCEPTS,
                MODEL_NUM_CLASSES, proba_out);
}
/**
 * Run one inference hop through the full ensemble pipeline:
 * feature extraction -> rest-activity gate -> specialists + meta-LDA ->
 * EMA smoothing -> MLP confidence cascade -> reject gate -> majority vote
 * -> debounce.
 *
 * @param confidence Output: winning-class smoothed probability [0,1].
 * @return Class index of the debounced output.
 *         NOTE(review): the header doc claims this returns a gesture enum;
 *         the code below returns class indices (s_rest_class_idx, vote
 *         winners) — confirm callers translate via inference_get_gesture_enum().
 */
int inference_ensemble_predict(float *confidence) {
    /* 1. Extract + calibrate features (shared with single-model path) */
    float features[MODEL_NUM_FEATURES];
    inference_extract_features(features); /* includes calibration_apply() */
    /* 2. Activity gate — skip inference during obvious REST.
     * Note: EMA/vote/debounce state is intentionally left untouched here,
     * so smoothing resumes from the pre-rest state when activity returns. */
    float total_rms_sq = 0.0f;
    for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
        /* RMS is the first feature per channel (index 0 in each 20-feature block) */
        float r = features[ch * ENSEMBLE_PER_CH_FEATURES];
        total_rms_sq += r * r;
    }
    if (sqrtf(total_rms_sq) < REST_ACTIVITY_THRESHOLD) {
        *confidence = 1.0f;
        return s_rest_class_idx;
    }
    /* 3. Run ensemble pipeline (raw probabilities) */
    float meta_probs[MODEL_NUM_CLASSES];
    inference_ensemble_predict_raw(features, meta_probs);
    /* 4. EMA smoothing on meta output */
    float max_smooth = 0.0f;
    int winner = 0;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        s_smoothed[c] = ENSEMBLE_EMA_ALPHA * s_smoothed[c] +
                        (1.0f - ENSEMBLE_EMA_ALPHA) * meta_probs[c];
        if (s_smoothed[c] > max_smooth) {
            max_smooth = s_smoothed[c];
            winner = c;
        }
    }
    /* 5. Confidence cascade: escalate to MLP if meta-LDA is uncertain.
     * NOTE(review): mlp_conf is an instantaneous softmax probability being
     * compared against an EMA-smoothed value — confirm the scales are
     * intended to be comparable. */
    if (max_smooth < ENSEMBLE_CONF_THRESHOLD) {
        float mlp_conf = 0.0f;
        int mlp_winner = inference_mlp_predict(features, MODEL_NUM_FEATURES, &mlp_conf);
        if (mlp_conf > max_smooth) {
            winner = mlp_winner;
            max_smooth = mlp_conf;
        }
    }
    /* 6. Reject if still uncertain — hold current output (skips the vote
     * and debounce updates below entirely). */
    if (max_smooth < REJECT_THRESHOLD) {
        *confidence = max_smooth;
        return s_current_output >= 0 ? s_current_output : s_rest_class_idx;
    }
    *confidence = max_smooth;
    /* 7. Majority vote (window = 5); -1 slots (not yet filled) are skipped. */
    s_vote_history[s_vote_head] = winner;
    s_vote_head = (s_vote_head + 1) % 5;
    int counts[MODEL_NUM_CLASSES];
    memset(counts, 0, sizeof(counts));
    for (int i = 0; i < 5; i++) {
        if (s_vote_history[i] >= 0) {
            counts[s_vote_history[i]]++;
        }
    }
    int majority = 0, majority_cnt = 0;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        if (counts[c] > majority_cnt) {
            majority_cnt = counts[c];
            majority = c;
        }
    }
    /* 8. Debounce — need 3 consecutive majority hops on the same new class
     * before the published output changes. */
    int final_out = (s_current_output >= 0) ? s_current_output : majority;
    if (s_current_output < 0) {
        /* First prediction ever */
        s_current_output = majority;
        final_out = majority;
    } else if (majority == s_current_output) {
        /* Staying in current gesture — collapse pending back onto it so a
         * single stray majority cannot keep an old candidate alive. */
        s_pending_output = majority;
        s_pending_count = 1;
    } else if (majority == s_pending_output) {
        /* Same new gesture again */
        if (++s_pending_count >= 3) {
            s_current_output = majority;
            final_out = majority;
        }
    } else {
        /* New gesture candidate — start counting */
        s_pending_output = majority;
        s_pending_count = 1;
    }
    return final_out;
}
#else /* MODEL_USE_ENSEMBLE == 0 — compile-time stubs */
void inference_ensemble_init(void) {}
void inference_ensemble_predict_raw(const float *features, float *proba_out) {
    /* Ensemble disabled at compile time: report a uniform distribution. */
    (void)features;
    const float uniform = 1.0f / MODEL_NUM_CLASSES;
    for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
        proba_out[c] = uniform;
    }
}
int inference_ensemble_predict(float *confidence) {
    /* Ensemble disabled at compile time: always class 0, zero confidence. */
    if (confidence) {
        *confidence = 0.0f;
    }
    return 0;
}
#endif /* MODEL_USE_ENSEMBLE */

View File

@@ -0,0 +1,44 @@
/**
* @file inference_ensemble.h
* @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F).
*
* Requires:
* - Change 1 expanded features (MODEL_EXPAND_FEATURES 1)
* - Change 7 training (train_ensemble.py) to generate model_weights_ensemble.h
* - Change E MLP (MODEL_USE_MLP 1) for confidence-cascade fallback
*
* Enable by setting MODEL_USE_ENSEMBLE 1 in model_weights.h and calling
* inference_ensemble_init() from app_main() instead of (or alongside) inference_init().
*/
#pragma once
#include <stdbool.h>
/**
* @brief Initialise ensemble state (EMA, vote history, debounce).
* Must be called before inference_ensemble_predict().
*/
void inference_ensemble_init(void);
/**
* @brief Compute ensemble probabilities without smoothing/voting/debounce.
*
* Runs the 3 specialist LDAs + meta-LDA stacker and writes the raw meta-LDA
* probabilities to proba_out. Used by the multi-model voting path in main.c.
*
* @param features Calibrated feature vector (MODEL_NUM_FEATURES floats).
* @param proba_out Output probability array (MODEL_NUM_CLASSES floats).
*/
void inference_ensemble_predict_raw(const float *features, float *proba_out);
/**
* @brief Run one inference hop through the full ensemble pipeline.
*
* Internally calls inference_extract_features() to pull the latest window,
* routes through the three specialist LDAs, the meta-LDA stacker, EMA
* smoothing, majority vote, and debounce.
*
* @param confidence Output: winning class smoothed probability [0,1].
* @return Gesture enum value (same as inference_predict() return).
*/
int inference_ensemble_predict(float *confidence);

View File

@@ -0,0 +1,75 @@
/**
* @file inference_mlp.cc
* @brief int8 MLP inference via TFLite Micro (Change E).
*
* Compiled as C++ because TFLite Micro headers require C++.
* Guarded by MODEL_USE_MLP — provides compile-time stubs when 0.
*/
#include "inference_mlp.h"
#include "model_weights.h"
#if MODEL_USE_MLP
#include <cmath>   /* roundf() in inference_mlp_predict() */
#include <cstdio>  /* printf() diagnostics in inference_mlp_init() */

#include "emg_model_data.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/schema/schema_generated.h"
static uint8_t tensor_arena[48 * 1024]; /* 48 KB — tune down if memory is tight */
static tflite::MicroInterpreter *s_interpreter = nullptr;
static TfLiteTensor *s_input = nullptr;
static TfLiteTensor *s_output = nullptr;
void inference_mlp_init(void) {
const tflite::Model *model = tflite::GetModel(g_model);
static tflite::MicroMutableOpResolver<4> resolver;
resolver.AddFullyConnected();
resolver.AddRelu();
resolver.AddSoftmax();
resolver.AddDequantize();
static tflite::MicroInterpreter interp(model, resolver,
tensor_arena, sizeof(tensor_arena));
s_interpreter = &interp;
s_interpreter->AllocateTensors();
s_input = s_interpreter->input(0);
s_output = s_interpreter->output(0);
}
int inference_mlp_predict(const float *features, int n_feat, float *conf_out) {
/* Quantise input: int8 = round(float / scale) + zero_point */
const float iscale = s_input->params.scale;
const int izp = s_input->params.zero_point;
for (int i = 0; i < n_feat; i++) {
int q = static_cast<int>(roundf(features[i] / iscale)) + izp;
if (q < -128) q = -128;
if (q > 127) q = 127;
s_input->data.int8[i] = static_cast<int8_t>(q);
}
s_interpreter->Invoke();
/* Dequantise output and find winning class */
const float oscale = s_output->params.scale;
const int ozp = s_output->params.zero_point;
float max_p = -1e9f;
int max_c = 0;
const int n_cls = s_output->dims->data[1];
for (int c = 0; c < n_cls; c++) {
float p = (s_output->data.int8[c] - ozp) * oscale;
if (p > max_p) { max_p = p; max_c = c; }
}
*conf_out = max_p;
return max_c;
}
#else /* MODEL_USE_MLP == 0 — compile-time stubs */
void inference_mlp_init(void) {}
int inference_mlp_predict(const float * /*features*/, int /*n_feat*/, float *conf_out) {
    /* MLP disabled at compile time: report class 0 with zero confidence. */
    if (conf_out != nullptr) {
        *conf_out = 0.0f;
    }
    return 0;
}
#endif /* MODEL_USE_MLP */

View File

@@ -0,0 +1,40 @@
/**
* @file inference_mlp.h
* @brief int8 MLP inference via TFLite Micro (Change E).
*
* Enable by:
* 1. Setting MODEL_USE_MLP 1 in model_weights.h.
* 2. Running train_mlp_tflite.py to generate emg_model_data.cc.
* 3. Adding TFLite Micro to platformio.ini lib_deps.
*
* When MODEL_USE_MLP 0, both functions are empty stubs that compile without
* TFLite Micro headers.
*/
#pragma once
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
/**
* @brief Initialise TFLite Micro interpreter and allocate tensors.
* Must be called before inference_mlp_predict().
* No-op when MODEL_USE_MLP 0.
*/
void inference_mlp_init(void);
/**
* @brief Run one forward pass through the int8 MLP.
*
* @param features Float32 feature vector (same order as Python extractor).
* @param n_feat Number of features (must match model input size).
* @param conf_out Output: winning class softmax probability [0,1].
* @return Winning class index. Returns 0 when MODEL_USE_MLP 0.
*/
int inference_mlp_predict(const float *features, int n_feat, float *conf_out);
#ifdef __cplusplus
}
#endif

View File

@@ -1,7 +1,7 @@
/** /**
* @file model_weights.h * @file model_weights.h
* @brief Trained LDA model weights exported from Python. * @brief Trained LDA model weights exported from Python.
* @date 2026-01-27 21:35:17 * @date 2026-03-08 20:27:52
*/ */
#ifndef MODEL_WEIGHTS_H #ifndef MODEL_WEIGHTS_H
@@ -11,7 +11,14 @@
/* Metadata */ /* Metadata */
#define MODEL_NUM_CLASSES 5 #define MODEL_NUM_CLASSES 5
#define MODEL_NUM_FEATURES 16 #define MODEL_NUM_FEATURES 69
#define MODEL_NORMALIZE_FEATURES 1
/* Compile-time feature flags */
#define MODEL_EXPAND_FEATURES 1
#define MODEL_USE_REINHARD 1
#define MODEL_USE_MLP 1
#define MODEL_USE_ENSEMBLE 1
/* Class Names */ /* Class Names */
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = { static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
@@ -28,35 +35,70 @@ static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
/* LDA Intercepts/Biases */ /* LDA Intercepts/Biases */
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = { static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
-14.097581f, -2.018629f, -4.478267f, 1.460458f, -5.562349f -2.361446f, -2.300285f, -3.189113f, -1.744772f, -2.137618f
}; };
/* LDA Coefficients (Weights) */ /* LDA Coefficients (Weights) */
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = { static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
/* fist */ /* fist */
{ {
0.070110f, -0.002554f, 0.043924f, 0.020555f, -0.660305f, 0.010691f, -0.074429f, -0.037253f, -1.005012f, 1.102594f, 0.220576f, -0.171785f, 0.160110f, 1.460799f, 0.160097f, 0.023789f,
0.057908f, -0.002655f, 0.042119f, -0.052956f, 0.063822f, 0.006184f, -0.025462f, 0.040815f, 0.316695f, 0.551076f, 0.193143f, -0.030601f, -1.068705f, -0.414611f, -0.040128f, 1.215391f,
0.945437f, 0.545068f, 0.736472f, 0.328106f, 1.212364f, -0.769889f, 0.428595f, -0.033948f,
0.181605f, -3.236872f, 0.181581f, -0.061825f, 1.166229f, 0.338513f, 0.207578f, 0.041526f,
0.793306f, 0.256971f, -0.062490f, 2.940555f, 0.251149f, 0.096629f, 0.124341f, -0.125052f,
0.362799f, 0.430005f, 0.106206f, -0.201260f, 0.529469f, -1.578447f, 0.529456f, -0.040727f,
1.636684f, 0.129321f, 0.011299f, 0.086495f, 1.510614f, -0.578640f, -0.025043f, 0.275174f,
0.300145f, 0.366742f, -0.104204f, -0.125212f, 0.819188f, 2.292691f, -0.957544f, -0.321092f,
-0.173973f, 0.306428f, -1.015226f, 0.770284f, 0.797008f
}, },
/* hook_em */ /* hook_em */
{ {
-0.002511f, 0.001034f, 0.027889f, 0.026006f, 0.183681f, -0.000773f, 0.016791f, -0.027926f, 1.694859f, -0.653387f, 0.521604f, 0.251859f, -0.706711f, 1.182055f, -0.706678f, 0.008488f,
-0.023321f, 0.000770f, 0.059023f, -0.056021f, 0.237063f, -0.007423f, 0.082101f, -0.021472f, 2.450049f, 0.068430f, -0.002051f, 0.214216f, 2.641921f, -0.677963f, -0.020216f, -0.084562f,
0.314933f, 0.166270f, -0.174888f, 0.023982f, -5.326996f, -0.143884f, -0.165025f, 0.014950f,
0.167851f, 3.526903f, 0.167886f, 0.139922f, 0.702011f, 0.197141f, 0.087553f, 0.093573f,
1.414652f, 0.237107f, -0.044205f, -3.707572f, 0.146742f, 0.025417f, -0.085462f, -0.227836f,
1.745303f, 0.408549f, -0.224446f, 0.081996f, 0.432159f, 0.151780f, 0.432186f, 0.009932f,
-1.447145f, -0.136321f, -0.071784f, -0.118749f, -1.872055f, 0.490270f, 0.017250f, -0.667072f,
-0.381349f, -0.429662f, -0.084705f, -0.123111f, -0.213680f, -3.070284f, 0.267574f, 0.343828f,
1.024997f, -0.347210f, 1.222141f, 4.897431f, -1.097906f
}, },
/* open */ /* open */
{ {
-0.006170f, 0.000208f, -0.041151f, 0.013271f, 0.054508f, -0.002356f, 0.000170f, 0.012941f, 2.756912f, 0.029446f, -0.408740f, -0.020121f, 0.375563f, -3.479846f, 0.375601f, -0.022014f,
-0.106180f, 0.003538f, -0.013656f, -0.017712f, 0.131131f, -0.002623f, -0.007022f, 0.024497f, -1.357556f, -0.445301f, -0.119886f, -0.020492f, -0.494526f, 0.247305f, 0.028053f, -0.106328f,
-0.585811f, -0.405593f, -0.486449f, -0.319108f, -3.159308f, 0.871359f, 0.611153f, -0.191602f,
0.059082f, 4.326131f, 0.059172f, -0.006684f, -0.502170f, -0.110067f, -0.057217f, -0.224138f,
-1.214966f, -1.056423f, -0.019260f, -0.878641f, 0.167700f, 0.576919f, -0.143247f, -0.914332f,
-2.970328f, -0.403648f, -0.014340f, 0.048198f, 0.332494f, 0.479158f, 0.332573f, -0.002956f,
-0.351570f, 0.181890f, 0.109983f, -0.013305f, 0.133854f, -0.180451f, -0.046355f, -0.355403f,
-0.123445f, 0.146366f, 0.317895f, 0.342519f, -0.399778f, -6.284204f, 0.422619f, -0.109523f,
-0.996114f, 0.059564f, 0.043680f, -3.046475f, -0.178861f
}, },
/* rest */ /* rest */
{ {
-0.011094f, 0.000160f, -0.012547f, -0.011058f, 0.130577f, -0.001942f, 0.020823f, -0.001961f, -1.609762f, -0.516942f, -0.285630f, -0.108276f, 0.196206f, 0.488562f, 0.196196f, 0.026133f,
0.018021f, -0.000404f, -0.065598f, 0.039676f, 0.018679f, -0.001522f, 0.023302f, -0.008474f, -1.189232f, -0.034391f, -0.062033f, -0.130823f, -1.033252f, 0.697769f, 0.032136f, -0.605629f,
-0.007999f, -0.097476f, 0.094768f, 0.333748f, 1.317443f, -0.880336f, -0.213594f, -0.073489f,
-0.207093f, -4.451737f, -0.207141f, 0.000240f, -1.645825f, -0.245072f, -0.026848f, 0.029140f,
-0.364751f, 0.395671f, 0.066683f, 3.611272f, -0.209066f, -0.214834f, 0.027566f, 0.514329f,
1.066669f, -0.684801f, -0.070219f, -0.046770f, -0.494221f, 0.722863f, -0.494243f, 0.008539f,
-0.124061f, -0.075996f, -0.084995f, -0.012796f, 0.291955f, 0.121964f, 0.042458f, -0.627127f,
0.001798f, -0.199790f, -0.153351f, 0.027358f, -0.117621f, 1.819176f, 0.092071f, -0.017248f,
0.251629f, 0.021265f, -0.106236f, 0.495258f, 0.069582f
}, },
/* thumbs_up */ /* thumbs_up */
{ {
-0.016738f, 0.000488f, 0.024199f, -0.024643f, -0.044912f, 0.000153f, -0.011080f, 0.043487f, -0.792044f, 0.374311f, 0.170137f, 0.127379f, -0.184879f, 0.153073f, -0.184920f, -0.053070f,
0.051828f, -0.001670f, 0.109633f, 0.004154f, -0.460694f, 0.008616f, -0.104097f, -0.020886f, 0.682094f, -0.099986f, 0.036795f, 0.061637f, 0.725941f, -0.348358f, -0.022970f, -0.004588f,
-0.632688f, -0.124638f, -0.221206f, -0.580598f, 5.030162f, 1.484454f, -0.536655f, 0.339228f,
-0.058585f, 2.787274f, -0.058607f, -0.068406f, 1.425498f, -0.005883f, -0.188273f, 0.049091f,
-0.309501f, -0.063155f, 0.013505f, -4.453858f, -0.215617f, -0.354558f, 0.060673f, 0.425839f,
-0.794436f, 0.734764f, 0.245384f, 0.149158f, -0.463152f, -0.279041f, -0.463210f, 0.019611f,
0.350645f, -0.055642f, 0.087931f, 0.064769f, -0.304954f, 0.079455f, -0.015292f, 1.791572f,
0.196916f, 0.237732f, 0.116177f, -0.153250f, 0.000654f, 4.127202f, 0.103829f, 0.125913f,
-0.223239f, -0.063374f, -0.047960f, -3.238916f, 0.344427f
}, },
}; };

View File

@@ -0,0 +1,65 @@
// Auto-generated by train_ensemble.py — do not edit
#pragma once
// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from
// model_weights.h to avoid redefinition conflicts.
#include "model_weights.h"
#define ENSEMBLE_PER_CH_FEATURES 20
#define TD_FEAT_OFFSET 0
#define TD_NUM_FEATURES 36
#define FD_FEAT_OFFSET 12
#define FD_NUM_FEATURES 24
#define CC_FEAT_OFFSET 60
#define CC_NUM_FEATURES 9
#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)
// Feature index arrays for gather operations (TD and FD are non-contiguous)
// TD indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51]
// FD indices: [12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39, 52, 53, 54, 55, 56, 57, 58, 59]
// CC indices: [60, 61, 62, 63, 64, 65, 66, 67, 68]
/* Specialist + meta LDA parameters (values unchanged from the generator).
 * `static` added: a plain file-scope `const` object has external linkage in
 * C, so including this header from a second translation unit would produce
 * duplicate-symbol link errors. `static const` also matches the convention
 * already used for MODEL_CLASS_NAMES in model_weights.h.
 * NOTE(review): train_ensemble.py should be updated to emit `static const`
 * so regeneration does not undo this. */
static const float LDA_TD_WEIGHTS[5][36] = {
    {-4.98005951f, 3.15698227f, -0.01656110f, 0.16589549f, 1.89759250f, 4.02171634f, 1.89759264f, -0.02318365f, 0.68313384f, 1.08146976f, 0.24790972f, 0.02111336f, 1.40219525f, -0.34718600f, 0.03292437f, 0.03224236f, 0.07318621f, -1.47505917f, 0.07318595f, -0.15298836f, 0.68723991f, 0.31478366f, 0.26694627f, 0.03383652f, 1.16109002f, -0.17694826f, -0.02001825f, 0.04811623f, 0.11145582f, 0.98011591f, 0.11145584f, 0.02078491f, 0.16441538f, -0.37776186f, -0.14062534f, -0.20666287f}, // fist
    {7.06735173f, -0.07584252f, 0.28878350f, 0.25548171f, -2.72040928f, -1.14291179f, -2.72040944f, -0.05457462f, 0.67944659f, -1.72471154f, -0.56168071f, 0.08847797f, -0.36636016f, -0.92733589f, 0.07186714f, -0.15407635f, 0.09285758f, -0.99197678f, 0.09285799f, 0.24903098f, -0.58741170f, -1.69916941f, -0.79791300f, -0.51257426f, -4.23844985f, -1.30822564f, 0.08744698f, -0.10944734f, 1.12572192f, 0.67459317f, 1.12572190f, 0.12677750f, -0.06012924f, 0.74169191f, 0.20452265f, 0.02420439f}, // hook_em
    {-5.68183942f, -2.70544312f, -0.02268383f, -0.07857558f, 0.94259344f, -1.41885234f, 0.94259353f, 0.12889098f, -0.30338581f, 0.37920265f, 0.06556953f, 0.00592915f, -6.03034018f, 1.95694619f, 0.22083802f, 0.33970496f, 0.85733980f, 4.45235717f, 0.85733951f, 0.16720219f, 0.59151687f, 1.12441071f, 0.07520788f, -0.44451340f, -0.73662069f, 1.71355275f, -0.19822542f, 0.32955834f, 0.05661870f, -1.29613804f, 0.05661887f, -0.21196686f, -0.61895423f, -1.71535920f, -0.73493998f, -0.21009359f}, // open
    {0.20260846f, -0.23909181f, -0.28033711f, -0.22450103f, 0.13810806f, -0.88146743f, 0.13810809f, 0.01613534f, -0.34304575f, 1.18327519f, 0.36712706f, -0.09916758f, 1.33818062f, -0.49450303f, 0.03771161f, -0.21199101f, -0.57758751f, -0.06415963f, -0.57758750f, -0.11776329f, -0.49294313f, 0.47593171f, 0.42732005f, 0.38381135f, 0.93751846f, -0.40085712f, -0.01569340f, -0.21022401f, -0.70330953f, 0.29469121f, -0.70330957f, -0.03953274f, -0.19099800f, 1.11867928f, 0.51910780f, 0.26138187f}, // rest
    {3.56990173f, 0.11258470f, 0.22746547f, 0.04228335f, -0.43704307f, 0.04449194f, -0.43704318f, -0.08323209f, -0.45652127f, -1.76552235f, -0.38107651f, 0.05260073f, 2.92141069f, 0.06553870f, -0.39303365f, 0.12369768f, -0.07810754f, -2.03448136f, -0.07810740f, -0.06529502f, 0.10545265f, -0.60931490f, -0.28035188f, 0.28069772f, 2.16483999f, 0.36093444f, 0.16473959f, 0.07185201f, -0.08960941f, -0.79422744f, -0.08960951f, 0.13979645f, 0.85101150f, -0.45589633f, -0.17267249f, -0.03933928f}, // thumbs_up
};
static const float LDA_TD_INTERCEPTS[5] = {
    -5.59817408f, -4.91484804f, -9.33780358f, -3.85428822f, -3.12479496f
};
static const float LDA_FD_WEIGHTS[5][24] = {
    {0.19477058f, 0.33798486f, 0.24311717f, 3.58490717f, 0.37449334f, 0.01733587f, 0.99771453f, 0.35846897f, -0.07609940f, -0.08551280f, -0.04212976f, -1.64981087f, -0.28267724f, -0.45724067f, 0.09742739f, -0.45412877f, -0.67354231f, 0.07763610f, -0.03042757f, 1.23057256f, 0.47423480f, 0.96870723f, 0.09008234f, 0.61253227f}, // fist
    {-1.33051949f, -1.06668043f, -0.13054796f, 1.49029766f, 0.34559104f, 2.58938944f, 2.10036791f, 0.00627139f, 0.67799858f, -0.05916871f, -0.02308780f, -0.51389981f, 0.49698104f, 0.52080687f, -0.53642075f, -0.44702479f, 0.45831951f, -0.05527688f, -0.12666255f, -2.29960756f, 0.15527345f, -0.60190848f, -0.50386274f, -0.20863809f}, // hook_em
    {0.86684408f, -0.18732078f, 0.04828207f, -4.18979500f, -0.25709076f, -1.34726223f, -1.93439169f, -0.36975670f, -0.57908022f, 0.40713764f, 0.14203746f, 3.81070011f, 0.15930714f, 0.81375493f, 0.17586089f, 1.21709829f, 0.30729211f, -0.06304829f, 0.23529519f, 2.10099132f, -1.15363931f, -0.25856679f, 0.65430329f, -0.77828374f}, // open
    {0.10167142f, 0.92873579f, -0.03451580f, -0.35104059f, 0.33898659f, -0.99047210f, -1.06006192f, 0.02907091f, -0.22314490f, 0.33956274f, -0.07111302f, 0.31329279f, -0.56910921f, -0.79014171f, -0.38388114f, 0.01131879f, 0.87387134f, -0.12380692f, 0.05214671f, -0.42582976f, -0.17031139f, -1.06661596f, -0.91600447f, -0.26435070f}, // rest
    {0.04198304f, -0.65837713f, -0.10663077f, -0.12701858f, -1.01290184f, 0.50100717f, 0.72188030f, -0.03150884f, 0.38356910f, -0.84413736f, 0.03745147f, -2.29871124f, 0.58565208f, 0.43331566f, 0.88777658f, -0.38229432f, -1.55494079f, 0.24865587f, -0.17540676f, -0.43032659f, 0.84763042f, 1.67359692f, 1.26211059f, 0.83625332f}, // thumbs_up
};
static const float LDA_FD_INTERCEPTS[5] = {
    -5.22055861f, -3.76092770f, -8.23556500f, -3.59434123f, -2.96625341f
};
static const float LDA_CC_WEIGHTS[5][9] = {
    {-0.94928369f, 2.44942094f, 0.26441444f, -0.80355120f, -1.19821936f, 1.39871694f, -2.06386346f, 0.59615281f, 1.57632161f}, // fist
    {-2.17638786f, 1.40504639f, 1.97661862f, 0.99132032f, 1.46231208f, -0.77606652f, 2.43631056f, 1.58734751f, -1.75342004f}, // hook_em
    {-0.15938422f, -2.49617253f, -0.43491962f, 0.32764670f, -3.51157144f, -0.74698502f, 2.59991544f, -0.64230610f, -3.23294039f}, // open
    {2.78062805f, -1.24792206f, -2.03404274f, -0.13924844f, 2.03596007f, -0.25590903f, -0.40921478f, -0.84411543f, 0.27031906f}, // rest
    {-1.42317081f, 0.84654172f, 1.66121800f, -0.27035074f, -0.03025088f, 0.56047099f, -2.30886825f, -0.06801108f, 3.01175668f}, // thumbs_up
};
static const float LDA_CC_INTERCEPTS[5] = {
    -2.77069779f, -4.07631951f, -6.23624226f, -1.94638886f, -2.30392172f
};
static const float META_LDA_WEIGHTS[5][15] = {
    {13.83327682f, -4.99575141f, -8.19270319f, -2.93944549f, 0.75007499f, 3.36110029f, 0.53962472f, -4.82143659f, -1.61056244f, 1.48044754f, 1.56410113f, -0.98537108f, 0.20587945f, 0.39305391f, -1.42106273f}, // fist
    {-4.64514194f, 13.09553405f, -4.41874813f, -1.25841220f, -2.98757893f, 0.16270053f, 3.15820296f, -2.24183188f, -1.32266753f, 0.20766953f, -0.65827396f, 1.59311387f, -1.22803182f, 0.04336258f, -0.24531550f}, // hook_em
    {-5.79187671f, -2.27375570f, 23.21490262f, -0.09093684f, -4.81706827f, -1.92030254f, -1.47383708f, 9.25922135f, 0.57114390f, -2.76993062f, -1.31290246f, -1.07932832f, 5.46202897f, -0.72440402f, -0.19414883f}, // open
    {-2.87608878f, -1.27035778f, -3.11538939f, 4.40275586f, -1.44999243f, -0.49668812f, -1.24040434f, 0.65617176f, 1.79490014f, -1.52075301f, 0.08593208f, 0.38743884f, -1.35293868f, 0.30551755f, -0.06438460f}, // rest
    {1.52549481f, -3.33085871f, -6.18851152f, -3.12231842f, 9.54599696f, -0.69931080f, -0.02514346f, -3.63196650f, -0.69837189f, 3.71692816f, 0.29062506f, -0.11110049f, -2.35881642f, -0.20019868f, 1.96143145f}, // thumbs_up
};
static const float META_LDA_INTERCEPTS[5] = {
    -8.46795016f, -8.64988671f, -21.16256224f, -3.22663037f, -6.35767829f
};

View File

@@ -1,109 +1,180 @@
/** /**
* @file emg_sensor.c * @file emg_sensor.c
* @brief EMG sensor driver implementation. * @brief EMG sensor driver — DMA-backed continuous ADC acquisition.
* *
* Provides EMG readings - fake data for now, real ADC when sensors arrive. * Change A: adc_continuous (DMA) replaces adc_oneshot polling.
* A background FreeRTOS task on Core 0 reads DMA frames, assembles
* complete 4-channel sample sets, and pushes them to a FreeRTOS queue.
* emg_sensor_read() blocks on that queue; at 1 kHz per channel this
* returns within ~1 ms, providing exact timing without vTaskDelay(1).
*/ */
#include "emg_sensor.h" #include "emg_sensor.h"
#include "esp_timer.h" #include "esp_timer.h"
#include <stdlib.h>
#include <stdio.h>
#include "freertos/FreeRTOS.h" #include "freertos/FreeRTOS.h"
#include "freertos/task.h" #include "freertos/task.h"
#include "esp_adc/adc_oneshot.h" #include "freertos/queue.h"
#include "esp_adc/adc_continuous.h"
#include "esp_adc/adc_cali.h" #include "esp_adc/adc_cali.h"
#include "esp_adc/adc_cali_scheme.h" #include "esp_adc/adc_cali_scheme.h"
#include "esp_err.h" #include "esp_err.h"
#include <assert.h>
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
adc_oneshot_unit_handle_t adc1_handle; // --- ADC DMA constants ---
adc_cali_handle_t cali_handle = NULL; // Total sample rate: 4 channels × 1 kHz = 4 kHz
#define ADC_TOTAL_SAMPLE_RATE_HZ (EMG_NUM_CHANNELS * EMG_SAMPLE_RATE_HZ)
// DMA frame: 64 conversions × 4 bytes = 256 bytes, arrives ~every 16 ms
#define ADC_CONV_FRAME_SIZE (64u * sizeof(adc_digi_output_data_t))
// Internal DMA pool: 2× conv_frame_size
#define ADC_POOL_SIZE (2u * ADC_CONV_FRAME_SIZE)
// Sample queue: buffers up to ~32 ms of assembled 4-channel sets
#define SAMPLE_QUEUE_DEPTH 32
const uint8_t emg_channels[EMG_NUM_CHANNELS] = { // --- Static handles ---
ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0 static adc_continuous_handle_t s_adc_handle = NULL;
ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1 static adc_cali_handle_t s_cali_handle = NULL;
ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2 static QueueHandle_t s_sample_queue = NULL;
ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3
// Channel mapping (ADC1, GPIO 2/3/9/10)
static const adc_channel_t s_channels[EMG_NUM_CHANNELS] = {
ADC_CHANNEL_1, // GPIO 2 — FCR/Belly (ch0)
ADC_CHANNEL_2, // GPIO 3 — Extensors (ch1)
ADC_CHANNEL_8, // GPIO 9 — FCU/Outer Flexors (ch2)
ADC_CHANNEL_9, // GPIO 10 — Bicep (ch3)
}; };
/*******************************************************************************
* ADC Sampling Task (Core 0)
******************************************************************************/
/**
* @brief DMA sampling task.
*
* Reads adc_continuous DMA frames, assembles complete 4-channel sample sets
* (one value per channel), applies curve-fitting calibration, and pushes
* emg_sample_t structs to s_sample_queue. Runs continuously on Core 0.
*/
static void adc_sampling_task(void *arg) {
uint8_t *buf = (uint8_t *)malloc(ADC_CONV_FRAME_SIZE);
if (!buf) {
printf("[EMG] ERROR: DMA read buffer alloc failed\n");
vTaskDelete(NULL);
return;
}
// Per-channel raw accumulator; holds the latest sample for each channel
int raw_set[EMG_NUM_CHANNELS];
bool got[EMG_NUM_CHANNELS];
memset(raw_set, 0, sizeof(raw_set));
memset(got, 0, sizeof(got));
while (1) {
uint32_t out_len = 0;
esp_err_t err = adc_continuous_read(
s_adc_handle, buf, (uint32_t)ADC_CONV_FRAME_SIZE,
&out_len, pdMS_TO_TICKS(100)
);
if (err != ESP_OK || out_len == 0) {
continue;
}
uint32_t n = out_len / sizeof(adc_digi_output_data_t);
adc_digi_output_data_t *p = (adc_digi_output_data_t *)buf;
for (uint32_t i = 0; i < n; i++) {
int ch = (int)p[i].type2.channel;
int raw = (int)p[i].type2.data;
// Map ADC channel index to sensor channel
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
if ((int)s_channels[c] == ch) {
raw_set[c] = raw;
got[c] = true;
break;
}
}
// Emit a complete sample set once all channels have been updated
bool all = true;
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
if (!got[c]) { all = false; break; }
}
if (!all) continue;
emg_sample_t s;
s.timestamp_ms = emg_sensor_get_timestamp_ms();
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
int mv = 0;
adc_cali_raw_to_voltage(s_cali_handle, raw_set[c], &mv);
s.channels[c] = (uint16_t)mv;
}
// Non-blocking send: drop if queue is full (prefer fresh data)
xQueueSend(s_sample_queue, &s, 0);
// Reset accumulator for the next complete set
memset(got, 0, sizeof(got));
}
}
}
/******************************************************************************* /*******************************************************************************
* Public Functions * Public Functions
******************************************************************************/ ******************************************************************************/
void emg_sensor_init(void) void emg_sensor_init(void) {
{ // 1. Curve-fitting calibration (same scheme as before)
#if FEATURE_FAKE_EMG adc_cali_curve_fitting_config_t cali_cfg = {
/* Seed random number generator for fake data */ .unit_id = ADC_UNIT_1,
srand((unsigned int)esp_timer_get_time()); .atten = ADC_ATTEN_DB_12,
#else
// 1. --- ADC Unit Setup ---
adc_oneshot_unit_init_cfg_t init_config1 = {
.unit_id = ADC_UNIT_1,
.ulp_mode = ADC_ULP_MODE_DISABLE,
};
ESP_ERROR_CHECK(adc_oneshot_new_unit(&init_config1, &adc1_handle));
// 2. --- ADC Channel Setup (GPIO 1?) ---
// Ensure the channel matches your GPIO in pinmap. For ADC1, GPIO1 is usually not CH0.
// Check your datasheet! (e.g., on S3, GPIO 1 is ADC1_CH0)
adc_oneshot_chan_cfg_t config = {
.bitwidth = ADC_BITWIDTH_DEFAULT, // 12-bit for S3
.atten = ADC_ATTEN_DB_12, // Allows up to ~3.1V
};
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++)
ESP_ERROR_CHECK(adc_oneshot_config_channel(adc1_handle, emg_channels[i], &config));
// 3. --- Calibration Setup (CORRECTED for S3) ---
// ESP32-S3 uses Curve Fitting, not Line Fitting
adc_cali_curve_fitting_config_t cali_config = {
.unit_id = ADC_UNIT_1,
.atten = ADC_ATTEN_DB_12,
.bitwidth = ADC_BITWIDTH_DEFAULT, .bitwidth = ADC_BITWIDTH_DEFAULT,
}; };
ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_config, &cali_handle)); ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_cfg, &s_cali_handle));
// while (1) {
// int raw_val, voltage_mv;
// // Read Raw
// ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, ADC_CHANNEL_1, &raw_val));
// // Convert to mV using calibration
// ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
// printf("Raw: %d | Voltage: %d mV\n", raw_val, voltage_mv); // 2. Create continuous ADC handle
// vTaskDelay(pdMS_TO_TICKS(500)); adc_continuous_handle_cfg_t adc_cfg = {
// } .max_store_buf_size = ADC_POOL_SIZE,
#endif .conv_frame_size = ADC_CONV_FRAME_SIZE,
} };
ESP_ERROR_CHECK(adc_continuous_new_handle(&adc_cfg, &s_adc_handle));
void emg_sensor_read(emg_sample_t *sample) // 3. Configure scan pattern (4 channels, 4 kHz total)
{ adc_digi_pattern_config_t patterns[EMG_NUM_CHANNELS];
sample->timestamp_ms = emg_sensor_get_timestamp_ms();
#if FEATURE_FAKE_EMG
/*
* Generate fake EMG data:
* - Base value around 1650 (middle of 3.3V millivolt range)
* - Random noise of +/- 50
* - Mimics real EMG baseline noise
*/
for (int i = 0; i < EMG_NUM_CHANNELS; i++) { for (int i = 0; i < EMG_NUM_CHANNELS; i++) {
int noise = (rand() % 101) - 50; /* -50 to +50 */ patterns[i].atten = ADC_ATTEN_DB_12;
sample->channels[i] = (uint16_t)(1650 + noise); patterns[i].channel = (uint8_t)s_channels[i];
} patterns[i].unit = ADC_UNIT_1;
#else patterns[i].bit_width = ADC_BITWIDTH_12;
int raw_val, voltage_mv;
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) {
ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, emg_channels[i], &raw_val));
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
sample->channels[i] = (uint16_t) voltage_mv;
} }
adc_continuous_config_t cont_cfg = {
.pattern_num = EMG_NUM_CHANNELS,
.adc_pattern = patterns,
.sample_freq_hz = ADC_TOTAL_SAMPLE_RATE_HZ,
.conv_mode = ADC_CONV_SINGLE_UNIT_1,
.format = ADC_DIGI_OUTPUT_FORMAT_TYPE2,
};
ESP_ERROR_CHECK(adc_continuous_config(s_adc_handle, &cont_cfg));
#endif // 4. Start DMA acquisition
ESP_ERROR_CHECK(adc_continuous_start(s_adc_handle));
// 5. Create sample queue (assembled complete sets land here)
s_sample_queue = xQueueCreate(SAMPLE_QUEUE_DEPTH, sizeof(emg_sample_t));
assert(s_sample_queue != NULL);
// 6. Launch sampling task pinned to Core 0
xTaskCreatePinnedToCore(adc_sampling_task, "adc_sample", 4096, NULL, 6, NULL, 0);
} }
uint32_t emg_sensor_get_timestamp_ms(void) void emg_sensor_read(emg_sample_t *sample) {
{ // Block until a complete 4-channel sample set arrives from the DMA task.
// At 1 kHz per channel this typically returns within ~1 ms.
xQueueReceive(s_sample_queue, sample, portMAX_DELAY);
}
uint32_t emg_sensor_get_timestamp_ms(void) {
return (uint32_t)(esp_timer_get_time() / 1000); return (uint32_t)(esp_timer_get_time() / 1000);
} }

View File

@@ -2,9 +2,8 @@
* @file emg_sensor.h * @file emg_sensor.h
* @brief EMG sensor driver for reading muscle signals. * @brief EMG sensor driver for reading muscle signals.
* *
* This module provides EMG data acquisition. Currently generates fake * This module provides EMG data acquisition from ADC channels connected
* data for testing (FEATURE_FAKE_EMG=1). When sensors arrive, the * to MyoWare sensors. Outputs calibrated millivolt values (0-3300 mV).
* implementation switches to real ADC reads without changing the interface.
* *
* @note This is Layer 2 (Driver). * @note This is Layer 2 (Driver).
*/ */
@@ -34,8 +33,7 @@ typedef struct {
/** /**
* @brief Initialize the EMG sensor system. * @brief Initialize the EMG sensor system.
* *
* If FEATURE_FAKE_EMG is enabled, just seeds the random generator. * Configures ADC channels and calibration for real sensor reading.
* Otherwise, configures ADC channels for real sensor reading.
*/ */
void emg_sensor_init(void); void emg_sensor_init(void);

View File

@@ -8,6 +8,9 @@
#include "hand.h" #include "hand.h"
#include "hal/servo_hal.h" #include "hal/servo_hal.h"
float maxAngles[] = {155, 155, 180, 165, 150};
float minAngles[] = {65, 45, 45, 30, 25};
/******************************************************************************* /*******************************************************************************
* Public Functions * Public Functions
******************************************************************************/ ******************************************************************************/

View File

@@ -13,6 +13,8 @@
#include "config/config.h" #include "config/config.h"
extern float maxAngles[];
extern float minAngles[];
/******************************************************************************* /*******************************************************************************
* Public Functions * Public Functions
******************************************************************************/ ******************************************************************************/

View File

@@ -55,7 +55,7 @@ void servo_hal_init(void)
.timer_sel = SERVO_PWM_TIMER, .timer_sel = SERVO_PWM_TIMER,
.intr_type = LEDC_INTR_DISABLE, .intr_type = LEDC_INTR_DISABLE,
.gpio_num = servo_pins[i], .gpio_num = servo_pins[i],
.duty = SERVO_DUTY_MIN, /* Start extended (open) */ .duty = servo_hal_degrees_to_duty(90), /* Start extended (open) */
.hpoint = 0 .hpoint = 0
}; };
ESP_ERROR_CHECK(ledc_channel_config(&channel_config)); ESP_ERROR_CHECK(ledc_channel_config(&channel_config));

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

2185
emg_gui.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,42 @@
"""
EMG Signal Analysis Script (STANDALONE - NOT PRODUCTION CODE)
==============================================================
!! WARNING !!
This script is for OFFLINE ANALYSIS AND VISUALIZATION ONLY.
It is NOT part of the training or inference pipeline.
DO NOT use this script's outputs for:
- Model training
- Feature extraction thresholds
- Production inference
The thresholds defined here (ZC_THRESHOLD_PERCENT, SSC_THRESHOLD_PERCENT)
are DIFFERENT from production values in:
- learning_data_collection.py (production: 0.1, 0.1)
- model_weights.h (production: 0.1, 0.1)
This script uses higher thresholds (0.7, 0.6) for visualization clarity,
which would produce INCORRECT features if used for training/inference.
To analyze collected sessions:
1. Update HDF5_PATH to your session file
2. Run: python learning_emg_filtering.py
3. View the generated plots
"""
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import h5py import h5py
from scipy.signal import butter, sosfiltfilt from scipy.signal import butter, sosfiltfilt
# ============================================================================= # =============================================================================
# CONFIGURABLE PARAMETERS # CONFIGURABLE PARAMETERS (FOR VISUALIZATION ONLY - NOT PRODUCTION VALUES!)
# ============================================================================= # =============================================================================
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS # These thresholds are intentionally different from production (0.1, 0.1)
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS # to provide cleaner visualizations. DO NOT use for training/inference.
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold (VISUALIZATION ONLY)
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold (VISUALIZATION ONLY)
# ============================================================================= # =============================================================================
# LOAD DATA FROM GUI's HDF5 FORMAT # LOAD DATA FROM GUI's HDF5 FORMAT

325
live_predict.py Normal file
View File

@@ -0,0 +1,325 @@
"""
live_predict.py — Laptop-side live EMG inference for Bucky Arm.
Use this script when the ESP32 is in EMG_MAIN mode and you want the laptop to
run the classifier (instead of the on-device model). Useful for:
- Comparing laptop accuracy vs. on-device accuracy before flashing a new model
- Debugging the feature pipeline without reflashing firmware
- Running an updated model that hasn't been exported to C yet
Workflow:
1. ESP32 must be in EMG_MAIN mode (MAIN_MODE = EMG_MAIN in config.h)
2. This script handshakes → requests STATE_LAPTOP_PREDICT
3. ESP32 streams raw ADC CSV at 1 kHz
4. Script collects 3s of REST for session normalization, then classifies
5. Every 25ms (one hop), the predicted gesture is sent back: {"gesture":"fist"}
6. ESP32 executes the received gesture command on the arm
Usage:
python live_predict.py --port COM3
python live_predict.py --port COM3 --model models/my_model.joblib
python live_predict.py --port COM3 --confidence 0.45
"""
import argparse
import sys
import time
from pathlib import Path
import numpy as np
import serial
# Import from the main training pipeline
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import (
EMGClassifier,
NUM_CHANNELS,
SAMPLING_RATE_HZ,
WINDOW_SIZE_MS,
HOP_SIZE_MS,
HAND_CHANNELS,
)
# Derived constants
WINDOW_SIZE = int(WINDOW_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 150 samples
HOP_SIZE = int(HOP_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 25 samples
BAUD_RATE = 921600
CALIB_SECS = 3.0 # seconds of REST to collect at startup for normalization
# ──────────────────────────────────────────────────────────────────────────────
def parse_args():
    """Parse and return command-line options for live prediction."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--port", required=True,
        help="Serial port (e.g. COM3 on Windows, /dev/ttyUSB0 on Linux)")
    parser.add_argument(
        "--model", default=None,
        help="Path to .joblib model file. Defaults to the most recently trained model.")
    parser.add_argument(
        "--confidence", type=float, default=0.40,
        help="Reject predictions below this confidence (default: 0.40, same as firmware)")
    parser.add_argument(
        "--no-calib", action="store_true",
        help="Skip REST calibration (use raw features — only for quick testing)")
    return parser.parse_args()
def load_model(model_path: str | None) -> EMGClassifier:
    """Load a trained EMGClassifier from disk.

    When model_path is None, the most recently trained model in models/
    is auto-selected. Exits the process if nothing usable is found.
    """
    if model_path:
        chosen = Path(model_path)
    else:
        chosen = EMGClassifier.get_latest_model_path()
        if chosen is None:
            print("[ERROR] No trained model found in models/. Run training first.")
            sys.exit(1)
        print(f"[Model] Auto-selected latest model: {chosen.name}")
    clf = EMGClassifier.load(chosen)
    if not clf.is_trained:
        print("[ERROR] Loaded model is not trained.")
        sys.exit(1)
    return clf
def handshake(ser: serial.Serial) -> bool:
    """Send the connect command and wait up to 5 s for ack_connect from the ESP32.

    Returns True on acknowledgement, False on timeout. Non-matching lines
    are logged and ignored.
    """
    print("[Handshake] Sending connect command...")
    ser.write(b'{"cmd":"connect"}\n')
    deadline = time.time() + 5.0
    while time.time() < deadline:
        text = ser.readline().decode("utf-8", errors="ignore").strip()
        if "ack_connect" in text:
            print(f"[Handshake] Connected — {text}")
            return True
        if text:
            print(f"[Handshake] (ignored) {text}")
    print("[ERROR] No ack_connect within 5s. Is the ESP32 powered and in EMG_MAIN mode?")
    return False
def collect_calibration_windows(
    ser: serial.Serial,
    n_windows: int,
    extractor,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Collect n_windows of REST EMG, extract features, and compute
    per-feature mean and std for session normalization.

    Reads CSV samples from the serial stream (JSON telemetry lines are
    skipped), slides a WINDOW_SIZE×NUM_CHANNELS buffer, and extracts one
    feature vector every HOP_SIZE samples.

    Returns (mean, std) arrays of shape (n_features,).
    """
    print(f"[Calib] Hold arm relaxed at rest for {CALIB_SECS:.0f}s...")
    raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32)
    samples = 0
    feat_list = []
    while len(feat_list) < n_windows:
        raw = ser.readline()
        line = raw.decode("utf-8", errors="ignore").strip()
        # Skip blank lines and JSON telemetry from the ESP32
        if not line or line.startswith("{"):
            continue
        try:
            vals = [float(v) for v in line.split(",")]
        except ValueError:
            continue
        if len(vals) != NUM_CHANNELS:
            continue
        # Slide window by one sample
        raw_buf = np.roll(raw_buf, -1, axis=0)
        raw_buf[-1] = vals
        samples += 1
        if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0:
            feat = extractor.extract_features_window(raw_buf)
            feat_list.append(feat)
            done = len(feat_list)
            if done % 10 == 0:
                pct = int(100 * done / n_windows)
                print(f" {pct}% ({done}/{n_windows} windows)", end="\r", flush=True)
    print(f"\n[Calib] Collected {len(feat_list)} windows.")
    feats = np.array(feat_list, dtype=np.float32)
    mean = feats.mean(axis=0)
    # Hoisted: the original evaluated feats.std(axis=0) twice.
    # Floor tiny stds at 1e-6 to avoid division by zero downstream.
    raw_std = feats.std(axis=0)
    std = np.where(raw_std > 1e-6, raw_std, 1e-6).astype(np.float32)
    print("[Calib] Session normalization computed.")
    return mean, std
def load_ensemble():
    """Load the ensemble sklearn bundle if present; return None otherwise."""
    bundle_path = Path(__file__).parent / 'models' / 'emg_ensemble.joblib'
    if not bundle_path.exists():
        return None
    try:
        import joblib
        bundle = joblib.load(bundle_path)
    except Exception as e:
        print(f"[Model] Ensemble load failed: {e}")
        return None
    print(f"[Model] Loaded ensemble (4 LDAs)")
    return bundle
def load_mlp():
    """Load MLP numpy weights if present; return None otherwise."""
    weights_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
    if not weights_path.exists():
        return None
    try:
        weights = dict(np.load(weights_path, allow_pickle=True))
    except Exception as e:
        print(f"[Model] MLP load failed: {e}")
        return None
    print(f"[Model] Loaded MLP weights (numpy)")
    return weights
def run_ensemble(ens, features):
    """Run ensemble: 3 specialist LDAs → meta-LDA → class probabilities."""
    specialist_probs = [
        ens['lda_td'].predict_proba([features[ens['td_idx']]])[0],
        ens['lda_fd'].predict_proba([features[ens['fd_idx']]])[0],
        ens['lda_cc'].predict_proba([features[ens['cc_idx']]])[0],
    ]
    meta_input = np.concatenate(specialist_probs)
    return ens['meta_lda'].predict_proba([meta_input])[0]
def run_mlp(mlp, features):
    """Run MLP forward pass: Dense(32,relu) → Dense(16,relu) → Dense(5,softmax)."""
    def relu(v):
        return np.maximum(0, v)
    h1 = relu(features.astype(np.float32) @ mlp['w0'] + mlp['b0'])
    h2 = relu(h1 @ mlp['w1'] + mlp['b1'])
    logits = h2 @ mlp['w2'] + mlp['b2']
    # Numerically-stable softmax (subtract the max before exponentiating)
    shifted = np.exp(logits - logits.max())
    return shifted / shifted.sum()
def main():
    """Entry point: load models, connect to the ESP32, optionally calibrate
    on REST data, then classify the live EMG stream and send gestures back.

    Runs until Ctrl+C, then sends a stop command and closes the port.
    """
    args = parse_args()
    # ── Load classifier ──────────────────────────────────────────────────────
    classifier = load_model(args.model)
    extractor = classifier.feature_extractor
    ensemble = load_ensemble()
    mlp = load_mlp()
    # Report which models will vote (base LDA always; ensemble/MLP if present)
    model_names = ["LDA"]
    if ensemble:
        model_names.append("Ensemble")
    if mlp:
        model_names.append("MLP")
    print(f"[Model] Active: {' + '.join(model_names)} ({len(model_names)} models)")
    # ── Open serial ──────────────────────────────────────────────────────────
    try:
        ser = serial.Serial(args.port, BAUD_RATE, timeout=1.0)
    except serial.SerialException as e:
        print(f"[ERROR] Could not open {args.port}: {e}")
        sys.exit(1)
    time.sleep(0.5)
    ser.reset_input_buffer()
    # ── Handshake ────────────────────────────────────────────────────────────
    if not handshake(ser):
        ser.close()
        sys.exit(1)
    # ── Request laptop-predict mode ──────────────────────────────────────────
    ser.write(b'{"cmd":"start_laptop_predict"}\n')
    print("[Control] ESP32 entering STATE_LAPTOP_PREDICT — streaming ADC...")
    # ── Calibration ──────────────────────────────────────────────────────────
    calib_mean = None
    calib_std = None
    if not args.no_calib:
        # One window per hop; never fewer than 20 windows
        n_calib = max(20, int(CALIB_SECS * 1000 / HOP_SIZE_MS))
        calib_mean, calib_std = collect_calibration_windows(ser, n_calib, extractor)
    else:
        print("[Calib] Skipped (--no-calib). Accuracy may be reduced.")
    # ── Live prediction loop ─────────────────────────────────────────────────
    print(f"\n[Predict] Running. Confidence threshold: {args.confidence:.2f}")
    print("[Predict] Press Ctrl+C to stop.\n")
    raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32)
    samples = 0
    last_gesture = None
    n_inferences = 0
    n_rejected = 0
    try:
        while True:
            raw = ser.readline()
            line = raw.decode("utf-8", errors="ignore").strip()
            # Skip JSON telemetry lines from ESP32
            if not line or line.startswith("{"):
                continue
            # Parse CSV sample
            try:
                vals = [float(v) for v in line.split(",")]
            except ValueError:
                continue
            if len(vals) != NUM_CHANNELS:
                continue
            # Slide window
            raw_buf = np.roll(raw_buf, -1, axis=0)
            raw_buf[-1] = vals
            samples += 1
            # Classify every HOP_SIZE samples
            if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0:
                feat = extractor.extract_features_window(raw_buf).astype(np.float32)
                # Apply session normalization
                if calib_mean is not None:
                    feat = (feat - calib_mean) / calib_std
                # Run all available models and average probabilities.
                # Secondary models are best-effort: a failure never stops the loop.
                probas = [classifier.model.predict_proba([feat])[0]]
                if ensemble:
                    try:
                        probas.append(run_ensemble(ensemble, feat))
                    except Exception:
                        pass
                if mlp:
                    try:
                        probas.append(run_mlp(mlp, feat))
                    except Exception:
                        pass
                proba = np.mean(probas, axis=0)
                class_idx = int(np.argmax(proba))
                confidence = float(proba[class_idx])
                gesture = classifier.label_names[class_idx]
                n_inferences += 1
                # Reject below threshold
                if confidence < args.confidence:
                    n_rejected += 1
                    continue
                # Send gesture command to ESP32
                cmd = f'{{"gesture":"{gesture}"}}\n'
                ser.write(cmd.encode("utf-8"))
                # Local logging (only on change); counters reset so the
                # reject rate is measured per gesture segment
                if gesture != last_gesture:
                    reject_rate = 100 * n_rejected / n_inferences if n_inferences else 0
                    print(f"{gesture:<12} conf={confidence:.2f} "
                          f"reject_rate={reject_rate:.0f}%")
                    last_gesture = gesture
                    n_rejected = 0
                    n_inferences = 0
    except KeyboardInterrupt:
        print("\n\n[Stop] Sending stop command to ESP32...")
        ser.write(b'{"cmd":"stop"}\n')
        time.sleep(0.2)
        ser.close()
        print("[Stop] Done.")
if __name__ == "__main__":
    main()

BIN
models/emg_ensemble.joblib Normal file

Binary file not shown.

Binary file not shown.

BIN
models/emg_mlp_weights.npz Normal file

Binary file not shown.

Binary file not shown.

200
train_ensemble.py Normal file
View File

@@ -0,0 +1,200 @@
"""
Train the full 3-specialist-LDA + meta-LDA ensemble.
Requires Change 1 (expanded features) to be implemented first.
Exports model_weights_ensemble.h for firmware Change F.
Architecture:
LDA_TD (36 time-domain feat) ─┐
LDA_FD (24 freq-domain feat) ├─ 15 probs ─► Meta-LDA ─► final class
LDA_CC (9 cross-ch feat) ─┘
Change 7 — priority Tier 3.
"""
import numpy as np
from pathlib import Path
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score
import sys
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import (
SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
)
# --- Load and extract features -----------------------------------------------
# Load every recorded session and turn raw EMG windows into feature vectors.
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float64)
# Per-session class-balanced normalization
# Must match EMGClassifier._apply_session_normalization():
#   mean = average of per-class means (not overall mean), std = overall std.
# StandardScaler uses overall mean, which biases toward the majority class.
for sid in np.unique(session_indices):
    mask = session_indices == sid
    X_sess = X[mask]
    y_sess = y[mask]
    # Class-balanced mean: average of per-class centroids
    class_means = []
    for cls in np.unique(y_sess):
        class_means.append(X_sess[y_sess == cls].mean(axis=0))
    balanced_mean = np.mean(class_means, axis=0)
    # Overall std (same as StandardScaler)
    std = X_sess.std(axis=0)
    std[std < 1e-12] = 1.0  # avoid division by zero
    X[mask] = (X_sess - balanced_mean) / std
feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS))
n_cls = len(np.unique(y))
# --- Feature subset indices ---------------------------------------------------
# Per-channel layout (20 features/channel): indices 0-11 TD, 12-19 FD
# Cross-channel features start at index 60 (3 channels × 20 features each)
TD_FEAT = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', 'ar1', 'ar2', 'ar3', 'ar4']
FD_FEAT = ['mnf', 'mdf', 'pkf', 'mnp', 'bp0', 'bp1', 'bp2', 'bp3']
# Select features by suffix; 'ch' prefix restricts to per-channel features
td_idx = [i for i, n in enumerate(feat_names)
          if any(n.endswith(f'_{f}') for f in TD_FEAT) and n.startswith('ch')]
fd_idx = [i for i, n in enumerate(feat_names)
          if any(n.endswith(f'_{f}') for f in FD_FEAT) and n.startswith('ch')]
cc_idx = [i for i, n in enumerate(feat_names) if n.startswith('cc_')]
print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}")
# Hard guards: the exported C header offsets assume these exact counts
assert len(td_idx) == 36, f"Expected 36 TD features, got {len(td_idx)}"
assert len(fd_idx) == 24, f"Expected 24 FD features, got {len(fd_idx)}"
assert len(cc_idx) == 9, f"Expected 9 CC features, got {len(cc_idx)}"
X_td = X[:, td_idx]
X_fd = X[:, fd_idx]
X_cc = X[:, cc_idx]
# --- Train specialist LDAs with out-of-fold stacking -------------------------
# Grouping folds by trial id prevents windows from one trial leaking
# across the train/validation split.
gkf = GroupKFold(n_splits=min(5, len(np.unique(trial_ids))))
print("Training specialist LDAs (out-of-fold for stacking)...")
lda_td = LinearDiscriminantAnalysis()
lda_fd = LinearDiscriminantAnalysis()
lda_cc = LinearDiscriminantAnalysis()
# Out-of-fold probabilities become the meta-learner's training inputs
oof_td = cross_val_predict(lda_td, X_td, y, cv=gkf, groups=trial_ids, method='predict_proba')
oof_fd = cross_val_predict(lda_fd, X_fd, y, cv=gkf, groups=trial_ids, method='predict_proba')
oof_cc = cross_val_predict(lda_cc, X_cc, y, cv=gkf, groups=trial_ids, method='predict_proba')
# Specialist CV accuracy (for diagnostics)
for name, mdl, Xs in [('LDA_TD', lda_td, X_td),
                      ('LDA_FD', lda_fd, X_fd),
                      ('LDA_CC', lda_cc, X_cc)]:
    sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids)
    print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%")
# --- Train meta-LDA on out-of-fold outputs ------------------------------------
X_meta = np.hstack([oof_td, oof_fd, oof_cc])  # (n_samples, 3*n_cls = 15)
meta_lda = LinearDiscriminantAnalysis()
meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids)
print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%")
# Fit all models on full dataset for deployment
lda_td.fit(X_td, y)
lda_fd.fit(X_fd, y)
lda_cc.fit(X_cc, y)
meta_lda.fit(X_meta, y)
# --- Export all weights to C header ------------------------------------------
def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order):
    """Generate C array strings for LDA weights and intercepts.

    Produces `const float {name}_WEIGHTS[n_cls][feat_dim]` and
    `const float {name}_INTERCEPTS[n_cls]` initializer text, one weight
    row per class in class_order, each row tagged with its label name.

    NOTE: sklearn LDA.coef_ for multi-class has shape (n_classes-1, n_features)
    when using the SVD solver. In that case we warn and zero-pad the missing
    rows; the model should be refit upstream with solver='lsqr' to obtain the
    full per-class matrix. (Dead code that constructed an unused throwaway
    LDA instance here has been removed — no behavior change.)
    """
    coef = lda.coef_
    intercept = lda.intercept_
    if coef.shape[0] != n_cls:
        print(f" WARNING: {name} coef_ shape {coef.shape} != ({n_cls}, {feat_dim}). "
              f"Padding with zeros. Refit with solver='lsqr' for full matrix.")
        padded = np.zeros((n_cls, feat_dim))
        padded[:coef.shape[0]] = coef
        coef = padded
        padded_i = np.zeros(n_cls)
        padded_i[:intercept.shape[0]] = intercept
        intercept = padded_i
    lines = []
    lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{")
    for c in class_order:
        row = ', '.join(f'{v:.8f}f' for v in coef[c])
        lines.append(f" {{{row}}}, // {label_names[c]}")
    lines.append("};")
    lines.append(f"const float {name}_INTERCEPTS[{n_cls}] = {{")
    intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order)
    lines.append(f" {intercept_str}")
    lines.append("};")
    return '\n'.join(lines)
# --- Export all weights to a firmware-ready C header --------------------------
class_order = list(range(n_cls))
out_path = Path(__file__).parent / 'EMG_Arm/src/core/model_weights_ensemble.h'
out_path.parent.mkdir(parents=True, exist_ok=True)
# Offsets assume each subset is contiguous from its first index; the raw
# index lists are also emitted as comments for gather-based firmware code.
td_offset = min(td_idx) if td_idx else 0
fd_offset = min(fd_idx) if fd_idx else 0
cc_offset = min(cc_idx) if cc_idx else 0
with open(out_path, 'w') as f:
    f.write("// Auto-generated by train_ensemble.py — do not edit\n")
    f.write("#pragma once\n\n")
    f.write("// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from\n")
    f.write("// model_weights.h to avoid redefinition conflicts.\n")
    f.write('#include "model_weights.h"\n\n')
    f.write(f"#define ENSEMBLE_PER_CH_FEATURES 20\n\n")
    f.write(f"#define TD_FEAT_OFFSET {td_offset}\n")
    f.write(f"#define TD_NUM_FEATURES {len(td_idx)}\n")
    f.write(f"#define FD_FEAT_OFFSET {fd_offset}\n")
    f.write(f"#define FD_NUM_FEATURES {len(fd_idx)}\n")
    f.write(f"#define CC_FEAT_OFFSET {cc_offset}\n")
    f.write(f"#define CC_NUM_FEATURES {len(cc_idx)}\n")
    f.write(f"#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)\n\n")
    f.write("// Feature index arrays for gather operations (TD and FD are non-contiguous)\n")
    f.write(f"// TD indices: {td_idx}\n")
    f.write(f"// FD indices: {fd_idx}\n")
    f.write(f"// CC indices: {cc_idx}\n\n")
    f.write(lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    f.write(lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    f.write(lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    # Meta-LDA consumes the 3 specialists' concatenated probabilities (3*n_cls inputs)
    f.write(lda_to_c_arrays(meta_lda, 'META_LDA', 3 * n_cls, n_cls, label_names, class_order))
    f.write('\n')
print(f"Exported ensemble weights to {out_path}")
print(f"Total weight storage: "
      f"{(len(td_idx) + len(fd_idx) + len(cc_idx) + 3*n_cls) * n_cls * 4} bytes float32")
# --- Also save sklearn models for laptop-side inference ----------------------
# Bundle keys are consumed by live_predict.py's load_ensemble()/run_ensemble()
import joblib
ensemble_bundle = {
    'lda_td': lda_td,
    'lda_fd': lda_fd,
    'lda_cc': lda_cc,
    'meta_lda': meta_lda,
    'td_idx': td_idx,
    'fd_idx': fd_idx,
    'cc_idx': cc_idx,
    'label_names': label_names,
}
ensemble_joblib = Path(__file__).parent / 'models' / 'emg_ensemble.joblib'
ensemble_joblib.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(ensemble_bundle, ensemble_joblib)
print(f"Saved laptop ensemble model to {ensemble_joblib}")

106
train_mlp_tflite.py Normal file
View File

@@ -0,0 +1,106 @@
"""
Train int8 MLP for ESP32-S3 deployment via TFLite Micro.
Run AFTER Change 0 (label shift) + Change 1 (expanded features).
Change E — priority Tier 3.
Outputs: EMG_Arm/src/core/emg_model_data.cc
"""
import numpy as np
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
try:
import tensorflow as tf
except ImportError:
print("ERROR: TensorFlow not installed. Run: pip install tensorflow")
sys.exit(1)
# --- Load and extract features -----------------------------------------------
# Load recorded sessions and extract the expanded feature set.
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float32)
# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py)
for sid in np.unique(session_indices):
    mask = session_indices == sid
    X_sess = X[mask]
    y_sess = y[mask]
    # Mean = average of per-class centroids; avoids majority-class bias
    class_means = [X_sess[y_sess == cls].mean(axis=0) for cls in np.unique(y_sess)]
    balanced_mean = np.mean(class_means, axis=0)
    std = X_sess.std(axis=0)
    std[std < 1e-12] = 1.0  # avoid division by zero
    X[mask] = (X_sess - balanced_mean) / std
n_feat = X.shape[1]
n_cls = len(np.unique(y))
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
print(f"Classes: {label_names}")
# --- Build and train MLP -----------------------------------------------------
# Architecture mirrors live_predict.py's run_mlp(): Dense(32) → Dense(16) → softmax
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(n_feat,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(n_cls, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
model.fit(X, y, epochs=150, batch_size=64, validation_split=0.1, verbose=1)
# --- Convert to int8 TFLite --------------------------------------------------
def representative_dataset():
    """Yield every 10th normalized feature vector (shape (1, n_feat)) for calibration."""
    for sample in X[::10]:
        yield [sample[np.newaxis, :]]
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Force pure-int8 ops and int8 I/O so no float kernels are needed on-device.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()
# --- Write emg_model_data.cc -------------------------------------------------
# Emit the flatbuffer as a C byte array. Bytes are wrapped 12 per line: the
# previous single ', '.join(...) produced one multi-hundred-kB source line,
# which chokes editors/diff tools and can exceed compiler logical-line limits.
out_cc = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.cc'
out_h = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.h'
hex_bytes = [f'0x{b:02x}' for b in tflite_model]
wrapped = [', '.join(hex_bytes[i:i + 12]) for i in range(0, len(hex_bytes), 12)]
with open(out_cc, 'w') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#include "emg_model_data.h"\n')
    f.write(f'const int g_model_len = {len(tflite_model)};\n')
    # alignas(8): TFLite Micro requires the model buffer to be 8-byte aligned.
    f.write('alignas(8) const unsigned char g_model[] = {\n  ')
    f.write(',\n  '.join(wrapped))
    f.write('\n};\n')
with open(out_h, 'w') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#pragma once\n\n')
    f.write('#ifdef __cplusplus\nextern "C" {\n#endif\n\n')
    f.write('extern const unsigned char g_model[];\n')
    f.write('extern const int g_model_len;\n\n')
    f.write('#ifdef __cplusplus\n}\n#endif\n')
print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
print(f"Wrote {out_h}")
# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
# Collect (kernel, bias) pairs from the weighted layers only (Dropout has none),
# preserving layer order: Dense(32) -> Dense(16) -> Dense(n_cls).
layer_weights = []
for layer in model.layers:
    weights = layer.get_weights()
    if weights:
        layer_weights.append(weights)
mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
mlp_path.parent.mkdir(parents=True, exist_ok=True)
(w0, b0), (w1, b1), (w2, b2) = layer_weights
np.savez(mlp_path,
         w0=w0, b0=b0,  # hidden Dense(32, relu)
         w1=w1, b1=b1,  # hidden Dense(16, relu)
         w2=w2, b2=b2,  # output Dense(n_cls, softmax)
         label_names=np.array(label_names))
print(f"Saved laptop MLP weights to {mlp_path}")
print(f"\nNext steps:")
print(f"  1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
print(f"  2. Run: pio run -t upload")