Files
ECE374N/Final Project/overall_analysis.ipynb

1146 lines
630 KiB
Plaintext
Raw Normal View History

2026-04-21 13:01:49 -05:00
{
"cells": [
{
"cell_type": "markdown",
2026-04-21 21:18:33 -05:00
"id": "d8960e56",
2026-04-21 13:01:49 -05:00
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"# Motor Imagery Decoder — Train OFFLINE, Evaluate ONLINE (FES vs NOFES)\n",
"\n",
"For each subject × offline-session, train a CSP + LDA classifier on the OFFLINE recording, then apply it to the two matched ONLINE sessions.\n",
"\n",
"**Pair 1:** train on `S001 OFFLINE_FES` → test on `S002 ONLINE_FES` and `S003 ONLINE_NOFES`\n",
"**Pair 2:** train on `S004 OFFLINE_NOFES` → test on `S006 ONLINE_FES` and `S005 ONLINE_NOFES`\n",
"\n",
"**Metrics reported per (subject × pair × condition):**\n",
"1. **Classification accuracy** — fraction of cued trials correctly classified\n",
"2. **Classification amplitude** — mean |LDA decision-function value|\n",
"3. **SNR** — (a) Fisher ratio of the LDA projection on online data, and (b) mu-band power ratio REST / MI over motor channels C3/Cz/C4"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 1,
2026-04-21 13:01:49 -05:00
"id": "578c9128",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:20:38.290986Z",
"iopub.status.busy": "2026-04-22T03:20:38.290709Z",
"iopub.status.idle": "2026-04-22T03:20:38.297801Z",
"shell.execute_reply": "2026-04-22T03:20:38.296447Z"
2026-04-21 21:18:33 -05:00
}
},
2026-04-21 13:01:49 -05:00
"outputs": [],
"source": [
"# Install dependencies if needed\n",
"# !pip install pyxdf mne scipy numpy matplotlib"
]
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 2,
2026-04-21 13:01:49 -05:00
"id": "857b22c0",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:20:38.300669Z",
"iopub.status.busy": "2026-04-22T03:20:38.300420Z",
"iopub.status.idle": "2026-04-22T03:20:39.360556Z",
"shell.execute_reply": "2026-04-22T03:20:39.360152Z"
2026-04-21 21:18:33 -05:00
}
},
2026-04-21 13:01:49 -05:00
"outputs": [],
2026-04-21 21:18:33 -05:00
"source": [
"import os\n",
"import re\n",
"import glob\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Patch\n",
"import pyxdf\n",
"from scipy.signal import welch, butter, filtfilt, iirnotch\n",
"from scipy.linalg import eigh\n",
"\n",
"plt.rcParams.update({'font.size': 11, 'figure.dpi': 120})"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "fe68bf0e",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 3,
2026-04-21 13:01:49 -05:00
"id": "dc4b2c55",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:20:39.362226Z",
"iopub.status.busy": "2026-04-22T03:20:39.362118Z",
"iopub.status.idle": "2026-04-22T03:20:39.364971Z",
"shell.execute_reply": "2026-04-22T03:20:39.364522Z"
2026-04-21 21:18:33 -05:00
}
},
2026-04-21 13:01:49 -05:00
"outputs": [],
2026-04-21 21:18:33 -05:00
"source": [
"DATA_DIR = os.path.join(os.path.dirname(os.path.abspath('__file__')), 'Group 2 - Glove')\n",
"\n",
2026-04-21 22:12:59 -05:00
"# Marker codes (from experiment trigger table)\n",
"MI_BEGIN = 200\n",
"MI_END = 220\n",
"MI_EARLYSTOP = 240 # online only: live classifier fired → successful MI detection\n",
"REST_BEGIN = 100\n",
"REST_END = 120\n",
"REST_EARLYSTOP = 140 # online only: live classifier fired → successful REST detection\n",
"ROBOT_BEGIN = 300\n",
"ROBOT_END = 320\n",
2026-04-21 21:18:33 -05:00
"\n",
2026-04-21 22:12:59 -05:00
"TARGET_MARKERS = [100, 120, 140, 200, 220, 240]\n",
"\n",
"T_PRE = -1.0\n",
"T_POST = 5.0\n",
2026-04-21 21:18:33 -05:00
"\n",
"# ── Preprocessing ────────────────────────────────────────────────────────────\n",
2026-04-21 22:12:59 -05:00
"NOTCH_FREQ = 60.0\n",
"NOTCH_Q = 30\n",
"BP_LO, BP_HI = 8.0, 30.0\n",
"USE_CAR = True\n",
"PTP_REJECT_UV = 100.0\n",
2026-04-21 21:18:33 -05:00
"\n",
"N_CSP = 4\n",
"\n",
"NON_EEG = {'AUX1', 'AUX2', 'AUX3', 'AUX7', 'AUX8', 'AUX9', 'TRIGGER'}\n",
"RENAME = {'FP1':'Fp1','FPZ':'Fpz','FP2':'Fp2','FZ':'Fz','CZ':'Cz',\n",
" 'PZ':'Pz','POZ':'POz','OZ':'Oz'}\n",
"\n",
"MOTOR_CH = ['C3', 'Cz', 'C4']\n",
"MU_BAND = (8, 13)\n",
"\n",
"PAIRS = [\n",
" {'name': 'Pair1 (train=OFFLINE_FES)',\n",
2026-04-21 22:24:23 -05:00
" 'train': 'S001', 'online_fes': 'S002', 'online_nofes': 'S003'},\n",
2026-04-21 21:18:33 -05:00
" {'name': 'Pair2 (train=OFFLINE_NOFES)',\n",
2026-04-21 22:24:23 -05:00
" 'train': 'S004', 'online_fes': 'S006', 'online_nofes': 'S005'},\n",
2026-04-21 21:18:33 -05:00
"]"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "21a40df3",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 4,
2026-04-21 13:01:49 -05:00
"id": "e798b039",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:20:39.366319Z",
"iopub.status.busy": "2026-04-22T03:20:39.366243Z",
"iopub.status.idle": "2026-04-22T03:20:39.379175Z",
"shell.execute_reply": "2026-04-22T03:20:39.378716Z"
2026-04-21 21:18:33 -05:00
}
},
2026-04-21 13:01:49 -05:00
"outputs": [],
2026-04-21 21:18:33 -05:00
"source": [
"# ── XDF loading + session parsing ─────────────────────────────────────────────\n",
"\n",
"def get_channel_names_from_xdf(eeg_stream):\n",
" ch_desc = eeg_stream['info']['desc'][0]\n",
" channels = ch_desc.get('channels', [{}])[0].get('channel', [])\n",
" return [ch['label'][0] for ch in channels]\n",
"\n",
"\n",
"_SESSION_RE = re.compile(r'ses-(S\\d+)(O[A-Z]*LINE)_(FES|NOFES)')\n",
"_SUBJ_RE = re.compile(r'SUBJ_(\\d+)')\n",
"\n",
"def parse_session(path):\n",
2026-04-21 22:12:59 -05:00
" base = os.path.basename(path)\n",
" m_subj = _SUBJ_RE.search(base)\n",
" m_ses = _SESSION_RE.search(base)\n",
2026-04-21 21:18:33 -05:00
" if not (m_subj and m_ses):\n",
" return None\n",
" ses_id, raw_kind, stim = m_ses.group(1), m_ses.group(2), m_ses.group(3)\n",
" kind = 'OFFLINE' if 'OFF' in raw_kind else 'ONLINE'\n",
" return m_subj.group(1), ses_id, kind, stim\n",
"\n",
"\n",
"def load_xdf_file(filepath):\n",
" streams, _ = pyxdf.load_xdf(filepath)\n",
"\n",
" eeg_stream = marker_stream = None\n",
" for s in streams:\n",
" stype = s['info']['type'][0].lower()\n",
" if stype == 'eeg': eeg_stream = s\n",
" elif stype == 'markers': marker_stream = s\n",
" if eeg_stream is None or marker_stream is None:\n",
" eeg_stream = streams[0]\n",
" marker_stream = streams[1] if len(streams) > 1 else None\n",
"\n",
" eeg_timestamps = np.array(eeg_stream['time_stamps'])\n",
" eeg_data = np.array(eeg_stream['time_series']).T\n",
" channel_names = get_channel_names_from_xdf(eeg_stream)\n",
" sfreq = float(eeg_stream['info']['nominal_srate'][0])\n",
"\n",
" valid_idx = [i for i, ch in enumerate(channel_names) if ch not in NON_EEG]\n",
" channel_names = [channel_names[i] for i in valid_idx]\n",
" eeg_data = eeg_data[valid_idx, :]\n",
" channel_names = [RENAME.get(ch, ch) for ch in channel_names]\n",
"\n",
" ts_arr = np.asarray(marker_stream['time_series'], dtype=float)\n",
" marker_data = ts_arr[:, 0].astype(int)\n",
" marker_ts = ts_arr[:, 1]\n",
" keep = np.isin(marker_data, TARGET_MARKERS)\n",
" return eeg_data, eeg_timestamps, marker_data[keep], marker_ts[keep], channel_names, sfreq\n",
"\n",
"\n",
"# ── Preprocessing primitives ─────────────────────────────────────────────────\n",
"\n",
"def notch_filter(data, freq, sfreq, Q=NOTCH_Q):\n",
" b, a = iirnotch(freq, Q, fs=sfreq)\n",
" return filtfilt(b, a, data, axis=-1)\n",
"\n",
"\n",
"def car(data):\n",
" return data - data.mean(axis=0, keepdims=True)\n",
"\n",
"\n",
"def bandpass(data, lo, hi, sfreq, order=4):\n",
" nyq = sfreq / 2.0\n",
" b, a = butter(order, [max(lo, 0.5) / nyq, min(hi, nyq - 0.1) / nyq], btype='band')\n",
" return filtfilt(b, a, data, axis=-1)\n",
"\n",
"\n",
"def reject_by_ptp(X, thresh_uv=PTP_REJECT_UV):\n",
" if X.size == 0:\n",
" return np.zeros(0, dtype=bool)\n",
2026-04-21 22:12:59 -05:00
" ptp = X.max(axis=-1) - X.min(axis=-1)\n",
2026-04-21 21:18:33 -05:00
" return ptp.max(axis=-1) < thresh_uv\n",
"\n",
"\n",
"def extract_epochs(eeg_data, eeg_ts, marker_data, marker_ts, sfreq, begin_code,\n",
" t_pre=T_PRE, t_post=T_POST):\n",
" epochs = []\n",
" n_pre = int(abs(t_pre) * sfreq)\n",
" for bi in np.where(marker_data == begin_code)[0]:\n",
" t_start = marker_ts[bi]\n",
" i0 = np.searchsorted(eeg_ts, t_start + t_pre)\n",
" i1 = np.searchsorted(eeg_ts, t_start + t_post)\n",
" if i0 < 0 or i1 > eeg_data.shape[1]:\n",
" continue\n",
" ep = eeg_data[:, i0:i1].copy()\n",
" if ep.shape[1] > n_pre:\n",
" ep -= ep[:, :n_pre].mean(axis=1, keepdims=True)\n",
" epochs.append(ep)\n",
" if not epochs:\n",
" return np.empty((0, eeg_data.shape[0], 0))\n",
" min_len = min(e.shape[-1] for e in epochs)\n",
" return np.stack([e[:, :min_len] for e in epochs])\n",
"\n",
"\n",
2026-04-21 22:12:59 -05:00
"def marker_accuracy(marker_data):\n",
" \"\"\"Online accuracy from EARLYSTOP markers: (MI_EARLYSTOP + REST_EARLYSTOP) / total trials.\n",
" Returns None for offline sessions (no EARLYSTOP markers present).\n",
" \"\"\"\n",
" n_mi_trials = int((marker_data == MI_BEGIN).sum())\n",
" n_rest_trials = int((marker_data == REST_BEGIN).sum())\n",
" n_total = n_mi_trials + n_rest_trials\n",
" if n_total == 0:\n",
" return None\n",
" n_mi_correct = int((marker_data == MI_EARLYSTOP).sum())\n",
" n_rest_correct = int((marker_data == REST_EARLYSTOP).sum())\n",
" # If no EARLYSTOP markers present at all, this is an offline session\n",
" if n_mi_correct + n_rest_correct == 0:\n",
" return None\n",
" return (n_mi_correct + n_rest_correct) / n_total\n",
"\n",
"\n",
2026-04-21 21:18:33 -05:00
"def load_session_epochs(filepath):\n",
2026-04-21 22:12:59 -05:00
" \"\"\"Preprocessing pipeline: notch → CAR → bandpass → epoch → PTP-reject.\n",
" Returns X, y, ch_names, sfreq, n_rejected, mk_acc\n",
" where mk_acc = marker-based online accuracy (None for offline sessions).\n",
2026-04-21 21:18:33 -05:00
" \"\"\"\n",
" eeg, eeg_ts, mk, mk_ts, ch_names, sfreq = load_xdf_file(filepath)\n",
"\n",
2026-04-21 22:12:59 -05:00
" # Marker-based accuracy before any EEG processing (counts all trials as seen by the live system)\n",
" mk_acc = marker_accuracy(mk)\n",
"\n",
2026-04-21 21:18:33 -05:00
" eeg = notch_filter(eeg, NOTCH_FREQ, sfreq)\n",
" if USE_CAR:\n",
" eeg = car(eeg)\n",
" eeg_bp = bandpass(eeg, BP_LO, BP_HI, sfreq)\n",
"\n",
" mi = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, MI_BEGIN)\n",
" rest = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, REST_BEGIN)\n",
"\n",
" n_pre = int(abs(T_PRE) * sfreq)\n",
" if mi.shape[-1] > n_pre: mi = mi[..., n_pre:]\n",
" if rest.shape[-1] > n_pre: rest = rest[..., n_pre:]\n",
"\n",
" n = min(mi.shape[-1], rest.shape[-1]) if (mi.size and rest.size) else 0\n",
" mi, rest = mi[..., :n], rest[..., :n]\n",
"\n",
" n0_mi, n0_rest = len(mi), len(rest)\n",
" mi = mi[reject_by_ptp(mi)]\n",
" rest = rest[reject_by_ptp(rest)]\n",
" n_rejected = (n0_mi - len(mi)) + (n0_rest - len(rest))\n",
"\n",
" X = np.concatenate([mi, rest], axis=0) if (len(mi) or len(rest)) else np.empty((0, len(ch_names), 0))\n",
" y = np.concatenate([np.ones(len(mi), int), np.zeros(len(rest), int)])\n",
2026-04-21 22:12:59 -05:00
" return X, y, ch_names, sfreq, n_rejected, mk_acc\n",
2026-04-21 21:18:33 -05:00
"\n",
"\n",
"# ── CSP + LDA (2-class, numpy/scipy only) ────────────────────────────────────\n",
"\n",
"def _mean_cov(X):\n",
" covs = np.einsum('ijk,ilk->ijl', X, X)\n",
" covs /= np.trace(covs, axis1=1, axis2=2)[:, None, None]\n",
" return covs.mean(axis=0)\n",
"\n",
"\n",
"class CSPLDA:\n",
2026-04-21 22:12:59 -05:00
" \"\"\"CSP log-var features + LDA. Ramoser 2000; Blankertz 2008.\n",
" Ledoit-Wolf shrinkage keeps the generalized eigenproblem well-posed after CAR.\n",
2026-04-21 21:18:33 -05:00
" \"\"\"\n",
"\n",
" def __init__(self, n_csp=N_CSP, cov_shrink=0.05, lda_reg=1e-4):\n",
" self.n_csp = n_csp\n",
" self.cov_shrink = cov_shrink\n",
" self.lda_reg = lda_reg\n",
"\n",
" def fit(self, X, y):\n",
2026-04-21 22:12:59 -05:00
" assert set(np.unique(y)) == {0, 1}\n",
2026-04-21 21:18:33 -05:00
" C1 = _mean_cov(X[y == 1])\n",
" C0 = _mean_cov(X[y == 0])\n",
" n_ch = C1.shape[0]\n",
" s = self.cov_shrink\n",
" C1 = (1 - s) * C1 + s * (np.trace(C1) / n_ch) * np.eye(n_ch)\n",
" C0 = (1 - s) * C0 + s * (np.trace(C0) / n_ch) * np.eye(n_ch)\n",
" evals, evecs = eigh(C1, C0 + C1)\n",
" order = np.argsort(evals)\n",
" k = self.n_csp // 2\n",
" self.filters_ = np.concatenate([evecs[:, order[:k]],\n",
" evecs[:, order[-k:]]], axis=1).T\n",
" F = self._features(X)\n",
" mu1, mu0 = F[y == 1].mean(0), F[y == 0].mean(0)\n",
" Sw = np.cov(F[y == 1].T, ddof=1) + np.cov(F[y == 0].T, ddof=1)\n",
" Sw += self.lda_reg * np.eye(Sw.shape[0])\n",
" self.coef_ = np.linalg.solve(Sw, mu1 - mu0)\n",
" self.intercept_ = -self.coef_ @ ((mu1 + mu0) / 2)\n",
" return self\n",
"\n",
" def _features(self, X):\n",
" Z = np.einsum('fc,ncs->nfs', self.filters_, X)\n",
" var = Z.var(axis=-1, ddof=1)\n",
" return np.log(var / var.sum(axis=1, keepdims=True))\n",
"\n",
" def decision_function(self, X):\n",
" return self._features(X) @ self.coef_ + self.intercept_\n",
"\n",
" def predict(self, X):\n",
" return (self.decision_function(X) > 0).astype(int)\n",
"\n",
"\n",
"# ── Evaluation metrics ───────────────────────────────────────────────────────\n",
"\n",
"def evaluate(clf, X, y):\n",
" margin = clf.decision_function(X)\n",
" pred = (margin > 0).astype(int)\n",
" amp = np.abs(margin).mean()\n",
" m1, m0 = margin[y == 1], margin[y == 0]\n",
" fisher = (m1.mean() - m0.mean()) ** 2 / (m1.var(ddof=1) + m0.var(ddof=1) + 1e-30)\n",
2026-04-21 22:12:59 -05:00
" return dict(amp=amp, fisher=fisher, margin=margin, y=y, pred=pred)\n",
2026-04-21 21:18:33 -05:00
"\n",
"\n",
"def spectral_snr(X, y, ch_idx, sfreq, band=MU_BAND):\n",
" def band_pwr(sig):\n",
" f, p = welch(sig, fs=sfreq,\n",
" nperseg=min(int(sfreq * 2), sig.shape[-1]),\n",
" noverlap=int(sfreq), axis=-1)\n",
" m = (f >= band[0]) & (f < band[1])\n",
" return np.trapezoid(p[..., m], f[m], axis=-1).mean()\n",
" return band_pwr(X[y == 0][:, ch_idx, :]) / (band_pwr(X[y == 1][:, ch_idx, :]) + 1e-30)"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "98d225db",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 5,
2026-04-21 13:01:49 -05:00
"id": "d266216b",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:20:39.380444Z",
"iopub.status.busy": "2026-04-22T03:20:39.380365Z",
"iopub.status.idle": "2026-04-22T03:21:10.042926Z",
"shell.execute_reply": "2026-04-22T03:21:10.041519Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-21 22:12:59 -05:00
"Found 24 XDF file(s).\n",
2026-04-21 21:18:33 -05:00
"Preprocessing: notch 60 Hz → CAR → bandpass 830 Hz → baseline-correct → PTP-reject @ 100 µV\n",
2026-04-21 22:24:23 -05:00
"\n",
" 002/S001 OFFLINE FES n= 85 (MI=43, REST=42) rej=5\n",
" 002/S002 ONLINE FES n= 53 (MI=27, REST=26) rej=7 mk_acc=0.883\n",
" 002/S003 ONLINE NOFES n= 52 (MI=26, REST=26) rej=8 mk_acc=0.833\n",
" 002/S004 OFFLINE NOFES n= 90 (MI=45, REST=45) rej=0\n",
" 002/S005 ONLINE NOFES n= 60 (MI=30, REST=30) rej=0 mk_acc=0.850\n",
" 002/S006 ONLINE FES n= 56 (MI=27, REST=29) rej=4 mk_acc=0.917\n",
" 003/S001 OFFLINE FES n= 89 (MI=44, REST=45) rej=1\n",
" 003/S002 ONLINE FES n= 59 (MI=29, REST=30) rej=1 mk_acc=0.750\n",
" 003/S003 ONLINE NOFES n= 38 (MI=17, REST=21) rej=0 mk_acc=0.763\n",
" 003/S004 OFFLINE NOFES n= 86 (MI=42, REST=44) rej=4\n",
" 003/S005 ONLINE NOFES n= 43 (MI=19, REST=24) rej=17 mk_acc=0.717\n",
" 003/S006 ONLINE FES n= 52 (MI=23, REST=29) rej=8 mk_acc=0.767\n",
" 005/S001 OFFLINE FES n= 90 (MI=45, REST=45) rej=0\n",
" 005/S002 ONLINE FES n= 60 (MI=30, REST=30) rej=0 mk_acc=0.800\n",
" 005/S003 ONLINE NOFES n= 59 (MI=29, REST=30) rej=1 mk_acc=0.933\n",
" 005/S004 OFFLINE NOFES n= 89 (MI=44, REST=45) rej=1\n",
" 005/S005 ONLINE NOFES n= 58 (MI=28, REST=30) rej=2 mk_acc=0.783\n",
" 005/S006 ONLINE FES n= 59 (MI=30, REST=29) rej=1 mk_acc=0.917\n",
" 009/S001 OFFLINE FES n= 57 (MI=33, REST=24) rej=33\n",
" 009/S002 ONLINE FES n= 42 (MI=21, REST=21) rej=18 mk_acc=0.717\n",
" 009/S003 ONLINE NOFES n= 1 (MI=1, REST=0) rej=59 mk_acc=0.717\n",
" 009/S004 OFFLINE NOFES n= 86 (MI=42, REST=44) rej=4\n",
" 009/S005 ONLINE NOFES n= 60 (MI=30, REST=30) rej=0 mk_acc=0.850\n",
2026-04-21 22:12:59 -05:00
" 009/S006 ONLINE FES n= 50 (MI=26, REST=24) rej=10 mk_acc=0.817\n",
2026-04-21 21:18:33 -05:00
"\n",
2026-04-21 22:12:59 -05:00
"Loaded 4 subject(s): ['002', '003', '005', '009'] | total artifact-rejected epochs: 184\n"
2026-04-21 21:18:33 -05:00
]
}
],
"source": [
"xdf_files = sorted(glob.glob(os.path.join(DATA_DIR, '*.xdf')))\n",
"print(f'Found {len(xdf_files)} XDF file(s).')\n",
"print(f'Preprocessing: notch {NOTCH_FREQ:.0f} Hz → '\n",
" f'{\"CAR → \" if USE_CAR else \"\"}bandpass {BP_LO:.0f}{BP_HI:.0f} Hz → '\n",
" f'baseline-correct → PTP-reject @ {PTP_REJECT_UV:.0f} µV\\n')\n",
"\n",
2026-04-21 22:12:59 -05:00
"sessions = {}\n",
2026-04-21 21:18:33 -05:00
"total_rej = 0\n",
"\n",
"for fp in xdf_files:\n",
" meta = parse_session(fp)\n",
" if meta is None:\n",
" print(f' SKIP (unparsed): {os.path.basename(fp)}')\n",
" continue\n",
" subj, ses_id, kind, stim = meta\n",
" try:\n",
2026-04-21 22:12:59 -05:00
" X, y, ch_names, sfreq, n_rej, mk_acc = load_session_epochs(fp)\n",
2026-04-21 21:18:33 -05:00
" except Exception as e:\n",
" print(f' ERROR {os.path.basename(fp)}: {e}')\n",
" continue\n",
"\n",
2026-04-21 22:24:23 -05:00
" sessions.setdefault(subj, {})[ses_id] = dict(\n",
2026-04-21 21:18:33 -05:00
" X=X, y=y, kind=kind, stim=stim,\n",
2026-04-21 22:12:59 -05:00
" ch_names=ch_names, sfreq=sfreq,\n",
" mk_acc=mk_acc, file=os.path.basename(fp))\n",
2026-04-21 21:18:33 -05:00
" total_rej += n_rej\n",
"\n",
2026-04-21 22:12:59 -05:00
" acc_str = f' mk_acc={mk_acc:.3f}' if mk_acc is not None else ''\n",
2026-04-21 21:18:33 -05:00
" print(f' {subj}/{ses_id} {kind:<7} {stim:<5} '\n",
" f'n={len(y):3d} (MI={int(y.sum())}, REST={int((1-y).sum())}) '\n",
2026-04-21 22:12:59 -05:00
" f'rej={n_rej}{acc_str}')\n",
2026-04-21 21:18:33 -05:00
"\n",
"subjects = sorted(sessions.keys())\n",
"print(f'\\nLoaded {len(subjects)} subject(s): {subjects} | '\n",
2026-04-21 22:24:23 -05:00
" f'total artifact-rejected epochs: {total_rej}')"
2026-04-21 21:18:33 -05:00
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "7b8c8bea",
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"## Verify Session Layout"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 6,
2026-04-21 13:01:49 -05:00
"id": "611baf23",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:10.051665Z",
"iopub.status.busy": "2026-04-22T03:21:10.051368Z",
"iopub.status.idle": "2026-04-22T03:21:10.057208Z",
"shell.execute_reply": "2026-04-22T03:21:10.056834Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Channels (32): ['Fp1', 'Fpz', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'FC5', 'FC1', 'FC2', 'FC6', 'M1', 'T7', 'C3', 'Cz', 'C4', 'T8', 'M2', 'CP5', 'CP1', 'CP2', 'CP6', 'P7', 'P3', 'Pz', 'P4', 'P8', 'POz', 'O1', 'Oz', 'O2']\n",
"Sampling rate: 512.0 Hz\n",
"Motor channels ['C3', 'Cz', 'C4'] → indices [14, 15, 16]\n"
]
}
],
"source": [
"# Verify channel layout is consistent across sessions, locate motor channels\n",
"ref_subj = subjects[0]\n",
"ref_ses = next(iter(sessions[ref_subj].values()))\n",
"channel_names_global = ref_ses['ch_names']\n",
"sfreq_global = ref_ses['sfreq']\n",
"\n",
"mismatches = [f'{subj}/{sid}' for subj in subjects for sid, s in sessions[subj].items()\n",
" if s['ch_names'] != channel_names_global]\n",
"if mismatches:\n",
" print('!! channel mismatch in:', mismatches)\n",
"\n",
"motor_idx_global = [channel_names_global.index(c) for c in MOTOR_CH\n",
" if c in channel_names_global]\n",
"\n",
"print(f'Channels ({len(channel_names_global)}): {channel_names_global}')\n",
"print(f'Sampling rate: {sfreq_global} Hz')\n",
"print(f'Motor channels {MOTOR_CH} → indices {motor_idx_global}')"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "70922abb",
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"## Train CSP + LDA on OFFLINE, Evaluate on ONLINE"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 7,
2026-04-21 13:01:49 -05:00
"id": "f5e80da3",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:10.059575Z",
"iopub.status.busy": "2026-04-22T03:21:10.059490Z",
"iopub.status.idle": "2026-04-22T03:21:11.605957Z",
"shell.execute_reply": "2026-04-22T03:21:11.605570Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-21 22:12:59 -05:00
"[009] Pair1 (train=OFFLINE_FES) / NOFES: only 1 clean epochs — acc=0.717 (marker-based), skipping EEG metrics\n",
2026-04-21 21:18:33 -05:00
"\n",
2026-04-21 22:12:59 -05:00
"Subj Pair Cond n trainAcc mkAcc |marg| Fisher muSNR\n",
2026-04-21 21:18:33 -05:00
"-------------------------------------------------------------------------------------------\n",
2026-04-21 22:12:59 -05:00
"002 Pair1 (train=OFFLINE_FES) FES 53 0.659 0.883 0.778 0.288 1.197\n",
"002 Pair1 (train=OFFLINE_FES) NOFES 52 0.659 0.833 1.162 0.073 1.452\n",
"002 Pair2 (train=OFFLINE_NOFES) FES 56 0.811 0.917 1.012 3.133 1.665\n",
"002 Pair2 (train=OFFLINE_NOFES) NOFES 60 0.811 0.850 0.756 0.695 1.576\n",
"003 Pair1 (train=OFFLINE_FES) FES 59 0.843 0.750 0.879 0.017 1.347\n",
"003 Pair1 (train=OFFLINE_FES) NOFES 38 0.843 0.763 0.910 0.683 1.219\n",
"003 Pair2 (train=OFFLINE_NOFES) FES 52 0.907 0.767 1.166 0.000 1.273\n",
"003 Pair2 (train=OFFLINE_NOFES) NOFES 43 0.907 0.717 1.071 0.029 1.296\n",
"005 Pair1 (train=OFFLINE_FES) FES 60 0.978 0.800 3.346 2.498 2.695\n",
"005 Pair1 (train=OFFLINE_FES) NOFES 59 0.978 0.933 3.620 3.029 2.399\n",
"005 Pair2 (train=OFFLINE_NOFES) FES 59 1.000 0.917 5.740 6.291 2.911\n",
"005 Pair2 (train=OFFLINE_NOFES) NOFES 58 1.000 0.783 5.261 4.325 2.287\n",
"009 Pair1 (train=OFFLINE_FES) FES 42 0.667 0.717 0.304 0.040 1.326\n",
"009 Pair2 (train=OFFLINE_NOFES) FES 50 1.000 0.817 4.673 6.192 1.594\n",
"009 Pair2 (train=OFFLINE_NOFES) NOFES 60 1.000 0.850 3.754 8.959 2.000\n"
2026-04-21 21:18:33 -05:00
]
}
],
"source": [
2026-04-21 22:12:59 -05:00
"MIN_TEST_TRIALS = 10\n",
2026-04-21 21:18:33 -05:00
"\n",
2026-04-21 22:12:59 -05:00
"results = []\n",
2026-04-21 21:18:33 -05:00
"\n",
"for subj in subjects:\n",
" subj_ses = sessions[subj]\n",
"\n",
" for pair in PAIRS:\n",
2026-04-21 22:24:23 -05:00
" needed = (pair['train'], pair['online_fes'], pair['online_nofes'])\n",
" missing = [k for k in needed if k not in subj_ses]\n",
2026-04-21 21:18:33 -05:00
" if missing:\n",
2026-04-21 22:24:23 -05:00
" print(f'[{subj}] {pair[\"name\"]}: missing {missing} — skipping')\n",
2026-04-21 21:18:33 -05:00
" continue\n",
"\n",
2026-04-21 22:24:23 -05:00
" train = subj_ses[pair['train']]\n",
2026-04-21 21:18:33 -05:00
" if set(np.unique(train['y'])) != {0, 1}:\n",
" print(f'[{subj}] {pair[\"name\"]}: training set lacks both classes — skipping')\n",
" continue\n",
2026-04-21 22:12:59 -05:00
"\n",
" clf = CSPLDA(n_csp=N_CSP).fit(train['X'], train['y'])\n",
2026-04-21 21:18:33 -05:00
" train_acc = (clf.predict(train['X']) == train['y']).mean()\n",
"\n",
2026-04-21 22:24:23 -05:00
" for cond_key, cond_label in [('online_fes', 'FES'), ('online_nofes', 'NOFES')]:\n",
" te = subj_ses[pair[cond_key]]\n",
2026-04-21 22:12:59 -05:00
" acc = te['mk_acc']\n",
" if acc is None:\n",
" print(f'[{subj}] {pair[\"name\"]} / {cond_label}: no EARLYSTOP markers — skipping')\n",
" continue\n",
"\n",
2026-04-21 21:18:33 -05:00
" if len(te['y']) < MIN_TEST_TRIALS or set(np.unique(te['y'])) != {0, 1}:\n",
2026-04-21 22:12:59 -05:00
" print(f'[{subj}] {pair[\"name\"]} / {cond_label}: only {len(te[\"y\"])} clean epochs — '\n",
" f'acc={acc:.3f} (marker-based), skipping EEG metrics')\n",
2026-04-21 21:18:33 -05:00
" continue\n",
2026-04-21 22:12:59 -05:00
"\n",
2026-04-21 21:18:33 -05:00
" res = evaluate(clf, te['X'], te['y'])\n",
" snr_s = spectral_snr(te['X'], te['y'], motor_idx_global, te['sfreq'])\n",
"\n",
" results.append(dict(\n",
" subject=subj, pair=pair['name'], condition=cond_label,\n",
" train_file=train['file'], test_file=te['file'],\n",
" train_acc=train_acc, n_test=len(te['y']),\n",
2026-04-21 22:12:59 -05:00
" acc=acc,\n",
" amp=res['amp'], fisher=res['fisher'], mu_snr=snr_s,\n",
2026-04-21 21:18:33 -05:00
" margin=res['margin'], y_test=res['y'], pred=res['pred'],\n",
" ))\n",
"\n",
2026-04-21 22:12:59 -05:00
"hdr = (f'{\"Subj\":<5} {\"Pair\":<28} {\"Cond\":<6} {\"n\":>4} '\n",
" f'{\"trainAcc\":>9} {\"mkAcc\":>7} {\"|marg|\":>8} {\"Fisher\":>8} {\"muSNR\":>8}')\n",
2026-04-21 21:18:33 -05:00
"print('\\n' + hdr)\n",
"print('-' * len(hdr))\n",
"for r in results:\n",
" print(f'{r[\"subject\"]:<5} {r[\"pair\"]:<28} {r[\"condition\"]:<6} {r[\"n_test\"]:>4} '\n",
" f'{r[\"train_acc\"]:>9.3f} {r[\"acc\"]:>7.3f} {r[\"amp\"]:>8.3f} '\n",
" f'{r[\"fisher\"]:>8.3f} {r[\"mu_snr\"]:>8.3f}')"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "2ab81600",
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"---\n",
"## Figure 1 — Per-metric comparison (FES vs NOFES)"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 8,
2026-04-21 13:01:49 -05:00
"id": "d53e63b9",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:11.609669Z",
"iopub.status.busy": "2026-04-22T03:21:11.609563Z",
"iopub.status.idle": "2026-04-22T03:21:12.097953Z",
"shell.execute_reply": "2026-04-22T03:21:12.097507Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"data": {
2026-04-21 22:12:59 -05:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABoAAAAQ8CAYAAACyzFyVAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3Qm85WP9OPDnjrkztjHGxMwwM7axbxESKQpFMsrSQiT8ZCJKKQlNkZBfSrSIKH6SqGnRjhYJlWwlpjCWyTAYBjNzZ+b8X59v/+/te8895+73nuW+36/Xccf5nuU5z/f5Ls/zeZaWUqlUSgAAAAAAADSNEbVOAAAAAAAAAANLAAgAAAAAAKDJCAABAAAAAAA0GQEgAAAAAACAJiMABAAAAAAA0GQEgAAAAAAAAJqMABAAAAAAAECTEQACAAAAAABoMgJAAAAAAAAATUYACAAAAAAAoMkIAAEAAAAAADQZASCAGrvhhhvSe9/73rTJJpuksWPHptGjR6dJkyal3XbbLZ155plp3rx5g/K9Dz/8cGppaWl/RBqK4v+L2+P1je5Tn/pUh9908803p2a03nrrtf/G+HfR5Zdf3iEP4v+bRXmZ7epR/rvjeOvpeyuVmyeeeCJ9/OMfT9ttt112HLe2tqZXvOIVaeONN0577rlnOumkk9I111yTmuXYicf111/f6XWRN12dV4qef/75dMEFF6Q3v/nNafLkyWmllVZKY8aMSRtuuGF65zvfmb7zne+kZcuW9aicd/coP3/19H3xqOTPf/5zet/73pft31VWWSWtuOKK2Xl7yy23TG9729vSGWeckW655ZbUW6VSKV100UXp1a9+dVpttdXSiBEj2tPxgx/8INWr8uOnqzIRZale1fr8GOU0ylJ8d9wLPPbYY0P6/c2mv/uzq2spg+eBBx5Ihx56aJoyZUoaNWpU+z545Stf2f6axYsXp89+9rPZc3EOLu7nv/71r8PyHrcWGvW+ergd24NV3rva/y+99FJac80127fdeuutA/KdAPTNyD6+D4B+evTRR7NGzj/84Q+dtv373//OHr/5zW+yCu65556bjjvuOHkOdei3v/1teutb35oFNIrmz5+fPR588MH0q1/9Kk2YMCG94x3vSM3i9NNPT/vvv38WpOit733ve+l//ud/0rPPPttp28KFC9O//vWvLGD26U9/On33u9/NAiv14n//93/TRz/60bR8+fKK5+377rsvC9b84x//SLvsskuv8zQC/wxPJ598ctawHQ4//PAsMAp5kDXuCYvB4mYUnZ5e85rXpGeeeabL1x111FHpyiuvTMNRBCweeeSR7N/rrruu4BV1aeWVV04f/vCH0yc+8Yns/z/0oQ9lQaBqHWsAGFwCQAA18Pjjj2c9vOfOnfvfE/LIkdlzq6++errrrrvae/6+/PLL6fjjj09PP/30kPac3mGHHbKG2Fz0sKQ5Gg4OOOCADv/frDbbbLO0+eabV9zW3e/efvvts4aVSqJHY+7FF1/MgjrF4M9GG22Upk2blgVG4jj+29/+ltra2lKziUDH1VdfnQ455JBeve+yyy5LRx55ZIfnYrRUnHMWLVqUNRDE3/D3v/89awyMQPlWW23V5ee+7nWv67Bviro7f+29995ZY0V3/vSnP6WPfOQj7Y2vsY9j1FeM/omG+3/+859Z8KqvjbOXXHJJh/9//etfn+VNWGeddfr0mTSG22+/PV177bXZv6OBLIJB1NY+++zTPgp7rbXWsjuGQIwsLQZ/4ty64447ZvfI66+/fvt1N649udi2++67Z6MmQ9xHd8c97sCIe6ziPWW1azDDU3RePOuss7Jj9rbbbss6/xx00EG1ThbAsCQABFAD73nPezoEfzbddNP0ox/9KGs0DtGz/HOf+1w69dRT218TPeGjghsNgkPhAx/4QPag+XoRx2M4OPjgg/scNI2y39X0Zbmf/exn2aiP3Oc///lsureiqPj+4h
e/SD/+8Y9Ts4n8jQBYNMD1RIyGOvbYYzvldYyqial+wpNPPpk1KOVTqEUg+sADD8wCTl19z8yZM/tcti+++OIeBUOvuOKK9uBONNJHg0YEC4uiPMyaNSsb/dVb8dtzO+20U8NMp0P/ffGLX2z/984779x+P0DtxHmhGd1xxx1Zr/wIOFYLlsyZMycddthhWcB+gw02GLK0Fc+B4Zvf/GZ605ve1OG56BBVnB40rg/FgFDoboor97gDd58VD6gkpvWNkeJXXXVV+3VOAAigNqwBBDDEfve736Wbbrqp/f9XWGGFdN1113Vo7Ile5VE5f9e73tX+XDQ6RhCouzmdf/rTn2ZrjkSlPnq0R8/J73//+wM6X3SludVjKqfosRzrd+TrGMUUT101gv785z/PGo+j4TXW/1h11VWzqZ7ic2JNlb6IXvjR2yzWVIq1FKLX/DHHHNOpUaGaF154IWuMznveR6P0+PHj0xvf+MasIWLp0qVV3xu9/2N0wKte9ao0bty47L0TJ07MpoGKqZ0qjQKJwF80dk+dOjVLb+RBpP3oo49Od955Z5cN6VE+ordl5N0222yTNVZ1N/KguzURyudFj98ba5LEKIcoT1Gu9t1333T33XdX/PwIXn71q1/N5uWPdEUevv3tb0/33HNPt99dvr2e1wop7oeiN7zhDRVHn8S6MJdeemmPPzemfizmRRzX5aJ8FF8zY8aM9m0xCiVGDsbxFBXwWJMoysoWW2yRTT35hS98IT333HOpv2bPnt2rdTUisL1kyZIODd0XXnhhe/AnxFR5cc7Ke3Pna0KUN/DVen/HWk9xrJeLYz7OOfm0Jz2RH3dFf/zjH6uuq9Of81SMSIvGzygLUTbifB3fH1OO/eUvf6n6vui0EOf0tddeOztXRceFONcW92dPRZAs8iimOIvvj+tfrJsUawaUi+n0ImgYI8FiVF5enuN3v/a1r82Ola7KcmyLwGx0oIhRHJFXa6yxRtp6662zY6SnayHEqLS99tqrwz6JdA3EVFxPPfVU1jM69+53v7vTaypdd2OkREyrE/sv8jGuIyeccEKX02f15bpb/N4IssaIx5gGMa73kZ89DbxGuY0yE8HN2AexH+OaEvv/LW95SxbEjakTu/ru/q7lEcdw5G+cZ6Icx2+P82GlY6Ynn93X4ynKTXQKiH0RAZa4TsQ1Nh+lG0Hk4r1Ycfq38nzpzUje+++/P7tHjGlJ99hjj4rTcMb0YnFuie+Mv1E++6I39zf59b/8uh9rxBXvCSr93lgvrrd50Yj3uPlaa/n0byH+Xe0YKS+/+f1xjNqJ9OR5Fefds88+OwvkRPriGha/MS+PcQ9XrR7R3RpA5WmLWQ3yNESZiPN43Mt2dR6Oa2HsgxhdHcdJpD3W34t7nvJ7sIG4T+5O+W+K821cS6K+Ed8T97/f/va321//k5/8JO26667Z/o/zXUwZfO+991b9/FjDKspWXGPjPZFPsR5W7If8vFDJggULsvNyfj2Iv/H/cd7tiTjO4xwc5+e8DhPnybjnr7TmY08Vr2fRsada/QGAQVYCYEh96EMfippH++NNb3pT1df+4Q9/6PDaFVZYofTcc8+1bz/88MM7bH/3u9/d4f+Lj6uvvrrDZz/00EMdtsdnFZV/dry+2ntf+9rXlqZOnVrxe7fZZpvS4sWLO3x2/P/BBx9cNa3xWH311Uu/+tWvepW3ixYtKr3+9a+v+HkTJ07slD833XRTh/ffeeedpXXXXbfLdO26666lZ599ttN3f+lLXyqNGjWqy/cW3/fSSy+V3vrWt3b5+hEjRpRmzpzZ6btuv/320mqrrVbxPQcddFBp8uTJ7f8fv6fom9/8ZofXx/8XFX//hAkTSm984xsrfs+YMWNKDz74YKe0vec976n4+tGjR5cOPfTQLr+7PG1nnHFGL/Z+5zLbm/eXl5vytFVz/vnnd3jfdtttV/re975Xevrpp0v98fjjj2fHe/65kXflTjrppA7f/e
c//zl7/m9/+1tp7NixXZateNxxxx09Tk/kZflxkP97ypQp2bEX4piqdl5Zvnx5afz48V2el4r+53/+p8NrDzjggA7
2026-04-21 21:18:33 -05:00
"text/plain": [
"<Figure size 1680x1080 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: fes_vs_nofes_metrics.png\n"
]
}
],
"source": [
"# Figure 1: one panel per metric, one bar per (subject × pair × condition) row of `results`.\n",
"# METRICS entries are (key in each results row, panel title, unit shown in the title).\n",
"METRICS = [\n",
"    ('acc', 'Classification accuracy', '0-1'),\n",
"    ('amp', 'Classification amplitude (mean |decision fn|)', 'a.u.'),\n",
"    ('fisher', 'Fisher ratio on LDA projection (test-set SNR)', 'a.u.'),\n",
"    ('mu_snr', 'μ-band power ratio REST / MI @ C3/Cz/C4', 'ratio'),\n",
"]\n",
"\n",
"cond_color = {'FES': '#E05C2A', 'NOFES': '#2A7BE0'}\n",
"CHANCE_ACC = 0.5  # chance level for the two-class (MI vs REST) decoder\n",
"\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 9))\n",
"fig.suptitle('Online decoding: FES vs NOFES feedback (per subject × offline-trained model)',\n",
"             fontsize=13, fontweight='bold', y=1.00)\n",
"\n",
"for ax, (key, title, unit) in zip(axes.ravel(), METRICS):\n",
"    labels, vals, colors = [], [], []\n",
"    for subj in subjects:\n",
"        for pair in PAIRS:\n",
"            tag = pair['name'].split()[0]  # \"Pair1\" / \"Pair2\"\n",
"            for cond in ('FES', 'NOFES'):\n",
"                row = next((r for r in results\n",
"                            if r['subject']==subj and r['pair']==pair['name']\n",
"                            and r['condition']==cond), None)\n",
"                if row is None: continue\n",
"                labels.append(f'{subj}\\n{tag}\\n{cond}')\n",
"                vals.append(row[key])\n",
"                colors.append(cond_color[cond])\n",
"    x = np.arange(len(vals))\n",
"    ax.bar(x, vals, color=colors, edgecolor='white', zorder=2)\n",
"    ax.set_xticks(x); ax.set_xticklabels(labels, fontsize=7.5)\n",
"    ax.set_title(f'{title} ({unit})', fontsize=11, fontweight='bold')\n",
"    ax.grid(axis='y', alpha=0.3)\n",
"    ax.spines[['top','right']].set_visible(False)\n",
"    if key == 'acc':\n",
"        ax.axhline(CHANCE_ACC, color='gray', linestyle='--', lw=0.8, alpha=0.6)\n",
"\n",
"fig.legend(handles=[Patch(color=cond_color['FES'], label='ONLINE_FES'),\n",
"                    Patch(color=cond_color['NOFES'], label='ONLINE_NOFES')],\n",
"           loc='upper right', ncol=2, bbox_to_anchor=(0.98, 1.0))\n",
"plt.tight_layout()\n",
"plt.savefig('fes_vs_nofes_metrics.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_vs_nofes_metrics.png')"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "248740bd",
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"---\n",
"## Figure 2 — LDA decision-function distributions\n",
"\n",
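"Visualizes classification amplitude and separability directly: the farther apart the MI and REST histograms are (and the farther each sits from the 0 decision boundary), the higher the Fisher ratio and the larger the mean |margin|. Compare that gap between the FES and NOFES sessions within each panel."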
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 9,
2026-04-21 13:01:49 -05:00
"id": "393042a0",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:12.099369Z",
"iopub.status.busy": "2026-04-22T03:21:12.099281Z",
"iopub.status.idle": "2026-04-22T03:21:12.990027Z",
"shell.execute_reply": "2026-04-22T03:21:12.989595Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"data": {
2026-04-21 22:12:59 -05:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAABZAAAAYTCAYAAABt5R22AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3QeUG9X1x/Grsn3tdW9gm2aH3jsBbHoILWBqANMhEEJC78YECB1CMSHgYFpooZMAgVBDNQnNhGaasXEva6+3Spr/+T0z+kta1fX2/X7O0VmtNDN6mhmNRnfuuy/geZ5nAAAAAAAAAACkCKY+AAAAAAAAAAAAAWQAAAAAAAAAQEZkIAMAAAAAAAAA0iKADAAAAAAAAABIiwAyAAAAAAAAACAtAsgAAAAAAAAAgLQIIAMAAAAAAAAA0iKADAAAAAAAAABIiwAyAAAAAAAAACAtAsgAAAAAAAAAgLQIIAMAAAAAAAAA0iKADABADoFAIH675JJL2n19jRkzJv76ut+V30uhvv76azv00ENt6NChFg6H422fMmWKdRevvPJK0nbR/11Rpn2rM7+/rtbmztouoCf79ttvkz6X3en7CQAAHwFkAMBK/1g66qijcs6jH1SJ8+hWXFxsvXv3ttVWW80FRs855xz79NNP82rD448/3mx5Z5xxBluzG6mrq7M999zTHnzwQZszZ45Fo1HrahSUTNxHsXK6c6CGfaXzef/99+3kk0+2DTfc0Pr06eO+swYOHGjbbbedXXzxxfbDDz/kFejX7YUXXsh6cVDfg9m+M/PZ1/VdnDiPPi+ZPju63XHHHTmXket9pbulvhcAAND1hTu6AQCAnqupqcndli1bZt999529+uqrdvXVV9sxxxxjN910k1VUVGScd/Lkyc0eu+++++zKK6+0oqIi605+9atf2V577eXuDx8+fKWWdc0118Tvb7vtttaZTZ061T7//PP4/1oH22+/vQWDQdtiiy2su1hzzTWTtov+70468/vrSp+Hzr4uu5P6+nr7zW9+kzbAumDBAnd788033ffV9ddf74LMuZx99tn23//+t1NdSJowYYL98pe/tPLy8o5uSpfWr1+/pM9ld/p+AgDARwAZANAhDj74YNt8881t6dKl9tFHH9mzzz5rjY2N7rm//OUvNn36dPvnP/9pJSUlzeZV1tdzzz3X7PF58+bZ008/bfvvv791t3XVWs4880zrKhKz5+TGG2/slgEzXRToStulq7+/SCTiLlyVlZV1qnZ1xXXZHXmeZ0cccYT97W9/iz82ePBgO+igg2zIkCHuotZDDz1kDQ0N7nbKKae4efQ3mw8++MBd5NSyO4vZs2fbddddZxdddFGL5t91111tt912a/Z4VVWV9STqScXnEgDQ3VHCAgDQIfbYYw/3g+vSSy+1J554wr788kvbZJNN4s+/9tprdvnll6edV115/XIGqou7xhprxJ9T8Lklqqur7fTTT3cBmtLSUvvJT35iV111lQs05aKyG8oSXnvttV3WtAJTmv93v/udzZo1K+N8f//7323cuHE2cuRI95r6EaplHHvssfbVV1/lVQNZwQxNP2rUKPe66mKtesHKgFKb/vWvfxVUA3natGl2wgkn2OjRo11Wmpa51lprudf48MMPc3a7V0DlD3/4g3sfCv4r4HLSSSe5CwX58rtajx8/PulxtSOxa3auerDZ1lvqelBm4L777mt9+/Z173nLLbd02yedWCxmf/3rX11G9LBhw9z71HwbbLCBCyIpO9Fv28SJEzO+rl/6JZ+6tv/4xz/chZFVVlklXvpl4403tvPPP9/mzp3bbPrU965ptD9ofrVX+4sy5hT4auvPSbb3p3V522232Q477GADBgxwn2cFn7St99lnH7vsssts+fLlblp1i1999dWTln300Uen7W6vaRPX88cff+y2b//+/V0PhXfeeafgmuAKKG699dbuM66MQ312EzPk8ymxkV
gewO/m31X3lUK2Xb46w37+6KOPJgWP119/fXeMV68YtePuu++2t99+O6mHjL7Lsh3rfQrU6hjZmWj96OJrSyhrX+899Xb88cfnvYzFixfbBRdc4Laztrf2I+1P6623nh1++OF25513NptHxxs9vssuu7iSItpXNI+C2Y888kja13n33XftkEMOiX/f6rbqqqu6ciS//e1v7b333mvx9PmU1lHGurK9dQzTsrT/rLvuui7TXbX+U6UeK/QdqjJfml/79ogRI+zcc8+NX3hPpLJPWhe68KHjXa9evdwydN6lfVAloQAAKJgHAECBvvnmG/0aj9/Gjx+fc5677roraR79n2rGjBleSUlJfJrevXt7jY2NSdPEYjFvzTXXjE+zxx57eFdccUX8/1Ao5M2aNaug97N06VJvww03TGqff9t7772T/p8wYULSvHfeeadXXFycdl7d+vbt6/373/9Omkfvady4cRnn0e3xxx+PT7/jjjvGH9d932effeZVVlZmXU7qtsn2Xm6//XavqKgo47LC4bA3adKkpHm0jMRptt9++7Tzjh07tsX7V7qbpnn55ZeTHtP/iTKtt9T1sNVWW6XdhsFg0HvppZeS5luyZEnG9+jf3n///WZty7Ztsr2PaDTqHXXUUVmXM2DAAO/NN9/M+N7XWGMNb9iwYWnnveSSS9r8c5Lt/R1//PF5bWsZOXJkzml9idNusskmXkVFRdo25Nvm1Pfn3/r06eN9+OGHGffd1OOctrn/nNqY7rW6yr5SyLbLpTPt5zvttFPSvM8++2za6c4555yk6S699NL4c6nbaejQofH7V199ddr2+/tDId+ZqRL3r9T1n7pvJrbp5JNPzriMRKnvK/U7pFD19fXe+uuvn3W7p66XhQsXeltssUXWeQ499FC3T/leeeUV9/2VbZ7E91Lo9Lk+9xdddJEXCAQyLkvHpyeeeCLjtuzfv7+37rrrpp1Xn5tEl19+ec7PZep3JQAA+aCEBQCg01BWozJknnzySfe/Mm6U5bPNNtvEp1Gd5MTsXGX0/PSnP3UZTIoHKTNZGWLnnXde3q+rwZBURsO30UYbuQw6vc4DDzyQcT5lMSpbV5l4ogxUZTmqHcoA0vzKrvrFL37hMqz9br1nnXVWUoabshnVPVqZw5pHZTjycdddd1lNTY27rwGelI2pLCxl4Wk5yuLOl7KjlLnnvxctR12tQ6GQ3XPPPS5DTV3/lWGr96l1ns7rr7/u3q8yq+6///54GYqXX37Zra+tttoq73qS2vbqKu5T9p8yff1pUktctJTapcwy7Uvff/+9yy4WrQvVOB07dmx8WmVF6z0m7rP77befa89nn30W33Z+rVqVYUkcPCuxTqYyG3PR9InZbJpH+5i2sfZzZeIp41mPJe5jiZTdpow3bV9lVytrVAMUiuq3ar3mUze8pZ+TTLTvJvYY2Gmnndy6VobmzJkzXQ3sTz75JP68PuPa5ldccUWzUji5BkLTfqztq4xplcfJVl89HW1XZdoqy1XZ6s8884x7fMmSJS5TUI+1VFfcVwrddp2p7dno++Pf//53/H8dV/WdlI6yU5V978t2vFXW/u9//3v3naYeGscdd1z8WNZRtN/qe+qbb76xP//5zy6rVhnbhdD3xrXXXps2MzmfmuL6XlCvF1Fte33n6DOq780ZM2YkbQvfkUce6fYv0fbWdlDWu/Y3fV/ouK3jkfYhbXPRvqDvL1F2ujKblZWrcljanxKP6S2ZPhu1Sdvep0xgHbdqa2vj3+HK1Nf70LpIV6Zp4cKFbp3ovavXi7Kv9XkQfT/rmKjzB1GmvE/HRn/8BH236fj9n//8J++2AwCQJK8wMwAA7ZCBLGeffXbSdA8//HDS84cffnj8ufLycm/ZsmXu8W233Tb++FprrZX39mpqavJ69eoVn3f06NEuK8qnrLJMWUcHHHBA/PGNNtrIa2hoSMqSKi0tjT9/ww03uMcXL16clOU7YsQIb968eUlt0n
uaO3duzkza0047Lf74iSee2Oy9KdP522+/TXos03vZf//9k7K4P/300/hzX3zxRVI21j777JMxA/m3v/1t/LkPPvg
2026-04-21 21:18:33 -05:00
"text/plain": [
2026-04-21 22:12:59 -05:00
"<Figure size 1440x1536 with 8 Axes>"
2026-04-21 21:18:33 -05:00
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: decision_margin_distributions.png\n"
]
}
],
"source": [
"# Figure 2: decision-function histograms per (subject × pair) panel, split by true class\n",
"# (y_test == 1 -> MI, y_test == 0 -> REST) and colored by feedback condition.\n",
"fig, axes = plt.subplots(len(subjects), len(PAIRS),\n",
"                         figsize=(6 * len(PAIRS), 3.2 * len(subjects)),\n",
"                         sharex=True, squeeze=False)\n",
"fig.suptitle('LDA decision-function distributions on ONLINE sessions\\n'\n",
"             '(class separation ↔ classification amplitude & SNR)',\n",
"             fontsize=12, fontweight='bold', y=1.01)\n",
"\n",
"N_BINS = 15\n",
"legend_drawn = False\n",
"for i, subj in enumerate(subjects):\n",
"    for j, pair in enumerate(PAIRS):\n",
"        ax = axes[i][j]\n",
"        rows = {cond: next((r for r in results\n",
"                            if r['subject']==subj and r['pair']==pair['name']\n",
"                            and r['condition']==cond), None)\n",
"                for cond in cond_color}\n",
"        present = [r for r in rows.values() if r is not None]\n",
"        if present:\n",
"            # Shared bin edges within the panel so the MI/REST and FES/NOFES histograms\n",
"            # are drawn on identical bins and can be compared bar-for-bar.\n",
"            edges = np.histogram_bin_edges(\n",
"                np.concatenate([np.ravel(r['margin']) for r in present]), bins=N_BINS)\n",
"        for cond, color in cond_color.items():\n",
"            row = rows[cond]\n",
"            if row is None: continue\n",
"            m_mi = row['margin'][row['y_test'] == 1]\n",
"            m_rest = row['margin'][row['y_test'] == 0]\n",
"            ax.hist(m_mi, bins=edges, alpha=0.5, color=color,\n",
"                    label=f'{cond} MI', density=True)\n",
"            ax.hist(m_rest, bins=edges, alpha=0.25, color=color, hatch='///',\n",
"                    edgecolor=color, label=f'{cond} REST', density=True)\n",
"        ax.axvline(0, color='k', lw=0.8)\n",
"        ax.set_title(f'{subj} | {pair[\"name\"]}', fontsize=10, fontweight='bold')\n",
"        if j == 0: ax.set_ylabel('Density')\n",
"        if i == len(subjects) - 1: ax.set_xlabel('LDA decision function')\n",
"        ax.spines[['top','right']].set_visible(False)\n",
"        # Draw the legend once, on the first panel that actually contains data\n",
"        if present and not legend_drawn:\n",
"            ax.legend(fontsize=6.5, loc='upper left', ncol=2)\n",
"            legend_drawn = True\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('decision_margin_distributions.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: decision_margin_distributions.png')"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "fcb6d19d",
"metadata": {},
2026-04-21 21:18:33 -05:00
"source": [
"---\n",
"## Figure 3 — Paired Δ (FES - NOFES) per metric\n",
"\n",
"Within each (subject × pair), the FES and NOFES sessions are scored with the same offline-trained model. Positive bars mean FES > NOFES; negative bars mean NOFES > FES. Pairing removes the offline-model-quality confound, so the difference mainly reflects the feedback type (other session-level differences, such as session order, are not controlled by this comparison)."
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 10,
2026-04-21 13:01:49 -05:00
"id": "75df404b",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:12.992553Z",
"iopub.status.busy": "2026-04-22T03:21:12.992447Z",
"iopub.status.idle": "2026-04-22T03:21:13.261572Z",
"shell.execute_reply": "2026-04-22T03:21:13.261249Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"data": {
2026-04-21 22:12:59 -05:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB3kAAAIwCAYAAACLCdFHAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAtU1JREFUeJzs3Ql8FPX9//FPgIRLwBjlUA4RtICAF1CEolBvq4gYqFoRtSiIBwgWBYuIihYRRdGgRSwinigatVqrBS9EwAuhLRWscigKIhIRAonu//H+/v6znd3sbjZhQ3aS1/Px2McmszOzc+1cn/l8vhmhUChkAAAAAAAAAAAAAIBAqFHZEwAAAAAAAAAAAAAASB5BXgAAAAAAAAAAAAAIEIK8AAAAAAAAAAAAABAgBHkBAAAAAAAAAAAAIEAI8gIAAAAAAAAAAABAgBDkBQAAAAAAAAAAAIAAIcgLAAAAAAAAAAAAAAFCkBcAAAAAAAAAAAAAAoQgLwAAAAAAAAAAAAAECEFeAAAAAAAAAAAAAAgQgrwAAAAAAAAAAAAAECAEeQEAAFBpZs+ebRkZGeHXF198UeZx+Ie/6aabkh6ud+/e4eH0d1V00UUXhefx4IMPruzJARBDKBSyY445xv1Oa9asaZ9++mlgllNl7sORHrTO/etQ2wT+z+uvvx5eLmeffTaLBQAAAClHkBcAAAAJ/e1vf4u4gfvss8+W6KdZs2bhzxs0aGA//fRT3Budej3//POlLvU33ngjYhj9j/T13HPPRawvvUaPHm3pSIGk6GmN9VKQ3E+B8mSGi7Wtzp8/38444wz3W8nKynK/k5YtW1r37t3t0ksvtQcffNDSSfQ83XDDDaUux3gBvtWrV9u1115rXbp0sZycHMvMzHTv+l/biD5PJniU6JWK9etN6/Dhw619+/ZWv359q127tjVt2tQOP/xwy83NtVtuucXWr19vqfToo4/ahx9+6P4+55xz7LDDDrOqgH142R/GSfSKDoAn+9uI/l3+/PPP9vDDD9uJJ55ojRs3dr/HRo0aWevWre1Xv/qVXXHFFfb0009bOvPvi2P9jtNl+9My1n5OdN7z9ttv7/VpAAAAQNVWq7InAAAAAOmtZ8+eLrvMC9y++eabLhDhUdbZ119/Hf5/+/btLmDRtWvXcDcN49HN1l69erm/1c+UKVPCn+233362t1x++eUu6CYtWrSwqujcc8+1jh07ur91E78izZo1q0S3uXPn2p/+9CcXRKjOhgwZUmL5FBUVud+KAoZLliyxefPm2dChQy1dTZs2zQU/DzrooKSHUTBpwoQJdtttt7m//b777jv3+uCDD9y4x40bZxMnTrQaNSrvOeTXXnvN+vbta4WFhRHdv/nmG/f617/+5R5y6dSpU8r2Gdqvjh8/Pvz/NddcY0GSin24f/gePXqkbNpQUnFxsTvuvfrqqxHdCwoK3EsB4UWLFrnXwIEDWYQpMGrUKDv//PPd39rPEegFAABAKhHkBQAAQELKODz66KNt2bJlJQK2sf73usUL8iroqCw+UXacXpXht7/9rVV1p556qnuVJxCgIGTdunWT6v+rr75yGd/RNm3aZC+++KL179/f0tmwYcOsTZs2Jbp7AfJYsrOz3Q37WPzj+vvf/x4R4D3qqKPslFNOcUF3BTlXrFhh77zzjqW7HTt22I033hgzmB/PmDFjbOrUqeH/GzZs6B48aNWqla1du9aefPJJF1hSAPjWW291wVV/wC+aMuLK87tNZv1qGi655JJwgFfBygEDBrhs6927d9uaNWvs3Xfftc8//9xSSb+PdevWub+VSXnsscdakKRiH64sb/yP9ivav0RLFAA/5JBD3INLsfgD78rg9Qd4jzvuOPeqV6+ebd682T7++GNbvHgxq2MPab+m/Z3owREtX+1Dta/XPl8PigAAAAApEQIAAABK8Yc//CGkU0e9MjIyQlu2bAl/9rvf/c51r1evXqhRo0
bu7zPOOCP8+Y4dO0K1a9cOD3/llVeGP/vLX/4S7q7X559/7rr7u8V6tWrVKjwOf/cJEyaEVqxYETr77LND2dnZoTp16oS6du0aeumll0rM0/HHHx8eTn/7RY/zgw8+CPXt2ze07777JhxnaaLn97PPPgtNnz491KlTJ7eMDjjggNBFF10U+vLLL0sMe8cdd4T69esXOuyww0I5OTmhWrVqhfbZZx837DXXXBNav359iWEGDx4cc5mJ/vc+U3+ffPKJm8f99tvPdVu4cGHS8zVp0qTwuLKyskIdOnQI//+b3/wmlG60Tv3rIdl59S+z6OUZj9aNN0ybNm1CxcXFJfrZvXt36G9/+1soncT63dWsWTO0cuXKuMvR+/3KsmXLIj478MADQ1988UXEd+h/dff39/7774c/1/j8n2k7raj1u3z58qSG0f7FP5976rTTTgt/55gxY1K6zxBN68iRI0OHH354qH79+u732bJly9Bvf/vb0FtvvRVzmCeeeCJ00kknhRo3bhzez2h7P+WUU0J//OMfQxs3bqzQfbjMmjUr3K1GjRqhDRs2lJjOzp07h/vp379/xGf/+te/QsOGDQv94he/cMcm7be179SyiDWuVNB68R8bk+XfT0f/jhLxDxN9DItHx0ZvmN69e8fs58cffwy9/vrrSU9/9O9U28SCBQtCffr0cduOXieffHJoyZIlMYffvn17aOrUqaGePXu643ZmZqbb9nQ8+vvf/55wWcV66bdblu1Pvv3229BNN90U6tKlS6hhw4ZuGg466KDQeeedF1q6dGmJaY7e7levXh2aPHlyqF27du43Fr0+BgwYEO73qquuSnrZAgAAAKUhyAsAAIBSKaDpv6H53HPPhT9r3ry563bCCSe4gJ7+VjD0p59+cp/rZq9/2Hnz5lVYgEA3VuvWrVuifwUJNB3lCfL+8pe/dDdtkxlnaaLnV8ss1vy1aNGiRCBCgd1Ey0Q3x//5z3+WK8h71FFHuQBQWQNj8vPPP7vgpT/Ycvfdd0cEBuMFoBIpbRuIfnnBoXQL8l599dXhYbQOP/3001AQ+JdPs2bNwn+ffvrpSQV5L7nkkojPZsyYEfN71N3fn4arjCDvhx9+GDGMtmFt2xWpsLDQBR+978zPz0/pPuOFF14o8buOfo0bNy7uAxvxXv7lWVFBXgX+GjRoEO6uh1yig+3+4V5++eXwZw899FDMfbZ/X/nOO++EUsnb52lfWtZA794M8ipw6g2joLc/YF9e0b9T7SN0fIxe7nooITpoq4cWDj300ITbiv/hh4oI8uqBlCZNmsTtV8cwBfD9orf7Xr16JVwf06ZNC3/Wtm3bPV7mAAAAgIdyzQAAACiV2tCNbpe3X79+9t///tc2bNjguh1//PGuvO9f//pX+/7772358uWuNG10OWeVhiyNSrZ+9tln9sADD8QsuRqvfVl9V/Pmze13v/uda+v08ccfD5diveOOO6xPnz5lXttqLzXV4/T84x//cO0jqhz2woULw2316Xuuuuoqmz9/frhfTUPv3r1dqVuV8lTbxlr2Tz/9tCv7u3XrVlce96WXXirzdHz00Udu/Woef/GLX7jStPXr109qWC1zrSvPoEGDrHv37q4Eq7YXvR555BEbO3aspaunnnrK3n///RLdVRo4XturKsd55513luiubfPSSy8N/69169myZYtbvirVqdLD+n3ot3XEEUdYOjv55JPdNqF2Ol9++WV744033LaYSPTv/rzzzovZn8o3+8vMvvXWW3HH+c9//jPmMlfZ5URlyZNZv+3atXP7r507d4bbxlV70iqfrPWjbdrbx6WKSuD72//t1q1byvYZKiut+fPmR+ViL7roIrd9anlo3y1qL1nllr02Q++9997wd2kb9dot1/g/+eQT14ZyMsq7D/do/6NtY+bMmeH2vf/whz+EP3/sscfCf6uktkqge/vryy67LNwGtH5rZ511lh5ud+XBNU3aV5599tm2evXqlL
VVrvZr77//frcvPfHEE+31118vdxvzmudY5Zo1X14J4GhaP7F+G9q+/SXOtd288MIL7u9PP/3Ufa790DHHHOM+0/F
2026-04-21 21:18:33 -05:00
"text/plain": [
"<Figure size 1920x540 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: fes_minus_nofes_delta.png\n"
]
}
],
"source": [
"# Figure 3: paired FES - NOFES differences, one bar per (subject × pair) where both\n",
"# online sessions were scored with the same offline-trained model.\n",
"fig, axes = plt.subplots(1, len(METRICS), figsize=(4 * len(METRICS), 4.5), squeeze=False)\n",
"fig.suptitle('Within-pair Δ = FES - NOFES (positive → FES better)',\n",
"             fontsize=12, fontweight='bold', y=1.03)\n",
"\n",
"for ax, (key, title, _) in zip(axes[0], METRICS):\n",
"    labels, deltas = [], []\n",
"    for subj in subjects:\n",
"        for pair in PAIRS:\n",
"            fes = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
"                        and r['condition']=='FES'), None)\n",
"            nof = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
"                        and r['condition']=='NOFES'), None)\n",
"            if fes is None or nof is None: continue\n",
"            deltas.append(fes[key] - nof[key])\n",
"            labels.append(f'{subj}\\n{pair[\"name\"].split()[0]}')\n",
"\n",
"    # Color by the winning condition; exact ties (and NaN deltas) stay neutral gray\n",
"    colors = [cond_color['FES'] if d > 0 else cond_color['NOFES'] if d < 0 else 'gray'\n",
"              for d in deltas]\n",
"    x = np.arange(len(deltas))\n",
"    ax.bar(x, deltas, color=colors, edgecolor='white', zorder=2)\n",
"    ax.axhline(0, color='k', lw=0.8)\n",
"    ax.set_xticks(x)\n",
"    ax.set_xticklabels(labels, fontsize=8)\n",
"    ax.set_title(f'Δ {title.split(\"(\")[0].strip()}', fontsize=10, fontweight='bold')\n",
"    ax.grid(axis='y', alpha=0.3)\n",
"    ax.spines[['top','right']].set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('fes_minus_nofes_delta.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_minus_nofes_delta.png')"
]
2026-04-21 13:01:49 -05:00
},
{
"cell_type": "markdown",
"id": "b3db60ba",
"metadata": {},
"source": [
"---\n",
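"## Summary Statistics\n",
"\n",
"Mean ± SD of each metric for FES vs NOFES, the mean paired Δ, and per-metric sign counts over all complete (subject × pair) comparisons. The table is produced by the final code cell, after the Figure 4 learning-rate analysis below."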
]
},
2026-04-22 04:21:48 -05:00
{
"cell_type": "markdown",
"id": "9f27a80e",
"metadata": {},
"source": [
"---\n",
"## Figure 4 — Within-Session Learning Rate \n",
"\n",
"This analysis estimates a within-run \"learning rate\": the run's trials are taken in chronological order, split into thirds (early / middle / late), and a line is fit through the per-third success rates. A trial's success comes from the online EARLYSTOP markers in the raw stream, not from the offline-trained LDA used in Figures 1-3.\n",
"\n",
"**Is this a reliable gauge?**\n",
"\n",
"In short, it is **highly noisy and generally unreliable for single runs**. A standard BCI run might contain only 30-40 trials, so each third holds only ~10-13 trials. A success rate estimated from 10 trials has high variance (at a true rate of 0.8 its standard error is about 0.13), so the slope (trend) is highly susceptible to chance. While this is a reasonable *exploratory* step, BCI learning curves are usually evaluated **across multiple sessions or days** rather than within a single 15-minute run. Aggregated across many subjects and runs, however, it might still reveal a macro-level effect of FES vs. NOFES."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "a9db789c",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABZAAAAI7CAYAAAByVb7+AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAABAABJREFUeJzs3Qd4JHX5B/B3+242vV/L9aPcHUUp0kFBEATpCtJFQASpUgUEFaUJAlKkI4qI4h+xoCBVQECk3XEHV5Or6ckm28v8n+9vM5vZzeymJ5vk+7lnnly2ZXZ2dnbmu++8P4umaZoQEREREREREREREWWwZl5ARERERERERERERMQAmYiIiIiIiIiIiIiyYgUyEREREREREREREZligExEREREREREREREphggExEREREREREREZEpBshEREREREREREREZIoBMhERERERERERERGZYoBMRERERERERERERKYYIBMRERERERERERGRKQbIRERERERERERERGSKATIRERERERERERERmWKATEREk8qjjz4qFoslNa1fv37Qj2G8/w9/+MMB32///fdP3Q//p/F5/Wj0nXbaaanXaM6cOVzkeeYXv/hF6vW55pprJtz6kO/bgVzzN5rvjVdeeSXt7+L3qQbL2rgM8FoMFF4P/X54nUbaSO4D7L333upxbDabfPDBByM2j0REREPFAJmIiMbc888/n3YA+Mc//rHPbaZNm5a6vqioSOLxeNr1L774Ytpj/N///d+UPvh+5pln5Ktf/apabk6nUy2zuro6+cIXviDf/va35f777x/vWZxUMgOkwYQYNH5hkz5ZrVbxer2ycOFCOemkk+Tf//73qP3NsV432tvb5YYbblD/LygokAsuuED9H1+GmS2L/iYaOZmvQb4Y7WC1v+3nQKd8+yJhNF155ZXqZyKRkEsvvXS8Z4eIiEjsXAZERDTW9tprL1VVo4fCr776qhxzzDGp6z/77DPZunVr6vfu7m753//+J7vuumvqMtxHhwPLffbZR/0ft7nllltS15WXl8tY+c53vqNCXJg1a9aY/d0zzzxTHnroobTLotGoWm4bNmyQt99+W55++mk5++yzJd+N5+tHA/eNb3xDlixZov5fUlIy4RadpmkSCARk9erVavrtb3+r3kOnn366THS33nqrtLW1qf8jHK+srBz1vznR14eJsizmz5+ftn3E7zRwV199tXR2dqr/669RvjrssMNkm222kU8//VT+9a9/qelLX/rSeM8WERFNYQyQiYhozKE69nOf+5y8++67fcJgs9/1y7IFyDgQrKioUP9fvHixmsbD17/+9TH/m//85z/TwuOdd95ZDj74YBVcIET6+OOPR7S6crSN5+s3VSA89fv9UlhYOOTHOOSQQ9Q00eyyyy7qfRoKheStt96Sv/3tb6llcvnll8upp56qqpMnqkgkIg8++GDq9xNPPDH1/y9/+ct9XvOnnnpK/vvf/6Z+v+qqq6SsrGzAf8/n80lxcfGEXR9Gw2guC3wxORmqUTO/KNQ/y1544YXU7+ecc06fgBxfKOpfjgwFzsYZCn09H48vI66//nr1/3vuuYcBMhERjS+NiIhoHHz/+9/X8DGEyWKxaK2tranrvvnNb6rLCwoKtJKSEvX/r371q6nrA4GA5nK5Uvc/77zzUtc98sgjqcsxrVu3Tl1uvMxsmj17duoxjJdfd9112scff6wdddRRWllZmeZ2u7Vdd91V+8tf/tLnOe23336p++H/RpmP+d5772lHHHGEVlpamvMx+3PRRRelHnf+/PlaLBbrc5tIJKI9//zzpvdvaWnRfvjDH2q77LKLVlxcrDkcDm3GjBnaCSecoL3zzjum93nyySe1gw46SKuurtbsdrtWWFiolt/BBx+s/eAHP9C2bNky5Ntne/10iURC++1vf6sdcsgh6vEwv1hHdtttN+3GG2/UOjs7+8wv/pb+eKeeeqq2atUq7c
QTT9Sqqqo0p9OpLV68WHv00UcHvMzN5hO/D0R3d7d22223aXvttZdanzD/eB5YF/75z3/2uX00GlXL6NBDD1WvL9YXm82mnvPnP/95dV1bW1u/6+L69evV+wp/C+83fX6HumxwW7P3znAeE7AufOtb31LziffFDjvsoN1///3a2rVrh7S8sf4Y74f5Mdp9993Trt+6dWva9Q899JD29a9/Xdt+++3V/OP18nq92rbbbqudddZZ2ieffJL1uWebhrM+9Oepp55K/Z3a2lotHo/nvL3xdTR7v2FbZbwe83vFFVdoc+fOVe9lfXnmWh/+9Kc/aSeffLJ6LWtqatQ64PF4tHnz5ql18j//+U+f+epvO5BNMBjUfvazn6ntAd4jeK9guS5atEg79thjtZtuuinr88d8Y/tx8cUXa3V1dWo+MY9YBqFQaMDzZ7YsXn755X7Xi8x100zm4+D3bK8V5hnbxG222UY9Fyz7s88+O20bmfn6m03GvzGU7e9AZc6/8e/mek/jtXj99de1Aw88UCsqKlL7Dfvvv7/21ltv9blv5rYp12M+88wz2h577KE+r4zvW7yn7rzzTrVNwH7ItGnTtHPOOUdrbm7OuQ+wcuVK7YwzztAWLFigtm1YdniP4rMX93/xxRf7zO/y5ctTj4f3W+b2iYiIaCwxQCYionGBsNR4wIaQQTdz5kx12Ze+9CXtsMMOU/9HcKaHIS+99FLafZ9++ulRC5BxEIiwI/P2VqtVzcdQAmSEVjigH8hj9ud73/te6v4VFRXaZ599NuD7vvvuuypUyLZMEL7cddddaff5yU9+MqjAYbC3zxXM4IsDhM65HgvBVuYyMIYGCLEQlJvddzAh8lAC5DVr1mgLFy7MOf+XXXZZ2n26uroGtO5mhvbGdRF/M/N1NguQB7NsBhogD+YxN23alDWARaA62OU9kAD56KOPTnv/hcPhtOsR0uda9giQjOvvYALkoawP/Tn99NNT98WXXv0ZbIC8zz77mC7PXOvDMccck/M5Yrk/8cQTIxIg44uqgS7/zPnGFwRYX83ug8c1fjk3EQLkzNdKnw444ADTec026X9jqNvf0Q6Qv/zlL6t1KHNe8Lm9YsWKIQXIe++9d9b1xvgeM074gg+hstk+AMJjPYge7OuPL0D022S+T4iIiMYSW1gQEdG4QM/izD7IRx55pKxdu1Y2btyoLttvv/3E4/HIX//6V+no6JAPP/xQtWjIbHGx77779vv3cLrsmjVr5L777jM9RTZbr0r8rZkzZ8o3v/lN1U8YvVL1gW1uvvlmOeCAAwb93NGTeKQeE61AdK2trapn4tKlS9Wp+lhWWM477rhjn/t1dXXJ4YcfLo2Njer3mpoaOeGEE9Qpwhig8LXXXlOvDQbg2mmnndSI8HDnnXemHgN/Q+/5jOfx0UcfyXvvvZf2dwZ7+1wuvvhi+cc//pH6fY899pCDDjpI9cz+3e9+py5bt26dfO1rX1OPbbf33c3B5ThF/6KLLpJgMCgPPPBAah382c9+ploYjAb8jaOOOkpWrVqlfsfp0Hj9p0+fLu+8844899xz6nK8/ni99NYD6O89d+5cNRjijBkz1LzjsfA80X4AfXzr6+vlxz/+sdx9992mf1v/m3h/YZ3YtGlTquXLaC+bwTzm+eefr56Lbs8995QDDzxQtbr585//LCMJLSzefPPNtFPmjz32WDUApVFVVZVaZxcsWKCeh8PhUP3Z//SnP6l1OBwOy3nnnSfLli1L9VjFQF833nhj6jHQMgPr/kisD/3B+1a32267yUh7/fXX1ePifYfnPpB2F6Wlpep13H777dXt3W63tLS0qO36ypUr1XYP2xn0wcd1Q4XHMr6eWL5Y7tjW4TMFLUvwGZBNc3Oz+pxBmwP0jf7973+fuj0e95e//KV873vfG1bv4sw2DcY2DiPdkxevFZYBlvtvfvOb1AB0L7/8svoM2n333VP9mrG+YvBFY5sX47yP1PZ3NGCZbr
vttnL00UfLBx98kGpLg+3NL37xC7n33nsH/Zho+4RtJJZDdXV1qt3Ws88+K4888kjqdvjcPOWUU9R7AZdjXTOD6zA
"text/plain": [
"<Figure size 1440x540 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: within_session_learning_rate.png\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJ8AAAJ6CAYAAACG1ON6AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjgsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvwVt1zgAAAAlwSFlzAAASdAAAEnQB3mYfeAAAkBJJREFUeJzt3Qd8U2X7//Gre0DZq8oGUdkgKKIIqCgKAgKiqOD4ubeoiLjAAeJ+xPE4mKK4QHAh4gAXAoIi41GmLMuG0tJN839d9/Ok/zRNm7TJaU6Sz/v1Ci3JyZ2Tk5ycnG/v+7qjHA6HQwAAAAAAAAALRFvRKAAAAAAAAED4BAAAAAAAAEvR8wkAAAAAAACWIXwCAAAAAACAZQifAAAAAAAAYBnCJwAAAAAAAFiG8AkAAAAAAACWIXwCAAAAAACAZQifAAAAAAAAYBnCJwAAAAAAAFiG8AkAAAAAAACWIXwCgHL6+++/JSoqqugyffr0otv0d9fbdFl/XH311UVtNW3a1Of7BXo9/OG6HuPGjQvaeiA49DV3fQ8gtM2fP7/otRwxYkRA2oy090hBQYE0a9bMPNeqVavKP//8E+xVCkvhcPwEgHBC+AQgrIIg/bKJ8OD6urpekpOTzYnbkCFD5LPPPhO7W7x4MQFciAjGey6Uwtn8/HwZPXq0+V3Xd8yYMSWW2b17t9x3333Svn17SUlJkbi4OKlbt66cdNJJMmDAAHn44Ydl/fr1EsliY2Pl3nvvNb8fPXrUbJNIsXLlSrnqqqukRYsWkpSUJImJiXLcccdJhw4d5PLLL5dJkybJoUOHJNL16tWr6HNBfweAcBAb7BUAgHDStWtXeeaZZ4r+X6tWLb/au+yyy6Rt27bm9+rVq/u9fuEgOzvbhI56mTt3rjz66KO2P2mPZOedd57p3RHKeM/9/x4hGzZsML/36dNH2rRpU2w7rV69Ws4++2w5ePBgsev3799vLn/99Zd8+umnUrt2bWndurVEsmuvvVYefPBBSU9PN9v1/vvvl1atWkk4mzFjhnnehYWFxa5PS0szlz/++ENmz54t559/vtSsWTNsjuMAgP8ifAKAANKTMfcTMn/07dvXXCJd8+bN5eabb5a8vDxzgvLBBx+Iw+Ewtz355JNy5513BvVkJZxlZGSYHiwV1b17d3MJNbznSnrllVeKftdeKu50H3UGT9pbbNiwYaaHi+6rW7ZskeXLl0d8rycn7fUzaNAgE8hoGPP666/Lc889J6EoJydHYmJiTC+30mhvpltvvbUoeDr++ONNT8IGDRpIVlaWCSZ//PFHE0KF23EcAPA/DgAIYVu3btUEouhy1VVXFbu9SZMmxW7buHGj4/LLL3fUrVvXER8f72jTpo1j+vTpHtv+559/HNdee62jXr16jsTEREeHDh0cb7zxhmPLli3FHnPatGlF99HfXW/T9VM9evQouu68884r8VirVq0qdr9PPvnEXK/r7LxOn4u75cuXO84//3xHSkqKo2rVqo6zzz7bsXjx4lLXQ/Xs2bPoev3d1XfffVfsfvp/12191113mefSuHFj83hxcXFmW/bu3dvx+uuvOwoKCkqso2t7jz76aJmvZ2n3c1/PSy+9tNjtv/zyi8+Pqf93vd2V+7bZvXu346abbnIcd9xx5v3SsmVLx9NPP+0oLCz0+Xm4b1Nft0FeXp7jzTffdJxzzjmOOnXqmG1du3ZtR58+fRwffPCBx/voug0aNMjRqlUrs2xsbKx5ndq1a+e4++67HTt27ChxH/f32L59+xw333yz4/jjj3fExMQUrW9Ft41V2zs9Pd0xatQoR8OGDR0JCQnmOU+cONGRn58flPfcxx9/7BgxYoSjffv2jvr165v1T0pKcjRv3txxxRVXlFje9bmXdnHdbyvyfli2bJlZb91fdRvpRV/X7t27O+68807HihUrfN422pZzvfR9pdvflf6/tM9FV5s3b3asXbvW5/eIys3Ndb
z22muOXr16Fb2va9Wq5TjrrLMckydPduTk5JS4j/t7QJ/rBRdc4KhRo4YjOTnZccYZZzgWLFjgcR0ra1t//vnnReuo7evj+sr92LZhwwbHsGHDzPrq8apz586Ot99+2+N9dX96//33Hf369XM0aNDAPL/q1aubz3b9HNd9yNv2XLJkidk+ej/396on8+fPL/W97Wrp0qXmM6is5+qqrGOd+2dbZmamY/To0eZ33T9139Tn4v7+KavNim4/5z6in2dnnnmmef/q/fT7xemnn+4YP368x33B06W0fQsA7I7wCUDEhE96UlitWjWPX+bcAyg9SW/UqJHHZS+66KJyh08zZswouk5P6NPS0oo93j333FN0u554O0OcssKnr776ynyBdl+/6Oho86U40OHTp59+6vVLsZ7cHTt2rFibVgQBGjq43r5p0yafH9PXMERPTPS18PQ8x40bZ2n4dODAAUfXrl3L3NbDhw8vsa31BLas+9SsWdOxbt26YvdxfY/pietJJ53kcX0rum2s2N56EtmxY0ePyw4YMCAo77khQ4aUue11v5w1a5bH517axbnfVuT9oCG0hjRl3ac820ZPjp3369SpU4nbDx48WKxt3V6lnYS7K+s9okGEBillPQ99L+zdu7fY/Vxv11DA02dlVFRUic/+ytzWhw4dMuvgXObHH390+Mr12NalS5eiEMj94gw1nDRoufDCC8tcVw35jh49Wur21LBEj2Oe3qulmTt3brHl58yZU6HnWtHwSQPh0l5XDRVd/3BSVpsV3X6rV68u9TuFXvT1U4RPAMIZw+4ARAwdrqVDs+6++25Tw+XNN9+UY8eOmdueeuopUwTV6fbbb5cdO3YU/b9Hjx7Su3dvWbVqlalZUl6XXHKJ3HHHHaa+hz6m1rXQ9VA6DEH/73TNNdeYIQzehjmMHDnSDENTWpRU60O1bNnSFET+/PPPxYoiuVoUtkuXLqaAsNag0u3422+/mcfU85MFCxaYOkxDhw4Vqwoe6+v44YcfFl136qmnmqE9gabDhLQYrg4l0iEyr732mnm+6vnnn5exY8eWOczEH/rarlixwvyu6+B8bdetWyfvv/9+0XtG64Hpejg1bNjQFKdt0qSJea/r+2Lnzp1mmKIOh9KhL1owurSi2c7aPOecc46ceeaZZnkdHlMZ26Y8bT7yyCPy+++/F923Xbt2MnDgQNm4caPZPsF4z9WoUUPOPfdcU8tIt70+F92Wui/++eef5jXT4aE61Mj5PPv372+KcztpHSWtkeVea6Yi7wfdfjqrmtLX8MorrzTDJ3VmNd1OP/zwQ7m2w/fff1/s+bvT56wzijlnBtPXbNq0adKtWzfp2LGjuY9+hpa3dp3OqKefu05aD0jb1O3xxRdfmOv0vXDFFVfIV1995bENfa66vfRzWF8TrbGkr6t+Zt1yyy1mm6emplb6ttb3zAknnFBUR2vJkiVyxhlnSHn9+uuvZvjaDTfcILm5uWa763BZNX78eLnoooukU6dO5v/33HNP0XaLjo42n9W6/2zbtk3efvttc3+dJOGuu+6SN954w+PjLV261Ayr1KGXjRs3ljVr1njd3/U9oJ9HzuHSuh/o59Rpp51mbjv99NPNc7fqM3XPnj3mM/D666+XOnXqmM/EzZs3m9sWLVpkhpTqMdqbimw/LSp/4YUXyq5du4rVldL6aErf3/oautbI0/eUfia6DgF2vS8AhKRgp18AUFk9n/QvzDq8zUmHkLne98iRI+Z67ZXk+tdo/Sum61+5r7nmmnL3fFI6lMl5vf4l3+mbb74pto46rM+ptJ5P7733XrHHeeSRR4r9Zfbkk08OeM8nJ+3xoUNPXn75Zcezzz7reOaZZ8zwEud9dKiiq4r2tCjrL8uuPRp27txZ5n0r2vNJL/PmzSu67cUXXyx22x9//GFJz6c1a9YUW16Hd7jSYSPO23TohvtQR30fL1y40Az/eP75583rM3DgwKL76HAg1+E9ru8xve
gQIU8qum0Cvb21N41rD0btMZWVlVV0v4ceeigo7znnuv3000/mc0DXX7e9e4+p77//vtTH9LSuFX0/uL7mOhzRnW6
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saved: individual_learning_rates.png\n"
]
}
],
"source": [
"# Within-session learning rate from the chronological online EARLYSTOP markers.\n",
"# Markers are re-read from the raw XDF file (bypassing the class-sorted epochs) so trial\n",
"# order in time is preserved; each run is split into thirds and a line is fit through\n",
"# the per-third success rates.\n",
"\n",
"SEGMENT_LABELS = ['Early', 'Middle', 'Late']\n",
"N_SEGMENTS = len(SEGMENT_LABELS)\n",
"MIN_TRIALS = 9  # fewer trials than this cannot be split into thirds meaningfully\n",
"jitter_rng = np.random.default_rng(0)  # seeded so the overlay plot is reproducible\n",
"\n",
"\n",
"def trial_outcomes(markers):\n",
"    \"\"\"Return chronological 1/0 success flags, one per cued trial in a marker stream.\n",
"\n",
"    A trial starts at MI_BEGIN or REST_BEGIN and runs up to the next start marker (or the\n",
"    end of the stream). It counts as a success only if the EARLYSTOP marker of the *cued*\n",
"    class occurs inside it, so an MI trial containing only REST_EARLYSTOP (or vice versa)\n",
"    is not scored as correct. Assumes *_EARLYSTOP marks the decoder reaching threshold\n",
"    for that class - TODO confirm against the online paradigm.\n",
"    \"\"\"\n",
"    markers = np.ravel(markers)\n",
"    starts = np.flatnonzero((markers == MI_BEGIN) | (markers == REST_BEGIN))\n",
"    ends = np.append(starts[1:], len(markers))\n",
"    outcomes = []\n",
"    for start, end in zip(starts, ends):\n",
"        target = MI_EARLYSTOP if markers[start] == MI_BEGIN else REST_EARLYSTOP\n",
"        outcomes.append(int(target in markers[start:end]))\n",
"    return outcomes\n",
"\n",
"\n",
"def mean_sem(values):\n",
"    \"\"\"Mean and standard error (sample SD, ddof=1) of `values`; (nan, nan) if empty.\"\"\"\n",
"    values = np.asarray(values, dtype=float)\n",
"    if values.size == 0:\n",
"        return np.nan, np.nan\n",
"    spread = values.std(ddof=1) if values.size > 1 else 0.0\n",
"    return values.mean(), spread / np.sqrt(values.size)\n",
"\n",
"\n",
"lr_results = []\n",
"for subj in subjects:\n",
"    subj_ses = sessions[subj]\n",
"    for pair in PAIRS:\n",
"        # Only use pairs where both online conditions are available\n",
"        if pair['online_fes'] not in subj_ses or pair['online_nofes'] not in subj_ses:\n",
"            continue\n",
"\n",
"        for cond_key, cond_label in [('online_fes', 'FES'), ('online_nofes', 'NOFES')]:\n",
"            te = subj_ses[pair[cond_key]]\n",
"            fp = os.path.join(DATA_DIR, te['file'])\n",
"            _, _, mk, _, _, _ = load_xdf_file(fp)\n",
"\n",
"            trials = trial_outcomes(mk)\n",
"            if len(trials) < MIN_TRIALS:\n",
"                continue\n",
"\n",
"            # Success rate per chronological segment, then the least-squares slope across them\n",
"            accs = [np.mean(s) for s in np.array_split(trials, N_SEGMENTS)]\n",
"            slope = np.polyfit(np.arange(N_SEGMENTS), accs, 1)[0]\n",
"\n",
"            lr_results.append({\n",
"                'subj': subj,\n",
"                'pair': pair['name'],\n",
"                'cond': cond_label,\n",
"                'accs': accs,\n",
"                'slope': slope\n",
"            })\n",
"\n",
"# Plot learning rates\n",
"if lr_results:\n",
"    fig, axes = plt.subplots(1, 2, figsize=(12, 4.5))\n",
"    fig.suptitle('Within-Session Learning Rate (Trial split into Thirds)', fontsize=12, fontweight='bold', y=1.05)\n",
"\n",
"    # 1. Mean slope ± SEM per condition (NaN bar if a condition has no usable runs)\n",
"    conds = ['FES', 'NOFES']\n",
"    slope_stats = [mean_sem([r['slope'] for r in lr_results if r['cond'] == c]) for c in conds]\n",
"    axes[0].bar(['ONLINE_FES', 'ONLINE_NOFES'], [m for m, _ in slope_stats],\n",
"                yerr=[e for _, e in slope_stats],\n",
"                color=[cond_color[c] for c in conds], capsize=5)\n",
"    axes[0].set_title('Average Performance Trend (Slope)', fontweight='bold')\n",
"    axes[0].axhline(0, color='k', linestyle='--', lw=1)\n",
"    axes[0].set_ylabel('Slope (Δ Acc per Third)')\n",
"\n",
"    # 2. Per-run success rate per third, overlaid (small x-jitter separates the lines)\n",
"    segment_x = np.arange(1, N_SEGMENTS + 1)\n",
"    for r in lr_results:\n",
"        jitter = jitter_rng.uniform(-0.02, 0.02)\n",
"        axes[1].plot(segment_x + jitter, r['accs'], color=cond_color[r['cond']],\n",
"                     alpha=0.3, marker='o')\n",
"\n",
"    axes[1].set_xticks(segment_x)\n",
"    axes[1].set_xticklabels(SEGMENT_LABELS)\n",
"    axes[1].set_title('Accuracy per third (over individual runs)', fontweight='bold')\n",
"    axes[1].set_ylabel('Marker Accuracy')\n",
"    axes[1].legend(handles=[Patch(color=cond_color['FES'], label='FES'),\n",
"                            Patch(color=cond_color['NOFES'], label='NOFES')])\n",
"\n",
"    for ax in axes:\n",
"        ax.spines[['top','right']].set_visible(False)\n",
"        ax.grid(axis='y', alpha=0.3)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.savefig('within_session_learning_rate.png', dpi=150, bbox_inches='tight')\n",
"    plt.show()\n",
"    print('Saved: within_session_learning_rate.png')\n",
"\n",
"    # 3. Individual run slopes grouped by subject; one bar per (pair, condition)\n",
"    subjects_with_lr = sorted({r['subj'] for r in lr_results})\n",
"    fig2, ax2 = plt.subplots(figsize=(max(8, len(subjects_with_lr) * 2.5), 5))\n",
"    fig2.suptitle('Individual Run Learning Rates (Slopes) per Subject', fontsize=12, fontweight='bold', y=1.05)\n",
"\n",
"    x = np.arange(len(subjects_with_lr))\n",
"    total_bars_per_subj = len(PAIRS) * 2\n",
"    group_width = 0.8\n",
"    width = group_width / total_bars_per_subj\n",
"    offsets = np.linspace(-group_width / 2 + width / 2, group_width / 2 - width / 2, total_bars_per_subj)\n",
"\n",
"    bar_idx = 0\n",
"    for p_idx, pair in enumerate(PAIRS):\n",
"        for cond in ['FES', 'NOFES']:\n",
"            slopes = [next((r['slope'] for r in lr_results\n",
"                            if r['subj'] == subj and r['pair'] == pair['name'] and r['cond'] == cond),\n",
"                           np.nan)\n",
"                      for subj in subjects_with_lr]\n",
"            # First pair solid, later pairs lighter and hatched\n",
"            alpha = 1.0 if p_idx == 0 else 0.6\n",
"            hatch = '' if p_idx == 0 else '//'\n",
"            label = f\"{pair['name'].split()[0]} {cond}\"\n",
"            ax2.bar(x + offsets[bar_idx], slopes, width, label=label, color=cond_color[cond],\n",
"                    alpha=alpha, hatch=hatch, edgecolor='white')\n",
"            bar_idx += 1\n",
"\n",
"    ax2.set_xticks(x)\n",
"    ax2.set_xticklabels(subjects_with_lr)\n",
"    ax2.axhline(0, color='k', linestyle='-', lw=0.8)\n",
"    ax2.set_ylabel('Slope (Δ Acc per Third)')\n",
"    ax2.legend(bbox_to_anchor=(1.01, 1), loc='upper left')\n",
"    ax2.grid(axis='y', alpha=0.3)\n",
"    ax2.spines[['top','right']].set_visible(False)\n",
"\n",
"    plt.tight_layout()\n",
"    plt.savefig('individual_learning_rates.png', dpi=150, bbox_inches='tight')\n",
"    plt.show()\n",
"    print('Saved: individual_learning_rates.png')\n",
"\n",
"else:\n",
"    print('No sufficient data for learning rate calculation.')"
]
},
2026-04-21 13:01:49 -05:00
{
"cell_type": "code",
2026-04-21 21:18:33 -05:00
"execution_count": 11,
2026-04-21 13:01:49 -05:00
"id": "cf55268e",
2026-04-21 21:18:33 -05:00
"metadata": {
"execution": {
2026-04-21 22:24:23 -05:00
"iopub.execute_input": "2026-04-22T03:21:13.263110Z",
"iopub.status.busy": "2026-04-22T03:21:13.263012Z",
"iopub.status.idle": "2026-04-22T03:21:13.268334Z",
"shell.execute_reply": "2026-04-22T03:21:13.267954Z"
2026-04-21 21:18:33 -05:00
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2026-04-21 22:12:59 -05:00
"=== Aggregate across 7 complete (subject × pair) comparisons ===\n",
2026-04-21 21:18:33 -05:00
"\n",
"Metric FES (mean ± sd) NOFES (mean ± sd) paired Δ\n",
"---------------------------------------------------------------------------------------\n",
2026-04-21 22:12:59 -05:00
"Classification accuracy 0.836 ± 0.070 0.819 ± 0.071 +0.017\n",
"Classification amplitude 2.513 ± 2.062 2.362 ± 1.813 +0.151\n",
"Fisher ratio (test SNR) 2.631 ± 2.757 2.542 ± 3.263 +0.090\n",
"μ-band SNR (REST/MI) 1.812 ± 0.700 1.747 ± 0.480 +0.065\n",
2026-04-21 21:18:33 -05:00
"\n",
2026-04-21 22:12:59 -05:00
" acc FES > NOFES in 4/7 comparisons (NOFES > FES in 3)\n",
" |margin| FES > NOFES in 4/7 comparisons (NOFES > FES in 3)\n",
" Fisher FES > NOFES in 3/7 comparisons (NOFES > FES in 4)\n",
" μ-SNR FES > NOFES in 4/7 comparisons (NOFES > FES in 3)\n"
2026-04-21 21:18:33 -05:00
]
}
],
"source": [
"# Aggregate FES vs NOFES over the (subject × pair) units where BOTH conditions were evaluated.\n",
"import math\n",
"\n",
"# (key in each results row, table label, short label for the sign-count lines)\n",
"SUMMARY_METRICS = [('acc', 'Classification accuracy', 'acc'),\n",
"                   ('amp', 'Classification amplitude', '|margin|'),\n",
"                   ('fisher', 'Fisher ratio (test SNR)', 'Fisher'),\n",
"                   ('mu_snr', 'μ-band SNR (REST/MI)', 'μ-SNR')]\n",
"\n",
"\n",
"def sd(a):\n",
"    \"\"\"Sample standard deviation (ddof=1); 0.0 when there are fewer than two values.\"\"\"\n",
"    return a.std(ddof=1) if len(a) > 1 else 0.0\n",
"\n",
"\n",
"def sign_test_p(n_pos, n_neg):\n",
"    \"\"\"Exact two-sided sign-test p-value for H0: FES > NOFES and NOFES > FES equally likely.\n",
"\n",
"    Ties are excluded, so n = n_pos + n_neg; returns 1.0 when there are no non-tied pairs.\n",
"    \"\"\"\n",
"    n = n_pos + n_neg\n",
"    if n == 0:\n",
"        return 1.0\n",
"    tail = sum(math.comb(n, i) for i in range(min(n_pos, n_neg) + 1)) / 2 ** n\n",
"    return min(1.0, 2 * tail)\n",
"\n",
"\n",
"paired = []\n",
"for subj in subjects:\n",
"    for pair in PAIRS:\n",
"        fes = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
"                    and r['condition']=='FES'), None)\n",
"        nof = next((r for r in results if r['subject']==subj and r['pair']==pair['name']\n",
"                    and r['condition']=='NOFES'), None)\n",
"        if fes is not None and nof is not None:\n",
"            paired.append((fes, nof))\n",
"\n",
"print(f'=== Aggregate across {len(paired)} complete (subject × pair) comparisons ===\\n')\n",
"\n",
"if not paired:\n",
"    print('No (subject × pair) has both FES and NOFES results - nothing to summarize.')\n",
"else:\n",
"    hdr = f'{\"Metric\":<28} {\"FES (mean ± sd)\":>22} {\"NOFES (mean ± sd)\":>22} {\"paired Δ\":>12}'\n",
"    print(hdr); print('-' * len(hdr))\n",
"    for k, label, _short in SUMMARY_METRICS:\n",
"        fes_v = np.array([f[k] for f, _ in paired])\n",
"        nof_v = np.array([n[k] for _, n in paired])\n",
"        delta = fes_v - nof_v\n",
"        print(f'{label:<28} {fes_v.mean():>10.3f} ± {sd(fes_v):>6.3f}    '\n",
"              f'{nof_v.mean():>10.3f} ± {sd(nof_v):>6.3f}   {delta.mean():>+12.3f}')\n",
"\n",
"    # Per-metric sign counts with an exact two-sided sign test (ties dropped)\n",
"    print()\n",
"    for k, _label, short in SUMMARY_METRICS:\n",
"        d = np.array([f[k] - n[k] for f, n in paired])\n",
"        n_pos = int((d > 0).sum()); n_neg = int((d < 0).sum())\n",
"        p = sign_test_p(n_pos, n_neg)\n",
"        print(f'  {short:<10} FES > NOFES in {n_pos}/{len(d)} comparisons '\n",
"              f'(NOFES > FES in {n_neg}); sign-test p = {p:.3f}')"
]
2026-04-21 13:01:49 -05:00
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
2026-04-21 21:18:33 -05:00
}