diff --git a/gemma/gm/nn/_modules_test.py b/gemma/gm/nn/_modules_test.py
index afde3dcf..c81c0eed 100644
--- a/gemma/gm/nn/_modules_test.py
+++ b/gemma/gm/nn/_modules_test.py
@@ -71,7 +71,57 @@
   np.testing.assert_array_equal(output, jnp.array(expected))
 
 
-# TODO(mblondel): Add tests for `encode_vision` here.
+def test_encode_vision_output_shape():
+  """encode_vision should project from vision_proj_dim to embed_dim."""
+  vocab_size = 10
+  embed_dim = 8
+  vision_proj_dim = 16
+  embedder = gm.nn.Embedder(
+      vocab_size=vocab_size,
+      embed_dim=embed_dim,
+      vision_proj_dim=vision_proj_dim,
+  )
+
+  rng = jax.random.PRNGKey(0)
+  dummy_vision_input = jnp.ones((1, 4, vision_proj_dim))
+  params = embedder.init(rng, dummy_vision_input, method=embedder.encode_vision)
+
+  # [B, num_patches, vision_proj_dim] -> [B, num_patches, embed_dim]
+  vision_input = jax.random.normal(rng, (2, 4, vision_proj_dim))
+  output = embedder.apply(
+      params, vision_input, method=embedder.encode_vision
+  )
+  assert output.shape == (2, 4, embed_dim)
+
+
+def test_encode_vision_different_batch_shapes():
+  """encode_vision should handle various batch/patch dimensions."""
+  vocab_size = 10
+  embed_dim = 8
+  vision_proj_dim = 16
+  embedder = gm.nn.Embedder(
+      vocab_size=vocab_size,
+      embed_dim=embed_dim,
+      vision_proj_dim=vision_proj_dim,
+  )
+
+  rng = jax.random.PRNGKey(0)
+  dummy_input = jnp.ones((1, 1, vision_proj_dim))
+  params = embedder.init(rng, dummy_input, method=embedder.encode_vision)
+
+  output_1 = embedder.apply(
+      params,
+      jax.random.normal(rng, (1, 1, vision_proj_dim)),
+      method=embedder.encode_vision,
+  )
+  assert output_1.shape == (1, 1, embed_dim)
+
+  output_8 = embedder.apply(
+      params,
+      jax.random.normal(rng, (1, 8, vision_proj_dim)),
+      method=embedder.encode_vision,
+  )
+  assert output_8.shape == (1, 8, embed_dim)
 
 
 def test_sliding_mask():