Multi-model prediction on laptop | need to move to on-board prediction
This commit is contained in:
2736
BUCKY_ARM_MASTER_PLAN.md
Normal file
2736
BUCKY_ARM_MASTER_PLAN.md
Normal file
File diff suppressed because it is too large
Load Diff
10
EMG_Arm/dependencies.lock
Normal file
10
EMG_Arm/dependencies.lock
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
dependencies:
|
||||||
|
idf:
|
||||||
|
source:
|
||||||
|
type: idf
|
||||||
|
version: 5.5.1
|
||||||
|
direct_dependencies:
|
||||||
|
- idf
|
||||||
|
manifest_hash: 26c322f28d1cb305b28c1bcb6df69caf3919f0c18286c9ac6394e338c217fba8
|
||||||
|
target: esp32s3
|
||||||
|
version: 2.0.0
|
||||||
10
EMG_Arm/idf_component.yml
Normal file
10
EMG_Arm/idf_component.yml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# ESP-IDF Component Manager dependencies
|
||||||
|
# These are ESP-IDF components (NOT PlatformIO libraries).
|
||||||
|
# Run `pio run --target menuconfig` after modifying to refresh the component index.
|
||||||
|
|
||||||
|
dependencies:
|
||||||
|
# esp-dsp: required when MODEL_EXPAND_FEATURES=1 (Change 1 — 69-feature FFT)
|
||||||
|
espressif/esp-dsp: ">=2.0.0"
|
||||||
|
|
||||||
|
# TFLite Micro: required when MODEL_USE_MLP=1 (Change E — int8 MLP)
|
||||||
|
# tensorflow/tflite-micro: ">=2.0.0"
|
||||||
@@ -14,4 +14,7 @@ board_build.partitions = partitions.csv
|
|||||||
|
|
||||||
monitor_speed = 921600
|
monitor_speed = 921600
|
||||||
monitor_dtr = 1
|
monitor_dtr = 1
|
||||||
monitor_rts = 1
|
monitor_rts = 1
|
||||||
|
|
||||||
|
; ── esp-dsp: required for MODEL_EXPAND_FEATURES=1 (FFT-based features) ───────
|
||||||
|
; Cloned locally: components/esp-dsp
|
||||||
@@ -594,6 +594,14 @@ CONFIG_PARTITION_TABLE_OFFSET=0x8000
|
|||||||
CONFIG_PARTITION_TABLE_MD5=y
|
CONFIG_PARTITION_TABLE_MD5=y
|
||||||
# end of Partition Table
|
# end of Partition Table
|
||||||
|
|
||||||
|
#
|
||||||
|
# ESP-NN
|
||||||
|
#
|
||||||
|
# CONFIG_NN_ANSI_C is not set
|
||||||
|
CONFIG_NN_OPTIMIZED=y
|
||||||
|
CONFIG_NN_OPTIMIZATIONS=1
|
||||||
|
# end of ESP-NN
|
||||||
|
|
||||||
#
|
#
|
||||||
# Compiler options
|
# Compiler options
|
||||||
#
|
#
|
||||||
@@ -2237,6 +2245,23 @@ CONFIG_WIFI_PROV_AUTOSTOP_TIMEOUT=30
|
|||||||
CONFIG_WIFI_PROV_STA_ALL_CHANNEL_SCAN=y
|
CONFIG_WIFI_PROV_STA_ALL_CHANNEL_SCAN=y
|
||||||
# CONFIG_WIFI_PROV_STA_FAST_SCAN is not set
|
# CONFIG_WIFI_PROV_STA_FAST_SCAN is not set
|
||||||
# end of Wi-Fi Provisioning Manager
|
# end of Wi-Fi Provisioning Manager
|
||||||
|
|
||||||
|
#
|
||||||
|
# DSP Library
|
||||||
|
#
|
||||||
|
CONFIG_DSP_OPTIMIZATIONS_SUPPORTED=y
|
||||||
|
# CONFIG_DSP_ANSI is not set
|
||||||
|
CONFIG_DSP_OPTIMIZED=y
|
||||||
|
CONFIG_DSP_OPTIMIZATION=1
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_512 is not set
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_1024 is not set
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_2048 is not set
|
||||||
|
CONFIG_DSP_MAX_FFT_SIZE_4096=y
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_8192 is not set
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_16384 is not set
|
||||||
|
# CONFIG_DSP_MAX_FFT_SIZE_32768 is not set
|
||||||
|
CONFIG_DSP_MAX_FFT_SIZE=4096
|
||||||
|
# end of DSP Library
|
||||||
# end of Component config
|
# end of Component config
|
||||||
|
|
||||||
# CONFIG_IDF_EXPERIMENTAL_FEATURES is not set
|
# CONFIG_IDF_EXPERIMENTAL_FEATURES is not set
|
||||||
|
|||||||
@@ -19,8 +19,13 @@ set(DRIVER_SOURCES
|
|||||||
)
|
)
|
||||||
|
|
||||||
set(CORE_SOURCES
|
set(CORE_SOURCES
|
||||||
|
core/bicep.c
|
||||||
|
core/calibration.c
|
||||||
core/gestures.c
|
core/gestures.c
|
||||||
core/inference.c
|
core/inference.c
|
||||||
|
core/inference_ensemble.c
|
||||||
|
core/inference_mlp.cc
|
||||||
|
core/emg_model_data.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
set(APP_SOURCES
|
set(APP_SOURCES
|
||||||
@@ -36,5 +41,5 @@ idf_component_register(
|
|||||||
${APP_SOURCES}
|
${APP_SOURCES}
|
||||||
INCLUDE_DIRS
|
INCLUDE_DIRS
|
||||||
.
|
.
|
||||||
REQUIRES esp_adc
|
REQUIRES esp_adc nvs_flash esp-dsp esp-tflite-micro
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,8 +20,13 @@
|
|||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include "config/config.h"
|
#include "config/config.h"
|
||||||
|
#include "core/bicep.h"
|
||||||
|
#include "core/calibration.h"
|
||||||
#include "core/gestures.h"
|
#include "core/gestures.h"
|
||||||
#include "core/inference.h" // [NEW]
|
#include "core/inference.h"
|
||||||
|
#include "core/inference_ensemble.h"
|
||||||
|
#include "core/inference_mlp.h"
|
||||||
|
#include "core/model_weights.h"
|
||||||
#include "drivers/emg_sensor.h"
|
#include "drivers/emg_sensor.h"
|
||||||
#include "drivers/hand.h"
|
#include "drivers/hand.h"
|
||||||
|
|
||||||
@@ -40,10 +45,12 @@
|
|||||||
* @brief Device state machine.
|
* @brief Device state machine.
|
||||||
*/
|
*/
|
||||||
typedef enum {
|
typedef enum {
|
||||||
STATE_IDLE = 0, /**< Waiting for connect command */
|
STATE_IDLE = 0, /**< Waiting for connect command */
|
||||||
STATE_CONNECTED, /**< Connected, waiting for start command */
|
STATE_CONNECTED, /**< Connected, waiting for start command */
|
||||||
STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */
|
STATE_STREAMING, /**< Streaming raw EMG CSV to laptop (data collection) */
|
||||||
STATE_PREDICTING, /**< [NEW] On-device inference and control */
|
STATE_PREDICTING, /**< On-device inference + arm control */
|
||||||
|
STATE_LAPTOP_PREDICT, /**< Streaming CSV to laptop; laptop infers + sends gesture cmds back */
|
||||||
|
STATE_CALIBRATING, /**< Collecting rest data for calibration */
|
||||||
} device_state_t;
|
} device_state_t;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -52,8 +59,10 @@ typedef enum {
|
|||||||
typedef enum {
|
typedef enum {
|
||||||
CMD_NONE = 0,
|
CMD_NONE = 0,
|
||||||
CMD_CONNECT,
|
CMD_CONNECT,
|
||||||
CMD_START, /**< Start raw streaming */
|
CMD_START, /**< Start raw ADC streaming to laptop */
|
||||||
CMD_START_PREDICT, /**< [NEW] Start on-device prediction */
|
CMD_START_PREDICT, /**< Start on-device inference + arm control */
|
||||||
|
CMD_START_LAPTOP_PREDICT, /**< Start laptop-mediated inference (stream + receive cmds) */
|
||||||
|
CMD_CALIBRATE, /**< Run rest calibration (z-score + bicep threshold) */
|
||||||
CMD_STOP,
|
CMD_STOP,
|
||||||
CMD_DISCONNECT,
|
CMD_DISCONNECT,
|
||||||
} command_t;
|
} command_t;
|
||||||
@@ -65,11 +74,17 @@ typedef enum {
|
|||||||
static volatile device_state_t g_device_state = STATE_IDLE;
|
static volatile device_state_t g_device_state = STATE_IDLE;
|
||||||
static QueueHandle_t g_cmd_queue = NULL;
|
static QueueHandle_t g_cmd_queue = NULL;
|
||||||
|
|
||||||
|
// Latest gesture command received from laptop during STATE_LAPTOP_PREDICT.
|
||||||
|
// Written by serial_input_task; read+cleared by run_laptop_predict_loop.
|
||||||
|
// gesture_t is a 32-bit int on LX7 — reads/writes are atomic; volatile is sufficient.
|
||||||
|
static volatile gesture_t g_laptop_gesture = GESTURE_NONE;
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Forward Declarations
|
* Forward Declarations
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
static void send_ack_connect(void);
|
static void send_ack_connect(void);
|
||||||
|
static gesture_t parse_laptop_gesture(const char *line);
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Command Parsing
|
* Command Parsing
|
||||||
@@ -102,13 +117,17 @@ static command_t parse_command(const char *line) {
|
|||||||
value_start++;
|
value_start++;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Match command strings */
|
/* Match command strings — ordered longest-prefix-first to avoid false matches */
|
||||||
if (strncmp(value_start, "connect", 7) == 0) {
|
if (strncmp(value_start, "connect", 7) == 0) {
|
||||||
return CMD_CONNECT;
|
return CMD_CONNECT;
|
||||||
|
} else if (strncmp(value_start, "start_laptop_predict", 20) == 0) {
|
||||||
|
return CMD_START_LAPTOP_PREDICT;
|
||||||
} else if (strncmp(value_start, "start_predict", 13) == 0) {
|
} else if (strncmp(value_start, "start_predict", 13) == 0) {
|
||||||
return CMD_START_PREDICT;
|
return CMD_START_PREDICT;
|
||||||
} else if (strncmp(value_start, "start", 5) == 0) {
|
} else if (strncmp(value_start, "start", 5) == 0) {
|
||||||
return CMD_START;
|
return CMD_START;
|
||||||
|
} else if (strncmp(value_start, "calibrate", 9) == 0) {
|
||||||
|
return CMD_CALIBRATE;
|
||||||
} else if (strncmp(value_start, "stop", 4) == 0) {
|
} else if (strncmp(value_start, "stop", 4) == 0) {
|
||||||
return CMD_STOP;
|
return CMD_STOP;
|
||||||
} else if (strncmp(value_start, "disconnect", 10) == 0) {
|
} else if (strncmp(value_start, "disconnect", 10) == 0) {
|
||||||
@@ -118,6 +137,38 @@ static command_t parse_command(const char *line) {
|
|||||||
return CMD_NONE;
|
return CMD_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Laptop Gesture Parser
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Parse a gesture command sent by live_predict.py.
|
||||||
|
*
|
||||||
|
* Expected format: {"gesture":"fist"}
|
||||||
|
* Returns GESTURE_NONE if the line is not a valid gesture command.
|
||||||
|
*/
|
||||||
|
static gesture_t parse_laptop_gesture(const char *line) {
|
||||||
|
const char *g = strstr(line, "\"gesture\"");
|
||||||
|
if (!g) return GESTURE_NONE;
|
||||||
|
|
||||||
|
const char *v = strchr(g, ':');
|
||||||
|
if (!v) return GESTURE_NONE;
|
||||||
|
|
||||||
|
v++;
|
||||||
|
while (*v == ' ' || *v == '"') v++;
|
||||||
|
|
||||||
|
// Extract gesture name up to the closing quote
|
||||||
|
char name[32] = {0};
|
||||||
|
int ni = 0;
|
||||||
|
while (*v && *v != '"' && ni < (int)(sizeof(name) - 1)) {
|
||||||
|
name[ni++] = *v++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delegate to the inference module's name→enum mapping
|
||||||
|
int result = inference_get_gesture_by_name(name);
|
||||||
|
return (gesture_t)result;
|
||||||
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Serial Input Task
|
* Serial Input Task
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
@@ -142,6 +193,15 @@ static void serial_input_task(void *pvParameters) {
|
|||||||
line_buffer[line_idx] = '\0';
|
line_buffer[line_idx] = '\0';
|
||||||
command_t cmd = parse_command(line_buffer);
|
command_t cmd = parse_command(line_buffer);
|
||||||
|
|
||||||
|
// When laptop-predict is active, try to parse gesture commands first.
|
||||||
|
// These are separate from FSM commands and must not block the FSM parser.
|
||||||
|
if (g_device_state == STATE_LAPTOP_PREDICT) {
|
||||||
|
gesture_t g = parse_laptop_gesture(line_buffer);
|
||||||
|
if (g != GESTURE_NONE) {
|
||||||
|
g_laptop_gesture = g;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (cmd != CMD_NONE) {
|
if (cmd != CMD_NONE) {
|
||||||
if (cmd == CMD_CONNECT) {
|
if (cmd == CMD_CONNECT) {
|
||||||
g_device_state = STATE_CONNECTED;
|
g_device_state = STATE_CONNECTED;
|
||||||
@@ -161,6 +221,15 @@ static void serial_input_task(void *pvParameters) {
|
|||||||
g_device_state = STATE_PREDICTING;
|
g_device_state = STATE_PREDICTING;
|
||||||
printf("[STATE] CONNECTED -> PREDICTING\n");
|
printf("[STATE] CONNECTED -> PREDICTING\n");
|
||||||
xQueueSend(g_cmd_queue, &cmd, 0);
|
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||||
|
} else if (cmd == CMD_START_LAPTOP_PREDICT) {
|
||||||
|
g_device_state = STATE_LAPTOP_PREDICT;
|
||||||
|
g_laptop_gesture = GESTURE_NONE;
|
||||||
|
printf("[STATE] CONNECTED -> LAPTOP_PREDICT\n");
|
||||||
|
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||||
|
} else if (cmd == CMD_CALIBRATE) {
|
||||||
|
g_device_state = STATE_CALIBRATING;
|
||||||
|
printf("[STATE] CONNECTED -> CALIBRATING\n");
|
||||||
|
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||||
} else if (cmd == CMD_DISCONNECT) {
|
} else if (cmd == CMD_DISCONNECT) {
|
||||||
g_device_state = STATE_IDLE;
|
g_device_state = STATE_IDLE;
|
||||||
printf("[STATE] CONNECTED -> IDLE\n");
|
printf("[STATE] CONNECTED -> IDLE\n");
|
||||||
@@ -169,6 +238,8 @@ static void serial_input_task(void *pvParameters) {
|
|||||||
|
|
||||||
case STATE_STREAMING:
|
case STATE_STREAMING:
|
||||||
case STATE_PREDICTING:
|
case STATE_PREDICTING:
|
||||||
|
case STATE_LAPTOP_PREDICT:
|
||||||
|
case STATE_CALIBRATING:
|
||||||
if (cmd == CMD_STOP) {
|
if (cmd == CMD_STOP) {
|
||||||
g_device_state = STATE_CONNECTED;
|
g_device_state = STATE_CONNECTED;
|
||||||
printf("[STATE] ACTIVE -> CONNECTED\n");
|
printf("[STATE] ACTIVE -> CONNECTED\n");
|
||||||
@@ -206,59 +277,198 @@ static void send_ack_connect(void) {
|
|||||||
*/
|
*/
|
||||||
static void stream_emg_data(void) {
|
static void stream_emg_data(void) {
|
||||||
emg_sample_t sample;
|
emg_sample_t sample;
|
||||||
const TickType_t delay_ticks = 1;
|
|
||||||
|
|
||||||
while (g_device_state == STATE_STREAMING) {
|
while (g_device_state == STATE_STREAMING) {
|
||||||
emg_sensor_read(&sample);
|
emg_sensor_read(&sample); // blocks ~1 ms (queue-paced by DMA task)
|
||||||
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
|
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
|
||||||
sample.channels[2], sample.channels[3]);
|
sample.channels[2], sample.channels[3]);
|
||||||
vTaskDelay(delay_ticks);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Multi-Model Voting Post-Processing
|
||||||
|
*
|
||||||
|
* Shared EMA smoothing, majority vote, and debounce applied to the averaged
|
||||||
|
* probability output from all enabled models (single LDA, ensemble, MLP).
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
#define VOTE_EMA_ALPHA 0.70f
|
||||||
|
#define VOTE_CONF_THRESHOLD 0.40f
|
||||||
|
#define VOTE_WINDOW_SIZE 5
|
||||||
|
#define VOTE_DEBOUNCE_COUNT 3
|
||||||
|
|
||||||
|
static float vote_smoothed[MODEL_NUM_CLASSES];
|
||||||
|
static int vote_history[VOTE_WINDOW_SIZE];
|
||||||
|
static int vote_head = 0;
|
||||||
|
static int vote_current_output = -1;
|
||||||
|
static int vote_pending_output = -1;
|
||||||
|
static int vote_pending_count = 0;
|
||||||
|
|
||||||
|
static void vote_init(void) {
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
vote_smoothed[c] = 1.0f / MODEL_NUM_CLASSES;
|
||||||
|
for (int i = 0; i < VOTE_WINDOW_SIZE; i++)
|
||||||
|
vote_history[i] = -1;
|
||||||
|
vote_head = 0;
|
||||||
|
vote_current_output = -1;
|
||||||
|
vote_pending_output = -1;
|
||||||
|
vote_pending_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Apply EMA + majority vote + debounce to averaged probabilities.
|
||||||
|
*
|
||||||
|
* @param avg_proba Averaged probability vector from all models.
|
||||||
|
* @param confidence Output: smoothed confidence of the final winner.
|
||||||
|
* @return Final gesture class index (-1 if uncertain).
|
||||||
|
*/
|
||||||
|
static int vote_postprocess(const float *avg_proba, float *confidence) {
|
||||||
|
/* EMA smoothing */
|
||||||
|
float max_smooth = 0.0f;
|
||||||
|
int winner = 0;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
vote_smoothed[c] = VOTE_EMA_ALPHA * vote_smoothed[c] +
|
||||||
|
(1.0f - VOTE_EMA_ALPHA) * avg_proba[c];
|
||||||
|
if (vote_smoothed[c] > max_smooth) {
|
||||||
|
max_smooth = vote_smoothed[c];
|
||||||
|
winner = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Confidence rejection */
|
||||||
|
if (max_smooth < VOTE_CONF_THRESHOLD) {
|
||||||
|
*confidence = max_smooth;
|
||||||
|
return vote_current_output;
|
||||||
|
}
|
||||||
|
|
||||||
|
*confidence = max_smooth;
|
||||||
|
|
||||||
|
/* Majority vote */
|
||||||
|
vote_history[vote_head] = winner;
|
||||||
|
vote_head = (vote_head + 1) % VOTE_WINDOW_SIZE;
|
||||||
|
|
||||||
|
int counts[MODEL_NUM_CLASSES];
|
||||||
|
memset(counts, 0, sizeof(counts));
|
||||||
|
for (int i = 0; i < VOTE_WINDOW_SIZE; i++) {
|
||||||
|
if (vote_history[i] >= 0)
|
||||||
|
counts[vote_history[i]]++;
|
||||||
|
}
|
||||||
|
int majority = 0, majority_cnt = 0;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
if (counts[c] > majority_cnt) {
|
||||||
|
majority_cnt = counts[c];
|
||||||
|
majority = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Debounce */
|
||||||
|
int final_out = (vote_current_output >= 0) ? vote_current_output : majority;
|
||||||
|
|
||||||
|
if (vote_current_output < 0) {
|
||||||
|
vote_current_output = majority;
|
||||||
|
final_out = majority;
|
||||||
|
} else if (majority == vote_current_output) {
|
||||||
|
vote_pending_output = majority;
|
||||||
|
vote_pending_count = 1;
|
||||||
|
} else if (majority == vote_pending_output) {
|
||||||
|
if (++vote_pending_count >= VOTE_DEBOUNCE_COUNT) {
|
||||||
|
vote_current_output = majority;
|
||||||
|
final_out = majority;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
vote_pending_output = majority;
|
||||||
|
vote_pending_count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return final_out;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Run on-device inference (Prediction Mode).
|
* @brief Run on-device inference (Prediction Mode).
|
||||||
|
*
|
||||||
|
* Uses multi-model voting: all enabled models (single LDA, ensemble, MLP)
|
||||||
|
* produce raw probability vectors which are averaged, then passed through
|
||||||
|
* shared EMA smoothing, majority vote, and debounce.
|
||||||
*/
|
*/
|
||||||
static void run_inference_loop(void) {
|
static void run_inference_loop(void) {
|
||||||
emg_sample_t sample;
|
emg_sample_t sample;
|
||||||
const TickType_t delay_ticks = 1; // 1ms @ 1kHz
|
|
||||||
int last_gesture = -1;
|
int last_gesture = -1;
|
||||||
|
int stride_counter = 0;
|
||||||
|
|
||||||
// Reset inference state
|
|
||||||
inference_init();
|
inference_init();
|
||||||
printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n");
|
vote_init();
|
||||||
|
|
||||||
|
int n_models = 1; /* single LDA always enabled */
|
||||||
|
#if MODEL_USE_ENSEMBLE
|
||||||
|
inference_ensemble_init();
|
||||||
|
n_models++;
|
||||||
|
#endif
|
||||||
|
#if MODEL_USE_MLP
|
||||||
|
inference_mlp_init();
|
||||||
|
n_models++;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
printf("{\"status\":\"info\",\"msg\":\"Multi-model inference (%d models)\"}\n",
|
||||||
|
n_models);
|
||||||
|
|
||||||
while (g_device_state == STATE_PREDICTING) {
|
while (g_device_state == STATE_PREDICTING) {
|
||||||
emg_sensor_read(&sample);
|
emg_sensor_read(&sample);
|
||||||
|
|
||||||
// Add to buffer
|
|
||||||
// Note: sample.channels is uint16_t, matching inference engine expectation
|
|
||||||
if (inference_add_sample(sample.channels)) {
|
if (inference_add_sample(sample.channels)) {
|
||||||
// Buffer full (sliding window), run prediction
|
|
||||||
// We can optimize stride here (e.g. valid prediction only every N
|
|
||||||
// samples) For now, let's predict every sample (sliding window) or
|
|
||||||
// throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz?
|
|
||||||
// maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid
|
|
||||||
// UART spam.
|
|
||||||
|
|
||||||
static int stride_counter = 0;
|
|
||||||
stride_counter++;
|
stride_counter++;
|
||||||
|
|
||||||
if (stride_counter >= 20) { // 20ms stride
|
if (stride_counter >= INFERENCE_HOP_SIZE) {
|
||||||
float confidence = 0;
|
|
||||||
int gesture_idx = inference_predict(&confidence);
|
|
||||||
stride_counter = 0;
|
stride_counter = 0;
|
||||||
|
|
||||||
if (gesture_idx >= 0) {
|
/* 1. Extract features once */
|
||||||
// Map class index (0-N) to gesture enum (correct hardware action)
|
float features[MODEL_NUM_FEATURES];
|
||||||
int gesture_enum = inference_get_gesture_enum(gesture_idx);
|
inference_extract_features(features);
|
||||||
|
|
||||||
// Execute gesture on hand
|
/* 2. Collect raw probabilities from each model */
|
||||||
|
float avg_proba[MODEL_NUM_CLASSES];
|
||||||
|
memset(avg_proba, 0, sizeof(avg_proba));
|
||||||
|
|
||||||
|
/* Model A: Single LDA (always active) */
|
||||||
|
float proba_lda[MODEL_NUM_CLASSES];
|
||||||
|
inference_predict_raw(features, proba_lda);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += proba_lda[c];
|
||||||
|
|
||||||
|
#if MODEL_USE_ENSEMBLE
|
||||||
|
/* Model B: 3-specialist + meta-LDA ensemble */
|
||||||
|
float proba_ens[MODEL_NUM_CLASSES];
|
||||||
|
inference_ensemble_predict_raw(features, proba_ens);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += proba_ens[c];
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if MODEL_USE_MLP
|
||||||
|
/* Model C: int8 MLP — returns class index + confidence.
|
||||||
|
* Spread confidence as soft one-hot for averaging. */
|
||||||
|
float mlp_conf = 0.0f;
|
||||||
|
int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES,
|
||||||
|
&mlp_conf);
|
||||||
|
float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += (c == mlp_class) ? mlp_conf : remainder;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/* 3. Average across models */
|
||||||
|
float inv_n = 1.0f / n_models;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] *= inv_n;
|
||||||
|
|
||||||
|
/* 4. Shared post-processing */
|
||||||
|
float confidence = 0.0f;
|
||||||
|
int gesture_idx = vote_postprocess(avg_proba, &confidence);
|
||||||
|
|
||||||
|
if (gesture_idx >= 0) {
|
||||||
|
int gesture_enum = inference_get_gesture_enum(gesture_idx);
|
||||||
gestures_execute((gesture_t)gesture_enum);
|
gestures_execute((gesture_t)gesture_enum);
|
||||||
|
|
||||||
// Send telemetry if changed or periodically?
|
bicep_state_t bicep = bicep_detect();
|
||||||
// "Live prediction flow should change to only have each new output...
|
(void)bicep;
|
||||||
// sent"
|
|
||||||
if (gesture_idx != last_gesture) {
|
if (gesture_idx != last_gesture) {
|
||||||
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
|
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
|
||||||
inference_get_class_name(gesture_idx), confidence);
|
inference_get_class_name(gesture_idx), confidence);
|
||||||
@@ -267,11 +477,192 @@ static void run_inference_loop(void) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
vTaskDelay(delay_ticks);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Laptop-mediated prediction loop (STATE_LAPTOP_PREDICT).
|
||||||
|
*
|
||||||
|
* Streams raw ADC CSV to laptop at 1kHz (same format as STATE_STREAMING).
|
||||||
|
* The laptop runs live_predict.py, which classifies each window and sends
|
||||||
|
* {"gesture":"<name>"} back over UART. serial_input_task() intercepts those
|
||||||
|
* lines and writes to g_laptop_gesture; this loop executes whatever is set.
|
||||||
|
*/
|
||||||
|
static void run_laptop_predict_loop(void) {
|
||||||
|
emg_sample_t sample;
|
||||||
|
gesture_t last_gesture = GESTURE_NONE;
|
||||||
|
|
||||||
|
printf("{\"status\":\"info\",\"msg\":\"Laptop-predict mode started\"}\n");
|
||||||
|
|
||||||
|
while (g_device_state == STATE_LAPTOP_PREDICT) {
|
||||||
|
emg_sensor_read(&sample);
|
||||||
|
// Stream raw CSV to laptop (live_predict.py reads this)
|
||||||
|
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
|
||||||
|
sample.channels[2], sample.channels[3]);
|
||||||
|
|
||||||
|
// Execute any gesture command that arrived from the laptop
|
||||||
|
gesture_t g = g_laptop_gesture;
|
||||||
|
if (g != GESTURE_NONE) {
|
||||||
|
g_laptop_gesture = GESTURE_NONE; // clear before executing
|
||||||
|
gestures_execute(g);
|
||||||
|
if (g != last_gesture) {
|
||||||
|
// Echo the executed gesture back for laptop-side logging
|
||||||
|
printf("{\"executed\":\"%s\"}\n", gestures_get_name(g));
|
||||||
|
last_gesture = g;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Fully autonomous inference loop (MAIN_MODE == EMG_STANDALONE).
|
||||||
|
*
|
||||||
|
* No laptop required. Runs forever until power-off.
|
||||||
|
* Assumes inference_init() and sensor init have already been called by app_main().
|
||||||
|
*/
|
||||||
|
static void run_standalone_loop(void) {
|
||||||
|
emg_sample_t sample;
|
||||||
|
int stride_counter = 0;
|
||||||
|
int last_gesture = -1;
|
||||||
|
|
||||||
|
inference_init();
|
||||||
|
vote_init();
|
||||||
|
|
||||||
|
int n_models = 1;
|
||||||
|
#if MODEL_USE_ENSEMBLE
|
||||||
|
inference_ensemble_init();
|
||||||
|
n_models++;
|
||||||
|
#endif
|
||||||
|
#if MODEL_USE_MLP
|
||||||
|
inference_mlp_init();
|
||||||
|
n_models++;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
printf("[STANDALONE] Running autonomous EMG control (%d models).\n", n_models);
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
emg_sensor_read(&sample);
|
||||||
|
|
||||||
|
if (inference_add_sample(sample.channels)) {
|
||||||
|
stride_counter++;
|
||||||
|
if (stride_counter >= INFERENCE_HOP_SIZE) {
|
||||||
|
stride_counter = 0;
|
||||||
|
|
||||||
|
float features[MODEL_NUM_FEATURES];
|
||||||
|
inference_extract_features(features);
|
||||||
|
|
||||||
|
float avg_proba[MODEL_NUM_CLASSES];
|
||||||
|
memset(avg_proba, 0, sizeof(avg_proba));
|
||||||
|
|
||||||
|
float proba_lda[MODEL_NUM_CLASSES];
|
||||||
|
inference_predict_raw(features, proba_lda);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += proba_lda[c];
|
||||||
|
|
||||||
|
#if MODEL_USE_ENSEMBLE
|
||||||
|
float proba_ens[MODEL_NUM_CLASSES];
|
||||||
|
inference_ensemble_predict_raw(features, proba_ens);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += proba_ens[c];
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if MODEL_USE_MLP
|
||||||
|
float mlp_conf = 0.0f;
|
||||||
|
int mlp_class = inference_mlp_predict(features, MODEL_NUM_FEATURES,
|
||||||
|
&mlp_conf);
|
||||||
|
float remainder = (1.0f - mlp_conf) / (MODEL_NUM_CLASSES - 1);
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] += (c == mlp_class) ? mlp_conf : remainder;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float inv_n = 1.0f / n_models;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
avg_proba[c] *= inv_n;
|
||||||
|
|
||||||
|
float confidence = 0.0f;
|
||||||
|
int gesture_idx = vote_postprocess(avg_proba, &confidence);
|
||||||
|
|
||||||
|
if (gesture_idx >= 0) {
|
||||||
|
int gesture_enum = inference_get_gesture_enum(gesture_idx);
|
||||||
|
gestures_execute((gesture_t)gesture_enum);
|
||||||
|
|
||||||
|
bicep_state_t bicep = bicep_detect();
|
||||||
|
(void)bicep;
|
||||||
|
|
||||||
|
if (gesture_idx != last_gesture) {
|
||||||
|
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
|
||||||
|
inference_get_class_name(gesture_idx), confidence);
|
||||||
|
last_gesture = gesture_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run rest calibration (STATE_CALIBRATING).
|
||||||
|
*
|
||||||
|
* Collects 3 seconds of rest EMG data, extracts features from each window,
|
||||||
|
* then updates:
|
||||||
|
* - NVS z-score calibration (Change D) via calibration_update()
|
||||||
|
* - Bicep detection threshold via bicep_calibrate_from_buffer()
|
||||||
|
*
|
||||||
|
* The user must keep their arm relaxed during this period.
|
||||||
|
* Send {"cmd": "calibrate"} from the host to trigger.
|
||||||
|
*/
|
||||||
|
static void run_calibration(void) {
|
||||||
|
#define CALIB_DURATION_SAMPLES 3000 /* 3 seconds at 1 kHz */
|
||||||
|
#define CALIB_MAX_WINDOWS \
|
||||||
|
((CALIB_DURATION_SAMPLES - INFERENCE_WINDOW_SIZE) / INFERENCE_HOP_SIZE + 1)
|
||||||
|
|
||||||
|
printf("{\"status\":\"calibrating\",\"duration_ms\":3000}\n");
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
inference_init(); /* reset buffer + filter state */
|
||||||
|
|
||||||
|
float *feat_matrix = (float *)malloc(
|
||||||
|
(size_t)CALIB_MAX_WINDOWS * MODEL_NUM_FEATURES * sizeof(float));
|
||||||
|
if (!feat_matrix) {
|
||||||
|
printf("{\"status\":\"error\",\"msg\":\"Calibration malloc failed\"}\n");
|
||||||
|
g_device_state = STATE_CONNECTED;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
emg_sample_t sample;
|
||||||
|
int window_count = 0;
|
||||||
|
int stride_counter = 0;
|
||||||
|
|
||||||
|
for (int s = 0; s < CALIB_DURATION_SAMPLES; s++) {
|
||||||
|
emg_sensor_read(&sample);
|
||||||
|
if (inference_add_sample(sample.channels)) {
|
||||||
|
stride_counter++;
|
||||||
|
if (stride_counter >= INFERENCE_HOP_SIZE) {
|
||||||
|
stride_counter = 0;
|
||||||
|
if (window_count < CALIB_MAX_WINDOWS) {
|
||||||
|
inference_extract_features(
|
||||||
|
feat_matrix + window_count * MODEL_NUM_FEATURES);
|
||||||
|
window_count++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (window_count >= 10) {
|
||||||
|
calibration_update(feat_matrix, window_count, MODEL_NUM_FEATURES);
|
||||||
|
bicep_calibrate_from_buffer(INFERENCE_WINDOW_SIZE);
|
||||||
|
printf("{\"status\":\"calibrated\",\"windows\":%d}\n", window_count);
|
||||||
|
} else {
|
||||||
|
printf("{\"status\":\"error\",\"msg\":\"Not enough calibration data\"}\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
free(feat_matrix);
|
||||||
|
g_device_state = STATE_CONNECTED;
|
||||||
|
|
||||||
|
#undef CALIB_DURATION_SAMPLES
|
||||||
|
#undef CALIB_MAX_WINDOWS
|
||||||
|
}
|
||||||
|
|
||||||
static void state_machine_loop(void) {
|
static void state_machine_loop(void) {
|
||||||
command_t cmd;
|
command_t cmd;
|
||||||
const TickType_t poll_interval = pdMS_TO_TICKS(50);
|
const TickType_t poll_interval = pdMS_TO_TICKS(50);
|
||||||
@@ -281,6 +672,10 @@ static void state_machine_loop(void) {
|
|||||||
stream_emg_data();
|
stream_emg_data();
|
||||||
} else if (g_device_state == STATE_PREDICTING) {
|
} else if (g_device_state == STATE_PREDICTING) {
|
||||||
run_inference_loop();
|
run_inference_loop();
|
||||||
|
} else if (g_device_state == STATE_LAPTOP_PREDICT) {
|
||||||
|
run_laptop_predict_loop();
|
||||||
|
} else if (g_device_state == STATE_CALIBRATING) {
|
||||||
|
run_calibration();
|
||||||
}
|
}
|
||||||
|
|
||||||
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
|
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
|
||||||
@@ -299,7 +694,9 @@ void appConnector() {
|
|||||||
printf("[PROTOCOL] Waiting for host to connect...\n");
|
printf("[PROTOCOL] Waiting for host to connect...\n");
|
||||||
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
|
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
|
||||||
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
|
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
|
||||||
"inference\n\n");
|
"inference\n");
|
||||||
|
printf("[PROTOCOL] Send: {\"cmd\": \"calibrate\"} for rest "
|
||||||
|
"calibration (3s)\n\n");
|
||||||
|
|
||||||
state_machine_loop();
|
state_machine_loop();
|
||||||
}
|
}
|
||||||
@@ -324,13 +721,108 @@ void app_main(void) {
|
|||||||
printf("[INIT] Initializing Inference Engine...\n");
|
printf("[INIT] Initializing Inference Engine...\n");
|
||||||
inference_init();
|
inference_init();
|
||||||
|
|
||||||
#if FEATURE_FAKE_EMG
|
printf("[INIT] Loading NVS calibration...\n");
|
||||||
printf("[INIT] Using FAKE EMG data (sensors not connected)\n");
|
calibration_init(); // Change D: no-op on first boot; loads if previously saved
|
||||||
#else
|
|
||||||
printf("[INIT] Using REAL EMG sensors\n");
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
// Bicep: load persisted threshold from NVS (if previously calibrated)
|
||||||
|
{
|
||||||
|
float bicep_thresh = 0.0f;
|
||||||
|
if (bicep_load_threshold(&bicep_thresh)) {
|
||||||
|
printf("[INIT] Bicep threshold loaded: %.1f\n", bicep_thresh);
|
||||||
|
} else {
|
||||||
|
printf("[INIT] No bicep calibration — run 'calibrate' command\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("[INIT] Using REAL EMG sensors\n");
|
||||||
printf("[INIT] Done!\n\n");
|
printf("[INIT] Done!\n\n");
|
||||||
|
|
||||||
appConnector();
|
switch (MAIN_MODE) {
|
||||||
}
|
case EMG_STANDALONE:
|
||||||
|
// Fully autonomous: no laptop needed after this point.
|
||||||
|
// Boots directly into the inference + arm control loop.
|
||||||
|
run_standalone_loop(); // never returns
|
||||||
|
break;
|
||||||
|
|
||||||
|
case SERVO_CALIBRATOR:
|
||||||
|
while (1) {
|
||||||
|
int angle;
|
||||||
|
printf("Enter servo angle (0-180): ");
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
// Read a line manually, yielding while waiting for UART input
|
||||||
|
char buf[16];
|
||||||
|
int idx = 0;
|
||||||
|
while (idx < (int)sizeof(buf) - 1) {
|
||||||
|
int ch = getchar();
|
||||||
|
if (ch == EOF) {
|
||||||
|
vTaskDelay(pdMS_TO_TICKS(10));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (ch == '\n' || ch == '\r')
|
||||||
|
break;
|
||||||
|
buf[idx++] = (char)ch;
|
||||||
|
}
|
||||||
|
buf[idx] = '\0';
|
||||||
|
|
||||||
|
if (idx == 0)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
if (sscanf(buf, "%d", &angle) == 1) {
|
||||||
|
if (angle >= 0 && angle <= 180) {
|
||||||
|
hand_set_finger_angle(FINGER_THUMB, angle);
|
||||||
|
vTaskDelay(pdMS_TO_TICKS(1000));
|
||||||
|
} else {
|
||||||
|
printf("Invalid angle. Must be between 0 and 180.\n");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
printf("Invalid input.\n");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case GESTURE_TESTER:
|
||||||
|
while (1) {
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
int ch = getchar();
|
||||||
|
|
||||||
|
if (ch == EOF) {
|
||||||
|
vTaskDelay(pdMS_TO_TICKS(10));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ch == '\n' || ch == '\r') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
gesture_t gesture = GESTURE_NONE;
|
||||||
|
|
||||||
|
switch (ch) {
|
||||||
|
case 'r': gesture = GESTURE_REST; break;
|
||||||
|
case 'f': gesture = GESTURE_FIST; break;
|
||||||
|
case 'o': gesture = GESTURE_OPEN; break;
|
||||||
|
case 'h': gesture = GESTURE_HOOK_EM; break;
|
||||||
|
case 't': gesture = GESTURE_THUMBS_UP; break;
|
||||||
|
default:
|
||||||
|
printf("Invalid gesture: %c\n", ch);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("Executing gesture: %s\n", gestures_get_name(gesture));
|
||||||
|
gestures_execute(gesture);
|
||||||
|
|
||||||
|
vTaskDelay(pdMS_TO_TICKS(500));
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
|
||||||
|
case EMG_MAIN:
|
||||||
|
appConnector();
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
printf("[ERROR] Unknown MAIN_MODE\n");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,19 +13,12 @@
|
|||||||
#include "driver/ledc.h"
|
#include "driver/ledc.h"
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Feature Flags
|
* Main Modes
|
||||||
*
|
|
||||||
* Compile-time switches to enable/disable features.
|
|
||||||
* Set to 1 to enable, 0 to disable.
|
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
/**
|
enum {EMG_MAIN, SERVO_CALIBRATOR, GESTURE_TESTER, EMG_STANDALONE};
|
||||||
* @brief Use fake EMG data (random values) instead of real ADC reads.
|
|
||||||
*
|
#define MAIN_MODE EMG_MAIN
|
||||||
* Set to 1 while waiting for EMG sensors to arrive.
|
|
||||||
* Set to 0 when ready to use real sensors.
|
|
||||||
*/
|
|
||||||
#define FEATURE_FAKE_EMG 0
|
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* GPIO Pin Definitions - Servos
|
* GPIO Pin Definitions - Servos
|
||||||
@@ -95,15 +88,4 @@ typedef enum {
|
|||||||
GESTURE_COUNT
|
GESTURE_COUNT
|
||||||
} gesture_t;
|
} gesture_t;
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief System operating modes.
|
|
||||||
*/
|
|
||||||
typedef enum {
|
|
||||||
MODE_IDLE = 0, /**< Waiting for commands */
|
|
||||||
MODE_DATA_STREAM, /**< Streaming EMG data to laptop */
|
|
||||||
MODE_COMMAND, /**< Executing gesture commands from laptop */
|
|
||||||
MODE_DEMO, /**< Running demo sequence */
|
|
||||||
MODE_COUNT
|
|
||||||
} system_mode_t;
|
|
||||||
|
|
||||||
#endif /* CONFIG_H */
|
#endif /* CONFIG_H */
|
||||||
|
|||||||
142
EMG_Arm/src/core/bicep.c
Normal file
142
EMG_Arm/src/core/bicep.c
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
/**
|
||||||
|
* @file bicep.c
|
||||||
|
* @brief Bicep channel subsystem — binary flex/unflex detector (Phase 1).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "bicep.h"
|
||||||
|
#include "inference.h" /* inference_get_bicep_rms() */
|
||||||
|
#include "nvs_flash.h"
|
||||||
|
#include "nvs.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
/* Tuning constants */
|
||||||
|
#define BICEP_WINDOW_SAMPLES 50 /**< 50 ms window at 1 kHz */
|
||||||
|
#define BICEP_FLEX_MULTIPLIER 2.5f /**< threshold = rest_rms × 2.5 */
|
||||||
|
#define BICEP_HYSTERESIS 1.3f /**< scale factor to enter flex (prevents toggling) */
|
||||||
|
|
||||||
|
/* NVS storage */
|
||||||
|
#define BICEP_NVS_NAMESPACE "bicep_calib"
|
||||||
|
#define BICEP_NVS_KEY_THRESH "threshold"
|
||||||
|
#define BICEP_NVS_KEY_VALID "calib_ok"
|
||||||
|
|
||||||
|
/* Module state */
|
||||||
|
static float s_threshold_mv = 0.0f;
|
||||||
|
static bicep_state_t s_state = BICEP_STATE_REST;
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Public API
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
float bicep_calibrate(const uint16_t *ch3_samples, int n_samples) {
|
||||||
|
if (n_samples <= 0) return 0.0f;
|
||||||
|
|
||||||
|
float rms_sq = 0.0f;
|
||||||
|
for (int i = 0; i < n_samples; i++) {
|
||||||
|
float v = (float)ch3_samples[i];
|
||||||
|
rms_sq += v * v;
|
||||||
|
}
|
||||||
|
float rest_rms = sqrtf(rms_sq / n_samples);
|
||||||
|
s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER;
|
||||||
|
s_state = BICEP_STATE_REST;
|
||||||
|
|
||||||
|
printf("[Bicep] Calibrated: rest_rms=%.1f mV, threshold=%.1f mV\n",
|
||||||
|
rest_rms, s_threshold_mv);
|
||||||
|
|
||||||
|
bicep_save_threshold(s_threshold_mv);
|
||||||
|
return s_threshold_mv;
|
||||||
|
}
|
||||||
|
|
||||||
|
bicep_state_t bicep_detect(void) {
|
||||||
|
if (s_threshold_mv <= 0.0f) {
|
||||||
|
return BICEP_STATE_REST; /* Not calibrated */
|
||||||
|
}
|
||||||
|
|
||||||
|
float rms = inference_get_bicep_rms(BICEP_WINDOW_SAMPLES);
|
||||||
|
|
||||||
|
/* Hysteretic threshold: need FLEX_MULTIPLIER × threshold to enter flex,
|
||||||
|
* drop below threshold to return to rest. */
|
||||||
|
if (s_state == BICEP_STATE_REST) {
|
||||||
|
if (rms > s_threshold_mv * BICEP_HYSTERESIS) {
|
||||||
|
s_state = BICEP_STATE_FLEX;
|
||||||
|
}
|
||||||
|
} else { /* BICEP_STATE_FLEX */
|
||||||
|
if (rms < s_threshold_mv) {
|
||||||
|
s_state = BICEP_STATE_REST;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return s_state;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bicep_save_threshold(float threshold_mv) {
|
||||||
|
nvs_handle_t h;
|
||||||
|
if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) {
|
||||||
|
printf("[Bicep] Failed to open NVS for write\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
esp_err_t err = ESP_OK;
|
||||||
|
err |= nvs_set_blob(h, BICEP_NVS_KEY_THRESH, &threshold_mv, sizeof(threshold_mv));
|
||||||
|
err |= nvs_set_u8 (h, BICEP_NVS_KEY_VALID, 1u);
|
||||||
|
err |= nvs_commit(h);
|
||||||
|
nvs_close(h);
|
||||||
|
|
||||||
|
if (err != ESP_OK) {
|
||||||
|
printf("[Bicep] NVS write failed (err=0x%x)\n", err);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
printf("[Bicep] Threshold %.1f mV saved to NVS\n", threshold_mv);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool bicep_load_threshold(float *threshold_mv_out) {
|
||||||
|
nvs_handle_t h;
|
||||||
|
if (nvs_open(BICEP_NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t valid = 0;
|
||||||
|
float thresh = 0.0f;
|
||||||
|
size_t sz = sizeof(thresh);
|
||||||
|
|
||||||
|
bool ok = (nvs_get_u8 (h, BICEP_NVS_KEY_VALID, &valid) == ESP_OK) &&
|
||||||
|
(valid == 1) &&
|
||||||
|
(nvs_get_blob(h, BICEP_NVS_KEY_THRESH, &thresh, &sz) == ESP_OK) &&
|
||||||
|
(thresh > 0.0f);
|
||||||
|
nvs_close(h);
|
||||||
|
|
||||||
|
if (ok) {
|
||||||
|
s_threshold_mv = thresh;
|
||||||
|
if (threshold_mv_out) *threshold_mv_out = thresh;
|
||||||
|
printf("[Bicep] Loaded threshold: %.1f mV\n", thresh);
|
||||||
|
}
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
float bicep_calibrate_from_buffer(int n_samples) {
|
||||||
|
float rest_rms = inference_get_bicep_rms(n_samples);
|
||||||
|
if (rest_rms < 1e-6f) {
|
||||||
|
printf("[Bicep] WARNING: rest RMS ≈ 0 — buffer may not be filled yet\n");
|
||||||
|
return 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
s_threshold_mv = rest_rms * BICEP_FLEX_MULTIPLIER;
|
||||||
|
s_state = BICEP_STATE_REST;
|
||||||
|
|
||||||
|
printf("[Bicep] Calibrated (filtered): rest_rms=%.2f, threshold=%.2f\n",
|
||||||
|
rest_rms, s_threshold_mv);
|
||||||
|
|
||||||
|
bicep_save_threshold(s_threshold_mv);
|
||||||
|
return s_threshold_mv;
|
||||||
|
}
|
||||||
|
|
||||||
|
void bicep_set_threshold(float threshold_mv) {
|
||||||
|
s_threshold_mv = threshold_mv;
|
||||||
|
s_state = BICEP_STATE_REST;
|
||||||
|
}
|
||||||
|
|
||||||
|
float bicep_get_threshold(void) {
|
||||||
|
return s_threshold_mv;
|
||||||
|
}
|
||||||
97
EMG_Arm/src/core/bicep.h
Normal file
97
EMG_Arm/src/core/bicep.h
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
/**
|
||||||
|
* @file bicep.h
|
||||||
|
* @brief Bicep channel (ch3) subsystem — Phase 1: binary flex/unflex detector.
|
||||||
|
*
|
||||||
|
* Implements a simple RMS threshold detector with hysteresis for bicep activation.
|
||||||
|
* ch3 data flows through the same IIR bandpass filter and circular buffer as the
|
||||||
|
* hand gesture channels (via inference_get_bicep_rms()), so no separate ADC read
|
||||||
|
* is required.
|
||||||
|
*
|
||||||
|
* Usage:
|
||||||
|
* 1. On startup: bicep_load_threshold(&thresh) — restore persisted threshold
|
||||||
|
* 2. After 3 s of relaxed rest:
|
||||||
|
* bicep_calibrate(raw_ch3_samples, n_samples) — sets + saves threshold
|
||||||
|
* 3. Every 25 ms hop:
|
||||||
|
* bicep_state_t state = bicep_detect();
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef BICEP_H
|
||||||
|
#define BICEP_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Bicep activation state.
|
||||||
|
*/
|
||||||
|
typedef enum {
|
||||||
|
BICEP_STATE_REST = 0,
|
||||||
|
BICEP_STATE_FLEX = 1,
|
||||||
|
} bicep_state_t;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calibrate bicep threshold from REST data.
|
||||||
|
*
|
||||||
|
* Computes rest-RMS over the provided samples, then sets the internal
|
||||||
|
* detection threshold to rest_rms × BICEP_FLEX_MULTIPLIER.
|
||||||
|
*
|
||||||
|
* @param ch3_samples Raw ADC / mV values from the bicep channel.
|
||||||
|
* @param n_samples Number of samples provided.
|
||||||
|
* @return Computed threshold in the same units as ch3_samples.
|
||||||
|
*/
|
||||||
|
float bicep_calibrate(const uint16_t *ch3_samples, int n_samples);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Detect current bicep state from the latest window.
|
||||||
|
*
|
||||||
|
* Uses inference_get_bicep_rms(BICEP_WINDOW_SAMPLES) internally, so
|
||||||
|
* inference_add_sample() must have been called to fill the buffer first.
|
||||||
|
*
|
||||||
|
* @return BICEP_STATE_FLEX or BICEP_STATE_REST.
|
||||||
|
*/
|
||||||
|
bicep_state_t bicep_detect(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Persist the current threshold to NVS.
|
||||||
|
*
|
||||||
|
* @param threshold_mv Threshold value to save (in mV / same units as bicep RMS).
|
||||||
|
* @return true on success.
|
||||||
|
*/
|
||||||
|
bool bicep_save_threshold(float threshold_mv);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Load the persisted threshold from NVS.
|
||||||
|
*
|
||||||
|
* @param threshold_mv_out Output pointer; untouched on failure.
|
||||||
|
* @return true if a valid threshold was loaded.
|
||||||
|
*/
|
||||||
|
bool bicep_load_threshold(float *threshold_mv_out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Set the detection threshold directly (without NVS save).
|
||||||
|
*/
|
||||||
|
void bicep_set_threshold(float threshold_mv);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return the current threshold (0 if not calibrated).
|
||||||
|
*/
|
||||||
|
float bicep_get_threshold(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Calibrate bicep threshold from the filtered inference buffer.
|
||||||
|
*
|
||||||
|
* Uses inference_get_bicep_rms() to read from the bandpass-filtered
|
||||||
|
* circular buffer — the same data source that bicep_detect() uses.
|
||||||
|
* The old bicep_calibrate() accepts raw uint16_t ADC values, which are
|
||||||
|
* in a different domain (includes DC offset) and produce unusable thresholds.
|
||||||
|
*
|
||||||
|
* Call this after the inference buffer has been filled with ≥ n_samples
|
||||||
|
* of rest data via inference_add_sample().
|
||||||
|
*
|
||||||
|
* @param n_samples Number of recent buffer samples to use for RMS.
|
||||||
|
* Clamped to INFERENCE_WINDOW_SIZE internally.
|
||||||
|
* @return Computed threshold (same units as bicep_detect sees).
|
||||||
|
*/
|
||||||
|
float bicep_calibrate_from_buffer(int n_samples);
|
||||||
|
|
||||||
|
#endif /* BICEP_H */
|
||||||
138
EMG_Arm/src/core/calibration.c
Normal file
138
EMG_Arm/src/core/calibration.c
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
/**
|
||||||
|
* @file calibration.c
|
||||||
|
* @brief NVS-backed z-score feature calibration (Change D).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "calibration.h"
|
||||||
|
#include "nvs_flash.h"
|
||||||
|
#include "nvs.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#define NVS_NAMESPACE "emg_calib"
|
||||||
|
#define NVS_KEY_MEAN "feat_mean"
|
||||||
|
#define NVS_KEY_STD "feat_std"
|
||||||
|
#define NVS_KEY_NFEAT "n_feat"
|
||||||
|
#define NVS_KEY_VALID "calib_ok"
|
||||||
|
|
||||||
|
static float s_mean[CALIB_MAX_FEATURES];
|
||||||
|
static float s_std[CALIB_MAX_FEATURES];
|
||||||
|
static int s_n_feat = 0;
|
||||||
|
static bool s_valid = false;
|
||||||
|
|
||||||
|
bool calibration_init(void) {
|
||||||
|
/* Standard NVS flash initialisation boilerplate */
|
||||||
|
esp_err_t err = nvs_flash_init();
|
||||||
|
if (err == ESP_ERR_NVS_NO_FREE_PAGES || err == ESP_ERR_NVS_NEW_VERSION_FOUND) {
|
||||||
|
nvs_flash_erase();
|
||||||
|
nvs_flash_init();
|
||||||
|
}
|
||||||
|
|
||||||
|
nvs_handle_t h;
|
||||||
|
if (nvs_open(NVS_NAMESPACE, NVS_READONLY, &h) != ESP_OK) {
|
||||||
|
printf("[Calib] No NVS partition found — identity transform active\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint8_t valid = 0;
|
||||||
|
int32_t n_feat = 0;
|
||||||
|
size_t mean_sz = sizeof(s_mean);
|
||||||
|
size_t std_sz = sizeof(s_std);
|
||||||
|
|
||||||
|
bool ok = (nvs_get_u8 (h, NVS_KEY_VALID, &valid) == ESP_OK) &&
|
||||||
|
(valid == 1) &&
|
||||||
|
(nvs_get_i32 (h, NVS_KEY_NFEAT, &n_feat) == ESP_OK) &&
|
||||||
|
(n_feat > 0 && n_feat <= CALIB_MAX_FEATURES) &&
|
||||||
|
(nvs_get_blob(h, NVS_KEY_MEAN, s_mean, &mean_sz) == ESP_OK) &&
|
||||||
|
(nvs_get_blob(h, NVS_KEY_STD, s_std, &std_sz) == ESP_OK);
|
||||||
|
|
||||||
|
nvs_close(h);
|
||||||
|
|
||||||
|
if (ok) {
|
||||||
|
s_n_feat = (int)n_feat;
|
||||||
|
s_valid = true;
|
||||||
|
printf("[Calib] Loaded from NVS (%d features)\n", s_n_feat);
|
||||||
|
} else {
|
||||||
|
printf("[Calib] No valid calibration in NVS — identity transform active\n");
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok;
|
||||||
|
}
|
||||||
|
|
||||||
|
void calibration_apply(float *feat) {
|
||||||
|
if (!s_valid) return;
|
||||||
|
for (int i = 0; i < s_n_feat; i++) {
|
||||||
|
feat[i] = (feat[i] - s_mean[i]) / s_std[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool calibration_update(const float *X_flat, int n_windows, int n_feat) {
|
||||||
|
if (n_windows < 10 || n_feat <= 0 || n_feat > CALIB_MAX_FEATURES) {
|
||||||
|
printf("[Calib] calibration_update: invalid args (%d windows, %d features)\n",
|
||||||
|
n_windows, n_feat);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
s_n_feat = n_feat;
|
||||||
|
|
||||||
|
/* Compute per-feature mean */
|
||||||
|
memset(s_mean, 0, sizeof(s_mean));
|
||||||
|
for (int w = 0; w < n_windows; w++) {
|
||||||
|
for (int f = 0; f < n_feat; f++) {
|
||||||
|
s_mean[f] += X_flat[w * n_feat + f];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int f = 0; f < n_feat; f++) {
|
||||||
|
s_mean[f] /= n_windows;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Compute per-feature std (with epsilon floor) */
|
||||||
|
memset(s_std, 0, sizeof(s_std));
|
||||||
|
for (int w = 0; w < n_windows; w++) {
|
||||||
|
for (int f = 0; f < n_feat; f++) {
|
||||||
|
float d = X_flat[w * n_feat + f] - s_mean[f];
|
||||||
|
s_std[f] += d * d;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int f = 0; f < n_feat; f++) {
|
||||||
|
float var = s_std[f] / n_windows;
|
||||||
|
s_std[f] = (var > 1e-12f) ? sqrtf(var) : 1e-6f;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Persist to NVS */
|
||||||
|
nvs_handle_t h;
|
||||||
|
if (nvs_open(NVS_NAMESPACE, NVS_READWRITE, &h) != ESP_OK) {
|
||||||
|
printf("[Calib] calibration_update: failed to open NVS\n");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
esp_err_t err = ESP_OK;
|
||||||
|
err |= nvs_set_blob(h, NVS_KEY_MEAN, s_mean, sizeof(s_mean));
|
||||||
|
err |= nvs_set_blob(h, NVS_KEY_STD, s_std, sizeof(s_std));
|
||||||
|
err |= nvs_set_i32 (h, NVS_KEY_NFEAT, (int32_t)n_feat);
|
||||||
|
err |= nvs_set_u8 (h, NVS_KEY_VALID, 1u);
|
||||||
|
err |= nvs_commit(h);
|
||||||
|
nvs_close(h);
|
||||||
|
|
||||||
|
if (err != ESP_OK) {
|
||||||
|
printf("[Calib] calibration_update: NVS write failed (err=0x%x)\n", err);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
s_valid = true;
|
||||||
|
printf("[Calib] Updated: %d REST windows, %d features saved to NVS\n",
|
||||||
|
n_windows, n_feat);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void calibration_reset(void) {
|
||||||
|
s_valid = false;
|
||||||
|
s_n_feat = 0;
|
||||||
|
memset(s_mean, 0, sizeof(s_mean));
|
||||||
|
memset(s_std, 0, sizeof(s_std));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool calibration_is_valid(void) {
|
||||||
|
return s_valid;
|
||||||
|
}
|
||||||
68
EMG_Arm/src/core/calibration.h
Normal file
68
EMG_Arm/src/core/calibration.h
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
/**
|
||||||
|
* @file calibration.h
|
||||||
|
* @brief NVS-backed z-score feature calibration for on-device EMG inference.
|
||||||
|
*
|
||||||
|
* Change D: Stores per-feature mean and std computed from a short REST session
|
||||||
|
* in ESP32 non-volatile storage (NVS). At inference time, calibration_apply()
|
||||||
|
* z-scores each feature vector before the LDA classifier sees it. This removes
|
||||||
|
* day-to-day electrode placement drift without retraining the model.
|
||||||
|
*
|
||||||
|
* Typical workflow:
|
||||||
|
* 1. calibration_init() — called once at startup; loads from NVS
|
||||||
|
* 2. calibration_update() — called after collecting ~3s of REST windows
|
||||||
|
* 3. calibration_apply(feat) — called every inference hop in inference.c
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef CALIBRATION_H
|
||||||
|
#define CALIBRATION_H
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
/* Maximum supported feature vector length.
|
||||||
|
* 96 > 69 (expanded) > 12 (legacy) — gives headroom for future expansion. */
|
||||||
|
#define CALIB_MAX_FEATURES 96
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialise NVS and load stored calibration statistics.
|
||||||
|
*
|
||||||
|
* Must be called once before any inference starts (e.g., in app_main).
|
||||||
|
* @return true Calibration data found and loaded.
|
||||||
|
* @return false No stored data; calibration_apply() will be a no-op.
|
||||||
|
*/
|
||||||
|
bool calibration_init(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Apply stored z-score calibration to a feature vector in-place.
|
||||||
|
*
|
||||||
|
* x_i_out = (x_i - mean_i) / std_i
|
||||||
|
*
|
||||||
|
* No-op if calibration is not valid (calibration_is_valid() == false).
|
||||||
|
* @param feat Feature vector of length n_feat (set during calibration_update).
|
||||||
|
*/
|
||||||
|
void calibration_apply(float *feat);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute and persist calibration statistics from REST EMG windows.
|
||||||
|
*
|
||||||
|
* Computes per-feature mean and std over n_windows windows, stores to NVS.
|
||||||
|
* After this call, calibration_apply() uses the new statistics.
|
||||||
|
*
|
||||||
|
* @param X_flat Flattened feature array [n_windows × n_feat], row-major.
|
||||||
|
* @param n_windows Number of windows (minimum 10).
|
||||||
|
* @param n_feat Feature vector length (≤ CALIB_MAX_FEATURES).
|
||||||
|
* @return true on success, false if inputs are invalid or NVS write fails.
|
||||||
|
*/
|
||||||
|
bool calibration_update(const float *X_flat, int n_windows, int n_feat);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Clear calibration state (in-memory only; does not erase NVS).
|
||||||
|
*/
|
||||||
|
void calibration_reset(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Check whether valid calibration statistics are loaded.
|
||||||
|
*/
|
||||||
|
bool calibration_is_valid(void);
|
||||||
|
|
||||||
|
#endif /* CALIBRATION_H */
|
||||||
6
EMG_Arm/src/core/emg_model_data.cc
Normal file
6
EMG_Arm/src/core/emg_model_data.cc
Normal file
File diff suppressed because one or more lines are too long
13
EMG_Arm/src/core/emg_model_data.h
Normal file
13
EMG_Arm/src/core/emg_model_data.h
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// Auto-generated by train_mlp_tflite.py — do not edit
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
extern const unsigned char g_model[];
|
||||||
|
extern const int g_model_len;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -10,6 +10,7 @@
|
|||||||
#include <freertos/FreeRTOS.h>
|
#include <freertos/FreeRTOS.h>
|
||||||
#include <freertos/task.h>
|
#include <freertos/task.h>
|
||||||
|
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Private Data
|
* Private Data
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
@@ -51,6 +52,19 @@ void gestures_execute(gesture_t gesture)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
gesture_t parse_gesture(const char *s)
|
||||||
|
{
|
||||||
|
if (strcmp(s, "rest") == 0) return GESTURE_REST;
|
||||||
|
if (strcmp(s, "fist") == 0) return GESTURE_FIST;
|
||||||
|
if (strcmp(s, "open") == 0) return GESTURE_OPEN;
|
||||||
|
if (strcmp(s, "hook-em") == 0 ||
|
||||||
|
strcmp(s, "hookem") == 0) return GESTURE_HOOK_EM;
|
||||||
|
if (strcmp(s, "thumbs-up") == 0 ||
|
||||||
|
strcmp(s, "thumbsup") == 0) return GESTURE_THUMBS_UP;
|
||||||
|
|
||||||
|
return GESTURE_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
const char* gestures_get_name(gesture_t gesture)
|
const char* gestures_get_name(gesture_t gesture)
|
||||||
{
|
{
|
||||||
if (gesture >= GESTURE_COUNT) {
|
if (gesture >= GESTURE_COUNT) {
|
||||||
@@ -65,38 +79,49 @@ const char* gestures_get_name(gesture_t gesture)
|
|||||||
|
|
||||||
void gesture_open(void)
|
void gesture_open(void)
|
||||||
{
|
{
|
||||||
hand_unflex_all();
|
hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]);
|
||||||
|
hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]);
|
||||||
|
hand_set_finger_angle(FINGER_MIDDLE, minAngles[FINGER_MIDDLE]);
|
||||||
|
hand_set_finger_angle(FINGER_RING, minAngles[FINGER_RING]);
|
||||||
|
hand_set_finger_angle(FINGER_PINKY, minAngles[FINGER_PINKY]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gesture_fist(void)
|
void gesture_fist(void)
|
||||||
{
|
{
|
||||||
hand_flex_all();
|
hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]);
|
||||||
|
hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
|
||||||
|
hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
|
||||||
|
hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]);
|
||||||
|
hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gesture_hook_em(void)
|
void gesture_hook_em(void)
|
||||||
{
|
{
|
||||||
/* Index and pinky extended, others flexed */
|
/* Index and pinky extended, others flexed */
|
||||||
hand_flex_finger(FINGER_THUMB);
|
hand_set_finger_angle(FINGER_THUMB, maxAngles[FINGER_THUMB]);
|
||||||
hand_unflex_finger(FINGER_INDEX);
|
hand_set_finger_angle(FINGER_INDEX, minAngles[FINGER_INDEX]);
|
||||||
hand_flex_finger(FINGER_MIDDLE);
|
hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
|
||||||
hand_flex_finger(FINGER_RING);
|
hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
|
||||||
hand_unflex_finger(FINGER_PINKY);
|
hand_set_finger_angle(FINGER_PINKY, minAngles[FINGER_PINKY]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gesture_thumbs_up(void)
|
void gesture_thumbs_up(void)
|
||||||
{
|
{
|
||||||
/* Thumb extended, others flexed */
|
/* Thumb extended, others flexed */
|
||||||
hand_unflex_finger(FINGER_THUMB);
|
hand_set_finger_angle(FINGER_THUMB, minAngles[FINGER_THUMB]);
|
||||||
hand_flex_finger(FINGER_INDEX);
|
hand_set_finger_angle(FINGER_INDEX, maxAngles[FINGER_INDEX]);
|
||||||
hand_flex_finger(FINGER_MIDDLE);
|
hand_set_finger_angle(FINGER_MIDDLE, maxAngles[FINGER_MIDDLE]);
|
||||||
hand_flex_finger(FINGER_RING);
|
hand_set_finger_angle(FINGER_RING, maxAngles[FINGER_RING]);
|
||||||
hand_flex_finger(FINGER_PINKY);
|
hand_set_finger_angle(FINGER_PINKY, maxAngles[FINGER_PINKY]);
|
||||||
}
|
}
|
||||||
|
|
||||||
void gesture_rest(void)
|
void gesture_rest(void)
|
||||||
{
|
{
|
||||||
/* Rest is same as open - neutral position */
|
hand_set_finger_angle(FINGER_THUMB, (maxAngles[FINGER_THUMB] + minAngles[FINGER_THUMB])/2);
|
||||||
gesture_open();
|
hand_set_finger_angle(FINGER_INDEX, (maxAngles[FINGER_INDEX] + minAngles[FINGER_INDEX])/2);
|
||||||
|
hand_set_finger_angle(FINGER_MIDDLE, (maxAngles[FINGER_MIDDLE] + minAngles[FINGER_MIDDLE])/2);
|
||||||
|
hand_set_finger_angle(FINGER_RING, (maxAngles[FINGER_RING] + minAngles[FINGER_RING])/2);
|
||||||
|
hand_set_finger_angle(FINGER_PINKY, (maxAngles[FINGER_PINKY] + minAngles[FINGER_PINKY])/2);
|
||||||
}
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
#define GESTURES_H
|
#define GESTURES_H
|
||||||
|
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <string.h>
|
||||||
#include "config/config.h"
|
#include "config/config.h"
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
@@ -26,6 +27,8 @@
|
|||||||
*/
|
*/
|
||||||
void gestures_execute(gesture_t gesture);
|
void gestures_execute(gesture_t gesture);
|
||||||
|
|
||||||
|
gesture_t parse_gesture(const char *s);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Get the name of a gesture as a string.
|
* @brief Get the name of a gesture as a string.
|
||||||
*
|
*
|
||||||
|
|||||||
@@ -4,19 +4,52 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include "inference.h"
|
#include "inference.h"
|
||||||
|
#include "calibration.h"
|
||||||
#include "config/config.h"
|
#include "config/config.h"
|
||||||
#include "model_weights.h"
|
#include "model_weights.h"
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
#include "dsps_fft2r.h" /* esp-dsp: complex 2-radix FFT */
|
||||||
|
#define FFT_N 256 /* Must match Python fft_n=256 */
|
||||||
|
#endif
|
||||||
|
|
||||||
// --- Constants ---
|
// --- Constants ---
|
||||||
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
|
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
|
||||||
#define VOTE_WINDOW 5 // Majority vote window size
|
#define VOTE_WINDOW 5 // Majority vote window size
|
||||||
#define DEBOUNCE_COUNT 3 // Confirmations needed to change output
|
#define DEBOUNCE_COUNT 3 // Confirmations needed to change output
|
||||||
|
// Change C: confidence rejection threshold.
|
||||||
|
// When the peak smoothed probability stays below this, hold the last confirmed
|
||||||
|
// output rather than outputting an uncertain prediction. Prevents false arm
|
||||||
|
// actuations during gesture transitions, rest-to-gesture onset, and electrode lift.
|
||||||
|
// Meta paper uses 0.35; 0.40 adds a prosthetic safety margin.
|
||||||
|
// Tune down to 0.35 if real gestures are being incorrectly rejected.
|
||||||
|
#define CONFIDENCE_THRESHOLD 0.40f
|
||||||
|
|
||||||
|
// --- Change B: IIR Bandpass Filter (20–450 Hz, 2nd-order Butterworth @ 1 kHz) ---
|
||||||
|
// Two cascaded biquad sections, Direct Form II Transposed.
|
||||||
|
// Computed via scipy.signal.butter(2, [20,450], btype='bandpass', fs=1000, output='sos').
|
||||||
|
// b coefficients [b0, b1, b2] per section:
|
||||||
|
#define IIR_NUM_SECTIONS 2
|
||||||
|
static const float IIR_B[IIR_NUM_SECTIONS][3] = {
|
||||||
|
{ 0.7320224766f, 1.4640449531f, 0.7320224766f }, /* section 0 */
|
||||||
|
{ 1.0000000000f, -2.0000000000f, 1.0000000000f }, /* section 1 */
|
||||||
|
};
|
||||||
|
// Feedback coefficients [a1, a2] per section (a0 = 1, implicit):
|
||||||
|
static const float IIR_A[IIR_NUM_SECTIONS][2] = {
|
||||||
|
{ 1.5597081442f, 0.6416146818f }, /* section 0 */
|
||||||
|
{ -1.8224796027f, 0.8372542588f }, /* section 1 */
|
||||||
|
};
|
||||||
|
|
||||||
// --- State ---
|
// --- State ---
|
||||||
static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
|
static float window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
|
||||||
|
static float biquad_w[IIR_NUM_SECTIONS][NUM_CHANNELS][2]; /* biquad delay states */
|
||||||
|
|
||||||
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
static bool s_fft_ready = false;
|
||||||
|
#endif
|
||||||
static int buffer_head = 0;
|
static int buffer_head = 0;
|
||||||
static int samples_collected = 0;
|
static int samples_collected = 0;
|
||||||
|
|
||||||
@@ -30,9 +63,17 @@ static int pending_count = 0;
|
|||||||
|
|
||||||
void inference_init(void) {
|
void inference_init(void) {
|
||||||
memset(window_buffer, 0, sizeof(window_buffer));
|
memset(window_buffer, 0, sizeof(window_buffer));
|
||||||
|
memset(biquad_w, 0, sizeof(biquad_w));
|
||||||
buffer_head = 0;
|
buffer_head = 0;
|
||||||
samples_collected = 0;
|
samples_collected = 0;
|
||||||
|
|
||||||
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
if (!s_fft_ready) {
|
||||||
|
dsps_fft2r_init_fc32(NULL, FFT_N);
|
||||||
|
s_fft_ready = true;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// Initialize smoothing
|
// Initialize smoothing
|
||||||
for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
|
for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
|
||||||
smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES;
|
smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES;
|
||||||
@@ -47,9 +88,17 @@ void inference_init(void) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool inference_add_sample(uint16_t *channels) {
|
bool inference_add_sample(uint16_t *channels) {
|
||||||
// Add to circular buffer
|
// Convert to float, apply per-channel biquad bandpass, then store in circular buffer.
|
||||||
for (int i = 0; i < NUM_CHANNELS; i++) {
|
for (int i = 0; i < NUM_CHANNELS; i++) {
|
||||||
window_buffer[buffer_head][i] = channels[i];
|
float x = (float)channels[i];
|
||||||
|
// Cascade IIR_NUM_SECTIONS biquad sections (Direct Form II Transposed)
|
||||||
|
for (int s = 0; s < IIR_NUM_SECTIONS; s++) {
|
||||||
|
float y = IIR_B[s][0] * x + biquad_w[s][i][0];
|
||||||
|
biquad_w[s][i][0] = IIR_B[s][1] * x - IIR_A[s][0] * y + biquad_w[s][i][1];
|
||||||
|
biquad_w[s][i][1] = IIR_B[s][2] * x - IIR_A[s][1] * y;
|
||||||
|
x = y;
|
||||||
|
}
|
||||||
|
window_buffer[buffer_head][i] = x;
|
||||||
}
|
}
|
||||||
|
|
||||||
buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;
|
buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
@@ -65,11 +114,265 @@ bool inference_add_sample(uint16_t *channels) {
|
|||||||
|
|
||||||
// --- Feature Extraction ---
|
// --- Feature Extraction ---
|
||||||
|
|
||||||
static void compute_features(float *features_out) {
|
/* ── helpers used by compute_features_expanded ──────────────────────────── */
|
||||||
// Process each channel
|
|
||||||
// We need to iterate over the logical window (unrolling circular buffer)
|
|
||||||
|
|
||||||
for (int ch = 0; ch < NUM_CHANNELS; ch++) {
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
|
||||||
|
/** Solve 4×4 linear system A·x = b via Gaussian elimination with partial pivoting.
|
||||||
|
* Returns false and leaves x zeroed if the matrix is singular. */
|
||||||
|
static bool solve_4x4(float A[4][4], const float b[4], float x[4]) {
|
||||||
|
float M[4][5];
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
for (int j = 0; j < 4; j++) M[i][j] = A[i][j];
|
||||||
|
M[i][4] = b[i];
|
||||||
|
}
|
||||||
|
for (int col = 0; col < 4; col++) {
|
||||||
|
int pivot = col;
|
||||||
|
float maxv = fabsf(M[col][col]);
|
||||||
|
for (int row = col + 1; row < 4; row++) {
|
||||||
|
if (fabsf(M[row][col]) > maxv) { maxv = fabsf(M[row][col]); pivot = row; }
|
||||||
|
}
|
||||||
|
if (maxv < 1e-8f) { x[0]=x[1]=x[2]=x[3]=0.f; return false; }
|
||||||
|
if (pivot != col) {
|
||||||
|
float tmp[5];
|
||||||
|
memcpy(tmp, M[col], 5*sizeof(float));
|
||||||
|
memcpy(M[col], M[pivot], 5*sizeof(float));
|
||||||
|
memcpy(M[pivot], tmp, 5*sizeof(float));
|
||||||
|
}
|
||||||
|
for (int row = col + 1; row < 4; row++) {
|
||||||
|
float f = M[row][col] / M[col][col];
|
||||||
|
for (int j = col; j < 5; j++) M[row][j] -= f * M[col][j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int row = 3; row >= 0; row--) {
|
||||||
|
x[row] = M[row][4];
|
||||||
|
for (int j = row + 1; j < 4; j++) x[row] -= M[row][j] * x[j];
|
||||||
|
x[row] /= M[row][row];
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Full 69-feature extraction (Change 1).
|
||||||
|
*
|
||||||
|
* Per channel (20): RMS, WL, ZC, SSC, MAV, VAR, IEMG, WAMP,
|
||||||
|
* AR1-AR4, MNF, MDF, PKF, MNP, BP0-BP3
|
||||||
|
* Cross-channel (9 for 3 channels): corr, log-RMS-ratio, cov for each pair
|
||||||
|
*
|
||||||
|
* Feature layout must exactly match EMGFeatureExtractor._EXPANDED_KEYS +
|
||||||
|
* cross-channel order in the Python class.
|
||||||
|
*/
|
||||||
|
static void compute_features_expanded(float *features_out) {
|
||||||
|
memset(features_out, 0, MODEL_NUM_FEATURES * sizeof(float));
|
||||||
|
|
||||||
|
/* Persistent buffers for centered signals (3 ch × 150 samples) */
|
||||||
|
static float ch_signals[HAND_NUM_CHANNELS][INFERENCE_WINDOW_SIZE];
|
||||||
|
static float s_fft_buf[FFT_N * 2]; /* Complex interleaved [re,im,...] */
|
||||||
|
|
||||||
|
float ch_rms[HAND_NUM_CHANNELS];
|
||||||
|
float norm_sq = 0.0f;
|
||||||
|
|
||||||
|
/* ──────────────────────────────────────────────────────────────────────
|
||||||
|
* Pass 1: per-channel TD + AR + spectral features
|
||||||
|
* ────────────────────────────────────────────────────────────────────── */
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
/* Read channel from circular buffer (oldest-first) */
|
||||||
|
float *sig = ch_signals[ch];
|
||||||
|
float sum = 0.0f;
|
||||||
|
int idx = buffer_head;
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
sig[i] = window_buffer[idx][ch];
|
||||||
|
sum += sig[i];
|
||||||
|
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
|
}
|
||||||
|
/* DC removal */
|
||||||
|
float mean = sum / INFERENCE_WINDOW_SIZE;
|
||||||
|
float sq_sum = 0.0f;
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
sig[i] -= mean;
|
||||||
|
sq_sum += sig[i] * sig[i];
|
||||||
|
}
|
||||||
|
/* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */
|
||||||
|
#if MODEL_USE_REINHARD
|
||||||
|
sq_sum = 0.0f;
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
float x = sig[i];
|
||||||
|
sig[i] = 64.0f * x / (32.0f + fabsf(x));
|
||||||
|
sq_sum += sig[i] * sig[i];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
|
||||||
|
ch_rms[ch] = rms;
|
||||||
|
norm_sq += rms * rms;
|
||||||
|
|
||||||
|
float zc_thresh = FEAT_ZC_THRESH * rms;
|
||||||
|
float ssc_thresh = (FEAT_SSC_THRESH * rms) * (FEAT_SSC_THRESH * rms);
|
||||||
|
|
||||||
|
/* TD features */
|
||||||
|
float wl = 0.0f, mav = 0.0f, iemg = 0.0f;
|
||||||
|
int zc = 0, ssc = 0, wamp = 0;
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
float a = fabsf(sig[i]);
|
||||||
|
mav += a;
|
||||||
|
iemg += a;
|
||||||
|
}
|
||||||
|
mav /= INFERENCE_WINDOW_SIZE;
|
||||||
|
float var_val = sq_sum / INFERENCE_WINDOW_SIZE; /* variance (mean already 0) */
|
||||||
|
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) {
|
||||||
|
float diff = sig[i+1] - sig[i];
|
||||||
|
float adiff = fabsf(diff);
|
||||||
|
wl += adiff;
|
||||||
|
if (adiff > zc_thresh) wamp++;
|
||||||
|
if ((sig[i] > 0.0f && sig[i+1] < 0.0f) ||
|
||||||
|
(sig[i] < 0.0f && sig[i+1] > 0.0f)) {
|
||||||
|
if (adiff > zc_thresh) zc++;
|
||||||
|
}
|
||||||
|
if (i < INFERENCE_WINDOW_SIZE - 2) {
|
||||||
|
float d1 = sig[i+1] - sig[i];
|
||||||
|
float d2 = sig[i+1] - sig[i+2];
|
||||||
|
if ((d1 * d2) > ssc_thresh) ssc++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* AR(4) via Yule-Walker */
|
||||||
|
float r[5] = {0};
|
||||||
|
for (int lag = 0; lag < 5; lag++) {
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE - lag; i++)
|
||||||
|
r[lag] += sig[i] * sig[i + lag];
|
||||||
|
r[lag] /= INFERENCE_WINDOW_SIZE;
|
||||||
|
}
|
||||||
|
float T[4][4], b_ar[4], ar[4] = {0};
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
for (int j = 0; j < 4; j++) T[i][j] = r[abs(i - j)];
|
||||||
|
b_ar[i] = r[i + 1];
|
||||||
|
}
|
||||||
|
solve_4x4(T, b_ar, ar);
|
||||||
|
|
||||||
|
/* Spectral features via FFT (zero-pad to FFT_N) */
|
||||||
|
memset(s_fft_buf, 0, sizeof(s_fft_buf));
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
s_fft_buf[2*i] = sig[i];
|
||||||
|
s_fft_buf[2*i + 1] = 0.0f;
|
||||||
|
}
|
||||||
|
dsps_fft2r_fc32(s_fft_buf, FFT_N);
|
||||||
|
dsps_bit_rev_fc32(s_fft_buf, FFT_N);
|
||||||
|
|
||||||
|
float total_power = 1e-10f;
|
||||||
|
const float freq_step = (float)EMG_SAMPLE_RATE_HZ / FFT_N;
|
||||||
|
float pwr[FFT_N / 2];
|
||||||
|
for (int k = 0; k < FFT_N / 2; k++) {
|
||||||
|
float re = s_fft_buf[2*k], im = s_fft_buf[2*k + 1];
|
||||||
|
pwr[k] = re*re + im*im;
|
||||||
|
total_power += pwr[k];
|
||||||
|
}
|
||||||
|
|
||||||
|
float mnf = 0.0f;
|
||||||
|
for (int k = 0; k < FFT_N / 2; k++) mnf += k * freq_step * pwr[k];
|
||||||
|
mnf /= total_power;
|
||||||
|
|
||||||
|
float cumsum = 0.0f, half = total_power / 2.0f;
|
||||||
|
int mdf_k = FFT_N / 2 - 1;
|
||||||
|
for (int k = 0; k < FFT_N / 2; k++) {
|
||||||
|
cumsum += pwr[k];
|
||||||
|
if (cumsum >= half) { mdf_k = k; break; }
|
||||||
|
}
|
||||||
|
float mdf = mdf_k * freq_step;
|
||||||
|
|
||||||
|
int pkf_k = 0;
|
||||||
|
for (int k = 1; k < FFT_N / 2; k++) {
|
||||||
|
if (pwr[k] > pwr[pkf_k]) pkf_k = k;
|
||||||
|
}
|
||||||
|
float pkf = pkf_k * freq_step;
|
||||||
|
float mnp = total_power / (FFT_N / 2);
|
||||||
|
|
||||||
|
float bp0 = 0, bp1 = 0, bp2 = 0, bp3 = 0;
|
||||||
|
for (int k = 0; k < FFT_N / 2; k++) {
|
||||||
|
float f = k * freq_step;
|
||||||
|
if (f >= 20.f && f < 80.f ) bp0 += pwr[k];
|
||||||
|
else if (f >= 80.f && f < 150.f) bp1 += pwr[k];
|
||||||
|
else if (f >= 150.f && f < 250.f) bp2 += pwr[k];
|
||||||
|
else if (f >= 250.f && f < 450.f) bp3 += pwr[k];
|
||||||
|
}
|
||||||
|
bp0 /= total_power; bp1 /= total_power;
|
||||||
|
bp2 /= total_power; bp3 /= total_power;
|
||||||
|
|
||||||
|
/* Store 20 features for this channel */
|
||||||
|
int base = ch * 20;
|
||||||
|
features_out[base + 0] = rms;
|
||||||
|
features_out[base + 1] = wl;
|
||||||
|
features_out[base + 2] = (float)zc;
|
||||||
|
features_out[base + 3] = (float)ssc;
|
||||||
|
features_out[base + 4] = mav;
|
||||||
|
features_out[base + 5] = var_val;
|
||||||
|
features_out[base + 6] = iemg;
|
||||||
|
features_out[base + 7] = (float)wamp;
|
||||||
|
features_out[base + 8] = ar[0];
|
||||||
|
features_out[base + 9] = ar[1];
|
||||||
|
features_out[base + 10] = ar[2];
|
||||||
|
features_out[base + 11] = ar[3];
|
||||||
|
features_out[base + 12] = mnf;
|
||||||
|
features_out[base + 13] = mdf;
|
||||||
|
features_out[base + 14] = pkf;
|
||||||
|
features_out[base + 15] = mnp;
|
||||||
|
features_out[base + 16] = bp0;
|
||||||
|
features_out[base + 17] = bp1;
|
||||||
|
features_out[base + 18] = bp2;
|
||||||
|
features_out[base + 19] = bp3;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ──────────────────────────────────────────────────────────────────────
|
||||||
|
* Amplitude normalization (matches Python normalize=True behaviour)
|
||||||
|
* ────────────────────────────────────────────────────────────────────── */
|
||||||
|
float norm_factor = sqrtf(norm_sq);
|
||||||
|
if (norm_factor < 1e-6f) norm_factor = 1e-6f;
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
int base = ch * 20;
|
||||||
|
features_out[base + 0] /= norm_factor; /* rms */
|
||||||
|
features_out[base + 1] /= norm_factor; /* wl */
|
||||||
|
features_out[base + 4] /= norm_factor; /* mav */
|
||||||
|
features_out[base + 6] /= norm_factor; /* iemg */
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ──────────────────────────────────────────────────────────────────────
|
||||||
|
* Cross-channel features (3 pairs × 3 features = 9)
|
||||||
|
* Order: (0,1), (0,2), (1,2) → corr, lrms, cov
|
||||||
|
* ────────────────────────────────────────────────────────────────────── */
|
||||||
|
int xc_base = HAND_NUM_CHANNELS * 20; /* = 60 */
|
||||||
|
int xc_idx = 0;
|
||||||
|
for (int i = 0; i < HAND_NUM_CHANNELS; i++) {
|
||||||
|
for (int j = i + 1; j < HAND_NUM_CHANNELS; j++) {
|
||||||
|
float ri = ch_rms[i] + 1e-10f;
|
||||||
|
float rj = ch_rms[j] + 1e-10f;
|
||||||
|
|
||||||
|
float dot_ij = 0.0f;
|
||||||
|
for (int k = 0; k < INFERENCE_WINDOW_SIZE; k++)
|
||||||
|
dot_ij += ch_signals[i][k] * ch_signals[j][k];
|
||||||
|
float n_inv = 1.0f / INFERENCE_WINDOW_SIZE;
|
||||||
|
|
||||||
|
float corr = dot_ij * n_inv / (ri * rj);
|
||||||
|
if (corr > 1.0f) corr = 1.0f;
|
||||||
|
if (corr < -1.0f) corr = -1.0f;
|
||||||
|
|
||||||
|
float lrms = logf(ri / rj);
|
||||||
|
float cov = dot_ij * n_inv / (norm_factor * norm_factor);
|
||||||
|
|
||||||
|
features_out[xc_base + xc_idx * 3 + 0] = corr;
|
||||||
|
features_out[xc_base + xc_idx * 3 + 1] = lrms;
|
||||||
|
features_out[xc_base + xc_idx * 3 + 2] = cov;
|
||||||
|
xc_idx++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* MODEL_EXPAND_FEATURES */
|
||||||
|
|
||||||
|
static void compute_features(float *features_out) {
|
||||||
|
// Process forearm channels only (ch0-ch2) for hand gesture classification.
|
||||||
|
// The bicep channel (ch3) is excluded — it will be processed independently.
|
||||||
|
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
float sq_sum = 0;
|
float sq_sum = 0;
|
||||||
|
|
||||||
@@ -79,7 +382,7 @@ static void compute_features(float *features_out) {
|
|||||||
|
|
||||||
int idx = buffer_head; // Oldest sample
|
int idx = buffer_head; // Oldest sample
|
||||||
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
signal[i] = (float)window_buffer[idx][ch];
|
signal[i] = window_buffer[idx][ch];
|
||||||
sum += signal[i];
|
sum += signal[i];
|
||||||
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
}
|
}
|
||||||
@@ -96,6 +399,15 @@ static void compute_features(float *features_out) {
|
|||||||
signal[i] -= mean;
|
signal[i] -= mean;
|
||||||
sq_sum += signal[i] * signal[i];
|
sq_sum += signal[i] * signal[i];
|
||||||
}
|
}
|
||||||
|
/* Change 4 — Reinhard tone-mapping: 64·x / (32 + |x|) */
|
||||||
|
#if MODEL_USE_REINHARD
|
||||||
|
sq_sum = 0.0f;
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
float x = signal[i];
|
||||||
|
signal[i] = 64.0f * x / (32.0f + fabsf(x));
|
||||||
|
sq_sum += signal[i] * signal[i];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
|
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
|
||||||
|
|
||||||
@@ -133,6 +445,60 @@ static void compute_features(float *features_out) {
|
|||||||
features_out[base + 2] = (float)zc;
|
features_out[base + 2] = (float)zc;
|
||||||
features_out[base + 3] = (float)ssc;
|
features_out[base + 3] = (float)ssc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if MODEL_NORMALIZE_FEATURES
|
||||||
|
// Normalize amplitude-dependent features (RMS, WL) by total RMS across
|
||||||
|
// channels. This makes the model robust to impedance shifts between sessions
|
||||||
|
// while preserving relative channel activation patterns.
|
||||||
|
float total_rms_sq = 0;
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
float ch_rms = features_out[ch * 4 + 0];
|
||||||
|
total_rms_sq += ch_rms * ch_rms;
|
||||||
|
}
|
||||||
|
float norm_factor = sqrtf(total_rms_sq);
|
||||||
|
if (norm_factor < 1e-6f) norm_factor = 1e-6f;
|
||||||
|
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
features_out[ch * 4 + 0] /= norm_factor; // RMS
|
||||||
|
features_out[ch * 4 + 1] /= norm_factor; // WL
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Feature extraction (public wrapper used by inference_ensemble.c) ---
|
||||||
|
|
||||||
|
void inference_extract_features(float *features_out) {
|
||||||
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
compute_features_expanded(features_out);
|
||||||
|
#else
|
||||||
|
compute_features(features_out);
|
||||||
|
#endif
|
||||||
|
calibration_apply(features_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Raw LDA probability (for multi-model voting) ---
|
||||||
|
|
||||||
|
void inference_predict_raw(const float *features, float *proba_out) {
|
||||||
|
float raw_scores[MODEL_NUM_CLASSES];
|
||||||
|
float max_score = -1e9f;
|
||||||
|
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
float score = LDA_INTERCEPTS[c];
|
||||||
|
for (int f = 0; f < MODEL_NUM_FEATURES; f++) {
|
||||||
|
score += features[f] * LDA_WEIGHTS[c][f];
|
||||||
|
}
|
||||||
|
raw_scores[c] = score;
|
||||||
|
if (score > max_score) max_score = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
float sum_exp = 0.0f;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
proba_out[c] = expf(raw_scores[c] - max_score);
|
||||||
|
sum_exp += proba_out[c];
|
||||||
|
}
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
proba_out[c] /= sum_exp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Prediction ---
|
// --- Prediction ---
|
||||||
@@ -144,7 +510,14 @@ int inference_predict(float *confidence) {
|
|||||||
|
|
||||||
// 1. Extract Features
|
// 1. Extract Features
|
||||||
float features[MODEL_NUM_FEATURES];
|
float features[MODEL_NUM_FEATURES];
|
||||||
|
#if MODEL_EXPAND_FEATURES
|
||||||
|
compute_features_expanded(features);
|
||||||
|
#else
|
||||||
compute_features(features);
|
compute_features(features);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// 1b. Change D: z-score normalise using NVS-stored session calibration
|
||||||
|
calibration_apply(features);
|
||||||
|
|
||||||
// 2. LDA Inference (Linear Score)
|
// 2. LDA Inference (Linear Score)
|
||||||
float raw_scores[MODEL_NUM_CLASSES];
|
float raw_scores[MODEL_NUM_CLASSES];
|
||||||
@@ -195,6 +568,15 @@ int inference_predict(float *confidence) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Change C: confidence rejection.
|
||||||
|
// If the strongest smoothed probability is still too low, the classifier is
|
||||||
|
// uncertain — return the last confirmed output without updating vote/debounce state.
|
||||||
|
// This prevents low-confidence transitions from actuating the arm spuriously.
|
||||||
|
if (max_smoothed_prob < CONFIDENCE_THRESHOLD) {
|
||||||
|
*confidence = max_smoothed_prob;
|
||||||
|
return current_output; // -1 (GESTURE_NONE) until first confident prediction
|
||||||
|
}
|
||||||
|
|
||||||
// 3b. Majority Vote
|
// 3b. Majority Vote
|
||||||
vote_history[vote_head] = smoothed_winner;
|
vote_history[vote_head] = smoothed_winner;
|
||||||
vote_head = (vote_head + 1) % VOTE_WINDOW;
|
vote_head = (vote_head + 1) % VOTE_WINDOW;
|
||||||
@@ -255,17 +637,12 @@ const char *inference_get_class_name(int class_idx) {
|
|||||||
|
|
||||||
int inference_get_gesture_enum(int class_idx) {
|
int inference_get_gesture_enum(int class_idx) {
|
||||||
const char *name = inference_get_class_name(class_idx);
|
const char *name = inference_get_class_name(class_idx);
|
||||||
|
return inference_get_gesture_by_name(name);
|
||||||
|
}
|
||||||
|
|
||||||
// Map string name to gesture_t enum
|
int inference_get_gesture_by_name(const char *name) {
|
||||||
// Strings must match those in Python list: ["fist", "hook_em", "open",
|
// Accepts both lowercase (Python output) and uppercase (C enum name style).
|
||||||
// "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums
|
// Add a new case here whenever a gesture is added to gesture_t in config.h.
|
||||||
// are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5
|
|
||||||
|
|
||||||
// Case-insensitive check would be safer, but let's assume Python output is
|
|
||||||
// lowercase as seen in scripts or uppercase if specified. In
|
|
||||||
// learning_data_collection.py, they seem to be "rest", "open", "fist", etc.
|
|
||||||
|
|
||||||
// Simple string matching
|
|
||||||
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0)
|
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0)
|
||||||
return GESTURE_REST;
|
return GESTURE_REST;
|
||||||
if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0)
|
if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0)
|
||||||
@@ -276,6 +653,21 @@ int inference_get_gesture_enum(int class_idx) {
|
|||||||
return GESTURE_HOOK_EM;
|
return GESTURE_HOOK_EM;
|
||||||
if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0)
|
if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0)
|
||||||
return GESTURE_THUMBS_UP;
|
return GESTURE_THUMBS_UP;
|
||||||
|
|
||||||
return GESTURE_NONE;
|
return GESTURE_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
float inference_get_bicep_rms(int n_samples) {
|
||||||
|
if (samples_collected < INFERENCE_WINDOW_SIZE) return 0.0f;
|
||||||
|
if (n_samples > INFERENCE_WINDOW_SIZE) n_samples = INFERENCE_WINDOW_SIZE;
|
||||||
|
|
||||||
|
float sum_sq = 0.0f;
|
||||||
|
// Walk backwards from buffer_head (oldest = buffer_head, newest = buffer_head - 1)
|
||||||
|
int start = (buffer_head - n_samples + INFERENCE_WINDOW_SIZE) % INFERENCE_WINDOW_SIZE;
|
||||||
|
int idx = start;
|
||||||
|
for (int i = 0; i < n_samples; i++) {
|
||||||
|
float v = window_buffer[idx][3]; // channel 3 = bicep
|
||||||
|
sum_sq += v * v;
|
||||||
|
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
|
}
|
||||||
|
return sqrtf(sum_sq / n_samples);
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,9 +9,11 @@
|
|||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
// --- Configuration ---
|
// --- Configuration (must match Python WINDOW_SIZE_MS / HOP_SIZE_MS) ---
|
||||||
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python)
|
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (150ms at 1kHz)
|
||||||
#define NUM_CHANNELS 4 // Number of EMG channels
|
#define INFERENCE_HOP_SIZE 25 // Hop/stride in samples (25ms at 1kHz)
|
||||||
|
#define NUM_CHANNELS 4 // Total EMG channels (buffer stores all)
|
||||||
|
#define HAND_NUM_CHANNELS 3 // Forearm channels for hand classifier (ch0-ch2)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Initialize the inference engine.
|
* @brief Initialize the inference engine.
|
||||||
@@ -44,4 +46,53 @@ const char *inference_get_class_name(int class_idx);
|
|||||||
*/
|
*/
|
||||||
int inference_get_gesture_enum(int class_idx);
|
int inference_get_gesture_enum(int class_idx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Map a gesture name string directly to gesture_t enum value.
|
||||||
|
*
|
||||||
|
* Used by the laptop-predict path to convert the name sent by live_predict.py
|
||||||
|
* into a gesture_t without needing a class index.
|
||||||
|
*
|
||||||
|
* @param name Lowercase gesture name, e.g. "fist", "rest", "open"
|
||||||
|
* @return gesture_t value, or GESTURE_NONE if unrecognised
|
||||||
|
*/
|
||||||
|
int inference_get_gesture_by_name(const char *name);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute LDA softmax probabilities without smoothing/voting/debounce.
|
||||||
|
*
|
||||||
|
* Used by the multi-model voting path in main.c. The caller is responsible
|
||||||
|
* for post-processing (EMA, majority vote, debounce).
|
||||||
|
*
|
||||||
|
* @param features Calibrated feature vector (MODEL_NUM_FEATURES floats).
|
||||||
|
* @param proba_out Output probability array (MODEL_NUM_CLASSES floats).
|
||||||
|
*/
|
||||||
|
void inference_predict_raw(const float *features, float *proba_out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Extract and calibrate features from the current window.
|
||||||
|
*
|
||||||
|
* Dispatches to compute_features() or compute_features_expanded() depending
|
||||||
|
* on MODEL_EXPAND_FEATURES, then applies calibration_apply(). The resulting
|
||||||
|
* float array is identical to what inference_predict() uses internally.
|
||||||
|
*
|
||||||
|
* Called by inference_ensemble.c so that the ensemble path does not duplicate
|
||||||
|
* the feature-extraction logic.
|
||||||
|
*
|
||||||
|
* @param features_out Caller-allocated array of MODEL_NUM_FEATURES floats.
|
||||||
|
*/
|
||||||
|
void inference_extract_features(float *features_out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute RMS of the last n_samples from channel 3 (bicep) in the
|
||||||
|
* inference circular buffer.
|
||||||
|
*
|
||||||
|
* Used by the bicep subsystem to obtain current activation level without
|
||||||
|
* exposing the internal window_buffer.
|
||||||
|
*
|
||||||
|
* @param n_samples Number of samples to include (clamped to INFERENCE_WINDOW_SIZE).
|
||||||
|
* @return RMS value in the same units as the filtered buffer (mV after Change B).
|
||||||
|
* Returns 0 if the buffer is not yet filled.
|
||||||
|
*/
|
||||||
|
float inference_get_bicep_rms(int n_samples);
|
||||||
|
|
||||||
#endif /* INFERENCE_H */
|
#endif /* INFERENCE_H */
|
||||||
|
|||||||
262
EMG_Arm/src/core/inference_ensemble.c
Normal file
262
EMG_Arm/src/core/inference_ensemble.c
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
/**
|
||||||
|
* @file inference_ensemble.c
|
||||||
|
* @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F).
|
||||||
|
*
|
||||||
|
* Guarded by MODEL_USE_ENSEMBLE in model_weights.h.
|
||||||
|
* When 0, provides empty stubs so the file compiles unconditionally.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "inference_ensemble.h"
|
||||||
|
#include "inference.h"
|
||||||
|
#include "model_weights.h"
|
||||||
|
|
||||||
|
#if MODEL_USE_ENSEMBLE
|
||||||
|
|
||||||
|
#include "inference_mlp.h"
|
||||||
|
#include "model_weights_ensemble.h"
|
||||||
|
#include "calibration.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#define ENSEMBLE_EMA_ALPHA 0.70f
|
||||||
|
#define ENSEMBLE_CONF_THRESHOLD 0.50f /**< below this: escalate to MLP fallback */
|
||||||
|
#define REJECT_THRESHOLD 0.40f /**< below this even after MLP: hold output */
|
||||||
|
#define REST_ACTIVITY_THRESHOLD 0.05f /**< total RMS gate — skip inference during rest */
|
||||||
|
|
||||||
|
/* Class index for "rest" — looked up once in init so we don't return
|
||||||
|
* the gesture_t enum value (GESTURE_REST = 1) which would be misinterpreted
|
||||||
|
* as class index 1 ("hook_em"). See Bug 2 in the code review. */
|
||||||
|
static int s_rest_class_idx = 0;
|
||||||
|
|
||||||
|
/* EMA probability state */
|
||||||
|
static float s_smoothed[MODEL_NUM_CLASSES];
|
||||||
|
|
||||||
|
/* Majority vote ring buffer (window = 5) */
|
||||||
|
static int s_vote_history[5];
|
||||||
|
static int s_vote_head = 0;
|
||||||
|
|
||||||
|
/* Debounce state */
|
||||||
|
static int s_current_output = -1;
|
||||||
|
static int s_pending_output = -1;
|
||||||
|
static int s_pending_count = 0;
|
||||||
|
|
||||||
|
/* ── Generic LDA softmax ──────────────────────────────────────────────────── */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute softmax class probabilities from a flat feature vector.
|
||||||
|
*
|
||||||
|
* @param feat Feature vector (contiguous, length n_feat).
|
||||||
|
* @param n_feat Number of features.
|
||||||
|
* @param weights_flat Row-major weight matrix, shape [n_classes][n_feat].
|
||||||
|
* @param intercepts Intercept vector, length n_classes.
|
||||||
|
* @param n_classes Number of output classes.
|
||||||
|
* @param proba_out Output probabilities, length n_classes (caller-allocated).
|
||||||
|
*/
|
||||||
|
static void lda_softmax(const float *feat, int n_feat,
|
||||||
|
const float *weights_flat, const float *intercepts,
|
||||||
|
int n_classes, float *proba_out) {
|
||||||
|
float raw[MODEL_NUM_CLASSES];
|
||||||
|
float max_raw = -1e9f;
|
||||||
|
float sum_exp = 0.0f;
|
||||||
|
|
||||||
|
for (int c = 0; c < n_classes; c++) {
|
||||||
|
raw[c] = intercepts[c];
|
||||||
|
const float *w = weights_flat + c * n_feat;
|
||||||
|
for (int f = 0; f < n_feat; f++) {
|
||||||
|
raw[c] += feat[f] * w[f];
|
||||||
|
}
|
||||||
|
if (raw[c] > max_raw) max_raw = raw[c];
|
||||||
|
}
|
||||||
|
for (int c = 0; c < n_classes; c++) {
|
||||||
|
proba_out[c] = expf(raw[c] - max_raw);
|
||||||
|
sum_exp += proba_out[c];
|
||||||
|
}
|
||||||
|
for (int c = 0; c < n_classes; c++) {
|
||||||
|
proba_out[c] /= sum_exp;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ── Public API ───────────────────────────────────────────────────────────── */
|
||||||
|
|
||||||
|
void inference_ensemble_init(void) {
|
||||||
|
/* Find the class index for "rest" once, so we can return it correctly
|
||||||
|
* from the activity gate without confusing class indices with gesture_t. */
|
||||||
|
s_rest_class_idx = 0;
|
||||||
|
for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
|
||||||
|
if (strcmp(MODEL_CLASS_NAMES[i], "rest") == 0) {
|
||||||
|
s_rest_class_idx = i;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
s_smoothed[c] = 1.0f / MODEL_NUM_CLASSES;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
s_vote_history[i] = -1;
|
||||||
|
}
|
||||||
|
s_vote_head = 0;
|
||||||
|
s_current_output = -1;
|
||||||
|
s_pending_output = -1;
|
||||||
|
s_pending_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
void inference_ensemble_predict_raw(const float *features, float *proba_out) {
|
||||||
|
/* Gather TD features (non-contiguous: 12 per channel × 3 channels) */
|
||||||
|
float td_buf[TD_NUM_FEATURES];
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
memcpy(td_buf + ch * 12,
|
||||||
|
features + ch * ENSEMBLE_PER_CH_FEATURES,
|
||||||
|
12 * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Gather FD features (non-contiguous: 8 per channel × 3 channels) */
|
||||||
|
float fd_buf[FD_NUM_FEATURES];
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
memcpy(fd_buf + ch * 8,
|
||||||
|
features + ch * ENSEMBLE_PER_CH_FEATURES + 12,
|
||||||
|
8 * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
/* CC features are already contiguous at the end (indices 60-68) */
|
||||||
|
const float *cc_buf = features + CC_FEAT_OFFSET;
|
||||||
|
|
||||||
|
/* Specialist LDA predictions */
|
||||||
|
float prob_td[MODEL_NUM_CLASSES];
|
||||||
|
float prob_fd[MODEL_NUM_CLASSES];
|
||||||
|
float prob_cc[MODEL_NUM_CLASSES];
|
||||||
|
|
||||||
|
lda_softmax(td_buf, TD_NUM_FEATURES,
|
||||||
|
(const float *)LDA_TD_WEIGHTS, LDA_TD_INTERCEPTS,
|
||||||
|
MODEL_NUM_CLASSES, prob_td);
|
||||||
|
lda_softmax(fd_buf, FD_NUM_FEATURES,
|
||||||
|
(const float *)LDA_FD_WEIGHTS, LDA_FD_INTERCEPTS,
|
||||||
|
MODEL_NUM_CLASSES, prob_fd);
|
||||||
|
lda_softmax(cc_buf, CC_NUM_FEATURES,
|
||||||
|
(const float *)LDA_CC_WEIGHTS, LDA_CC_INTERCEPTS,
|
||||||
|
MODEL_NUM_CLASSES, prob_cc);
|
||||||
|
|
||||||
|
/* Meta-LDA stacker */
|
||||||
|
float meta_in[META_NUM_INPUTS];
|
||||||
|
memcpy(meta_in, prob_td, MODEL_NUM_CLASSES * sizeof(float));
|
||||||
|
memcpy(meta_in + MODEL_NUM_CLASSES, prob_fd, MODEL_NUM_CLASSES * sizeof(float));
|
||||||
|
memcpy(meta_in + 2*MODEL_NUM_CLASSES, prob_cc, MODEL_NUM_CLASSES * sizeof(float));
|
||||||
|
|
||||||
|
lda_softmax(meta_in, META_NUM_INPUTS,
|
||||||
|
(const float *)META_LDA_WEIGHTS, META_LDA_INTERCEPTS,
|
||||||
|
MODEL_NUM_CLASSES, proba_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
int inference_ensemble_predict(float *confidence) {
|
||||||
|
/* 1. Extract + calibrate features (shared with single-model path) */
|
||||||
|
float features[MODEL_NUM_FEATURES];
|
||||||
|
inference_extract_features(features); /* includes calibration_apply() */
|
||||||
|
|
||||||
|
/* 2. Activity gate — skip inference during obvious REST */
|
||||||
|
float total_rms_sq = 0.0f;
|
||||||
|
for (int ch = 0; ch < HAND_NUM_CHANNELS; ch++) {
|
||||||
|
/* RMS is the first feature per channel (index 0 in each 20-feature block) */
|
||||||
|
float r = features[ch * ENSEMBLE_PER_CH_FEATURES];
|
||||||
|
total_rms_sq += r * r;
|
||||||
|
}
|
||||||
|
if (sqrtf(total_rms_sq) < REST_ACTIVITY_THRESHOLD) {
|
||||||
|
*confidence = 1.0f;
|
||||||
|
return s_rest_class_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 3. Run ensemble pipeline (raw probabilities) */
|
||||||
|
float meta_probs[MODEL_NUM_CLASSES];
|
||||||
|
inference_ensemble_predict_raw(features, meta_probs);
|
||||||
|
|
||||||
|
/* 7. EMA smoothing on meta output */
|
||||||
|
float max_smooth = 0.0f;
|
||||||
|
int winner = 0;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
s_smoothed[c] = ENSEMBLE_EMA_ALPHA * s_smoothed[c] +
|
||||||
|
(1.0f - ENSEMBLE_EMA_ALPHA) * meta_probs[c];
|
||||||
|
if (s_smoothed[c] > max_smooth) {
|
||||||
|
max_smooth = s_smoothed[c];
|
||||||
|
winner = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 8. Confidence cascade: escalate to MLP if meta-LDA is uncertain */
|
||||||
|
if (max_smooth < ENSEMBLE_CONF_THRESHOLD) {
|
||||||
|
float mlp_conf = 0.0f;
|
||||||
|
int mlp_winner = inference_mlp_predict(features, MODEL_NUM_FEATURES, &mlp_conf);
|
||||||
|
if (mlp_conf > max_smooth) {
|
||||||
|
winner = mlp_winner;
|
||||||
|
max_smooth = mlp_conf;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 9. Reject if still uncertain — hold current output */
|
||||||
|
if (max_smooth < REJECT_THRESHOLD) {
|
||||||
|
*confidence = max_smooth;
|
||||||
|
return s_current_output >= 0 ? s_current_output : s_rest_class_idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
*confidence = max_smooth;
|
||||||
|
|
||||||
|
/* 10. Majority vote (window = 5) */
|
||||||
|
s_vote_history[s_vote_head] = winner;
|
||||||
|
s_vote_head = (s_vote_head + 1) % 5;
|
||||||
|
|
||||||
|
int counts[MODEL_NUM_CLASSES];
|
||||||
|
memset(counts, 0, sizeof(counts));
|
||||||
|
for (int i = 0; i < 5; i++) {
|
||||||
|
if (s_vote_history[i] >= 0) {
|
||||||
|
counts[s_vote_history[i]]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int majority = 0, majority_cnt = 0;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
if (counts[c] > majority_cnt) {
|
||||||
|
majority_cnt = counts[c];
|
||||||
|
majority = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 11. Debounce — need 3 consecutive predictions to change output */
|
||||||
|
int final_out = (s_current_output >= 0) ? s_current_output : majority;
|
||||||
|
|
||||||
|
if (s_current_output < 0) {
|
||||||
|
/* First prediction ever */
|
||||||
|
s_current_output = majority;
|
||||||
|
final_out = majority;
|
||||||
|
} else if (majority == s_current_output) {
|
||||||
|
/* Staying in current gesture — reset pending */
|
||||||
|
s_pending_output = majority;
|
||||||
|
s_pending_count = 1;
|
||||||
|
} else if (majority == s_pending_output) {
|
||||||
|
/* Same new gesture again */
|
||||||
|
if (++s_pending_count >= 3) {
|
||||||
|
s_current_output = majority;
|
||||||
|
final_out = majority;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
/* New gesture candidate — start counting */
|
||||||
|
s_pending_output = majority;
|
||||||
|
s_pending_count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return final_out;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else /* MODEL_USE_ENSEMBLE == 0 — compile-time stubs */
|
||||||
|
|
||||||
|
void inference_ensemble_init(void) {}
|
||||||
|
|
||||||
|
void inference_ensemble_predict_raw(const float *features, float *proba_out) {
|
||||||
|
(void)features;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++)
|
||||||
|
proba_out[c] = 1.0f / MODEL_NUM_CLASSES;
|
||||||
|
}
|
||||||
|
|
||||||
|
int inference_ensemble_predict(float *confidence) {
|
||||||
|
if (confidence) *confidence = 0.0f;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* MODEL_USE_ENSEMBLE */
|
||||||
44
EMG_Arm/src/core/inference_ensemble.h
Normal file
44
EMG_Arm/src/core/inference_ensemble.h
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
/**
|
||||||
|
* @file inference_ensemble.h
|
||||||
|
* @brief 3-specialist-LDA + meta-LDA ensemble inference pipeline (Change F).
|
||||||
|
*
|
||||||
|
* Requires:
|
||||||
|
* - Change 1 expanded features (MODEL_EXPAND_FEATURES 1)
|
||||||
|
* - Change 7 training (train_ensemble.py) to generate model_weights_ensemble.h
|
||||||
|
* - Change E MLP (MODEL_USE_MLP 1) for confidence-cascade fallback
|
||||||
|
*
|
||||||
|
* Enable by setting MODEL_USE_ENSEMBLE 1 in model_weights.h and calling
|
||||||
|
* inference_ensemble_init() from app_main() instead of (or alongside) inference_init().
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialise ensemble state (EMA, vote history, debounce).
|
||||||
|
* Must be called before inference_ensemble_predict().
|
||||||
|
*/
|
||||||
|
void inference_ensemble_init(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Compute ensemble probabilities without smoothing/voting/debounce.
|
||||||
|
*
|
||||||
|
* Runs the 3 specialist LDAs + meta-LDA stacker and writes the raw meta-LDA
|
||||||
|
* probabilities to proba_out. Used by the multi-model voting path in main.c.
|
||||||
|
*
|
||||||
|
* @param features Calibrated feature vector (MODEL_NUM_FEATURES floats).
|
||||||
|
* @param proba_out Output probability array (MODEL_NUM_CLASSES floats).
|
||||||
|
*/
|
||||||
|
void inference_ensemble_predict_raw(const float *features, float *proba_out);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run one inference hop through the full ensemble pipeline.
|
||||||
|
*
|
||||||
|
* Internally calls inference_extract_features() to pull the latest window,
|
||||||
|
* routes through the three specialist LDAs, the meta-LDA stacker, EMA
|
||||||
|
* smoothing, majority vote, and debounce.
|
||||||
|
*
|
||||||
|
* @param confidence Output: winning class smoothed probability [0,1].
|
||||||
|
* @return Gesture enum value (same as inference_predict() return).
|
||||||
|
*/
|
||||||
|
int inference_ensemble_predict(float *confidence);
|
||||||
75
EMG_Arm/src/core/inference_mlp.cc
Normal file
75
EMG_Arm/src/core/inference_mlp.cc
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
/**
|
||||||
|
* @file inference_mlp.cc
|
||||||
|
* @brief int8 MLP inference via TFLite Micro (Change E).
|
||||||
|
*
|
||||||
|
* Compiled as C++ because TFLite Micro headers require C++.
|
||||||
|
* Guarded by MODEL_USE_MLP — provides compile-time stubs when 0.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "inference_mlp.h"
|
||||||
|
#include "model_weights.h"
|
||||||
|
|
||||||
|
#if MODEL_USE_MLP
|
||||||
|
|
||||||
|
#include "emg_model_data.h"
|
||||||
|
#include "tensorflow/lite/micro/micro_interpreter.h"
|
||||||
|
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
|
||||||
|
#include "tensorflow/lite/schema/schema_generated.h"
|
||||||
|
|
||||||
|
static uint8_t tensor_arena[48 * 1024]; /* 48 KB — tune down if memory is tight */
|
||||||
|
static tflite::MicroInterpreter *s_interpreter = nullptr;
|
||||||
|
static TfLiteTensor *s_input = nullptr;
|
||||||
|
static TfLiteTensor *s_output = nullptr;
|
||||||
|
|
||||||
|
void inference_mlp_init(void) {
|
||||||
|
const tflite::Model *model = tflite::GetModel(g_model);
|
||||||
|
static tflite::MicroMutableOpResolver<4> resolver;
|
||||||
|
resolver.AddFullyConnected();
|
||||||
|
resolver.AddRelu();
|
||||||
|
resolver.AddSoftmax();
|
||||||
|
resolver.AddDequantize();
|
||||||
|
static tflite::MicroInterpreter interp(model, resolver,
|
||||||
|
tensor_arena, sizeof(tensor_arena));
|
||||||
|
s_interpreter = &interp;
|
||||||
|
s_interpreter->AllocateTensors();
|
||||||
|
s_input = s_interpreter->input(0);
|
||||||
|
s_output = s_interpreter->output(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
int inference_mlp_predict(const float *features, int n_feat, float *conf_out) {
|
||||||
|
/* Quantise input: int8 = round(float / scale) + zero_point */
|
||||||
|
const float iscale = s_input->params.scale;
|
||||||
|
const int izp = s_input->params.zero_point;
|
||||||
|
for (int i = 0; i < n_feat; i++) {
|
||||||
|
int q = static_cast<int>(roundf(features[i] / iscale)) + izp;
|
||||||
|
if (q < -128) q = -128;
|
||||||
|
if (q > 127) q = 127;
|
||||||
|
s_input->data.int8[i] = static_cast<int8_t>(q);
|
||||||
|
}
|
||||||
|
|
||||||
|
s_interpreter->Invoke();
|
||||||
|
|
||||||
|
/* Dequantise output and find winning class */
|
||||||
|
const float oscale = s_output->params.scale;
|
||||||
|
const int ozp = s_output->params.zero_point;
|
||||||
|
float max_p = -1e9f;
|
||||||
|
int max_c = 0;
|
||||||
|
const int n_cls = s_output->dims->data[1];
|
||||||
|
for (int c = 0; c < n_cls; c++) {
|
||||||
|
float p = (s_output->data.int8[c] - ozp) * oscale;
|
||||||
|
if (p > max_p) { max_p = p; max_c = c; }
|
||||||
|
}
|
||||||
|
*conf_out = max_p;
|
||||||
|
return max_c;
|
||||||
|
}
|
||||||
|
|
||||||
|
#else /* MODEL_USE_MLP == 0 — compile-time stubs */
|
||||||
|
|
||||||
|
void inference_mlp_init(void) {}
|
||||||
|
|
||||||
|
int inference_mlp_predict(const float * /*features*/, int /*n_feat*/, float *conf_out) {
|
||||||
|
if (conf_out) *conf_out = 0.0f;
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif /* MODEL_USE_MLP */
|
||||||
40
EMG_Arm/src/core/inference_mlp.h
Normal file
40
EMG_Arm/src/core/inference_mlp.h
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* @file inference_mlp.h
|
||||||
|
* @brief int8 MLP inference via TFLite Micro (Change E).
|
||||||
|
*
|
||||||
|
* Enable by:
|
||||||
|
* 1. Setting MODEL_USE_MLP 1 in model_weights.h.
|
||||||
|
* 2. Running train_mlp_tflite.py to generate emg_model_data.cc.
|
||||||
|
* 3. Adding TFLite Micro to platformio.ini lib_deps.
|
||||||
|
*
|
||||||
|
* When MODEL_USE_MLP 0, both functions are empty stubs that compile without
|
||||||
|
* TFLite Micro headers.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
#include <stdbool.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
extern "C" {
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialise TFLite Micro interpreter and allocate tensors.
|
||||||
|
* Must be called before inference_mlp_predict().
|
||||||
|
* No-op when MODEL_USE_MLP 0.
|
||||||
|
*/
|
||||||
|
void inference_mlp_init(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run one forward pass through the int8 MLP.
|
||||||
|
*
|
||||||
|
* @param features Float32 feature vector (same order as Python extractor).
|
||||||
|
* @param n_feat Number of features (must match model input size).
|
||||||
|
* @param conf_out Output: winning class softmax probability [0,1].
|
||||||
|
* @return Winning class index. Returns 0 when MODEL_USE_MLP 0.
|
||||||
|
*/
|
||||||
|
int inference_mlp_predict(const float *features, int n_feat, float *conf_out);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
}
|
||||||
|
#endif
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
/**
|
/**
|
||||||
* @file model_weights.h
|
* @file model_weights.h
|
||||||
* @brief Trained LDA model weights exported from Python.
|
* @brief Trained LDA model weights exported from Python.
|
||||||
* @date 2026-01-27 21:35:17
|
* @date 2026-03-08 20:27:52
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef MODEL_WEIGHTS_H
|
#ifndef MODEL_WEIGHTS_H
|
||||||
@@ -11,7 +11,14 @@
|
|||||||
|
|
||||||
/* Metadata */
|
/* Metadata */
|
||||||
#define MODEL_NUM_CLASSES 5
|
#define MODEL_NUM_CLASSES 5
|
||||||
#define MODEL_NUM_FEATURES 16
|
#define MODEL_NUM_FEATURES 69
|
||||||
|
#define MODEL_NORMALIZE_FEATURES 1
|
||||||
|
|
||||||
|
/* Compile-time feature flags */
|
||||||
|
#define MODEL_EXPAND_FEATURES 1
|
||||||
|
#define MODEL_USE_REINHARD 1
|
||||||
|
#define MODEL_USE_MLP 1
|
||||||
|
#define MODEL_USE_ENSEMBLE 1
|
||||||
|
|
||||||
/* Class Names */
|
/* Class Names */
|
||||||
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
|
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
|
||||||
@@ -28,35 +35,70 @@ static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
|
|||||||
|
|
||||||
/* LDA Intercepts/Biases */
|
/* LDA Intercepts/Biases */
|
||||||
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
|
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
|
||||||
-14.097581f, -2.018629f, -4.478267f, 1.460458f, -5.562349f
|
-2.361446f, -2.300285f, -3.189113f, -1.744772f, -2.137618f
|
||||||
};
|
};
|
||||||
|
|
||||||
/* LDA Coefficients (Weights) */
|
/* LDA Coefficients (Weights) */
|
||||||
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
|
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
|
||||||
/* fist */
|
/* fist */
|
||||||
{
|
{
|
||||||
0.070110f, -0.002554f, 0.043924f, 0.020555f, -0.660305f, 0.010691f, -0.074429f, -0.037253f,
|
-1.005012f, 1.102594f, 0.220576f, -0.171785f, 0.160110f, 1.460799f, 0.160097f, 0.023789f,
|
||||||
0.057908f, -0.002655f, 0.042119f, -0.052956f, 0.063822f, 0.006184f, -0.025462f, 0.040815f,
|
0.316695f, 0.551076f, 0.193143f, -0.030601f, -1.068705f, -0.414611f, -0.040128f, 1.215391f,
|
||||||
|
0.945437f, 0.545068f, 0.736472f, 0.328106f, 1.212364f, -0.769889f, 0.428595f, -0.033948f,
|
||||||
|
0.181605f, -3.236872f, 0.181581f, -0.061825f, 1.166229f, 0.338513f, 0.207578f, 0.041526f,
|
||||||
|
0.793306f, 0.256971f, -0.062490f, 2.940555f, 0.251149f, 0.096629f, 0.124341f, -0.125052f,
|
||||||
|
0.362799f, 0.430005f, 0.106206f, -0.201260f, 0.529469f, -1.578447f, 0.529456f, -0.040727f,
|
||||||
|
1.636684f, 0.129321f, 0.011299f, 0.086495f, 1.510614f, -0.578640f, -0.025043f, 0.275174f,
|
||||||
|
0.300145f, 0.366742f, -0.104204f, -0.125212f, 0.819188f, 2.292691f, -0.957544f, -0.321092f,
|
||||||
|
-0.173973f, 0.306428f, -1.015226f, 0.770284f, 0.797008f
|
||||||
},
|
},
|
||||||
/* hook_em */
|
/* hook_em */
|
||||||
{
|
{
|
||||||
-0.002511f, 0.001034f, 0.027889f, 0.026006f, 0.183681f, -0.000773f, 0.016791f, -0.027926f,
|
1.694859f, -0.653387f, 0.521604f, 0.251859f, -0.706711f, 1.182055f, -0.706678f, 0.008488f,
|
||||||
-0.023321f, 0.000770f, 0.059023f, -0.056021f, 0.237063f, -0.007423f, 0.082101f, -0.021472f,
|
2.450049f, 0.068430f, -0.002051f, 0.214216f, 2.641921f, -0.677963f, -0.020216f, -0.084562f,
|
||||||
|
0.314933f, 0.166270f, -0.174888f, 0.023982f, -5.326996f, -0.143884f, -0.165025f, 0.014950f,
|
||||||
|
0.167851f, 3.526903f, 0.167886f, 0.139922f, 0.702011f, 0.197141f, 0.087553f, 0.093573f,
|
||||||
|
1.414652f, 0.237107f, -0.044205f, -3.707572f, 0.146742f, 0.025417f, -0.085462f, -0.227836f,
|
||||||
|
1.745303f, 0.408549f, -0.224446f, 0.081996f, 0.432159f, 0.151780f, 0.432186f, 0.009932f,
|
||||||
|
-1.447145f, -0.136321f, -0.071784f, -0.118749f, -1.872055f, 0.490270f, 0.017250f, -0.667072f,
|
||||||
|
-0.381349f, -0.429662f, -0.084705f, -0.123111f, -0.213680f, -3.070284f, 0.267574f, 0.343828f,
|
||||||
|
1.024997f, -0.347210f, 1.222141f, 4.897431f, -1.097906f
|
||||||
},
|
},
|
||||||
/* open */
|
/* open */
|
||||||
{
|
{
|
||||||
-0.006170f, 0.000208f, -0.041151f, 0.013271f, 0.054508f, -0.002356f, 0.000170f, 0.012941f,
|
2.756912f, 0.029446f, -0.408740f, -0.020121f, 0.375563f, -3.479846f, 0.375601f, -0.022014f,
|
||||||
-0.106180f, 0.003538f, -0.013656f, -0.017712f, 0.131131f, -0.002623f, -0.007022f, 0.024497f,
|
-1.357556f, -0.445301f, -0.119886f, -0.020492f, -0.494526f, 0.247305f, 0.028053f, -0.106328f,
|
||||||
|
-0.585811f, -0.405593f, -0.486449f, -0.319108f, -3.159308f, 0.871359f, 0.611153f, -0.191602f,
|
||||||
|
0.059082f, 4.326131f, 0.059172f, -0.006684f, -0.502170f, -0.110067f, -0.057217f, -0.224138f,
|
||||||
|
-1.214966f, -1.056423f, -0.019260f, -0.878641f, 0.167700f, 0.576919f, -0.143247f, -0.914332f,
|
||||||
|
-2.970328f, -0.403648f, -0.014340f, 0.048198f, 0.332494f, 0.479158f, 0.332573f, -0.002956f,
|
||||||
|
-0.351570f, 0.181890f, 0.109983f, -0.013305f, 0.133854f, -0.180451f, -0.046355f, -0.355403f,
|
||||||
|
-0.123445f, 0.146366f, 0.317895f, 0.342519f, -0.399778f, -6.284204f, 0.422619f, -0.109523f,
|
||||||
|
-0.996114f, 0.059564f, 0.043680f, -3.046475f, -0.178861f
|
||||||
},
|
},
|
||||||
/* rest */
|
/* rest */
|
||||||
{
|
{
|
||||||
-0.011094f, 0.000160f, -0.012547f, -0.011058f, 0.130577f, -0.001942f, 0.020823f, -0.001961f,
|
-1.609762f, -0.516942f, -0.285630f, -0.108276f, 0.196206f, 0.488562f, 0.196196f, 0.026133f,
|
||||||
0.018021f, -0.000404f, -0.065598f, 0.039676f, 0.018679f, -0.001522f, 0.023302f, -0.008474f,
|
-1.189232f, -0.034391f, -0.062033f, -0.130823f, -1.033252f, 0.697769f, 0.032136f, -0.605629f,
|
||||||
|
-0.007999f, -0.097476f, 0.094768f, 0.333748f, 1.317443f, -0.880336f, -0.213594f, -0.073489f,
|
||||||
|
-0.207093f, -4.451737f, -0.207141f, 0.000240f, -1.645825f, -0.245072f, -0.026848f, 0.029140f,
|
||||||
|
-0.364751f, 0.395671f, 0.066683f, 3.611272f, -0.209066f, -0.214834f, 0.027566f, 0.514329f,
|
||||||
|
1.066669f, -0.684801f, -0.070219f, -0.046770f, -0.494221f, 0.722863f, -0.494243f, 0.008539f,
|
||||||
|
-0.124061f, -0.075996f, -0.084995f, -0.012796f, 0.291955f, 0.121964f, 0.042458f, -0.627127f,
|
||||||
|
0.001798f, -0.199790f, -0.153351f, 0.027358f, -0.117621f, 1.819176f, 0.092071f, -0.017248f,
|
||||||
|
0.251629f, 0.021265f, -0.106236f, 0.495258f, 0.069582f
|
||||||
},
|
},
|
||||||
/* thumbs_up */
|
/* thumbs_up */
|
||||||
{
|
{
|
||||||
-0.016738f, 0.000488f, 0.024199f, -0.024643f, -0.044912f, 0.000153f, -0.011080f, 0.043487f,
|
-0.792044f, 0.374311f, 0.170137f, 0.127379f, -0.184879f, 0.153073f, -0.184920f, -0.053070f,
|
||||||
0.051828f, -0.001670f, 0.109633f, 0.004154f, -0.460694f, 0.008616f, -0.104097f, -0.020886f,
|
0.682094f, -0.099986f, 0.036795f, 0.061637f, 0.725941f, -0.348358f, -0.022970f, -0.004588f,
|
||||||
|
-0.632688f, -0.124638f, -0.221206f, -0.580598f, 5.030162f, 1.484454f, -0.536655f, 0.339228f,
|
||||||
|
-0.058585f, 2.787274f, -0.058607f, -0.068406f, 1.425498f, -0.005883f, -0.188273f, 0.049091f,
|
||||||
|
-0.309501f, -0.063155f, 0.013505f, -4.453858f, -0.215617f, -0.354558f, 0.060673f, 0.425839f,
|
||||||
|
-0.794436f, 0.734764f, 0.245384f, 0.149158f, -0.463152f, -0.279041f, -0.463210f, 0.019611f,
|
||||||
|
0.350645f, -0.055642f, 0.087931f, 0.064769f, -0.304954f, 0.079455f, -0.015292f, 1.791572f,
|
||||||
|
0.196916f, 0.237732f, 0.116177f, -0.153250f, 0.000654f, 4.127202f, 0.103829f, 0.125913f,
|
||||||
|
-0.223239f, -0.063374f, -0.047960f, -3.238916f, 0.344427f
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
65
EMG_Arm/src/core/model_weights_ensemble.h
Normal file
65
EMG_Arm/src/core/model_weights_ensemble.h
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
// Auto-generated by train_ensemble.py — do not edit
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from
|
||||||
|
// model_weights.h to avoid redefinition conflicts.
|
||||||
|
#include "model_weights.h"
|
||||||
|
|
||||||
|
#define ENSEMBLE_PER_CH_FEATURES 20
|
||||||
|
|
||||||
|
#define TD_FEAT_OFFSET 0
|
||||||
|
#define TD_NUM_FEATURES 36
|
||||||
|
#define FD_FEAT_OFFSET 12
|
||||||
|
#define FD_NUM_FEATURES 24
|
||||||
|
#define CC_FEAT_OFFSET 60
|
||||||
|
#define CC_NUM_FEATURES 9
|
||||||
|
#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)
|
||||||
|
|
||||||
|
// Feature index arrays for gather operations (TD and FD are non-contiguous)
|
||||||
|
// TD indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51]
|
||||||
|
// FD indices: [12, 13, 14, 15, 16, 17, 18, 19, 32, 33, 34, 35, 36, 37, 38, 39, 52, 53, 54, 55, 56, 57, 58, 59]
|
||||||
|
// CC indices: [60, 61, 62, 63, 64, 65, 66, 67, 68]
|
||||||
|
|
||||||
|
const float LDA_TD_WEIGHTS[5][36] = {
|
||||||
|
{-4.98005951f, 3.15698227f, -0.01656110f, 0.16589549f, 1.89759250f, 4.02171634f, 1.89759264f, -0.02318365f, 0.68313384f, 1.08146976f, 0.24790972f, 0.02111336f, 1.40219525f, -0.34718600f, 0.03292437f, 0.03224236f, 0.07318621f, -1.47505917f, 0.07318595f, -0.15298836f, 0.68723991f, 0.31478366f, 0.26694627f, 0.03383652f, 1.16109002f, -0.17694826f, -0.02001825f, 0.04811623f, 0.11145582f, 0.98011591f, 0.11145584f, 0.02078491f, 0.16441538f, -0.37776186f, -0.14062534f, -0.20666287f}, // fist
|
||||||
|
{7.06735173f, -0.07584252f, 0.28878350f, 0.25548171f, -2.72040928f, -1.14291179f, -2.72040944f, -0.05457462f, 0.67944659f, -1.72471154f, -0.56168071f, 0.08847797f, -0.36636016f, -0.92733589f, 0.07186714f, -0.15407635f, 0.09285758f, -0.99197678f, 0.09285799f, 0.24903098f, -0.58741170f, -1.69916941f, -0.79791300f, -0.51257426f, -4.23844985f, -1.30822564f, 0.08744698f, -0.10944734f, 1.12572192f, 0.67459317f, 1.12572190f, 0.12677750f, -0.06012924f, 0.74169191f, 0.20452265f, 0.02420439f}, // hook_em
|
||||||
|
{-5.68183942f, -2.70544312f, -0.02268383f, -0.07857558f, 0.94259344f, -1.41885234f, 0.94259353f, 0.12889098f, -0.30338581f, 0.37920265f, 0.06556953f, 0.00592915f, -6.03034018f, 1.95694619f, 0.22083802f, 0.33970496f, 0.85733980f, 4.45235717f, 0.85733951f, 0.16720219f, 0.59151687f, 1.12441071f, 0.07520788f, -0.44451340f, -0.73662069f, 1.71355275f, -0.19822542f, 0.32955834f, 0.05661870f, -1.29613804f, 0.05661887f, -0.21196686f, -0.61895423f, -1.71535920f, -0.73493998f, -0.21009359f}, // open
|
||||||
|
{0.20260846f, -0.23909181f, -0.28033711f, -0.22450103f, 0.13810806f, -0.88146743f, 0.13810809f, 0.01613534f, -0.34304575f, 1.18327519f, 0.36712706f, -0.09916758f, 1.33818062f, -0.49450303f, 0.03771161f, -0.21199101f, -0.57758751f, -0.06415963f, -0.57758750f, -0.11776329f, -0.49294313f, 0.47593171f, 0.42732005f, 0.38381135f, 0.93751846f, -0.40085712f, -0.01569340f, -0.21022401f, -0.70330953f, 0.29469121f, -0.70330957f, -0.03953274f, -0.19099800f, 1.11867928f, 0.51910780f, 0.26138187f}, // rest
|
||||||
|
{3.56990173f, 0.11258470f, 0.22746547f, 0.04228335f, -0.43704307f, 0.04449194f, -0.43704318f, -0.08323209f, -0.45652127f, -1.76552235f, -0.38107651f, 0.05260073f, 2.92141069f, 0.06553870f, -0.39303365f, 0.12369768f, -0.07810754f, -2.03448136f, -0.07810740f, -0.06529502f, 0.10545265f, -0.60931490f, -0.28035188f, 0.28069772f, 2.16483999f, 0.36093444f, 0.16473959f, 0.07185201f, -0.08960941f, -0.79422744f, -0.08960951f, 0.13979645f, 0.85101150f, -0.45589633f, -0.17267249f, -0.03933928f}, // thumbs_up
|
||||||
|
};
|
||||||
|
const float LDA_TD_INTERCEPTS[5] = {
|
||||||
|
-5.59817408f, -4.91484804f, -9.33780358f, -3.85428822f, -3.12479496f
|
||||||
|
};
|
||||||
|
|
||||||
|
const float LDA_FD_WEIGHTS[5][24] = {
|
||||||
|
{0.19477058f, 0.33798486f, 0.24311717f, 3.58490717f, 0.37449334f, 0.01733587f, 0.99771453f, 0.35846897f, -0.07609940f, -0.08551280f, -0.04212976f, -1.64981087f, -0.28267724f, -0.45724067f, 0.09742739f, -0.45412877f, -0.67354231f, 0.07763610f, -0.03042757f, 1.23057256f, 0.47423480f, 0.96870723f, 0.09008234f, 0.61253227f}, // fist
|
||||||
|
{-1.33051949f, -1.06668043f, -0.13054796f, 1.49029766f, 0.34559104f, 2.58938944f, 2.10036791f, 0.00627139f, 0.67799858f, -0.05916871f, -0.02308780f, -0.51389981f, 0.49698104f, 0.52080687f, -0.53642075f, -0.44702479f, 0.45831951f, -0.05527688f, -0.12666255f, -2.29960756f, 0.15527345f, -0.60190848f, -0.50386274f, -0.20863809f}, // hook_em
|
||||||
|
{0.86684408f, -0.18732078f, 0.04828207f, -4.18979500f, -0.25709076f, -1.34726223f, -1.93439169f, -0.36975670f, -0.57908022f, 0.40713764f, 0.14203746f, 3.81070011f, 0.15930714f, 0.81375493f, 0.17586089f, 1.21709829f, 0.30729211f, -0.06304829f, 0.23529519f, 2.10099132f, -1.15363931f, -0.25856679f, 0.65430329f, -0.77828374f}, // open
|
||||||
|
{0.10167142f, 0.92873579f, -0.03451580f, -0.35104059f, 0.33898659f, -0.99047210f, -1.06006192f, 0.02907091f, -0.22314490f, 0.33956274f, -0.07111302f, 0.31329279f, -0.56910921f, -0.79014171f, -0.38388114f, 0.01131879f, 0.87387134f, -0.12380692f, 0.05214671f, -0.42582976f, -0.17031139f, -1.06661596f, -0.91600447f, -0.26435070f}, // rest
|
||||||
|
{0.04198304f, -0.65837713f, -0.10663077f, -0.12701858f, -1.01290184f, 0.50100717f, 0.72188030f, -0.03150884f, 0.38356910f, -0.84413736f, 0.03745147f, -2.29871124f, 0.58565208f, 0.43331566f, 0.88777658f, -0.38229432f, -1.55494079f, 0.24865587f, -0.17540676f, -0.43032659f, 0.84763042f, 1.67359692f, 1.26211059f, 0.83625332f}, // thumbs_up
|
||||||
|
};
|
||||||
|
const float LDA_FD_INTERCEPTS[5] = {
|
||||||
|
-5.22055861f, -3.76092770f, -8.23556500f, -3.59434123f, -2.96625341f
|
||||||
|
};
|
||||||
|
|
||||||
|
const float LDA_CC_WEIGHTS[5][9] = {
|
||||||
|
{-0.94928369f, 2.44942094f, 0.26441444f, -0.80355120f, -1.19821936f, 1.39871694f, -2.06386346f, 0.59615281f, 1.57632161f}, // fist
|
||||||
|
{-2.17638786f, 1.40504639f, 1.97661862f, 0.99132032f, 1.46231208f, -0.77606652f, 2.43631056f, 1.58734751f, -1.75342004f}, // hook_em
|
||||||
|
{-0.15938422f, -2.49617253f, -0.43491962f, 0.32764670f, -3.51157144f, -0.74698502f, 2.59991544f, -0.64230610f, -3.23294039f}, // open
|
||||||
|
{2.78062805f, -1.24792206f, -2.03404274f, -0.13924844f, 2.03596007f, -0.25590903f, -0.40921478f, -0.84411543f, 0.27031906f}, // rest
|
||||||
|
{-1.42317081f, 0.84654172f, 1.66121800f, -0.27035074f, -0.03025088f, 0.56047099f, -2.30886825f, -0.06801108f, 3.01175668f}, // thumbs_up
|
||||||
|
};
|
||||||
|
const float LDA_CC_INTERCEPTS[5] = {
|
||||||
|
-2.77069779f, -4.07631951f, -6.23624226f, -1.94638886f, -2.30392172f
|
||||||
|
};
|
||||||
|
|
||||||
|
const float META_LDA_WEIGHTS[5][15] = {
|
||||||
|
{13.83327682f, -4.99575141f, -8.19270319f, -2.93944549f, 0.75007499f, 3.36110029f, 0.53962472f, -4.82143659f, -1.61056244f, 1.48044754f, 1.56410113f, -0.98537108f, 0.20587945f, 0.39305391f, -1.42106273f}, // fist
|
||||||
|
{-4.64514194f, 13.09553405f, -4.41874813f, -1.25841220f, -2.98757893f, 0.16270053f, 3.15820296f, -2.24183188f, -1.32266753f, 0.20766953f, -0.65827396f, 1.59311387f, -1.22803182f, 0.04336258f, -0.24531550f}, // hook_em
|
||||||
|
{-5.79187671f, -2.27375570f, 23.21490262f, -0.09093684f, -4.81706827f, -1.92030254f, -1.47383708f, 9.25922135f, 0.57114390f, -2.76993062f, -1.31290246f, -1.07932832f, 5.46202897f, -0.72440402f, -0.19414883f}, // open
|
||||||
|
{-2.87608878f, -1.27035778f, -3.11538939f, 4.40275586f, -1.44999243f, -0.49668812f, -1.24040434f, 0.65617176f, 1.79490014f, -1.52075301f, 0.08593208f, 0.38743884f, -1.35293868f, 0.30551755f, -0.06438460f}, // rest
|
||||||
|
{1.52549481f, -3.33085871f, -6.18851152f, -3.12231842f, 9.54599696f, -0.69931080f, -0.02514346f, -3.63196650f, -0.69837189f, 3.71692816f, 0.29062506f, -0.11110049f, -2.35881642f, -0.20019868f, 1.96143145f}, // thumbs_up
|
||||||
|
};
|
||||||
|
const float META_LDA_INTERCEPTS[5] = {
|
||||||
|
-8.46795016f, -8.64988671f, -21.16256224f, -3.22663037f, -6.35767829f
|
||||||
|
};
|
||||||
@@ -1,109 +1,180 @@
|
|||||||
/**
|
/**
|
||||||
* @file emg_sensor.c
|
* @file emg_sensor.c
|
||||||
* @brief EMG sensor driver implementation.
|
* @brief EMG sensor driver — DMA-backed continuous ADC acquisition.
|
||||||
*
|
*
|
||||||
* Provides EMG readings - fake data for now, real ADC when sensors arrive.
|
* Change A: adc_continuous (DMA) replaces adc_oneshot polling.
|
||||||
|
* A background FreeRTOS task on Core 0 reads DMA frames, assembles
|
||||||
|
* complete 4-channel sample sets, and pushes them to a FreeRTOS queue.
|
||||||
|
* emg_sensor_read() blocks on that queue; at 1 kHz per channel this
|
||||||
|
* returns within ~1 ms, providing exact timing without vTaskDelay(1).
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#include "emg_sensor.h"
|
#include "emg_sensor.h"
|
||||||
#include "esp_timer.h"
|
#include "esp_timer.h"
|
||||||
#include <stdlib.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include "freertos/FreeRTOS.h"
|
#include "freertos/FreeRTOS.h"
|
||||||
#include "freertos/task.h"
|
#include "freertos/task.h"
|
||||||
#include "esp_adc/adc_oneshot.h"
|
#include "freertos/queue.h"
|
||||||
|
#include "esp_adc/adc_continuous.h"
|
||||||
#include "esp_adc/adc_cali.h"
|
#include "esp_adc/adc_cali.h"
|
||||||
#include "esp_adc/adc_cali_scheme.h"
|
#include "esp_adc/adc_cali_scheme.h"
|
||||||
#include "esp_err.h"
|
#include "esp_err.h"
|
||||||
|
#include <assert.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
#include <stdlib.h>
|
||||||
|
|
||||||
adc_oneshot_unit_handle_t adc1_handle;
|
// --- ADC DMA constants ---
|
||||||
adc_cali_handle_t cali_handle = NULL;
|
// Total sample rate: 4 channels × 1 kHz = 4 kHz
|
||||||
|
#define ADC_TOTAL_SAMPLE_RATE_HZ (EMG_NUM_CHANNELS * EMG_SAMPLE_RATE_HZ)
|
||||||
|
// DMA frame: 64 conversions × 4 bytes = 256 bytes, arrives ~every 16 ms
|
||||||
|
#define ADC_CONV_FRAME_SIZE (64u * sizeof(adc_digi_output_data_t))
|
||||||
|
// Internal DMA pool: 2× conv_frame_size
|
||||||
|
#define ADC_POOL_SIZE (2u * ADC_CONV_FRAME_SIZE)
|
||||||
|
// Sample queue: buffers up to ~32 ms of assembled 4-channel sets
|
||||||
|
#define SAMPLE_QUEUE_DEPTH 32
|
||||||
|
|
||||||
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {
|
// --- Static handles ---
|
||||||
ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0
|
static adc_continuous_handle_t s_adc_handle = NULL;
|
||||||
ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1
|
static adc_cali_handle_t s_cali_handle = NULL;
|
||||||
ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2
|
static QueueHandle_t s_sample_queue = NULL;
|
||||||
ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3
|
|
||||||
|
// Channel mapping (ADC1, GPIO 2/3/9/10)
|
||||||
|
static const adc_channel_t s_channels[EMG_NUM_CHANNELS] = {
|
||||||
|
ADC_CHANNEL_1, // GPIO 2 — FCR/Belly (ch0)
|
||||||
|
ADC_CHANNEL_2, // GPIO 3 — Extensors (ch1)
|
||||||
|
ADC_CHANNEL_8, // GPIO 9 — FCU/Outer Flexors (ch2)
|
||||||
|
ADC_CHANNEL_9, // GPIO 10 — Bicep (ch3)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* ADC Sampling Task (Core 0)
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief DMA sampling task.
|
||||||
|
*
|
||||||
|
* Reads adc_continuous DMA frames, assembles complete 4-channel sample sets
|
||||||
|
* (one value per channel), applies curve-fitting calibration, and pushes
|
||||||
|
* emg_sample_t structs to s_sample_queue. Runs continuously on Core 0.
|
||||||
|
*/
|
||||||
|
static void adc_sampling_task(void *arg) {
|
||||||
|
uint8_t *buf = (uint8_t *)malloc(ADC_CONV_FRAME_SIZE);
|
||||||
|
if (!buf) {
|
||||||
|
printf("[EMG] ERROR: DMA read buffer alloc failed\n");
|
||||||
|
vTaskDelete(NULL);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Per-channel raw accumulator; holds the latest sample for each channel
|
||||||
|
int raw_set[EMG_NUM_CHANNELS];
|
||||||
|
bool got[EMG_NUM_CHANNELS];
|
||||||
|
memset(raw_set, 0, sizeof(raw_set));
|
||||||
|
memset(got, 0, sizeof(got));
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
uint32_t out_len = 0;
|
||||||
|
esp_err_t err = adc_continuous_read(
|
||||||
|
s_adc_handle, buf, (uint32_t)ADC_CONV_FRAME_SIZE,
|
||||||
|
&out_len, pdMS_TO_TICKS(100)
|
||||||
|
);
|
||||||
|
if (err != ESP_OK || out_len == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t n = out_len / sizeof(adc_digi_output_data_t);
|
||||||
|
adc_digi_output_data_t *p = (adc_digi_output_data_t *)buf;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < n; i++) {
|
||||||
|
int ch = (int)p[i].type2.channel;
|
||||||
|
int raw = (int)p[i].type2.data;
|
||||||
|
|
||||||
|
// Map ADC channel index to sensor channel
|
||||||
|
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
|
||||||
|
if ((int)s_channels[c] == ch) {
|
||||||
|
raw_set[c] = raw;
|
||||||
|
got[c] = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Emit a complete sample set once all channels have been updated
|
||||||
|
bool all = true;
|
||||||
|
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
|
||||||
|
if (!got[c]) { all = false; break; }
|
||||||
|
}
|
||||||
|
if (!all) continue;
|
||||||
|
|
||||||
|
emg_sample_t s;
|
||||||
|
s.timestamp_ms = emg_sensor_get_timestamp_ms();
|
||||||
|
for (int c = 0; c < EMG_NUM_CHANNELS; c++) {
|
||||||
|
int mv = 0;
|
||||||
|
adc_cali_raw_to_voltage(s_cali_handle, raw_set[c], &mv);
|
||||||
|
s.channels[c] = (uint16_t)mv;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-blocking send: drop if queue is full (prefer fresh data)
|
||||||
|
xQueueSend(s_sample_queue, &s, 0);
|
||||||
|
|
||||||
|
// Reset accumulator for the next complete set
|
||||||
|
memset(got, 0, sizeof(got));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Public Functions
|
* Public Functions
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
void emg_sensor_init(void)
|
void emg_sensor_init(void) {
|
||||||
{
|
// 1. Curve-fitting calibration (same scheme as before)
|
||||||
#if FEATURE_FAKE_EMG
|
adc_cali_curve_fitting_config_t cali_cfg = {
|
||||||
/* Seed random number generator for fake data */
|
.unit_id = ADC_UNIT_1,
|
||||||
srand((unsigned int)esp_timer_get_time());
|
.atten = ADC_ATTEN_DB_12,
|
||||||
#else
|
|
||||||
// 1. --- ADC Unit Setup ---
|
|
||||||
adc_oneshot_unit_init_cfg_t init_config1 = {
|
|
||||||
.unit_id = ADC_UNIT_1,
|
|
||||||
.ulp_mode = ADC_ULP_MODE_DISABLE,
|
|
||||||
};
|
|
||||||
ESP_ERROR_CHECK(adc_oneshot_new_unit(&init_config1, &adc1_handle));
|
|
||||||
|
|
||||||
// 2. --- ADC Channel Setup (GPIO 1?) ---
|
|
||||||
// Ensure the channel matches your GPIO in pinmap. For ADC1, GPIO1 is usually not CH0.
|
|
||||||
// Check your datasheet! (e.g., on S3, GPIO 1 is ADC1_CH0)
|
|
||||||
adc_oneshot_chan_cfg_t config = {
|
|
||||||
.bitwidth = ADC_BITWIDTH_DEFAULT, // 12-bit for S3
|
|
||||||
.atten = ADC_ATTEN_DB_12, // Allows up to ~3.1V
|
|
||||||
};
|
|
||||||
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++)
|
|
||||||
ESP_ERROR_CHECK(adc_oneshot_config_channel(adc1_handle, emg_channels[i], &config));
|
|
||||||
|
|
||||||
// 3. --- Calibration Setup (CORRECTED for S3) ---
|
|
||||||
// ESP32-S3 uses Curve Fitting, not Line Fitting
|
|
||||||
adc_cali_curve_fitting_config_t cali_config = {
|
|
||||||
.unit_id = ADC_UNIT_1,
|
|
||||||
.atten = ADC_ATTEN_DB_12,
|
|
||||||
.bitwidth = ADC_BITWIDTH_DEFAULT,
|
.bitwidth = ADC_BITWIDTH_DEFAULT,
|
||||||
};
|
};
|
||||||
ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_config, &cali_handle));
|
ESP_ERROR_CHECK(adc_cali_create_scheme_curve_fitting(&cali_cfg, &s_cali_handle));
|
||||||
|
|
||||||
// while (1) {
|
|
||||||
// int raw_val, voltage_mv;
|
|
||||||
|
|
||||||
// // Read Raw
|
|
||||||
// ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, ADC_CHANNEL_1, &raw_val));
|
|
||||||
|
|
||||||
// // Convert to mV using calibration
|
|
||||||
// ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
|
|
||||||
|
|
||||||
// printf("Raw: %d | Voltage: %d mV\n", raw_val, voltage_mv);
|
// 2. Create continuous ADC handle
|
||||||
// vTaskDelay(pdMS_TO_TICKS(500));
|
adc_continuous_handle_cfg_t adc_cfg = {
|
||||||
// }
|
.max_store_buf_size = ADC_POOL_SIZE,
|
||||||
#endif
|
.conv_frame_size = ADC_CONV_FRAME_SIZE,
|
||||||
}
|
};
|
||||||
|
ESP_ERROR_CHECK(adc_continuous_new_handle(&adc_cfg, &s_adc_handle));
|
||||||
|
|
||||||
void emg_sensor_read(emg_sample_t *sample)
|
// 3. Configure scan pattern (4 channels, 4 kHz total)
|
||||||
{
|
adc_digi_pattern_config_t patterns[EMG_NUM_CHANNELS];
|
||||||
sample->timestamp_ms = emg_sensor_get_timestamp_ms();
|
|
||||||
|
|
||||||
#if FEATURE_FAKE_EMG
|
|
||||||
/*
|
|
||||||
* Generate fake EMG data:
|
|
||||||
* - Base value around 1650 (middle of 3.3V millivolt range)
|
|
||||||
* - Random noise of +/- 50
|
|
||||||
* - Mimics real EMG baseline noise
|
|
||||||
*/
|
|
||||||
for (int i = 0; i < EMG_NUM_CHANNELS; i++) {
|
for (int i = 0; i < EMG_NUM_CHANNELS; i++) {
|
||||||
int noise = (rand() % 101) - 50; /* -50 to +50 */
|
patterns[i].atten = ADC_ATTEN_DB_12;
|
||||||
sample->channels[i] = (uint16_t)(1650 + noise);
|
patterns[i].channel = (uint8_t)s_channels[i];
|
||||||
}
|
patterns[i].unit = ADC_UNIT_1;
|
||||||
#else
|
patterns[i].bit_width = ADC_BITWIDTH_12;
|
||||||
int raw_val, voltage_mv;
|
|
||||||
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) {
|
|
||||||
ESP_ERROR_CHECK(adc_oneshot_read(adc1_handle, emg_channels[i], &raw_val));
|
|
||||||
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
|
|
||||||
sample->channels[i] = (uint16_t) voltage_mv;
|
|
||||||
}
|
}
|
||||||
|
adc_continuous_config_t cont_cfg = {
|
||||||
|
.pattern_num = EMG_NUM_CHANNELS,
|
||||||
|
.adc_pattern = patterns,
|
||||||
|
.sample_freq_hz = ADC_TOTAL_SAMPLE_RATE_HZ,
|
||||||
|
.conv_mode = ADC_CONV_SINGLE_UNIT_1,
|
||||||
|
.format = ADC_DIGI_OUTPUT_FORMAT_TYPE2,
|
||||||
|
};
|
||||||
|
ESP_ERROR_CHECK(adc_continuous_config(s_adc_handle, &cont_cfg));
|
||||||
|
|
||||||
#endif
|
// 4. Start DMA acquisition
|
||||||
|
ESP_ERROR_CHECK(adc_continuous_start(s_adc_handle));
|
||||||
|
|
||||||
|
// 5. Create sample queue (assembled complete sets land here)
|
||||||
|
s_sample_queue = xQueueCreate(SAMPLE_QUEUE_DEPTH, sizeof(emg_sample_t));
|
||||||
|
assert(s_sample_queue != NULL);
|
||||||
|
|
||||||
|
// 6. Launch sampling task pinned to Core 0
|
||||||
|
xTaskCreatePinnedToCore(adc_sampling_task, "adc_sample", 4096, NULL, 6, NULL, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t emg_sensor_get_timestamp_ms(void)
|
void emg_sensor_read(emg_sample_t *sample) {
|
||||||
{
|
// Block until a complete 4-channel sample set arrives from the DMA task.
|
||||||
|
// At 1 kHz per channel this typically returns within ~1 ms.
|
||||||
|
xQueueReceive(s_sample_queue, sample, portMAX_DELAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t emg_sensor_get_timestamp_ms(void) {
|
||||||
return (uint32_t)(esp_timer_get_time() / 1000);
|
return (uint32_t)(esp_timer_get_time() / 1000);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,8 @@
|
|||||||
* @file emg_sensor.h
|
* @file emg_sensor.h
|
||||||
* @brief EMG sensor driver for reading muscle signals.
|
* @brief EMG sensor driver for reading muscle signals.
|
||||||
*
|
*
|
||||||
* This module provides EMG data acquisition. Currently generates fake
|
* This module provides EMG data acquisition from ADC channels connected
|
||||||
* data for testing (FEATURE_FAKE_EMG=1). When sensors arrive, the
|
* to MyoWare sensors. Outputs calibrated millivolt values (0-3300 mV).
|
||||||
* implementation switches to real ADC reads without changing the interface.
|
|
||||||
*
|
*
|
||||||
* @note This is Layer 2 (Driver).
|
* @note This is Layer 2 (Driver).
|
||||||
*/
|
*/
|
||||||
@@ -34,8 +33,7 @@ typedef struct {
|
|||||||
/**
|
/**
|
||||||
* @brief Initialize the EMG sensor system.
|
* @brief Initialize the EMG sensor system.
|
||||||
*
|
*
|
||||||
* If FEATURE_FAKE_EMG is enabled, just seeds the random generator.
|
* Configures ADC channels and calibration for real sensor reading.
|
||||||
* Otherwise, configures ADC channels for real sensor reading.
|
|
||||||
*/
|
*/
|
||||||
void emg_sensor_init(void);
|
void emg_sensor_init(void);
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,9 @@
|
|||||||
#include "hand.h"
|
#include "hand.h"
|
||||||
#include "hal/servo_hal.h"
|
#include "hal/servo_hal.h"
|
||||||
|
|
||||||
|
float maxAngles[] = {155, 155, 180, 165, 150};
|
||||||
|
float minAngles[] = {65, 45, 45, 30, 25};
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Public Functions
|
* Public Functions
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|||||||
@@ -13,6 +13,8 @@
|
|||||||
|
|
||||||
#include "config/config.h"
|
#include "config/config.h"
|
||||||
|
|
||||||
|
extern float maxAngles[];
|
||||||
|
extern float minAngles[];
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Public Functions
|
* Public Functions
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ void servo_hal_init(void)
|
|||||||
.timer_sel = SERVO_PWM_TIMER,
|
.timer_sel = SERVO_PWM_TIMER,
|
||||||
.intr_type = LEDC_INTR_DISABLE,
|
.intr_type = LEDC_INTR_DISABLE,
|
||||||
.gpio_num = servo_pins[i],
|
.gpio_num = servo_pins[i],
|
||||||
.duty = SERVO_DUTY_MIN, /* Start extended (open) */
|
.duty = servo_hal_degrees_to_duty(90), /* Start extended (open) */
|
||||||
.hpoint = 0
|
.hpoint = 0
|
||||||
};
|
};
|
||||||
ESP_ERROR_CHECK(ledc_channel_config(&channel_config));
|
ESP_ERROR_CHECK(ledc_channel_config(&channel_config));
|
||||||
|
|||||||
BIN
collected_data/new_system_000_20260308_180810.hdf5
Normal file
BIN
collected_data/new_system_000_20260308_180810.hdf5
Normal file
Binary file not shown.
BIN
collected_data/new_system_001_20260308_185810.hdf5
Normal file
BIN
collected_data/new_system_001_20260308_185810.hdf5
Normal file
Binary file not shown.
BIN
collected_data/new_system_002_20260308_191206.hdf5
Normal file
BIN
collected_data/new_system_002_20260308_191206.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated001_20260214_195555.hdf5
Normal file
BIN
collected_data/updated001_20260214_195555.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated002_20260214_195732.hdf5
Normal file
BIN
collected_data/updated002_20260214_195732.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated003_20260214_200039.hdf5
Normal file
BIN
collected_data/updated003_20260214_200039.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated004_20260214_200216.hdf5
Normal file
BIN
collected_data/updated004_20260214_200216.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated005_20260214_202724.hdf5
Normal file
BIN
collected_data/updated005_20260214_202724.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated006_20260214_202910.hdf5
Normal file
BIN
collected_data/updated006_20260214_202910.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated007_20260214_203049.hdf5
Normal file
BIN
collected_data/updated007_20260214_203049.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated008_20260214_203228.hdf5
Normal file
BIN
collected_data/updated008_20260214_203228.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated009_20260214_203612.hdf5
Normal file
BIN
collected_data/updated009_20260214_203612.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated010_20260214_204204.hdf5
Normal file
BIN
collected_data/updated010_20260214_204204.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated011_20260214_212146.hdf5
Normal file
BIN
collected_data/updated011_20260214_212146.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated012_20260214_212732.hdf5
Normal file
BIN
collected_data/updated012_20260214_212732.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated013_20260214_212957.hdf5
Normal file
BIN
collected_data/updated013_20260214_212957.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated014_20260214_213133.hdf5
Normal file
BIN
collected_data/updated014_20260214_213133.hdf5
Normal file
Binary file not shown.
BIN
collected_data/updated015_20260214_213536.hdf5
Normal file
BIN
collected_data/updated015_20260214_213536.hdf5
Normal file
Binary file not shown.
2185
emg_gui.py
2185
emg_gui.py
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,42 @@
|
|||||||
|
"""
|
||||||
|
EMG Signal Analysis Script (STANDALONE - NOT PRODUCTION CODE)
|
||||||
|
==============================================================
|
||||||
|
|
||||||
|
!! WARNING !!
|
||||||
|
This script is for OFFLINE ANALYSIS AND VISUALIZATION ONLY.
|
||||||
|
It is NOT part of the training or inference pipeline.
|
||||||
|
|
||||||
|
DO NOT use this script's outputs for:
|
||||||
|
- Model training
|
||||||
|
- Feature extraction thresholds
|
||||||
|
- Production inference
|
||||||
|
|
||||||
|
The thresholds defined here (ZC_THRESHOLD_PERCENT, SSC_THRESHOLD_PERCENT)
|
||||||
|
are DIFFERENT from production values in:
|
||||||
|
- learning_data_collection.py (production: 0.1, 0.1)
|
||||||
|
- model_weights.h (production: 0.1, 0.1)
|
||||||
|
|
||||||
|
This script uses higher thresholds (0.7, 0.6) for visualization clarity,
|
||||||
|
which would produce INCORRECT features if used for training/inference.
|
||||||
|
|
||||||
|
To analyze collected sessions:
|
||||||
|
1. Update HDF5_PATH to your session file
|
||||||
|
2. Run: python learning_emg_filtering.py
|
||||||
|
3. View the generated plots
|
||||||
|
"""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import h5py
|
import h5py
|
||||||
from scipy.signal import butter, sosfiltfilt
|
from scipy.signal import butter, sosfiltfilt
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# CONFIGURABLE PARAMETERS
|
# CONFIGURABLE PARAMETERS (FOR VISUALIZATION ONLY - NOT PRODUCTION VALUES!)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS
|
# These thresholds are intentionally different from production (0.1, 0.1)
|
||||||
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS
|
# to provide cleaner visualizations. DO NOT use for training/inference.
|
||||||
|
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold (VISUALIZATION ONLY)
|
||||||
|
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold (VISUALIZATION ONLY)
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# LOAD DATA FROM GUI's HDF5 FORMAT
|
# LOAD DATA FROM GUI's HDF5 FORMAT
|
||||||
|
|||||||
325
live_predict.py
Normal file
325
live_predict.py
Normal file
@@ -0,0 +1,325 @@
|
|||||||
|
"""
|
||||||
|
live_predict.py — Laptop-side live EMG inference for Bucky Arm.
|
||||||
|
|
||||||
|
Use this script when the ESP32 is in EMG_MAIN mode and you want the laptop to
|
||||||
|
run the classifier (instead of the on-device model). Useful for:
|
||||||
|
- Comparing laptop accuracy vs. on-device accuracy before flashing a new model
|
||||||
|
- Debugging the feature pipeline without reflashing firmware
|
||||||
|
- Running an updated model that hasn't been exported to C yet
|
||||||
|
|
||||||
|
Workflow:
|
||||||
|
1. ESP32 must be in EMG_MAIN mode (MAIN_MODE = EMG_MAIN in config.h)
|
||||||
|
2. This script handshakes → requests STATE_LAPTOP_PREDICT
|
||||||
|
3. ESP32 streams raw ADC CSV at 1 kHz
|
||||||
|
4. Script collects 3s of REST for session normalization, then classifies
|
||||||
|
5. Every 25ms (one hop), the predicted gesture is sent back: {"gesture":"fist"}
|
||||||
|
6. ESP32 executes the received gesture command on the arm
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python live_predict.py --port COM3
|
||||||
|
python live_predict.py --port COM3 --model models/my_model.joblib
|
||||||
|
python live_predict.py --port COM3 --confidence 0.45
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import serial
|
||||||
|
|
||||||
|
# Import from the main training pipeline
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from learning_data_collection import (
|
||||||
|
EMGClassifier,
|
||||||
|
NUM_CHANNELS,
|
||||||
|
SAMPLING_RATE_HZ,
|
||||||
|
WINDOW_SIZE_MS,
|
||||||
|
HOP_SIZE_MS,
|
||||||
|
HAND_CHANNELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Derived constants
|
||||||
|
WINDOW_SIZE = int(WINDOW_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 150 samples
|
||||||
|
HOP_SIZE = int(HOP_SIZE_MS * SAMPLING_RATE_HZ / 1000) # 25 samples
|
||||||
|
BAUD_RATE = 921600
|
||||||
|
CALIB_SECS = 3.0 # seconds of REST to collect at startup for normalization
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||||
|
p.add_argument("--port", required=True,
|
||||||
|
help="Serial port (e.g. COM3 on Windows, /dev/ttyUSB0 on Linux)")
|
||||||
|
p.add_argument("--model", default=None,
|
||||||
|
help="Path to .joblib model file. Defaults to the most recently trained model.")
|
||||||
|
p.add_argument("--confidence", type=float, default=0.40,
|
||||||
|
help="Reject predictions below this confidence (default: 0.40, same as firmware)")
|
||||||
|
p.add_argument("--no-calib", action="store_true",
|
||||||
|
help="Skip REST calibration (use raw features — only for quick testing)")
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: str | None) -> EMGClassifier:
|
||||||
|
if model_path:
|
||||||
|
path = Path(model_path)
|
||||||
|
else:
|
||||||
|
path = EMGClassifier.get_latest_model_path()
|
||||||
|
if path is None:
|
||||||
|
print("[ERROR] No trained model found in models/. Run training first.")
|
||||||
|
sys.exit(1)
|
||||||
|
print(f"[Model] Auto-selected latest model: {path.name}")
|
||||||
|
|
||||||
|
classifier = EMGClassifier.load(path)
|
||||||
|
if not classifier.is_trained:
|
||||||
|
print("[ERROR] Loaded model is not trained.")
|
||||||
|
sys.exit(1)
|
||||||
|
return classifier
|
||||||
|
|
||||||
|
|
||||||
|
def handshake(ser: serial.Serial) -> bool:
|
||||||
|
"""Send connect command and wait for ack_connect from ESP32."""
|
||||||
|
print("[Handshake] Sending connect command...")
|
||||||
|
ser.write(b'{"cmd":"connect"}\n')
|
||||||
|
deadline = time.time() + 5.0
|
||||||
|
while time.time() < deadline:
|
||||||
|
raw = ser.readline()
|
||||||
|
line = raw.decode("utf-8", errors="ignore").strip()
|
||||||
|
if "ack_connect" in line:
|
||||||
|
print(f"[Handshake] Connected — {line}")
|
||||||
|
return True
|
||||||
|
if line:
|
||||||
|
print(f"[Handshake] (ignored) {line}")
|
||||||
|
print("[ERROR] No ack_connect within 5s. Is the ESP32 powered and in EMG_MAIN mode?")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def collect_calibration_windows(
|
||||||
|
ser: serial.Serial,
|
||||||
|
n_windows: int,
|
||||||
|
extractor,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Collect n_windows of REST EMG, extract features, and compute
|
||||||
|
per-feature mean and std for session normalization.
|
||||||
|
|
||||||
|
Returns (mean, std) arrays of shape (n_features,).
|
||||||
|
"""
|
||||||
|
print(f"[Calib] Hold arm relaxed at rest for {CALIB_SECS:.0f}s...")
|
||||||
|
|
||||||
|
raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32)
|
||||||
|
samples = 0
|
||||||
|
feat_list = []
|
||||||
|
|
||||||
|
while len(feat_list) < n_windows:
|
||||||
|
raw = ser.readline()
|
||||||
|
line = raw.decode("utf-8", errors="ignore").strip()
|
||||||
|
if not line or line.startswith("{"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
vals = [float(v) for v in line.split(",")]
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if len(vals) != NUM_CHANNELS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Slide window
|
||||||
|
raw_buf = np.roll(raw_buf, -1, axis=0)
|
||||||
|
raw_buf[-1] = vals
|
||||||
|
samples += 1
|
||||||
|
|
||||||
|
if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0:
|
||||||
|
feat = extractor.extract_features_window(raw_buf)
|
||||||
|
feat_list.append(feat)
|
||||||
|
|
||||||
|
done = len(feat_list)
|
||||||
|
if done % 10 == 0:
|
||||||
|
pct = int(100 * done / n_windows)
|
||||||
|
print(f" {pct}% ({done}/{n_windows} windows)", end="\r", flush=True)
|
||||||
|
|
||||||
|
print(f"\n[Calib] Collected {len(feat_list)} windows.")
|
||||||
|
feats = np.array(feat_list, dtype=np.float32)
|
||||||
|
mean = feats.mean(axis=0)
|
||||||
|
std = np.where(feats.std(axis=0) > 1e-6, feats.std(axis=0), 1e-6).astype(np.float32)
|
||||||
|
print("[Calib] Session normalization computed.")
|
||||||
|
return mean, std
|
||||||
|
|
||||||
|
|
||||||
|
def load_ensemble():
|
||||||
|
"""Load ensemble sklearn models if available."""
|
||||||
|
path = Path(__file__).parent / 'models' / 'emg_ensemble.joblib'
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
import joblib
|
||||||
|
ens = joblib.load(path)
|
||||||
|
print(f"[Model] Loaded ensemble (4 LDAs)")
|
||||||
|
return ens
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Model] Ensemble load failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def load_mlp():
|
||||||
|
"""Load MLP numpy weights if available."""
|
||||||
|
path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
|
||||||
|
if not path.exists():
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
mlp = dict(np.load(path, allow_pickle=True))
|
||||||
|
print(f"[Model] Loaded MLP weights (numpy)")
|
||||||
|
return mlp
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[Model] MLP load failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run_ensemble(ens, features):
|
||||||
|
"""Run ensemble: 3 specialist LDAs → meta-LDA → probabilities."""
|
||||||
|
p_td = ens['lda_td'].predict_proba([features[ens['td_idx']]])[0]
|
||||||
|
p_fd = ens['lda_fd'].predict_proba([features[ens['fd_idx']]])[0]
|
||||||
|
p_cc = ens['lda_cc'].predict_proba([features[ens['cc_idx']]])[0]
|
||||||
|
x_meta = np.concatenate([p_td, p_fd, p_cc])
|
||||||
|
return ens['meta_lda'].predict_proba([x_meta])[0]
|
||||||
|
|
||||||
|
|
||||||
|
def run_mlp(mlp, features):
|
||||||
|
"""Run MLP forward pass: Dense(32,relu) → Dense(16,relu) → Dense(5,softmax)."""
|
||||||
|
x = features.astype(np.float32)
|
||||||
|
x = np.maximum(0, x @ mlp['w0'] + mlp['b0'])
|
||||||
|
x = np.maximum(0, x @ mlp['w1'] + mlp['b1'])
|
||||||
|
logits = x @ mlp['w2'] + mlp['b2']
|
||||||
|
e = np.exp(logits - logits.max())
|
||||||
|
return e / e.sum()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# ── Load classifier ──────────────────────────────────────────────────────
|
||||||
|
classifier = load_model(args.model)
|
||||||
|
extractor = classifier.feature_extractor
|
||||||
|
ensemble = load_ensemble()
|
||||||
|
mlp = load_mlp()
|
||||||
|
model_names = ["LDA"]
|
||||||
|
if ensemble:
|
||||||
|
model_names.append("Ensemble")
|
||||||
|
if mlp:
|
||||||
|
model_names.append("MLP")
|
||||||
|
print(f"[Model] Active: {' + '.join(model_names)} ({len(model_names)} models)")
|
||||||
|
|
||||||
|
# ── Open serial ──────────────────────────────────────────────────────────
|
||||||
|
try:
|
||||||
|
ser = serial.Serial(args.port, BAUD_RATE, timeout=1.0)
|
||||||
|
except serial.SerialException as e:
|
||||||
|
print(f"[ERROR] Could not open {args.port}: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
ser.reset_input_buffer()
|
||||||
|
|
||||||
|
# ── Handshake ────────────────────────────────────────────────────────────
|
||||||
|
if not handshake(ser):
|
||||||
|
ser.close()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# ── Request laptop-predict mode ──────────────────────────────────────────
|
||||||
|
ser.write(b'{"cmd":"start_laptop_predict"}\n')
|
||||||
|
print("[Control] ESP32 entering STATE_LAPTOP_PREDICT — streaming ADC...")
|
||||||
|
|
||||||
|
# ── Calibration ──────────────────────────────────────────────────────────
|
||||||
|
calib_mean = None
|
||||||
|
calib_std = None
|
||||||
|
if not args.no_calib:
|
||||||
|
n_calib = max(20, int(CALIB_SECS * 1000 / HOP_SIZE_MS))
|
||||||
|
calib_mean, calib_std = collect_calibration_windows(ser, n_calib, extractor)
|
||||||
|
else:
|
||||||
|
print("[Calib] Skipped (--no-calib). Accuracy may be reduced.")
|
||||||
|
|
||||||
|
# ── Live prediction loop ─────────────────────────────────────────────────
|
||||||
|
print(f"\n[Predict] Running. Confidence threshold: {args.confidence:.2f}")
|
||||||
|
print("[Predict] Press Ctrl+C to stop.\n")
|
||||||
|
|
||||||
|
raw_buf = np.zeros((WINDOW_SIZE, NUM_CHANNELS), dtype=np.float32)
|
||||||
|
samples = 0
|
||||||
|
last_gesture = None
|
||||||
|
n_inferences = 0
|
||||||
|
n_rejected = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
raw = ser.readline()
|
||||||
|
line = raw.decode("utf-8", errors="ignore").strip()
|
||||||
|
|
||||||
|
# Skip JSON telemetry lines from ESP32
|
||||||
|
if not line or line.startswith("{"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Parse CSV sample
|
||||||
|
try:
|
||||||
|
vals = [float(v) for v in line.split(",")]
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
if len(vals) != NUM_CHANNELS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Slide window
|
||||||
|
raw_buf = np.roll(raw_buf, -1, axis=0)
|
||||||
|
raw_buf[-1] = vals
|
||||||
|
samples += 1
|
||||||
|
|
||||||
|
# Classify every HOP_SIZE samples
|
||||||
|
if samples >= WINDOW_SIZE and (samples % HOP_SIZE) == 0:
|
||||||
|
feat = extractor.extract_features_window(raw_buf).astype(np.float32)
|
||||||
|
|
||||||
|
# Apply session normalization
|
||||||
|
if calib_mean is not None:
|
||||||
|
feat = (feat - calib_mean) / calib_std
|
||||||
|
|
||||||
|
# Run all available models and average probabilities
|
||||||
|
probas = [classifier.model.predict_proba([feat])[0]]
|
||||||
|
if ensemble:
|
||||||
|
try:
|
||||||
|
probas.append(run_ensemble(ensemble, feat))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if mlp:
|
||||||
|
try:
|
||||||
|
probas.append(run_mlp(mlp, feat))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
proba = np.mean(probas, axis=0)
|
||||||
|
class_idx = int(np.argmax(proba))
|
||||||
|
confidence = float(proba[class_idx])
|
||||||
|
gesture = classifier.label_names[class_idx]
|
||||||
|
n_inferences += 1
|
||||||
|
|
||||||
|
# Reject below threshold
|
||||||
|
if confidence < args.confidence:
|
||||||
|
n_rejected += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Send gesture command to ESP32
|
||||||
|
cmd = f'{{"gesture":"{gesture}"}}\n'
|
||||||
|
ser.write(cmd.encode("utf-8"))
|
||||||
|
|
||||||
|
# Local logging (only on change)
|
||||||
|
if gesture != last_gesture:
|
||||||
|
reject_rate = 100 * n_rejected / n_inferences if n_inferences else 0
|
||||||
|
print(f" → {gesture:<12} conf={confidence:.2f} "
|
||||||
|
f"reject_rate={reject_rate:.0f}%")
|
||||||
|
last_gesture = gesture
|
||||||
|
n_rejected = 0
|
||||||
|
n_inferences = 0
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n[Stop] Sending stop command to ESP32...")
|
||||||
|
ser.write(b'{"cmd":"stop"}\n')
|
||||||
|
time.sleep(0.2)
|
||||||
|
ser.close()
|
||||||
|
print("[Stop] Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
BIN
models/emg_ensemble.joblib
Normal file
BIN
models/emg_ensemble.joblib
Normal file
Binary file not shown.
Binary file not shown.
BIN
models/emg_mlp_weights.npz
Normal file
BIN
models/emg_mlp_weights.npz
Normal file
Binary file not shown.
BIN
models/emg_qda_classifier.joblib
Normal file
BIN
models/emg_qda_classifier.joblib
Normal file
Binary file not shown.
200
train_ensemble.py
Normal file
200
train_ensemble.py
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
"""
|
||||||
|
Train the full 3-specialist-LDA + meta-LDA ensemble.
|
||||||
|
Requires Change 1 (expanded features) to be implemented first.
|
||||||
|
Exports model_weights_ensemble.h for firmware Change F.
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
LDA_TD (36 time-domain feat) ─┐
|
||||||
|
LDA_FD (24 freq-domain feat) ├─ 15 probs ─► Meta-LDA ─► final class
|
||||||
|
LDA_CC (9 cross-ch feat) ─┘
|
||||||
|
|
||||||
|
Change 7 — priority Tier 3.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
|
||||||
|
from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from learning_data_collection import (
|
||||||
|
SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- Load and extract features -----------------------------------------------
|
||||||
|
storage = SessionStorage()
|
||||||
|
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
|
||||||
|
|
||||||
|
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
|
||||||
|
X = extractor.extract_features_batch(X_raw).astype(np.float64)
|
||||||
|
|
||||||
|
# Per-session class-balanced normalization
|
||||||
|
# Must match EMGClassifier._apply_session_normalization():
|
||||||
|
# mean = average of per-class means (not overall mean), std = overall std.
|
||||||
|
# StandardScaler uses overall mean, which biases toward the majority class.
|
||||||
|
for sid in np.unique(session_indices):
|
||||||
|
mask = session_indices == sid
|
||||||
|
X_sess = X[mask]
|
||||||
|
y_sess = y[mask]
|
||||||
|
|
||||||
|
# Class-balanced mean: average of per-class centroids
|
||||||
|
class_means = []
|
||||||
|
for cls in np.unique(y_sess):
|
||||||
|
class_means.append(X_sess[y_sess == cls].mean(axis=0))
|
||||||
|
balanced_mean = np.mean(class_means, axis=0)
|
||||||
|
|
||||||
|
# Overall std (same as StandardScaler)
|
||||||
|
std = X_sess.std(axis=0)
|
||||||
|
std[std < 1e-12] = 1.0 # avoid division by zero
|
||||||
|
|
||||||
|
X[mask] = (X_sess - balanced_mean) / std
|
||||||
|
|
||||||
|
feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS))
|
||||||
|
n_cls = len(np.unique(y))
|
||||||
|
|
||||||
|
# --- Feature subset indices ---------------------------------------------------
|
||||||
|
# Per-channel layout (20 features/channel): indices 0-11 TD, 12-19 FD
|
||||||
|
# Cross-channel features start at index 60 (3 channels × 20 features each)
|
||||||
|
TD_FEAT = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', 'ar1', 'ar2', 'ar3', 'ar4']
|
||||||
|
FD_FEAT = ['mnf', 'mdf', 'pkf', 'mnp', 'bp0', 'bp1', 'bp2', 'bp3']
|
||||||
|
|
||||||
|
td_idx = [i for i, n in enumerate(feat_names)
|
||||||
|
if any(n.endswith(f'_{f}') for f in TD_FEAT) and n.startswith('ch')]
|
||||||
|
fd_idx = [i for i, n in enumerate(feat_names)
|
||||||
|
if any(n.endswith(f'_{f}') for f in FD_FEAT) and n.startswith('ch')]
|
||||||
|
cc_idx = [i for i, n in enumerate(feat_names) if n.startswith('cc_')]
|
||||||
|
|
||||||
|
print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}")
|
||||||
|
assert len(td_idx) == 36, f"Expected 36 TD features, got {len(td_idx)}"
|
||||||
|
assert len(fd_idx) == 24, f"Expected 24 FD features, got {len(fd_idx)}"
|
||||||
|
assert len(cc_idx) == 9, f"Expected 9 CC features, got {len(cc_idx)}"
|
||||||
|
|
||||||
|
X_td = X[:, td_idx]
|
||||||
|
X_fd = X[:, fd_idx]
|
||||||
|
X_cc = X[:, cc_idx]
|
||||||
|
|
||||||
|
# --- Train specialist LDAs with out-of-fold stacking -------------------------
|
||||||
|
gkf = GroupKFold(n_splits=min(5, len(np.unique(trial_ids))))
|
||||||
|
|
||||||
|
print("Training specialist LDAs (out-of-fold for stacking)...")
|
||||||
|
lda_td = LinearDiscriminantAnalysis()
|
||||||
|
lda_fd = LinearDiscriminantAnalysis()
|
||||||
|
lda_cc = LinearDiscriminantAnalysis()
|
||||||
|
|
||||||
|
oof_td = cross_val_predict(lda_td, X_td, y, cv=gkf, groups=trial_ids, method='predict_proba')
|
||||||
|
oof_fd = cross_val_predict(lda_fd, X_fd, y, cv=gkf, groups=trial_ids, method='predict_proba')
|
||||||
|
oof_cc = cross_val_predict(lda_cc, X_cc, y, cv=gkf, groups=trial_ids, method='predict_proba')
|
||||||
|
|
||||||
|
# Specialist CV accuracy (for diagnostics)
|
||||||
|
for name, mdl, Xs in [('LDA_TD', lda_td, X_td),
|
||||||
|
('LDA_FD', lda_fd, X_fd),
|
||||||
|
('LDA_CC', lda_cc, X_cc)]:
|
||||||
|
sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids)
|
||||||
|
print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%")
|
||||||
|
|
||||||
|
# --- Train meta-LDA on out-of-fold outputs ------------------------------------
|
||||||
|
X_meta = np.hstack([oof_td, oof_fd, oof_cc]) # (n_samples, 3*n_cls = 15)
|
||||||
|
meta_lda = LinearDiscriminantAnalysis()
|
||||||
|
meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids)
|
||||||
|
print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%")
|
||||||
|
|
||||||
|
# Fit all models on full dataset for deployment
|
||||||
|
lda_td.fit(X_td, y)
|
||||||
|
lda_fd.fit(X_fd, y)
|
||||||
|
lda_cc.fit(X_cc, y)
|
||||||
|
meta_lda.fit(X_meta, y)
|
||||||
|
|
||||||
|
# --- Export all weights to C header ------------------------------------------
|
||||||
|
def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order):
|
||||||
|
"""Generate C array strings for LDA weights and intercepts.
|
||||||
|
|
||||||
|
NOTE: sklearn LDA.coef_ for multi-class has shape (n_classes-1, n_features)
|
||||||
|
when using SVD solver. If so, we use decision_function and re-derive weights.
|
||||||
|
"""
|
||||||
|
coef = lda.coef_
|
||||||
|
intercept = lda.intercept_
|
||||||
|
|
||||||
|
if coef.shape[0] != n_cls:
|
||||||
|
# SVD solver returns (n_cls-1, n_feat); sklearn handles this internally
|
||||||
|
# via scalings_. We refit with 'lsqr' solver to get full coef matrix.
|
||||||
|
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA2
|
||||||
|
lda2 = LDA2(solver='lsqr')
|
||||||
|
# We can't refit here (no data) so just warn and pad with zeros
|
||||||
|
print(f" WARNING: {name} coef_ shape {coef.shape} != ({n_cls}, {feat_dim}). "
|
||||||
|
f"Padding with zeros. Refit with solver='lsqr' for full matrix.")
|
||||||
|
padded = np.zeros((n_cls, feat_dim))
|
||||||
|
padded[:coef.shape[0]] = coef
|
||||||
|
coef = padded
|
||||||
|
padded_i = np.zeros(n_cls)
|
||||||
|
padded_i[:intercept.shape[0]] = intercept
|
||||||
|
intercept = padded_i
|
||||||
|
|
||||||
|
lines = []
|
||||||
|
lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{")
|
||||||
|
for c in class_order:
|
||||||
|
row = ', '.join(f'{v:.8f}f' for v in coef[c])
|
||||||
|
lines.append(f" {{{row}}}, // {label_names[c]}")
|
||||||
|
lines.append("};")
|
||||||
|
lines.append(f"const float {name}_INTERCEPTS[{n_cls}] = {{")
|
||||||
|
intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order)
|
||||||
|
lines.append(f" {intercept_str}")
|
||||||
|
lines.append("};")
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class_order = list(range(n_cls))
|
||||||
|
out_path = Path(__file__).parent / 'EMG_Arm/src/core/model_weights_ensemble.h'
|
||||||
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
td_offset = min(td_idx) if td_idx else 0
|
||||||
|
fd_offset = min(fd_idx) if fd_idx else 0
|
||||||
|
cc_offset = min(cc_idx) if cc_idx else 0
|
||||||
|
|
||||||
|
with open(out_path, 'w') as f:
|
||||||
|
f.write("// Auto-generated by train_ensemble.py — do not edit\n")
|
||||||
|
f.write("#pragma once\n\n")
|
||||||
|
f.write("// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from\n")
|
||||||
|
f.write("// model_weights.h to avoid redefinition conflicts.\n")
|
||||||
|
f.write('#include "model_weights.h"\n\n')
|
||||||
|
f.write(f"#define ENSEMBLE_PER_CH_FEATURES 20\n\n")
|
||||||
|
f.write(f"#define TD_FEAT_OFFSET {td_offset}\n")
|
||||||
|
f.write(f"#define TD_NUM_FEATURES {len(td_idx)}\n")
|
||||||
|
f.write(f"#define FD_FEAT_OFFSET {fd_offset}\n")
|
||||||
|
f.write(f"#define FD_NUM_FEATURES {len(fd_idx)}\n")
|
||||||
|
f.write(f"#define CC_FEAT_OFFSET {cc_offset}\n")
|
||||||
|
f.write(f"#define CC_NUM_FEATURES {len(cc_idx)}\n")
|
||||||
|
f.write(f"#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)\n\n")
|
||||||
|
|
||||||
|
f.write("// Feature index arrays for gather operations (TD and FD are non-contiguous)\n")
|
||||||
|
f.write(f"// TD indices: {td_idx}\n")
|
||||||
|
f.write(f"// FD indices: {fd_idx}\n")
|
||||||
|
f.write(f"// CC indices: {cc_idx}\n\n")
|
||||||
|
|
||||||
|
f.write(lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order))
|
||||||
|
f.write('\n\n')
|
||||||
|
f.write(lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order))
|
||||||
|
f.write('\n\n')
|
||||||
|
f.write(lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order))
|
||||||
|
f.write('\n\n')
|
||||||
|
f.write(lda_to_c_arrays(meta_lda, 'META_LDA', 3 * n_cls, n_cls, label_names, class_order))
|
||||||
|
f.write('\n')
|
||||||
|
|
||||||
|
print(f"Exported ensemble weights to {out_path}")
|
||||||
|
print(f"Total weight storage: "
|
||||||
|
f"{(len(td_idx) + len(fd_idx) + len(cc_idx) + 3*n_cls) * n_cls * 4} bytes float32")
|
||||||
|
|
||||||
|
# --- Also save sklearn models for laptop-side inference ----------------------
|
||||||
|
import joblib
|
||||||
|
ensemble_bundle = {
|
||||||
|
'lda_td': lda_td,
|
||||||
|
'lda_fd': lda_fd,
|
||||||
|
'lda_cc': lda_cc,
|
||||||
|
'meta_lda': meta_lda,
|
||||||
|
'td_idx': td_idx,
|
||||||
|
'fd_idx': fd_idx,
|
||||||
|
'cc_idx': cc_idx,
|
||||||
|
'label_names': label_names,
|
||||||
|
}
|
||||||
|
ensemble_joblib = Path(__file__).parent / 'models' / 'emg_ensemble.joblib'
|
||||||
|
ensemble_joblib.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
joblib.dump(ensemble_bundle, ensemble_joblib)
|
||||||
|
print(f"Saved laptop ensemble model to {ensemble_joblib}")
|
||||||
106
train_mlp_tflite.py
Normal file
106
train_mlp_tflite.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""
|
||||||
|
Train int8 MLP for ESP32-S3 deployment via TFLite Micro.
|
||||||
|
Run AFTER Change 0 (label shift) + Change 1 (expanded features).
|
||||||
|
|
||||||
|
Change E — priority Tier 3.
|
||||||
|
Outputs: EMG_Arm/src/core/emg_model_data.cc
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
|
||||||
|
|
||||||
|
try:
|
||||||
|
import tensorflow as tf
|
||||||
|
except ImportError:
|
||||||
|
print("ERROR: TensorFlow not installed. Run: pip install tensorflow")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# --- Load and extract features -----------------------------------------------
|
||||||
|
storage = SessionStorage()
|
||||||
|
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
|
||||||
|
|
||||||
|
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
|
||||||
|
X = extractor.extract_features_batch(X_raw).astype(np.float32)
|
||||||
|
|
||||||
|
# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py)
|
||||||
|
for sid in np.unique(session_indices):
|
||||||
|
mask = session_indices == sid
|
||||||
|
X_sess = X[mask]
|
||||||
|
y_sess = y[mask]
|
||||||
|
class_means = [X_sess[y_sess == cls].mean(axis=0) for cls in np.unique(y_sess)]
|
||||||
|
balanced_mean = np.mean(class_means, axis=0)
|
||||||
|
std = X_sess.std(axis=0)
|
||||||
|
std[std < 1e-12] = 1.0
|
||||||
|
X[mask] = (X_sess - balanced_mean) / std
|
||||||
|
|
||||||
|
n_feat = X.shape[1]
|
||||||
|
n_cls = len(np.unique(y))
|
||||||
|
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
|
||||||
|
print(f"Classes: {label_names}")
|
||||||
|
|
||||||
|
# --- Build and train MLP -----------------------------------------------------
|
||||||
|
model = tf.keras.Sequential([
|
||||||
|
tf.keras.layers.Input(shape=(n_feat,)),
|
||||||
|
tf.keras.layers.Dense(32, activation='relu'),
|
||||||
|
tf.keras.layers.Dropout(0.2),
|
||||||
|
tf.keras.layers.Dense(16, activation='relu'),
|
||||||
|
tf.keras.layers.Dense(n_cls, activation='softmax'),
|
||||||
|
])
|
||||||
|
model.compile(optimizer='adam',
|
||||||
|
loss='sparse_categorical_crossentropy',
|
||||||
|
metrics=['accuracy'])
|
||||||
|
model.summary()
|
||||||
|
model.fit(X, y, epochs=150, batch_size=64, validation_split=0.1, verbose=1)
|
||||||
|
|
||||||
|
# --- Convert to int8 TFLite --------------------------------------------------
|
||||||
|
def representative_dataset():
|
||||||
|
for i in range(0, len(X), 10):
|
||||||
|
yield [X[i:i+1]]
|
||||||
|
|
||||||
|
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||||
|
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||||
|
converter.representative_dataset = representative_dataset
|
||||||
|
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||||
|
converter.inference_input_type = tf.int8
|
||||||
|
converter.inference_output_type = tf.int8
|
||||||
|
tflite_model = converter.convert()
|
||||||
|
|
||||||
|
# --- Write emg_model_data.cc -------------------------------------------------
|
||||||
|
out_cc = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.cc'
|
||||||
|
out_h = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.h'
|
||||||
|
|
||||||
|
with open(out_cc, 'w') as f:
|
||||||
|
f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
|
||||||
|
f.write('#include "emg_model_data.h"\n')
|
||||||
|
f.write(f'const int g_model_len = {len(tflite_model)};\n')
|
||||||
|
f.write('alignas(8) const unsigned char g_model[] = {\n ')
|
||||||
|
f.write(', '.join(f'0x{b:02x}' for b in tflite_model))
|
||||||
|
f.write('\n};\n')
|
||||||
|
|
||||||
|
with open(out_h, 'w') as f:
|
||||||
|
f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
|
||||||
|
f.write('#pragma once\n\n')
|
||||||
|
f.write('#ifdef __cplusplus\nextern "C" {\n#endif\n\n')
|
||||||
|
f.write('extern const unsigned char g_model[];\n')
|
||||||
|
f.write('extern const int g_model_len;\n\n')
|
||||||
|
f.write('#ifdef __cplusplus\n}\n#endif\n')
|
||||||
|
|
||||||
|
print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
|
||||||
|
print(f"Wrote {out_h}")
|
||||||
|
|
||||||
|
# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
|
||||||
|
layer_weights = [layer.get_weights() for layer in model.layers if layer.get_weights()]
|
||||||
|
mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
|
||||||
|
mlp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
np.savez(mlp_path,
|
||||||
|
w0=layer_weights[0][0], b0=layer_weights[0][1], # Dense(32, relu)
|
||||||
|
w1=layer_weights[1][0], b1=layer_weights[1][1], # Dense(16, relu)
|
||||||
|
w2=layer_weights[2][0], b2=layer_weights[2][1], # Dense(5, softmax)
|
||||||
|
label_names=np.array(label_names))
|
||||||
|
print(f"Saved laptop MLP weights to {mlp_path}")
|
||||||
|
|
||||||
|
print(f"\nNext steps:")
|
||||||
|
print(f" 1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
|
||||||
|
print(f" 2. Run: pio run -t upload")
|
||||||
Reference in New Issue
Block a user