overhaul
This commit is contained in:
@@ -20,6 +20,7 @@ set(DRIVER_SOURCES
|
|||||||
|
|
||||||
set(CORE_SOURCES
|
set(CORE_SOURCES
|
||||||
core/gestures.c
|
core/gestures.c
|
||||||
|
core/inference.c
|
||||||
)
|
)
|
||||||
|
|
||||||
set(APP_SOURCES
|
set(APP_SOURCES
|
||||||
|
|||||||
@@ -12,17 +12,18 @@
|
|||||||
* @note This is Layer 4 (Application).
|
* @note This is Layer 4 (Application).
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
#include "esp_timer.h"
|
||||||
|
#include <freertos/FreeRTOS.h>
|
||||||
|
#include <freertos/queue.h>
|
||||||
|
#include <freertos/task.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#include <freertos/FreeRTOS.h>
|
|
||||||
#include <freertos/task.h>
|
|
||||||
#include <freertos/queue.h>
|
|
||||||
#include "esp_timer.h"
|
|
||||||
|
|
||||||
#include "config/config.h"
|
#include "config/config.h"
|
||||||
#include "drivers/hand.h"
|
|
||||||
#include "drivers/emg_sensor.h"
|
|
||||||
#include "core/gestures.h"
|
#include "core/gestures.h"
|
||||||
|
#include "core/inference.h" // [NEW]
|
||||||
|
#include "drivers/emg_sensor.h"
|
||||||
|
#include "drivers/hand.h"
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Constants
|
* Constants
|
||||||
@@ -39,20 +40,22 @@
|
|||||||
* @brief Device state machine.
|
* @brief Device state machine.
|
||||||
*/
|
*/
|
||||||
typedef enum {
|
typedef enum {
|
||||||
STATE_IDLE = 0, /**< Waiting for connect command */
|
STATE_IDLE = 0, /**< Waiting for connect command */
|
||||||
STATE_CONNECTED, /**< Connected, waiting for start command */
|
STATE_CONNECTED, /**< Connected, waiting for start command */
|
||||||
STATE_STREAMING, /**< Actively streaming EMG data */
|
STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */
|
||||||
|
STATE_PREDICTING, /**< [NEW] On-device inference and control */
|
||||||
} device_state_t;
|
} device_state_t;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Commands from host.
|
* @brief Commands from host.
|
||||||
*/
|
*/
|
||||||
typedef enum {
|
typedef enum {
|
||||||
CMD_NONE = 0,
|
CMD_NONE = 0,
|
||||||
CMD_CONNECT,
|
CMD_CONNECT,
|
||||||
CMD_START,
|
CMD_START, /**< Start raw streaming */
|
||||||
CMD_STOP,
|
CMD_START_PREDICT, /**< [NEW] Start on-device prediction */
|
||||||
CMD_DISCONNECT,
|
CMD_STOP,
|
||||||
|
CMD_DISCONNECT,
|
||||||
} command_t;
|
} command_t;
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
@@ -80,38 +83,39 @@ static void send_ack_connect(void);
|
|||||||
* @param line Input line from serial
|
* @param line Input line from serial
|
||||||
* @return Parsed command
|
* @return Parsed command
|
||||||
*/
|
*/
|
||||||
static command_t parse_command(const char* line)
|
static command_t parse_command(const char *line) {
|
||||||
{
|
/* Simple JSON parsing - look for "cmd" field */
|
||||||
/* Simple JSON parsing - look for "cmd" field */
|
const char *cmd_start = strstr(line, "\"cmd\"");
|
||||||
const char* cmd_start = strstr(line, "\"cmd\"");
|
if (!cmd_start) {
|
||||||
if (!cmd_start) {
|
|
||||||
return CMD_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Find the value after "cmd": */
|
|
||||||
const char* value_start = strchr(cmd_start, ':');
|
|
||||||
if (!value_start) {
|
|
||||||
return CMD_NONE;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Skip whitespace and opening quote */
|
|
||||||
value_start++;
|
|
||||||
while (*value_start == ' ' || *value_start == '"') {
|
|
||||||
value_start++;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Match command strings */
|
|
||||||
if (strncmp(value_start, "connect", 7) == 0) {
|
|
||||||
return CMD_CONNECT;
|
|
||||||
} else if (strncmp(value_start, "start", 5) == 0) {
|
|
||||||
return CMD_START;
|
|
||||||
} else if (strncmp(value_start, "stop", 4) == 0) {
|
|
||||||
return CMD_STOP;
|
|
||||||
} else if (strncmp(value_start, "disconnect", 10) == 0) {
|
|
||||||
return CMD_DISCONNECT;
|
|
||||||
}
|
|
||||||
|
|
||||||
return CMD_NONE;
|
return CMD_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Find the value after "cmd": */
|
||||||
|
const char *value_start = strchr(cmd_start, ':');
|
||||||
|
if (!value_start) {
|
||||||
|
return CMD_NONE;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Skip whitespace and opening quote */
|
||||||
|
value_start++;
|
||||||
|
while (*value_start == ' ' || *value_start == '"') {
|
||||||
|
value_start++;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Match command strings */
|
||||||
|
if (strncmp(value_start, "connect", 7) == 0) {
|
||||||
|
return CMD_CONNECT;
|
||||||
|
} else if (strncmp(value_start, "start_predict", 13) == 0) {
|
||||||
|
return CMD_START_PREDICT;
|
||||||
|
} else if (strncmp(value_start, "start", 5) == 0) {
|
||||||
|
return CMD_START;
|
||||||
|
} else if (strncmp(value_start, "stop", 4) == 0) {
|
||||||
|
return CMD_STOP;
|
||||||
|
} else if (strncmp(value_start, "disconnect", 10) == 0) {
|
||||||
|
return CMD_DISCONNECT;
|
||||||
|
}
|
||||||
|
|
||||||
|
return CMD_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
@@ -120,224 +124,213 @@ static command_t parse_command(const char* line)
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief FreeRTOS task to read serial input and parse commands.
|
* @brief FreeRTOS task to read serial input and parse commands.
|
||||||
*
|
|
||||||
* This task runs continuously, reading lines from stdin (USB serial)
|
|
||||||
* and updating device state directly. This allows commands to interrupt
|
|
||||||
* streaming immediately via the volatile state variable.
|
|
||||||
*
|
|
||||||
* @param pvParameters Unused
|
|
||||||
*/
|
*/
|
||||||
static void serial_input_task(void* pvParameters)
|
static void serial_input_task(void *pvParameters) {
|
||||||
{
|
char line_buffer[CMD_BUFFER_SIZE];
|
||||||
char line_buffer[CMD_BUFFER_SIZE];
|
int line_idx = 0;
|
||||||
int line_idx = 0;
|
|
||||||
|
|
||||||
while (1) {
|
while (1) {
|
||||||
/* Read one character at a time */
|
int c = getchar();
|
||||||
int c = getchar();
|
|
||||||
|
|
||||||
if (c == EOF || c == 0xFF) {
|
if (c == EOF || c == 0xFF) {
|
||||||
/* No data available, yield to other tasks */
|
vTaskDelay(pdMS_TO_TICKS(10));
|
||||||
vTaskDelay(pdMS_TO_TICKS(10));
|
continue;
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (c == '\n' || c == '\r') {
|
|
||||||
/* End of line - process command */
|
|
||||||
if (line_idx > 0) {
|
|
||||||
line_buffer[line_idx] = '\0';
|
|
||||||
|
|
||||||
command_t cmd = parse_command(line_buffer);
|
|
||||||
|
|
||||||
if (cmd != CMD_NONE) {
|
|
||||||
/* Handle CONNECT command from ANY state (for reconnection/recovery) */
|
|
||||||
if (cmd == CMD_CONNECT) {
|
|
||||||
/* Stop streaming if active, reset to CONNECTED state */
|
|
||||||
g_device_state = STATE_CONNECTED;
|
|
||||||
send_ack_connect();
|
|
||||||
printf("[STATE] ANY -> CONNECTED (reconnect)\n");
|
|
||||||
}
|
|
||||||
/* Handle other state transitions */
|
|
||||||
else {
|
|
||||||
switch (g_device_state) {
|
|
||||||
case STATE_IDLE:
|
|
||||||
/* Only CONNECT allowed from IDLE (handled above) */
|
|
||||||
break;
|
|
||||||
|
|
||||||
case STATE_CONNECTED:
|
|
||||||
if (cmd == CMD_START) {
|
|
||||||
g_device_state = STATE_STREAMING;
|
|
||||||
printf("[STATE] CONNECTED -> STREAMING\n");
|
|
||||||
/* Signal state machine to start streaming */
|
|
||||||
xQueueSend(g_cmd_queue, &cmd, 0);
|
|
||||||
} else if (cmd == CMD_DISCONNECT) {
|
|
||||||
g_device_state = STATE_IDLE;
|
|
||||||
printf("[STATE] CONNECTED -> IDLE\n");
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
|
|
||||||
case STATE_STREAMING:
|
|
||||||
if (cmd == CMD_STOP) {
|
|
||||||
g_device_state = STATE_CONNECTED;
|
|
||||||
printf("[STATE] STREAMING -> CONNECTED\n");
|
|
||||||
/* Streaming loop will exit when it sees state change */
|
|
||||||
} else if (cmd == CMD_DISCONNECT) {
|
|
||||||
g_device_state = STATE_IDLE;
|
|
||||||
printf("[STATE] STREAMING -> IDLE\n");
|
|
||||||
/* Streaming loop will exit when it sees state change */
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
line_idx = 0;
|
|
||||||
}
|
|
||||||
} else if (line_idx < CMD_BUFFER_SIZE - 1) {
|
|
||||||
/* Add character to buffer */
|
|
||||||
line_buffer[line_idx++] = (char)c;
|
|
||||||
} else {
|
|
||||||
/* Buffer overflow - reset */
|
|
||||||
line_idx = 0;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (c == '\n' || c == '\r') {
|
||||||
|
if (line_idx > 0) {
|
||||||
|
line_buffer[line_idx] = '\0';
|
||||||
|
command_t cmd = parse_command(line_buffer);
|
||||||
|
|
||||||
|
if (cmd != CMD_NONE) {
|
||||||
|
if (cmd == CMD_CONNECT) {
|
||||||
|
g_device_state = STATE_CONNECTED;
|
||||||
|
send_ack_connect();
|
||||||
|
printf("[STATE] ANY -> CONNECTED (reconnect)\n");
|
||||||
|
} else {
|
||||||
|
switch (g_device_state) {
|
||||||
|
case STATE_IDLE:
|
||||||
|
break;
|
||||||
|
|
||||||
|
case STATE_CONNECTED:
|
||||||
|
if (cmd == CMD_START) {
|
||||||
|
g_device_state = STATE_STREAMING;
|
||||||
|
printf("[STATE] CONNECTED -> STREAMING\n");
|
||||||
|
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||||
|
} else if (cmd == CMD_START_PREDICT) {
|
||||||
|
g_device_state = STATE_PREDICTING;
|
||||||
|
printf("[STATE] CONNECTED -> PREDICTING\n");
|
||||||
|
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||||
|
} else if (cmd == CMD_DISCONNECT) {
|
||||||
|
g_device_state = STATE_IDLE;
|
||||||
|
printf("[STATE] CONNECTED -> IDLE\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
|
||||||
|
case STATE_STREAMING:
|
||||||
|
case STATE_PREDICTING:
|
||||||
|
if (cmd == CMD_STOP) {
|
||||||
|
g_device_state = STATE_CONNECTED;
|
||||||
|
printf("[STATE] ACTIVE -> CONNECTED\n");
|
||||||
|
} else if (cmd == CMD_DISCONNECT) {
|
||||||
|
g_device_state = STATE_IDLE;
|
||||||
|
printf("[STATE] ACTIVE -> IDLE\n");
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
line_idx = 0;
|
||||||
|
}
|
||||||
|
} else if (line_idx < CMD_BUFFER_SIZE - 1) {
|
||||||
|
line_buffer[line_idx++] = (char)c;
|
||||||
|
} else {
|
||||||
|
line_idx = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* State Machine
|
* State Machine
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
/**
|
static void send_ack_connect(void) {
|
||||||
* @brief Send JSON acknowledgment for connection.
|
printf(
|
||||||
*/
|
"{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
|
||||||
static void send_ack_connect(void)
|
EMG_NUM_CHANNELS);
|
||||||
{
|
fflush(stdout);
|
||||||
printf("{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
|
|
||||||
EMG_NUM_CHANNELS);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Stream EMG data continuously until stopped.
|
* @brief Stream raw EMG data (Training Mode).
|
||||||
*
|
|
||||||
* This function blocks and streams data at the configured sample rate.
|
|
||||||
* Returns when state changes from STREAMING.
|
|
||||||
*/
|
*/
|
||||||
static void stream_emg_data(void)
|
static void stream_emg_data(void) {
|
||||||
{
|
emg_sample_t sample;
|
||||||
emg_sample_t sample;
|
const TickType_t delay_ticks = 1;
|
||||||
const TickType_t delay_ticks = 1; /* 1 tick = 1ms at 1000 Hz tick rate */
|
|
||||||
|
|
||||||
while (g_device_state == STATE_STREAMING) {
|
while (g_device_state == STATE_STREAMING) {
|
||||||
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
|
|
||||||
emg_sensor_read(&sample);
|
|
||||||
|
|
||||||
/* Output in CSV format - channels only, Python handles timestamps */
|
|
||||||
printf("%u,%u,%u,%u\n",
|
|
||||||
sample.channels[0],
|
|
||||||
sample.channels[1],
|
|
||||||
sample.channels[2],
|
|
||||||
sample.channels[3]);
|
|
||||||
|
|
||||||
/* Yield to FreeRTOS scheduler - prevents watchdog timeout */
|
|
||||||
vTaskDelay(delay_ticks);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Main state machine loop.
|
|
||||||
*
|
|
||||||
* Monitors device state and starts streaming when requested.
|
|
||||||
* Serial input task handles all state transitions directly.
|
|
||||||
*/
|
|
||||||
static void state_machine_loop(void)
|
|
||||||
{
|
|
||||||
command_t cmd;
|
|
||||||
const TickType_t poll_interval = pdMS_TO_TICKS(50);
|
|
||||||
|
|
||||||
while (1) {
|
|
||||||
/* Check if we should start streaming */
|
|
||||||
if (g_device_state == STATE_STREAMING) {
|
|
||||||
/* Stream until state changes (via serial input task) */
|
|
||||||
stream_emg_data();
|
|
||||||
/* Returns when state is no longer STREAMING */
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Wait for start command or just poll state */
|
|
||||||
/* Timeout allows checking state even if queue is empty */
|
|
||||||
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
|
|
||||||
|
|
||||||
/* Note: State transitions are handled by serial_input_task */
|
|
||||||
/* This loop only triggers streaming when state becomes STREAMING */
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void emgPrinter() {
|
|
||||||
TickType_t previousWake = xTaskGetTickCount();
|
|
||||||
while (1) {
|
|
||||||
emg_sample_t sample;
|
|
||||||
emg_sensor_read(&sample);
|
emg_sensor_read(&sample);
|
||||||
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) {
|
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
|
||||||
printf("%d", sample.channels[i]);
|
sample.channels[2], sample.channels[3]);
|
||||||
if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
|
vTaskDelay(delay_ticks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run on-device inference (Prediction Mode).
|
||||||
|
*/
|
||||||
|
static void run_inference_loop(void) {
|
||||||
|
emg_sample_t sample;
|
||||||
|
const TickType_t delay_ticks = 1; // 1ms @ 1kHz
|
||||||
|
int last_gesture = -1;
|
||||||
|
|
||||||
|
// Reset inference state
|
||||||
|
inference_init();
|
||||||
|
printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n");
|
||||||
|
|
||||||
|
while (g_device_state == STATE_PREDICTING) {
|
||||||
|
emg_sensor_read(&sample);
|
||||||
|
|
||||||
|
// Add to buffer
|
||||||
|
// Note: sample.channels is uint16_t, matching inference engine expectation
|
||||||
|
if (inference_add_sample(sample.channels)) {
|
||||||
|
// Buffer full (sliding window), run prediction
|
||||||
|
// We can optimize stride here (e.g. valid prediction only every N
|
||||||
|
// samples) For now, let's predict every sample (sliding window) or
|
||||||
|
// throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz?
|
||||||
|
// maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid
|
||||||
|
// UART spam.
|
||||||
|
|
||||||
|
static int stride_counter = 0;
|
||||||
|
stride_counter++;
|
||||||
|
|
||||||
|
if (stride_counter >= 20) { // 20ms stride
|
||||||
|
float confidence = 0;
|
||||||
|
int gesture_idx = inference_predict(&confidence);
|
||||||
|
stride_counter = 0;
|
||||||
|
|
||||||
|
if (gesture_idx >= 0) {
|
||||||
|
// Map class index (0-N) to gesture enum (correct hardware action)
|
||||||
|
int gesture_enum = inference_get_gesture_enum(gesture_idx);
|
||||||
|
|
||||||
|
// Execute gesture on hand
|
||||||
|
gestures_execute((gesture_t)gesture_enum);
|
||||||
|
|
||||||
|
// Send telemetry if changed or periodically?
|
||||||
|
// "Live prediction flow should change to only have each new output...
|
||||||
|
// sent"
|
||||||
|
if (gesture_idx != last_gesture) {
|
||||||
|
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
|
||||||
|
inference_get_class_name(gesture_idx), confidence);
|
||||||
|
last_gesture = gesture_idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
printf("\n");
|
|
||||||
// vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
|
vTaskDelay(delay_ticks);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void state_machine_loop(void) {
|
||||||
|
command_t cmd;
|
||||||
|
const TickType_t poll_interval = pdMS_TO_TICKS(50);
|
||||||
|
|
||||||
|
while (1) {
|
||||||
|
if (g_device_state == STATE_STREAMING) {
|
||||||
|
stream_emg_data();
|
||||||
|
} else if (g_device_state == STATE_PREDICTING) {
|
||||||
|
run_inference_loop();
|
||||||
|
}
|
||||||
|
|
||||||
|
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void appConnector() {
|
void appConnector() {
|
||||||
/* Create command queue */
|
g_cmd_queue = xQueueCreate(10, sizeof(command_t));
|
||||||
g_cmd_queue = xQueueCreate(10, sizeof(command_t));
|
if (g_cmd_queue == NULL) {
|
||||||
if (g_cmd_queue == NULL) {
|
printf("[ERROR] Failed to create command queue!\n");
|
||||||
printf("[ERROR] Failed to create command queue!\n");
|
return;
|
||||||
return;
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/* Launch serial input task */
|
xTaskCreate(serial_input_task, "serial_input", 4096, NULL, 5, NULL);
|
||||||
xTaskCreate(
|
|
||||||
serial_input_task,
|
|
||||||
"serial_input",
|
|
||||||
4096, /* Stack size */
|
|
||||||
NULL, /* Parameters */
|
|
||||||
5, /* Priority */
|
|
||||||
NULL /* Task handle */
|
|
||||||
);
|
|
||||||
|
|
||||||
printf("[PROTOCOL] Waiting for host to connect...\n");
|
printf("[PROTOCOL] Waiting for host to connect...\n");
|
||||||
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n\n");
|
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
|
||||||
|
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
|
||||||
|
"inference\n\n");
|
||||||
|
|
||||||
/* Run main state machine */
|
state_machine_loop();
|
||||||
state_machine_loop();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Application Entry Point
|
* Application Entry Point
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
void app_main(void)
|
void app_main(void) {
|
||||||
{
|
printf("\n");
|
||||||
printf("\n");
|
printf("========================================\n");
|
||||||
printf("========================================\n");
|
printf(" Bucky Arm - EMG Robotic Hand\n");
|
||||||
printf(" Bucky Arm - EMG Robotic Hand\n");
|
printf(" Firmware v2.1.0 (Inference Enabled)\n");
|
||||||
printf(" Firmware v2.0.0 (Handshake Protocol)\n");
|
printf("========================================\n\n");
|
||||||
printf("========================================\n\n");
|
|
||||||
|
|
||||||
/* Initialize subsystems */
|
printf("[INIT] Initializing hand (servos)...\n");
|
||||||
printf("[INIT] Initializing hand (servos)...\n");
|
hand_init();
|
||||||
hand_init();
|
|
||||||
|
|
||||||
printf("[INIT] Initializing EMG sensor...\n");
|
printf("[INIT] Initializing EMG sensor...\n");
|
||||||
emg_sensor_init();
|
emg_sensor_init();
|
||||||
|
|
||||||
|
printf("[INIT] Initializing Inference Engine...\n");
|
||||||
|
inference_init();
|
||||||
|
|
||||||
#if FEATURE_FAKE_EMG
|
#if FEATURE_FAKE_EMG
|
||||||
printf("[INIT] Using FAKE EMG data (sensors not connected)\n");
|
printf("[INIT] Using FAKE EMG data (sensors not connected)\n");
|
||||||
#else
|
#else
|
||||||
printf("[INIT] Using REAL EMG sensors\n");
|
printf("[INIT] Using REAL EMG sensors\n");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
printf("[INIT] Done!\n\n");
|
printf("[INIT] Done!\n\n");
|
||||||
|
|
||||||
// emgPrinter();
|
appConnector();
|
||||||
appConnector();
|
|
||||||
}
|
}
|
||||||
|
|||||||
281
EMG_Arm/src/core/inference.c
Normal file
281
EMG_Arm/src/core/inference.c
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
/**
|
||||||
|
* @file inference.c
|
||||||
|
* @brief Implementation of EMG inference engine.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "inference.h"
|
||||||
|
#include "config/config.h"
|
||||||
|
#include "model_weights.h"
|
||||||
|
#include <math.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
#include <string.h>
|
||||||
|
|
||||||
|
// --- Constants ---
|
||||||
|
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
|
||||||
|
#define VOTE_WINDOW 5 // Majority vote window size
|
||||||
|
#define DEBOUNCE_COUNT 3 // Confirmations needed to change output
|
||||||
|
|
||||||
|
// --- State ---
|
||||||
|
static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
|
||||||
|
static int buffer_head = 0;
|
||||||
|
static int samples_collected = 0;
|
||||||
|
|
||||||
|
// Smoothing State
|
||||||
|
static float smoothed_probs[MODEL_NUM_CLASSES];
|
||||||
|
static int vote_history[VOTE_WINDOW];
|
||||||
|
static int vote_head = 0;
|
||||||
|
static int current_output = -1;
|
||||||
|
static int pending_output = -1;
|
||||||
|
static int pending_count = 0;
|
||||||
|
|
||||||
|
void inference_init(void) {
|
||||||
|
memset(window_buffer, 0, sizeof(window_buffer));
|
||||||
|
buffer_head = 0;
|
||||||
|
samples_collected = 0;
|
||||||
|
|
||||||
|
// Initialize smoothing
|
||||||
|
for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
|
||||||
|
smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < VOTE_WINDOW; i++) {
|
||||||
|
vote_history[i] = -1;
|
||||||
|
}
|
||||||
|
vote_head = 0;
|
||||||
|
current_output = -1;
|
||||||
|
pending_output = -1;
|
||||||
|
pending_count = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool inference_add_sample(uint16_t *channels) {
|
||||||
|
// Add to circular buffer
|
||||||
|
for (int i = 0; i < NUM_CHANNELS; i++) {
|
||||||
|
window_buffer[buffer_head][i] = channels[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
|
|
||||||
|
if (samples_collected < INFERENCE_WINDOW_SIZE) {
|
||||||
|
samples_collected++;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true; // Buffer is full (always ready in sliding window, but caller
|
||||||
|
// controls stride)
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Feature Extraction ---
|
||||||
|
|
||||||
|
static void compute_features(float *features_out) {
|
||||||
|
// Process each channel
|
||||||
|
// We need to iterate over the logical window (unrolling circular buffer)
|
||||||
|
|
||||||
|
for (int ch = 0; ch < NUM_CHANNELS; ch++) {
|
||||||
|
float sum = 0;
|
||||||
|
float sq_sum = 0;
|
||||||
|
|
||||||
|
// Pass 1: Mean (for centering) and raw values collection
|
||||||
|
// We could optimize by not copying, but accessing logically is safer
|
||||||
|
float signal[INFERENCE_WINDOW_SIZE];
|
||||||
|
|
||||||
|
int idx = buffer_head; // Oldest sample
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
signal[i] = (float)window_buffer[idx][ch];
|
||||||
|
sum += signal[i];
|
||||||
|
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
||||||
|
}
|
||||||
|
|
||||||
|
float mean = sum / INFERENCE_WINDOW_SIZE;
|
||||||
|
|
||||||
|
// Pass 2: Centering and Features
|
||||||
|
float wl = 0;
|
||||||
|
int zc = 0;
|
||||||
|
int ssc = 0;
|
||||||
|
|
||||||
|
// Center the signal
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||||
|
signal[i] -= mean;
|
||||||
|
sq_sum += signal[i] * signal[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
|
||||||
|
|
||||||
|
// Thresholds
|
||||||
|
float zc_thresh = FEAT_ZC_THRESH * rms;
|
||||||
|
float ssc_thresh = (FEAT_SSC_THRESH * rms) *
|
||||||
|
(FEAT_SSC_THRESH * rms); // threshold is on diff product
|
||||||
|
|
||||||
|
for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) {
|
||||||
|
// WL
|
||||||
|
wl += fabsf(signal[i + 1] - signal[i]);
|
||||||
|
|
||||||
|
// ZC
|
||||||
|
if ((signal[i] > 0 && signal[i + 1] < 0) ||
|
||||||
|
(signal[i] < 0 && signal[i + 1] > 0)) {
|
||||||
|
if (fabsf(signal[i] - signal[i + 1]) > zc_thresh) {
|
||||||
|
zc++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSC (needs 3 points, so loop to N-2)
|
||||||
|
if (i < INFERENCE_WINDOW_SIZE - 2) {
|
||||||
|
float diff1 = signal[i + 1] - signal[i];
|
||||||
|
float diff2 = signal[i + 1] - signal[i + 2];
|
||||||
|
if ((diff1 * diff2) > ssc_thresh) {
|
||||||
|
ssc++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store features: [RMS, WL, ZC, SSC] per channel
|
||||||
|
int base = ch * 4;
|
||||||
|
features_out[base + 0] = rms;
|
||||||
|
features_out[base + 1] = wl;
|
||||||
|
features_out[base + 2] = (float)zc;
|
||||||
|
features_out[base + 3] = (float)ssc;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Prediction ---
|
||||||
|
|
||||||
|
int inference_predict(float *confidence) {
|
||||||
|
if (samples_collected < INFERENCE_WINDOW_SIZE) {
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. Extract Features
|
||||||
|
float features[MODEL_NUM_FEATURES];
|
||||||
|
compute_features(features);
|
||||||
|
|
||||||
|
// 2. LDA Inference (Linear Score)
|
||||||
|
float raw_scores[MODEL_NUM_CLASSES];
|
||||||
|
float max_score = -1e9;
|
||||||
|
int max_idx = 0;
|
||||||
|
|
||||||
|
// Calculate raw discriminative scores
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
float score = LDA_INTERCEPTS[c];
|
||||||
|
for (int f = 0; f < MODEL_NUM_FEATURES; f++) {
|
||||||
|
score += features[f] * LDA_WEIGHTS[c][f];
|
||||||
|
}
|
||||||
|
raw_scores[c] = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert scores to probabilities (Softmax)
|
||||||
|
// LDA scores are log-likelihoods + const. Softmax is appropriate.
|
||||||
|
float sum_exp = 0;
|
||||||
|
float probas[MODEL_NUM_CLASSES];
|
||||||
|
|
||||||
|
// Numerical stability: subtract max
|
||||||
|
// Create temp copy for max finding
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
if (raw_scores[c] > max_score)
|
||||||
|
max_score = raw_scores[c];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
probas[c] = expf(raw_scores[c] - max_score);
|
||||||
|
sum_exp += probas[c];
|
||||||
|
}
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
probas[c] /= sum_exp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Smoothing
|
||||||
|
// 3a. Probability EMA
|
||||||
|
float max_smoothed_prob = 0;
|
||||||
|
int smoothed_winner = 0;
|
||||||
|
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
smoothed_probs[c] = (SMOOTHING_FACTOR * smoothed_probs[c]) +
|
||||||
|
((1.0f - SMOOTHING_FACTOR) * probas[c]);
|
||||||
|
|
||||||
|
if (smoothed_probs[c] > max_smoothed_prob) {
|
||||||
|
max_smoothed_prob = smoothed_probs[c];
|
||||||
|
smoothed_winner = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3b. Majority Vote
|
||||||
|
vote_history[vote_head] = smoothed_winner;
|
||||||
|
vote_head = (vote_head + 1) % VOTE_WINDOW;
|
||||||
|
|
||||||
|
int counts[MODEL_NUM_CLASSES];
|
||||||
|
memset(counts, 0, sizeof(counts));
|
||||||
|
|
||||||
|
for (int i = 0; i < VOTE_WINDOW; i++) {
|
||||||
|
if (vote_history[i] != -1) {
|
||||||
|
counts[vote_history[i]]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int majority_winner = 0;
|
||||||
|
int majority_count = 0;
|
||||||
|
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||||
|
if (counts[c] > majority_count) {
|
||||||
|
majority_count = counts[c];
|
||||||
|
majority_winner = c;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3c. Debounce
|
||||||
|
int final_result = current_output;
|
||||||
|
|
||||||
|
if (current_output == -1) {
|
||||||
|
current_output = majority_winner;
|
||||||
|
pending_output = majority_winner;
|
||||||
|
pending_count = 1;
|
||||||
|
final_result = majority_winner;
|
||||||
|
} else if (majority_winner == current_output) {
|
||||||
|
pending_output = majority_winner;
|
||||||
|
pending_count = 1;
|
||||||
|
} else if (majority_winner == pending_output) {
|
||||||
|
pending_count++;
|
||||||
|
if (pending_count >= DEBOUNCE_COUNT) {
|
||||||
|
current_output = majority_winner;
|
||||||
|
final_result = majority_winner;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
pending_output = majority_winner;
|
||||||
|
pending_count = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use smoothed probability of the final winner as confidence
|
||||||
|
// Or simpler: use fraction of votes
|
||||||
|
*confidence = (float)majority_count / VOTE_WINDOW;
|
||||||
|
|
||||||
|
return final_result;
|
||||||
|
}
|
||||||
|
|
||||||
|
const char *inference_get_class_name(int class_idx) {
|
||||||
|
if (class_idx >= 0 && class_idx < MODEL_NUM_CLASSES) {
|
||||||
|
return MODEL_CLASS_NAMES[class_idx];
|
||||||
|
}
|
||||||
|
return "UNKNOWN";
|
||||||
|
}
|
||||||
|
|
||||||
|
int inference_get_gesture_enum(int class_idx) {
|
||||||
|
const char *name = inference_get_class_name(class_idx);
|
||||||
|
|
||||||
|
// Map string name to gesture_t enum
|
||||||
|
// Strings must match those in Python list: ["fist", "hook_em", "open",
|
||||||
|
// "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums
|
||||||
|
// are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5
|
||||||
|
|
||||||
|
// Case-insensitive check would be safer, but let's assume Python output is
|
||||||
|
// lowercase as seen in scripts or uppercase if specified. In
|
||||||
|
// learning_data_collection.py, they seem to be "rest", "open", "fist", etc.
|
||||||
|
|
||||||
|
// Simple string matching
|
||||||
|
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0)
|
||||||
|
return GESTURE_REST;
|
||||||
|
if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0)
|
||||||
|
return GESTURE_FIST;
|
||||||
|
if (strcmp(name, "open") == 0 || strcmp(name, "OPEN") == 0)
|
||||||
|
return GESTURE_OPEN;
|
||||||
|
if (strcmp(name, "hook_em") == 0 || strcmp(name, "HOOK_EM") == 0)
|
||||||
|
return GESTURE_HOOK_EM;
|
||||||
|
if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0)
|
||||||
|
return GESTURE_THUMBS_UP;
|
||||||
|
|
||||||
|
return GESTURE_NONE;
|
||||||
|
}
|
||||||
47
EMG_Arm/src/core/inference.h
Normal file
47
EMG_Arm/src/core/inference.h
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
/**
|
||||||
|
* @file inference.h
|
||||||
|
* @brief On-device inference engine for EMG gesture recognition.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef INFERENCE_H
|
||||||
|
#define INFERENCE_H
|
||||||
|
|
||||||
|
#include <stdbool.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
// --- Configuration ---
|
||||||
|
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python)
|
||||||
|
#define NUM_CHANNELS 4 // Number of EMG channels
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Initialize the inference engine.
|
||||||
|
*/
|
||||||
|
void inference_init(void);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Add a sample to the inference buffer.
|
||||||
|
*
|
||||||
|
* @param channels Array of 4 channel values (raw ADC)
|
||||||
|
* @return true if a full window is ready for processing
|
||||||
|
*/
|
||||||
|
bool inference_add_sample(uint16_t *channels);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Run inference on the current window.
|
||||||
|
*
|
||||||
|
* @param confidence Output pointer for confidence score (0.0 - 1.0)
|
||||||
|
* @return Detected class index (-1 if error)
|
||||||
|
*/
|
||||||
|
int inference_predict(float *confidence);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Get the name of a class index.
|
||||||
|
*/
|
||||||
|
const char *inference_get_class_name(int class_idx);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Map class index to gesture_t enum.
|
||||||
|
*/
|
||||||
|
int inference_get_gesture_enum(int class_idx);
|
||||||
|
|
||||||
|
#endif /* INFERENCE_H */
|
||||||
44
EMG_Arm/src/core/model_weights.h
Normal file
44
EMG_Arm/src/core/model_weights.h
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
/**
|
||||||
|
* @file model_weights.h
|
||||||
|
* @brief Placeholder for trained model weights.
|
||||||
|
*
|
||||||
|
* This file should be generated by the Python script (Training Page -> Export for ESP32).
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MODEL_WEIGHTS_H
|
||||||
|
#define MODEL_WEIGHTS_H
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
/* Metadata */
|
||||||
|
#define MODEL_NUM_CLASSES 5
|
||||||
|
#define MODEL_NUM_FEATURES 16
|
||||||
|
|
||||||
|
/* Class Names */
|
||||||
|
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
|
||||||
|
"REST",
|
||||||
|
"OPEN",
|
||||||
|
"FIST",
|
||||||
|
"HOOK_EM",
|
||||||
|
"THUMBS_UP",
|
||||||
|
};
|
||||||
|
|
||||||
|
/* Feature Extractor Parameters */
|
||||||
|
#define FEAT_ZC_THRESH 0.1f
|
||||||
|
#define FEAT_SSC_THRESH 0.1f
|
||||||
|
|
||||||
|
/* LDA Intercepts/Biases */
|
||||||
|
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
|
||||||
|
0.0f, 0.0f, 0.0f, 0.0f, 0.0f
|
||||||
|
};
|
||||||
|
|
||||||
|
/* LDA Coefficients (Weights) */
|
||||||
|
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
|
||||||
|
{0.0f},
|
||||||
|
{0.0f},
|
||||||
|
{0.0f},
|
||||||
|
{0.0f},
|
||||||
|
{0.0f}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif /* MODEL_WEIGHTS_H */
|
||||||
267
emg_gui.py
267
emg_gui.py
@@ -1322,6 +1322,18 @@ class TrainingPage(BasePage):
|
|||||||
)
|
)
|
||||||
self.train_button.pack(pady=20)
|
self.train_button.pack(pady=20)
|
||||||
|
|
||||||
|
# Export button
|
||||||
|
self.export_button = ctk.CTkButton(
|
||||||
|
self.content,
|
||||||
|
text="Export for ESP32",
|
||||||
|
font=ctk.CTkFont(size=14),
|
||||||
|
height=40,
|
||||||
|
fg_color="green",
|
||||||
|
state="disabled",
|
||||||
|
command=self.export_model
|
||||||
|
)
|
||||||
|
self.export_button.pack(pady=5)
|
||||||
|
|
||||||
# Progress
|
# Progress
|
||||||
self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
|
self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
|
||||||
self.progress_bar.pack(pady=10)
|
self.progress_bar.pack(pady=10)
|
||||||
@@ -1434,6 +1446,32 @@ class TrainingPage(BasePage):
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.after(0, lambda: self.train_button.configure(state="normal"))
|
self.after(0, lambda: self.train_button.configure(state="normal"))
|
||||||
|
self.after(0, lambda: self.export_button.configure(state="normal"))
|
||||||
|
|
||||||
|
def export_model(self):
|
||||||
|
"""Export trained model to C header."""
|
||||||
|
if not self.classifier or not self.classifier.is_trained:
|
||||||
|
messagebox.showerror("Error", "No trained model to export!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Default path in ESP32 project
|
||||||
|
default_path = Path("EMG_Arm/src/core/model_weights.h").absolute()
|
||||||
|
|
||||||
|
# Ask user for location, defaulting to the ESP32 project source
|
||||||
|
filename = tk.filedialog.asksaveasfilename(
|
||||||
|
title="Export Model Header",
|
||||||
|
initialdir=default_path.parent,
|
||||||
|
initialfile=default_path.name,
|
||||||
|
filetypes=[("C Header", "*.h")]
|
||||||
|
)
|
||||||
|
|
||||||
|
if filename:
|
||||||
|
try:
|
||||||
|
path = self.classifier.export_to_header(filename)
|
||||||
|
self._log(f"\nExported model to: {path}")
|
||||||
|
messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.")
|
||||||
|
except Exception as e:
|
||||||
|
messagebox.showerror("Export Error", f"Failed to export: {e}")
|
||||||
|
|
||||||
def _log(self, text: str):
|
def _log(self, text: str):
|
||||||
"""Add text to results."""
|
"""Add text to results."""
|
||||||
@@ -1861,115 +1899,158 @@ class PredictionPage(BasePage):
|
|||||||
self._update_connection_status("gray", "Disconnected")
|
self._update_connection_status("gray", "Disconnected")
|
||||||
self.connect_button.configure(text="Connect")
|
self.connect_button.configure(text="Connect")
|
||||||
|
|
||||||
def _prediction_thread(self):
|
def toggle_prediction(self):
|
||||||
"""Background prediction thread."""
|
"""Start or stop prediction."""
|
||||||
# For simulated mode, create new stream
|
if self.is_predicting:
|
||||||
if not self.using_real_hardware:
|
self.stop_prediction()
|
||||||
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
else:
|
||||||
|
self.start_prediction()
|
||||||
|
|
||||||
# Stream is already started (either via handshake for real HW or will be started for simulated)
|
def start_prediction(self):
|
||||||
|
"""Start live prediction."""
|
||||||
|
# Determine mode
|
||||||
|
self.using_real_hardware = (self.source_var.get() == "real")
|
||||||
|
|
||||||
|
if self.using_real_hardware:
|
||||||
|
if not self.is_connected or not self.stream:
|
||||||
|
messagebox.showerror("Not Connected", "Please connect to ESP32 first.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("[DEBUG] Starting Edge Prediction (On-Device)...")
|
||||||
|
try:
|
||||||
|
# Send "start_predict" command to ESP32
|
||||||
|
if hasattr(self.stream, 'ser'):
|
||||||
|
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||||
|
self.stream.running = True
|
||||||
|
else:
|
||||||
|
# Fallback
|
||||||
|
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||||
|
self.stream.running = True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
messagebox.showerror("Start Error", f"Failed to start: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Simulated - use PC-side inference
|
||||||
|
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
||||||
|
self.stream.start()
|
||||||
|
|
||||||
|
# Load model for PC-side (Simulated) OR for display (optional)
|
||||||
|
# Even for Edge, we might want the label list.
|
||||||
|
if not self.using_real_hardware:
|
||||||
|
if not self.classifier:
|
||||||
|
model_path = EMGClassifier.get_default_model_path()
|
||||||
|
if model_path.exists():
|
||||||
|
self.classifier = EMGClassifier.load(model_path)
|
||||||
|
self.model_label.configure(text="Model: Loaded", text_color="green")
|
||||||
|
else:
|
||||||
|
self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange")
|
||||||
|
|
||||||
|
# Reset smoother
|
||||||
|
self.smoother = PredictionSmoother(
|
||||||
|
label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"],
|
||||||
|
probability_smoothing=0.7,
|
||||||
|
majority_vote_window=5,
|
||||||
|
debounce_count=3
|
||||||
|
)
|
||||||
|
|
||||||
|
self.is_predicting = True
|
||||||
|
self.start_button.configure(text="Stop Prediction", fg_color="red")
|
||||||
|
|
||||||
|
# Start display loop
|
||||||
|
self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True)
|
||||||
|
self.prediction_thread.start()
|
||||||
|
|
||||||
|
self.update_prediction_ui()
|
||||||
|
|
||||||
|
def stop_prediction(self):
|
||||||
|
"""Stop prediction."""
|
||||||
|
self.is_predicting = False
|
||||||
|
if self.stream:
|
||||||
|
self.stream.stop() # Sends "stop" usually
|
||||||
|
if not self.using_real_hardware:
|
||||||
|
self.stream = None
|
||||||
|
|
||||||
|
self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
|
||||||
|
self.prediction_label.configure(text="---", text_color="gray")
|
||||||
|
self.confidence_label.configure(text="Confidence: ---%")
|
||||||
|
self.confidence_bar.set(0)
|
||||||
|
|
||||||
|
def prediction_loop(self):
|
||||||
|
"""Loop for reading data and (optionally) running inference."""
|
||||||
|
import json
|
||||||
|
|
||||||
parser = EMGParser(num_channels=NUM_CHANNELS)
|
parser = EMGParser(num_channels=NUM_CHANNELS)
|
||||||
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
|
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
|
||||||
|
|
||||||
# Simulated gesture cycling (only for simulated mode)
|
|
||||||
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
|
|
||||||
gesture_idx = 0
|
|
||||||
gesture_duration = 2.5
|
|
||||||
gesture_start = time.perf_counter()
|
|
||||||
current_gesture = gesture_cycle[0]
|
|
||||||
|
|
||||||
# Start simulated stream if needed
|
|
||||||
if not self.using_real_hardware:
|
|
||||||
try:
|
|
||||||
if hasattr(self.stream, 'set_gesture'):
|
|
||||||
self.stream.set_gesture(current_gesture)
|
|
||||||
self.stream.start()
|
|
||||||
except Exception as e:
|
|
||||||
self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# Real hardware is already streaming
|
|
||||||
self.data_queue.put(('connection_status', ('green', 'Streaming')))
|
|
||||||
|
|
||||||
while self.is_predicting:
|
while self.is_predicting:
|
||||||
# Change simulated gesture periodically (only for simulated mode)
|
|
||||||
if hasattr(self.stream, 'set_gesture'):
|
|
||||||
elapsed = time.perf_counter() - gesture_start
|
|
||||||
if elapsed > gesture_duration:
|
|
||||||
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
|
|
||||||
gesture_start = time.perf_counter()
|
|
||||||
current_gesture = gesture_cycle[gesture_idx]
|
|
||||||
self.stream.set_gesture(current_gesture)
|
|
||||||
self.data_queue.put(('sim_gesture', current_gesture))
|
|
||||||
|
|
||||||
# Read and process
|
|
||||||
try:
|
try:
|
||||||
line = self.stream.readline()
|
line = self.stream.readline()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if self.using_real_hardware:
|
||||||
|
# Edge Inference Mode: Expect JSON
|
||||||
|
try:
|
||||||
|
line = line.strip()
|
||||||
|
if line.startswith('{'):
|
||||||
|
data = json.loads(line)
|
||||||
|
|
||||||
|
if "gesture" in data:
|
||||||
|
# Update UI with Edge Prediction
|
||||||
|
gesture = data["gesture"]
|
||||||
|
conf = float(data.get("conf", 0.0))
|
||||||
|
|
||||||
|
self.data_queue.put(('prediction', (gesture, conf)))
|
||||||
|
|
||||||
|
elif "status" in data:
|
||||||
|
print(f"[ESP32] {data}")
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
# PC Side Inference (Simulated)
|
||||||
|
sample = parser.parse_line(line)
|
||||||
|
if sample:
|
||||||
|
window = windower.add_sample(sample)
|
||||||
|
if window and self.classifier:
|
||||||
|
# Run Inference Local
|
||||||
|
raw_label, proba = self.classifier.predict(window.to_numpy())
|
||||||
|
label, conf, _ = self.smoother.update(raw_label, proba)
|
||||||
|
|
||||||
|
self.data_queue.put(('prediction', (label, conf)))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Only report error if we didn't intentionally stop
|
|
||||||
if self.is_predicting:
|
if self.is_predicting:
|
||||||
self.data_queue.put(('error', f"Serial read error: {e}"))
|
print(f"Prediction loop error: {e}")
|
||||||
break
|
break
|
||||||
|
|
||||||
if line:
|
|
||||||
sample = parser.parse_line(line)
|
|
||||||
if sample:
|
|
||||||
window = windower.add_sample(sample)
|
|
||||||
if window:
|
|
||||||
# Get raw prediction
|
|
||||||
window_data = window.to_numpy()
|
|
||||||
raw_label, proba = self.classifier.predict(window_data)
|
|
||||||
raw_confidence = max(proba) * 100
|
|
||||||
|
|
||||||
# Apply smoothing
|
|
||||||
smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba)
|
|
||||||
smoothed_confidence = smoothed_conf * 100
|
|
||||||
|
|
||||||
# Send both raw and smoothed to UI
|
|
||||||
self.data_queue.put(('prediction', (
|
|
||||||
smoothed_label, # The stable output
|
|
||||||
smoothed_confidence,
|
|
||||||
raw_label, # The raw (possibly twitchy) output
|
|
||||||
raw_confidence,
|
|
||||||
)))
|
|
||||||
|
|
||||||
# Safe cleanup - stream might already be stopped
|
|
||||||
try:
|
|
||||||
if self.stream:
|
|
||||||
self.stream.stop()
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore cleanup errors
|
|
||||||
|
|
||||||
def update_prediction_ui(self):
|
def update_prediction_ui(self):
|
||||||
"""Update UI from prediction thread."""
|
"""Update UI from queue."""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
msg_type, data = self.data_queue.get_nowait()
|
msg_type, data = self.data_queue.get_nowait()
|
||||||
|
|
||||||
if msg_type == 'prediction':
|
if msg_type == 'prediction':
|
||||||
smoothed_label, smoothed_conf, raw_label, raw_conf = data
|
label, conf = data
|
||||||
|
|
||||||
# Display smoothed (stable) prediction
|
# Update label
|
||||||
display_label = smoothed_label.upper().replace("_", " ")
|
self.prediction_label.configure(
|
||||||
color = get_gesture_color(smoothed_label)
|
text=label.upper(),
|
||||||
|
text_color=get_gesture_color(label)
|
||||||
self.prediction_label.configure(text=display_label, text_color=color)
|
)
|
||||||
self.confidence_bar.set(smoothed_conf / 100)
|
|
||||||
self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%")
|
# Update confidence
|
||||||
|
self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%")
|
||||||
# Show raw prediction for comparison (grayed out)
|
self.confidence_bar.set(conf)
|
||||||
raw_display = raw_label.upper().replace("_", " ")
|
|
||||||
if raw_label != smoothed_label:
|
# Clear raw label since we don't have raw vs smooth distinction in edge mode
|
||||||
# Raw differs from smoothed - show it was filtered
|
# (or we could expose it if we updated the C struct, but for now keep it simple)
|
||||||
self.raw_label.configure(
|
self.raw_label.configure(text="", text_color="gray")
|
||||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
|
|
||||||
text_color="orange"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.raw_label.configure(
|
|
||||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
|
|
||||||
text_color="gray"
|
|
||||||
)
|
|
||||||
|
|
||||||
elif msg_type == 'sim_gesture':
|
elif msg_type == 'sim_gesture':
|
||||||
self.sim_label.configure(text=f"[Simulating: {data}]")
|
self.sim_label.configure(text=f"[Simulating: {data}]")
|
||||||
|
|||||||
@@ -1757,6 +1757,132 @@ class EMGClassifier:
|
|||||||
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
|
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
|
def export_to_header(self, filepath: Path) -> Path:
|
||||||
|
"""
|
||||||
|
Export trained model to a C header file for ESP32 inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filepath: Output .h file path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to the saved header file
|
||||||
|
"""
|
||||||
|
if not self.is_trained:
|
||||||
|
raise ValueError("Cannot export untrained classifier!")
|
||||||
|
|
||||||
|
filepath = Path(filepath)
|
||||||
|
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
n_classes = len(self.label_names)
|
||||||
|
n_features = len(self.feature_names)
|
||||||
|
|
||||||
|
# Get LDA parameters
|
||||||
|
# coef_: (n_classes, n_features) - access as [class][feature]
|
||||||
|
# intercept_: (n_classes,)
|
||||||
|
coefs = self.lda.coef_
|
||||||
|
intercepts = self.lda.intercept_
|
||||||
|
|
||||||
|
# Add logic for binary classification (sklearn stores only 1 set of coefs)
|
||||||
|
# For >2 classes, it stores n_classes sets.
|
||||||
|
if n_classes == 2:
|
||||||
|
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
|
||||||
|
# We need to expand this to 2 classes for the C inference engine to be generic.
|
||||||
|
# Class 1 decision = dot(w, x) + b
|
||||||
|
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
|
||||||
|
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
|
||||||
|
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
|
||||||
|
# Let's check sklearn docs or behavior.
|
||||||
|
# Actually, LDA in sklearn for binary case is special.
|
||||||
|
# To make C code simple (always argmax), let's explicitly store 2 rows.
|
||||||
|
# Row 1 (Index 1 in sklearn): coef, intercept
|
||||||
|
# Row 0 (Index 0): -coef, -intercept ?
|
||||||
|
# Wait, LDA is generative. The decision boundary is linear.
|
||||||
|
# Let's assume Multiclass for now or handle binary specifically.
|
||||||
|
# For simplicity in C, we prefer (n_classes, n_features).
|
||||||
|
# If coefs.shape[0] != n_classes, we need to handle it.
|
||||||
|
if coefs.shape[0] == 1:
|
||||||
|
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
|
||||||
|
# Class 1 (positive)
|
||||||
|
c1_coef = coefs[0]
|
||||||
|
c1_int = intercepts[0]
|
||||||
|
# Class 0 (negative) - Effectively -score for decision boundary at 0
|
||||||
|
# But strictly speaking LDA is comparison of log-posteriors.
|
||||||
|
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
|
||||||
|
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
|
||||||
|
# To map this to ArgMax(Score0, Score1):
|
||||||
|
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
|
||||||
|
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
|
||||||
|
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Generate C content
|
||||||
|
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
c_content = [
|
||||||
|
"/**",
|
||||||
|
f" * @file {filepath.name}",
|
||||||
|
" * @brief Trained LDA model weights exported from Python.",
|
||||||
|
f" * @date {timestamp}",
|
||||||
|
" */",
|
||||||
|
"",
|
||||||
|
"#ifndef MODEL_WEIGHTS_H",
|
||||||
|
"#define MODEL_WEIGHTS_H",
|
||||||
|
"",
|
||||||
|
"#include <stdint.h>",
|
||||||
|
"",
|
||||||
|
"/* Metadata */",
|
||||||
|
f"#define MODEL_NUM_CLASSES {n_classes}",
|
||||||
|
f"#define MODEL_NUM_FEATURES {n_features}",
|
||||||
|
"",
|
||||||
|
"/* Class Names */",
|
||||||
|
"static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in self.label_names:
|
||||||
|
c_content.append(f' "{name}",')
|
||||||
|
c_content.append("};")
|
||||||
|
c_content.append("")
|
||||||
|
|
||||||
|
c_content.append("/* Feature Extractor Parameters */")
|
||||||
|
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
|
||||||
|
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
|
||||||
|
c_content.append("")
|
||||||
|
|
||||||
|
c_content.append("/* LDA Intercepts/Biases */")
|
||||||
|
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
|
||||||
|
line = " "
|
||||||
|
for val in intercepts:
|
||||||
|
line += f"{val:.6f}f, "
|
||||||
|
c_content.append(line.rstrip(", "))
|
||||||
|
c_content.append("};")
|
||||||
|
c_content.append("")
|
||||||
|
|
||||||
|
c_content.append("/* LDA Coefficients (Weights) */")
|
||||||
|
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
|
||||||
|
|
||||||
|
for i, row in enumerate(coefs):
|
||||||
|
c_content.append(f" /* {self.label_names[i]} */")
|
||||||
|
c_content.append(" {")
|
||||||
|
line = " "
|
||||||
|
for j, val in enumerate(row):
|
||||||
|
line += f"{val:.6f}f, "
|
||||||
|
if (j + 1) % 8 == 0:
|
||||||
|
c_content.append(line)
|
||||||
|
line = " "
|
||||||
|
if line.strip():
|
||||||
|
c_content.append(line.rstrip(", "))
|
||||||
|
c_content.append(" },")
|
||||||
|
|
||||||
|
c_content.append("};")
|
||||||
|
c_content.append("")
|
||||||
|
c_content.append("#endif /* MODEL_WEIGHTS_H */")
|
||||||
|
|
||||||
|
with open(filepath, 'w') as f:
|
||||||
|
f.write('\n'.join(c_content))
|
||||||
|
|
||||||
|
print(f"[Classifier] Model weights exported to: {filepath}")
|
||||||
|
return filepath
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, filepath: Path) -> 'EMGClassifier':
|
def load(cls, filepath: Path) -> 'EMGClassifier':
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user