Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,11 @@ num_frames: 81
guidance_scale: 5.0
flow_shift: 3.0

# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only)
# Skips the unconditional forward pass on ~35% of steps via residual compensation.
# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2
use_cfg_cache: False

# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
guidance_rescale: 0.0
num_inference_steps: 30
Expand Down
1 change: 1 addition & 0 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt):
num_frames=config.num_frames,
num_inference_steps=config.num_inference_steps,
guidance_scale=config.guidance_scale,
use_cfg_cache=config.use_cfg_cache,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might cause errors when trying to run WAN models other than WAN 2.1 T2V, since generate_wan.py is a common script for all WAN models. It would be better to add the use_cfg_cache parameter to all the other WAN config files as well.

)
elif model_key == WAN2_2:
return pipeline(
Expand Down
105 changes: 105 additions & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,108 @@ def transformer_forward_pass(
latents = latents[:bsz]

return noise_pred, latents


@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_full_cfg(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_doubled: jax.Array,
    timestep: jax.Array,
    prompt_embeds_combined: jax.Array,
    guidance_scale: float,
    encoder_hidden_states_image=None,
):
  """Full classifier-free-guidance forward pass over a pre-doubled batch.

  Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt
  embeds, runs the transformer once on the doubled batch, and merges the two
  halves with the standard CFG formula. The raw conditional/unconditional
  halves are also returned so the caller can cache them and skip the
  unconditional forward pass on later (cache) steps.

  Args:
    graphdef: nnx graph definition of the WAN transformer.
    sharded_state: sharded nnx parameter state.
    rest_of_state: remaining (non-sharded) nnx state.
    latents_doubled: latents concatenated as [cond, uncond] along batch dim 0.
    timestep: timestep broadcast to the doubled batch size.
    prompt_embeds_combined: [cond, uncond] prompt embeddings concatenated on
      batch dim 0.
    guidance_scale: CFG scale. Static under jit — each distinct value triggers
      one compilation.
    encoder_hidden_states_image: optional image-conditioning embeddings.

  Returns:
    Tuple of (noise_pred_merged, noise_cond, noise_uncond).
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  # First half of the doubled batch is conditional, second half unconditional.
  bsz = latents_doubled.shape[0] // 2
  noise_pred = wan_transformer(
      hidden_states=latents_doubled,
      timestep=timestep,
      encoder_hidden_states=prompt_embeds_combined,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )
  noise_cond = noise_pred[:bsz]
  noise_uncond = noise_pred[bsz:]
  # Standard CFG merge: uncond + s * (cond - uncond).
  noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond)
  return noise_pred_merged, noise_cond, noise_uncond


@partial(jax.jit, static_argnames=("guidance_scale",))
def transformer_forward_pass_cfg_cache(
    graphdef,
    sharded_state,
    rest_of_state,
    latents_cond: jax.Array,
    timestep_cond: jax.Array,
    prompt_cond_embeds: jax.Array,
    cached_noise_cond: jax.Array,
    cached_noise_uncond: jax.Array,
    guidance_scale: float,
    w1: float = 1.0,
    w2: float = 1.0,
    encoder_hidden_states_image=None,
):
  """CFG-Cache forward pass with FFT frequency-domain compensation.

  FasterCache (Lv et al., ICLR 2025) CFG-Cache: instead of a second
  (unconditional) transformer forward pass, approximate the unconditional
  output from the cached cond/uncond pair of the last full-CFG step:
    1. Frequency-domain bias: ΔF = FFT(cached_uncond) - FFT(cached_cond)
    2. Split ΔF into low-freq (ΔLF) and high-freq (ΔHF) via a spectral mask
    3. Apply phase-dependent weights:
         F_low  = FFT(new_cond)_low  + w1 * ΔLF
         F_high = FFT(new_cond)_high + w2 * ΔHF
    4. Reconstruct: uncond_approx = IFFT(F_low + F_high)

  w1/w2 encode the denoising phase (α = 0.2, FasterCache default):
    Early (high noise): w1 = 1+α, w2 = 1 → boost low-freq correction
    Late  (low noise):  w1 = 1,   w2 = 1+α → boost high-freq correction

  On TPU this compiles to a single static XLA graph with half the batch size
  of a full CFG pass.

  Args:
    graphdef: nnx graph definition of the WAN transformer.
    sharded_state: sharded nnx parameter state.
    rest_of_state: remaining (non-sharded) nnx state.
    latents_cond: conditional-only latents (not doubled).
    timestep_cond: timestep broadcast to the conditional batch size.
    prompt_cond_embeds: conditional prompt embeddings only.
    cached_noise_cond: conditional prediction stored at the last full step.
    cached_noise_uncond: unconditional prediction from the last full step.
    guidance_scale: CFG scale (static under jit).
    w1: weight applied to the low-frequency bias band.
    w2: weight applied to the high-frequency bias band.
    encoder_hidden_states_image: optional image-conditioning embeddings.

  Returns:
    Tuple of (noise_pred_merged, noise_cond) — the merged prediction and the
    fresh conditional output so the caller can refresh its cache.
  """
  wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state)
  noise_cond = wan_transformer(
      hidden_states=latents_cond,
      timestep=timestep_cond,
      encoder_hidden_states=prompt_cond_embeds,
      encoder_hidden_states_image=encoder_hidden_states_image,
  )

  # FFT over the last two (spatial) dims; float32 for FFT numerical stability.
  # assumes layout [..., H, W] on the trailing dims — TODO confirm for WAN.
  fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32))
  fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32))
  fft_bias = fft_uncond_cached - fft_cond_cached

  # Build the low/high frequency mask (25% cutoff). Shapes are static under
  # jit, so the cutoffs are plain Python ints — use builtin max, not
  # jnp.maximum, to keep mask construction fully static.
  h = fft_bias.shape[-2]
  w_rfft = fft_bias.shape[-1]
  ch = max(1, h // 4)
  cw = max(1, w_rfft // 4)
  freq_h = jnp.arange(h)
  freq_w = jnp.arange(w_rfft)
  # Low-freq: indices near DC (0) in both dims. The full-FFT H axis stores
  # negative frequencies at the top (index h-k ↔ frequency -k), hence the
  # wrap-around term; the rfft W axis has non-negative frequencies only.
  low_h = (freq_h < ch) | (freq_h >= h - ch + 1)
  low_w = freq_w < cw
  low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32)
  high_mask = 1.0 - low_mask

  # Apply phase-dependent weights to the frequency bias.
  fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2)

  # Reconstruct the approximate unconditional output. Pass s= so irfft2
  # recovers the original (possibly odd) spatial width exactly.
  fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32))
  fft_uncond_approx = fft_cond_new + fft_bias_weighted
  noise_uncond_approx = jnp.fft.irfft2(
      fft_uncond_approx, s=noise_cond.shape[-2:]
  ).astype(noise_cond.dtype)

  # Standard CFG merge against the approximated unconditional prediction.
  noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx)
  return noise_pred_merged, noise_cond
141 changes: 123 additions & 18 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .wan_pipeline import WanPipeline, transformer_forward_pass
from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache
from ...models.wan.transformers.transformer_wan import WanModel
from typing import List, Union, Optional
from ...pyconfig import HyperParameters
Expand Down Expand Up @@ -90,6 +90,7 @@ def __call__(
prompt_embeds: Optional[jax.Array] = None,
negative_prompt_embeds: Optional[jax.Array] = None,
vae_only: bool = False,
use_cfg_cache: bool = False,
):
latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs(
prompt,
Expand All @@ -114,6 +115,8 @@ def __call__(
num_inference_steps=num_inference_steps,
scheduler=self.scheduler,
scheduler_state=scheduler_state,
use_cfg_cache=use_cfg_cache,
height=height,
)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
Expand All @@ -140,26 +143,128 @@ def run_inference_2_1(
num_inference_steps: int,
scheduler: FlaxUniPCMultistepScheduler,
scheduler_state,
use_cfg_cache: bool = False,
height: int = 480,
):
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
"""Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.

CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True):
- Full CFG steps : run transformer on [cond, uncond] batch (batch×2).
Cache raw noise_cond and noise_uncond for FFT bias.
- Cache steps : run transformer on cond batch only (batch×1).
Estimate uncond via FFT frequency-domain compensation:
ΔF = FFT(cached_uncond) - FFT(cached_cond)
Split ΔF into low-freq (ΔLF) and high-freq (ΔHF).
uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF)
Phase-dependent weights (α=0.2):
Early (high noise): w1=1.2, w2=1.0 (boost low-freq)
Late (low noise): w1=1.0, w2=1.2 (boost high-freq)
- Schedule : full CFG for the first 1/3 of steps, then
full CFG every 5 steps, cache the rest.

Two separately-compiled JAX-jitted functions handle full and cache steps so
XLA sees static shapes throughout — the key requirement for TPU efficiency.
"""
do_cfg = guidance_scale > 1.0
bsz = latents.shape[0]

# Resolution-dependent CFG cache config (FasterCache / MixCache guidance)
if height >= 720:
# 720p: conservative — protect last 40%, interval=5
cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = int(num_inference_steps * 0.9)
cfg_cache_alpha = 0.2
else:
# 480p: moderate — protect last 2 steps, interval=5
cfg_cache_interval = 5
cfg_cache_start_step = int(num_inference_steps / 3)
cfg_cache_end_step = num_inference_steps - 2
cfg_cache_alpha = 0.2

# Pre-split embeds once, outside the loop.
prompt_cond_embeds = prompt_embeds
prompt_embeds_combined = None
if do_cfg:
prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)

# Pre-compute cache schedule and phase-dependent weights.
# t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq.
t0_step = num_inference_steps // 2
first_full_step_seen = False
step_is_cache = []
step_w1w2 = []
for s in range(num_inference_steps):
is_cache = (
use_cfg_cache
and do_cfg
and first_full_step_seen
and s >= cfg_cache_start_step
and s < cfg_cache_end_step
and (s - cfg_cache_start_step) % cfg_cache_interval != 0
)
step_is_cache.append(is_cache)
if not is_cache:
first_full_step_seen = True
# Phase-dependent weights: w = 1 + α·I(condition)
if s < t0_step:
step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # early: boost low-freq
else:
step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # late: boost high-freq

# Cache tensors (on-device JAX arrays, initialised to None).
cached_noise_cond = None
cached_noise_uncond = None

for step in range(num_inference_steps):
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
if do_classifier_free_guidance:
latents = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, latents.shape[0])

noise_pred, latents = transformer_forward_pass(
graphdef,
sharded_state,
rest_of_state,
latents,
timestep,
prompt_embeds,
do_classifier_free_guidance=do_classifier_free_guidance,
guidance_scale=guidance_scale,
)
is_cache_step = step_is_cache[step]

if is_cache_step:
# ── Cache step: cond-only forward + FFT frequency compensation ──
w1, w2 = step_w1w2[step]
timestep = jnp.broadcast_to(t, bsz)
noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
graphdef,
sharded_state,
rest_of_state,
latents,
timestep,
prompt_cond_embeds,
cached_noise_cond,
cached_noise_uncond,
guidance_scale=guidance_scale,
w1=jnp.float32(w1),
w2=jnp.float32(w2),
)

elif do_cfg:
# ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
latents_doubled = jnp.concatenate([latents] * 2)
timestep = jnp.broadcast_to(t, bsz * 2)
noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
graphdef,
sharded_state,
rest_of_state,
latents_doubled,
timestep,
prompt_embeds_combined,
guidance_scale=guidance_scale,
)

else:
# ── No CFG (guidance_scale <= 1.0) ──
timestep = jnp.broadcast_to(t, bsz)
noise_pred, latents = transformer_forward_pass(
graphdef,
sharded_state,
rest_of_state,
latents,
timestep,
prompt_cond_embeds,
do_classifier_free_guidance=False,
guidance_scale=guidance_scale,
)

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents
Loading