From 9d3334c78b382d08a8afe4a7889e98dfa620e10b Mon Sep 17 00:00:00 2001 From: Matteo Fasulo <74818541+MatteoFasulo@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:29:55 +0100 Subject: [PATCH 1/9] Update hand gesture classification experiment details --- docs/model/TinyMyo.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/model/TinyMyo.md b/docs/model/TinyMyo.md index 35a332d..281759e 100644 --- a/docs/model/TinyMyo.md +++ b/docs/model/TinyMyo.md @@ -113,9 +113,9 @@ TinyMyo supports three major categories: Evaluated on: -* **Ninapro DB5** (52 classes, 10 subjects) -* **EPN-612** (5 classes, 612 subjects) -* **UCI EMG** (6 classes, 36 subjects) +* **Ninapro DB5** (52 classes, 10 subjects, 200 Hz) +* **EPN-612** (5 classes, 612 subjects, 200 Hz) +* **UCI EMG** (6 classes, 36 subjects, 200 Hz) * **Generic Neuromotor Interface** (Meta wristband; 9 gestures) * Repository: [MatteoFasulo/generic-neuromotor-interface](https://github.com/MatteoFasulo/generic-neuromotor-interface) @@ -126,8 +126,8 @@ Evaluated on: * EMG filtering: **20–90 Hz** bandpass + 50 Hz notch * Windows: - * **200 ms** (best for DB5) - * **1000 ms** (best for EPN & UCI) + * **1 sec** (best for DB5) + * **5 sec** (best for EPN & UCI) * Per-channel z-scoring * Linear classification head @@ -138,9 +138,9 @@ Evaluated on: | Dataset | Metric | Result | | ------------------------ | -------- | ----------------- | -| **Ninapro DB5 (200 ms)** | Accuracy | **89.41 ± 0.16%** | -| **EPN-612 (1000 ms)** | Accuracy | **96.74 ± 0.09%** | -| **UCI EMG (1000 ms)** | Accuracy | **97.56 ± 0.32%** | +| **Ninapro DB5 (1 sec)** | Accuracy | **89.41 ± 0.16%** | +| **EPN-612 (5 sec)** | Accuracy | **96.74 ± 0.09%** | +| **UCI EMG (5 sec)** | Accuracy | **97.56 ± 0.32%** | | **Neuromotor Interface** | CLER | **0.153 ± 0.006** | TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI. @@ -151,7 +151,7 @@ TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI. 
Dataset: **Ninapro DB8** Task: Regress **5 joint angles (DoA)** -Preprocessing: z-score only; windows of **200 ms** or **1000 ms** +Preprocessing: z-score only; windows of **100 ms** or **500 ms** **Regression head (788k params)** @@ -219,7 +219,7 @@ Key elements: * Integer softmax, integer LayerNorm, integer GELU * Static liveness-based memory arena -**Runtime (NinaPro EPN612 pipeline):** +**Runtime (EPN612 dataset):** * **0.785 s inference time** * **44.91 mJ energy** From 28a362e3383ececca257f72b582040cbf67672b9 Mon Sep 17 00:00:00 2001 From: Matteo Fasulo <74818541+MatteoFasulo@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:29:08 +0100 Subject: [PATCH 2/9] Clarify dataset and performance metrics in TinyMyo.md Updated dataset information and performance metrics for hand kinematic regression. --- docs/model/TinyMyo.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/model/TinyMyo.md b/docs/model/TinyMyo.md index 281759e..06a52bc 100644 --- a/docs/model/TinyMyo.md +++ b/docs/model/TinyMyo.md @@ -149,7 +149,7 @@ TinyMyo achieves **state-of-the-art** on DB5, EPN-612, and UCI. ### **4.2 Hand Kinematic Regression** -Dataset: **Ninapro DB8** +Dataset: **Ninapro DB8** (2000 Hz) Task: Regress **5 joint angles (DoA)** Preprocessing: z-score only; windows of **100 ms** or **500 ms** @@ -162,7 +162,7 @@ Preprocessing: z-score only; windows of **100 ms** or **500 ms** **Performance (Fine-tuned)** -* **MAE = 8.77 ± 0.12°** (1000 ms window) +* **MAE = 8.77 ± 0.12°** (500 ms window) Although previous works achieve lower MAE (≈6.89°), those models are **subject-specific**, whereas TinyMyo trains **one model across all subjects**, a significantly harder problem. 
From 99393ed8de67d263840d438ea34fc75fe7bc01e1 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Tue, 31 Mar 2026 11:19:36 +0200 Subject: [PATCH 3/9] feat: Add optional Weights & Biases logger integration to training script --- run_train.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/run_train.py b/run_train.py index 0431cfc..2ee0cd8 100644 --- a/run_train.py +++ b/run_train.py @@ -32,12 +32,12 @@ from omegaconf import DictConfig, OmegaConf from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.strategies import DDPStrategy from util.train_utils import find_last_checkpoint_path -for env_var in ["DATA_PATH", "CHECKPOINT_DIR"]: +for env_var in ["DATA_PATH", "CHECKPOINT_DIR", "LOG_DIR"]: env_var_value = os.environ.get(env_var) if env_var_value is None or env_var_value == "#CHANGEME": raise RuntimeError(f"Environment variable {env_var} is not set. 
Please set it before running the script.") @@ -64,6 +64,18 @@ def train(cfg: DictConfig): save_dir=osp.expanduser(cfg.io.base_output_path), name=cfg.tag, version=version ) + loggers = [tb_logger] + + # Weights & Biases + if cfg.wandb: + wandb_logger = WandbLogger( + name=version, + project=cfg.wandb.project, + log_model="all", + save_dir=cfg.wandb.save_dir, + ) + loggers.append(wandb_logger) + # DataLoader print("===> Loading datasets") data_module = hydra.utils.instantiate(cfg.data_module) @@ -113,14 +125,14 @@ def train(cfg: DictConfig): del cfg.trainer.strategy trainer = Trainer( **cfg.trainer, - logger=tb_logger, + logger=loggers, callbacks=callbacks, strategy=DDPStrategy(find_unused_parameters=cfg.find_unused_parameters), ) else: trainer = Trainer( **cfg.trainer, - logger=tb_logger, + logger=loggers, callbacks=callbacks, ) From ca5a85aa6a296e4c3a25e7f0c187bc13aca23ed7 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Tue, 31 Mar 2026 11:20:50 +0200 Subject: [PATCH 4/9] feat: Add local dataset implementations module for BioFoundation. 
This avoids namespace issues with third party libraries (e.g., datasets from HuggingFace) --- datasets/__init__.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 datasets/__init__.py diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000..6dfe643 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1 @@ +"""Local dataset implementations for BioFoundation.""" From 8f6a7d1feaf9de9ae07bf5d580a19aa69b76dca1 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Tue, 31 Mar 2026 11:22:29 +0200 Subject: [PATCH 5/9] fix: Comment out Grouped Query Attention (GQA) for backward compatibility with old pytorch versions (<2.5) --- models/TinyMyo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/TinyMyo.py b/models/TinyMyo.py index 5e15687..203d1f0 100644 --- a/models/TinyMyo.py +++ b/models/TinyMyo.py @@ -260,7 +260,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: v, dropout_p=self.attn_drop if self.training else 0.0, is_causal=False, - enable_gqa=False, + #enable_gqa=False, ) x = x.transpose(2, 1).reshape(B, N, C) From c767e4be2a504878f53b886326b84c12905164fc Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Tue, 31 Mar 2026 11:38:03 +0200 Subject: [PATCH 6/9] feat: EMG dataset handling - large scale pretraining using shared memory optimization for DDP - finetuning dataset preload in RAM option and documentation - regression task support for EMG finetuning - additional dataset configurations (log dir, individual h5 file) --- .../data_module/emg_pretrain_data_module.yaml | 34 +- config/experiment/TinyMyo_finetune.yaml | 6 +- config/experiment/TinyMyo_pretrain.yaml | 6 +- datasets/emg_finetune_dataset.py | 150 ++++---- datasets/emg_pretrain_dataset.py | 258 +++++++++----- tasks/finetune_task_EMG.py | 333 ++++++++++++------ 6 files changed, 502 insertions(+), 285 deletions(-) diff --git a/config/data_module/emg_pretrain_data_module.yaml b/config/data_module/emg_pretrain_data_module.yaml index 55b293f..9bbc3f6 
100644 --- a/config/data_module/emg_pretrain_data_module.yaml +++ b/config/data_module/emg_pretrain_data_module.yaml @@ -27,14 +27,36 @@ data_module: train_val_split_ratio: 0.8 datasets: demo_dataset: null - emg2pose: + emg2pose_train: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/emg2pose/h5/" - db6: + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/train.h5" + emg2pose_val: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/ninapro/DB6/h5/" + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/val.h5" + emg2pose_test: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/test.h5" + db6_train: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/train.h5" + pad_up_to_max_chans: 16 + db6_val: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/val.h5" + pad_up_to_max_chans: 16 + db6_test: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/test.h5" + pad_up_to_max_chans: 16 + db7_train: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/train.h5" + pad_up_to_max_chans: 16 + db7_val: + _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/val.h5" pad_up_to_max_chans: 16 - db7: + db7_test: _target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset' - data_dir: "${env:DATA_PATH}/ninapro/DB7/h5/" + hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/test.h5" pad_up_to_max_chans: 16 diff --git a/config/experiment/TinyMyo_finetune.yaml b/config/experiment/TinyMyo_finetune.yaml index 9532c5f..b4e6c20 100644 --- a/config/experiment/TinyMyo_finetune.yaml +++ b/config/experiment/TinyMyo_finetune.yaml @@ -41,7 +41,7 @@ finetuning: freeze_layers: False io: - base_output_path: 
${env:DATA_PATH} + base_output_path: ${env:LOG_DIR} checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints version: 0 @@ -98,3 +98,7 @@ scheduler: warmup_epochs: 5 total_training_opt_steps: ${max_epochs} t_in_epochs: True + +wandb: + project: "TinyMyo" + save_dir: ${env:LOG_DIR} \ No newline at end of file diff --git a/config/experiment/TinyMyo_pretrain.yaml b/config/experiment/TinyMyo_pretrain.yaml index 3485edd..0c0c74c 100644 --- a/config/experiment/TinyMyo_pretrain.yaml +++ b/config/experiment/TinyMyo_pretrain.yaml @@ -30,7 +30,7 @@ final_test: False pretrained_checkpoint_path: null io: - base_output_path: ${env:DATA_PATH} + base_output_path: ${env:LOG_DIR} checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints version: 0 @@ -77,3 +77,7 @@ optimizer: lr: 1e-4 betas: [0.9, 0.98] weight_decay: 0.01 + +wandb: + project: "TinyMyo" + save_dir: ${env:LOG_DIR} diff --git a/datasets/emg_finetune_dataset.py b/datasets/emg_finetune_dataset.py index e6ad10f..103bbf6 100644 --- a/datasets/emg_finetune_dataset.py +++ b/datasets/emg_finetune_dataset.py @@ -24,93 +24,110 @@ class EMGDataset(torch.utils.data.Dataset): - """ - A PyTorch Dataset class for loading EMG (Electromyography) data from HDF5 files. - This dataset supports lazy loading of data from HDF5 files, with optional caching - to improve performance during training. It can be used for both fine-tuning (with labels) - and inference (without labels) modes. The class handles data preprocessing, such as - converting to tensors and optional unsqueezing. + """PyTorch Dataset for loading EMG data from HDF5 files. + + This dataset supports both classification and regression tasks. It provides + two loading modes: pre-loading the entire dataset into RAM for speed, or + lazy-loading from disk with an LRU cache for large datasets that don't fit + in memory. + Attributes: - hdf5_file (str): Path to the HDF5 file containing the dataset. - unsqueeze (bool): Whether to add an extra dimension to the input data (default: False). 
- finetune (bool): If True, loads both data and labels; if False, loads only data (default: True). - cache_size (int): Maximum number of samples to cache in memory (default: 1500). - use_cache (bool): Whether to use caching for faster access (default: True). - regression (bool): If True, treats labels as regression targets (float); else, classification (long) (default: False). - num_samples (int): Total number of samples in the dataset, determined from HDF5 file. - data (h5py.File or None): Handle to the opened HDF5 file (lazy-loaded). - X_ds (h5py.Dataset or None): Dataset handle for input data. - Y_ds (h5py.Dataset or None): Dataset handle for labels (if finetune is True). - cache (dict): Dictionary for caching data items (if use_cache is True). - cache_queue (deque): Queue to track the order of cached items for LRU eviction. - Note: - - The HDF5 file is expected to have 'data' and 'label' datasets. - - Caching uses an LRU (Least Recently Used) eviction policy. - - Suitable for use with PyTorch DataLoader for batched loading. + hdf5_file (str): Path to the HDF5 source file. + finetune (bool): If True, returns (data, label). If False, returns data only. + unsqueeze (bool): If True, adds a channel dimension to the input. + cache_size (int): Max number of samples to keep in the LRU cache. + use_cache (bool): Enables LRU caching for lazy loading. + regression (bool): If True, labels are treated as floats. Else, longs. + preload_in_memory (bool): If True, loads the full HDF5 content into RAM on init. + num_samples (int): Total number of samples in the dataset. 
""" - def __init__( self, hdf5_file: str, - unsqueeze: bool = False, finetune: bool = True, + unsqueeze: bool = False, cache_size: int = 1500, - use_cache: bool = True, + use_cache: bool = False, regression: bool = False, + preload_in_memory: bool = True, ): self.hdf5_file = hdf5_file + self.finetune = finetune self.unsqueeze = unsqueeze self.cache_size = cache_size - self.finetune = finetune self.use_cache = use_cache self.regression = regression + self.preload_in_memory = preload_in_memory self.data = None self.X_ds = None self.Y_ds = None - - # Open once to get length, then close immediately - with h5py.File(self.hdf5_file, "r") as f: - self.num_samples = f["data"].shape[0] + self.X_tensor = None + self.Y_tensor = None + + if self.preload_in_memory: + with h5py.File(self.hdf5_file, "r") as f: + X_np = f["data"][:] + Y_np = f["label"][:] if self.finetune else None + + self.X_tensor = torch.from_numpy(X_np).float().contiguous() + if self.finetune: + if self.regression: + self.Y_tensor = torch.from_numpy(Y_np).float().contiguous() + else: + self.Y_tensor = torch.from_numpy(Y_np).long().contiguous() + + self.num_samples = self.X_tensor.shape[0] + else: + self.data = h5py.File(self.hdf5_file, "r") + self.X_ds = self.data["data"] + self.Y_ds = self.data["label"] if self.finetune else None + self.num_samples = self.X_ds.shape[0] if self.use_cache: self.cache: dict[int, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} self.cache_queue = deque() - def _open_file(self) -> None: - # 'rdcc_nbytes' to increase the raw data chunk cache size - self.data = h5py.File(self.hdf5_file, "r", rdcc_nbytes=1024 * 1024 * 4) - if self.data is not None: - self.X_ds = self.data["data"] - self.Y_ds = self.data["label"] - def __len__(self) -> int: + """Returns the total number of samples in the dataset.""" return self.num_samples - def __getitem__(self, index): + def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Retrieves the EMG 
data and optional label at the specified index. + + Args: + index (int): Index of the sample to retrieve. + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + - If finetune=True: (EMG tensor, Label tensor) + - If finetune=False: EMG tensor + """ # Check Cache if self.use_cache and index in self.cache: return self._process_data(self.cache[index]) - # Open file (Lazy Loading for Multiprocessing) - if self.data is None: - self._open_file() - - # Read Data, HDF5 slicing returns numpy array - X_np = self.X_ds[index] - X = torch.from_numpy(X_np).float() - - if self.finetune: - Y_np = self.Y_ds[index] - if self.regression: - Y = torch.from_numpy(Y_np).float() + if self.preload_in_memory: + X = self.X_tensor[index] + if self.finetune: + Y = self.Y_tensor[index] + data_item = (X, Y) else: - # Ensure scalar is converted properly - Y = torch.tensor(Y_np, dtype=torch.long) - - data_item = (X, Y) + data_item = X else: - data_item = X + # Read Data, HDF5 slicing returns numpy array + X_np = self.X_ds[index] + X = torch.from_numpy(X_np).float() + + if self.finetune: + Y_np = self.Y_ds[index] + if self.regression: + Y = torch.from_numpy(Y_np).float() + else: + Y = torch.tensor(Y_np, dtype=torch.long) + data_item = (X, Y) + else: + data_item = X # Update Cache if self.use_cache: @@ -124,21 +141,26 @@ def __getitem__(self, index): return self._process_data(data_item) - def _process_data(self, data_item): - """Helper to handle squeezing/returning uniformly.""" + def _process_data(self, data_item: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Applies final transformations like unsqueezing. + + Args: + data_item (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): Raw data/label tuple. + + Returns: + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Processed data. 
+ """ if self.finetune: X, Y = data_item - else: - X = data_item - Y = None + if self.unsqueeze: + X = X.unsqueeze(0) + return X, Y + X = data_item if self.unsqueeze: X = X.unsqueeze(0) + return X - if self.finetune: - return X, Y - else: - return X def __del__(self): if self.data is not None: diff --git a/datasets/emg_pretrain_dataset.py b/datasets/emg_pretrain_dataset.py index 15a3248..39063b4 100644 --- a/datasets/emg_pretrain_dataset.py +++ b/datasets/emg_pretrain_dataset.py @@ -17,120 +17,186 @@ # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* import os -import threading -from collections import deque -from typing import Optional +import fcntl +from multiprocessing import shared_memory import h5py +from tqdm import tqdm +import numpy as np import torch +from joblib import Parallel, delayed import torch.nn.functional as F from torch.utils.data import Dataset -# thread-local storage for per-worker file handle -_thread_local = threading.local() - - -def _get_h5_handle(path): - h5f = getattr(_thread_local, "h5f", None) - if h5f is None or h5f.filename != path: - h5f = h5py.File(path, "r") - _thread_local.h5f = h5f - return h5f - - class EMGPretrainDataset(Dataset): - """ - A PyTorch Dataset class for loading EMG (electromyography) data from HDF5 files for pretraining purposes. - This dataset discovers all .h5 files in the specified directory, builds an index of samples across all files, - and provides access to individual samples. It supports optional caching to improve performance, channel padding, - and squeezing of the data tensor. - Args: - data_dir (str): Path to the directory containing .h5 files. - squeeze (bool, optional): Whether to squeeze the data tensor. Defaults to False. - cache_size (int, optional): Size of the cache. Defaults to 1500. - use_cache (bool, optional): Enable caching. Defaults to True. - pad_up_to_max_chans (int | None, optional): Number of channels to pad to. Defaults to None. 
- max_samples (int | None, optional): Limit the total number of samples. Defaults to None. - Raises: - RuntimeError: If no .h5 files are found in the data directory. - Note: - The .h5 files are expected to have a 'data' dataset with shape (N, C, T), where N is the number of samples, - C is the number of channels, and T is the number of time points. - Caching uses a simple LRU mechanism with a deque to track access order. - The __del__ method ensures that any open HDF5 file handles are closed upon deletion. + """Shared-memory optimized Dataset for large-scale EMG pretraining. + + This dataset loads HDF5 data into a shared RAM block (POSIX shared memory) + to allow fast access across multiple worker processes without + the serial overhead of HDF5 reads or the memory duplication of standard + multiprocessing. + + Attributes: + hdf5_file_path (str): Path to the single HDF5 source file. + minmax (bool): Whether to apply min-max scaling to [-1, 1]. + pad_up_to_max_chans (Optional[int]): If set, zero-pads channels to this count. + total_len (int): Total number of samples across all HDF5 groups. + ram_data (np.ndarray): View into the shared memory block. """ def __init__( self, - data_dir: str, - squeeze: bool = False, - cache_size: int = 1500, - use_cache: bool = True, - pad_up_to_max_chans: Optional[int] = None, - max_samples: Optional[int] = None, + hdf5_file: str, + minmax: bool = True, + pad_up_to_max_chans: int | None = None, + n_jobs: int = 16 ): + """Initializes the shared memory dataset and loads data if needed. + + Args: + hdf5_file (str): Path to the HDF5 file. + minmax (bool): Enable scaling. Defaults to True. + pad_up_to_max_chans (Optional[int]): Target channel count for padding. + n_jobs (int): Number of parallel threads for the initial load. 
+ """ super().__init__() - self.squeeze = squeeze - self.cache_size = cache_size - self.use_cache = use_cache + self.minmax = minmax self.pad_up_to_max_chans = pad_up_to_max_chans - # discover all .h5 files - self.file_paths = sorted(os.path.join(data_dir, fn) for fn in os.listdir(data_dir) if fn.endswith(".h5")) - if not self.file_paths: - raise RuntimeError(f"No .h5 files in {data_dir!r}") - - # build index of (file_path, sample_idx) - self.index_map = [] - for fp in self.file_paths: - with h5py.File(fp, "r") as h5f: - n = h5f["data"].shape[0] - for i in range(n): - self.index_map.append((fp, i)) - if max_samples is not None: - self.index_map = self.index_map[:max_samples] - - # Cache to store recently accessed samples - if use_cache: - self.cache = {} - self.cache_queue = deque(maxlen=self.cache_size) - - def __len__(self): - return len(self.index_map) - - def __getitem__(self, index): - if self.use_cache and index in self.cache: - cached_data = self.cache[index] - X = cached_data - else: - fp, local_idx = self.index_map[index] - h5f = _get_h5_handle(fp) - np_x = h5f["data"][local_idx] # shape (C, T) - X = torch.from_numpy(np_x).float() - - if self.use_cache: - # If cache is full, remove oldest item from dict AND queue - if len(self.cache) >= self.cache_size: - oldest_index = self.cache_queue.popleft() - del self.cache[oldest_index] - - self.cache[index] = X - self.cache_queue.append(index) - - # squeeze if requested - if self.squeeze: - X = X.unsqueeze(0) - - # pad channels if requested + self.rank = int(os.environ.get("LOCAL_RANK", os.environ.get("RANK", 0))) + + # This class will be instantiated once per file (e.g., train.h5, val.h5) + self.hdf5_file_path = hdf5_file + if not os.path.exists(self.hdf5_file_path) or not self.hdf5_file_path.endswith(".h5"): + raise ValueError(f"Expected hdf5_file to be a path to a single HDF5 file, but got {self.hdf5_file_path}") + + self.ram_data = None + self.shm_block = None + + file_name = 
os.path.basename(self.hdf5_file_path) + + # Calculate total shape from all groups + with h5py.File(self.hdf5_file_path, "r") as hf: + group_keys = sorted(hf.keys()) + if not group_keys: raise ValueError(f"HDF5 file {file_name} contains no data groups.") + + self.group_offsets = [0] + total_samples = 0 + for key in group_keys: + num_in_group = hf[key]['X'].shape[0] + total_samples += num_in_group + self.group_offsets.append(total_samples) + + # Get other dimensions from the first group + _, C, T = hf[group_keys[0]]['X'].shape + final_shape = (total_samples, C, T) + + self.total_len = total_samples + + # Allocate shared memory and load in parallel + clean_name = f"{os.path.splitext(file_name)[0]}_{final_shape[0]}" + shm_name = f"emg_shm_{clean_name}" + target_dtype = np.float16 # cast to fp16 to fit in RAM + num_bytes = int(np.prod(final_shape)) * np.dtype(target_dtype).itemsize + + lock_path = f"/tmp/{shm_name}.lock" + ready_path = f"/tmp/{shm_name}.ready" + file_mtime = os.path.getmtime(self.hdf5_file_path) + ready_token = f"{os.path.abspath(self.hdf5_file_path)}|{file_mtime}|{num_bytes}" + + with open(lock_path, "w") as lockf: + fcntl.flock(lockf, fcntl.LOCK_EX) + + shm = None + shm_needs_load = True + + if os.path.exists(ready_path): + try: + with open(ready_path, "r") as rf: + token = rf.read().strip() + if token == ready_token: + existing = shared_memory.SharedMemory(name=shm_name) + if existing.size == num_bytes: + shm = existing + shm_needs_load = False + else: + existing.close() + try: + existing.unlink() + except FileNotFoundError: + pass + except Exception: + shm_needs_load = True + + if shm_needs_load: + try: + stale = shared_memory.SharedMemory(name=shm_name) + stale.close() + try: + stale.unlink() + except FileNotFoundError: + pass + except FileNotFoundError: + pass + + print(f"[PID {os.getpid()}] Allocating {num_bytes / 1e9:.2f} GB of Shared RAM for {file_name}...") + shm = shared_memory.SharedMemory(create=True, name=shm_name, size=num_bytes) + 
shm_arr = np.ndarray(final_shape, dtype=target_dtype, buffer=shm.buf) + + def load_group(group_idx): + key = group_keys[group_idx] + start_offset = self.group_offsets[group_idx] + end_offset = self.group_offsets[group_idx + 1] + with h5py.File(self.hdf5_file_path, "r") as local_hf: + data_chunk = local_hf[key]['X'][:].astype(target_dtype) + shm_arr[start_offset:end_offset] = data_chunk + + print(f"[PID {os.getpid()}] Parallel loading groups from {file_name} using {n_jobs} cores...") + Parallel(n_jobs=n_jobs, backend="threading")( + delayed(load_group)(i) for i in tqdm(range(len(group_keys)), desc=f"Loading {file_name}") + ) + with open(ready_path, "w") as wf: + wf.write(ready_token) + print(f"[PID {os.getpid()}] Finished loading {file_name}!") + + self.shm_block = shm + self.ram_data = np.ndarray(final_shape, dtype=target_dtype, buffer=shm.buf) + + def __len__(self) -> int: + """Returns the total number of samples.""" + return self.total_len + + def _minmax_scale(self, x: torch.Tensor, eps: float = 1e-10) -> torch.Tensor: + """Scales EMG signal to [-1, 1] range.""" + maxv = x.amax(dim=-1, keepdim=True) + minv = x.amin(dim=-1, keepdim=True) + x = (x - minv) / (maxv - minv + eps) + return (x - 0.5) * 2 + + def __getitem__(self, idx: int) -> torch.Tensor: + """Retrieves a single sample as a float32 tensor. + + Args: + idx (int): Global index across all groups. + + Returns: + torch.Tensor: Normalized and padded EMG tensor of shape (C, T). 
+ """ + if idx < 0 or idx >= self.total_len: + raise IndexError(f"Index {idx} out of range for dataset of size {self.total_len}") + + # Direct, instant access from the single RAM array + X = torch.tensor(self.ram_data[idx], dtype=torch.float32).contiguous() + + if self.minmax: X = self._minmax_scale(X) if self.pad_up_to_max_chans is not None: C = X.shape[0] to_pad = self.pad_up_to_max_chans - C - if to_pad > 0: - X = F.pad(X, (0, 0, 0, to_pad)) # (channels, time) -> pad channels - + if to_pad > 0: X = F.pad(X.T, (0, to_pad)).T return X + def __del__(self): - h5f = getattr(_thread_local, "h5f", None) - if h5f is not None: - h5f.close() + if hasattr(self, 'shm_block') and self.shm_block is not None: + try: self.shm_block.close() + except Exception: pass \ No newline at end of file diff --git a/tasks/finetune_task_EMG.py b/tasks/finetune_task_EMG.py index 8599443..7a59d55 100644 --- a/tasks/finetune_task_EMG.py +++ b/tasks/finetune_task_EMG.py @@ -16,7 +16,7 @@ # * * # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -from typing import Optional +from typing import Optional, Tuple import hydra import pytorch_lightning as pl @@ -35,37 +35,43 @@ Precision, Recall, ) +from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, R2Score from util.train_utils import MinMaxNormalization class FinetuneTask(pl.LightningModule): - """ - PyTorch Lightning module for fine-tuning a classification model, with support for: - - - Classification types: - - `bc`: Binary Classification - - `ml`: Multi-Label Classification - - - Metric logging during training, validation, and testing, including accuracy, precision, recall, F1 score, AUROC, and more - - Optional input normalization with configurable normalization functions - - Custom optimizer support including SGD, Adam, AdamW, and LAMB - - Learning rate schedulers with configurable scheduling strategies - - Layer-wise learning rate decay for fine-grained learning 
rate control across model blocks + """PyTorch LightningModule for EMG fine-tuning tasks. + + This module manages the training, validation, and testing for both + classification (gesture detection) and regression (kinematic tracking) tasks. + It supports modular model architectures, metric collections, and layer-wise learning rate decay. + + Attributes: + model (nn.Module): The instantiated neural network. + num_classes (int): Number of target classes or regression outputs. + classification_type (str): Format of classification (e.g., 'bc', 'ml'). + task (str): The specific task type ('classification' or 'regression'). + normalize (bool): Whether input normalization is enabled. + criterion (nn.Module): The loss function (CrossEntropy or L1). """ def __init__(self, hparams: DictConfig): - """ - Initialize the FinetuneTask module. + """Initializes the FinetuneTask with Hydra configurations. + + Sets up the model, loss functions, and metric collections based on the + provided task type. Args: - hparams (DictConfig): Hyperparameters and configuration loaded via Hydra. + hparams (DictConfig): Configuration object containing 'model', + 'optimizer', 'scheduler', and 'finetuning' parameters. 
""" super().__init__() self.save_hyperparameters(hparams) self.model = hydra.utils.instantiate(self.hparams.model) self.num_classes = self.hparams.model.num_classes self.classification_type = self.hparams.model.classification_type + self.task = self.hparams.model.task # Enable normalization if specified in parameters self.normalize = False @@ -73,70 +79,100 @@ def __init__(self, hparams: DictConfig): self.normalize = True self.normalize_fct = MinMaxNormalization() - # Loss function - self.criterion = nn.CrossEntropyLoss(label_smoothing=0.10) - - # Classification mode detection - if not isinstance(self.num_classes, int): - raise TypeError("Number of classes must be an integer.") - elif self.num_classes < 2: - raise ValueError("Number of classes must be at least 2.") - elif self.num_classes == 2: - self.classification_task = "binary" - else: - self.classification_task = "multiclass" + if self.task == "regression": + self.criterion = nn.L1Loss() - # Metrics - label_metrics = MetricCollection( + # Metric + mean_metrics = MetricCollection( { - "micro_acc": Accuracy( - task=self.classification_task, - num_classes=self.num_classes, - average="micro", - ), - "macro_acc": Accuracy( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "recall": Recall(task="multiclass", num_classes=self.num_classes, average="macro"), - "precision": Precision( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "f1": F1Score( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "cohen_kappa": CohenKappa(task=self.classification_task, num_classes=self.num_classes), + "rmse": MeanSquaredError(squared=False), + "mae": MeanAbsoluteError(), } - ) - logit_metrics = MetricCollection( - { - "auroc": AUROC( - task=self.classification_task, - num_classes=self.num_classes, - average="macro", - ), - "average_precision": AveragePrecision( - task=self.classification_task, - 
num_classes=self.num_classes, - average="macro", - ), - } - ) - self.train_label_metrics = label_metrics.clone(prefix="train/") - self.val_label_metrics = label_metrics.clone(prefix="val/") - self.test_label_metrics = label_metrics.clone(prefix="test/") - self.train_logit_metrics = logit_metrics.clone(prefix="train/") - self.val_logit_metrics = logit_metrics.clone(prefix="val/") - self.test_logit_metrics = logit_metrics.clone(prefix="test/") - - def load_pretrained_checkpoint(self, model_ckpt): - """ - Load a pretrained model checkpoint and unfreeze specific layers for fine-tuning. + ) + scalar_metrics = MetricCollection( + { + "r2": R2Score(num_outputs=self.num_classes, multioutput="uniform_average"), + } + ) + + self.train_mean_metrics = mean_metrics.clone(prefix="train/") + self.train_scalar_metrics = scalar_metrics.clone(prefix="train/") + self.val_mean_metrics = mean_metrics.clone(prefix="val/") + self.val_scalar_metrics = scalar_metrics.clone(prefix="val/") + self.test_mean_metrics = mean_metrics.clone(prefix="test/") + self.test_scalar_metrics = scalar_metrics.clone(prefix="test/") + + else: + # Loss function + self.criterion = nn.CrossEntropyLoss(label_smoothing=0.10) + + # Classification mode detection + if not isinstance(self.num_classes, int): + raise TypeError("Number of classes must be an integer.") + elif self.num_classes < 2: + raise ValueError("Number of classes must be at least 2.") + elif self.num_classes == 2: + self.classification_task = "binary" + else: + self.classification_task = "multiclass" + + # Metrics + label_metrics = MetricCollection( + { + "micro_acc": Accuracy( + task=self.classification_task, + num_classes=self.num_classes, + average="micro", + ), + "macro_acc": Accuracy( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "recall": Recall(task="multiclass", num_classes=self.num_classes, average="macro"), + "precision": Precision( + task=self.classification_task, + num_classes=self.num_classes, 
+ average="macro", + ), + "f1": F1Score( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "cohen_kappa": CohenKappa(task=self.classification_task, num_classes=self.num_classes), + } + ) + logit_metrics = MetricCollection( + { + "auroc": AUROC( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + "average_precision": AveragePrecision( + task=self.classification_task, + num_classes=self.num_classes, + average="macro", + ), + } + ) + self.train_label_metrics = label_metrics.clone(prefix="train/") + self.val_label_metrics = label_metrics.clone(prefix="val/") + self.test_label_metrics = label_metrics.clone(prefix="test/") + self.train_logit_metrics = logit_metrics.clone(prefix="train/") + self.val_logit_metrics = logit_metrics.clone(prefix="val/") + self.test_logit_metrics = logit_metrics.clone(prefix="test/") + + + def load_pretrained_checkpoint(self, model_ckpt: str) -> None: + """Loads a pretrained PyTorch Lightning checkpoint (.ckpt). + + This method loads the state dict, optionally freezes layers based on configuration, + and ensures the model head remains trainable for fine-tuning. + + Args: + model_ckpt (str): Path to the .ckpt file. """ assert self.model.model_head is not None print("Loading pretrained checkpoint from .ckpt file") @@ -151,9 +187,11 @@ def load_pretrained_checkpoint(self, model_ckpt): print("Pretrained model ready.") - def load_safetensors_checkpoint(self, model_ckpt): - """ - Load a pretrained model checkpoint in safetensors format and unfreeze specific layers for fine-tuning. + def load_safetensors_checkpoint(self, model_ckpt: str) -> None: + """Loads a pretrained model checkpoint in safetensors format. + + Args: + model_ckpt (str): Path to the .safetensors file. 
""" assert self.model.model_head is not None print("Loading pretrained safetensors checkpoint") @@ -168,29 +206,31 @@ def load_safetensors_checkpoint(self, model_ckpt): print("Pretrained model ready.") - def generate_fake_mask(self, batch_size, C, T): - """ - Create a dummy mask tensor to simulate attention masking. + def generate_fake_mask(self, batch_size: int, C: int, T: int) -> torch.Tensor: + """Creates a dummy boolean mask tensor to simulate attention masking. Args: - batch_size (int): Number of samples. + batch_size (int): Batch size (B). C (int): Number of channels. - T (int): Temporal dimension. + T (int): Sequence length (tokens). Returns: - torch.Tensor: Boolean mask tensor of shape (B, C, T). + torch.Tensor: Boolean mask of shape (B, C, T) initialized to False. """ return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device) def _step(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict: - """ - Perform forward pass and post-process predictions. + """Performs a forward pass and extracts probabilities and labels. Args: - X (torch.Tensor): Input tensor. + X (torch.Tensor): Input EMG tensor of shape (B, C, T). + mask (Optional[torch.Tensor]): Attention mask. Defaults to None. Returns: - dict: Dictionary containing predicted labels, probabilities, and logits. + dict: Dictionary with keys "label", "probs", and "logits". + + Raises: + NotImplementedError: If classification_type is not 'bc' or 'ml'. """ y_pred_logits, _ = self.model(X, mask=mask) @@ -207,7 +247,16 @@ def _step(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict: "logits": y_pred_logits, } - def training_step(self, batch, batch_idx): + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Standard PyTorch Lightning training step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. 
+ + Returns: + torch.Tensor: Computed loss. + """ X, y = batch if self.normalize: X = self.normalize_fct(X) @@ -215,10 +264,18 @@ def training_step(self, batch, batch_idx): y_pred = self._step(X, mask=mask) loss = self.criterion(y_pred["logits"], y) - self.train_label_metrics(y_pred["label"], y) - self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) - self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.train_mean_metrics(logits_flat, y_flat) + self.train_scalar_metrics(logits_flat, y_flat) + self.log_dict(self.train_mean_metrics, on_step=True, on_epoch=False) + self.log_dict(self.train_scalar_metrics, on_step=True, on_epoch=False) + else: + self.train_label_metrics(y_pred["label"], y) + self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y) + self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) + self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) self.log( "train_loss", loss, @@ -230,7 +287,16 @@ def training_step(self, batch, batch_idx): ) return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Standard PyTorch Lightning validation step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Computed validation loss. 
+ """ X, y = batch if self.normalize: X = self.normalize_fct(X) @@ -238,14 +304,31 @@ def validation_step(self, batch, batch_idx): y_pred = self._step(X, mask=mask) loss = self.criterion(y_pred["logits"], y) - self.val_label_metrics(y_pred["label"], y) - self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True) - self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.val_mean_metrics(logits_flat, y_flat) + self.val_scalar_metrics(logits_flat, y_flat) + self.log_dict(self.val_mean_metrics, on_step=False, on_epoch=True) + self.log_dict(self.val_scalar_metrics, on_step=False, on_epoch=True) + else: + self.val_label_metrics(y_pred["label"], y) + self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y) + self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True) + self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True) return loss - def test_step(self, batch, batch_idx): + def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + """Standard PyTorch Lightning test step. + + Args: + batch (Tuple[torch.Tensor, torch.Tensor]): Tuple containing (input, target). + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Computed test loss. 
+ """ X, y = batch if self.normalize: X = self.normalize_fct(X) @@ -253,25 +336,41 @@ def test_step(self, batch, batch_idx): y_pred = self._step(X, mask=mask) loss = self.criterion(y_pred["logits"], y) - self.test_label_metrics(y_pred["label"], y) - self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True) - self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) + if self.task == "regression": + logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) + self.test_mean_metrics(logits_flat, y_flat) + self.test_scalar_metrics(logits_flat, y_flat) + self.log_dict(self.test_mean_metrics, on_step=False, on_epoch=True) + self.log_dict(self.test_scalar_metrics, on_step=False, on_epoch=True) + else: + self.test_label_metrics(y_pred["label"], y) + self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y) + self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True) + self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True) return loss - def lr_scheduler_step(self, scheduler, metric): - """ - Custom scheduler step function for step-based LR schedulers + def lr_scheduler_step(self, scheduler: torch.optim.lr_scheduler._LRScheduler, metric: Optional[torch.Tensor]) -> None: + """Custom scheduler step logic for step-based schedulers. + + Args: + scheduler (torch.optim.lr_scheduler._LRScheduler): The optimizer scheduler. + metric (Optional[torch.Tensor]): Optional metric for ReduceLROnPlateau. """ scheduler.step(epoch=self.current_epoch) - def configure_optimizers(self): - """ - Configure the optimizer and learning rate scheduler. + def configure_optimizers(self) -> dict: + """Configures optimizers and learning rate schedulers. 
+ + Implements layer-wise learning rate decay for the Transformer encoder/Mamba blocks, + ensuring lower layers decay more than the head. Returns: - dict: Configuration dictionary with optimizer and LR scheduler. + dict: Configuration for the PyTorch Lightning trainer. + + Raises: + NotImplementedError: If the optimizer name is not supported. """ num_blocks = self.hparams.model.n_layer params_to_pass = [] @@ -323,17 +422,17 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - def _handle_binary(self, preds): - """ - Special handling for binary classification probabilities. + def _handle_binary(self, preds: torch.Tensor) -> torch.Tensor: + """Slices logits for binary classification task. Args: - preds (torch.Tensor): Logit outputs. + preds (torch.Tensor): Logit outputs from the model. Returns: - torch.Tensor: Probabilities for the positive class. + torch.Tensor: Logits/probabilities for the positive class if binary, else full preds. """ if self.classification_task == "binary" and self.classification_type != "mc": return preds[:, 1].squeeze() else: return preds + From fc62df3b8c904cb575dc17779b54a744ef3f5d97 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Thu, 2 Apr 2026 17:34:16 +0200 Subject: [PATCH 7/9] feat: add TinyMyo finetuning and pretraining configurations for 4 layers, EMG finetuning dataset handling, and improved training script logging --- config/experiment/TinyMyo_finetune.yaml | 1 + config/experiment/TinyssimoMyo_finetune.yaml | 111 ++++++++++++++++ config/experiment/TinyssimoMyo_pretrain.yaml | 87 ++++++++++++ datasets/emg_finetune_dataset.py | 132 +++---------------- models/TinyMyo.py | 39 ------ run_train.py | 1 + tasks/finetune_task_EMG.py | 56 +++----- 7 files changed, 237 insertions(+), 190 deletions(-) create mode 100644 config/experiment/TinyssimoMyo_finetune.yaml create mode 100644 config/experiment/TinyssimoMyo_pretrain.yaml diff --git a/config/experiment/TinyMyo_finetune.yaml 
b/config/experiment/TinyMyo_finetune.yaml index b4e6c20..0fe2690 100644 --- a/config/experiment/TinyMyo_finetune.yaml +++ b/config/experiment/TinyMyo_finetune.yaml @@ -100,5 +100,6 @@ scheduler: t_in_epochs: True wandb: + entity: "TinyMyo" project: "TinyMyo" save_dir: ${env:LOG_DIR} \ No newline at end of file diff --git a/config/experiment/TinyssimoMyo_finetune.yaml b/config/experiment/TinyssimoMyo_finetune.yaml new file mode 100644 index 0000000..e7d5ee9 --- /dev/null +++ b/config/experiment/TinyssimoMyo_finetune.yaml @@ -0,0 +1,111 @@ +# @package _global_ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2025 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. 
* +#* * +#* Author: Matteo Fasulo * +#*----------------------------------------------------------------------------* +tag: EMG_finetune + +gpus: -1 +num_nodes: 1 +num_workers: 8 +batch_size: 32 +max_epochs: 50 + +training: True +final_validate: True +final_test: True +finetune_pretrained: True +resume: False + +layerwise_lr_decay: 0.90 +scheduler_type: cosine + +pretrained_checkpoint_path: null +pretrained_safetensors_path: null + +finetuning: + freeze_layers: False + +io: + base_output_path: ${env:LOG_DIR} + checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints + version: 0 + +defaults: + - override /data_module: emg_finetune_data_module + - override /model: TinyMyo_finetune + - override /scheduler: cosine + - override /task: finetune_task_TinyMyo + - override /criterion: finetune_criterion + +masking: + patch_size: [1, 20] + masking_ratio: 0.50 + unmasked_loss_coeff: 0.1 + +input_normalization: + normalize: False + +model: + n_layer: 4 + attn_drop: 0.1 + proj_drop: 0.1 + drop_path: 0.1 + num_classes: 6 + task: "classification" + classification_type: "ml" + reduction_type: "concat" + +trainer: + accelerator: gpu + num_nodes: ${num_nodes} + devices: ${gpus} + strategy: auto + max_epochs: ${max_epochs} + +model_checkpoint: + save_last: True + monitor: "val_loss" + mode: "min" + save_top_k: 1 + +callbacks: + early_stopping: + _target_: 'pytorch_lightning.callbacks.EarlyStopping' + monitor: "val_loss" + patience: 7 + mode: "min" + verbose: True + +optimizer: + optim: 'AdamW' + lr: 5e-4 + betas: [0.9, 0.98] + weight_decay: 1e-2 + +scheduler: + trainer: ${trainer} + min_lr: 1e-5 + warmup_lr_init: 1e-6 + warmup_epochs: 5 + total_training_opt_steps: ${max_epochs} + t_in_epochs: True + +wandb: + entity: "TinyMyo" + project: "TinyMyo" + save_dir: ${env:LOG_DIR} diff --git a/config/experiment/TinyssimoMyo_pretrain.yaml b/config/experiment/TinyssimoMyo_pretrain.yaml new file mode 100644 index 0000000..6b26959 --- /dev/null +++ b/config/experiment/TinyssimoMyo_pretrain.yaml @@ 
-0,0 +1,87 @@ +# @package _global_ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2025 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Matteo Fasulo * +#*----------------------------------------------------------------------------* +tag: EMG_pretrain + +gpus: -1 +num_nodes: 1 +num_workers: 8 +batch_size: 128 +max_epochs: 50 + +final_validate: True +final_test: False + +pretrained_checkpoint_path: null +io: + base_output_path: ${env:LOG_DIR} + checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints + version: 0 + +defaults: + - override /data_module: emg_pretrain_data_module + - override /model: TinyMyo_pretrain + - override /scheduler: cosine + - override /task: pretrain_task_TinyMyo + - override /criterion: pretrain_criterion + +masking: + patch_size: [1, 20] + masking_ratio: 0.50 + unmasked_loss_coeff: 0.1 + +input_normalization: + normalize: True + +model: + n_layer: 4 + +scheduler: + trainer: ${trainer} + min_lr: 1e-6 + warmup_lr_init: 1e-6 + warmup_epochs: 10 + total_training_opt_steps: ${max_epochs} + t_in_epochs: True + +trainer: + accelerator: gpu + num_nodes: ${num_nodes} + devices: ${gpus} + strategy: auto + max_epochs: ${max_epochs} + gradient_clip_val: 3 + accumulate_grad_batches: 8 + +model_checkpoint: + save_last: True + monitor: "val_loss" + mode: "min" + save_top_k: 1 + +optimizer: + 
optim: 'AdamW' + lr: 1e-4 + betas: [0.9, 0.98] + weight_decay: 0.01 + +wandb: + entity: "TinyMyo" + project: "TinyMyo" + save_dir: ${env:LOG_DIR} diff --git a/datasets/emg_finetune_dataset.py b/datasets/emg_finetune_dataset.py index 103bbf6..56e4fc4 100644 --- a/datasets/emg_finetune_dataset.py +++ b/datasets/emg_finetune_dataset.py @@ -16,152 +16,62 @@ # * * # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -from collections import deque from typing import Tuple, Union import h5py +import numpy as np import torch class EMGDataset(torch.utils.data.Dataset): """PyTorch Dataset for loading EMG data from HDF5 files. - This dataset supports both classification and regression tasks. It provides - two loading modes: pre-loading the entire dataset into RAM for speed, or - lazy-loading from disk with an LRU cache for large datasets that don't fit - in memory. - Attributes: hdf5_file (str): Path to the HDF5 source file. finetune (bool): If True, returns (data, label). If False, returns data only. - unsqueeze (bool): If True, adds a channel dimension to the input. - cache_size (int): Max number of samples to keep in the LRU cache. - use_cache (bool): Enables LRU caching for lazy loading. regression (bool): If True, labels are treated as floats. Else, longs. - preload_in_memory (bool): If True, loads the full HDF5 content into RAM on init. - num_samples (int): Total number of samples in the dataset. 
""" def __init__( self, hdf5_file: str, finetune: bool = True, - unsqueeze: bool = False, - cache_size: int = 1500, - use_cache: bool = False, regression: bool = False, - preload_in_memory: bool = True, + verbose: bool = False, ): self.hdf5_file = hdf5_file self.finetune = finetune - self.unsqueeze = unsqueeze - self.cache_size = cache_size - self.use_cache = use_cache self.regression = regression - self.preload_in_memory = preload_in_memory - - self.data = None - self.X_ds = None - self.Y_ds = None - self.X_tensor = None - self.Y_tensor = None - - if self.preload_in_memory: - with h5py.File(self.hdf5_file, "r") as f: - X_np = f["data"][:] - Y_np = f["label"][:] if self.finetune else None - self.X_tensor = torch.from_numpy(X_np).float().contiguous() - if self.finetune: - if self.regression: - self.Y_tensor = torch.from_numpy(Y_np).float().contiguous() - else: - self.Y_tensor = torch.from_numpy(Y_np).long().contiguous() + with h5py.File(self.hdf5_file, "r") as f: + X_np = f["data"][:] + Y_np = f["label"][:] if self.finetune else None - self.num_samples = self.X_tensor.shape[0] - else: - self.data = h5py.File(self.hdf5_file, "r") - self.X_ds = self.data["data"] - self.Y_ds = self.data["label"] if self.finetune else None - self.num_samples = self.X_ds.shape[0] + self.X_tensor = torch.from_numpy(X_np).float().contiguous() + if self.finetune: + if self.regression: + self.Y_tensor = torch.from_numpy(Y_np).float().contiguous() + else: + self.Y_tensor = torch.from_numpy(Y_np).long().contiguous() + if verbose: + uniq, cnt = np.unique(Y_np, return_counts=True) + print( + f"[EMGDataset] {self.hdf5_file}: label min={Y_np.min()}, max={Y_np.max()}, classes={len(uniq)}" + ) + print(f"[EMGDataset] {self.hdf5_file}: class hist={dict(zip(uniq.tolist(), cnt.tolist()))}") - if self.use_cache: - self.cache: dict[int, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} - self.cache_queue = deque() + self.num_samples = self.X_tensor.shape[0] # [N, C, T] def __len__(self) -> int: 
"""Returns the total number of samples in the dataset.""" return self.num_samples def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Retrieves the EMG data and optional label at the specified index. + """Retrieves the EMG data and optional label at the specified index.""" - Args: - index (int): Index of the sample to retrieve. + X = self.X_tensor[index] - Returns: - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - - If finetune=True: (EMG tensor, Label tensor) - - If finetune=False: EMG tensor - """ - # Check Cache - if self.use_cache and index in self.cache: - return self._process_data(self.cache[index]) - - if self.preload_in_memory: - X = self.X_tensor[index] - if self.finetune: - Y = self.Y_tensor[index] - data_item = (X, Y) - else: - data_item = X - else: - # Read Data, HDF5 slicing returns numpy array - X_np = self.X_ds[index] - X = torch.from_numpy(X_np).float() - - if self.finetune: - Y_np = self.Y_ds[index] - if self.regression: - Y = torch.from_numpy(Y_np).float() - else: - Y = torch.tensor(Y_np, dtype=torch.long) - data_item = (X, Y) - else: - data_item = X - - # Update Cache - if self.use_cache: - # If cache is full, remove oldest item from dict AND queue - if len(self.cache) >= self.cache_size: - oldest_index = self.cache_queue.popleft() - del self.cache[oldest_index] - - self.cache[index] = data_item - self.cache_queue.append(index) - - return self._process_data(data_item) - - def _process_data(self, data_item: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """Applies final transformations like unsqueezing. - - Args: - data_item (Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): Raw data/label tuple. - - Returns: - Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: Processed data. 
- """ if self.finetune: - X, Y = data_item - if self.unsqueeze: - X = X.unsqueeze(0) + Y = self.Y_tensor[index] return X, Y - X = data_item - if self.unsqueeze: - X = X.unsqueeze(0) return X - - - def __del__(self): - if self.data is not None: - self.data.close() diff --git a/models/TinyMyo.py b/models/TinyMyo.py index 203d1f0..2fadd02 100644 --- a/models/TinyMyo.py +++ b/models/TinyMyo.py @@ -1,4 +1,3 @@ -import math from dataclasses import dataclass from typing import Literal, Optional, Tuple @@ -407,15 +406,6 @@ def __post_init__(self): self.classifier = nn.Linear(feat_dim, self.num_classes) - # init weights - self.apply(self._init_weights) - - def _init_weights(self, m: nn.Module): - if isinstance(m, nn.Linear): - torch.nn.init.xavier_uniform_(m.weight) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: @@ -508,15 +498,6 @@ def __post_init__(self): nn.Conv1d(self.hidden_dim, self.output_dim, kernel_size=1), ) - # Initialize weights - self.apply(self._init_weights) - - def _init_weights(self, m): - if isinstance(m, nn.Conv1d): - nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") - if m.bias is not None: - nn.init.zeros_(m.bias) - def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the model. @@ -666,7 +647,6 @@ def initialize_weights(self): trunc_normal_(self.mask_token, std=0.02) self.apply(self._init_weights) - self.fix_init_weight() def _init_weights(self, m): """Initializes the model weights.""" @@ -679,25 +659,6 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) - def fix_init_weight(self): - """ - Rescales the weights of attention and MLP layers to improve training stability. - - For each layer, weights are divided by sqrt(2 * layer_id). 
- """ - - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - for layer_id, layer in enumerate(self.blocks, start=1): - attn_proj = getattr(getattr(layer, "attn", None), "proj", None) - if attn_proj is not None: - rescale(attn_proj.weight.data, layer_id) - - mlp_fc2 = getattr(getattr(layer, "mlp", None), "fc2", None) - if mlp_fc2 is not None: - rescale(mlp_fc2.weight.data, layer_id) - def prepare_tokens( self, x_signal: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: diff --git a/run_train.py b/run_train.py index 2ee0cd8..304a0b0 100644 --- a/run_train.py +++ b/run_train.py @@ -70,6 +70,7 @@ def train(cfg: DictConfig): if cfg.wandb: wandb_logger = WandbLogger( name=version, + entity=cfg.wandb.entity, project=cfg.wandb.project, log_model="all", save_dir=cfg.wandb.save_dir, diff --git a/tasks/finetune_task_EMG.py b/tasks/finetune_task_EMG.py index 7a59d55..fa51849 100644 --- a/tasks/finetune_task_EMG.py +++ b/tasks/finetune_task_EMG.py @@ -22,7 +22,6 @@ import pytorch_lightning as pl import torch import torch.nn as nn -import torch_optimizer as torch_optim from omegaconf import DictConfig from safetensors.torch import load_file from torchmetrics import MetricCollection @@ -84,27 +83,19 @@ def __init__(self, hparams: DictConfig): # Metric mean_metrics = MetricCollection( - { - "rmse": MeanSquaredError(squared=False), - "mae": MeanAbsoluteError(), - } - ) - scalar_metrics = MetricCollection( { - "r2": R2Score(num_outputs=self.num_classes, multioutput="uniform_average"), + "rmse": MeanSquaredError(squared=False), + "mae": MeanAbsoluteError(), } ) self.train_mean_metrics = mean_metrics.clone(prefix="train/") - self.train_scalar_metrics = scalar_metrics.clone(prefix="train/") self.val_mean_metrics = mean_metrics.clone(prefix="val/") - self.val_scalar_metrics = scalar_metrics.clone(prefix="val/") self.test_mean_metrics = mean_metrics.clone(prefix="test/") - self.test_scalar_metrics = scalar_metrics.clone(prefix="test/") else: # 
Loss function - self.criterion = nn.CrossEntropyLoss(label_smoothing=0.10) + self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1) # Classification mode detection if not isinstance(self.num_classes, int): @@ -178,7 +169,9 @@ def load_pretrained_checkpoint(self, model_ckpt: str) -> None: print("Loading pretrained checkpoint from .ckpt file") checkpoint = torch.load(model_ckpt, map_location="cpu", weights_only=False) state_dict = checkpoint["state_dict"] - self.load_state_dict(state_dict, strict=False) + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + print(f"Missing keys when loading checkpoint: {missing_keys}") + print(f"Unexpected keys when loading checkpoint: {unexpected_keys}") for name, param in self.model.named_parameters(): if self.hparams.finetuning.freeze_layers: param.requires_grad = False @@ -196,7 +189,9 @@ def load_safetensors_checkpoint(self, model_ckpt: str) -> None: assert self.model.model_head is not None print("Loading pretrained safetensors checkpoint") state_dict = load_file(model_ckpt) - self.load_state_dict(state_dict, strict=False) + missing_keys, unexpected_keys = self.load_state_dict(state_dict, strict=False) + print(f"Missing keys when loading checkpoint: {missing_keys}") + print(f"Unexpected keys when loading checkpoint: {unexpected_keys}") for name, param in self.model.named_parameters(): if self.hparams.finetuning.freeze_layers: @@ -268,9 +263,7 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.train_mean_metrics(logits_flat, y_flat) - self.train_scalar_metrics(logits_flat, y_flat) self.log_dict(self.train_mean_metrics, on_step=True, on_epoch=False) - self.log_dict(self.train_scalar_metrics, on_step=True, on_epoch=False) else: self.train_label_metrics(y_pred["label"], y) 
self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y) @@ -308,9 +301,7 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.val_mean_metrics(logits_flat, y_flat) - self.val_scalar_metrics(logits_flat, y_flat) self.log_dict(self.val_mean_metrics, on_step=False, on_epoch=True) - self.log_dict(self.val_scalar_metrics, on_step=False, on_epoch=True) else: self.val_label_metrics(y_pred["label"], y) self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y) @@ -340,9 +331,7 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.test_mean_metrics(logits_flat, y_flat) - self.test_scalar_metrics(logits_flat, y_flat) self.log_dict(self.test_mean_metrics, on_step=False, on_epoch=True) - self.log_dict(self.test_scalar_metrics, on_step=False, on_epoch=True) else: self.test_label_metrics(y_pred["label"], y) self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y) @@ -379,39 +368,26 @@ def configure_optimizers(self) -> dict: for name, param in self.model.named_parameters(): lr = base_lr - if "mamba_blocks" in name or "norm_layers" in name: + if "norm_layers" in name: block_nr = int(name.split(".")[1]) lr *= decay_factor ** (num_blocks - block_nr) params_to_pass.append({"params": param, "lr": lr}) - if self.hparams.optimizer.optim == "SGD": - optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=self.hparams.optimizer.momentum) - elif self.hparams.optimizer.optim == "Adam": - optimizer = torch.optim.Adam( - params_to_pass, - lr=base_lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "AdamW": + if self.hparams.optimizer.optim == "AdamW": 
optimizer = torch.optim.AdamW( params_to_pass, lr=base_lr, weight_decay=self.hparams.optimizer.weight_decay, betas=self.hparams.optimizer.betas, ) - elif self.hparams.optimizer.optim == "LAMB": - optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr) else: raise NotImplementedError("No valid optimizer name") - if self.hparams.scheduler_type == "multi_step_lr": - scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer) - else: - scheduler = hydra.utils.instantiate( - self.hparams.scheduler, - optimizer=optimizer, - total_training_opt_steps=self.trainer.estimated_stepping_batches, - ) + scheduler = hydra.utils.instantiate( + self.hparams.scheduler, + optimizer=optimizer, + total_training_opt_steps=self.trainer.estimated_stepping_batches, + ) lr_scheduler_config = { "scheduler": scheduler, From f9ddfcb61f12b063c47c7afa67e1ad658645a507 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Fri, 10 Apr 2026 16:16:03 +0200 Subject: [PATCH 8/9] refactor: training script with wandb offline parameter and finish closure --- run_train.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/run_train.py b/run_train.py index 304a0b0..650073b 100644 --- a/run_train.py +++ b/run_train.py @@ -34,6 +34,8 @@ from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from pytorch_lightning.strategies import DDPStrategy +import wandb + from util.train_utils import find_last_checkpoint_path @@ -54,7 +56,7 @@ def train(cfg: DictConfig): seed_everything(cfg.seed) - date_format = "%d_%m_%H-%M" + date_format = "%d_%m_%H-%M-%S.%f" # Create your version_name version = f"{cfg.tag}_{datetime.now().strftime(date_format)}" @@ -69,11 +71,11 @@ def train(cfg: DictConfig): # Weights & Biases if cfg.wandb: wandb_logger = WandbLogger( - name=version, entity=cfg.wandb.entity, project=cfg.wandb.project, - log_model="all", save_dir=cfg.wandb.save_dir, + name=version, + 
offline=cfg.wandb.offline, ) loggers.append(wandb_logger) @@ -177,6 +179,9 @@ def train(cfg: DictConfig): if not cfg.training: trainer.save_checkpoint(f"{checkpoint_dirpath}/last.ckpt") + if wandb.run is not None: + wandb.finish() + @pl.utilities.rank_zero_only def _run_test( @@ -195,7 +200,6 @@ def _run_test( results["test_metrics"] = test_results return results, trainer - @hydra.main(config_path="./config", config_name="defaults", version_base="1.1") def run(cfg: DictConfig): print(f"PyTorch-Lightning Version: {pl.__version__}") From 636835b01b88c65eeb3ebb7a94f9650c2dc7a1e4 Mon Sep 17 00:00:00 2001 From: MatteoFasulo Date: Fri, 17 Apr 2026 18:03:26 +0200 Subject: [PATCH 9/9] refactor: TinyMyo configuration and model for improved training efficiency and logging --- config/experiment/TinyMyo_finetune.yaml | 22 +- config/experiment/TinyMyo_pretrain.yaml | 39 ++- config/experiment/TinyssimoMyo_finetune.yaml | 111 -------- config/experiment/TinyssimoMyo_pretrain.yaml | 87 ------ models/TinyMyo.py | 279 ++++++++++--------- run_train.py | 17 +- tasks/finetune_task_EMG.py | 215 +++++++------- tasks/pretrain_task_EMG.py | 250 +++++++---------- util/ckpt_to_safetensor.py | 2 +- 9 files changed, 397 insertions(+), 625 deletions(-) delete mode 100644 config/experiment/TinyssimoMyo_finetune.yaml delete mode 100644 config/experiment/TinyssimoMyo_pretrain.yaml diff --git a/config/experiment/TinyMyo_finetune.yaml b/config/experiment/TinyMyo_finetune.yaml index 0fe2690..224e60f 100644 --- a/config/experiment/TinyMyo_finetune.yaml +++ b/config/experiment/TinyMyo_finetune.yaml @@ -21,7 +21,7 @@ tag: EMG_finetune gpus: -1 num_nodes: 1 -num_workers: 8 +num_workers: 4 batch_size: 32 max_epochs: 50 @@ -32,7 +32,6 @@ finetune_pretrained: True resume: False layerwise_lr_decay: 0.90 -scheduler_type: cosine pretrained_checkpoint_path: null pretrained_safetensors_path: null @@ -52,23 +51,16 @@ defaults: - override /task: finetune_task_TinyMyo - override /criterion: finetune_criterion 
-masking: - patch_size: [1, 20] - masking_ratio: 0.50 - unmasked_loss_coeff: 0.1 - -input_normalization: - normalize: False - model: + n_layer: 8 num_classes: 6 - classification_type: "ml" + task: "classification" trainer: accelerator: gpu num_nodes: ${num_nodes} devices: ${gpus} - strategy: auto + strategy: ddp_find_unused_parameters_true max_epochs: ${max_epochs} model_checkpoint: @@ -89,7 +81,7 @@ optimizer: optim: 'AdamW' lr: 5e-4 betas: [0.9, 0.98] - weight_decay: 0.01 + weight_decay: 1e-2 scheduler: trainer: ${trainer} @@ -102,4 +94,6 @@ scheduler: wandb: entity: "TinyMyo" project: "TinyMyo" - save_dir: ${env:LOG_DIR} \ No newline at end of file + save_dir: ${env:LOG_DIR} + run_name: "TinyMyo-Finetuning" + offline: True \ No newline at end of file diff --git a/config/experiment/TinyMyo_pretrain.yaml b/config/experiment/TinyMyo_pretrain.yaml index 0c0c74c..9c0f287 100644 --- a/config/experiment/TinyMyo_pretrain.yaml +++ b/config/experiment/TinyMyo_pretrain.yaml @@ -22,8 +22,8 @@ tag: EMG_pretrain gpus: -1 num_nodes: 1 num_workers: 8 -batch_size: 128 -max_epochs: 50 +batch_size: 512 +max_epochs: 30 final_validate: True final_test: False @@ -49,22 +49,23 @@ masking: input_normalization: normalize: True -scheduler: - trainer: ${trainer} - min_lr: 1e-6 - warmup_lr_init: 1e-6 - warmup_epochs: 10 - total_training_opt_steps: ${max_epochs} - t_in_epochs: True +model: + n_layer: 8 + drop_path: 0.0 # Stochastic depth disabled for pretraining + num_classes: 0 # No classification head for pretraining + task: pretraining + +criterion: + loss_type: 'smooth_l1' trainer: accelerator: gpu num_nodes: ${num_nodes} devices: ${gpus} strategy: auto + precision: "bf16-mixed" max_epochs: ${max_epochs} - gradient_clip_val: 3 - accumulate_grad_batches: 8 + gradient_clip_val: 1 model_checkpoint: save_last: True @@ -73,11 +74,21 @@ model_checkpoint: save_top_k: 1 optimizer: - optim: 'AdamW' - lr: 1e-4 + lr: 5e-4 betas: [0.9, 0.98] - weight_decay: 0.01 + weight_decay: 1e-2 + +scheduler: + 
trainer: ${trainer} + min_lr: 1e-6 + warmup_lr_init: 1e-6 + warmup_epochs: 3 + total_training_opt_steps: ${max_epochs} + t_in_epochs: True wandb: + entity: "TinyMyo" project: "TinyMyo" save_dir: ${env:LOG_DIR} + run_name: "TinyMyo-Pretraining" + offline: True diff --git a/config/experiment/TinyssimoMyo_finetune.yaml b/config/experiment/TinyssimoMyo_finetune.yaml deleted file mode 100644 index e7d5ee9..0000000 --- a/config/experiment/TinyssimoMyo_finetune.yaml +++ /dev/null @@ -1,111 +0,0 @@ -# @package _global_ -#*----------------------------------------------------------------------------* -#* Copyright (C) 2025 ETH Zurich, Switzerland * -#* SPDX-License-Identifier: Apache-2.0 * -#* * -#* Licensed under the Apache License, Version 2.0 (the "License"); * -#* you may not use this file except in compliance with the License. * -#* You may obtain a copy of the License at * -#* * -#* http://www.apache.org/licenses/LICENSE-2.0 * -#* * -#* Unless required by applicable law or agreed to in writing, software * -#* distributed under the License is distributed on an "AS IS" BASIS, * -#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * -#* See the License for the specific language governing permissions and * -#* limitations under the License. 
* -#* * -#* Author: Matteo Fasulo * -#*----------------------------------------------------------------------------* -tag: EMG_finetune - -gpus: -1 -num_nodes: 1 -num_workers: 8 -batch_size: 32 -max_epochs: 50 - -training: True -final_validate: True -final_test: True -finetune_pretrained: True -resume: False - -layerwise_lr_decay: 0.90 -scheduler_type: cosine - -pretrained_checkpoint_path: null -pretrained_safetensors_path: null - -finetuning: - freeze_layers: False - -io: - base_output_path: ${env:LOG_DIR} - checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints - version: 0 - -defaults: - - override /data_module: emg_finetune_data_module - - override /model: TinyMyo_finetune - - override /scheduler: cosine - - override /task: finetune_task_TinyMyo - - override /criterion: finetune_criterion - -masking: - patch_size: [1, 20] - masking_ratio: 0.50 - unmasked_loss_coeff: 0.1 - -input_normalization: - normalize: False - -model: - n_layer: 4 - attn_drop: 0.1 - proj_drop: 0.1 - drop_path: 0.1 - num_classes: 6 - task: "classification" - classification_type: "ml" - reduction_type: "concat" - -trainer: - accelerator: gpu - num_nodes: ${num_nodes} - devices: ${gpus} - strategy: auto - max_epochs: ${max_epochs} - -model_checkpoint: - save_last: True - monitor: "val_loss" - mode: "min" - save_top_k: 1 - -callbacks: - early_stopping: - _target_: 'pytorch_lightning.callbacks.EarlyStopping' - monitor: "val_loss" - patience: 7 - mode: "min" - verbose: True - -optimizer: - optim: 'AdamW' - lr: 5e-4 - betas: [0.9, 0.98] - weight_decay: 1e-2 - -scheduler: - trainer: ${trainer} - min_lr: 1e-5 - warmup_lr_init: 1e-6 - warmup_epochs: 5 - total_training_opt_steps: ${max_epochs} - t_in_epochs: True - -wandb: - entity: "TinyMyo" - project: "TinyMyo" - save_dir: ${env:LOG_DIR} diff --git a/config/experiment/TinyssimoMyo_pretrain.yaml b/config/experiment/TinyssimoMyo_pretrain.yaml deleted file mode 100644 index 6b26959..0000000 --- a/config/experiment/TinyssimoMyo_pretrain.yaml +++ 
/dev/null @@ -1,87 +0,0 @@ -# @package _global_ -#*----------------------------------------------------------------------------* -#* Copyright (C) 2025 ETH Zurich, Switzerland * -#* SPDX-License-Identifier: Apache-2.0 * -#* * -#* Licensed under the Apache License, Version 2.0 (the "License"); * -#* you may not use this file except in compliance with the License. * -#* You may obtain a copy of the License at * -#* * -#* http://www.apache.org/licenses/LICENSE-2.0 * -#* * -#* Unless required by applicable law or agreed to in writing, software * -#* distributed under the License is distributed on an "AS IS" BASIS, * -#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * -#* See the License for the specific language governing permissions and * -#* limitations under the License. * -#* * -#* Author: Matteo Fasulo * -#*----------------------------------------------------------------------------* -tag: EMG_pretrain - -gpus: -1 -num_nodes: 1 -num_workers: 8 -batch_size: 128 -max_epochs: 50 - -final_validate: True -final_test: False - -pretrained_checkpoint_path: null -io: - base_output_path: ${env:LOG_DIR} - checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints - version: 0 - -defaults: - - override /data_module: emg_pretrain_data_module - - override /model: TinyMyo_pretrain - - override /scheduler: cosine - - override /task: pretrain_task_TinyMyo - - override /criterion: pretrain_criterion - -masking: - patch_size: [1, 20] - masking_ratio: 0.50 - unmasked_loss_coeff: 0.1 - -input_normalization: - normalize: True - -model: - n_layer: 4 - -scheduler: - trainer: ${trainer} - min_lr: 1e-6 - warmup_lr_init: 1e-6 - warmup_epochs: 10 - total_training_opt_steps: ${max_epochs} - t_in_epochs: True - -trainer: - accelerator: gpu - num_nodes: ${num_nodes} - devices: ${gpus} - strategy: auto - max_epochs: ${max_epochs} - gradient_clip_val: 3 - accumulate_grad_batches: 8 - -model_checkpoint: - save_last: True - monitor: "val_loss" - mode: "min" - save_top_k: 1 - 
-optimizer: - optim: 'AdamW' - lr: 1e-4 - betas: [0.9, 0.98] - weight_decay: 0.01 - -wandb: - entity: "TinyMyo" - project: "TinyMyo" - save_dir: ${env:LOG_DIR} diff --git a/models/TinyMyo.py b/models/TinyMyo.py index 2fadd02..5d61e22 100644 --- a/models/TinyMyo.py +++ b/models/TinyMyo.py @@ -1,5 +1,6 @@ from dataclasses import dataclass -from typing import Literal, Optional, Tuple +import math +from typing import Literal, Optional import torch import torch.nn as nn @@ -112,7 +113,7 @@ def forward( # reshape the cache for broadcasting # tensor has shape [b, s, 1, h_d // 2, 2] if packed samples, # otherwise has shape [1, s, 1, h_d // 2, 2] - rope_cache = rope_cache.view(-1, xshaped.size(1), 1, xshaped.size(3), 2) + rope_cache = rope_cache.reshape(-1, xshaped.size(1), 1, xshaped.size(3), 2) # tensor has shape [b, s, n_h, h_d // 2, 2] x_out = torch.stack( @@ -239,7 +240,12 @@ def __post_init__(self): self.proj = nn.Linear(self.dim, self.dim) self.p_drop = nn.Dropout(self.proj_drop) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, + x: torch.Tensor, + pos_ids: torch.Tensor = None, + attn_mask: torch.Tensor = None, + ) -> torch.Tensor: """Forward pass for rotary self-attention block.""" B, N, C = x.shape qkv = ( @@ -249,17 +255,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) # (K, B, H, N, D) q, k, v = qkv.unbind(0) # each: (B, H, N, D) - q = self.rope(q) - k = self.rope(k) + # Transpose to [B, N, H, D] for RoPE + q = q.transpose(1, 2) + k = k.transpose(1, 2) + + q = self.rope(q, input_pos=pos_ids) + k = self.rope(k, input_pos=pos_ids) + + # Transpose back to [B, H, N, D] for SDPA + q = q.transpose(1, 2) + k = k.transpose(1, 2) - # pylint: disable=not-callable x = F.scaled_dot_product_attention( q, k, v, + attn_mask=attn_mask, # Mask out padded channels dropout_p=self.attn_drop if self.training else 0.0, is_causal=False, - #enable_gqa=False, ) x = x.transpose(2, 1).reshape(B, N, C) @@ -320,8 +333,15 @@ def __post_init__(self): 
drop=self.drop, ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = x + self.drop_path1(self.attn(self.norm1(x))) + def forward( + self, + x: torch.Tensor, + pos_ids: torch.Tensor = None, + attn_mask: torch.Tensor = None, + ) -> torch.Tensor: + x = x + self.drop_path1( + self.attn(self.norm1(x), pos_ids=pos_ids, attn_mask=attn_mask) + ) x = x + self.drop_path2(self.mlp(self.norm2(x))) return x @@ -376,15 +396,12 @@ class EMGClassificationHead(nn.Module): """ A classification head for EMG (Electromyography) data processing, designed to classify token embeddings into a specified number of classes. - This module takes token embeddings as input, applies a reduction strategy (either mean or concatenation across channels), + This module takes token embeddings as input, applies a reduction strategy (concatenation across channels), averages across patches, and then uses a linear classifier to produce logits for classification. embed_dim (int): Dimensionality of the token embeddings. num_classes (int): Number of output classes for classification. in_chans (int): Number of input channels (e.g., EMG channels). - reduction (str): Reduction strategy for combining channel features. Options are "mean" or "concat". - - "mean": Averages across channels, resulting in feature dimension of embed_dim. - - "concat": Concatenates across channels, resulting in feature dimension of in_chans * embed_dim. Defaults to "concat". Attributes: classifier (nn.Linear): Linear layer for final classification, mapping from reduced feature dimension to num_classes. 
@@ -393,19 +410,22 @@ class EMGClassificationHead(nn.Module): embed_dim: int num_classes: int in_chans: int - reduction: Literal["mean", "concat"] = "concat" def __post_init__(self): super().__init__() - # after reduction, feature_dim to either embed_dim or in_chans*embed_dim - feat_dim = ( - self.embed_dim - if self.reduction == "mean" - else self.in_chans * self.embed_dim - ) + feat_dim = self.in_chans * self.embed_dim self.classifier = nn.Linear(feat_dim, self.num_classes) + # init weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: @@ -413,22 +433,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: logits: (B, num_classes) """ - _, N, _ = x.shape + B, N, D = x.shape num_patches = N // self.in_chans - if self.reduction == "mean": - # Reshape to (B, in_chans, num_patches, embed_dim) - x = rearrange(x, "b (c p) d -> b c p d", c=self.in_chans, p=num_patches) - # Take mean across the channels (in_chans) - x = x.mean(dim=1) # (B, num_patches, embed_dim) - elif self.reduction == "concat": - # Reshape to (B, num_patches, embed_dim * in_chans) - x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans, p=num_patches) - else: - raise ValueError(f"Unknown reduction type: {self.reduction}") - - # average across patches - x = x.mean(dim=1) # (B, feat_dim) + # Reshape to (B, num_patches, embed_dim * in_chans) + x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans, p=num_patches) + x = x.mean(dim=1) # apply projection to get logits logits = self.classifier(x) @@ -441,45 +451,28 @@ class EMGRegressionHead(nn.Module): A regression head for EMG (Electromyography) signals using convolutional layers. 
    This module processes embedded features from a transformer model to perform
-    regression, predicting output signals of a specified dimension and length. It supports
-    different reduction methods for combining channel and patch features, followed by
-    convolutional layers for regression, and upsampling to a target sequence length.
+    regression, predicting output signals of a specified dimension and length, upsampling to a target sequence length.
 
     Args:
         in_chans (int): Number of input channels (e.g., EMG channels).
         embed_dim (int): Dimension of the input embeddings.
         output_dim (int): Dimension of the output regression targets.
-        reduction (str): Method to reduce features across channels.
-            "mean" averages embeddings, "concat" concatenates them. Defaults to "concat".
         hidden_dim (int): Hidden dimension for the convolutional layers. Defaults to 256.
         dropout (float): Dropout probability applied after the first convolution. Defaults to 0.1.
         target_length (int): Desired length of the output sequence. If the input length differs,
-            linear interpolation is used to upsample. Defaults to 500.
-
-    Attributes:
-        in_chans (int): Number of input channels.
-        embed_dim (int): Dimension of the embeddings.
-        output_dim (int): Dimension of the output.
-        reduction (str): Reduction method used.
-        dropout (float): Dropout rate.
-        target_length (int): Target output sequence length.
+            linear interpolation is used to upsample. Defaults to 1000.
 
""" in_chans: int embed_dim: int output_dim: int - reduction: Literal["mean", "concat"] = "concat" hidden_dim: int = 256 dropout: float = 0.1 - target_length: int = 500 + target_length: int = 1000 def __post_init__(self): super().__init__() - feat_dim = ( - self.embed_dim - if self.reduction == "mean" - else self.in_chans * self.embed_dim - ) + feat_dim = self.in_chans * self.embed_dim self.regressor = nn.Sequential( nn.Conv1d(feat_dim, self.hidden_dim, kernel_size=1), @@ -498,6 +491,15 @@ def __post_init__(self): nn.Conv1d(self.hidden_dim, self.output_dim, kernel_size=1), ) + # Initialize weights + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_uniform_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass of the model. @@ -511,12 +513,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: is the target sequence length and output_dim is the output dimension. """ # x: (B, num_tokens, token_dim) - if self.reduction == "mean": - x = rearrange(x, "b (c p) d -> b p d", c=self.in_chans) - elif self.reduction == "concat": - x = rearrange(x, "b (c p) d -> b p (c d)", c=self.in_chans) - else: - raise ValueError(f"Unknown reduction type: {self.reduction}") + x = rearrange( + x, "b (c p) d -> b p (c d)", c=self.in_chans, p=x.size(1) // self.in_chans + ) # conv head expects (B, C, L) x = x.transpose(1, 2) # (B, feat_dim, num_patches) @@ -556,10 +555,7 @@ class TinyMyo(nn.Module): drop_path (float, optional): Stochastic depth drop path rate. Defaults to 0.1. norm_layer (nn.Module, optional): Normalization layer class. Defaults to nn.LayerNorm. task (str, optional): Task type, one of "pretraining", "classification", or "regression". Defaults to "classification". - classification_type (str, optional): Type of classification (e.g., "ml" for multi-label). Defaults to "ml". 
- reduction_type (str, optional): Type of reduction to apply, either "mean" or "concat". Defaults to "concat". num_classes (int, optional): Number of classes for classification or output dimension for regression. Defaults to 53. - reg_target_len (int, optional): Target length for regression output. Defaults to 500. """ img_size: int = 1000 @@ -575,15 +571,16 @@ class TinyMyo(nn.Module): drop_path: float = 0.1 norm_layer = nn.LayerNorm task: Literal["pretraining", "classification", "regression"] = "classification" - classification_type: Literal["ml", "mc"] = "ml" - reduction_type: Literal["concat", "mean"] = "concat" num_classes: int = 53 - reg_target_len: int = 500 def __post_init__(self): super().__init__() + # Learnable mask token for pretraining (dimension: embed_dim) self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # Learnable channel IDs (always initialized to the maximum number of channels, but only the first C are used based on input) + self.channel_embed = nn.Parameter(torch.zeros(1, 16, 1, self.embed_dim)) + self.patch_embedding = PatchingModule( img_size=self.img_size, patch_size=self.patch_size, @@ -608,45 +605,50 @@ def __post_init__(self): ) self.norm = self.norm_layer(self.embed_dim) - if ( - self.task == "pretraining" or self.num_classes == 0 - ): # reconstruction (pre-training) + # reconstruction (pre-training) + if self.task == "pretraining": self.model_head = PatchReconstructionHead( img_size=self.img_size, patch_size=self.patch_size, in_chans=self.in_chans, embed_dim=self.embed_dim, ) - elif self.task == "classification" and self.num_classes > 0: + # classification or regression tasks + elif self.task == "classification": + assert self.num_classes > 0, ( + "num_classes must be > 0 for classification task" + ) self.model_head = EMGClassificationHead( embed_dim=self.embed_dim, num_classes=self.num_classes, in_chans=self.in_chans, - reduction=self.reduction_type, ) elif self.task == "regression": + assert self.num_classes > 0, ( + 
"num_classes must be > 0 for regression task (output dimension of regression targets)" + ) self.model_head = EMGRegressionHead( in_chans=self.in_chans, embed_dim=self.embed_dim, output_dim=self.num_classes, - reduction=self.reduction_type, - target_length=self.reg_target_len, + target_length=self.img_size, ) else: raise ValueError(f"Unknown task type {self.task}") self.initialize_weights() # Some checks - assert ( - self.img_size % self.patch_size == 0 - ), f"img_size ({self.img_size}) must be divisible by patch_size ({self.patch_size})" + assert self.img_size % self.patch_size == 0, ( + f"img_size ({self.img_size}) must be divisible by patch_size ({self.patch_size})" + ) def initialize_weights(self): """Initializes the model weights.""" - # Encodings Initializations code taken from the LaBraM paper trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.channel_embed, std=0.02) self.apply(self._init_weights) + self.fix_init_weight() def _init_weights(self, m): """Initializes the model weights.""" @@ -659,67 +661,86 @@ def _init_weights(self, m): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) + def fix_init_weight(self): + """Rescales the weights of attention and MLP layers to improve training stability.""" + + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + def prepare_tokens( self, x_signal: torch.Tensor, mask: Optional[torch.Tensor] = None ) -> torch.Tensor: - """ - Prepares input tokens by embedding patches and applying masking if provided. - Args: - x_signal (torch.Tensor): Input signal tensor of shape (B, C, T). - mask (Optional[torch.Tensor]): Optional mask tensor of shape (B, C, T) indicating which patches to mask. - Returns: - torch.Tensor: Prepared token embeddings of shape (B, N, D) where N is number of patches and D is embed_dim. 
- """ - x_patched = self.patch_embedding(x_signal) # [B, N, D] - x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel + _, C, T = x_signal.shape + + x_patched = self.patch_embedding(x_signal) # [B, N, D] where N = C * P + P = T // self.patch_size + + # Unflatten to grid:[Batch, Channels, Patches, Dim] + x_patched = rearrange(x_patched, "B (C P) D -> B C P D", C=C, P=P) + if mask is not None: - mask_tokens = self.mask_token.repeat( - x_masked.shape[0], x_masked.shape[1], 1 - ) # (B, N, D) N = C * num_patches_per_channel - mask = rearrange( - mask, "B C (S P) -> B (C S) P", P=self.patch_size - ) # (B, C, T) -> (B, N, P) - mask = ( - (mask.sum(dim=-1) > 0).unsqueeze(-1).float() - ) # (B, N, 1), since a patch is either fully masked or not - x_masked = torch.where(mask.bool(), mask_tokens, x_masked) - return x_masked + # Reduce independent signal mask (B, C, T) to patch mask (B, C, P) + mask_p = rearrange(mask, "B C (P s) -> B C P s", s=self.patch_size).any( + dim=-1 + ) + mask_exp = mask_p.unsqueeze(-1) # (B, C, P, 1) - def forward( - self, x_signal: torch.Tensor, mask: Optional[torch.BoolTensor] = None - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass of the TinyMyo model. + # Apply Masking + x_patched = torch.where( + mask_exp, self.mask_token.view(1, 1, 1, -1), x_patched + ) - This method processes the input signal tensor through the transformer blocks, - applies normalization, and then either reconstructs the signal or performs - classification/regression based on the model's configuration. + # Apply Spatial Identity + # Crop to C channels, fewer channels can be used + x = x_patched + self.channel_embed[:, :C, :, :] - Args: - x_signal (torch.Tensor): The input signal tensor of shape [B, C, T], - where B is batch size, C is number of channels, and T is the temporal dimension. - mask (Optional[torch.BoolTensor]): Optional boolean mask tensor for - masking certain tokens during processing. If None, no masking is applied. 
+ return rearrange(x, "B C P D -> B (C P) D") - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - If num_classes == 0 (reconstruction mode): The reconstructed signal - tensor of shape [B, N, patch_size] and the original input signal tensor. - - Otherwise (classification/regression mode): The output tensor of shape - [B, Out] and the original input signal tensor. - """ - x_original = x_signal.clone() + def forward( + self, + x_signal: torch.Tensor, + mask: Optional[torch.BoolTensor] = None, + pad_mask_ch: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + B, C, T = x_signal.shape x = self.prepare_tokens(x_signal, mask=mask) - # forward pass through transformer blocks + _, N, _ = x.shape + # Number of patches per channel + P = T // self.patch_size + assert C * P == N, f"Token shape mismatch: expected {C * P} tokens, got {N}." + + # Cyclic Sequence for RoPE + # Sequence resets for each channel:[0..P-1, 0..P-1, ...] + pos_ids_single = torch.arange(P, device=x.device).repeat(C) + pos_ids = pos_ids_single.unsqueeze(0).expand(B, -1) # Shape (B, N) + + # Attention Masking for Padded Channels + attn_mask = None + if pad_mask_ch is not None: + token_pad_mask = pad_mask_ch.repeat_interleave(P, dim=1) + # Boolean mask for SDPA: True means keep, False means mask out + attn_mask = (~token_pad_mask).reshape(B, 1, 1, N) + + # Pass through RoPE-Attention Blocks for blk in self.blocks: - x = blk(x) - x_latent = self.norm(x) # [B, N, D] + x = blk(x, pos_ids=pos_ids, attn_mask=attn_mask) + + # Final normalization + x = self.norm(x) + + # Pass through task-specific head + return self.model_head(x) - if self.num_classes == 0: # reconstruction - x_reconstructed = self.model_head(x_latent) # [B, N, patch_size] - return x_reconstructed, x_original - else: # classification or regression - x_out = self.model_head(x_latent) # [B, Out] - return x_out, x_original +if __name__ == "__main__": + model = TinyMyo() + input_signal = torch.randn(1, 16, 1000) # (B, C, T) + with 
torch.no_grad(): + output = model(input_signal) + print("Input shape:", input_signal.shape) + print("Output shape:", output.shape) diff --git a/run_train.py b/run_train.py index 650073b..1d392fd 100644 --- a/run_train.py +++ b/run_train.py @@ -42,7 +42,9 @@ for env_var in ["DATA_PATH", "CHECKPOINT_DIR", "LOG_DIR"]: env_var_value = os.environ.get(env_var) if env_var_value is None or env_var_value == "#CHANGEME": - raise RuntimeError(f"Environment variable {env_var} is not set. Please set it before running the script.") + raise RuntimeError( + f"Environment variable {env_var} is not set. Please set it before running the script." + ) OmegaConf.register_new_resolver("env", lambda key: os.getenv(key)) OmegaConf.register_new_resolver("get_method", hydra.utils.get_method) @@ -69,12 +71,13 @@ def train(cfg: DictConfig): loggers = [tb_logger] # Weights & Biases + wandb_logger = None if cfg.wandb: wandb_logger = WandbLogger( entity=cfg.wandb.entity, project=cfg.wandb.project, save_dir=cfg.wandb.save_dir, - name=version, + name=cfg.wandb.run_name if cfg.wandb.run_name else version, offline=cfg.wandb.offline, ) loggers.append(wandb_logger) @@ -173,7 +176,8 @@ def train(cfg: DictConfig): datamodule=data_module, results=results, accelerator=cfg.trainer.accelerator, - last_ckpt=best_ckpt, + ckpt=best_ckpt, + wandb_logger=wandb_logger, ) if not cfg.training: @@ -189,17 +193,20 @@ def _run_test( datamodule: pl.LightningDataModule, results, accelerator, - last_ckpt, + ckpt, + wandb_logger=None, ): trainer = pl.Trainer( accelerator=accelerator, devices=1, + logger=wandb_logger if wandb_logger else [], ) print("===> Start testing") - test_results = trainer.test(module, datamodule=datamodule, ckpt_path=last_ckpt) + test_results = trainer.test(module, datamodule=datamodule, ckpt_path=ckpt) results["test_metrics"] = test_results return results, trainer + @hydra.main(config_path="./config", config_name="defaults", version_base="1.1") def run(cfg: DictConfig): print(f"PyTorch-Lightning 
Version: {pl.__version__}") diff --git a/tasks/finetune_task_EMG.py b/tasks/finetune_task_EMG.py index fa51849..b6d4991 100644 --- a/tasks/finetune_task_EMG.py +++ b/tasks/finetune_task_EMG.py @@ -16,7 +16,7 @@ # * * # * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -from typing import Optional, Tuple +from typing import Tuple import hydra import pytorch_lightning as pl @@ -34,9 +34,7 @@ Precision, Recall, ) -from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError, R2Score - -from util.train_utils import MinMaxNormalization +from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError class FinetuneTask(pl.LightningModule): @@ -49,7 +47,6 @@ class FinetuneTask(pl.LightningModule): Attributes: model (nn.Module): The instantiated neural network. num_classes (int): Number of target classes or regression outputs. - classification_type (str): Format of classification (e.g., 'bc', 'ml'). task (str): The specific task type ('classification' or 'regression'). normalize (bool): Whether input normalization is enabled. criterion (nn.Module): The loss function (CrossEntropy or L1). 
@@ -69,15 +66,8 @@ def __init__(self, hparams: DictConfig): self.save_hyperparameters(hparams) self.model = hydra.utils.instantiate(self.hparams.model) self.num_classes = self.hparams.model.num_classes - self.classification_type = self.hparams.model.classification_type self.task = self.hparams.model.task - # Enable normalization if specified in parameters - self.normalize = False - if "input_normalization" in self.hparams and self.hparams.input_normalization.normalize: - self.normalize = True - self.normalize_fct = MinMaxNormalization() - if self.task == "regression": self.criterion = nn.L1Loss() @@ -95,15 +85,13 @@ def __init__(self, hparams: DictConfig): else: # Loss function - self.criterion = nn.CrossEntropyLoss(label_smoothing=0.1) + self.criterion = hydra.utils.instantiate(self.hparams.criterion) # Classification mode detection if not isinstance(self.num_classes, int): raise TypeError("Number of classes must be an integer.") elif self.num_classes < 2: raise ValueError("Number of classes must be at least 2.") - elif self.num_classes == 2: - self.classification_task = "binary" else: self.classification_task = "multiclass" @@ -120,7 +108,9 @@ def __init__(self, hparams: DictConfig): num_classes=self.num_classes, average="macro", ), - "recall": Recall(task="multiclass", num_classes=self.num_classes, average="macro"), + "recall": Recall( + task="multiclass", num_classes=self.num_classes, average="macro" + ), "precision": Precision( task=self.classification_task, num_classes=self.num_classes, @@ -131,7 +121,9 @@ def __init__(self, hparams: DictConfig): num_classes=self.num_classes, average="macro", ), - "cohen_kappa": CohenKappa(task=self.classification_task, num_classes=self.num_classes), + "cohen_kappa": CohenKappa( + task=self.classification_task, num_classes=self.num_classes + ), } ) logit_metrics = MetricCollection( @@ -155,7 +147,6 @@ def __init__(self, hparams: DictConfig): self.val_logit_metrics = logit_metrics.clone(prefix="val/") self.test_logit_metrics = 
logit_metrics.clone(prefix="test/") - def load_pretrained_checkpoint(self, model_ckpt: str) -> None: """Loads a pretrained PyTorch Lightning checkpoint (.ckpt). @@ -201,48 +192,30 @@ def load_safetensors_checkpoint(self, model_ckpt: str) -> None: print("Pretrained model ready.") - def generate_fake_mask(self, batch_size: int, C: int, T: int) -> torch.Tensor: - """Creates a dummy boolean mask tensor to simulate attention masking. - - Args: - batch_size (int): Batch size (B). - C (int): Number of channels. - T (int): Sequence length (tokens). - - Returns: - torch.Tensor: Boolean mask of shape (B, C, T) initialized to False. - """ - return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device) - - def _step(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> dict: + def _step(self, X: torch.Tensor) -> dict: """Performs a forward pass and extracts probabilities and labels. Args: X (torch.Tensor): Input EMG tensor of shape (B, C, T). - mask (Optional[torch.Tensor]): Attention mask. Defaults to None. Returns: - dict: Dictionary with keys "label", "probs", and "logits". - - Raises: - NotImplementedError: If classification_type is not 'bc' or 'ml'. + dict: Dictionary with keys "label", and "logits". 
""" - y_pred_logits, _ = self.model(X, mask=mask) - - if self.classification_type in ("bc", "ml"): - y_pred_probs = torch.softmax(y_pred_logits, dim=1) - y_pred_label = torch.argmax(y_pred_probs, dim=1) + y_pred_logits = self.model(X) + if self.task != "regression": + y_pred_label = torch.argmax(y_pred_logits, dim=1) else: - raise NotImplementedError(f"No valid classification type: {self.classification_type}") + y_pred_label = None return { "label": y_pred_label, - "probs": y_pred_probs, "logits": y_pred_logits, } - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + def training_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: """Standard PyTorch Lightning training step. Args: @@ -253,20 +226,20 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int torch.Tensor: Computed loss. """ X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) if self.task == "regression": - logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.train_mean_metrics(logits_flat, y_flat) self.log_dict(self.train_mean_metrics, on_step=True, on_epoch=False) else: self.train_label_metrics(y_pred["label"], y) - self.train_logit_metrics(self._handle_binary(y_pred["logits"]), y) + self.train_logit_metrics(y_pred["logits"], y) self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) self.log( @@ -280,7 +253,9 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int ) return loss - def validation_step(self, batch: 
Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + def validation_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: """Standard PyTorch Lightning validation step. Args: @@ -291,26 +266,35 @@ def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: i torch.Tensor: Computed validation loss. """ X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) if self.task == "regression": - logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.val_mean_metrics(logits_flat, y_flat) self.log_dict(self.val_mean_metrics, on_step=False, on_epoch=True) else: self.val_label_metrics(y_pred["label"], y) - self.val_logit_metrics(self._handle_binary(y_pred["logits"]), y) - self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True) + self.val_logit_metrics(y_pred["logits"], y) + self.log_dict( + self.val_label_metrics, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) self.log("val_loss", loss, prog_bar=True, logger=True, sync_dist=True) return loss - def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: + def test_step( + self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int + ) -> torch.Tensor: """Standard PyTorch Lightning test step. Args: @@ -321,71 +305,86 @@ def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: Computed test loss. 
""" X, y = batch - if self.normalize: - X = self.normalize_fct(X) - mask = self.generate_fake_mask(X.shape[0], X.shape[1], X.shape[2]) - y_pred = self._step(X, mask=mask) + + y_pred = self._step(X) loss = self.criterion(y_pred["logits"], y) if self.task == "regression": - logits_flat = y_pred["logits"].reshape(-1, self.num_classes) # (B*T, num_classes) + logits_flat = y_pred["logits"].reshape( + -1, self.num_classes + ) # (B*T, num_classes) y_flat = y.reshape(-1, self.num_classes) # (B*T, num_classes) self.test_mean_metrics(logits_flat, y_flat) self.log_dict(self.test_mean_metrics, on_step=False, on_epoch=True) else: self.test_label_metrics(y_pred["label"], y) - self.test_logit_metrics(self._handle_binary(y_pred["logits"]), y) + self.test_logit_metrics(y_pred["logits"], y) self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True) self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) - self.log("test_loss", loss, prog_bar=True, logger=True, sync_dist=True) return loss - def lr_scheduler_step(self, scheduler: torch.optim.lr_scheduler._LRScheduler, metric: Optional[torch.Tensor]) -> None: - """Custom scheduler step logic for step-based schedulers. - - Args: - scheduler (torch.optim.lr_scheduler._LRScheduler): The optimizer scheduler. - metric (Optional[torch.Tensor]): Optional metric for ReduceLROnPlateau. - """ - scheduler.step(epoch=self.current_epoch) - def configure_optimizers(self) -> dict: - """Configures optimizers and learning rate schedulers. - - Implements layer-wise learning rate decay for the Transformer encoder/Mamba blocks, - ensuring lower layers decay more than the head. + """ + Configure the optimizer and learning rate scheduler. Returns: - dict: Configuration for the PyTorch Lightning trainer. - - Raises: - NotImplementedError: If the optimizer name is not supported. + dict: Configuration dictionary with optimizer and LR scheduler. 
""" - num_blocks = self.hparams.model.n_layer - params_to_pass = [] base_lr = self.hparams.optimizer.lr + base_wd = getattr(self.hparams.optimizer, "weight_decay", 0.0) decay_factor = self.hparams.layerwise_lr_decay + num_blocks = self.hparams.model.n_layer + # 1) gather all model parameters, but tag which ones should get no weight decay + no_decay = {"positional_embedding", "channel_embed"} + params_to_pass = [] for name, param in self.model.named_parameters(): + if not param.requires_grad: + continue + + # Skip parameters that belong to the head, they will be handled separately + if "model_head" in name: + continue + + # find layer‐wise learning rate lr = base_lr - if "norm_layers" in name: - block_nr = int(name.split(".")[1]) - lr *= decay_factor ** (num_blocks - block_nr) - params_to_pass.append({"params": param, "lr": lr}) - - if self.hparams.optimizer.optim == "AdamW": - optimizer = torch.optim.AdamW( - params_to_pass, - lr=base_lr, - weight_decay=self.hparams.optimizer.weight_decay, - betas=self.hparams.optimizer.betas, + if name.startswith("blocks."): + block_id = int(name.split(".")[1]) + lr = base_lr * (decay_factor ** (num_blocks - block_id)) + + # choose weight_decay + wd = 0.0 if any(nd in name for nd in no_decay) else base_wd + + params_to_pass.append( + { + "params": [param], + "lr": lr, + "weight_decay": wd, + } ) - else: - raise NotImplementedError("No valid optimizer name") + + # 2) head params always get the base lr & base weight decay + head_params = [ + p for n, p in self.model.model_head.named_parameters() if p.requires_grad + ] + params_to_pass.append( + { + "params": head_params, + "lr": base_lr, + "weight_decay": base_wd, + } + ) + + optimizer = torch.optim.AdamW( + params_to_pass, + lr=self.hparams.optimizer.lr, + weight_decay=self.hparams.optimizer.weight_decay, + betas=self.hparams.optimizer.betas, + ) scheduler = hydra.utils.instantiate( self.hparams.scheduler, - optimizer=optimizer, + optimizer, 
total_training_opt_steps=self.trainer.estimated_stepping_batches, ) @@ -398,17 +397,5 @@ def configure_optimizers(self) -> dict: return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} - def _handle_binary(self, preds: torch.Tensor) -> torch.Tensor: - """Slices logits for binary classification task. - - Args: - preds (torch.Tensor): Logit outputs from the model. - - Returns: - torch.Tensor: Logits/probabilities for the positive class if binary, else full preds. - """ - if self.classification_task == "binary" and self.classification_type != "mc": - return preds[:, 1].squeeze() - else: - return preds - + def lr_scheduler_step(self, scheduler, metric): + scheduler.step(epoch=self.current_epoch) diff --git a/tasks/pretrain_task_EMG.py b/tasks/pretrain_task_EMG.py index c535075..b685113 100644 --- a/tasks/pretrain_task_EMG.py +++ b/tasks/pretrain_task_EMG.py @@ -14,27 +14,20 @@ # * See the License for the specific language governing permissions and * # * limitations under the License. * # * * -# * Author: Matteo Fasulo * +# * Author: Matteo Fasulo * # *----------------------------------------------------------------------------* -import hydra import matplotlib.pyplot as plt -import numpy as np -import pytorch_lightning as pl -import torch -import torch_optimizer as torch_optim from omegaconf import DictConfig +import hydra +import torch +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +import wandb from util.train_utils import MinMaxNormalization class MaskTask(pl.LightningModule): - """ - PyTorch Lightning module for training a model with masked reconstruction. - - Args: - hparams (DictConfig): Parameters and configurations loaded via Hydra. - """ - def __init__(self, hparams: DictConfig): super().__init__() self.save_hyperparameters(hparams) @@ -53,6 +46,7 @@ def __init__(self, hparams: DictConfig): def generate_mask(self, batch_size, C, T): """ Generate per-sample patch-level boolean masks (MAE-style). 
+ Fully independent random masking across both time and channels. Returns: mask_full (torch.BoolTensor): Shape (B, C, T) @@ -61,114 +55,104 @@ def generate_mask(self, batch_size, C, T): patch_H, patch_W = self.patch_size num_patches_H = C // patch_H num_patches_W = T // patch_W - N = num_patches_H * num_patches_W + + # Total number of patches per sample (e.g., 16 channels * 50 time patches = 800) + num_patches_total = num_patches_H * num_patches_W # Number of patches to mask per sample - num_to_mask = int(N * self.masking_ratio) + num_to_mask = int(num_patches_total * self.masking_ratio) - # Generate patch-level mask (B, N) - vectorized - mask_patches = torch.zeros(batch_size, N, dtype=torch.bool, device=self.device) + # Generate a flat mask over ALL patches (B, num_patches_total) + mask_flat = torch.zeros( + batch_size, num_patches_total, dtype=torch.bool, device=self.device + ) for b in range(batch_size): - selected = torch.randperm(N, device=self.device)[:num_to_mask] - mask_patches[b, selected] = True + selected = torch.randperm(num_patches_total, device=self.device)[ + :num_to_mask + ] + mask_flat[b, selected] = True - # unpatchify using reshape and repeat_interleave - # (B, N) -> (B, num_patches_H, num_patches_W) - mask_patches_2d = mask_patches.reshape(batch_size, num_patches_H, num_patches_W) + # Reshape the flat mask back into the 2D grid of patches (B, num_patches_H, num_patches_W) + mask_patches_2d = mask_flat.view(batch_size, num_patches_H, num_patches_W) - # Expand to full shape using repeat_interleave + # Expand to full signal shape using repeat_interleave # (B, num_patches_H, num_patches_W) -> (B, C, T) - mask_full = mask_patches_2d.repeat_interleave(patch_H, dim=1).repeat_interleave(patch_W, dim=2) + mask_full = mask_patches_2d.repeat_interleave(patch_H, dim=1).repeat_interleave( + patch_W, dim=2 + ) return mask_full def unpatchify(self, x_patches: torch.Tensor, in_chans: int) -> torch.Tensor: """ Convert patch embeddings (B, N, P) back to waveform 
(B, C, T) - - Args: - x_patches: (B, N, P) - in_chans: number of channels C - Returns: - x_reconstructed: (B, C, T) """ B, N, P = x_patches.shape num_patches_per_chan = N // in_chans x_recon = x_patches.reshape(B, in_chans, num_patches_per_chan * P) return x_recon - def training_step(self, batch, batch_idx): - """ - Training step: apply mask, normalize and compute loss. + def _step(self, X): + B, C, T = X.shape - Args: - batch (torch.Tensor): Input batch. - batch_idx (int): Batch index. + # Detect zero-padded channels (Shape: B, C. True if channel is 0.0 padded) + pad_mask_ch = (X.abs().max(dim=-1).values == 0) - Returns: - torch.Tensor: Loss value. - """ - X = batch - mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2]) + # Generate symmetrical time mask + mask = self.generate_mask(B, C, T) + + # Remove padded channels from the mask + # Broadcast pad_mask_ch (B, C) to (B, C, T) and set mask to False + mask[pad_mask_ch.unsqueeze(-1).expand(-1, -1, T)] = False if self.normalize: X = self.normalize_fct(X) - x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P) + # Pass pad_mask_ch to the model so the attention can ignore them. + x_reconstructed = self.model(X, mask=mask, pad_mask_ch=pad_mask_ch) - # unpatchify to original signal shape (B, C, T) - x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans) - - # Compute loss on masked parts and unmasked parts (with coefficient) - masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask) - loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - - self.log( - "train_loss", - masked_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss + return { + "x_original": X, + "x_reconstructed": x_reconstructed, + "mask": mask, + } - def validation_step(self, batch, batch_idx): - """ - Validation step: apply mask, normalize, compute loss and log signals. 
- Args: - batch (torch.Tensor): Input batch. - batch_idx (int): Batch index. + def training_step(self, X, batch_idx): + out = self._step(X) + x_original = out["x_original"] + mask = out["mask"] - Returns: - torch.Tensor: Loss value. - """ - X = batch - mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2]) + x_reconstructed = self.unpatchify(out["x_reconstructed"], in_chans=self.model.in_chans) + masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask) + loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - if self.normalize: - X = self.normalize_fct(X) + losses = { + "loss": loss, + "masked_loss": masked_loss, + "unmasked_loss": unmasked_loss, + } - x_reconstructed, x_original = self.model(X, mask=mask) # x_reconstructed: (B, N, P) + self.log_dict({f"train_{k}": v for k, v in losses.items()}, prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True) + return loss - # unpatchify to original signal shape (B, C, T) - x_reconstructed_unpatched = self.unpatchify(x_reconstructed, self.hparams.model.in_chans) + def validation_step(self, X, batch_idx): + out = self._step(X) + x_original = out["x_original"] + x_reconstructed = self.unpatchify(out["x_reconstructed"], in_chans=self.model.in_chans) + mask = out["mask"] - # Compute loss on masked parts and unmasked parts (with coefficient) - masked_loss, unmasked_loss = self.criterion(x_reconstructed_unpatched, x_original, mask) + masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask) loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss - self.log( - "val_loss", - loss, - prog_bar=True, - on_step=False, - on_epoch=True, - logger=True, - sync_dist=True, - ) + losses = { + "loss": loss, + "masked_loss": masked_loss, + "unmasked_loss": unmasked_loss, + } + + self.log_dict({f"val_{k}": v for k, v in losses.items()}, prog_bar=True, on_step=False, on_epoch=True, logger=True, sync_dist=True) # Fixed indices for logging signals random_indices = [6, 
16, 30] @@ -177,46 +161,22 @@ def validation_step(self, batch, batch_idx): if batch_idx == 0: self.log_signals_with_mask( x_original.float(), - x_reconstructed_unpatched.float(), + x_reconstructed.float(), mask, batch_indices=random_indices, - batch_idx=batch_idx, ) return loss def configure_optimizers(self): - """ - Configure optimizer and scheduler based on parameters. - - Returns: - dict: Dictionary with optimizer and scheduler for PyTorch Lightning. - """ - if self.hparams.optimizer.optim == "SGD": - optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.optimizer.lr, momentum=0.9) - elif self.hparams.optimizer.optim == "Adam": - optimizer = torch.optim.Adam( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "AdamW": - optimizer = torch.optim.AdamW( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - weight_decay=self.hparams.optimizer.weight_decay, - ) - elif self.hparams.optimizer.optim == "LAMB": - optimizer = torch_optim.Lamb( - self.model.parameters(), - lr=self.hparams.optimizer.lr, - ) - else: - raise NotImplementedError("No valid optim name") - - scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer) + optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=self.hparams.optimizer.lr, + weight_decay=self.hparams.optimizer.weight_decay, + ) + scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer, total_training_opt_steps=self.trainer.estimated_stepping_batches) lr_scheduler_config = { "scheduler": scheduler, - "interval": "epoch", + "interval": "step", "frequency": 1, "monitor": "val_loss", } @@ -224,9 +184,9 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} def lr_scheduler_step(self, scheduler, metric): - scheduler.step(epoch=self.current_epoch) + scheduler.step(self.global_step) - def log_signals_with_mask(self, original, reconstructed, 
mask=None, batch_indices=None, batch_idx=None): + def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indices=None): """ Log original and reconstructed signals highlighting masked regions. @@ -235,7 +195,6 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice reconstructed (torch.Tensor): Signals reconstructed by the model. mask (torch.BoolTensor, optional): Applied mask. batch_indices (list[int], optional): Batch indices to log. - batch_idx (int, optional): Current batch index. """ patch_H, patch_W = self.patch_size batch_size, C, T = original.shape @@ -250,18 +209,8 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice original_signal_c2 = original_signal[:patch_H, :] reconstructed_signal_c2 = reconstructed_signal[:patch_H, :] - ax.plot( - original_signal_c2[0].cpu().numpy(), - label="Original Channel 0", - color="blue", - alpha=0.7, - ) - ax.plot( - reconstructed_signal_c2[0].cpu().numpy(), - label="Reconstructed Channel 0", - color="orange", - alpha=0.7, - ) + ax.plot(original_signal_c2[0].cpu().numpy(), label='Original Channel 0', color='blue', alpha=0.7) + ax.plot(reconstructed_signal_c2[0].cpu().numpy(), label='Reconstructed Channel 0', color='orange', alpha=0.7) if mask is not None: mask_c2 = mask[batch_idx, :patch_H, :] @@ -270,26 +219,27 @@ def log_signals_with_mask(self, original, reconstructed, mask=None, batch_indice # Highlight masked regions with a light gray transparent band for i in range(patch_H): for j in range(T // patch_W): - if mask_c2[i, j * patch_W : (j + 1) * patch_W].all(): - ax.axvspan( - j * patch_W, - (j + 1) * patch_W, - color="lightgray", - alpha=0.1, - ) + if mask_c2[i, j * patch_W:(j + 1) * patch_W].all(): + ax.axvspan(j * patch_W, (j + 1) * patch_W, color='lightgray', alpha=0.3) indices.append(j) - # Remove duplicates and sort highlighted indices - indices_array = np.array(indices) - indices_array = np.unique(indices_array) - ax.set_title(f"Signal 
Reconstruction - batch_ {batch_idx}") ax.legend() - # Log the figure on TensorBoard with batch and index in the title - self.logger.experiment.add_figure( - f"Original and Reconstructed Signals with Mask (batch_0_ {batch_idx}, F1 = 0)", - fig, - self.current_epoch, - ) - plt.close(fig) + # Log the figure to WandB + if self.trainer.is_global_zero: + wandb_logger = None + for logger in self.trainer.loggers: + if isinstance(logger, WandbLogger): + wandb_logger = logger + break + + if wandb_logger: + wandb_logger.experiment.log({ + "reconstruction/channel0": wandb.Image( + fig, + caption=f"epoch={self.current_epoch}, batch={batch_idx}" + ) + }, step=self.global_step) + + plt.close(fig) \ No newline at end of file diff --git a/util/ckpt_to_safetensor.py b/util/ckpt_to_safetensor.py index 2441f74..04415e7 100644 --- a/util/ckpt_to_safetensor.py +++ b/util/ckpt_to_safetensor.py @@ -34,7 +34,7 @@ parser.add_argument( "--safetensor_path", type=str, - default="model.safetensors", + required=True, help="Path to save the converted safetensors file.", ) parser.add_argument(