From 6a2663fad3811a9101fe1f9e98f0648b2235ea77 Mon Sep 17 00:00:00 2001 From: adithya32 <163162210+KumarADITHYA123@users.noreply.github.com> Date: Sun, 15 Feb 2026 02:54:28 +0530 Subject: [PATCH] Fix division by zero in sampling when temperature is 0.0 --- gemma/gm/text/_sampling.py | 9 ++++++++- gemma/gm/text/_sampling_test.py | 20 ++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/gemma/gm/text/_sampling.py b/gemma/gm/text/_sampling.py index 76d46bc6..9efd363a 100644 --- a/gemma/gm/text/_sampling.py +++ b/gemma/gm/text/_sampling.py @@ -58,6 +58,8 @@ class RandomSampling(SamplingMethod): @typechecked def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']: + if self.temperature < 1e-6: + return Greedy().get_next_tokens(logits, rng) return jax.random.categorical(rng, logits / self.temperature, axis=-1) @@ -70,6 +72,9 @@ class TopkSampling(SamplingMethod): @typechecked def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']: + if self.temperature < 1e-6: + return Greedy().get_next_tokens(logits, rng) + logits, batch_shape = enp.flatten(logits, '... V') batch_size = logits.shape[0] @@ -91,6 +96,9 @@ class TopPSampling(SamplingMethod): @typechecked def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']: + if self.temperature < 1e-6: + return Greedy().get_next_tokens(logits, rng) + # temperature scaling logits = logits / self.temperature @@ -115,4 +123,3 @@ def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']: ) return jax.random.categorical(rng, logits, axis=-1) - diff --git a/gemma/gm/text/_sampling_test.py b/gemma/gm/text/_sampling_test.py index afbe1a0d..2ee77526 100644 --- a/gemma/gm/text/_sampling_test.py +++ b/gemma/gm/text/_sampling_test.py @@ -79,3 +79,23 @@ def test_top1_sampling_matches_greedy_sampling(): tokens_top1 = top1_sampling.get_next_tokens(logits, rng) np.testing.assert_array_equal(tokens_greedy, tokens_top1) + + +def test_zero_temperature_behavior(): + rng = jax.random.PRNGKey(0) + logits = jax.numpy.array([[10.0, 5.0]]) + + # Test RandomSampling + sampler = gm.text.RandomSampling(temperature=0.0) + tokens = sampler.get_next_tokens(logits, rng) + np.testing.assert_array_equal(tokens, [0]) + + # Test TopkSampling + sampler = gm.text.TopkSampling(k=5, temperature=0.0) + tokens = sampler.get_next_tokens(logits, rng) + np.testing.assert_array_equal(tokens, [0]) + + # Test TopPSampling + sampler = gm.text.TopPSampling(p=0.9, temperature=0.0) + tokens = sampler.get_next_tokens(logits, rng) + np.testing.assert_array_equal(tokens, [0])