diff --git a/config/data_module/emg_pretrain_data_module.yaml b/config/data_module/emg_pretrain_data_module.yaml index 55b293f..9bbc3f6 100644 --- a/config/data_module/emg_pretrain_data_module.yaml +++ b/config/data_module/emg_pretrain_data_module.yaml @@ -27,14 +27,36 @@ data_module: train_val_split_ratio: 0.8 datasets: demo_dataset: null - emg2pose: + emg2pose_train: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/emg2pose/h5/" - db6: + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/train.h5" + emg2pose_val: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/ninapro/DB6/h5/" + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/val.h5" + emg2pose_test: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/test.h5" + db6_train: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/train.h5" + pad_up_to_max_chans: 16 + db6_val: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/val.h5" + pad_up_to_max_chans: 16 + db6_test: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/test.h5" + pad_up_to_max_chans: 16 + db7_train: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/train.h5" + pad_up_to_max_chans: 16 + db7_val: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/val.h5" pad_up_to_max_chans: 16 - db7: + db7_test: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/ninapro/DB7/h5/" + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/test.h5" pad_up_to_max_chans: 16 diff --git a/config/experiment/TinyMyo_finetune.yaml b/config/experiment/TinyMyo_finetune.yaml index 9532c5f..224e60f 100644 --- a/config/experiment/TinyMyo_finetune.yaml +++ b/config/experiment/TinyMyo_finetune.yaml @@ -21,7 +21,7 @@ tag: EMG_finetune gpus: -1 num_nodes: 1 -num_workers: 8 +num_workers: 4 batch_size: 32 max_epochs: 50 @@ -32,7 +32,6 @@ finetune_pretrained: True resume: False layerwise_lr_decay: 0.90 -scheduler_type: cosine pretrained_checkpoint_path: null pretrained_safetensors_path: null @@ -41,7 +40,7 @@ finetuning: freeze_layers: False io: - base_output_path: ${env:DATA_PATH} + base_output_path: ${env:LOG_DIR} checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints version: 0 @@ -52,23 +51,16 @@ defaults: - override /task: finetune_task_TinyMyo - override /criterion: finetune_criterion -masking: - patch_size: [1, 20] - masking_ratio: 0.50 - unmasked_loss_coeff: 0.1 - -input_normalization: - normalize: False - model: + n_layer: 8 num_classes: 6 - classification_type: "ml" + task: "classification" trainer: accelerator: gpu num_nodes: ${num_nodes} devices: ${gpus} - strategy: auto + strategy: ddp_find_unused_parameters_true max_epochs: ${max_epochs} model_checkpoint: @@ -89,7 +81,7 @@ optimizer: optim: 'AdamW' lr: 5e-4 betas: [0.9, 0.98] - weight_decay: 0.01 + weight_decay: 1e-2 scheduler: trainer: ${trainer} @@ -98,3 +90,10 @@ scheduler: warmup_epochs: 5 total_training_opt_steps: ${max_epochs} t_in_epochs: True + +wandb: + entity: "TinyMyo" + project: "TinyMyo" + save_dir: ${env:LOG_DIR} + run_name: "TinyMyo-Finetuning" + offline: True \ No newline at end of file diff --git a/config/experiment/TinyMyo_pretrain.yaml b/config/experiment/TinyMyo_pretrain.yaml index 3485edd..9c0f287 100644 --- a/config/experiment/TinyMyo_pretrain.yaml +++ b/config/experiment/TinyMyo_pretrain.yaml @@ -22,15 +22,15 @@ tag: EMG_pretrain gpus: -1 num_nodes: 1 num_workers: 8 -batch_size: 128 -max_epochs: 50 +batch_size: 512 +max_epochs: 30 final_validate: True final_test: False pretrained_checkpoint_path: null io: - base_output_path: ${env:DATA_PATH} + base_output_path: ${env:LOG_DIR} checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints version: 0 @@ -49,22 +49,23 @@ masking: input_normalization: normalize: True -scheduler: - trainer: ${trainer} - min_lr: 1e-6 - warmup_lr_init: 1e-6 - warmup_epochs: 10 - total_training_opt_steps: ${max_epochs} - t_in_epochs: True +model: + n_layer: 8 + drop_path: 0.0 # Stochastic depth disabled for pretraining + num_classes: 0 # No classification head for pretraining + task: pretraining + +criterion: + loss_type: 'smooth_l1' trainer: accelerator: gpu num_nodes: ${num_nodes} devices: ${gpus} strategy: auto + precision: "bf16-mixed" max_epochs: ${max_epochs} - gradient_clip_val: 3 - accumulate_grad_batches: 8 + gradient_clip_val: 1 model_checkpoint: save_last: True @@ -73,7 +74,21 @@ model_checkpoint: save_top_k: 1 optimizer: - optim: 'AdamW' - lr: 1e-4 + lr: 5e-4 betas: [0.9, 0.98] - weight_decay: 0.01 + weight_decay: 1e-2 + +scheduler: + trainer: ${trainer} + min_lr: 1e-6 + warmup_lr_init: 1e-6 + warmup_epochs: 3 + total_training_opt_steps: ${max_epochs} + t_in_epochs: True + +wandb: + entity: "TinyMyo" + project: "TinyMyo" + save_dir: ${env:LOG_DIR} + run_name: "TinyMyo-Pretraining" + offline: True diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..6dfe643 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1 @@ +"""Local dataset implementations for BioFoundation.""" diff --git a/datasets/emg_finetune_dataset.py b/datasets/emg_finetune_dataset.py index e6ad10f..56e4fc4 100644 --- a/datasets/emg_finetune_dataset.py +++ b/datasets/emg_finetune_dataset.py @@ -16,130 +16,62 @@ # * * # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -from collections import deque from typing import Tuple, Union import h5py +import numpy as np import torch class EMGDataset(torch.utils.data.Dataset): - """ - A PyTorch Dataset class for loading EMG (Electromyography) data from HDF5 files. - This dataset supports lazy loading of data from HDF5 files, with optional caching - to improve performance during training. It can be used for both fine-tuning (with labels) - and inference (without labels) modes. The class handles data preprocessing, such as - converting to tensors and optional unsqueezing. + """PyTorch Dataset for loading EMG data from HDF5 files. + Attributes: - hdf5_file (str): Path to the HDF5 file containing the dataset. - unsqueeze (bool): Whether to add an extra dimension to the input data (default: False). - finetune (bool): If True, loads both data and labels; if False, loads only data (default: True). - cache_size (int): Maximum number of samples to cache in memory (default: 1500). - use_cache (bool): Whether to use caching for faster access (default: True). - regression (bool): If True, treats labels as regression targets (float); else, classification (long) (default: False). - num_samples (int): Total number of samples in the dataset, determined from HDF5 file. - data (h5py.File or None): Handle to the opened HDF5 file (lazy-loaded). - X_ds (h5py.Dataset or None): Dataset handle for input data. - Y_ds (h5py.Dataset or None): Dataset handle for labels (if finetune is True). - cache (dict): Dictionary for caching data items (if use_cache is True). - cache_queue (deque): Queue to track the order of cached items for LRU eviction. - Note: - - The HDF5 file is expected to have 'data' and 'label' datasets. - - Caching uses an LRU (Least Recently Used) eviction policy. - - Suitable for use with PyTorch DataLoader for batched loading. + hdf5_file (str): Path to the HDF5 source file. + finetune (bool): If True, returns (data, label). If False, returns data only. + regression (bool): If True, labels are treated as floats. Else, longs. """ - def __init__( self, hdf5_file: str, - unsqueeze: bool = False, finetune: bool = True, - cache_size: int = 1500, - use_cache: bool = True, regression: bool = False, + verbose: bool = False, ): self.hdf5_file = hdf5_file - self.unsqueeze = unsqueeze - self.cache_size = cache_size self.finetune = finetune - self.use_cache = use_cache self.regression = regression - self.data = None - self.X_ds = None - self.Y_ds = None - - # Open once to get length, then close immediately with h5py.File(self.hdf5_file, "r") as f: - self.num_samples = f["data"].shape[0] - - if self.use_cache: - self.cache: dict[int, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} - self.cache_queue = deque() - - def _open_file(self) -> None: - # 'rdcc_nbytes' to increase the raw data chunk cache size - self.data = h5py.File(self.hdf5_file, "r", rdcc_nbytes=1024 * 1024 * 4) - if self.data is not None: - self.X_ds = self.data["data"] - self.Y_ds = self.data["label"] - - def __len__(self) -> int: - return self.num_samples - - def __getitem__(self, index): - # Check Cache - if self.use_cache and index in self.cache: - return self._process_data(self.cache[index]) - - # Open file (Lazy Loading for Multiprocessing) - if self.data is None: - self._open_file() - - # Read Data, HDF5 slicing returns numpy array - X_np = self.X_ds[index] - X = torch.from_numpy(X_np).float() + X_np = f["data"][:] + Y_np = f["label"][:] if self.finetune else None + self.X_tensor = torch.from_numpy(X_np).float().contiguous() if self.finetune: - Y_np = self.Y_ds[index] if self.regression: - Y = torch.from_numpy(Y_np).float() + self.Y_tensor = torch.from_numpy(Y_np).float().contiguous() else: - # Ensure scalar is converted properly - Y = torch.tensor(Y_np, dtype=torch.long) - - data_item = (X, Y) - else: - data_item = X + self.Y_tensor = torch.from_numpy(Y_np).long().contiguous() + if verbose: + uniq, cnt = np.unique(Y_np, return_counts=True) + print( + f"[EMGDataset] {self.hdf5_file}: label min={Y_np.min()}, max={Y_np.max()}, classes={len(uniq)}" + ) + print(f"[EMGDataset] {self.hdf5_file}: class hist={dict(zip(uniq.tolist(), cnt.tolist()))}") - # Update Cache - if self.use_cache: - # If cache is full, remove oldest item from dict AND queue - if len(self.cache) >= self.cache_size: - oldest_index = self.cache_queue.popleft() - del self.cache[oldest_index] + self.num_samples = self.X_tensor.shape[0] # [N, C, T] - self.cache[index] = data_item - self.cache_queue.append(index) - - return self._process_data(data_item) + def __len__(self) -> int: + """Returns the total number of samples in the dataset.""" + return self.num_samples - def _process_data(self, data_item): - """Helper to handle squeezing/returning uniformly.""" - if self.finetune: - X, Y = data_item - else: - X = data_item - Y = None + def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Retrieves the EMG data and optional label at the specified index.""" - if self.unsqueeze: - X = X.unsqueeze(0) + X = self.X_tensor[index] if self.finetune: + Y = self.Y_tensor[index] return X, Y - else: - return X - def __del__(self): - if self.data is not None: - self.data.close() + return X diff --git a/datasets/emg_pretrain_dataset.py b/datasets/emg_pretrain_dataset.py index 15a3248..39063b4 100644 --- a/datasets/emg_pretrain_dataset.py +++ b/datasets/emg_pretrain_dataset.py @@ -17,120 +17,186 @@ # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* import os -import threading -from collections import deque -from typing import Optional +import fcntl +from multiprocessing import shared_memory import h5py +from tqdm import tqdm +import numpy as np import torch +from joblib import Parallel, delayed import torch.nn.functional as F from torch.utils.data import Dataset -# thread-local storage for per-worker file handle -_thread_local = threading.local() - - -def _get_h5_handle(path): - h5f = getattr(_thread_local, "h5f", None) - if h5f is None or h5f.filename != path: - h5f = h5py.File(path, "r") - _thread_local.h5f = h5f - return h5f - - class EMGPretrainDataset(Dataset): - """ - A PyTorch Dataset class for loading EMG (electromyography) data from HDF5 files for pretraining purposes. - This dataset discovers all .h5 files in the specified directory, builds an index of samples across all files, - and provides access to individual samples. It supports optional caching to improve performance, channel padding, - and squeezing of the data tensor. - Args: - data_dir (str): Path to the directory containing .h5 files. - squeeze (bool, optional): Whether to squeeze the data tensor. Defaults to False. - cache_size (int, optional): Size of the cache. Defaults to 1500. - use_cache (bool, optional): Enable caching. Defaults to True. - pad_up_to_max_chans (int | None, optional): Number of channels to pad to. Defaults to None. - max_samples (int | None, optional): Limit the total number of samples. Defaults to None. - Raises: - RuntimeError: If no .h5 files are found in the data directory. - Note: - The .h5 files are expected to have a 'data' dataset with shape (N, C, T), where N is the number of samples, - C is the number of channels, and T is the number of time points. - Caching uses a simple LRU mechanism with a deque to track access order. - The __del__ method ensures that any open HDF5 file handles are closed upon deletion. + """Shared-memory optimized Dataset for large-scale EMG pretraining. + + This dataset loads HDF5 data into a shared RAM block (POSIX shared memory) + to allow fast access across multiple worker processes without + the serial overhead of HDF5 reads or the memory duplication of standard + multiprocessing. + + Attributes: + hdf5_file_path (str): Path to the single HDF5 source file. + minmax (bool): Whether to apply min-max scaling to [-1, 1]. + pad_up_to_max_chans (Optional[int]): If set, zero-pads channels to this count. + total_len (int): Total number of samples across all HDF5 groups. + ram_data (np.ndarray): View into the shared memory block. """ def __init__( self, - data_dir: str, - squeeze: bool = False, - cache_size: int = 1500, - use_cache: bool = True, - pad_up_to_max_chans: Optional[int] = None, - max_samples: Optional[int] = None, + hdf5_file: str, + minmax: bool = True, + pad_up_to_max_chans: int | None = None, + n_jobs: int = 16 ): + """Initializes the shared memory dataset and loads data if needed. + + Args: + hdf5_file (str): Path to the HDF5 file. + minmax (bool): Enable scaling. Defaults to True. + pad_up_to_max_chans (Optional[int]): Target channel count for padding. + n_jobs (int): Number of parallel threads for the initial load. + """ super().__init__() - self.squeeze = squeeze - self.cache_size = cache_size - self.use_cache = use_cache + self.minmax = minmax self.pad_up_to_max_chans = pad_up_to_max_chans - # discover all .h5 files - self.file_paths = sorted(os.path.join(data_dir, fn) for fn in os.listdir(data_dir) if fn.endswith(".h5")) - if not self.file_paths: - raise RuntimeError(f"No .h5 files in {data_dir!r}") - - # build index of (file_path, sample_idx) - self.index_map = [] - for fp in self.file_paths: - with h5py.File(fp, "r") as h5f: - n = h5f["data"].shape[0] - for i in range(n): - self.index_map.append((fp, i)) - if max_samples is not None: - self.index_map = self.index_map[:max_samples] - - # Cache to store recently accessed samples - if use_cache: - self.cache = {} - self.cache_queue = deque(maxlen=self.cache_size) - - def __len__(self): - return len(self.index_map) - - def __getitem__(self, index): - if self.use_cache and index in self.cache: - cached_data = self.cache[index] - X = cached_data - else: - fp, local_idx = self.index_map[index] - h5f = _get_h5_handle(fp) - np_x = h5f["data"][local_idx] # shape (C, T) - X = torch.from_numpy(np_x).float() - - if self.use_cache: - # If cache is full, remove oldest item from dict AND queue - if len(self.cache) >= self.cache_size: - oldest_index = self.cache_queue.popleft() - del self.cache[oldest_index] - - self.cache[index] = X - self.cache_queue.append(index) - - # squeeze if requested - if self.squeeze: - X = X.unsqueeze(0) - - # pad channels if requested + self.rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", 0))) + + # This class will be instantiated once per file (e.g., train.h5, val.h5) + self.hdf5_file_path = hdf5_file + if not os.path.exists(self.hdf5_file_path) or not self.hdf5_file_path.endswith(".h5"): + raise ValueError(f"Expected hdf5_file to be a path to a single HDF5 file, but got {self.hdf5_file_path}") + + self.ram_data = None + self.shm_block = None + + file_name = os.path.basename(self.hdf5_file_path) + + # Calculate total shape from all groups + with h5py.File(self.hdf5_file_path, "r") as hf: + group_keys = sorted(hf.keys()) + if not group_keys: raise ValueError(f"HDF5 file {file_name} contains no data groups.") + + self.group_offsets = [0] + total_samples = 0 + for key in group_keys: + num_in_group = hf[key]['X'].shape[0] + total_samples += num_in_group + self.group_offsets.append(total_samples) + + # Get other dimensions from the first group + _, C, T = hf[group_keys[0]]['X'].shape + final_shape = (total_samples, C, T) + + self.total_len = total_samples + + # Allocate shared memory and load in parallel + clean_name = f"{os.path.splitext(file_name)[0]}_{final_shape[0]}" + shm_name = f"emg_shm_{clean_name}" + target_dtype = np.float16 # cast to fp16 to fit in RAM + num_bytes = int(np.prod(final_shape)) * np.dtype(target_dtype).itemsize + + lock_path = f"/tmp/{shm_name}.lock" + ready_path = f"/tmp/{shm_name}.ready" + file_mtime = os.path.getmtime(self.hdf5_file_path) + ready_token = f"{os.path.abspath(self.hdf5_file_path)}|{file_mtime}|{num_bytes}" + + with open(lock_path, "w") as lockf: + fcntl.flock(lockf, fcntl.LOCK_EX) + + shm = None + shm_needs_load = True + + if os.path.exists(ready_path): + try: + with open(ready_path, "r") as rf: + token = rf.read().strip() + if token == ready_token: + existing = shared_memory.SharedMemory(name=shm_name) + if existing.size == num_bytes: + shm = existing + shm_needs_load = False + else: + existing.close() + try: + existing.unlink() + except FileNotFoundError: + pass + except Exception: + shm_needs_load = True + + if shm_needs_load: + try: + stale = shared_memory.SharedMemory(name=shm_name) + stale.close() + try: + stale.unlink() + except FileNotFoundError: + pass + except FileNotFoundError: + pass + + print(f"[PID {os.getpid()}] Allocating {num_bytes / 1e9:.2f} GB of Shared RAM for {file_name}...") + shm = shared_memory.SharedMemory(create=True, name=shm_name, size=num_bytes) + shm_arr = np.ndarray(final_shape, dtype=target_dtype, buffer=shm.buf) + + def load_group(group_idx): + key = group_keys[group_idx] + start_offset = self.group_offsets[group_idx] + end_offset = self.group_offsets[group_idx + 1] + with h5py.File(self.hdf5_file_path, "r") as local_hf: + data_chunk = local_hf[key]['X'][:].astype(target_dtype) + shm_arr[start_offset:end_offset] = data_chunk + + print(f"[PID {os.getpid()}] Parallel loading groups from {file_name} using {n_jobs} cores...") + Parallel(n_jobs=n_jobs, backend="threading")( + delayed(load_group)(i) for i in tqdm(range(len(group_keys)), desc=f"Loading {file_name}") + ) + with open(ready_path, "w") as wf: + wf.write(ready_token) + print(f"[PID {os.getpid()}] Finished loading {file_name}!") + + self.shm_block = shm + self.ram_data = np.ndarray(final_shape, dtype=target_dtype, buffer=shm.buf) + + def __len__(self) -> int: + """Returns the total number of samples.""" + return self.total_len + + def _minmax_scale(self, x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: + """Scales EMG signal to [-1, 1] range.""" + maxv = x.amax(dim=-1, keepdim=True) + minv = x.amin(dim=-1, keepdim=True) + x = (x - minv) / (maxv - minv + eps) + return (x - 0.5) * 2 + + def __getitem__(self, idx: int) -> torch.Tensor: + """Retrieves a single sample as a float32 tensor. + + Args: + idx (int): Global index across all groups. + + Returns: + torch.Tensor: Normalized and padded EMG tensor of shape (C, T). + """ + if idx < 0 or idx >= self.total_len: + raise IndexError(f"Index {idx} out of range for dataset of size {self.total_len}") + + # Direct, instant access from the single RAM array + X = torch.tensor(self.ram_data[idx], dtype=torch.float32).contiguous() + + if self.minmax: X = self._minmax_scale(X) if self.pad_up_to_max_chans is not None: C = X.shape[0] to_pad = self.pad_up_to_max_chans - C - if to_pad > 0: - X = F.pad(X, (0, 0, 0, to_pad)) # (channels, time) -> pad channels - + if to_pad > 0: X = F.pad(X.T, (0, to_pad)).T return X + def __del__(self): - h5f = getattr(_thread_local, "h5f", None) - if h5f is not None: - h5f.close() + if hasattr(self, 'shm_block') and self.shm_block is not None: + try: self.shm_block.close() + except Exception: pass \ No newline at end of file diff --git a/docs/model/TinyMyo.md b/docs/model/TinyMyo.md index 35a332d..06a52bc 100644 --- a/docs/model/TinyMyo.md +++ b/docs/model/TinyMyo.md @@ -113,9 +113,9 @@ TinyMyo supports three major categories: Evaluated on: -* **Ninapro DB5** (52 classes, 10 subjects) -* **EPN-612** (5 classes, 612 subjects) -* **UCI EMG** (6 classes, 36 subjects) +* **Ninapro DB5** (52 classes, 10 subjects, 200 Hz) +* **EPN-612** (5 classes, 612 subjects, 200 Hz) +* **UCI EMG** (6 classes, 36 subjects, 200 Hz) * **Generic Neuromotor Interface** (Meta wristband; 9 gestures) * Repository: [MatteoFasulo/generic-neuromotor-interface](https://github.com/MatteoFasulo/generic-neuromotor-interface) @@ -126,8 +126,8 @@ Evaluated on: * EMG filtering: **20–90 Hz** bandpass + 50 Hz notch * Windows: - * **200 ms** (best for DB5) - * **1000 ms** (best for EPN & UCI) + * **1 sec** (best for DB5) + * **5 sec** (best for EPN & UCI) * Per-channel z-scoring * Linear classification head @@ -138,9 +138,9 @@ Evaluated on: | Dataset | Metric | Result | | ------------------------ | -------- | ----------------- | -| **Ninapro DB5 (200 ms)** | Accuracy | **89.41 ± 0.16%** | -| **EPN-612 (1000 ms)** | Accuracy | **96.74 ± 0.09%** | -| **UCI EMG (1000 ms)** | Accuracy | **97.56 ± 0.32%** | +| **Ninapro DB5 (1 sec)** | Accuracy | **89.41 ± 0.16%** | +| **EPN-612 (5 sec)** | Accuracy | **96.74 ± 0.09%** | +| **UCI EMG (5 sec)** | Accuracy | **97.56 ± 0.32%** | | **Neuromotor Interface** | CLER | **0.153 ± 0.006** | TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI. @@ -149,9 +149,9 @@ TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI. ### **4.2 Hand Kinematic Regression** -Dataset: **Ninapro DB8** +Dataset: **Ninapro DB8** (2000 Hz) Task: Regress **5 joint angles (DoA)** -Preprocessing: z-score only; windows of **200 ms** or **1000 ms** +Preprocessing: z-score only; windows of **100 ms** or **500 ms** **Regression head (788k params)** @@ -162,7 +162,7 @@ Preprocessing: z-score only; windows of **200 ms** or **1000 ms** **Performance (Fine-tuned)** -* **MAE = 8.77 ± 0.12°** (1000 ms window) +* **MAE = 8.77 ± 0.12°** (500 ms window) Although previous works achieve lower MAE (≈6.89°), those models are **subject-specific**, whereas TinyMyo trains **one model across all subjects**, a significantly harder problem. @@ -219,7 +219,7 @@ Key elements: * Integer softmax, integer LayerNorm, integer GELU * Static liveness-based memory arena -**Runtime (NinaPro EPN612 pipeline):** +**Runtime (EPN612 dataset):** * **0.785 s inference time** * **44.91 mJ energy** diff --git a/models/TinyMyo.py b/models/TinyMyo.py index 5e15687..5d61e22 100644 --- a/models/TinyMyo.py +++ b/models/TinyMyo.py @@ -1,6 +1,6 @@ -import math from dataclasses import dataclass -from typing import Literal, Optional, Tuple +import math +from typing import Literal, Optional import torch import torch.nn as nn @@ -113,7 +113,7 @@ def forward( # reshape the cache for broadcasting # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, # otherwise has shape [1, s, 1, h_d // 2, 2] - rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + rope_cache = rope_cache.reshape(-1, xshaped.size(1), 1, xshaped.size(3), 2) # tensor has shape [b, s, n_h, h_d // 2, 2] x_out = torch.stack( @@ -240,7 +240,12 @@ def __post_init__(self): self.proj = nn.Linear(self.dim, self.dim) self.p_drop = nn.Dropout(self.proj_drop) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + pos_ids: torch.Tensor = None, + attn_mask: torch.Tensor = None, + ) -> torch.Tensor: """Forward pass for rotary self-attention block.""" B, N, C = x.shape qkv = ( @@ -250,17 +255,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # (K, B, H, N, D) q, k, v = qkv.unbind(0) # each: (B, H, N, D) - q = self.rope(q) - k = self.rope(k) + # Transpose to [B, N, H, D] for RoPE + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + q = self.rope(q, input_pos=pos_ids) + k = self.rope(k, input_pos=pos_ids) + + # Transpose back to [B, H, N, D] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) - # pylint: disable=not-callable x = F.scaled_dot_product_attention( q, k, v, + attn_mask=attn_mask, # Mask out padded channels dropout_p=self.attn_drop if self.training else 0.0, is_causal=False, - enable_gqa=False, ) x = x.transpose(2, 1).reshape(B, N, C) @@ -321,8 +333,15 @@ def __post_init__(self): drop=self.drop, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.attn(self.norm1(x))) + def forward( + self, + x: torch.Tensor, + pos_ids: torch.Tensor = None, + attn_mask: torch.Tensor = None, + ) -> torch.Tensor: + x = x + self.drop_path1( + self.attn(self.norm1(x), pos_ids=pos_ids, attn_mask=attn_mask) + ) x = x + self.drop_path2(self.mlp(self.norm2(x))) return x @@ -377,15 +396,12 @@ class EMGClassificationHead(nn.Module): """ A classification head for EMG (Electromyography) data processing, designed to classify token embeddings into a specified number of classes. - This module takes token embeddings as input, applies a reduction strategy (either mean or concatenation across channels), + This module takes token embeddings as input, applies a reduction strategy (concatenation across channels), averages across patches, and then uses a linear classifier to produce logits for classification. embed_dim (int): Dimensionality of the token embeddings. num_classes (int): Number of output classes for classification. in_chans (int): Number of input channels (e.g., EMG channels). - reduction (str): Reduction strategy for combining channel features. Options are "mean" or "concat". - - "mean": Averages across channels, resulting in feature dimension of embed_dim. - - "concat": Concatenates across channels, resulting in feature dimension of in_chans * embed_dim. Defaults to "concat". Attributes: classifier (nn.Linear): Linear layer for final classification, mapping from reduced feature dimension to num_classes. @@ -394,23 +410,17 @@ class EMGClassificationHead(nn.Module): embed_dim: int num_classes: int in_chans: int - reduction: Literal["mean", "concat"] = "concat" def __post_init__(self): super().__init__() - # after reduction, feature_dim to either embed_dim or in_chans*embed_dim - feat_dim = ( - self.embed_dim - if self.reduction == "mean" - else self.in_chans * self.embed_dim - ) + feat_dim = self.in_chans * self.embed_dim self.classifier = nn.Linear(feat_dim, self.num_classes) # init weights self.apply(self._init_weights) - def _init_weights(self, m: nn.Module): + def _init_weights(self, m): if isinstance(m, nn.Linear): torch.nn.init.xavier_uniform_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: @@ -423,22 +433,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: logits: (B, num_classes) """ - _, N, _ = x.shape + B, N, D = x.shape num_patches = N // self.in_chans - if self.reduction == "mean": - # Reshape to (B, in_chans, num_patches, embed_dim) - x = rearrange(x, "b (c p) d -> b c p d", c=self.in_chans, p=num_patches) - # Take mean across the channels (in_chans) - x = x.mean(dim=1) # (B, num_patches, embed_dim) - elif self.reduction == "concat": - # Reshape to (B, num_patches, embed_dim * in_chans) - x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans, p=num_patches) - else: - raise ValueError(f"Unknown reduction type: {self.reduction}") - - # average across patches - x = x.mean(dim=1) # (B, feat_dim) + # Reshape to (B, num_patches, embed_dim * in_chans) + x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans, p=num_patches) + x = x.mean(dim=1) # apply projection to get logits logits = self.classifier(x) @@ -451,45 +451,28 @@ class EMGRegressionHead(nn.Module): A regression head for EMG (Electromyography) signals using convolutional layers. This module processes embedded features from a transformer model to perform - regression, predicting output signals of a specified dimension and length. It supports - different reduction methods for combining channel and patch features, followed by - convolutional layers for regression, and upsampling to a target sequence length. + regression, predicting output signals of a specified dimension and length upsampling to a target sequence length. Args: in_chans (int): Number of input channels (e.g., EMG channels). embed_dim (int): Dimension of the input embeddings. output_dim (int): Dimension of the output regression targets. - reduction (str): Method to reduce features across channels. - "mean" averages embeddings, "concat" concatenates them. Defaults to "concat". hidden_dim (int): Hidden dimension for the convolutional layers. Defaults to 256. dropout (float): Dropout probability applied after the first convolution. Defaults to 0.1. target_length (int): Desired length of the output sequence. If the input length differs, - linear interpolation is used to upsample. Defaults to 500. - - Attributes: - in_chans (int): Number of input channels. - embed_dim (int): Dimension of the embeddings. - output_dim (int): Dimension of the output. - reduction (str): Reduction method used. - dropout (float): Dropout rate. - target_length (int): Target output sequence length. + linear interpolation is used to upsample. Defaults to 1000. """ in_chans: int embed_dim: int output_dim: int - reduction: Literal["mean", "concat"] = "concat" hidden_dim: int = 256 dropout: float = 0.1 - target_length: int = 500 + target_length: int = 1000 def __post_init__(self): super().__init__() - feat_dim = ( - self.embed_dim - if self.reduction == "mean" - else self.in_chans * self.embed_dim - ) + feat_dim = self.in_chans * self.embed_dim self.regressor = nn.Sequential( nn.Conv1d(feat_dim, self.hidden_dim, kernel_size=1), @@ -530,12 +513,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: is the target sequence length and output_dim is the output dimension. """ # x: (B, num_tokens, token_dim) - if self.reduction == "mean": - x = rearrange(x, "b (c p) d -> b p d", c=self.in_chans) - elif self.reduction == "concat": - x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans) - else: - raise ValueError(f"Unknown reduction type: {self.reduction}") + x = rearrange( + x, "b (c p) d -> b p (c d)", c=self.in_chans, p=x.size(1) // self.in_chans + ) # conv head expects (B, C, L) x = x.transpose(1, 2) # (B, feat_dim, num_patches) @@ -575,10 +555,7 @@ class TinyMyo(nn.Module): drop_path (float, optional): Stochastic depth drop path rate. Defaults to 0.1. norm_layer (nn.Module, optional): Normalization layer class. Defaults to nn.LayerNorm. task (str, optional): Task type, one of "pretraining", "classification", or "regression". Defaults to "classification". - classification_type (str, optional): Type of classification (e.g., "ml" for multi-label). Defaults to "ml". - reduction_type (str, optional): Type of reduction to apply, either "mean" or "concat". Defaults to "concat". num_classes (int, optional): Number of classes for classification or output dimension for regression. Defaults to 53. - reg_target_len (int, optional): Target length for regression output. Defaults to 500. """ img_size: int = 1000 @@ -594,15 +571,16 @@ class TinyMyo(nn.Module): drop_path: float = 0.1 norm_layer = nn.LayerNorm task: Literal["pretraining", "classification", "regression"] = "classification" - classification_type: Literal["ml", "mc"] = "ml" - reduction_type: Literal["concat", "mean"] = "concat" num_classes: int = 53 - reg_target_len: int = 500 def __post_init__(self): super().__init__() + # Learnable mask token for pretraining (dimension: embed_dim) self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # Learnable channel IDs (always initialized to the maximum number of channels, but only the first C are used based on input) + self.channel_embed = nn.Parameter(torch.zeros(1, 16, 1, self.embed_dim)) + self.patch_embedding = PatchingModule( img_size=self.img_size, patch_size=self.patch_size, @@ -627,43 +605,47 @@ def __post_init__(self): ) self.norm = self.norm_layer(self.embed_dim) - if ( - self.task == "pretraining" or self.num_classes == 0 - ): # reconstruction (pre-training) + # reconstruction (pre-training) + if self.task == "pretraining": self.model_head = PatchReconstructionHead( img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim, ) - elif self.task == "classification" and self.num_classes > 0: + # classification or regression tasks + elif self.task == "classification": + assert self.num_classes > 0, ( + "num_classes must be > 0 for classification task" + ) self.model_head = EMGClassificationHead( embed_dim=self.embed_dim, num_classes=self.num_classes, in_chans=self.in_chans, - reduction=self.reduction_type, ) elif self.task == "regression": + assert self.num_classes > 0, ( + "num_classes must be > 0 for regression task (output dimension of regression targets)" + ) self.model_head = EMGRegressionHead( in_chans=self.in_chans, embed_dim=self.embed_dim, output_dim=self.num_classes, - reduction=self.reduction_type, - target_length=self.reg_target_len, + target_length=self.img_size, ) else: raise ValueError(f"Unknown task type {self.task}") self.initialize_weights() # Some checks - assert ( - self.img_size % self.patch_size == 0 - ), f"img_size ({self.img_size}) must be divisible by patch_size ({self.patch_size})" + assert self.img_size % self.patch_size == 0, ( + f"img_size ({self.img_size}) must be divisible by patch_size ({self.patch_size})" + ) def initialize_weights(self): """Initializes the model weights.""" - # Encodings Initializations code taken from the LaBraM paper trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.channel_embed, std=0.02) self.apply(self._init_weights) self.fix_init_weight() @@ -680,85 +662,85 @@ def _init_weights(self, m): nn.init.constant_(m.weight, 1.0) def fix_init_weight(self): - """ - Rescales the weights of attention and MLP layers to improve training stability. - - For each layer, weights are divided by sqrt(2 * layer_id). - """ + """Rescales the weights of attention and MLP layers to improve training stability.""" def rescale(param, layer_id): param.div_(math.sqrt(2.0 * layer_id)) - for layer_id, layer in enumerate(self.blocks, start=1): - attn_proj = getattr(getattr(layer, "attn", None), "proj", None) - if attn_proj is not None: - rescale(attn_proj.weight.data, layer_id) - - mlp_fc2 = getattr(getattr(layer, "mlp", None), "fc2", None) - if mlp_fc2 is not None: - rescale(mlp_fc2.weight.data, layer_id) + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) def prepare_tokens( self, x_signal: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: - """ - Prepares input tokens by embedding patches and applying masking if provided. - Args: - x_signal (torch.Tensor): Input signal tensor of shape (B, C, T). - mask (Optional[torch.Tensor]): Optional mask tensor of shape (B, C, T) indicating which patches to mask. - Returns: - torch.Tensor: Prepared token embeddings of shape (B, N, D) where N is number of patches and D is embed_dim. - """ - x_patched = self.patch_embedding(x_signal) # [B, N, D] - x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel + _, C, T = x_signal.shape + + x_patched = self.patch_embedding(x_signal) # [B, N, D] where N = C * P + P = T // self.patch_size + + # Unflatten to grid:[Batch, Channels, Patches, Dim] + x_patched = rearrange(x_patched, "B (C P) D -> B C P D", C=C, P=P) + if mask is not None: - mask_tokens = self.mask_token.repeat( - x_masked.shape[0], x_masked.shape[1], 1 - ) # (B, N, D) N = C * num_patches_per_channel - mask = rearrange( - mask, "B C (S P) -> B (C S) P", P=self.patch_size - ) # (B, C, T) -> (B, N, P) - mask = ( - (mask.sum(dim=-1) > 0).unsqueeze(-1).float() - ) # (B, N, 1), since a patch is either fully masked or not - x_masked = torch.where(mask.bool(), mask_tokens, x_masked) - return x_masked + # Reduce independent signal mask (B, C, T) to patch mask (B, C, P) + mask_p = rearrange(mask, "B C (P s) -> B C P s", s=self.patch_size).any( + dim=-1 + ) + mask_exp = mask_p.unsqueeze(-1) # (B, C, P, 1) - def forward( - self, x_signal: torch.Tensor, mask: Optional[torch.BoolTensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass of the TinyMyo model. + # Apply Masking + x_patched = torch.where( + mask_exp, self.mask_token.view(1, 1, 1, -1), x_patched + ) - This method processes the input signal tensor through the transformer blocks, - applies normalization, and then either reconstructs the signal or performs - classification/regression based on the model's configuration. + # Apply Spatial Identity + # Crop to C channels, fewer channels can be used + x = x_patched + self.channel_embed[:, :C, :, :] - Args: - x_signal (torch.Tensor): The input signal tensor of shape [B, C, T], - where B is batch size, C is number of channels, and T is the temporal dimension. - mask (Optional[torch.BoolTensor]): Optional boolean mask tensor for - masking certain tokens during processing. If None, no masking is applied. + return rearrange(x, "B C P D -> B (C P) D") - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - If num_classes == 0 (reconstruction mode): The reconstructed signal - tensor of shape [B, N, patch_size] and the original input signal tensor. - - Otherwise (classification/regression mode): The output tensor of shape - [B, Out] and the original input signal tensor. - """ - x_original = x_signal.clone() + def forward( + self, + x_signal: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + pad_mask_ch: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, C, T = x_signal.shape x = self.prepare_tokens(x_signal, mask=mask) - # forward pass through transformer blocks + _, N, _ = x.shape + # Number of patches per channel + P = T // self.patch_size + assert C * P == N, f"Token shape mismatch: expected {C * P} tokens, got {N}." + + # Cyclic Sequence for RoPE + # Sequence resets for each channel:[0..P-1, 0..P-1, ...] + pos_ids_single = torch.arange(P, device=x.device).repeat(C) + pos_ids = pos_ids_single.unsqueeze(0).expand(B, -1) # Shape (B, N) + + # Attention Masking for Padded Channels + attn_mask = None + if pad_mask_ch is not None: + token_pad_mask = pad_mask_ch.repeat_interleave(P, dim=1) + # Boolean mask for SDPA: True means keep, False means mask out + attn_mask = (~token_pad_mask).reshape(B, 1, 1, N) + + # Pass through RoPE-Attention Blocks for blk in self.blocks: - x = blk(x) - x_latent = self.norm(x) # [B, N, D] + x = blk(x, pos_ids=pos_ids, attn_mask=attn_mask) + + # Final normalization + x = self.norm(x) + + # Pass through task-specific head + return self.model_head(x) - if self.num_classes == 0: # reconstruction - x_reconstructed = self.model_head(x_latent) # [B, N, patch_size] - return x_reconstructed, x_original - else: # classification or regression - x_out = self.model_head(x_latent) # [B, Out] - return x_out, x_original +if __name__ == "__main__": + model = TinyMyo() + input_signal = torch.randn(1, 16, 1000) # (B, C, T) + with torch.no_grad(): + output = model(input_signal) + print("Input shape:", input_signal.shape) + print("Output shape:", output.shape) diff --git a/run_train.py b/run_train.py index 0431cfc..1d392fd 100644 --- a/run_train.py +++ b/run_train.py @@ -32,15 +32,19 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.strategies import DDPStrategy +import wandb + from util.train_utils import find_last_checkpoint_path -for env_var in ["DATA_PATH", "CHECKPOINT_DIR"]: +for env_var in ["DATA_PATH", "CHECKPOINT_DIR", "LOG_DIR"]: env_var_value = os.environ.get(env_var) if env_var_value is None or env_var_value == "#CHANGEME": - raise RuntimeError(f"Environment variable {env_var} is not set. Please set it before running the script.") + raise RuntimeError( + f"Environment variable {env_var} is not set. Please set it before running the script." + ) OmegaConf.register_new_resolver("env", lambda key: os.getenv(key)) OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) @@ -54,7 +58,7 @@ def train(cfg: DictConfig): seed_everything(cfg.seed) - date_format = "%d_%m_%H-%M" + date_format = "%d_%m_%H-%M-%S.%f" # Create your version_name version = f"{cfg.tag}_{datetime.now().strftime(date_format)}" @@ -64,6 +68,20 @@ def train(cfg: DictConfig): save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=version ) + loggers = [tb_logger] + + # Weights & Biases + wandb_logger = None + if cfg.wandb: + wandb_logger = WandbLogger( + entity=cfg.wandb.entity, + project=cfg.wandb.project, + save_dir=cfg.wandb.save_dir, + name=cfg.wandb.run_name if cfg.wandb.run_name else version, + offline=cfg.wandb.offline, + ) + loggers.append(wandb_logger) + # DataLoader print("===> Loading datasets") data_module = hydra.utils.instantiate(cfg.data_module) @@ -113,14 +131,14 @@ def train(cfg: DictConfig): del cfg.trainer.strategy trainer = Trainer( **cfg.trainer, - logger=tb_logger, + logger=loggers, callbacks=callbacks, strategy=DDPStrategy(find_unused_parameters=cfg.find_unused_parameters), ) else: trainer = Trainer( **cfg.trainer, - logger=tb_logger, + logger=loggers, callbacks=callbacks, ) @@ -158,12 +176,16 @@ def train(cfg: DictConfig): datamodule=data_module, results=results, accelerator=cfg.trainer.accelerator, - last_ckpt=best_ckpt, + ckpt=best_ckpt, + wandb_logger=wandb_logger, ) if not cfg.training: trainer.save_checkpoint(f"{checkpoint_dirpath}/last.ckpt") + if wandb.run is not None: + wandb.finish() + @pl.utilities.rank_zero_only def _run_test( @@ -171,14 +193,16 @@ def _run_test( datamodule: pl.LightningDataModule, results, accelerator, - last_ckpt, + ckpt, + wandb_logger=None, ): trainer = pl.Trainer( accelerator=accelerator, devices=1, + logger=wandb_logger if wandb_logger else [], ) print("===> Start testing") - test_results = trainer.test(module, datamodule=datamodule, ckpt_path=last_ckpt) + test_results = trainer.test(module, datamodule=datamodule, ckpt_path=ckpt) results["test_metrics"] = test_results return results, trainer diff --git a/tasks/finetune_task_EMG.py b/tasks/finetune_task_EMG.py index 8599443..b6d4991 100644 --- a/tasks/finetune_task_EMG.py +++ b/tasks/finetune_task_EMG.py @@ -16,13 +16,12 @@ # * * # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -from typing import Optional +from typing import Tuple import hydra import pytorch_lightning as pl import torch import torch.nn as nn -import torch_optimizer as torch_optim from omegaconf import DictConfig from safetensors.torch import load_file from torchmetrics import MetricCollection @@ -35,114 +34,135 @@ Precision, Recall, ) - -from util.train_utils import MinMaxNormalization +from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError class FinetuneTask(pl.LightningModule): - """ - PyTorch Lightning module for fine-tuning a classification model, with support for: - - - Classification types: - - `bc`: Binary Classification - - `ml`: Multi-Label Classification - - - Metric logging during training, validation, and testing, including accuracy, precision, recall, F1 score, AUROC, and more - - Optional input normalization with configurable normalization functions - - Custom optimizer support including SGD, Adam, AdamW, and LAMB - - Learning rate schedulers with configurable scheduling strategies - - Layer-wise learning rate decay for fine-grained learning rate control across model blocks + """PyTorch LightningModule for EMG fine-tuning tasks. + + This module manages the training, validation, and testing for both + classification (gesture detection) and regression (kinematic tracking) tasks. + It supports modular model architectures, metric collections, and layer-wise learning rate decay. + + Attributes: + model (nn.Module): The instantiated neural network. + num_classes (int): Number of target classes or regression outputs. + task (str): The specific task type ('classification' or 'regression'). + normalize (bool): Whether input normalization is enabled. + criterion (nn.Module): The loss function (CrossEntropy or L1). """ def __init__(self, hparams: DictConfig): - """ - Initialize the FinetuneTask module. + """Initializes the FinetuneTask with Hydra configurations. + + Sets up the model, loss functions, and metric collections based on the + provided task type. Args: - hparams (DictConfig): Hyperparameters and configuration loaded via Hydra. + hparams (DictConfig): Configuration object containing 'model', + 'optimizer', 'scheduler', and 'finetuning' parameters. """ super().__init__() self.save_hyperparameters(hparams) self.model = hydra.utils.instantiate(self.hparams.model) self.num_classes = self.hparams.model.num_classes - self.classification_type = self.hparams.model.classification_type - - # Enable normalization if specified in parameters - self.normalize = False - if "input_normalization" in self.hparams and self.hparams.input_normalization.normalize: - self.normalize = True - self.normalize_fct = MinMaxNormalization() - - # Loss function - self.criterion = nn.CrossEntropyLoss(label_smoothing=0.10) - - # Classification mode detection - if not isinstance(self.num_classes, int): - raise TypeError("Number of classes must be an integer.") - elif self.num_classes < 2: - raise ValueError("Number of classes must be at least 2.") - elif self.num_classes == 2: - self.classification_task = "binary" + self.task = self.hparams.model.task + + if self.task == "regression": + self.criterion = nn.L1Loss() + + # Metric + mean_metrics = MetricCollection( + { + "rmse": MeanSquaredError(squared=False), + "mae": MeanAbsoluteError(), + } + ) + + self.train_mean_metrics = mean_metrics.clone(prefix="train/") + self.val_mean_metrics = mean_metrics.clone(prefix="val/") + self.test_mean_metrics = mean_metrics.clone(prefix="test/") + else: - self.classification_task = "multiclass" + # Loss function + self.criterion = hydra.utils.instantiate(self.hparams.criterion) + + # Classification mode detection + if not isinstance(self.num_classes, int): + raise TypeError("Number of classes must be an integer.") + elif self.num_classes < 2: + raise ValueError("Number of classes must be at least 2.") + else: + self.classification_task = "multiclass" + + # Metrics + label_metrics = MetricCollection( + { + "micro_acc": Accuracy( + task=self.classification_task, + num_classes=self.num_classes, + average="micro", + ), + "macro_acc": Accuracy( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "recall": Recall( + task="multiclass", num_classes=self.num_classes, average="macro" + ), + "precision": Precision( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "f1": F1Score( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "cohen_kappa": CohenKappa( + task=self.classification_task, num_classes=self.num_classes + ), + } + ) + logit_metrics = MetricCollection( + { + "auroc": AUROC( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "average_precision": AveragePrecision( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + } + ) + self.train_label_metrics = label_metrics.clone(prefix="train/") + self.val_label_metrics = label_metrics.clone(prefix="val/") + self.test_label_metrics = label_metrics.clone(prefix="test/") + self.train_logit_metrics = logit_metrics.clone(prefix="train/") + self.val_logit_metrics = logit_metrics.clone(prefix="val/") + self.test_logit_metrics = logit_metrics.clone(prefix="test/") - # Metrics - label_metrics = MetricCollection( - { - "micro_acc": Accuracy( - task=self.classification_task, - num_classes=self.num_classes, - average="micro", - ), - "macro_acc": Accuracy( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "recall": Recall(task="multiclass", num_classes=self.num_classes, average="macro"), - "precision": Precision( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "f1": F1Score( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "cohen_kappa": CohenKappa(task=self.classification_task, num_classes=self.num_classes), - } - ) - logit_metrics = MetricCollection( - { - "auroc": AUROC( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "average_precision": AveragePrecision( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - } - ) - self.train_label_metrics = label_metrics.clone(prefix="train/") - self.val_label_metrics = label_metrics.clone(prefix="val/") - self.test_label_metrics = label_metrics.clone(prefix="test/") - self.train_logit_metrics = logit_metrics.clone(prefix="train/") - self.val_logit_metrics = logit_metrics.clone(prefix="val/") - self.test_logit_metrics = logit_metrics.clone(prefix="test/") - - def load_pretrained_checkpoint(self, model_ckpt): - """ - Load a pretrained model checkpoint and unfreeze specific layers for fine-tuning. + def load_pretrained_checkpoint(self, model_ckpt: str) -> None: + """Loads a pretrained PyTorch Lightning checkpoint (.ckpt). + + This method loads the state dict, optionally freezes layers based on configuration, + and ensures the model head remains trainable for fine-tuning. + + Args: + model_ckpt (str): Path to the .ckpt file. """ assert self.model.model_head is not None print("Loading pretrained checkpoint from .ckpt file") checkpoint = torch.load(model_ckpt, map_location="cpu", weights_only=False) state_dict = checkpoint["state_dict"] - self.load_state_dict(state_dict, strict=False) + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + print(f"Missing keys when loading checkpoint: {missing_keys}") + print(f"Unexpected keys when loading checkpoint: {unexpected_keys}") for name, param in self.model.named_parameters(): if self.hparams.finetuning.freeze_layers: param.requires_grad = False @@ -151,14 +171,18 @@ def load_pretrained_checkpoint(self, model_ckpt): print("Pretrained model ready.") - def load_safetensors_checkpoint(self, model_ckpt): - """ - Load a pretrained model checkpoint in safetensors format and unfreeze specific layers for fine-tuning. + def load_safetensors_checkpoint(self, model_ckpt: str) -> None: + """Loads a pretrained model checkpoint in safetensors format. + + Args: + model_ckpt (str): Path to the .safetensors file. """ assert self.model.model_head is not None print("Loading pretrained safetensors checkpoint") state_dict = load_file(model_ckpt) - self.load_state_dict(state_dict, strict=False) + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + print(f"Missing keys when loading checkpoint: {missing_keys}") + print(f"Unexpected keys when loading checkpoint: {unexpected_keys}") for name, param in self.model.named_parameters(): if self.hparams.finetuning.freeze_layers: @@ -168,57 +192,56 @@ def load_safetensors_checkpoint(self, model_ckpt): print("Pretrained model ready.") - def generate_fake_mask(self, batch_size, C, T): - """ - Create a dummy mask tensor to simulate attention masking. - - Args: - batch_size (int): Number of samples. - C (int): Number of channels. - T (int): Temporal dimension. - - Returns: - torch.Tensor: Boolean mask tensor of shape (B, C, T). - """ - return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device) - - def _step(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict: - """ - Perform forward pass and post-process predictions. + def _step(self, X: torch.Tensor) -> dict: + """Performs a forward pass and extracts probabilities and labels. Args: - X (torch.Tensor): Input tensor. + X (torch.Tensor): Input EMG tensor of shape (B, C, T). Returns: - dict: Dictionary containing predicted labels, probabilities, and logits. + dict: Dictionary with keys "label", and "logits". """ - y_pred_logits, _ = self.model(X, mask=mask) - - if self.classification_type in ("bc", "ml"): - y_pred_probs = torch.softmax(y_pred_logits, dim=1) - y_pred_label = torch.argmax(y_pred_probs, dim=1) + y_pred_logits = self.model(X) + if self.task != "regression": + y_pred_label = torch.argmax(y_pred_logits, dim=1) else: - raise NotImplementedError(f"No valid classification type: {self.classification_type}") + y_pred_label = None return { "label": y_pred_label, - "probs": y_pred_probs, "logits": y_pred_logits, } - def training_step(self, batch, batch_idx): + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Standard PyTorch Lightning training step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Computed loss. + """ X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) - self.train_label_metrics(y_pred["label"], y) - self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) - self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.train_mean_metrics(logits_flat, y_flat) + self.log_dict(self.train_mean_metrics, on_step=True, on_epoch=False) + else: + self.train_label_metrics(y_pred["label"], y) + self.train_logit_metrics(y_pred["logits"], y) + self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) + self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) self.log( "train_loss", loss, @@ -230,89 +253,140 @@ def training_step(self, batch, batch_idx): ) return loss - def validation_step(self, batch, batch_idx): + def validation_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Standard PyTorch Lightning validation step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Computed validation loss. + """ X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) - self.val_label_metrics(y_pred["label"], y) - self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True) - self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.val_mean_metrics(logits_flat, y_flat) + self.log_dict(self.val_mean_metrics, on_step=False, on_epoch=True) + else: + self.val_label_metrics(y_pred["label"], y) + self.val_logit_metrics(y_pred["logits"], y) + self.log_dict( + self.val_label_metrics, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True) return loss - def test_step(self, batch, batch_idx): + def test_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: + """Standard PyTorch Lightning test step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Computed test loss. + """ X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) - self.test_label_metrics(y_pred["label"], y) - self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True) - self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) - self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.test_mean_metrics(logits_flat, y_flat) + self.log_dict(self.test_mean_metrics, on_step=False, on_epoch=True) + else: + self.test_label_metrics(y_pred["label"], y) + self.test_logit_metrics(y_pred["logits"], y) + self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True) + self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) return loss - def lr_scheduler_step(self, scheduler, metric): - """ - Custom scheduler step function for step-based LR schedulers - """ - scheduler.step(epoch=self.current_epoch) - - def configure_optimizers(self): + def configure_optimizers(self) -> dict: """ Configure the optimizer and learning rate scheduler. Returns: dict: Configuration dictionary with optimizer and LR scheduler. """ - num_blocks = self.hparams.model.n_layer - params_to_pass = [] base_lr = self.hparams.optimizer.lr + base_wd = getattr(self.hparams.optimizer, "weight_decay", 0.0) decay_factor = self.hparams.layerwise_lr_decay + num_blocks = self.hparams.model.n_layer + # 1) gather all model parameters, but tag which ones should get no weight decay + no_decay = {"positional_embedding", "channel_embed"} + params_to_pass = [] for name, param in self.model.named_parameters(): + if not param.requires_grad: + continue + + # Skip parameters that belong to the head, they will be handled separately + if "model_head" in name: + continue + + # find layer‐wise learning rate lr = base_lr - if "mamba_blocks" in name or "norm_layers" in name: - block_nr = int(name.split(".")[1]) - lr *= decay_factor ** (num_blocks - block_nr) - params_to_pass.append({"params": param, "lr": lr}) - - if self.hparams.optimizer.optim == "SGD": - optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=self.hparams.optimizer.momentum) - elif self.hparams.optimizer.optim == "Adam": - optimizer = torch.optim.Adam( - params_to_pass, - lr=base_lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "AdamW": - optimizer = torch.optim.AdamW( - params_to_pass, - lr=base_lr, - weight_decay=self.hparams.optimizer.weight_decay, - betas=self.hparams.optimizer.betas, + if name.startswith("blocks."): + block_id = int(name.split(".")[1]) + lr = base_lr * (decay_factor ** (num_blocks - block_id)) + + # choose weight_decay + wd = 0.0 if any(nd in name for nd in no_decay) else base_wd + + params_to_pass.append( + { + "params": [param], + "lr": lr, + "weight_decay": wd, + } ) - elif self.hparams.optimizer.optim == "LAMB": - optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr) - else: - raise NotImplementedError("No valid optimizer name") - if self.hparams.scheduler_type == "multi_step_lr": - scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer) - else: - scheduler = hydra.utils.instantiate( - self.hparams.scheduler, - optimizer=optimizer, - total_training_opt_steps=self.trainer.estimated_stepping_batches, - ) + # 2) head params always get the base lr & base weight decay + head_params = [ + p for n, p in self.model.model_head.named_parameters() if p.requires_grad + ] + params_to_pass.append( + { + "params": head_params, + "lr": base_lr, + "weight_decay": base_wd, + } + ) + + optimizer = torch.optim.AdamW( + params_to_pass, + lr=self.hparams.optimizer.lr, + weight_decay=self.hparams.optimizer.weight_decay, + betas=self.hparams.optimizer.betas, + ) + + scheduler = hydra.utils.instantiate( + self.hparams.scheduler, + optimizer, + total_training_opt_steps=self.trainer.estimated_stepping_batches, + ) lr_scheduler_config = { "scheduler": scheduler, @@ -323,17 +397,5 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - def _handle_binary(self, preds): - """ - Special handling for binary classification probabilities. - - Args: - preds (torch.Tensor): Logit outputs. - - Returns: - torch.Tensor: Probabilities for the positive class. - """ - if self.classification_task == "binary" and self.classification_type != "mc": - return preds[:, 1].squeeze() - else: - return preds + def lr_scheduler_step(self, scheduler, metric): + scheduler.step(epoch=self.current_epoch) diff --git a/tasks/pretrain_task_EMG.py b/tasks/pretrain_task_EMG.py index c535075..b685113 100644 --- a/tasks/pretrain_task_EMG.py +++ b/tasks/pretrain_task_EMG.py @@ -14,27 +14,20 @@ # * See the License for the specific language governing permissions and * # * limitations under the License. * # * * -# * Author: Matteo Fasulo * +# * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -import hydra import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import torch -import torch_optimizer as torch_optim from omegaconf import DictConfig +import hydra +import torch +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +import wandb from util.train_utils import MinMaxNormalization class MaskTask(pl.LightningModule): - """ - PyTorch Lightning module for training a model with masked reconstruction. - - Args: - hparams (DictConfig): Parameters and configurations loaded via Hydra. - """ - def __init__(self, hparams: DictConfig): super().__init__() self.save_hyperparameters(hparams) @@ -53,6 +46,7 @@ def __init__(self, hparams: DictConfig): def generate_mask(self, batch_size, C, T): """ Generate per-sample patch-level boolean masks (MAE-style). + Fully independent random masking across both time and channels. Returns: mask_full (torch.BoolTensor): Shape (B, C, T) @@ -61,114 +55,104 @@ def generate_mask(self, batch_size, C, T): patch_H, patch_W = self.patch_size num_patches_H = C // patch_H num_patches_W = T // patch_W - N = num_patches_H * num_patches_W + + # Total number of patches per sample (e.g., 16 channels * 50 time patches = 800) + num_patches_total = num_patches_H * num_patches_W # Number of patches to mask per sample - num_to_mask = int(N * self.masking_ratio) + num_to_mask = int(num_patches_total * self.masking_ratio) - # Generate patch-level mask (B, N) - vectorized - mask_patches = torch.zeros(batch_size, N, dtype=torch.bool, device=self.device) + # Generate a flat mask over ALL patches (B, num_patches_total) + mask_flat = torch.zeros( + batch_size, num_patches_total, dtype=torch.bool, device=self.device + ) for b in range(batch_size): - selected = torch.randperm(N, device=self.device)[:num_to_mask] - mask_patches[b, selected] = True + selected = torch.randperm(num_patches_total, device=self.device)[ + :num_to_mask + ] + mask_flat[b, selected] = True - # unpatchify using reshape and repeat_interleave - # (B, N) -> (B, num_patches_H, num_patches_W) - mask_patches_2d = mask_patches.reshape(batch_size, num_patches_H, num_patches_W) + # Reshape the flat mask back into the 2D grid of patches (B, num_patches_H, num_patches_W) + mask_patches_2d = mask_flat.view(batch_size, num_patches_H, num_patches_W) - # Expand to full shape using repeat_interleave + # Expand to full signal shape using repeat_interleave # (B, num_patches_H, num_patches_W) -> (B, C, T) - mask_full = mask_patches_2d.repeat_interleave(patch_H, dim=1).repeat_interleave(patch_W, dim=2) + mask_full = mask_patches_2d.repeat_interleave(patch_H, dim=1).repeat_interleave( + patch_W, dim=2 + ) return mask_full def unpatchify(self, x_patches: torch.Tensor, in_chans: int) -> torch.Tensor: """ Convert patch embeddings (B, N, P) back to waveform (B, C, T) - - Args: - x_patches: (B, N, P) - in_chans: number of channels C - Returns: - x_reconstructed: (B, C, T) """ B, N, P = x_patches.shape num_patches_per_chan = N // in_chans x_recon = x_patches.reshape(B, in_chans, num_patches_per_chan * P) return x_recon - def training_step(self, batch, batch_idx): - """ - Training step: apply mask, normalize and compute loss. + def _step(self, X): + B, C, T = X.shape - Args: - batch (torch.Tensor): Input batch. - batch_idx (int): Batch index. + # Detect zero-padded channels (Shape: B, C. True if channel is 0.0 padded) + pad_mask_ch = (X.abs().max(dim=-1).values == 0) - Returns: - torch.Tensor: Loss value. - """ - X = batch - mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2]) + # Generate symmetrical time mask + mask = self.generate_mask(B, C, T) + + # Remove padded channels from the mask + # Broadcast pad_mask_ch (B, C) to (B, C, T) and set mask to False + mask[pad_mask_ch.unsqueeze(-1).expand(-1, -1, T)] = False if self.normalize: X = self.normalize_fct(X) - x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P) + # Pass pad_mask_ch to the model so the attention can ignore them. + x_reconstructed = self.model(X, mask=mask, pad_mask_ch=pad_mask_ch) - # unpatchify to original signal shape (B, C, T) - x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans) - - # Compute loss on masked parts and unmasked parts (with coefficient) - masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask) - loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - - self.log( - "train_loss", - masked_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss + return { + "x_original": X, + "x_reconstructed": x_reconstructed, + "mask": mask, + } - def validation_step(self, batch, batch_idx): - """ - Validation step: apply mask, normalize, compute loss and log signals. - Args: - batch (torch.Tensor): Input batch. - batch_idx (int): Batch index. + def training_step(self, X, batch_idx): + out = self._step(X) + x_original = out["x_original"] + mask = out["mask"] - Returns: - torch.Tensor: Loss value. - """ - X = batch - mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2]) + x_reconstructed = self.unpatchify(out["x_reconstructed"], in_chans=self.model.in_chans) + masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask) + loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - if self.normalize: - X = self.normalize_fct(X) + losses = { + "loss": loss, + "masked_loss": masked_loss, + "unmasked_loss": unmasked_loss, + } - x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P) + self.log_dict({f"train_{k}": v for k, v in losses.items()}, prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True) + return loss - # unpatchify to original signal shape (B, C, T) - x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans) + def validation_step(self, X, batch_idx): + out = self._step(X) + x_original = out["x_original"] + x_reconstructed = self.unpatchify(out["x_reconstructed"], in_chans=self.model.in_chans) + mask = out["mask"] - # Compute loss on masked parts and unmasked parts (with coefficient) - masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask) + masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask) loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - self.log( - "val_loss", - loss, - prog_bar=True, - on_step=False, - on_epoch=True, - logger=True, - sync_dist=True, - ) + losses = { + "loss": loss, + "masked_loss": masked_loss, + "unmasked_loss": unmasked_loss, + } + + self.log_dict({f"val_{k}": v for k, v in losses.items()}, prog_bar=True, on_step=False, on_epoch=True, logger=True, sync_dist=True) # Fixed indices for logging signals random_indices = [6, 16, 30] @@ -177,46 +161,22 @@ def validation_step(self, batch, batch_idx): if batch_idx == 0: self.log_signals_with_mask( x_original.float(), - x_reconstructed_unpatched.float(), + x_reconstructed.float(), mask, batch_indices=random_indices, - batch_idx=batch_idx, ) return loss def configure_optimizers(self): - """ - Configure optimizer and scheduler based on parameters. - - Returns: - dict: Dictionary with optimizer and scheduler for PyTorch Lightning. - """ - if self.hparams.optimizer.optim == "SGD": - optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.optimizer.lr, momentum=0.9) - elif self.hparams.optimizer.optim == "Adam": - optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "AdamW": - optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "LAMB": - optimizer = torch_optim.Lamb( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - ) - else: - raise NotImplementedError("No valid optim name") - - scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer) + optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.hparams.optimizer.lr, + weight_decay=self.hparams.optimizer.weight_decay, + ) + scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer, total_training_opt_steps=self.trainer.estimated_stepping_batches) lr_scheduler_config = { "scheduler": scheduler, - "interval": "epoch", + "interval": "step", "frequency": 1, "monitor": "val_loss", } @@ -224,9 +184,9 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} def lr_scheduler_step(self, scheduler, metric): - scheduler.step(epoch=self.current_epoch) + scheduler.step(self.global_step) - def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indices=None, batch_idx=None): + def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indices=None): """ Log original and reconstructed signals highlighting masked regions. @@ -235,7 +195,6 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice reconstructed (torch.Tensor): Signals reconstructed by the model. mask (torch.BoolTensor, optional): Applied mask. batch_indices (list[int], optional): Batch indices to log. - batch_idx (int, optional): Current batch index. """ patch_H, patch_W = self.patch_size batch_size, C, T = original.shape @@ -250,18 +209,8 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice original_signal_c2 = original_signal[:patch_H, :] reconstructed_signal_c2 = reconstructed_signal[:patch_H, :] - ax.plot( - original_signal_c2[0].cpu().numpy(), - label="Original Channel 0", - color="blue", - alpha=0.7, - ) - ax.plot( - reconstructed_signal_c2[0].cpu().numpy(), - label="Reconstructed Channel 0", - color="orange", - alpha=0.7, - ) + ax.plot(original_signal_c2[0].cpu().numpy(), label='Original Channel 0', color='blue', alpha=0.7) + ax.plot(reconstructed_signal_c2[0].cpu().numpy(), label='Reconstructed Channel 0', color='orange', alpha=0.7) if mask is not None: mask_c2 = mask[batch_idx, :patch_H, :] @@ -270,26 +219,27 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice # Highlight masked regions with a light gray transparent band for i in range(patch_H): for j in range(T // patch_W): - if mask_c2[i, j * patch_W : (j + 1) * patch_W].all(): - ax.axvspan( - j * patch_W, - (j + 1) * patch_W, - color="lightgray", - alpha=0.1, - ) + if mask_c2[i, j * patch_W:(j + 1) * patch_W].all(): + ax.axvspan(j * patch_W, (j + 1) * patch_W, color='lightgray', alpha=0.3) indices.append(j) - # Remove duplicates and sort highlighted indices - indices_array = np.array(indices) - indices_array = np.unique(indices_array) - ax.set_title(f"Signal Reconstruction - batch_ {batch_idx}") ax.legend() - # Log the figure on TensorBoard with batch and index in the title - self.logger.experiment.add_figure( - f"Original and Reconstructed Signals with Mask (batch_0_ {batch_idx}, F1 = 0)", - fig, - self.current_epoch, - ) - plt.close(fig) + # Log the figure to WandB + if self.trainer.is_global_zero: + wandb_logger = None + for logger in self.trainer.loggers: + if isinstance(logger, WandbLogger): + wandb_logger = logger + break + + if wandb_logger: + wandb_logger.experiment.log({ + "reconstruction/channel0": wandb.Image( + fig, + caption=f"epoch={self.current_epoch}, batch={batch_idx}" + ) + }, step=self.global_step) + + plt.close(fig) \ No newline at end of file diff --git a/util/ckpt_to_safetensor.py b/util/ckpt_to_safetensor.py index 2441f74..04415e7 100644 --- a/util/ckpt_to_safetensor.py +++ b/util/ckpt_to_safetensor.py @@ -34,7 +34,7 @@ parser.add_argument( "--safetensor_path", type=str, - default="model.safetensors", + required=True, help="Path to save the converted safetensors file.", ) parser.add_argument(