post-LDA-Transfer

This commit is contained in:
pulipakaa24
2026-03-06 00:08:08 -06:00
parent 05505d97dd
commit 24019c18c0

View File

@@ -1667,11 +1667,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 53,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"── Cross-Subject LDA Transfer Decoding ─────────────────────────────────\n",
" Window: 350 ms (179 samples) | Overlap: 75%\n",
" Chance level : 33.3%\n",
"\n",
" Train S1 → Test S2 Run 6 : 0.5718 (57.2%) [+23.8 pp above chance]\n",
" Train S2 → Test S1 Run 6 : 0.5808 (58.1%) [+24.7 pp above chance]\n",
" Mean transfer accuracy : 0.5763 (57.6%)\n"
]
}
],
"source": [ "source": [
"# Test against run 6 of opposite subject.\n", "# ── Cross-subject transfer decoding: LDA trained on one subject, tested on the other ──\n",
"\n", "\n",
"WIN_TEST = 350 # ms\n", "WIN_TEST = 350 # ms\n",
"OVL_TEST = 0.75\n", "OVL_TEST = 0.75\n",
@@ -1690,26 +1704,31 @@
"X_test1, y_test1 = extract_run_Xy(s1_runs[5], WIN_TEST, OVL_TEST, raw=False)\n", "X_test1, y_test1 = extract_run_Xy(s1_runs[5], WIN_TEST, OVL_TEST, raw=False)\n",
"X_test2, y_test2 = extract_run_Xy(s2_runs[5], WIN_TEST, OVL_TEST, raw=False)\n", "X_test2, y_test2 = extract_run_Xy(s2_runs[5], WIN_TEST, OVL_TEST, raw=False)\n",
"\n", "\n",
"scaler = StandardScaler()\n", "scaler1 = StandardScaler()\n",
"X_tr_s1 = scaler.fit_transform(X_train1)\n", "X_tr_s1 = scaler1.fit_transform(X_train1)\n",
"X_te_s1 = scaler.transform(X_test1)\n", "X_te_s1 = scaler1.transform(X_test1)\n",
"\n", "\n",
"X_tr_s2 = scaler.fit_transform(X_train2)\n", "scaler2 = StandardScaler()\n",
"X_te_s2 = scaler.transform(X_test2)\n", "X_tr_s2 = scaler2.fit_transform(X_train2)\n",
"X_te_s2 = scaler2.transform(X_test2)\n",
"\n", "\n",
"lda1 = LinearDiscriminantAnalysis()\n", "lda1 = LinearDiscriminantAnalysis()\n",
"lda1.fit(X_tr_s1, y_train1)\n", "lda1.fit(X_tr_s1, y_train1)\n",
"acc = lda1.score(X_te_s2, y_test2)\n", "acc_s1_to_s2 = lda1.score(X_te_s2, y_test2)\n",
"\n", "\n",
"lda2 = LinearDiscriminantAnalysis()\n", "lda2 = LinearDiscriminantAnalysis()\n",
"lda2.fit(X_tr_s2, y_train2)\n", "lda2.fit(X_tr_s2, y_train2)\n",
"acc = lda2.score(X_te_s1, y_test1)\n", "acc_s2_to_s1 = lda2.score(X_te_s1, y_test1)\n",
"\n", "\n",
"print(f\"Subject 1 LDA held-out accuracy on Run 6\")\n", "chance = 1 / 3\n",
"\n",
"print(\"── Cross-Subject LDA Transfer Decoding ─────────────────────────────────\")\n",
"print(f\" Window: {WIN_TEST} ms ({ms_to_samples(WIN_TEST)} samples) | Overlap: {OVL_TEST:.0%}\")\n", "print(f\" Window: {WIN_TEST} ms ({ms_to_samples(WIN_TEST)} samples) | Overlap: {OVL_TEST:.0%}\")\n",
"print(f\" Train windows : {len(X_train)}\")\n", "print(f\" Chance level : {chance:.1%}\")\n",
"print(f\" Test windows : {len(X_test)}\")\n", "print()\n",
"print(f\" Accuracy : {acc:.4f} ({acc*100:.1f}%)\")" "print(f\" Train S1 → Test S2 Run 6 : {acc_s1_to_s2:.4f} ({acc_s1_to_s2*100:.1f}%) [+{(acc_s1_to_s2 - chance)*100:.1f} pp above chance]\")\n",
"print(f\" Train S2 → Test S1 Run 6 : {acc_s2_to_s1:.4f} ({acc_s2_to_s1*100:.1f}%) [+{(acc_s2_to_s1 - chance)*100:.1f} pp above chance]\")\n",
"print(f\" Mean transfer accuracy : {(acc_s1_to_s2 + acc_s2_to_s1) / 2:.4f} ({(acc_s1_to_s2 + acc_s2_to_s1) / 2 * 100:.1f}%)\")"
] ]
} }
], ],