-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Add MAGNUS: Multi-Attention Guided Network for Unified Segmentation #8717
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughAdds a new MAGNUS CNN–ViT fusion segmentation implementation and its component classes in monai/networks/nets: CNNPath, TransformerPath, CrossModalAttentionFusion, ScaleAdaptiveConv, MagnusSEBlock, DecoderBlock, and MAGNUS. Re-exports these symbols from monai/networks/nets/init.py. Introduces comprehensive unit tests covering components, configurations, 2D/3D variants, deep supervision, gradient/CUDA paths, and error cases. Adds a MAGNUS autoclass entry in docs/source/networks.rst. Includes weight initialization and spatial-dimension and head-divisibility validations. Estimated code review effort🎯 5 (Critical) | ⏱️ ~120 minutes 🚥 Pre-merge checks | ✅ 3✅ Passed checks (3 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 147-166: The transformer path lacks positional embeddings: add a
learnable positional embedding parameter (e.g., self.pos_embed =
nn.Parameter(torch.zeros(1, num_patches, hidden_dim))) initialized (truncated
normal or normal) and sized to match the sequence length produced by
self.embedding (compute num_patches from input spatial dimensions divided by
patch_size or infer from the flattened embedding shape at runtime), then in the
forward pass add this positional embedding to the flattened patch tokens before
passing them into self.transformer; ensure the parameter is registered on the
correct device and that self.norm still applies after the transformer.
🧹 Nitpick comments (4)
monai/networks/nets/magnus.py (2)
37-37: Sort__all__alphabetically.Per Ruff RUF022: apply isort-style sorting to
__all__.Proposed fix
-__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] +__all__ = ["CNNPath", "CrossModalAttentionFusion", "MAGNUS", "ScaleAdaptiveConv", "TransformerPath"]
703-704: Addstrict=Truetozip().Ensures decoder_blocks and cnn_skips have matching lengths, catching bugs if construction changes.
Proposed fix
- for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)): + for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):tests/networks/nets/test_magnus.py (2)
29-35: Consider importing frommonai.networks.netsto verify public exports.Current imports bypass the public API. Testing via
from monai.networks.nets import MAGNUSwould also validate the__init__.pyexports.
190-194: Addstrict=Truetozip().Although length is asserted above, adding
strict=Truemakes intent explicit.Proposed fix
- for i, (feat, out) in enumerate(zip(features, outputs)): + for i, (feat, out) in enumerate(zip(features, outputs, strict=True)):
| aux_weights: Sequence[float] = (0.4, 0.3, 0.3), | ||
| ) -> None: | ||
| super().__init__() | ||
|
|
||
| if spatial_dims not in (2, 3): | ||
| raise ValueError(f"spatial_dims must be 2 or 3, got {spatial_dims}.") | ||
|
|
||
| self.spatial_dims = spatial_dims | ||
| self.in_channels = in_channels | ||
| self.out_channels = out_channels | ||
| self.features = list(features) | ||
| self.deep_supervision = deep_supervision | ||
| self.aux_weights = list(aux_weights) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aux_weights is stored but never used.
The aux_weights parameter is documented and stored as an attribute but not applied anywhere in the model. Either apply them in the forward pass or remove from constructor and document that users should handle weighting externally.
🧰 Tools
🪛 Ruff (0.14.13)
558-558: Avoid specifying long messages outside the exception class
(TRY003)
- Add learnable positional embeddings to TransformerPath for proper spatial reasoning - Implement dynamic positional embedding interpolation for varying input sizes - Add positional dropout for regularization - Update aux_weights docstring to clarify it's for external use only Addresses CodeRabbit review comments on PR Project-MONAI#8717 Signed-off-by: Sefa Aras <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Line 37: The __all__ export list in magnus.py omits the public classes SEBlock
and DecoderBlock; update the __all__ variable to include "SEBlock" and
"DecoderBlock" (and reorder the entries alphabetically) so these symbols are
properly exported for external importers; ensure the string names exactly match
the class names SEBlock and DecoderBlock and preserve existing entries like
MAGNUS, CNNPath, TransformerPath, CrossModalAttentionFusion, ScaleAdaptiveConv.
🧹 Nitpick comments (5)
monai/networks/nets/magnus.py (5)
177-200: Unused parameterxin_interpolate_pos_encoding.Docstring says "for device reference" but it's not used. Either remove it or use
x.deviceif device placement is needed.Proposed fix (remove unused parameter)
- def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor: + def _interpolate_pos_encoding(self, num_patches: int) -> torch.Tensor: """ Interpolate positional embeddings to match the number of patches. Args: - x: input tensor for device reference. num_patches: target number of patches. Returns: Interpolated positional embeddings of shape (1, num_patches, hidden_dim). """And update the call site at line 222:
- pos_embed = self._interpolate_pos_encoding(x_flat, num_patches) + pos_embed = self._interpolate_pos_encoding(num_patches)
746-746: Addstrict=Truetozip()for safety.Lengths should match by construction, but explicit strictness prevents silent bugs if refactored later.
Proposed fix
- for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)): + for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):
690-705: Minor: Kaiming initialization withnonlinearity="relu"applied to GELU layers.TransformerEncoderLayer uses GELU activation. While Kaiming init defaults assume ReLU, PyTorch lacks a GELU option. This is acceptable but worth noting.
580-601: MissingRaisessection in docstring.Per coding guidelines, raised exceptions should be documented.
ValueErroris raised at line 601 but not documented.Proposed docstring addition
Add to the docstring before
Example::Raises: ValueError: If spatial_dims is not 2 or 3.
250-261: MissingRaisessection in docstring.
ValueErroris raised when channels is not divisible by num_heads.Proposed docstring addition
Add to the class docstring:
Raises: ValueError: If channels is not divisible by num_heads.
monai/networks/nets/magnus.py
Outdated
| from monai.networks.blocks import Convolution, UpSample | ||
| from monai.networks.layers.utils import get_act_layer, get_norm_layer | ||
|
|
||
| __all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
__all__ is missing SEBlock and DecoderBlock.
These are public classes that should be exported. Also, consider sorting alphabetically per static analysis.
Proposed fix
-__all__ = ["MAGNUS", "CNNPath", "TransformerPath", "CrossModalAttentionFusion", "ScaleAdaptiveConv"]
+__all__ = [
+ "CNNPath",
+ "CrossModalAttentionFusion",
+ "DecoderBlock",
+ "MAGNUS",
+ "ScaleAdaptiveConv",
+ "SEBlock",
+ "TransformerPath",
+]🧰 Tools
🪛 Ruff (0.14.13)
37-37: __all__ is not sorted
Apply an isort-style sorting to __all__
(RUF022)
🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` at line 37, The __all__ export list in
magnus.py omits the public classes SEBlock and DecoderBlock; update the __all__
variable to include "SEBlock" and "DecoderBlock" (and reorder the entries
alphabetically) so these symbols are properly exported for external importers;
ensure the string names exactly match the class names SEBlock and DecoderBlock
and preserve existing entries like MAGNUS, CNNPath, TransformerPath,
CrossModalAttentionFusion, ScaleAdaptiveConv.
dzenanz
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for contributing this interesting new architecture!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 635-641: The computed or provided attention head counts must be
validated so that vit_hidden_dim % vit_num_heads == 0 (and similarly for
fusion_num_heads if used by a TransformerEncoderLayer); update the logic around
vit_hidden_dim, vit_num_heads, and fusion_num_heads to check divisibility after
computing/defaulting vit_num_heads and before using them—if not divisible either
raise a clear ValueError explaining the mismatch and showing vit_hidden_dim and
vit_num_heads, or adjust by choosing the largest valid nhead ≤ computed value
that divides vit_hidden_dim (and apply the same validation/adjustment for
fusion_num_heads), ensuring any automated fallback still respects
TransformerEncoderLayer’s d_model % nhead == 0 requirement.
🧹 Nitpick comments (2)
monai/networks/nets/magnus.py (2)
243-272: Document the ValueError in the docstring.
Add aRaises:section for the channels/heads divisibility check. As per coding guidelines, ...
546-627: AddRaises:to MAGNUS docstring.
spatial_dimsvalidation throwsValueErrorbut isn’t documented. As per coding guidelines, ...
| # Compute number of attention heads | ||
| vit_hidden_dim = self.features[-1] | ||
| if vit_num_heads is None: | ||
| vit_num_heads = max(vit_hidden_dim // 32, 1) | ||
| if fusion_num_heads is None: | ||
| fusion_num_heads = vit_num_heads | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
# First, let's examine the code in question
cat -n monai/networks/nets/magnus.py | sed -n '625,655p'Repository: Project-MONAI/MONAI
Length of output: 1325
🏁 Script executed:
# Check for TransformerEncoderLayer usage in the file
rg -n "TransformerEncoderLayer" monai/networks/nets/magnus.py -A 3 -B 3Repository: Project-MONAI/MONAI
Length of output: 337
🏁 Script executed:
# Check the broader context of vit_num_heads and fusion_num_heads usage
rg -n "vit_num_heads|fusion_num_heads" monai/networks/nets/magnus.pyRepository: Project-MONAI/MONAI
Length of output: 652
🌐 Web query:
PyTorch TransformerEncoderLayer d_model nhead divisibility requirement
💡 Result:
In torch.nn.TransformerEncoderLayer(d_model, nhead, ...), the self-attention module is multi-head attention, which splits the model/embedding dimension across heads. That implies a hard constraint:
- Requirement:
d_model % nhead == 0 - Per-head dimension:
head_dim = d_model // nhead(each head gets an equal slice of the embedding) [1], [2].
If d_model is not divisible by nhead, PyTorch raises the familiar error “embed_dim must be divisible by num_heads” (since d_model is used as embed_dim internally) [3].
Fix: pick an nhead that divides d_model (e.g., d_model=512 with nhead=8 or 16), or change d_model accordingly.
Sources: [1] [2] [3]
🏁 Script executed:
# Find TransformerPath class definition
rg -n "class TransformerPath" monai/networks/nets/magnus.py -A 30Repository: Project-MONAI/MONAI
Length of output: 1205
🏁 Script executed:
# Check if there's any validation of num_heads elsewhere
rg -n "% num_heads|% nhead|divisible" monai/networks/nets/magnus.pyRepository: Project-MONAI/MONAI
Length of output: 209
🏁 Script executed:
# Get more context around line 270 to understand where this validation is
cat -n monai/networks/nets/magnus.py | sed -n '260,280p'Repository: Project-MONAI/MONAI
Length of output: 882
🏁 Script executed:
# Check what class/method contains the line 270 validation
rg -n "def " monai/networks/nets/magnus.py | awk -F: '{print $1":"$2}' | while read line; do
linenum=$(echo $line | cut -d: -f1)
if [ $linenum -lt 270 ]; then
echo $line
else
break
fi
done | tail -1Repository: Project-MONAI/MONAI
Length of output: 81
🏁 Script executed:
# Get the class name for line 257
rg -n "class " monai/networks/nets/magnus.py | awk -F: '$1 < 257' | tail -1Repository: Project-MONAI/MONAI
Length of output: 110
🏁 Script executed:
# Check TransformerPath __init__ for any divisibility validation
cat -n monai/networks/nets/magnus.py | sed -n '136,180p'Repository: Project-MONAI/MONAI
Length of output: 1945
🏁 Script executed:
# Check where fusion_num_heads is used - what class instantiates it at line 667
cat -n monai/networks/nets/magnus.py | sed -n '655,675p'Repository: Project-MONAI/MONAI
Length of output: 880
🏁 Script executed:
# Check if vit_num_heads parameter is documented with any constraints
cat -n monai/networks/nets/magnus.py | sed -n '560,580p'Repository: Project-MONAI/MONAI
Length of output: 1503
🏁 Script executed:
# Check if there's any validation of user-provided vit_num_heads elsewhere in __init__
cat -n monai/networks/nets/magnus.py | sed -n '600,645p'Repository: Project-MONAI/MONAI
Length of output: 1933
Add divisibility validation for attention heads.
TransformerEncoderLayer requires d_model % nhead == 0. The auto-computed vit_num_heads can violate this (e.g., vit_hidden_dim=100 → vit_num_heads=3), and user-provided values have no validation. Validate both paths:
✅ Proposed fix
vit_hidden_dim = self.features[-1]
if vit_num_heads is None:
vit_num_heads = max(vit_hidden_dim // 32, 1)
+ if vit_hidden_dim % vit_num_heads != 0:
+ raise ValueError(
+ f"vit_hidden_dim ({vit_hidden_dim}) must be divisible by vit_num_heads ({vit_num_heads})."
+ )
if fusion_num_heads is None:
fusion_num_heads = vit_num_heads
+ if vit_hidden_dim % fusion_num_heads != 0:
+ raise ValueError(
+ f"vit_hidden_dim ({vit_hidden_dim}) must be divisible by fusion_num_heads ({fusion_num_heads})."
+ )🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` around lines 635 - 641, The computed or
provided attention head counts must be validated so that vit_hidden_dim %
vit_num_heads == 0 (and similarly for fusion_num_heads if used by a
TransformerEncoderLayer); update the logic around vit_hidden_dim, vit_num_heads,
and fusion_num_heads to check divisibility after computing/defaulting
vit_num_heads and before using them—if not divisible either raise a clear
ValueError explaining the mismatch and showing vit_hidden_dim and vit_num_heads,
or adjust by choosing the largest valid nhead ≤ computed value that divides
vit_hidden_dim (and apply the same validation/adjustment for fusion_num_heads),
ensuring any automated fallback still respects TransformerEncoderLayer’s d_model
% nhead == 0 requirement.
Thanks a lot for the encouraging feedback! It means a lot. I’m happy to refine the implementation further—if you have any preferences on design details, tests, or benchmarks, I’d love to align with them. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 192-215: The function _interpolate_pos_encoding contains an unused
parameter x; remove x from its signature and docstring (change def
_interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) to def
_interpolate_pos_encoding(self, num_patches: int)) and update all internal call
sites that pass x (e.g., any calls like self._interpolate_pos_encoding(x,
num_patches)) to call self._interpolate_pos_encoding(num_patches) so the
parameter list matches; also apply the same removal to the duplicate occurrence
referenced around lines 236-238 and ensure type hints and return description
remain correct.
🧹 Nitpick comments (1)
monai/networks/nets/magnus.py (1)
258-279: Add aRaisessection for theValueError.
Line 278 raises but the docstring doesn’t document it. As per coding guidelines, add aRaises:block.Docstring tweak
""" Cross-modal attention fusion between CNN and Transformer features. @@ Args: spatial_dims: number of spatial dimensions (2 or 3). channels: number of input/output channels. num_heads: number of attention heads. dropout: dropout rate for attention weights. + + Raises: + ValueError: if channels is not divisible by num_heads. """
| def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor: | ||
| """ | ||
| Interpolate positional embeddings to match the number of patches. | ||
|
|
||
| Args: | ||
| x: input tensor for device reference. | ||
| num_patches: target number of patches. | ||
|
|
||
| Returns: | ||
| Interpolated positional embeddings of shape (1, num_patches, hidden_dim). | ||
| """ | ||
| if num_patches == self.pos_embed.shape[1]: | ||
| return self.pos_embed | ||
|
|
||
| # Interpolate positional embeddings | ||
| pos_embed = self.pos_embed.transpose(1, 2) # (1, hidden_dim, N) | ||
| pos_embed = F.interpolate( | ||
| pos_embed, | ||
| size=num_patches, | ||
| mode="linear", | ||
| align_corners=False, | ||
| ) | ||
| result: torch.Tensor = pos_embed.transpose(1, 2) # (1, num_patches, hidden_dim) | ||
| return result |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused _interpolate_pos_encoding arg (x).
Ruff reports Line 192 as unused.
Proposed fix
- def _interpolate_pos_encoding(self, x: torch.Tensor, num_patches: int) -> torch.Tensor:
+ def _interpolate_pos_encoding(self, num_patches: int) -> torch.Tensor:
"""
Interpolate positional embeddings to match the number of patches.
Args:
- x: input tensor for device reference.
num_patches: target number of patches.
@@
- pos_embed = self._interpolate_pos_encoding(x_flat, num_patches)
+ pos_embed = self._interpolate_pos_encoding(num_patches)Also applies to: 236-238
🧰 Tools
🪛 Ruff (0.14.14)
192-192: Unused method argument: x
(ARG002)
🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` around lines 192 - 215, The function
_interpolate_pos_encoding contains an unused parameter x; remove x from its
signature and docstring (change def _interpolate_pos_encoding(self, x:
torch.Tensor, num_patches: int) to def _interpolate_pos_encoding(self,
num_patches: int)) and update all internal call sites that pass x (e.g., any
calls like self._interpolate_pos_encoding(x, num_patches)) to call
self._interpolate_pos_encoding(num_patches) so the parameter list matches; also
apply the same removal to the duplicate occurrence referenced around lines
236-238 and ensure type hints and return description remain correct.
- Add learnable positional embeddings to TransformerPath for proper spatial reasoning - Implement dynamic positional embedding interpolation for varying input sizes - Add positional dropout for regularization - Update aux_weights docstring to clarify it's for external use only Addresses CodeRabbit review comments on PR Project-MONAI#8717 Signed-off-by: Sefa Aras <[email protected]>
606909c to
47b7247
Compare
- Add MAGNUS hybrid CNN-Transformer architecture for medical image segmentation - Implement CNNPath for hierarchical feature extraction - Implement TransformerPath for global context modeling - Add CrossModalAttentionFusion for bidirectional cross-attention - Add ScaleAdaptiveConv for multi-scale feature extraction - Add SEBlock for channel recalibration - Support both 2D and 3D medical images - Add deep supervision option - Add comprehensive unit tests Reference: Aras et al., IEEE Access 2026, DOI: 10.1109/ACCESS.2026.3656667 Signed-off-by: Sefa Aras <[email protected]>
for more information, see https://pre-commit.ci Signed-off-by: Sefa Aras <[email protected]>
- Add learnable positional embeddings to TransformerPath for proper spatial reasoning - Implement dynamic positional embedding interpolation for varying input sizes - Add positional dropout for regularization - Update aux_weights docstring to clarify it's for external use only Addresses CodeRabbit review comments on PR Project-MONAI#8717 Signed-off-by: Sefa Aras <[email protected]>
- Update type annotations to modern Python 3.10+ syntax (X | Y instead of Union) - Remove unused imports (Optional, Union) - Add docstrings to all __init__ methods for better coverage - Apply black and isort formatting - Fix ruff linting issues Improves docstring coverage from 72% to 80%+ Signed-off-by: Sefa Aras <[email protected]>
- Export SEBlock and DecoderBlock in __all__ and __init__.py - Add unit tests for SEBlock and DecoderBlock components - Fix TransformerEncoder warning with enable_nested_tensor=False - Add MAGNUS documentation to networks.rst Signed-off-by: Sefa Aras <[email protected]>
Signed-off-by: Sefa Aras <[email protected]>
Signed-off-by: Sefa Aras <[email protected]>
Signed-off-by: Sefa Aras <[email protected]>
Signed-off-by: Sefa Aras <[email protected]>
Signed-off-by: Sefa Aras <[email protected]>
5d1f0c7 to
7022314
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@monai/networks/nets/magnus.py`:
- Around line 704-719: The _init_weights method currently applies Kaiming normal
initialization to all nn.Linear layers (including Transformer attention/MLP
projections), which is inappropriate for GELU-based transformer modules like
TransformerEncoderLayer; change the initialization for nn.Linear inside
_init_weights to use truncated normal with std=0.02 (nn.init.trunc_normal_(...,
std=0.02)) for weights and keep biases zero, while leaving conv and norm
initializations unchanged so only linear layers used by attention/MLP get the
transformer-friendly init.
🧹 Nitpick comments (3)
tests/networks/nets/test_magnus.py (1)
155-160: Considerstrict=Trueforzip()iteration.Line 155 already validates lengths match, but adding
strict=Trueat line 156 provides an extra safeguard.Proposed fix
- for i, (feat, out) in enumerate(zip(features, outputs)): + for i, (feat, out) in enumerate(zip(features, outputs, strict=True)):monai/networks/nets/magnus.py (2)
36-44: Sort__all__alphabetically.Per static analysis (RUF022).
Proposed fix
__all__ = [ + "CNNPath", + "CrossModalAttentionFusion", + "DecoderBlock", "MAGNUS", - "CNNPath", - "TransformerPath", - "CrossModalAttentionFusion", + "MagnusSEBlock", "ScaleAdaptiveConv", - "MagnusSEBlock", - "DecoderBlock", + "TransformerPath", ]
758-758: Addstrict=Truetozip()for safety.Proposed fix
- for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips)): + for i, (decoder_block, skip) in enumerate(zip(self.decoder_blocks, cnn_skips, strict=True)):
| def _init_weights(self) -> None: | ||
| """Initialize model weights using Kaiming initialization.""" | ||
| for m in self.modules(): | ||
| if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)): | ||
| nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||
| if m.bias is not None: | ||
| nn.init.constant_(m.bias, 0) | ||
| elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.InstanceNorm2d, nn.InstanceNorm3d)): | ||
| if m.weight is not None: | ||
| nn.init.constant_(m.weight, 1) | ||
| if m.bias is not None: | ||
| nn.init.constant_(m.bias, 0) | ||
| elif isinstance(m, nn.Linear): | ||
| nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") | ||
| if m.bias is not None: | ||
| nn.init.constant_(m.bias, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
find . -name "magnus.py" -type fRepository: Project-MONAI/MONAI
Length of output: 94
🏁 Script executed:
cat -n monai/networks/nets/magnus.py | head -750 | tail -100Repository: Project-MONAI/MONAI
Length of output: 4531
🏁 Script executed:
rg "TransformerPath|pos_embed" monai/networks/nets/magnus.py -A 5 -B 2Repository: Project-MONAI/MONAI
Length of output: 2077
🏁 Script executed:
rg "TransformerEncoderLayer|trunc_normal_" monai/networks/nets/magnus.py -A 3 -B 2Repository: Project-MONAI/MONAI
Length of output: 556
🌐 Web query:
Vision Transformer ViT weight initialization best practices attention layers GELU
💡 Result:
ViT weight-init best practices (attention layers + MLP/GELU)
1) Use the “Transformer-standard” small-std init for most linear weights (incl. QKV + proj + MLP).
Common, well-tested defaults are (truncated) normal with std ≈ 0.02 for nn.Linear weights (this covers Q, K, V, the attention output projection, and the MLP/FFN layers). Biases are typically zeros. This matches widespread ViT practice in timm-style implementations and HF configs (initializer range 0.02). [1][2][3]
2) LayerNorm (and similar norms): weight = 1, bias = 0.
This is the standard and what major implementations do. [2][3]
3) Learnable position embeddings and CLS token: initialize small, then learn.
A common choice is trunc_normal(std=0.02) for pos_embed, and a very small normal for cls_token (some implementations use std around 1e-6 for the class token). [3][4]
4) Patch embedding (Conv2d) init depends on “PyTorch/timm” vs “JAX-style” compatibility.
Two common choices:
- Treat it like other weights (trunc_normal std 0.02), OR
- JAX-style: LeCun normal for the Conv kernel (i.e., variance-scaling fan-in, truncated normal). [4][5]
5) GELU doesn’t require special initialization by itself.
GELU is typically paired with the same initializers above; stability is usually handled more by pre-LN architecture, residual scaling choices, and LR/warmup than by changing init specifically for GELU. HF ViT exposes "gelu" while keeping the same initializer_range. [1]
Practical PyTorch-style default (good starting point)
Linear/Conv: normal or trunc_normal,std=0.02; bias zeros. [2][4]LayerNorm: weight ones, bias zeros. [2][4]pos_embed: trunc_normalstd=0.02;cls_token: very small normal. [3][4]- If matching JAX ViT behavior for patch Conv: LeCun normal. [4][5]
Sources: Hugging Face ViT config & init, plus timm-like init patterns in open implementations, and LeCun-normal definition. [1][2][3][4][5]
References
[1] https://huggingface.co/docs/transformers/v4.15.0/en/model_doc/vit
[2] https://huggingface.co/transformers/v4.11.1/_modules/transformers/models/vit/modeling_vit.html
[3] https://vis4d.readthedocs.io/en/stable/_modules/vis4d/op/base/vit.html
[4] https://flowvision.readthedocs.io/en/latest/_modules/flowvision/models/vision_transformer.html
[5] https://docs.jax.dev/en/latest/_autosummary/jax.nn.initializers.lecun_normal.html
Use truncated normal initialization for transformer attention/MLP layers instead of Kaiming with ReLU.
The _init_weights method applies nn.init.kaiming_normal_(..., nonlinearity="relu") to all nn.Linear layers (line 717), including attention projections and MLP in TransformerEncoderLayer. Vision Transformer best practices initialize these layers with truncated normal (std=0.02), not Kaiming—Kaiming is designed for ReLU networks and is suboptimal for GELU-based transformers. Consider using nn.init.trunc_normal_(..., std=0.02) for nn.Linear instead.
🤖 Prompt for AI Agents
In `@monai/networks/nets/magnus.py` around lines 704 - 719, The _init_weights
method currently applies Kaiming normal initialization to all nn.Linear layers
(including Transformer attention/MLP projections), which is inappropriate for
GELU-based transformer modules like TransformerEncoderLayer; change the
initialization for nn.Linear inside _init_weights to use truncated normal with
std=0.02 (nn.init.trunc_normal_(..., std=0.02)) for weights and keep biases
zero, while leaving conv and norm initializations unchanged so only linear
layers used by attention/MLP get the transformer-friendly init.
Signed-off-by: Sefa Aras <[email protected]>
Description
This PR adds MAGNUS (Multi-Attention Guided Network for Unified Segmentation), a hybrid CNN-Transformer architecture for medical image segmentation.
Key Features
New Files
monai/networks/nets/magnus.py- Main implementationtests/networks/nets/test_magnus.py- Unit tests (17 tests)Modified Files
monai/networks/nets/__init__.py- Export MAGNUS and componentsUsage Example
from monai.networks.nets import MAGNUS
model = MAGNUS(
spatial_dims=3,
in_channels=1,
out_channels=2,
features=(64, 128, 256, 512),
)
Test Results
All 17 unit tests pass ✅
Reference
Aras, E., Kayikcioglu, T., Aras, S., & Merd, N. (2026). MAGNUS: Multi-Attention Guided Network for Unified Segmentation via CNN-ViT Fusion. IEEE Access. DOI: 10.1109/ACCESS.2026.3656667