diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 91a3e092..fa630961 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -324,6 +324,11 @@ num_frames: 81 guidance_scale: 5.0 flow_shift: 3.0 +# Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) +# Skips the unconditional forward pass on ~35% of steps via residual compensation. +# See: FasterCache (Lv et al. 2024), WAN 2.1 paper §4.4.2 +use_cfg_cache: False + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 num_inference_steps: 30 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index f53cc59b..828bc1a2 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -125,6 +125,7 @@ def call_pipeline(config, pipeline, prompt, negative_prompt): num_frames=config.num_frames, num_inference_steps=config.num_inference_steps, guidance_scale=config.guidance_scale, + use_cfg_cache=config.use_cfg_cache, ) elif model_key == WAN2_2: return pipeline( diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7c0314b4..31a224dc 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -778,3 +778,108 @@ def transformer_forward_pass( latents = latents[:bsz] return noise_pred, latents + + +@partial(jax.jit, static_argnames=("guidance_scale",)) +def transformer_forward_pass_full_cfg( + graphdef, + sharded_state, + rest_of_state, + latents_doubled: jnp.array, + timestep: jnp.array, + prompt_embeds_combined: jnp.array, + guidance_scale: float, + encoder_hidden_states_image=None, +): + """Full CFG forward pass. + + Accepts pre-doubled latents and pre-concatenated [cond, uncond] prompt embeds. + Returns the merged noise_pred plus raw noise_cond and noise_uncond for + CFG cache storage. Keeping cond/uncond separate avoids a second forward + pass on cache steps. + """ + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + bsz = latents_doubled.shape[0] // 2 + noise_pred = wan_transformer( + hidden_states=latents_doubled, + timestep=timestep, + encoder_hidden_states=prompt_embeds_combined, + encoder_hidden_states_image=encoder_hidden_states_image, + ) + noise_cond = noise_pred[:bsz] + noise_uncond = noise_pred[bsz:] + noise_pred_merged = noise_uncond + guidance_scale * (noise_cond - noise_uncond) + return noise_pred_merged, noise_cond, noise_uncond + + +@partial(jax.jit, static_argnames=("guidance_scale",)) +def transformer_forward_pass_cfg_cache( + graphdef, + sharded_state, + rest_of_state, + latents_cond: jnp.array, + timestep_cond: jnp.array, + prompt_cond_embeds: jnp.array, + cached_noise_cond: jnp.array, + cached_noise_uncond: jnp.array, + guidance_scale: float, + w1: float = 1.0, + w2: float = 1.0, + encoder_hidden_states_image=None, +): + """CFG-Cache forward pass with FFT frequency-domain compensation. + + FasterCache (Lv et al., ICLR 2025) CFG-Cache: + 1. Compute frequency-domain bias: ΔF = FFT(uncond) - FFT(cond) + 2. Split into low-freq (ΔLF) and high-freq (ΔHF) via spectral mask + 3. Apply phase-dependent weights: + F_low = FFT(new_cond)_low + w1 * ΔLF + F_high = FFT(new_cond)_high + w2 * ΔHF + 4. Reconstruct: uncond_approx = IFFT(F_low + F_high) + + w1/w2 encode the denoising phase: + Early (high noise): w1=1+α, w2=1 → boost low-freq correction + Late (low noise): w1=1, w2=1+α → boost high-freq correction + where α=0.2 (FasterCache default). + + On TPU this compiles to a single static XLA graph with half the batch size + of a full CFG pass. + """ + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) + noise_cond = wan_transformer( + hidden_states=latents_cond, + timestep=timestep_cond, + encoder_hidden_states=prompt_cond_embeds, + encoder_hidden_states_image=encoder_hidden_states_image, + ) + + # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W] + fft_cond_cached = jnp.fft.rfft2(cached_noise_cond.astype(jnp.float32)) + fft_uncond_cached = jnp.fft.rfft2(cached_noise_uncond.astype(jnp.float32)) + fft_bias = fft_uncond_cached - fft_cond_cached + + # Build low/high frequency mask (25% cutoff) + h = fft_bias.shape[-2] + w_rfft = fft_bias.shape[-1] + ch = jnp.maximum(1, h // 4) + cw = jnp.maximum(1, w_rfft // 4) + freq_h = jnp.arange(h) + freq_w = jnp.arange(w_rfft) + # Low-freq: indices near DC (0) in both dims; account for wrap-around in dim H + low_h = (freq_h < ch) | (freq_h >= h - ch + 1) + low_w = freq_w < cw + low_mask = (low_h[:, None] & low_w[None, :]).astype(jnp.float32) + high_mask = 1.0 - low_mask + + # Apply phase-dependent weights to frequency bias + fft_bias_weighted = fft_bias * (low_mask * w1 + high_mask * w2) + + # Reconstruct unconditional output + fft_cond_new = jnp.fft.rfft2(noise_cond.astype(jnp.float32)) + fft_uncond_approx = fft_cond_new + fft_bias_weighted + noise_uncond_approx = jnp.fft.irfft2( + fft_uncond_approx, s=noise_cond.shape[-2:] + ).astype(noise_cond.dtype) + + noise_pred_merged = noise_uncond_approx + guidance_scale * (noise_cond - noise_uncond_approx) + return noise_pred_merged, noise_cond diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index c247facb..b50c4be6 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .wan_pipeline import WanPipeline, transformer_forward_pass +from .wan_pipeline import WanPipeline, transformer_forward_pass, transformer_forward_pass_full_cfg, transformer_forward_pass_cfg_cache from ...models.wan.transformers.transformer_wan import WanModel from typing import List, Union, Optional from ...pyconfig import HyperParameters @@ -90,6 +90,7 @@ def __call__( prompt_embeds: Optional[jax.Array] = None, negative_prompt_embeds: Optional[jax.Array] = None, vae_only: bool = False, + use_cfg_cache: bool = False, ): latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, @@ -114,6 +115,8 @@ def __call__( num_inference_steps=num_inference_steps, scheduler=self.scheduler, scheduler_state=scheduler_state, + use_cfg_cache=use_cfg_cache, + height=height, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -140,26 +143,128 @@ def run_inference_2_1( num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, scheduler_state, + use_cfg_cache: bool = False, + height: int = 480, ): - do_classifier_free_guidance = guidance_scale > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. + + CFG-Cache strategy (Lv et al., ICLR 2025, enabled via use_cfg_cache=True): + - Full CFG steps : run transformer on [cond, uncond] batch (batch×2). + Cache raw noise_cond and noise_uncond for FFT bias. + - Cache steps : run transformer on cond batch only (batch×1). + Estimate uncond via FFT frequency-domain compensation: + ΔF = FFT(cached_uncond) - FFT(cached_cond) + Split ΔF into low-freq (ΔLF) and high-freq (ΔHF). + uncond_approx = IFFT(FFT(new_cond) + w1*ΔLF + w2*ΔHF) + Phase-dependent weights (α=0.2): + Early (high noise): w1=1.2, w2=1.0 (boost low-freq) + Late (low noise): w1=1.0, w2=1.2 (boost high-freq) + - Schedule : full CFG for the first 1/3 of steps, then + full CFG every 5 steps, cache the rest. + + Two separately-compiled JAX-jitted functions handle full and cache steps so + XLA sees static shapes throughout — the key requirement for TPU efficiency. + """ + do_cfg = guidance_scale > 1.0 + bsz = latents.shape[0] + + # Resolution-dependent CFG cache config (FasterCache / MixCache guidance) + if height >= 720: + # 720p: conservative — protect last 40%, interval=5 + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = int(num_inference_steps * 0.9) + cfg_cache_alpha = 0.2 + else: + # 480p: moderate — protect last 2 steps, interval=5 + cfg_cache_interval = 5 + cfg_cache_start_step = int(num_inference_steps / 3) + cfg_cache_end_step = num_inference_steps - 2 + cfg_cache_alpha = 0.2 + + # Pre-split embeds once, outside the loop. + prompt_cond_embeds = prompt_embeds + prompt_embeds_combined = None + if do_cfg: + prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + # Pre-compute cache schedule and phase-dependent weights. + # t₀ = midpoint step; before t₀ boost low-freq, after boost high-freq. + t0_step = num_inference_steps // 2 + first_full_step_seen = False + step_is_cache = [] + step_w1w2 = [] + for s in range(num_inference_steps): + is_cache = ( + use_cfg_cache + and do_cfg + and first_full_step_seen + and s >= cfg_cache_start_step + and s < cfg_cache_end_step + and (s - cfg_cache_start_step) % cfg_cache_interval != 0 + ) + step_is_cache.append(is_cache) + if not is_cache: + first_full_step_seen = True + # Phase-dependent weights: w = 1 + α·I(condition) + if s < t0_step: + step_w1w2.append((1.0 + cfg_cache_alpha, 1.0)) # early: boost low-freq + else: + step_w1w2.append((1.0, 1.0 + cfg_cache_alpha)) # late: boost high-freq + + # Cache tensors (on-device JAX arrays, initialised to None). + cached_noise_cond = None + cached_noise_uncond = None + for step in range(num_inference_steps): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - - noise_pred, latents = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance=do_classifier_free_guidance, - guidance_scale=guidance_scale, - ) + is_cache_step = step_is_cache[step] + + if is_cache_step: + # ── Cache step: cond-only forward + FFT frequency compensation ── + w1, w2 = step_w1w2[step] + timestep = jnp.broadcast_to(t, bsz) + noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + cached_noise_cond, + cached_noise_uncond, + guidance_scale=guidance_scale, + w1=jnp.float32(w1), + w2=jnp.float32(w2), + ) + + elif do_cfg: + # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── + latents_doubled = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, bsz * 2) + noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + graphdef, + sharded_state, + rest_of_state, + latents_doubled, + timestep, + prompt_embeds_combined, + guidance_scale=guidance_scale, + ) + + else: + # ── No CFG (guidance_scale <= 1.0) ── + timestep = jnp.broadcast_to(t, bsz) + noise_pred, latents = transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_cond_embeds, + do_classifier_free_guidance=False, + guidance_scale=guidance_scale, + ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents