overhaul
This commit is contained in:
267
emg_gui.py
267
emg_gui.py
@@ -1322,6 +1322,18 @@ class TrainingPage(BasePage):
|
||||
)
|
||||
self.train_button.pack(pady=20)
|
||||
|
||||
# Export button
|
||||
self.export_button = ctk.CTkButton(
|
||||
self.content,
|
||||
text="Export for ESP32",
|
||||
font=ctk.CTkFont(size=14),
|
||||
height=40,
|
||||
fg_color="green",
|
||||
state="disabled",
|
||||
command=self.export_model
|
||||
)
|
||||
self.export_button.pack(pady=5)
|
||||
|
||||
# Progress
|
||||
self.progress_bar = ctk.CTkProgressBar(self.content, width=400)
|
||||
self.progress_bar.pack(pady=10)
|
||||
@@ -1434,6 +1446,32 @@ class TrainingPage(BasePage):
|
||||
|
||||
finally:
|
||||
self.after(0, lambda: self.train_button.configure(state="normal"))
|
||||
self.after(0, lambda: self.export_button.configure(state="normal"))
|
||||
|
||||
def export_model(self):
|
||||
"""Export trained model to C header."""
|
||||
if not self.classifier or not self.classifier.is_trained:
|
||||
messagebox.showerror("Error", "No trained model to export!")
|
||||
return
|
||||
|
||||
# Default path in ESP32 project
|
||||
default_path = Path("EMG_Arm/src/core/model_weights.h").absolute()
|
||||
|
||||
# Ask user for location, defaulting to the ESP32 project source
|
||||
filename = tk.filedialog.asksaveasfilename(
|
||||
title="Export Model Header",
|
||||
initialdir=default_path.parent,
|
||||
initialfile=default_path.name,
|
||||
filetypes=[("C Header", "*.h")]
|
||||
)
|
||||
|
||||
if filename:
|
||||
try:
|
||||
path = self.classifier.export_to_header(filename)
|
||||
self._log(f"\nExported model to: {path}")
|
||||
messagebox.showinfo("Export Success", f"Model exported to:\n{path}\n\nRecompile ESP32 firmware to apply.")
|
||||
except Exception as e:
|
||||
messagebox.showerror("Export Error", f"Failed to export: {e}")
|
||||
|
||||
def _log(self, text: str):
|
||||
"""Add text to results."""
|
||||
@@ -1861,115 +1899,158 @@ class PredictionPage(BasePage):
|
||||
self._update_connection_status("gray", "Disconnected")
|
||||
self.connect_button.configure(text="Connect")
|
||||
|
||||
def _prediction_thread(self):
|
||||
"""Background prediction thread."""
|
||||
# For simulated mode, create new stream
|
||||
if not self.using_real_hardware:
|
||||
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
||||
def toggle_prediction(self):
|
||||
"""Start or stop prediction."""
|
||||
if self.is_predicting:
|
||||
self.stop_prediction()
|
||||
else:
|
||||
self.start_prediction()
|
||||
|
||||
# Stream is already started (either via handshake for real HW or will be started for simulated)
|
||||
def start_prediction(self):
|
||||
"""Start live prediction."""
|
||||
# Determine mode
|
||||
self.using_real_hardware = (self.source_var.get() == "real")
|
||||
|
||||
if self.using_real_hardware:
|
||||
if not self.is_connected or not self.stream:
|
||||
messagebox.showerror("Not Connected", "Please connect to ESP32 first.")
|
||||
return
|
||||
|
||||
print("[DEBUG] Starting Edge Prediction (On-Device)...")
|
||||
try:
|
||||
# Send "start_predict" command to ESP32
|
||||
if hasattr(self.stream, 'ser'):
|
||||
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||
self.stream.running = True
|
||||
else:
|
||||
# Fallback
|
||||
self.stream.ser.write(b'{"cmd": "start_predict"}\n')
|
||||
self.stream.running = True
|
||||
|
||||
except Exception as e:
|
||||
messagebox.showerror("Start Error", f"Failed to start: {e}")
|
||||
return
|
||||
|
||||
else:
|
||||
# Simulated - use PC-side inference
|
||||
self.stream = GestureAwareEMGStream(num_channels=NUM_CHANNELS, sample_rate=SAMPLING_RATE_HZ)
|
||||
self.stream.start()
|
||||
|
||||
# Load model for PC-side (Simulated) OR for display (optional)
|
||||
# Even for Edge, we might want the label list.
|
||||
if not self.using_real_hardware:
|
||||
if not self.classifier:
|
||||
model_path = EMGClassifier.get_default_model_path()
|
||||
if model_path.exists():
|
||||
self.classifier = EMGClassifier.load(model_path)
|
||||
self.model_label.configure(text="Model: Loaded", text_color="green")
|
||||
else:
|
||||
self.model_label.configure(text="Model: Not found (Simulating)", text_color="orange")
|
||||
|
||||
# Reset smoother
|
||||
self.smoother = PredictionSmoother(
|
||||
label_names=self.classifier.label_names if self.classifier else ["rest", "open", "fist", "hook_em", "thumbs_up"],
|
||||
probability_smoothing=0.7,
|
||||
majority_vote_window=5,
|
||||
debounce_count=3
|
||||
)
|
||||
|
||||
self.is_predicting = True
|
||||
self.start_button.configure(text="Stop Prediction", fg_color="red")
|
||||
|
||||
# Start display loop
|
||||
self.prediction_thread = threading.Thread(target=self.prediction_loop, daemon=True)
|
||||
self.prediction_thread.start()
|
||||
|
||||
self.update_prediction_ui()
|
||||
|
||||
def stop_prediction(self):
|
||||
"""Stop prediction."""
|
||||
self.is_predicting = False
|
||||
if self.stream:
|
||||
self.stream.stop() # Sends "stop" usually
|
||||
if not self.using_real_hardware:
|
||||
self.stream = None
|
||||
|
||||
self.start_button.configure(text="Start Prediction", fg_color=["#3B8ED0", "#1F6AA5"])
|
||||
self.prediction_label.configure(text="---", text_color="gray")
|
||||
self.confidence_label.configure(text="Confidence: ---%")
|
||||
self.confidence_bar.set(0)
|
||||
|
||||
def prediction_loop(self):
|
||||
"""Loop for reading data and (optionally) running inference."""
|
||||
import json
|
||||
|
||||
parser = EMGParser(num_channels=NUM_CHANNELS)
|
||||
windower = Windower(window_size_ms=WINDOW_SIZE_MS, sample_rate=SAMPLING_RATE_HZ, overlap=0.0)
|
||||
|
||||
# Simulated gesture cycling (only for simulated mode)
|
||||
gesture_cycle = ["rest", "open", "fist", "hook_em", "thumbs_up"]
|
||||
gesture_idx = 0
|
||||
gesture_duration = 2.5
|
||||
gesture_start = time.perf_counter()
|
||||
current_gesture = gesture_cycle[0]
|
||||
|
||||
# Start simulated stream if needed
|
||||
if not self.using_real_hardware:
|
||||
try:
|
||||
if hasattr(self.stream, 'set_gesture'):
|
||||
self.stream.set_gesture(current_gesture)
|
||||
self.stream.start()
|
||||
except Exception as e:
|
||||
self.data_queue.put(('error', f"Failed to start simulated stream: {e}"))
|
||||
return
|
||||
else:
|
||||
# Real hardware is already streaming
|
||||
self.data_queue.put(('connection_status', ('green', 'Streaming')))
|
||||
|
||||
while self.is_predicting:
|
||||
# Change simulated gesture periodically (only for simulated mode)
|
||||
if hasattr(self.stream, 'set_gesture'):
|
||||
elapsed = time.perf_counter() - gesture_start
|
||||
if elapsed > gesture_duration:
|
||||
gesture_idx = (gesture_idx + 1) % len(gesture_cycle)
|
||||
gesture_start = time.perf_counter()
|
||||
current_gesture = gesture_cycle[gesture_idx]
|
||||
self.stream.set_gesture(current_gesture)
|
||||
self.data_queue.put(('sim_gesture', current_gesture))
|
||||
|
||||
# Read and process
|
||||
try:
|
||||
line = self.stream.readline()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if self.using_real_hardware:
|
||||
# Edge Inference Mode: Expect JSON
|
||||
try:
|
||||
line = line.strip()
|
||||
if line.startswith('{'):
|
||||
data = json.loads(line)
|
||||
|
||||
if "gesture" in data:
|
||||
# Update UI with Edge Prediction
|
||||
gesture = data["gesture"]
|
||||
conf = float(data.get("conf", 0.0))
|
||||
|
||||
self.data_queue.put(('prediction', (gesture, conf)))
|
||||
|
||||
elif "status" in data:
|
||||
print(f"[ESP32] {data}")
|
||||
else:
|
||||
pass
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
else:
|
||||
# PC Side Inference (Simulated)
|
||||
sample = parser.parse_line(line)
|
||||
if sample:
|
||||
window = windower.add_sample(sample)
|
||||
if window and self.classifier:
|
||||
# Run Inference Local
|
||||
raw_label, proba = self.classifier.predict(window.to_numpy())
|
||||
label, conf, _ = self.smoother.update(raw_label, proba)
|
||||
|
||||
self.data_queue.put(('prediction', (label, conf)))
|
||||
|
||||
except Exception as e:
|
||||
# Only report error if we didn't intentionally stop
|
||||
if self.is_predicting:
|
||||
self.data_queue.put(('error', f"Serial read error: {e}"))
|
||||
print(f"Prediction loop error: {e}")
|
||||
break
|
||||
|
||||
if line:
|
||||
sample = parser.parse_line(line)
|
||||
if sample:
|
||||
window = windower.add_sample(sample)
|
||||
if window:
|
||||
# Get raw prediction
|
||||
window_data = window.to_numpy()
|
||||
raw_label, proba = self.classifier.predict(window_data)
|
||||
raw_confidence = max(proba) * 100
|
||||
|
||||
# Apply smoothing
|
||||
smoothed_label, smoothed_conf, debug = self.smoother.update(raw_label, proba)
|
||||
smoothed_confidence = smoothed_conf * 100
|
||||
|
||||
# Send both raw and smoothed to UI
|
||||
self.data_queue.put(('prediction', (
|
||||
smoothed_label, # The stable output
|
||||
smoothed_confidence,
|
||||
raw_label, # The raw (possibly twitchy) output
|
||||
raw_confidence,
|
||||
)))
|
||||
|
||||
# Safe cleanup - stream might already be stopped
|
||||
try:
|
||||
if self.stream:
|
||||
self.stream.stop()
|
||||
except Exception:
|
||||
pass # Ignore cleanup errors
|
||||
|
||||
def update_prediction_ui(self):
|
||||
"""Update UI from prediction thread."""
|
||||
"""Update UI from queue."""
|
||||
try:
|
||||
while True:
|
||||
msg_type, data = self.data_queue.get_nowait()
|
||||
|
||||
|
||||
if msg_type == 'prediction':
|
||||
smoothed_label, smoothed_conf, raw_label, raw_conf = data
|
||||
|
||||
# Display smoothed (stable) prediction
|
||||
display_label = smoothed_label.upper().replace("_", " ")
|
||||
color = get_gesture_color(smoothed_label)
|
||||
|
||||
self.prediction_label.configure(text=display_label, text_color=color)
|
||||
self.confidence_bar.set(smoothed_conf / 100)
|
||||
self.confidence_label.configure(text=f"Confidence: {smoothed_conf:.1f}%")
|
||||
|
||||
# Show raw prediction for comparison (grayed out)
|
||||
raw_display = raw_label.upper().replace("_", " ")
|
||||
if raw_label != smoothed_label:
|
||||
# Raw differs from smoothed - show it was filtered
|
||||
self.raw_label.configure(
|
||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%) → filtered",
|
||||
text_color="orange"
|
||||
)
|
||||
else:
|
||||
self.raw_label.configure(
|
||||
text=f"Raw: {raw_display} ({raw_conf:.0f}%)",
|
||||
text_color="gray"
|
||||
)
|
||||
label, conf = data
|
||||
|
||||
# Update label
|
||||
self.prediction_label.configure(
|
||||
text=label.upper(),
|
||||
text_color=get_gesture_color(label)
|
||||
)
|
||||
|
||||
# Update confidence
|
||||
self.confidence_label.configure(text=f"Confidence: {conf*100:.1f}%")
|
||||
self.confidence_bar.set(conf)
|
||||
|
||||
# Clear raw label since we don't have raw vs smooth distinction in edge mode
|
||||
# (or we could expose it if we updated the C struct, but for now keep it simple)
|
||||
self.raw_label.configure(text="", text_color="gray")
|
||||
|
||||
elif msg_type == 'sim_gesture':
|
||||
self.sim_label.configure(text=f"[Simulating: {data}]")
|
||||
|
||||
Reference in New Issue
Block a user