Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ def approximate_distribution(
return topic_distributions, topic_token_distributions

def find_topics(
self, search_term: str | None = None, image: str | None = None, top_n: int = 5
self, search_term: str | List[str] | None = None, image: str | None = None, top_n: int = 5
) -> Tuple[List[int], List[float]]:
"""Find topics most similar to a search_term.

Expand Down Expand Up @@ -1462,6 +1462,10 @@ def find_topics(

Note that the search query is typically more accurate if the
search_term consists of a phrase or multiple words.

When ``search_term`` is a list of strings, each term is embedded
independently and the resulting vectors are averaged into a single
query embedding before computing similarity against topic embeddings.
"""
if self.embedding_model is None:
raise Exception("This method can only be used if you did not use custom embeddings.")
Expand All @@ -1471,7 +1475,11 @@ def find_topics(

# Extract search_term embeddings and compare with topic embeddings
if search_term is not None:
search_embedding = self._extract_embeddings([search_term], method="word", verbose=False).flatten()
search_terms = [search_term] if isinstance(search_term, str) else list(search_term)
if not search_terms:
raise ValueError("search_term must be a non-empty string or list of strings.")
search_embeddings = self._extract_embeddings(search_terms, method="word", verbose=False)
search_embedding = np.mean(search_embeddings, axis=0).flatten()
elif image is not None:
search_embedding = self._extract_embeddings(
[None], images=[image], method="document", verbose=False
Expand Down
25 changes: 25 additions & 0 deletions tests/test_representation/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,28 @@ def test_find_topics(model, request):

assert np.mean(similarity) > 0.1
assert len(similar_topics) > 0


@pytest.mark.parametrize(
"model",
[
("kmeans_pca_topic_model"),
("base_topic_model"),
],
)
@pytest.mark.parametrize("search_term", [["car"], ["car", "vehicle"]])
def test_find_topics_with_list(model, search_term, request):
"""Regression test for #2392 / #2475: ``find_topics`` must accept a list
of search terms (including a single-element list) without crashing."""
topic_model = copy.deepcopy(request.getfixturevalue(model))
similar_topics, similarity = topic_model.find_topics(search_term)

assert len(similar_topics) > 0
assert len(similar_topics) == len(similarity)


def test_find_topics_empty_list_raises(base_topic_model):
"""An empty ``search_term`` list must raise instead of silently returning NaN similarities."""
topic_model = copy.deepcopy(base_topic_model)
with pytest.raises(ValueError, match="non-empty"):
topic_model.find_topics([])
Loading