Multi-model prediction on laptop | need to move to on-board prediction
This commit is contained in:
106
train_mlp_tflite.py
Normal file
106
train_mlp_tflite.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Train int8 MLP for ESP32-S3 deployment via TFLite Micro.
|
||||
Run AFTER Change 0 (label shift) + Change 1 (expanded features).
|
||||
|
||||
Change E — priority Tier 3.
|
||||
Outputs: EMG_Arm/src/core/emg_model_data.cc
|
||||
"""
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
print("ERROR: TensorFlow not installed. Run: pip install tensorflow")
|
||||
sys.exit(1)
|
||||
|
||||
# --- Load and extract features -----------------------------------------------
|
||||
storage = SessionStorage()
|
||||
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
|
||||
|
||||
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
|
||||
X = extractor.extract_features_batch(X_raw).astype(np.float32)
|
||||
|
||||
# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py)
|
||||
for sid in np.unique(session_indices):
|
||||
mask = session_indices == sid
|
||||
X_sess = X[mask]
|
||||
y_sess = y[mask]
|
||||
class_means = [X_sess[y_sess == cls].mean(axis=0) for cls in np.unique(y_sess)]
|
||||
balanced_mean = np.mean(class_means, axis=0)
|
||||
std = X_sess.std(axis=0)
|
||||
std[std < 1e-12] = 1.0
|
||||
X[mask] = (X_sess - balanced_mean) / std
|
||||
|
||||
n_feat = X.shape[1]
|
||||
n_cls = len(np.unique(y))
|
||||
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
|
||||
print(f"Classes: {label_names}")
|
||||
|
||||
# --- Build and train MLP -----------------------------------------------------
|
||||
model = tf.keras.Sequential([
|
||||
tf.keras.layers.Input(shape=(n_feat,)),
|
||||
tf.keras.layers.Dense(32, activation='relu'),
|
||||
tf.keras.layers.Dropout(0.2),
|
||||
tf.keras.layers.Dense(16, activation='relu'),
|
||||
tf.keras.layers.Dense(n_cls, activation='softmax'),
|
||||
])
|
||||
model.compile(optimizer='adam',
|
||||
loss='sparse_categorical_crossentropy',
|
||||
metrics=['accuracy'])
|
||||
model.summary()
|
||||
model.fit(X, y, epochs=150, batch_size=64, validation_split=0.1, verbose=1)
|
||||
|
||||
# --- Convert to int8 TFLite --------------------------------------------------
|
||||
def representative_dataset():
|
||||
for i in range(0, len(X), 10):
|
||||
yield [X[i:i+1]]
|
||||
|
||||
converter = tf.lite.TFLiteConverter.from_keras_model(model)
|
||||
converter.optimizations = [tf.lite.Optimize.DEFAULT]
|
||||
converter.representative_dataset = representative_dataset
|
||||
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
|
||||
converter.inference_input_type = tf.int8
|
||||
converter.inference_output_type = tf.int8
|
||||
tflite_model = converter.convert()
|
||||
|
||||
# --- Write emg_model_data.cc -------------------------------------------------
|
||||
out_cc = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.cc'
|
||||
out_h = Path(__file__).parent / 'EMG_Arm/src/core/emg_model_data.h'
|
||||
|
||||
with open(out_cc, 'w') as f:
|
||||
f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
|
||||
f.write('#include "emg_model_data.h"\n')
|
||||
f.write(f'const int g_model_len = {len(tflite_model)};\n')
|
||||
f.write('alignas(8) const unsigned char g_model[] = {\n ')
|
||||
f.write(', '.join(f'0x{b:02x}' for b in tflite_model))
|
||||
f.write('\n};\n')
|
||||
|
||||
with open(out_h, 'w') as f:
|
||||
f.write('// Auto-generated by train_mlp_tflite.py — do not edit\n')
|
||||
f.write('#pragma once\n\n')
|
||||
f.write('#ifdef __cplusplus\nextern "C" {\n#endif\n\n')
|
||||
f.write('extern const unsigned char g_model[];\n')
|
||||
f.write('extern const int g_model_len;\n\n')
|
||||
f.write('#ifdef __cplusplus\n}\n#endif\n')
|
||||
|
||||
print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
|
||||
print(f"Wrote {out_h}")
|
||||
|
||||
# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
|
||||
layer_weights = [layer.get_weights() for layer in model.layers if layer.get_weights()]
|
||||
mlp_path = Path(__file__).parent / 'models' / 'emg_mlp_weights.npz'
|
||||
mlp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
np.savez(mlp_path,
|
||||
w0=layer_weights[0][0], b0=layer_weights[0][1], # Dense(32, relu)
|
||||
w1=layer_weights[1][0], b1=layer_weights[1][1], # Dense(16, relu)
|
||||
w2=layer_weights[2][0], b2=layer_weights[2][1], # Dense(5, softmax)
|
||||
label_names=np.array(label_names))
|
||||
print(f"Saved laptop MLP weights to {mlp_path}")
|
||||
|
||||
print(f"\nNext steps:")
|
||||
print(f" 1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
|
||||
print(f" 2. Run: pio run -t upload")
|
||||
Reference in New Issue
Block a user