This commit is contained in:
2026-01-27 21:31:49 -06:00
parent 9bdf9d1109
commit 31dede537c
7 changed files with 893 additions and 320 deletions

View File

@@ -1322,6 +1322,18 @@ class TrainingPage(BasePage):
)
self.train_button.pack(pady=20)
# Export button
self.export_button = ctk.CTkButton(
self.content,
text="Export for ESP32",
font=ctk.CTkFont(size=14),
height=40,
fg_color="green",
state="disabled",
command=self.export_model
)
self.export_button.pack(pady=5)
# Progress
self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
self.progress_bar.pack(pady=10)
@@ -1434,6 +1446,32 @@ class TrainingPage(BasePage):
finally:
self.after(0, lambda: self.train_button.configure(state="normal"))
self.after(0, lambda: self.export_button.configure(state="normal"))
def export_model(self):
"""Export trained model to C header."""
if not self.classifier or not self.classifier.is_trained:
messagebox.showerror("Error", "No trained model to export!")
return
# Default path in ESP32 project
default_path = Path("EMG_Arm/src/core/model_weights.h").absolute()
# Ask user for location, defaulting to the ESP32 project source
filename = tk.filedialog.asksaveasfilename(
title="Export Model Header",
initialdir=default_path.parent,
initialfile=default_path.name,
filetypes=[("C Header", "*.h")]
)
if filename:
try:
path = self.classifier.export_to_header(filename)
self._log(f"\nExported model to: {path}")
messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.")
except Exception as e:
messagebox.showerror("Export Error", f"Failed to export: {e}")
def _log(self, text: str):
"""Add text to results."""
@@ -1861,115 +1899,158 @@ class PredictionPage(BasePage):
self._update_connection_status("gray", "Disconnected")
self.connect_button.configure(text="Connect")
def _prediction_thread(self):
"""Background prediction thread."""
# For simulated mode, create new stream
if not self.using_real_hardware:
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
def toggle_prediction(self):
"""Start or stop prediction."""
if self.is_predicting:
self.stop_prediction()
else:
self.start_prediction()
# Stream is already started (either via handshake for real HW or will be started for simulated)
def start_prediction(self):
"""Start live prediction."""
# Determine mode
self.using_real_hardware = (self.source_var.get() == "real")
if self.using_real_hardware:
if not self.is_connected or not self.stream:
messagebox.showerror("Not Connected", "Please connect to ESP32 first.")
return
print("[DEBUG] Starting Edge Prediction (On-Device)...")
try:
# Send "start_predict" command to ESP32
if hasattr(self.stream, 'ser'):
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
self.stream.running = True
else:
# Fallback
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
self.stream.running = True
except Exception as e:
messagebox.showerror("Start Error", f"Failed to start: {e}")
return
else:
# Simulated - use PC-side inference
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
self.stream.start()
# Load model for PC-side (Simulated) OR for display (optional)
# Even for Edge, we might want the label list.
if not self.using_real_hardware:
if not self.classifier:
model_path = EMGClassifier.get_default_model_path()
if model_path.exists():
self.classifier = EMGClassifier.load(model_path)
self.model_label.configure(text="Model: Loaded", text_color="green")
else:
self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange")
# Reset smoother
self.smoother = PredictionSmoother(
label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"],
probability_smoothing=0.7,
majority_vote_window=5,
debounce_count=3
)
self.is_predicting = True
self.start_button.configure(text="Stop Prediction", fg_color="red")
# Start display loop
self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True)
self.prediction_thread.start()
self.update_prediction_ui()
def stop_prediction(self):
"""Stop prediction."""
self.is_predicting = False
if self.stream:
self.stream.stop() # Sends "stop" usually
if not self.using_real_hardware:
self.stream = None
self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
self.prediction_label.configure(text="---", text_color="gray")
self.confidence_label.configure(text="Confidence: ---%")
self.confidence_bar.set(0)
def prediction_loop(self):
"""Loop for reading data and (optionally) running inference."""
import json
parser = EMGParser(num_channels=NUM_CHANNELS)
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
# Simulated gesture cycling (only for simulated mode)
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
gesture_idx = 0
gesture_duration = 2.5
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[0]
# Start simulated stream if needed
if not self.using_real_hardware:
try:
if hasattr(self.stream, 'set_gesture'):
self.stream.set_gesture(current_gesture)
self.stream.start()
except Exception as e:
self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
return
else:
# Real hardware is already streaming
self.data_queue.put(('connection_status', ('green', 'Streaming')))
while self.is_predicting:
# Change simulated gesture periodically (only for simulated mode)
if hasattr(self.stream, 'set_gesture'):
elapsed = time.perf_counter() - gesture_start
if elapsed > gesture_duration:
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
gesture_start = time.perf_counter()
current_gesture = gesture_cycle[gesture_idx]
self.stream.set_gesture(current_gesture)
self.data_queue.put(('sim_gesture', current_gesture))
# Read and process
try:
line = self.stream.readline()
if not line:
continue
if self.using_real_hardware:
# Edge Inference Mode: Expect JSON
try:
line = line.strip()
if line.startswith('{'):
data = json.loads(line)
if "gesture" in data:
# Update UI with Edge Prediction
gesture = data["gesture"]
conf = float(data.get("conf", 0.0))
self.data_queue.put(('prediction', (gesture, conf)))
elif "status" in data:
print(f"[ESP32] {data}")
else:
pass
except json.JSONDecodeError:
pass
else:
# PC Side Inference (Simulated)
sample = parser.parse_line(line)
if sample:
window = windower.add_sample(sample)
if window and self.classifier:
# Run Inference Local
raw_label, proba = self.classifier.predict(window.to_numpy())
label, conf, _ = self.smoother.update(raw_label, proba)
self.data_queue.put(('prediction', (label, conf)))
except Exception as e:
# Only report error if we didn't intentionally stop
if self.is_predicting:
self.data_queue.put(('error', f"Serial read error: {e}"))
print(f"Prediction loop error: {e}")
break
if line:
sample = parser.parse_line(line)
if sample:
window = windower.add_sample(sample)
if window:
# Get raw prediction
window_data = window.to_numpy()
raw_label, proba = self.classifier.predict(window_data)
raw_confidence = max(proba) * 100
# Apply smoothing
smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba)
smoothed_confidence = smoothed_conf * 100
# Send both raw and smoothed to UI
self.data_queue.put(('prediction', (
smoothed_label, # The stable output
smoothed_confidence,
raw_label, # The raw (possibly twitchy) output
raw_confidence,
)))
# Safe cleanup - stream might already be stopped
try:
if self.stream:
self.stream.stop()
except Exception:
pass # Ignore cleanup errors
def update_prediction_ui(self):
"""Update UI from prediction thread."""
"""Update UI from queue."""
try:
while True:
msg_type, data = self.data_queue.get_nowait()
if msg_type == 'prediction':
smoothed_label, smoothed_conf, raw_label, raw_conf = data
# Display smoothed (stable) prediction
display_label = smoothed_label.upper().replace("_", " ")
color = get_gesture_color(smoothed_label)
self.prediction_label.configure(text=display_label, text_color=color)
self.confidence_bar.set(smoothed_conf / 100)
self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%")
# Show raw prediction for comparison (grayed out)
raw_display = raw_label.upper().replace("_", " ")
if raw_label != smoothed_label:
# Raw differs from smoothed - show it was filtered
self.raw_label.configure(
text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
text_color="orange"
)
else:
self.raw_label.configure(
text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
text_color="gray"
)
label, conf = data
# Update label
self.prediction_label.configure(
text=label.upper(),
text_color=get_gesture_color(label)
)
# Update confidence
self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%")
self.confidence_bar.set(conf)
# Clear raw label since we don't have raw vs smooth distinction in edge mode
# (or we could expose it if we updated the C struct, but for now keep it simple)
self.raw_label.configure(text="", text_color="gray")
elif msg_type == 'sim_gesture':
self.sim_label.configure(text=f"[Simulating: {data}]")