"""Train the full 3-specialist-LDA + meta-LDA ensemble.

Requires Change 1 (expanded features) to be implemented first.
Exports model_weights_ensemble.h for firmware Change F.

Architecture:

    LDA_TD (36 time-domain feat) ─┐
    LDA_FD (24 freq-domain feat)  ├─ 15 probs ─► Meta-LDA ─► final class
    LDA_CC (9 cross-ch feat)     ─┘

Change 7 — priority Tier 3.
"""
import sys
from pathlib import Path

import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_predict, GroupKFold, cross_val_score

# Make the sibling data-collection module importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent))
from learning_data_collection import (
    SessionStorage, EMGFeatureExtractor, HAND_CHANNELS
)
# --- Load and extract features -----------------------------------------------
storage = SessionStorage()
X_raw, y, trial_ids, session_indices, label_names, _ = storage.load_all_for_training()

extractor = EMGFeatureExtractor(channels=HAND_CHANNELS, cross_channel=True, expanded=True, reinhard=True)
X = extractor.extract_features_batch(X_raw).astype(np.float64)

# Per-session class-balanced normalization
# Must match EMGClassifier._apply_session_normalization():
# mean = average of per-class means (not overall mean), std = overall std.
# StandardScaler uses overall mean, which biases toward the majority class.
for sid in np.unique(session_indices):
    mask = session_indices == sid
    X_sess = X[mask]
    y_sess = y[mask]

    # Class-balanced mean: the mean of per-class centroids, so each class
    # contributes equally regardless of how many samples it has.
    centroids = np.stack([X_sess[y_sess == cls].mean(axis=0)
                          for cls in np.unique(y_sess)])
    balanced_mean = centroids.mean(axis=0)

    # Overall std (same as StandardScaler); guard zero-variance features.
    std = X_sess.std(axis=0)
    std[std < 1e-12] = 1.0

    X[mask] = (X_sess - balanced_mean) / std

feat_names = extractor.get_feature_names(n_channels=len(HAND_CHANNELS))
n_cls = len(np.unique(y))
# --- Feature subset indices ---------------------------------------------------
# Per-channel layout (20 features/channel): indices 0-11 TD, 12-19 FD
# Cross-channel features start at index 60 (3 channels × 20 features each)
TD_FEAT = ['rms', 'wl', 'zc', 'ssc', 'mav', 'var', 'iemg', 'wamp', 'ar1', 'ar2', 'ar3', 'ar4']
FD_FEAT = ['mnf', 'mdf', 'pkf', 'mnp', 'bp0', 'bp1', 'bp2', 'bp3']

# Per-channel ('ch...') feature names, selected by their '_<feat>' suffix.
td_idx = [i for i, n in enumerate(feat_names)
          if any(n.endswith(f'_{f}') for f in TD_FEAT) and n.startswith('ch')]
fd_idx = [i for i, n in enumerate(feat_names)
          if any(n.endswith(f'_{f}') for f in FD_FEAT) and n.startswith('ch')]
cc_idx = [i for i, n in enumerate(feat_names) if n.startswith('cc_')]

print(f"Feature subsets — TD: {len(td_idx)}, FD: {len(fd_idx)}, CC: {len(cc_idx)}")

# Validate subset sizes with explicit raises rather than `assert`: asserts are
# stripped under `python -O`, and a wrong subset size here would silently
# export a corrupt firmware header downstream.
for subset_name, idx, expected in (('TD', td_idx, 36),
                                   ('FD', fd_idx, 24),
                                   ('CC', cc_idx, 9)):
    if len(idx) != expected:
        raise ValueError(f"Expected {expected} {subset_name} features, got {len(idx)}")

X_td = X[:, td_idx]
X_fd = X[:, fd_idx]
X_cc = X[:, cc_idx]
# --- Train specialist LDAs with out-of-fold stacking -------------------------
# Group folds by trial so windows from one trial never straddle train/test.
gkf = GroupKFold(n_splits=min(5, len(np.unique(trial_ids))))

print("Training specialist LDAs (out-of-fold for stacking)...")
lda_td = LinearDiscriminantAnalysis()
lda_fd = LinearDiscriminantAnalysis()
lda_cc = LinearDiscriminantAnalysis()

specialists = [('LDA_TD', lda_td, X_td),
               ('LDA_FD', lda_fd, X_fd),
               ('LDA_CC', lda_cc, X_cc)]

# Out-of-fold probabilities: every sample's probs come from a fold model that
# never saw its trial, so the meta-learner trains on honest specialist output.
oof_td, oof_fd, oof_cc = (
    cross_val_predict(mdl, Xs, y, cv=gkf, groups=trial_ids, method='predict_proba')
    for _, mdl, Xs in specialists
)

# Specialist CV accuracy (for diagnostics)
for name, mdl, Xs in specialists:
    sc = cross_val_score(mdl, Xs, y, cv=gkf, groups=trial_ids)
    print(f" {name}: {sc.mean()*100:.1f}% ± {sc.std()*100:.1f}%")

# --- Train meta-LDA on out-of-fold outputs ------------------------------------
X_meta = np.hstack([oof_td, oof_fd, oof_cc])  # (n_samples, 3*n_cls = 15)
meta_lda = LinearDiscriminantAnalysis()
meta_sc = cross_val_score(meta_lda, X_meta, y, cv=gkf, groups=trial_ids)
print(f" Meta-LDA: {meta_sc.mean()*100:.1f}% ± {meta_sc.std()*100:.1f}%")

# Fit all models on full dataset for deployment
lda_td.fit(X_td, y)
lda_fd.fit(X_fd, y)
lda_cc.fit(X_cc, y)
meta_lda.fit(X_meta, y)
def lda_to_c_arrays(lda, name, feat_dim, n_cls, label_names, class_order):
    """Generate C array source for an LDA's weights and intercepts.

    Args:
        lda: fitted LinearDiscriminantAnalysis (only ``coef_`` and
            ``intercept_`` are read).
        name: C identifier prefix, e.g. ``'LDA_TD'``.
        feat_dim: expected number of input features (columns).
        n_cls: expected number of classes (rows).
        label_names: class index -> human-readable label (row comments).
        class_order: class indices in the order rows are emitted.

    Returns:
        str: C source declaring ``{name}_WEIGHTS[n_cls][feat_dim]`` and
        ``{name}_INTERCEPTS[n_cls]``.

    NOTE: sklearn ``LDA.coef_`` for multi-class can have shape
    (n_classes-1, n_features) depending on solver/data. We cannot refit here
    (no training data in scope), so we warn and zero-pad the missing rows;
    refit upstream with solver='lsqr' to get the full matrix.
    """
    coef = lda.coef_
    intercept = lda.intercept_

    if coef.shape[0] != n_cls:
        # Degenerate shape: emit a usable (if partially zeroed) header rather
        # than crash, and tell the operator how to fix it properly.
        print(f" WARNING: {name} coef_ shape {coef.shape} != ({n_cls}, {feat_dim}). "
              f"Padding with zeros. Refit with solver='lsqr' for full matrix.")
        padded = np.zeros((n_cls, feat_dim))
        padded[:coef.shape[0]] = coef
        coef = padded
        padded_i = np.zeros(n_cls)
        padded_i[:intercept.shape[0]] = intercept
        intercept = padded_i

    lines = []
    lines.append(f"const float {name}_WEIGHTS[{n_cls}][{feat_dim}] = {{")
    for c in class_order:
        row = ', '.join(f'{v:.8f}f' for v in coef[c])
        lines.append(f" {{{row}}}, // {label_names[c]}")
    lines.append("};")
    lines.append(f"const float {name}_INTERCEPTS[{n_cls}] = {{")
    intercept_str = ', '.join(f'{intercept[c]:.8f}f' for c in class_order)
    lines.append(f" {intercept_str}")
    lines.append("};")
    return '\n'.join(lines)
# --- Export all weights to C header ------------------------------------------
class_order = list(range(n_cls))
out_path = Path(__file__).parent / 'EMG_Arm/src/core/model_weights_ensemble.h'
out_path.parent.mkdir(parents=True, exist_ok=True)

# Firmware gathers each specialist's features starting at these offsets.
td_offset = min(td_idx) if td_idx else 0
fd_offset = min(fd_idx) if fd_idx else 0
cc_offset = min(cc_idx) if cc_idx else 0

# Assemble the whole header in memory, then write it in one call.
parts = [
    "// Auto-generated by train_ensemble.py — do not edit\n",
    "#pragma once\n\n",
    "// Pull MODEL_NUM_CLASSES, MODEL_NUM_FEATURES, MODEL_CLASS_NAMES from\n",
    "// model_weights.h to avoid redefinition conflicts.\n",
    '#include "model_weights.h"\n\n',
    "#define ENSEMBLE_PER_CH_FEATURES 20\n\n",
    f"#define TD_FEAT_OFFSET {td_offset}\n",
    f"#define TD_NUM_FEATURES {len(td_idx)}\n",
    f"#define FD_FEAT_OFFSET {fd_offset}\n",
    f"#define FD_NUM_FEATURES {len(fd_idx)}\n",
    f"#define CC_FEAT_OFFSET {cc_offset}\n",
    f"#define CC_NUM_FEATURES {len(cc_idx)}\n",
    "#define META_NUM_INPUTS (3 * MODEL_NUM_CLASSES)\n\n",
    "// Feature index arrays for gather operations (TD and FD are non-contiguous)\n",
    f"// TD indices: {td_idx}\n",
    f"// FD indices: {fd_idx}\n",
    f"// CC indices: {cc_idx}\n\n",
    lda_to_c_arrays(lda_td, 'LDA_TD', len(td_idx), n_cls, label_names, class_order),
    '\n\n',
    lda_to_c_arrays(lda_fd, 'LDA_FD', len(fd_idx), n_cls, label_names, class_order),
    '\n\n',
    lda_to_c_arrays(lda_cc, 'LDA_CC', len(cc_idx), n_cls, label_names, class_order),
    '\n\n',
    lda_to_c_arrays(meta_lda, 'META_LDA', 3 * n_cls, n_cls, label_names, class_order),
    '\n',
]
with open(out_path, 'w') as f:
    f.write(''.join(parts))

print(f"Exported ensemble weights to {out_path}")
print(f"Total weight storage: "
      f"{(len(td_idx) + len(fd_idx) + len(cc_idx) + 3*n_cls) * n_cls * 4} bytes float32")
# --- Also save sklearn models for laptop-side inference ----------------------
import joblib

# Bundle fitted models plus the index maps needed to slice features at
# inference time; label_names lets the consumer decode predictions.
ensemble_bundle = dict(
    lda_td=lda_td,
    lda_fd=lda_fd,
    lda_cc=lda_cc,
    meta_lda=meta_lda,
    td_idx=td_idx,
    fd_idx=fd_idx,
    cc_idx=cc_idx,
    label_names=label_names,
)
ensemble_joblib = Path(__file__).parent / 'models' / 'emg_ensemble.joblib'
ensemble_joblib.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(ensemble_bundle, ensemble_joblib)
print(f"Saved laptop ensemble model to {ensemble_joblib}")