Add native PyTorch nn.MultiheadAttention converter support #1457

@ChunzhengLab

Description

Motivation

There is currently no way to convert a PyTorch model containing nn.MultiheadAttention to HLS. The existing HGQ2 path supports MHA via Keras v3, but it requires rewriting the model with the HGQ2 API. Many users have existing PyTorch models they want to deploy directly without retraining.

Proposal

Decompose nn.MultiheadAttention into existing hls4ml layers (EinsumDense, Einsum, Softmax) at the converter level, so no new C++ backend code is required.
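To illustrate the proposed decomposition, here is a sketch (not the actual converter code) that reproduces nn.MultiheadAttention using only einsum projections and a softmax, mirroring the EinsumDense → Einsum → Softmax → Einsum → EinsumDense structure; shapes and the atol tolerance are illustrative choices:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, E, H = 2, 4, 8, 2  # batch, seq len, embed dim, heads
d = E // H               # per-head dim
mha = torch.nn.MultiheadAttention(E, H, batch_first=True)
x = torch.randn(B, T, E)

# Split PyTorch's packed in-projection (shape [3E, E]) into Q/K/V weights.
w_q, w_k, w_v = mha.in_proj_weight.chunk(3)
b_q, b_k, b_v = mha.in_proj_bias.chunk(3)

def proj(t, w, b):
    # EinsumDense-style projection: (B,T,E) @ W^T + b, then split into heads.
    return (torch.einsum('bte,fe->btf', t, w) + b).reshape(B, T, H, d)

q, k, v = proj(x, w_q, b_q), proj(x, w_k, b_k), proj(x, w_v, b_v)

# Einsum: scaled dot-product scores per head, then Softmax over keys.
scores = torch.einsum('bthd,bshd->bhts', q, k) / d**0.5
attn = F.softmax(scores, dim=-1)

# Einsum: attention-weighted sum of values, then output EinsumDense projection.
ctx = torch.einsum('bhts,bshd->bthd', attn, v).reshape(B, T, E)
out = torch.einsum('bte,fe->btf', ctx, mha.out_proj.weight) + mha.out_proj.bias

ref, _ = mha(x, x, x, need_weights=False)
print(torch.allclose(out, ref, atol=1e-5))
```

Since each step is already expressible as an existing hls4ml layer, the converter only needs to emit this graph rather than a monolithic attention kernel.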

Supported:

  • Self-attention and cross-attention (kdim/vdim != embed_dim)
  • Vivado and Vitis backends, io_parallel

Not supported (future work):

  • batch_first=False
  • Attention masks
  • io_stream
I have a working implementation with 6 passing tests. Happy to open a PR if there is interest.

Related work
