"""
Train int8 MLP for ESP32-S3 deployment via TFLite Micro.

Run AFTER Change 0 (label shift) + Change 1 (expanded features).
Change E — priority Tier 3.

Outputs:
    EMG_Arm/src/core/emg_model_data.cc   int8 TFLite flatbuffer as a C byte array
    EMG_Arm/src/core/emg_model_data.h    extern declarations for g_model / g_model_len
    models/emg_mlp_weights.npz           float MLP weights for laptop-side inference
"""

import numpy as np
from pathlib import Path
import sys

sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import SessionStorage, EMGFeatureExtractor, HAND_CHANNELS

try:
    import tensorflow as tf
except ImportError:
    print("ERROR: TensorFlow not installed. Run: pip install tensorflow")
    sys.exit(1)

_SCRIPT_DIR = Path(__file__).parent


def _normalize_per_session(X, y, session_indices):
    """Standardize X in place, one session at a time.

    For each session the centre is the average of the per-class mean vectors
    (class-balanced, so a class with many windows does not dominate), and the
    scale is the plain per-feature std of that session. Features whose std is
    below 1e-12 are divided by 1.0 instead, to avoid division by ~zero.
    Must match EMGClassifier + train_ensemble.py.
    """
    for sid in np.unique(session_indices):
        mask = session_indices == sid
        X_sess = X[mask]
        y_sess = y[mask]
        class_means = [X_sess[y_sess == cls].mean(axis=0) for cls in np.unique(y_sess)]
        balanced_mean = np.mean(class_means, axis=0)
        std = X_sess.std(axis=0)
        std[std < 1e-12] = 1.0
        X[mask] = (X_sess - balanced_mean) / std


def _c_byte_array(data, per_line=12):
    """Return `data` as comma-separated C hex literals, `per_line` bytes per row."""
    rows = [
        ', '.join(f'0x{b:02x}' for b in data[start:start + per_line])
        for start in range(0, len(data), per_line)
    ]
    return ',\n  '.join(rows)


# --- Load and extract features -----------------------------------------------
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()

extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True,
                                expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float32)

# Per-session class-balanced normalization (must match EMGClassifier + train_ensemble.py)
_normalize_per_session(X, y, session_indices)

n_feat = X.shape[1]
# Size the output layer from the largest class id rather than the number of
# distinct labels: sparse_categorical_crossentropy requires every label to be
# < n_cls, and a class id with zero samples would otherwise shrink the layer.
# When ids are contiguous 0..K-1 this equals len(np.unique(y)).
n_cls = int(y.max()) + 1
if n_cls != len(label_names):
    print(f"WARNING: labels span {n_cls} class ids but label_names has "
          f"{len(label_names)} entries: {label_names}")
print(f"Dataset: {len(X)} samples, {n_feat} features, {n_cls} classes")
print(f"Classes: {label_names}")

# --- Build and train MLP -----------------------------------------------------
model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(n_feat,)),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(n_cls, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.summary()
# NOTE(review): Keras takes validation_split from the LAST 10% of (X, y) before
# shuffling. If load_all_for_training returns samples grouped by session or
# class, val accuracy is not a representative estimate — confirm the ordering.
model.fit(X, y, epochs=150, batch_size=64, validation_split=0.1, verbose=1)


# --- Convert to int8 TFLite --------------------------------------------------
def representative_dataset():
    """Yield every 10th normalized sample (batch of 1) to calibrate int8 ranges."""
    for i in range(0, len(X), 10):
        yield [X[i:i+1]]


converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
# Full-integer model: int8 ops only, int8 input and output tensors.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# --- Write emg_model_data.cc -------------------------------------------------
out_cc = _SCRIPT_DIR / 'EMG_Arm/src/core/emg_model_data.cc'
out_h = _SCRIPT_DIR / 'EMG_Arm/src/core/emg_model_data.h'
# The banner contains a non-ASCII em dash, so pin the encoding instead of
# relying on the platform's locale default.
banner = '// Auto-generated by train_mlp_tflite.py — do not edit\n'

out_cc.write_text(
    banner
    + '#include "emg_model_data.h"\n'
    + f'const int g_model_len = {len(tflite_model)};\n'
    + 'alignas(8) const unsigned char g_model[] = {\n  '
    + _c_byte_array(tflite_model)
    + '\n};\n',
    encoding='utf-8',
)

out_h.write_text(
    banner
    + '#pragma once\n\n'
    + '#ifdef __cplusplus\nextern "C" {\n#endif\n\n'
    + 'extern const unsigned char g_model[];\n'
    + 'extern const int g_model_len;\n\n'
    + '#ifdef __cplusplus\n}\n#endif\n',
    encoding='utf-8',
)

print(f"Wrote {out_cc} ({len(tflite_model)} bytes)")
print(f"Wrote {out_h}")

# --- Save MLP weights as numpy for laptop-side inference (no TF needed) ------
# Only the three Dense layers carry weights (Dropout has none), in model order.
layer_weights = [layer.get_weights() for layer in model.layers if layer.get_weights()]
mlp_path = _SCRIPT_DIR / 'models' / 'emg_mlp_weights.npz'
mlp_path.parent.mkdir(parents=True, exist_ok=True)
np.savez(mlp_path,
         w0=layer_weights[0][0], b0=layer_weights[0][1],  # hidden Dense(32, relu)
         w1=layer_weights[1][0], b1=layer_weights[1][1],  # hidden Dense(16, relu)
         w2=layer_weights[2][0], b2=layer_weights[2][1],  # output Dense(n_cls, softmax)
         label_names=np.array(label_names))
print(f"Saved laptop MLP weights to {mlp_path}")

print("\nNext steps:")
print("  1. Set MODEL_USE_MLP 1 in EMG_Arm/src/core/model_weights.h")
print("  2. Run: pio run -t upload")