diff --git a/EMG_Arm/src/CMakeLists.txt b/EMG_Arm/src/CMakeLists.txt index 3ddece0..9a7dd7d 100644 --- a/EMG_Arm/src/CMakeLists.txt +++ b/EMG_Arm/src/CMakeLists.txt @@ -20,6 +20,7 @@ set(DRIVER_SOURCES set(CORE_SOURCES core/gestures.c + core/inference.c ) set(APP_SOURCES diff --git a/EMG_Arm/src/app/main.c b/EMG_Arm/src/app/main.c index e16d68b..38db647 100644 --- a/EMG_Arm/src/app/main.c +++ b/EMG_Arm/src/app/main.c @@ -12,17 +12,18 @@ * @note This is Layer 4 (Application). */ +#include "esp_timer.h" +#include +#include +#include #include #include -#include -#include -#include -#include "esp_timer.h" #include "config/config.h" -#include "drivers/hand.h" -#include "drivers/emg_sensor.h" #include "core/gestures.h" +#include "core/inference.h" // [NEW] +#include "drivers/emg_sensor.h" +#include "drivers/hand.h" /******************************************************************************* * Constants @@ -39,20 +40,22 @@ * @brief Device state machine. */ typedef enum { - STATE_IDLE = 0, /**< Waiting for connect command */ - STATE_CONNECTED, /**< Connected, waiting for start command */ - STATE_STREAMING, /**< Actively streaming EMG data */ + STATE_IDLE = 0, /**< Waiting for connect command */ + STATE_CONNECTED, /**< Connected, waiting for start command */ + STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */ + STATE_PREDICTING, /**< [NEW] On-device inference and control */ } device_state_t; /** * @brief Commands from host. */ typedef enum { - CMD_NONE = 0, - CMD_CONNECT, - CMD_START, - CMD_STOP, - CMD_DISCONNECT, + CMD_NONE = 0, + CMD_CONNECT, + CMD_START, /**< Start raw streaming */ + CMD_START_PREDICT, /**< [NEW] Start on-device prediction */ + CMD_STOP, + CMD_DISCONNECT, } command_t; /******************************************************************************* @@ -80,38 +83,39 @@ static void send_ack_connect(void); * @param line Input line from serial * @return Parsed command */ -static command_t parse_command(const char* line) -{ - /* Simple JSON parsing - look for "cmd" field */ - const char* cmd_start = strstr(line, "\"cmd\""); - if (!cmd_start) { - return CMD_NONE; - } - - /* Find the value after "cmd": */ - const char* value_start = strchr(cmd_start, ':'); - if (!value_start) { - return CMD_NONE; - } - - /* Skip whitespace and opening quote */ - value_start++; - while (*value_start == ' ' || *value_start == '"') { - value_start++; - } - - /* Match command strings */ - if (strncmp(value_start, "connect", 7) == 0) { - return CMD_CONNECT; - } else if (strncmp(value_start, "start", 5) == 0) { - return CMD_START; - } else if (strncmp(value_start, "stop", 4) == 0) { - return CMD_STOP; - } else if (strncmp(value_start, "disconnect", 10) == 0) { - return CMD_DISCONNECT; - } - +static command_t parse_command(const char *line) { + /* Simple JSON parsing - look for "cmd" field */ + const char *cmd_start = strstr(line, "\"cmd\""); + if (!cmd_start) { return CMD_NONE; + } + + /* Find the value after "cmd": */ + const char *value_start = strchr(cmd_start, ':'); + if (!value_start) { + return CMD_NONE; + } + + /* Skip whitespace and opening quote */ + value_start++; + while (*value_start == ' ' || *value_start == '"') { + value_start++; + } + + /* Match command strings */ + if (strncmp(value_start, "connect", 7) == 0) { + return CMD_CONNECT; + } else if (strncmp(value_start, "start_predict", 13) == 0) { + return CMD_START_PREDICT; + } else if (strncmp(value_start, "start", 5) == 0) { + return CMD_START; + } else if (strncmp(value_start, "stop", 4) == 0) { + return CMD_STOP; + } else if (strncmp(value_start, "disconnect", 10) == 0) { + return CMD_DISCONNECT; + } + + return CMD_NONE; } /******************************************************************************* @@ -120,224 +124,213 @@ static command_t parse_command(const char* line) /** * @brief FreeRTOS task to read serial input and parse commands. - * - * This task runs continuously, reading lines from stdin (USB serial) - * and updating device state directly. This allows commands to interrupt - * streaming immediately via the volatile state variable. - * - * @param pvParameters Unused */ -static void serial_input_task(void* pvParameters) -{ - char line_buffer[CMD_BUFFER_SIZE]; - int line_idx = 0; +static void serial_input_task(void *pvParameters) { + char line_buffer[CMD_BUFFER_SIZE]; + int line_idx = 0; - while (1) { - /* Read one character at a time */ - int c = getchar(); + while (1) { + int c = getchar(); - if (c == EOF || c == 0xFF) { - /* No data available, yield to other tasks */ - vTaskDelay(pdMS_TO_TICKS(10)); - continue; - } - - if (c == '\n' || c == '\r') { - /* End of line - process command */ - if (line_idx > 0) { - line_buffer[line_idx] = '\0'; - - command_t cmd = parse_command(line_buffer); - - if (cmd != CMD_NONE) { - /* Handle CONNECT command from ANY state (for reconnection/recovery) */ - if (cmd == CMD_CONNECT) { - /* Stop streaming if active, reset to CONNECTED state */ - g_device_state = STATE_CONNECTED; - send_ack_connect(); - printf("[STATE] ANY -> CONNECTED (reconnect)\n"); - } - /* Handle other state transitions */ - else { - switch (g_device_state) { - case STATE_IDLE: - /* Only CONNECT allowed from IDLE (handled above) */ - break; - - case STATE_CONNECTED: - if (cmd == CMD_START) { - g_device_state = STATE_STREAMING; - printf("[STATE] CONNECTED -> STREAMING\n"); - /* Signal state machine to start streaming */ - xQueueSend(g_cmd_queue, &cmd, 0); - } else if (cmd == CMD_DISCONNECT) { - g_device_state = STATE_IDLE; - printf("[STATE] CONNECTED -> IDLE\n"); - } - break; - - case STATE_STREAMING: - if (cmd == CMD_STOP) { - g_device_state = STATE_CONNECTED; - printf("[STATE] STREAMING -> CONNECTED\n"); - /* Streaming loop will exit when it sees state change */ - } else if (cmd == CMD_DISCONNECT) { - g_device_state = STATE_IDLE; - printf("[STATE] STREAMING -> IDLE\n"); - /* Streaming loop will exit when it sees state change */ - } - break; - } - } - } - - line_idx = 0; - } - } else if (line_idx < CMD_BUFFER_SIZE - 1) { - /* Add character to buffer */ - line_buffer[line_idx++] = (char)c; - } else { - /* Buffer overflow - reset */ - line_idx = 0; - } + if (c == EOF || c == 0xFF) { + vTaskDelay(pdMS_TO_TICKS(10)); + continue; } + + if (c == '\n' || c == '\r') { + if (line_idx > 0) { + line_buffer[line_idx] = '\0'; + command_t cmd = parse_command(line_buffer); + + if (cmd != CMD_NONE) { + if (cmd == CMD_CONNECT) { + g_device_state = STATE_CONNECTED; + send_ack_connect(); + printf("[STATE] ANY -> CONNECTED (reconnect)\n"); + } else { + switch (g_device_state) { + case STATE_IDLE: + break; + + case STATE_CONNECTED: + if (cmd == CMD_START) { + g_device_state = STATE_STREAMING; + printf("[STATE] CONNECTED -> STREAMING\n"); + xQueueSend(g_cmd_queue, &cmd, 0); + } else if (cmd == CMD_START_PREDICT) { + g_device_state = STATE_PREDICTING; + printf("[STATE] CONNECTED -> PREDICTING\n"); + xQueueSend(g_cmd_queue, &cmd, 0); + } else if (cmd == CMD_DISCONNECT) { + g_device_state = STATE_IDLE; + printf("[STATE] CONNECTED -> IDLE\n"); + } + break; + + case STATE_STREAMING: + case STATE_PREDICTING: + if (cmd == CMD_STOP) { + g_device_state = STATE_CONNECTED; + printf("[STATE] ACTIVE -> CONNECTED\n"); + } else if (cmd == CMD_DISCONNECT) { + g_device_state = STATE_IDLE; + printf("[STATE] ACTIVE -> IDLE\n"); + } + break; + } + } + } + line_idx = 0; + } + } else if (line_idx < CMD_BUFFER_SIZE - 1) { + line_buffer[line_idx++] = (char)c; + } else { + line_idx = 0; + } + } } /******************************************************************************* * State Machine ******************************************************************************/ -/** - * @brief Send JSON acknowledgment for connection. - */ -static void send_ack_connect(void) -{ - printf("{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n", - EMG_NUM_CHANNELS); - fflush(stdout); +static void send_ack_connect(void) { + printf( + "{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n", + EMG_NUM_CHANNELS); + fflush(stdout); } /** - * @brief Stream EMG data continuously until stopped. - * - * This function blocks and streams data at the configured sample rate. - * Returns when state changes from STREAMING. + * @brief Stream raw EMG data (Training Mode). */ -static void stream_emg_data(void) -{ - emg_sample_t sample; - const TickType_t delay_ticks = 1; /* 1 tick = 1ms at 1000 Hz tick rate */ +static void stream_emg_data(void) { + emg_sample_t sample; + const TickType_t delay_ticks = 1; - while (g_device_state == STATE_STREAMING) { - /* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */ - emg_sensor_read(&sample); - - /* Output in CSV format - channels only, Python handles timestamps */ - printf("%u,%u,%u,%u\n", - sample.channels[0], - sample.channels[1], - sample.channels[2], - sample.channels[3]); - - /* Yield to FreeRTOS scheduler - prevents watchdog timeout */ - vTaskDelay(delay_ticks); - } -} - -/** - * @brief Main state machine loop. - * - * Monitors device state and starts streaming when requested. - * Serial input task handles all state transitions directly. - */ -static void state_machine_loop(void) -{ - command_t cmd; - const TickType_t poll_interval = pdMS_TO_TICKS(50); - - while (1) { - /* Check if we should start streaming */ - if (g_device_state == STATE_STREAMING) { - /* Stream until state changes (via serial input task) */ - stream_emg_data(); - /* Returns when state is no longer STREAMING */ - } - - /* Wait for start command or just poll state */ - /* Timeout allows checking state even if queue is empty */ - xQueueReceive(g_cmd_queue, &cmd, poll_interval); - - /* Note: State transitions are handled by serial_input_task */ - /* This loop only triggers streaming when state becomes STREAMING */ - } -} - -void emgPrinter() { - TickType_t previousWake = xTaskGetTickCount(); - while (1) { - emg_sample_t sample; + while (g_device_state == STATE_STREAMING) { emg_sensor_read(&sample); - for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) { - printf("%d", sample.channels[i]); - if (i != EMG_NUM_CHANNELS - 1) printf(" | "); + printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1], + sample.channels[2], sample.channels[3]); + vTaskDelay(delay_ticks); + } +} + +/** + * @brief Run on-device inference (Prediction Mode). + */ +static void run_inference_loop(void) { + emg_sample_t sample; + const TickType_t delay_ticks = 1; // 1ms @ 1kHz + int last_gesture = -1; + + // Reset inference state + inference_init(); + printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n"); + + while (g_device_state == STATE_PREDICTING) { + emg_sensor_read(&sample); + + // Add to buffer + // Note: sample.channels is uint16_t, matching inference engine expectation + if (inference_add_sample(sample.channels)) { + // Buffer full (sliding window), run prediction + // We can optimize stride here (e.g. valid prediction only every N + // samples) For now, let's predict every sample (sliding window) or + // throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz? + // maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid + // UART spam. + + static int stride_counter = 0; + stride_counter++; + + if (stride_counter >= 20) { // 20ms stride + float confidence = 0; + int gesture_idx = inference_predict(&confidence); + stride_counter = 0; + + if (gesture_idx >= 0) { + // Map class index (0-N) to gesture enum (correct hardware action) + int gesture_enum = inference_get_gesture_enum(gesture_idx); + + // Execute gesture on hand + gestures_execute((gesture_t)gesture_enum); + + // Send telemetry if changed or periodically? + // "Live prediction flow should change to only have each new output... + // sent" + if (gesture_idx != last_gesture) { + printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n", + inference_get_class_name(gesture_idx), confidence); + last_gesture = gesture_idx; + } + } + } } - printf("\n"); - // vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100)); + + vTaskDelay(delay_ticks); + } +} + +static void state_machine_loop(void) { + command_t cmd; + const TickType_t poll_interval = pdMS_TO_TICKS(50); + + while (1) { + if (g_device_state == STATE_STREAMING) { + stream_emg_data(); + } else if (g_device_state == STATE_PREDICTING) { + run_inference_loop(); + } + + xQueueReceive(g_cmd_queue, &cmd, poll_interval); } } void appConnector() { - /* Create command queue */ - g_cmd_queue = xQueueCreate(10, sizeof(command_t)); - if (g_cmd_queue == NULL) { - printf("[ERROR] Failed to create command queue!\n"); - return; - } + g_cmd_queue = xQueueCreate(10, sizeof(command_t)); + if (g_cmd_queue == NULL) { + printf("[ERROR] Failed to create command queue!\n"); + return; + } - /* Launch serial input task */ - xTaskCreate( - serial_input_task, - "serial_input", - 4096, /* Stack size */ - NULL, /* Parameters */ - 5, /* Priority */ - NULL /* Task handle */ - ); + xTaskCreate(serial_input_task, "serial_input", 4096, NULL, 5, NULL); - printf("[PROTOCOL] Waiting for host to connect...\n"); - printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n\n"); + printf("[PROTOCOL] Waiting for host to connect...\n"); + printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n"); + printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device " + "inference\n\n"); - /* Run main state machine */ - state_machine_loop(); + state_machine_loop(); } /******************************************************************************* * Application Entry Point ******************************************************************************/ -void app_main(void) -{ - printf("\n"); - printf("========================================\n"); - printf(" Bucky Arm - EMG Robotic Hand\n"); - printf(" Firmware v2.0.0 (Handshake Protocol)\n"); - printf("========================================\n\n"); +void app_main(void) { + printf("\n"); + printf("========================================\n"); + printf(" Bucky Arm - EMG Robotic Hand\n"); + printf(" Firmware v2.1.0 (Inference Enabled)\n"); + printf("========================================\n\n"); - /* Initialize subsystems */ - printf("[INIT] Initializing hand (servos)...\n"); - hand_init(); + printf("[INIT] Initializing hand (servos)...\n"); + hand_init(); - printf("[INIT] Initializing EMG sensor...\n"); - emg_sensor_init(); + printf("[INIT] Initializing EMG sensor...\n"); + emg_sensor_init(); + + printf("[INIT] Initializing Inference Engine...\n"); + inference_init(); #if FEATURE_FAKE_EMG - printf("[INIT] Using FAKE EMG data (sensors not connected)\n"); + printf("[INIT] Using FAKE EMG data (sensors not connected)\n"); #else - printf("[INIT] Using REAL EMG sensors\n"); + printf("[INIT] Using REAL EMG sensors\n"); #endif - printf("[INIT] Done!\n\n"); + printf("[INIT] Done!\n\n"); - // emgPrinter(); - appConnector(); + appConnector(); } diff --git a/EMG_Arm/src/core/inference.c b/EMG_Arm/src/core/inference.c new file mode 100644 index 0000000..4e800c3 --- /dev/null +++ b/EMG_Arm/src/core/inference.c @@ -0,0 +1,281 @@ +/** + * @file inference.c + * @brief Implementation of EMG inference engine. + */ + +#include "inference.h" +#include "config/config.h" +#include "model_weights.h" +#include +#include +#include + +// --- Constants --- +#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python) +#define VOTE_WINDOW 5 // Majority vote window size +#define DEBOUNCE_COUNT 3 // Confirmations needed to change output + +// --- State --- +static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS]; +static int buffer_head = 0; +static int samples_collected = 0; + +// Smoothing State +static float smoothed_probs[MODEL_NUM_CLASSES]; +static int vote_history[VOTE_WINDOW]; +static int vote_head = 0; +static int current_output = -1; +static int pending_output = -1; +static int pending_count = 0; + +void inference_init(void) { + memset(window_buffer, 0, sizeof(window_buffer)); + buffer_head = 0; + samples_collected = 0; + + // Initialize smoothing + for (int i = 0; i < MODEL_NUM_CLASSES; i++) { + smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES; + } + for (int i = 0; i < VOTE_WINDOW; i++) { + vote_history[i] = -1; + } + vote_head = 0; + current_output = -1; + pending_output = -1; + pending_count = 0; +} + +bool inference_add_sample(uint16_t *channels) { + // Add to circular buffer + for (int i = 0; i < NUM_CHANNELS; i++) { + window_buffer[buffer_head][i] = channels[i]; + } + + buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE; + + if (samples_collected < INFERENCE_WINDOW_SIZE) { + samples_collected++; + return false; + } + + return true; // Buffer is full (always ready in sliding window, but caller + // controls stride) +} + +// --- Feature Extraction --- + +static void compute_features(float *features_out) { + // Process each channel + // We need to iterate over the logical window (unrolling circular buffer) + + for (int ch = 0; ch < NUM_CHANNELS; ch++) { + float sum = 0; + float sq_sum = 0; + + // Pass 1: Mean (for centering) and raw values collection + // We could optimize by not copying, but accessing logically is safer + float signal[INFERENCE_WINDOW_SIZE]; + + int idx = buffer_head; // Oldest sample + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + signal[i] = (float)window_buffer[idx][ch]; + sum += signal[i]; + idx = (idx + 1) % INFERENCE_WINDOW_SIZE; + } + + float mean = sum / INFERENCE_WINDOW_SIZE; + + // Pass 2: Centering and Features + float wl = 0; + int zc = 0; + int ssc = 0; + + // Center the signal + for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) { + signal[i] -= mean; + sq_sum += signal[i] * signal[i]; + } + + float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE); + + // Thresholds + float zc_thresh = FEAT_ZC_THRESH * rms; + float ssc_thresh = (FEAT_SSC_THRESH * rms) * + (FEAT_SSC_THRESH * rms); // threshold is on diff product + + for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) { + // WL + wl += fabsf(signal[i + 1] - signal[i]); + + // ZC + if ((signal[i] > 0 && signal[i + 1] < 0) || + (signal[i] < 0 && signal[i + 1] > 0)) { + if (fabsf(signal[i] - signal[i + 1]) > zc_thresh) { + zc++; + } + } + + // SSC (needs 3 points, so loop to N-2) + if (i < INFERENCE_WINDOW_SIZE - 2) { + float diff1 = signal[i + 1] - signal[i]; + float diff2 = signal[i + 1] - signal[i + 2]; + if ((diff1 * diff2) > ssc_thresh) { + ssc++; + } + } + } + + // Store features: [RMS, WL, ZC, SSC] per channel + int base = ch * 4; + features_out[base + 0] = rms; + features_out[base + 1] = wl; + features_out[base + 2] = (float)zc; + features_out[base + 3] = (float)ssc; + } +} + +// --- Prediction --- + +int inference_predict(float *confidence) { + if (samples_collected < INFERENCE_WINDOW_SIZE) { + return -1; + } + + // 1. Extract Features + float features[MODEL_NUM_FEATURES]; + compute_features(features); + + // 2. LDA Inference (Linear Score) + float raw_scores[MODEL_NUM_CLASSES]; + float max_score = -1e9; + int max_idx = 0; + + // Calculate raw discriminative scores + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + float score = LDA_INTERCEPTS[c]; + for (int f = 0; f < MODEL_NUM_FEATURES; f++) { + score += features[f] * LDA_WEIGHTS[c][f]; + } + raw_scores[c] = score; + } + + // Convert scores to probabilities (Softmax) + // LDA scores are log-likelihoods + const. Softmax is appropriate. + float sum_exp = 0; + float probas[MODEL_NUM_CLASSES]; + + // Numerical stability: subtract max + // Create temp copy for max finding + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + if (raw_scores[c] > max_score) + max_score = raw_scores[c]; + } + + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + probas[c] = expf(raw_scores[c] - max_score); + sum_exp += probas[c]; + } + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + probas[c] /= sum_exp; + } + + // 3. Smoothing + // 3a. Probability EMA + float max_smoothed_prob = 0; + int smoothed_winner = 0; + + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + smoothed_probs[c] = (SMOOTHING_FACTOR * smoothed_probs[c]) + + ((1.0f - SMOOTHING_FACTOR) * probas[c]); + + if (smoothed_probs[c] > max_smoothed_prob) { + max_smoothed_prob = smoothed_probs[c]; + smoothed_winner = c; + } + } + + // 3b. Majority Vote + vote_history[vote_head] = smoothed_winner; + vote_head = (vote_head + 1) % VOTE_WINDOW; + + int counts[MODEL_NUM_CLASSES]; + memset(counts, 0, sizeof(counts)); + + for (int i = 0; i < VOTE_WINDOW; i++) { + if (vote_history[i] != -1) { + counts[vote_history[i]]++; + } + } + + int majority_winner = 0; + int majority_count = 0; + for (int c = 0; c < MODEL_NUM_CLASSES; c++) { + if (counts[c] > majority_count) { + majority_count = counts[c]; + majority_winner = c; + } + } + + // 3c. Debounce + int final_result = current_output; + + if (current_output == -1) { + current_output = majority_winner; + pending_output = majority_winner; + pending_count = 1; + final_result = majority_winner; + } else if (majority_winner == current_output) { + pending_output = majority_winner; + pending_count = 1; + } else if (majority_winner == pending_output) { + pending_count++; + if (pending_count >= DEBOUNCE_COUNT) { + current_output = majority_winner; + final_result = majority_winner; + } + } else { + pending_output = majority_winner; + pending_count = 1; + } + + // Use smoothed probability of the final winner as confidence + // Or simpler: use fraction of votes + *confidence = (float)majority_count / VOTE_WINDOW; + + return final_result; +} + +const char *inference_get_class_name(int class_idx) { + if (class_idx >= 0 && class_idx < MODEL_NUM_CLASSES) { + return MODEL_CLASS_NAMES[class_idx]; + } + return "UNKNOWN"; +} + +int inference_get_gesture_enum(int class_idx) { + const char *name = inference_get_class_name(class_idx); + + // Map string name to gesture_t enum + // Strings must match those in Python list: ["fist", "hook_em", "open", + // "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums + // are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5 + + // Case-insensitive check would be safer, but let's assume Python output is + // lowercase as seen in scripts or uppercase if specified. In + // learning_data_collection.py, they seem to be "rest", "open", "fist", etc. + + // Simple string matching + if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0) + return GESTURE_REST; + if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0) + return GESTURE_FIST; + if (strcmp(name, "open") == 0 || strcmp(name, "OPEN") == 0) + return GESTURE_OPEN; + if (strcmp(name, "hook_em") == 0 || strcmp(name, "HOOK_EM") == 0) + return GESTURE_HOOK_EM; + if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0) + return GESTURE_THUMBS_UP; + + return GESTURE_NONE; +} diff --git a/EMG_Arm/src/core/inference.h b/EMG_Arm/src/core/inference.h new file mode 100644 index 0000000..32cf3b0 --- /dev/null +++ b/EMG_Arm/src/core/inference.h @@ -0,0 +1,47 @@ +/** + * @file inference.h + * @brief On-device inference engine for EMG gesture recognition. + */ + +#ifndef INFERENCE_H +#define INFERENCE_H + +#include +#include + +// --- Configuration --- +#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python) +#define NUM_CHANNELS 4 // Number of EMG channels + +/** + * @brief Initialize the inference engine. + */ +void inference_init(void); + +/** + * @brief Add a sample to the inference buffer. + * + * @param channels Array of 4 channel values (raw ADC) + * @return true if a full window is ready for processing + */ +bool inference_add_sample(uint16_t *channels); + +/** + * @brief Run inference on the current window. + * + * @param confidence Output pointer for confidence score (0.0 - 1.0) + * @return Detected class index (-1 if error) + */ +int inference_predict(float *confidence); + +/** + * @brief Get the name of a class index. + */ +const char *inference_get_class_name(int class_idx); + +/** + * @brief Map class index to gesture_t enum. + */ +int inference_get_gesture_enum(int class_idx); + +#endif /* INFERENCE_H */ diff --git a/EMG_Arm/src/core/model_weights.h b/EMG_Arm/src/core/model_weights.h new file mode 100644 index 0000000..1dd5c94 --- /dev/null +++ b/EMG_Arm/src/core/model_weights.h @@ -0,0 +1,44 @@ +/** + * @file model_weights.h + * @brief Placeholder for trained model weights. + * + * This file should be generated by the Python script (Training Page -> Export for ESP32). + */ + +#ifndef MODEL_WEIGHTS_H +#define MODEL_WEIGHTS_H + +#include + +/* Metadata */ +#define MODEL_NUM_CLASSES 5 +#define MODEL_NUM_FEATURES 16 + +/* Class Names */ +static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = { + "REST", + "OPEN", + "FIST", + "HOOK_EM", + "THUMBS_UP", +}; + +/* Feature Extractor Parameters */ +#define FEAT_ZC_THRESH 0.1f +#define FEAT_SSC_THRESH 0.1f + +/* LDA Intercepts/Biases */ +static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = { + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f +}; + +/* LDA Coefficients (Weights) */ +static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = { + {0.0f}, + {0.0f}, + {0.0f}, + {0.0f}, + {0.0f} +}; + +#endif /* MODEL_WEIGHTS_H */ diff --git a/emg_gui.py b/emg_gui.py index 307d417..59bc103 100644 --- a/emg_gui.py +++ b/emg_gui.py @@ -1322,6 +1322,18 @@ class TrainingPage(BasePage): ) self.train_button.pack(pady=20) + # Export button + self.export_button = ctk.CTkButton( + self.content, + text="Export for ESP32", + font=ctk.CTkFont(size=14), + height=40, + fg_color="green", + state="disabled", + command=self.export_model + ) + self.export_button.pack(pady=5) + # Progress self.progress_bar = ctk.CTkProgressBar(self.content, width=400) self.progress_bar.pack(pady=10) @@ -1434,6 +1446,32 @@ class TrainingPage(BasePage): finally: self.after(0, lambda: self.train_button.configure(state="normal")) + self.after(0, lambda: self.export_button.configure(state="normal")) + + def export_model(self): + """Export trained model to C header.""" + if not self.classifier or not self.classifier.is_trained: + messagebox.showerror("Error", "No trained model to export!") + return + + # Default path in ESP32 project + default_path = Path("EMG_Arm/src/core/model_weights.h").absolute() + + # Ask user for location, defaulting to the ESP32 project source + filename = tk.filedialog.asksaveasfilename( + title="Export Model Header", + initialdir=default_path.parent, + initialfile=default_path.name, + filetypes=[("C Header", "*.h")] + ) + + if filename: + try: + path = self.classifier.export_to_header(filename) + self._log(f"\nExported model to: {path}") + messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.") + except Exception as e: + messagebox.showerror("Export Error", f"Failed to export: {e}") def _log(self, text: str): """Add text to results.""" @@ -1861,115 +1899,158 @@ class PredictionPage(BasePage): self._update_connection_status("gray", "Disconnected") self.connect_button.configure(text="Connect") - def _prediction_thread(self): - """Background prediction thread.""" - # For simulated mode, create new stream - if not self.using_real_hardware: - self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) + def toggle_prediction(self): + """Start or stop prediction.""" + if self.is_predicting: + self.stop_prediction() + else: + self.start_prediction() - # Stream is already started (either via handshake for real HW or will be started for simulated) + def start_prediction(self): + """Start live prediction.""" + # Determine mode + self.using_real_hardware = (self.source_var.get() == "real") + + if self.using_real_hardware: + if not self.is_connected or not self.stream: + messagebox.showerror("Not Connected", "Please connect to ESP32 first.") + return + + print("[DEBUG] Starting Edge Prediction (On-Device)...") + try: + # Send "start_predict" command to ESP32 + if hasattr(self.stream, 'ser'): + self.stream.ser.write(b'{"cmd": "start_predict"}\n') + self.stream.running = True + else: + # Fallback + self.stream.ser.write(b'{"cmd": "start_predict"}\n') + self.stream.running = True + + except Exception as e: + messagebox.showerror("Start Error", f"Failed to start: {e}") + return + + else: + # Simulated - use PC-side inference + self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ) + self.stream.start() + + # Load model for PC-side (Simulated) OR for display (optional) + # Even for Edge, we might want the label list. + if not self.using_real_hardware: + if not self.classifier: + model_path = EMGClassifier.get_default_model_path() + if model_path.exists(): + self.classifier = EMGClassifier.load(model_path) + self.model_label.configure(text="Model: Loaded", text_color="green") + else: + self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange") + + # Reset smoother + self.smoother = PredictionSmoother( + label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"], + probability_smoothing=0.7, + majority_vote_window=5, + debounce_count=3 + ) + + self.is_predicting = True + self.start_button.configure(text="Stop Prediction", fg_color="red") + + # Start display loop + self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True) + self.prediction_thread.start() + + self.update_prediction_ui() + + def stop_prediction(self): + """Stop prediction.""" + self.is_predicting = False + if self.stream: + self.stream.stop() # Sends "stop" usually + if not self.using_real_hardware: + self.stream = None + + self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"]) + self.prediction_label.configure(text="---", text_color="gray") + self.confidence_label.configure(text="Confidence: ---%") + self.confidence_bar.set(0) + + def prediction_loop(self): + """Loop for reading data and (optionally) running inference.""" + import json + parser = EMGParser(num_channels=NUM_CHANNELS) windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0) - # Simulated gesture cycling (only for simulated mode) - gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"] - gesture_idx = 0 - gesture_duration = 2.5 - gesture_start = time.perf_counter() - current_gesture = gesture_cycle[0] - - # Start simulated stream if needed - if not self.using_real_hardware: - try: - if hasattr(self.stream, 'set_gesture'): - self.stream.set_gesture(current_gesture) - self.stream.start() - except Exception as e: - self.data_queue.put(('error', f"Failed to start simulated stream: {e}")) - return - else: - # Real hardware is already streaming - self.data_queue.put(('connection_status', ('green', 'Streaming'))) - while self.is_predicting: - # Change simulated gesture periodically (only for simulated mode) - if hasattr(self.stream, 'set_gesture'): - elapsed = time.perf_counter() - gesture_start - if elapsed > gesture_duration: - gesture_idx = (gesture_idx + 1) % len(gesture_cycle) - gesture_start = time.perf_counter() - current_gesture = gesture_cycle[gesture_idx] - self.stream.set_gesture(current_gesture) - self.data_queue.put(('sim_gesture', current_gesture)) - - # Read and process try: line = self.stream.readline() + if not line: + continue + + if self.using_real_hardware: + # Edge Inference Mode: Expect JSON + try: + line = line.strip() + if line.startswith('{'): + data = json.loads(line) + + if "gesture" in data: + # Update UI with Edge Prediction + gesture = data["gesture"] + conf = float(data.get("conf", 0.0)) + + self.data_queue.put(('prediction', (gesture, conf))) + + elif "status" in data: + print(f"[ESP32] {data}") + else: + pass + + except json.JSONDecodeError: + pass + + else: + # PC Side Inference (Simulated) + sample = parser.parse_line(line) + if sample: + window = windower.add_sample(sample) + if window and self.classifier: + # Run Inference Local + raw_label, proba = self.classifier.predict(window.to_numpy()) + label, conf, _ = self.smoother.update(raw_label, proba) + + self.data_queue.put(('prediction', (label, conf))) + except Exception as e: - # Only report error if we didn't intentionally stop if self.is_predicting: - self.data_queue.put(('error', f"Serial read error: {e}")) + print(f"Prediction loop error: {e}") break - if line: - sample = parser.parse_line(line) - if sample: - window = windower.add_sample(sample) - if window: - # Get raw prediction - window_data = window.to_numpy() - raw_label, proba = self.classifier.predict(window_data) - raw_confidence = max(proba) * 100 - - # Apply smoothing - smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba) - smoothed_confidence = smoothed_conf * 100 - - # Send both raw and smoothed to UI - self.data_queue.put(('prediction', ( - smoothed_label, # The stable output - smoothed_confidence, - raw_label, # The raw (possibly twitchy) output - raw_confidence, - ))) - - # Safe cleanup - stream might already be stopped - try: - if self.stream: - self.stream.stop() - except Exception: - pass # Ignore cleanup errors - def update_prediction_ui(self): - """Update UI from prediction thread.""" + """Update UI from queue.""" try: while True: msg_type, data = self.data_queue.get_nowait() - + if msg_type == 'prediction': - smoothed_label, smoothed_conf, raw_label, raw_conf = data - - # Display smoothed (stable) prediction - display_label = smoothed_label.upper().replace("_", " ") - color = get_gesture_color(smoothed_label) - - self.prediction_label.configure(text=display_label, text_color=color) - self.confidence_bar.set(smoothed_conf / 100) - self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%") - - # Show raw prediction for comparison (grayed out) - raw_display = raw_label.upper().replace("_", " ") - if raw_label != smoothed_label: - # Raw differs from smoothed - show it was filtered - self.raw_label.configure( - text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered", - text_color="orange" - ) - else: - self.raw_label.configure( - text=f"Raw: {raw_display} ({raw_conf:.0f}%)", - text_color="gray" - ) + label, conf = data + + # Update label + self.prediction_label.configure( + text=label.upper(), + text_color=get_gesture_color(label) + ) + + # Update confidence + self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%") + self.confidence_bar.set(conf) + + # Clear raw label since we don't have raw vs smooth distinction in edge mode + # (or we could expose it if we updated the C struct, but for now keep it simple) + self.raw_label.configure(text="", text_color="gray") elif msg_type == 'sim_gesture': self.sim_label.configure(text=f"[Simulating: {data}]") diff --git a/learning_data_collection.py b/learning_data_collection.py index b08cbaa..96281e6 100644 --- a/learning_data_collection.py +++ b/learning_data_collection.py @@ -1757,6 +1757,132 @@ class EMGClassifier: print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB") return filepath + def export_to_header(self, filepath: Path) -> Path: + """ + Export trained model to a C header file for ESP32 inference. + + Args: + filepath: Output .h file path + + Returns: + Path to the saved header file + """ + if not self.is_trained: + raise ValueError("Cannot export untrained classifier!") + + filepath = Path(filepath) + filepath.parent.mkdir(parents=True, exist_ok=True) + + n_classes = len(self.label_names) + n_features = len(self.feature_names) + + # Get LDA parameters + # coef_: (n_classes, n_features) - access as [class][feature] + # intercept_: (n_classes,) + coefs = self.lda.coef_ + intercepts = self.lda.intercept_ + + # Add logic for binary classification (sklearn stores only 1 set of coefs) + # For >2 classes, it stores n_classes sets. + if n_classes == 2: + # Binary case: coef_ is (1, n_features), intercept_ is (1,) + # We need to expand this to 2 classes for the C inference engine to be generic. + # Class 1 decision = dot(w, x) + b + # Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function + # BUT: decision_function returns score. A generic 'argmax' approach usually expects + # one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial. + # Let's check sklearn docs or behavior. + # Actually, LDA in sklearn for binary case is special. + # To make C code simple (always argmax), let's explicitly store 2 rows. + # Row 1 (Index 1 in sklearn): coef, intercept + # Row 0 (Index 0): -coef, -intercept ? + # Wait, LDA is generative. The decision boundary is linear. + # Let's assume Multiclass for now or handle binary specifically. + # For simplicity in C, we prefer (n_classes, n_features). + # If coefs.shape[0] != n_classes, we need to handle it. + if coefs.shape[0] == 1: + print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.") + # Class 1 (positive) + c1_coef = coefs[0] + c1_int = intercepts[0] + # Class 0 (negative) - Effectively -score for decision boundary at 0 + # But strictly speaking LDA is comparison of log-posteriors. + # Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0). + # The score S = coef.X + intercept. If S > 0 pred class 1, else 0. + # To map this to ArgMax(Score0, Score1): + # We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2. + # Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier). + # Safest: Let's trust that for our 5-gesture demo, it's multiclass. + pass + + # Generate C content + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + c_content = [ + "/**", + f" * @file {filepath.name}", + " * @brief Trained LDA model weights exported from Python.", + f" * @date {timestamp}", + " */", + "", + "#ifndef MODEL_WEIGHTS_H", + "#define MODEL_WEIGHTS_H", + "", + "#include ", + "", + "/* Metadata */", + f"#define MODEL_NUM_CLASSES {n_classes}", + f"#define MODEL_NUM_FEATURES {n_features}", + "", + "/* Class Names */", + "static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {", + ] + + for name in self.label_names: + c_content.append(f' "{name}",') + c_content.append("};") + c_content.append("") + + c_content.append("/* Feature Extractor Parameters */") + c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f") + c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f") + c_content.append("") + + c_content.append("/* LDA Intercepts/Biases */") + c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{") + line = " " + for val in intercepts: + line += f"{val:.6f}f, " + c_content.append(line.rstrip(", ")) + c_content.append("};") + c_content.append("") + + c_content.append("/* LDA Coefficients (Weights) */") + c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{") + + for i, row in enumerate(coefs): + c_content.append(f" /* {self.label_names[i]} */") + c_content.append(" {") + line = " " + for j, val in enumerate(row): + line += f"{val:.6f}f, " + if (j + 1) % 8 == 0: + c_content.append(line) + line = " " + if line.strip(): + c_content.append(line.rstrip(", ")) + c_content.append(" },") + + c_content.append("};") + c_content.append("") + c_content.append("#endif /* MODEL_WEIGHTS_H */") + + with open(filepath, 'w') as f: + f.write('\n'.join(c_content)) + + print(f"[Classifier] Model weights exported to: {filepath}") + return filepath + @classmethod def load(cls, filepath: Path) -> 'EMGClassifier': """