diff --git a/defuser/modeling/unfused_moe/qwen2_moe.py b/defuser/modeling/unfused_moe/qwen2_moe.py index 7368fa7..51b89a5 100644 --- a/defuser/modeling/unfused_moe/qwen2_moe.py +++ b/defuser/modeling/unfused_moe/qwen2_moe.py @@ -34,6 +34,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Route tokens exactly like HF Qwen2 MoE, then run explicit expert modules.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = run_routed_experts( @@ -44,7 +45,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.num_experts, ) - shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output diff --git a/defuser/modeling/unfused_moe/qwen3_next.py b/defuser/modeling/unfused_moe/qwen3_next.py index af58062..8d4510a 100644 --- a/defuser/modeling/unfused_moe/qwen3_next.py +++ b/defuser/modeling/unfused_moe/qwen3_next.py @@ -33,6 +33,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Route tokens exactly like HF Qwen3-Next MoE, then run explicit experts.""" batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) + shared_expert_output = self.shared_expert(hidden_states) _, routing_weights, selected_experts = self.gate(hidden_states) routing_weights = routing_weights.to(hidden_states.dtype) final_hidden_states = run_routed_experts( @@ -43,7 +44,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: self.num_experts, ) - shared_expert_output = self.shared_expert(hidden_states) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output final_hidden_states = final_hidden_states + shared_expert_output diff --git a/pyproject.toml b/pyproject.toml index 7b87412..50ae60b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "Defuser" -version = "0.0.20" +version = "0.0.21" description = "Model defuser helper for HF Transformers." readme = "README.md" requires-python = ">=3.9" diff --git a/tests/test_convert_model.py b/tests/test_convert_model.py index 475c779..3cfbdbf 100644 --- a/tests/test_convert_model.py +++ b/tests/test_convert_model.py @@ -402,6 +402,76 @@ def _assert_sparse_moe_defused_matches_fused_math( torch.testing.assert_close(actual, expected, **assert_close_kwargs) +def _force_route_all_experts(block: nn.Module) -> None: + """Set MoE routers to select all experts so execution-order hooks always fire.""" + + router = getattr(block, "gate", None) + num_experts = getattr(block, "num_experts", None) + if router is None or num_experts is None: + return + + for name in ("top_k", "num_experts_per_tok"): + if hasattr(router, name): + setattr(router, name, num_experts) + return + + +def _semantic_sparse_moe_execution_order(block: nn.Module, hidden_states: torch.Tensor) -> list[str]: + """Record semantic MoE execution order for shared expert, router, routed experts, and shared gate.""" + + _force_route_all_experts(block) + raw_events: list[str] = [] + handles = [] + + def _record(event_name: str): + def _hook(_module, _inputs): + raw_events.append(event_name) + return _hook + + if hasattr(block, "shared_expert"): + handles.append(block.shared_expert.register_forward_pre_hook(_record("shared_expert"))) + if hasattr(block, "gate"): + handles.append(block.gate.register_forward_pre_hook(_record("gate"))) + if hasattr(block, "shared_expert_gate"): + handles.append(block.shared_expert_gate.register_forward_pre_hook(_record("shared_expert_gate"))) + + experts = getattr(block, "experts", None) + if isinstance(experts, nn.ModuleList): + for idx, expert in enumerate(experts): + handles.append(expert.register_forward_pre_hook(_record(f"expert_{idx}"))) + elif isinstance(experts, nn.Module): + handles.append(experts.register_forward_pre_hook(_record("experts"))) + + try: + with torch.inference_mode(): + block.eval()(hidden_states) + finally: + for handle in handles: + handle.remove() + + semantic_events: list[str] = [] + for event in raw_events: + normalized = "routed_experts" if event.startswith("expert_") or event == "experts" else event + if not semantic_events or semantic_events[-1] != normalized: + semantic_events.append(normalized) + return semantic_events + + +def _assert_sparse_moe_defused_matches_fused_execution_order( + original_block: nn.Module, + defused_block: nn.Module, + hidden_states: torch.Tensor, +) -> None: + """Defused blocks must preserve the same semantic execution order as fused HF blocks.""" + + _seed_floating_tensors(original_block) + _copy_sparse_moe_weights(original_block, defused_block) + + expected = _semantic_sparse_moe_execution_order(original_block, hidden_states) + actual = _semantic_sparse_moe_execution_order(defused_block, hidden_states) + assert actual == expected + + def test_qwen2_moe(): model_type = "qwen2_moe" replace_fused_blocks(model_type) @@ -858,6 +928,17 @@ def test_qwen2_moe_defused_forward_matches_fused_math(): ) +def test_qwen2_moe_defused_forward_matches_fused_execution_order(): + config = _tiny_moe_config(Qwen2MoeConfig) + hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) + + _assert_sparse_moe_defused_matches_fused_execution_order( + Qwen2MoeSparseMoeBlock(config), + LinearQwen2MoeSparseMoeBlock(config), + hidden_states, + ) + + def test_qwen3_moe_defused_forward_matches_fused_math(): config = _tiny_moe_config(Qwen3MoeConfig) hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) @@ -880,6 +961,17 @@ def test_qwen3_next_defused_forward_matches_fused_math(): ) +def test_qwen3_next_defused_forward_matches_fused_execution_order(): + config = _tiny_moe_config(Qwen3NextConfig) + hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32) + + _assert_sparse_moe_defused_matches_fused_execution_order( + Qwen3NextSparseMoeBlock(config), + LinearQwen3NextSparseMoeBlock(config), + hidden_states, + ) + + def test_qwen3_omni_defused_forward_matches_fused_math(): config = _tiny_qwen3_omni_config().thinker_config.text_config hidden_states = torch.randn(2, 3, config.hidden_size, dtype=torch.float32)