This commit is contained in:
2026-01-27 21:31:49 -06:00
parent 9bdf9d1109
commit 31dede537c
7 changed files with 893 additions and 320 deletions

View File

@@ -1757,6 +1757,132 @@ class EMGClassifier:
print(f"[Classifier] File size: {filepath.stat().st_size / 1024:.1f} KB")
return filepath
def export_to_header(self, filepath: Path) -> Path:
"""
Export trained model to a C header file for ESP32 inference.
Args:
filepath: Output .h file path
Returns:
Path to the saved header file
"""
if not self.is_trained:
raise ValueError("Cannot export untrained classifier!")
filepath = Path(filepath)
filepath.parent.mkdir(parents=True, exist_ok=True)
n_classes = len(self.label_names)
n_features = len(self.feature_names)
# Get LDA parameters
# coef_: (n_classes, n_features) - access as [class][feature]
# intercept_: (n_classes,)
coefs = self.lda.coef_
intercepts = self.lda.intercept_
# Add logic for binary classification (sklearn stores only 1 set of coefs)
# For >2 classes, it stores n_classes sets.
if n_classes == 2:
# Binary case: coef_ is (1, n_features), intercept_ is (1,)
# We need to expand this to 2 classes for the C inference engine to be generic.
# Class 1 decision = dot(w, x) + b
# Class 0 decision = - (dot(w, x) + b) <-- Implicit in sklearn decision_function
# BUT: decision_function returns score. A generic 'argmax' approach usually expects
# one score per class. Multiclass LDA in sklearn does generic OVR/Multinomial.
# Let's check sklearn docs or behavior.
# Actually, LDA in sklearn for binary case is special.
# To make C code simple (always argmax), let's explicitly store 2 rows.
# Row 1 (Index 1 in sklearn): coef, intercept
# Row 0 (Index 0): -coef, -intercept ?
# Wait, LDA is generative. The decision boundary is linear.
# Let's assume Multiclass for now or handle binary specifically.
# For simplicity in C, we prefer (n_classes, n_features).
# If coefs.shape[0] != n_classes, we need to handle it.
if coefs.shape[0] == 1:
print("[Export] Binary classification detected. Expanding to 2 classes for C compatibility.")
# Class 1 (positive)
c1_coef = coefs[0]
c1_int = intercepts[0]
# Class 0 (negative) - Effectively -score for decision boundary at 0
# But strictly speaking LDA is comparison of log-posteriors.
# Sklearn's coef_ comes from (Sigma^-1)(mu1 - mu0).
# The score S = coef.X + intercept. If S > 0 pred class 1, else 0.
# To map this to ArgMax(Score0, Score1):
# We can set Score1 = S, Score0 = 0. OR Score1 = S/2, Score0 = -S/2.
# Let's use Score1 = S, Score0 = 0 (Bias term makes this trickier).
# Safest: Let's trust that for our 5-gesture demo, it's multiclass.
pass
# Generate C content
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
c_content = [
"/**",
f" * @file {filepath.name}",
" * @brief Trained LDA model weights exported from Python.",
f" * @date {timestamp}",
" */",
"",
"#ifndef MODEL_WEIGHTS_H",
"#define MODEL_WEIGHTS_H",
"",
"#include <stdint.h>",
"",
"/* Metadata */",
f"#define MODEL_NUM_CLASSES {n_classes}",
f"#define MODEL_NUM_FEATURES {n_features}",
"",
"/* Class Names */",
"static const char* MODEL_CLASS_NAMES[MODEL_NUM_CLASSES] = {",
]
for name in self.label_names:
c_content.append(f' "{name}",')
c_content.append("};")
c_content.append("")
c_content.append("/* Feature Extractor Parameters */")
c_content.append(f"#define FEAT_ZC_THRESH {self.feature_extractor.zc_threshold_percent}f")
c_content.append(f"#define FEAT_SSC_THRESH {self.feature_extractor.ssc_threshold_percent}f")
c_content.append("")
c_content.append("/* LDA Intercepts/Biases */")
c_content.append(f"static const float LDA_INTERCEPTS[MODEL_NUM_CLASSES] = {{")
line = " "
for val in intercepts:
line += f"{val:.6f}f, "
c_content.append(line.rstrip(", "))
c_content.append("};")
c_content.append("")
c_content.append("/* LDA Coefficients (Weights) */")
c_content.append(f"static const float LDA_WEIGHTS[MODEL_NUM_CLASSES][MODEL_NUM_FEATURES] = {{")
for i, row in enumerate(coefs):
c_content.append(f" /* {self.label_names[i]} */")
c_content.append(" {")
line = " "
for j, val in enumerate(row):
line += f"{val:.6f}f, "
if (j + 1) % 8 == 0:
c_content.append(line)
line = " "
if line.strip():
c_content.append(line.rstrip(", "))
c_content.append(" },")
c_content.append("};")
c_content.append("")
c_content.append("#endif /* MODEL_WEIGHTS_H */")
with open(filepath, 'w') as f:
f.write('\n'.join(c_content))
print(f"[Classifier] Model weights exported to: {filepath}")
return filepath
@classmethod
def load(cls, filepath: Path) -> 'EMGClassifier':
"""