# EMG_Arm/train_ensemble.py
# NOTE: the following metadata came from the web file-viewer this file was
# copied out of (201 lines, 8.4 KiB, Python). The viewer also warned about
# ambiguous Unicode characters — those are the box-drawing characters in the
# architecture diagram of the module docstring below, which are intentional.
"""
Train the full 3-specialist-LDA + meta-LDA ensemble.
Requires Change 1 (expanded features) to be implemented first.
Exports model_weights_ensemble.h for firmware Change F.
Architecture:
LDA_TD (36 time-domain feat) ─┐
LDA_FD (24 freq-domain feat) ├─ 15 probs ─► Meta-LDA ─► final class
LDA_CC (9 cross-ch feat) ─┘
Change 7 — priority Tier 3.
"""
import numpy as np
from pathlib import Path
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score
import sys
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import (
SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
)
# --- Load and extract features -----------------------------------------------
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()
extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True,
                                expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float64)

# Per-session, class-balanced normalization.
# Must mirror EMGClassifier._apply_session_normalization():
#   mean = average of the per-class means (NOT the overall mean),
#   std  = overall std.
# StandardScaler's overall mean would drag the center toward whichever
# class has the most samples, so we balance the center across classes.
for session_id in np.unique(session_indices):
    in_session = session_indices == session_id
    feats = X[in_session]
    labels = y[in_session]
    # Class-balanced center: mean of the per-class centroids.
    centroids = [feats[labels == c].mean(axis=0) for c in np.unique(labels)]
    center = np.mean(centroids, axis=0)
    # Overall std, exactly as StandardScaler computes it.
    scale = feats.std(axis=0)
    scale[scale < 1e-12] = 1.0  # guard against zero-variance features
    X[in_session] = (feats - center) / scale

feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS))
n_cls = len(np.unique(y))
# --- Feature subset indices ---------------------------------------------------
# Per-channel layout (20 features/channel): indices 0-11 TD, 12-19 FD
# Cross-channel features start at index 60 (3 channels × 20 features each)
TD_FEAT = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', 'ar1', 'ar2', 'ar3', 'ar4']
FD_FEAT = ['mnf', 'mdf', 'pkf', 'mnp', 'bp0', 'bp1', 'bp2', 'bp3']
# str.endswith accepts a tuple of suffixes — one call instead of an any() scan.
_td_suffixes = tuple(f'_{f}' for f in TD_FEAT)
_fd_suffixes = tuple(f'_{f}' for f in FD_FEAT)
td_idx = [i for i, n in enumerate(feat_names)
          if n.startswith('ch') and n.endswith(_td_suffixes)]
fd_idx = [i for i, n in enumerate(feat_names)
          if n.startswith('ch') and n.endswith(_fd_suffixes)]
cc_idx = [i for i, n in enumerate(feat_names) if n.startswith('cc_')]
print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}")
assert len(td_idx) == 36, f"Expected 36 TD features, got {len(td_idx)}"
assert len(fd_idx) == 24, f"Expected 24 FD features, got {len(fd_idx)}"
assert len(cc_idx) == 9, f"Expected 9 CC features, got {len(cc_idx)}"
X_td = X[:, td_idx]
X_fd = X[:, fd_idx]
X_cc = X[:, cc_idx]
# --- Train specialist LDAs with out-of-fold stacking -------------------------
gkf = GroupKFold(n_splits=min(5, len(np.unique(trial_ids))))
print("Training specialist LDAs (out-of-fold for stacking)...")
lda_td = LinearDiscriminantAnalysis()
lda_fd = LinearDiscriminantAnalysis()
lda_cc = LinearDiscriminantAnalysis()
specialists = [('LDA_TD', lda_td, X_td),
               ('LDA_FD', lda_fd, X_fd),
               ('LDA_CC', lda_cc, X_cc)]
# Out-of-fold class probabilities: each sample is scored by a model that never
# saw its trial group, so the meta-learner trains on honest specialist outputs.
oof_td, oof_fd, oof_cc = (
    cross_val_predict(mdl, Xs, y, cv=gkf, groups=trial_ids, method='predict_proba')
    for _, mdl, Xs in specialists
)
# Specialist CV accuracy (diagnostics only)
for name, mdl, Xs in specialists:
    sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids)
    print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%")
# --- Train meta-LDA on out-of-fold outputs ------------------------------------
X_meta = np.hstack([oof_td, oof_fd, oof_cc])  # (n_samples, 3*n_cls = 15)
meta_lda = LinearDiscriminantAnalysis()
meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids)
print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%")
# Refit every model on the full dataset for deployment.
lda_td.fit(X_td, y)
lda_fd.fit(X_fd, y)
lda_cc.fit(X_cc, y)
meta_lda.fit(X_meta, y)
# --- Export all weights to C header ------------------------------------------
def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order):
    """Generate C array source for one LDA's weights and intercepts.

    Parameters
    ----------
    lda : fitted estimator exposing ``coef_`` and ``intercept_``
    name : str — prefix for the generated C identifiers
    feat_dim : int — number of input features for this model
    n_cls : int — number of output classes (rows to emit)
    label_names : sequence mapping class index -> label (emitted as comments)
    class_order : iterable of class indices giving row emission order

    Returns
    -------
    str — two ``const float`` C array definitions joined with newlines.

    NOTE: sklearn's multi-class LDA ``coef_`` can have shape
    (n_classes-1, n_features) with the SVD solver. We cannot refit here
    (no training data in scope), so in that case we warn and zero-pad the
    missing rows; refit with solver='lsqr' upstream to get the full matrix.
    """
    coef = lda.coef_
    intercept = lda.intercept_
    if coef.shape[0] != n_cls:
        # SVD solver returned a reduced matrix — warn loudly, then zero-pad
        # so the emitted C arrays always have the declared dimensions.
        print(f" WARNING: {name} coef_ shape {coef.shape} != ({n_cls}, {feat_dim}). "
              f"Padding with zeros. Refit with solver='lsqr' for full matrix.")
        padded = np.zeros((n_cls, feat_dim))
        padded[:coef.shape[0]] = coef
        coef = padded
        padded_i = np.zeros(n_cls)
        padded_i[:intercept.shape[0]] = intercept
        intercept = padded_i
    lines = []
    lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{")
    for c in class_order:
        row = ', '.join(f'{v:.8f}f' for v in coef[c])
        lines.append(f" {{{row}}}, // {label_names[c]}")
    lines.append("};")
    lines.append(f"const float {name}_INTERCEPTS[{n_cls}] = {{")
    intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order)
    lines.append(f" {intercept_str}")
    lines.append("};")
    return '\n'.join(lines)
# --- Export all weights to C header ------------------------------------------
class_order = list(range(n_cls))
out_path = Path(__file__).parent / 'EMG_Arm/src/core/model_weights_ensemble.h'
out_path.parent.mkdir(parents=True, exist_ok=True)
# Offsets of each specialist's first feature in the full feature vector.
td_offset = min(td_idx) if td_idx else 0
fd_offset = min(fd_idx) if fd_idx else 0
cc_offset = min(cc_idx) if cc_idx else 0
with open(out_path, 'w') as f:
    f.write("// Auto-generated by train_ensemble.py — do not edit\n")
    f.write("#pragma once\n\n")
    f.write("// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from\n")
    f.write("// model_weights.h to avoid redefinition conflicts.\n")
    f.write('#include "model_weights.h"\n\n')
    f.write(f"#define ENSEMBLE_PER_CH_FEATURES 20\n\n")
    f.write(f"#define TD_FEAT_OFFSET {td_offset}\n")
    f.write(f"#define TD_NUM_FEATURES {len(td_idx)}\n")
    f.write(f"#define FD_FEAT_OFFSET {fd_offset}\n")
    f.write(f"#define FD_NUM_FEATURES {len(fd_idx)}\n")
    f.write(f"#define CC_FEAT_OFFSET {cc_offset}\n")
    f.write(f"#define CC_NUM_FEATURES {len(cc_idx)}\n")
    f.write(f"#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)\n\n")
    f.write("// Feature index arrays for gather operations (TD and FD are non-contiguous)\n")
    f.write(f"// TD indices: {td_idx}\n")
    f.write(f"// FD indices: {fd_idx}\n")
    f.write(f"// CC indices: {cc_idx}\n\n")
    f.write(lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    f.write(lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    f.write(lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order))
    f.write('\n\n')
    f.write(lda_to_c_arrays(meta_lda, 'META_LDA', 3 * n_cls, n_cls, label_names, class_order))
    f.write('\n')
print(f"Exported ensemble weights to {out_path}")
# Storage estimate for the firmware: weight matrices for the 3 specialists
# plus the meta-LDA, AND the 4 intercept vectors (the previous count omitted
# the intercepts, undercounting by 4 * n_cls floats).
weight_floats = (len(td_idx) + len(fd_idx) + len(cc_idx) + 3 * n_cls) * n_cls
intercept_floats = 4 * n_cls  # 3 specialist + 1 meta intercept vectors
print(f"Total weight storage: {(weight_floats + intercept_floats) * 4} bytes float32")
# --- Also save sklearn models for laptop-side inference ----------------------
import joblib

# Everything laptop-side inference needs: the four fitted models, the feature
# index subsets, and the class label names.
ensemble_bundle = dict(
    lda_td=lda_td,
    lda_fd=lda_fd,
    lda_cc=lda_cc,
    meta_lda=meta_lda,
    td_idx=td_idx,
    fd_idx=fd_idx,
    cc_idx=cc_idx,
    label_names=label_names,
)
models_dir = Path(__file__).parent / 'models'
models_dir.mkdir(parents=True, exist_ok=True)
ensemble_joblib = models_dir / 'emg_ensemble.joblib'
joblib.dump(ensemble_bundle, ensemble_joblib)
print(f"Saved laptop ensemble model to {ensemble_joblib}")