
Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models #1033

Merged
Edwardf0t1 merged 4 commits into main from zhiyu/handle-moe-w13-scales
Mar 16, 2026

Conversation

Contributor

@Edwardf0t1 Edwardf0t1 commented Mar 13, 2026

What does this PR do?

Unify weight_scale_2 between gate_proj/up_proj (and w1/w3) in the HF export path for MOE models. Serving engines fuse these projections into a single gate_up_proj and require a shared scale; this takes the element-wise max of the two independent scales as a conservative choice that avoids overflow.
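As a toy illustration of why the element-wise max is the safe shared scale (a sketch with made-up numbers, not the PR's code):

```python
# Hedged sketch: why the element-wise max is the conservative shared scale
# when gate_proj and up_proj get fused into a single gate_up_proj.
import torch

gate_scale = torch.tensor([0.02, 0.05])  # hypothetical per-expert scales
up_scale = torch.tensor([0.04, 0.03])

# A fused kernel needs ONE scale covering both weight ranges. Taking the
# element-wise max guarantees neither projection overflows the quantized
# range; the smaller-scale projection only loses a little dynamic range.
shared = torch.maximum(gate_scale, up_scale)
assert torch.all(shared >= gate_scale) and torch.all(shared >= up_scale)
```

Using a smaller shared value (e.g. the min or the mean) would clip whichever projection has the larger weights, which is why the PR picks the max.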

Type of change: Bug fix

Usage

# Add a code snippet demonstrating how to use this

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • New Features
    • Automatic synchronization of quantization scaling between Mixture-of-Experts gate and up projections during model export for non‑fused MoE setups (e.g., Qwen MoE, DeepSeek).
  • Bug Fixes / Improvements
    • Export now emits a brief notification when gate/up scaling values are adjusted to ensure consistent quantization.

…pe in export

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
…pe in export

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@copy-pr-bot

copy-pr-bot bot commented Mar 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@Edwardf0t1 Edwardf0t1 marked this pull request as ready for review March 13, 2026 05:50
@Edwardf0t1 Edwardf0t1 requested a review from a team as a code owner March 13, 2026 05:50
Contributor

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Adds a new MoE gate/up amax synchronization function and invokes it during transformer checkpoint export to align weight-quantizer amax values for non-fused MoE expert gate/up pairs.

Changes

MoE Synchronization Implementation (modelopt/torch/export/layer_utils.py):
Adds a _GATE_UP_PAIRS constant and sync_moe_gate_up_amax(model: nn.Module) -> int, which traverses MoE modules and experts, compares gate and up weight_quantizer.amax values, updates both to their element-wise maximum for non-fused gate/up pairs, and returns the number of synced pairs.

Export Integration (modelopt/torch/export/unified_export_hf.py):
Imports sync_moe_gate_up_amax and calls it at the end of _export_transformers_checkpoint; if syncing occurred, prints a status message indicating how many gate/up projection amaxes were synced.
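The traversal described above can be sketched roughly as follows (a hypothetical simplification; the real sync_moe_gate_up_amax in layer_utils.py handles more module layouts and the SequentialQuantizer case):

```python
import torch
import torch.nn as nn

# Hypothetical pair names, mirroring the _GATE_UP_PAIRS idea from the PR.
GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]

class FakeQuantizer(nn.Module):
    """Stand-in for a weight quantizer that stores a calibrated amax."""
    def __init__(self, amax):
        super().__init__()
        self.amax = torch.tensor(amax)

def sync_gate_up_amax(model: nn.Module) -> int:
    """Set each gate/up weight-quantizer amax pair to its element-wise max.

    Returns the number of pairs synced. Sketch only; the actual export
    helper walks MoE expert containers rather than arbitrary modules.
    """
    synced = 0
    for module in model.modules():
        for gate_name, up_name in GATE_UP_PAIRS:
            gate = getattr(module, gate_name, None)
            up = getattr(module, up_name, None)
            if gate is None or up is None:
                continue
            gate_wq = getattr(gate, "weight_quantizer", None)
            up_wq = getattr(up, "weight_quantizer", None)
            if gate_wq is None or up_wq is None:
                continue
            shared = torch.maximum(gate_wq.amax, up_wq.amax)
            gate_wq.amax = shared.clone()
            up_wq.amax = shared.clone()
            synced += 1
    return synced
```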

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The PR title directly and accurately summarizes the main change: unifying weight_scale_2 between gate_proj/up_proj in HF export for MOE models, matching the core functionality added.
Docstring Coverage ✅ Passed Docstring coverage is 100.00% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed No critical security anti-patterns found. Code lacks torch.load with weights_only=False, numpy.load with allow_pickle=True, hardcoded trust_remote_code=True, eval/exec on untrusted input, nosec comments, or non-permissive license dependencies.




Copilot AI left a comment

Pull request overview

This PR updates the Hugging Face export path to ensure weight_scale_2 is shared between fused MoE MLP projection pairs (gate/up), matching serving-engine expectations when they fuse these projections into a single kernel.

Changes:

  • Add max_gate_up_scales() utility to replace gate_proj/up_proj (and w1/w3) weight_scale_2 pairs with their element-wise max in the exported state dict.
  • Invoke the new post-processing step at the end of _export_transformers_checkpoint() after postprocess_state_dict().

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • modelopt/torch/export/unified_export_hf.py: Calls max_gate_up_scales() during HF checkpoint export and reports how many pairs were tied.
  • modelopt/torch/export/quant_utils.py: Introduces max_gate_up_scales() to unify weight_scale_2 across gate/up projection pairs.


Comment on lines +795 to +797
tied = max_gate_up_scales(quantized_state_dict)
if tied:
    print(f"Tied weight_scale_2 for {tied} gate/up projection pair(s) in MoE experts.")
Comment on lines +1158 to +1168
"""Replace gate_proj and up_proj weight_scale_2 with their element-wise max.

For MOE models where gate_proj and up_proj are quantized independently,
serving engines typically fuse them into a single gate_up_proj and need
a single shared scale. Using max is conservative (avoids overflow at the
cost of slightly reduced dynamic range).
"""
suffix_pairs = {
".gate_proj.weight_scale_2": ".up_proj.weight_scale_2",
".w1.weight_scale_2": ".w3.weight_scale_2",
}
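Filling in the rest of the reviewed snippet, a plausible implementation might look like this (a sketch assuming the suffix-pair naming above; the function as merged may differ, and the PR ultimately moved the sync to the amax level):

```python
import torch

def max_gate_up_scales(tensors: dict[str, torch.Tensor]) -> int:
    """Tie gate/up weight_scale_2 entries to their element-wise max, in place.

    Sketch based on the reviewed snippet; returns the number of pairs tied.
    """
    suffix_pairs = {
        ".gate_proj.weight_scale_2": ".up_proj.weight_scale_2",
        ".w1.weight_scale_2": ".w3.weight_scale_2",
    }
    tied = 0
    for key in list(tensors):
        for gate_suffix, up_suffix in suffix_pairs.items():
            if not key.endswith(gate_suffix):
                continue
            # Derive the sibling up-projection key from the shared prefix.
            up_key = key[: -len(gate_suffix)] + up_suffix
            if up_key not in tensors:
                continue
            shared = torch.maximum(tensors[key], tensors[up_key])
            tensors[key] = shared
            tensors[up_key] = shared.clone()
            tied += 1
    return tied
```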

codecov bot commented Mar 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.09%. Comparing base (bc87981) to head (c521b8a).
⚠️ Report is 10 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1033      +/-   ##
==========================================
- Coverage   70.11%   70.09%   -0.03%     
==========================================
  Files         221      221              
  Lines       25459    25541      +82     
==========================================
+ Hits        17851    17902      +51     
- Misses       7608     7639      +31     

return post_state_dict


def max_gate_up_scales(tensors: dict[str, torch.Tensor]) -> int:
Collaborator

I think this may not work. We also need to requantize ws1 if the ws2 changes. I think we need to do it here instead: https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L230

Contributor Author

@Edwardf0t1 Edwardf0t1 Mar 13, 2026

The reason will be displayed to describe this comment to others. Learn more.

I agree it's safer to sync the _amax on the weight quantizers of gate/up pairs. In @lukealonso's implementation it's done at both the amax level and the post-processing stage, and he verified that the accuracy looks good.

Collaborator

@cjluo-nv cjluo-nv left a comment

The reason will be displayed to describe this comment to others. Learn more.

https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L396-L418 this logic is used to align w1 and w3 scales.

We might want to debug why this fails for qwen3.5

@Edwardf0t1
Contributor Author

> https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/export/unified_export_hf.py#L396-L418 this logic is used to align w1 and w3 scales.
>
> We might want to debug why this fails for qwen3.5

It failed for kimi-k2.5 as well. That's exactly why we needed this PR to fix it.

quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
)

tied = max_gate_up_scales(quantized_state_dict)
Contributor


Should we do this before compressing the weights, i.e., syncing the amax of gate and up? Is my understanding correct that the block-wise scales are still based on the separate global scales?

Contributor Author


You are right - updated to sync at the amax level.
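A small numeric sketch of the reviewer's point (illustrative NVFP4-style two-level scaling; the numbers and the 6.0 * 448.0 divisor follow the usual NVFP4 convention, but this is not the exporter's code):

```python
import torch

# In two-level schemes, the per-block scale ws is expressed relative to the
# global scale ws2 (weight_scale_2). So ws depends on ws2.
amax = torch.tensor(6.0 * 448.0)       # calibrated global amax
ws2 = amax / (6.0 * 448.0)             # global scale -> 1.0 here
block_amax = torch.tensor([1344.0, 2688.0])
ws = (block_amax / 6.0) / ws2          # per-block scales, derived from ws2

# Overwriting ws2 with a larger "shared" value AFTER compression, without
# recomputing ws, breaks the invariant ws * ws2 == block_amax / 6:
new_ws2 = ws2 * 2
assert not torch.allclose(ws * new_ws2, block_amax / 6.0)
# Syncing at the amax level BEFORE quantization lets both ws2 and ws be
# re-derived consistently, which is the fix adopted in this PR.
```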

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `modelopt/torch/export/layer_utils.py`:
- Around line 1206-1213: The loop that iterates over _GATE_UP_PAIRS incorrectly
uses break when a pair is missing a quantizer or amax, which stops checking
remaining pairs; update the logic in the loop that accesses
gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to use continue
instead of break so only the current pair is skipped and the next pair in
_GATE_UP_PAIRS is checked; keep the final break that exits after successfully
syncing unchanged.


📥 Commits

Reviewing files that changed from the base of the PR and between 7398fea and 17b7f46.

📒 Files selected for processing (2)
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/export/unified_export_hf.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/unified_export_hf.py

Comment on lines +1206 to +1213
gate_wq = getattr(gate_linear, "weight_quantizer", None)
up_wq = getattr(up_linear, "weight_quantizer", None)
if gate_wq is None or up_wq is None:
    break
gate_amax = getattr(gate_wq, "amax", None)
up_amax = getattr(up_wq, "amax", None)
if gate_amax is None or up_amax is None:
    break

⚠️ Potential issue | 🟠 Major

modelopt/torch/quantization/plugins/vllm.py-112-        self.input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:113:        self.weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-114-        self.output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-115-        self.output_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-116-        assert type(self.quant_method) is vllm_linear.UnquantizedLinearMethod, (
--
modelopt/torch/quantization/plugins/vllm.py-159-    def _setup(self):
modelopt/torch/quantization/plugins/vllm.py-160-        self.w13_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py-161-        self.w2_input_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_input)
modelopt/torch/quantization/plugins/vllm.py:162:        self.w13_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py:163:        self.w2_weight_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_weight)
modelopt/torch/quantization/plugins/vllm.py-164-        self.w13_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-165-        self.w2_output_quantizer = TensorQuantizer(QuantLinearConvBase.default_quant_desc_output)
modelopt/torch/quantization/plugins/vllm.py-166-        self.w13_output_quantizer.disable()
--
modelopt/torch/quantization/plugins/vllm.py-181-        if B is self.w13_weight:
modelopt/torch/quantization/plugins/vllm.py-182-            # First layer of expert
modelopt/torch/quantization/plugins/vllm.py-183-            A = self.w13_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:184:            if self.w13_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-185-                original_weight = self.w13_weight
modelopt/torch/quantization/plugins/vllm.py:186:                self.w13_weight = self.w13_weight_quantizer(self.w13_weight)
modelopt/torch/quantization/plugins/vllm.py-187-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-188-                self.w13_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-189-            else:
--
modelopt/torch/quantization/plugins/vllm.py-192-                C[:] = self.w13_output_quantizer(C)
modelopt/torch/quantization/plugins/vllm.py-193-        elif B is self.w2_weight:
modelopt/torch/quantization/plugins/vllm.py-194-            A = self.w2_input_quantizer(A)  # noqa: N806
modelopt/torch/quantization/plugins/vllm.py:195:            if self.w2_weight_quantizer.is_enabled:
modelopt/torch/quantization/plugins/vllm.py-196-                original_weight = self.w2_weight
modelopt/torch/quantization/plugins/vllm.py:197:                self.w2_weight = self.w2_weight_quantizer(self.w2_weight)
modelopt/torch/quantization/plugins/vllm.py-198-                vllm_fused_moe_package._invoke_fused_moe_kernel(A, B, C, *args, **kwargs)
modelopt/torch/quantization/plugins/vllm.py-199-                self.w2_weight = original_weight
modelopt/torch/quantization/plugins/vllm.py-200-            else:
--
modelopt/torch/quantization/plugins/vllm.py-232-        # the MoE weights can be super large and consume too much memory, so we fold the weights one by one
modelopt/torch/quantization/plugins/vllm.py-233-        for i in range(self.w13_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-234-            self.w13_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:235:                self.w13_weight_quantizer(self.w13_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-236-                    self.w13_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-237-                )
modelopt/torch/quantization/plugins/vllm.py-238-            )
modelopt/torch/quantization/plugins/vllm.py:239:        self.w13_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-240-        for i in range(self.w2_weight.shape[0]):
modelopt/torch/quantization/plugins/vllm.py-241-            self.w2_weight[i].copy_(
modelopt/torch/quantization/plugins/vllm.py:242:                self.w2_weight_quantizer(self.w2_weight[i].float().contiguous()).to(
modelopt/torch/quantization/plugins/vllm.py-243-                    self.w2_weight.dtype
modelopt/torch/quantization/plugins/vllm.py-244-                )
modelopt/torch/quantization/plugins/vllm.py-245-            )
modelopt/torch/quantization/plugins/vllm.py:246:        self.w2_weight_quantizer.disable()
modelopt/torch/quantization/plugins/vllm.py-247-
modelopt/torch/quantization/plugins/vllm.py-248-        torch.cuda.empty_cache()
modelopt/torch/quantization/plugins/vllm.py-249-
--
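The per-expert loop above avoids materializing a fake-quantized copy of the whole stacked MoE weight. A toy sketch of that fold-in-place pattern (`fold_in_place` and the rounding "quantizer" are illustrative stand-ins, not repository code):

```python
# Fold each expert's weight in place, one slice at a time, instead of
# building a full quantized copy of the stacked tensor (memory-friendly,
# per the comment in the vLLM plugin above).
def fold_in_place(stacked_weights, fake_quant):
    for i, expert_weight in enumerate(stacked_weights):
        # Overwrite the slice with its quantize-dequantize result.
        stacked_weights[i] = [fake_quant(x) for x in expert_weight]

weights = [[0.9, 1.6], [2.4, -0.3]]
fold_in_place(weights, lambda x: float(round(x)))  # toy quantizer: snap to ints
assert weights == [[1.0, 2.0], [2.0, 0.0]]
```

After folding, the real code disables the quantizer (`self.w13_weight_quantizer.disable()`) so the already-folded weights are not quantized a second time at forward.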
modelopt/torch/quantization/nn/modules/quant_module.py-121-
modelopt/torch/quantization/nn/modules/quant_module.py-122-    def fold_weight(self, keep_attrs: bool = False):
modelopt/torch/quantization/nn/modules/quant_module.py-123-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_module.py:124:        # Handle all attributes that end with _weight_quantizer
modelopt/torch/quantization/nn/modules/quant_module.py-125-        for name in dir(self):
modelopt/torch/quantization/nn/modules/quant_module.py-126-            attr = getattr(self, name)
modelopt/torch/quantization/nn/modules/quant_module.py-127-            if (
modelopt/torch/quantization/nn/modules/quant_module.py:128:                name.endswith("weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_module.py-129-                and isinstance(attr, TensorQuantizer)
modelopt/torch/quantization/nn/modules/quant_module.py-130-                and attr.fake_quant
modelopt/torch/quantization/nn/modules/quant_module.py-131-            ):
modelopt/torch/quantization/nn/modules/quant_module.py:132:                # Get the corresponding weight name by removing _weight_quantizer suffix
modelopt/torch/quantization/nn/modules/quant_module.py-133-                weight_name = name[:-10]
modelopt/torch/quantization/nn/modules/quant_module.py-134-
modelopt/torch/quantization/nn/modules/quant_module.py-135-                assert hasattr(self, weight_name), (
--
modelopt/torch/quantization/nn/modules/quant_module.py-203-    Quantized linear modules are modules where both the input and the weight are quantized.
modelopt/torch/quantization/nn/modules/quant_module.py-204-    """
modelopt/torch/quantization/nn/modules/quant_module.py-205-
modelopt/torch/quantization/nn/modules/quant_module.py:206:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_module.py-207-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_module.py-208-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_module.py-209-
--
modelopt/torch/quantization/nn/modules/quant_module.py-219-    @staticmethod
modelopt/torch/quantization/nn/modules/quant_module.py-220-    def _get_quantized_weight(module: "QuantLinearConvBase", weight: torch.Tensor) -> torch.Tensor:
modelopt/torch/quantization/nn/modules/quant_module.py-221-        if module._enable_weight_quantization or is_torch_export_mode():
modelopt/torch/quantization/nn/modules/quant_module.py:222:            return module.weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_module.py-223-        return weight
modelopt/torch/quantization/nn/modules/quant_module.py-224-
modelopt/torch/quantization/nn/modules/quant_module.py-225-    def forward(self, input, *args, **kwargs):
--
modelopt/torch/quantization/nn/modules/quant_module.py-234-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_module.py-235-        super()._setup()
modelopt/torch/quantization/nn/modules/quant_module.py-236-        self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_module.py:237:            "weight_quantizer", TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_module.py-238-        )
modelopt/torch/quantization/nn/modules/quant_module.py-239-        self._register_temp_attribute("_enable_weight_quantization", False)
modelopt/torch/quantization/nn/modules/quant_module.py-240-        self._register_dynamic_attribute("weight", self._get_quantized_weight)
--
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1319-        return super()._fake_quantize(inputs)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1320-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1321-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py:1322:class SequentialQuantizer(nn.Sequential):
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1323-    """A sequential container for  :class:`TensorQuantizer` modules.
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1324-
modelopt/torch/quantization/nn/modules/tensor_quantizer.py-1325-    This module is used to quantize a tensor in multiple formats sequentially. It takes as input
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-47-class QuantRNNBase(QuantModule):
modelopt/torch/quantization/nn/modules/quant_rnn.py-48-    """Base class for quantized RNN modules."""
modelopt/torch/quantization/nn/modules/quant_rnn.py-49-
modelopt/torch/quantization/nn/modules/quant_rnn.py:50:    weight_quantizer: TensorQuantizer | SequentialQuantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py-51-    _enable_weight_quantization: bool
modelopt/torch/quantization/nn/modules/quant_rnn.py-52-    default_quant_desc_weight = QUANT_DESC_8BIT_PER_TENSOR
modelopt/torch/quantization/nn/modules/quant_rnn.py-53-    default_quant_desc_input = QUANT_DESC_8BIT_PER_TENSOR
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-75-        self._enable_weight_quantization = False
modelopt/torch/quantization/nn/modules/quant_rnn.py-76-
modelopt/torch/quantization/nn/modules/quant_rnn.py-77-    @staticmethod
modelopt/torch/quantization/nn/modules/quant_rnn.py:78:    def _get_quantized_weight_handler(weight_quantizer_name: str):
modelopt/torch/quantization/nn/modules/quant_rnn.py-79-        def _get_quantized_weight(module: "QuantRNNBase", weight: torch.Tensor):
modelopt/torch/quantization/nn/modules/quant_rnn.py-80-            if module._enable_weight_quantization:
modelopt/torch/quantization/nn/modules/quant_rnn.py:81:                weight_quantizer = getattr(module, weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py:82:                return weight_quantizer(weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-83-            return weight
modelopt/torch/quantization/nn/modules/quant_rnn.py-84-
modelopt/torch/quantization/nn/modules/quant_rnn.py-85-        return _get_quantized_weight
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-102-        for name, _ in self.named_parameters():
modelopt/torch/quantization/nn/modules/quant_rnn.py-103-            if name.startswith("weight"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-104-                # to be compatible with our current config, the name is somewhat weird
modelopt/torch/quantization/nn/modules/quant_rnn.py:105:                # it would be weight_xxx_weight_quantizer
modelopt/torch/quantization/nn/modules/quant_rnn.py:106:                weight_quantizer_name = name + "_weight_quantizer"
modelopt/torch/quantization/nn/modules/quant_rnn.py-107-                self._register_temp_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:108:                    weight_quantizer_name, TensorQuantizer(self.default_quant_desc_weight)
modelopt/torch/quantization/nn/modules/quant_rnn.py-109-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-110-                self._register_dynamic_attribute(
modelopt/torch/quantization/nn/modules/quant_rnn.py:111:                    name, self._get_quantized_weight_handler(weight_quantizer_name)
modelopt/torch/quantization/nn/modules/quant_rnn.py-112-                )
modelopt/torch/quantization/nn/modules/quant_rnn.py-113-        # for cells
modelopt/torch/quantization/nn/modules/quant_rnn.py-114-        self._register_temp_attribute("_input_quantizers", [])
--
modelopt/torch/quantization/nn/modules/quant_rnn.py-143-        for iq in self._input_quantizers + self._proj_input_quantizers:
modelopt/torch/quantization/nn/modules/quant_rnn.py-144-            iq.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-145-
modelopt/torch/quantization/nn/modules/quant_rnn.py:146:    def _disable_weight_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-147-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:148:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-149-                module.disable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-150-
modelopt/torch/quantization/nn/modules/quant_rnn.py:151:    def _enable_weight_quantizer(self):
modelopt/torch/quantization/nn/modules/quant_rnn.py-152-        for name, module in self.named_modules():
modelopt/torch/quantization/nn/modules/quant_rnn.py:153:            if name.endswith("weight_quantizer"):
modelopt/torch/quantization/nn/modules/quant_rnn.py-154-                module.enable()
modelopt/torch/quantization/nn/modules/quant_rnn.py-155-
modelopt/torch/quantization/nn/modules/quant_rnn.py-156-    def _setup(self):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-46-        """Quantized version of a generic linear functional."""
modelopt/torch/quantization/nn/modules/quant_linear.py-47-        output = getattr(package, func_name)(
modelopt/torch/quantization/nn/modules/quant_linear.py-48-            self.input_quantizer(input),
modelopt/torch/quantization/nn/modules/quant_linear.py:49:            self.weight_quantizer(weight),
modelopt/torch/quantization/nn/modules/quant_linear.py-50-            *args,
modelopt/torch/quantization/nn/modules/quant_linear.py-51-            **kwargs,
modelopt/torch/quantization/nn/modules/quant_linear.py-52-        )
--
modelopt/torch/quantization/nn/modules/quant_linear.py-119-
modelopt/torch/quantization/nn/modules/quant_linear.py-120-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-121-        """Overrides and bypass the _setup function."""
modelopt/torch/quantization/nn/modules/quant_linear.py:122:        if isinstance(self.weight_quantizer, SVDQuantTensorQuantizer):
modelopt/torch/quantization/nn/modules/quant_linear.py-123-            return
modelopt/torch/quantization/nn/modules/quant_linear.py:124:        self.weight_quantizer.__class__ = SVDQuantTensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-125-
modelopt/torch/quantization/nn/modules/quant_linear.py-126-    def _not_sequential_quantizers(self):
modelopt/torch/quantization/nn/modules/quant_linear.py:127:        return isinstance(self.weight_quantizer, TensorQuantizer) and isinstance(
modelopt/torch/quantization/nn/modules/quant_linear.py-128-            self.input_quantizer, TensorQuantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-129-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-130-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-138-        """Compute the LoRA residual if present, otherwise return None."""
modelopt/torch/quantization/nn/modules/quant_linear.py-139-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py-140-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:141:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:142:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-143-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py:144:            lora_a = F.linear(input, weight=self.weight_quantizer.svdquant_lora_a)
modelopt/torch/quantization/nn/modules/quant_linear.py:145:            lora_b = F.linear(lora_a, weight=self.weight_quantizer.svdquant_lora_b)
modelopt/torch/quantization/nn/modules/quant_linear.py-146-            return lora_b
modelopt/torch/quantization/nn/modules/quant_linear.py-147-        return None
modelopt/torch/quantization/nn/modules/quant_linear.py-148-
--
modelopt/torch/quantization/nn/modules/quant_linear.py-150-        """SVDQuant layer forward function."""
modelopt/torch/quantization/nn/modules/quant_linear.py-151-        has_svdquant_lora = (
modelopt/torch/quantization/nn/modules/quant_linear.py-152-            self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:153:            and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:154:            and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-155-        )
modelopt/torch/quantization/nn/modules/quant_linear.py-156-        if has_svdquant_lora:
modelopt/torch/quantization/nn/modules/quant_linear.py-157-            input = self._apply_pre_quant_scale(input)
--
modelopt/torch/quantization/nn/modules/quant_linear.py-166-        """Fold the weight for faster eval."""
modelopt/torch/quantization/nn/modules/quant_linear.py-167-        super().fold_weight(keep_attrs)
modelopt/torch/quantization/nn/modules/quant_linear.py-168-        if (
modelopt/torch/quantization/nn/modules/quant_linear.py:169:            hasattr(self, "weight_quantizer")
modelopt/torch/quantization/nn/modules/quant_linear.py-170-            and hasattr(self, "weight")
modelopt/torch/quantization/nn/modules/quant_linear.py:171:            and self.weight_quantizer.fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-172-        ):
modelopt/torch/quantization/nn/modules/quant_linear.py-173-            if (
modelopt/torch/quantization/nn/modules/quant_linear.py-174-                self._not_sequential_quantizers()
modelopt/torch/quantization/nn/modules/quant_linear.py:175:                and self.weight_quantizer.svdquant_lora_a is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:176:                and self.weight_quantizer.svdquant_lora_b is not None
modelopt/torch/quantization/nn/modules/quant_linear.py-177-            ):
modelopt/torch/quantization/nn/modules/quant_linear.py-178-                self.weight.data.copy_(
modelopt/torch/quantization/nn/modules/quant_linear.py-179-                    self.weight
modelopt/torch/quantization/nn/modules/quant_linear.py:180:                    + self.weight_quantizer.svdquant_lora_b @ self.weight_quantizer.svdquant_lora_a
modelopt/torch/quantization/nn/modules/quant_linear.py-181-                )
modelopt/torch/quantization/nn/modules/quant_linear.py-182-            if not keep_attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py-183-                _attrs = [
--
modelopt/torch/quantization/nn/modules/quant_linear.py-185-                    "_svdquant_lora_b",
modelopt/torch/quantization/nn/modules/quant_linear.py-186-                ]
modelopt/torch/quantization/nn/modules/quant_linear.py-187-                for attr in _attrs:
modelopt/torch/quantization/nn/modules/quant_linear.py:188:                    if hasattr(self.weight_quantizer, attr):
modelopt/torch/quantization/nn/modules/quant_linear.py:189:                        delattr(self.weight_quantizer, attr)
modelopt/torch/quantization/nn/modules/quant_linear.py-190-
modelopt/torch/quantization/nn/modules/quant_linear.py-191-
modelopt/torch/quantization/nn/modules/quant_linear.py-192-class RealQuantLinear(QuantModule):
--
modelopt/torch/quantization/nn/modules/quant_linear.py-241-
modelopt/torch/quantization/nn/modules/quant_linear.py-242-    def _setup(self):
modelopt/torch/quantization/nn/modules/quant_linear.py-243-        class RealQuantParameterDict(dict):
modelopt/torch/quantization/nn/modules/quant_linear.py:244:            def __init__(self, weight_quantizer: TensorQuantizer, *args, **kwargs):
modelopt/torch/quantization/nn/modules/quant_linear.py-245-                super().__init__(*args, **kwargs)
modelopt/torch/quantization/nn/modules/quant_linear.py:246:                self.weight_quantizer = weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py-247-
modelopt/torch/quantization/nn/modules/quant_linear.py-248-            def __setitem__(self, key, value):
modelopt/torch/quantization/nn/modules/quant_linear.py-249-                if (
modelopt/torch/quantization/nn/modules/quant_linear.py-250-                    key == "weight"
modelopt/torch/quantization/nn/modules/quant_linear.py:251:                    and self.weight_quantizer
modelopt/torch/quantization/nn/modules/quant_linear.py:252:                    and self.weight_quantizer.is_enabled
modelopt/torch/quantization/nn/modules/quant_linear.py:253:                    and not self.weight_quantizer._fake_quant
modelopt/torch/quantization/nn/modules/quant_linear.py-254-                    and value.element_size() > 1
modelopt/torch/quantization/nn/modules/quant_linear.py-255-                ):
modelopt/torch/quantization/nn/modules/quant_linear.py-256-                    # reset the amax for later calibration
modelopt/torch/quantization/nn/modules/quant_linear.py-257-                    if (
modelopt/torch/quantization/nn/modules/quant_linear.py:258:                        self.weight_quantizer.amax is not None
modelopt/torch/quantization/nn/modules/quant_linear.py:259:                        and self.weight_quantizer.amax.is_meta
modelopt/torch/quantization/nn/modules/quant_linear.py-260-                    ):
modelopt/torch/quantization/nn/modules/quant_linear.py:261:                        delattr(self.weight_quantizer, "_amax")
modelopt/torch/quantization/nn/modules/quant_linear.py:262:                        self.weight_quantizer.amax = self.weight_quantizer._get_amax(value)
modelopt/torch/quantization/nn/modules/quant_linear.py:263:                        self.weight_quantizer._calibrator.reset()
modelopt/torch/quantization/nn/modules/quant_linear.py-264-                    # compress the weight
modelopt/torch/quantization/nn/modules/quant_linear.py:265:                    real_quant_tensor = self.weight_quantizer(value)
modelopt/torch/quantization/nn/modules/quant_linear.py-266-                    real_quant_value = QTensorWrapper(real_quant_tensor)
modelopt/torch/quantization/nn/modules/quant_linear.py-267-                    del value  # delete the original weight to save memory
modelopt/torch/quantization/nn/modules/quant_linear.py-268-                    value = real_quant_value
--
modelopt/torch/quantization/nn/modules/quant_linear.py-270-
modelopt/torch/quantization/nn/modules/quant_linear.py-271-        # Monkey patch the _parameters.__setitem__ to real quant the weight when loading
modelopt/torch/quantization/nn/modules/quant_linear.py-272-        # HF accelerate loads the weight by directly assigning the weight through the _parameters dict.
modelopt/torch/quantization/nn/modules/quant_linear.py:273:        self._parameters = RealQuantParameterDict(self.weight_quantizer, self._parameters)
modelopt/torch/quantization/nn/modules/quant_linear.py-274-
modelopt/torch/quantization/nn/modules/quant_linear.py-275-        # Function to dynamically override load_state_dict
modelopt/torch/quantization/nn/modules/quant_linear.py-276-        dynamically_update_state_methods(self)
--
modelopt/torch/quantization/plugins/transformer_engine.py-78-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-79-            weight, inputs = args[idx], args[idx + 1]
modelopt/torch/quantization/plugins/transformer_engine.py-80-            remaining_args = args[idx + 2 :]
modelopt/torch/quantization/plugins/transformer_engine.py:81:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-82-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-83-            new_args = (weight, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-84-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-90-            idx = 1 if func_name == "_forward" else 0
modelopt/torch/quantization/plugins/transformer_engine.py-91-            weight, weight_fp8, inputs = args[idx], args[idx + 1], args[idx + 2]
modelopt/torch/quantization/plugins/transformer_engine.py-92-            remaining_args = args[idx + 3 :]
modelopt/torch/quantization/plugins/transformer_engine.py:93:            weight = self.weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-94-            inputs = self.input_quantizer(inputs)
modelopt/torch/quantization/plugins/transformer_engine.py-95-            new_args = (weight, weight_fp8, inputs, *remaining_args)
modelopt/torch/quantization/plugins/transformer_engine.py-96-            new_args = (args[0], *new_args) if func_name == "_forward" else new_args
--
modelopt/torch/quantization/plugins/transformer_engine.py-170-        weights_and_biases = args[-2 * num_gemms :]
modelopt/torch/quantization/plugins/transformer_engine.py-171-        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
modelopt/torch/quantization/plugins/transformer_engine.py-172-        quantized_inputs = self.input_quantizer(inp)
modelopt/torch/quantization/plugins/transformer_engine.py:173:        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
modelopt/torch/quantization/plugins/transformer_engine.py-174-
modelopt/torch/quantization/plugins/transformer_engine.py-175-        output = getattr(package, func_name)(
modelopt/torch/quantization/plugins/transformer_engine.py-176-            *(
--
modelopt/torch/quantization/plugins/transformer_engine.py-208-
modelopt/torch/quantization/plugins/transformer_engine.py-209-    @staticmethod
modelopt/torch/quantization/plugins/transformer_engine.py-210-    def forward(ctx, inp, ln_weight, ln_bias, weight, *args, **kwargs):
modelopt/torch/quantization/plugins/transformer_engine.py:211:        input_quantizer, weight_quantizer = _QuantLayerNormLinearFunc.modelopt_quantizers
modelopt/torch/quantization/plugins/transformer_engine.py-212-
modelopt/torch/quantization/plugins/transformer_engine.py:213:        qweight = weight_quantizer(weight)
modelopt/torch/quantization/plugins/transformer_engine.py-214-        qweight.requires_grad = weight.requires_grad
modelopt/torch/quantization/plugins/transformer_engine.py-215-        if ctx is not None:
modelopt/torch/quantization/plugins/transformer_engine.py-216-            # We need to recompute the quantized input for the backward pass, so we save the input_quantizer
--
modelopt/torch/quantization/plugins/transformer_engine.py-291-        # This is multi-process safe (such as in torch distributed jobs), not multi-thread safe
modelopt/torch/quantization/plugins/transformer_engine.py-292-        _QuantLayerNormLinearFunc.modelopt_quantizers = (
modelopt/torch/quantization/plugins/transformer_engine.py-293-            self.input_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py:294:            self.weight_quantizer,
modelopt/torch/quantization/plugins/transformer_engine.py-295-        )
modelopt/torch/quantization/plugins/transformer_engine.py-296-        with replace_function(
modelopt/torch/quantization/plugins/transformer_engine.py-297-            te_layernorm_linear,
--
modelopt/torch/quantization/plugins/peft.py-34-class _QuantLoraLinear(QuantModule):
modelopt/torch/quantization/plugins/peft.py-35-    def _setup(self):
modelopt/torch/quantization/plugins/peft.py-36-        self.input_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py:37:        self.weight_quantizer = TensorQuantizer()
modelopt/torch/quantization/plugins/peft.py-38-        self.output_quantizer = Tenso

... [stdout truncated: output exceeded 50000 characters (209210 total). Please modify your command to reduce output size, for example by:
- Using grep/rg to filter results
- Using head/tail to limit lines
- Using --max-count flag with grep/rg
- Processing output with awk/sed to extract only needed information]

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Full context of sync_moe_gate_up_amax function ==="
sed -n '1174,1221p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer has amax property ==="
sed -n '1322,1400p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py | head -80

echo
echo "=== Check imports and _GATE_UP_PAIRS ==="
sed -n '1,50p' modelopt/torch/export/layer_utils.py
rg -n '_GATE_UP_PAIRS|from.*SequentialQuantizer|import.*SequentialQuantizer' modelopt/torch/export/layer_utils.py

Repository: NVIDIA/Model-Optimizer
Exit code: 0

stdout:

=== Full context of sync_moe_gate_up_amax function ===
# Gate/up naming pairs for standard (unfused) MoE architectures.
# Fused variants (gate_up_proj, linear_fc1) already share a single quantizer and need no sync.
_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]


def sync_moe_gate_up_amax(model: nn.Module) -> int:
    """Take element-wise max of gate and up weight quantizer amaxes per expert.

    Serving engines fuse gate_proj and up_proj into a single gate_up_proj and
    require a single weight_scale_2. Since weight_scale_2 = amax / (6 * 448),
    syncing amaxes before quantization ensures the per-block weight_scale values
    are computed against a consistent global scale.

    Only affects standard MoE models with separate gate/up linear layers
    (e.g. Qwen MoE, DeepSeek). Models with already-fused gate_up_proj
    (e.g. Llama4, GptOss) are unaffected.

    Returns:
        Number of expert gate/up pairs whose amaxes were synced.
    """
    synced = 0
    for _, sub_module in model.named_modules():
        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
            continue
        if not hasattr(sub_module.experts, "__iter__"):
            continue
        for expert in sub_module.experts:
            for gate_name, up_name in _GATE_UP_PAIRS:
                gate_linear = getattr(expert, gate_name, None)
                up_linear = getattr(expert, up_name, None)
                if gate_linear is None or up_linear is None:
                    continue
                gate_wq = getattr(gate_linear, "weight_quantizer", None)
                up_wq = getattr(up_linear, "weight_quantizer", None)
                if gate_wq is None or up_wq is None:
                    break
                gate_amax = getattr(gate_wq, "amax", None)
                up_amax = getattr(up_wq, "amax", None)
                if gate_amax is None or up_amax is None:
                    break
                if not torch.equal(gate_amax, up_amax):
                    shared_amax = torch.max(gate_amax, up_amax)
                    gate_wq.amax = shared_amax
                    up_wq.amax = shared_amax.clone()
                    synced += 1
                break
    return synced


=== Check if SequentialQuantizer has amax property ===
class SequentialQuantizer(nn.Sequential):
    """A sequential container for  :class:`TensorQuantizer` modules.

    This modules is used to quantize a tensor in multiple formats sequentially. It takes as input
    :class:`TensorQuantizer` modules and containerize them similar to :class:`torch.nn.Sequential`.

    We delegate certain properties and methods to all contained quantizers.
    In the case of conflicts, the first quantizer's property or method takes priority.

    `SequentialQuantizer` is useful in cases like INT4 weights, FP8 activations where weight quantization is not the
    same as the gemm quantization. It allows for applying multiple quantization formats to the same tensor in sequence.

    Use `SequentialQuantizer` methods in lower level implementations for better code organization and readability.

    Args:
        quantizers (TensorQuantizer): :class:`TensorQuantizer` modules to be added to the container.

    """

    _delegated_properties = ["fake_quant", "is_enabled", "amax"]
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:
            super().__setattr__(name, value)

    def get_modelopt_state(self) -> dict[str, Any]:
        """Get meta state to be saved in checkpoint."""
        return {"num_quantizers": len(self), "is_sequential_quantizer": True}

    def set_from_attribute_config(
        self,
        attributes: list[dict[str, Any] | QuantizerAttributeConfig]
        | dict[str, Any]
        | QuantizerAttributeConfig,
    ):
        """Set the attributes of contained quantizers from a list of attribute_dicts."""
        if not isinstance(attributes, (list, tuple)):
            assert isinstance(attributes, (dict, QuantizerAttributeConfig)), (
                "attributes must be a list or a dict."
            )
            attributes = [attributes] * len(self)

        for attribute, quantizer in zip(attributes, self):

=== Check imports and _GATE_UP_PAIRS ===
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for model_config export.

Some of the logics in this file are empirical and needs constant update if exceptions occur.
"""

from warnings import warn

import torch
import torch.nn as nn

try:
    from transformers.activations import ACT2FN
except Exception:
    warn("Cannot find transformers package. Hugginface modules cannot be exported.")

from modelopt.torch.utils import distributed as dist
from modelopt.torch.utils import import_plugin

from ..quantization.nn import SequentialQuantizer, TensorQuantizer
from .hf_config_map import HF_CONFIG_MAP
from .mcore_config_map import MCORE_CONFIG_MAP
from .model_config import (
    LAYERNORM_DEFAULT,
    LAYERNORM_RMS,
    LINEAR_COLUMN,
    LINEAR_GROUP,
    LINEAR_ROW,
    QUANTIZATION_FP8,
    QUANTIZATION_NONE,
    QUANTIZATION_NVFP4,
    AttentionConfig,
    ConvConfig,
    DecoderLayerConfig,
    EmbeddingConfig,
    ExpertConfig,
34:from ..quantization.nn import SequentialQuantizer, TensorQuantizer
1176:_GATE_UP_PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]
1201:            for gate_name, up_name in _GATE_UP_PAIRS:

Script executed:

#!/bin/bash
set -euo pipefail

echo "=== Check build_mlp_config handling at lines 827-831 ==="
sed -n '820,840p' modelopt/torch/export/layer_utils.py

echo
echo "=== Check if SequentialQuantizer.amax delegation actually works ==="
sed -n '1342,1380p' modelopt/torch/quantization/nn/modules/tensor_quantizer.py

echo
echo "=== Verify the loop structure - what does break actually break? ==="
sed -n '1195,1225p' modelopt/torch/export/layer_utils.py | cat -n

Repository: NVIDIA/Model-Optimizer


Exit code: 0

stdout:

=== Check build_mlp_config handling at lines 827-831 ===
            or get_quantization_format(module) in [QUANTIZATION_FP8, QUANTIZATION_NVFP4]
        )
    ):
        preprocess_linear_fusion([fc_linear, gate_linear])

    if fc_linear is not None:
        weight_quantizer = None
        if hasattr(fc_linear, "weight_quantizer"):
            weight_quantizer = fc_linear.weight_quantizer
            if isinstance(weight_quantizer, SequentialQuantizer):
                weight_quantizer = weight_quantizer[0]

        # swap fused fc and gate
        if decoder_type in ["chatglm", "phi3"]:
            weights = torch.chunk(fc_linear.weight, 2, dim=0)
            weights = (weights[1], weights[0])
            fc_linear.weight.data = torch.cat(weights, dim=0)

            if (
                weight_quantizer is not None
                and weight_quantizer.is_enabled

=== Check if SequentialQuantizer.amax delegation actually works ===
    _delegated_methods = [
        "reset_amax",
        "disable",
        "enable",
        "load_calib_amax",
        "load_calib_bias",
    ]

    def __init__(self, *quantizers: TensorQuantizer):
        """Initialize SequentialQuantizer module."""
        super().__init__(*quantizers)
        assert all(isinstance(q, TensorQuantizer) for q in self), (
            "All quantizers must be a TensorQuantizer."
        )

    def __getattr__(self, name):
        """Delegate properties and methods to all contained quantizers."""
        if name in self._delegated_properties:
            # Return the property of the first quantizer
            return getattr(self[0], name)

        if name in self._delegated_methods:

            def method_wrapper(*args, **kwargs):
                outputs = getattr(self[0], name)(*args, **kwargs)
                for quantizer in self[1:]:
                    outputs = getattr(quantizer, name)(*args, **kwargs)
                return outputs

            return method_wrapper

        # Defer to super class for attributes not handled here
        return super().__getattr__(name)

    def __setattr__(self, name, value):
        if name in self._delegated_properties:
            for quantizer in self:
                setattr(quantizer, name, value)
        else:

=== Verify the loop structure - what does break actually break? ===
     1	    for _, sub_module in model.named_modules():
     2	        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
     3	            continue
     4	        if not hasattr(sub_module.experts, "__iter__"):
     5	            continue
     6	        for expert in sub_module.experts:
     7	            for gate_name, up_name in _GATE_UP_PAIRS:
     8	                gate_linear = getattr(expert, gate_name, None)
     9	                up_linear = getattr(expert, up_name, None)
    10	                if gate_linear is None or up_linear is None:
    11	                    continue
    12	                gate_wq = getattr(gate_linear, "weight_quantizer", None)
    13	                up_wq = getattr(up_linear, "weight_quantizer", None)
    14	                if gate_wq is None or up_wq is None:
    15	                    break
    16	                gate_amax = getattr(gate_wq, "amax", None)
    17	                up_amax = getattr(up_wq, "amax", None)
    18	                if gate_amax is None or up_amax is None:
    19	                    break
    20	                if not torch.equal(gate_amax, up_amax):
    21	                    shared_amax = torch.max(gate_amax, up_amax)
    22	                    gate_wq.amax = shared_amax
    23	                    up_wq.amax = shared_amax.clone()
    24	                    synced += 1
    25	                break
    26	    return synced
    27	
    28	
    29	def build_stacked_experts(
    30	    experts: nn.Module,
    31	    linear_names: list[str],

Use continue instead of break to check all gate/up pairs (not just the first) per expert.

On lines 1209 and 1213, break exits the pair-checking loop entirely if any pair lacks a quantizer or amax. Since _GATE_UP_PAIRS has two pairs (gate_proj/up_proj and w1/w3), a missing quantizer in the first pair prevents checking the second. This should be continue to skip only the current pair and attempt the next one. (The final break on line 1214 correctly exits after syncing.)

Proposed fix
                 gate_wq = getattr(gate_linear, "weight_quantizer", None)
                 up_wq = getattr(up_linear, "weight_quantizer", None)
                 if gate_wq is None or up_wq is None:
-                    break
+                    continue
                 gate_amax = getattr(gate_wq, "amax", None)
                 up_amax = getattr(up_wq, "amax", None)
                 if gate_amax is None or up_amax is None:
-                    break
+                    continue
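The control-flow difference is easiest to see in a stripped-down sketch, decoupled from real modules. Plain dicts stand in for the linears and quantizers, and the values are invented; only the second naming pair (w1/w3) is fully calibrated, which is exactly the case the finding describes:

```python
# Dict stand-ins for an expert's submodules. The first naming pair exists
# but has no weight_quantizer; the second pair is fully calibrated.
PAIRS = [("gate_proj", "up_proj"), ("w1", "w3")]

def make_expert():
    return {
        "gate_proj": {}, "up_proj": {},  # linears without quantizers
        "w1": {"weight_quantizer": {"amax": 2.0}},
        "w3": {"weight_quantizer": {"amax": 3.0}},
    }

def sync_expert(expert, on_missing):
    """on_missing='break' mirrors the original code, 'continue' the fix."""
    synced = 0
    for gate_name, up_name in PAIRS:
        gate, up = expert.get(gate_name), expert.get(up_name)
        if gate is None or up is None:
            continue
        gate_wq = gate.get("weight_quantizer")
        up_wq = up.get("weight_quantizer")
        if gate_wq is None or up_wq is None:
            if on_missing == "break":
                break  # bug: never reaches the w1/w3 pair
            continue   # fix: try the next naming pair
        if gate_wq["amax"] != up_wq["amax"]:
            shared = max(gate_wq["amax"], up_wq["amax"])
            gate_wq["amax"] = shared
            up_wq["amax"] = shared
            synced += 1
        break  # a usable pair was found and handled; stop
    return synced

assert sync_expert(make_expert(), "break") == 0     # bug: nothing synced
assert sync_expert(make_expert(), "continue") == 1  # fix: w1/w3 synced
```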
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/layer_utils.py` around lines 1206 - 1213, The loop that
iterates over _GATE_UP_PAIRS incorrectly uses break when a pair is missing a
quantizer or amax, which stops checking remaining pairs; update the logic in the
loop that accesses gate_linear/up_linear, gate_wq/up_wq and gate_amax/up_amax to
use continue instead of break so only the current pair is skipped and the next
pair in _GATE_UP_PAIRS is checked; keep the final break that exits after
successfully syncing unchanged.

Contributor

@meenchen meenchen left a comment


LGTM

@Edwardf0t1 Edwardf0t1 requested a review from cjluo-nv March 14, 2026 05:46
@cjluo-nv
Collaborator

If the goal for this is to serve as a safeguard, can we do a check (like an assert) so when the bug reproduces, we can see it immediately?

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@Edwardf0t1
Contributor Author

If the goal for this is to serve as a safeguard, can we do a check (like an assert) so when the bug reproduces, we can see it immediately?

Added a warning.

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_hf.py (1)

779-789: Implementation is correct; consider clarifying warning message.

The placement is appropriate—syncing amaxes before _process_quantized_modules ensures consistent weight_scale_2 computation downstream. The safety-net approach using warnings.warn() is reasonable for graceful handling of edge cases.

Minor nit: The warning mentions "mismatched weight_scale_2" but at this point the function syncs amax values (weight_scale_2 is computed later). Consider rewording for clarity:

📝 Optional: Clarify warning message
     synced = sync_moe_gate_up_amax(model)
     if synced:
         warnings.warn(
-            f"Found {synced} MoE expert gate/up projection pair(s) with mismatched "
-            f"weight_scale_2 after requantize_resmooth_fused_llm_layers. "
+            f"Found {synced} MoE expert gate/up projection pair(s) with mismatched amaxes "
+            f"after requantize_resmooth_fused_llm_layers. "
             f"This typically means the dummy forward did not activate these experts. "
-            f"Taking element-wise max of amaxes for serving-engine fusion."
+            f"Taking element-wise max of amaxes to ensure unified weight_scale_2 for serving-engine fusion."
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 779 - 789, The
warning text refers to "mismatched weight_scale_2" even though this block calls
sync_moe_gate_up_amax to reconcile amax values (weight_scale_2 is computed later
by requantize_resmooth_fused_llm_layers/_process_quantized_modules); update the
warnings.warn message in the sync_moe_gate_up_amax handling to state that
mismatched amax (or weight quantizer amaxes) were found and that the code is
taking the element-wise max of amaxes so downstream weight_scale_2 computation
will be consistent, keeping reference to the same context (e.g., mention
requantize_resmooth_fused_llm_layers dummy forward scenario).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 779-789: The warning text refers to "mismatched weight_scale_2"
even though this block calls sync_moe_gate_up_amax to reconcile amax values
(weight_scale_2 is computed later by
requantize_resmooth_fused_llm_layers/_process_quantized_modules); update the
warnings.warn message in the sync_moe_gate_up_amax handling to state that
mismatched amax (or weight quantizer amaxes) were found and that the code is
taking the element-wise max of amaxes so downstream weight_scale_2 computation
will be consistent, keeping reference to the same context (e.g., mention
requantize_resmooth_fused_llm_layers dummy forward scenario).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 078c1c03-41cb-4615-a802-45ec7cc83dff

📥 Commits

Reviewing files that changed from the base of the PR and between 17b7f46 and c521b8a.

📒 Files selected for processing (1)
  • modelopt/torch/export/unified_export_hf.py

cjluo-nv

This comment was marked as outdated.

cjluo-nv

This comment was marked as outdated.

@Edwardf0t1 Edwardf0t1 requested a review from cjluo-nv March 16, 2026 20:28
Collaborator

@cjluo-nv cjluo-nv left a comment


Concrete proposal

Sorry for the confusion — here's exactly what I'd suggest:

What to change

Delete the new sync_moe_gate_up_amax function and _GATE_UP_PAIRS constant from layer_utils.py. Instead, add the sync logic inline in unified_export_hf.py inside the existing expert loop (lines 699-750), which already calls get_expert_linear_names and iterates over experts. This avoids any new traversal or naming constants.

Where to put it

In _export_transformers_checkpoint, right after the existing expert input-quantizer loop (after line 750), add something like:

    # Sync gate/up weight quantizer amaxes for unfused MoE experts.
    # Serving engines fuse gate+up into one projection and need a shared weight_scale_2.
    synced = 0
    for _, sub_module in model.named_modules():
        if not (is_moe(sub_module) and hasattr(sub_module, "experts")):
            continue
        if not isinstance(sub_module.experts, collections.abc.Iterable):
            continue
        linear_names = get_expert_linear_names(sub_module)
        # Only unfused 3-linear experts (e.g. gate_proj/down_proj/up_proj or w1/w2/w3)
    if len(linear_names) != 3:
            continue
        gate_name, up_name = linear_names[0], linear_names[2]
        for expert in sub_module.experts:
            gate_wq = getattr(getattr(expert, gate_name, None), "weight_quantizer", None)
            up_wq = getattr(getattr(expert, up_name, None), "weight_quantizer", None)
            if gate_wq is None or up_wq is None:
                continue
            gate_amax = getattr(gate_wq, "amax", None)
            up_amax = getattr(up_wq, "amax", None)
            if gate_amax is None or up_amax is None:
                continue
            if not torch.equal(gate_amax, up_amax):
                shared = torch.max(gate_amax, up_amax)
                gate_wq.amax = shared
                up_wq.amax = shared.clone()
                synced += 1
    if synced:
        warnings.warn(
            f"Synced {synced} MoE expert gate/up weight quantizer amax pair(s) "
            f"by taking element-wise max for serving-engine fusion."
        )

Why this is better

  • No new function or constant in layer_utils.py — zero diff there
  • Uses get_expert_linear_names to derive gate/up names — single source of truth
  • Same traversal pattern as the existing loop above it (lines 699-750) — consistent, easy to follow
  • Same fix logic you already wrote — just relocated

That's it — same behavior, less code, no duplication.

@Edwardf0t1
Contributor Author

Concrete proposal

(Full proposal quoted above: delete the sync_moe_gate_up_amax function and the _GATE_UP_PAIRS constant from layer_utils.py, and inline the sync logic in the existing expert loop of _export_transformers_checkpoint in unified_export_hf.py, deriving the gate/up names from get_expert_linear_names.)

I think the suggestion has a real problem. Here's why:
The current ordering in _export_transformers_checkpoint is:

Step 1 (lines 699-750):  Handle uncalibrated expert input quantizers
Step 2 (line 755):        requantize_resmooth_fused_llm_layers(model)
                            └─ preprocess_linear_fusion → syncs gate/up weight amaxes
                               (but may miss experts not activated by dummy forward)
Step 3 (line 782):        sync_moe_gate_up_amax(model)  ← safety net
Step 4 (line 792):        _process_quantized_modules     ← quantizes weights using amaxes

If we inline the sync into Step 1 (lines 699-750), it runs before requantize_resmooth_fused_llm_layers in Step 2. That's a problem because:

  1. preprocess_linear_fusion (called by Step 2) can resmooth weights and recalibrate amaxes via _update_pre_quant_scale, which resets amaxes and re-runs finish_stats_collection. So any sync done in Step 1 could be undone by Step 2.

  2. The current placement as a safety net after Step 2 is deliberate — it catches exactly the pairs that requantize_resmooth_fused_llm_layers missed, without interfering with what it successfully handled.

  3. The existing loop at 699-750 is specifically for input quantizer amaxes of uncalibrated experts. Mixing in weight quantizer amax sync conflates two independent concerns and would make the already-complex loop harder to reason about.

The one valid point in the suggestion is avoiding a redundant model traversal and the _GATE_UP_PAIRS constant. But the traversal cost is negligible (it's iterating named_modules(), not doing any computation), and _GATE_UP_PAIRS is a 2-line constant that maps cleanly to what get_expert_linear_names returns for unfused architectures.

My recommendation: keep sync_moe_gate_up_amax as a separate function in its current position (after requantize_resmooth_fused_llm_layers). The ordering is correct, the separation of concerns is clean, and the cost is trivial.
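The ordering argument hinges on the relation stated in the function's docstring, weight_scale_2 = amax / (6 * 448): whatever amax Step 4 reads determines the exported scale. A minimal numeric sketch of what the Step 3 safety net guarantees (plain floats, invented amax values):

```python
def weight_scale_2(amax):
    # Relation from the sync_moe_gate_up_amax docstring: amax / (6 * 448).
    return amax / (6 * 448)

# Invented per-expert amaxes for a pair the dummy forward failed to sync.
gate_amax, up_amax = 1.2, 0.9
assert weight_scale_2(gate_amax) != weight_scale_2(up_amax)

# Element-wise max is the conservative shared choice: neither projection's
# weights can overflow under the shared scale, at the cost of slightly
# coarser quantization for the smaller-amax projection.
gate_amax = up_amax = max(gate_amax, up_amax)
assert weight_scale_2(gate_amax) == weight_scale_2(up_amax)
```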

Collaborator

@cjluo-nv cjluo-nv left a comment


Approving with one remaining suggestion: replace _GATE_UP_PAIRS with get_expert_linear_names inside sync_moe_gate_up_amax. Instead of hardcoding the gate/up pairs, derive them from get_expert_linear_names(sub_module) — when len(linear_names) == 3, use linear_names[0] and linear_names[2] as the gate/up pair. This keeps expert naming in a single source of truth and the helper function can stay as-is otherwise. The placement and logic are correct.
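A hedged sketch of that last suggestion; the hardcoded name lists below stand in for what get_expert_linear_names would return, and the [gate, down, up] ordering is assumed per the proposal's comment:

```python
def gate_up_pair(linear_names):
    # Only unfused 3-linear experts expose separate gate/up projections;
    # fused experts (e.g. a single gate_up_proj) return None and are skipped.
    if len(linear_names) != 3:
        return None
    # Assumed ordering per the proposal: [gate, down, up].
    return linear_names[0], linear_names[2]

# Both unfused naming conventions yield the pairs _GATE_UP_PAIRS hardcodes:
assert gate_up_pair(["gate_proj", "down_proj", "up_proj"]) == ("gate_proj", "up_proj")
assert gate_up_pair(["w1", "w2", "w3"]) == ("w1", "w3")
# Fused variants fall through untouched:
assert gate_up_pair(["gate_up_proj", "down_proj"]) is None
```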

@Edwardf0t1 Edwardf0t1 merged commit 7b34de6 into main Mar 16, 2026
40 checks passed
@Edwardf0t1 Edwardf0t1 deleted the zhiyu/handle-moe-w13-scales branch March 16, 2026 22:14