Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion gemma/gm/nn/_modules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,57 @@ def test_embedder_decode():
np.testing.assert_array_equal(output, jnp.array(expected))


# TODO(mblondel): Add tests for `encode_vision` here.
def test_encode_vision_output_shape():
"""encode_vision should project from vision_proj_dim to embed_dim."""
vocab_size = 10
embed_dim = 8
vision_proj_dim = 16
embedder = gm.nn.Embedder(
vocab_size=vocab_size,
embed_dim=embed_dim,
vision_proj_dim=vision_proj_dim,
)

rng = jax.random.PRNGKey(0)
dummy_vision_input = jnp.ones((1, 4, vision_proj_dim))
params = embedder.init(rng, dummy_vision_input, method=embedder.encode_vision)

# [B, num_patches, vision_proj_dim] -> [B, num_patches, embed_dim]
vision_input = jax.random.normal(rng, (2, 4, vision_proj_dim))
output = embedder.apply(
params, vision_input, method=embedder.encode_vision
)
assert output.shape == (2, 4, embed_dim)


def test_encode_vision_different_batch_shapes():
"""encode_vision should handle various batch/patch dimensions."""
vocab_size = 10
embed_dim = 8
vision_proj_dim = 16
embedder = gm.nn.Embedder(
vocab_size=vocab_size,
embed_dim=embed_dim,
vision_proj_dim=vision_proj_dim,
)

rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 1, vision_proj_dim))
params = embedder.init(rng, dummy_input, method=embedder.encode_vision)

output_1 = embedder.apply(
params,
jax.random.normal(rng, (1, 1, vision_proj_dim)),
method=embedder.encode_vision,
)
assert output_1.shape == (1, 1, embed_dim)

output_8 = embedder.apply(
params,
jax.random.normal(rng, (1, 8, vision_proj_dim)),
method=embedder.encode_vision,
)
assert output_8.shape == (1, 8, embed_dim)


def test_sliding_mask():
Expand Down