diff --git a/CodonTransformer/CodonPrediction.py b/CodonTransformer/CodonPrediction.py index 7639080..9154820 100644 --- a/CodonTransformer/CodonPrediction.py +++ b/CodonTransformer/CodonPrediction.py @@ -284,8 +284,8 @@ def sample_non_deterministic( raise ValueError("top_p must be a float between 0 and 1.") # Compute probabilities using temperature scaling - logits /= temperature - probs = torch.softmax(logits, dim=-1) + probs = torch.softmax(logits / temperature, dim=-1) + # Remove batch dimension if present if probs.dim() == 3: