# EMG_Arm/emg_gui.py

"""
EMG Data Collection GUI
=======================
A modern GUI for the EMG data collection pipeline.
Features:
- Data collection with live EMG visualization and gesture prompts
- Session inspector with signal and feature plots
- Model training with progress and results
- Live prediction demo
- LDA visualization
Requirements:
pip install customtkinter matplotlib numpy
Run:
python emg_gui.py
"""
import customtkinter as ctk
import tkinter as tk
from tkinter import messagebox
import threading
import queue
import time
import numpy as np
from pathlib import Path
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from matplotlib.figure import Figure
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
# Import from the existing pipeline
from learning_data_collection import (
# Configuration
NUM_CHANNELS, SAMPLING_RATE_HZ, WINDOW_SIZE_MS, WINDOW_OVERLAP,
GESTURE_HOLD_SEC, REST_BETWEEN_SEC, REPS_PER_GESTURE, DATA_DIR, USER_ID,
# Classes
EMGSample, EMGWindow, EMGParser, Windower,
GestureAwareEMGStream, SimulatedEMGStream,
PromptScheduler, SessionStorage, SessionMetadata,
EMGFeatureExtractor, EMGClassifier, PredictionSmoother,
)
# Import real serial stream for ESP32 hardware
from serial_stream import RealSerialStream
import serial.tools.list_ports
# =============================================================================
# APPEARANCE SETTINGS
# =============================================================================
# Global CustomTkinter look-and-feel, applied once when this module is imported.
_APPEARANCE_MODE = "dark"  # one of "dark", "light", "system"
_COLOR_THEME = "blue"  # built-in themes: "blue", "green", "dark-blue"
ctk.set_appearance_mode(_APPEARANCE_MODE)
ctk.set_default_color_theme(_COLOR_THEME)
# Colors for gestures (names match ESP32 gesture definitions)
# Display color per gesture (names match the ESP32 gesture definitions).
GESTURE_COLORS = dict(
    rest="#6c757d",  # Gray
    open="#17a2b8",  # Cyan
    fist="#007bff",  # Blue
    hook_em="#fd7e14",  # Orange (Hook 'em Horns)
    thumbs_up="#28a745",  # Green
)

# Fallback for names that contain none of the keys above.
_UNKNOWN_GESTURE_COLOR = "#dc3545"  # Red


def get_gesture_color(gesture_name: str) -> str:
    """Return the display color for *gesture_name*.

    Matching is case-insensitive and by substring: the first key of
    ``GESTURE_COLORS`` (in definition order) that appears anywhere in the
    lower-cased name wins, so ``"FIST"`` and ``"fist_2"`` both get the fist
    color. Names containing no known key get red.
    """
    lowered = gesture_name.lower()
    return next(
        (color for key, color in GESTURE_COLORS.items() if key in lowered),
        _UNKNOWN_GESTURE_COLOR,
    )
# =============================================================================
# MAIN APPLICATION
# =============================================================================
class EMGApp(ctk.CTk):
    """Main application window: a navigation sidebar plus a stack of pages.

    All pages are constructed up front inside ``page_container``; only the
    page named by ``current_page`` is gridded at any time.
    """

    def __init__(self):
        super().__init__()
        self.title("EMG Data Collection Pipeline")
        self.geometry("1400x900")
        self.minsize(1200, 700)

        # Column 1 (the pages) absorbs extra width; the sidebar does not.
        self.grid_columnconfigure(1, weight=1)
        self.grid_rowconfigure(0, weight=1)

        # Sidebar reports clicks back through show_page().
        self.sidebar = Sidebar(self, self.show_page)
        self.sidebar.grid(row=0, column=0, sticky="nsew")

        # Container that hosts whichever page is currently visible.
        self.page_container = ctk.CTkFrame(self, fg_color="transparent")
        self.page_container.grid(row=0, column=1, sticky="nsew", padx=20, pady=20)
        self.page_container.grid_columnconfigure(0, weight=1)
        self.page_container.grid_rowconfigure(0, weight=1)

        # Build every page now; navigation only grids / un-grids them.
        self.pages = {}
        self.pages["collection"] = CollectionPage(self.page_container)
        self.pages["inspect"] = InspectPage(self.page_container)
        self.pages["training"] = TrainingPage(self.page_container)
        self.pages["prediction"] = PredictionPage(self.page_container)
        self.pages["visualization"] = VisualizationPage(self.page_container)

        # Show default page
        self.current_page = None
        self.show_page("collection")

        # Route the window-manager close button through on_close().
        self.protocol("WM_DELETE_WINDOW", self.on_close)

    def show_page(self, page_name: str):
        """Hide the current page (calling its on_hide) and show *page_name*.

        Raises:
            KeyError: if *page_name* is not a registered page. The check is
                done before anything is hidden, so the current page stays
                visible instead of leaving an empty window.
        """
        if page_name not in self.pages:
            raise KeyError(f"Unknown page: {page_name!r}")

        if self.current_page:
            current = self.pages[self.current_page]
            current.grid_forget()
            current.on_hide()

        new_page = self.pages[page_name]
        new_page.grid(row=0, column=0, sticky="nsew")
        new_page.on_show()
        self.current_page = page_name

        self.sidebar.set_active(page_name)

    def on_close(self):
        """Stop every page that exposes ``stop()``, then destroy the window.

        Stopping is best-effort: an exception from one page is reported and
        does not prevent the other pages from stopping or the window from
        closing.
        """
        try:
            for name, page in self.pages.items():
                stop = getattr(page, "stop", None)
                if not callable(stop):
                    continue
                try:
                    stop()
                except Exception as exc:  # keep shutting down regardless
                    print(f"[WARN] Failed to stop page '{name}' on close: {exc}")
        finally:
            self.destroy()
# =============================================================================
# SIDEBAR NAVIGATION
# =============================================================================
class Sidebar(ctk.CTkFrame):
    """Left navigation panel: title, one button per page, status footer.

    Args:
        parent: Widget that owns the sidebar.
        on_select_callback: Called with a page id (e.g. ``"training"``) when
            the matching navigation button is clicked.
    """

    # (page id, button caption) in display order.
    _NAV_ITEMS = (
        ("collection", "1. Collect Data"),
        ("inspect", "2. Inspect Sessions"),
        ("training", "3. Train Model"),
        ("prediction", "4. Live Prediction"),
        ("visualization", "5. Visualize LDA"),
    )

    def __init__(self, parent, on_select_callback):
        super().__init__(parent, width=200, corner_radius=0)
        self.on_select = on_select_callback

        # --- Title block ---
        self.logo_label = ctk.CTkLabel(
            self,
            text="EMG Pipeline",
            font=ctk.CTkFont(size=20, weight="bold"),
        )
        self.logo_label.pack(pady=(20, 10))
        self.subtitle = ctk.CTkLabel(
            self,
            text="Data Collection & ML",
            font=ctk.CTkFont(size=12),
            text_color="gray",
        )
        self.subtitle.pack(pady=(0, 20))

        # --- Navigation buttons ---
        self.nav_buttons = {}
        for page_id, caption in self._NAV_ITEMS:
            self.nav_buttons[page_id] = self._make_nav_button(page_id, caption)

        # An empty, expanding label pushes the status footer to the bottom.
        ctk.CTkLabel(self, text="").pack(expand=True)

        # --- Status footer ---
        self.status_frame = ctk.CTkFrame(self, fg_color="transparent")
        self.status_frame.pack(fill="x", padx=10, pady=10)
        self.session_count_label = ctk.CTkLabel(
            self.status_frame, text="Sessions: 0", font=ctk.CTkFont(size=12)
        )
        self.session_count_label.pack()
        self.model_status_label = ctk.CTkLabel(
            self.status_frame,
            text="Model: Not saved",
            font=ctk.CTkFont(size=12),
            text_color="gray",
        )
        self.model_status_label.pack()

        self.update_status()

    def _make_nav_button(self, page_id, caption):
        """Create and pack one navigation button that selects *page_id*."""
        button = ctk.CTkButton(
            self,
            text=caption,
            font=ctk.CTkFont(size=14),
            height=40,
            corner_radius=8,
            fg_color="transparent",
            text_color=("gray10", "gray90"),
            hover_color=("gray70", "gray30"),
            anchor="w",
            command=lambda: self.on_select(page_id),
        )
        button.pack(fill="x", padx=10, pady=5)
        return button

    def set_active(self, page_id: str):
        """Highlight the button for *page_id*; clear the highlight on all others."""
        for pid, button in self.nav_buttons.items():
            highlight = ("gray75", "gray25") if pid == page_id else "transparent"
            button.configure(fg_color=highlight)

    def update_status(self):
        """Refresh the session count and the saved-model indicator."""
        session_total = len(SessionStorage().list_sessions())
        self.session_count_label.configure(text=f"Sessions: {session_total}")

        if EMGClassifier.get_default_model_path().exists():
            text, color = "Model: Saved", "green"
        else:
            text, color = "Model: Not saved", "gray"
        self.model_status_label.configure(text=text, text_color=color)
# =============================================================================
# BASE PAGE CLASS
# =============================================================================
class BasePage(ctk.CTkFrame):
    """Common base for the application's pages.

    The page is one stretching column: row 0 holds the header (see
    ``create_header``) and row 1 is the expanding content area. Subclasses
    may override ``on_show`` / ``on_hide`` to react to navigation.
    """

    def __init__(self, parent):
        super().__init__(parent, fg_color="transparent")
        self.grid_columnconfigure(0, weight=1)
        self.grid_rowconfigure(1, weight=1)

    def on_show(self):
        """Hook run after the page has been gridded; does nothing by default."""

    def on_hide(self):
        """Hook run after the page has been un-gridded; does nothing by default."""

    def create_header(self, title: str, subtitle: str = ""):
        """Place a bold title, plus an optional gray subtitle, in grid row 0.

        Returns:
            The transparent frame that contains the header labels.
        """
        frame = ctk.CTkFrame(self, fg_color="transparent")
        frame.grid(row=0, column=0, sticky="ew", pady=(0, 20))

        ctk.CTkLabel(
            frame, text=title, font=ctk.CTkFont(size=28, weight="bold")
        ).pack(anchor="w")

        if subtitle:
            ctk.CTkLabel(
                frame,
                text=subtitle,
                font=ctk.CTkFont(size=14),
                text_color="gray",
            ).pack(anchor="w")

        return frame
# =============================================================================
# DATA COLLECTION PAGE
# =============================================================================
class CollectionPage(BasePage):
"""Data collection page with live EMG visualization."""
def __init__(self, parent):
    """Build the data-collection page: header, control panel and live plot.

    The collection/connection state below must be assigned before
    setup_controls() runs.
    """
    super().__init__(parent)
    self.create_header(
        "Data Collection",
        "Collect labeled EMG data with timed gesture prompts",
    )

    # --- Collection / connection state (keep before setup_controls) ---
    self.is_collecting = self.is_connected = False
    self.using_real_hardware = False
    self.stream = self.parser = self.windower = self.scheduler = None
    self.collected_windows, self.collected_labels = [], []
    self.collected_raw_samples = []  # raw samples, used for label alignment
    self.sample_buffer = []
    self.collection_thread = None
    self.data_queue = queue.Queue()  # worker thread -> UI messages

    # --- Layout: controls (weight 1) | live plot (weight 2) ---
    self.content = ctk.CTkFrame(self)
    self.content.grid(row=1, column=0, sticky="nsew")
    for column, weight in ((0, 1), (1, 2)):
        self.content.grid_columnconfigure(column, weight=weight)
    self.content.grid_rowconfigure(0, weight=1)

    self.controls_panel = ctk.CTkFrame(self.content)
    self.controls_panel.grid(row=0, column=0, sticky="nsew", padx=(0, 10), pady=0)
    self.setup_controls()

    self.plot_panel = ctk.CTkFrame(self.content)
    self.plot_panel.grid(row=0, column=1, sticky="nsew", padx=(10, 0), pady=0)
    self.setup_plot()
def setup_controls(self):
    """Build the left-hand control panel.

    From top to bottom it contains:
        - the user-ID entry;
        - the data-source radios and the serial port/connection sub-panel;
        - the gesture checkboxes;
        - the hold/reps sliders;
        - the start/save buttons;
        - the progress/status area.

    Slider readouts are initialised from the sliders themselves, so they
    always show the value start_collection() will read.
    """
    # --- User ID ---
    user_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    user_frame.pack(fill="x", padx=20, pady=(20, 10))
    ctk.CTkLabel(user_frame, text="User ID:", font=ctk.CTkFont(size=14)).pack(anchor="w")
    self.user_id_entry = ctk.CTkEntry(user_frame, placeholder_text="user_001")
    self.user_id_entry.pack(fill="x", pady=(5, 0))
    self.user_id_entry.insert(0, USER_ID)

    # --- Data source (simulated vs. real ESP32) ---
    source_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    source_frame.pack(fill="x", padx=20, pady=10)
    ctk.CTkLabel(source_frame, text="Data Source:", font=ctk.CTkFont(size=14)).pack(anchor="w")
    self.source_var = ctk.StringVar(value="simulated")
    radio_frame = ctk.CTkFrame(source_frame, fg_color="transparent")
    radio_frame.pack(fill="x", pady=(5, 0))
    self.sim_radio = ctk.CTkRadioButton(
        radio_frame, text="Simulated", variable=self.source_var, value="simulated",
        command=self._on_source_change
    )
    self.sim_radio.pack(side="left", padx=(0, 20))
    self.real_radio = ctk.CTkRadioButton(
        radio_frame, text="Real ESP32", variable=self.source_var, value="real",
        command=self._on_source_change
    )
    self.real_radio.pack(side="left")

    # Port selection + connection controls. port_frame is intentionally NOT
    # packed here; _on_source_change() shows it only in "Real ESP32" mode.
    self.port_frame = ctk.CTkFrame(source_frame, fg_color="transparent")
    port_select_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent")
    port_select_frame.pack(fill="x", pady=(5, 0))
    ctk.CTkLabel(port_select_frame, text="Port:").pack(side="left")
    self.port_var = ctk.StringVar(value="Auto-detect")
    self.port_dropdown = ctk.CTkOptionMenu(
        port_select_frame, variable=self.port_var,
        values=["Auto-detect"], width=150
    )
    self.port_dropdown.pack(side="left", padx=(10, 5))
    # Rescan-ports button; needs a visible glyph (an empty caption leaves
    # an unlabeled 30 px button).
    self.refresh_ports_btn = ctk.CTkButton(
        port_select_frame, text="↻", width=30,
        command=self._refresh_ports
    )
    self.refresh_ports_btn.pack(side="left")

    connect_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent")
    connect_frame.pack(fill="x", pady=(5, 0))
    self.connect_button = ctk.CTkButton(
        connect_frame, text="Connect",
        width=100, height=28,
        command=self._toggle_connection,
        state="disabled"  # enabled by _on_source_change() in "Real ESP32" mode
    )
    self.connect_button.pack(side="left", padx=(0, 10))
    self.connection_status = ctk.CTkLabel(
        connect_frame, text="● Disconnected",
        font=ctk.CTkFont(size=11), text_color="gray"
    )
    self.connection_status.pack(side="left")

    # --- Gesture selection (all checked by default) ---
    gesture_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    gesture_frame.pack(fill="x", padx=20, pady=10)
    ctk.CTkLabel(gesture_frame, text="Gestures:", font=ctk.CTkFont(size=14)).pack(anchor="w")
    self.gesture_vars = {}
    available_gestures = ["open", "fist", "hook_em", "thumbs_up"]
    for gesture in available_gestures:
        var = ctk.BooleanVar(value=True)
        cb = ctk.CTkCheckBox(gesture_frame, text=gesture.replace("_", " ").title(), variable=var)
        cb.pack(anchor="w", pady=2)
        self.gesture_vars[gesture] = var

    # --- Settings ---
    settings_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    settings_frame.pack(fill="x", padx=20, pady=10)
    ctk.CTkLabel(settings_frame, text="Settings:", font=ctk.CTkFont(size=14)).pack(anchor="w")

    # Hold duration (1-5 s in 0.5 s steps). The slider may clamp the
    # configured default to its range, so the readout is taken from the
    # slider rather than from GESTURE_HOLD_SEC.
    hold_frame = ctk.CTkFrame(settings_frame, fg_color="transparent")
    hold_frame.pack(fill="x", pady=5)
    ctk.CTkLabel(hold_frame, text="Hold (sec):").pack(side="left")
    self.hold_slider = ctk.CTkSlider(hold_frame, from_=1, to=5, number_of_steps=8)
    self.hold_slider.set(GESTURE_HOLD_SEC)
    self.hold_slider.pack(side="left", fill="x", expand=True, padx=10)
    self.hold_label = ctk.CTkLabel(hold_frame, text=f"{self.hold_slider.get():.1f}")
    self.hold_label.pack(side="right")
    self.hold_slider.configure(command=lambda v: self.hold_label.configure(text=f"{v:.1f}"))

    # Repetitions (1-5). Shown as int(), matching how start_collection()
    # reads the slider.
    reps_frame = ctk.CTkFrame(settings_frame, fg_color="transparent")
    reps_frame.pack(fill="x", pady=5)
    ctk.CTkLabel(reps_frame, text="Reps:").pack(side="left")
    self.reps_slider = ctk.CTkSlider(reps_frame, from_=1, to=5, number_of_steps=4)
    self.reps_slider.set(REPS_PER_GESTURE)
    self.reps_slider.pack(side="left", fill="x", expand=True, padx=10)
    self.reps_label = ctk.CTkLabel(reps_frame, text=f"{int(self.reps_slider.get())}")
    self.reps_label.pack(side="right")
    self.reps_slider.configure(command=lambda v: self.reps_label.configure(text=f"{int(v)}"))

    # --- Start / save buttons ---
    button_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    button_frame.pack(fill="x", padx=20, pady=20)
    self.start_button = ctk.CTkButton(
        button_frame, text="Start Collection",
        font=ctk.CTkFont(size=16, weight="bold"),
        height=50,
        command=self.toggle_collection
    )
    self.start_button.pack(fill="x", pady=5)
    self.save_button = ctk.CTkButton(
        button_frame, text="Save Session",
        font=ctk.CTkFont(size=14),
        height=40,
        state="disabled",
        command=self.save_session
    )
    self.save_button.pack(fill="x", pady=5)

    # --- Progress / status ---
    progress_frame = ctk.CTkFrame(self.controls_panel, fg_color="transparent")
    progress_frame.pack(fill="x", padx=20, pady=10)
    self.progress_bar = ctk.CTkProgressBar(progress_frame)
    self.progress_bar.pack(fill="x", pady=5)
    self.progress_bar.set(0)
    self.status_label = ctk.CTkLabel(
        progress_frame, text="Ready to collect",
        font=ctk.CTkFont(size=12)
    )
    self.status_label.pack()
    self.window_count_label = ctk.CTkLabel(
        progress_frame, text="Windows: 0",
        font=ctk.CTkFont(size=12),
        text_color="gray"
    )
    self.window_count_label.pack()
def setup_plot(self):
    """Build the right-hand panel: gesture prompt on top, live EMG below.

    The figure has one stacked subplot per channel, each showing a rolling
    window of recent samples. ``self.plot_data`` holds one buffer per
    channel and ``self.plot_lines`` the matching line artists.
    update_collection_ui() refreshes both.
    """
    history_len = 500  # samples visible per channel trace
    background = '#2b2b2b'
    trace_color = '#00ff88'

    # --- Gesture prompt ---
    self.prompt_frame = ctk.CTkFrame(self.plot_panel)
    self.prompt_frame.pack(fill="x", padx=20, pady=20)
    self.prompt_label = ctk.CTkLabel(
        self.prompt_frame,
        text="READY",
        font=ctk.CTkFont(size=48, weight="bold"),
        text_color="gray",
        width=500,  # fixed width so changing text doesn't resize the label
    )
    self.prompt_label.pack(pady=30)
    self.countdown_label = ctk.CTkLabel(
        self.prompt_frame, text="", font=ctk.CTkFont(size=18)
    )
    self.countdown_label.pack()

    # --- Live EMG figure: one stacked axis per channel ---
    self.fig = Figure(figsize=(8, 5), dpi=100, facecolor=background)
    self.axes = []
    for channel in range(NUM_CHANNELS):
        ax = self.fig.add_subplot(NUM_CHANNELS, 1, channel + 1)
        ax.set_facecolor(background)
        ax.tick_params(colors='white')
        ax.set_ylabel(f'Ch{channel}', color='white', fontsize=10)
        ax.set_xlim(0, history_len)
        # Fixed y-range in millivolts. NOTE(review): an earlier comment put
        # the ESP32 output at 0-3100 mV while the limit is 3300; confirm.
        ax.set_ylim(0, 3300)
        ax.grid(True, alpha=0.3)
        for spine in ax.spines.values():
            spine.set_color('white')
        self.axes.append(ax)
    self.axes[-1].set_xlabel('Samples', color='white')
    self.fig.tight_layout()

    self.canvas = FigureCanvasTkAgg(self.fig, master=self.plot_panel)
    self.canvas.draw()
    self.canvas.get_tk_widget().pack(fill="both", expand=True, padx=20, pady=(0, 20))

    # Rolling per-channel buffers (start flat at zero) and their line artists.
    self.plot_data = [np.zeros(history_len) for _ in range(NUM_CHANNELS)]
    self.plot_lines = [
        ax.plot(buffer, color=trace_color, linewidth=1)[0]
        for ax, buffer in zip(self.axes, self.plot_data)
    ]
def toggle_collection(self):
    """Start/Stop button handler: stop if collecting, otherwise start.

    First logs the current collection/connection state to stdout. A click
    arriving while ``self._toggling`` is still set is ignored as a
    double-click; the flag is cleared 100 ms after each accepted click.
    """
    print("\n" + "=" * 80)
    print("[DEBUG] toggle_collection() called")
    print("[DEBUG] Current state:")
    snapshot = [
        ("is_collecting", self.is_collecting),
        ("is_connected", self.is_connected),
        ("using_real_hardware", self.using_real_hardware),
        ("source_var", self.source_var.get()),
        ("stream exists", self.stream is not None),
    ]
    if self.stream and hasattr(self.stream, 'state'):
        snapshot.append(("stream.state", self.stream.state))
    snapshot.append(("button text", self.start_button.cget('text')))
    snapshot.append(("button state", self.start_button.cget('state')))
    for field, value in snapshot:
        print(f" - {field}: {value}")

    # Debounce: drop clicks that arrive while a toggle is still settling.
    if getattr(self, '_toggling', False):
        print("[DEBUG] BLOCKED: Already toggling (debounce)")
        print("=" * 80 + "\n")
        return

    self._toggling = True
    stopping = self.is_collecting
    try:
        if stopping:
            print("[DEBUG] Branch: STOPPING collection")
        else:
            print("[DEBUG] Branch: STARTING collection")
        (self.stop_collection if stopping else self.start_collection)()
    finally:
        # Re-arm shortly afterwards so an immediate second click is ignored.
        self.after(100, lambda: setattr(self, '_toggling', False))
def start_collection(self):
    """Begin a new timed collection run.

    Steps, in order:
        1. Wait for any previous worker thread to exit.
        2. Drop stale queue messages.
        3. Validate the gesture selection.
        4. Start the data stream. A real ESP32 must already be connected;
           for simulated mode a fresh GestureAwareEMGStream is created.
        5. Build a new parser, windower and scheduler.
        6. Reset the collected data and lock the source controls.
        7. Launch collection_loop() on a daemon thread, plus the
           update_collection_ui() polling loop.

    If a precondition fails, it shows a dialog and returns early without
    starting.
    """
    print("[DEBUG] start_collection() entered")

    # stop_collection() only clears is_collecting, so the previous worker
    # may still be mid-iteration. Its loop re-checks is_collecting, so it
    # would resume once the flag is set again below and run next to the new
    # worker. Its trailing ('done', None) could also arrive after the drain
    # below and end the new session immediately. Wait for it to exit first.
    previous = self.collection_thread
    if previous is not None and previous.is_alive():
        print("[DEBUG] Waiting for previous collection thread to exit...")
        previous.join(timeout=2.0)
        if previous.is_alive():
            print("[DEBUG] EXIT: Previous collection thread still running")
            messagebox.showwarning(
                "Busy",
                "The previous collection is still shutting down.\n"
                "Please try again in a moment.",
            )
            return

    # Drain stale messages from earlier sessions (e.g. an old 'done') so
    # they cannot stop the new session.
    stale_count = 0
    try:
        while True:
            msg = self.data_queue.get_nowait()
            stale_count += 1
            print(f"[DEBUG] Drained stale message: {msg[0]}")
    except queue.Empty:
        pass
    if stale_count > 0:
        print(f"[DEBUG] Cleared {stale_count} stale message(s) from queue")

    gestures = [g for g, var in self.gesture_vars.items() if var.get()]
    print(f"[DEBUG] Selected gestures: {gestures}")
    if not gestures:
        print("[DEBUG] EXIT: No gestures selected")
        messagebox.showwarning("No Gestures", "Please select at least one gesture.")
        return

    # Pick the data source and get a running stream.
    self.using_real_hardware = (self.source_var.get() == "real")
    print(f"[DEBUG] using_real_hardware set to: {self.using_real_hardware}")
    if self.using_real_hardware:
        print("[DEBUG] Real hardware path")
        print(f"[DEBUG] Checking connection: is_connected={self.is_connected}, stream exists={self.stream is not None}")
        if not self.is_connected or not self.stream:
            print("[DEBUG] EXIT: Not connected to device")
            messagebox.showerror("Not Connected", "Please connect to the ESP32 first.")
            return
        print("[DEBUG] Calling stream.start()...")
        try:
            self.stream.start()
            print("[DEBUG] stream.start() succeeded")
        except Exception as e:
            print(f"[DEBUG] stream.start() FAILED: {e}")
            # Try to return the device to its connected-but-idle state.
            if self.stream:
                try:
                    print("[DEBUG] Attempting stream.stop() to reset state...")
                    self.stream.stop()
                    print("[DEBUG] stream.stop() succeeded")
                except Exception as e2:
                    print(f"[DEBUG] stream.stop() FAILED: {e2}")
            messagebox.showerror("Start Error", f"Failed to start streaming:\n{e}")
            print("[DEBUG] EXIT: Stream start error")
            return
    else:
        print("[DEBUG] Simulated stream path")
        # Gesture-aware simulator so prompts change the synthetic signal.
        self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
        print("[DEBUG] Created GestureAwareEMGStream")
        self.stream.start()
        print("[DEBUG] Started simulated stream")

    # Fresh per-session processing pipeline.
    self.parser = EMGParser(num_channels=NUM_CHANNELS)
    self.windower = Windower(
        window_size_ms=WINDOW_SIZE_MS,
        sample_rate=SAMPLING_RATE_HZ,
        overlap=WINDOW_OVERLAP
    )
    self.scheduler = PromptScheduler(
        gestures=gestures,
        hold_sec=self.hold_slider.get(),
        rest_sec=REST_BETWEEN_SEC,
        reps=int(self.reps_slider.get())
    )

    # Reset collected data (raw samples are kept for label alignment).
    self.collected_windows = []
    self.collected_labels = []
    self.collected_raw_samples = []
    self.sample_buffer = []
    print("[DEBUG] Reset collection state")

    self.is_collecting = True
    print("[DEBUG] Set is_collecting = True")

    self.start_button.configure(text="Stop Collection", fg_color="red")
    self.save_button.configure(state="disabled")
    self.status_label.configure(text="Starting...")
    print("[DEBUG] Updated UI - button now shows 'Stop Collection'")

    # Lock source selection / connection while collecting.
    self.sim_radio.configure(state="disabled")
    self.real_radio.configure(state="disabled")
    if self.using_real_hardware:
        self.connect_button.configure(state="disabled")
    print("[DEBUG] Disabled source/connection controls")

    self.collection_thread = threading.Thread(target=self.collection_loop, daemon=True)
    self.collection_thread.start()
    print("[DEBUG] Started collection thread")

    self.update_collection_ui()
    print("[DEBUG] start_collection() completed successfully")
    print("="*80 + "\n")
def stop_collection(self):
    """Stop the current collection and return the UI to its idle state.

    Clears ``is_collecting`` (the worker loop checks it every iteration)
    and stops the stream. For real hardware the serial connection stays
    open and only streaming stops; a simulated stream is always discarded,
    even if its ``stop()`` fails.

    It then drains queued worker messages and resets the controls. This
    includes the status-label color, which may still be red/orange from an
    earlier error or warning.
    """
    print("[DEBUG] stop_collection() called")
    print(f"[DEBUG] Was collecting: {self.is_collecting}")
    self.is_collecting = False

    # Stream cleanup is best-effort: the stream may already be in an error state.
    if self.stream:
        if self.using_real_hardware:
            print("[DEBUG] Calling stream.stop() for real hardware")
            try:
                # Stop streaming but keep the connection (back to CONNECTED).
                self.stream.stop()
            except Exception as e:
                print(f"[DEBUG] Exception during stream cleanup: {e}")
            else:
                print("[DEBUG] stream.stop() completed")
        else:
            print("[DEBUG] Stopping simulated stream")
            try:
                self.stream.stop()
            except Exception as e:
                print(f"[DEBUG] Exception during stream cleanup: {e}")
            else:
                print("[DEBUG] Simulated stream stopped and cleared")
            finally:
                # A new simulated stream is created on every start.
                self.stream = None

    # Discard still-queued worker messages so they can't leak into the next run.
    try:
        while True:
            self.data_queue.get_nowait()
    except queue.Empty:
        pass

    self.start_button.configure(
        text="Start Collection",
        fg_color=ctk.ThemeManager.theme["CTkButton"]["fg_color"],
    )
    self.status_label.configure(
        text=f"Collected {len(self.collected_windows)} windows",
        # Undo the red/orange set by 'error' / 'warning' messages.
        text_color=ctk.ThemeManager.theme["CTkLabel"]["text_color"],
    )
    self.prompt_label.configure(text="DONE", text_color="green")
    self.countdown_label.configure(text="")
    print("[DEBUG] UI reset - button shows 'Start Collection'")

    # Re-enable source selection and (for hardware) the connect button.
    self.sim_radio.configure(state="normal")
    self.real_radio.configure(state="normal")
    if self.using_real_hardware:
        self.connect_button.configure(state="normal")
        # Still connected, just no longer streaming.
        if self.is_connected:
            if self.stream and self.stream.device_info:
                device_name = self.stream.device_info.get('device', 'ESP32')
            else:
                device_name = 'ESP32'
            self._update_connection_status("green", f"Connected ({device_name})")

    if self.collected_windows:
        self.save_button.configure(state="normal")
    print("[DEBUG] stop_collection() completed")
    print("="*80 + "\n")
def collection_loop(self):
    """Worker-thread body: read samples, build labeled windows, post UI updates.

    Runs until ``is_collecting`` is cleared or the scheduler reports the
    session complete. It does not touch widgets. Everything the UI needs is
    posted to ``self.data_queue`` as ``(msg_type, payload)`` tuples, and
    update_collection_ui() consumes them. A final ``('done', None)`` is
    always posted on exit.
    """
    # The stream was already started by start_collection().
    if self.using_real_hardware:
        self.data_queue.put(('connection_status', ('green', 'Streaming')))
    self.scheduler.start_session()

    last_ui_update = time.perf_counter()
    last_plot_update = time.perf_counter()
    last_data_time = time.perf_counter()  # for the real-hardware no-data timeout
    sample_batch = []  # channel lists accumulated between plot updates
    timeout_warning_sent = False

    while self.is_collecting and not self.scheduler.is_session_complete():
        prompt = self.scheduler.get_current_prompt()
        current_time = time.perf_counter()

        if prompt:
            # Only the simulated GestureAwareEMGStream exposes set_gesture().
            if hasattr(self.stream, 'set_gesture'):
                self.stream.set_gesture(prompt.gesture_name)

            # Prompt/progress updates are throttled to one per 200 ms, so the
            # countdown and next-gesture lookup are only computed when due.
            if current_time - last_ui_update > 0.2:
                elapsed_in_session = self.scheduler.get_elapsed_time()
                elapsed_in_gesture = elapsed_in_session - prompt.start_time
                time_remaining_in_gesture = prompt.duration_sec - elapsed_in_gesture

                # Upcoming non-rest gesture, displayed during rest periods.
                prompts = self.scheduler.schedule.prompts
                next_idx = prompts.index(prompt) + 1
                next_gesture = None
                if next_idx < len(prompts) and prompts[next_idx].gesture_name != "rest":
                    next_gesture = prompts[next_idx].gesture_name

                self.data_queue.put(('prompt_with_countdown', (
                    prompt.gesture_name,
                    time_remaining_in_gesture,
                    next_gesture,
                )))
                progress = elapsed_in_session / self.scheduler.schedule.total_duration
                self.data_queue.put(('progress', progress))
                last_ui_update = current_time

        try:
            line = self.stream.readline()
        except Exception as e:
            # Only report an error if the stop wasn't intentional.
            if self.is_collecting:
                self.data_queue.put(('error', f"Serial read error: {e}"))
            break

        if line:
            last_data_time = current_time
            if timeout_warning_sent:
                # Data is flowing again after a "No data" warning: restore
                # the indicator instead of leaving it orange.
                self.data_queue.put(('connection_status', ('green', 'Streaming')))
                timeout_warning_sent = False

            sample = self.parser.parse_line(line)
            if sample:
                # Raw samples are kept for label alignment at save time.
                self.collected_raw_samples.append(sample)

                # Plot data is batched and posted at most every 50 ms (20 FPS).
                sample_batch.append(sample.channels)
                if current_time - last_plot_update > 0.05:
                    if sample_batch:
                        self.data_queue.put(('samples_batch', sample_batch))
                        sample_batch = []
                    last_plot_update = current_time

                window = self.windower.add_sample(sample)
                if window:
                    label = self.scheduler.get_label_for_time(window.start_time)
                    self.collected_windows.append(window)
                    self.collected_labels.append(label)
                    self.data_queue.put(('window_count', len(self.collected_windows)))
        elif self.using_real_hardware and (current_time - last_data_time > 3.0):
            # No data for 3 s from the ESP32: warn once until data resumes.
            if not timeout_warning_sent:
                self.data_queue.put(('warning', 'No data received - check ESP32 connection'))
                self.data_queue.put(('connection_status', ('orange', 'No data')))
                timeout_warning_sent = True

    self.data_queue.put(('done', None))
def update_collection_ui(self):
    """Apply queued worker messages to the widgets (runs on the Tk thread).

    Handles at most 10 messages per call so a backlog cannot starve the
    event loop. Redraws the plot at most once per call. Re-schedules itself
    every 50 ms while collection is active.
    """
    needs_redraw = False
    try:
        for _ in range(10):
            msg_type, data = self.data_queue.get_nowait()

            if msg_type == 'prompt_with_countdown':
                gesture_name, time_remaining, next_gesture = data
                countdown_int = int(np.ceil(time_remaining))
                if gesture_name == "rest" and next_gesture:
                    # During rest, announce the upcoming gesture.
                    next_display = next_gesture.upper().replace("_", " ")
                    color = get_gesture_color(next_gesture)
                    display_text = f"{next_display} in {countdown_int}"
                else:
                    # During a gesture, show the one being held.
                    gesture_display = gesture_name.upper().replace("_", " ")
                    color = get_gesture_color(gesture_name)
                    if countdown_int > 0:
                        display_text = f"{gesture_display} {countdown_int}"
                    else:
                        display_text = gesture_display
                self.prompt_label.configure(text=display_text, text_color=color)

            elif msg_type == 'progress':
                self.progress_bar.set(data)
                # Guard against progress > 1.0, which would show negative time.
                remaining = max(0.0, self.scheduler.schedule.total_duration * (1 - data))
                self.countdown_label.configure(text=f"Total: {remaining:.1f}s remaining")

            elif msg_type == 'samples_batch':
                # Shift each channel's rolling buffer left by the batch length
                # in one step. The result is identical to one np.roll per
                # sample, but cheaper on the UI thread.
                for ch, buffer in enumerate(self.plot_data):
                    values = [sample[ch] for sample in data if ch < len(sample)]
                    capacity = len(buffer)
                    if not values or capacity == 0:
                        continue
                    tail = np.asarray(values[-capacity:], dtype=buffer.dtype)
                    self.plot_data[ch] = np.concatenate((buffer[len(tail):], tail))
                for line, ydata in zip(self.plot_lines, self.plot_data):
                    line.set_ydata(ydata)
                needs_redraw = True

            elif msg_type == 'window_count':
                self.window_count_label.configure(text=f"Windows: {data}")

            elif msg_type == 'error':
                # Fatal for this run: report it and stop collecting.
                self.status_label.configure(text=f"Error: {data}", text_color="red")
                if self.using_real_hardware:
                    self._update_connection_status("red", "Disconnected")
                messagebox.showerror("Collection Error", data)
                self.stop_collection()
                return

            elif msg_type == 'warning':
                # Non-fatal: show it and keep collecting.
                self.status_label.configure(text=f"Warning: {data}", text_color="orange")

            elif msg_type == 'connection_status':
                color, text = data
                self._update_connection_status(color, text)

            elif msg_type == 'done':
                self.stop_collection()
                return
    except queue.Empty:
        pass

    if needs_redraw:
        self.canvas.draw_idle()

    if self.is_collecting:
        self.after(50, self.update_collection_ui)
def save_session(self):
    """Persist the collected windows as a new session, then reset the page.

    Metadata uses:
        - the entered user ID, falling back to USER_ID when the entry is
          blank or whitespace;
        - the currently checked gestures.

    The raw samples and the scheduler's session start time are also passed
    to storage for label alignment. If storage raises, an error dialog is
    shown and the collected data is kept, so the save can be retried.
    """
    if not self.collected_windows:
        messagebox.showwarning("No Data", "No data to save!")
        return

    user_id = self.user_id_entry.get().strip() or USER_ID
    # NOTE(review): read from the checkboxes at save time, which may differ
    # from the gestures that were selected when this session was recorded.
    gestures = [g for g, var in self.gesture_vars.items() if var.get()]

    # Session start time is used by storage for label alignment.
    session_start_time = None
    if self.scheduler and self.scheduler.session_start_time:
        session_start_time = self.scheduler.session_start_time

    try:
        storage = SessionStorage()
        session_id = storage.generate_session_id(user_id)
        metadata = SessionMetadata(
            user_id=user_id,
            session_id=session_id,
            timestamp=datetime.now().isoformat(),
            sampling_rate=SAMPLING_RATE_HZ,
            window_size_ms=WINDOW_SIZE_MS,
            num_channels=NUM_CHANNELS,
            gestures=gestures,
            notes=""
        )
        storage.save_session(
            windows=self.collected_windows,
            labels=self.collected_labels,
            metadata=metadata,
            raw_samples=self.collected_raw_samples or None,
            session_start_time=session_start_time
        )
    except Exception as exc:
        # Keep the collected data so the user can retry the save.
        messagebox.showerror("Save Error", f"Failed to save session:\n{exc}")
        return

    alignment_msg = ""
    if session_start_time and self.collected_raw_samples:
        alignment_msg = "\n\nLabel alignment: enabled"
    messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}{alignment_msg}")

    # Refresh the sidebar's session count.
    app = self.winfo_toplevel()
    if isinstance(app, EMGApp):
        app.sidebar.update_status()

    # Reset for the next collection.
    self.collected_windows = []
    self.collected_labels = []
    self.collected_raw_samples = []
    self.save_button.configure(state="disabled")
    self.status_label.configure(text="Ready to collect")
    self.window_count_label.configure(text="Windows: 0")
    self.progress_bar.set(0)
    self.prompt_label.configure(text="READY", text_color="gray")
def _on_source_change(self):
"""Show/hide port selection based on data source."""
print("\n" + "="*80)
print("[DEBUG] _on_source_change() called")
print(f"[DEBUG] Before cleanup:")
print(f" - is_connected: {self.is_connected}")
print(f" - is_collecting: {self.is_collecting}")
print(f" - stream exists: {self.stream is not None}")
print(f" - source_var changing to: {self.source_var.get()}")
# Clean up any existing connection/stream when switching modes
if self.is_connected and self.stream:
print("[DEBUG] Disconnecting existing stream...")
try:
self.stream.disconnect()
print("[DEBUG] Stream disconnected successfully")
except Exception as e:
print(f"[DEBUG] Stream disconnect failed: {e}")
self.is_connected = False
self.stream = None
print("[DEBUG] Cleared is_connected and stream")
print(f"[DEBUG] NOTE: is_collecting remains: {self.is_collecting}")
if self.source_var.get() == "real":
print("[DEBUG] Configuring for REAL hardware mode")
self.port_frame.pack(fill="x", pady=(5, 0))
self._refresh_ports()
self.connect_button.configure(text="Connect", state="normal")
self.start_button.configure(state="disabled") # Must connect first
self._update_connection_status("gray", "Disconnected")
print("[DEBUG] Start button DISABLED (must connect first)")
else:
print("[DEBUG] Configuring for SIMULATED mode")
self.port_frame.pack_forget()
self._update_connection_status("gray", "Not using hardware")
self.connect_button.configure(state="disabled")
self.start_button.configure(state="normal") # Simulated mode doesn't need connect
print("[DEBUG] Start button ENABLED (no connection needed)")
print("="*80 + "\n")
def _refresh_ports(self):
    """Rescan serial ports, refresh the dropdown and report how many were found.

    If the currently selected port no longer exists, the selection falls
    back to "Auto-detect" so a stale device path is not used to connect.
    """
    ports = serial.tools.list_ports.comports()
    port_names = ["Auto-detect"] + [p.device for p in ports]

    self.port_dropdown.configure(values=port_names)
    if self.port_var.get() not in port_names:
        self.port_var.set("Auto-detect")

    if ports:
        self._update_connection_status("orange", f"Found {len(ports)} port(s)")
    else:
        self._update_connection_status("red", "No ports found")
def _get_serial_port(self):
    """Return the port chosen in the dropdown, or None for "Auto-detect"."""
    selected = self.port_var.get()
    if selected == "Auto-detect":
        return None
    return selected
def _update_connection_status(self, color: str, text: str):
    """Show *text* in the connection status label, drawn in *color*."""
    label_text = "{}".format(text)
    self.connection_status.configure(text=label_text, text_color=color)
def _toggle_connection(self):
    """Flip the hardware link: disconnect when connected, otherwise connect."""
    action = self._disconnect_device if self.is_connected else self._connect_device
    action()
def _connect_device(self):
    """Open the selected serial port and perform the ESP32 handshake.

    On success the page is marked connected, the button flips to
    "Disconnect" and Start is enabled. On failure an error dialog explains
    the likely cause, the status shows "Timeout"/"Failed", and any
    half-opened stream is released so the port is not left busy.
    """
    port = self._get_serial_port()
    try:
        # Show progress before the blocking handshake freezes the event loop.
        self._update_connection_status("orange", "Connecting...")
        self.connect_button.configure(state="disabled")
        self.update()  # Force UI update

        self.stream = RealSerialStream(port=port)
        device_info = self.stream.connect(timeout=5.0)

        self.is_connected = True
        self._update_connection_status("green", f"Connected ({device_info.get('device', 'ESP32')})")
        self.connect_button.configure(text="Disconnect", state="normal")
        self.start_button.configure(state="normal")
    except TimeoutError:
        messagebox.showerror(
            "Connection Timeout",
            "Device did not respond within 5 seconds.\n\n"
            "Check that:\n"
            "• ESP32 is powered on\n"
            "• Correct firmware is flashed\n"
            "• USB cable is properly connected"
        )
        self._update_connection_status("red", "Timeout")
        self.connect_button.configure(state="normal")
        self._discard_failed_stream()
    except Exception as e:
        details = str(e)
        error_msg = f"Failed to connect:\n{e}"
        if "Permission denied" in details or "Resource busy" in details:
            error_msg += "\n\nThe port may still be in use. Wait a few seconds and try again."
        elif (isinstance(e, FileNotFoundError)
              or "FileNotFoundError" in details
              or "No such file" in details):
            # pyserial typically reports open failures as SerialException, so
            # the underlying OS error usually only survives in the message text.
            error_msg += "\n\nPort not found. Try refreshing the port list."
        messagebox.showerror("Connection Error", error_msg)
        self._update_connection_status("red", "Failed")
        self.connect_button.configure(state="normal")
        self._discard_failed_stream()

def _discard_failed_stream(self):
    """Best-effort release of a stream whose connect attempt failed."""
    if self.stream:
        try:
            self.stream.disconnect()
        except Exception:
            pass  # The port may never have opened; nothing left to release.
    self.stream = None
def _disconnect_device(self):
    """Close the ESP32 connection and return the controls to the disconnected state.

    The page always ends up marked disconnected, with Start disabled, even
    if closing the port raises; in that case a warning dialog is shown first.
    """
    try:
        if self.stream:
            self.stream.disconnect()
            # Give OS time to release the port
            time.sleep(0.5)
    except Exception as e:
        messagebox.showwarning("Disconnect Warning", f"Error during disconnect: {e}")
    finally:
        self.is_connected = False
        self.stream = None
        self._update_connection_status("gray", "Disconnected")
        self.connect_button.configure(text="Connect")
        self.start_button.configure(state="disabled")
def on_hide(self):
    """Page-leave hook: end a collection that is still running."""
    if not self.is_collecting:
        return
    self.stop_collection()
def stop(self):
    """Stop collection and the stream.

    A stream that is already in an error state must not abort the rest of
    the shutdown, so a failure from ``stream.stop()`` is reported and not
    re-raised.
    """
    self.is_collecting = False
    if self.stream:
        try:
            self.stream.stop()
        except Exception as e:
            print(f"[WARN] Failed to stop stream during shutdown: {e}")
# =============================================================================
# INSPECT SESSIONS PAGE
# =============================================================================
class InspectPage(BasePage):
    """Browse saved sessions and preview their raw signals.

    The left column lists every stored session as a button. Choosing one
    shows its metadata and a 2x2 grid of mean-centred traces for up to the
    first four channels in the right column.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.create_header("Inspect Sessions", "View saved session data and features")

        # Two-column body: session list (weight 1) beside details (weight 3).
        self.content = ctk.CTkFrame(self)
        self.content.grid(row=1, column=0, sticky="nsew")
        for col, weight in ((0, 1), (1, 3)):
            self.content.grid_columnconfigure(col, weight=weight)
        self.content.grid_rowconfigure(0, weight=1)

        self._build_list_panel()
        self._build_details_panel()

        # The signal preview figure/canvas are created on demand by show_session().
        self.fig = None
        self.canvas = None
        self.session_buttons = []

    def _build_list_panel(self):
        """Create the left panel: title, scrollable session list, Refresh button."""
        panel = ctk.CTkFrame(self.content)
        panel.grid(row=0, column=0, sticky="nsew", padx=(0, 10))
        self.list_panel = panel

        heading_font = ctk.CTkFont(size=16, weight="bold")
        ctk.CTkLabel(panel, text="Sessions", font=heading_font).pack(pady=10)

        self.session_listbox = ctk.CTkScrollableFrame(panel)
        self.session_listbox.pack(fill="both", expand=True, padx=10, pady=10)

        self.refresh_button = ctk.CTkButton(panel, text="Refresh", command=self.load_sessions)
        self.refresh_button.pack(pady=10)

    def _build_details_panel(self):
        """Create the right panel that holds the session summary label."""
        panel = ctk.CTkFrame(self.content)
        panel.grid(row=0, column=1, sticky="nsew", padx=(10, 0))
        self.details_panel = panel

        self.details_label = ctk.CTkLabel(
            panel,
            text="Select a session to view details",
            font=ctk.CTkFont(size=14),
        )
        self.details_label.pack(pady=20)

    def on_show(self):
        """Refresh the session list every time the page becomes visible."""
        self.load_sessions()

    def load_sessions(self):
        """Rebuild the session list from the sessions currently in storage."""
        # Remove the widgets from the previous listing.
        for widget in self.session_buttons:
            widget.destroy()
        self.session_buttons = []

        storage = SessionStorage()
        session_ids = storage.list_sessions()

        if not session_ids:
            placeholder = ctk.CTkLabel(self.session_listbox, text="No sessions found")
            placeholder.pack(pady=10)
            self.session_buttons.append(placeholder)
            return

        for sid in session_ids:
            num_windows = storage.get_session_info(sid)['num_windows']
            button = ctk.CTkButton(
                self.session_listbox,
                text=f"{sid}\n{num_windows} windows",
                font=ctk.CTkFont(size=12),
                height=60,
                anchor="w",
                # Bind the id now; a bare closure would see only the last id.
                command=lambda chosen=sid: self.show_session(chosen),
            )
            button.pack(fill="x", pady=5)
            self.session_buttons.append(button)

    @staticmethod
    def _plot_channel(ax, X, ch):
        """Draw the first 2000 mean-centred samples of channel *ch* on *ax*."""
        ax.set_facecolor('#2b2b2b')
        ax.tick_params(colors='white')
        trace = X[:, :, ch].flatten()
        centred = trace - trace.mean()
        ax.plot(centred[:2000], color='#00ff88', linewidth=0.5)
        ax.set_title(f'Channel {ch}', color='white', fontsize=10)
        ax.set_ylabel('Amplitude', color='white', fontsize=8)
        ax.grid(True, alpha=0.3)
        for spine in ax.spines.values():
            spine.set_color('white')

    def show_session(self, session_id: str):
        """Show *session_id*'s metadata and a preview of up to four channels."""
        storage = SessionStorage()
        try:
            X, y, label_names = storage.load_for_training(session_id)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load session: {e}")
            return

        # Drop the previous preview before drawing a new one.
        if self.canvas:
            self.canvas.get_tk_widget().destroy()

        info = storage.get_session_info(session_id)
        info_text = f"""Session: {session_id}
User: {info['user_id']}
Time: {info['timestamp']}
Windows: {X.shape[0]}
Samples/window: {X.shape[1]}
Channels: {X.shape[2]}
Gestures: {', '.join(label_names)}"""
        self.details_label.configure(text=info_text)

        self.fig = Figure(figsize=(10, 6), dpi=100, facecolor='#2b2b2b')
        channel_count = min(X.shape[2], 4)
        for ch in range(channel_count):
            self._plot_channel(self.fig.add_subplot(2, 2, ch + 1), X, ch)
        self.fig.tight_layout()

        self.canvas = FigureCanvasTkAgg(self.fig, master=self.details_panel)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill="both", expand=True, padx=20, pady=20)
# =============================================================================
# TRAINING PAGE
# =============================================================================
class TrainingPage(BasePage):
    """Page for training the classifier on every collected session.

    Training runs on a background thread. Progress and results are pushed
    back to the widgets through ``self.after(0, ...)`` callbacks.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.create_header(
            "Train Classifier",
            "Train LDA model on all collected sessions"
        )

        # Content
        self.content = ctk.CTkFrame(self)
        self.content.grid(row=1, column=0, sticky="nsew")
        self.content.grid_columnconfigure(0, weight=1)

        # Sessions info
        self.info_frame = ctk.CTkFrame(self.content)
        self.info_frame.pack(fill="x", padx=20, pady=20)
        self.sessions_label = ctk.CTkLabel(
            self.info_frame,
            text="Loading sessions...",
            font=ctk.CTkFont(size=14)
        )
        self.sessions_label.pack(pady=10)

        # Train button
        self.train_button = ctk.CTkButton(
            self.content,
            text="Train on All Sessions",
            font=ctk.CTkFont(size=18, weight="bold"),
            height=60,
            command=self.train_model
        )
        self.train_button.pack(pady=20)

        # Progress
        self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
        self.progress_bar.pack(pady=10)
        self.progress_bar.set(0)
        self.status_label = ctk.CTkLabel(self.content, text="", font=ctk.CTkFont(size=12))
        self.status_label.pack()

        # Results
        self.results_frame = ctk.CTkFrame(self.content)
        self.results_frame.pack(fill="both", expand=True, padx=20, pady=20)
        self.results_text = ctk.CTkTextbox(self.results_frame, font=ctk.CTkFont(family="Courier", size=12))
        self.results_text.pack(fill="both", expand=True, padx=10, pady=10)

        self.classifier = None
        # True while a training run is in progress. update_session_info() runs
        # every time the page is shown and must not re-enable the button
        # mid-run, or a second run could start and also save to the same path.
        self.is_training = False

    def on_show(self):
        """Update session info when shown."""
        self.update_session_info()

    def update_session_info(self):
        """List the stored sessions and their window counts.

        The Train button is enabled only when sessions exist and no training
        run is in progress.
        """
        storage = SessionStorage()
        sessions = storage.list_sessions()

        if not sessions:
            self.sessions_label.configure(text="No sessions found. Collect data first!")
            self.train_button.configure(state="disabled")
            return

        total_windows = 0
        info_lines = [f"Found {len(sessions)} session(s):\n"]
        for session_id in sessions:
            info = storage.get_session_info(session_id)
            info_lines.append(f" - {session_id}: {info['num_windows']} windows")
            total_windows += info['num_windows']
        info_lines.append(f"\nTotal: {total_windows} windows")

        self.sessions_label.configure(text="\n".join(info_lines))
        self.train_button.configure(state="disabled" if self.is_training else "normal")

    def train_model(self):
        """Start a background training run; ignored if one is already running."""
        if self.is_training:
            return
        self.is_training = True
        self.train_button.configure(state="disabled")
        self.results_text.delete("1.0", "end")
        self.progress_bar.set(0)
        self.status_label.configure(text="Loading data...")

        # Run in thread to not block UI
        thread = threading.Thread(target=self._train_thread, daemon=True)
        thread.start()

    def _train_thread(self):
        """Load every session, then train, cross-validate and save the model.

        Runs on a worker thread; every widget update is scheduled through
        ``self.after(0, ...)``.
        """
        try:
            storage = SessionStorage()

            # Load data
            self.after(0, lambda: self.status_label.configure(text="Loading all sessions..."))
            self.after(0, lambda: self.progress_bar.set(0.2))
            X, y, label_names, loaded_sessions = storage.load_all_for_training()
            self.after(0, lambda: self._log(f"Loaded {X.shape[0]} windows from {len(loaded_sessions)} sessions"))
            self.after(0, lambda: self._log(f"Labels: {label_names}\n"))

            # Train
            self.after(0, lambda: self.status_label.configure(text="Training classifier..."))
            self.after(0, lambda: self.progress_bar.set(0.5))
            classifier = EMGClassifier()
            classifier.train(X, y, label_names)
            # Publish only a classifier that finished training.
            self.classifier = classifier
            self.after(0, lambda: self._log("Training complete!\n"))

            # Cross-validation
            self.after(0, lambda: self.status_label.configure(text="Running cross-validation..."))
            self.after(0, lambda: self.progress_bar.set(0.7))
            cv_scores = classifier.cross_validate(X, y, cv=5)
            self.after(0, lambda: self._log(f"Cross-validation scores: {cv_scores.round(3)}"))
            self.after(0, lambda: self._log(f"Mean accuracy: {cv_scores.mean()*100:.1f}% (+/- {cv_scores.std()*100:.1f}%)\n"))

            # Feature importance
            self.after(0, lambda: self._log("Feature importance (top 8):"))
            importance = classifier.get_feature_importance()
            for name, score in list(importance.items())[:8]:
                # Default args bind this iteration's values into the callback.
                self.after(0, lambda n=name, s=score: self._log(f" {n}: {s:.3f}"))

            # Save model
            self.after(0, lambda: self.status_label.configure(text="Saving model..."))
            self.after(0, lambda: self.progress_bar.set(0.9))
            model_path = classifier.save(EMGClassifier.get_default_model_path())
            self.after(0, lambda: self._log(f"\nModel saved to: {model_path}"))
            self.after(0, lambda: self.progress_bar.set(1.0))
            self.after(0, lambda: self.status_label.configure(text="Training complete!"))

            # Update sidebar
            self.after(0, lambda: self._update_sidebar())

        except Exception as e:
            # Capture the message now: Python unbinds `e` when this handler
            # exits, so a lambda closing over `e` would raise NameError when
            # the Tk callback finally runs.
            err = str(e)
            self.after(0, lambda: self._log(f"\nError: {err}"))
            self.after(0, lambda: self.status_label.configure(text="Training failed!"))
        finally:
            self.after(0, self._on_training_finished)

    def _on_training_finished(self):
        """Mark the run finished and re-enable the Train button (Tk thread)."""
        self.is_training = False
        self.train_button.configure(state="normal")

    def _log(self, text: str):
        """Append a line to the results box and scroll to it."""
        self.results_text.insert("end", text + "\n")
        self.results_text.see("end")

    def _update_sidebar(self):
        """Refresh the sidebar status if the top-level window is the EMGApp."""
        app = self.winfo_toplevel()
        if isinstance(app, EMGApp):
            app.sidebar.update_status()
# =============================================================================
# LIVE PREDICTION PAGE
# =============================================================================
class PredictionPage(BasePage):
    """Page for the live gesture-prediction demo.

    Samples come from either a simulated stream that cycles through
    gestures, or from an ESP32 connected over serial. A worker thread reads
    samples, classifies each window and posts ``(msg_type, data)`` tuples to
    ``data_queue``. ``update_prediction_ui`` drains that queue on the Tk
    thread every 50 ms.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.create_header(
            "Live Prediction",
            "Real-time gesture classification"
        )

        # State (MUST be initialized BEFORE creating UI elements)
        self.is_predicting = False
        self.is_connected = False
        self.using_real_hardware = False
        self.classifier = None
        self.smoother = None
        self.stream = None
        self.data_queue = queue.Queue()
        # Stop signal for the current worker thread. start_prediction() makes
        # a fresh Event per session, so a worker left over from an earlier
        # session cannot keep running (or report errors) into a newer one.
        self._stop_event = threading.Event()
        self._stop_event.set()  # No worker is running yet.
        # Guards toggle_prediction() against rapid double-clicks.
        self._toggling = False

        # Content
        self.content = ctk.CTkFrame(self)
        self.content.grid(row=1, column=0, sticky="nsew")
        self.content.grid_columnconfigure(0, weight=1)
        self.content.grid_rowconfigure(1, weight=1)

        # Model status
        self.status_frame = ctk.CTkFrame(self.content)
        self.status_frame.pack(fill="x", padx=20, pady=20)
        self.model_label = ctk.CTkLabel(
            self.status_frame,
            text="Checking for saved model...",
            font=ctk.CTkFont(size=14)
        )
        self.model_label.pack(pady=10)

        # Data Source selection
        source_frame = ctk.CTkFrame(self.status_frame, fg_color="transparent")
        source_frame.pack(fill="x", pady=(10, 0))
        ctk.CTkLabel(source_frame, text="Data Source:", font=ctk.CTkFont(size=14)).pack(anchor="w")
        self.source_var = ctk.StringVar(value="simulated")
        radio_frame = ctk.CTkFrame(source_frame, fg_color="transparent")
        radio_frame.pack(fill="x", pady=(5, 0))
        self.sim_radio = ctk.CTkRadioButton(
            radio_frame, text="Simulated", variable=self.source_var, value="simulated",
            command=self._on_source_change
        )
        self.sim_radio.pack(side="left", padx=(0, 20))
        self.real_radio = ctk.CTkRadioButton(
            radio_frame, text="Real ESP32", variable=self.source_var, value="real",
            command=self._on_source_change
        )
        self.real_radio.pack(side="left")

        # Port selection (initially hidden)
        self.port_frame = ctk.CTkFrame(source_frame, fg_color="transparent")
        port_select_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent")
        port_select_frame.pack(fill="x", pady=(5, 0))
        ctk.CTkLabel(port_select_frame, text="Port:").pack(side="left")
        self.port_var = ctk.StringVar(value="Auto-detect")
        self.port_dropdown = ctk.CTkOptionMenu(
            port_select_frame, variable=self.port_var,
            values=["Auto-detect"], width=150
        )
        self.port_dropdown.pack(side="left", padx=(10, 5))
        self.refresh_ports_btn = ctk.CTkButton(
            port_select_frame, text="", width=30,
            command=self._refresh_ports
        )
        self.refresh_ports_btn.pack(side="left")

        # Connection status and button
        connect_frame = ctk.CTkFrame(self.port_frame, fg_color="transparent")
        connect_frame.pack(fill="x", pady=(5, 0))
        self.connect_button = ctk.CTkButton(
            connect_frame, text="Connect",
            width=100, height=28,
            command=self._toggle_connection,
            state="disabled"  # Disabled until "Real ESP32" selected
        )
        self.connect_button.pack(side="left", padx=(0, 10))
        self.connection_status = ctk.CTkLabel(
            connect_frame, text="● Disconnected",
            font=ctk.CTkFont(size=11), text_color="gray"
        )
        self.connection_status.pack(side="left")

        # Start button
        self.start_button = ctk.CTkButton(
            self.content,
            text="Start Prediction",
            font=ctk.CTkFont(size=18, weight="bold"),
            height=60,
            command=self.toggle_prediction
        )
        self.start_button.pack(pady=20)

        # Prediction display
        self.prediction_frame = ctk.CTkFrame(self.content)
        self.prediction_frame.pack(fill="both", expand=True, padx=20, pady=20)
        self.prediction_label = ctk.CTkLabel(
            self.prediction_frame,
            text="---",
            font=ctk.CTkFont(size=72, weight="bold")
        )
        self.prediction_label.pack(pady=30)
        self.confidence_bar = ctk.CTkProgressBar(self.prediction_frame, width=400, height=30)
        self.confidence_bar.pack(pady=10)
        self.confidence_bar.set(0)
        self.confidence_label = ctk.CTkLabel(
            self.prediction_frame,
            text="Confidence: ---%",
            font=ctk.CTkFont(size=18)
        )
        self.confidence_label.pack()

        # Simulated gesture indicator
        self.sim_label = ctk.CTkLabel(
            self.prediction_frame,
            text="",
            font=ctk.CTkFont(size=14),
            text_color="gray"
        )
        self.sim_label.pack(pady=10)

        # Smoothing info display
        self.smoothing_info_label = ctk.CTkLabel(
            self.prediction_frame,
            text="Smoothing: EMA(0.7) + Majority(5) + Debounce(3)",
            font=ctk.CTkFont(size=11),
            text_color="gray"
        )
        self.smoothing_info_label.pack()
        self.raw_label = ctk.CTkLabel(
            self.prediction_frame,
            text="",
            font=ctk.CTkFont(size=12),
            text_color="gray"
        )
        self.raw_label.pack(pady=5)

    def on_show(self):
        """Check model status when shown."""
        self.check_model()

    def check_model(self):
        """Enable Start only if a saved model exists at the default path."""
        model_path = EMGClassifier.get_default_model_path()
        if model_path.exists():
            self.model_label.configure(
                text=f"Saved model found: {model_path.name}",
                text_color="green"
            )
            self.start_button.configure(state="normal")
        else:
            self.model_label.configure(
                text="No saved model. Train a model first (Option 3).",
                text_color="orange"
            )
            self.start_button.configure(state="disabled")

    def toggle_prediction(self):
        """Start or stop prediction, ignoring clicks within 100 ms of the last one."""
        if self._toggling:
            return
        self._toggling = True
        try:
            if self.is_predicting:
                self.stop_prediction()
            else:
                self.start_prediction()
        finally:
            # Reset flag after brief delay to prevent immediate re-trigger
            self.after(100, lambda: setattr(self, '_toggling', False))

    def start_prediction(self):
        """Load the saved model, prepare the data source and launch the worker.

        Real hardware must already be connected; the device is told to start
        streaming here. In simulated mode the simulated stream is created
        here on the Tk thread, so ``self.stream`` is set before the worker
        starts and ``stop_prediction`` can always reach it.
        """
        # Drain stale messages from previous sessions FIRST so they cannot
        # act on (e.g. stop) the new session.
        try:
            while True:
                self.data_queue.get_nowait()
        except queue.Empty:
            pass

        # Load model
        try:
            self.classifier = EMGClassifier.load(EMGClassifier.get_default_model_path())
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load model: {e}")
            return

        # Determine data source
        self.using_real_hardware = (self.source_var.get() == "real")

        if self.using_real_hardware:
            # For real hardware, must be connected
            if not self.is_connected or not self.stream:
                messagebox.showerror("Not Connected", "Please connect to the ESP32 first.")
                return
            # Send start command to begin streaming
            try:
                self.stream.start()
            except Exception as e:
                messagebox.showerror("Start Error", f"Failed to start streaming:\n{e}")
                return
        else:
            try:
                self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
            except Exception as e:
                messagebox.showerror("Start Error", f"Failed to create simulated stream:\n{e}")
                return

        # Create prediction smoother
        self.smoother = PredictionSmoother(
            label_names=self.classifier.label_names,
            probability_smoothing=0.7,  # Higher = more smoothing
            majority_vote_window=5,  # Past predictions to consider
            debounce_count=3,  # Consecutive same predictions to change output
        )

        self.is_predicting = True
        self._stop_event = threading.Event()
        self.start_button.configure(text="Stop", fg_color="red")

        # Disable source selection and connection during prediction
        self.sim_radio.configure(state="disabled")
        self.real_radio.configure(state="disabled")
        if self.using_real_hardware:
            self.connect_button.configure(state="disabled")

        # Hand the worker its own stream and stop signal for this session.
        thread = threading.Thread(
            target=self._prediction_thread,
            args=(self.stream, self._stop_event),
            daemon=True,
        )
        thread.start()

        # Start UI update
        self.update_prediction_ui()

    def stop_prediction(self):
        """Stop live prediction and restore the idle UI.

        With real hardware only streaming is stopped: the stream object and
        the serial connection are kept so prediction can be restarted without
        reconnecting. A simulated stream is discarded.
        """
        self.is_predicting = False
        self._stop_event.set()

        # Safe cleanup - stream might already be in error state
        if self.stream:
            try:
                # Real hardware: send stop command (returns to CONNECTED state).
                # Simulated: stop generating samples.
                self.stream.stop()
            except Exception:
                pass  # Ignore cleanup errors; the stream may already be dead.
        if not self.using_real_hardware:
            self.stream = None

        self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
        self.prediction_label.configure(text="---", text_color="white")
        self.confidence_bar.set(0)
        self.confidence_label.configure(text="Confidence: ---%")
        self.sim_label.configure(text="")
        self.raw_label.configure(text="", text_color="gray")

        # Re-enable source selection and connection button
        self.sim_radio.configure(state="normal")
        self.real_radio.configure(state="normal")
        if self.using_real_hardware:
            self.connect_button.configure(state="normal")
            # Still connected, just not streaming
            if self.is_connected:
                device_name = self.stream.device_info.get('device', 'ESP32') if self.stream and self.stream.device_info else 'ESP32'
                self._update_connection_status("green", f"Connected ({device_name})")

    def _on_source_change(self):
        """Show/hide port selection based on data source.

        Any open hardware connection is closed first so a stream opened in
        one mode never leaks into the other.
        """
        if self.is_connected and self.stream:
            try:
                self.stream.disconnect()
            except Exception:
                pass  # Best effort: the port may already be gone; it is dropped below.
            self.is_connected = False
            self.stream = None

        if self.source_var.get() == "real":
            self.port_frame.pack(fill="x", pady=(5, 0))
            self._refresh_ports()
            self.connect_button.configure(text="Connect", state="normal")
            self._update_connection_status("gray", "Disconnected")
            # Start stays clickable; start_prediction() refuses to run until connected.
        else:
            self.port_frame.pack_forget()
            self._update_connection_status("gray", "Not using hardware")
            self.connect_button.configure(state="disabled")

    def _refresh_ports(self):
        """Scan serial ports and repopulate the port dropdown.

        A selected port that disappeared falls back to "Auto-detect". The
        status label is only rewritten while disconnected, so a refresh never
        hides a live "Connected"/"Streaming" indicator.
        """
        ports = serial.tools.list_ports.comports()
        port_names = ["Auto-detect"] + [p.device for p in ports]
        self.port_dropdown.configure(values=port_names)
        if self.port_var.get() not in port_names:
            self.port_var.set("Auto-detect")

        if self.is_connected:
            return
        if ports:
            self._update_connection_status("orange", f"Found {len(ports)} port(s)")
        else:
            self._update_connection_status("red", "No ports found")

    def _get_serial_port(self):
        """Get selected port, or None for auto-detect."""
        port = self.port_var.get()
        return None if port == "Auto-detect" else port

    def _update_connection_status(self, color: str, text: str):
        """Update the connection status indicator."""
        self.connection_status.configure(text=f"{text}", text_color=color)

    def _toggle_connection(self):
        """Connect or disconnect from ESP32."""
        if self.is_connected:
            self._disconnect_device()
        else:
            self._connect_device()

    def _connect_device(self):
        """Open the selected serial port and perform the ESP32 handshake.

        On failure an error dialog explains the likely cause and any
        half-opened stream is released so the port is not left busy.
        """
        port = self._get_serial_port()
        try:
            # Show progress before the blocking handshake freezes the event loop.
            self._update_connection_status("orange", "Connecting...")
            self.connect_button.configure(state="disabled")
            self.update()  # Force UI update

            self.stream = RealSerialStream(port=port)
            device_info = self.stream.connect(timeout=5.0)

            self.is_connected = True
            self._update_connection_status("green", f"Connected ({device_info.get('device', 'ESP32')})")
            self.connect_button.configure(text="Disconnect", state="normal")
        except TimeoutError:
            messagebox.showerror(
                "Connection Timeout",
                "Device did not respond within 5 seconds.\n\n"
                "Check that:\n"
                "• ESP32 is powered on\n"
                "• Correct firmware is flashed\n"
                "• USB cable is properly connected"
            )
            self._update_connection_status("red", "Timeout")
            self.connect_button.configure(state="normal")
            self._discard_failed_stream()
        except Exception as e:
            details = str(e)
            error_msg = f"Failed to connect:\n{e}"
            if "Permission denied" in details or "Resource busy" in details:
                error_msg += "\n\nThe port may still be in use. Wait a few seconds and try again."
            elif (isinstance(e, FileNotFoundError)
                  or "FileNotFoundError" in details
                  or "No such file" in details):
                # pyserial typically reports open failures as SerialException,
                # so the underlying OS error usually only survives in the text.
                error_msg += "\n\nPort not found. Try refreshing the port list."
            messagebox.showerror("Connection Error", error_msg)
            self._update_connection_status("red", "Failed")
            self.connect_button.configure(state="normal")
            self._discard_failed_stream()

    def _discard_failed_stream(self):
        """Best-effort disconnect and drop of the current stream."""
        if self.stream:
            try:
                self.stream.disconnect()
            except Exception:
                pass  # The port may never have opened or may already be gone.
        self.stream = None

    def _disconnect_device(self):
        """Close the ESP32 connection; the page always ends up disconnected."""
        try:
            if self.stream:
                self.stream.disconnect()
                # Give OS time to release the port
                time.sleep(0.5)
        except Exception as e:
            messagebox.showwarning("Disconnect Warning", f"Error during disconnect: {e}")
        finally:
            self.is_connected = False
            self.stream = None
            self._update_connection_status("gray", "Disconnected")
            self.connect_button.configure(text="Connect")

    def _drop_connection(self):
        """Release a hardware link that failed mid-stream and reset the connect UI."""
        self._discard_failed_stream()
        self.is_connected = False
        self._update_connection_status("red", "Disconnected")
        self.connect_button.configure(text="Connect", state="normal")

    def _prediction_thread(self, stream=None, stop_event=None):
        """Background worker: read samples, classify windows, post results.

        Args:
            stream: stream to read from; defaults to ``self.stream``.
            stop_event: set to end this worker; defaults to ``self._stop_event``.

        Posts 'prediction', 'sim_gesture', 'connection_status' and 'error'
        messages to ``data_queue``. Only a simulated stream is stopped here.
        A real stream belongs to the serial connection and is stopped by
        ``stop_prediction``.
        """
        # Snapshot per-session objects so a later session cannot swap them
        # out from under this worker.
        is_real = self.using_real_hardware
        classifier = self.classifier
        smoother = self.smoother
        if stream is None:
            stream = self.stream
        if stream is None and not is_real:
            # No stream handed in: create the simulator here.
            stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
            self.stream = stream
        if stop_event is None:
            stop_event = self._stop_event

        parser = EMGParser(num_channels=NUM_CHANNELS)
        windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)

        # Simulated gesture cycling (only for simulated mode)
        gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
        gesture_idx = 0
        gesture_duration = 2.5
        gesture_start = time.perf_counter()
        current_gesture = gesture_cycle[0]

        if not is_real:
            try:
                if hasattr(stream, 'set_gesture'):
                    stream.set_gesture(current_gesture)
                stream.start()
            except Exception as e:
                if not stop_event.is_set():
                    self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
                return
        else:
            # Real hardware is already streaming (start_prediction sent the start command)
            self.data_queue.put(('connection_status', ('green', 'Streaming')))

        while self.is_predicting and not stop_event.is_set():
            # Change simulated gesture periodically (only for simulated mode)
            if hasattr(stream, 'set_gesture'):
                elapsed = time.perf_counter() - gesture_start
                if elapsed > gesture_duration:
                    gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
                    gesture_start = time.perf_counter()
                    current_gesture = gesture_cycle[gesture_idx]
                    stream.set_gesture(current_gesture)
                    self.data_queue.put(('sim_gesture', current_gesture))

            # Read and process
            try:
                line = stream.readline()
            except Exception as e:
                # Only report error if this session wasn't intentionally stopped
                if not stop_event.is_set():
                    self.data_queue.put(('error', f"Serial read error: {e}"))
                break

            if line:
                sample = parser.parse_line(line)
                if sample:
                    window = windower.add_sample(sample)
                    if window:
                        # Get raw prediction
                        window_data = window.to_numpy()
                        raw_label, proba = classifier.predict(window_data)
                        raw_confidence = max(proba) * 100

                        # Apply smoothing
                        smoothed_label, smoothed_conf, _debug = smoother.update(raw_label, proba)
                        smoothed_confidence = smoothed_conf * 100

                        # Send both raw and smoothed to UI
                        self.data_queue.put(('prediction', (
                            smoothed_label,  # The stable output
                            smoothed_confidence,
                            raw_label,  # The raw (possibly twitchy) output
                            raw_confidence,
                        )))

        # This worker owns a simulated stream; stop it (it may already be stopped).
        if not is_real:
            try:
                stream.stop()
            except Exception:
                pass  # Ignore cleanup errors

    def update_prediction_ui(self):
        """Apply queued worker messages to the widgets; reschedules every 50 ms while predicting."""
        try:
            while True:
                msg_type, data = self.data_queue.get_nowait()

                if msg_type == 'prediction':
                    smoothed_label, smoothed_conf, raw_label, raw_conf = data

                    # Display smoothed (stable) prediction
                    display_label = smoothed_label.upper().replace("_", " ")
                    color = get_gesture_color(smoothed_label)
                    self.prediction_label.configure(text=display_label, text_color=color)
                    self.confidence_bar.set(smoothed_conf / 100)
                    self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%")

                    # Show raw prediction for comparison (grayed out)
                    raw_display = raw_label.upper().replace("_", " ")
                    if raw_label != smoothed_label:
                        # Raw differs from smoothed - show it was filtered
                        self.raw_label.configure(
                            text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
                            text_color="orange"
                        )
                    else:
                        self.raw_label.configure(
                            text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
                            text_color="gray"
                        )

                elif msg_type == 'sim_gesture':
                    self.sim_label.configure(text=f"[Simulating: {data}]")

                elif msg_type == 'error':
                    # Show error and stop prediction
                    if self.using_real_hardware:
                        self._update_connection_status("red", "Disconnected")
                    messagebox.showerror("Prediction Error", data)
                    self.stop_prediction()
                    if self.using_real_hardware:
                        # stop_prediction() re-labels the link as connected; after
                        # a read failure, release it and require an explicit reconnect.
                        self._drop_connection()
                    return

                elif msg_type == 'connection_status':
                    # Update connection indicator
                    color, text = data
                    self._update_connection_status(color, text)
                    # The worker only sends this for real hardware (as "Streaming").
                    if self.using_real_hardware:
                        self.sim_label.configure(text="[Real ESP32 Hardware]")

        except queue.Empty:
            pass

        if self.is_predicting:
            self.after(50, self.update_prediction_ui)

    def on_hide(self):
        """Stop when leaving page."""
        if self.is_predicting:
            self.stop_prediction()

    def stop(self):
        """Stop the worker and the stream without letting stream errors escape."""
        self.is_predicting = False
        self._stop_event.set()
        if self.stream:
            try:
                self.stream.stop()
            except Exception as e:
                print(f"[WARN] Failed to stop stream during shutdown: {e}")
# =============================================================================
# VISUALIZATION PAGE
# =============================================================================
class VisualizationPage(BasePage):
    """Page for LDA visualization of all collected sessions."""

    def __init__(self, parent):
        super().__init__(parent)
        self.create_header(
            "LDA Visualization",
            "Visualize decision boundaries and feature space"
        )

        # Content
        self.content = ctk.CTkFrame(self)
        self.content.grid(row=1, column=0, sticky="nsew")
        self.content.grid_columnconfigure(0, weight=1)
        self.content.grid_rowconfigure(1, weight=1)

        # Controls
        self.controls = ctk.CTkFrame(self.content)
        self.controls.pack(fill="x", padx=20, pady=20)
        self.generate_button = ctk.CTkButton(
            self.controls,
            text="Generate Visualizations",
            font=ctk.CTkFont(size=16, weight="bold"),
            height=50,
            command=self.generate_plots
        )
        self.generate_button.pack(side="left", padx=10)
        self.status_label = ctk.CTkLabel(self.controls, text="", font=ctk.CTkFont(size=12))
        self.status_label.pack(side="left", padx=20)

        # Plot area
        self.plot_frame = ctk.CTkFrame(self.content)
        self.plot_frame.pack(fill="both", expand=True, padx=20, pady=(0, 20))
        self.canvas = None

    def generate_plots(self):
        """Start building the LDA plots in a background thread."""
        storage = SessionStorage()
        sessions = storage.list_sessions()
        if not sessions:
            messagebox.showwarning("No Data", "No sessions found. Collect data first!")
            return

        self.status_label.configure(text="Loading data...")
        self.generate_button.configure(state="disabled")

        # Run in thread
        thread = threading.Thread(target=self._generate_thread, daemon=True)
        thread.start()

    def _generate_thread(self):
        """Fit LDA on all sessions and build the plot figure (worker thread).

        Widgets are only touched through ``self.after(0, ...)`` callbacks.
        """
        try:
            storage = SessionStorage()
            X, y, label_names, _ = storage.load_all_for_training()
            self.after(0, lambda: self.status_label.configure(text="Extracting features..."))

            # Extract features and train LDA
            extractor = EMGFeatureExtractor()
            X_features = extractor.extract_features_batch(X)
            lda = LinearDiscriminantAnalysis()
            lda.fit(X_features, y)
            X_lda = lda.transform(X_features)
            has_second_axis = X_lda.shape[1] >= 2

            self.after(0, lambda: self.status_label.configure(text="Creating plots..."))
            n_classes = len(label_names)
            colors = plt.cm.viridis(np.linspace(0.2, 0.8, n_classes))

            # Create figure
            fig = Figure(figsize=(12, 5), dpi=100, facecolor='#2b2b2b')

            # Plot 1: LDA Feature Space with Decision Boundaries
            ax1 = fig.add_subplot(1, 2, 1)
            ax1.set_facecolor('#2b2b2b')
            ax1.tick_params(colors='white')

            if has_second_axis:
                # Create mesh grid for decision boundaries
                x_min, x_max = X_lda[:, 0].min() - 1, X_lda[:, 0].max() + 1
                y_min, y_max = X_lda[:, 1].min() - 1, X_lda[:, 1].max() + 1
                xx, yy = np.meshgrid(
                    np.linspace(x_min, x_max, 200),
                    np.linspace(y_min, y_max, 200)
                )
                # Train a classifier on the 2D LDA space for visualization
                lda_2d = LinearDiscriminantAnalysis()
                lda_2d.fit(X_lda[:, :2], y)
                Z = lda_2d.predict(np.c_[xx.ravel(), yy.ravel()])
                Z = Z.reshape(xx.shape)

                # y appears to hold class indices 0..n_classes-1 (see the
                # `y == i` masks below), so edges at k - 0.5 give exactly one
                # filled band -- and one color -- per class, including the last.
                class_edges = np.arange(n_classes + 1) - 0.5
                ax1.contourf(xx, yy, Z, alpha=0.3, levels=class_edges, colors=list(colors))
                # Decision boundaries: only the interior edges between classes
                ax1.contour(xx, yy, Z, levels=class_edges[1:-1], colors='white', linewidths=1.5, alpha=0.8)

            # Plot data points
            for i, label in enumerate(label_names):
                mask = y == i
                ax1.scatter(
                    X_lda[mask, 0],
                    X_lda[mask, 1] if has_second_axis else np.zeros(mask.sum()),
                    c=[colors[i]], label=label, s=50, alpha=0.9,
                    edgecolors='white', linewidth=0.5
                )

            ax1.set_xlabel('LDA Component 1', color='white')
            # With a single LDA component the y axis carries no data.
            ax1.set_ylabel('LDA Component 2' if has_second_axis else '', color='white')
            ax1.set_title('LDA Decision Boundaries', color='white', fontsize=14)
            ax1.legend(facecolor='#2b2b2b', labelcolor='white', loc='upper right')
            for spine in ax1.spines.values():
                spine.set_color('white')

            # Plot 2: Class distributions
            ax2 = fig.add_subplot(1, 2, 2)
            ax2.set_facecolor('#2b2b2b')
            ax2.tick_params(colors='white')

            for i, label in enumerate(label_names):
                mask = y == i
                ax2.hist(X_lda[mask, 0], bins=20, alpha=0.6, label=label, color=colors[i])

            ax2.set_xlabel('LDA Component 1', color='white')
            ax2.set_ylabel('Count', color='white')
            ax2.set_title('Class Distributions', color='white', fontsize=14)
            ax2.legend(facecolor='#2b2b2b', labelcolor='white')
            for spine in ax2.spines.values():
                spine.set_color('white')

            fig.tight_layout()

            # Display in GUI
            self.after(0, lambda: self._show_plot(fig))
            self.after(0, lambda: self.status_label.configure(text="Done!"))

        except Exception as e:
            # Capture the message now: `e` is unbound when this handler exits,
            # so a lambda closing over it would raise NameError when run.
            err = str(e)
            self.after(0, lambda: self.status_label.configure(text=f"Error: {err}"))
        finally:
            self.after(0, lambda: self.generate_button.configure(state="normal"))

    def _show_plot(self, fig):
        """Replace any previous canvas with one showing *fig*."""
        if self.canvas:
            self.canvas.get_tk_widget().destroy()

        self.canvas = FigureCanvasTkAgg(fig, master=self.plot_frame)
        self.canvas.draw()
        self.canvas.get_tk_widget().pack(fill="both", expand=True, padx=10, pady=10)
# =============================================================================
# ENTRY POINT
# =============================================================================
if __name__ == "__main__":
    # Script entry point: create the main application window and run its main loop.
    app = EMGApp()
    app.mainloop()