This project implements a HAT (Hybrid Attention Transformer) model for the SISR (Single Image Super-Resolution) task. The primary goal is to upscale low-resolution (LR) images by a given factor (2x, 4x, 8x) to produce super-resolution (SR) images with high fidelity and perceptual quality. This project focuses on the Classical Image Super-Resolution task.
This implementation is based on the papers *Activating More Pixels in Image Super-Resolution Transformer* and *HAT: Hybrid Attention Transformer for Image Restoration*.
The following images demonstrate the visual improvements of the HAT model compared to standard bicubic interpolation.
- Hybrid Attention Block (HAB): The core building block, which combines Window-based Multi-head Self-Attention (WMSA) with a Channel Attention Block (CAB). This hybrid approach allows the model to capture both long-range dependencies and local context simultaneously.
- Overlapping Cross-Attention (OCA): Extends the receptive field by using an overlapping window mechanism. This lets the model aggregate features across window boundaries, effectively "activating more pixels" for better reconstruction quality.
- Residual Hybrid Attention Group (RHAG): A hierarchical structure consisting of multiple HABs followed by an Overlapping Cross-Attention Block (OCAB) and a convolutional layer. This design facilitates deep feature extraction while maintaining stable gradient flow through residual connections.
- Activating More Pixels: Specifically designed to overcome a limitation of standard Transformers in SISR, where only a small portion of the input pixels is typically utilized. The combination of HAB and OCA significantly increases the number of pixels contributing to the final super-resolved output.
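As a rough illustration of the channel-attention half of a HAB, the CAB can be sketched as a squeeze-and-excitation style block. This is a minimal sketch with hypothetical layer sizes; the actual definition lives in `models.py` and may differ:

```python
import torch
import torch.nn as nn

class ChannelAttentionBlock(nn.Module):
    """Minimal CAB sketch: two 3x3 convolutions followed by
    squeeze-and-excitation style channel attention."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, padding=1),
        )
        # Global pooling -> bottleneck -> per-channel gate in [0, 1]
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.body(x)
        return y * self.attn(y)  # reweight channels of the conv features
```

Inside a HAB, the output of such a block is added (with a small weight) to the WMSA branch, so convolutional local context and attention-based long-range context are fused in one block.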
The model is pre-trained on the ImageNet dataset. During training, the `DynamicPairDataset` class in `datasets.py` processes these images by extracting random patches and applying augmentations like horizontal/vertical flips and rotations.
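The patch-plus-augmentation step described above can be sketched roughly as follows (a hypothetical helper for illustration; the actual logic in `datasets.py` may differ in details such as patch size and probability values):

```python
import random
import numpy as np

def augment_patch(hr: np.ndarray, lr: np.ndarray, scale: int, patch: int = 64):
    """Extract an aligned random LR/HR patch pair, then apply random
    flips and 90-degree rotations. `patch` is the LR patch size."""
    h, w = lr.shape[:2]
    y, x = random.randrange(h - patch + 1), random.randrange(w - patch + 1)
    lr_p = lr[y:y + patch, x:x + patch]
    hr_p = hr[y * scale:(y + patch) * scale, x * scale:(x + patch) * scale]
    if random.random() < 0.5:              # horizontal flip
        lr_p, hr_p = lr_p[:, ::-1], hr_p[:, ::-1]
    if random.random() < 0.5:              # vertical flip
        lr_p, hr_p = lr_p[::-1], hr_p[::-1]
    k = random.randrange(4)                # rotation by a multiple of 90 degrees
    lr_p, hr_p = np.rot90(lr_p, k), np.rot90(hr_p, k)
    return lr_p.copy(), hr_p.copy()        # copy to make arrays contiguous
```

The key invariant is that the LR and HR crops stay spatially aligned: the HR window is the LR window scaled by the upscaling factor, and both receive identical flips/rotations.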
The model is fine-tuned on the DF2K (DIV2K + Flickr2K) dataset. Prior to training, the `prepare_data.py` script is used to crop HR images to ensure divisibility by the scaling factor and to generate corresponding LR images using MATLAB-like bicubic downsampling. During training, the `StaticPairDataset` class in `datasets.py` further processes these images by extracting random patches and applying augmentations like horizontal/vertical flips and rotations.
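The crop-and-downsample step can be sketched with Pillow. Note that PIL's bicubic kernel is close to, but not identical to, the MATLAB-like bicubic that `prepare_data.py` emulates, so this is an approximation for illustration only:

```python
from PIL import Image

def make_pair(hr: Image.Image, scale: int) -> tuple[Image.Image, Image.Image]:
    """Crop an HR image so both dimensions are divisible by `scale`,
    then bicubic-downsample it to produce the LR counterpart."""
    w, h = hr.size
    w, h = w - w % scale, h - h % scale   # e.g. 103x57 -> 100x56 for scale=4
    hr = hr.crop((0, 0, w, h))
    lr = hr.resize((w // scale, h // scale), Image.BICUBIC)
    return hr, lr
```

Cropping first guarantees that every LR pixel corresponds to an exact `scale x scale` block of HR pixels, which keeps the training pairs perfectly aligned.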
The Set14 dataset is used for validation.
The `test.py` script is configured to evaluate the trained model on the standard benchmark datasets: Set5, Set14, BSDS100, Urban100, and Manga109.
```
.
├── checkpoints/      # Model weights (.safetensors) and training states (.pth)
├── images/           # Inference inputs and outputs
├── config.py         # Hyperparameters and file paths
├── datasets.py       # Dataset classes and image augmentations
├── prepare_data.py   # Script for generating HR and LR pairs from raw datasets
├── inference.py      # Inference pipeline
├── models.py         # HAT model architecture definition
├── test.py           # Testing pipeline
├── trainer.py        # Trainer class for model training
├── train.py          # Training pipeline
└── utils.py          # Utility functions
```
All hyperparameters, paths, and training settings can be configured in the `config.py` file.

Explanation of some settings:

- `LOAD_CHECKPOINT`: Set to `True` to resume training from the specified checkpoint (for `train.py`).
- `LOAD_BEST_CHECKPOINT`: Set to `True` to resume training from the best checkpoint (for `train.py`).
- `TRAIN_DATASET_PATH`: Path to the training directory containing HR and LR subfolders generated by `prepare_data.py`.
- `VAL_DATASET_PATH`: Path to the validation directory containing HR and LR subfolders generated by `prepare_data.py`.
- `TEST_DATASET_PATHS`: List of paths to test datasets prepared by the `prepare_data.py` script.
- `DEV_MODE`: Set to `True` to use a 10% subset of the training data for quick testing.
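A hypothetical excerpt of `config.py` showing how these settings fit together (the paths and values below are illustrative only, not the repository defaults):

```python
# Illustrative config.py fragment; adjust paths to your own layout.
LOAD_CHECKPOINT = False        # resume from the latest checkpoint
LOAD_BEST_CHECKPOINT = False   # resume from the best checkpoint instead

TRAIN_DATASET_PATH = "data/DF2K"          # contains HR/ and LR_x4/ subfolders
VAL_DATASET_PATH = "data/DIV2K_valid"     # same layout as the training set
TEST_DATASET_PATHS = [
    "data/Set5",
    "data/Set14",
    "data/BSDS100",
    "data/Urban100",
    "data/Manga109",
]

DEV_MODE = False               # True -> train on a 10% subset for quick runs
```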
- Clone the repository:

  ```sh
  git clone https://github.com/ash1ra/HAT.git
  cd HAT
  ```

- Create a `.venv` and install the dependencies:

  ```sh
  uv sync
  ```

- Activate the virtual environment:

  ```sh
  # On Windows
  .venv\Scripts\activate

  # On Unix or macOS
  source .venv/bin/activate
  ```
- Download the ImageNet dataset.
- Download the DIV2K datasets (Train Data (HR images) and Validation Data (HR images)).
- Download the standard benchmark datasets (Set5, Set14, BSDS100, Urban100) and the Manga109 dataset.
- Organize your raw data containing the original high-resolution images:

  ```
  data/
  ├── DF2K/
  │   ├── 1.jpg
  │   └── ...
  ├── DIV2K_valid/
  │   ├── 1.jpg
  │   └── ...
  ├── Set5/
  │   ├── baboon.png
  │   └── ...
  ├── Set14/
  │   └── ...
  ...
  ```

  or

  ```
  data/
  ├── DF2K.txt
  ├── DIV2K_valid.txt
  ├── Set5.txt
  ├── Set14.txt
  ...
  ```

- Run `prepare_data.py` to generate the training/validation pairs. This script will create `HR` and `LR_x{scaling factor}` directories within your dataset path, which are required for the training process.
- Update the paths (`TRAIN_DATASET_PATH`, etc.) in `config.py` to point to these newly created directories.
- Adjust parameters in `config.py` as needed.
- To track your experiments, set `USE_WANDB = True` in `config.py`. The trainer will log the loss, learning rate, and visual samples automatically.
- Run the training script:

  ```sh
  python train.py
  ```

- Training progress will be logged to the console and to a file in the `logs/` directory.
- Checkpoints will be saved in `checkpoints/`.
To evaluate the model's performance on the test datasets:

- Ensure that `BEST_CHECKPOINT_DIR_PATH` in `config.py` points to your trained model (e.g., `checkpoints/best`).
- Run the test script:

  ```sh
  python test.py
  ```

- The script will print the average PSNR and SSIM for each dataset.
The inference script supports command-line flags to specify parameters without editing the configuration file. To upscale a single image:
```sh
python inference.py -i images/input.png -o images/output.png -s 4
```

Available flags:

- `-i`, `--input`: Path to the input image.
- `-o`, `--output`: Path to save the result.
- `-s`, `--scaling-factor`: Upscaling factor.
- `-ts`, `--tile-size`: Size of the processing tiles for memory efficiency.
- `-to`, `--tile-overlap`: Number of overlapping pixels between adjacent tiles to prevent edge-blending artifacts.
- `-c`, `--comparison`: Generate an additional image comparing the bicubic baseline, the HAT output, and the original image.
- `-v`, `--vertical`: Stack comparison images vertically.
- `-dt`, `--dtype`: Floating-point precision used for Automatic Mixed Precision (AMP) during model inference.
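The tile/overlap mechanics behind `-ts`/`-to` can be sketched as follows. This is a simplified sketch that averages overlapping regions uniformly; the real `inference.py` may blend tiles differently and runs the model under AMP:

```python
import torch

def upscale_tiled(model, lr: torch.Tensor, scale: int,
                  tile: int = 128, overlap: int = 16) -> torch.Tensor:
    """Run `model` over overlapping tiles of `lr` (shape 1xCxHxW) and
    average the overlaps, so a large image fits in limited GPU memory."""
    _, c, h, w = lr.shape
    out = torch.zeros(1, c, h * scale, w * scale)
    weight = torch.zeros_like(out)
    step = tile - overlap
    for top in range(0, h, step):
        for left in range(0, w, step):
            # Clamp so edge tiles stay inside the image.
            t = min(top, max(h - tile, 0))
            l = min(left, max(w - tile, 0))
            patch = lr[:, :, t:t + tile, l:l + tile]
            with torch.no_grad():
                sr = model(patch)
            ts, ls = t * scale, l * scale
            out[:, :, ts:ts + sr.shape[2], ls:ls + sr.shape[3]] += sr
            weight[:, :, ts:ts + sr.shape[2], ls:ls + sr.shape[3]] += 1
    return out / weight  # average where tiles overlapped
```

A larger `--tile-overlap` reduces visible seams at tile boundaries at the cost of more redundant computation.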
The model was pre-trained for 100,000 iterations and fine-tuned for 50,000 iterations with a batch size of 32 on an NVIDIA RTX 4060 Ti (8 GB), which took approximately 216 and 106 hours, respectively. The pre-training dataset consisted of 1,152,197 filtered images from the ImageNet dataset; the fine-tuning dataset consisted of 3,450 images from the DF2K dataset. The remaining hyperparameters are specified in the `config.py` file. The final model is the one with the highest PSNR on the validation set.
The final model (`checkpoints/best`) was evaluated on the standard benchmark datasets. Metrics are calculated on the Y channel after shaving 4 px (the scaling factor) from each border.
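For reference, Y-channel PSNR with border shaving can be computed as below, using the BT.601 RGB-to-Y conversion that is standard for SISR benchmarks (SSIM is omitted for brevity; the project's own metric code lives in `utils.py` and may differ slightly):

```python
import numpy as np

def psnr_y(sr: np.ndarray, hr: np.ndarray, shave: int = 4) -> float:
    """PSNR between two uint8 RGB images, computed on the luma (Y)
    channel after shaving `shave` pixels from every border."""
    coeffs = np.array([65.481, 128.553, 24.966]) / 255.0  # BT.601 luma weights

    def to_y(img: np.ndarray) -> np.ndarray:
        return img.astype(np.float64) @ coeffs + 16.0

    y_sr = to_y(sr)[shave:-shave, shave:-shave]
    y_hr = to_y(hr)[shave:-shave, shave:-shave]
    mse = np.mean((y_sr - y_hr) ** 2)
    return float("inf") if mse == 0 else 10 * np.log10(255.0 ** 2 / mse)
```

Shaving the border before measuring excludes edge pixels that the model cannot reconstruct reliably, which is the convention used by the numbers in the table below.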
PSNR (dB) / SSIM Comparison
| Dataset | HAT (this project) | HAT (paper) |
|---|---|---|
| Set5 | 32.84/0.9039 | 33.18/0.9073 |
| Set14 | 29.18/0.7956 | 29.38/0.8001 |
| BSDS100 | 27.93/0.7493 | 28.05/0.7534 |
| Urban100 | 27.70/0.8294 | 28.37/0.8447 |
| Manga109 | 32.29/0.9274 | 32.87/0.9319 |
Note: Differences in results are primarily due to training constraints; I pre-trained the model for 100,000 iterations compared to the original 800,000 iterations and fine-tuned the model for 50,000 iterations compared to the original 250,000. Additionally, the learning rate in this implementation was decayed five times more frequently.
Note 2: During the first 78,000 pre-training iterations, one image was inadvertently omitted from the Set14 validation dataset. Correcting this omission caused a sudden but expected drop in the validation metrics.
These examples include a diverse set of subjects (e.g., anime, real-world photography) to highlight the model's ability to reconstruct fine details and complex textures.
This implementation is based on the paper *Activating More Pixels in Image Super-Resolution Transformer*:

```bibtex
@misc{chen2023activatingpixelsimagesuperresolution,
    title={Activating More Pixels in Image Super-Resolution Transformer},
    author={Xiangyu Chen and Xintao Wang and Jiantao Zhou and Yu Qiao and Chao Dong},
    year={2023},
    eprint={2205.04437},
    archivePrefix={arXiv},
    primaryClass={eess.IV},
    url={https://arxiv.org/abs/2205.04437},
}
```

and on the paper *HAT: Hybrid Attention Transformer for Image Restoration*:

```bibtex
@misc{chen2025hathybridattentiontransformer,
    title={HAT: Hybrid Attention Transformer for Image Restoration},
    author={Xiangyu Chen and Xintao Wang and Wenlong Zhang and Xiangtao Kong and Yu Qiao and Jiantao Zhou and Chao Dong},
    year={2025},
    eprint={2309.05239},
    archivePrefix={arXiv},
    primaryClass={cs.CV},
    url={https://arxiv.org/abs/2309.05239},
}
```

DIV2K dataset citation:

```bibtex
@InProceedings{Timofte_2018_CVPR_Workshops,
    author = {Timofte, Radu and Gu, Shuhang and Wu, Jiqing and Van Gool, Luc and Zhang, Lei and Yang, Ming-Hsuan and Haris, Muhammad and others},
    title = {NTIRE 2018 Challenge on Single Image Super-Resolution: Methods and Results},
    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
    month = {June},
    year = {2018}
}
```

Manga109 dataset citation:

```bibtex
@article{mtap_matsui_2017,
    author={Yusuke Matsui and Kota Ito and Yuji Aramaki and Azuma Fujimoto and Toru Ogawa and Toshihiko Yamasaki and Kiyoharu Aizawa},
    title={Sketch-based Manga Retrieval using Manga109 Dataset},
    journal={Multimedia Tools and Applications},
    volume={76},
    number={20},
    pages={21811--21838},
    doi={10.1007/s11042-016-4020-z},
    year={2017}
}

@article{multimedia_aizawa_2020,
    author={Kiyoharu Aizawa and Azuma Fujimoto and Atsushi Otsubo and Toru Ogawa and Yusuke Matsui and Koki Tsubota and Hikaru Ikuta},
    title={Building a Manga Dataset ``Manga109'' with Annotations for Multimedia Applications},
    journal={IEEE MultiMedia},
    volume={27},
    number={2},
    pages={8--18},
    doi={10.1109/mmul.2020.2987895},
    year={2020}
}
```

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.