Files
EMG_Arm/train_mlp_tflite.py

107 lines
4.3 KiB
Python

"""
Train int8 MLP for ESP32-S3 deployment via TFLite Micro.
Run AFTER Change 0 (label shift) + Change 1 (expanded features).
Change E — priority Tier 3.
Outputs: EMG_Arm/src/core/emg_model_data.cc
"""
import numpy as np
from pathlib import Path
import sys
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
try:
    import tensorflow as tf
except ImportError:
    # sys.exit(<str>) prints the message to stderr and exits with status 1, so the
    # diagnostic is not lost when stdout is redirected to a training log.
    sys.exit("ERROR: TensorFlow not installed. Run: pip install tensorflow")
# --- Load and extract features -----------------------------------------------
# Pull every recorded session; the sixth item returned is not needed here.
storage = SessionStorage()
(X_raw, y, trial_ids, session_indices,
 label_names, _) = storage.load_all_for_training()

# Feature configuration for the hand channels; the resulting matrix is cast to
# float32 because that is the dtype fed to Keras and to the TFLite calibration.
extractor = EMGFeatureExtractor(
    channels=HAND_CHANNELS,
    cross_channel=True,
    expanded=True,
    reinhard=True,
)
X = extractor.extract_features_batch(X_raw)
X = X.astype(np.float32)
# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py)
def _balanced_zscore(feats, labels):
    """Return ``feats`` standardised with class-balanced centring.

    The centre is the unweighted mean of the per-class mean vectors, so classes
    with many samples do not pull it toward themselves. The scale is the plain
    per-feature standard deviation over all rows. Features whose std is below
    1e-12 are left unscaled (divisor 1.0) to avoid dividing by ~0.
    """
    per_class_centres = [
        feats[labels == label].mean(axis=0)
        for label in np.unique(labels)
    ]
    centre = np.mean(per_class_centres, axis=0)
    scale = feats.std(axis=0)
    scale[scale < 1e-12] = 1.0
    return (feats - centre) / scale


# Each recording session is normalised independently, in place in X.
for sess_id in np.unique(session_indices):
    rows = session_indices == sess_id
    X[rows] = _balanced_zscore(X[rows], y[rows])
if len(y) == 0:
    sys.exit("ERROR: no training samples found - record sessions before training.")
n_feat = X.shape[1]
# Size the output layer as (largest label id + 1), not as the number of distinct
# labels present. sparse_categorical_crossentropy requires every label to be
# < n_cls, and output unit i must correspond to label id i (label_names is
# saved alongside the weights for lookup). The two formulas agree whenever the
# label ids are contiguous from 0. A distinct-label count would reject gapped
# ids at fit time.
n_cls = int(np.max(y)) + 1
if n_cls != len(label_names):
    print(f"WARNING: {n_cls} output classes but {len(label_names)} label names - "
          "check the label encoding")
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
print(f"Classes: {label_names}")
# --- Build and train MLP -----------------------------------------------------
# Two ReLU hidden layers (32 then 16 units) with dropout after the first, and a
# softmax head with one unit per class.
mlp_layers = [
    tf.keras.layers.Input(shape=(n_feat,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(n_cls, activation='softmax'),
]
model = tf.keras.Sequential(mlp_layers)

# Integer class labels -> sparse categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.summary()

model.fit(
    X,
    y,
    epochs=150,
    batch_size=64,
    # Keras holds out the *last* 10% of rows (taken before shuffling).
    validation_split=0.1,
    verbose=1,
)
# --- Convert to int8 TFLite --------------------------------------------------
def representative_dataset(data=None, step=10):
    """Yield calibration samples for full-integer post-training quantization.

    The TFLite converter calls this with no arguments. Each yielded item is a
    one-element list holding a single ``(1, n_features)`` float32 row, matching
    the model's single input.

    Args:
        data: 2-D feature matrix to sample from. Defaults to the module-level,
            already-normalised training matrix ``X``.
        step: Use every ``step``-th row (default 10, i.e. roughly 10% of rows).

    Raises:
        ValueError: if ``step`` is smaller than 1 (raised on first iteration).
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")
    samples = X if data is None else data
    # The Keras model's input is float32; coerce so calibration sees that dtype.
    samples = np.asarray(samples, dtype=np.float32)
    for start in range(0, len(samples), step):
        # Slice rather than index so the leading batch dimension (size 1) is kept.
        yield [samples[start:start + 1]]
# Full-integer post-training quantization:
# - calibrate activation ranges with representative_dataset;
# - allow only int8 builtin kernels;
# - make the model's input and output tensors int8 as well.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = converter.inference_output_type = tf.int8
tflite_model = converter.convert()
# --- Write emg_model_data.cc -------------------------------------------------
# NOTE(review): both paths are resolved relative to this script's directory;
# confirm they land in the PlatformIO project's src/core (the module docstring
# names EMG_Arm/src/core/emg_model_data.cc).
out_cc = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.cc'
out_h = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.h'

# Wrap the byte array so the generated source stays readable and diff-friendly.
HEX_BYTES_PER_LINE = 12
hex_bytes = [f'0x{b:02x}' for b in tflite_model]
hex_rows = [
    ', '.join(hex_bytes[i:i + HEX_BYTES_PER_LINE])
    for i in range(0, len(hex_bytes), HEX_BYTES_PER_LINE)
]

# Explicit UTF-8 and LF line endings: the banner contains a non-ASCII em dash,
# and the platform default encoding may not be UTF-8 (or may not be able to
# encode it at all).
with open(out_cc, 'w', encoding='utf-8', newline='\n') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#include "emg_model_data.h"\n')
    # The header (included above) declares these extern "C", so the const
    # definitions below keep external C linkage.
    f.write(f'const int g_model_len = {len(tflite_model)};\n')
    f.write('alignas(8) const unsigned char g_model[] = {\n  ')
    f.write(',\n  '.join(hex_rows))
    f.write('\n};\n')

with open(out_h, 'w', encoding='utf-8', newline='\n') as f:
    f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
    f.write('#pragma once\n\n')
    f.write('#ifdef __cplusplus\nextern "C" {\n#endif\n\n')
    f.write('extern const unsigned char g_model[];\n')
    f.write('extern const int g_model_len;\n\n')
    f.write('#ifdef __cplusplus\n}\n#endif\n')

print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
print(f"Wrote {out_h}")
# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
# Only Dense layers carry weights (Dropout has none), so this is one
# [kernel, bias] pair per Dense layer, in model order.
layer_weights = [
    weights
    for weights in (layer.get_weights() for layer in model.layers)
    if weights
]

mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
mlp_path.parent.mkdir(parents=True, exist_ok=True)

# Keys are w{i}/b{i} for the i-th Dense layer. For the model built above:
# 0 = Dense(32, relu), 1 = Dense(16, relu), 2 = Dense(n_cls, softmax).
dense_params = {}
for idx, (kernel, bias) in enumerate(layer_weights):
    dense_params[f'w{idx}'] = kernel
    dense_params[f'b{idx}'] = bias
np.savez(mlp_path, **dense_params, label_names=np.array(label_names))
print(f"Saved laptop MLP weights to {mlp_path}")
# Manual follow-up steps to deploy the newly generated model to the firmware.
print("\nNext steps:")
print(" 1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
print(" 2. Run: pio run -t upload")