184 lines
22 KiB
Plaintext
184 lines
22 KiB
Plaintext
|
|
{
|
|||
|
|
"cells": [
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": "# Motor Imagery Decoder — Train OFFLINE, Evaluate ONLINE (FES vs NOFES)\n\nFor each subject × offline session, train a CSP + LDA classifier on the OFFLINE recording, then apply it unchanged to the two matched ONLINE sessions.\n\n- **Pair 1:** train on `S001 OFFLINE_FES` → test on `S002 ONLINE_FES` and `S003 ONLINE_NOFES`\n- **Pair 2:** train on `S004 OFFLINE_NOFES` → test on `S006 ONLINE_FES` and `S005 ONLINE_NOFES`\n\n**Metrics reported per (subject × pair × condition):**\n\n1. **Classification accuracy** — fraction of cued trials correctly classified\n2. **Classification amplitude** — mean |LDA decision-function value|\n3. **SNR** — (a) Fisher ratio of the LDA projection on online data, and (b) mu-band power ratio REST / MI over motor channels C3/Cz/C4 (mu power drops during motor imagery, so a larger ratio means stronger desynchronization)"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": 13,
|
|||
|
|
"id": "578c9128",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
|
|||
|
|
"# Install dependencies if needed\n",
|
|||
|
|
"# !pip install pyxdf mne scipy numpy matplotlib"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "857b22c0",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": "import glob\nimport math\nimport os\nimport re\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom matplotlib.patches import Patch\nimport pyxdf\nfrom scipy.signal import welch, butter, filtfilt\nfrom scipy.linalg import eigh\n\nplt.rcParams.update({'font.size': 11, 'figure.dpi': 120})"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "fe68bf0e",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"## Configuration"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "dc4b2c55",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# Data folder, resolved against the kernel's working directory. Notebooks have no\n",
  "# __file__, so anchor on the CWD explicitly rather than via abspath('__file__').\n",
  "DATA_DIR = os.path.join(os.getcwd(), 'Group 2 - Glove')\n",
  "\n",
  "# Marker codes (same cue-encoding in offline and online)\n",
  "MI_BEGIN, MI_END = 200, 220\n",
  "REST_BEGIN, REST_END = 100, 120\n",
  "TARGET_MARKERS = [REST_BEGIN, REST_END, MI_BEGIN, MI_END]  # all other markers are dropped at load\n",
  "\n",
  "# Epoch window in seconds (t=0 at cue marker)\n",
  "T_PRE = -1.0   # baseline start\n",
  "T_POST = 5.0   # epoch end\n",
  "\n",
  "# Bandpass for MI decoding (mu + beta), Hz\n",
  "BP_LO, BP_HI = 8.0, 30.0\n",
  "\n",
  "# CSP spatial filters (top N/2 + bottom N/2)\n",
  "N_CSP = 4\n",
  "\n",
  "# Channels removed at load time, and upper-case labels mapped to 10-20 casing\n",
  "NON_EEG = {'AUX1', 'AUX2', 'AUX3', 'AUX7', 'AUX8', 'AUX9', 'TRIGGER'}\n",
  "RENAME = {'FP1': 'Fp1', 'FPZ': 'Fpz', 'FP2': 'Fp2', 'FZ': 'Fz', 'CZ': 'Cz',\n",
  "          'PZ': 'Pz', 'POZ': 'POz', 'OZ': 'Oz'}\n",
  "\n",
  "# Motor channels for the mu-band SNR metric, and the mu band in Hz\n",
  "MOTOR_CH = ['C3', 'Cz', 'C4']\n",
  "MU_BAND = (8, 13)\n",
  "\n",
  "# Design: per subject, each OFFLINE session trains a model tested on two ONLINE sessions\n",
  "PAIRS = [\n",
  "    {'name': 'Pair1 (train=OFFLINE_FES)',\n",
  "     'train': 'S001', 'online_fes': 'S002', 'online_nofes': 'S003'},\n",
  "    {'name': 'Pair2 (train=OFFLINE_NOFES)',\n",
  "     'train': 'S004', 'online_fes': 'S006', 'online_nofes': 'S005'},\n",
  "]"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "21a40df3",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"## Helper Functions"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "e798b039",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": "# ── XDF loading + session parsing ─────────────────────────────────────────────\n\ndef get_channel_names_from_xdf(eeg_stream):\n ch_desc = eeg_stream['info']['desc'][0]\n channels = ch_desc.get('channels', [{}])[0].get('channel', [])\n return [ch['label'][0] for ch in channels]\n\n\n# Tolerates the 'ONOLINE' typo in subj 003 / S005\n_SESSION_RE = re.compile(r'ses-(S\\d+)(O[A-Z]*LINE)_(FES|NOFES)')\n_SUBJ_RE = re.compile(r'SUBJ_(\\d+)')\n\ndef parse_session(path):\n \"\"\"Return (subject, session_id, kind, stim) or None.\"\"\"\n base = os.path.basename(path)\n m_subj = _SUBJ_RE.search(base)\n m_ses = _SESSION_RE.search(base)\n if not (m_subj and m_ses):\n return None\n ses_id, raw_kind, stim = m_ses.group(1), m_ses.group(2), m_ses.group(3)\n kind = 'OFFLINE' if 'OFF' in raw_kind else 'ONLINE'\n return m_subj.group(1), ses_id, kind, stim\n\n\ndef load_xdf_file(filepath):\n streams, _ = pyxdf.load_xdf(filepath)\n\n eeg_stream = marker_stream = None\n for s in streams:\n stype = s['info']['type'][0].lower()\n if stype == 'eeg': eeg_stream = s\n elif stype == 'markers': marker_stream = s\n if eeg_stream is None or marker_stream is None:\n eeg_stream = streams[0]\n marker_stream = streams[1] if len(streams) > 1 else None\n\n eeg_timestamps = np.array(eeg_stream['time_stamps'])\n eeg_data = np.array(eeg_stream['time_series']).T\n channel_names = get_channel_names_from_xdf(eeg_stream)\n sfreq = float(eeg_stream['info']['nominal_srate'][0])\n\n valid_idx = [i for i, ch in enumerate(channel_names) if ch not in NON_EEG]\n channel_names = [channel_names[i] for i in valid_idx]\n eeg_data = eeg_data[valid_idx, :]\n channel_names = [RENAME.get(ch, ch) for ch in channel_names]\n\n ts_arr = np.asarray(marker_stream['time_series'], dtype=float)\n marker_data = ts_arr[:, 0].astype(int)\n marker_ts = ts_arr[:, 1]\n keep = np.isin(marker_data, TARGET_MARKERS)\n return eeg_data, eeg_timestamps, marker_data[keep], marker_ts[keep], channel_names, sfreq\n\n\ndef bandpass(data, 
lo, hi, sfreq, order=4):\n nyq = sfreq / 2.0\n b, a = butter(order, [max(lo, 0.5) / nyq, min(hi, nyq - 0.1) / nyq], btype='band')\n return filtfilt(b, a, data, axis=-1)\n\n\ndef extract_epochs(eeg_data, eeg_ts, marker_data, marker_ts, sfreq, begin_code,\n t_pre=T_PRE, t_post=T_POST):\n \"\"\"Returns (n_epochs, n_ch, n_samp) — all epochs trimmed to the same length.\"\"\"\n epochs = []\n n_pre = int(abs(t_pre) * sfreq)\n\n for bi in np.where(marker_data == begin_code)[0]:\n t_start = marker_ts[bi]\n i0 = np.searchsorted(eeg_ts, t_start + t_pre)\n i1 = np.searchsorted(eeg_ts, t_start + t_post)\n if i0 < 0 or i1 > eeg_data.shape[1]:\n continue\n ep = eeg_data[:, i0:i1].copy()\n if ep.shape[1] > n_pre:\n ep -= ep[:, :n_pre].mean(axis=1, keepdims=True)\n epochs.append(ep)\n\n if not epochs:\n return np.empty((0, eeg_data.shape[0], 0))\n min_len = min(e.shape[-1] for e in epochs)\n return np.stack([e[:, :min_len] for e in epochs])\n\n\ndef load_session_epochs(filepath):\n \"\"\"Bandpass continuous EEG, epoch on MI/REST cues, keep ACTIVE window [0, T_POST].\n Returns X (n_trials, n_ch, n_samp), y (1=MI, 0=REST), ch_names, sfreq.\n \"\"\"\n eeg, eeg_ts, mk, mk_ts, ch_names, sfreq = load_xdf_file(filepath)\n eeg_bp = bandpass(eeg, BP_LO, BP_HI, sfreq)\n\n mi = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, MI_BEGIN)\n rest = extract_epochs(eeg_bp, eeg_ts, mk, mk_ts, sfreq, REST_BEGIN)\n\n n_pre = int(abs(T_PRE) * sfreq)\n if mi.shape[-1] > n_pre: mi = mi[..., n_pre:]\n if rest.shape[-1] > n_pre: rest = rest[..., n_pre:]\n\n n = min(mi.shape[-1], rest.shape[-1]) if (mi.size and rest.size) else 0\n mi, rest = mi[..., :n],
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "98d225db",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"## Load Data"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "d266216b",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# Fail early on a wrong path: glob would otherwise just find 0 files.\n",
  "if not os.path.isdir(DATA_DIR):\n",
  "    raise FileNotFoundError(f'DATA_DIR does not exist: {DATA_DIR}')\n",
  "\n",
  "xdf_files = sorted(glob.glob(os.path.join(DATA_DIR, '*.xdf')))\n",
  "print(f'Found {len(xdf_files)} XDF file(s).\\n')\n",
  "\n",
  "# sessions[subject][session_id] = dict(X, y, kind, stim, ch_names, sfreq, file)\n",
  "sessions = {}\n",
  "\n",
  "for fp in xdf_files:\n",
  "    fname = os.path.basename(fp)\n",
  "    meta = parse_session(fp)\n",
  "    if meta is None:\n",
  "        print(f' SKIP (unparsed): {fname}')\n",
  "        continue\n",
  "    subj, ses_id, kind, stim = meta\n",
  "    # Best-effort: report a file that fails to load and carry on with the rest.\n",
  "    try:\n",
  "        X, y, ch_names, sfreq = load_session_epochs(fp)\n",
  "    except Exception as e:\n",
  "        print(f' ERROR {fname}: {e}')\n",
  "        continue\n",
  "\n",
  "    subj_sessions = sessions.setdefault(subj, {})\n",
  "    if ses_id in subj_sessions:\n",
  "        # Two files parsed to the same (subject, session); the later one wins,\n",
  "        # so make the overwrite visible instead of silent.\n",
  "        print(f' !! duplicate {subj}/{ses_id}: {fname} replaces {subj_sessions[ses_id][\"file\"]}')\n",
  "    subj_sessions[ses_id] = dict(\n",
  "        X=X, y=y, kind=kind, stim=stim,\n",
  "        ch_names=ch_names, sfreq=sfreq, file=fname)\n",
  "\n",
  "    print(f' {subj}/{ses_id} {kind:<7} {stim:<5} '\n",
  "          f'n={len(y):3d} (MI={int(y.sum())}, REST={int((1-y).sum())})')\n",
  "\n",
  "subjects = sorted(sessions)\n",
  "print(f'\\nLoaded {len(subjects)} subject(s): {subjects}')\n",
  "if not subjects:\n",
  "    # Every later cell indexes into `subjects`; stop here with a clear message.\n",
  "    raise RuntimeError(f'No usable sessions loaded from {DATA_DIR}; '\n",
  "                       'check the folder contents and file names.')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "7b8c8bea",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": "## Verify Session Layout"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "611baf23",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# Verify channel layout is consistent across sessions, locate motor channels.\n",
  "# Reference layout = first loaded session of the alphabetically-first subject.\n",
  "if not subjects:\n",
  "    raise RuntimeError('No sessions loaded — nothing to verify.')\n",
  "\n",
  "ref_subj = subjects[0]\n",
  "ref_ses = next(iter(sessions[ref_subj].values()))\n",
  "channel_names_global = ref_ses['ch_names']\n",
  "sfreq_global = ref_ses['sfreq']\n",
  "\n",
  "mismatches = [f'{subj}/{sid}' for subj in subjects for sid, s in sessions[subj].items()\n",
  "              if s['ch_names'] != channel_names_global]\n",
  "if mismatches:\n",
  "    print('!! channel mismatch in:', mismatches)\n",
  "\n",
  "srate_mismatches = [f'{subj}/{sid}' for subj in subjects for sid, s in sessions[subj].items()\n",
  "                    if s['sfreq'] != sfreq_global]\n",
  "if srate_mismatches:\n",
  "    print('!! sampling-rate mismatch in:', srate_mismatches)\n",
  "\n",
  "motor_idx_global = [channel_names_global.index(c) for c in MOTOR_CH\n",
  "                    if c in channel_names_global]\n",
  "missing_motor = [c for c in MOTOR_CH if c not in channel_names_global]\n",
  "if missing_motor:\n",
  "    # These would otherwise be dropped silently from the motor-channel index list.\n",
  "    print('!! motor channel(s) not found:', missing_motor)\n",
  "\n",
  "print(f'Channels ({len(channel_names_global)}): {channel_names_global}')\n",
  "print(f'Sampling rate: {sfreq_global} Hz')\n",
  "print(f'Motor channels {MOTOR_CH} → indices {motor_idx_global}')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "70922abb",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": "## Train CSP + LDA on OFFLINE, Evaluate on ONLINE"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "f5e80da3",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "results = []  # one row per (subject, pair, condition)\n",
  "\n",
  "for subj in subjects:\n",
  "    subj_ses = sessions[subj]\n",
  "\n",
  "    for pair in PAIRS:\n",
  "        needed = (pair['train'], pair['online_fes'], pair['online_nofes'])\n",
  "        missing = [k for k in needed if k not in subj_ses]\n",
  "        if missing:\n",
  "            print(f'[{subj}] {pair[\"name\"]}: missing {missing} — skipping')\n",
  "            continue\n",
  "\n",
  "        train = subj_ses[pair['train']]\n",
  "        clf = CSPLDA(n_csp=N_CSP).fit(train['X'], train['y'])\n",
  "        # Resubstitution accuracy on the training session itself (optimistic by construction)\n",
  "        train_acc = (clf.predict(train['X']) == train['y']).mean()\n",
  "\n",
  "        for cond_key, cond_label in [('online_fes', 'FES'), ('online_nofes', 'NOFES')]:\n",
  "            te = subj_ses[pair[cond_key]]\n",
  "            # CSP spatial filters weight channels by position, so the test session\n",
  "            # is only meaningful if it shares the training montage exactly.\n",
  "            if te['ch_names'] != train['ch_names']:\n",
  "                print(f'[{subj}] {pair[\"name\"]} {cond_label}: channel layout of '\n",
  "                      f'{te[\"file\"]} differs from {train[\"file\"]} — skipping')\n",
  "                continue\n",
  "            # Motor-channel indices from this session's own channel list, so they\n",
  "            # stay correct even if its layout differs from the global reference.\n",
  "            motor_idx = [te['ch_names'].index(c) for c in MOTOR_CH if c in te['ch_names']]\n",
  "\n",
  "            res = evaluate(clf, te['X'], te['y'])\n",
  "            snr_s = spectral_snr(te['X'], te['y'], motor_idx, te['sfreq'])\n",
  "\n",
  "            results.append(dict(\n",
  "                subject=subj, pair=pair['name'], condition=cond_label,\n",
  "                train_file=train['file'], test_file=te['file'],\n",
  "                train_acc=train_acc, n_test=len(te['y']),\n",
  "                acc=res['acc'], amp=res['amp'], fisher=res['fisher'], mu_snr=snr_s,\n",
  "                margin=res['margin'], y_test=res['y'], pred=res['pred'],\n",
  "            ))\n",
  "\n",
  "# Results table\n",
  "hdr = f'{\"Subj\":<5} {\"Pair\":<28} {\"Cond\":<6} {\"n\":>4} {\"trainAcc\":>9} {\"acc\":>7} {\"|marg|\":>8} {\"Fisher\":>8} {\"muSNR\":>8}'\n",
  "print('\\n' + hdr)\n",
  "print('-' * len(hdr))\n",
  "for r in results:\n",
  "    print(f'{r[\"subject\"]:<5} {r[\"pair\"]:<28} {r[\"condition\"]:<6} {r[\"n_test\"]:>4} '\n",
  "          f'{r[\"train_acc\"]:>9.3f} {r[\"acc\"]:>7.3f} {r[\"amp\"]:>8.3f} '\n",
  "          f'{r[\"fisher\"]:>8.3f} {r[\"mu_snr\"]:>8.3f}')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "2ab81600",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": "---\n## Figure 1 — Per-metric comparison (FES vs NOFES)"
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "d53e63b9",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# (results key, panel title, unit shown in the title and on the y-axis)\n",
  "METRICS = [\n",
  "    ('acc', 'Classification accuracy', '0–1'),\n",
  "    ('amp', 'Classification amplitude (mean |decision fn|)', 'a.u.'),\n",
  "    ('fisher', 'Fisher ratio on LDA projection (test-set SNR)', 'a.u.'),\n",
  "    ('mu_snr', 'μ-band power ratio REST / MI @ C3/Cz/C4', 'ratio'),\n",
  "]\n",
  "\n",
  "cond_color = {'FES': '#E05C2A', 'NOFES': '#2A7BE0'}\n",
  "\n",
  "# One row per (subject, pair name, condition): built once instead of scanning per bar\n",
  "result_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
  "\n",
  "fig, axes = plt.subplots(2, 2, figsize=(14, 9))\n",
  "fig.suptitle('Online decoding: FES vs NOFES feedback (per subject × offline-trained model)',\n",
  "             fontsize=13, fontweight='bold', y=1.00)\n",
  "\n",
  "for ax, (key, title, unit) in zip(axes.ravel(), METRICS):\n",
  "    labels, vals, colors = [], [], []\n",
  "    for subj in subjects:\n",
  "        for pair in PAIRS:\n",
  "            tag = pair['name'].split()[0]  # \"Pair1\" / \"Pair2\"\n",
  "            for cond in ('FES', 'NOFES'):\n",
  "                row = result_by_key.get((subj, pair['name'], cond))\n",
  "                if row is None:\n",
  "                    continue\n",
  "                labels.append(f'{subj}\\n{tag}\\n{cond}')\n",
  "                vals.append(row[key])\n",
  "                colors.append(cond_color[cond])\n",
  "    x = np.arange(len(vals))\n",
  "    ax.bar(x, vals, color=colors, edgecolor='white', zorder=2)\n",
  "    ax.set_xticks(x)\n",
  "    ax.set_xticklabels(labels, fontsize=7.5)\n",
  "    ax.set_title(f'{title} ({unit})', fontsize=11, fontweight='bold')\n",
  "    ax.set_ylabel(unit)\n",
  "    ax.grid(axis='y', alpha=0.3)\n",
  "    ax.spines[['top', 'right']].set_visible(False)\n",
  "    if key == 'acc':\n",
  "        ax.set_ylim(0, 1)  # fixed scale so panels/subjects are comparable\n",
  "        ax.axhline(0.5, color='gray', linestyle='--', lw=0.8, alpha=0.6)  # chance for 2 balanced classes\n",
  "\n",
  "fig.legend(handles=[Patch(color=cond_color['FES'], label='ONLINE_FES'),\n",
  "                    Patch(color=cond_color['NOFES'], label='ONLINE_NOFES')],\n",
  "           loc='upper right', ncol=2, bbox_to_anchor=(0.98, 1.0))\n",
  "plt.tight_layout()\n",
  "plt.savefig('fes_vs_nofes_metrics.png', dpi=150, bbox_inches='tight')\n",
  "plt.show()\n",
  "print('Saved: fes_vs_nofes_metrics.png')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "248740bd",
|
|||
|
|
"metadata": {},
|
|||
|
|
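"source": "---\n## Figure 2 — LDA decision-function distributions\n\nVisualizes classification amplitude and separability directly. Solid histograms are MI trials and hatched histograms are REST trials, colored by feedback condition. The vertical line marks a decision-function value of 0. The farther apart the MI and REST distributions sit (and the farther from 0), the higher the Fisher ratio and the larger the mean |margin|. Compare that spread between FES and NOFES within each panel."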
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "393042a0",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "fig, axes = plt.subplots(len(subjects), len(PAIRS),\n",
  "                         figsize=(6 * len(PAIRS), 3.2 * len(subjects)),\n",
  "                         sharex=True, squeeze=False)\n",
  "fig.suptitle('LDA decision-function distributions on ONLINE sessions\\n'\n",
  "             '(class separation ↔ classification amplitude & SNR)',\n",
  "             fontsize=12, fontweight='bold', y=1.01)\n",
  "\n",
  "# Legend entries built explicitly, so the legend is complete even when the panel\n",
  "# that carries it has no data for some condition.\n",
  "legend_handles = []\n",
  "for cond, color in cond_color.items():\n",
  "    legend_handles.append(Patch(facecolor=color, alpha=0.5, label=f'{cond} MI'))\n",
  "    legend_handles.append(Patch(facecolor=color, edgecolor=color, alpha=0.25,\n",
  "                                hatch='///', label=f'{cond} REST'))\n",
  "\n",
  "for i, subj in enumerate(subjects):\n",
  "    for j, pair in enumerate(PAIRS):\n",
  "        ax = axes[i][j]\n",
  "        rows = {cond: next((r for r in results\n",
  "                            if r['subject'] == subj and r['pair'] == pair['name']\n",
  "                            and r['condition'] == cond), None)\n",
  "                for cond in cond_color}\n",
  "        margins = [r['margin'] for r in rows.values() if r is not None and len(r['margin'])]\n",
  "        # Shared bin edges per panel: with bins=15 per call, each histogram got its own\n",
  "        # edges, so bar positions/widths were not comparable across the four curves.\n",
  "        edges = np.histogram_bin_edges(np.concatenate(margins), bins=15) if margins else 15\n",
  "        for cond, color in cond_color.items():\n",
  "            row = rows[cond]\n",
  "            if row is None:\n",
  "                continue\n",
  "            m_mi = row['margin'][row['y_test'] == 1]\n",
  "            m_rest = row['margin'][row['y_test'] == 0]\n",
  "            if len(m_mi):\n",
  "                ax.hist(m_mi, bins=edges, alpha=0.5, color=color, density=True)\n",
  "            if len(m_rest):\n",
  "                ax.hist(m_rest, bins=edges, alpha=0.25, color=color, hatch='///',\n",
  "                        edgecolor=color, density=True)\n",
  "        ax.axvline(0, color='k', lw=0.8)  # decision-function value 0 (class boundary)\n",
  "        ax.set_title(f'{subj} | {pair[\"name\"]}', fontsize=10, fontweight='bold')\n",
  "        if j == 0:\n",
  "            ax.set_ylabel('Density')\n",
  "        if i == len(subjects) - 1:\n",
  "            ax.set_xlabel('LDA decision function')\n",
  "        ax.spines[['top', 'right']].set_visible(False)\n",
  "\n",
  "axes[0][0].legend(handles=legend_handles, fontsize=6.5, loc='upper left', ncol=2)\n",
  "plt.tight_layout()\n",
  "plt.savefig('decision_margin_distributions.png', dpi=150, bbox_inches='tight')\n",
  "plt.show()\n",
  "print('Saved: decision_margin_distributions.png')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "fcb6d19d",
|
|||
|
|
"metadata": {},
|
|||
|
|
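"source": "---\n## Figure 3 — Paired Δ (FES − NOFES) per metric\n\nWithin each (subject × pair), the FES and NOFES sessions are scored by the same offline-trained model. Positive bars mean FES > NOFES; negative bars mean NOFES > FES. Because the model is held fixed, differences in offline-model quality cancel out of Δ, isolating the effect of feedback type. Session order is also counterbalanced across pairs, assuming session IDs follow recording order: in Pair 1 the FES session (S002) precedes NOFES (S003), while in Pair 2 NOFES (S005) precedes FES (S006)."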
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "75df404b",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# One row per (subject, pair name, condition) for direct FES/NOFES lookup\n",
  "result_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
  "\n",
  "fig, axes = plt.subplots(1, len(METRICS), figsize=(4 * len(METRICS), 4.5), squeeze=False)\n",
  "fig.suptitle('Within-pair Δ = FES − NOFES (positive → FES better)',\n",
  "             fontsize=12, fontweight='bold', y=1.03)\n",
  "\n",
  "for ax, (key, title, _unit) in zip(axes[0], METRICS):\n",
  "    labels, deltas = [], []\n",
  "    for subj in subjects:\n",
  "        for pair in PAIRS:\n",
  "            fes = result_by_key.get((subj, pair['name'], 'FES'))\n",
  "            nof = result_by_key.get((subj, pair['name'], 'NOFES'))\n",
  "            if fes is None or nof is None:\n",
  "                continue\n",
  "            deltas.append(fes[key] - nof[key])\n",
  "            labels.append(f'{subj}\\n{pair[\"name\"].split()[0]}')\n",
  "\n",
  "    # Bar takes the colour of the condition that scored higher (same palette as\n",
  "    # cond_color); ties and NaNs are neutral grey instead of counting as NOFES.\n",
  "    colors = [cond_color['FES'] if d > 0 else cond_color['NOFES'] if d < 0 else 'lightgray'\n",
  "              for d in deltas]\n",
  "    x = np.arange(len(deltas))\n",
  "    ax.bar(x, deltas, color=colors, edgecolor='white', zorder=2)\n",
  "    ax.axhline(0, color='k', lw=0.8)\n",
  "    ax.set_xticks(x)\n",
  "    ax.set_xticklabels(labels, fontsize=8)\n",
  "    ax.set_title(f'Δ {title.split(\"(\")[0].strip()}', fontsize=10, fontweight='bold')\n",
  "    ax.set_ylabel('FES − NOFES')\n",
  "    ax.grid(axis='y', alpha=0.3)\n",
  "    ax.spines[['top', 'right']].set_visible(False)\n",
  "\n",
  "plt.tight_layout()\n",
  "plt.savefig('fes_minus_nofes_delta.png', dpi=150, bbox_inches='tight')\n",
  "plt.show()\n",
  "print('Saved: fes_minus_nofes_delta.png')"
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "markdown",
|
|||
|
|
"id": "b3db60ba",
|
|||
|
|
"metadata": {},
|
|||
|
|
"source": [
|
|||
|
|
"---\n",
|
|||
|
|
"## Summary Statistics"
|
|||
|
|
]
|
|||
|
|
},
|
|||
|
|
{
|
|||
|
|
"cell_type": "code",
|
|||
|
|
"execution_count": null,
|
|||
|
|
"id": "cf55268e",
|
|||
|
|
"metadata": {},
|
|||
|
|
"outputs": [],
|
|||
|
|
"source": [
  "# Pair FES and NOFES rows explicitly by (subject, pair) so every Δ is a true\n",
  "# within-pair difference, independent of the order rows were appended to `results`.\n",
  "result_by_key = {(r['subject'], r['pair'], r['condition']): r for r in results}\n",
  "paired = [(row, result_by_key[(s, p, 'NOFES')])\n",
  "          for (s, p, c), row in result_by_key.items()\n",
  "          if c == 'FES' and (s, p, 'NOFES') in result_by_key]\n",
  "\n",
  "SUMMARY_METRICS = [  # (results key, table label, short label for the sign test)\n",
  "    ('acc', 'Classification accuracy', 'acc'),\n",
  "    ('amp', 'Classification amplitude', '|margin|'),\n",
  "    ('fisher', 'Fisher ratio (test SNR)', 'Fisher'),\n",
  "    ('mu_snr', 'μ-band SNR (REST/MI)', 'μ-SNR'),\n",
  "]\n",
  "\n",
  "\n",
  "def sample_sd(values):\n",
  "    \"\"\"Sample standard deviation (ddof=1); NaN when fewer than two values.\"\"\"\n",
  "    return values.std(ddof=1) if values.size > 1 else float('nan')\n",
  "\n",
  "\n",
  "def sign_test_p(n_pos, n_neg):\n",
  "    \"\"\"Exact two-sided sign-test p-value under H0: P(FES > NOFES) = 0.5 (ties dropped).\"\"\"\n",
  "    n = n_pos + n_neg\n",
  "    if n == 0:\n",
  "        return float('nan')\n",
  "    k = min(n_pos, n_neg)\n",
  "    return min(1.0, 2 * sum(math.comb(n, i) for i in range(k + 1)) / 2 ** n)\n",
  "\n",
  "\n",
  "n_pairs = len(paired)\n",
  "print(f'=== Aggregate across {n_pairs} (subject × pair) comparisons ===\\n')\n",
  "\n",
  "if n_pairs == 0:\n",
  "    print('No (subject × pair) has both FES and NOFES results — nothing to summarise.')\n",
  "else:\n",
  "    hdr = f'{\"Metric\":<28} {\"FES (mean ± sd)\":>22} {\"NOFES (mean ± sd)\":>22} {\"paired Δ\":>12}'\n",
  "    print(hdr); print('-' * len(hdr))\n",
  "    for k, label, _ in SUMMARY_METRICS:\n",
  "        fes = np.array([f[k] for f, _ in paired], dtype=float)\n",
  "        nof = np.array([nf[k] for _, nf in paired], dtype=float)\n",
  "        delta = fes - nof\n",
  "        print(f'{label:<28} {fes.mean():>10.3f} ± {sample_sd(fes):>6.3f} '\n",
  "              f'{nof.mean():>10.3f} ± {sample_sd(nof):>6.3f} {delta.mean():>+12.3f}')\n",
  "\n",
  "    # Sign test: how often FES beats NOFES, with an exact two-sided p-value\n",
  "    print()\n",
  "    for k, _, short in SUMMARY_METRICS:\n",
  "        d = np.array([f[k] - nf[k] for f, nf in paired], dtype=float)\n",
  "        n_pos = int((d > 0).sum()); n_neg = int((d < 0).sum())\n",
  "        print(f' {short:<10} FES > NOFES in {n_pos}/{len(d)} comparisons '\n",
  "              f'(NOFES > FES in {n_neg}); sign-test p = {sign_test_p(n_pos, n_neg):.3f}')"
]
|
|||
|
|
}
|
|||
|
|
],
|
|||
|
|
"metadata": {
|
|||
|
|
"kernelspec": {
|
|||
|
|
"display_name": "venv",
|
|||
|
|
"language": "python",
|
|||
|
|
"name": "python3"
|
|||
|
|
},
|
|||
|
|
"language_info": {
|
|||
|
|
"codemirror_mode": {
|
|||
|
|
"name": "ipython",
|
|||
|
|
"version": 3
|
|||
|
|
},
|
|||
|
|
"file_extension": ".py",
|
|||
|
|
"mimetype": "text/x-python",
|
|||
|
|
"name": "python",
|
|||
|
|
"nbconvert_exporter": "python",
|
|||
|
|
"pygments_lexer": "ipython3",
|
|||
|
|
"version": "3.13.7"
|
|||
|
|
}
|
|||
|
|
},
|
|||
|
|
"nbformat": 4,
|
|||
|
|
"nbformat_minor": 5
|
|||
|
|
}
|