bunch of training and label latency fixes, also trained 70% accurate model.
This commit is contained in:
@@ -12,6 +12,6 @@ board_upload.maximum_size = 33554432
|
|||||||
; and I can give you the content for it.
|
; and I can give you the content for it.
|
||||||
board_build.partitions = partitions.csv
|
board_build.partitions = partitions.csv
|
||||||
|
|
||||||
monitor_speed = 115200
|
monitor_speed = 921600
|
||||||
monitor_dtr = 1
|
monitor_dtr = 1
|
||||||
monitor_rts = 1
|
monitor_rts = 1
|
||||||
@@ -232,9 +232,8 @@ static void stream_emg_data(void)
|
|||||||
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
|
/* Read EMG (fake or real depending on FEATURE_FAKE_EMG) */
|
||||||
emg_sensor_read(&sample);
|
emg_sensor_read(&sample);
|
||||||
|
|
||||||
/* Output in CSV format matching Python expectation */
|
/* Output in CSV format - channels only, Python handles timestamps */
|
||||||
printf("%lu,%u,%u,%u,%u\n",
|
printf("%u,%u,%u,%u\n",
|
||||||
(unsigned long)sample.timestamp_ms,
|
|
||||||
sample.channels[0],
|
sample.channels[0],
|
||||||
sample.channels[1],
|
sample.channels[1],
|
||||||
sample.channels[2],
|
sample.channels[2],
|
||||||
@@ -283,7 +282,7 @@ void emgPrinter() {
|
|||||||
if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
|
if (i != EMG_NUM_CHANNELS - 1) printf(" | ");
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
|
// vTaskDelayUntil(&previousWake, pdMS_TO_TICKS(100));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -339,6 +338,6 @@ void app_main(void)
|
|||||||
|
|
||||||
printf("[INIT] Done!\n\n");
|
printf("[INIT] Done!\n\n");
|
||||||
|
|
||||||
emgPrinter();
|
// emgPrinter();
|
||||||
// appConnector();
|
appConnector();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -63,7 +63,7 @@
|
|||||||
* EMG Configuration
|
* EMG Configuration
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
#define EMG_NUM_CHANNELS 1 /**< Number of EMG sensor channels */
|
#define EMG_NUM_CHANNELS 4 /**< Number of EMG sensor channels */
|
||||||
#define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */
|
#define EMG_SAMPLE_RATE_HZ 1000 /**< Samples per second per channel */
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
|
|||||||
@@ -19,7 +19,12 @@
|
|||||||
adc_oneshot_unit_handle_t adc1_handle;
|
adc_oneshot_unit_handle_t adc1_handle;
|
||||||
adc_cali_handle_t cali_handle = NULL;
|
adc_cali_handle_t cali_handle = NULL;
|
||||||
|
|
||||||
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {ADC_CHANNEL_1};
|
const uint8_t emg_channels[EMG_NUM_CHANNELS] = {
|
||||||
|
ADC_CHANNEL_1, // GPIO 2 - EMG Channel 0
|
||||||
|
ADC_CHANNEL_2, // GPIO 3 - EMG Channel 1
|
||||||
|
ADC_CHANNEL_8, // GPIO 9 - EMG Channel 2
|
||||||
|
ADC_CHANNEL_9 // GPIO 10 - EMG Channel 3
|
||||||
|
};
|
||||||
|
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Public Functions
|
* Public Functions
|
||||||
@@ -94,7 +99,6 @@ void emg_sensor_read(emg_sample_t *sample)
|
|||||||
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
|
ESP_ERROR_CHECK(adc_cali_raw_to_voltage(cali_handle, raw_val, &voltage_mv));
|
||||||
sample->channels[i] = (uint16_t) voltage_mv;
|
sample->channels[i] = (uint16_t) voltage_mv;
|
||||||
}
|
}
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
Binary file not shown.
BIN
collected_data/latency_fix_100_20260127_184804.hdf5
Normal file
BIN
collected_data/latency_fix_100_20260127_184804.hdf5
Normal file
Binary file not shown.
BIN
collected_data/latency_fix_101_20260127_185344.hdf5
Normal file
BIN
collected_data/latency_fix_101_20260127_185344.hdf5
Normal file
Binary file not shown.
BIN
collected_data/latency_fix_102_20260127_190022.hdf5
Normal file
BIN
collected_data/latency_fix_102_20260127_190022.hdf5
Normal file
Binary file not shown.
BIN
collected_data/latency_fix_103_20260127_191249.hdf5
Normal file
BIN
collected_data/latency_fix_103_20260127_191249.hdf5
Normal file
Binary file not shown.
BIN
collected_data/latency_fix_104_20260127_195150.hdf5
Normal file
BIN
collected_data/latency_fix_104_20260127_195150.hdf5
Normal file
Binary file not shown.
BIN
collected_data/latency_fix_105_20260127_195503.hdf5
Normal file
BIN
collected_data/latency_fix_105_20260127_195503.hdf5
Normal file
Binary file not shown.
BIN
collected_data/new_placements_000_20260127_174231.hdf5
Normal file
BIN
collected_data/new_placements_000_20260127_174231.hdf5
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
28
emg_gui.py
28
emg_gui.py
@@ -301,6 +301,7 @@ class CollectionPage(BasePage):
|
|||||||
self.scheduler = None
|
self.scheduler = None
|
||||||
self.collected_windows = []
|
self.collected_windows = []
|
||||||
self.collected_labels = []
|
self.collected_labels = []
|
||||||
|
self.collected_raw_samples = [] # For label alignment
|
||||||
self.sample_buffer = []
|
self.sample_buffer = []
|
||||||
self.collection_thread = None
|
self.collection_thread = None
|
||||||
self.data_queue = queue.Queue()
|
self.data_queue = queue.Queue()
|
||||||
@@ -511,7 +512,7 @@ class CollectionPage(BasePage):
|
|||||||
ax.tick_params(colors='white')
|
ax.tick_params(colors='white')
|
||||||
ax.set_ylabel(f'Ch{i}', color='white', fontsize=10)
|
ax.set_ylabel(f'Ch{i}', color='white', fontsize=10)
|
||||||
ax.set_xlim(0, 500)
|
ax.set_xlim(0, 500)
|
||||||
ax.set_ylim(0, 1024)
|
ax.set_ylim(0, 3300) # ESP32 outputs millivolts (0-3100 mV)
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
for spine in ax.spines.values():
|
for spine in ax.spines.values():
|
||||||
spine.set_color('white')
|
spine.set_color('white')
|
||||||
@@ -648,6 +649,7 @@ class CollectionPage(BasePage):
|
|||||||
# Reset state
|
# Reset state
|
||||||
self.collected_windows = []
|
self.collected_windows = []
|
||||||
self.collected_labels = []
|
self.collected_labels = []
|
||||||
|
self.collected_raw_samples = [] # Store raw samples for label alignment
|
||||||
self.sample_buffer = []
|
self.sample_buffer = []
|
||||||
print("[DEBUG] Reset collection state")
|
print("[DEBUG] Reset collection state")
|
||||||
|
|
||||||
@@ -800,6 +802,9 @@ class CollectionPage(BasePage):
|
|||||||
timeout_warning_sent = False
|
timeout_warning_sent = False
|
||||||
sample = self.parser.parse_line(line)
|
sample = self.parser.parse_line(line)
|
||||||
if sample:
|
if sample:
|
||||||
|
# Store raw sample for label alignment
|
||||||
|
self.collected_raw_samples.append(sample)
|
||||||
|
|
||||||
# Batch samples for plotting (don't send every single one)
|
# Batch samples for plotting (don't send every single one)
|
||||||
sample_batch.append(sample.channels)
|
sample_batch.append(sample.channels)
|
||||||
|
|
||||||
@@ -932,9 +937,25 @@ class CollectionPage(BasePage):
|
|||||||
notes=""
|
notes=""
|
||||||
)
|
)
|
||||||
|
|
||||||
filepath = storage.save_session(self.collected_windows, self.collected_labels, metadata)
|
# Get session start time for label alignment
|
||||||
|
session_start_time = None
|
||||||
|
if self.scheduler and self.scheduler.session_start_time:
|
||||||
|
session_start_time = self.scheduler.session_start_time
|
||||||
|
|
||||||
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}")
|
filepath = storage.save_session(
|
||||||
|
windows=self.collected_windows,
|
||||||
|
labels=self.collected_labels,
|
||||||
|
metadata=metadata,
|
||||||
|
raw_samples=self.collected_raw_samples if self.collected_raw_samples else None,
|
||||||
|
session_start_time=session_start_time
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if alignment was performed
|
||||||
|
alignment_msg = ""
|
||||||
|
if session_start_time and self.collected_raw_samples:
|
||||||
|
alignment_msg = "\n\nLabel alignment: enabled"
|
||||||
|
|
||||||
|
messagebox.showinfo("Saved", f"Session saved!\n\nID: {session_id}\nWindows: {len(self.collected_windows)}{alignment_msg}")
|
||||||
|
|
||||||
# Update sidebar
|
# Update sidebar
|
||||||
app = self.winfo_toplevel()
|
app = self.winfo_toplevel()
|
||||||
@@ -944,6 +965,7 @@ class CollectionPage(BasePage):
|
|||||||
# Reset for next collection
|
# Reset for next collection
|
||||||
self.collected_windows = []
|
self.collected_windows = []
|
||||||
self.collected_labels = []
|
self.collected_labels = []
|
||||||
|
self.collected_raw_samples = []
|
||||||
self.save_button.configure(state="disabled")
|
self.save_button.configure(state="disabled")
|
||||||
self.status_label.configure(text="Ready to collect")
|
self.status_label.configure(text="Ready to collect")
|
||||||
self.window_count_label.configure(text="Windows: 0")
|
self.window_count_label.configure(text="Windows: 0")
|
||||||
|
|||||||
BIN
extra_data/emg_real_001_20260125_132316.hdf5
Normal file
BIN
extra_data/emg_real_001_20260125_132316.hdf5
Normal file
Binary file not shown.
BIN
extra_data/emg_real_001_20260125_133141.hdf5
Normal file
BIN
extra_data/emg_real_001_20260125_133141.hdf5
Normal file
Binary file not shown.
BIN
extra_data/emg_real_001_20260125_133545.hdf5
Normal file
BIN
extra_data/emg_real_001_20260125_133545.hdf5
Normal file
Binary file not shown.
BIN
extra_data/emg_real_001_20260125_144247.hdf5
Normal file
BIN
extra_data/emg_real_001_20260125_144247.hdf5
Normal file
Binary file not shown.
BIN
extra_data/label_test_002_20260125_183955.hdf5
Normal file
BIN
extra_data/label_test_002_20260125_183955.hdf5
Normal file
Binary file not shown.
BIN
extra_data/label_test_003_20260125_185807.hdf5
Normal file
BIN
extra_data/label_test_003_20260125_185807.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_debug_20260125_181559.hdf5
Normal file
BIN
extra_data/latency_debug_20260125_181559.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_fix_000_20260126_230849.hdf5
Normal file
BIN
extra_data/latency_fix_000_20260126_230849.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_fix_002_20260126_233357.hdf5
Normal file
BIN
extra_data/latency_fix_002_20260126_233357.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_fix_003_20260126_234717.hdf5
Normal file
BIN
extra_data/latency_fix_003_20260126_234717.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_fix_004_20260127_002041.hdf5
Normal file
BIN
extra_data/latency_fix_004_20260127_002041.hdf5
Normal file
Binary file not shown.
BIN
extra_data/latency_fix_005_20260127_002618.hdf5
Normal file
BIN
extra_data/latency_fix_005_20260127_002618.hdf5
Normal file
Binary file not shown.
@@ -39,7 +39,7 @@ import matplotlib.pyplot as plt
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
|
NUM_CHANNELS = 4 # Number of EMG channels (MyoWare sensors)
|
||||||
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
|
SAMPLING_RATE_HZ = 1000 # Must match ESP32's EMG_SAMPLE_RATE_HZ
|
||||||
SERIAL_BAUD = 115200 # Typical baud rate for ESP32
|
SERIAL_BAUD = 921600 # High baud rate to prevent serial buffer backlog
|
||||||
|
|
||||||
# Windowing configuration
|
# Windowing configuration
|
||||||
WINDOW_SIZE_MS = 150 # Window size in milliseconds
|
WINDOW_SIZE_MS = 150 # Window size in milliseconds
|
||||||
@@ -55,6 +55,27 @@ DATA_DIR = Path("collected_data") # Directory to store session files
|
|||||||
MODEL_DIR = Path("models") # Directory to store trained models
|
MODEL_DIR = Path("models") # Directory to store trained models
|
||||||
USER_ID = "user_001" # Current user ID (change per user)
|
USER_ID = "user_001" # Current user ID (change per user)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LABEL ALIGNMENT CONFIGURATION
|
||||||
|
# =============================================================================
|
||||||
|
# Human reaction time causes EMG activity to lag behind label prompts.
|
||||||
|
# We detect when EMG actually rises and shift labels to match.
|
||||||
|
|
||||||
|
ENABLE_LABEL_ALIGNMENT = True # Enable/disable automatic label alignment
|
||||||
|
ONSET_THRESHOLD = 2 # Signal must exceed baseline + threshold * std
|
||||||
|
ONSET_SEARCH_MS = 2000 # Search window after prompt (ms)
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# TRANSITION WINDOW FILTERING
|
||||||
|
# =============================================================================
|
||||||
|
# Windows near gesture transitions contain ambiguous data (reaction time at start,
|
||||||
|
# muscle relaxation at end). Discard these during training for cleaner labels.
|
||||||
|
# This is standard practice in EMG research (see Frontiers Neurorobotics 2023).
|
||||||
|
|
||||||
|
DISCARD_TRANSITION_WINDOWS = True # Enable/disable transition filtering during training
|
||||||
|
TRANSITION_START_MS = 300 # Discard windows within this time AFTER gesture starts
|
||||||
|
TRANSITION_END_MS = 150 # Discard windows within this time BEFORE gesture ends
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# DATA STRUCTURES
|
# DATA STRUCTURES
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -187,30 +208,27 @@ class EMGParser:
|
|||||||
"""
|
"""
|
||||||
Parse a line from ESP32 into an EMGSample.
|
Parse a line from ESP32 into an EMGSample.
|
||||||
|
|
||||||
Expected format: "timestamp_ms,ch0,ch1,ch2,ch3\n"
|
Expected format: "ch0,ch1,ch2,ch3\n" (channels only, no ESP32 timestamp)
|
||||||
|
Python assigns timestamp on receipt for label alignment.
|
||||||
Returns None if parsing fails.
|
Returns None if parsing fails.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Strip whitespace and split
|
# Strip whitespace and split
|
||||||
parts = line.strip().split(',')
|
parts = line.strip().split(',')
|
||||||
|
|
||||||
# Validate we have correct number of fields
|
# Validate we have correct number of fields (channels only)
|
||||||
expected_fields = 1 + self.num_channels # timestamp + channels
|
if len(parts) != self.num_channels:
|
||||||
if len(parts) != expected_fields:
|
|
||||||
self.parse_errors += 1
|
self.parse_errors += 1
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Parse ESP32 timestamp
|
|
||||||
esp_timestamp_ms = int(parts[0])
|
|
||||||
|
|
||||||
# Parse channel values
|
# Parse channel values
|
||||||
channels = [float(parts[i + 1]) for i in range(self.num_channels)]
|
channels = [float(parts[i]) for i in range(self.num_channels)]
|
||||||
|
|
||||||
# Create sample with Python-side timestamp
|
# Create sample with Python-side timestamp (aligned with label clock)
|
||||||
sample = EMGSample(
|
sample = EMGSample(
|
||||||
timestamp=time.perf_counter(), # High-resolution monotonic clock
|
timestamp=time.perf_counter(), # High-resolution monotonic clock
|
||||||
channels=channels,
|
channels=channels,
|
||||||
esp_timestamp_ms=esp_timestamp_ms
|
esp_timestamp_ms=None # No longer using ESP32 timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
self.samples_parsed += 1
|
self.samples_parsed += 1
|
||||||
@@ -491,6 +509,179 @@ class GestureAwareEMGStream(SimulatedEMGStream):
|
|||||||
time.sleep(interval)
|
time.sleep(interval)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# LABEL ALIGNMENT (Simple Onset Detection)
|
||||||
|
# =============================================================================
|
||||||
|
from scipy.signal import butter, sosfiltfilt
|
||||||
|
|
||||||
|
|
||||||
|
def align_labels_with_onset(
|
||||||
|
labels: list[str],
|
||||||
|
window_start_times: np.ndarray,
|
||||||
|
raw_timestamps: np.ndarray,
|
||||||
|
raw_channels: np.ndarray,
|
||||||
|
sampling_rate: int,
|
||||||
|
threshold_factor: float = 2.0,
|
||||||
|
search_ms: float = 800
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
Align labels to EMG onset by detecting when signal rises above baseline.
|
||||||
|
|
||||||
|
Simple algorithm:
|
||||||
|
1. High-pass filter to remove DC offset
|
||||||
|
2. Compute RMS envelope across channels
|
||||||
|
3. At each label transition, find where envelope exceeds baseline + threshold
|
||||||
|
4. Move label boundary to that point
|
||||||
|
"""
|
||||||
|
if len(labels) == 0:
|
||||||
|
return labels.copy()
|
||||||
|
|
||||||
|
# High-pass filter to remove DC (raw ADC has ~2340mV offset)
|
||||||
|
nyquist = sampling_rate / 2
|
||||||
|
sos = butter(2, 20.0 / nyquist, btype='high', output='sos')
|
||||||
|
|
||||||
|
# Filter and compute envelope (RMS across channels)
|
||||||
|
filtered = np.zeros_like(raw_channels)
|
||||||
|
for ch in range(raw_channels.shape[1]):
|
||||||
|
filtered[:, ch] = sosfiltfilt(sos, raw_channels[:, ch])
|
||||||
|
envelope = np.sqrt(np.mean(filtered ** 2, axis=1))
|
||||||
|
|
||||||
|
# Smooth envelope
|
||||||
|
sos_lp = butter(2, 10.0 / nyquist, btype='low', output='sos')
|
||||||
|
envelope = sosfiltfilt(sos_lp, envelope)
|
||||||
|
|
||||||
|
# Find transitions and detect onsets
|
||||||
|
search_samples = int(search_ms / 1000 * sampling_rate)
|
||||||
|
baseline_samples = int(200 / 1000 * sampling_rate)
|
||||||
|
|
||||||
|
boundaries = [] # (time, new_label)
|
||||||
|
|
||||||
|
for i in range(1, len(labels)):
|
||||||
|
if labels[i] != labels[i - 1]:
|
||||||
|
prompt_time = window_start_times[i]
|
||||||
|
|
||||||
|
# Find index in raw signal closest to prompt time
|
||||||
|
prompt_idx = np.searchsorted(raw_timestamps, prompt_time)
|
||||||
|
|
||||||
|
# Get baseline (before transition)
|
||||||
|
base_start = max(0, prompt_idx - baseline_samples)
|
||||||
|
baseline = envelope[base_start:prompt_idx]
|
||||||
|
if len(baseline) == 0:
|
||||||
|
boundaries.append((prompt_time + 0.3, labels[i]))
|
||||||
|
continue
|
||||||
|
|
||||||
|
threshold = np.mean(baseline) + threshold_factor * np.std(baseline)
|
||||||
|
|
||||||
|
# Search forward for onset
|
||||||
|
search_end = min(len(envelope), prompt_idx + search_samples)
|
||||||
|
onset_idx = None
|
||||||
|
|
||||||
|
for j in range(prompt_idx, search_end):
|
||||||
|
if envelope[j] > threshold:
|
||||||
|
onset_idx = j
|
||||||
|
break
|
||||||
|
|
||||||
|
if onset_idx is not None:
|
||||||
|
onset_time = raw_timestamps[onset_idx]
|
||||||
|
else:
|
||||||
|
onset_time = prompt_time + 0.3 # fallback
|
||||||
|
|
||||||
|
boundaries.append((onset_time, labels[i]))
|
||||||
|
|
||||||
|
# Assign labels based on detected boundaries
|
||||||
|
aligned = []
|
||||||
|
boundary_idx = 0
|
||||||
|
current_label = labels[0]
|
||||||
|
|
||||||
|
for t in window_start_times:
|
||||||
|
while boundary_idx < len(boundaries) and t >= boundaries[boundary_idx][0]:
|
||||||
|
current_label = boundaries[boundary_idx][1]
|
||||||
|
boundary_idx += 1
|
||||||
|
aligned.append(current_label)
|
||||||
|
|
||||||
|
return aligned
|
||||||
|
|
||||||
|
|
||||||
|
def filter_transition_windows(
|
||||||
|
X: np.ndarray,
|
||||||
|
y: np.ndarray,
|
||||||
|
labels: list[str],
|
||||||
|
start_times: np.ndarray,
|
||||||
|
end_times: np.ndarray,
|
||||||
|
transition_start_ms: float = TRANSITION_START_MS,
|
||||||
|
transition_end_ms: float = TRANSITION_END_MS
|
||||||
|
) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
||||||
|
"""
|
||||||
|
Filter out windows that fall within transition zones at gesture boundaries.
|
||||||
|
|
||||||
|
This removes ambiguous data where:
|
||||||
|
- User is still reacting to prompt (start of gesture)
|
||||||
|
- User is anticipating next gesture (end of gesture)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X: EMG data array (n_windows, samples, channels)
|
||||||
|
y: Label indices (n_windows,)
|
||||||
|
labels: String labels (n_windows,)
|
||||||
|
start_times: Window start times in seconds (n_windows,)
|
||||||
|
end_times: Window end times in seconds (n_windows,)
|
||||||
|
transition_start_ms: Discard windows within this time after gesture start
|
||||||
|
transition_end_ms: Discard windows within this time before gesture end
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Filtered (X, y, labels) with transition windows removed
|
||||||
|
"""
|
||||||
|
if len(X) == 0:
|
||||||
|
return X, y, labels
|
||||||
|
|
||||||
|
transition_start_sec = transition_start_ms / 1000.0
|
||||||
|
transition_end_sec = transition_end_ms / 1000.0
|
||||||
|
|
||||||
|
# Find gesture boundaries (where label changes)
|
||||||
|
# Each boundary is the START of a new gesture segment
|
||||||
|
boundaries = [0] # First window starts a segment
|
||||||
|
for i in range(1, len(labels)):
|
||||||
|
if labels[i] != labels[i-1]:
|
||||||
|
boundaries.append(i)
|
||||||
|
boundaries.append(len(labels)) # End marker
|
||||||
|
|
||||||
|
# For each segment, find start_time and end_time of the gesture
|
||||||
|
# Then mark windows that are within transition zones
|
||||||
|
keep_mask = np.ones(len(X), dtype=bool)
|
||||||
|
|
||||||
|
for seg_idx in range(len(boundaries) - 1):
|
||||||
|
seg_start_idx = boundaries[seg_idx]
|
||||||
|
seg_end_idx = boundaries[seg_idx + 1]
|
||||||
|
|
||||||
|
# Get the time boundaries of this gesture segment
|
||||||
|
gesture_start_time = start_times[seg_start_idx]
|
||||||
|
gesture_end_time = end_times[seg_end_idx - 1] # Last window's end time
|
||||||
|
|
||||||
|
# Mark windows in transition zones
|
||||||
|
for i in range(seg_start_idx, seg_end_idx):
|
||||||
|
window_start = start_times[i]
|
||||||
|
window_end = end_times[i]
|
||||||
|
|
||||||
|
# Check if window is too close to gesture START (reaction time zone)
|
||||||
|
if window_start < gesture_start_time + transition_start_sec:
|
||||||
|
keep_mask[i] = False
|
||||||
|
|
||||||
|
# Check if window is too close to gesture END (anticipation zone)
|
||||||
|
if window_end > gesture_end_time - transition_end_sec:
|
||||||
|
keep_mask[i] = False
|
||||||
|
|
||||||
|
# Apply filter
|
||||||
|
X_filtered = X[keep_mask]
|
||||||
|
y_filtered = y[keep_mask]
|
||||||
|
labels_filtered = [l for l, keep in zip(labels, keep_mask) if keep]
|
||||||
|
|
||||||
|
n_removed = len(X) - len(X_filtered)
|
||||||
|
if n_removed > 0:
|
||||||
|
print(f"[Filter] Removed {n_removed} transition windows ({n_removed/len(X)*100:.1f}%)")
|
||||||
|
print(f"[Filter] Kept {len(X_filtered)} windows for training")
|
||||||
|
|
||||||
|
return X_filtered, y_filtered, labels_filtered
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SESSION STORAGE (Save/Load labeled data to HDF5)
|
# SESSION STORAGE (Save/Load labeled data to HDF5)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
@@ -527,16 +718,24 @@ class SessionStorage:
|
|||||||
windows: list[EMGWindow],
|
windows: list[EMGWindow],
|
||||||
labels: list[str],
|
labels: list[str],
|
||||||
metadata: SessionMetadata,
|
metadata: SessionMetadata,
|
||||||
raw_samples: Optional[list[EMGSample]] = None
|
raw_samples: Optional[list[EMGSample]] = None,
|
||||||
|
session_start_time: Optional[float] = None,
|
||||||
|
enable_alignment: bool = ENABLE_LABEL_ALIGNMENT
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Save a collection session to HDF5.
|
Save a collection session to HDF5 with optional label alignment.
|
||||||
|
|
||||||
|
When raw_samples and session_start_time are provided and enable_alignment
|
||||||
|
is True, automatically detects EMG onset and corrects labels for human
|
||||||
|
reaction time delay.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
windows: List of EMGWindow objects (no label info)
|
windows: List of EMGWindow objects (no label info)
|
||||||
labels: List of gesture labels, parallel to windows
|
labels: List of gesture labels, parallel to windows
|
||||||
metadata: Session metadata
|
metadata: Session metadata
|
||||||
raw_samples: Optional raw samples for debugging
|
raw_samples: Raw samples (required for alignment)
|
||||||
|
session_start_time: When session started (required for alignment)
|
||||||
|
enable_alignment: Whether to perform automatic label alignment
|
||||||
"""
|
"""
|
||||||
filepath = self.get_session_filepath(metadata.session_id)
|
filepath = self.get_session_filepath(metadata.session_id)
|
||||||
|
|
||||||
@@ -549,6 +748,35 @@ class SessionStorage:
|
|||||||
window_samples = len(windows[0].samples)
|
window_samples = len(windows[0].samples)
|
||||||
num_channels = len(windows[0].samples[0].channels)
|
num_channels = len(windows[0].samples[0].channels)
|
||||||
|
|
||||||
|
# Prepare timing arrays
|
||||||
|
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
|
||||||
|
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
|
||||||
|
|
||||||
|
# Label alignment using onset detection
|
||||||
|
aligned_labels = labels
|
||||||
|
original_labels = labels
|
||||||
|
|
||||||
|
if enable_alignment and raw_samples and len(raw_samples) > 0:
|
||||||
|
print("[Storage] Aligning labels to EMG onset...")
|
||||||
|
|
||||||
|
raw_timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
|
||||||
|
raw_channels = np.array([s.channels for s in raw_samples], dtype=np.float32)
|
||||||
|
|
||||||
|
aligned_labels = align_labels_with_onset(
|
||||||
|
labels=labels,
|
||||||
|
window_start_times=start_times,
|
||||||
|
raw_timestamps=raw_timestamps,
|
||||||
|
raw_channels=raw_channels,
|
||||||
|
sampling_rate=metadata.sampling_rate,
|
||||||
|
threshold_factor=ONSET_THRESHOLD,
|
||||||
|
search_ms=ONSET_SEARCH_MS
|
||||||
|
)
|
||||||
|
|
||||||
|
changed = sum(1 for a, b in zip(labels, aligned_labels) if a != b)
|
||||||
|
print(f"[Storage] Labels aligned: {changed}/{len(labels)} windows shifted")
|
||||||
|
elif enable_alignment:
|
||||||
|
print("[Storage] Warning: No raw samples, skipping alignment")
|
||||||
|
|
||||||
with h5py.File(filepath, 'w') as f:
|
with h5py.File(filepath, 'w') as f:
|
||||||
# Metadata as attributes
|
# Metadata as attributes
|
||||||
f.attrs['user_id'] = metadata.user_id
|
f.attrs['user_id'] = metadata.user_id
|
||||||
@@ -568,19 +796,24 @@ class SessionStorage:
|
|||||||
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
|
emg_data = np.array([w.to_numpy() for w in windows], dtype=np.float32)
|
||||||
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
|
windows_grp.create_dataset('emg_data', data=emg_data, compression='gzip', compression_opts=4)
|
||||||
|
|
||||||
# Labels stored separately from window data
|
# Store ALIGNED labels as primary (what training will use)
|
||||||
max_label_len = max(len(l) for l in labels)
|
max_label_len = max(len(l) for l in aligned_labels)
|
||||||
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
|
dt = h5py.string_dtype(encoding='utf-8', length=max_label_len + 1)
|
||||||
windows_grp.create_dataset('labels', data=labels, dtype=dt)
|
windows_grp.create_dataset('labels', data=aligned_labels, dtype=dt)
|
||||||
|
|
||||||
|
# Also store original labels for reference/debugging
|
||||||
|
windows_grp.create_dataset('labels_original', data=original_labels, dtype=dt)
|
||||||
|
|
||||||
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
|
window_ids = np.array([w.window_id for w in windows], dtype=np.int32)
|
||||||
windows_grp.create_dataset('window_ids', data=window_ids)
|
windows_grp.create_dataset('window_ids', data=window_ids)
|
||||||
|
|
||||||
start_times = np.array([w.start_time for w in windows], dtype=np.float64)
|
|
||||||
end_times = np.array([w.end_time for w in windows], dtype=np.float64)
|
|
||||||
windows_grp.create_dataset('start_times', data=start_times)
|
windows_grp.create_dataset('start_times', data=start_times)
|
||||||
windows_grp.create_dataset('end_times', data=end_times)
|
windows_grp.create_dataset('end_times', data=end_times)
|
||||||
|
|
||||||
|
# Store alignment metadata
|
||||||
|
f.attrs['alignment_enabled'] = enable_alignment
|
||||||
|
f.attrs['alignment_method'] = 'onset_detection' if (enable_alignment and raw_samples) else 'none'
|
||||||
|
|
||||||
if raw_samples:
|
if raw_samples:
|
||||||
raw_grp = f.create_group('raw_samples')
|
raw_grp = f.create_group('raw_samples')
|
||||||
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
|
timestamps = np.array([s.timestamp for s in raw_samples], dtype=np.float64)
|
||||||
@@ -655,13 +888,21 @@ class SessionStorage:
|
|||||||
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
|
print(f"[Storage] {len(windows)} windows, {len(metadata.gestures)} gesture types")
|
||||||
return windows, labels_out, metadata
|
return windows, labels_out, metadata
|
||||||
|
|
||||||
def load_for_training(self, session_id: str) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
def load_for_training(self, session_id: str, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str]]:
|
||||||
"""Load a single session in ML-ready format: X, y, label_names."""
|
"""
|
||||||
|
Load a single session in ML-ready format: X, y, label_names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: The session to load
|
||||||
|
filter_transitions: If True, remove windows in transition zones (default from config)
|
||||||
|
"""
|
||||||
filepath = self.get_session_filepath(session_id)
|
filepath = self.get_session_filepath(session_id)
|
||||||
|
|
||||||
with h5py.File(filepath, 'r') as f:
|
with h5py.File(filepath, 'r') as f:
|
||||||
X = f['windows/emg_data'][:]
|
X = f['windows/emg_data'][:]
|
||||||
labels_raw = f['windows/labels'][:]
|
labels_raw = f['windows/labels'][:]
|
||||||
|
start_times = f['windows/start_times'][:]
|
||||||
|
end_times = f['windows/end_times'][:]
|
||||||
|
|
||||||
labels = []
|
labels = []
|
||||||
for l in labels_raw:
|
for l in labels_raw:
|
||||||
@@ -670,18 +911,33 @@ class SessionStorage:
|
|||||||
else:
|
else:
|
||||||
labels.append(l)
|
labels.append(l)
|
||||||
|
|
||||||
|
print(f"[Storage] Loaded session: {session_id} ({X.shape[0]} windows)")
|
||||||
|
|
||||||
|
# Apply transition filtering if enabled
|
||||||
|
if filter_transitions:
|
||||||
|
label_names_pre = sorted(set(labels))
|
||||||
|
label_to_idx_pre = {name: idx for idx, name in enumerate(label_names_pre)}
|
||||||
|
y_pre = np.array([label_to_idx_pre[l] for l in labels], dtype=np.int32)
|
||||||
|
|
||||||
|
X, y_pre, labels = filter_transition_windows(
|
||||||
|
X, y_pre, labels, start_times, end_times
|
||||||
|
)
|
||||||
|
|
||||||
label_names = sorted(set(labels))
|
label_names = sorted(set(labels))
|
||||||
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
|
label_to_idx = {name: idx for idx, name in enumerate(label_names)}
|
||||||
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
|
y = np.array([label_to_idx[l] for l in labels], dtype=np.int32)
|
||||||
|
|
||||||
print(f"[Storage] Loaded for training: X{X.shape}, y{y.shape}")
|
print(f"[Storage] Ready for training: X{X.shape}, y{y.shape}")
|
||||||
print(f"[Storage] Labels: {label_names}")
|
print(f"[Storage] Labels: {label_names}")
|
||||||
return X, y, label_names
|
return X, y, label_names
|
||||||
|
|
||||||
def load_all_for_training(self) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
|
def load_all_for_training(self, filter_transitions: bool = DISCARD_TRANSITION_WINDOWS) -> tuple[np.ndarray, np.ndarray, list[str], list[str]]:
|
||||||
"""
|
"""
|
||||||
Load ALL sessions combined into a single training dataset.
|
Load ALL sessions combined into a single training dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter_transitions: If True, remove windows in transition zones (default from config)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
|
X: Combined EMG windows from all sessions (n_total_windows, samples, channels)
|
||||||
y: Combined labels as integers (n_total_windows,)
|
y: Combined labels as integers (n_total_windows,)
|
||||||
@@ -697,11 +953,15 @@ class SessionStorage:
|
|||||||
raise ValueError("No sessions found to load!")
|
raise ValueError("No sessions found to load!")
|
||||||
|
|
||||||
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
|
print(f"[Storage] Loading {len(sessions)} session(s) for combined training...")
|
||||||
|
if filter_transitions:
|
||||||
|
print(f"[Storage] Transition filtering: START={TRANSITION_START_MS}ms, END={TRANSITION_END_MS}ms")
|
||||||
|
|
||||||
all_X = []
|
all_X = []
|
||||||
all_labels = []
|
all_labels = []
|
||||||
loaded_sessions = []
|
loaded_sessions = []
|
||||||
reference_shape = None
|
reference_shape = None
|
||||||
|
total_removed = 0
|
||||||
|
total_original = 0
|
||||||
|
|
||||||
for session_id in sessions:
|
for session_id in sessions:
|
||||||
filepath = self.get_session_filepath(session_id)
|
filepath = self.get_session_filepath(session_id)
|
||||||
@@ -709,6 +969,8 @@ class SessionStorage:
|
|||||||
with h5py.File(filepath, 'r') as f:
|
with h5py.File(filepath, 'r') as f:
|
||||||
X = f['windows/emg_data'][:]
|
X = f['windows/emg_data'][:]
|
||||||
labels_raw = f['windows/labels'][:]
|
labels_raw = f['windows/labels'][:]
|
||||||
|
start_times = f['windows/start_times'][:]
|
||||||
|
end_times = f['windows/end_times'][:]
|
||||||
|
|
||||||
# Validate shape compatibility
|
# Validate shape compatibility
|
||||||
if reference_shape is None:
|
if reference_shape is None:
|
||||||
@@ -725,10 +987,26 @@ class SessionStorage:
|
|||||||
else:
|
else:
|
||||||
labels.append(l)
|
labels.append(l)
|
||||||
|
|
||||||
|
original_count = len(X)
|
||||||
|
total_original += original_count
|
||||||
|
|
||||||
|
# Apply transition filtering per session (each has its own gesture boundaries)
|
||||||
|
if filter_transitions:
|
||||||
|
# Need temporary y for filtering function
|
||||||
|
temp_label_names = sorted(set(labels))
|
||||||
|
temp_label_to_idx = {name: idx for idx, name in enumerate(temp_label_names)}
|
||||||
|
temp_y = np.array([temp_label_to_idx[l] for l in labels], dtype=np.int32)
|
||||||
|
|
||||||
|
X, temp_y, labels = filter_transition_windows(
|
||||||
|
X, temp_y, labels, start_times, end_times
|
||||||
|
)
|
||||||
|
total_removed += original_count - len(X)
|
||||||
|
|
||||||
all_X.append(X)
|
all_X.append(X)
|
||||||
all_labels.extend(labels)
|
all_labels.extend(labels)
|
||||||
loaded_sessions.append(session_id)
|
loaded_sessions.append(session_id)
|
||||||
print(f"[Storage] - {session_id}: {X.shape[0]} windows")
|
print(f"[Storage] - {session_id}: {len(X)} windows" +
|
||||||
|
(f" (was {original_count})" if filter_transitions and len(X) != original_count else ""))
|
||||||
|
|
||||||
if not all_X:
|
if not all_X:
|
||||||
raise ValueError("No compatible sessions found!")
|
raise ValueError("No compatible sessions found!")
|
||||||
@@ -742,6 +1020,8 @@ class SessionStorage:
|
|||||||
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
|
y_combined = np.array([label_to_idx[l] for l in all_labels], dtype=np.int32)
|
||||||
|
|
||||||
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
|
print(f"[Storage] Combined dataset: X{X_combined.shape}, y{y_combined.shape}")
|
||||||
|
if filter_transitions and total_removed > 0:
|
||||||
|
print(f"[Storage] Total removed: {total_removed}/{total_original} windows ({total_removed/total_original*100:.1f}%)")
|
||||||
print(f"[Storage] Labels: {label_names}")
|
print(f"[Storage] Labels: {label_names}")
|
||||||
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")
|
print(f"[Storage] Sessions loaded: {len(loaded_sessions)}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,41 +1,152 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import h5py
|
import h5py
|
||||||
import pandas as pd
|
|
||||||
from scipy.signal import butter, sosfiltfilt
|
from scipy.signal import butter, sosfiltfilt
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# CONFIGURABLE PARAMETERS
|
# CONFIGURABLE PARAMETERS
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
ZC_THRESHOLD_PERCENT = 0.6 # Zero Crossing threshold as fraction of RMS
|
ZC_THRESHOLD_PERCENT = 0.7 # Zero Crossing threshold as fraction of RMS
|
||||||
SSC_THRESHOLD_PERCENT = 0.3 # Slope Sign Change threshold as fraction of RMS
|
SSC_THRESHOLD_PERCENT = 0.6 # Slope Sign Change threshold as fraction of RMS
|
||||||
|
|
||||||
file = h5py.File("assets/discrete_gestures_user_000_dataset_000.hdf5", "r")
|
# =============================================================================
|
||||||
|
# LOAD DATA FROM GUI's HDF5 FORMAT
|
||||||
|
# =============================================================================
|
||||||
|
# Update this path to your collected session file
|
||||||
|
HDF5_PATH = "collected_data\latency_fix_106_20260127_200043.hdf5"
|
||||||
|
|
||||||
|
file = h5py.File(HDF5_PATH, "r")
|
||||||
|
|
||||||
|
# Print HDF5 structure for debugging
|
||||||
|
print("HDF5 Structure:")
|
||||||
def print_tree(name, obj):
|
def print_tree(name, obj):
|
||||||
print(name)
|
print(f" {name}")
|
||||||
|
|
||||||
file.visititems(print_tree)
|
file.visititems(print_tree)
|
||||||
|
print()
|
||||||
|
|
||||||
print(list(file.keys()))
|
# Load metadata
|
||||||
|
fs = file.attrs['sampling_rate'] # Sampling rate in Hz
|
||||||
|
n_channels = file.attrs['num_channels']
|
||||||
|
window_size_ms = file.attrs['window_size_ms']
|
||||||
|
print(f"Sampling rate: {fs} Hz")
|
||||||
|
print(f"Channels: {n_channels}")
|
||||||
|
print(f"Window size: {window_size_ms} ms")
|
||||||
|
|
||||||
data = file["data"]
|
# Load windowed EMG data: shape (n_windows, samples_per_window, channels)
|
||||||
print(type(data))
|
emg_windows = file['windows/emg_data'][:]
|
||||||
print(data)
|
labels = file['windows/labels'][:]
|
||||||
print(data.dtype)
|
start_times = file['windows/start_times'][:]
|
||||||
|
end_times = file['windows/end_times'][:]
|
||||||
|
|
||||||
raw = data[:]
|
print(f"Windows shape: {emg_windows.shape}")
|
||||||
print(raw.shape)
|
print(f"Labels: {len(labels)} (unique: {np.unique(labels)})")
|
||||||
print(raw.dtype)
|
|
||||||
|
|
||||||
emg = raw['emg']
|
# Flatten windows to continuous signal for filtering analysis
|
||||||
time = raw['time']
|
# Shape: (n_windows * samples_per_window, channels)
|
||||||
print(emg.shape)
|
n_windows, samples_per_window, _ = emg_windows.shape
|
||||||
|
emg = emg_windows.reshape(-1, n_channels)
|
||||||
|
print(f"Flattened EMG shape: {emg.shape}")
|
||||||
|
|
||||||
dt = np.diff(time)
|
# Reconstruct time vector (assumes continuous recording)
|
||||||
fs = 1.0 / np.median(dt)
|
total_samples = emg.shape[0]
|
||||||
print("fs =", fs)
|
time = np.arange(total_samples) / fs
|
||||||
print("dt min/median/max =", dt.min(), np.median(dt), dt.max())
|
print(f"Time range: {time[0]:.2f}s to {time[-1]:.2f}s")
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# PLOT RAW EMG DATA (before filtering)
|
||||||
|
# =============================================================================
|
||||||
|
print("\nPlotting raw EMG data...")
|
||||||
|
|
||||||
|
# Color map for channels
|
||||||
|
channel_colors = ['#00ff88', '#ff6b6b', '#4ecdc4', '#ffe66d']
|
||||||
|
|
||||||
|
# Map window indices to actual time in the flattened data
|
||||||
|
# Window i starts at sample (i * samples_per_window), which is time (i * samples_per_window / fs)
|
||||||
|
window_time_in_data = np.arange(len(labels)) * samples_per_window / fs
|
||||||
|
|
||||||
|
# Compute global Y range across all channels for consistent comparison
|
||||||
|
emg_global_min = emg.min()
|
||||||
|
emg_global_max = emg.max()
|
||||||
|
emg_margin = (emg_global_max - emg_global_min) * 0.05 # 5% margin
|
||||||
|
emg_ylim = (emg_global_min - emg_margin, emg_global_max + emg_margin)
|
||||||
|
print(f"Global EMG range: {emg_global_min:.1f} to {emg_global_max:.1f} mV")
|
||||||
|
|
||||||
|
# Full session plot with shared Y-axis
|
||||||
|
fig_raw, axes_raw = plt.subplots(n_channels, 1, figsize=(14, 2.5 * n_channels), sharex=True, sharey=True)
|
||||||
|
if n_channels == 1:
|
||||||
|
axes_raw = [axes_raw]
|
||||||
|
|
||||||
|
for ch in range(n_channels):
|
||||||
|
ax = axes_raw[ch]
|
||||||
|
|
||||||
|
# Plot raw EMG signal
|
||||||
|
ax.plot(time, emg[:, ch], linewidth=0.3, color=channel_colors[ch % len(channel_colors)], alpha=0.8)
|
||||||
|
ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.set_xlim(time[0], time[-1])
|
||||||
|
ax.set_ylim(emg_ylim) # Same Y range for all channels
|
||||||
|
|
||||||
|
# Add gesture markers AFTER plotting
|
||||||
|
y_min, y_max = emg_ylim
|
||||||
|
y_text = y_min + (y_max - y_min) * 0.85 # Position text at 85% height
|
||||||
|
for ch in range(n_channels):
|
||||||
|
ax = axes_raw[ch]
|
||||||
|
|
||||||
|
# Add gesture transition markers using window-aligned times
|
||||||
|
prev_label = None
|
||||||
|
for i, lbl in enumerate(labels):
|
||||||
|
lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
|
||||||
|
t = window_time_in_data[i] # Time in the flattened data
|
||||||
|
|
||||||
|
if lbl_str != prev_label:
|
||||||
|
if 'open' in lbl_str:
|
||||||
|
color = 'cyan'
|
||||||
|
elif 'fist' in lbl_str:
|
||||||
|
color = 'blue'
|
||||||
|
elif 'hook' in lbl_str:
|
||||||
|
color = 'orange'
|
||||||
|
elif 'thumb' in lbl_str:
|
||||||
|
color = 'green'
|
||||||
|
elif 'rest' in lbl_str:
|
||||||
|
color = 'gray'
|
||||||
|
else:
|
||||||
|
color = 'red'
|
||||||
|
|
||||||
|
ax.axvline(t, color=color, linestyle='--', alpha=0.7, linewidth=1)
|
||||||
|
|
||||||
|
# Only add text label on first channel to avoid clutter
|
||||||
|
if ch == 0 and lbl_str != 'rest':
|
||||||
|
ax.text(t + 0.2, y_text, lbl_str, fontsize=8,
|
||||||
|
color=color, rotation=0, ha='left', va='top',
|
||||||
|
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7))
|
||||||
|
prev_label = lbl_str
|
||||||
|
|
||||||
|
axes_raw[-1].set_xlabel('Time (s)', fontsize=11)
|
||||||
|
fig_raw.suptitle('Raw EMG Signal (All Channels)', fontsize=14, fontweight='bold')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
# Zoomed view of first 2 seconds to see waveform detail
|
||||||
|
zoom_duration = 2.0
|
||||||
|
zoom_samples = int(zoom_duration * fs)
|
||||||
|
|
||||||
|
if total_samples > zoom_samples:
|
||||||
|
fig_zoom, axes_zoom = plt.subplots(n_channels, 1, figsize=(14, 2 * n_channels), sharex=True, sharey=True)
|
||||||
|
if n_channels == 1:
|
||||||
|
axes_zoom = [axes_zoom]
|
||||||
|
|
||||||
|
for ch in range(n_channels):
|
||||||
|
ax = axes_zoom[ch]
|
||||||
|
ax.plot(time[:zoom_samples], emg[:zoom_samples, ch],
|
||||||
|
linewidth=0.5, color=channel_colors[ch % len(channel_colors)])
|
||||||
|
ax.set_ylabel(f'Ch {ch}\n(mV)', fontsize=10)
|
||||||
|
ax.set_ylim(emg_ylim) # Same Y range as full plot
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
axes_zoom[-1].set_xlabel('Time (s)', fontsize=11)
|
||||||
|
fig_zoom.suptitle(f'Raw EMG Signal (First {zoom_duration}s - Zoomed)', fontsize=14, fontweight='bold')
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.show()
|
||||||
|
|
||||||
# 1) pick one channel
|
# 1) pick one channel
|
||||||
emg_ch = emg[:, 0].astype(np.float32)
|
emg_ch = emg[:, 0].astype(np.float32)
|
||||||
@@ -122,10 +233,9 @@ def compute_all_features_windowed(x, window_len, threshold_zc, threshold_ssc):
|
|||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# COMPUTE FEATURES FOR ALL 16 CHANNELS
|
# COMPUTE FEATURES FOR ALL CHANNELS
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
n_channels = 16
|
|
||||||
all_features = {} # Non-overlapping windows - for ML
|
all_features = {} # Non-overlapping windows - for ML
|
||||||
|
|
||||||
for ch in range(n_channels):
|
for ch in range(n_channels):
|
||||||
@@ -143,89 +253,106 @@ for ch in range(n_channels):
|
|||||||
all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i}
|
all_features[ch] = {'rms': rms_i, 'wl': wl_i, 'zc': zc_i, 'ssc': ssc_i}
|
||||||
|
|
||||||
# Time vector for windowed features
|
# Time vector for windowed features
|
||||||
n_windows = len(all_features[0]['rms'])
|
n_feat_windows = len(all_features[0]['rms'])
|
||||||
window_centers = np.arange(n_windows) * window_samples + window_samples // 2
|
window_centers = np.arange(n_feat_windows) * window_samples + window_samples // 2
|
||||||
time_windows = time_ch[window_centers]
|
time_windows = time_ch[window_centers]
|
||||||
|
|
||||||
print(f"\nComputed features for {n_channels} channels")
|
print(f"\nComputed features for {n_channels} channels")
|
||||||
print(f"Windows per channel (non-overlapping): {n_windows}")
|
print(f"Windows per channel (non-overlapping): {n_feat_windows}")
|
||||||
print(f"Window size: {window_samples} samples ({window_ms} ms)")
|
print(f"Window size: {window_samples} samples ({window_ms} ms)")
|
||||||
|
|
||||||
# 5) Load gesture labels (prompts)
|
# 5) Use embedded gesture labels from HDF5
|
||||||
prompts = pd.read_hdf("assets/discrete_gestures_user_000_dataset_000.hdf5", key="prompts")
|
# Labels are per-window, aligned with emg_windows
|
||||||
print("\nUnique gestures:", prompts['name'].unique())
|
unique_labels = np.unique(labels)
|
||||||
|
print(f"\nUnique gestures in session: {unique_labels}")
|
||||||
|
|
||||||
t_abs = time_ch # absolute timestamps
|
# Map labels to time in the flattened data (same as raw plot)
|
||||||
|
# Each original window i maps to time (i * samples_per_window / fs) in the flattened data
|
||||||
# Find first occurrence of each gesture type
|
# But we dropped `drop` samples, so adjust accordingly
|
||||||
index_gestures = prompts[prompts['name'].str.contains('index')]
|
label_times_in_data = np.arange(len(labels)) * samples_per_window / fs
|
||||||
middle_gestures = prompts[prompts['name'].str.contains('middle')]
|
print(f"Data duration: {time[-1]:.2f}s ({len(labels)} windows)")
|
||||||
thumb_gestures = prompts[prompts['name'].str.contains('thumb')]
|
|
||||||
|
|
||||||
# Define plot configurations: 1) Index+Middle combined, 2) Thumb separate
|
|
||||||
plot_configs = [
|
|
||||||
{'name': 'Index & Middle Finger', 'start_time': index_gestures['time'].iloc[0], 'filter': 'index|middle'},
|
|
||||||
{'name': 'Thumb', 'start_time': thumb_gestures['time'].iloc[0], 'filter': 'thumb'},
|
|
||||||
]
|
|
||||||
|
|
||||||
# Color function for markers
|
# Color function for markers
|
||||||
def get_gesture_color(name):
|
def get_gesture_color(name):
|
||||||
if 'index' in name:
|
"""Assign colors to gesture types."""
|
||||||
return 'green'
|
name_str = name.decode('utf-8') if isinstance(name, bytes) else name
|
||||||
elif 'middle' in name:
|
if 'open' in name_str:
|
||||||
|
return 'cyan'
|
||||||
|
elif 'fist' in name_str:
|
||||||
return 'blue'
|
return 'blue'
|
||||||
elif 'thumb' in name:
|
elif 'hook' in name_str:
|
||||||
return 'orange'
|
return 'orange'
|
||||||
|
elif 'thumb' in name_str:
|
||||||
|
return 'green'
|
||||||
|
elif 'rest' in name_str:
|
||||||
return 'gray'
|
return 'gray'
|
||||||
|
return 'red'
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL 16 CHANNELS
|
# 6) PLOT ALL FEATURES (RMS, WL, ZC, SSC) FOR ALL CHANNELS
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
feature_names = ['rms', 'wl', 'zc', 'ssc']
|
feature_names = ['rms', 'wl', 'zc', 'ssc']
|
||||||
feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
|
feature_titles = ['RMS Envelope', 'Waveform Length (WL)', 'Zero Crossings (ZC)', 'Slope Sign Changes (SSC)']
|
||||||
feature_colors = ['red', 'blue', 'green', 'purple']
|
feature_colors = ['red', 'blue', 'green', 'purple']
|
||||||
feature_ylabels = ['Amplitude', 'WL (a.u.)', 'Count', 'Count']
|
feature_ylabels = ['Amplitude (mV)', 'WL (a.u.)', 'Count', 'Count']
|
||||||
|
|
||||||
for config in plot_configs:
|
# Determine subplot grid based on channel count
|
||||||
t_start = config['start_time'] - 0.5
|
if n_channels <= 4:
|
||||||
t_end = t_start + 10.0
|
n_rows, n_cols = 2, 2
|
||||||
|
elif n_channels <= 9:
|
||||||
|
n_rows, n_cols = 3, 3
|
||||||
|
else:
|
||||||
|
n_rows, n_cols = 4, 4
|
||||||
|
|
||||||
# Mask for windowed time vector
|
# Plot entire session for each feature
|
||||||
mask_win = (time_windows >= t_start) & (time_windows <= t_end)
|
for feat_idx, feat_name in enumerate(feature_names):
|
||||||
t_win_rel = time_windows[mask_win] - t_start
|
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 8), sharex=True, sharey=True)
|
||||||
|
|
||||||
# Get gestures in window
|
|
||||||
gesture_mask = (prompts['time'] >= t_start) & (prompts['time'] <= t_end) & \
|
|
||||||
(prompts['name'].str.contains(config['filter']))
|
|
||||||
gestures_in_window = prompts[gesture_mask]
|
|
||||||
|
|
||||||
# Create one figure per feature
|
|
||||||
for feat_idx, feat_name in enumerate(feature_names):
|
|
||||||
fig, axes = plt.subplots(4, 4, figsize=(10, 8), sharex=True, sharey=True)
|
|
||||||
axes = axes.flatten()
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
# Compute global Y range for this feature across all channels
|
||||||
|
all_feat_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
|
||||||
|
feat_min, feat_max = all_feat_vals.min(), all_feat_vals.max()
|
||||||
|
feat_margin = (feat_max - feat_min) * 0.05 if feat_max > feat_min else 1
|
||||||
|
feat_ylim = (feat_min - feat_margin, feat_max + feat_margin)
|
||||||
|
|
||||||
for ch in range(n_channels):
|
for ch in range(n_channels):
|
||||||
ax = axes[ch]
|
ax = axes[ch]
|
||||||
|
|
||||||
# Plot windowed feature
|
# Plot windowed feature over time
|
||||||
feat_data = all_features[ch][feat_name][mask_win]
|
feat_data = all_features[ch][feat_name]
|
||||||
ax.plot(t_win_rel, feat_data, linewidth=1, color=feature_colors[feat_idx])
|
ax.plot(time_windows, feat_data, linewidth=0.8, color=feature_colors[feat_idx])
|
||||||
ax.set_title(f"Ch {ch}", fontsize=9)
|
ax.set_title(f"Channel {ch}", fontsize=10)
|
||||||
|
ax.set_ylim(feat_ylim) # Same Y range for all channels
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
# Add gesture markers
|
# Add gesture transition markers from labels
|
||||||
for _, row in gestures_in_window.iterrows():
|
prev_label = None
|
||||||
t_g = row['time'] - t_start
|
for i, lbl in enumerate(labels):
|
||||||
color = get_gesture_color(row['name'])
|
lbl_str = lbl.decode('utf-8') if isinstance(lbl, bytes) else lbl
|
||||||
ax.axvline(t_g, color=color, linestyle='--', alpha=0.5, linewidth=0.5)
|
t = label_times_in_data[i] # Time in flattened data
|
||||||
|
if lbl_str != prev_label and lbl_str != 'rest':
|
||||||
|
# Only show markers within the time_windows range
|
||||||
|
if t <= time_windows[-1]:
|
||||||
|
color = get_gesture_color(lbl)
|
||||||
|
ax.axvline(t, color=color, linestyle='--', alpha=0.6, linewidth=1)
|
||||||
|
prev_label = lbl_str
|
||||||
|
|
||||||
# Set subtitle based on gesture type
|
# Hide unused subplots
|
||||||
if 'Index' in config['name']:
|
for ch in range(n_channels, len(axes)):
|
||||||
subtitle = "(Green=index, Blue=middle)"
|
axes[ch].set_visible(False)
|
||||||
else:
|
|
||||||
subtitle = "(Orange=thumb)"
|
|
||||||
|
|
||||||
fig.suptitle(f"{feature_titles[feat_idx]} - {config['name']} Gestures\n{subtitle}", fontsize=12)
|
# Legend for gesture colors
|
||||||
|
from matplotlib.lines import Line2D
|
||||||
|
legend_elements = [
|
||||||
|
Line2D([0], [0], color='cyan', linestyle='--', label='Open'),
|
||||||
|
Line2D([0], [0], color='blue', linestyle='--', label='Fist'),
|
||||||
|
Line2D([0], [0], color='orange', linestyle='--', label='Hook Em'),
|
||||||
|
Line2D([0], [0], color='green', linestyle='--', label='Thumbs Up'),
|
||||||
|
]
|
||||||
|
fig.legend(handles=legend_elements, loc='upper right', fontsize=9)
|
||||||
|
|
||||||
|
fig.suptitle(f"{feature_titles[feat_idx]} - All Channels", fontsize=14)
|
||||||
fig.supxlabel("Time (s)")
|
fig.supxlabel("Time (s)")
|
||||||
fig.supylabel(feature_ylabels[feat_idx])
|
fig.supylabel(feature_ylabels[feat_idx])
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
@@ -235,9 +362,13 @@ for config in plot_configs:
|
|||||||
# SUMMARY: Feature statistics across all channels
|
# SUMMARY: Feature statistics across all channels
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("FEATURE SUMMARY (all channels, all windows)")
|
print(f"FEATURE SUMMARY ({n_channels} channels, {n_feat_windows} windows each)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
for feat_name in ['rms', 'wl', 'zc', 'ssc']:
|
for feat_name in ['rms', 'wl', 'zc', 'ssc']:
|
||||||
all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
|
all_vals = np.concatenate([all_features[ch][feat_name] for ch in range(n_channels)])
|
||||||
print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | "
|
print(f"{feat_name.upper():4s} | min: {all_vals.min():10.4f} | max: {all_vals.max():10.4f} | "
|
||||||
f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}")
|
f"mean: {all_vals.mean():10.4f} | std: {all_vals.std():10.4f}")
|
||||||
|
|
||||||
|
# Close the HDF5 file
|
||||||
|
file.close()
|
||||||
|
print("\n[Done] HDF5 file closed.")
|
||||||
Binary file not shown.
@@ -99,13 +99,13 @@ class RealSerialStream:
|
|||||||
@note Requires pyserial: pip install pyserial
|
@note Requires pyserial: pip install pyserial
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, port: str = None, baud_rate: int = 115200, timeout: float = 1.0):
|
def __init__(self, port: str = None, baud_rate: int = 921600, timeout: float = 0.05):
|
||||||
"""
|
"""
|
||||||
@brief Initialize the serial stream.
|
@brief Initialize the serial stream.
|
||||||
|
|
||||||
@param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux).
|
@param port Serial port name (e.g., 'COM3' on Windows, '/dev/ttyUSB0' on Linux).
|
||||||
If None, will attempt to auto-detect the ESP32.
|
If None, will attempt to auto-detect the ESP32.
|
||||||
@param baud_rate Communication speed in bits per second. Default 115200 matches ESP32.
|
@param baud_rate Communication speed in bits per second. Default 921600 for high-throughput streaming.
|
||||||
@param timeout Read timeout in seconds for readline().
|
@param timeout Read timeout in seconds for readline().
|
||||||
"""
|
"""
|
||||||
self.port = port
|
self.port = port
|
||||||
@@ -233,6 +233,9 @@ class RealSerialStream:
|
|||||||
"Must call connect() first."
|
"Must call connect() first."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Flush any stale data before starting fresh stream
|
||||||
|
self.serial.reset_input_buffer()
|
||||||
|
|
||||||
# Send start command
|
# Send start command
|
||||||
start_cmd = {"cmd": "start"}
|
start_cmd = {"cmd": "start"}
|
||||||
self._send_json(start_cmd)
|
self._send_json(start_cmd)
|
||||||
|
|||||||
Reference in New Issue
Block a user