This commit is contained in:
2026-01-27 21:31:49 -06:00
parent 9bdf9d1109
commit 31dede537c
7 changed files with 893 additions and 320 deletions

View File

@@ -20,6 +20,7 @@ set(DRIVER_SOURCES
set(CORE_SOURCES set(CORE_SOURCES
core/gestures.c core/gestures.c
core/inference.c
) )
set(APP_SOURCES set(APP_SOURCES

View File

@@ -12,17 +12,18 @@
* @note This is Layer 4 (Application). * @note This is Layer 4 (Application).
*/ */
#include "esp_timer.h"
#include <freertos/FreeRTOS.h>
#include <freertos/queue.h>
#include <freertos/task.h>
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <freertos/FreeRTOS.h>
#include <freertos/task.h>
#include <freertos/queue.h>
#include "esp_timer.h"
#include "config/config.h" #include "config/config.h"
#include "drivers/hand.h"
#include "drivers/emg_sensor.h"
#include "core/gestures.h" #include "core/gestures.h"
#include "core/inference.h" // [NEW]
#include "drivers/emg_sensor.h"
#include "drivers/hand.h"
/******************************************************************************* /*******************************************************************************
* Constants * Constants
@@ -39,20 +40,22 @@
* @brief Device state machine. * @brief Device state machine.
*/ */
typedef enum { typedef enum {
STATE_IDLE = 0, /**< Waiting for connect command */ STATE_IDLE = 0, /**< Waiting for connect command */
STATE_CONNECTED, /**< Connected, waiting for start command */ STATE_CONNECTED, /**< Connected, waiting for start command */
STATE_STREAMING, /**< Actively streaming EMG data */ STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */
STATE_PREDICTING, /**< [NEW] On-device inference and control */
} device_state_t; } device_state_t;
/** /**
* @brief Commands from host. * @brief Commands from host.
*/ */
typedef enum { typedef enum {
CMD_NONE = 0, CMD_NONE = 0,
CMD_CONNECT, CMD_CONNECT,
CMD_START, CMD_START, /**< Start raw streaming */
CMD_STOP, CMD_START_PREDICT, /**< [NEW] Start on-device prediction */
CMD_DISCONNECT, CMD_STOP,
CMD_DISCONNECT,
} command_t; } command_t;
/******************************************************************************* /*******************************************************************************
@@ -80,38 +83,39 @@ static void send_ack_connect(void);
* @param line Input line from serial * @param line Input line from serial
* @return Parsed command * @return Parsed command
*/ */
static command_t parse_command(const char* line) static command_t parse_command(const char *line) {
{ /* Simple JSON parsing - look for "cmd" field */
/* Simple JSON parsing - look for "cmd" field */ const char *cmd_start = strstr(line, "\"cmd\"");
const char* cmd_start = strstr(line, "\"cmd\""); if (!cmd_start) {
if (!cmd_start) {
return CMD_NONE;
}
/* Find the value after "cmd": */
const char* value_start = strchr(cmd_start, ':');
if (!value_start) {
return CMD_NONE;
}
/* Skip whitespace and opening quote */
value_start++;
while (*value_start == ' ' || *value_start == '"') {
value_start++;
}
/* Match command strings */
if (strncmp(value_start, "connect", 7) == 0) {
return CMD_CONNECT;
} else if (strncmp(value_start, "start", 5) == 0) {
return CMD_START;
} else if (strncmp(value_start, "stop", 4) == 0) {
return CMD_STOP;
} else if (strncmp(value_start, "disconnect", 10) == 0) {
return CMD_DISCONNECT;
}
return CMD_NONE; return CMD_NONE;
}
/* Find the value after "cmd": */
const char *value_start = strchr(cmd_start, ':');
if (!value_start) {
return CMD_NONE;
}
/* Skip whitespace and opening quote */
value_start++;
while (*value_start == ' ' || *value_start == '"') {
value_start++;
}
/* Match command strings */
if (strncmp(value_start, "connect", 7) == 0) {
return CMD_CONNECT;
} else if (strncmp(value_start, "start_predict", 13) == 0) {
return CMD_START_PREDICT;
} else if (strncmp(value_start, "start", 5) == 0) {
return CMD_START;
} else if (strncmp(value_start, "stop", 4) == 0) {
return CMD_STOP;
} else if (strncmp(value_start, "disconnect", 10) == 0) {
return CMD_DISCONNECT;
}
return CMD_NONE;
} }
/******************************************************************************* /*******************************************************************************
@@ -120,224 +124,213 @@ static command_t parse_command(const char* line)
/** /**
* @brief FreeRTOS task to read serial input and parse commands. * @brief FreeRTOS task to read serial input and parse commands.
*
* This task runs continuously, reading lines from stdin (USB serial)
* and updating device state directly. This allows commands to interrupt
* streaming immediately via the volatile state variable.
*
* @param pvParameters Unused
*/ */
static void serial_input_task(void* pvParameters) static void serial_input_task(void *pvParameters) {
{ char line_buffer[CMD_BUFFER_SIZE];
char line_buffer[CMD_BUFFER_SIZE]; int line_idx = 0;
int line_idx = 0;
while (1) { while (1) {
/* Read one character at a time */ int c = getchar();
int c = getchar();
if (c == EOF || c == 0xFF) { if (c == EOF || c == 0xFF) {
/* No data available, yield to other tasks */ vTaskDelay(pdMS_TO_TICKS(10));
vTaskDelay(pdMS_TO_TICKS(10)); continue;
continue;
}
if (c == '\n' || c == '\r') {
/* End of line - process command */
if (line_idx > 0) {
line_buffer[line_idx] = '\0';
command_t cmd = parse_command(line_buffer);
if (cmd != CMD_NONE) {
/* Handle CONNECT command from ANY state (for reconnection/recovery) */
if (cmd == CMD_CONNECT) {
/* Stop streaming if active, reset to CONNECTED state */
g_device_state = STATE_CONNECTED;
send_ack_connect();
printf("[STATE] ANY -> CONNECTED (reconnect)\n");
}
/* Handle other state transitions */
else {
switch (g_device_state) {
case STATE_IDLE:
/* Only CONNECT allowed from IDLE (handled above) */
break;
case STATE_CONNECTED:
if (cmd == CMD_START) {
g_device_state = STATE_STREAMING;
printf("[STATE] CONNECTED -> STREAMING\n");
/* Signal state machine to start streaming */
xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_DISCONNECT) {
g_device_state = STATE_IDLE;
printf("[STATE] CONNECTED -> IDLE\n");
}
break;
case STATE_STREAMING:
if (cmd == CMD_STOP) {
g_device_state = STATE_CONNECTED;
printf("[STATE] STREAMING -> CONNECTED\n");
/* Streaming loop will exit when it sees state change */
} else if (cmd == CMD_DISCONNECT) {
g_device_state = STATE_IDLE;
printf("[STATE] STREAMING -> IDLE\n");
/* Streaming loop will exit when it sees state change */
}
break;
}
}
}
line_idx = 0;
}
} else if (line_idx < CMD_BUFFER_SIZE - 1) {
/* Add character to buffer */
line_buffer[line_idx++] = (char)c;
} else {
/* Buffer overflow - reset */
line_idx = 0;
}
} }
if (c == '\n' || c == '\r') {
if (line_idx > 0) {
line_buffer[line_idx] = '\0';
command_t cmd = parse_command(line_buffer);
if (cmd != CMD_NONE) {
if (cmd == CMD_CONNECT) {
g_device_state = STATE_CONNECTED;
send_ack_connect();
printf("[STATE] ANY -> CONNECTED (reconnect)\n");
} else {
switch (g_device_state) {
case STATE_IDLE:
break;
case STATE_CONNECTED:
if (cmd == CMD_START) {
g_device_state = STATE_STREAMING;
printf("[STATE] CONNECTED -> STREAMING\n");
xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_START_PREDICT) {
g_device_state = STATE_PREDICTING;
printf("[STATE] CONNECTED -> PREDICTING\n");
xQueueSend(g_cmd_queue, &cmd, 0);
} else if (cmd == CMD_DISCONNECT) {
g_device_state = STATE_IDLE;
printf("[STATE] CONNECTED -> IDLE\n");
}
break;
case STATE_STREAMING:
case STATE_PREDICTING:
if (cmd == CMD_STOP) {
g_device_state = STATE_CONNECTED;
printf("[STATE] ACTIVE -> CONNECTED\n");
} else if (cmd == CMD_DISCONNECT) {
g_device_state = STATE_IDLE;
printf("[STATE] ACTIVE -> IDLE\n");
}
break;
}
}
}
line_idx = 0;
}
} else if (line_idx < CMD_BUFFER_SIZE - 1) {
line_buffer[line_idx++] = (char)c;
} else {
line_idx = 0;
}
}
} }
/******************************************************************************* /*******************************************************************************
* State Machine * State Machine
******************************************************************************/ ******************************************************************************/
/** static void send_ack_connect(void) {
* @brief Send JSON acknowledgment for connection. printf(
*/ "{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
static void send_ack_connect(void) EMG_NUM_CHANNELS);
{ fflush(stdout);
printf("{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
EMG_NUM_CHANNELS);
fflush(stdout);
} }
/** /**
* @brief Stream EMG data continuously until stopped. * @brief Stream raw EMG data (Training Mode).
*
* This function blocks and streams data at the configured sample rate.
* Returns when state changes from STREAMING.
*/ */
static void stream_emg_data(void) static void stream_emg_data(void) {
{ emg_sample_t sample;
emg_sample_t sample; const TickType_t delay_ticks = 1;
const TickType_t delay_ticks = 1; /* 1 tick = 1ms at 1000 Hz tick rate */
while (g_device_state == STATE_STREAMING) { while (g_device_state == STATE_STREAMING) {
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
emg_sensor_read(&sample);
/* Output in CSV format - channels only, Python handles timestamps */
printf("%u,%u,%u,%u\n",
sample.channels[0],
sample.channels[1],
sample.channels[2],
sample.channels[3]);
/* Yield to FreeRTOS scheduler - prevents watchdog timeout */
vTaskDelay(delay_ticks);
}
}
/**
* @brief Main state machine loop.
*
* Monitors device state and starts streaming when requested.
* Serial input task handles all state transitions directly.
*/
static void state_machine_loop(void)
{
command_t cmd;
const TickType_t poll_interval = pdMS_TO_TICKS(50);
while (1) {
/* Check if we should start streaming */
if (g_device_state == STATE_STREAMING) {
/* Stream until state changes (via serial input task) */
stream_emg_data();
/* Returns when state is no longer STREAMING */
}
/* Wait for start command or just poll state */
/* Timeout allows checking state even if queue is empty */
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
/* Note: State transitions are handled by serial_input_task */
/* This loop only triggers streaming when state becomes STREAMING */
}
}
void emgPrinter() {
TickType_t previousWake = xTaskGetTickCount();
while (1) {
emg_sample_t sample;
emg_sensor_read(&sample); emg_sensor_read(&sample);
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) { printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
printf("%d", sample.channels[i]); sample.channels[2], sample.channels[3]);
if (i != EMG_NUM_CHANNELS - 1) printf(" | "); vTaskDelay(delay_ticks);
}
}
/**
* @brief Run on-device inference (Prediction Mode).
*/
static void run_inference_loop(void) {
emg_sample_t sample;
const TickType_t delay_ticks = 1; // 1ms @ 1kHz
int last_gesture = -1;
// Reset inference state
inference_init();
printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n");
while (g_device_state == STATE_PREDICTING) {
emg_sensor_read(&sample);
// Add to buffer
// Note: sample.channels is uint16_t, matching inference engine expectation
if (inference_add_sample(sample.channels)) {
// Buffer full (sliding window), run prediction
// We can optimize stride here (e.g. valid prediction only every N
// samples) For now, let's predict every sample (sliding window) or
// throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz?
// maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid
// UART spam.
static int stride_counter = 0;
stride_counter++;
if (stride_counter >= 20) { // 20ms stride
float confidence = 0;
int gesture_idx = inference_predict(&confidence);
stride_counter = 0;
if (gesture_idx >= 0) {
// Map class index (0-N) to gesture enum (correct hardware action)
int gesture_enum = inference_get_gesture_enum(gesture_idx);
// Execute gesture on hand
gestures_execute((gesture_t)gesture_enum);
// Send telemetry if changed or periodically?
// "Live prediction flow should change to only have each new output...
// sent"
if (gesture_idx != last_gesture) {
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
inference_get_class_name(gesture_idx), confidence);
last_gesture = gesture_idx;
}
}
}
} }
printf("\n");
// vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100)); vTaskDelay(delay_ticks);
}
}
static void state_machine_loop(void) {
command_t cmd;
const TickType_t poll_interval = pdMS_TO_TICKS(50);
while (1) {
if (g_device_state == STATE_STREAMING) {
stream_emg_data();
} else if (g_device_state == STATE_PREDICTING) {
run_inference_loop();
}
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
} }
} }
void appConnector() { void appConnector() {
/* Create command queue */ g_cmd_queue = xQueueCreate(10, sizeof(command_t));
g_cmd_queue = xQueueCreate(10, sizeof(command_t)); if (g_cmd_queue == NULL) {
if (g_cmd_queue == NULL) { printf("[ERROR] Failed to create command queue!\n");
printf("[ERROR] Failed to create command queue!\n"); return;
return; }
}
/* Launch serial input task */ xTaskCreate(serial_input_task, "serial_input", 4096, NULL, 5, NULL);
xTaskCreate(
serial_input_task,
"serial_input",
4096, /* Stack size */
NULL, /* Parameters */
5, /* Priority */
NULL /* Task handle */
);
printf("[PROTOCOL] Waiting for host to connect...\n"); printf("[PROTOCOL] Waiting for host to connect...\n");
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n\n"); printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
"inference\n\n");
/* Run main state machine */ state_machine_loop();
state_machine_loop();
} }
/******************************************************************************* /*******************************************************************************
* Application Entry Point * Application Entry Point
******************************************************************************/ ******************************************************************************/
void app_main(void) void app_main(void) {
{ printf("\n");
printf("\n"); printf("========================================\n");
printf("========================================\n"); printf(" Bucky Arm - EMG Robotic Hand\n");
printf(" Bucky Arm - EMG Robotic Hand\n"); printf(" Firmware v2.1.0 (Inference Enabled)\n");
printf(" Firmware v2.0.0 (Handshake Protocol)\n"); printf("========================================\n\n");
printf("========================================\n\n");
/* Initialize subsystems */ printf("[INIT] Initializing hand (servos)...\n");
printf("[INIT] Initializing hand (servos)...\n"); hand_init();
hand_init();
printf("[INIT] Initializing EMG sensor...\n"); printf("[INIT] Initializing EMG sensor...\n");
emg_sensor_init(); emg_sensor_init();
printf("[INIT] Initializing Inference Engine...\n");
inference_init();
#if FEATURE_FAKE_EMG #if FEATURE_FAKE_EMG
printf("[INIT] Using FAKE EMG data (sensors not connected)\n"); printf("[INIT] Using FAKE EMG data (sensors not connected)\n");
#else #else
printf("[INIT] Using REAL EMG sensors\n"); printf("[INIT] Using REAL EMG sensors\n");
#endif #endif
printf("[INIT] Done!\n\n"); printf("[INIT] Done!\n\n");
// emgPrinter(); appConnector();
appConnector();
} }

View File

@@ -0,0 +1,281 @@
/**
* @file inference.c
* @brief Implementation of EMG inference engine.
*/
#include "inference.h"
#include "config/config.h"
#include "model_weights.h"
#include <math.h>
#include <stdio.h>
#include <string.h>
// --- Constants ---
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
#define VOTE_WINDOW 5         // Majority vote window size (predictions)
#define DEBOUNCE_COUNT 3      // Confirmations needed to change output
// --- State ---
// Circular buffer of raw ADC samples: one row per time step, one column per
// channel (150 x 4 x 2 bytes = 1200 bytes of static RAM).
static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
static int buffer_head = 0;       // Next write slot; also oldest sample once full
static int samples_collected = 0; // Saturates at INFERENCE_WINDOW_SIZE
// Smoothing State
static float smoothed_probs[MODEL_NUM_CLASSES]; // EMA of per-class probabilities
static int vote_history[VOTE_WINDOW]; // Ring of recent winners; -1 = empty slot
static int vote_head = 0;             // Next write slot in vote_history
static int current_output = -1; // Debounced class currently emitted (-1 = none)
static int pending_output = -1; // Challenger class awaiting confirmation
static int pending_count = 0;   // Consecutive confirmations for the challenger
/**
 * @brief Reset all inference state: window buffer, probability EMA,
 *        majority-vote ring and debounce bookkeeping.
 *
 * Call at the start of every prediction session so stale samples from a
 * previous session cannot leak into the first window of the new one.
 */
void inference_init(void) {
  /* Empty the sliding-window sample buffer. */
  memset(window_buffer, 0, sizeof(window_buffer));
  samples_collected = 0;
  buffer_head = 0;

  /* Start the probability EMA from a uniform distribution. */
  const float uniform = 1.0f / MODEL_NUM_CLASSES;
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    smoothed_probs[c] = uniform;
  }

  /* -1 marks "no vote yet" slots in the majority-vote ring. */
  for (int v = 0; v < VOTE_WINDOW; v++) {
    vote_history[v] = -1;
  }
  vote_head = 0;

  /* Debounce state: nothing emitted, nothing pending. */
  current_output = -1;
  pending_output = -1;
  pending_count = 0;
}
/**
 * @brief Push one multi-channel sample into the circular window buffer.
 *
 * @param channels Array of NUM_CHANNELS raw ADC values (not modified).
 * @return true once the buffer holds a full INFERENCE_WINDOW_SIZE-sample
 *         window (sliding window ready for prediction); false while the
 *         buffer is still warming up.
 */
bool inference_add_sample(uint16_t *channels) {
  /* Overwrite the oldest slot, then advance the write head. */
  for (int i = 0; i < NUM_CHANNELS; i++) {
    window_buffer[buffer_head][i] = channels[i];
  }
  buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;

  if (samples_collected < INFERENCE_WINDOW_SIZE) {
    samples_collected++;
  }
  /* Bug fix: the original returned false on the sample that completed the
   * first window and only reported true from the next sample onward.
   * Report readiness as soon as the window first fills. */
  return samples_collected >= INFERENCE_WINDOW_SIZE;
}
// --- Feature Extraction ---
/**
 * @brief Compute time-domain EMG features over the buffered window.
 *
 * For each of the NUM_CHANNELS channels the window is unrolled from the
 * circular buffer, mean-centered, and four classic features are computed:
 * RMS (amplitude), Waveform Length (WL), Zero Crossings (ZC) and Slope
 * Sign Changes (SSC). ZC/SSC use RMS-relative noise thresholds so the
 * counts are robust to the signal's overall amplitude.
 *
 * @param features_out Output array of MODEL_NUM_FEATURES floats, laid out
 *        as [RMS, WL, ZC, SSC] per channel, channel-major (ch * 4 + k).
 *        Must therefore hold at least NUM_CHANNELS * 4 entries — assumes
 *        MODEL_NUM_FEATURES == NUM_CHANNELS * 4; TODO confirm against the
 *        exported model_weights.h.
 */
static void compute_features(float *features_out) {
  // Process each channel independently.
  // We iterate over the logical window (unrolling the circular buffer into
  // chronological order, oldest sample first).
  for (int ch = 0; ch < NUM_CHANNELS; ch++) {
    float sum = 0;
    float sq_sum = 0;
    // Pass 1: copy into a linear scratch array and accumulate the sum for
    // the mean. Copying costs 600 bytes of stack (150 floats) but keeps the
    // index arithmetic in the feature loops simple and safe.
    float signal[INFERENCE_WINDOW_SIZE];
    int idx = buffer_head; // buffer_head points at the oldest sample
    for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
      signal[i] = (float)window_buffer[idx][ch];
      sum += signal[i];
      idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
    }
    float mean = sum / INFERENCE_WINDOW_SIZE;
    // Pass 2: center the signal (remove DC offset) and compute features.
    float wl = 0;
    int zc = 0;
    int ssc = 0;
    // Center the signal; accumulate the sum of squares for RMS.
    for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
      signal[i] -= mean;
      sq_sum += signal[i] * signal[i];
    }
    float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
    // Noise thresholds scale with the channel's RMS amplitude.
    float zc_thresh = FEAT_ZC_THRESH * rms;
    float ssc_thresh = (FEAT_SSC_THRESH * rms) *
                       (FEAT_SSC_THRESH * rms); // threshold is on diff product
    for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) {
      // WL: cumulative absolute first difference.
      wl += fabsf(signal[i + 1] - signal[i]);
      // ZC: sign change between consecutive samples, counted only when the
      // step is large enough to exceed the noise threshold.
      if ((signal[i] > 0 && signal[i + 1] < 0) ||
          (signal[i] < 0 && signal[i + 1] > 0)) {
        if (fabsf(signal[i] - signal[i + 1]) > zc_thresh) {
          zc++;
        }
      }
      // SSC: slope sign change, needs 3 consecutive points so it stops at
      // index N-2. diff1*diff2 > 0 means the middle point is a local
      // extremum; the squared threshold filters out noise-level wiggles.
      if (i < INFERENCE_WINDOW_SIZE - 2) {
        float diff1 = signal[i + 1] - signal[i];
        float diff2 = signal[i + 1] - signal[i + 2];
        if ((diff1 * diff2) > ssc_thresh) {
          ssc++;
        }
      }
    }
    // Store features: [RMS, WL, ZC, SSC] per channel, channel-major.
    int base = ch * 4;
    features_out[base + 0] = rms;
    features_out[base + 1] = wl;
    features_out[base + 2] = (float)zc;
    features_out[base + 3] = (float)ssc;
  }
}
// --- Prediction ---
/**
 * @brief Run one inference pass over the current sample window.
 *
 * Pipeline: feature extraction -> LDA discriminant scores -> softmax
 * probabilities -> probability EMA -> majority vote over recent winners
 * -> debounce. The debounced output only switches class after
 * DEBOUNCE_COUNT consecutive confirmations, suppressing output twitching.
 *
 * @param confidence Out: fraction of the vote window held by the majority
 *        class (0.0 - 1.0). Written only on a successful prediction.
 * @return Stable (debounced) class index, or -1 if the sample window is
 *         not yet full.
 */
int inference_predict(float *confidence) {
  if (samples_collected < INFERENCE_WINDOW_SIZE) {
    return -1; /* Not enough data buffered yet. */
  }

  /* 1. Extract features from the buffered window. */
  float features[MODEL_NUM_FEATURES];
  compute_features(features);

  /* 2. LDA: raw discriminant score per class (intercept + dot product). */
  float raw_scores[MODEL_NUM_CLASSES];
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    float score = LDA_INTERCEPTS[c];
    for (int f = 0; f < MODEL_NUM_FEATURES; f++) {
      score += features[f] * LDA_WEIGHTS[c][f];
    }
    raw_scores[c] = score;
  }

  /* Softmax over the scores. LDA scores are log-likelihoods up to a
   * constant, so softmax is appropriate. Subtract the max score first for
   * numerical stability (avoids expf overflow). Note: seeded from
   * raw_scores[0] rather than a -1e9 sentinel, which was wrong for
   * arbitrarily negative scores; the unused max_idx local is also gone. */
  float max_score = raw_scores[0];
  for (int c = 1; c < MODEL_NUM_CLASSES; c++) {
    if (raw_scores[c] > max_score) {
      max_score = raw_scores[c];
    }
  }
  float probas[MODEL_NUM_CLASSES];
  float sum_exp = 0;
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    probas[c] = expf(raw_scores[c] - max_score);
    sum_exp += probas[c];
  }
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    probas[c] /= sum_exp;
  }

  /* 3a. Exponential moving average of class probabilities; track the class
   * with the highest smoothed probability. */
  float max_smoothed_prob = 0;
  int smoothed_winner = 0;
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    smoothed_probs[c] = (SMOOTHING_FACTOR * smoothed_probs[c]) +
                        ((1.0f - SMOOTHING_FACTOR) * probas[c]);
    if (smoothed_probs[c] > max_smoothed_prob) {
      max_smoothed_prob = smoothed_probs[c];
      smoothed_winner = c;
    }
  }

  /* 3b. Majority vote over the last VOTE_WINDOW smoothed winners. */
  vote_history[vote_head] = smoothed_winner;
  vote_head = (vote_head + 1) % VOTE_WINDOW;

  int counts[MODEL_NUM_CLASSES];
  memset(counts, 0, sizeof(counts));
  for (int i = 0; i < VOTE_WINDOW; i++) {
    if (vote_history[i] != -1) { /* -1 = slot not yet filled */
      counts[vote_history[i]]++;
    }
  }
  int majority_winner = 0;
  int majority_count = 0;
  for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
    if (counts[c] > majority_count) {
      majority_count = counts[c];
      majority_winner = c;
    }
  }

  /* 3c. Debounce: the emitted class only switches after the challenger has
   * won the majority vote DEBOUNCE_COUNT times in a row. */
  int final_result = current_output;
  if (current_output == -1) {
    /* First ever prediction: accept immediately. */
    current_output = majority_winner;
    pending_output = majority_winner;
    pending_count = 1;
    final_result = majority_winner;
  } else if (majority_winner == current_output) {
    /* Winner agrees with current output: reset any pending challenger. */
    pending_output = majority_winner;
    pending_count = 1;
  } else if (majority_winner == pending_output) {
    /* Challenger confirmed again; switch once it hits the threshold. */
    pending_count++;
    if (pending_count >= DEBOUNCE_COUNT) {
      current_output = majority_winner;
      final_result = majority_winner;
    }
  } else {
    /* New challenger: start a fresh confirmation streak. */
    pending_output = majority_winner;
    pending_count = 1;
  }

  /* Confidence = fraction of the vote window held by the majority class. */
  *confidence = (float)majority_count / VOTE_WINDOW;
  return final_result;
}
/**
 * @brief Look up the display name for a class index.
 *
 * @param class_idx Class index; out-of-range values (including the -1
 *        "no output" sentinel) are tolerated.
 * @return Static class-name string, or "UNKNOWN" for invalid indices.
 */
const char *inference_get_class_name(int class_idx) {
  const bool in_range = (class_idx >= 0) && (class_idx < MODEL_NUM_CLASSES);
  return in_range ? MODEL_CLASS_NAMES[class_idx] : "UNKNOWN";
}
/**
 * @brief Map a model class index to the hardware gesture_t enum value.
 *
 * Class names may be exported by the Python tooling in either lowercase
 * ("rest") or uppercase ("REST"); both exact spellings are accepted per
 * gesture, nothing else (no general case folding).
 *
 * @param class_idx Class index from inference_predict().
 * @return Matching gesture enum value, or GESTURE_NONE if the name is
 *         unrecognized (including the "UNKNOWN" sentinel).
 */
int inference_get_gesture_enum(int class_idx) {
  /* Table pairing both accepted spellings with the hardware action. */
  static const struct {
    const char *lower;
    const char *upper;
    int gesture;
  } name_map[] = {
      {"rest", "REST", GESTURE_REST},
      {"fist", "FIST", GESTURE_FIST},
      {"open", "OPEN", GESTURE_OPEN},
      {"hook_em", "HOOK_EM", GESTURE_HOOK_EM},
      {"thumbs_up", "THUMBS_UP", GESTURE_THUMBS_UP},
  };

  const char *name = inference_get_class_name(class_idx);
  for (size_t i = 0; i < sizeof(name_map) / sizeof(name_map[0]); i++) {
    if (strcmp(name, name_map[i].lower) == 0 ||
        strcmp(name, name_map[i].upper) == 0) {
      return name_map[i].gesture;
    }
  }
  return GESTURE_NONE;
}

View File

@@ -0,0 +1,47 @@
/**
 * @file inference.h
 * @brief On-device inference engine for EMG gesture recognition.
 *
 * Usage: call inference_init() once per session, feed samples through
 * inference_add_sample(), and call inference_predict() whenever it reports
 * a full window. Not thread-safe: all functions share static state and
 * must be called from a single task.
 */
#ifndef INFERENCE_H
#define INFERENCE_H
#include <stdbool.h>
#include <stdint.h>
// --- Configuration ---
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python)
#define NUM_CHANNELS 4            // Number of EMG channels
/**
 * @brief Initialize (or reset) the inference engine.
 *
 * Clears the sample window and all smoothing/debounce state. Call before
 * each prediction session.
 */
void inference_init(void);
/**
 * @brief Add a sample to the inference buffer.
 *
 * @param channels Array of NUM_CHANNELS channel values (raw ADC)
 * @return true if a full window is ready for processing; false while the
 *         window is still filling after init
 */
bool inference_add_sample(uint16_t *channels);
/**
 * @brief Run inference on the current window.
 *
 * Applies feature extraction, the LDA model, and temporal smoothing, so
 * the returned class is the stable (debounced) output, not the raw
 * per-window winner.
 *
 * @param confidence Output pointer for confidence score (0.0 - 1.0)
 * @return Detected class index, or -1 if the window is not yet full
 */
int inference_predict(float *confidence);
/**
 * @brief Get the name of a class index ("UNKNOWN" if out of range).
 */
const char *inference_get_class_name(int class_idx);
/**
 * @brief Map class index to gesture_t enum (GESTURE_NONE if unmatched).
 */
int inference_get_gesture_enum(int class_idx);
#endif /* INFERENCE_H */

View File

@@ -0,0 +1,44 @@
/**
 * @file model_weights.h
 * @brief Placeholder for trained model weights.
 *
 * This file should be generated by the Python script (Training Page -> Export for ESP32).
 *
 * WARNING: with the all-zero placeholder weights below, every class gets an
 * identical discriminant score, so predictions are meaningless until a real
 * trained export overwrites this file.
 */
#ifndef MODEL_WEIGHTS_H
#define MODEL_WEIGHTS_H
#include <stdint.h>
/* Metadata */
#define MODEL_NUM_CLASSES 5
#define MODEL_NUM_FEATURES 16 /* 4 channels x 4 features (RMS, WL, ZC, SSC) */
/* Class Names — order defines the class indices used by the inference
 * engine; must match the label order used during training. */
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
    "REST",
    "OPEN",
    "FIST",
    "HOOK_EM",
    "THUMBS_UP",
};
/* Feature Extractor Parameters (RMS-relative noise thresholds for the
 * zero-crossing and slope-sign-change counters; must match Python). */
#define FEAT_ZC_THRESH 0.1f
#define FEAT_SSC_THRESH 0.1f
/* LDA Intercepts/Biases — one per class. */
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
    0.0f, 0.0f, 0.0f, 0.0f, 0.0f
};
/* LDA Coefficients (Weights) — rows are classes, columns are features.
 * Partial initializers zero-fill the remaining elements of each row. */
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
    {0.0f},
    {0.0f},
    {0.0f},
    {0.0f},
    {0.0f}
};
#endif /* MODEL_WEIGHTS_H */

View File

@@ -1322,6 +1322,18 @@ class TrainingPage(BasePage):
) )
self.train_button.pack(pady=20) self.train_button.pack(pady=20)
# Export button
self.export_button = ctk.CTkButton(
self.content,
text="Export for ESP32",
font=ctk.CTkFont(size=14),
height=40,
fg_color="green",
state="disabled",
command=self.export_model
)
self.export_button.pack(pady=5)
# Progress # Progress
self.progress_bar = ctk.CTkProgressBar(self.content, width=400) self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
self.progress_bar.pack(pady=10) self.progress_bar.pack(pady=10)
@@ -1434,6 +1446,32 @@ class TrainingPage(BasePage):
finally: finally:
self.after(0, lambda: self.train_button.configure(state="normal")) self.after(0, lambda: self.train_button.configure(state="normal"))
self.after(0, lambda: self.export_button.configure(state="normal"))
def export_model(self):
"""Export trained model to C header."""
if not self.classifier or not self.classifier.is_trained:
messagebox.showerror("Error", "No trained model to export!")
return
# Default path in ESP32 project
default_path = Path("EMG_Arm/src/core/model_weights.h").absolute()
# Ask user for location, defaulting to the ESP32 project source
filename = tk.filedialog.asksaveasfilename(
title="Export Model Header",
initialdir=default_path.parent,
initialfile=default_path.name,
filetypes=[("C Header", "*.h")]
)
if filename:
try:
path = self.classifier.export_to_header(filename)
self._log(f"\nExported model to: {path}")
messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.")
except Exception as e:
messagebox.showerror("Export Error", f"Failed to export: {e}")
def _log(self, text: str): def _log(self, text: str):
"""Add text to results.""" """Add text to results."""
@@ -1861,115 +1899,158 @@ class PredictionPage(BasePage):
self._update_connection_status("gray", "Disconnected") self._update_connection_status("gray", "Disconnected")
self.connect_button.configure(text="Connect") self.connect_button.configure(text="Connect")
def _prediction_thread(self): def toggle_prediction(self):
"""Background prediction thread.""" """Start or stop prediction."""
# For simulated mode, create new stream if self.is_predicting:
if not self.using_real_hardware: self.stop_prediction()
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) else:
self.start_prediction()
# Stream is already started (either via handshake for real HW or will be started for simulated) def start_prediction(self):
"""Start live prediction."""
# Determine mode
self.using_real_hardware = (self.source_var.get() == "real")
if self.using_real_hardware:
if not self.is_connected or not self.stream:
messagebox.showerror("Not Connected", "Please connect to ESP32 first.")
return
print("[DEBUG] Starting Edge Prediction (On-Device)...")
try:
# Send "start_predict" command to ESP32
if hasattr(self.stream, 'ser'):
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
self.stream.running = True
else:
# Fallback
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
self.stream.running = True
except Exception as e:
messagebox.showerror("Start Error", f"Failed to start: {e}")
return
else:
# Simulated - use PC-side inference
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
self.stream.start()
# Load model for PC-side (Simulated) OR for display (optional)
# Even for Edge, we might want the label list.
if not self.using_real_hardware:
if not self.classifier:
model_path = EMGClassifier.get_default_model_path()
if model_path.exists():
self.classifier = EMGClassifier.load(model_path)
self.model_label.configure(text="Model: Loaded", text_color="green")
else:
self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange")
# Reset smoother
self.smoother = PredictionSmoother(
label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"],
probability_smoothing=0.7,
majority_vote_window=5,
debounce_count=3
)
self.is_predicting = True
self.start_button.configure(text="Stop Prediction", fg_color="red")
# Start display loop
self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True)
self.prediction_thread.start()
self.update_prediction_ui()
def stop_prediction(self):
"""Stop prediction."""
self.is_predicting = False
if self.stream:
self.stream.stop() # Sends "stop" usually
if not self.using_real_hardware:
self.stream = None
self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
self.prediction_label.configure(text="---", text_color="gray")
self.confidence_label.configure(text="Confidence: ---%")
self.confidence_bar.set(0)
def prediction_loop(self):
"""Loop for reading data and (optionally) running inference."""
import json
parser = EMGParser(num_channels=NUM_CHANNELS) parser = EMGParser(num_channels=NUM_CHANNELS)
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0) windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
# Simulated gesture cycling (only for simulated mode)
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
gesture_idx = 0
gesture_duration = 2.5
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[0]
# Start simulated stream if needed
if not self.using_real_hardware:
try:
if hasattr(self.stream, 'set_gesture'):
self.stream.set_gesture(current_gesture)
self.stream.start()
except Exception as e:
self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
return
else:
# Real hardware is already streaming
self.data_queue.put(('connection_status', ('green', 'Streaming')))
while self.is_predicting: while self.is_predicting:
# Change simulated gesture periodically (only for simulated mode)
if hasattr(self.stream, 'set_gesture'):
elapsed = time.perf_counter() - gesture_start
if elapsed > gesture_duration:
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[gesture_idx]
self.stream.set_gesture(current_gesture)
self.data_queue.put(('sim_gesture', current_gesture))
# Read and process
try: try:
line = self.stream.readline() line = self.stream.readline()
if not line:
continue
if self.using_real_hardware:
# Edge Inference Mode: Expect JSON
try:
line = line.strip()
if line.startswith('{'):
data = json.loads(line)
if "gesture" in data:
# Update UI with Edge Prediction
gesture = data["gesture"]
conf = float(data.get("conf", 0.0))
self.data_queue.put(('prediction', (gesture, conf)))
elif "status" in data:
print(f"[ESP32] {data}")
else:
pass
except json.JSONDecodeError:
pass
else:
# PC Side Inference (Simulated)
sample = parser.parse_line(line)
if sample:
window = windower.add_sample(sample)
if window and self.classifier:
# Run Inference Local
raw_label, proba = self.classifier.predict(window.to_numpy())
label, conf, _ = self.smoother.update(raw_label, proba)
self.data_queue.put(('prediction', (label, conf)))
except Exception as e: except Exception as e:
# Only report error if we didn't intentionally stop
if self.is_predicting: if self.is_predicting:
self.data_queue.put(('error', f"Serial read error: {e}")) print(f"Prediction loop error: {e}")
break break
if line:
sample = parser.parse_line(line)
if sample:
window = windower.add_sample(sample)
if window:
# Get raw prediction
window_data = window.to_numpy()
raw_label, proba = self.classifier.predict(window_data)
raw_confidence = max(proba) * 100
# Apply smoothing
smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba)
smoothed_confidence = smoothed_conf * 100
# Send both raw and smoothed to UI
self.data_queue.put(('prediction', (
smoothed_label, # The stable output
smoothed_confidence,
raw_label, # The raw (possibly twitchy) output
raw_confidence,
)))
# Safe cleanup - stream might already be stopped
try:
if self.stream:
self.stream.stop()
except Exception:
pass # Ignore cleanup errors
def update_prediction_ui(self): def update_prediction_ui(self):
"""Update UI from prediction thread.""" """Update UI from queue."""
try: try:
while True: while True:
msg_type, data = self.data_queue.get_nowait() msg_type, data = self.data_queue.get_nowait()
if msg_type == 'prediction': if msg_type == 'prediction':
smoothed_label, smoothed_conf, raw_label, raw_conf = data label, conf = data
# Display smoothed (stable) prediction # Update label
display_label = smoothed_label.upper().replace("_", " ") self.prediction_label.configure(
color = get_gesture_color(smoothed_label) text=label.upper(),
text_color=get_gesture_color(label)
self.prediction_label.configure(text=display_label, text_color=color) )
self.confidence_bar.set(smoothed_conf / 100)
self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%") # Update confidence
self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%")
# Show raw prediction for comparison (grayed out) self.confidence_bar.set(conf)
raw_display = raw_label.upper().replace("_", " ")
if raw_label != smoothed_label: # Clear raw label since we don't have raw vs smooth distinction in edge mode
# Raw differs from smoothed - show it was filtered # (or we could expose it if we updated the C struct, but for now keep it simple)
self.raw_label.configure( self.raw_label.configure(text="", text_color="gray")
text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
text_color="orange"
)
else:
self.raw_label.configure(
text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
text_color="gray"
)
elif msg_type == 'sim_gesture': elif msg_type == 'sim_gesture':
self.sim_label.configure(text=f"[Simulating: {data}]") self.sim_label.configure(text=f"[Simulating: {data}]")

View File

@@ -1757,6 +1757,132 @@ class EMGClassifier:
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB") print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
return filepath return filepath
def export_to_header(self, filepath: Path) -> Path:
"""
Export trained model to a C header file for ESP32 inference.
Args:
filepath: Output .h file path
Returns:
Path to the saved header file
"""
if not self.is_trained:
raise ValueError("Cannot export untrained classifier!")
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
n_classes = len(self.label_names)
n_features = len(self.feature_names)
# Get LDA parameters
# coef_: (n_classes, n_features) - access as [class][feature]
# intercept_: (n_classes,)
coefs = self.lda.coef_
intercepts = self.lda.intercept_
# Add logic for binary classification (sklearn stores only 1 set of coefs)
# For >2 classes, it stores n_classes sets.
if n_classes == 2:
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
# We need to expand this to 2 classes for the C inference engine to be generic.
# Class 1 decision = dot(w, x) + b
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
# Let's check sklearn docs or behavior.
# Actually, LDA in sklearn for binary case is special.
# To make C code simple (always argmax), let's explicitly store 2 rows.
# Row 1 (Index 1 in sklearn): coef, intercept
# Row 0 (Index 0): -coef, -intercept ?
# Wait, LDA is generative. The decision boundary is linear.
# Let's assume Multiclass for now or handle binary specifically.
# For simplicity in C, we prefer (n_classes, n_features).
# If coefs.shape[0] != n_classes, we need to handle it.
if coefs.shape[0] == 1:
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
# Class 1 (positive)
c1_coef = coefs[0]
c1_int = intercepts[0]
# Class 0 (negative) - Effectively -score for decision boundary at 0
# But strictly speaking LDA is comparison of log-posteriors.
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
# To map this to ArgMax(Score0, Score1):
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
pass
# Generate C content
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
c_content = [
"/**",
f" * @file {filepath.name}",
" * @brief Trained LDA model weights exported from Python.",
f" * @date {timestamp}",
" */",
"",
"#ifndef MODEL_WEIGHTS_H",
"#define MODEL_WEIGHTS_H",
"",
"#include <stdint.h>",
"",
"/* Metadata */",
f"#define MODEL_NUM_CLASSES {n_classes}",
f"#define MODEL_NUM_FEATURES {n_features}",
"",
"/* Class Names */",
"static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {",
]
for name in self.label_names:
c_content.append(f' "{name}",')
c_content.append("};")
c_content.append("")
c_content.append("/* Feature Extractor Parameters */")
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
c_content.append("")
c_content.append("/* LDA Intercepts/Biases */")
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
line = " "
for val in intercepts:
line += f"{val:.6f}f, "
c_content.append(line.rstrip(", "))
c_content.append("};")
c_content.append("")
c_content.append("/* LDA Coefficients (Weights) */")
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
for i, row in enumerate(coefs):
c_content.append(f" /* {self.label_names[i]} */")
c_content.append(" {")
line = " "
for j, val in enumerate(row):
line += f"{val:.6f}f, "
if (j + 1) % 8 == 0:
c_content.append(line)
line = " "
if line.strip():
c_content.append(line.rstrip(", "))
c_content.append(" },")
c_content.append("};")
c_content.append("")
c_content.append("#endif /* MODEL_WEIGHTS_H */")
with open(filepath, 'w') as f:
f.write('\n'.join(c_content))
print(f"[Classifier] Model weights exported to: {filepath}")
return filepath
@classmethod @classmethod
def load(cls, filepath: Path) -> 'EMGClassifier': def load(cls, filepath: Path) -> 'EMGClassifier':
""" """