Files
ECE374N/Final Project/overall_analysis.ipynb
2026-04-21 13:01:49 -05:00

184 lines
22 KiB
Plaintext
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": "# Motor Imagery Decoder — Train OFFLINE, Evaluate ONLINE (FES vs NOFES)\n\nFor each subject and each OFFLINE session, train a CSP + LDA classifier on the OFFLINE recording, then apply it to the two matched ONLINE sessions.\n\n- **Pair 1:** train on `S001 OFFLINE_FES` → test on `S002 ONLINE_FES` and `S003 ONLINE_NOFES`\n- **Pair 2:** train on `S004 OFFLINE_NOFES` → test on `S006 ONLINE_FES` and `S005 ONLINE_NOFES`\n\n**Metrics reported per (subject × pair × condition):**\n\n1. **Classification accuracy** — fraction of cued trials correctly classified\n2. **Classification amplitude** — mean |LDA decision-function value|\n3. **SNR** — (a) Fisher ratio of the LDA decision values (MI vs REST) on the online data, and (b) mu-band (8-13 Hz) power ratio REST / MI over motor channels C3/Cz/C4"
},
{
"cell_type": "code",
"execution_count": 13,
"id": "578c9128",
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies if needed\n",
"# !pip install pyxdf mne scipy numpy matplotlib"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "857b22c0",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import re\n",
"import glob\n",
"from math import comb\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.patches import Patch\n",
"import pyxdf\n",
"from scipy.signal import welch, butter, filtfilt\n",
"from scipy.linalg import eigh\n",
"\n",
"plt.rcParams.update({'font.size': 11, 'figure.dpi': 120})"
]
},
{
"cell_type": "markdown",
"id": "fe68bf0e",
"metadata": {},
"source": [
"## Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc4b2c55",
"metadata": {},
"outputs": [],
"source": [
"# Data directory sits next to this notebook. `__file__` is not defined inside a\n",
"# notebook kernel, so resolve it from the working directory (which Jupyter normally\n",
"# sets to the notebook's folder).\n",
"DATA_DIR = os.path.join(os.getcwd(), 'Group 2 - Glove')\n",
"\n",
"# Marker codes (same cue-encoding in offline and online)\n",
"MI_BEGIN, MI_END = 200, 220\n",
"REST_BEGIN, REST_END = 100, 120\n",
"TARGET_MARKERS = [REST_BEGIN, REST_END, MI_BEGIN, MI_END]\n",
"\n",
"# Epoch window in seconds (t=0 at cue marker)\n",
"T_PRE = -1.0   # baseline start\n",
"T_POST = 5.0   # epoch end\n",
"\n",
"# Bandpass for MI decoding (mu + beta), in Hz\n",
"BP_LO, BP_HI = 8.0, 30.0\n",
"\n",
"# CSP spatial filters: N_CSP // 2 taken from each end of the eigenvalue spectrum\n",
"N_CSP = 4\n",
"assert N_CSP >= 2 and N_CSP % 2 == 0, 'N_CSP must be an even number >= 2'\n",
"\n",
"# Channels dropped at load time, and upper-case labels mapped to standard 10-20 spelling\n",
"NON_EEG = {'AUX1', 'AUX2', 'AUX3', 'AUX7', 'AUX8', 'AUX9', 'TRIGGER'}\n",
"RENAME = {'FP1': 'Fp1', 'FPZ': 'Fpz', 'FP2': 'Fp2', 'FZ': 'Fz', 'CZ': 'Cz',\n",
"          'PZ': 'Pz', 'POZ': 'POz', 'OZ': 'Oz'}\n",
"\n",
"MOTOR_CH = ['C3', 'Cz', 'C4']\n",
"MU_BAND = (8, 13)  # Hz\n",
"\n",
"# Design: per subject, each OFFLINE session trains a model tested on two ONLINE sessions\n",
"PAIRS = [\n",
"    {'name': 'Pair1 (train=OFFLINE_FES)',\n",
"     'train': 'S001', 'online_fes': 'S002', 'online_nofes': 'S003'},\n",
"    {'name': 'Pair2 (train=OFFLINE_NOFES)',\n",
"     'train': 'S004', 'online_fes': 'S006', 'online_nofes': 'S005'},\n",
"]"
]
},
{
"cell_type": "markdown",
"id": "21a40df3",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e798b039",
"metadata": {},
"outputs": [],
"source": [
"# ── XDF loading + session parsing ─────────────────────────────────────────────\n",
"\n",
"def get_channel_names_from_xdf(eeg_stream):\n",
"    \"\"\"Return the channel labels stored in an XDF stream's <desc><channels> metadata.\n",
"\n",
"    Raises ValueError when no labels are present (e.g. an empty <desc>), because every\n",
"    later step selects channels by name.\n",
"    \"\"\"\n",
"    ch_desc = eeg_stream['info']['desc'][0] or {}\n",
"    channels_node = (ch_desc.get('channels') or [{}])[0] or {}\n",
"    labels = [ch['label'][0] for ch in channels_node.get('channel', [])]\n",
"    if not labels:\n",
"        raise ValueError('EEG stream has no channel labels in its <desc> metadata')\n",
"    return labels\n",
"\n",
"\n",
"# Tolerates the 'ONOLINE' typo in subj 003 / S005\n",
"_SESSION_RE = re.compile(r'ses-(S\\d+)(O[A-Z]*LINE)_(FES|NOFES)')\n",
"_SUBJ_RE = re.compile(r'SUBJ_(\\d+)')\n",
"\n",
"def parse_session(path):\n",
"    \"\"\"Return (subject, session_id, kind, stim) parsed from the file name, or None.\n",
"\n",
"    kind is 'OFFLINE' or 'ONLINE'; stim is 'FES' or 'NOFES'.\n",
"    \"\"\"\n",
"    base = os.path.basename(path)\n",
"    m_subj = _SUBJ_RE.search(base)\n",
"    m_ses = _SESSION_RE.search(base)\n",
"    if not (m_subj and m_ses):\n",
"        return None\n",
"    ses_id, raw_kind, stim = m_ses.group(1), m_ses.group(2), m_ses.group(3)\n",
"    kind = 'OFFLINE' if 'OFF' in raw_kind else 'ONLINE'\n",
"    return m_subj.group(1), ses_id, kind, stim\n",
"\n",
"\n",
"def load_xdf_file(filepath):\n",
"    \"\"\"Load one XDF recording.\n",
"\n",
"    Returns (eeg_data, eeg_timestamps, marker_codes, marker_times, channel_names, sfreq)\n",
"    where eeg_data is (n_ch, n_samp) with NON_EEG channels removed and labels renamed\n",
"    via RENAME; only markers whose code is in TARGET_MARKERS are returned.\n",
"    Raises ValueError if an EEG stream and a distinct marker stream cannot be found.\n",
"    \"\"\"\n",
"    streams, _ = pyxdf.load_xdf(filepath)\n",
"\n",
"    eeg_stream = marker_stream = None\n",
"    for s in streams:\n",
"        stype = s['info']['type'][0].lower()\n",
"        if stype == 'eeg':\n",
"            eeg_stream = s\n",
"        elif stype == 'markers':\n",
"            marker_stream = s\n",
"    # Fall back to stream order only for the role that was not identified by type,\n",
"    # never assigning the same stream to both roles.\n",
"    if eeg_stream is None:\n",
"        eeg_stream = next((s for s in streams if s is not marker_stream), None)\n",
"    if marker_stream is None:\n",
"        marker_stream = next((s for s in streams if s is not eeg_stream), None)\n",
"    if eeg_stream is None or marker_stream is None:\n",
"        raise ValueError(f'expected an EEG stream and a marker stream, '\n",
"                         f'found {len(streams)} stream(s)')\n",
"\n",
"    eeg_timestamps = np.array(eeg_stream['time_stamps'])\n",
"    eeg_data = np.array(eeg_stream['time_series']).T  # (n_ch, n_samp)\n",
"    channel_names = get_channel_names_from_xdf(eeg_stream)\n",
"    sfreq = float(eeg_stream['info']['nominal_srate'][0])\n",
"\n",
"    valid_idx = [i for i, ch in enumerate(channel_names) if ch not in NON_EEG]\n",
"    channel_names = [RENAME.get(channel_names[i], channel_names[i]) for i in valid_idx]\n",
"    eeg_data = eeg_data[valid_idx, :]\n",
"\n",
"    # The marker stream is read as (code, time) columns: column 1, not the stream's own\n",
"    # 'time_stamps', is used as the event time.\n",
"    # NOTE(review): confirm column 1 is on the same clock as the EEG time stamps.\n",
"    ts_arr = np.asarray(marker_stream['time_series'], dtype=float)\n",
"    marker_data = ts_arr[:, 0].astype(int)\n",
"    marker_ts = ts_arr[:, 1]\n",
"    keep = np.isin(marker_data, TARGET_MARKERS)\n",
"    return eeg_data, eeg_timestamps, marker_data[keep], marker_ts[keep], channel_names, sfreq\n",
"\n",
"\n",
"def bandpass(data, lo, hi, sfreq, order=4):\n",
"    \"\"\"Zero-phase Butterworth band-pass along the last axis.\n",
"\n",
"    Cut-offs are clamped to [0.5 Hz, Nyquist - 0.1 Hz] so the filter design stays valid.\n",
"    \"\"\"\n",
"    nyq = sfreq / 2.0\n",
"    b, a = butter(order, [max(lo, 0.5) / nyq, min(hi, nyq - 0.1) / nyq], btype='band')\n",
"    return filtfilt(b, a, data, axis=-1)\n",
"\n",
"\n",
"def extract_epochs(eeg_data, eeg_ts, marker_data, marker_ts, sfreq, begin_code,\n",
"                   t_pre=T_PRE, t_post=T_POST):\n",
"    \"\"\"Cut a [t_pre, t_post) s window around every marker equal to `begin_code`.\n",
"\n",
"    Each epoch is baseline-corrected by the mean of its first |t_pre| s. Epochs whose\n",
"    window is not fully inside the recording are dropped.\n",
"    Returns (n_epochs, n_ch, n_samp), all epochs trimmed to the same length.\n",
"    \"\"\"\n",
"    n_ch = eeg_data.shape[0]\n",
"    if len(eeg_ts) == 0:\n",
"        return np.empty((0, n_ch, 0))\n",
"    n_pre = int(abs(t_pre) * sfreq)\n",
"    t_first, t_last = eeg_ts[0], eeg_ts[-1]\n",
"\n",
"    epochs = []\n",
"    for bi in np.where(marker_data == begin_code)[0]:\n",
"        t_start = marker_ts[bi]\n",
"        # np.searchsorted clamps to [0, len], so an out-of-range window would be silently\n",
"        # truncated (misaligning cue and baseline); reject it by time instead.\n",
"        if t_start + t_pre < t_first or t_start + t_post > t_last:\n",
"            continue\n",
"        i0 = np.searchsorted(eeg_ts, t_start + t_pre)\n",
"        i1 = np.searchsorted(eeg_ts, t_start + t_post)\n",
"        ep = eeg_data[:, i0:i1].copy()\n",
"        if ep.shape[1] > n_pre:\n",
"            ep -= ep[:, :n_pre].mean(axis=1, keepdims=True)\n",
"        epochs.append(ep)\n",
"\n",
"    if not epochs:\n",
"        return np.empty((0, n_ch, 0))\n",
"    min_len = min(e.shape[-1] for e in epochs)\n",
"    return np.stack([e[:, :min_len] for e in epochs])\n",
"\n",
"\n",
"def load_session_epochs(filepath):\n",
"    \"\"\"Bandpass continuous EEG, epoch on MI/REST cues, keep the active window [0, T_POST).\n",
"\n",
"    Returns X (n_trials, n_ch, n_samp), y (1=MI, 0=REST), ch_names, sfreq.\n",
"    Raises ValueError if the recording yields no MI or no REST epochs.\n",
"    \"\"\"\n",
"    eeg, eeg_ts, mk, mk_ts, ch_names, sfreq = load_xdf_file(filepath)\n",
"    eeg_bp = bandpass(eeg, BP_LO, BP_HI, sfreq)\n",
"\n",
"    mi = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, MI_BEGIN)\n",
"    rest = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, REST_BEGIN)\n",
"    if len(mi) == 0 or len(rest) == 0:\n",
"        # A one-class session can neither train nor score a 2-class decoder.\n",
"        raise ValueError(f'need both MI and REST epochs (got MI={len(mi)}, REST={len(rest)})')\n",
"\n",
"    # Drop the pre-cue baseline samples.\n",
"    n_pre = int(abs(T_PRE) * sfreq)\n",
"    if mi.shape[-1] > n_pre:\n",
"        mi = mi[..., n_pre:]\n",
"    if rest.shape[-1] > n_pre:\n",
"        rest = rest[..., n_pre:]\n",
"\n",
"    # Equalise MI / REST epoch lengths so both classes stack into one array.\n",
"    n = min(mi.shape[-1], rest.shape[-1])\n",
"    mi, rest = mi[..., :n], rest[..., :n]\n",
"\n",
"    X = np.concatenate([mi, rest], axis=0)\n",
"    y = np.concatenate([np.ones(len(mi), int), np.zeros(len(rest), int)])\n",
"    return X, y, ch_names, sfreq\n",
"\n",
"\n",
"# ── CSP + LDA (2-class, numpy/scipy only) ────────────────────────────────────\n",
"\n",
"def _mean_cov(X):\n",
"    \"\"\"Average trace-normalized trial covariance. X: (n_trials, n_ch, n_samp).\"\"\"\n",
"    covs = np.einsum('ijk,ilk->ijl', X, X)\n",
"    covs /= np.trace(covs, axis1=1, axis2=2)[:, None, None]\n",
"    return covs.mean(axis=0)\n",
"\n",
"\n",
"class CSPLDA:\n",
"    \"\"\"Common Spatial Patterns (log-var features) + Linear Discriminant Analysis.\"\"\"\n",
"\n",
"    def __init__(self, n_csp=N_CSP, reg=1e-6):\n",
"        self.n_csp = n_csp  # spatial filters kept: n_csp // 2 from each end of the spectrum\n",
"        self.reg = reg      # ridge added to the LDA within-class scatter\n",
"\n",
"    def fit(self, X, y):\n",
"        \"\"\"Fit CSP filters and LDA weights. X: (n_trials, n_ch, n_samp); y in {0, 1}.\n",
"\n",
"        Raises ValueError if n_csp < 2 (no filters would be selected).\n",
"        \"\"\"\n",
"        assert set(np.unique(y)) == {0, 1}, 'CSPLDA requires both classes in training set'\n",
"        k = self.n_csp // 2\n",
"        if k < 1:\n",
"            # With k == 0, order[-k:] would select *every* eigenvector.\n",
"            raise ValueError(f'n_csp must be >= 2, got {self.n_csp}')\n",
"        C1 = _mean_cov(X[y == 1])\n",
"        C0 = _mean_cov(X[y == 0])\n",
"        # Generalized eigenproblem: C1 w = λ (C0 + C1) w. λ near 1 → variance dominated by\n",
"        # class 1, λ near 0 → by class 0.\n",
"        evals, evecs = eigh(C1, C0 + C1)\n",
"        order = np.argsort(evals)\n",
"        # Stack bottom-k (max class-0 variance) and top-k (max class-1 variance) filters\n",
"        self.filters_ = np.concatenate([evecs[:, order[:k]],\n",
"                                        evecs[:, order[-k:]]], axis=1).T  # (2k, n_ch)\n",
"\n",
"        F = self._features(X)\n",
"        mu1, mu0 = F[y == 1].mean(0), F[y == 0].mean(0)\n",
"        Sw = np.cov(F[y == 1].T, ddof=1) + np.cov(F[y == 0].T, ddof=1)\n",
"        Sw += self.reg * np.eye(Sw.shape[0])\n",
"        self.coef_ = np.linalg.solve(Sw, mu1 - mu0)\n",
"        self.intercept_ = -self.coef_ @ ((mu1 + mu0) / 2)\n",
"        return self\n",
"\n",
"    def _features(self, X):\n",
"        \"\"\"Spatially filter trials and return normalized log-variance features (n_trials, 2k).\"\"\"\n",
"        Z = np.einsum('fc,ncs->nfs', self.filters_, X)  # spatial filter\n",
"        var = Z.var(axis=-1, ddof=1)\n",
"        return np.log(var / var.sum(axis=1, keepdims=True))  # normalized log-var\n",
"\n",
"    def decision_function(self, X):\n",
"        \"\"\"Signed LDA score per trial; > 0 means MI (class 1).\"\"\"\n",
"        return self._features(X) @ self.coef_ + self.intercept_\n",
"\n",
"    def predict(self, X):\n",
"        return (self.decision_function(X) > 0).astype(int)\n",
"\n",
"\n",
"# ── Evaluation metrics ───────────────────────────────────────────────────────\n",
"\n",
"def evaluate(clf, X, y):\n",
"    \"\"\"Score a fitted classifier on labelled trials.\n",
"\n",
"    Returns a dict with accuracy, mean |decision value| ('amp'), the Fisher ratio of the\n",
"    decision values between classes, and the raw per-trial margin / labels / predictions.\n",
"    \"\"\"\n",
"    margin = clf.decision_function(X)\n",
"    pred = (margin > 0).astype(int)\n",
"    acc = (pred == y).mean()\n",
"    amp = np.abs(margin).mean()\n",
"    m1, m0 = margin[y == 1], margin[y == 0]\n",
"    fisher = (m1.mean() - m0.mean()) ** 2 / (m1.var(ddof=1) + m0.var(ddof=1) + 1e-30)\n",
"    return dict(acc=acc, amp=amp, fisher=fisher, margin=margin, y=y, pred=pred)\n",
"\n",
"\n",
"def spectral_snr(X, y, ch_idx, sfreq, band=MU_BAND):\n",
"    \"\"\"Ratio of REST band power to MI band power, averaged over the channels in `ch_idx`.\n",
"\n",
"    Band power is the trapezoidal integral of a Welch PSD (2-s segments, or the whole\n",
"    epoch if shorter, with 50 % overlap) over [band[0], band[1]).\n",
"    \"\"\"\n",
"    def band_pwr(sig):\n",
"        nperseg = min(int(sfreq * 2), sig.shape[-1])\n",
"        # 50 % overlap; a fixed 1-s overlap would reach nperseg for epochs shorter than\n",
"        # 1 s and make welch() raise.\n",
"        f, p = welch(sig, fs=sfreq, nperseg=nperseg, noverlap=nperseg // 2, axis=-1)\n",
"        m = (f >= band[0]) & (f < band[1])\n",
"        return np.trapezoid(p[..., m], f[m], axis=-1).mean()\n",
"\n",
"    return band_pwr(X[y == 0][:, ch_idx, :]) / (band_pwr(X[y == 1][:, ch_idx, :]) + 1e-30)"
]
},
{
"cell_type": "markdown",
"id": "98d225db",
"metadata": {},
"source": [
"## Load Data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d266216b",
"metadata": {},
"outputs": [],
"source": [
"xdf_files = sorted(glob.glob(os.path.join(DATA_DIR, '*.xdf')))\n",
"print(f'Found {len(xdf_files)} XDF file(s) in {DATA_DIR}.\\n')\n",
"\n",
"# sessions[subject][session_id] = dict(X, y, kind, stim, ch_names, sfreq, file)\n",
"sessions = {}\n",
"\n",
"for fp in xdf_files:\n",
"    fname = os.path.basename(fp)\n",
"    meta = parse_session(fp)\n",
"    if meta is None:\n",
"        print(f'  SKIP (unparsed): {fname}')\n",
"        continue\n",
"    subj, ses_id, kind, stim = meta\n",
"    try:\n",
"        X, y, ch_names, sfreq = load_session_epochs(fp)\n",
"    except Exception as e:  # best-effort: report the bad file and keep loading the rest\n",
"        print(f'  ERROR {fname}: {e}')\n",
"        continue\n",
"\n",
"    subj_sessions = sessions.setdefault(subj, {})\n",
"    if ses_id in subj_sessions:\n",
"        # Two files map to the same subject/session; the later file (sorted order) wins.\n",
"        print(f'  WARNING duplicate {subj}/{ses_id}: {fname} replaces '\n",
"              f'{subj_sessions[ses_id][\"file\"]}')\n",
"    subj_sessions[ses_id] = dict(\n",
"        X=X, y=y, kind=kind, stim=stim,\n",
"        ch_names=ch_names, sfreq=sfreq, file=fname)\n",
"\n",
"    print(f'  {subj}/{ses_id}  {kind:<7} {stim:<5}  '\n",
"          f'n={len(y):3d}  (MI={int(y.sum())}, REST={int((1 - y).sum())})')\n",
"\n",
"subjects = sorted(sessions.keys())\n",
"print(f'\\nLoaded {len(subjects)} subject(s): {subjects}')\n",
"if not subjects:\n",
"    print(f'!! No sessions loaded — check that DATA_DIR ({DATA_DIR}) contains the .xdf files.')"
]
},
{
"cell_type": "markdown",
"id": "7b8c8bea",
"metadata": {},
"source": "## Verify Session Layout"
},
{
"cell_type": "code",
"execution_count": null,
"id": "611baf23",
"metadata": {},
"outputs": [],
"source": [
"# Verify channel layout and sampling rate are consistent across sessions, locate motor channels\n",
"if not subjects:\n",
"    raise RuntimeError(f'No sessions loaded from {DATA_DIR} — nothing to verify.')\n",
"\n",
"ref_subj = subjects[0]\n",
"ref_ses = next(iter(sessions[ref_subj].values()))\n",
"channel_names_global = ref_ses['ch_names']\n",
"sfreq_global = ref_ses['sfreq']\n",
"\n",
"all_sessions = [(f'{subj}/{sid}', s) for subj in subjects for sid, s in sessions[subj].items()]\n",
"mismatches = [tag for tag, s in all_sessions if s['ch_names'] != channel_names_global]\n",
"if mismatches:\n",
"    print('!! channel mismatch in:', mismatches)\n",
"srate_mismatches = [tag for tag, s in all_sessions if s['sfreq'] != sfreq_global]\n",
"if srate_mismatches:\n",
"    print('!! sampling-rate mismatch in:', srate_mismatches)\n",
"\n",
"motor_idx_global = [channel_names_global.index(c) for c in MOTOR_CH\n",
"                    if c in channel_names_global]\n",
"missing_motor = [c for c in MOTOR_CH if c not in channel_names_global]\n",
"if missing_motor:\n",
"    print('!! motor channel(s) not found, excluded from the μ-band SNR:', missing_motor)\n",
"\n",
"print(f'Channels ({len(channel_names_global)}): {channel_names_global}')\n",
"print(f'Sampling rate: {sfreq_global} Hz')\n",
"print(f'Motor channels {MOTOR_CH} → indices {motor_idx_global}')"
]
},
{
"cell_type": "markdown",
"id": "70922abb",
"metadata": {},
"source": "## Train CSP + LDA on OFFLINE, Evaluate on ONLINE"
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5e80da3",
"metadata": {},
"outputs": [],
"source": [
"results = []  # one row per (subject, pair, condition)\n",
"CONDITIONS = [('online_fes', 'FES'), ('online_nofes', 'NOFES')]\n",
"\n",
"for subj in subjects:\n",
"    subj_ses = sessions[subj]\n",
"\n",
"    for pair in PAIRS:\n",
"        needed = (pair['train'], pair['online_fes'], pair['online_nofes'])\n",
"        missing = [k for k in needed if k not in subj_ses]\n",
"        if missing:\n",
"            print(f'[{subj}] {pair[\"name\"]}: missing {missing} — skipping')\n",
"            continue\n",
"\n",
"        train = subj_ses[pair['train']]\n",
"        # CSP filters are tied to the training montage: applying them to a session with a\n",
"        # different channel list would silently mix channels. Skip the whole pair so the\n",
"        # FES and NOFES rows stay paired.\n",
"        odd_layout = [k for k in needed[1:] if subj_ses[k]['ch_names'] != train['ch_names']]\n",
"        if odd_layout:\n",
"            print(f'[{subj}] {pair[\"name\"]}: channel layout of {odd_layout} differs from '\n",
"                  f'{pair[\"train\"]} — skipping')\n",
"            continue\n",
"        motor_idx = [train['ch_names'].index(c) for c in MOTOR_CH if c in train['ch_names']]\n",
"\n",
"        clf = CSPLDA(n_csp=N_CSP).fit(train['X'], train['y'])\n",
"        # Resubstitution accuracy on the training data itself (optimistic by construction)\n",
"        train_acc = (clf.predict(train['X']) == train['y']).mean()\n",
"\n",
"        for cond_key, cond_label in CONDITIONS:\n",
"            te = subj_ses[pair[cond_key]]\n",
"            res = evaluate(clf, te['X'], te['y'])\n",
"            snr_s = spectral_snr(te['X'], te['y'], motor_idx, te['sfreq'])\n",
"\n",
"            results.append(dict(\n",
"                subject=subj, pair=pair['name'], condition=cond_label,\n",
"                train_file=train['file'], test_file=te['file'],\n",
"                train_acc=train_acc, n_test=len(te['y']),\n",
"                acc=res['acc'], amp=res['amp'], fisher=res['fisher'], mu_snr=snr_s,\n",
"                margin=res['margin'], y_test=res['y'], pred=res['pred'],\n",
"            ))\n",
"\n",
"# Results table\n",
"hdr = (f'{\"Subj\":<5} {\"Pair\":<28} {\"Cond\":<6} {\"n\":>4} {\"trainAcc\":>9} '\n",
"       f'{\"acc\":>7} {\"|marg|\":>8} {\"Fisher\":>8} {\"muSNR\":>8}')\n",
"print('\\n' + hdr)\n",
"print('-' * len(hdr))\n",
"for r in results:\n",
"    print(f'{r[\"subject\"]:<5} {r[\"pair\"]:<28} {r[\"condition\"]:<6} {r[\"n_test\"]:>4} '\n",
"          f'{r[\"train_acc\"]:>9.3f} {r[\"acc\"]:>7.3f} {r[\"amp\"]:>8.3f} '\n",
"          f'{r[\"fisher\"]:>8.3f} {r[\"mu_snr\"]:>8.3f}')"
]
},
{
"cell_type": "markdown",
"id": "2ab81600",
"metadata": {},
"source": "---\n## Figure 1 — Per-metric comparison (FES vs NOFES)"
},
{
"cell_type": "code",
"execution_count": null,
"id": "d53e63b9",
"metadata": {},
"outputs": [],
"source": [
"# (key in results, panel title, unit)\n",
"METRICS = [\n",
"    ('acc',    'Classification accuracy', '0-1'),\n",
"    ('amp',    'Classification amplitude (mean |decision fn|)', 'a.u.'),\n",
"    ('fisher', 'Fisher ratio on LDA projection (test-set SNR)', 'a.u.'),\n",
"    ('mu_snr', 'μ-band power ratio REST / MI @ ' + '/'.join(MOTOR_CH), 'ratio'),\n",
"]\n",
"\n",
"cond_color = {'FES': '#E05C2A', 'NOFES': '#2A7BE0'}\n",
"rows_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
"\n",
"fig, axes = plt.subplots(2, 2, figsize=(14, 9))\n",
"fig.suptitle('Online decoding: FES vs NOFES feedback (per subject × offline-trained model)',\n",
"             fontsize=13, fontweight='bold', y=1.00)\n",
"\n",
"for ax, (key, title, unit) in zip(axes.ravel(), METRICS):\n",
"    labels, vals, colors = [], [], []\n",
"    for subj in subjects:\n",
"        for pair in PAIRS:\n",
"            tag = pair['name'].split()[0]  # \"Pair1\" / \"Pair2\"\n",
"            for cond in ('FES', 'NOFES'):\n",
"                row = rows_by_key.get((subj, pair['name'], cond))\n",
"                if row is None:\n",
"                    continue\n",
"                labels.append(f'{subj}\\n{tag}\\n{cond}')\n",
"                vals.append(row[key])\n",
"                colors.append(cond_color[cond])\n",
"    x = np.arange(len(vals))\n",
"    ax.bar(x, vals, color=colors, edgecolor='white', zorder=2)\n",
"    ax.set_xticks(x)\n",
"    ax.set_xticklabels(labels, fontsize=7.5)\n",
"    ax.set_title(f'{title} ({unit})', fontsize=11, fontweight='bold')\n",
"    ax.grid(axis='y', alpha=0.3)\n",
"    ax.spines[['top', 'right']].set_visible(False)\n",
"    if key == 'acc':\n",
"        ax.set_ylim(0, 1)\n",
"        ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)  # 2-class chance\n",
"\n",
"fig.legend(handles=[Patch(color=cond_color['FES'], label='ONLINE_FES'),\n",
"                    Patch(color=cond_color['NOFES'], label='ONLINE_NOFES')],\n",
"           loc='upper right', ncol=2, bbox_to_anchor=(0.98, 1.0))\n",
"plt.tight_layout()\n",
"plt.savefig('fes_vs_nofes_metrics.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_vs_nofes_metrics.png')"
]
},
{
"cell_type": "markdown",
"id": "248740bd",
"metadata": {},
"source": "---\n## Figure 2 — LDA decision-function distributions\n\nShows classification amplitude and class separability directly. Within each condition (FES, NOFES), a wider gap between the MI and REST histograms means a higher Fisher ratio, and histograms lying further from the decision boundary at 0 mean a larger mean |margin| (classification amplitude)."
},
{
"cell_type": "code",
"execution_count": null,
"id": "393042a0",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(len(subjects), len(PAIRS),\n",
"                         figsize=(6 * len(PAIRS), 3.2 * len(subjects)),\n",
"                         sharex=True, squeeze=False)\n",
"fig.suptitle('LDA decision-function distributions on ONLINE sessions\\n'\n",
"             '(class separation ↔ classification amplitude & SNR)',\n",
"             fontsize=12, fontweight='bold', y=1.01)\n",
"\n",
"rows_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
"\n",
"for i, subj in enumerate(subjects):\n",
"    for j, pair in enumerate(PAIRS):\n",
"        ax = axes[i][j]\n",
"        rows = {cond: rows_by_key.get((subj, pair['name'], cond)) for cond in cond_color}\n",
"        present = [r for r in rows.values() if r is not None]\n",
"        if present:\n",
"            # One set of bin edges per panel so every histogram in it is directly comparable\n",
"            # (a separate bins=15 per call would give each histogram its own bin widths).\n",
"            edges = np.histogram_bin_edges(np.concatenate([r['margin'] for r in present]), bins=15)\n",
"        for cond, color in cond_color.items():\n",
"            row = rows[cond]\n",
"            if row is None:\n",
"                continue\n",
"            m_mi = row['margin'][row['y_test'] == 1]\n",
"            m_rest = row['margin'][row['y_test'] == 0]\n",
"            ax.hist(m_mi, bins=edges, alpha=0.5, color=color,\n",
"                    label=f'{cond} MI', density=True)\n",
"            ax.hist(m_rest, bins=edges, alpha=0.25, color=color, hatch='///',\n",
"                    edgecolor=color, label=f'{cond} REST', density=True)\n",
"        ax.axvline(0, color='k', lw=0.8)\n",
"        ax.set_title(f'{subj} | {pair[\"name\"]}', fontsize=10, fontweight='bold')\n",
"        if j == 0:\n",
"            ax.set_ylabel('Density')\n",
"        if i == len(subjects) - 1:\n",
"            ax.set_xlabel('LDA decision function')\n",
"        ax.spines[['top', 'right']].set_visible(False)\n",
"        if i == 0 and j == 0 and present:\n",
"            ax.legend(fontsize=6.5, loc='upper left', ncol=2)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('decision_margin_distributions.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: decision_margin_distributions.png')"
]
},
{
"cell_type": "markdown",
"id": "fcb6d19d",
"metadata": {},
"source": "---\n## Figure 3 — Paired Δ (FES - NOFES) per metric\n\nWithin each (subject × pair), the FES and NOFES sessions are scored with the same offline-trained model. Positive bars mean FES > NOFES; negative bars mean NOFES > FES. Because the model is shared, this removes the offline-model-quality confound, so the difference reflects the feedback type (plus ordinary session-to-session variability)."
},
{
"cell_type": "code",
"execution_count": null,
"id": "75df404b",
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, len(METRICS), figsize=(4 * len(METRICS), 4.5), squeeze=False)\n",
"fig.suptitle('Within-pair Δ = FES - NOFES (positive → FES better)',\n",
"             fontsize=12, fontweight='bold', y=1.03)\n",
"\n",
"rows_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
"\n",
"for ax, (key, title, _) in zip(axes[0], METRICS):\n",
"    labels, deltas = [], []\n",
"    for subj in subjects:\n",
"        for pair in PAIRS:\n",
"            fes = rows_by_key.get((subj, pair['name'], 'FES'))\n",
"            nof = rows_by_key.get((subj, pair['name'], 'NOFES'))\n",
"            if fes is None or nof is None:\n",
"                continue\n",
"            deltas.append(fes[key] - nof[key])\n",
"            labels.append(f'{subj}\\n{pair[\"name\"].split()[0]}')\n",
"\n",
"    # FES colour where FES is higher, NOFES colour where NOFES is higher, grey for ties / nan\n",
"    colors = ['#E05C2A' if d > 0 else '#2A7BE0' if d < 0 else 'gray' for d in deltas]\n",
"    x = np.arange(len(deltas))\n",
"    ax.bar(x, deltas, color=colors, edgecolor='white', zorder=2)\n",
"    ax.axhline(0, color='k', lw=0.8)\n",
"    ax.set_xticks(x)\n",
"    ax.set_xticklabels(labels, fontsize=8)\n",
"    ax.set_title(f'Δ {title.split(\"(\")[0].strip()}', fontsize=10, fontweight='bold')\n",
"    ax.grid(axis='y', alpha=0.3)\n",
"    ax.spines[['top', 'right']].set_visible(False)\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig('fes_minus_nofes_delta.png', dpi=150, bbox_inches='tight')\n",
"plt.show()\n",
"print('Saved: fes_minus_nofes_delta.png')"
]
},
{
"cell_type": "markdown",
"id": "b3db60ba",
"metadata": {},
"source": [
"---\n",
"## Summary Statistics"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cf55268e",
"metadata": {},
"outputs": [],
"source": [
"# (key in results, long label, short label)\n",
"SUMMARY_METRICS = [('acc', 'Classification accuracy', 'acc'),\n",
"                   ('amp', 'Classification amplitude', '|margin|'),\n",
"                   ('fisher', 'Fisher ratio (test SNR)', 'Fisher'),\n",
"                   ('mu_snr', 'μ-band SNR (REST/MI)', 'μ-SNR')]\n",
"\n",
"# Pair FES and NOFES rows explicitly by (subject, pair) instead of relying on list order.\n",
"rows_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
"paired = [(s, p) for (s, p, c) in rows_by_key if c == 'FES' and (s, p, 'NOFES') in rows_by_key]\n",
"\n",
"\n",
"def _mean_sd(v):\n",
"    \"\"\"Format 'mean ± sample SD'; the SD is undefined (nan) for fewer than two values.\"\"\"\n",
"    sd = v.std(ddof=1) if len(v) > 1 else float('nan')\n",
"    return f'{v.mean():>10.3f} ± {sd:>6.3f}'\n",
"\n",
"\n",
"def _sign_test_p(n_pos, n_neg):\n",
"    \"\"\"Exact two-sided sign-test p-value (ties already excluded); nan if no untied pairs.\"\"\"\n",
"    n = n_pos + n_neg\n",
"    if n == 0:\n",
"        return float('nan')\n",
"    tail = sum(comb(n, i) for i in range(min(n_pos, n_neg) + 1)) / 2 ** n\n",
"    return min(1.0, 2 * tail)\n",
"\n",
"\n",
"print(f'=== Aggregate across {len(paired)} (subject × pair) comparisons ===\\n')\n",
"if not paired:\n",
"    print('No complete FES/NOFES comparison available — nothing to summarise.')\n",
"else:\n",
"    hdr = (f'{\"Metric\":<28} {\"FES (mean ± sd)\":>22} '\n",
"           f'{\"NOFES (mean ± sd)\":>22} {\"paired Δ\":>12}')\n",
"    print(hdr)\n",
"    print('-' * len(hdr))\n",
"    deltas = {}\n",
"    for key, label, _ in SUMMARY_METRICS:\n",
"        fes = np.array([rows_by_key[(s, p, 'FES')][key] for s, p in paired])\n",
"        nof = np.array([rows_by_key[(s, p, 'NOFES')][key] for s, p in paired])\n",
"        deltas[key] = fes - nof\n",
"        print(f'{label:<28} {_mean_sd(fes):>22} {_mean_sd(nof):>22} {deltas[key].mean():>+12.3f}')\n",
"\n",
"    # Exact two-sided sign test on the paired differences (ties dropped)\n",
"    print()\n",
"    for key, _, short in SUMMARY_METRICS:\n",
"        d = deltas[key]\n",
"        n_pos, n_neg = int((d > 0).sum()), int((d < 0).sum())\n",
"        print(f'  {short:<10} FES > NOFES in {n_pos}/{len(d)} comparisons '\n",
"              f'(NOFES > FES in {n_neg}); sign-test p = {_sign_test_p(n_pos, n_neg):.3f}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}