overhaul
This commit is contained in:
@@ -20,6 +20,7 @@ set(DRIVER_SOURCES
|
||||
|
||||
set(CORE_SOURCES
|
||||
core/gestures.c
|
||||
core/inference.c
|
||||
)
|
||||
|
||||
set(APP_SOURCES
|
||||
|
||||
@@ -12,17 +12,18 @@
|
||||
* @note This is Layer 4 (Application).
|
||||
*/
|
||||
|
||||
#include "esp_timer.h"
|
||||
#include <freertos/FreeRTOS.h>
|
||||
#include <freertos/queue.h>
|
||||
#include <freertos/task.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
#include <freertos/FreeRTOS.h>
|
||||
#include <freertos/task.h>
|
||||
#include <freertos/queue.h>
|
||||
#include "esp_timer.h"
|
||||
|
||||
#include "config/config.h"
|
||||
#include "drivers/hand.h"
|
||||
#include "drivers/emg_sensor.h"
|
||||
#include "core/gestures.h"
|
||||
#include "core/inference.h" // [NEW]
|
||||
#include "drivers/emg_sensor.h"
|
||||
#include "drivers/hand.h"
|
||||
|
||||
/*******************************************************************************
|
||||
* Constants
|
||||
@@ -41,7 +42,8 @@
|
||||
typedef enum {
|
||||
STATE_IDLE = 0, /**< Waiting for connect command */
|
||||
STATE_CONNECTED, /**< Connected, waiting for start command */
|
||||
STATE_STREAMING, /**< Actively streaming EMG data */
|
||||
STATE_STREAMING, /**< Actively streaming raw EMG data (for training) */
|
||||
STATE_PREDICTING, /**< [NEW] On-device inference and control */
|
||||
} device_state_t;
|
||||
|
||||
/**
|
||||
@@ -50,7 +52,8 @@ typedef enum {
|
||||
typedef enum {
|
||||
CMD_NONE = 0,
|
||||
CMD_CONNECT,
|
||||
CMD_START,
|
||||
CMD_START, /**< Start raw streaming */
|
||||
CMD_START_PREDICT, /**< [NEW] Start on-device prediction */
|
||||
CMD_STOP,
|
||||
CMD_DISCONNECT,
|
||||
} command_t;
|
||||
@@ -80,16 +83,15 @@ static void send_ack_connect(void);
|
||||
* @param line Input line from serial
|
||||
* @return Parsed command
|
||||
*/
|
||||
static command_t parse_command(const char* line)
|
||||
{
|
||||
static command_t parse_command(const char *line) {
|
||||
/* Simple JSON parsing - look for "cmd" field */
|
||||
const char* cmd_start = strstr(line, "\"cmd\"");
|
||||
const char *cmd_start = strstr(line, "\"cmd\"");
|
||||
if (!cmd_start) {
|
||||
return CMD_NONE;
|
||||
}
|
||||
|
||||
/* Find the value after "cmd": */
|
||||
const char* value_start = strchr(cmd_start, ':');
|
||||
const char *value_start = strchr(cmd_start, ':');
|
||||
if (!value_start) {
|
||||
return CMD_NONE;
|
||||
}
|
||||
@@ -103,6 +105,8 @@ static command_t parse_command(const char* line)
|
||||
/* Match command strings */
|
||||
if (strncmp(value_start, "connect", 7) == 0) {
|
||||
return CMD_CONNECT;
|
||||
} else if (strncmp(value_start, "start_predict", 13) == 0) {
|
||||
return CMD_START_PREDICT;
|
||||
} else if (strncmp(value_start, "start", 5) == 0) {
|
||||
return CMD_START;
|
||||
} else if (strncmp(value_start, "stop", 4) == 0) {
|
||||
@@ -120,55 +124,42 @@ static command_t parse_command(const char* line)
|
||||
|
||||
/**
|
||||
* @brief FreeRTOS task to read serial input and parse commands.
|
||||
*
|
||||
* This task runs continuously, reading lines from stdin (USB serial)
|
||||
* and updating device state directly. This allows commands to interrupt
|
||||
* streaming immediately via the volatile state variable.
|
||||
*
|
||||
* @param pvParameters Unused
|
||||
*/
|
||||
static void serial_input_task(void* pvParameters)
|
||||
{
|
||||
static void serial_input_task(void *pvParameters) {
|
||||
char line_buffer[CMD_BUFFER_SIZE];
|
||||
int line_idx = 0;
|
||||
|
||||
while (1) {
|
||||
/* Read one character at a time */
|
||||
int c = getchar();
|
||||
|
||||
if (c == EOF || c == 0xFF) {
|
||||
/* No data available, yield to other tasks */
|
||||
vTaskDelay(pdMS_TO_TICKS(10));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (c == '\n' || c == '\r') {
|
||||
/* End of line - process command */
|
||||
if (line_idx > 0) {
|
||||
line_buffer[line_idx] = '\0';
|
||||
|
||||
command_t cmd = parse_command(line_buffer);
|
||||
|
||||
if (cmd != CMD_NONE) {
|
||||
/* Handle CONNECT command from ANY state (for reconnection/recovery) */
|
||||
if (cmd == CMD_CONNECT) {
|
||||
/* Stop streaming if active, reset to CONNECTED state */
|
||||
g_device_state = STATE_CONNECTED;
|
||||
send_ack_connect();
|
||||
printf("[STATE] ANY -> CONNECTED (reconnect)\n");
|
||||
}
|
||||
/* Handle other state transitions */
|
||||
else {
|
||||
} else {
|
||||
switch (g_device_state) {
|
||||
case STATE_IDLE:
|
||||
/* Only CONNECT allowed from IDLE (handled above) */
|
||||
break;
|
||||
|
||||
case STATE_CONNECTED:
|
||||
if (cmd == CMD_START) {
|
||||
g_device_state = STATE_STREAMING;
|
||||
printf("[STATE] CONNECTED -> STREAMING\n");
|
||||
/* Signal state machine to start streaming */
|
||||
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||
} else if (cmd == CMD_START_PREDICT) {
|
||||
g_device_state = STATE_PREDICTING;
|
||||
printf("[STATE] CONNECTED -> PREDICTING\n");
|
||||
xQueueSend(g_cmd_queue, &cmd, 0);
|
||||
} else if (cmd == CMD_DISCONNECT) {
|
||||
g_device_state = STATE_IDLE;
|
||||
@@ -177,27 +168,23 @@ static void serial_input_task(void* pvParameters)
|
||||
break;
|
||||
|
||||
case STATE_STREAMING:
|
||||
case STATE_PREDICTING:
|
||||
if (cmd == CMD_STOP) {
|
||||
g_device_state = STATE_CONNECTED;
|
||||
printf("[STATE] STREAMING -> CONNECTED\n");
|
||||
/* Streaming loop will exit when it sees state change */
|
||||
printf("[STATE] ACTIVE -> CONNECTED\n");
|
||||
} else if (cmd == CMD_DISCONNECT) {
|
||||
g_device_state = STATE_IDLE;
|
||||
printf("[STATE] STREAMING -> IDLE\n");
|
||||
/* Streaming loop will exit when it sees state change */
|
||||
printf("[STATE] ACTIVE -> IDLE\n");
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
line_idx = 0;
|
||||
}
|
||||
} else if (line_idx < CMD_BUFFER_SIZE - 1) {
|
||||
/* Add character to buffer */
|
||||
line_buffer[line_idx++] = (char)c;
|
||||
} else {
|
||||
/* Buffer overflow - reset */
|
||||
line_idx = 0;
|
||||
}
|
||||
}
|
||||
@@ -207,107 +194,113 @@ static void serial_input_task(void* pvParameters)
|
||||
* State Machine
|
||||
******************************************************************************/
|
||||
|
||||
/**
|
||||
* @brief Send JSON acknowledgment for connection.
|
||||
*/
|
||||
static void send_ack_connect(void)
|
||||
{
|
||||
printf("{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
|
||||
static void send_ack_connect(void) {
|
||||
printf(
|
||||
"{\"status\":\"ack_connect\",\"device\":\"ESP32-EMG\",\"channels\":%d}\n",
|
||||
EMG_NUM_CHANNELS);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stream EMG data continuously until stopped.
|
||||
*
|
||||
* This function blocks and streams data at the configured sample rate.
|
||||
* Returns when state changes from STREAMING.
|
||||
* @brief Stream raw EMG data (Training Mode).
|
||||
*/
|
||||
static void stream_emg_data(void)
|
||||
{
|
||||
static void stream_emg_data(void) {
|
||||
emg_sample_t sample;
|
||||
const TickType_t delay_ticks = 1; /* 1 tick = 1ms at 1000 Hz tick rate */
|
||||
const TickType_t delay_ticks = 1;
|
||||
|
||||
while (g_device_state == STATE_STREAMING) {
|
||||
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
|
||||
emg_sensor_read(&sample);
|
||||
|
||||
/* Output in CSV format - channels only, Python handles timestamps */
|
||||
printf("%u,%u,%u,%u\n",
|
||||
sample.channels[0],
|
||||
sample.channels[1],
|
||||
sample.channels[2],
|
||||
sample.channels[3]);
|
||||
|
||||
/* Yield to FreeRTOS scheduler - prevents watchdog timeout */
|
||||
printf("%u,%u,%u,%u\n", sample.channels[0], sample.channels[1],
|
||||
sample.channels[2], sample.channels[3]);
|
||||
vTaskDelay(delay_ticks);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Main state machine loop.
|
||||
*
|
||||
* Monitors device state and starts streaming when requested.
|
||||
* Serial input task handles all state transitions directly.
|
||||
* @brief Run on-device inference (Prediction Mode).
|
||||
*/
|
||||
static void state_machine_loop(void)
|
||||
{
|
||||
static void run_inference_loop(void) {
|
||||
emg_sample_t sample;
|
||||
const TickType_t delay_ticks = 1; // 1ms @ 1kHz
|
||||
int last_gesture = -1;
|
||||
|
||||
// Reset inference state
|
||||
inference_init();
|
||||
printf("{\"status\":\"info\",\"msg\":\"Inference started\"}\n");
|
||||
|
||||
while (g_device_state == STATE_PREDICTING) {
|
||||
emg_sensor_read(&sample);
|
||||
|
||||
// Add to buffer
|
||||
// Note: sample.channels is uint16_t, matching inference engine expectation
|
||||
if (inference_add_sample(sample.channels)) {
|
||||
// Buffer full (sliding window), run prediction
|
||||
// We can optimize stride here (e.g. valid prediction only every N
|
||||
// samples) For now, let's predict every sample (sliding window) or
|
||||
// throttle if too slow. ESP32S3 is fast enough for 4ch features @ 1kHz?
|
||||
// maybe. Let's degrade to 50Hz updates (20ms stride) to be safe and avoid
|
||||
// UART spam.
|
||||
|
||||
static int stride_counter = 0;
|
||||
stride_counter++;
|
||||
|
||||
if (stride_counter >= 20) { // 20ms stride
|
||||
float confidence = 0;
|
||||
int gesture_idx = inference_predict(&confidence);
|
||||
stride_counter = 0;
|
||||
|
||||
if (gesture_idx >= 0) {
|
||||
// Map class index (0-N) to gesture enum (correct hardware action)
|
||||
int gesture_enum = inference_get_gesture_enum(gesture_idx);
|
||||
|
||||
// Execute gesture on hand
|
||||
gestures_execute((gesture_t)gesture_enum);
|
||||
|
||||
// Send telemetry if changed or periodically?
|
||||
// "Live prediction flow should change to only have each new output...
|
||||
// sent"
|
||||
if (gesture_idx != last_gesture) {
|
||||
printf("{\"gesture\":\"%s\",\"conf\":%.2f}\n",
|
||||
inference_get_class_name(gesture_idx), confidence);
|
||||
last_gesture = gesture_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
vTaskDelay(delay_ticks);
|
||||
}
|
||||
}
|
||||
|
||||
static void state_machine_loop(void) {
|
||||
command_t cmd;
|
||||
const TickType_t poll_interval = pdMS_TO_TICKS(50);
|
||||
|
||||
while (1) {
|
||||
/* Check if we should start streaming */
|
||||
if (g_device_state == STATE_STREAMING) {
|
||||
/* Stream until state changes (via serial input task) */
|
||||
stream_emg_data();
|
||||
/* Returns when state is no longer STREAMING */
|
||||
} else if (g_device_state == STATE_PREDICTING) {
|
||||
run_inference_loop();
|
||||
}
|
||||
|
||||
/* Wait for start command or just poll state */
|
||||
/* Timeout allows checking state even if queue is empty */
|
||||
xQueueReceive(g_cmd_queue, &cmd, poll_interval);
|
||||
|
||||
/* Note: State transitions are handled by serial_input_task */
|
||||
/* This loop only triggers streaming when state becomes STREAMING */
|
||||
}
|
||||
}
|
||||
|
||||
void emgPrinter() {
|
||||
TickType_t previousWake = xTaskGetTickCount();
|
||||
while (1) {
|
||||
emg_sample_t sample;
|
||||
emg_sensor_read(&sample);
|
||||
for (uint8_t i = 0; i < EMG_NUM_CHANNELS; i++) {
|
||||
printf("%d", sample.channels[i]);
|
||||
if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
|
||||
}
|
||||
printf("\n");
|
||||
// vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
|
||||
}
|
||||
}
|
||||
|
||||
void appConnector() {
|
||||
/* Create command queue */
|
||||
g_cmd_queue = xQueueCreate(10, sizeof(command_t));
|
||||
if (g_cmd_queue == NULL) {
|
||||
printf("[ERROR] Failed to create command queue!\n");
|
||||
return;
|
||||
}
|
||||
|
||||
/* Launch serial input task */
|
||||
xTaskCreate(
|
||||
serial_input_task,
|
||||
"serial_input",
|
||||
4096, /* Stack size */
|
||||
NULL, /* Parameters */
|
||||
5, /* Priority */
|
||||
NULL /* Task handle */
|
||||
);
|
||||
xTaskCreate(serial_input_task, "serial_input", 4096, NULL, 5, NULL);
|
||||
|
||||
printf("[PROTOCOL] Waiting for host to connect...\n");
|
||||
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n\n");
|
||||
printf("[PROTOCOL] Send: {\"cmd\": \"connect\"}\n");
|
||||
printf("[PROTOCOL] Send: {\"cmd\": \"start_predict\"} for on-device "
|
||||
"inference\n\n");
|
||||
|
||||
/* Run main state machine */
|
||||
state_machine_loop();
|
||||
}
|
||||
|
||||
@@ -315,21 +308,22 @@ void appConnector() {
|
||||
* Application Entry Point
|
||||
******************************************************************************/
|
||||
|
||||
void app_main(void)
|
||||
{
|
||||
void app_main(void) {
|
||||
printf("\n");
|
||||
printf("========================================\n");
|
||||
printf(" Bucky Arm - EMG Robotic Hand\n");
|
||||
printf(" Firmware v2.0.0 (Handshake Protocol)\n");
|
||||
printf(" Firmware v2.1.0 (Inference Enabled)\n");
|
||||
printf("========================================\n\n");
|
||||
|
||||
/* Initialize subsystems */
|
||||
printf("[INIT] Initializing hand (servos)...\n");
|
||||
hand_init();
|
||||
|
||||
printf("[INIT] Initializing EMG sensor...\n");
|
||||
emg_sensor_init();
|
||||
|
||||
printf("[INIT] Initializing Inference Engine...\n");
|
||||
inference_init();
|
||||
|
||||
#if FEATURE_FAKE_EMG
|
||||
printf("[INIT] Using FAKE EMG data (sensors not connected)\n");
|
||||
#else
|
||||
@@ -338,6 +332,5 @@ void app_main(void)
|
||||
|
||||
printf("[INIT] Done!\n\n");
|
||||
|
||||
// emgPrinter();
|
||||
appConnector();
|
||||
}
|
||||
|
||||
281
EMG_Arm/src/core/inference.c
Normal file
281
EMG_Arm/src/core/inference.c
Normal file
@@ -0,0 +1,281 @@
|
||||
/**
|
||||
* @file inference.c
|
||||
* @brief Implementation of EMG inference engine.
|
||||
*/
|
||||
|
||||
#include "inference.h"
|
||||
#include "config/config.h"
|
||||
#include "model_weights.h"
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
// --- Constants ---
|
||||
#define SMOOTHING_FACTOR 0.7f // EMA factor for probability (matches Python)
|
||||
#define VOTE_WINDOW 5 // Majority vote window size
|
||||
#define DEBOUNCE_COUNT 3 // Confirmations needed to change output
|
||||
|
||||
// --- State ---
|
||||
static uint16_t window_buffer[INFERENCE_WINDOW_SIZE][NUM_CHANNELS];
|
||||
static int buffer_head = 0;
|
||||
static int samples_collected = 0;
|
||||
|
||||
// Smoothing State
|
||||
static float smoothed_probs[MODEL_NUM_CLASSES];
|
||||
static int vote_history[VOTE_WINDOW];
|
||||
static int vote_head = 0;
|
||||
static int current_output = -1;
|
||||
static int pending_output = -1;
|
||||
static int pending_count = 0;
|
||||
|
||||
void inference_init(void) {
|
||||
memset(window_buffer, 0, sizeof(window_buffer));
|
||||
buffer_head = 0;
|
||||
samples_collected = 0;
|
||||
|
||||
// Initialize smoothing
|
||||
for (int i = 0; i < MODEL_NUM_CLASSES; i++) {
|
||||
smoothed_probs[i] = 1.0f / MODEL_NUM_CLASSES;
|
||||
}
|
||||
for (int i = 0; i < VOTE_WINDOW; i++) {
|
||||
vote_history[i] = -1;
|
||||
}
|
||||
vote_head = 0;
|
||||
current_output = -1;
|
||||
pending_output = -1;
|
||||
pending_count = 0;
|
||||
}
|
||||
|
||||
bool inference_add_sample(uint16_t *channels) {
|
||||
// Add to circular buffer
|
||||
for (int i = 0; i < NUM_CHANNELS; i++) {
|
||||
window_buffer[buffer_head][i] = channels[i];
|
||||
}
|
||||
|
||||
buffer_head = (buffer_head + 1) % INFERENCE_WINDOW_SIZE;
|
||||
|
||||
if (samples_collected < INFERENCE_WINDOW_SIZE) {
|
||||
samples_collected++;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true; // Buffer is full (always ready in sliding window, but caller
|
||||
// controls stride)
|
||||
}
|
||||
|
||||
// --- Feature Extraction ---
|
||||
|
||||
static void compute_features(float *features_out) {
|
||||
// Process each channel
|
||||
// We need to iterate over the logical window (unrolling circular buffer)
|
||||
|
||||
for (int ch = 0; ch < NUM_CHANNELS; ch++) {
|
||||
float sum = 0;
|
||||
float sq_sum = 0;
|
||||
|
||||
// Pass 1: Mean (for centering) and raw values collection
|
||||
// We could optimize by not copying, but accessing logically is safer
|
||||
float signal[INFERENCE_WINDOW_SIZE];
|
||||
|
||||
int idx = buffer_head; // Oldest sample
|
||||
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||
signal[i] = (float)window_buffer[idx][ch];
|
||||
sum += signal[i];
|
||||
idx = (idx + 1) % INFERENCE_WINDOW_SIZE;
|
||||
}
|
||||
|
||||
float mean = sum / INFERENCE_WINDOW_SIZE;
|
||||
|
||||
// Pass 2: Centering and Features
|
||||
float wl = 0;
|
||||
int zc = 0;
|
||||
int ssc = 0;
|
||||
|
||||
// Center the signal
|
||||
for (int i = 0; i < INFERENCE_WINDOW_SIZE; i++) {
|
||||
signal[i] -= mean;
|
||||
sq_sum += signal[i] * signal[i];
|
||||
}
|
||||
|
||||
float rms = sqrtf(sq_sum / INFERENCE_WINDOW_SIZE);
|
||||
|
||||
// Thresholds
|
||||
float zc_thresh = FEAT_ZC_THRESH * rms;
|
||||
float ssc_thresh = (FEAT_SSC_THRESH * rms) *
|
||||
(FEAT_SSC_THRESH * rms); // threshold is on diff product
|
||||
|
||||
for (int i = 0; i < INFERENCE_WINDOW_SIZE - 1; i++) {
|
||||
// WL
|
||||
wl += fabsf(signal[i + 1] - signal[i]);
|
||||
|
||||
// ZC
|
||||
if ((signal[i] > 0 && signal[i + 1] < 0) ||
|
||||
(signal[i] < 0 && signal[i + 1] > 0)) {
|
||||
if (fabsf(signal[i] - signal[i + 1]) > zc_thresh) {
|
||||
zc++;
|
||||
}
|
||||
}
|
||||
|
||||
// SSC (needs 3 points, so loop to N-2)
|
||||
if (i < INFERENCE_WINDOW_SIZE - 2) {
|
||||
float diff1 = signal[i + 1] - signal[i];
|
||||
float diff2 = signal[i + 1] - signal[i + 2];
|
||||
if ((diff1 * diff2) > ssc_thresh) {
|
||||
ssc++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store features: [RMS, WL, ZC, SSC] per channel
|
||||
int base = ch * 4;
|
||||
features_out[base + 0] = rms;
|
||||
features_out[base + 1] = wl;
|
||||
features_out[base + 2] = (float)zc;
|
||||
features_out[base + 3] = (float)ssc;
|
||||
}
|
||||
}
|
||||
|
||||
// --- Prediction ---
|
||||
|
||||
int inference_predict(float *confidence) {
|
||||
if (samples_collected < INFERENCE_WINDOW_SIZE) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// 1. Extract Features
|
||||
float features[MODEL_NUM_FEATURES];
|
||||
compute_features(features);
|
||||
|
||||
// 2. LDA Inference (Linear Score)
|
||||
float raw_scores[MODEL_NUM_CLASSES];
|
||||
float max_score = -1e9;
|
||||
int max_idx = 0;
|
||||
|
||||
// Calculate raw discriminative scores
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
float score = LDA_INTERCEPTS[c];
|
||||
for (int f = 0; f < MODEL_NUM_FEATURES; f++) {
|
||||
score += features[f] * LDA_WEIGHTS[c][f];
|
||||
}
|
||||
raw_scores[c] = score;
|
||||
}
|
||||
|
||||
// Convert scores to probabilities (Softmax)
|
||||
// LDA scores are log-likelihoods + const. Softmax is appropriate.
|
||||
float sum_exp = 0;
|
||||
float probas[MODEL_NUM_CLASSES];
|
||||
|
||||
// Numerical stability: subtract max
|
||||
// Create temp copy for max finding
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
if (raw_scores[c] > max_score)
|
||||
max_score = raw_scores[c];
|
||||
}
|
||||
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
probas[c] = expf(raw_scores[c] - max_score);
|
||||
sum_exp += probas[c];
|
||||
}
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
probas[c] /= sum_exp;
|
||||
}
|
||||
|
||||
// 3. Smoothing
|
||||
// 3a. Probability EMA
|
||||
float max_smoothed_prob = 0;
|
||||
int smoothed_winner = 0;
|
||||
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
smoothed_probs[c] = (SMOOTHING_FACTOR * smoothed_probs[c]) +
|
||||
((1.0f - SMOOTHING_FACTOR) * probas[c]);
|
||||
|
||||
if (smoothed_probs[c] > max_smoothed_prob) {
|
||||
max_smoothed_prob = smoothed_probs[c];
|
||||
smoothed_winner = c;
|
||||
}
|
||||
}
|
||||
|
||||
// 3b. Majority Vote
|
||||
vote_history[vote_head] = smoothed_winner;
|
||||
vote_head = (vote_head + 1) % VOTE_WINDOW;
|
||||
|
||||
int counts[MODEL_NUM_CLASSES];
|
||||
memset(counts, 0, sizeof(counts));
|
||||
|
||||
for (int i = 0; i < VOTE_WINDOW; i++) {
|
||||
if (vote_history[i] != -1) {
|
||||
counts[vote_history[i]]++;
|
||||
}
|
||||
}
|
||||
|
||||
int majority_winner = 0;
|
||||
int majority_count = 0;
|
||||
for (int c = 0; c < MODEL_NUM_CLASSES; c++) {
|
||||
if (counts[c] > majority_count) {
|
||||
majority_count = counts[c];
|
||||
majority_winner = c;
|
||||
}
|
||||
}
|
||||
|
||||
// 3c. Debounce
|
||||
int final_result = current_output;
|
||||
|
||||
if (current_output == -1) {
|
||||
current_output = majority_winner;
|
||||
pending_output = majority_winner;
|
||||
pending_count = 1;
|
||||
final_result = majority_winner;
|
||||
} else if (majority_winner == current_output) {
|
||||
pending_output = majority_winner;
|
||||
pending_count = 1;
|
||||
} else if (majority_winner == pending_output) {
|
||||
pending_count++;
|
||||
if (pending_count >= DEBOUNCE_COUNT) {
|
||||
current_output = majority_winner;
|
||||
final_result = majority_winner;
|
||||
}
|
||||
} else {
|
||||
pending_output = majority_winner;
|
||||
pending_count = 1;
|
||||
}
|
||||
|
||||
// Use smoothed probability of the final winner as confidence
|
||||
// Or simpler: use fraction of votes
|
||||
*confidence = (float)majority_count / VOTE_WINDOW;
|
||||
|
||||
return final_result;
|
||||
}
|
||||
|
||||
const char *inference_get_class_name(int class_idx) {
|
||||
if (class_idx >= 0 && class_idx < MODEL_NUM_CLASSES) {
|
||||
return MODEL_CLASS_NAMES[class_idx];
|
||||
}
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
int inference_get_gesture_enum(int class_idx) {
|
||||
const char *name = inference_get_class_name(class_idx);
|
||||
|
||||
// Map string name to gesture_t enum
|
||||
// Strings must match those in Python list: ["fist", "hook_em", "open",
|
||||
// "rest", "thumbs_up"] Note: Python strings are lowercase, config.h enums
|
||||
// are: GESTURE_NONE=0, REST=1, FIST=2, OPEN=3, HOOK_EM=4, THUMBS_UP=5
|
||||
|
||||
// Case-insensitive check would be safer, but let's assume Python output is
|
||||
// lowercase as seen in scripts or uppercase if specified. In
|
||||
// learning_data_collection.py, they seem to be "rest", "open", "fist", etc.
|
||||
|
||||
// Simple string matching
|
||||
if (strcmp(name, "rest") == 0 || strcmp(name, "REST") == 0)
|
||||
return GESTURE_REST;
|
||||
if (strcmp(name, "fist") == 0 || strcmp(name, "FIST") == 0)
|
||||
return GESTURE_FIST;
|
||||
if (strcmp(name, "open") == 0 || strcmp(name, "OPEN") == 0)
|
||||
return GESTURE_OPEN;
|
||||
if (strcmp(name, "hook_em") == 0 || strcmp(name, "HOOK_EM") == 0)
|
||||
return GESTURE_HOOK_EM;
|
||||
if (strcmp(name, "thumbs_up") == 0 || strcmp(name, "THUMBS_UP") == 0)
|
||||
return GESTURE_THUMBS_UP;
|
||||
|
||||
return GESTURE_NONE;
|
||||
}
|
||||
47
EMG_Arm/src/core/inference.h
Normal file
47
EMG_Arm/src/core/inference.h
Normal file
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* @file inference.h
|
||||
* @brief On-device inference engine for EMG gesture recognition.
|
||||
*/
|
||||
|
||||
#ifndef INFERENCE_H
|
||||
#define INFERENCE_H
|
||||
|
||||
#include <stdbool.h>
|
||||
#include <stdint.h>
|
||||
|
||||
// --- Configuration ---
|
||||
#define INFERENCE_WINDOW_SIZE 150 // Window size in samples (must match Python)
|
||||
#define NUM_CHANNELS 4 // Number of EMG channels
|
||||
|
||||
/**
|
||||
* @brief Initialize the inference engine.
|
||||
*/
|
||||
void inference_init(void);
|
||||
|
||||
/**
|
||||
* @brief Add a sample to the inference buffer.
|
||||
*
|
||||
* @param channels Array of 4 channel values (raw ADC)
|
||||
* @return true if a full window is ready for processing
|
||||
*/
|
||||
bool inference_add_sample(uint16_t *channels);
|
||||
|
||||
/**
|
||||
* @brief Run inference on the current window.
|
||||
*
|
||||
* @param confidence Output pointer for confidence score (0.0 - 1.0)
|
||||
* @return Detected class index (-1 if error)
|
||||
*/
|
||||
int inference_predict(float *confidence);
|
||||
|
||||
/**
|
||||
* @brief Get the name of a class index.
|
||||
*/
|
||||
const char *inference_get_class_name(int class_idx);
|
||||
|
||||
/**
|
||||
* @brief Map class index to gesture_t enum.
|
||||
*/
|
||||
int inference_get_gesture_enum(int class_idx);
|
||||
|
||||
#endif /* INFERENCE_H */
|
||||
44
EMG_Arm/src/core/model_weights.h
Normal file
44
EMG_Arm/src/core/model_weights.h
Normal file
@@ -0,0 +1,44 @@
|
||||
/**
|
||||
* @file model_weights.h
|
||||
* @brief Placeholder for trained model weights.
|
||||
*
|
||||
* This file should be generated by the Python script (Training Page -> Export for ESP32).
|
||||
*/
|
||||
|
||||
#ifndef MODEL_WEIGHTS_H
|
||||
#define MODEL_WEIGHTS_H
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
/* Metadata */
|
||||
#define MODEL_NUM_CLASSES 5
|
||||
#define MODEL_NUM_FEATURES 16
|
||||
|
||||
/* Class Names */
|
||||
static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {
|
||||
"REST",
|
||||
"OPEN",
|
||||
"FIST",
|
||||
"HOOK_EM",
|
||||
"THUMBS_UP",
|
||||
};
|
||||
|
||||
/* Feature Extractor Parameters */
|
||||
#define FEAT_ZC_THRESH 0.1f
|
||||
#define FEAT_SSC_THRESH 0.1f
|
||||
|
||||
/* LDA Intercepts/Biases */
|
||||
static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {
|
||||
0.0f, 0.0f, 0.0f, 0.0f, 0.0f
|
||||
};
|
||||
|
||||
/* LDA Coefficients (Weights) */
|
||||
static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {
|
||||
{0.0f},
|
||||
{0.0f},
|
||||
{0.0f},
|
||||
{0.0f},
|
||||
{0.0f}
|
||||
};
|
||||
|
||||
#endif /* MODEL_WEIGHTS_H */
|
||||
255
emg_gui.py
255
emg_gui.py
@@ -1322,6 +1322,18 @@ class TrainingPage(BasePage):
|
||||
)
|
||||
self.train_button.pack(pady=20)
|
||||
|
||||
# Export button
|
||||
self.export_button = ctk.CTkButton(
|
||||
self.content,
|
||||
text="Export for ESP32",
|
||||
font=ctk.CTkFont(size=14),
|
||||
height=40,
|
||||
fg_color="green",
|
||||
state="disabled",
|
||||
command=self.export_model
|
||||
)
|
||||
self.export_button.pack(pady=5)
|
||||
|
||||
# Progress
|
||||
self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
|
||||
self.progress_bar.pack(pady=10)
|
||||
@@ -1434,6 +1446,32 @@ class TrainingPage(BasePage):
|
||||
|
||||
finally:
|
||||
self.after(0, lambda: self.train_button.configure(state="normal"))
|
||||
self.after(0, lambda: self.export_button.configure(state="normal"))
|
||||
|
||||
def export_model(self):
|
||||
"""Export trained model to C header."""
|
||||
if not self.classifier or not self.classifier.is_trained:
|
||||
messagebox.showerror("Error", "No trained model to export!")
|
||||
return
|
||||
|
||||
# Default path in ESP32 project
|
||||
default_path = Path("EMG_Arm/src/core/model_weights.h").absolute()
|
||||
|
||||
# Ask user for location, defaulting to the ESP32 project source
|
||||
filename = tk.filedialog.asksaveasfilename(
|
||||
title="Export Model Header",
|
||||
initialdir=default_path.parent,
|
||||
initialfile=default_path.name,
|
||||
filetypes=[("C Header", "*.h")]
|
||||
)
|
||||
|
||||
if filename:
|
||||
try:
|
||||
path = self.classifier.export_to_header(filename)
|
||||
self._log(f"\nExported model to: {path}")
|
||||
messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.")
|
||||
except Exception as e:
|
||||
messagebox.showerror("Export Error", f"Failed to export: {e}")
|
||||
|
||||
def _log(self, text: str):
|
||||
"""Add text to results."""
|
||||
@@ -1861,116 +1899,159 @@ class PredictionPage(BasePage):
|
||||
self._update_connection_status("gray", "Disconnected")
|
||||
self.connect_button.configure(text="Connect")
|
||||
|
||||
def _prediction_thread(self):
|
||||
"""Background prediction thread."""
|
||||
# For simulated mode, create new stream
|
||||
if not self.using_real_hardware:
|
||||
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
||||
def toggle_prediction(self):
|
||||
"""Start or stop prediction."""
|
||||
if self.is_predicting:
|
||||
self.stop_prediction()
|
||||
else:
|
||||
self.start_prediction()
|
||||
|
||||
def start_prediction(self):
|
||||
"""Start live prediction."""
|
||||
# Determine mode
|
||||
self.using_real_hardware = (self.source_var.get() == "real")
|
||||
|
||||
if self.using_real_hardware:
|
||||
if not self.is_connected or not self.stream:
|
||||
messagebox.showerror("Not Connected", "Please connect to ESP32 first.")
|
||||
return
|
||||
|
||||
print("[DEBUG] Starting Edge Prediction (On-Device)...")
|
||||
try:
|
||||
# Send "start_predict" command to ESP32
|
||||
if hasattr(self.stream, 'ser'):
|
||||
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||
self.stream.running = True
|
||||
else:
|
||||
# Fallback
|
||||
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||
self.stream.running = True
|
||||
|
||||
except Exception as e:
|
||||
messagebox.showerror("Start Error", f"Failed to start: {e}")
|
||||
return
|
||||
|
||||
else:
|
||||
# Simulated - use PC-side inference
|
||||
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
||||
self.stream.start()
|
||||
|
||||
# Load model for PC-side (Simulated) OR for display (optional)
|
||||
# Even for Edge, we might want the label list.
|
||||
if not self.using_real_hardware:
|
||||
if not self.classifier:
|
||||
model_path = EMGClassifier.get_default_model_path()
|
||||
if model_path.exists():
|
||||
self.classifier = EMGClassifier.load(model_path)
|
||||
self.model_label.configure(text="Model: Loaded", text_color="green")
|
||||
else:
|
||||
self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange")
|
||||
|
||||
# Reset smoother
|
||||
self.smoother = PredictionSmoother(
|
||||
label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"],
|
||||
probability_smoothing=0.7,
|
||||
majority_vote_window=5,
|
||||
debounce_count=3
|
||||
)
|
||||
|
||||
self.is_predicting = True
|
||||
self.start_button.configure(text="Stop Prediction", fg_color="red")
|
||||
|
||||
# Start display loop
|
||||
self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True)
|
||||
self.prediction_thread.start()
|
||||
|
||||
self.update_prediction_ui()
|
||||
|
||||
def stop_prediction(self):
|
||||
"""Stop prediction."""
|
||||
self.is_predicting = False
|
||||
if self.stream:
|
||||
self.stream.stop() # Sends "stop" usually
|
||||
if not self.using_real_hardware:
|
||||
self.stream = None
|
||||
|
||||
self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
|
||||
self.prediction_label.configure(text="---", text_color="gray")
|
||||
self.confidence_label.configure(text="Confidence: ---%")
|
||||
self.confidence_bar.set(0)
|
||||
|
||||
def prediction_loop(self):
|
||||
"""Loop for reading data and (optionally) running inference."""
|
||||
import json
|
||||
|
||||
# Stream is already started (either via handshake for real HW or will be started for simulated)
|
||||
parser = EMGParser(num_channels=NUM_CHANNELS)
|
||||
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
|
||||
|
||||
# Simulated gesture cycling (only for simulated mode)
|
||||
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
|
||||
gesture_idx = 0
|
||||
gesture_duration = 2.5
|
||||
gesture_start = time.perf_counter()
|
||||
current_gesture = gesture_cycle[0]
|
||||
|
||||
# Start simulated stream if needed
|
||||
if not self.using_real_hardware:
|
||||
try:
|
||||
if hasattr(self.stream, 'set_gesture'):
|
||||
self.stream.set_gesture(current_gesture)
|
||||
self.stream.start()
|
||||
except Exception as e:
|
||||
self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
|
||||
return
|
||||
else:
|
||||
# Real hardware is already streaming
|
||||
self.data_queue.put(('connection_status', ('green', 'Streaming')))
|
||||
|
||||
while self.is_predicting:
|
||||
# Change simulated gesture periodically (only for simulated mode)
|
||||
if hasattr(self.stream, 'set_gesture'):
|
||||
elapsed = time.perf_counter() - gesture_start
|
||||
if elapsed > gesture_duration:
|
||||
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
|
||||
gesture_start = time.perf_counter()
|
||||
current_gesture = gesture_cycle[gesture_idx]
|
||||
self.stream.set_gesture(current_gesture)
|
||||
self.data_queue.put(('sim_gesture', current_gesture))
|
||||
|
||||
# Read and process
|
||||
try:
|
||||
line = self.stream.readline()
|
||||
except Exception as e:
|
||||
# Only report error if we didn't intentionally stop
|
||||
if self.is_predicting:
|
||||
self.data_queue.put(('error', f"Serial read error: {e}"))
|
||||
break
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line:
|
||||
if self.using_real_hardware:
|
||||
# Edge Inference Mode: Expect JSON
|
||||
try:
|
||||
line = line.strip()
|
||||
if line.startswith('{'):
|
||||
data = json.loads(line)
|
||||
|
||||
if "gesture" in data:
|
||||
# Update UI with Edge Prediction
|
||||
gesture = data["gesture"]
|
||||
conf = float(data.get("conf", 0.0))
|
||||
|
||||
self.data_queue.put(('prediction', (gesture, conf)))
|
||||
|
||||
elif "status" in data:
|
||||
print(f"[ESP32] {data}")
|
||||
else:
|
||||
pass
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
else:
|
||||
# PC Side Inference (Simulated)
|
||||
sample = parser.parse_line(line)
|
||||
if sample:
|
||||
window = windower.add_sample(sample)
|
||||
if window:
|
||||
# Get raw prediction
|
||||
window_data = window.to_numpy()
|
||||
raw_label, proba = self.classifier.predict(window_data)
|
||||
raw_confidence = max(proba) * 100
|
||||
if window and self.classifier:
|
||||
# Run Inference Local
|
||||
raw_label, proba = self.classifier.predict(window.to_numpy())
|
||||
label, conf, _ = self.smoother.update(raw_label, proba)
|
||||
|
||||
# Apply smoothing
|
||||
smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba)
|
||||
smoothed_confidence = smoothed_conf * 100
|
||||
self.data_queue.put(('prediction', (label, conf)))
|
||||
|
||||
# Send both raw and smoothed to UI
|
||||
self.data_queue.put(('prediction', (
|
||||
smoothed_label, # The stable output
|
||||
smoothed_confidence,
|
||||
raw_label, # The raw (possibly twitchy) output
|
||||
raw_confidence,
|
||||
)))
|
||||
|
||||
# Safe cleanup - stream might already be stopped
|
||||
try:
|
||||
if self.stream:
|
||||
self.stream.stop()
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
except Exception as e:
|
||||
if self.is_predicting:
|
||||
print(f"Prediction loop error: {e}")
|
||||
break
|
||||
|
||||
def update_prediction_ui(self):
|
||||
"""Update UI from prediction thread."""
|
||||
"""Update UI from queue."""
|
||||
try:
|
||||
while True:
|
||||
msg_type, data = self.data_queue.get_nowait()
|
||||
|
||||
if msg_type == 'prediction':
|
||||
smoothed_label, smoothed_conf, raw_label, raw_conf = data
|
||||
label, conf = data
|
||||
|
||||
# Display smoothed (stable) prediction
|
||||
display_label = smoothed_label.upper().replace("_", " ")
|
||||
color = get_gesture_color(smoothed_label)
|
||||
|
||||
self.prediction_label.configure(text=display_label, text_color=color)
|
||||
self.confidence_bar.set(smoothed_conf / 100)
|
||||
self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%")
|
||||
|
||||
# Show raw prediction for comparison (grayed out)
|
||||
raw_display = raw_label.upper().replace("_", " ")
|
||||
if raw_label != smoothed_label:
|
||||
# Raw differs from smoothed - show it was filtered
|
||||
self.raw_label.configure(
|
||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
|
||||
text_color="orange"
|
||||
)
|
||||
else:
|
||||
self.raw_label.configure(
|
||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
|
||||
text_color="gray"
|
||||
# Update label
|
||||
self.prediction_label.configure(
|
||||
text=label.upper(),
|
||||
text_color=get_gesture_color(label)
|
||||
)
|
||||
|
||||
# Update confidence
|
||||
self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%")
|
||||
self.confidence_bar.set(conf)
|
||||
|
||||
# Clear raw label since we don't have raw vs smooth distinction in edge mode
|
||||
# (or we could expose it if we updated the C struct, but for now keep it simple)
|
||||
self.raw_label.configure(text="", text_color="gray")
|
||||
|
||||
elif msg_type == 'sim_gesture':
|
||||
self.sim_label.configure(text=f"[Simulating: {data}]")
|
||||
|
||||
|
||||
@@ -1757,6 +1757,132 @@ class EMGClassifier:
|
||||
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
|
||||
return filepath
|
||||
|
||||
def export_to_header(self, filepath: Path) -> Path:
|
||||
"""
|
||||
Export trained model to a C header file for ESP32 inference.
|
||||
|
||||
Args:
|
||||
filepath: Output .h file path
|
||||
|
||||
Returns:
|
||||
Path to the saved header file
|
||||
"""
|
||||
if not self.is_trained:
|
||||
raise ValueError("Cannot export untrained classifier!")
|
||||
|
||||
filepath = Path(filepath)
|
||||
filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
n_classes = len(self.label_names)
|
||||
n_features = len(self.feature_names)
|
||||
|
||||
# Get LDA parameters
|
||||
# coef_: (n_classes, n_features) - access as [class][feature]
|
||||
# intercept_: (n_classes,)
|
||||
coefs = self.lda.coef_
|
||||
intercepts = self.lda.intercept_
|
||||
|
||||
# Add logic for binary classification (sklearn stores only 1 set of coefs)
|
||||
# For >2 classes, it stores n_classes sets.
|
||||
if n_classes == 2:
|
||||
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
|
||||
# We need to expand this to 2 classes for the C inference engine to be generic.
|
||||
# Class 1 decision = dot(w, x) + b
|
||||
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
|
||||
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
|
||||
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
|
||||
# Let's check sklearn docs or behavior.
|
||||
# Actually, LDA in sklearn for binary case is special.
|
||||
# To make C code simple (always argmax), let's explicitly store 2 rows.
|
||||
# Row 1 (Index 1 in sklearn): coef, intercept
|
||||
# Row 0 (Index 0): -coef, -intercept ?
|
||||
# Wait, LDA is generative. The decision boundary is linear.
|
||||
# Let's assume Multiclass for now or handle binary specifically.
|
||||
# For simplicity in C, we prefer (n_classes, n_features).
|
||||
# If coefs.shape[0] != n_classes, we need to handle it.
|
||||
if coefs.shape[0] == 1:
|
||||
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
|
||||
# Class 1 (positive)
|
||||
c1_coef = coefs[0]
|
||||
c1_int = intercepts[0]
|
||||
# Class 0 (negative) - Effectively -score for decision boundary at 0
|
||||
# But strictly speaking LDA is comparison of log-posteriors.
|
||||
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
|
||||
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
|
||||
# To map this to ArgMax(Score0, Score1):
|
||||
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
|
||||
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
|
||||
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
|
||||
pass
|
||||
|
||||
# Generate C content
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
c_content = [
|
||||
"/**",
|
||||
f" * @file {filepath.name}",
|
||||
" * @brief Trained LDA model weights exported from Python.",
|
||||
f" * @date {timestamp}",
|
||||
" */",
|
||||
"",
|
||||
"#ifndef MODEL_WEIGHTS_H",
|
||||
"#define MODEL_WEIGHTS_H",
|
||||
"",
|
||||
"#include <stdint.h>",
|
||||
"",
|
||||
"/* Metadata */",
|
||||
f"#define MODEL_NUM_CLASSES {n_classes}",
|
||||
f"#define MODEL_NUM_FEATURES {n_features}",
|
||||
"",
|
||||
"/* Class Names */",
|
||||
"static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {",
|
||||
]
|
||||
|
||||
for name in self.label_names:
|
||||
c_content.append(f' "{name}",')
|
||||
c_content.append("};")
|
||||
c_content.append("")
|
||||
|
||||
c_content.append("/* Feature Extractor Parameters */")
|
||||
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
|
||||
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
|
||||
c_content.append("")
|
||||
|
||||
c_content.append("/* LDA Intercepts/Biases */")
|
||||
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
|
||||
line = " "
|
||||
for val in intercepts:
|
||||
line += f"{val:.6f}f, "
|
||||
c_content.append(line.rstrip(", "))
|
||||
c_content.append("};")
|
||||
c_content.append("")
|
||||
|
||||
c_content.append("/* LDA Coefficients (Weights) */")
|
||||
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
|
||||
|
||||
for i, row in enumerate(coefs):
|
||||
c_content.append(f" /* {self.label_names[i]} */")
|
||||
c_content.append(" {")
|
||||
line = " "
|
||||
for j, val in enumerate(row):
|
||||
line += f"{val:.6f}f, "
|
||||
if (j + 1) % 8 == 0:
|
||||
c_content.append(line)
|
||||
line = " "
|
||||
if line.strip():
|
||||
c_content.append(line.rstrip(", "))
|
||||
c_content.append(" },")
|
||||
|
||||
c_content.append("};")
|
||||
c_content.append("")
|
||||
c_content.append("#endif /* MODEL_WEIGHTS_H */")
|
||||
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('\n'.join(c_content))
|
||||
|
||||
print(f"[Classifier] Model weights exported to: {filepath}")
|
||||
return filepath
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath: Path) -> 'EMGClassifier':
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user