diff --git a/HW_2_sp23_EE374N-385J_Neural_Eng_EMGanalysis/HW2_EMG_Analysis.ipynb b/HW_2_sp23_EE374N-385J_Neural_Eng_EMGanalysis/HW2_EMG_Analysis.ipynb index fc08760..3d5f560 100644 --- a/HW_2_sp23_EE374N-385J_Neural_Eng_EMGanalysis/HW2_EMG_Analysis.ipynb +++ b/HW_2_sp23_EE374N-385J_Neural_Eng_EMGanalysis/HW2_EMG_Analysis.ipynb @@ -19,7 +19,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2026-03-04T06:31:57.995449Z", @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2026-03-04T06:31:58.477974Z", @@ -154,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2026-03-04T06:31:58.661192Z", @@ -281,7 +281,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -360,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -428,7 +428,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -489,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -580,7 +580,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -640,7 +640,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 9, "metadata": {}, "outputs": [ { @@ -709,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -744,7 +744,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -794,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -855,7 +855,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -905,7 +905,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -983,17 +983,448 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], - "source": "# ── 2.4.1 Setup + Part A: LDA vs KNN on extracted features ──────────────────\nimport torch\nimport torch.nn as nn\nfrom torch.amp import autocast, GradScaler\nfrom sklearn.discriminant_analysis import LinearDiscriminantAnalysis\nfrom sklearn.neighbors import KNeighborsClassifier\nfrom sklearn.preprocessing import StandardScaler\nfrom tqdm.notebook import tqdm\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nprint(f\"Using device: {device}\")\nif device.type == 'cuda':\n print(f\" GPU: {torch.cuda.get_device_name(0)}\")\n torch.backends.cudnn.benchmark = True\n\nuse_amp = device.type == 'cuda'\n\nwin_sizes = [50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350]\noverlap_fracs = [0, 0.25, 0.5, 0.75]\nk_values = [1, 3, 5, 7, 11, 15]\n\n\ndef extract_run_Xy(run, win_len, ovl_frac, raw=False):\n \"\"\"Extract windowed data from one run.\n\n raw=False : X is (n_windows, 8) — [MAV×4, SSC×4] feature vectors\n raw=True : X is (n_windows, 4, win_len) — raw EMG, channels-first\n \"\"\"\n typ_r, pos_r, emg_r = run['typ'], run['pos'], run['emg']\n task_code = {101: 0, 201: 1, 301: 2}\n step = max(1, int(win_len * (1 - ovl_frac)))\n\n s_mask = np.isin(typ_r, [101, 201, 301])\n e_mask = np.isin(typ_r, [102, 202, 302])\n starts, ends, codes = pos_r[s_mask], pos_r[e_mask], typ_r[s_mask]\n\n X_list, y_list = [], []\n for s_pos, e_pos, s_code in zip(starts, ends, codes):\n seg = emg_r[s_pos:e_pos, :] # (n_samples, 4)\n n_win = max(0, (len(seg) - win_len) // step + 1)\n if n_win == 0:\n continue\n\n if raw:\n for w in range(n_win):\n s = w * step\n X_list.append(seg[s:s + win_len, :].T.astype(np.float32)) # (4, win_len)\n y_list.append(task_code[s_code])\n else:\n ch_feats, n_win_min = [], n_win\n for ch in range(4):\n m, r = extract_features(seg[:, ch], win_len, ovl_frac)\n ch_feats.append((m, r))\n n_win_min = min(n_win_min, len(m))\n for w in range(n_win_min):\n row = ([ch_feats[c][0][w] for c in range(4)] +\n [ch_feats[c][1][w] for c in range(4)])\n X_list.append(row)\n y_list.append(task_code[s_code])\n\n if X_list:\n return np.array(X_list), np.array(y_list)\n shape = (0, 4, win_len) if raw else (0, 8)\n return np.empty(shape, dtype=np.float32), np.array([], dtype=int)\n\n\n# ── Part A grid search ────────────────────────────────────────────────────────\nresults_feat = {}\n\nfor subj_idx, subj_runs in enumerate(filtered_data):\n n_cv = min(5, len(subj_runs))\n cv_runs = subj_runs[:n_cv]\n results = []\n n_total = len(win_sizes) * len(overlap_fracs)\n\n pbar = tqdm(total=n_total, desc=f'[Part A] Subject {subj_idx + 1}', unit='combo')\n\n for win_len in win_sizes:\n for ovl in overlap_fracs:\n pbar.set_postfix(win=win_len, ovl=f'{ovl:.0%}')\n run_Xy = [extract_run_Xy(r, win_len, ovl, raw=False) for r in cv_runs]\n\n clf_accs = {'LDA': [], **{f'KNN-{k}': [] for k in k_values}}\n\n for test_r in range(n_cv):\n train_idx = [i for i in range(n_cv) if i != test_r]\n Xtr_parts = [run_Xy[i][0] for i in train_idx if len(run_Xy[i][0]) > 0]\n ytr_parts = [run_Xy[i][1] for i in train_idx if len(run_Xy[i][1]) > 0]\n if not Xtr_parts:\n continue\n X_train = np.vstack(Xtr_parts)\n y_train = np.concatenate(ytr_parts)\n X_test, y_test = run_Xy[test_r]\n if len(X_test) == 0:\n continue\n\n scaler = StandardScaler()\n X_tr_s = scaler.fit_transform(X_train)\n X_te_s = scaler.transform(X_test)\n\n try:\n lda = LinearDiscriminantAnalysis()\n lda.fit(X_tr_s, y_train)\n clf_accs['LDA'].append(lda.score(X_te_s, y_test))\n except Exception:\n pass\n\n for k in k_values:\n try:\n knn = KNeighborsClassifier(n_neighbors=k)\n knn.fit(X_tr_s, y_train)\n clf_accs[f'KNN-{k}'].append(knn.score(X_te_s, y_test))\n except Exception:\n pass\n\n for clf_name, accs in clf_accs.items():\n if accs:\n results.append((win_len, ovl, clf_name,\n np.mean(accs), np.std(accs)))\n pbar.update(1)\n\n pbar.close()\n results_feat[subj_idx] = results\n best = max(results, key=lambda x: x[3])\n print(f\"Subject {subj_idx + 1} – Best: win={best[0]}, overlap={best[1]:.0%}, \"\n f\"{best[2]}, accuracy = {best[3]:.3f} ± {best[4]:.3f}\")\n\n\n# ── Part A heatmaps (LDA + representative K values) ──────────────────────────\nclf_show = ['LDA', 'KNN-1', 'KNN-5', 'KNN-11']\n\nfig, axes = plt.subplots(2, 4, figsize=(22, 10))\nfig.suptitle('Part A – Feature CV Accuracy: Window × Overlap\\n'\n '(LDA & KNN on MAV+SSC features, run-wise CV)',\n fontsize=13, fontweight='bold')\n\nfor subj_idx in range(2):\n for col, clf_name in enumerate(clf_show):\n ax = axes[subj_idx, col]\n acc_grid = np.full((len(win_sizes), len(overlap_fracs)), np.nan)\n for w, o, c, a, s in results_feat[subj_idx]:\n if c == clf_name:\n acc_grid[win_sizes.index(w), overlap_fracs.index(o)] = a\n\n im = ax.imshow(acc_grid, aspect='auto', cmap='RdYlGn',\n vmin=0.3, vmax=1.0, origin='lower')\n ax.set_xticks(range(len(overlap_fracs)))\n ax.set_xticklabels([f'{o:.0%}' for o in overlap_fracs])\n ax.set_yticks(range(len(win_sizes)))\n ax.set_yticklabels(win_sizes)\n ax.set_xlabel('Overlap')\n ax.set_ylabel('Window Size (samples)')\n ax.set_title(f'Subject {subj_idx + 1} – {clf_name}', fontsize=10)\n\n for wi in range(len(win_sizes)):\n for oi in range(len(overlap_fracs)):\n v = acc_grid[wi, oi]\n if not np.isnan(v):\n ax.text(oi, wi, f'{v:.2f}', ha='center', va='center',\n fontsize=5.5, color='black')\n fig.colorbar(im, ax=ax, shrink=0.8, label='Accuracy')\n\nplt.tight_layout()\nplt.savefig('2_4a_cv_heatmap_features.png', bbox_inches='tight')\nplt.show()" + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cuda\n", + " GPU: NVIDIA GeForce RTX 3060 Ti\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3f9c2dbf5b71469d931f6d54c0aa499d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "[Part A] Subject 1: 0%| | 0/52 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# ── 2.4.1 Setup + Part A: LDA vs KNN on extracted features ──────────────────\n", + "import torch\n", + "import torch.nn as nn\n", + "from torch.amp import autocast, GradScaler\n", + "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n", + "from sklearn.neighbors import KNeighborsClassifier\n", + "from sklearn.preprocessing import StandardScaler\n", + "from tqdm.notebook import tqdm\n", + "\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "print(f\"Using device: {device}\")\n", + "if device.type == 'cuda':\n", + " print(f\" GPU: {torch.cuda.get_device_name(0)}\")\n", + " torch.backends.cudnn.benchmark = True\n", + "\n", + "use_amp = device.type == 'cuda'\n", + "\n", + "win_sizes = [50, 75, 100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350]\n", + "overlap_fracs = [0, 0.25, 0.5, 0.75]\n", + "k_values = [1, 3, 5, 7, 11, 15]\n", + "\n", + "\n", + "def extract_run_Xy(run, win_len, ovl_frac, raw=False):\n", + " \"\"\"Extract windowed data from one run.\n", + "\n", + " raw=False : X is (n_windows, 8) — [MAV×4, SSC×4] feature vectors\n", + " raw=True : X is (n_windows, 4, win_len) — raw EMG, channels-first\n", + " \"\"\"\n", + " typ_r, pos_r, emg_r = run['typ'], run['pos'], run['emg']\n", + " task_code = {101: 0, 201: 1, 301: 2}\n", + " step = max(1, int(win_len * (1 - ovl_frac)))\n", + "\n", + " s_mask = np.isin(typ_r, [101, 201, 301])\n", + " e_mask = np.isin(typ_r, [102, 202, 302])\n", + " starts, ends, codes = pos_r[s_mask], pos_r[e_mask], typ_r[s_mask]\n", + "\n", + " X_list, y_list = [], []\n", + " for s_pos, e_pos, s_code in zip(starts, ends, codes):\n", + " seg = emg_r[s_pos:e_pos, :] # (n_samples, 4)\n", + " n_win = max(0, (len(seg) - win_len) // step + 1)\n", + " if n_win == 0:\n", + " continue\n", + "\n", + " if raw:\n", + " for w in range(n_win):\n", + " s = w * step\n", + " X_list.append(seg[s:s + win_len, :].T.astype(np.float32)) # (4, win_len)\n", + " y_list.append(task_code[s_code])\n", + " else:\n", + " ch_feats, n_win_min = [], n_win\n", + " for ch in range(4):\n", + " m, r = extract_features(seg[:, ch], win_len, ovl_frac)\n", + " ch_feats.append((m, r))\n", + " n_win_min = min(n_win_min, len(m))\n", + " for w in range(n_win_min):\n", + " row = ([ch_feats[c][0][w] for c in range(4)] +\n", + " [ch_feats[c][1][w] for c in range(4)])\n", + " X_list.append(row)\n", + " y_list.append(task_code[s_code])\n", + "\n", + " if X_list:\n", + " return np.array(X_list), np.array(y_list)\n", + " shape = (0, 4, win_len) if raw else (0, 8)\n", + " return np.empty(shape, dtype=np.float32), np.array([], dtype=int)\n", + "\n", + "\n", + "# ── Part A grid search ────────────────────────────────────────────────────────\n", + "results_feat = {}\n", + "\n", + "for subj_idx, subj_runs in enumerate(filtered_data):\n", + " n_cv = min(5, len(subj_runs))\n", + " cv_runs = subj_runs[:n_cv]\n", + " results = []\n", + " n_total = len(win_sizes) * len(overlap_fracs)\n", + "\n", + " pbar = tqdm(total=n_total, desc=f'[Part A] Subject {subj_idx + 1}', unit='combo')\n", + "\n", + " for win_len in win_sizes:\n", + " for ovl in overlap_fracs:\n", + " pbar.set_postfix(win=win_len, ovl=f'{ovl:.0%}')\n", + " run_Xy = [extract_run_Xy(r, win_len, ovl, raw=False) for r in cv_runs]\n", + "\n", + " clf_accs = {'LDA': [], **{f'KNN-{k}': [] for k in k_values}}\n", + "\n", + " for test_r in range(n_cv):\n", + " train_idx = [i for i in range(n_cv) if i != test_r]\n", + " Xtr_parts = [run_Xy[i][0] for i in train_idx if len(run_Xy[i][0]) > 0]\n", + " ytr_parts = [run_Xy[i][1] for i in train_idx if len(run_Xy[i][1]) > 0]\n", + " if not Xtr_parts:\n", + " continue\n", + " X_train = np.vstack(Xtr_parts)\n", + " y_train = np.concatenate(ytr_parts)\n", + " X_test, y_test = run_Xy[test_r]\n", + " if len(X_test) == 0:\n", + " continue\n", + "\n", + " scaler = StandardScaler()\n", + " X_tr_s = scaler.fit_transform(X_train)\n", + " X_te_s = scaler.transform(X_test)\n", + "\n", + " try:\n", + " lda = LinearDiscriminantAnalysis()\n", + " lda.fit(X_tr_s, y_train)\n", + " clf_accs['LDA'].append(lda.score(X_te_s, y_test))\n", + " except Exception:\n", + " pass\n", + "\n", + " for k in k_values:\n", + " try:\n", + " knn = KNeighborsClassifier(n_neighbors=k)\n", + " knn.fit(X_tr_s, y_train)\n", + " clf_accs[f'KNN-{k}'].append(knn.score(X_te_s, y_test))\n", + " except Exception:\n", + " pass\n", + "\n", + " for clf_name, accs in clf_accs.items():\n", + " if accs:\n", + " results.append((win_len, ovl, clf_name,\n", + " np.mean(accs), np.std(accs)))\n", + " pbar.update(1)\n", + "\n", + " pbar.close()\n", + " results_feat[subj_idx] = results\n", + " best = max(results, key=lambda x: x[3])\n", + " print(f\"Subject {subj_idx + 1} – Best: win={best[0]}, overlap={best[1]:.0%}, \"\n", + " f\"{best[2]}, accuracy = {best[3]:.3f} ± {best[4]:.3f}\")\n", + "\n", + "\n", + "# ── Part A heatmaps (LDA + representative K values) ──────────────────────────\n", + "clf_show = ['LDA', 'KNN-1', 'KNN-5', 'KNN-11']\n", + "\n", + "fig, axes = plt.subplots(2, 4, figsize=(22, 10))\n", + "fig.suptitle('Part A – Feature CV Accuracy: Window × Overlap\\n'\n", + " '(LDA & KNN on MAV+SSC features, run-wise CV)',\n", + " fontsize=13, fontweight='bold')\n", + "\n", + "for subj_idx in range(2):\n", + " for col, clf_name in enumerate(clf_show):\n", + " ax = axes[subj_idx, col]\n", + " acc_grid = np.full((len(win_sizes), len(overlap_fracs)), np.nan)\n", + " for w, o, c, a, s in results_feat[subj_idx]:\n", + " if c == clf_name:\n", + " acc_grid[win_sizes.index(w), overlap_fracs.index(o)] = a\n", + "\n", + " im = ax.imshow(acc_grid, aspect='auto', cmap='RdYlGn',\n", + " vmin=0.3, vmax=1.0, origin='lower')\n", + " ax.set_xticks(range(len(overlap_fracs)))\n", + " ax.set_xticklabels([f'{o:.0%}' for o in overlap_fracs])\n", + " ax.set_yticks(range(len(win_sizes)))\n", + " ax.set_yticklabels(win_sizes)\n", + " ax.set_xlabel('Overlap')\n", + " ax.set_ylabel('Window Size (samples)')\n", + " ax.set_title(f'Subject {subj_idx + 1} – {clf_name}', fontsize=10)\n", + "\n", + " for wi in range(len(win_sizes)):\n", + " for oi in range(len(overlap_fracs)):\n", + " v = acc_grid[wi, oi]\n", + " if not np.isnan(v):\n", + " ax.text(oi, wi, f'{v:.2f}', ha='center', va='center',\n", + " fontsize=5.5, color='black')\n", + " fig.colorbar(im, ax=ax, shrink=0.8, label='Accuracy')\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('2_4a_cv_heatmap_features.png', bbox_inches='tight')\n", + "plt.show()" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], - "source": "# ── 2.4.2 Part B: CNN on raw EMG windows ────────────────────────────────────\n\nclass EMG_CNN_Raw(nn.Module):\n \"\"\"1-D CNN for multi-channel raw EMG.\n Input: (batch, 4, win_len) — 4 EMG channels, temporal dim last.\n \"\"\"\n def __init__(self, n_classes=3, channels=(16,), kernel_size=7):\n super().__init__()\n layers = []\n in_ch = 4\n for out_ch in channels:\n layers += [\n nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size,\n padding=kernel_size // 2),\n nn.BatchNorm1d(out_ch),\n nn.ReLU(),\n ]\n in_ch = out_ch\n layers += [nn.AdaptiveAvgPool1d(1), nn.Flatten(),\n nn.Linear(in_ch, n_classes)]\n self.net = nn.Sequential(*layers)\n\n def forward(self, x):\n return self.net(x)\n\n\nCNN_CONFIGS_RAW = {\n 'CNN-S': dict(channels=(16,), kernel_size=7),\n 'CNN-M': dict(channels=(16, 32), kernel_size=7),\n 'CNN-L': dict(channels=(16, 32, 64), kernel_size=7),\n}\n\n\ndef train_cnn_raw(X_train, y_train, X_test, y_test, cnn_cfg,\n n_epochs=80, lr=1e-3):\n \"\"\"Train raw-EMG CNN. X: (n, 4, win_len), y: (n,).\"\"\"\n # Normalise per channel across training windows: shape (1, 4, 1)\n mu = X_train.mean(axis=(0, 2), keepdims=True)\n sigma = X_train.std(axis=(0, 2), keepdims=True) + 1e-8\n Xtr = torch.tensor((X_train - mu) / sigma, dtype=torch.float32, device=device)\n ytr = torch.tensor(y_train, dtype=torch.long, device=device)\n Xte = torch.tensor((X_test - mu) / sigma, dtype=torch.float32, device=device)\n yte = torch.tensor(y_test, dtype=torch.long, device=device)\n\n model = EMG_CNN_Raw(**cnn_cfg).to(device)\n opt = torch.optim.Adam(model.parameters(), lr=lr)\n loss_fn = nn.CrossEntropyLoss()\n scaler = GradScaler('cuda', enabled=use_amp)\n\n model.train()\n for _ in range(n_epochs):\n opt.zero_grad(set_to_none=True)\n with autocast(device_type=device.type, enabled=use_amp):\n loss = loss_fn(model(Xtr), ytr)\n scaler.scale(loss).backward()\n scaler.step(opt)\n scaler.update()\n\n model.eval()\n with torch.no_grad(), autocast(device_type=device.type, enabled=use_amp):\n preds = model(Xte).argmax(dim=1)\n return (preds == yte).float().mean().item()\n\n\n# ── Part B grid search ────────────────────────────────────────────────────────\nresults_cnn = {}\n\nfor subj_idx, subj_runs in enumerate(filtered_data):\n n_cv = min(5, len(subj_runs))\n cv_runs = subj_runs[:n_cv]\n results = []\n n_total = len(win_sizes) * len(overlap_fracs)\n\n pbar = tqdm(total=n_total, desc=f'[Part B] Subject {subj_idx + 1}', unit='combo')\n\n for win_len in win_sizes:\n for ovl in overlap_fracs:\n pbar.set_postfix(win=win_len, ovl=f'{ovl:.0%}')\n run_Xy = [extract_run_Xy(r, win_len, ovl, raw=True) for r in cv_runs]\n\n for cnn_name, cnn_cfg in CNN_CONFIGS_RAW.items():\n accs = []\n for test_r in range(n_cv):\n train_idx = [i for i in range(n_cv) if i != test_r]\n Xtr_parts = [run_Xy[i][0] for i in train_idx if len(run_Xy[i][0]) > 0]\n ytr_parts = [run_Xy[i][1] for i in train_idx if len(run_Xy[i][1]) > 0]\n if not Xtr_parts:\n continue\n X_train = np.vstack(Xtr_parts)\n y_train = np.concatenate(ytr_parts)\n X_test, y_test = run_Xy[test_r]\n if len(X_test) == 0:\n continue\n try:\n acc = train_cnn_raw(X_train, y_train, X_test, y_test, cnn_cfg)\n accs.append(acc)\n except Exception:\n pass\n if accs:\n results.append((win_len, ovl, cnn_name,\n np.mean(accs), np.std(accs)))\n pbar.update(1)\n\n pbar.close()\n if device.type == 'cuda':\n torch.cuda.empty_cache()\n\n results_cnn[subj_idx] = results\n best = max(results, key=lambda x: x[3])\n print(f\"Subject {subj_idx + 1} – Best: win={best[0]}, overlap={best[1]:.0%}, \"\n f\"{best[2]}, accuracy = {best[3]:.3f} ± {best[4]:.3f}\")\n\n\n# ── Part B heatmaps ───────────────────────────────────────────────────────────\ncnn_names_plot = list(CNN_CONFIGS_RAW.keys())\n\nfig, axes = plt.subplots(2, 3, figsize=(18, 10))\nfig.suptitle('Part B – Raw EMG CNN CV Accuracy: Window × Overlap\\n'\n '(CNN on raw 4-channel windows, run-wise CV)',\n fontsize=13, fontweight='bold')\n\nfor subj_idx in range(2):\n for col, cnn_name in enumerate(cnn_names_plot):\n ax = axes[subj_idx, col]\n acc_grid = np.full((len(win_sizes), len(overlap_fracs)), np.nan)\n for w, o, c, a, s in results_cnn[subj_idx]:\n if c == cnn_name:\n acc_grid[win_sizes.index(w), overlap_fracs.index(o)] = a\n\n im = ax.imshow(acc_grid, aspect='auto', cmap='RdYlGn',\n vmin=0.3, vmax=1.0, origin='lower')\n ax.set_xticks(range(len(overlap_fracs)))\n ax.set_xticklabels([f'{o:.0%}' for o in overlap_fracs])\n ax.set_yticks(range(len(win_sizes)))\n ax.set_yticklabels(win_sizes)\n ax.set_xlabel('Overlap')\n ax.set_ylabel('Window Size (samples)')\n ax.set_title(f'Subject {subj_idx + 1} – {cnn_name}', fontsize=10)\n\n for wi in range(len(win_sizes)):\n for oi in range(len(overlap_fracs)):\n v = acc_grid[wi, oi]\n if not np.isnan(v):\n ax.text(oi, wi, f'{v:.2f}', ha='center', va='center',\n fontsize=5.5, color='black')\n fig.colorbar(im, ax=ax, shrink=0.8, label='Accuracy')\n\nplt.tight_layout()\nplt.savefig('2_4b_cv_heatmap_cnn_raw.png', bbox_inches='tight')\nplt.show()" + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "89babf4e1c194b91a3ca72d1ad41c270", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "[Part B] Subject 1: 0%| | 0/52 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# ── 2.4.2 Part B: CNN on raw EMG windows ────────────────────────────────────\n", + "\n", + "class EMG_CNN_Raw(nn.Module):\n", + " \"\"\"1-D CNN for multi-channel raw EMG.\n", + " Input: (batch, 4, win_len) — 4 EMG channels, temporal dim last.\n", + " \"\"\"\n", + " def __init__(self, n_classes=3, channels=(16,), kernel_size=7):\n", + " super().__init__()\n", + " layers = []\n", + " in_ch = 4\n", + " for out_ch in channels:\n", + " layers += [\n", + " nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size,\n", + " padding=kernel_size // 2),\n", + " nn.BatchNorm1d(out_ch),\n", + " nn.ReLU(),\n", + " ]\n", + " in_ch = out_ch\n", + " layers += [nn.AdaptiveAvgPool1d(1), nn.Flatten(),\n", + " nn.Linear(in_ch, n_classes)]\n", + " self.net = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " return self.net(x)\n", + "\n", + "\n", + "CNN_CONFIGS_RAW = {\n", + " 'CNN-S': dict(channels=(16,), kernel_size=7),\n", + " 'CNN-M': dict(channels=(16, 32), kernel_size=7),\n", + " 'CNN-L': dict(channels=(16, 32, 64), kernel_size=7),\n", + "}\n", + "\n", + "\n", + "def train_cnn_raw(X_train, y_train, X_test, y_test, cnn_cfg,\n", + " n_epochs=80, lr=1e-3):\n", + " \"\"\"Train raw-EMG CNN. X: (n, 4, win_len), y: (n,).\"\"\"\n", + " # Normalise per channel across training windows: shape (1, 4, 1)\n", + " mu = X_train.mean(axis=(0, 2), keepdims=True)\n", + " sigma = X_train.std(axis=(0, 2), keepdims=True) + 1e-8\n", + " Xtr = torch.tensor((X_train - mu) / sigma, dtype=torch.float32, device=device)\n", + " ytr = torch.tensor(y_train, dtype=torch.long, device=device)\n", + " Xte = torch.tensor((X_test - mu) / sigma, dtype=torch.float32, device=device)\n", + " yte = torch.tensor(y_test, dtype=torch.long, device=device)\n", + "\n", + " model = EMG_CNN_Raw(**cnn_cfg).to(device)\n", + " opt = torch.optim.Adam(model.parameters(), lr=lr)\n", + " loss_fn = nn.CrossEntropyLoss()\n", + " scaler = GradScaler('cuda', enabled=use_amp)\n", + "\n", + " model.train()\n", + " for _ in range(n_epochs):\n", + " opt.zero_grad(set_to_none=True)\n", + " with autocast(device_type=device.type, enabled=use_amp):\n", + " loss = loss_fn(model(Xtr), ytr)\n", + " scaler.scale(loss).backward()\n", + " scaler.step(opt)\n", + " scaler.update()\n", + "\n", + " model.eval()\n", + " with torch.no_grad(), autocast(device_type=device.type, enabled=use_amp):\n", + " preds = model(Xte).argmax(dim=1)\n", + " return (preds == yte).float().mean().item()\n", + "\n", + "\n", + "# ── Part B grid search ────────────────────────────────────────────────────────\n", + "results_cnn = {}\n", + "\n", + "for subj_idx, subj_runs in enumerate(filtered_data):\n", + " n_cv = min(5, len(subj_runs))\n", + " cv_runs = subj_runs[:n_cv]\n", + " results = []\n", + " n_total = len(win_sizes) * len(overlap_fracs)\n", + "\n", + " pbar = tqdm(total=n_total, desc=f'[Part B] Subject {subj_idx + 1}', unit='combo')\n", + "\n", + " for win_len in win_sizes:\n", + " for ovl in overlap_fracs:\n", + " pbar.set_postfix(win=win_len, ovl=f'{ovl:.0%}')\n", + " run_Xy = [extract_run_Xy(r, win_len, ovl, raw=True) for r in cv_runs]\n", + "\n", + " for cnn_name, cnn_cfg in CNN_CONFIGS_RAW.items():\n", + " accs = []\n", + " for test_r in range(n_cv):\n", + " train_idx = [i for i in range(n_cv) if i != test_r]\n", + " Xtr_parts = [run_Xy[i][0] for i in train_idx if len(run_Xy[i][0]) > 0]\n", + " ytr_parts = [run_Xy[i][1] for i in train_idx if len(run_Xy[i][1]) > 0]\n", + " if not Xtr_parts:\n", + " continue\n", + " X_train = np.vstack(Xtr_parts)\n", + " y_train = np.concatenate(ytr_parts)\n", + " X_test, y_test = run_Xy[test_r]\n", + " if len(X_test) == 0:\n", + " continue\n", + " try:\n", + " acc = train_cnn_raw(X_train, y_train, X_test, y_test, cnn_cfg)\n", + " accs.append(acc)\n", + " except Exception:\n", + " pass\n", + " if accs:\n", + " results.append((win_len, ovl, cnn_name,\n", + " np.mean(accs), np.std(accs)))\n", + " pbar.update(1)\n", + "\n", + " pbar.close()\n", + " if device.type == 'cuda':\n", + " torch.cuda.empty_cache()\n", + "\n", + " results_cnn[subj_idx] = results\n", + " best = max(results, key=lambda x: x[3])\n", + " print(f\"Subject {subj_idx + 1} – Best: win={best[0]}, overlap={best[1]:.0%}, \"\n", + " f\"{best[2]}, accuracy = {best[3]:.3f} ± {best[4]:.3f}\")\n", + "\n", + "\n", + "# ── Part B heatmaps ───────────────────────────────────────────────────────────\n", + "cnn_names_plot = list(CNN_CONFIGS_RAW.keys())\n", + "\n", + "fig, axes = plt.subplots(2, 3, figsize=(18, 10))\n", + "fig.suptitle('Part B – Raw EMG CNN CV Accuracy: Window × Overlap\\n'\n", + " '(CNN on raw 4-channel windows, run-wise CV)',\n", + " fontsize=13, fontweight='bold')\n", + "\n", + "for subj_idx in range(2):\n", + " for col, cnn_name in enumerate(cnn_names_plot):\n", + " ax = axes[subj_idx, col]\n", + " acc_grid = np.full((len(win_sizes), len(overlap_fracs)), np.nan)\n", + " for w, o, c, a, s in results_cnn[subj_idx]:\n", + " if c == cnn_name:\n", + " acc_grid[win_sizes.index(w), overlap_fracs.index(o)] = a\n", + "\n", + " im = ax.imshow(acc_grid, aspect='auto', cmap='RdYlGn',\n", + " vmin=0.3, vmax=1.0, origin='lower')\n", + " ax.set_xticks(range(len(overlap_fracs)))\n", + " ax.set_xticklabels([f'{o:.0%}' for o in overlap_fracs])\n", + " ax.set_yticks(range(len(win_sizes)))\n", + " ax.set_yticklabels(win_sizes)\n", + " ax.set_xlabel('Overlap')\n", + " ax.set_ylabel('Window Size (samples)')\n", + " ax.set_title(f'Subject {subj_idx + 1} – {cnn_name}', fontsize=10)\n", + "\n", + " for wi in range(len(win_sizes)):\n", + " for oi in range(len(overlap_fracs)):\n", + " v = acc_grid[wi, oi]\n", + " if not np.isnan(v):\n", + " ax.text(oi, wi, f'{v:.2f}', ha='center', va='center',\n", + " fontsize=5.5, color='black')\n", + " fig.colorbar(im, ax=ax, shrink=0.8, label='Accuracy')\n", + "\n", + "plt.tight_layout()\n", + "plt.savefig('2_4b_cv_heatmap_cnn_raw.png', bbox_inches='tight')\n", + "plt.show()" + ] } ], "metadata": { @@ -1017,4 +1448,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}