From 4ea7dcd2e387a16fc6552aed7cf6d3c0dbd028d8 Mon Sep 17 00:00:00 2001
From: ChiragTrivedi06
Date: Wed, 18 Feb 2026 17:27:16 +0530
Subject: [PATCH] fix: prevent division by zero in sampling when temperature is 0.0

Dividing the logits by a zero temperature produces inf/NaN logits. Treat a
temperature below 1e-6 as the greedy limit and pick the argmax, instead of
sampling from the unscaled logits (which would silently behave like
temperature=1.0).

Also ignore local virtual environments in .gitignore.
---
 .gitignore                 |  3 +++
 gemma/gm/text/_sampling.py | 17 ++++++++++++++---
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/.gitignore b/.gitignore
index f0a15d26..275f49bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,6 @@ poetry.lock
 # Ignore generated docs
 docs/_build
 docs/api
+
+# virtual environments
+.venv/
diff --git a/gemma/gm/text/_sampling.py b/gemma/gm/text/_sampling.py
--- a/gemma/gm/text/_sampling.py
+++ b/gemma/gm/text/_sampling.py
@@ -58,6 +58,9 @@ class RandomSampling(SamplingMethod):
 
   @typechecked
   def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
+    if self.temperature < 1e-6:
+      # Zero temperature is the greedy limit; dividing would give inf/NaN.
+      return jnp.argmax(logits, axis=-1)
     return jax.random.categorical(rng, logits / self.temperature, axis=-1)
 
 
@@ -74,9 +77,13 @@ def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
     batch_size = logits.shape[0]
     topk_values, topk_indices = jax.lax.top_k(logits, self.k)
 
-    sampled_topk_indices = jax.random.categorical(
-        rng, topk_values / self.temperature, axis=-1
-    )
+    if self.temperature < 1e-6:
+      # Zero temperature is the greedy limit: take the best of the k candidates.
+      sampled_topk_indices = jnp.argmax(topk_values, axis=-1)
+    else:
+      sampled_topk_indices = jax.random.categorical(
+          rng, topk_values / self.temperature, axis=-1
+      )
     batch_indices = jnp.arange(batch_size)
     topk_indices = topk_indices[batch_indices, sampled_topk_indices]
     return enp.unflatten(topk_indices, batch_shape, '...')
@@ -91,6 +98,10 @@ class TopPSampling(SamplingMethod):
 
   @typechecked
   def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
+    if self.temperature < 1e-6:
+      # Zero temperature collapses the distribution onto the argmax token.
+      return jnp.argmax(logits, axis=-1)
+
     # temperature scaling
     logits = logits / self.temperature
 