Add native PyTorch nn.MultiheadAttention converter support #1457
Motivation
There is currently no way to convert a PyTorch model containing nn.MultiheadAttention to HLS. The existing HGQ2 path supports MHA via Keras v3, but requires rewriting the model with the HGQ2 API. Many users have existing PyTorch models they want to deploy directly without retraining.
Proposal
Decompose nn.MultiheadAttention into existing hls4ml layers (EinsumDense, Einsum, Softmax) at the converter level, requiring no new C++ backend code.
Supported:
- Self-attention and cross-attention (including kdim/vdim != embed_dim)
- Vivado and Vitis backends, io_parallel
Not supported (future work):
- batch_first=False
- Attention masks
- io_stream
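The proposed decomposition can be sketched in plain PyTorch. This is an illustrative reference model, not the actual converter code: it unpacks the layer's in_proj_weight into Q/K/V projections (the EinsumDense steps), computes scaled dot-product scores and attention with einsum and softmax, and checks the result against nn.MultiheadAttention itself. All variable names and shapes here are assumptions for a self-attention, batch_first=True case.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
embed_dim, num_heads, seq_len, batch = 8, 2, 4, 1
head_dim = embed_dim // num_heads

mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
mha.eval()
x = torch.randn(batch, seq_len, embed_dim)

with torch.no_grad():
    ref, _ = mha(x, x, x, need_weights=False)

    # Split the packed in-projection weight/bias into Q, K, V parts
    w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
    b_q, b_k, b_v = mha.in_proj_bias.chunk(3, dim=0)

    # EinsumDense-style projections: (batch, seq, embed) -> (batch, seq, heads, head_dim)
    q = (x @ w_q.T + b_q).reshape(batch, seq_len, num_heads, head_dim)
    k = (x @ w_k.T + b_k).reshape(batch, seq_len, num_heads, head_dim)
    v = (x @ w_v.T + b_v).reshape(batch, seq_len, num_heads, head_dim)

    # Einsum for scaled dot-product scores, Softmax over the key axis
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / head_dim**0.5
    attn = F.softmax(scores, dim=-1)

    # Einsum to apply attention, then the output projection
    out = torch.einsum('bhqk,bkhd->bqhd', attn, v).reshape(batch, seq_len, embed_dim)
    out = out @ mha.out_proj.weight.T + mha.out_proj.bias

print(torch.allclose(out, ref, atol=1e-6))
```

Each step maps onto an existing hls4ml layer (EinsumDense for the projections, Einsum for the score and value contractions, Softmax for the normalization), which is why no new C++ backend code would be needed.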
I have a working implementation with 6 passing tests. Happy to open a PR if there is interest.
Related work
- HGQ2 QMultiHeadAttention (Keras v3 path)
- https://arxiv.org/abs/2409.05207
- https://arxiv.org/abs/2405.00645