
Mixed-precision is not available in torch-xpu 2.10 #2306

@WhitePr

Description


OS: Arch Linux x86_64
Kernel: Linux 6.19.10-arch1-1
Python: 3.13.12
GPU: Intel Arc B580

Python packages:

Package                 Version
----------------------- ----------
absl-py                 2.4.0
accelerate              1.13.0
annotated-doc           0.0.4
anyio                   4.13.0
bitsandbytes            0.49.2
certifi                 2026.2.25
charset-normalizer      3.4.7
click                   8.3.1
diffusers               0.32.1
dpcpp-cpp-rt            2025.3.1
einops                  0.7.0
filelock                3.25.2
fsspec                  2026.2.0
grpcio                  1.80.0
h11                     0.16.0
hf-xet                  1.4.3
httpcore                1.0.9
httpx                   0.28.1
huggingface-hub         0.36.2
idna                    3.11
imagesize               2.0.0
impi-rt                 2021.17.0
importlib-metadata      9.0.0
intel-cmplr-lib-rt      2025.3.1
intel-cmplr-lib-ur      2025.3.1
intel-cmplr-lic-rt      2025.3.1
intel-opencl-rt         2025.3.1
intel-openmp            2025.3.1
intel-pti               0.15.0
intel-sycl-rt           2025.3.1
jinja2                  3.1.6
markdown                3.10.2
markdown-it-py          4.0.0
markupsafe              3.0.3
mdurl                   0.1.2
mkl                     2025.3.0
mpmath                  1.3.0
networkx                3.6.1
numpy                   2.4.3
oneccl                  2021.17.1
oneccl-devel            2021.17.1
onemkl-license          2025.3.0
onemkl-sycl-blas        2025.3.0
onemkl-sycl-dft         2025.3.0
onemkl-sycl-lapack      2025.3.0
onemkl-sycl-rng         2025.3.0
onemkl-sycl-sparse      2025.3.0
opencv-python           4.10.0.84
packaging               26.0
pillow                  12.1.1
protobuf                7.34.1
psutil                  7.2.2
pyelftools              0.32
pygments                2.20.0
pyyaml                  6.0.3
regex                   2026.3.32
requests                2.33.1
rich                    14.3.3
safetensors             0.4.5
setuptools              70.2.0
shellingham             1.5.4
sympy                   1.14.0
tbb                     2022.3.0
tcmlib                  1.4.1
tensorboard             2.20.0
tensorboard-data-server 0.7.2
tokenizers              0.21.4
toml                    0.10.2
torch                   2.10.0+xpu
torchvision             0.25.0+xpu
tqdm                    4.67.3
transformers            4.54.1
triton-xpu              3.6.0
typer                   0.24.1
typing-extensions       4.15.0
umf                     1.0.2
urllib3                 2.6.3
voluptuous              0.15.2
werkzeug                3.1.8
zipp                    3.23.0

Shell output:

/sd-scripts/library/anima_models.py:237: UserWarning: In XPU autocast, but the target dtype is not supported. Disabling autocast.
XPU Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.
  with torch.autocast(device_type=x.device.type, dtype=torch.float32):
Traceback (most recent call last):
  File "/sd-scripts/anima_train_network.py", line 451, in <module>
    trainer.train(args)
    ~~~~~~~~~~~~~^^^^^^
  File "/sd-scripts/train_network.py", line 1430, in train
    accelerator.backward(loss)
    ~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/accelerate/accelerator.py", line 2838, in backward
    loss.backward(**kwargs)
    ~~~~~~~~~~~~~^^^^^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/_tensor.py", line 630, in backward
    torch.autograd.backward(
    ~~~~~~~~~~~~~~~~~~~~~~~^
        self, gradient, retain_graph, create_graph, inputs=inputs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/autograd/__init__.py", line 364, in backward
    _engine_run_backward(
    ~~~~~~~~~~~~~~~~~~~~^
        tensors,
        ^^^^^^^^
    ...<5 lines>...
        accumulate_grad=True,
        ^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        t_outputs, *args, **kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    )  # Calls into the C++ engine to run the backward pass
    ^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/autograd/function.py", line 317, in apply
    return user_fn(self, *args)
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 523, in decorate_bwd
    return bwd(*args, **kwargs)  # pyrefly: ignore [not-callable]
  File "/sd-scripts/library/anima_models.py", line 107, in backward
    outputs = ctx.forward_function(*inputs)
  File "/sd-scripts/library/anima_models.py", line 920, in _forward
    self.self_attn(
    ~~~~~~~~~~~~~~^
        rearrange(normalized_x, "b t h w d -> b (t h w) d"),
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    ...<2 lines>...
        rope_emb=rope_emb_L_1_1_D,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
    ),
    ^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sd-scripts/library/anima_models.py", line 362, in forward
    q, k, v = self.compute_qkv(x, context, rope_emb=rope_emb)
              ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sd-scripts/library/anima_models.py", line 337, in compute_qkv
    q = self.q_proj(x)
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sd-scripts/networks/lora_anima.py", line 87, in forward
    lx = self.lora_down(x)
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
    return forward_call(*args, **kwargs)
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/torch/nn/modules/linear.py", line 134, in forward
    return F.linear(input, self.weight, self.bias)
           ~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::BFloat16 != float

Traceback (most recent call last):
  File "/sd-scripts/xpu2.10/bin/accelerate", line 10, in <module>
    sys.exit(main())
             ~~~~^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/accelerate/commands/accelerate_cli.py", line 50, in main
    args.func(args)
    ~~~~~~~~~^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/accelerate/commands/launch.py", line 1405, in launch_command
    simple_launcher(args)
    ~~~~~~~~~~~~~~~^^^^^^
  File "/sd-scripts/xpu2.10/lib/python3.13/site-packages/accelerate/commands/launch.py", line 993, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
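The final RuntimeError is a plain dtype mismatch inside F.linear, and it can be reproduced in isolation on any device (a minimal sketch, not specific to XPU: a bfloat16 activation, as produced inside the autocast region, meeting a float32 LoRA-style weight fails the same way):

```python
import torch
import torch.nn.functional as F

# bfloat16 activation vs. float32 weight: F.linear requires matching dtypes,
# so this raises the same "expected mat1 and mat2 to have the same dtype" error.
x = torch.randn(2, 16, dtype=torch.bfloat16)
w = torch.randn(4, 16, dtype=torch.float32)

try:
    F.linear(x, w)
except RuntimeError as e:
    print("dtype mismatch:", e)
```

This suggests the LoRA down-projection weight stays in float32 while the recomputed checkpoint activations arrive in bfloat16, rather than being handled by autocast during the backward recompute.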

Command:

accelerate launch --num_cpu_threads_per_process 1 \
anima_train_network.py \
--pretrained_model_name_or_path="anima-preview2.safetensors" \
--qwen3="qwen_3_06b_base.safetensors" \
--vae="qwen_image_vae.safetensors" \
--dataset_config="dataset.toml" \
--output_dir="output" \
--output_name="ANIMA" \
--save_model_as=safetensors \
--network_module=networks.lora_anima \
--network_dim=32 \
--network_alpha=16 \
--timestep_sampling=sigmoid \
--sigmoid_scale=1.0 \
--learning_rate=0.0001 \
--max_train_epochs=20 \
--sdpa \
--gradient_checkpointing \
--mixed_precision="fp16" \
--lr_scheduler="cosine" \
--cache_latents \
--cache_latents_to_disk \
--blocks_to_swap=0 \
--optimizer_type="AdamW8bit" \
--lr_warmup_steps=0 \
--save_every_n_epochs=1 \
--lr_scheduler_num_cycles=0 \
--logging_dir="./log/" \
--unsloth_offload_checkpointing

I tested both bf16 and fp16 for --mixed_precision; neither worked. However, --full_bf16 worked fine (--full_fp16 did not).
Everything works fine on torch-xpu 2.9.

By the way:
torch-xpu 2.11 cannot be used either, due to the issue described here: intel/torch-xpu-ops#2800
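The UserWarning at the top of the log may be related: anima_models.py requests `torch.autocast(..., dtype=torch.float32)`, which XPU autocast rejects (it only supports bfloat16/float16 targets), so autocast is disabled for that region. A hedged sketch of the PyTorch-documented alternative for running a float32 island inside an autocast region, disabling autocast and casting explicitly (the `norm_in_fp32` name and `norm_layer` parameter are illustrative, not from the repo):

```python
import torch

def norm_in_fp32(norm_layer, x):
    # Instead of autocast(dtype=torch.float32), which XPU rejects,
    # disable autocast for the region, upcast the input manually,
    # and cast the result back to the surrounding compute dtype.
    with torch.autocast(device_type=x.device.type, enabled=False):
        return norm_layer(x.float()).to(x.dtype)
```

Under an outer bfloat16 autocast region, this runs the normalization in float32 and returns a bfloat16 tensor, which avoids the "target dtype is not supported" warning path entirely.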
