{
"cells": [
{
"cell_type": "markdown",
"id": "d8960e56",
"metadata": {},
"source": [
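"# Motor Imagery Decoder — Train OFFLINE, Evaluate ONLINE (FES vs NOFES)\n",
"\n",
"For each subject and offline session, train a CSP + LDA classifier on the OFFLINE recording, then apply it to the two matched ONLINE sessions.\n",
"\n",
"- **Pair 1:** train on `S001 OFFLINE_FES` → test on `S002 ONLINE_FES` and `S003 ONLINE_NOFES`\n",
"- **Pair 2:** train on `S004 OFFLINE_NOFES` → test on `S006 ONLINE_FES` and `S005 ONLINE_NOFES`\n",
"\n",
"**Metrics reported per (subject × pair × condition):**\n",
"\n",
"1. **Classification accuracy** — fraction of cued trials in which the live online classifier emitted the matching EARLYSTOP marker (240 for MI, 140 for REST), counted over all trials before artifact rejection; per-class accuracy and mean detection latency (BEGIN → EARLYSTOP) are reported alongside\n",
"2. **Classification amplitude** — mean |decision-function value| of the offline-trained CSP + LDA on the clean online epochs\n",
"3. **SNR** — (a) Fisher ratio $(\\mu_1 - \\mu_0)^2 / (\\sigma_1^2 + \\sigma_0^2)$ of the per-class LDA decision values on the online data, and (b) mu-band (8–13 Hz) power ratio REST / MI over motor channels C3/Cz/C4 (values above 1 indicate mu desynchronization during MI)"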
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "578c9128",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:26:50.521758Z",
"iopub.status.busy": "2026-04-22T19:26:50.521477Z",
"iopub.status.idle": "2026-04-22T19:26:50.528154Z",
"shell.execute_reply": "2026-04-22T19:26:50.527272Z"
}
},
"outputs": [],
"source": [
"# Install dependencies if needed\n",
"# !pip install pyxdf mne scipy numpy matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "857b22c0",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:26:50.530924Z",
"iopub.status.busy": "2026-04-22T19:26:50.530685Z",
"iopub.status.idle": "2026-04-22T19:26:51.596469Z",
"shell.execute_reply": "2026-04-22T19:26:51.595934Z"
}
},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import glob\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Patch\n",
"import pyxdf\n",
"from scipy.signal import welch, butter, filtfilt, iirnotch\n",
"from scipy.linalg import eigh\n",
"from scipy.stats import wasserstein_distance\n",
"\n",
"plt.rcParams.update({'font.size': 11, 'figure.dpi': 120})"
]
},
{
"cell_type": "markdown",
"id": "fe68bf0e",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "dc4b2c55",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:26:51.597901Z",
"iopub.status.busy": "2026-04-22T19:26:51.597789Z",
"iopub.status.idle": "2026-04-22T19:26:51.600886Z",
"shell.execute_reply": "2026-04-22T19:26:51.600373Z"
}
},
"outputs": [],
"source": [
"# __file__ is not reliably defined inside notebooks; resolve the data folder from the working directory\n",
"DATA_DIR = os.path.join(os.getcwd(), 'Group 2 - Glove')\n",
"\n",
"# Marker codes (from experiment trigger table)\n",
"MI_BEGIN = 200\n",
"MI_END = 220\n",
"MI_EARLYSTOP = 240  # online only: live classifier fired → successful MI detection\n",
"REST_BEGIN = 100\n",
"REST_END = 120\n",
"REST_EARLYSTOP = 140  # online only: live classifier fired → successful REST detection\n",
"ROBOT_BEGIN = 300\n",
"ROBOT_END = 320\n",
"\n",
"TARGET_MARKERS = [REST_BEGIN, REST_END, REST_EARLYSTOP, MI_BEGIN, MI_END, MI_EARLYSTOP]\n",
"\n",
"# Epoch window relative to each BEGIN marker (seconds); [T_PRE, 0) serves as the baseline\n",
"T_PRE = -1.0\n",
"T_POST = 5.0\n",
"\n",
"# ── Preprocessing ────────────────────────────────────────────────────────────\n",
"NOTCH_FREQ = 60.0\n",
"NOTCH_Q = 30\n",
"BP_LO, BP_HI = 8.0, 30.0\n",
"USE_CAR = True\n",
"PTP_REJECT_UV = 100.0\n",
"\n",
"N_CSP = 4\n",
"\n",
"NON_EEG = {'AUX1', 'AUX2', 'AUX3', 'AUX7', 'AUX8', 'AUX9', 'TRIGGER'}\n",
"RENAME = {'FP1':'Fp1','FPZ':'Fpz','FP2':'Fp2','FZ':'Fz','CZ':'Cz',\n",
" 'PZ':'Pz','POZ':'POz','OZ':'Oz'}\n",
"\n",
"MOTOR_CH = ['C3', 'Cz', 'C4']\n",
"MU_BAND = (8, 13)\n",
"\n",
"PAIRS = [\n",
" {'name': 'Pair1 (train=OFFLINE_FES)',\n",
" 'train': 'S001', 'online_fes': 'S002', 'online_nofes': 'S003'},\n",
" {'name': 'Pair2 (train=OFFLINE_NOFES)',\n",
" 'train': 'S004', 'online_fes': 'S006', 'online_nofes': 'S005'},\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "21a40df3",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e798b039",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:26:51.602206Z",
"iopub.status.busy": "2026-04-22T19:26:51.602130Z",
"iopub.status.idle": "2026-04-22T19:26:51.616577Z",
"shell.execute_reply": "2026-04-22T19:26:51.616072Z"
}
},
"outputs": [],
"source": [
"# ── XDF loading + session parsing ─────────────────────────────────────────────\n",
"\n",
"def get_channel_names_from_xdf(eeg_stream):\n",
" ch_desc = eeg_stream['info']['desc'][0]\n",
" channels = ch_desc.get('channels', [{}])[0].get('channel', [])\n",
" return [ch['label'][0] for ch in channels]\n",
"\n",
"\n",
"_SESSION_RE = re.compile(r'ses-(S\\d+)(O[A-Z]*LINE)_(FES|NOFES)')\n",
"_SUBJ_RE = re.compile(r'SUBJ_(\\d+)')\n",
"\n",
"def parse_session(path):\n",
" base = os.path.basename(path)\n",
" m_subj = _SUBJ_RE.search(base)\n",
" m_ses = _SESSION_RE.search(base)\n",
" if not (m_subj and m_ses):\n",
" return None\n",
" ses_id, raw_kind, stim = m_ses.group(1), m_ses.group(2), m_ses.group(3)\n",
" kind = 'OFFLINE' if 'OFF' in raw_kind else 'ONLINE'\n",
" return m_subj.group(1), ses_id, kind, stim\n",
"\n",
"\n",
"def load_xdf_file(filepath):\n",
" streams, _ = pyxdf.load_xdf(filepath)\n",
"\n",
" eeg_stream = marker_stream = None\n",
" for s in streams:\n",
" stype = s['info']['type'][0].lower()\n",
" if stype == 'eeg': eeg_stream = s\n",
" elif stype == 'markers': marker_stream = s\n",
" if eeg_stream is None or marker_stream is None:\n",
" eeg_stream = streams[0]\n",
" marker_stream = streams[1] if len(streams) > 1 else None\n",
"\n",
" eeg_timestamps = np.array(eeg_stream['time_stamps'])\n",
" eeg_data = np.array(eeg_stream['time_series']).T\n",
" channel_names = get_channel_names_from_xdf(eeg_stream)\n",
" sfreq = float(eeg_stream['info']['nominal_srate'][0])\n",
"\n",
" valid_idx = [i for i, ch in enumerate(channel_names) if ch not in NON_EEG]\n",
" channel_names = [channel_names[i] for i in valid_idx]\n",
" eeg_data = eeg_data[valid_idx, :]\n",
" channel_names = [RENAME.get(ch, ch) for ch in channel_names]\n",
"\n",
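"    if marker_stream is None:\n",
"        raise ValueError('no marker stream found in XDF file')\n",
"    # Each marker sample is [code, time]: column 0 is the trigger code, column 1 is used as the\n",
"    # event timestamp and compared directly against the EEG time_stamps during epoching\n",
"    ts_arr = np.asarray(marker_stream['time_series'], dtype=float)\n",
"    marker_data = ts_arr[:, 0].astype(int)\n",
"    marker_ts = ts_arr[:, 1]\n",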
" keep = np.isin(marker_data, TARGET_MARKERS)\n",
" return eeg_data, eeg_timestamps, marker_data[keep], marker_ts[keep], channel_names, sfreq\n",
"\n",
"\n",
"# ── Preprocessing primitives ─────────────────────────────────────────────────\n",
"\n",
"def notch_filter(data, freq, sfreq, Q=NOTCH_Q):\n",
" b, a = iirnotch(freq, Q, fs=sfreq)\n",
" return filtfilt(b, a, data, axis=-1)\n",
"\n",
"\n",
"def car(data):\n",
" return data - data.mean(axis=0, keepdims=True)\n",
"\n",
"\n",
"def bandpass(data, lo, hi, sfreq, order=4):\n",
" nyq = sfreq / 2.0\n",
" b, a = butter(order, [max(lo, 0.5) / nyq, min(hi, nyq - 0.1) / nyq], btype='band')\n",
" return filtfilt(b, a, data, axis=-1)\n",
"\n",
"\n",
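"def reject_by_ptp(X, thresh_uv=PTP_REJECT_UV):\n",
"    \"\"\"Boolean keep-mask over epochs: True where every channel's peak-to-peak amplitude < thresh_uv.\"\"\"\n",
"    if X.size == 0:\n",
"        # No samples to test: reject every epoch (the mask length must still equal len(X))\n",
"        return np.zeros(X.shape[0], dtype=bool)\n",
"    ptp = X.max(axis=-1) - X.min(axis=-1)  # (n_epochs, n_channels)\n",
"    return ptp.max(axis=-1) < thresh_uv\n",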
"\n",
"\n",
"def extract_epochs(eeg_data, eeg_ts, marker_data, marker_ts, sfreq, begin_code,\n",
"                   t_pre=T_PRE, t_post=T_POST):\n",
"    \"\"\"Cut fixed-length [t_pre, t_post) windows around each `begin_code` marker and subtract the pre-cue baseline.\"\"\"\n",
"    epochs = []\n",
"    n_pre = int(abs(t_pre) * sfreq)\n",
"    n_win = int(round((t_post - t_pre) * sfreq))  # fixed epoch length in samples\n",
"    for bi in np.where(marker_data == begin_code)[0]:\n",
"        t_start = marker_ts[bi]\n",
"        i0 = np.searchsorted(eeg_ts, t_start + t_pre)\n",
"        i1 = i0 + n_win\n",
"        # searchsorted never returns < 0 or > len, so check the window against the recording edges explicitly\n",
"        if t_start + t_pre < eeg_ts[0] or i1 > eeg_data.shape[1]:\n",
"            continue\n",
"        ep = eeg_data[:, i0:i1].copy()\n",
"        if ep.shape[1] > n_pre:\n",
"            ep -= ep[:, :n_pre].mean(axis=1, keepdims=True)\n",
"        epochs.append(ep)\n",
"    if not epochs:\n",
"        return np.empty((0, eeg_data.shape[0], 0))\n",
"    return np.stack(epochs)\n",
"\n",
"\n",
"# ── Trial-level parsing from markers ─────────────────────────────────────────\n",
"\n",
"def trial_events(marker_data, marker_ts):\n",
" \"\"\"Parse marker stream into ordered per-trial records.\n",
" Returns list of dicts with keys: cls ('MI'|'REST'), success (bool),\n",
" latency (seconds from BEGIN to EARLYSTOP, or None on failure), t_begin.\n",
" \"\"\"\n",
" begin_codes = {MI_BEGIN: 'MI', REST_BEGIN: 'REST'}\n",
" early_of = {'MI': MI_EARLYSTOP, 'REST': REST_EARLYSTOP}\n",
"\n",
" begin_idx = np.where(np.isin(marker_data, list(begin_codes.keys())))[0]\n",
" trials = []\n",
" for i, bi in enumerate(begin_idx):\n",
" cls = begin_codes[int(marker_data[bi])]\n",
" t_begin = float(marker_ts[bi])\n",
" # Search until next BEGIN (or end of stream) for EARLYSTOP of matching class\n",
" end = begin_idx[i + 1] if i + 1 < len(begin_idx) else len(marker_data)\n",
" window = marker_data[bi + 1:end]\n",
" wts = marker_ts[bi + 1:end]\n",
" hit = np.where(window == early_of[cls])[0]\n",
" if len(hit):\n",
" trials.append(dict(cls=cls, success=True,\n",
" latency=float(wts[hit[0]] - t_begin), t_begin=t_begin))\n",
" else:\n",
" trials.append(dict(cls=cls, success=False, latency=None, t_begin=t_begin))\n",
" return trials\n",
"\n",
"\n",
"def marker_stats(marker_data, marker_ts):\n",
" \"\"\"Online performance summary from markers. Returns None for offline sessions.\n",
" Fields: mk_acc, mi_acc, rest_acc, mi_latency, rest_latency, trials (ordered).\n",
" \"\"\"\n",
" trials = trial_events(marker_data, marker_ts)\n",
" if not trials or not any(t['success'] for t in trials):\n",
" return None # offline or no successful trials\n",
" mi = [t for t in trials if t['cls'] == 'MI']\n",
" rest = [t for t in trials if t['cls'] == 'REST']\n",
"\n",
" def _acc(ts): return sum(t['success'] for t in ts) / len(ts) if ts else None\n",
" def _lat(ts):\n",
" lats = [t['latency'] for t in ts if t['success']]\n",
" return float(np.mean(lats)) if lats else None\n",
"\n",
" n_correct = sum(t['success'] for t in trials)\n",
" return dict(mk_acc = n_correct / len(trials),\n",
" mi_acc = _acc(mi), rest_acc = _acc(rest),\n",
" mi_latency = _lat(mi), rest_latency = _lat(rest),\n",
" trials = trials)\n",
"\n",
"\n",
"def load_session_epochs(filepath):\n",
"    \"\"\"Preprocessing pipeline: notch → CAR (if USE_CAR) → bandpass → epoch + baseline-correct\n",
"    → crop baseline → PTP-reject.\n",
"    Returns X (n_epochs, n_channels, n_samples), y (1 = MI, 0 = REST), ch_names, sfreq,\n",
"    n_rejected, stats, where `stats` is the marker_stats() dict (None for offline sessions).\n",
"    \"\"\"\n",
" eeg, eeg_ts, mk, mk_ts, ch_names, sfreq = load_xdf_file(filepath)\n",
"\n",
" stats = marker_stats(mk, mk_ts)\n",
"\n",
" eeg = notch_filter(eeg, NOTCH_FREQ, sfreq)\n",
" if USE_CAR:\n",
" eeg = car(eeg)\n",
" eeg_bp = bandpass(eeg, BP_LO, BP_HI, sfreq)\n",
"\n",
" mi = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, MI_BEGIN)\n",
" rest = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, REST_BEGIN)\n",
"\n",
" n_pre = int(abs(T_PRE) * sfreq)\n",
" if mi.shape[-1] > n_pre: mi = mi[..., n_pre:]\n",
" if rest.shape[-1] > n_pre: rest = rest[..., n_pre:]\n",
"\n",
" n = min(mi.shape[-1], rest.shape[-1]) if (mi.size and rest.size) else 0\n",
" mi, rest = mi[..., :n], rest[..., :n]\n",
"\n",
" n0_mi, n0_rest = len(mi), len(rest)\n",
" mi = mi[reject_by_ptp(mi)]\n",
" rest = rest[reject_by_ptp(rest)]\n",
" n_rejected = (n0_mi - len(mi)) + (n0_rest - len(rest))\n",
"\n",
" X = np.concatenate([mi, rest], axis=0) if (len(mi) or len(rest)) else np.empty((0, len(ch_names), 0))\n",
" y = np.concatenate([np.ones(len(mi), int), np.zeros(len(rest), int)])\n",
" return X, y, ch_names, sfreq, n_rejected, stats\n",
"\n",
"\n",
"# ── CSP + LDA (2-class, numpy/scipy only) ────────────────────────────────────\n",
"\n",
"def _mean_cov(X):\n",
" covs = np.einsum('ijk,ilk->ijl', X, X)\n",
" covs /= np.trace(covs, axis1=1, axis2=2)[:, None, None]\n",
" return covs.mean(axis=0)\n",
"\n",
"\n",
"class CSPLDA:\n",
"    \"\"\"CSP log-variance features + LDA. Ramoser 2000; Blankertz 2008.\n",
"    A fixed shrinkage (`cov_shrink`) of each class covariance toward a scaled identity keeps the\n",
"    generalized eigenproblem well-posed after CAR, which leaves the covariance rank-deficient.\n",
"    \"\"\"\n",
"\n",
" def __init__(self, n_csp=N_CSP, cov_shrink=0.05, lda_reg=1e-4):\n",
" self.n_csp = n_csp\n",
" self.cov_shrink = cov_shrink\n",
" self.lda_reg = lda_reg\n",
"\n",
" def fit(self, X, y):\n",
" assert set(np.unique(y)) == {0, 1}\n",
" C1 = _mean_cov(X[y == 1])\n",
" C0 = _mean_cov(X[y == 0])\n",
" n_ch = C1.shape[0]\n",
" s = self.cov_shrink\n",
" C1 = (1 - s) * C1 + s * (np.trace(C1) / n_ch) * np.eye(n_ch)\n",
" C0 = (1 - s) * C0 + s * (np.trace(C0) / n_ch) * np.eye(n_ch)\n",
" evals, evecs = eigh(C1, C0 + C1)\n",
" order = np.argsort(evals)\n",
" k = self.n_csp // 2\n",
" self.filters_ = np.concatenate([evecs[:, order[:k]],\n",
" evecs[:, order[-k:]]], axis=1).T\n",
" F = self._features(X)\n",
" mu1, mu0 = F[y == 1].mean(0), F[y == 0].mean(0)\n",
" Sw = np.cov(F[y == 1].T, ddof=1) + np.cov(F[y == 0].T, ddof=1)\n",
" Sw += self.lda_reg * np.eye(Sw.shape[0])\n",
" self.coef_ = np.linalg.solve(Sw, mu1 - mu0)\n",
" self.intercept_ = -self.coef_ @ ((mu1 + mu0) / 2)\n",
" return self\n",
"\n",
" def _features(self, X):\n",
" Z = np.einsum('fc,ncs->nfs', self.filters_, X)\n",
" var = Z.var(axis=-1, ddof=1)\n",
" return np.log(var / var.sum(axis=1, keepdims=True))\n",
"\n",
" def decision_function(self, X):\n",
" return self._features(X) @ self.coef_ + self.intercept_\n",
"\n",
" def predict(self, X):\n",
" return (self.decision_function(X) > 0).astype(int)\n",
"\n",
"\n",
"# ── Evaluation metrics ───────────────────────────────────────────────────────\n",
"\n",
"def evaluate(clf, X, y):\n",
" margin = clf.decision_function(X)\n",
" pred = (margin > 0).astype(int)\n",
" amp = np.abs(margin).mean()\n",
" m1, m0 = margin[y == 1], margin[y == 0]\n",
" fisher = (m1.mean() - m0.mean()) ** 2 / (m1.var(ddof=1) + m0.var(ddof=1) + 1e-30)\n",
" return dict(amp=amp, fisher=fisher, margin=margin, y=y, pred=pred)\n",
"\n",
"\n",
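"# np.trapezoid exists only in NumPy >= 2.0; fall back to np.trapz on older versions\n",
"_trapz = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz\n",
"\n",
"\n",
"def spectral_snr(X, y, ch_idx, sfreq, band=MU_BAND):\n",
"    \"\"\"Mean band power of REST epochs (y == 0) divided by that of MI epochs (y == 1) over `ch_idx`.\"\"\"\n",
"    def band_pwr(sig):\n",
"        nperseg = min(int(sfreq * 2), sig.shape[-1])\n",
"        f, p = welch(sig, fs=sfreq, nperseg=nperseg,\n",
"                     noverlap=nperseg // 2, axis=-1)  # 50 % overlap, always < nperseg\n",
"        m = (f >= band[0]) & (f < band[1])\n",
"        return _trapz(p[..., m], f[m], axis=-1).mean()  # integrate PSD, then average over epochs and channels\n",
"    return band_pwr(X[y == 0][:, ch_idx, :]) / (band_pwr(X[y == 1][:, ch_idx, :]) + 1e-30)"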
]
},
{
"cell_type": "markdown",
"id": "98d225db",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d266216b",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:26:51.617805Z",
"iopub.status.busy": "2026-04-22T19:26:51.617727Z",
"iopub.status.idle": "2026-04-22T19:27:24.179747Z",
"shell.execute_reply": "2026-04-22T19:27:24.179038Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 24 XDF file(s).\n",
"Preprocessing: notch 60 Hz → CAR → bandpass 8– 30 Hz → baseline-correct → PTP-reject @ 100 µV\n",
"\n",
" 002/S001 OFFLINE FES n= 85 (MI=43, REST=42) rej=5\n",
" 002/S002 ONLINE FES n= 53 (MI=27, REST=26) rej=7 acc=0.883 (MI=0.87, REST=0.90) lat(MI=2.74s, REST=2.46s)\n",
" 002/S003 ONLINE NOFES n= 52 (MI=26, REST=26) rej=8 acc=0.833 (MI=0.87, REST=0.80) lat(MI=2.26s, REST=2.29s)\n",
" 002/S004 OFFLINE NOFES n= 90 (MI=45, REST=45) rej=0\n",
" 002/S005 ONLINE NOFES n= 60 (MI=30, REST=30) rej=0 acc=0.850 (MI=0.83, REST=0.87) lat(MI=2.33s, REST=2.15s)\n",
" 002/S006 ONLINE FES n= 56 (MI=27, REST=29) rej=4 acc=0.917 (MI=1.00, REST=0.83) lat(MI=2.26s, REST=1.76s)\n",
" 003/S001 OFFLINE FES n= 89 (MI=44, REST=45) rej=1\n",
" 003/S002 ONLINE FES n= 59 (MI=29, REST=30) rej=1 acc=0.750 (MI=0.70, REST=0.80) lat(MI=2.59s, REST=2.12s)\n",
" 003/S003 ONLINE NOFES n= 38 (MI=17, REST=21) rej=0 acc=0.763 (MI=0.76, REST=0.76) lat(MI=2.15s, REST=2.74s)\n",
" 003/S004 OFFLINE NOFES n= 86 (MI=42, REST=44) rej=4\n",
" 003/S005 ONLINE NOFES n= 43 (MI=19, REST=24) rej=17 acc=0.717 (MI=0.67, REST=0.77) lat(MI=2.58s, REST=2.53s)\n",
" 003/S006 ONLINE FES n= 52 (MI=23, REST=29) rej=8 acc=0.767 (MI=0.77, REST=0.77) lat(MI=2.31s, REST=2.56s)\n",
" 005/S001 OFFLINE FES n= 90 (MI=45, REST=45) rej=0\n",
" 005/S002 ONLINE FES n= 60 (MI=30, REST=30) rej=0 acc=0.800 (MI=0.67, REST=0.93) lat(MI=2.77s, REST=2.13s)\n",
" 005/S003 ONLINE NOFES n= 59 (MI=29, REST=30) rej=1 acc=0.933 (MI=0.97, REST=0.90) lat(MI=1.86s, REST=2.23s)\n",
" 005/S004 OFFLINE NOFES n= 89 (MI=44, REST=45) rej=1\n",
" 005/S005 ONLINE NOFES n= 58 (MI=28, REST=30) rej=2 acc=0.783 (MI=0.67, REST=0.90) lat(MI=2.12s, REST=2.73s)\n",
" 005/S006 ONLINE FES n= 59 (MI=30, REST=29) rej=1 acc=0.917 (MI=0.90, REST=0.93) lat(MI=2.31s, REST=2.27s)\n",
" 009/S001 OFFLINE FES n= 57 (MI=33, REST=24) rej=33\n",
" 009/S002 ONLINE FES n= 42 (MI=21, REST=21) rej=18 acc=0.717 (MI=0.80, REST=0.63) lat(MI=2.56s, REST=2.09s)\n",
" 009/S003 ONLINE NOFES n= 1 (MI=1, REST=0) rej=59 acc=0.717 (MI=0.63, REST=0.80) lat(MI=2.36s, REST=2.47s)\n",
" 009/S004 OFFLINE NOFES n= 86 (MI=42, REST=44) rej=4\n",
" 009/S005 ONLINE NOFES n= 60 (MI=30, REST=30) rej=0 acc=0.850 (MI=0.80, REST=0.90) lat(MI=2.22s, REST=2.00s)\n",
" 009/S006 ONLINE FES n= 50 (MI=26, REST=24) rej=10 acc=0.817 (MI=0.83, REST=0.80) lat(MI=2.64s, REST=1.69s)\n",
"\n",
"Loaded 4 subject(s): ['002', '003', '005', '009'] | total artifact-rejected epochs: 184\n"
]
}
],
"source": [
"xdf_files = sorted(glob.glob(os.path.join(DATA_DIR, '*.xdf')))\n",
"print(f'Found {len(xdf_files)} XDF file(s).')\n",
"print(f'Preprocessing: notch {NOTCH_FREQ:.0f} Hz → '\n",
"      f'{\"CAR → \" if USE_CAR else \"\"}bandpass {BP_LO:.0f}–{BP_HI:.0f} Hz → '\n",
" f'baseline-correct → PTP-reject @ {PTP_REJECT_UV:.0f} µV\\n')\n",
"\n",
"sessions = {}\n",
"total_rej = 0\n",
"\n",
"for fp in xdf_files:\n",
" meta = parse_session(fp)\n",
" if meta is None:\n",
" print(f' SKIP (unparsed): {os.path.basename(fp)}')\n",
" continue\n",
" subj, ses_id, kind, stim = meta\n",
" try:\n",
" X, y, ch_names, sfreq, n_rej, stats = load_session_epochs(fp)\n",
" except Exception as e:\n",
" print(f' ERROR {os.path.basename(fp)}: {e}')\n",
" continue\n",
"\n",
" sessions.setdefault(subj, {})[ses_id] = dict(\n",
" X=X, y=y, kind=kind, stim=stim,\n",
" ch_names=ch_names, sfreq=sfreq,\n",
" stats=stats, file=os.path.basename(fp))\n",
" total_rej += n_rej\n",
"\n",
"    if stats is not None:\n",
"        # Latency is None when a class had no successful trials, so format every field defensively\n",
"        _f = lambda v: '--' if v is None else f'{v:.2f}'\n",
"        info = (f' acc={stats[\"mk_acc\"]:.3f} '\n",
"                f'(MI={_f(stats[\"mi_acc\"])}, REST={_f(stats[\"rest_acc\"])}) '\n",
"                f'lat(MI={_f(stats[\"mi_latency\"])}s, REST={_f(stats[\"rest_latency\"])}s)')\n",
"    else:\n",
"        info = ''\n",
"    print(f'  {subj}/{ses_id} {kind:<7} {stim:<5} '\n",
"          f'n={len(y):3d} (MI={int(y.sum())}, REST={int((1-y).sum())}) '\n",
"          f'rej={n_rej}{info}')\n",
"\n",
"subjects = sorted(sessions.keys())\n",
"print(f'\\nLoaded {len(subjects)} subject(s): {subjects} | '\n",
" f'total artifact-rejected epochs: {total_rej}')"
]
},
{
"cell_type": "markdown",
"id": "7b8c8bea",
"metadata": {},
"source": [
"## Verify Session Layout"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "611baf23",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:24.182586Z",
"iopub.status.busy": "2026-04-22T19:27:24.182447Z",
"iopub.status.idle": "2026-04-22T19:27:24.188463Z",
"shell.execute_reply": "2026-04-22T19:27:24.188025Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Channels (32): ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5', 'FC1', 'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2', 'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1', 'Oz', 'O2']\n",
"Sampling rate: 512.0 Hz\n",
"Motor channels ['C3', 'Cz', 'C4'] → indices [14, 15, 16]\n"
]
}
],
"source": [
"# Verify channel layout is consistent across sessions, locate motor channels\n",
"ref_subj = subjects[0]\n",
"ref_ses = next(iter(sessions[ref_subj].values()))\n",
"channel_names_global = ref_ses['ch_names']\n",
"sfreq_global = ref_ses['sfreq']\n",
"\n",
"mismatches = [f'{subj}/{sid}' for subj in subjects for sid, s in sessions[subj].items()\n",
" if s['ch_names'] != channel_names_global]\n",
"if mismatches:\n",
" print('!! channel mismatch in:', mismatches)\n",
"\n",
"motor_idx_global = [channel_names_global.index(c) for c in MOTOR_CH\n",
" if c in channel_names_global]\n",
"\n",
"print(f'Channels ({len(channel_names_global)}): {channel_names_global}')\n",
"print(f'Sampling rate: {sfreq_global} Hz')\n",
"print(f'Motor channels {MOTOR_CH} → indices {motor_idx_global}')"
]
},
{
"cell_type": "markdown",
"id": "70922abb",
"metadata": {},
"source": [
"## Train CSP + LDA on OFFLINE, Evaluate on ONLINE"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f5e80da3",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:24.190995Z",
"iopub.status.busy": "2026-04-22T19:27:24.190909Z",
"iopub.status.idle": "2026-04-22T19:27:26.129399Z",
"shell.execute_reply": "2026-04-22T19:27:26.128933Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[009] Pair1 (train=OFFLINE_FES) / NOFES: only 1 clean epochs — acc=0.717, skipping EEG-derived metrics\n",
"\n",
"Subj Pair Cond n trainAcc mkAcc MIacc REacc MIlat RElat |marg| Fisher muSNR\n",
"----------------------------------------------------------------------------------------------------------------------\n",
"002 Pair1 (train=OFFLINE_FES) FES 53 0.659 0.883 0.87 0.90 2.74 2.46 0.778 0.288 1.197\n",
"002 Pair1 (train=OFFLINE_FES) NOFES 52 0.659 0.833 0.87 0.80 2.26 2.29 1.162 0.073 1.452\n",
"002 Pair2 (train=OFFLINE_NOFES) FES 56 0.811 0.917 1.00 0.83 2.26 1.76 1.012 3.133 1.665\n",
"002 Pair2 (train=OFFLINE_NOFES) NOFES 60 0.811 0.850 0.83 0.87 2.33 2.15 0.756 0.695 1.576\n",
"003 Pair1 (train=OFFLINE_FES) FES 59 0.843 0.750 0.70 0.80 2.59 2.12 0.879 0.017 1.347\n",
"003 Pair1 (train=OFFLINE_FES) NOFES 38 0.843 0.763 0.76 0.76 2.15 2.74 0.910 0.683 1.219\n",
"003 Pair2 (train=OFFLINE_NOFES) FES 52 0.907 0.767 0.77 0.77 2.31 2.56 1.166 0.000 1.273\n",
"003 Pair2 (train=OFFLINE_NOFES) NOFES 43 0.907 0.717 0.67 0.77 2.58 2.53 1.071 0.029 1.296\n",
"005 Pair1 (train=OFFLINE_FES) FES 60 0.978 0.800 0.67 0.93 2.77 2.13 3.346 2.498 2.695\n",
"005 Pair1 (train=OFFLINE_FES) NOFES 59 0.978 0.933 0.97 0.90 1.86 2.23 3.620 3.029 2.399\n",
"005 Pair2 (train=OFFLINE_NOFES) FES 59 1.000 0.917 0.90 0.93 2.31 2.27 5.740 6.291 2.911\n",
"005 Pair2 (train=OFFLINE_NOFES) NOFES 58 1.000 0.783 0.67 0.90 2.12 2.73 5.261 4.325 2.287\n",
"009 Pair1 (train=OFFLINE_FES) FES 42 0.667 0.717 0.80 0.63 2.56 2.09 0.304 0.040 1.326\n",
"009 Pair1 (train=OFFLINE_FES) NOFES 1 0.667 0.717 0.63 0.80 2.36 2.47 -- -- --\n",
"009 Pair2 (train=OFFLINE_NOFES) FES 50 1.000 0.817 0.83 0.80 2.64 1.69 4.673 6.192 1.594\n",
"009 Pair2 (train=OFFLINE_NOFES) NOFES 60 1.000 0.850 0.80 0.90 2.22 2.00 3.754 8.959 2.000\n"
]
}
],
"source": [
"MIN_TEST_TRIALS = 10\n",
"\n",
"results = []\n",
"\n",
"for subj in subjects:\n",
" subj_ses = sessions[subj]\n",
"\n",
" for pair in PAIRS:\n",
" needed = (pair['train'], pair['online_fes'], pair['online_nofes'])\n",
" missing = [k for k in needed if k not in subj_ses]\n",
" if missing:\n",
" print(f'[{subj}] {pair[\"name\"]}: missing {missing} — skipping')\n",
" continue\n",
"\n",
" train = subj_ses[pair['train']]\n",
" if set(np.unique(train['y'])) != {0, 1}:\n",
" print(f'[{subj}] {pair[\"name\"]}: training set lacks both classes — skipping')\n",
" continue\n",
"\n",
" clf = CSPLDA(n_csp=N_CSP).fit(train['X'], train['y'])\n",
"        # Resubstitution accuracy on the training session itself (optimistically biased)\n",
"        train_acc = (clf.predict(train['X']) == train['y']).mean()\n",
"\n",
" for cond_key, cond_label in [('online_fes', 'FES'), ('online_nofes', 'NOFES')]:\n",
" te = subj_ses[pair[cond_key]]\n",
" st = te['stats']\n",
" if st is None:\n",
" print(f'[{subj}] {pair[\"name\"]} / {cond_label}: no EARLYSTOP markers — skipping')\n",
" continue\n",
"\n",
" row = dict(\n",
" subject=subj, pair=pair['name'], condition=cond_label,\n",
" train_file=train['file'], test_file=te['file'],\n",
" train_acc=train_acc, n_test=len(te['y']),\n",
" acc = st['mk_acc'],\n",
" mi_acc = st['mi_acc'],\n",
" rest_acc = st['rest_acc'],\n",
" mi_latency = st['mi_latency'],\n",
" rest_latency = st['rest_latency'],\n",
" trials = st['trials'], # ordered per-trial records for trajectory analysis\n",
" )\n",
"\n",
" # EEG-based metrics require enough clean epochs of both classes\n",
" if len(te['y']) >= MIN_TEST_TRIALS and set(np.unique(te['y'])) == {0, 1}:\n",
" res = evaluate(clf, te['X'], te['y'])\n",
" snr_s = spectral_snr(te['X'], te['y'], motor_idx_global, te['sfreq'])\n",
" row.update(amp=res['amp'], fisher=res['fisher'], mu_snr=snr_s,\n",
" margin=res['margin'], y_test=res['y'], pred=res['pred'])\n",
" else:\n",
" print(f'[{subj}] {pair[\"name\"]} / {cond_label}: only {len(te[\"y\"])} clean epochs — '\n",
" f'acc={st[\"mk_acc\"]:.3f}, skipping EEG-derived metrics')\n",
" row.update(amp=np.nan, fisher=np.nan, mu_snr=np.nan,\n",
" margin=np.array([]), y_test=np.array([]), pred=np.array([]))\n",
"\n",
" results.append(row)\n",
"\n",
"hdr = (f'{\"Subj\":<5} {\"Pair\":<28} {\"Cond\":<6} {\"n\":>4} '\n",
"       f'{\"trainAcc\":>9} {\"mkAcc\":>7} {\"MIacc\":>6} {\"REacc\":>6} '\n",
"       f'{\"MIlat\":>6} {\"RElat\":>6} {\"|marg|\":>8} {\"Fisher\":>8} {\"muSNR\":>7}')\n",
"print('\\n' + hdr)\n",
"print('-' * len(hdr))\n",
"\n",
"def fmt(v, spec):\n",
"    \"\"\"Format v with `spec` (e.g. '>8.3f'); print a right-aligned '--' for None or NaN.\"\"\"\n",
"    if v is None or (isinstance(v, float) and np.isnan(v)):\n",
"        width = int(spec.split('.')[0].lstrip('<>^'))\n",
"        return f'{\"--\":>{width}}'\n",
"    return f'{v:{spec}}'\n",
"\n",
"for r in results:\n",
"    print(f'{r[\"subject\"]:<5} {r[\"pair\"]:<28} {r[\"condition\"]:<6} {r[\"n_test\"]:>4} '\n",
"          f'{r[\"train_acc\"]:>9.3f} {r[\"acc\"]:>7.3f} '\n",
"          f'{fmt(r[\"mi_acc\"], \">6.2f\")} {fmt(r[\"rest_acc\"], \">6.2f\")} '\n",
"          f'{fmt(r[\"mi_latency\"], \">6.2f\")} {fmt(r[\"rest_latency\"], \">6.2f\")} '\n",
"          f'{fmt(r[\"amp\"], \">8.3f\")} {fmt(r[\"fisher\"], \">8.3f\")} {fmt(r[\"mu_snr\"], \">7.3f\")}')"
]
},
{
"cell_type": "markdown",
"id": "2ab81600",
"metadata": {},
"source": [
"---\n",
"## Figure 1 — Per-metric comparison (FES vs NOFES)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "d53e63b9",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:26.130998Z",
"iopub.status.busy": "2026-04-22T19:27:26.130913Z",
"iopub.status.idle": "2026-04-22T19:27:26.627725Z",
"shell.execute_reply": "2026-04-22T19:27:26.627293Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABoAAAAQ8CAYAAACyzFyVAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3Qm4XdPdOOB1k9yEEBEpSUhiaMxTKapUS4uWqmgNHSg1fSqltFqtKpqWKuqrVumglJYPNbTpoDM6qKKTqVXSIoYQQhNBkpvk/J/f9t+3+5577jyc4b7v8xw3zj7DOmuvvfZe67fXWk2lUqmUAAAAAAAAaBjDqp0AAAAAAAAA+pcAEAAAAAAAQIMRAAIAAAAAAGgwAkAAAAAAAAANRgAIAAAAAACgwQgAAQAAAAAANBgBIAAAAAAAgAYjAAQAAAAAANBgBIAAAAAAAAAajAAQAAAAAABAgxEAAgAAAAAAaDACQABVdtNNN6UPfOADaeONN05jx45No0aNSpMmTUq77rprOvPMM9O8efMG5HsfeeSR1NTU1PqINBTF/xe3x+vr3Wc+85k2v+nWW29NjWi99dZr/Y3x76LLL7+8TR7E/zeK8jLb2aP8d8fx1t33Vio3Tz75ZPrkJz+Ztt122+w4bm5uTq961avSRhttlPbYY4900kknpWuvvTY1yrETjxtvvLHd6yJvOqtXihYuXJguuOCC9La3vS1Nnjw5rbzyymnMmDHp1a9+dXrPe96TrrnmmrR8+fJulfOuHuX1V3ffF49K/vznP6cjjjgi27+rrLJKWmmllbJ6e4sttkjvfOc70xlnnJFuu+221FOlUilddNFF6XWve11abbXV0rBhw1rT8YMf/CDVqvLjp7MyEWWpVlW7foxyGmUpvjuuBR5//PFB/f5G09f92dm5lIHz4IMPpkMOOSRNmTIljRw5snUfvOY1r2l9zZIlS9LnP//57Lmog4v7+W9/+9uQvMathnq9rh5qx/ZAlffO9v9LL72U1lxzzdZtt99+e798JwC9M6KX7wOgjx577LGsk/MPf/hDu21PPfVU9vjNb36TNXDPPffcdNxxx8lzqEG//e1v0zve8Y4soFE0f/787PHQQw+lX/3qV2nChAnp3e9+d2oUp59+etpvv/2yIEVPXX/99el//ud/0vPPP99u26JFi9K///3vLGD22c9+Nn3ve9/LAiu14n//93/Txz/+8bRixYqK9fb999+fBWv++c9/pp133rnHeRqBf4amk08+OevYDocddlgWGIU8yBrXhMVgcSOKm55e//rXp+eee67T1x111FHpyiuvTENRBCweffTR7N/rrruu4BU1afTo0emjH/1o+tSnPpX9/0c+8pEsCNTRjTUADCwBIIAqeOKJJ7I7vOfOnfvfCnnEiOy51VdfPd19992td/6+/PLL6fjjj0/PPvvsoN45vf3222cdsbm4w5LG6DjYf//92/x/o9p0003TZpttVnFbV797u+22yzpWKok7GnMvvvhiFtQpBn823HDDNG3atCwwEsfx3//+99TS0pIaTQQ6rr766nTwwQf36H2XXXZZOvLII9s8F6Olos5ZvHhx1kEQf8M//vGPrDMwAuVbbrllp5/7xje+sc2+Keqq/tprr72yzoqu/OlPf0of+9jHWjtfYx/HqK8Y/RMd9//617+y4FVvO2cvueSSNv//pje9KcubsM466/TqM6kPd955Z7ruuuuyf0cHWQSDqK699967dRT2WmutZXcMghhZWgz+RN26ww47ZNfI66+/fut5N849udi22267ZaMmQ1xHd8U1bv+Ia6ziNWVH52CGprh58ayzzsqO2TvuuCO7+efAAw+sdrIAhiQBIIAqeP/7398m+LPJJpukH/3oR1mncYg7y7/whS+kU089tfU1cSd8NHCjQ3AwfOhDH8oeNN5dxPEYCg466KBeB02j7Hc2fVnuZz/7WTbqI/fFL34xm+6tKBq+v/
jFL9KPf/zj1GgifyMAFh1w3RGjoY499th2eR2jamKqn/D0009nHUr5FGoRiD7ggAOygFNn3zNz5sxel+2LL764W8HQK664ojW4E5300aERwcKiKA+zZs3KRn/1VPz23I477lg30+nQd1/+8pdb/73TTju1Xg9QPVEvNKK77roruys/Ao4dBUvmzJmTDj300Cxgv8EGGwxa2op1YPj2t7+d3vrWt7Z5Lm6IKk4PGueHYkAodDXFlWvc/rvOigdUEtP6xkjxq666qvU8JwAEUB3WAAIYZL/73e/SLbfc0vr/w4cPTzfccEObzp64qzwa5+9973tbn4tOxwgCdTWn809/+tNszZFo1Mcd7XHn5Pe///1+nS+60tzqMZVT3LEc63fk6xjFFE+ddYL+/Oc/zzqPo+M11v9YddVVs6me4nNiTZXeiLvw426zWFMp1lKIu+aPOeaYdp0KHXnhhReyzuj8zvvolB4/fnx6y1veknVELFu2rMP3xt3/MTrgta99bRo3blz23okTJ2bTQMXUTpVGgUTgLzq7p06dmqU38iDSfvTRR6e//vWvnXakR/mIuy0j77beeuuss6qrkQddrYlQPi96/N5YkyRGOUR5inK1zz77pHvuuafi50fw8utf/3o2L3+kK/LwXe96V7r33nu7/O7y7bW8VkhxPxS9+c1vrjj6JNaFufTSS7v9uTH1YzEv4rguF+Wj+JoZM2a0botRKDFyMI6naIDHmkRRVjbffPNs6skvfelL6T//+U/qq9mzZ/doXY0IbC9durRNR/eFF17YGvwJMVVe1Fn53dz5mhDlHXzV3t+x1lMc6+XimI86J5/2pDvy467oj3/8Y4fr6vSlnooRadH5GWUhykbU1/H9MeXYX/7ylw7fFzctRJ2+9tprZ3VV3LgQdW1xf3ZXBMkij2KKs/j+OP/FukmxZkC5mE4vgoYxEixG5eXlOX73G97whuxY6awsx7YIzMYNFDGKI/JqjTXWSFtttVV2jHR3LYQYlbbnnnu22SeRrv6YiuuZZ57J7ozOve9972v3mkrn3RgpEdPqxP6LfIzzyAknnNDp9Fm9Oe8WvzeCrDHiMaZBjPN95Gd3A69RbqPMRHAz9kHsxzinxP5/+9vfngVxY+rEzr67r2t5xDEc+Rv1TJTj+O1RH1Y6Zrrz2b09nqLcxE0BsS8iwBLniTjH5qN0I4hcvBYrTv9Wni89Gcn7wAMPZNeIMS3p7rvvXnEazpheLOqW+M74G+WzN3pyfZOf/8vP+7FGXPGaoNLvjfXiepoX9XiNm6+1lk//FuLfHR0j5eU3vz6OUTuRnjyvot49++yzs0BOpC/OYfEb8/IY13AdtSO6WgOoPG0xq0GehigTUY/HtWxn9XCcC2MfxOjqOE4i7bH+XlzzlF+D9cd1clfKf1PUt3EuifZGfE9c/373u99tff1PfvKTtMsuu2T7P+q7mDL4vvvu6/DzYw2rKFtxjo33RD7FelixH/J6oZIFCxZk9XJ+Poi/8f9R73ZHHOdRB0f9nLdhop6Ma/5Kaz52V/F8Fjf2dNR+AGCAlQAYVB/5yEei5dH6eOtb39rha//whz+0ee3w4cNL//nPf1q3H3bYYW22v+9972vz/8XH1Vdf3eazH3744Tbb47OKyj87Xt/Re9/whjeUpk6dWvF7t95669KSJUvafHb8/0EHHdRhWuOx+uqrl371q1/1KG8XL15cetOb3lTx8yZOnNguf2655ZY27//rX/9aWnfddTtN1y677FJ6/vnn2333V77yldLIkSM7fW/xfS+99FLpHe94R6evHzZsWGnmzJntvuvOO+8srbbaahXfc+CBB5YmT57c+v/xe4q+/e1vt3l9/H9R8fdPmDCh9Ja3vKXi94wZM6b00EMPtUvb+9///oqvHzVqVOmQQw7p9LvL03bGGWf0YO+3L7M9eX95uSlPW0fOP//8Nu/bdtttS9dff33p2WefLfXFE088kR3v+edG3pU76aST2n
z3n//85+z5v//976WxY8d2Wrbicdddd3U7PZGX5cdB/u8pU6Zkx16IY6qjemXFihWl8ePHd1ovFf3P//xPm9fuv//
"text/plain": [
"<Figure size 1680x1080 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: fes_vs_nofes_metrics.png\n"
]
}
],
"source": [
"METRICS = [\n",
"    ('acc', 'Classification accuracy', '0–1'),\n",
" ('amp', 'Classification amplitude (mean |decision fn|)', 'a.u.'),\n",
" ('fisher', 'Fisher ratio on LDA projection (test-set SNR)', 'a.u.'),\n",
" ('mu_snr', 'μ-band power ratio REST / MI @ C3/Cz/C4', 'ratio'),\n",
"]\n",
"\n",
"cond_color = {'FES': '#E05C2A', 'NOFES': '#2A7BE0'}\n",
"\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 9))\n",
"fig.suptitle('Online decoding: FES vs NOFES feedback (per subject × offline-trained model)',\n",
" fontsize=13, fontweight='bold', y=1.00)\n",
"\n",
"for ax, (key, title, unit) in zip(axes.ravel(), METRICS):\n",
" labels, vals, colors = [], [], []\n",
" for subj in subjects:\n",
" for pair in PAIRS:\n",
" tag = pair['name'].split()[0] # \"Pair1\" / \"Pair2\"\n",
" for cond in ('FES', 'NOFES'):\n",
" row = next((r for r in results\n",
" if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']==cond), None)\n",
" if row is None: continue\n",
" labels.append(f'{subj}\\n{tag}\\n{cond}')\n",
" vals.append(row[key])\n",
" colors.append(cond_color[cond])\n",
" x = np.arange(len(vals))\n",
" ax.bar(x, vals, color=colors, edgecolor='white', zorder=2)\n",
" ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=7.5)\n",
" ax.set_title(f'{title} ({unit})', fontsize=11, fontweight='bold')\n",
" ax.grid(axis='y', alpha=0.3)\n",
" ax.spines[['top','right']].set_visible(False)\n",
" if key == 'acc':\n",
" ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)\n",
"\n",
"fig.legend(handles=[Patch(color=cond_color['FES'], label='ONLINE_FES'),\n",
" Patch(color=cond_color['NOFES'], label='ONLINE_NOFES')],\n",
" loc='upper right', ncol=2, bbox_to_anchor=(0.98, 1.0))\n",
"plt.tight_layout()\n",
"plt.savefig('fes_vs_nofes_metrics.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_vs_nofes_metrics.png')"
]
},
{
"cell_type": "markdown",
"id": "248740bd",
"metadata": {},
"source": [
"---\n",
"## Figure 2 — LDA decision-function distributions\n",
"\n",
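"Visualizes classification amplitude and separability directly. Within each condition (FES or NOFES), a wider gap between the MI and REST histograms corresponds to a higher Fisher ratio, and histograms that sit farther from the zero decision boundary correspond to a larger mean |margin|."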
]
},
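{
"cell_type": "markdown",
"id": "3f1c9a2e",
"metadata": {},
"source": [
"The two quantities read off these histograms can be written compactly. This is a minimal sketch, not the implementation used earlier in this notebook: it assumes `margin` holds LDA decision-function values and `y` uses 1 = MI and 0 = REST (matching `row['margin']` and `row['y_test']` in the plotting code), and it uses one common Fisher-ratio convention, $(\\mu_{MI} - \\mu_{REST})^2 / (\\sigma^2_{MI} + \\sigma^2_{REST})$, which may differ in detail (e.g. variance normalization) from the notebook's own metric. The function names and toy data are hypothetical.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def fisher_ratio(margin, y):\n",
"    # Between-class separation over within-class spread on the 1-D LDA projection\n",
"    margin, y = np.asarray(margin, dtype=float), np.asarray(y)\n",
"    m_mi, m_rest = margin[y == 1], margin[y == 0]\n",
"    return float((m_mi.mean() - m_rest.mean()) ** 2 / (m_mi.var(ddof=1) + m_rest.var(ddof=1)))\n",
"\n",
"def mean_abs_margin(margin):\n",
"    # Classification amplitude: average distance of trials from the decision boundary (0)\n",
"    return float(np.mean(np.abs(margin)))\n",
"\n",
"# Hypothetical toy data: three MI trials and three REST trials\n",
"margin = np.array([1.2, 0.8, 1.5, -0.9, -1.1, -0.4])\n",
"y = np.array([1, 1, 1, 0, 0, 0])\n",
"print(fisher_ratio(margin, y), mean_abs_margin(margin))\n",
"```\n",
"\n",
"Both numbers grow as the MI and REST histograms move apart, but only the Fisher ratio penalizes overlap: a classifier can have a large mean |margin| while still placing many trials on the wrong side of zero."
]
},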
{
"cell_type": "code",
"execution_count": 8,
"id": "393042a0",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:26.629009Z",
"iopub.status.busy": "2026-04-22T19:27:26.628916Z",
"iopub.status.idle": "2026-04-22T19:27:27.543034Z",
"shell.execute_reply": "2026-04-22T19:27:27.542546Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/numpy/lib/_histograms_impl.py:897: RuntimeWarning: invalid value encountered in divide\n",
" return n / db / n.sum(), bin_edges\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABZAAAAYTCAYAAABt5R22AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3QeUG9X1x/Grsn3tdW9gm2aH3jsBbHoILWBqANMhEEJC78YECB1CMSHgYFpooZMAgVBDNQnNhGaasXEva6+3Spr/+T0z+kta1fX2/X7O0VmtNDN6mhmNRnfuuy/geZ5nAAAAAAAAAACkCKY+AAAAAAAAAAAAAWQAAAAAAAAAQEZkIAMAAAAAAAAA0iKADAAAAAAAAABIiwAyAAAAAAAAACAtAsgAAAAAAAAAgLQIIAMAAAAAAAAA0iKADAAAAAAAAABIiwAyAAAAAAAAACAtAsgAAAAAAAAAgLQIIAMAAAAAAAAA0iKADABADoFAIH675JJL2n19jRkzJv76ut+V30uhvv76azv00ENt6NChFg6H422fMmWKdRevvPJK0nbR/11Rpn2rM7+/rtbmztouoCf79ttvkz6X3en7CQAAHwFkAMBK/1g66qijcs6jH1SJ8+hWXFxsvXv3ttVWW80FRs855xz79NNP82rD448/3mx5Z5xxBluzG6mrq7M999zTHnzwQZszZ45Fo1HrahSUTNxHsXK6c6CGfaXzef/99+3kk0+2DTfc0Pr06eO+swYOHGjbbbedXXzxxfbDDz/kFejX7YUXXsh6cVDfg9m+M/PZ1/VdnDiPPi+ZPju63XHHHTmXket9pbulvhcAAND1hTu6AQCAnqupqcndli1bZt999529+uqrdvXVV9sxxxxjN910k1VUVGScd/Lkyc0eu+++++zKK6+0oqIi605+9atf2V577eXuDx8+fKWWdc0118Tvb7vtttaZTZ061T7//PP4/1oH22+/vQWDQdtiiy2su1hzzTWTtov+70468/vrSp+Hzr4uu5P6+nr7zW9+kzbAumDBAnd788033ffV9ddf74LMuZx99tn23//+t1NdSJowYYL98pe/tPLy8o5uSpfWr1+/pM9ld/p+AgDARwAZANAhDj74YNt8881t6dKl9tFHH9mzzz5rjY2N7rm//OUvNn36dPvnP/9pJSUlzeZV1tdzzz3X7PF58+bZ008/bfvvv791t3XVWs4880zrKhKz5+TGG2/slgEzXRToStulq7+/SCTiLlyVlZV1qnZ1xXXZHXmeZ0cccYT97W9/iz82ePBgO+igg2zIkCHuotZDDz1kDQ0N7nbKKae4efQ3mw8++MBd5NSyO4vZs2fbddddZxdddFGL5t91111tt912a/Z4VVWV9STqScXnEgDQ3VHCAgDQIfbYYw/3g+vSSy+1J554wr788kvbZJNN4s+/9tprdvnll6edV115/XIGqou7xhprxJ9T8Lklqqur7fTTT3cBmtLSUvvJT35iV111lQs05aKyG8oSXnvttV3WtAJTmv93v/udzZo1K+N8f//7323cuHE2cuRI95r6EaplHHvssfbVV1/lVQNZwQxNP2rUKPe66mKtesHKgFKb/vWvfxVUA3natGl2wgkn2OjRo11Wmpa51lprudf48MMPc3a7V0DlD3/4g3sfCv4r4HLSSSe5CwX58rtajx8/PulxtSOxa3auerDZ1lvqelBm4L777mt9+/Z173nLLbd02yedWCxmf/3rX11G9LBhw9z71HwbbLCBCyIpO9Fv28SJEzO+rl/6JZ+6tv/4xz/chZFVVlklXvpl4403tvPPP9/mzp3bbPrU965ptD9ofrVX+4sy5hT4auvPSbb3p3V522232Q477GADBgxwn2cFn7St99lnH7vsssts+fLlblp1i1999dWTln300Uen7W6vaRPX88cff+y2b//+/V0PhXfeeafgmuAKKG699dbuM66MQ312EzPk8ymxkV
gewO/m31X3lUK2Xb46w37+6KOPJgWP119/fXeMV68YtePuu++2t99+O6mHjL7Lsh3rfQrU6hjZmWj96OJrSyhrX+899Xb88cfnvYzFixfbBRdc4Laztrf2I+1P6623nh1++OF25513NptHxxs9vssuu7iSItpXNI+C2Y888kja13n33XftkEMOiX/f6rbqqqu6ciS//e1v7b333mvx9PmU1lHGurK9dQzTsrT/rLvuui7TXbX+U6UeK/QdqjJfml/79ogRI+zcc8+NX3hPpLJPWhe68KHjXa9evdwydN6lfVAloQAAKJgHAECBvvnmG/0aj9/Gjx+fc5677roraR79n2rGjBleSUlJfJrevXt7jY2NSdPEYjFvzTXXjE+zxx57eFdccUX8/1Ao5M2aNaug97N06VJvww03TGqff9t7772T/p8wYULSvHfeeadXXFycdl7d+vbt6/373/9Omkfvady4cRnn0e3xxx+PT7/jjjvGH9d932effeZVVlZmXU7qtsn2Xm6//XavqKgo47LC4bA3adKkpHm0jMRptt9++7Tzjh07tsX7V7qbpnn55ZeTHtP/iTKtt9T1sNVWW6XdhsFg0HvppZeS5luyZEnG9+jf3n///WZty7Ztsr2PaDTqHXXUUVmXM2DAAO/NN9/M+N7XWGMNb9iwYWnnveSSS9r8c5Lt/R1//PF5bWsZOXJkzml9idNusskmXkVFRdo25Nvm1Pfn3/r06eN9+OGHGffd1OOctrn/nNqY7rW6yr5SyLbLpTPt5zvttFPSvM8++2za6c4555yk6S699NL4c6nbaejQofH7V199ddr2+/tDId+ZqRL3r9T1n7pvJrbp5JNPzriMRKnvK/U7pFD19fXe+uuvn3W7p66XhQsXeltssUXWeQ499FC3T/leeeUV9/2VbZ7E91Lo9Lk+9xdddJEXCAQyLkvHpyeeeCLjtuzfv7+37rrrpp1Xn5tEl19+ec7PZep3JQAA+aCEBQCg01BWozJknnzySfe/Mm6U5bPNNtvEp1Gd5MTsXGX0/PSnP3UZTIoHKTNZGWLnnXde3q+rwZBURsO30UYbuQw6vc4DDzyQcT5lMSpbV5l4ogxUZTmqHcoA0vzKrvrFL37hMqz9br1nnXVWUoabshnVPVqZw5pHZTjycdddd1lNTY27rwGelI2pLCxl4Wk5yuLOl7KjlLnnvxctR12tQ6GQ3XPPPS5DTV3/lWGr96l1ns7rr7/u3q8yq+6///54GYqXX37Zra+tttoq73qS2vbqKu5T9p8yff1pUktctJTapcwy7Uvff/+9yy4WrQvVOB07dmx8WmVF6z0m7rP77befa89nn30W33Z+rVqVYUkcPCuxTqYyG3PR9InZbJpH+5i2sfZzZeIp41mPJe5jiZTdpow3bV9lVytrVAMUiuq3ar3mUze8pZ+TTLTvJvYY2Gmnndy6VobmzJkzXQ3sTz75JP68PuPa5ldccUWzUji5BkLTfqztq4xplcfJVl89HW1XZdoqy1XZ6s8884x7fMmSJS5TUI+1VFfcVwrddp2p7dno++Pf//53/H8dV/WdlI6yU5V978t2vFXW/u9//3v3naYeGscdd1z8WNZRtN/qe+qbb76xP//5zy6rVhnbhdD3xrXXXps2MzmfmuL6XlCvF1Fte33n6DOq780ZM2YkbQvfkUce6fYv0fbWdlDWu/Y3fV/ouK3jkfYhbXPRvqDvL1F2ujKblZWrcljanxKP6S2ZPhu1Sdvep0xgHbdqa2vj3+HK1Nf70LpIV6Zp4cKFbp3ovavXi7Kv9XkQfT/rmKjzB1GmvE/HRn/8BH236fj9n//8J++2AwCQJK8wMwAA7ZCBLGeffXbSdA8//HDS84cffnj8ufLycm/ZsmXu8W233Tb++FprrZX39mpqavJ69eoVn3f06NEuK8qnrLJMWUcHHHBA/PGNNtrIa2hoSMqSKi0tjT9/ww03uMcXL16clOU7YsQIb968eUlt0n
uaO3duzkza0047Lf74iSee2Oy9KdP522+/TXos03vZf//9k7K4P/300/hzX3zxRVI21j777JMxA/m3v/1t/LkPPvg
"text/plain": [
"<Figure size 1440x1536 with 8 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: decision_margin_distributions.png\n"
]
}
],
"source": [
"fig, axes = plt.subplots(len(subjects), len(PAIRS),\n",
" figsize=(6 * len(PAIRS), 3.2 * len(subjects)),\n",
" sharex=True, squeeze=False)\n",
"fig.suptitle('LDA decision-function distributions on ONLINE sessions\\n'\n",
" '(class separation ↔ classification amplitude & SNR)',\n",
" fontsize=12, fontweight='bold', y=1.01)\n",
"\n",
"for i, subj in enumerate(subjects):\n",
" for j, pair in enumerate(PAIRS):\n",
" ax = axes[i][j]\n",
" for cond, color in cond_color.items():\n",
" row = next((r for r in results\n",
" if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']==cond), None)\n",
" if row is None: continue\n",
"            m_mi   = row['margin'][row['y_test'] == 1]\n",
"            m_rest = row['margin'][row['y_test'] == 0]\n",
"            # Skip empty classes: with density=True an empty histogram divides 0 by 0 (NaN + RuntimeWarning)\n",
"            if m_mi.size:\n",
"                ax.hist(m_mi, bins=15, alpha=0.5, color=color,\n",
"                        label=f'{cond} MI', density=True)\n",
"            if m_rest.size:\n",
"                ax.hist(m_rest, bins=15, alpha=0.25, color=color, hatch='///',\n",
"                        edgecolor=color, label=f'{cond} REST', density=True)\n",
" ax.axvline(0, color='k', lw=0.8)\n",
" ax.set_title(f'{subj} | {pair[\"name\"]}', fontsize=10, fontweight='bold')\n",
" if j == 0: ax.set_ylabel('Density')\n",
" if i == len(subjects) - 1: ax.set_xlabel('LDA decision function')\n",
" ax.spines[['top','right']].set_visible(False)\n",
" if i == 0 and j == 0:\n",
" ax.legend(fontsize=6.5, loc='upper left', ncol=2)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('decision_margin_distributions.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: decision_margin_distributions.png')"
]
},
{
"cell_type": "markdown",
"id": "fcb6d19d",
"metadata": {},
"source": [
"---\n",
"## Figure 3 — Paired Δ (FES − NOFES) per metric\n",
"\n",
"Within each (subject × pair), the FES and NOFES sessions are scored with the same offline-trained model. Positive bars mean FES > NOFES; negative bars mean NOFES > FES. Pairing removes the offline-model-quality confound, so the remaining difference reflects feedback type together with any session-order effect within the pair."
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "75df404b",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:27.545755Z",
"iopub.status.busy": "2026-04-22T19:27:27.545655Z",
"iopub.status.idle": "2026-04-22T19:27:27.826448Z",
"shell.execute_reply": "2026-04-22T19:27:27.826020Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB3kAAAIwCAYAAACLCdFHAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAtKxJREFUeJzs3Ql8FPX9//FPgIRLwBDlUA4RtIBcKlCEolBvqogQqFoRtSiIBwgWBYuIihYRRZGgRSwinggatVorBS9EwAuhLRWscigKIhIRAonu//H+/v6znd3sbjZhQ3aT1/PxmMcmszOzM7Ozc33m8/mmBQKBgAEAAAAAAAAAAAAAUkKV8p4BAAAAAAAAAAAAAED8CPICAAAAAAAAAAAAQAohyAsAAAAAAAAAAAAAKYQgLwAAAAAAAAAAAACkEIK8AAAAAAAAAAAAAJBCCPICAAAAAAAAAAAAQAohyAsAAAAAAAAAAAAAKYQgLwAAAAAAAAAAAACkEIK8AAAAAAAAAAAAAJBCCPICAAAAAAAAAAAAQAohyAsAAAAAAAAAAAAAKYQgLwAAAMrN3LlzLS0tLdh98cUXJZ6Gf/xbb7017vF69eoVHE9/V0SXXnppcBmPOuqo8p4dABEEAgE78cQT3e+0atWq9umnn6bMeirPfTiSg75z/3eobQL/Z/HixcH1cv7557NaAAAAkHAEeQEAABDT3/72t5AbuAsXLiwyTOPGjYPv16lTx3766aeoNzrVvfDCC8Wu9TfeeCNkHP2P5PX888+HfF/qxowZY8lIgaTweY3UKUjup0B5PONF2lYXLVpk55xzjvutZGRkuN9Js2bNrFu3bnbFFVfYww8/bMkkfJluvvnmYtdjtADf+vXr7YYbbrDOnTtbVlaWpaenu1f9r21E78cTPIrVJeL79eZ1xIgR1qZNG6tdu7ZVr17dGjVqZMcdd5xlZ2fb7bffbps3b7ZEevzxx+3DDz90fw8YMMCOPfZYqwjYh5f8YZxYXXgAPN7fRvjv8ueff7ZHH33UTjvtNGvQoIH7PdarV89atGhhv/rVr+zqq6+2Z5991pKZf18c6XecLNuf1rH2c6LznrfffvugzwMAAAAqtmrlPQMAAABIbj169HDZZV7g9s0333SBCI+yzr7++uvg/7t373YBiy5dugT7aRyPbrb27NnT/a1hpk6dGnyvfv36drBcddVVLugmTZs2tYroggsusHbt2rm/dRO/LM2ZM6dIv/nz59uf/vQnF0SozIYOHVpk/RQUFLjfigKGK1assAULFtiwYcMsWU2fPt0FP4888si4x1EwaeLEiXbnnXe6v/2+++47133wwQdu2uPHj7dJkyZZlSrl9xzy66+/bn379rX8/PyQ/t98843r/vWvf7mHXNq3b5+wfYb2qxMmTAj+f/3111sqScQ+3D9+9+7dEzZvKKqwsNAd91577bWQ/nl5ea5TQHjZsmWuGzRoEKswAUaPHm0XXXSR+1v7OQK9AAAASCSCvAAAAIhJGYcnnHCCrVq1qkjANtL/Xr9oQV4FHZXFJ8qOU1cefvvb31pFd9ZZZ7muNIEABSFr1qwZ1/BfffWVy/gOt23bNnvppZesf//+lsyGDx9uLVu2LNLfC5BHkpmZ6W7YR+Kf1t///veQAO/xxx9vZ555pgu6K8i5Zs0ae+eddyzZ7dmzx2655ZaIwfxoxo4da9OmTQv+X7duXffgQfPmzW3jxo329NNPu8CSAsB33HGHC676A37hlBFXmt9tPN+v5uHyyy8PBngVrBw4cKDLtt6/f79t2LDB3n33Xfv8888tkfT72LRpk/tbmZQnnXSSpZJE7MOV5Y3/0X5F+5dwsQLgRx99tHtwKRJ/4F0ZvP4A78knn+y6WrVq2fbt2+3jjz+25cuX83UcIO3XtL8TPTii9at9qPb12ufrQREAAAAgIQIAAABAMf7whz8EdOqoLi0tLbBjx47ge7/73e9c/1q1agXq1a
vn/j7nnHOC7+/ZsydQvXr14PjXXHNN8L2//OUvwf7qPv/8c9ff3y9S17x58+A0/P0nTpwYWLNmTeD8888PZGZmBmrUqBHo0qVL4OWXXy6yTKecckpwPP3tFz7NDz74INC3b9/AoYceGnOaxQlf3s8++ywwY8aMQPv27d06OvzwwwOXXnpp4Msvvywy7t133x3o169f4Nhjjw1kZWUFqlWrFjjkkEPcuNdff31g8+bNRcYZMmRIxHUm+t97T8N98sknbhnr16/v+i1dujTu5Zo8eXJwWhkZGYG2bdsG///Nb34TSDb6Tv3fQ7zL6l9n4eszGn033jgtW7YMFBYWFhlm//79gb/97W+BZBLpd1e1atXA2rVro65H7/crq1atCnnviCOOCHzxxRchn6H/1d8/3Pvvvx98X9Pzv6fttKy+39WrV8c1jvYv/uU8UGeffXbwM8eOHZvQfYZoXkeNGhU47rjjArVr13a/z2bNmgV++9vfBt56662I4zz11FOB008/PdCgQYPgfkbb+5lnnhn44x//GNi6dWuZ7sNlzpw5wX5VqlQJbNmypch8dujQIThM//79Q97717/+FRg+fHjgF7/4hTs2ab+tfafWRaRpJYK+F/+xMV7+/XT47ygW/zjhx7BodGz0xunVq1fEYX788cfA4sWL457/8N+ptoklS5YEevfu7bYddWeccUZgxYoVEcffvXt3YNq0aYEePXq443Z6errb9nQ8+vvf/x5zXUXq9NstyfYn3377beDWW28NdO7cOVC3bl03D0ceeWTgwgsvDKxcubLIPIdv9+vXrw9MmTIl0Lp1a/cbC/8+Bg4cGBz22muvjXvdAgAAAMUhyAsAAIBiKaDpv6H5/PPPB99r0qSJ63fqqae6gJ7+VjD0p59+cu/rZq9/3AULFpRZgEA3VmvWrFlkeAUJNB+lCfL+8pe/dDdt45lmccKXV+ss0vI1bdq0SCBCgd1Y60Q3x//5z3+WKsh7/PHHuwBQSQNj8vPPP7vgpT/Yct9994UEBqMFoGIpbhsI77zgULIFea+77rrgOPoOP/3000Aq8K+fxo0bB//u06dPXEHeyy+/POS9WbNmRfwc9fcPp/HKI8j74YcfhoyjbVjbdlnKz893wUfvM3NzcxO6z3jxxReL/K7Du/Hjx0d9YCNa51+fZRXkVeCvTp06wf56yCU82O4f75VXXgm+98gjj0TcZ/v3le+8804gkbx9nvalJQ30HswgrwKn3jgKevsD9qUV/jvVPkLHx/D1rocSwoO2emjhmGOOibmt+B9+KIsgrx5IadiwYdRhdQxTAN8vfLvv2bNnzO9j+vTpwfdatWp1wOscAAAA8FCuGQAAAMVSG7rh7fL269fP/vvf/9qWLVtcv1NOOcWV9/3rX/9q33//va1evdqVpg0v56zSkMVRydbPPvvMHnrooYglV6O1L6vPatKkif3ud79zbZ0++eSTwVKsd999t/Xu3bvE37baS030ND3/+Mc/XPuIKoe9dOnSYFt9+pxrr73WFi1aFBxW89CrVy9X6lalPNW2sdb9s88+68r+7ty505XHffnll0s8Hx999JH7frWMv/jFL1xp2tq1a8c1rta5vivP4MGDrVu3bq4Eq7YXdY899piNGzfOktUzzzxj77//fpH+Kg0cre1VleO85557ivTXtnnFFVcE/9d369mxY4dbvyrVqdLD+n3ot9WxY0dLZmeccYbbJtRO5yuvvGJvvPGG2xZjCf/dX3jhhRGHU/lmf5nZt956K+o0//nPf0Zc5yq7HKsseTzfb+vWrd3+a+/evcG2cdWetMon6/vRNu3t4xJFJfD97f927do1YfsMlZXW8nnLo3Kxl156qds+tT607xa1l6xyy16boQ888EDws7SNeu2Wa/qffPKJa0M5HqXdh3u0/9G2MXv27GD73n/4wx+C7z/xxBPBv1VSWyXQvf31lVdeGWwDWr+18847Tw+3u/LgmiftK88//3xbv359wt
oqV/u1M2fOdPvS0047zRYvXlzqNua1zJHKNWu5vBLA4fT9RPptaPv2lzjXdvPiiy+6vz/99FP3vvZDJ554ontPx2f
"text/plain": [
"<Figure size 1920x540 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: fes_minus_nofes_delta.png\n"
]
}
],
"source": [
"fig, axes = plt.subplots(1, 4, figsize=(16, 4.5))\n",
"fig.suptitle('Within-pair Δ = FES − NOFES (positive → FES better)',\n",
" fontsize=12, fontweight='bold', y=1.03)\n",
"\n",
"for ax, (key, title, _) in zip(axes, METRICS):\n",
" labels, deltas = [], []\n",
" for subj in subjects:\n",
" for pair in PAIRS:\n",
" fes = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='FES'), None)\n",
" nof = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='NOFES'), None)\n",
" if fes is None or nof is None: continue\n",
" deltas.append(fes[key] - nof[key])\n",
" labels.append(f'{subj}\\n{pair[\"name\"].split()[0]}')\n",
"\n",
" colors = ['#E05C2A' if d > 0 else '#2A7BE0' for d in deltas]\n",
" ax.bar(np.arange(len(deltas)), deltas, color=colors, edgecolor='white', zorder=2)\n",
" ax.axhline(0, color='k', lw=0.8)\n",
" ax.set_xticks(np.arange(len(deltas)))\n",
" ax.set_xticklabels(labels, fontsize=8)\n",
" ax.set_title(f'Δ {title.split(\"(\")[0].strip()}', fontsize=10, fontweight='bold')\n",
" ax.grid(axis='y', alpha=0.3)\n",
" ax.spines[['top','right']].set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('fes_minus_nofes_delta.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_minus_nofes_delta.png')"
]
},
{
"cell_type": "markdown",
"id": "8f9533f8",
"metadata": {},
"source": [
"---\n",
"## Figure 5 — Per-class accuracy and EARLYSTOP latency\n",
"\n",
"FES stimulation fires only on MI trials, so if it helps the subject or the system, the benefit should\n",
"show up specifically in **MI accuracy** (count_240 / count_200), not in REST accuracy.\n",
"A rise in MI accuracy with no change in REST accuracy would be a genuine sensitivity gain; an\n",
"equal-and-opposite change in MI and REST accuracy would instead point to a bias/threshold shift.\n",
"\n",
"**EARLYSTOP latency** (BEGIN → EARLYSTOP, in seconds) is a continuous readout of how quickly and\n",
"how confidently the live classifier committed to a decision: shorter latency means a\n",
"sharper MI signal that crossed the detection threshold earlier."
]
},
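{
"cell_type": "markdown",
"id": "9d4b7e10",
"metadata": {},
"source": [
"The sensitivity-versus-bias distinction above can be illustrated with a hypothetical simulation (Gaussian MI and REST decision values, not data from this study; `per_class_acc` is an illustrative helper, not part of this notebook). Shifting the threshold while both distributions stay fixed trades MI accuracy against REST accuracy; moving the MI distribution farther from REST raises MI accuracy without lowering REST accuracy.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"\n",
"def per_class_acc(mi, rest, thr):\n",
"    # MI accuracy: fraction of MI trials above the threshold (sensitivity)\n",
"    # REST accuracy: fraction of REST trials at or below it (specificity)\n",
"    return float(np.mean(mi > thr)), float(np.mean(rest <= thr))\n",
"\n",
"rest = rng.normal(-1.0, 1.0, 2000)\n",
"mi = rng.normal(1.0, 1.0, 2000)\n",
"mi_wide = rng.normal(2.0, 1.0, 2000)  # same spread, larger MI/REST separation\n",
"\n",
"# Bias/threshold shift: distributions fixed, one class gains while the other loses\n",
"for thr in (-0.5, 0.0, 0.5):\n",
"    a, b = per_class_acc(mi, rest, thr)\n",
"    print(f'threshold {thr:+.1f}: MI-acc={a:.2f}  REST-acc={b:.2f}')\n",
"\n",
"# Sensitivity gain: same threshold and REST trials, so REST accuracy is unchanged\n",
"a, b = per_class_acc(mi_wide, rest, 0.0)\n",
"print(f'wider separation: MI-acc={a:.2f}  REST-acc={b:.2f}')\n",
"```"
]
},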
{
"cell_type": "code",
"execution_count": 10,
"id": "086ef172",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:27.828111Z",
"iopub.status.busy": "2026-04-22T19:27:27.828016Z",
"iopub.status.idle": "2026-04-22T19:27:28.197602Z",
"shell.execute_reply": "2026-04-22T19:27:28.197189Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/stats/_axis_nan_policy.py:430: RuntimeWarning: Precision loss occurred in moment calculation due to catastrophic cancellation. This occurs when the data are nearly identical. Results may be unreliable.\n",
" return hypotest_fun_in(*args, **kwds)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABoAAAATYCAYAAAA/GQUBAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3Qe4FNX5x/F3d2+hI1xpSlGKINgrVkBFY++aWDC2GDsmaNRo7H81ajQ27IUQS0xMjFFjBzUWsAsKCkhTpNxLv313/s/v4Kyze/f2snf3fj/Ps3fvzmw5O3Nmdmbe854T8jzPMwAAAAAAAAAAAGSNcLoLAAAAAAAAAAAAgKZFAAgAAAAAAAAAACDLEAACAAAAAAAAAADIMgSAAAAAAAAAAAAAsgwBIAAAAAAAAAAAgCxDAAgAAAAAAAAAACDLEAACAAAAAAAAAADIMgSAAAAAAAAAAAAAsgwBIAAAAAAAAAAAgCxDAAgAAAAAAAAAACDLEAACAAAAAAAAAADIMgSAAACoo1AoVO2tc+fOtuWWW9qvfvUr+/TTT1mmAb/85S8TltX8+fPrvHz03OBr9V5Aa/HYY49V2Rd07NjRVq9enfL5r776asr9R/I2sdlmmyXMnzJlimWi0aNHp/y+4XDYunTpYsOHD7fTTjvN3nvvvQbtd5NvqXz00UfuM7bYYgu3btq1a2d9+vSxrbbayo488ki76qqr7H//+1+1+6u63rTOgGRTp061I444wnr27Gm5ubm28cYb29Zbb22nnnqq/ec//2nR7S+dgvu0pthWrr766ozbR6qMwTLrOwAAALQEAkAAADSBdevW2axZs+zBBx+0nXbayW655RaWaxYKXrzRxTUgWXFxsQsMpXLPPfewwMzM8zxbu3atffXVV/boo4/a7rvv3iz7zD/96U+2yy67uM/45ptv3LopKyuzH374wWbOnGn/+te/7Nprr7W77rqL9YImd++999qYMWPsueees+XLl1tlZaUVFhbajBkz3D7ioYceYqmjVQbYAABAdslJdwEAAMhUBx54oHXo0MFKSkps+vTp7gKPRKNRu+SSS1wr35/97GfpLmZGU4v9o48+Ov545513Tmt5gLq477777MILL0yYtnDhwmZv8d/aKTg+YMAAW7NmjX3wwQfu3nf55ZfbscceW+vFS3+/W5sPP/zQJkyY4IJNoqyjHXbYwWX/KAg0d+5cmzdvXnx+cB+jgH7QggUL3Pv5lMUxatSohOcowwPwLV261H7zm98k1L8999zTZQt/+eWX9u2337aphXXQQQfZsmXLmmxbUfZg8NigR48ejX5PAACAbEUACACARrTu9S9WqssntfT95JNP4vPvuOMOAkCNpIs6f//736mjyCjKBnz99ddt3333TQgKKTjclp177rnxbhyXLFliI0aMsJUrV7rHyo5QF3lnnnlmnfe7NXn88cfjF9/VKl4BJwWggpQJpOwMZWUEy6hbkLI11GWXT+Vmv4SaqFtBBRp9f/jDH1x3gz51FavuCdsKbbdN6bjjjnM3AAAA1I4u4AAAaAJdu3a1Cy64IGGasoKSqSW5soO2335795q8vDzbdNNNXcv3N954o85j6Dz55JO2xx57uHE06juujlq9n3/++W4MDLVG1rgECrTooubPf/5zu/32223VqlV17re+IeP06EKvLo5vtNFG1qlTJ9t7771TZkfU9b1ffvllO/74492F4fbt27v31PfTsv7++++rLYeyt+6//34XqFNmQH5+viuTxnM644wz7PPPP0/oXiV5bIeGjE/08ccf28UXX2z77befDRo0yLp162Y5OTmuPmy33XYuc0TdVQW9++67CZ/1u9/9rsr7VlRUuMwE/zlan0GxWMz+9re/2WGHHebqnL6r6o8uil933XXxC/G1LX9dNP/1r39t/fr1c+X2v/ecOXPs97//vWvpPWTIEFcW1S3VMbXW1oV9ffeaWsyfffbZtskmm7
hxWgYPHuzeb/369XUay6Ih21by92tst34qe6oLnroQHOzuSXWtOey///7x76JlGNyOfdq+g99ZdcL31ltvuX3AwIED3XakOqLvtOOOO7rxzarr2q4htAy03QcFAzGNFdyGVB/0HZL17t3bzjrrLJd9lE4N2ScEKXj2xBNP2OGHH+62S607bXfahk466SS3XpOpbtx6662u4YIyMrS9dO/e3bbZZhv3+xD8TaltvJXksbCS60ly91TaHm644Qa3X1BZ/YCeuui78cYb3YV97b+1flQHlfGl5xx11FH2z3/+s8Zl+dlnn9k555zjMnC1/PR67Qf22Wcf93395wTLe95551V5H+0vVff95wwbNszqQ8szaMWKFQmPtV5PP/10aywFUrVt+vtNlVPLtry8vE6vb8h+0/f222+7/f/QoUNdfdPnq/7pN0DdLtanizIFYv36q/fRrW/fvrbbbru5+pi83usyBpDq+M033+wyrwoKCtzvkeq4jp1uuummlPvHVMc73333Xfw3T8tHWYzaXnX80JTUJaV+A/Wd9Rn+8Zl+S/Ud/u///q9Kmf3fRq1Hn/6v6XfN/z064IADrFevXu47aZ+j5aKGS6m+V1MsFwXkdZynYzX9xijD29+2lc2lOiB//vOfEz4rVbBdjSyCz9F+HAAA1MADAAB1op/N4O3bb79NmP/CCy8kzM/Ly0uY/9e//tXr0KFDlfcJ3s4///wqn3vKKackPOeEE06otSzV+fLLL72uXbvWWAbdpk+fHn/Nm2++mTDvqquuSnhPfXZwvspbU/kvuuiiaj/3tttuq9d7l5WVeccdd1yN32WjjTbyXnvttSrLYsaMGd6QIUNqfO3tt9/unjtgwIBal1ly2apz3XXX1fpe7dq1c/UpaNiwYfH5/fv392KxWML8559/vtplWVRU5I0ePbrGz+zbt6/36aef1rj89R6bbLJJyu/9l7/8pdbvFYlEvAceeKDKMtHn9OvXL+VrdthhB2/77bdPmJasodtW8vcbNWqUVx+PPvpowut/85vfeJ07d3b/5+TkeIsXL66ybLQM9Tk1bb/J9U3bYF089dRTCa976KGHqjxnxx13jM8vKChw25BMnjzZC4VCNS7Djh071mv5JH9PLa+gww47LGH+448/XuU9GrqvO/TQQxNed+yxx3pvvPGGV1xc7NVX8nqubz1prn2CqI7ttNNONb72wgsvTHjN1KlTvV69etX4mn/+85/x52ufX1N9TF4+yes5WJ/79OnjjRkzJuH5mi+LFi2qdTnoNm7cuCr7P7n00ktrrMP67fMFy9ClSxdv3bp1Ce/1+uuvJ7z21ltvrdc6Xb58ecI+Sfu+//73v15Tmjt3bpX9sX/bc889vZ133rlZ9pvl5eVuHdT0um233bbaOuCvb98NN9xQ6zofMWJEwmtqq5MffPBBtcvGv2m+nheUfLxz4IEHet26dUv5es2rj9qOpfbdd99al4N+J+fPn1/tPjbVLbi/0mu32WabWpf1ggULmnS5rF692jvggANq/NzDDz/cPXfNmjUJx6n77LNPlfe78sorE1774Ycf1mtdAADQ1tAFHAAATURdugSp9XIwW2TcuHHxLqDUunvkyJEuAyM4fpAGI1cryt/+9rfVfo5aeuv12267rWu5Hex2rjZq/a/u6nzKElGLc2V/LF682LX6Th4To6mpDGqFqxb5ykbSWBw+tUTW2BqpWuunoq6aghkMWh56rcbwUNaMlrdazB555JGu1ffmm2/unldUVOSyJYLZQWrBqlbjykzQ+AwaJD55/IJ//OMf1Y4DUp/xidRidYsttnDlVctbtdZWBo1uUlpaaqeddporh1rIi1qLq3WtP56MWl8HMyj++te/xv9Xq+GTTz45/lgt6oMtpPv37+++q+rdtGnT3DSt/4MPPth9b7UET8V/D9Vt1T8t20gkkvActezVMtQ61vJXa2SNeaF6pcdqza3PCWbLqBX5okWL4o/VMnjXXXd12U
Y1ZQ019bbVWGqxreWu7B9lZTzwwAN2zTXX2D333JNQZ+++++5m+fwjjjjCLXfVb79OBLMMZs+endDtlMrqZyqonP6
"text/plain": [
"<Figure size 1680x1200 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: per_class_acc_latency.png\n",
"\n",
"Per-subject paired Δ = mean(FES) − mean(NOFES) & T-Test:\n",
"\n",
" MI accuracy:\n",
" subj 002: FES=0.93 NOFES=0.85 Δ=+0.08 | p = 0.5000 (t = 1.00, n = 2)\n",
" subj 003: FES=0.73 NOFES=0.72 Δ=+0.02 | p = 0.8656 (t = 0.21, n = 2)\n",
" subj 005: FES=0.78 NOFES=0.82 Δ=-0.03 | p = 0.9208 (t = -0.12, n = 2)\n",
" subj 009: FES=0.82 NOFES=0.72 Δ=+0.10 | p = 0.3743 (t = 1.50, n = 2)\n",
"\n",
" REST accuracy:\n",
" subj 002: FES=0.87 NOFES=0.83 Δ=+0.03 | p = 0.7048 (t = 0.50, n = 2)\n",
" subj 003: FES=0.78 NOFES=0.76 Δ=+0.02 | p = 0.5000 (t = 1.00, n = 2)\n",
" subj 005: FES=0.93 NOFES=0.90 Δ=+0.03 | p = 0.0000 (t = inf, n = 2)\n",
" subj 009: FES=0.72 NOFES=0.85 Δ=-0.13 | p = 0.1560 (t = -4.00, n = 2)\n",
"\n",
" MI EARLYSTOP latency:\n",
" subj 002: FES=2.50 NOFES=2.30 Δ=+0.20 | p = 0.6020 (t = 0.72, n = 2)\n",
" subj 003: FES=2.45 NOFES=2.36 Δ=+0.09 | p = 0.8457 (t = 0.25, n = 2)\n",
" subj 005: FES=2.54 NOFES=1.99 Δ=+0.56 | p = 0.3678 (t = 1.53, n = 2)\n",
" subj 009: FES=2.60 NOFES=2.29 Δ=+0.31 | p = 0.2213 (t = 2.76, n = 2)\n",
"\n",
" REST EARLYSTOP latency:\n",
" subj 002: FES=2.11 NOFES=2.22 Δ=-0.10 | p = 0.7758 (t = -0.37, n = 2)\n",
" subj 003: FES=2.34 NOFES=2.64 Δ=-0.30 | p = 0.5264 (t = -0.92, n = 2)\n",
" subj 005: FES=2.20 NOFES=2.48 Δ=-0.28 | p = 0.3581 (t = -1.59, n = 2)\n",
" subj 009: FES=1.89 NOFES=2.23 Δ=-0.34 | p = 0.0667 (t = -9.50, n = 2)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/stats/_axis_nan_policy.py:430: RuntimeWarning: Precision loss occurred in moment calculation due to catastrophic cancellation. This occurs when the data are nearly identical. Results may be unreliable.\n",
" return hypotest_fun_in(*args, **kwds)\n"
]
}
],
"source": [
"from scipy.stats import ttest_rel\n",
"\n",
"PER_CLASS_METRICS = [\n",
"    ('mi_acc',       'MI accuracy',            '0–1'),\n",
"    ('rest_acc',     'REST accuracy',          '0–1'),\n",
" ('mi_latency', 'MI EARLYSTOP latency', 'seconds'),\n",
" ('rest_latency', 'REST EARLYSTOP latency', 'seconds'),\n",
"]\n",
"\n",
"# Restructure data to average across multiple sessions (Pairs) for each subject\n",
"def subj_cond_metrics(key):\n",
" \"\"\"Return {subj: {'FES': [vals...], 'NOFES': [vals...]}} using non-null entries in `results`.\"\"\"\n",
" out = {s: {'FES': [], 'NOFES': []} for s in subjects}\n",
" for r in results:\n",
" v = r.get(key)\n",
" if v is None or (isinstance(v, float) and np.isnan(v)):\n",
" continue\n",
" out[r['subject']][r['condition']].append(v)\n",
" return out\n",
"\n",
"mi_acc_data = subj_cond_metrics('mi_acc')\n",
"rest_acc_data = subj_cond_metrics('rest_acc')\n",
"mi_lat_data = subj_cond_metrics('mi_latency')\n",
"rest_lat_data = subj_cond_metrics('rest_latency')\n",
"\n",
"metrics_data = [mi_acc_data, rest_acc_data, mi_lat_data, rest_lat_data]\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 10))\n",
"fig.suptitle('Per-subject average: MI vs REST, accuracy & decision latency\\n(mean across that subject\\'s sessions ± SEM)',\n",
" fontsize=13, fontweight='bold', y=1.03)\n",
"\n",
"width = 0.38\n",
"x = np.arange(len(subjects))\n",
"\n",
"for ax, data_dict, (key, title, unit) in zip(axes.ravel(), metrics_data, PER_CLASS_METRICS):\n",
" for i, cond in enumerate(('FES', 'NOFES')):\n",
" means = [np.mean(data_dict[s][cond]) if data_dict[s][cond] else np.nan for s in subjects]\n",
" sems = [np.std(data_dict[s][cond], ddof=1) / np.sqrt(len(data_dict[s][cond]))\n",
" if len(data_dict[s][cond]) > 1 else 0.0 for s in subjects]\n",
" offset = (i - 0.5) * width\n",
" ax.bar(x + offset, means, width, yerr=sems,\n",
" color=cond_color[cond], label=cond, edgecolor='white', capsize=4)\n",
" \n",
" # Overlay individual session values\n",
" for xi, s in zip(x, subjects):\n",
" if data_dict[s][cond]:\n",
" ax.scatter(np.full(len(data_dict[s][cond]), xi + offset), data_dict[s][cond],\n",
" color='k', alpha=0.5, s=14, zorder=3)\n",
" \n",
" # Calculate p-values for pairs and add them to custom labels\n",
" new_labels = []\n",
" for s in subjects:\n",
" label = s\n",
"        fes_v, nof_v = data_dict[s]['FES'], data_dict[s]['NOFES']\n",
"        # ttest_rel needs >1 paired session and is degenerate (infinite or NaN t, precision-loss\n",
"        # warning) when all paired differences are identical, so skip both cases\n",
"        if fes_v and len(fes_v) == len(nof_v) and len(fes_v) > 1:\n",
"            diffs = np.subtract(fes_v, nof_v)\n",
"            if not np.allclose(diffs, diffs[0]):\n",
"                _, p_val = ttest_rel(fes_v, nof_v)\n",
"                p_str = f'p={p_val:.3f}' if p_val >= 0.001 else 'p<0.001'\n",
"                label = f'{s}\\n({p_str})'\n",
" new_labels.append(label)\n",
"\n",
" ax.set_xticks(x)\n",
" ax.set_xticklabels(new_labels, fontsize=8)\n",
" ax.set_title(f'{title} ({unit})', fontsize=11, fontweight='bold')\n",
" ax.grid(axis='y', alpha=0.3)\n",
" ax.spines[['top','right']].set_visible(False)\n",
" if 'acc' in key:\n",
" ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)\n",
" ax.set_ylim(0, 1.05)\n",
"\n",
"fig.legend(handles=[Patch(color=cond_color['FES'], label='ONLINE_FES'),\n",
" Patch(color=cond_color['NOFES'], label='ONLINE_NOFES')],\n",
" loc='upper right', ncol=2, bbox_to_anchor=(0.98, 1.0))\n",
"plt.tight_layout()\n",
"plt.savefig('per_class_acc_latency.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: per_class_acc_latency.png')\n",
"\n",
"# ── Paired Δ (FES − NOFES) summary per subject ─────────────────────\n",
"print('\\nPer-subject paired Δ = mean(FES) − mean(NOFES) & T-Test:')\n",
"for (key, title, _), data_dict in zip(PER_CLASS_METRICS, metrics_data):\n",
" print(f'\\n {title}:')\n",
" for s in subjects:\n",
" if data_dict[s]['FES'] and data_dict[s]['NOFES']:\n",
" delta = np.mean(data_dict[s]['FES']) - np.mean(data_dict[s]['NOFES'])\n",
"            fes_v, nof_v = data_dict[s]['FES'], data_dict[s]['NOFES']\n",
"            n = len(fes_v)\n",
"            if n == len(nof_v) and n > 1 and not np.allclose(np.subtract(fes_v, nof_v), fes_v[0] - nof_v[0]):\n",
"                t_stat, p_val = ttest_rel(fes_v, nof_v)\n",
"                stats_str = f'p = {p_val:.4f} (t = {t_stat:.2f}, n = {n})'\n",
"            elif n == len(nof_v) and n > 1:\n",
"                stats_str = f'n = {n} (identical paired differences; paired t-test degenerate)'\n",
"            else:\n",
"                stats_str = f'n = {n} (too few paired runs)'\n",
" \n",
" print(f' subj {s}: FES={np.mean(data_dict[s][\"FES\"]):.2f} '\n",
" f'NOFES={np.mean(data_dict[s][\"NOFES\"]):.2f} Δ={delta:+.2f} | {stats_str}')"
]
},
{
"cell_type": "markdown",
"id": "c74da761",
"metadata": {},
"source": [
"### Per-subject average latency\n",
"\n",
"Collapses the per-(subject × pair) bars into one FES value and one NOFES value per subject\n",
"by averaging across all of that subject's ONLINE sessions of each type. Error bars are the\n",
"session-to-session SEM within each (subject, condition). This view is easier to read when you only need\n",
"the per-subject FES vs NOFES comparison without the pair split."
]
},
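{
"cell_type": "markdown",
"id": "c2e8f5a7",
"metadata": {},
"source": [
"A minimal sketch of the mean ± SEM behind these error bars, assuming the SEM is the sample standard deviation (`ddof=1`) divided by $\\sqrt{n}$, as in the Figure 5 cell; `mean_sem` is a hypothetical helper, not part of this notebook. With only two sessions per (subject, condition), as in the printed summary above (n = 2), the SEM reduces to $|a - b| / 2$, so each error bar summarizes just two values.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def mean_sem(values):\n",
"    # Drop NaNs, then return (mean, SEM); SEM is undefined (NaN) for fewer than 2 values\n",
"    v = np.asarray(values, dtype=float)\n",
"    v = v[~np.isnan(v)]\n",
"    if v.size == 0:\n",
"        return np.nan, np.nan\n",
"    sem = v.std(ddof=1) / np.sqrt(v.size) if v.size > 1 else np.nan\n",
"    return float(v.mean()), float(sem)\n",
"\n",
"a, b = 2.50, 2.30  # hypothetical latencies (s) from two sessions\n",
"m, s = mean_sem([a, b])\n",
"assert np.isclose(s, abs(a - b) / 2)  # two-sample SEM identity\n",
"```\n",
"\n",
"Unlike the Figure 5 cell, which draws a zero-height error bar when only one session is available, this sketch returns NaN so that a missing SEM is not mistaken for zero variability."
]
},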
{
"cell_type": "code",
"execution_count": 11,
"id": "4ea9951f",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:28.199503Z",
"iopub.status.busy": "2026-04-22T19:27:28.199411Z",
"iopub.status.idle": "2026-04-22T19:27:28.407015Z",
"shell.execute_reply": "2026-04-22T19:27:28.406594Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABZAAAAIrCAYAAABxg7xlAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAsuVJREFUeJzs3QeYE+XWwPGTzRZYysIuvVfp2BEsFBWvBUQBCxYQLCheCxcsYO8NG6j4XUUQwYYNxYoKiIWmgqKC0qXDLixl2ZbN95wXJzfJZkuy2dT/73nyJJlMMpPJZPLOyXnPa3M6nU4BAAAAAAAAAMBLgvcEAAAAAAAAAAAIIAMAAAAAAAAASkQGMgAAAAAAAADAJwLIAAAAAAAAAACfCCADAAAAAAAAAHwigAwAAAAAAAAA8IkAMgAAAAAAAADAJwLIAAAAAAAAAACfCCADAAAAAAAAAHwigAwAAAAAAAAA8IkAMgAAAAAAAADAJwLIAKLKFVdcITabzVxatGhR7udNmzbN9Ty9bNiwoVLXE0B8uvfeez2ONQifrVu3SvXq1c3n0LJlSyksLOTjQNhp28U6PmibBhVrp+l87s/T10F043c0Pl122WWu7/Hs2bPDvToAfCCADEQQ70aw+yU5OVkaN24s/fv3l1mzZoV7VRGA3r17uz5PvY3YCACUdpk/f36JrzFgwIBi869cubJcf564X5KSkqRu3bpy8skny2OPPSb79+8v9lxdD/fn6MmZuzfffNPj8Z49e4rT6Sz2Ol988YXHfEceeaTk5+e7Hn/vvfekX79+0rBhQ3PMqlGjhjRr1ky6d+8uV199tfzf//1fiSeI5b34ChLec889ctJJJ5ntoMutVauWdO3aVUaNGiU//fRTidu0pGXoNq1fv7707dtXXn75ZXE4HBJOBJwCc+edd8rBgwfN7VtuuUUSExOD+rkA0RDIpO0ROQHPSNs3EHv++usv0/bp0KGDVKtWTVJSUqRBgwbSqVMnGTx4sDzwwAPy999/l/rnTUkX73MXX+3SZ555xud6vf/++8Xm9f4D7fbbb3d9f2677Tb+9AUiEC1pIEoUFBSYQIle5syZIwMHDpS33nqLE+JyOv744+WJJ55w3U9PT6+sjwoo0/bt2+WTTz4pNn3KlCny9NNP+7UFNaty9+7d5vLdd9/Ja6+9JosXLzYnDuV18cUXm8b922+/be4vXLhQJk6cKDfddJNrHg1MaxDYooFaXZZeq6uuusqsv/dx68CBA+ZkRddJ//waOXKkBMsLL7wg//nPfyQvL89jenZ2tvz666/mMnnyZLPe+n6qVKlS7m26c+dO+fLLL81FT64+++wzk82K6LB69Wp59dVXze20tDQZMWJEuFcJQCXQ9px7+07be4huZ5xxBr+3AZg7d66ce+65kpub6zF9x44d5vL777/Lu+++K126dJGmTZtKZdC21g033CB2u91j+pNPPlnmczt37mz+uNdkBes3/Morr6yU9QQQGALIQAQ77rjj5KKLLjKZgJrZosEaK7tQM/1efPFF+fe//11py9cGiDYANBsv2uk/73pB+Ozbt09q1qwZcx9B7dq1Zfz48T4fa926tc/p2ij21Z1+xowZJovYCsqWRpepmbYajNbnacBT/fbbb/LKK6+YBrw/NNCqgeNt27a5Xl+zia33MGbMGNm0aZNr/vvuu89k+Spt7LsHj48++mj517/+ZQJ3WVlZJpD77bfflnmCqH+KLVu2zOM96vb15fnnn/c4/mmWzYUXXijt27c320SD4XrCpF566SXZu3evK0DuS6tWreS6664zt3Ub6PF2165d5r4G5jXDbMKECeXYkogEuj8XFRWZ2+eff365/zxAeMXq7wQqj+4vY8eOZRPHkBNPPNFc4oGe32mJJTVv3ryAeyjq753+UWoFj/WPlQsuuMD0AtOeYmvWrJHvv/9e1q9fX+Zr6bmnnoN6K0/QWV//gw8+kEGDBrmmLV
myxLSjymPIkCGmTWklCRBABiKME0DEWL9+vfYZd12GDRvm8fjnn3/u8XjPnj09Hi8qKnK+9dZbznPOOcfZoEEDZ1JSkjMtLc15yimnOP/v//7PWVBQUGyZ7q93zz33OBcsWOA87bTTzPN0mq5TeWRlZTnHjx/vPPLII501atRw2u12Z0ZGhrNjx47OSy+91PnSSy95zN+8efMS3+fUqVM91st9HXRea7q+xoEDB5y33nqruZ2cnOxs1aqVeR+5ubnlfs1At53Kzs52Pv74486TTz7ZmZ6ebp5Xr149Z48ePZz33XefmUfXx33Zvi66fmXRdb755pvNOjVr1sxZvXp1s7y6des6+/TpY9azsLDQNf+aNWucNpvNtYzXX3+92GveeOONrsfr16/vzM/Pdz22e/du57333us87rjjnDVr1jTLaty4sXPIkCHOJUuWFHst7238119/OR977DFn+/btzWfTq1evgN6Huy+++MI8LzU11VmrVi1nv379nMuXLy+2jb3p+9J9UPftOnXqmOXp/tm3b1/n22+/7fSX+/6rt/11xBFHuJ7vflsvs2bN8vkc933fex/+9NNPPR679tprPZ47b968Yt91Xz7++ONixxj9bsydO9dj+oknnujxGY0ePdr1WOvWrX1+fvoZfPbZZ6Vul9Leo7u///7bWaVKFdd81apVc/7888/FjkmdO3f2eL133nnHYx73x6z907Jq1SqP74/uq+VR2r6on8OVV17pPPbYY50NGzY07yElJcXZtGlT5/nnn2+O8aVtD18Xfc1g/gb8+OOPznPPPdd8v3T9jj/+eOecOXN8vleHw+GcOXOmWZ6+H/2e6/N0u48aNcq5a9cuZ2ZmprNq1aquZTzzzDOlbjNd35ycHGdF6PFf18N6zU8++aTYPN6/Jdu3b3cOHz7cHB/0mHT66aebbaF0P7zooouctWvXNsce/V58//33Ppetv0lPPvmk86STTjLzW78Juk31+OVNP5M777zTefbZZ5vvjq63/n7qdtD9RB/Tfdmb7q/u+66uv37vGzVqZD6HNm3amN8m3SfKqyLH5vL8FlrL8P7te++998x8ujzv78yvv/7qvPrqq51t27Y1+5Huk7qdRowYYY79vvbJF154wbwHPcbrttTfL31O//79nQ888ID5jAKdvzy/CSVdfM2r++HGjRudQ4cONdtZP7tOnTo5p02bVmwZeoy77rrrnN27d3c2adLE7Is6v373zjrrLOebb77pMX+w2h7+biPv75Y7X59/SW2IdevWOV988UVn165dzeeu2+eKK65wbtmypdyvaVm0aJHz8ssvd7Zs2dK8lm67Ll26OO+66y5zjAr0+Ob92+rr4t7GLWkdFy9ebI4x+r3T3wS9aHtLf2tvuukm59KlS8v8nCryWps2bXKOHTvWbBP9Hup71W111VVXmd/CUH/PymrT5eXlOSdPnuzs3bu3eb3ExERz3NFj86RJk4qdA1Tkt87f8xt/ue+/7r/n/lqxYkWJbQPvY6p328r7u1ee44KvNopuG6uN6O7CCy8sNo/3d8Oyd+9e83la81i/wwAiAwFkIIoCyNq4cn9cT6gs2ljSE9DSGrHa0Dp48KDHa7o/ridw7j/spQVw3OmyvQM13hfvk4hgBJA14KkNPl/L08Cg+4luaa8Z6LbTxpoGfkp6jgYAgnkS99FHH5X5OnoSqQ11i570W4/pe3Sn20e3ofW4BuIteoLh/pj3RfcTbaSX9rnpiYKvAF0g70O99tprHgE966JBBf28SzrZ0JPDkvYT66JBce/lVVYA+ZtvvvFYtp70t2vXznX/zDPP9Du4+ssvv3g8pkGnQALISoM17vM+9NBDHu9XT7z1z4GS/ojQE6s///zTr21Snvfo7v777/eY77bbbvM5n3dg/dRTT/V43Nf+6U6DidbjGhArj9JOfMeMGVPmvv/www+XuD18XayTxGD8BpxwwgkmcOD9vISEBOfXX39d7CTP+zvufb
GC+tdcc41rmp50e9M/mazHNUBWUfpHqPV6eszQdfXmvm018KDBEu/1131dj1e6T/s67vzxxx8er7l27Vrzu1zaNnE
"text/plain": [
"<Figure size 1440x540 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: per_subject_latency.png\n",
"\n",
"Per-subject paired Δ = mean(FES lat) − mean(NOFES lat) & T-Test:\n",
" MI latency:\n",
" subj 002: FES=2.50s NOFES=2.30s Δ=+0.20s | p = 0.6020 (t = 0.72, n = 2)\n",
" subj 003: FES=2.45s NOFES=2.36s Δ=+0.09s | p = 0.8457 (t = 0.25, n = 2)\n",
" subj 005: FES=2.54s NOFES=1.99s Δ=+0.56s | p = 0.3678 (t = 1.53, n = 2)\n",
" subj 009: FES=2.60s NOFES=2.29s Δ=+0.31s | p = 0.2213 (t = 2.76, n = 2)\n",
" REST latency:\n",
" subj 002: FES=2.11s NOFES=2.22s Δ=-0.10s | p = 0.7758 (t = -0.37, n = 2)\n",
" subj 003: FES=2.34s NOFES=2.64s Δ=-0.30s | p = 0.5264 (t = -0.92, n = 2)\n",
" subj 005: FES=2.20s NOFES=2.48s Δ=-0.28s | p = 0.3581 (t = -1.59, n = 2)\n",
" subj 009: FES=1.89s NOFES=2.23s Δ=-0.34s | p = 0.0667 (t = -9.50, n = 2)\n"
]
}
],
"source": [
"from scipy.stats import ttest_rel\n",
"\n",
"# Collect each subject's latencies per condition (one value per session, pairs pooled); averaging happens at plot time\n",
"def subj_cond_latencies(key):\n",
" \"\"\"Return {subj: {'FES': [vals...], 'NOFES': [vals...]}} using all non-null entries in `results`.\"\"\"\n",
" out = {s: {'FES': [], 'NOFES': []} for s in subjects}\n",
" for r in results:\n",
" v = r.get(key)\n",
" if v is None or (isinstance(v, float) and np.isnan(v)):\n",
" continue\n",
" out[r['subject']][r['condition']].append(v)\n",
" return out\n",
"\n",
"mi_lat = subj_cond_latencies('mi_latency')\n",
"rest_lat = subj_cond_latencies('rest_latency')\n",
"\n",
"fig, axes = plt.subplots(1, 2, figsize=(12, 4.5), sharey=True)\n",
"fig.suptitle('Per-subject average EARLYSTOP latency (mean across that subject\\'s sessions ± SEM)',\n",
" fontsize=12, fontweight='bold', y=1.02)\n",
"\n",
"width = 0.38\n",
"x = np.arange(len(subjects))\n",
"\n",
"for ax, data, title in [(axes[0], mi_lat, 'MI trials'),\n",
" (axes[1], rest_lat, 'REST trials')]:\n",
" for i, cond in enumerate(('FES', 'NOFES')):\n",
" means = [np.mean(data[s][cond]) if data[s][cond] else np.nan for s in subjects]\n",
" sems = [np.std(data[s][cond], ddof=1) / np.sqrt(len(data[s][cond]))\n",
" if len(data[s][cond]) > 1 else 0.0 for s in subjects]\n",
" offset = (i - 0.5) * width\n",
" ax.bar(x + offset, means, width, yerr=sems,\n",
" color=cond_color[cond], label=cond, edgecolor='white', capsize=4)\n",
" # Overlay individual session values\n",
" for xi, s in zip(x, subjects):\n",
" if data[s][cond]:\n",
" ax.scatter(np.full(len(data[s][cond]), xi + offset), data[s][cond],\n",
" color='k', alpha=0.5, s=14, zorder=3)\n",
" \n",
" # Calculate and format p-values for x-axis labels\n",
" new_labels = []\n",
" for s in subjects:\n",
" label = s\n",
" if data[s]['FES'] and data[s]['NOFES'] and len(data[s]['FES']) == len(data[s]['NOFES']) and len(data[s]['FES']) > 1:\n",
" try:\n",
" _, p_val = ttest_rel(data[s]['FES'], data[s]['NOFES'])\n",
" p_str = f'p={p_val:.3f}' if p_val >= 0.001 else 'p<0.001'\n",
" label = f'{s}\\n({p_str})'\n",
" except Exception:\n",
" pass\n",
" new_labels.append(label)\n",
" \n",
" ax.set_xticks(x)\n",
" ax.set_xticklabels(new_labels, fontsize=8) # slightly smaller to fit the p-vals\n",
" ax.set_xlabel('Subject')\n",
" ax.set_title(title, fontweight='bold')\n",
" ax.grid(axis='y', alpha=0.3)\n",
" ax.spines[['top','right']].set_visible(False)\n",
"\n",
"axes[0].set_ylabel('BEGIN → EARLYSTOP latency (s)')\n",
"axes[0].legend(loc='lower right')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('per_subject_latency.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: per_subject_latency.png')\n",
"\n",
"# Per-subject paired Δ summary\n",
"print('\\nPer-subject paired Δ = mean(FES lat) − mean(NOFES lat) & T-Test:')\n",
"for key, title in [('mi_latency', 'MI'), ('rest_latency', 'REST')]:\n",
" data = subj_cond_latencies(key)\n",
" print(f' {title} latency:')\n",
" for s in subjects:\n",
" if data[s]['FES'] and data[s]['NOFES']:\n",
" delta = np.mean(data[s]['FES']) - np.mean(data[s]['NOFES'])\n",
" \n",
"            # Lists are appended in `results` order (Pair 1, then Pair 2), so FES[i] and NOFES[i]\n",
"            # come from the same pair. With n = 2 pairs the paired t-test has df = 1, so these\n",
"            # per-subject p-values are descriptive only.\n",
" if len(data[s]['FES']) == len(data[s]['NOFES']) and len(data[s]['FES']) > 1:\n",
" t_stat, p_val = ttest_rel(data[s]['FES'], data[s]['NOFES'])\n",
" stats_str = f'p = {p_val:.4f} (t = {t_stat:.2f}, n = {len(data[s][\"FES\"])})'\n",
" else:\n",
" stats_str = f'n = {len(data[s][\"FES\"])} (too few or unbalanced runs for paired t-test)'\n",
" \n",
" print(f' subj {s}: FES={np.mean(data[s][\"FES\"]):.2f}s '\n",
" f'NOFES={np.mean(data[s][\"NOFES\"]):.2f}s Δ={delta:+.2f}s | {stats_str}')"
]
},
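{
"cell_type": "markdown",
"id": "5e0c7a12",
"metadata": {},
"source": [
"Why the per-subject p-values above stay large even when |t| is big: each subject contributes only two FES and two NOFES sessions, so `ttest_rel` runs with df = n − 1 = 1. The sketch below uses SciPy's `stats.t` distribution to show the df = 1 two-sided 5% critical value and to recompute the p-value for |t| = 9.5, which approximately matches subject 009's REST result printed above. It illustrates the t distribution only and is not part of the analysis pipeline.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy import stats\n",
"\n",
"# Paired t-test on n = 2 matched sessions -> df = n - 1 = 1\n",
"df = 1\n",
"t_crit = stats.t.ppf(0.975, df)     # two-sided 5% critical value, about 12.71\n",
"p_at_9_5 = 2 * stats.t.sf(9.5, df)  # two-sided p for |t| = 9.5, about 0.067\n",
"\n",
"# For df = 1 (Cauchy tails) the same p-value has a closed form\n",
"p_closed = 1 - (2 / np.pi) * np.arctan(9.5)\n",
"print(round(float(t_crit), 2), round(float(p_at_9_5), 4), round(float(p_closed), 4))\n",
"```\n",
"\n",
"With df = 1, |t| must exceed roughly 12.7 to reach p < 0.05, so per-subject tests on two sessions per condition are best read as descriptive."
]
},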
{
"cell_type": "code",
"execution_count": 12,
"id": "2ba1928b",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:28.408334Z",
"iopub.status.busy": "2026-04-22T19:27:28.408256Z",
"iopub.status.idle": "2026-04-22T19:27:28.518108Z",
"shell.execute_reply": "2026-04-22T19:27:28.517764Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0EAAAIMCAYAAAAzTKffAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAmixJREFUeJzt3QecE9Xax/FnC70JSFOKAhaKHRUbTfEqiFgAKyj23rCBvfcGdkUR7KhXFLFdaYpKsWBXQBQp0qUtZXfJ+/mfeyfvZDbZzdZskt/XT1ySTJLJZObMPKc8JyMUCoUMAAAAANJEZqJXAAAAAAAqEkEQAAAAgLRCEAQAAAAgrRAEAQAAAEgrBEEAAAAA0gpBEAAAAIC0QhAEAAAAIK0QBAEAAABIKwRBAAAAANIKQRAAAACAtEIQBAAAACCtEAQlyMKFC+3222+3ww8/3Fq0aGG1atWyqlWrWoMGDWyPPfawU045xZ588klbunSppaJRo0ZZRkZG+HbzzTcnepVQAfy/eVG3b7/9Nub7zJ07N+prfvzxxxJ9fnZ2tm2zzTa211572SWXXGI///xz1Nf/8ccfBV4bpOPW/3zNmjVt3rx5Ud9vypQplpmZGbH8Bx98ELHMu+++a8cff7y1bNnSqlev7t6vefPmrpwYMGCA3XbbbfbVV1+Flz/99NOLtZ39N32/oN9++82uueYa69y5szVu3DhcTnXo0MHOOecc++STT0q0zb1t07p1azvxxBMLfO94UI4gnQ0bNqzAMaVjIhn413mHHXYo1msnT54c8XqVeUCJhFChNm3aFLr88stD2dnZIW3+om6NGjVKyV/o+eefj/ieN910U6JXCRUgnn3eu33zzTcx3+eGG26I+pqrrrqqTD5fx+frr79e4PXz588vsGzQ1q1bQ927d49Y5tBDD41aFuy8884Ryw0ePDj8fF5eXujkk0+Oa30vvPDC8OtOO+20Ym1n/03fz7N582b3vpmZmUW+Tt938eLFpf7NTzrppFBubm6oMpQjwe04adKkMntvoLSmT58eysrKKnAM6ZhIBv51btWqVbFeq2PR/3odq4lU2dYH8csuWeiEkti4caNr+fnss88iHldtqGqgt912W1u/fr39+uuvrqVItm7dysZGyjryyCPd/h+NWmai0flzzJgxUZ978cUX7a677rKsrKxiff6KFSvsiy++sC1btrjH8/LyXCtHr169XCttcahm8tlnn7Xdd9/dNmzY4B5Ta8lzzz1nZ5xxRni5W2+91bWyeLbbbjt78MEHw/dHjBhhL7/8cvh+lSpVrFOnTtaoUSPLyclxr12wYEGBz993331dOeK3fPlymzp1avi+vrO+e5D3XfPz861v374FWmd22WUX23nnnV0L9axZs8Ll06RJk2z//fe3mTNnWpMmTeLa5trW3333nf3555/h51555RVr06aNa90CEN2mTZvstNNOc8dpOlIZqNZxf5kHlEgxAiaUkmp5g7XNd955ZygnJ6fAsr///nvotttuC7Vv3z4ltzstQempsJaHkta6ValSJeL++++/X6LP//bbbwu810cffVTsliDP8OHDI5arX79+6O+//3bPfffddwU+69133414/W677RZ+rm7duq5MCFqwYEHokUcecbfibLOial7vuOOOiOVr1KgR+ve//x2xzI8//hhq3bp1xHI9e/Ys1jZXq8+gQYMintd33bJlSygetAQhHV1xxRXh42WHHXZIu5agyoaWoORFEFRBvv/++1BGRkbEgT9y5MgiX6cuKUUVHupWo2CqQ4cO7mKlXr16ERdiAwcODO25556h7bbbLlS9evVQ1apVQ02bNnVddEaMGOFeH89BvXr16tC1114batu2rXsPddU75ZRT3IVYNLqQuffee0Pt2rULVatWLdS4cWPXvWfu3LllcvGii6nrr78+1Lt3b9etaNttt3WBZa1atUI77bSTW7dgFxYFnNtss034c3URF8
3QoUMj1u+tt96KeH7FihXuQvGggw4KNWjQwH1uw4YN3TbV7xrtIi7aNl2yZEnovPPOC7Vs2dK9R9++fcPvf+utt4aOPfbY0K677uq2nS6aa9as6U56xx13XIGL0uC2v++++1wQrW2vbdOvXz934Rrvtv/Pf/7jfq8dd9zR7Ve6abtqfX/55ZdQooKg008/PeI9br/99oj7J554Yok/f4899oh4/pVXXilxEKRucQcffHDEsv379w/l5+eH9ttvv4jHta8G6Xfznt99991DpVGcIGjdunURx4hujz76aNRlFTgGu8tNmzatWNtcwV1wGZWX8ShqX546dWro0ksvDXXt2tXtxyobdZzpr8rEyy67LPTbb79FvEbvEVyfaLdg2fLXX3+5cqNTp05u++l41XGr8umNN95w+0M8669uhRdccIErE1TOqtw+//zzQ6tWrYq5HXRcX3LJJe47eZ/dpEmT0P777+/WSb+pPl/lgT/YXL9+fYH3evzxxyPW6Yknnii38rgk3yPW76RtOWvWrNAxxxzjPlvn24ceeiji/b/44gtXGan103rpGNt+++1DRx99dOi1115zx2Y0qqDQ++o30TlUr9Pvss8++4TOOeecqIFHSV4Tj08//TR8zOl8c+ONN5Y6CNJ2veuuu0IHHHBA+HxWp04dd8yoYmPYsGGhmTNnRrxGZUhh5WBR3UmD5ZH2T+1re+21lzvXqNJI50OVMSUNOrTOZ555ZmiXXXYJ/946f+qaKPh9SrIvRjsfRLv510+v0fuqkkvbWF0atc21jjqv33333aGFCxcW+ZuhbBAEVZDrrrsu4qDQAVBS/vdp1qxZgfEH/iBIB35RB6gusBTgFFbI6GIuWOh5txYtWhQ4Qesi/PDDD4+6vA78c889N64L8cKMHTs2rgJI40f8NM6hsIs2FcY6cXnPK2D0j1PQtlEAWNhn6mSiQKawbdqtWzd3UvQ/5gVBKqDj+W6qRQ9eXGnb/+tf/4q6vE7GOgEUtu31+qLGouji7Lnnniv2b1baIEgXbbVr147Y/zV2xr9v6jv+888/Jfp8f+uLblOmTClxECS6wNYJ3b+8Loz893ViDe4rouPEv5wuhD///POoFSNlGQTp4i14vG7cuDHm8ocddljE8ldeeWWxtvmGDRsKLKPvWRZBUPBYj3VMvPfee6UKgl599dUiy1pdaAe3Y3D9jzzySHdBFO31e++9d9TKFa1vtLEh0bb5s88+G/H4k08+WeD9VLHjPa9jbe3ateVaHpfke0T7nQYMGFCgddULglRGXnzxxUWuW5cuXUIrV66MWC8FB0W9Tr99aV8Tb/nXpk2bcBD7559/Rg0Gi0P7ZLDyp6hxh2UdBDVv3jx0wgknRP1cBS4ffvhhsYMglUPBimf/Tc+pAq00+2Jxg6Cff/7ZXZ8VtbyOJVQMgqAKEgxU1KJSUtEOGp2sVICrZkhBiUcFrQpL1U6qRkcX2br4Dh6IOkEUVsh4N7U2qVY1mNhBXff8brnllgKvVQ2YXquLjuBzpQmCVBgfeOCBoaOOOsrVQqomKVg7/dVXX4Vfp5ql4MWl3+TJk2P+Vr/++mvERbh3caLP9k5O3q1Hjx5xbVNdBCtoUeB0/PHHRwRBCpI6d+7sLo769OnjWhAUgPhf/+abbxa57bVNtH/4WxdibXvVUvqf10WZ1k/fx//ZOklMnDixWL9Z8LP1vfSdgzfVbkYzevToiNerJl+uueaaiMeffvrpuD7ff1H19ddfR+zX2vbBVtLiBkFy//33F3rCC/5+nl69ekVdXr+Bfs+LLrrIXRzEqr0uaRAUrFlW2VWY4EWYjvF4t7lXs13UMqUJglQWqAVE+7/KvyOOOKLAsapj0OuWrBYB7YPBizy93r+P/vDDD+HWJv9+o8/TsayySK0MsRJfRFt/7+JMZaXKtOBzY8aMKXLfUuuTzgM6Zr3KGm97an/Wd43Vwqjl/BeOZ599dqi8y+OSfI/CglW1OunY0bnq4Y
cfjlom6jvqnKh9Oxi8+sttVTj4y3sde9oPVBbr9d56+QOakrwmXv6g3utJUtog6KWXXop4vSqWtP1Uianjxjtfl2c
"text/plain": [
"<Figure size 720x540 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: grand_avg_latency.png\n",
"\n",
"Grand-average latency (across-subject mean ± SEM) & Paired T-Test:\n",
" MI FES = 2.523 ± 0.032 s NOFES = 2.235 ± 0.084 s Δ = +0.288 s | p = 0.0632 (t = 2.89, n = 4)\n",
" REST FES = 2.135 ± 0.094 s NOFES = 2.391 ± 0.102 s Δ = -0.256 s | p = 0.0163 (t = -4.89, n = 4)\n"
]
}
],
"source": [
"# Grand average across subjects: one pair of bars per class (MI, REST).\n",
"# Each subject contributes one mean (averaged across their sessions) per condition,\n",
"# so error bars = between-subject SEM.\n",
"from scipy.stats import ttest_rel\n",
"\n",
"def subject_means(data):\n",
" return {cond: np.array([np.mean(data[s][cond]) for s in subjects if data[s][cond]])\n",
" for cond in ('FES', 'NOFES')}\n",
"\n",
"mi_sub = subject_means(mi_lat)\n",
"rest_sub = subject_means(rest_lat)\n",
"\n",
"fig, ax = plt.subplots(figsize=(6, 4.5))\n",
"classes = ['MI', 'REST']\n",
"x = np.arange(len(classes))\n",
"width = 0.38\n",
"\n",
"p_values = {}\n",
"for cls, lat in [('MI', mi_lat), ('REST', rest_lat)]:\n",
"    # Pair by subject: keep only subjects with at least one FES and one NOFES session\n",
"    paired = [s for s in subjects if lat[s]['FES'] and lat[s]['NOFES']]\n",
"    fes_paired = np.array([np.mean(lat[s]['FES']) for s in paired])\n",
"    nofes_paired = np.array([np.mean(lat[s]['NOFES']) for s in paired])\n",
"    # Paired t-test across subjects (df = number of paired subjects - 1); NaN if fewer than 2\n",
"    p_values[cls] = ttest_rel(fes_paired, nofes_paired).pvalue if len(paired) > 1 else np.nan\n",
"\n",
"for i, cond in enumerate(('FES', 'NOFES')):\n",
" means = [mi_sub[cond].mean(), rest_sub[cond].mean()]\n",
" sems = [mi_sub[cond].std(ddof=1) / np.sqrt(len(mi_sub[cond])),\n",
" rest_sub[cond].std(ddof=1) / np.sqrt(len(rest_sub[cond]))]\n",
" offset = (i - 0.5) * width\n",
" ax.bar(x + offset, means, width, yerr=sems,\n",
" color=cond_color[cond], label=cond, edgecolor='white', capsize=5)\n",
" for j, arr in enumerate([mi_sub[cond], rest_sub[cond]]):\n",
" ax.scatter(np.full(len(arr), x[j] + offset), arr,\n",
" color='k', alpha=0.6, s=20, zorder=3)\n",
"\n",
"# Modify X-labels to contain the P-Values below each\n",
"new_labels = []\n",
"for cls in classes:\n",
" p_val = p_values[cls]\n",
" if not np.isnan(p_val):\n",
" p_str = f'p={p_val:.3f}' if p_val >= 0.001 else 'p<0.001'\n",
" new_labels.append(f'{cls}\\n({p_str})')\n",
" else:\n",
" new_labels.append(cls)\n",
"\n",
"ax.set_xticks(x); ax.set_xticklabels(new_labels)\n",
"ax.set_ylabel('BEGIN → EARLYSTOP latency (s)')\n",
"ax.set_title(f'Grand-average EARLYSTOP latency across {len(subjects)} subjects\\n'\n",
" '(points = per-subject means, bars = across-subject mean ± SEM)',\n",
" fontweight='bold')\n",
"ax.legend(); ax.grid(axis='y', alpha=0.3)\n",
"ax.spines[['top','right']].set_visible(False)\n",
"plt.tight_layout()\n",
"plt.savefig('grand_avg_latency.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: grand_avg_latency.png')\n",
"\n",
"print('\\nGrand-average latency (across-subject mean ± SEM) & Paired T-Test:')\n",
"sem = lambda a: a.std(ddof=1) / np.sqrt(len(a)) if len(a) > 1 else 0.0\n",
"for cls, d, lat in [('MI', mi_sub, mi_lat), ('REST', rest_sub, rest_lat)]:\n",
"    # Align arrays for the paired t-test: keep only subjects with both FES and NOFES sessions\n",
"    paired = [s for s in subjects if lat[s]['FES'] and lat[s]['NOFES']]\n",
"    fes_paired = np.array([np.mean(lat[s]['FES']) for s in paired])\n",
"    nofes_paired = np.array([np.mean(lat[s]['NOFES']) for s in paired])\n",
"\n",
"    delta = d['FES'].mean() - d['NOFES'].mean()\n",
"\n",
"    if len(fes_paired) > 1:\n",
"        t_stat, p_val = ttest_rel(fes_paired, nofes_paired)\n",
"        stats_str = f'p = {p_val:.4f} (t = {t_stat:.2f}, n = {len(fes_paired)})'\n",
"    else:\n",
"        stats_str = 'Not enough paired data for t-test'\n",
"\n",
"    print(f'  {cls:<5} FES = {d[\"FES\"].mean():.3f} ± {sem(d[\"FES\"]):.3f} s  '\n",
"          f'NOFES = {d[\"NOFES\"].mean():.3f} ± {sem(d[\"NOFES\"]):.3f} s  '\n",
"          f'Δ = {delta:+.3f} s | {stats_str}')"
]
},
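{
"cell_type": "markdown",
"id": "b71d4f93",
"metadata": {},
"source": [
"A minimal sketch of the subject-level pairing used in the grand-average cell, on hypothetical latencies (the dictionary below is made up, not study data): each subject's sessions are first averaged per condition, and only subjects with both FES and NOFES data enter `ttest_rel`. The test therefore has df = (number of paired subjects) − 1 and never treats individual sessions as independent observations.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import ttest_rel\n",
"\n",
"# Hypothetical per-session latencies (s): subject -> condition -> list of session values\n",
"lat = {\n",
"    'A': {'FES': [2.4, 2.6], 'NOFES': [2.2, 2.3]},\n",
"    'B': {'FES': [2.5, 2.5], 'NOFES': [2.4]},  # unequal session counts are fine after averaging\n",
"    'C': {'FES': [2.7], 'NOFES': []},          # no NOFES session, so C is dropped from the test\n",
"}\n",
"\n",
"# One mean per subject and condition, restricted to subjects that have both conditions\n",
"paired = [s for s in lat if lat[s]['FES'] and lat[s]['NOFES']]\n",
"fes = np.array([np.mean(lat[s]['FES']) for s in paired])\n",
"nofes = np.array([np.mean(lat[s]['NOFES']) for s in paired])\n",
"\n",
"res = ttest_rel(fes, nofes)  # df = len(paired) - 1\n",
"print(paired, fes - nofes, res.statistic, res.pvalue)\n",
"```\n",
"\n",
"Here subject C is excluded and the remaining two subjects give a paired test with df = 1."
]
},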
{
"cell_type": "markdown",
"id": "21da9cda",
"metadata": {},
"source": [
"---\n",
"## Figure 6 — Decision-margin distribution shift (Wasserstein-1)\n",
"\n",
"Mean |margin| collapses the whole distribution to a scalar. If FES spreads the margin\n",
"distribution further from zero, or sharpens it into a narrow high-confidence mode, without\n",
"changing the mean, the shift shows up here but not in mean amplitude. For each (subject × pair),\n",
"we compute the Wasserstein-1 distance W₁ between the FES and NOFES margin distributions,\n",
"separately for MI and REST trials. For 1-D samples, W₁ equals the area between the two empirical\n",
"CDFs and is expressed in LDA decision-value units; larger W₁ means greater divergence between\n",
"conditions. Because W₁ is unsigned, the direction of the MI-margin shift is plotted separately."
]
},
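{
"cell_type": "markdown",
"id": "e7a1c3f0",
"metadata": {},
"source": [
"A small synthetic illustration of why W₁ can separate margin distributions that mean |margin| cannot. The two samples below are made-up values, not study data, and the sketch assumes the notebook's `wasserstein_distance` is `scipy.stats.wasserstein_distance`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"from scipy.stats import wasserstein_distance\n",
"\n",
"# Hypothetical margins with identical mean |margin| = 1.0\n",
"a = np.array([-1.0, -1.0, 1.0, 1.0])  # every trial at |margin| = 1\n",
"b = np.array([-2.0, 0.0, 0.0, 2.0])   # half near zero, half far from zero\n",
"\n",
"print(np.abs(a).mean(), np.abs(b).mean())  # 1.0 1.0: mean amplitude cannot tell them apart\n",
"print(wasserstein_distance(a, b))          # 1.0: W1 picks up the change in shape\n",
"```\n",
"\n",
"For equal-size 1-D samples, W₁ reduces to the mean absolute difference between the sorted values; here the pairs (−1, −2), (−1, 0), (1, 0) and (1, 2) each differ by 1."
]
},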
{
"cell_type": "code",
"execution_count": 13,
"id": "41bf28d0",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:28.519611Z",
"iopub.status.busy": "2026-04-22T19:27:28.519516Z",
"iopub.status.idle": "2026-04-22T19:27:28.702716Z",
"shell.execute_reply": "2026-04-22T19:27:28.702337Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABggAAAIqCAYAAADxQcHcAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAA3udJREFUeJzs3QeUE9X7//EHlo6AFGkiCFJEmgURUQELooCKBUSUqqBiwa6oKNgAsRewIc2GiuDXigXsCtgRQURAUFGQjvQl//O5vzP5T7JJNslmd7O779c5OZtNJpMpdyYz97n3ucUCgUDAAAAAAAAAAABAkVI8vxcAAAAAAAAAAADkPQIEAAAAAAAAAAAUQQQIAAAAAAAAAAAogggQAAAAAAAAAABQBBEgAAAAAAAAAACgCCJAAAAAAAAAAABAEUSAAAAAAAAAAACAIogAAQAAAAAAAAAARRABAgAAAAAAAAAAiiACBAAAAAAAAAAAFEEECAAAAAAAAAAAKIIIEAAAipwDDzzQihUr5h79+/e3wspbRz1GjBgRfP2jjz4KeU//I7VUrrztq/IGFGYrVqwIOadMmjQpvxcJPulwzo/2e1QQ1wVZdezYMbhP9Dw/xSojKnv+99JFQVzmdF0uAEByCBAAQBTvvvtuyIXv9OnTs0xTq1at4PsVKlSwzMzMkPc/+OCDkHnMnDmT7V0AbzZT5dVXX7XLLrvM2rZta+XKlQspG6pgK4ioGEwPqpAtDOUpvwKFsR7+iprwSpxoj0hBoU2bNtkdd9xhRx55pFWqVMlKlixpVatWtUaNGtkpp5xiN954o3311VdWlISfP/R4+umnYwbcYlVEfffddzZkyBBr2bKl7bvvvlaqVCnbb7/97JhjjrHbbrvN/vrrr4ifS3a/hi9XtEekCvFvvvnG+vXrZwcddJCVLVvWypQpY7Vr17ZWrVpZ7969bcyYMbZhw4aktiuA3FeYf3cL4zU4ACC2Etm8DwBFlioUMjIygpX+H3/8sZ199tnB95csWWJ///138P+tW7fat99+6yp/PPqMRxfZxx13XJ4tP6K75ZZbXGWdNG/ePM821V133WU//PCD5TdVSI0dOzbkfwC5548//nDn//AKpPXr17vH0qVLbdasWbZz504XQCzKbr/9djv//PNdEDVeO3bssCuvvDJicOHff/91jy+++MLuvfdee+CBB1wQIT9NnjzZBg4caHv37g15ffXq1e7x448/2osvvmidO3e2ypUrW2Hg/81p165d0vPh9ys9XXrppdatWzf3/IADDrB0dfLJJ9s+++xj6aYglut03ZYAgOQQIACAKNQj4PDDD7f58+dnqeyP9L/3WrQAgSqi1Vq0KFGljYIsaimbTgYNGpQv3+u1QFW52r17t73xxhv5shy6eb/uuuvy5buBdKRK2Jtvvjnie7Eqajp16uQqScKph4DfDTfcEAwOlChRwgWbmzZt6p6vWrXK5s2bZ99//32O16MwUAX5/fffb8OHD49r+kAgYH369HE9tDw1atSwnj17Ws2aNe2XX36xadOmueCLHurFpc/ob073aziVoUgV+v4KcfUK0Hd7wYH999/flQct67Zt29zyfvbZZ247FCap+s3h9yv19uzZ465J1JMlWeeee64VBDoWcxKgSrXNmzdbxYoVC2S5TrdtCQDIoQAAIKrrr78+oFOlHsWKFQusW7cu+N7555/vXi9XrlygUqVK7nm3bt2C72/bti1QunTp4Ocvv/zy4Hv33ntvoHv37oHGjRsHqlatGihRokRgn332CbRo0SJw9dVXB1atWpVlWbZv3x4YPXp0oE2bNu77MjIyApUrV3bzOOeccwJjxozJ0fSyd+/ewLRp0wJdu3YN1KxZM1CyZEn32eOOOy7w5JNPBnbv3p3lM9766XH77bcHPv7448CJJ54Y3CbLly93082dOzdw7rnnBurWreu2ix77779/oF27doGhQ4cG5s+f76bTPP
zzjPSYOHFiyDLMmjXLrVOdOnUCpUqVClSoUCFw5JFHBsaOHRv477//sixzvXr1gvPq169f8HUta/j3fPDBB4ETTjjBzVP7umPHjoEvv/wyqaPGvyyat/+7vO2UqI0bN7oyo3XXNtX+1X7ftWtXln3jmTNnTsh7+t+TmZkZGDdunNvnKpsqNxUrVgwcdNBBgdNOOy1w5513BrZu3ZplO0Z7eGbMmBHo06dPoGXLloEaNWq4/VS2bNlAgwYN3LH01VdfZVm38G3022+/BcaPHx9o1apVoEyZMm75zjvvvMCff/4Zcdv8/fffgeHDh7uysO+++7ryXKtWLbcPH3/88SzT//zzz4FLLrkk0KRJE7ev9R3anldddVXgjz/+SGi/qFx5y63tpG12ww03uOdad6239smOHTsifl7bQ9urfv36bjm0PDo/aH3856HwMhvp0aFDB3dsV69ePfja3XffHbKd/Oe5tWvXBt+76667gu9p2+X0fCH//vtvYMSIEYHWrVu7sqXP6VygfTlv3ryUl4No/OVXz+MRfuz4j6tYdO7N7jOrV68OngfjoXKgMnXSSScFDjzwQLct9VtSpUqVwNFHH+1+Z/Q7lN3579dffw307t07sN9++7my2axZs8CkSZMifudff/0VGDhwoCtL2vbaB0899VRg2bJlMc/R0UQrvzrf/vPPPxGPp/Dbl1deeSXkvebNmwfWr18fMs13330XKF++fHAaLbv/mE52v4YvVzzn8ddffz2uz+h3xn8sxuOXX35xv4UqbzpntG3bNjB9+vSY5/xkf0dF5y8di/rNV/nRsazj8fDDDw9cc801gZ07dwanjbV99d2nn356oHbt2m4e+m044IAD3Lla12FaL0886/LWW28FzjzzzOD8tC4qq8OGDXPnu3A6R/rPl5pGvwX6vLZFw4YN3fGkc14iIl1TvPDCC4HDDjvMlUFdD+lc6J0rdd7U8affcu0HHd/+bZjK4/7HH39021yfC9+O7733njuPqwzpt1PXtt9//32W67NY29AvfN9/88037rs1b20HlbM333wzkKhkynusddB54+abb3ZlRWVG1z8qz4cccoi7Tnn66acT+t2NVgZee+01t6903e8tQyLLrOu7UaNGuesT71pa19G6HvTL7joz0rVwItfgsbalqOzq/KDj2LvXUXlr37594NFHH414/ZNsWYl33wEAoiNAAAAx6CLUf7GqSk6Pbt70mm6MVUGm57qAVQWrzJ49O+SzqsTw6KI11sW3bnYWLlwYsiydOnXK9qI9J9PrQr1Lly4xp9dFfnhFgf993fDoojz8huSjjz5yNwax5u1VGCRyc6Kb9UGDBsWcVpWq/oqmRAIExxxzjKswDZ+nKi8WLVqUo2MnFQGCzZs3uwr3SOutyvxI2ze7G9Hstqd/WRMJEJx99tkxpytevHjgueeei7mNVGER6bMHH3xwlhtNBXb8lbLhD91E+j3zzDOuIijWMfnZZ5/FvW/8FYcKiOimNtJ8dZzu2bMn5LMjR46MWO68hyqVFi9enHBFhQJ03mudO3cOfp8q+f3Tq/Ii0nnkggsuyPH5QhXg2h7RPqPzhyoOUlUO0iVAoEo87zPaD5Eq8BK1YMGCbPe9Kmq9gF6k9db5w79s/kd4kECBa1XaxnO+STZAoCCU93zIkCFxBQgUwPW/984770T8rhtvvDFkujvuuCNfAgQ6vvyfUYVmKvzwww/uGiSe/eM/5yf7O6ryoMrsWJ/bsGFDcPpo21fn/ezKsb88ZRfg7t+/f8x5VatWLfDFF19ErdxW8FaBgUifVWV+IsLLt4KikeY7YMAAV8Eb6T2tT6qPewUo/AEz/3acOnVq1Oue8OvKZAIERx11VMTfWl0D6Lo5t8t7tEpt/XYouBhru3q/EzkJEBx77LFZpk00QKCATaTv1Pl8y5YtaREgUHBT5TDWfA499NDAmjVrclxWEtl3AIDoSDEEADEoZ3T4OATdu3e3ZcuWuZzS0q
FDB9ct+6233rKNGze6HPOHHXZYlhRE7du3Dz6vU6eOG/SrXr16Lh2BUs9ofi+//LLLR60UBEpJ8eabb7rpFy9ebO+
"text/plain": [
"<Figure size 1560x540 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: margin_wasserstein.png\n",
"\n",
"Mean W₁ across 7 (subject × pair) comparisons:\n",
" MI trials: 1.207 ± 1.150\n",
" REST trials: 0.785 ± 0.623\n",
" W₁(MI) > W₁(REST) in 6/7 comparisons — FES perturbs MI margins more than REST margins\n"
]
}
],
"source": [
"w_results = [] # one row per (subject, pair) — Wasserstein between FES and NOFES margin dists\n",
"for subj in subjects:\n",
" for pair in PAIRS:\n",
" fes = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='FES'), None)\n",
" nof = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='NOFES'), None)\n",
" if fes is None or nof is None: continue\n",
" if fes['margin'].size == 0 or nof['margin'].size == 0: continue\n",
"\n",
" def w1(fe_mask, no_mask):\n",
" f = fes['margin'][fe_mask]; n = nof['margin'][no_mask]\n",
" return wasserstein_distance(f, n) if (f.size and n.size) else np.nan\n",
"\n",
" w_results.append(dict(\n",
" subject=subj, pair=pair['name'].split()[0],\n",
" w_mi = w1(fes['y_test'] == 1, nof['y_test'] == 1),\n",
" w_rest = w1(fes['y_test'] == 0, nof['y_test'] == 0),\n",
" w_all = wasserstein_distance(fes['margin'], nof['margin']),\n",
" fes_mi_mean = fes['margin'][fes['y_test']==1].mean() if (fes['y_test']==1).any() else np.nan,\n",
" nofes_mi_mean = nof['margin'][nof['y_test']==1].mean() if (nof['y_test']==1).any() else np.nan,\n",
" ))\n",
"\n",
"if w_results:\n",
" labels = [f'{r[\"subject\"]}\\n{r[\"pair\"]}' for r in w_results]\n",
" w_mi = np.array([r['w_mi'] for r in w_results])\n",
" w_rest = np.array([r['w_rest'] for r in w_results])\n",
" x = np.arange(len(labels))\n",
"\n",
" fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))\n",
" fig.suptitle('Wasserstein-1 distance between FES and NOFES decision-margin distributions',\n",
" fontsize=12, fontweight='bold', y=1.02)\n",
"\n",
" width = 0.4\n",
" axes[0].bar(x - width/2, w_mi, width, color='#E05C2A', label='MI trials', edgecolor='white')\n",
" axes[0].bar(x + width/2, w_rest, width, color='#2A7BE0', label='REST trials', edgecolor='white')\n",
" axes[0].set_xticks(x); axes[0].set_xticklabels(labels, fontsize=8)\n",
" axes[0].set_ylabel('W₁(FES margin, NOFES margin)')\n",
" axes[0].set_title('Per-(subject × pair), by trial class', fontweight='bold')\n",
" axes[0].legend(); axes[0].grid(axis='y', alpha=0.3)\n",
" axes[0].spines[['top','right']].set_visible(False)\n",
"\n",
" # Direction of the shift (FES mean − NOFES mean) on MI margins\n",
" fes_mi = np.array([r['fes_mi_mean'] for r in w_results])\n",
" nofes_mi = np.array([r['nofes_mi_mean'] for r in w_results])\n",
" mi_shift = fes_mi - nofes_mi\n",
" colors_sh = ['#E05C2A' if d > 0 else '#2A7BE0' for d in mi_shift]\n",
" axes[1].bar(x, mi_shift, color=colors_sh, edgecolor='white')\n",
" axes[1].axhline(0, color='k', lw=0.8)\n",
" axes[1].set_xticks(x); axes[1].set_xticklabels(labels, fontsize=8)\n",
" axes[1].set_ylabel('mean(FES MI-margin) − mean(NOFES MI-margin)')\n",
"    axes[1].set_title('Direction of MI-margin shift\\n(bar > 0: FES MI margins higher on average than NOFES)', fontsize=9, fontweight='bold')\n",
" axes[1].grid(axis='y', alpha=0.3)\n",
" axes[1].spines[['top','right']].set_visible(False)\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig('margin_wasserstein.png', dpi=150, bbox_inches='tight')\n",
" plt.show()\n",
" print('Saved: margin_wasserstein.png')\n",
"\n",
" print(f'\\nMean W₁ across {len(w_results)} (subject × pair) comparisons:')\n",
" print(f' MI trials: {np.nanmean(w_mi):.3f} ± {np.nanstd(w_mi, ddof=1):.3f}')\n",
" print(f' REST trials: {np.nanmean(w_rest):.3f} ± {np.nanstd(w_rest, ddof=1):.3f}')\n",
"    n_mi_gt = int((w_mi > w_rest).sum())\n",
"    verdict = ('FES perturbs MI margins more than REST margins' if n_mi_gt > len(w_results) / 2\n",
"               else 'no consistent MI > REST difference')\n",
"    print(f'  W₁(MI) > W₁(REST) in {n_mi_gt}/{len(w_results)} comparisons: {verdict}')\n",
"else:\n",
"    print('No comparisons available for Wasserstein analysis.')"
]
},
{
"cell_type": "markdown",
"id": "b3db60ba",
"metadata": {},
"source": [
"---\n",
"## Summary Statistics"
]
},
{
"cell_type": "markdown",
"id": "9f27a80e",
"metadata": {},
"source": [
"---\n",
"## Figure 4 — Within-session MI-accuracy trajectory (sliding window)\n",
"\n",
"Refined from the original \"thirds\" analysis to address its noise floor. Each online session\n",
"is parsed into chronological MI-only trials, then a W-trial sliding window gives a smoother\n",
"accuracy trajectory. The slope of a linear fit to this trajectory is a per-session \"learning\n",
"rate.\" Restricting to MI trials isolates the class that FES actually perturbs (orthotic +\n",
"stimulation fires on MI, not REST) — slope differences between FES and NOFES sessions would\n",
"indicate whether proprioceptive feedback accelerates within-session adaptation.\n",
"\n",
"Caveat: even with sliding windows, ~25–30 MI trials per session is low-N and the slope\n",
"estimator is noisy. Consecutive windows also share W − 1 trials, so points on a trajectory\n",
"are strongly correlated. Read this as exploratory, not confirmatory."
]
},
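{
"cell_type": "markdown",
"id": "4b9d2e61",
"metadata": {},
"source": [
"A minimal sketch of the trajectory-and-slope computation described above, on a made-up sequence of MI outcomes (not study data). It mirrors the analysis cell's list comprehension, shows that the same sliding-window means can be obtained with `np.convolve(..., mode='valid')`, and takes the slope as the first coefficient of a degree-1 `np.polyfit`.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"WINDOW = 8  # trials per sliding window, as in the analysis cell\n",
"\n",
"# Hypothetical chronological MI outcomes for one session (1 = correct, 0 = incorrect)\n",
"mi_successes = np.array([0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1])\n",
"\n",
"# Accuracy in each window of WINDOW consecutive trials (len = n_trials - WINDOW + 1)\n",
"traj = np.array([mi_successes[i:i + WINDOW].mean()\n",
"                 for i in range(len(mi_successes) - WINDOW + 1)])\n",
"\n",
"# Equivalent moving average via convolution with a flat kernel\n",
"traj_conv = np.convolve(mi_successes, np.ones(WINDOW) / WINDOW, mode='valid')\n",
"\n",
"# Per-session learning rate: change in windowed accuracy per one-trial window shift\n",
"slope = np.polyfit(np.arange(len(traj)), traj, 1)[0]\n",
"print(len(traj), np.allclose(traj, traj_conv), round(float(slope), 4))\n",
"```\n",
"\n",
"With 16 trials and a window of 8 there are only 9 windows, and neighboring windows share 7 trials, which is why the slope estimate stays noisy at realistic per-session trial counts."
]
},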
{
"cell_type": "code",
"execution_count": 14,
"id": "a9db789c",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:28.705422Z",
"iopub.status.busy": "2026-04-22T19:27:28.705328Z",
"iopub.status.idle": "2026-04-22T19:27:28.901376Z",
"shell.execute_reply": "2026-04-22T19:27:28.900874Z"
}
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABggAAAIqCAYAAADxQcHcAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzsnQV4HNf19s+y0JZJMskgWWaSKXbiMHPaQJv8g23SNE3zFZJiqBxqmrZpoA1jm7RpsGF0yDEzyWxLtmWSZNHy97x3NOvZ1Ura1e6stKv393gfjXd26M6dmTsH3mMJBoNBIYQQQgghhBBCCCGEEEJIj8La1TtACCGEEEIIIYQQQgghhJDUQwcBIYQQQgghhBBCCCGEENIDoYOAEEIIIYQQQgghhBBCCOmB0EFACCGEEEIIIYQQQgghhPRA6CAghBBCCCGEEEIIIYQQQnogdBAQQgghhBBCCCGEEEIIIT0QOggIIYQQQgghhBBCCCGEkB4IHQSEEEIIIYQQQgghhBBCSA+EDgJCCCGEEEIIIYQQQgghpAdCBwEhhBBCCCGEEEIIIYQQ0gOhg4AQQgghhBBCCCGEEEII6YHQQUAIIYSQMH71q1+JxWIJfdKBjz/+OGyf8f94GTFiRGj5K6+8Mubl8Ft9OayDSMa1XWf7RrrRU44znVm6dKnYbDZ1jo4++uhuez+Nl+OOOy60PUyTnnkPf/LJJ8P63tatWzu9rkcffTS0nh/96EcJ7xshhBCSydBBQAghhESAF1LjCyo+AwcOFLfbHbWtpk6d2ur3xpfayPXR8JZaItsfBghCCElHR8mNN94ogUBATf/iF7/o6t0hKWD16tVy7bXXyrhx4yQvL0/sdrv07t1bysvL5ac//alUVlbGtb6e4oy5/PLLZciQIWr6wQcflE2bNnX1LhFCCCHdFntX7wAhhBCSDuzZs0eee+45+da3vhX2/fvvvy/Lly/vsv0iGqWlpXLPPfeE/T9VfPOb35SJEyeqaRhtCNsuXbn55pultrZWTet9mnQf3nvvPfnoo4/U9JgxY+SMM87IuPspCefNN9+Ur33ta+LxeMK+r6urk2XLlqkPIuU///xz5UBINd35+ed0OuX666+XX/7yl6r9kB35zDPPdPVuEUIIId0SOggIIYSQGPnzn//cykFw7733sv26AcXFxXLTTTd1ybZPO+009SFsu66kublZSc84HI5Or+Oaa66RTODQoUOSn58vmcYDDzwQmr744osz8n5KwvnJT34Scg7A4H3VVVfJ0KFD5cMPPww5iw4ePKgcOo8//njKmg8Oil69enX75x+uEzgIwL///W+57777pH///l29W4QQQki3gxJDhBBCSAfA6AZWrlypIjh11qxZI++8807Yb8yOJPz617+uUuZhKMDLOeSN8PKLDIeOZATwm+uuu04t73K5pKysTBkVgsFgh9v2+XzKaNSenq9ROxjtsXPnzg4NH/rvR48eHTbv9NNPD80zOmEgpWCUC3r33Xfb1cyGVMjIkSPD1g0DSyw1FrZv3y5XXHGFFBYWqvZClORTTz0VlwZzpFTJxo0b5f/+7/86XGcs4NjPPffcUH/IycmRYcOGyfHHH69kJzZs2NBqmbVr16o+MHbsWMnNzZXs7GwViYzzGU2mAkbnu+66S4444ggpKChQ0hZ9+/ZVy1x44YVy9913J/T7jvSrYejFMkceeaT06dNHGb/RdqeccoqKBNWlVtqTk/rggw/kxBNPVNcLjhntM3/+fEkmO3bsUP158uTJyjCNc1tSUqIM7uvXr2/1e+wn2vyYY46R4cOHq2VwDnFsJ5xwgvzjH/8Qv9/fajnjsSEadt68eXLSSSeptsa5xDlMpA3aktZJtF0XL16sot3xWxwrjhHXaGf1xiP7zb59++R73/ueMpyiz+n3DGzj6quvlhkzZsjgwYNVG2VlZanrBPdS/f4Rud5t27aFvsP12ZYeP+6dL774opx11l
kyaNAgdQ5xLnBecQ5x30wWVVVV8sYbb4T+f9FFF6X8fhqtRg2k9+644w51T0G/hxzfd7/7XWVAjgTX6/333y8TJkxQ5wHnBPcjnL+OqK6ulltvvVWmT5+uItXR1lge98DXX3897LfYNvqBvo/PP/98aN5XX30V+h7nyngPwfWqz8M9pzuwefPmsP17+OGH5ZZbblFjEWPE/t69eztcl37uPvnkk9B3mI4mwRd5nhsaGpSkFe5ruA//v//3/zq8h7/yyitK5mfKlCmqX6B/4DmFjJRLL71UnYt4iPf5ArBPM2fOVNPoq08//XRc2ySEEEJ6DEFCCCGEhLFlyxZYzEOfr3/966Hp0047LfS7b3/726Hvzz///LBlsI621nfFFVfE1eJ+vz945ZVXhq0j8tO/f//gF198EbbcscceG5pfUlISHDx4cNRlf/WrX4Utd/vtt4fN1/nDH/4Q+q5v377BpqamsOVOPfXU0PzTTz+9w+N66623wraza9cu9b3P5wvm5+eHvj/nnHNCyzz33HOh751OZ7ChoUF9/9FHH4WtC/8Hw4cPb7fdjMdn/O3MmTOD/fr1i/r7J598Muw4cD71eViHEeM6J0+eHOzVq1dM6+yIZ599tsPjeuKJJ8KWefTRR1WbtfX7Pn36BD/77LOwZU4++eSY268zv2+v7SoqKlS/bW9dJ510UrCxsbHNa+2oo44KWiyWVstlZ2cH165dG3N7G89j5PX7v//9L6y/Rn6ysrKCL730Utgyr7/+eofthGsI174R4/w5c+YEbTZbq/tOIm3Q1nEmss4PPvgg6HK5Wv3WarUGzzrrrFb7HwvGfoN739ixY8PWg3sYuPHGGztsZ9zXoq23rY9+b2lubg6eccYZ7f72uOOOC92jEuWpp54Ku/92xf002vPh6KOPjnrsxx9/fKt9vOqqq6L+trS0NDh+/PjQ//HsMjJ//vzggAED2m3ryy67LOx6OeKII0Lzrr322tD3d911V9hyS5YsCc0rKysLfX/LLbfEdF4i2yqWj7E9O2LatGmh5crLy4ObNm1Sfe+1115T15A+729/+1uH64o8d+09Nzo6z/r9ob17eOS4KNo9AM8yI9h+W/eEeJ8vOj/84Q/DnhmEEEIIaQ0lhgghhJAOmDRpkopkRsQeMgYQhd2vXz959tln1XxEr5155pny0ksvmdKWiPI3FtZF1DmiJpERgOhWr9erIjDxXUVFRVQdYEQhImIT0ZqIon3ooYekqalJzfvTn/6kshA6kib5zne+I7/97W/VcgcOHFDp+pdddpmah+0jolgHUbsdcfTRR6ttYv8BoqERFbt06VLV3jqffvqpivK0Wq1hkY9z5sxR0YgdaaojKvkPf/hD6LtvfOMbKqK4PRYuXKgi1hHljeN95JFHQhHdd955p8osiJcVK1YkbZ2IwtXRoycRUYusDfTPL7/8Muz3iNTE+dOjZdGn0V9gc/7Xv/6lijdCpgJa13ofWrduXVjGDOah3XBusB1sw1j0Md7ftwfa5bzzzguLnsUxjh8/XvWzzz77LFQD5Ac/+IGK1o4GdLkR2YxocWh1IwsHoP3/8pe/qOsgERBpjv1qbGxU/0e2CvowrrVXX31VbRNRr8gaQaFRRN8CRL4iqhbtM2DAANXe2Cf0fUSJ47y89dZb8t///lcuuOCCqNtGe6L/X3LJJSoiHhlOxuvJzDaIdZ04dtwjjAXecf0he+m1114Li4jvLLj34INshrlz56p+rBcmRWYD7jPI7ECkMdoLNRbQbxYtWhSKlMa1h2h0XU8d9wusB+AcYZ8j9fhRLFg/btybcJ5wXaFPILsFx4zI+x/+8Idt9s94wP1RR4+ITvX9NBpYH651XJuo06NngUD+BvcdRHsDXA9PPPFEaLmioiIVXY52wvfGfYzMBjjnnHNCEfK4dtCnkC2CPqTXAEKbo0/qcjLoD3qEurHtjNMAbYBiv7t27VL3Ph0s3x3461//qjJUampq1LmMrA
eBjBw8U/Bs7whkXqHIMa5P/d6Ke5Jx2Wh9Sz/Ps2bNkpNPPlmdMzzLOgJR/shwQt/A73FfxLX6v//9Tz0v0A9x/z7
"text/plain": [
"<Figure size 1560x540 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: within_session_mi_trajectory.png\n",
"\n",
"Condition-averaged slopes (FES − NOFES paired per session):\n",
" FES: mean slope = +0.0020 (n=8)\n",
" NOFES: mean slope = +0.0045 (n=8)\n"
]
}
],
"source": [
"WINDOW = 8 # trials per sliding window\n",
"\n",
"lr_results = []\n",
"for r in results:\n",
" mi_successes = np.array([int(t['success']) for t in r['trials'] if t['cls'] == 'MI'])\n",
" if len(mi_successes) < WINDOW + 2:\n",
" continue\n",
" traj = np.array([mi_successes[i:i+WINDOW].mean()\n",
" for i in range(len(mi_successes) - WINDOW + 1)])\n",
" slope = np.polyfit(np.arange(len(traj)), traj, 1)[0]\n",
" lr_results.append(dict(subj=r['subject'], pair=r['pair'], cond=r['condition'],\n",
" traj=traj, slope=slope, n_mi=len(mi_successes)))\n",
"\n",
"# ── Figure 4a: slope per condition (FES vs NOFES) ─────────────────────────────\n",
"fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))\n",
"fig.suptitle(f'MI-only within-session learning rate (window = {WINDOW} trials)',\n",
" fontsize=12, fontweight='bold', y=1.02)\n",
"\n",
"fes_slopes = np.array([r['slope'] for r in lr_results if r['cond'] == 'FES'])\n",
"nofes_slopes = np.array([r['slope'] for r in lr_results if r['cond'] == 'NOFES'])\n",
"\n",
"ax = axes[0]\n",
"means = [fes_slopes.mean(), nofes_slopes.mean()]\n",
"sems = [fes_slopes.std(ddof=1) / np.sqrt(len(fes_slopes)),\n",
" nofes_slopes.std(ddof=1) / np.sqrt(len(nofes_slopes))]\n",
"ax.bar(['FES', 'NOFES'], means, yerr=sems,\n",
" color=[cond_color['FES'], cond_color['NOFES']],\n",
" capsize=6, edgecolor='white')\n",
"# Overlay individual points\n",
"for i, slopes in enumerate([fes_slopes, nofes_slopes]):\n",
" ax.scatter(np.full(len(slopes), i) + np.random.uniform(-0.08, 0.08, len(slopes)),\n",
" slopes, color='k', alpha=0.5, s=20, zorder=3)\n",
"ax.axhline(0, color='k', linestyle='--', lw=0.8)\n",
"ax.set_ylabel('Slope of MI-accuracy trajectory\\n(Δ fraction correct / trial-step)')\n",
"ax.set_title('Condition-averaged learning rate', fontweight='bold')\n",
"ax.spines[['top','right']].set_visible(False)\n",
"ax.grid(axis='y', alpha=0.3)\n",
"\n",
"# ── Figure 4b: trajectories per (subject × pair) ─────────────────────────────\n",
"ax = axes[1]\n",
"for r in lr_results:\n",
" ax.plot(r['traj'], color=cond_color[r['cond']], alpha=0.6, lw=1.5,\n",
" label=r['cond'])\n",
"ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)\n",
"ax.set_xlabel(f'Sliding window start (MI trial index, window = {WINDOW})')\n",
"ax.set_ylabel('MI accuracy within window')\n",
"ax.set_title('All session trajectories overlaid', fontweight='bold')\n",
"# Dedup legend\n",
"h, l = ax.get_legend_handles_labels()\n",
"seen = set()\n",
"handles = [(hh, ll) for hh, ll in zip(h, l) if not (ll in seen or seen.add(ll))]\n",
"ax.legend([h for h, _ in handles], [l for _, l in handles], loc='lower left')\n",
"ax.spines[['top','right']].set_visible(False)\n",
"ax.grid(axis='y', alpha=0.3)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('within_session_mi_trajectory.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: within_session_mi_trajectory.png')\n",
"print(f'\\nCondition-averaged slopes (FES − NOFES paired per session):')\n",
"print(f' FES: mean slope = {fes_slopes.mean():+.4f} (n={len(fes_slopes)})')\n",
"print(f' NOFES: mean slope = {nofes_slopes.mean():+.4f} (n={len(nofes_slopes)})')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "ddedb148",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABoAAAARRCAYAAAAGpkARAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3QWYlFXbB/B7Z7ZZdulGVMAA7FZAMD+7X+NVsTEwsVswsBs7sDteG+x4fcVCpUQRyQWW7Zqder7rf5ZnPM+zM7Mzs7OT/5/XXLKTT8e5z32fHMMwDCEiIiIiIiIiIiIiIqKM4Uj2BBAREREREREREREREVF8MQBERERERERERERERESUYRgAIiIiIiIiIiIiIiIiyjAMABEREREREREREREREWUYBoCIiIiIiIiIiIiIiIgyDANAREREREREREREREREGYYBICIiIiIiIiIiIiIiogzDABAREREREREREREREVGGYQCIiIiIiIiIiIiIiIgowzAARERERERERERERERElGEYACIiIiIiIiIiIiIiIsowDAARUdaaP3++XHrppbLzzjtL3759JT8/X0pKSmTTTTeVY489Vp5//nlxuVyS7WpqauTGG29Uy6l79+6Sl5cnPXr0kGHDhsm4cePkvPPOk8cff1z8fn9cfzcnJyfwwO9E4/rrr7d8/vPPP5dM0tzcLGVlZZZ5xOO///1vsict62244YZt1kskj5NOOqnTl93TTz9t+U38TZkFx0p9HetwHNRfw3GyM3Tk2J0Ic+fOlQsvvFC222476dWrlzqn4f/bbrutev63335LyPEB/860/fN///ufnHvuubLNNttI79691bLt2rWrjBo1Sk455RR5++23xev1SjoItx3H6xoj3P6aDjwej9x9992yww47qOvnwsJCGTRokIwdO1Yuu+wyWbp0acqeRzvrOBVuHw8F0x3L/Eb6/enm77//VtsS5rGgoEBWrFjRqb+V6GuxVPjtdFVdXS0333yz7LbbboFzTFFRkQwZMkQOO+wwee2118QwjDaf23///QPLedq0aUmZdiKiZGMAiIiyTkNDg5x44omqQeL222+X7777TtauXatuZBsbG2XRokXy0ksvyfHHHy9XXnmlZLMFCxbIyJEj5ZprrlHLCcEgNN7gAnzx4sXyxRdfyP333y+nn366uN3uZE9uSuqMRj00otXV1bV5/tlnn+3wdxMlQqKCERS5TA+co0MHzlVbbrml3HPPPfLTTz9JZWWlOqfh/z///LN6fquttpLTTjuNHUCisGbNGtXAtssuu8gDDzwgc+bMkXXr1qlli2uuefPmyVNPPSWHHnqoTJ8+XTJdugd2ItHS0iJ77bWXXHTRRfLDDz+o62c8t3LlSvnqq6/ktttuU/sUUbTQOQ/bEkyYMEEFFYn+/PNP2WKLLeSqq65SHd7McwzO7cuWLZO33npLjjrqKPWwB4FwH2tCAAnnLCKibJOb7AkgIkqk+vp6GTNmjPzyyy+W5/v3768ahZxOp+qRhcAHLh7jndWSTjD/xxxzjKxatSrw3ODBg2XzzTdXPfLKy8tVow6yUVLNiBEj5Igjjgj8jV5imeSZZ54J+vwrr7wi9957r8pmo+RAIygCyjo0juk9obEPYRvVoQd1Z0NvYX2/yNTewxQcjoP6+rdvg5kMjUQHHHCAfPrpp5bn0ZiEnsM47yMzyDz3PfHEE+q5Dz/8UHJzE3O7lK77Jxr8EfhZvny55XlMP7Yxn8+nGu7QaQQy4boqXtcYu+++u8o+S0cPP/ywfPnll4G/e/bsqbYDBPx+/PFHdb2dyudRff2ho1MyYbqx3HSYV8yzCdsJthddnz59JNPMnj1bXn31VfVvBE8RDOpMXbp0sWwLibgWo9ggQxfnG33djR49WnVORCdF0+uvv646cqKahwnHJtz/IziNY9OUKVPkwQcf5KogouxiEBFlkaOPPhpdggKPoqIiY8aMGYbf77
e8b9myZcZFF11kXHLJJUa2+uGHHyzL6txzz23znpaWFuOjjz4yjj32WPXveNJ/e/fddzfS1VNPPWWZF/zdEatXrzacTmfg+/Ly8izf//rrr8dt2ik+JkyYYFlH1113XdYv2s8++4zLpBPgWKlva9HAdql/FusoU47dN9xwg2W6unTpos5dug8//FA9r79vypQpcZ2OIUOGBL4b/84EO++8s2WZ9ezZ03jvvffavG/BggXGSSedZNx7771GOujIdtyR/TBdHHTQQZZ5/PvvvwOv4XrwxRdfNL7//vusO4/Gax+3XzumyrG0sx133HGBed5tt92MTLZkyRLLOsY2TqGVlJQElhXug37//feQ1y9nn312m88/9thjlmuAmpoaLm4iyiosAUdEWQMlSV5++eU2mRQoB2cv0YFMlzvvvFOmTp0atlbz6tWr5cwzz1TvRy9he/1m/OYZZ5whm222WaA+Ot57+OGHqzJewaCUGsqqoaefXkN/o402UuU2UJZO7xUI6M100003qXF6MD4PPtOtWzc1Tg96Pd9www3y+++/R7W8/vjjD8vf48ePb/MeZJrss88+8sILL7TJOrHXNMfy60i5ISwXzCN6fWI59uvXT04++eQ2vY4j/W4ss7vuuivQAxfTjx6se+65pypVE26cAvRkvvjii9UYEhgXCZ/F9KAm9bXXXqvKCZolrjCNOvzdkTJLWNboUW26+uqrw5aBQ5kE/fdQl98O04tlYL7H3hsWPbaRXXTwwQfLwIEDVQZYaWmpbL/99mofQUlAu2j2F/QOR0kH9PodPnx4YFwObPfo4YuyTSjXFApKOZx11lkyYMAAtW1gu8f3oSRNJKV40NMWvUwxbgXGVsL6xHyijIQ9Y6Cz2KcTyxy9EzEuCXo56tOO7RYlKlGqypxnPDDNWIYzZswI2ss+knKEKLuCMb323XffwNho2MaxbaM8VriMP2QFYt/bdddd1b6EdYhjGLYTLN+qqqrAdmE/nuAYFW7asI2g9ycyNbHtYbqQuYn5xTYfbH8NdhyYOXOm7L333mqe8Bwy5vT3TJo0qc33YFliOZvvwfE8Usg2xdgnm2yyiVqPWCZYrlh3KG2DdayXzwxWGg/HuFNPPTWw72H7vu6666SpqSni6Qj13fp2gXWgwzoKdwyPVCTHbmx3eN78LWTmBHP00UcH3oNjSCRjQ5jHeh3+xrlLh23e/j5cB+g984PtQwsXLpR///vfar1i/WBdo8SMfpyORHv7p31sEWzz5jGiuLhYnfMPPPBA+fXXX8OOfXjOOeeoYzyOr5hefBe2xXDH2FBQcgfj/piQRf3ee++p/dIO+w3OrRMnTmzzWrz2b/QCP+SQQ9TxB9satqNHH3005PTjWuqggw5Syw7XaLiGQs/x9oS6xjCve1AeVxdq3JZIzk/JWjbtsV/zoRyT/hoyyHHsT6RozqPhxgDCvnDJJZeoa+6hQ4eq8wWON7g+2HrrreX8889vc42cSDivdOS8Fez6/J133lHHfMwj9gWM4/Tuu++GnIaKigp1zjDHB8U6xzEQx6A33ngjpvnCd2IMF9Nxxx1neR3ZH/p0o/ygDvuy+Rr+rcPYZPpnzeoG7Y3DE2xZffDBB+o6AscNHHt33HFHefPNN0POF647sGxxzMXy3WOPPWTWrFkRLxesG2QpbbDBBmrfNcerxbVxsDKLuI42pxfbvv2cZr6G6devFZE5o88r7jdSiX7MwbTjXGuyH2twrLQ78sgjA9+B+wNcKxMRZZVkR6CIiBLliiuusPQO2nbbbTvUU2vcuHHGgAEDQvbeQo9jh8Nhed3+QA/KpqamwGeQibTvvvuG/Qwe55xzTuAzzc3NxpZbbtnuZ26//fao5heZJPrnhw4dqrKlVq5cGVNvTSy/aHqb66/tsMMOxujRo4POV58+fVTP4mi+++eff7b00gz2GDNmjFFdXd1mvu677z4jPz8/7GfxOXuGQ6hHtL3st9lmm8BnCwoKjNraWm
OnnXYKPIdpq6ystHxms802C7y+wQYbtMl4e+eddyzTdOeddwZeq6qqUtt6uHkYNGiQMWfOnJj3l2effbbd5YTefo8
"text/plain": [
"<Figure size 1680x1080 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: avg_within_session_trajectory_cross_overlay.png\n"
]
}
],
"source": [
"# Recompute sliding-window trajectories for both MI and REST trials (used by the 2x2 overlay below)\n",
"WINDOW_BOTH = 8\n",
"\n",
"lr_results_all = []\n",
"for r in results:\n",
" for trial_type in ['MI', 'REST']:\n",
" successes = np.array([int(t['success']) for t in r['trials'] if t['cls'] == trial_type])\n",
" if len(successes) < WINDOW_BOTH + 2:\n",
" continue\n",
" traj = np.array([successes[i:i+WINDOW_BOTH].mean()\n",
" for i in range(len(successes) - WINDOW_BOTH + 1)])\n",
" lr_results_all.append(dict(\n",
" subj=r['subject'], pair=r['pair'], cond=r['condition'], \n",
" trial_type=trial_type, traj=traj\n",
" ))\n",
"\n",
"# Create a 2x2 grid: Rows = MI vs REST, Cols = ONLINE FES vs ONLINE NOFES\n",
"# In each subplot, overlay Pair 1 (FES CALIB) vs Pair 2 (NOFES CALIB)\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 9), sharey=True, sharex=True)\n",
"fig.suptitle(f'Cross-Subject Average Trajectory split by Online Condition & Trial Type (window = {WINDOW_BOTH})\\n'\n",
" 'Lines overlay Calibration Types: Pair 1 (FES) vs Pair 2 (NOFES)',\n",
" fontsize=13, fontweight='bold', y=1.02)\n",
"\n",
"TRIAL_TYPES = ['MI', 'REST']\n",
"ONLINE_CONDS = ['FES', 'NOFES']\n",
"\n",
"# Calibration-pair colors, chosen to differ from the online-condition colors (cond_color):\n",
"# Pair 1 (FES_CALIB)   -> dark orange\n",
"# Pair 2 (NOFES_CALIB) -> dark blue\n",
"calib_style = {\n",
" 'Pair1': {'color': '#d95f02', 'label': 'FES_CALIB (Pair 1)'},\n",
" 'Pair2': {'color': '#1f78b4', 'label': 'NOFES_CALIB (Pair 2)'}\n",
"}\n",
"\n",
"import warnings\n",
"\n",
"for row_idx, trial_type in enumerate(TRIAL_TYPES):\n",
" for col_idx, online_cond in enumerate(ONLINE_CONDS):\n",
" ax = axes[row_idx, col_idx]\n",
" \n",
" for pair_prefix, style_info in calib_style.items():\n",
" # Extract trajectories matching this specific combination:\n",
" # (trial_type, online_cond, pair starts with Pair1 or Pair2)\n",
" match_subset = [r['traj'] for r in lr_results_all \n",
" if r['trial_type'] == trial_type \n",
" and r['cond'] == online_cond \n",
" and r['pair'].startswith(pair_prefix)]\n",
" \n",
" if not match_subset:\n",
" continue\n",
" \n",
" max_len = max(len(t) for t in match_subset)\n",
" padded = np.full((len(match_subset), max_len), np.nan)\n",
" for j, t in enumerate(match_subset):\n",
" padded[j, :len(t)] = t\n",
" \n",
" with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\", category=RuntimeWarning)\n",
" mean_traj = np.nanmean(padded, axis=0)\n",
" n_present = np.sum(~np.isnan(padded), axis=0)\n",
" sem_traj = np.nanstd(padded, axis=0, ddof=1) / np.sqrt(n_present)\n",
" \n",
" x_vals = np.arange(max_len)\n",
" ax.plot(x_vals, mean_traj, color=style_info['color'], label=style_info['label'], lw=2)\n",
" ax.fill_between(x_vals, mean_traj - sem_traj, mean_traj + sem_traj, \n",
" color=style_info['color'], alpha=0.15, edgecolor='none')\n",
" \n",
" ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)\n",
" ax.set_title(f'ONLINE: {online_cond} | Trials: {trial_type}', fontweight='bold')\n",
" ax.grid(axis='y', alpha=0.3)\n",
" ax.spines[['top','right']].set_visible(False)\n",
" \n",
" if row_idx == 1:\n",
" ax.set_xlabel('Sliding window start (trial index)')\n",
" if col_idx == 0:\n",
" ax.set_ylabel('Average Accuracy within window')\n",
" if row_idx == 0 and col_idx == 0:\n",
" ax.legend(loc='lower left')\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('avg_within_session_trajectory_cross_overlay.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: avg_within_session_trajectory_cross_overlay.png')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "cf55268e",
"metadata": {
"execution": {
"iopub.execute_input": "2026-04-22T19:27:28.903912Z",
"iopub.status.busy": "2026-04-22T19:27:28.903829Z",
"iopub.status.idle": "2026-04-22T19:27:28.914870Z",
"shell.execute_reply": "2026-04-22T19:27:28.914546Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Aggregate across 8 complete (subject × pair) comparisons ===\n",
"\n",
"Metric FES NOFES paired Δ 95% CI (bootstrap) sign\n",
"----------------------------------------------------------------------------------------------------------------\n",
"Overall accuracy (markers) 0.821 ± 0.077 0.806 ± 0.075 +0.015 [ -0.040, +0.065] 4/7\n",
"MI accuracy 0.817 ± 0.108 0.775 ± 0.115 +0.042 [ -0.079, +0.142] 5/7\n",
"REST accuracy 0.825 ± 0.100 0.837 ± 0.061 -0.012 [ -0.070, +0.039] 4/7\n",
"MI EARLYSTOP latency (s) 2.523 ± 0.204 2.235 ± 0.209 +0.288 [ +0.058, +0.530] * 6/8\n",
"REST EARLYSTOP latency (s) 2.135 ± 0.303 2.391 ± 0.270 -0.256 [ -0.416, -0.077] * 2/8\n",
"Classification amplitude 2.237 ± 2.063 2.362 ± 1.813 +0.151 [ -0.142, +0.476] 4/7\n",
"Fisher ratio (test SNR) 2.307 ± 2.712 2.542 ± 3.263 +0.090 [ -1.065, +1.225] 3/7\n",
"μ-band SNR (REST/MI) 1.751 ± 0.671 1.747 ± 0.480 +0.065 [ -0.161, +0.299] 4/7\n",
"\n",
"* = 95% bootstrap CI excludes zero (suggestive at this n).\n",
"Sign column: number of (subject × pair) comparisons where FES > NOFES.\n"
]
}
],
"source": [
"def boot_paired_ci(deltas, n_boot=10000, alpha=0.05, seed=0):\n",
" \"\"\"Bootstrap percentile CI on the mean of paired differences.\"\"\"\n",
" rng = np.random.default_rng(seed)\n",
" d = np.asarray([x for x in deltas if x is not None and not (isinstance(x, float) and np.isnan(x))])\n",
" if len(d) == 0:\n",
" return (np.nan, np.nan, np.nan)\n",
" boot_means = d[rng.integers(0, len(d), size=(n_boot, len(d)))].mean(axis=1)\n",
" lo, hi = np.percentile(boot_means, [100*alpha/2, 100*(1-alpha/2)])\n",
" return float(d.mean()), float(lo), float(hi)\n",
"\n",
"\n",
"# Collect all (subject × pair) where both FES and NOFES exist\n",
"paired = []\n",
"for subj in subjects:\n",
" for pair in PAIRS:\n",
" fes = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='FES'), None)\n",
" nof = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
" and r['condition']=='NOFES'), None)\n",
" if fes and nof: paired.append((fes, nof))\n",
"\n",
"print(f'=== Aggregate across {len(paired)} complete (subject × pair) comparisons ===\\n')\n",
"\n",
"METRICS_SUMMARY = [\n",
" ('acc', 'Overall accuracy (markers)'),\n",
" ('mi_acc', 'MI accuracy'),\n",
" ('rest_acc', 'REST accuracy'),\n",
" ('mi_latency', 'MI EARLYSTOP latency (s)'),\n",
" ('rest_latency', 'REST EARLYSTOP latency (s)'),\n",
" ('amp', 'Classification amplitude'),\n",
" ('fisher', 'Fisher ratio (test SNR)'),\n",
" ('mu_snr', 'μ-band SNR (REST/MI)'),\n",
"]\n",
"\n",
"hdr = (f'{\"Metric\":<28} {\"FES\":>18} {\"NOFES\":>18} '\n",
" f'{\"paired Δ\":>9} {\"95% CI (bootstrap)\":>24} {\"sign\":>10}')\n",
"print(hdr); print('-' * len(hdr))\n",
"\n",
"def fmt_mean_sd(arr):\n",
" a = np.array([x for x in arr if x is not None and not (isinstance(x, float) and np.isnan(x))])\n",
" if len(a) == 0: return ' -- '\n",
" sd = a.std(ddof=1) if len(a) > 1 else 0.0\n",
" return f'{a.mean():>8.3f} ± {sd:6.3f}'\n",
"\n",
"for k, label in METRICS_SUMMARY:\n",
" fes_v = [f[k] for f, _ in paired]\n",
" nof_v = [n[k] for _, n in paired]\n",
" deltas = [f - n for f, n in zip(fes_v, nof_v)\n",
" if f is not None and n is not None\n",
" and not (isinstance(f, float) and np.isnan(f))\n",
" and not (isinstance(n, float) and np.isnan(n))]\n",
" if not deltas:\n",
" print(f'{label:<28} {\"(no data)\":<18} {\"\":<18}')\n",
" continue\n",
" mean_d, lo, hi = boot_paired_ci(deltas)\n",
" d_arr = np.array(deltas)\n",
" n_pos, n_neg = int((d_arr > 0).sum()), int((d_arr < 0).sum())\n",
" ci_crosses_zero = (lo < 0 < hi)\n",
" marker = ' ' if ci_crosses_zero else ' *'\n",
" print(f'{label:<28} {fmt_mean_sd(fes_v):>18} {fmt_mean_sd(nof_v):>18} '\n",
" f'{mean_d:>+9.3f} [{lo:>+7.3f},{hi:>+7.3f}]{marker} '\n",
" f'{n_pos}/{n_pos+n_neg}')\n",
"\n",
"print('\\n* = 95% bootstrap CI excludes zero (suggestive at this n).')\n",
"print('Sign column: number of (subject × pair) comparisons where FES > NOFES.')"
]
},
{
"cell_type": "markdown",
"id": "ee6e054e",
"metadata": {},
"source": [
"## Riemannian Generic Recentering & Discriminability\n",
"\n",
"To test whether FES-related changes in class discriminability are masked when only accuracy is examined, we apply a standard transfer-learning baseline: Riemannian Generic Recentering (Zanini et al., 2018).\n",
"\n",
"We also compute a discriminability metric that does not depend on the fixed CSP projections. For each session, trial covariances are computed and trace-normalized. The reference point $C_{ref}$ is the Riemannian mean of the resting-state trials, and discriminability is the affine-invariant Riemannian distance between the MI and REST class centers on the manifold of SPD matrices, $\\delta_R(C_{MI}, C_{REST})$, with $\\delta_R(A, B) = \\lVert \\log(A^{-1/2} B A^{-1/2}) \\rVert_F = \\sqrt{\\sum_i \\log^2 \\lambda_i}$, where $\\lambda_i$ are the generalized eigenvalues of the pair $(A, B)$.\n",
"\n",
"Recentering transforms each session's test data as $\\tilde{X} = C_{ref}^{-1/2} X$, which maps the reference covariance to the identity; the same offline decoder is then evaluated on these aligned session distributions."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "4cb58a3b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<>:157: SyntaxWarning: invalid escape sequence '\\d'\n",
"<>:157: SyntaxWarning: invalid escape sequence '\\d'\n",
"/var/folders/98/qwcfyxcd0c12f9zp2wjlf1x80000gn/T/ipykernel_41885/1640400890.py:157: SyntaxWarning: invalid escape sequence '\\d'\n",
" ax.set_ylabel('Riemannian distance (MI vs REST)\\n$\\delta_R(C_{MI}, C_{REST})$')\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.9375691313508494e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.669963803608868e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.645524816095128e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.663373071280961e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.6914791949032847e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.6459381005574144e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.574740959007683e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.6370933350269555e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.5834058892971734e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.619226858195733e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.802433522130159e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.223125920536493e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.7162919724570517e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.9009604106934567e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.750509135686626e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.913175562960757e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.7225637711028584e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.0329918184869934e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.8006281048873087e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.0723573088052246e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.7691939412319846e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.007728608121999e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.422868575242497e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 3.0645254120286126e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.8596526698824483e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.2336790648652972e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.8526260428090153e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.2948400779817427e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.850686253077868e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.826465184871438e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.8727151431041065e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.793714914157437e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.3041236094301544e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.7244569473274707e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.2527220293745332e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.714964358671654e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.275182870324487e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.455653261686123e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.308158695654615e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.796099171627302e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.3019580157253367e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.711078152814985e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.3069898990797965e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.646265488710886e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.4338975954189277e-13\n",
" return f(*arrays, *other_args, **kwargs)\n",
"/Users/adipu/ECE374N/Final Project/venv/lib/python3.13/site-packages/scipy/_lib/_util.py:1181: RuntimeWarning: logm result may be inaccurate, approximate err = 2.5362700984456297e-13\n",
" return f(*arrays, *other_args, **kwargs)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Riemannian Alignment & Discriminability Metrics ===\n",
"\n",
"Metric FES NOFES paired Δ 95% CI (boot)\n",
"-----------------------------------------------------------------------------------------------------\n",
"Riemannian MI– REST Distance 1.702 ± 0.588 1.687 ± 0.422 +0.016 [ -0.395, +0.371] \n",
"Aligned CSPLDA test acc 0.690 ± 0.147 0.705 ± 0.197 -0.015 [ -0.101, +0.070] \n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsAAAAJICAYAAABvzFJaAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAsTBJREFUeJzt3Qd4U/X6B/C3TdJ0lzJa9l4yXSAOQNx6VbxuUVHcW6/4v4p7oF739brluvfluq5e9xXBgYCoIKBs2VBaSnfTrP/zfdtTT9KkTZqkTZvv53kOoZknJ8nJm995f++b5PV6vUJERERElCCSW3sFiIiIiIhaEgNgIiIiIkooDICJiIiIKKEwACYiIiKihMIAmIiIiIgSCgNgIiIiIkooDICJiIiIKKEwACYiIiKihMIAmIiIiIgSCgNgIiIiIkooDICJiIiIKKEwAKa4lZSU5LP8/vvvUbvvr776yue+zz333KjdN/nCtvV/LY0lIyND+vfvLyeeeKLMnj1bvF5vwM3Xt29fn9tRbLXV7e2/3lhSUlIkJydHL5s4caJcffXV8u233zZ6PwcffHDM9j2xgPUzry/Wv6XcfvvtPo/94osvSjyIdB/f1PNq7DPC75e2gQEwEYUUuGKnHm2VlZWyfv16effdd+XUU0/VQNjj8fAVoahxOp1SWloqGzZskHnz5sk//vEPOeigg2Ts2LGyatUqbmlqcQimzftWBNvU8qyt8JhElMD22GMPGTZsmFRVVcmSJUtky5Yt9Ze999578sYbb8iZZ57pc5tjjjlGCgoKWmFtE1N72d4TJkyQLl26SFlZmaxYsUI2b95cf9miRYtkn332kU8//VQOOOAAn9thpLhz5871f+NIRTzD+p100kn1fw8fPrzFHhufZfNjY2S0PYjkeeE9Z77tmDFjor5+FDkGwETUojDSa4x41NTUyKGHHirffPNN/eUfffRRgwD4ySef5KvUgtrL9r7jjjt80gHmzp0rF110Uf3Ib3l5uUyePFmDYwQt5tu1JVj3f//73632ecbS3kTyvPADpLVeDwodUyCozTIfQsKvc7fbLU8//bT+2saISHZ2thx22GFN5vsZKioq5NZbb5XBgweL3W6XvLw8Oeuss2TTpk1Bb7Ny5UrNKRw1apQ+HvINe/bsKaeccop88cUXIeeWLV68WP785z/rF1lycrL8/e9/97nNDz/8IBdccIEMHTpUMjMzJTU1Vfr16ydTp07Vy0JNYcB1MTKBx8FzxP3df//9PmkHxvq99NJLPvc3adKkqKdEYHuZR0qgqKioWTmpuN0999yjh7c7deokNptNR/HwHnj++ef1ULi/QLl6eL2nTZsm+fn5kp6eLvvuu6+8+eab9bfBiCGCqqysLH3NDz/8cJk/f36D+3Y4HHLffffJ6aefLiNHjpRu3brp65aWlqbvkT/96U+6jQOlfAQ6RLpt2za5/PLLpU+fPvra9ejRQy677DIpLi6O2msfyvZuzecVKYzsfv3117rOhsLCQn0+4eQA//LLL/WfR+xr8F7D/gKBz5QpU+SRRx6R3bt3N3h8bLsXXnhBjj32WN1W2G54Hw0YMEC352effdboPg63v/fee2XEiBH63uzQoUNIOcCB9jnYLx511FGaH437Ofroo+v3JXjtHn/8cX198bris4B9jXkEvbH7buq99NZbb+noPD4/eB7777+/fPDBBwFfs8cee0wfe6+99tL3BtYH7xO8hvhsYz2xXUKBIwHXX3+9zjvAfeA1wHtv586dYT+vxgTLATbOx/7FDD+4/D8X+Gyaz3vllVcaPA7eY3gPGdcZN25cyOtIIph0QhSX8PY0L+vXrw96eX5+vvewww5rcBssKSkp3u+++87ntnPmzPG5Dm67xx57BLx9r169vLt27Wqwfo888ojXarUGvI2xXHzxxV6Px+Nzu9tuu83nOqeeeqrXZrP5nIf7Nlx33XXepKSkoI+By2bOnNlg/c
455xyf651xxhne5OTkgPdxxRVXBF2/YAu2YSj81wP3b/bQQw/5XD5t2rQG99GnTx+f6/jDunTp0qXR9d1///29hYWFDW5nvs4BBxwQ9H7+/ve/6xLotbDb7d7vv//e57537twZ0nbEe8/hcPjc9oUXXvC5ztFHH+3t2LFjwNvvvffe3pqamqi89qFs79Z8Xk3xX+9g71F8vszX69Gjh8/lEydODLrv+eqrr/T1bur5L1q0yOc+f/vtt6D7GGPB62Zmvqxbt27eSZMm+ZyXk5Oj18P6mc/H+pv5f6aPPfbYgO+H1NRU3VeedNJJAdcP27e4uLjR+8Zr3NhrctZZZwW8b3yuZs+e3eC1ysjIaHJbjxo1qsF6+X+2jznmmKDbH+u4cePGiJ5XY49tvK7+5wdb8Nh4PpmZmfXnjRs3rsG2mTVrls/tnn/++QbXoeAYAFPcCicANn+RHX744frFYD7/0EMP9bltsB3R8OHD9cvDP7C96667fG7/+uuvNwiycTt8oXfu3NnnsjvuuMPntsECzEGDBulOGuuAQAvuuecen+vgywDP5YgjjvDZOWJ56aWXGg2CsKSlpXkPPvhg7+DBg33Ox5fhhg0b9HZvvfWWfgH67+AnTJig5xvLsmXLIg6Aq6urNeg0X/6f//ynwX009mWzcuXKBtsCwRO+5AcMGOBz/iGHHNLk+wDbAsHy6NGjfc5H0IPL8Fh4Dbp37+5zOd53gQJFvB/GjBnjPfLII73HH3+898ADD9TXwXxb/AhoLFA0AoR99tmnwfbC8sorr0TltQ9le7fm84pWAIz3jP9jmbdBYwEwXmf/95rx/M2Pbw6AEczgh7T5dhaLRd9jxx13nD5//N1YAGwseP/hs4j3IO6zOQGweV/iv17p6en1ATeeq38Aevfdd0cUKBrvHdx3Xl6ez/kDBw5s8Frh8bOzs7377ruv3mby5Mn6Pvbfx1955ZUh7ePxemH7INhvbN8QiwAY+0zsO/FczJcjMDfvW7EPhmuuucbnej/++KPP42A7GJdhe1RUVDTYfhQcA2CKW+EGwAh4EFDBqlWrNCg1LsP/zaNJgXaO5lHUF1980ecy7GgMbrfb27Nnz/rL8EVhXrfy8nLdyZq/UIqKihr9MnryySd9nhueB740zV8+I0aM8LmfHTt2+Hx5ISDDugULgjp06OD95Zdf9DKXy9VgxBzP2cz/9qGO+Przvx9jZ49g3z+IvPDCCwPeR2NfNlOmTPEJpj755JP6yzD6jlF4820//vjjRt8HRtCF2yIQ9g8+fv311/rtb/4S9X+PYfQT29v/CABs377dJ2gfO3Zsk4Hiq6++Wn/57bff7nPZueee2+g2D/e1b2x7t+bzilYAXFVV1WA9Fi5cGFIAjB+qxvnnn39+g/vesmWL97nnnvNu2rSp/rybb77Z5/6w/1i8eLHP7TZv3tzgx5//Ou611156/wZjfxduAIzXaMWKFXoZjm75/3DB45SWlurlWCfzZRiFjiRQRPBnjNbiM+QfBP/+++8+t//555/1PeuvrKzM269fP5+jgGaBPtvPPvts/eVLlixpENybf7TEIgAO9jnwPypmwLbAD6NA7ze8X8yj+IGO5FDjmANM7caDDz6oeV0waNAgGTJkSP1lmGyFXL9gkAuG3DDDcccd53P51q1b6///448/+uTC4TGvu+46Ofnkk3U555xzdHKNudTXl19+GfSxMQns0ksv9TkP94kcYuQlG5CXhwk8xuMgT9JcNxfriHUL5pJLLtHcQbBYLJrzF+w5xtKvv/4qb7/9tk52Mx4TeZSo/vDss8+GdV/YJh9++GH938gnnDVrVv02Qi72smXLfG5jvr4/vG+Q9w3Iqdtvv/18LkeeJvI+ATmfmCke7D2G/Gasz//93/9pHnHHjh01VxT327VrV5/3yG+//dbo80Run3liYGPvz1i/9vH0vJorklJ7yFU2fPzxx/LAAw/oe3nNmjU6D6F79+5y3nnn6T7FXN
3E7KGHHpK9997b5zzkt/o/f38o4Yb7Nxj7u3DhfYxqLJCbm+uzr4Tp06drbjL45xObq7Y0x8yZM+tzl/EZ8v+M+b/
"text/plain": [
"<Figure size 720x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Saved: riemannian_discriminability.png\n"
]
}
],
"source": [
"from scipy.linalg import sqrtm, logm, eigh, inv, norm, expm\n",
"\n",
"def riemannian_mean_cov(covs, tol=1e-5, max_iter=30):\n",
"    \"\"\"Riemannian (Karcher) mean of SPD matrices via the fixed-point iteration\n",
"    C <- C^{1/2} expm(mean_i logm(C^{-1/2} C_i C^{-1/2})) C^{1/2}.\"\"\"\n",
"    covs = [np.asarray(C, dtype=np.float64) for C in covs]\n",
"    # Add a small trace-scaled ridge so every matrix is strictly SPD\n",
"    covs = [C + 1e-5 * np.trace(C) / C.shape[0] * np.eye(C.shape[0]) for C in covs]\n",
"    C_mean = np.mean(covs, axis=0)  # initialize with the Euclidean mean\n",
"    for _ in range(max_iter):\n",
"        C_sqrt = np.real(sqrtm(C_mean))\n",
"        C_sqrt_inv = inv(C_sqrt)\n",
"        # Log-map every matrix into the tangent space at C_mean and average\n",
"        tangent_mean = np.mean([np.real(logm(C_sqrt_inv @ C @ C_sqrt_inv)) for C in covs], axis=0)\n",
"        # Exponential map back onto the SPD manifold\n",
"        C_mean = np.real(C_sqrt @ expm(tangent_mean) @ C_sqrt)\n",
"        if norm(tangent_mean) < tol:\n",
"            break\n",
"    return C_mean\n",
"\n",
"def riemannian_dist(A, B):\n",
"    \"\"\"Affine-invariant Riemannian distance between SPD matrices A and B:\n",
"    sqrt(sum(log(lambda_i)^2)) over the generalized eigenvalues lambda_i of (A, B).\"\"\"\n",
" A_reg = A + 1e-6 * np.trace(A) / A.shape[0] * np.eye(A.shape[0])\n",
" B_reg = B + 1e-6 * np.trace(B) / B.shape[0] * np.eye(B.shape[0])\n",
" evals = np.clip(np.real(eigh(A_reg, B_reg, eigvals_only=True)), 1e-12, np.inf)\n",
" return np.sqrt(np.sum(np.log(evals)**2))\n",
"\n",
"def recenter_data(X, y):\n",
"    \"\"\"\n",
"    Riemannian re-centering (Zanini et al., 2018).\n",
"    Computes the reference covariance C_ref as the Riemannian mean of the REST (y == 0)\n",
"    trials and transforms each trial X -> R @ X with R = C_ref^{-1/2}.\n",
"    Returns (X_aligned, C_ref, MI-REST Riemannian distance).\n",
"    \"\"\"\n",
" # compute trace-normalized covariances\n",
" covs = np.einsum('ijk,ilk->ijl', X, X)\n",
" covs /= np.trace(covs, axis1=1, axis2=2)[:, None, None]\n",
" \n",
" covs_rest = covs[y == 0]\n",
" if len(covs_rest) == 0:\n",
" return X, None, np.nan\n",
" \n",
" C_ref = riemannian_mean_cov(covs_rest)\n",
" \n",
" # Alignment transformation R = C_ref^{-1/2}\n",
" R = np.real(inv(sqrtm(C_ref)))\n",
" X_aligned = np.einsum('ij,kjl->kil', R, X)\n",
" \n",
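"    # Riemannian discriminability: distance between the MI and REST centroids.\n",
"    # Computed on the un-aligned covariances; the affine-invariant metric gives the\n",
"    # same value after re-centering (up to the small regularization).\n",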
" covs_mi = covs[y == 1]\n",
" if len(covs_mi) > 0:\n",
" C_mi = riemannian_mean_cov(covs_mi)\n",
" dist = riemannian_dist(C_mi, C_ref)\n",
" else:\n",
" dist = np.nan\n",
" \n",
" return X_aligned, C_ref, dist\n",
"\n",
"MIN_TEST_TRIALS = 10\n",
"ra_results = []\n",
"\n",
"for subj in subjects:\n",
" subj_ses = sessions[subj]\n",
" for pair in PAIRS:\n",
" needed = (pair['train'], pair['online_fes'], pair['online_nofes'])\n",
" if any(k not in subj_ses for k in needed): continue\n",
" \n",
" train = subj_ses[pair['train']]\n",
" if set(np.unique(train['y'])) != {0, 1}: continue\n",
" \n",
" # 1. Recenter Training data\n",
" X_train_r, _, _ = recenter_data(train['X'], train['y'])\n",
" \n",
" # 2. Fit CSPLDA on recentered offline train data\n",
" clf = CSPLDA(n_csp=N_CSP).fit(X_train_r, train['y'])\n",
" train_acc = (clf.predict(X_train_r) == train['y']).mean()\n",
" \n",
" for cond_key, cond_label in [('online_fes', 'FES'), ('online_nofes', 'NOFES')]:\n",
" te = subj_ses[pair[cond_key]]\n",
" if len(te['y']) < MIN_TEST_TRIALS or set(np.unique(te['y'])) != {0, 1}: continue\n",
" \n",
"        # 3. Recenter test data (online session); the cue labels select the REST trials used as the reference\n",
" X_test_r, _, riemann_dist = recenter_data(te['X'], te['y'])\n",
" \n",
" # 4. Evaluate decoder on recentered test data\n",
" res = evaluate(clf, X_test_r, te['y'])\n",
" \n",
" y_true = res['y']\n",
" y_pred = res['pred']\n",
" acc = (y_true == y_pred).mean()\n",
" mi_acc = (y_pred[y_true == 1] == 1).mean() if (y_true == 1).any() else np.nan\n",
" rest_acc = (y_pred[y_true == 0] == 0).mean() if (y_true == 0).any() else np.nan\n",
" \n",
" ra_results.append({\n",
" 'subject': subj,\n",
" 'pair': pair['name'],\n",
" 'condition': cond_label,\n",
" 'train_acc': train_acc,\n",
" 'test_acc': acc,\n",
" 'mi_acc': mi_acc,\n",
" 'rest_acc': rest_acc,\n",
" 'riemann_dist': riemann_dist\n",
" })\n",
"\n",
"print(\"=== Riemannian Alignment & Discriminability Metrics ===\\n\")\n",
"hdr = f'{\"Metric\":<28} {\"FES\":>18} {\"NOFES\":>18} {\"paired Δ\":>9} {\"95% CI (boot)\":>24}'\n",
"print(hdr); print('-'*len(hdr))\n",
"\n",
"def fmt_ra(arr):\n",
" if len(arr) == 0: return ' -- '\n",
" return f'{np.mean(arr):>8.3f} ± {np.std(arr, ddof=1):6.3f}'\n",
"\n",
"metrics = [\n",
"    ('riemann_dist', 'Riemannian MI-REST distance'),\n",
" ('test_acc', 'Aligned CSPLDA test acc')\n",
"]\n",
"\n",
"for m_key, m_label in metrics:\n",
" fes_vals, nof_vals, deltas = [], [], []\n",
" for subj in subjects:\n",
" for pair in PAIRS:\n",
" fes = next((r for r in ra_results if r['subject']==subj and r['pair']==pair['name'] and r['condition']=='FES'), None)\n",
" nof = next((r for r in ra_results if r['subject']==subj and r['pair']==pair['name'] and r['condition']=='NOFES'), None)\n",
" if fes and nof and not np.isnan(fes[m_key]) and not np.isnan(nof[m_key]):\n",
" fes_vals.append(fes[m_key])\n",
" nof_vals.append(nof[m_key])\n",
" deltas.append(fes[m_key] - nof[m_key])\n",
" \n",
" if len(deltas) > 0:\n",
" mean_d, lo, hi = boot_paired_ci(deltas)\n",
" marker = ' ' if lo < 0 < hi else ' *'\n",
" print(f'{m_label:<28} {fmt_ra(fes_vals):>18} {fmt_ra(nof_vals):>18} '\n",
" f'{mean_d:>+9.3f} [{lo:>+7.3f},{hi:>+7.3f}]{marker}')\n",
"\n",
"# Plot mean ± SEM Riemannian MI-REST distance per condition, with paired per-subject lines\n",
"fes_dists = [r['riemann_dist'] for r in ra_results if r['condition'] == 'FES']\n",
"nof_dists = [r['riemann_dist'] for r in ra_results if r['condition'] == 'NOFES']\n",
"\n",
"if fes_dists and nof_dists:\n",
" fig, ax = plt.subplots(figsize=(6, 5))\n",
" means = [np.mean(fes_dists), np.mean(nof_dists)]\n",
" sems = [np.std(fes_dists, ddof=1)/np.sqrt(len(fes_dists)),\n",
" np.std(nof_dists, ddof=1)/np.sqrt(len(nof_dists))]\n",
" ax.bar(['ONLINE FES', 'ONLINE NOFES'], means, yerr=sems, \n",
" color=[cond_color['FES'], cond_color['NOFES']], capsize=6, edgecolor='white')\n",
"\n",
" # Paired lines\n",
" for subj in subjects:\n",
" for pair in PAIRS:\n",
" fes = next((r for r in ra_results if r['subject']==subj and r['pair']==pair['name'] and r['condition']=='FES'), None)\n",
" nof = next((r for r in ra_results if r['subject']==subj and r['pair']==pair['name'] and r['condition']=='NOFES'), None)\n",
" if fes and nof:\n",
" ax.plot([0, 1], [fes['riemann_dist'], nof['riemann_dist']], color='k', alpha=0.3, lw=1)\n",
" ax.scatter([0, 1], [fes['riemann_dist'], nof['riemann_dist']], color='k', alpha=0.5, s=20, zorder=3)\n",
"\n",
"    ax.set_ylabel('Riemannian distance (MI vs REST)\\n' r'$\\delta_R(C_{MI}, C_{REST})$')\n",
" ax.set_title('Inherent Riemannian Discriminability\\n(Independent of classifier projection)', fontweight='bold')\n",
" ax.spines[['top', 'right']].set_visible(False)\n",
" ax.grid(axis='y', alpha=0.3)\n",
" plt.tight_layout()\n",
" plt.savefig('riemannian_discriminability.png', dpi=150, bbox_inches='tight')\n",
" plt.show()\n",
" print('\\nSaved: riemannian_discriminability.png')"
]
},
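{
"cell_type": "markdown",
"id": "3b7e2c91",
"metadata": {},
"source": [
"The re-centering cell above relies on two properties of the affine-invariant Riemannian metric. The sketch below checks them on synthetic SPD matrices, not the recorded sessions, and assumes only NumPy and SciPy. First, re-centering with $R = C_{ref}^{-1/2}$ maps the reference covariance to the identity. Second, the distance is unchanged by the congruence $C \\mapsto R C R^\\top$, so re-centering moves each session's REST centroid to the identity without changing the MI-REST distance. The helpers `spd_dist` and `random_spd` are hypothetical and separate from `riemannian_dist` (they omit its regularization)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d0a8f14",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.linalg import eigh, sqrtm, inv\n",
"\n",
"def spd_dist(A, B):\n",
"    # Hypothetical helper: affine-invariant distance, sqrt(sum(log(lambda_i)^2)) over generalized eigenvalues of (A, B)\n",
"    evals = eigh(A, B, eigvals_only=True)\n",
"    return np.sqrt(np.sum(np.log(evals) ** 2))\n",
"\n",
"def random_spd(n, rng):\n",
"    # Hypothetical helper: well-conditioned random SPD matrix\n",
"    M = rng.standard_normal((n, n))\n",
"    return M @ M.T + n * np.eye(n)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n = 4\n",
"C_rest, C_mi = random_spd(n, rng), random_spd(n, rng)\n",
"\n",
"# Re-centering matrix R = C_rest^{-1/2} (symmetric, so R.T == R)\n",
"R = np.real(inv(sqrtm(C_rest)))\n",
"\n",
"# 1) The reference covariance maps to the identity: R C_rest R^T = I\n",
"print(np.allclose(R @ C_rest @ R.T, np.eye(n)))\n",
"\n",
"# 2) Affine invariance: the MI-REST distance is unchanged by the congruence\n",
"d_before = spd_dist(C_mi, C_rest)\n",
"d_after = spd_dist(R @ C_mi @ R.T, R @ C_rest @ R.T)\n",
"print(np.isclose(d_before, d_after))\n"
]
},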
{
"cell_type": "markdown",
"id": "09f20fd7",
"metadata": {},
"source": [
"## Nonstationarity: Drift from Offline Calibration to Online Session\n",
"\n",
"This analysis quantifies covariate shift (drift) as the Riemannian distance between two class centroids. The first is the mean covariance of each class in the offline calibration session; the second is the mean covariance of the same class in the subsequent online session:\n",
"- **FES drift:** OFFLINE_FES (S001) → ONLINE_FES (S002)\n",
"- **NOFES drift:** OFFLINE_NOFES (S004) → ONLINE_NOFES (S005)\n",
"\n",
"Drift is computed separately for the MI and REST classes. A larger distance means the class-conditional covariance structure in the online session has moved further from the one the classifier was calibrated on, which can degrade a decoder that is not adapted to the new session."
]
},
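{
"cell_type": "markdown",
"id": "8c41f6a3",
"metadata": {},
"source": [
"This sketch gives a sense of scale for the drift values reported below. It uses synthetic data and assumes NumPy and SciPy. It computes the same affine-invariant distance between an offline and an online class centroid. A uniform power change by a factor $c$ on all $n$ channels, $C_{online} = c\\,C_{offline}$, gives $\\delta_R = \\sqrt{n}\\,|\\log c|$ regardless of $C_{offline}$. The same change confined to one channel gives only $|\\log c|$. The helper `spd_dist` is hypothetical and omits the regularization used by `riemannian_dist`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e27b9d50",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.linalg import eigh\n",
"\n",
"def spd_dist(A, B):\n",
"    # Hypothetical helper: affine-invariant distance over generalized eigenvalues of (A, B)\n",
"    evals = eigh(A, B, eigvals_only=True)\n",
"    return np.sqrt(np.sum(np.log(evals) ** 2))\n",
"\n",
"rng = np.random.default_rng(1)\n",
"n = 3\n",
"M = rng.standard_normal((n, n))\n",
"C_offline = M @ M.T + n * np.eye(n)  # synthetic offline class centroid\n",
"\n",
"# Uniform power change by a factor c on every channel between sessions\n",
"c = 2.0\n",
"drift_global = spd_dist(C_offline, c * C_offline)\n",
"print(np.isclose(drift_global, np.sqrt(n) * abs(np.log(c))))\n",
"\n",
"# The same factor confined to a single channel drifts less: |log c| instead of sqrt(n) |log c|\n",
"drift_one_channel = spd_dist(np.eye(n), np.diag([c] + [1.0] * (n - 1)))\n",
"print(np.isclose(drift_one_channel, abs(np.log(c))))\n"
]
},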
{
"cell_type": "code",
"execution_count": 23,
"id": "17f73b20",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"<>:112: SyntaxWarning: invalid escape sequence '\\d'\n",
"<>:112: SyntaxWarning: invalid escape sequence '\\d'\n",
"/var/folders/98/qwcfyxcd0c12f9zp2wjlf1x80000gn/T/ipykernel_41885/2712111310.py:112: SyntaxWarning: invalid escape sequence '\\d'\n",
" axes[0].set_ylabel('Riemannian distance (Offline vs Online)\\n$\\delta_R(C_{train}, C_{test})$')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=== Offline to Online Nonstationarity (Riemannian Drift) ===\n",
"\n",
"Metric FES NOFES paired Δ 95% CI (boot)\n",
"-----------------------------------------------------------------------------------------------------\n",
"MI Drift (Offline->Online) 2.470 ± 1.075 2.248 ± 0.935 +0.222 [ -0.245, +0.869] \n",
"REST Drift (Offline->Online) 2.534 ± 1.087 2.354 ± 0.987 +0.180 [ -0.375, +0.836] \n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJoAAAJQCAYAAADGyg5uAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAA6nhJREFUeJzs3Qd8E/X7B/Cn6d6UvYdsRBCRpQjIcOACUXGCIKjgnihO3OBeuBVxr7/j554MRdlb9t6jlJbulf/r85SLlzRpk+bapM3n/Xpd0+QyLpfL5XvPPd/nG2a32+1CRERERERERETkJ5u/T0BERERERERERMRAExERERERERERWYYZTUREREREREREZAkGmoiIiIiIiIiIyBIMNBERERERERERkSUYaCIiIiIiIiIiIksw0ERERERERERERJZgoImIiIiIiIiIiCzBQBMREREREREREVmCgSYiIiIiIiIiIrIEA01ERERERERERGQJBpqIyC9hYWGlpqefftrj/a+//vpS92/ZsmXQfApXXnml07LNmjWr1H2WLVsmF110kTRt2lQiIyMtfR9bt251ev0BAwb49PgZM2Y4Pf7BBx/06fGun43NZpPo6GhJSUmR1q1by+DBg2XSpEmyYsUKH9+Zd6/pbh3a7XZ544035OSTT5ZatWrpMlX0/Vlt165dugz9+/eXRo0a6bpKTk6Wdu3ayahRo+TLL7/U5S+LN9uTN+sA24p5XWJbsmq7ChYbNmyQe+65R0455RTH+o6Li5NWrVrJeeedJ9OnT5e0tDRLXqu8dYb1bp6P754ZPj/z/Oqupr2fQMF3+auvvtL9A/YT2F9gO8b23K9fP92usF+pLDV5P5GRkSFTp07V/WSdOnV0f4r95THHHKO3XXPNNfLyyy/L4cOHpbpBW8T82aCtQkQUzCICvQBEVPOgIXfLLbfowbBZenq6vPvuu1KdrVmzRk466STJycmRUDkoys/P1wmN882bN8tvv/0m06ZNk9NPP13eeecdPUCqTFOmTNGpLDg4M98Hy1WZDXGsl0cffVQefvhhXTdmuI4DHgRF3nvvPenatat88skn0r59+wpvT96sg5osOztbbrrpJnn77beluLi41HwcIGP65ptv5PPPP5fff/89IMtZ3WCdIUhnQMDUXXC9OkMAcMyYMY7rDzzwQMAC1OvWrZORI0fK8uXLS83bu3evTnPnzpXHH39c7rvvPg2qMqjnnfXr1+uJkB07dpRqd2DasmWLzJs3T2/r1auXnHjiiZZ8pkRE5B4DTURkOTTo/ve//2mGgdlbb70lmZmZQb3Ge/To4bSM9erVc5o/c+ZMp6BAw4YNtdEaEREh9evXl5rmzDPPlNjYWG2oI4vpwIEDjnk//fSTdOvWTQ+M2rZtW+HXGDFihON/d+vwtddec7qOA4TmzZvrAVinTp0kEHBmHBlGZi1atJDOnTvrupo/f74UFBTo7TioxDby999/S8eOHSu0Pfm7DuLj453W87HHHivVRVZWlmZWLFq0yOn2unXrygknnCBRUVGyc+dOWblypRQVFbkNRFUGrHfzOg2mzMzKMHToUNm/f3+gF6PaQlC5T58+un8wIOOmd+/ekpSUJKtWrZJt27Y5gtUINCFo4vrdr0zVdT+BwD8CeOYgE/afXbp00YxH/G6tXr1aTwBUV2iLmD8btFWIiIIZA01EVCleeOEFp0ATDv5eeumloF/b1113nU6e4Iyz69lyZPbUVOiKZBxAozH/9ddfy4QJExzrYd++fXL22WfrQT4O+CsCGShlMa9zBGIWLlwogfThhx86BZmQuYdtG+vFgAOe888/3xEcwcElDhKwnsLDw33envxdBzhIKW89B6trr73WKciEbkbYv4wbN84pa/LQoUP6uRhZC5UN3R0xhQrsC6hiEADF998cZEKg4IsvvpBmzZo5bnvllVe0e7kRLH399dc1y+zSSy+tklVfXfcTS5cu1S7Ihm
HDhun7MO9rsU5xAuCDDz7Q4FN1g6BfdfxsiCiE2YmI/IDdiDGlpKTYY2NjHddXrlzpuN+XX37puL1JkyZOj2vRooXTc+bm5tqfeOIJ+8iRI+2dO3e2N2zY0B4dHW2PiYnRxw4dOtQ+Y8YMe1FRUanleeedd5ye+4EHHrDv3r3bPnHiRHvz5s3tUVFR9saNG9snTJhgP3ToUKnHjx492unxf/zxh96O5zHf7m7Cfcy2bt1qv/POO+0nnHCCPTk52R4REWGvW7euvX///vann37anpGRUer1t2zZ4vScuK+r/Px8+7Rp0+wdO3bU9VK/fn37pZdeat+4caPb91/RzxMTlsfVmjVr7PHx8U73e/HFF53ug8/UPB+f1fTp0+3du3d3PDYtLa3Ua5q3BdfncDe5vt+y7lfeZ+yttm3bOj3+9ttvd3s/bHcJCQlO9/3ggw982p68WQfGZ4RtxdNnV9525bo8WF/r1q2zjxo1Sr9/kZGR9pYtW9rvuusue05Ojsd18+uvv+q22KpVK90XYML6uvbaa+1r1661+2rZsmX2sLAwt+vQE+w/zN577z37VVddZe/Ro4e9adOmuv3h/eC72LdvX/tjjz1mP3z4cKnnqcg6M3P97OCTTz6xn3TSSbpdJCUl2QcPHux2+3P32unp6bo/adOmje7Hunbt6tf+0pvvjfk9u3s/rvbv32+fMmWKvU+fPvbatWvrPg+/C71799bb9+3b5/ZxrvuAwsJC+yuvvGI/8cQT7XFxcfbExET7oEGD7H/++afdW1iv3rxH130k9svPPPOMfcCAAbqNYFvB/rtbt272O+64w+0+sTzvv/++02vi/WD/4M5tt93mdF98f6z8rtbE/cRHH33ktFz4/Cri4MGD9kcffdR+8sknO7bfOnXq6Lb31ltv6W+vO9guL7nkEvsxxxyj7wXroUGDBvodvfLKK+0vv/yyPS8vz6/HuG7P+B1z5++//7aPGTPG3q5dO93XYV+A/cC5556r+5/KaDcREbnDQBMR+cX1AGHcuHGO6+PHj3fbuEVDzlNwAQ4cOODVAQIO0lwbb64NpjPPPFMbjO4ejwCQa8PRqkDTzJkznYJu7iY04BYvXuz0+uU19LG8p512mtvnw8HLNddcU+ZBlC+fp+tBiNlNN93kdD8cPJu5HpRedtllpZ67OgaaEDx1fe6dO3d6vL/5+4DpggsuqBaBphEjRnjcfs8555xS7xPbJQ4cy1pOHKy8/fbbdl9MnjzZ6TmM4Iovjj322HLXYbNmzTQwXJmBJtcAgjEhkOb6WNfXxvtGEMnduqjo/tLqQNNPP/3kcV9rTJiP+7ky3wcH21hWT9vQvHnzKi3QhP0x9stl3R/fC+zffYHvk/k5zL+Nrnbs2FHqNVetWmXZd7Um7ie++OILp+fAyRcEajZt2uT1c2B7qVevXpnLhgAqglFmCHy7BsPdTfie+vOY8gJNxcXF9htuuKHc5+zXr589NTXV0nYTEZE77DpHRJa68cYb5c0339T/kaKOEWC2b98us2fP1ttiYmLk6quv1iKn5UENFhSqrV27tnaXSU1NlSVLljhq2vz666/aZenWW2/1+Bw//PCD1rHp3r27Poe5Ww2eC0WaL7/8cq/rsaALj1FHAzBKkFHHyaiV88cff2jxWXSXMLRp00ZHvkG9HnQ3A6wX1EBCbQ7XWlCeoEjszz//7HQb3ltCQoJ2C6iqeh6o1/L88887ri9YsEDfr7mrghm2Bax/1HTCKEuLFy/2uiYMupcY0OUB68y1psi///6rNVDMNYxQM8nq+jmuXdawfTZp0sTj/THSkfF9MD/e2+3J23VgNbwe6kShUDkKcZu7paD+2p9//il9+/Z13IbuPuhSaMB3Fl2DUKcK9zUKyo8fP14/i1NPPdWr5UBdKzOsj4pAnbEOHTrocuG7gjps+C4ePHjQ0dXxhhtu0GLilQWjcaJw/nHHHaff+d27d+vtiLOgeyC2FU+1zozi0Rj9EX
WpCgsLJTc316/9JbY/fLbYR5qfA121fK3Rg+/e8OHD9fkMGEURj8c87OuM7o24H7Z713plBuwfMeF7he8A9i1GlzO
"text/plain": [
"<Figure size 1200x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Saved: riemannian_drift_offline_to_online.png\n"
]
}
],
"source": [
"def get_covs(X, eps=1e-5):\n",
"    \"\"\"Trace-normalized trial covariances plus a small trace-scaled ridge so each is strictly SPD.\"\"\"\n",
"    c = np.einsum('ijk,ilk->ijl', X, X)\n",
"    c /= np.trace(c, axis1=1, axis2=2)[:, None, None]\n",
"    n_ch = c.shape[1]\n",
"    ridge = eps * np.trace(c, axis1=1, axis2=2)[:, None, None] / n_ch\n",
"    return c + ridge * np.eye(n_ch)\n",
"\n",
"drift_results = []\n",
"\n",
"for subj in subjects:\n",
"    subj_ses = sessions[subj]\n",
"\n",
"    # Drift from each offline calibration session to its matched online session:\n",
"    #   FES:   OFFLINE_FES   -> ONLINE_FES\n",
"    #   NOFES: OFFLINE_NOFES -> ONLINE_NOFES\n",
"    drift_pairs = [\n",
"        ('FES', PAIRS[0]['train'], PAIRS[0]['online_fes']),\n",
"        ('NOFES', PAIRS[1]['train'], PAIRS[1]['online_nofes'])\n",
"    ]\n",
"\n",
"    for cond_label, train_key, test_key in drift_pairs:\n",
"        if train_key not in subj_ses or test_key not in subj_ses:\n",
"            continue\n",
"\n",
"        train = subj_ses[train_key]\n",
"        test = subj_ses[test_key]\n",
"\n",
"        train_covs = get_covs(train['X'])\n",
"        test_covs = get_covs(test['X'])\n",
" \n",
"        # Split the trial covariances by class (MI = 1, REST = 0)\n",
" train_mi, train_rest = train_covs[train['y'] == 1], train_covs[train['y'] == 0]\n",
" test_mi, test_rest = test_covs[test['y'] == 1], test_covs[test['y'] == 0]\n",
" \n",
" if len(train_mi) == 0 or len(train_rest) == 0 or len(test_mi) == 0 or len(test_rest) == 0:\n",
" continue\n",
" \n",
" # Riemannian mean for each phase/class\n",
" C_train_mi = riemannian_mean_cov(train_mi)\n",
" C_train_rest = riemannian_mean_cov(train_rest)\n",
" \n",
" C_test_mi = riemannian_mean_cov(test_mi)\n",
" C_test_rest = riemannian_mean_cov(test_rest)\n",
" \n",
" # Calculate Riemannian Drift (distance) between offline and online centers\n",
" mi_drift = riemannian_dist(C_train_mi, C_test_mi)\n",
" rest_drift = riemannian_dist(C_train_rest, C_test_rest)\n",
" \n",
" drift_results.append({\n",
" 'subject': subj,\n",
" 'condition': cond_label,\n",
" 'mi_drift': mi_drift,\n",
" 'rest_drift': rest_drift\n",
" })\n",
"\n",
"print(\"=== Offline to Online Nonstationarity (Riemannian Drift) ===\\n\")\n",
"hdr = f'{\"Metric\":<28} {\"FES\":>18} {\"NOFES\":>18} {\"paired Δ\":>9} {\"95% CI (boot)\":>24}'\n",
"print(hdr); print('-'*len(hdr))\n",
"\n",
"drift_metrics = [\n",
" ('mi_drift', 'MI Drift (Offline->Online)'),\n",
" ('rest_drift', 'REST Drift (Offline->Online)')\n",
"]\n",
"\n",
"for m_key, m_label in drift_metrics:\n",
" fes_vals, nof_vals, deltas = [], [], []\n",
" for subj in subjects:\n",
" fes = next((r for r in drift_results if r['subject']==subj and r['condition']=='FES'), None)\n",
" nof = next((r for r in drift_results if r['subject']==subj and r['condition']=='NOFES'), None)\n",
" if fes and nof and not np.isnan(fes[m_key]) and not np.isnan(nof[m_key]):\n",
" fes_vals.append(fes[m_key])\n",
" nof_vals.append(nof[m_key])\n",
" deltas.append(fes[m_key] - nof[m_key])\n",
" \n",
" if len(deltas) > 0:\n",
" mean_d, lo, hi = boot_paired_ci(deltas)\n",
" marker = ' ' if lo < 0 < hi else ' *'\n",
" print(f'{m_label:<28} {fmt_ra(fes_vals):>18} {fmt_ra(nof_vals):>18} '\n",
" f'{mean_d:>+9.3f} [{lo:>+7.3f},{hi:>+7.3f}]{marker}')\n",
"\n",
"\n",
"# ── Figure ──\n",
"fig, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)\n",
"fig.suptitle('Manifold Drift: Offline Calibration to Online Session\\nMeasured by Riemannian Distance', fontweight='bold')\n",
"\n",
"for ax, (m_key, m_label) in zip(axes, drift_metrics):\n",
" fes_dists = [r[m_key] for r in drift_results if r['condition'] == 'FES']\n",
" nof_dists = [r[m_key] for r in drift_results if r['condition'] == 'NOFES']\n",
"\n",
" if fes_dists and nof_dists:\n",
" means = [np.mean(fes_dists), np.mean(nof_dists)]\n",
" sems = [np.std(fes_dists, ddof=1)/np.sqrt(len(fes_dists)),\n",
" np.std(nof_dists, ddof=1)/np.sqrt(len(nof_dists))]\n",
" \n",
" ax.bar(['FES Path\\n(FES Calib -> Online)', 'NOFES Path\\n(NOFES Calib -> Online)'], means, yerr=sems, \n",
" color=[cond_color['FES'], cond_color['NOFES']], capsize=6, edgecolor='white')\n",
"\n",
" # Paired lines\n",
" for subj in subjects:\n",
" fes = next((r for r in drift_results if r['subject']==subj and r['condition']=='FES'), None)\n",
" nof = next((r for r in drift_results if r['subject']==subj and r['condition']=='NOFES'), None)\n",
" if fes and nof:\n",
" ax.plot([0, 1], [fes[m_key], nof[m_key]], color='k', alpha=0.3, lw=1)\n",
" ax.scatter([0, 1], [fes[m_key], nof[m_key]], color='k', alpha=0.5, s=20, zorder=3)\n",
"\n",
" ax.set_title(m_label, fontweight='bold', fontsize=11)\n",
" ax.spines[['top', 'right']].set_visible(False)\n",
" ax.grid(axis='y', alpha=0.3)\n",
"\n",
"axes[0].set_ylabel('Riemannian distance (Offline vs Online)\\n' r'$\\delta_R(C_{train}, C_{test})$')\n",
"plt.tight_layout()\n",
"plt.savefig('riemannian_drift_offline_to_online.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('\\nSaved: riemannian_drift_offline_to_online.png')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}