From cc2d046fb09baa3e469c936268b6ef6dabd64f78 Mon Sep 17 00:00:00 2001 From: sebastianbreguel Date: Tue, 7 Apr 2026 16:40:16 -0400 Subject: [PATCH] Fix find_topics crash when search_term is a list of strings Closes #2475, closes #2392 find_topics() always wrapped search_term in an extra list before passing it to the embedder, so a list input became a nested list. sentence-transformers then interpreted the inner list as a text-pair, raising IndexError for a single-element list and silently producing wrong embeddings for length-2 lists. Resolve search_term to a flat list of strings (string -> single-element list, list -> as-is), embed all terms, and average the resulting vectors into a single query embedding. Empty lists now raise ValueError explicitly instead of silently returning NaN similarities. Adds a regression test covering single- and multi-element list inputs across fixture models, plus a test for the empty-list ValueError. --- bertopic/_bertopic.py | 12 +++++++-- .../test_representations.py | 25 +++++++++++++++++++ 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index cfafb58a..68921f5f 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1429,7 +1429,7 @@ def approximate_distribution( return topic_distributions, topic_token_distributions def find_topics( - self, search_term: str | None = None, image: str | None = None, top_n: int = 5 + self, search_term: str | List[str] | None = None, image: str | None = None, top_n: int = 5 ) -> Tuple[List[int], List[float]]: """Find topics most similar to a search_term. @@ -1462,6 +1462,10 @@ def find_topics( Note that the search query is typically more accurate if the search_term consists of a phrase or multiple words. + + When ``search_term`` is a list of strings, each term is embedded + independently and the resulting vectors are averaged into a single + query embedding before computing similarity against topic embeddings. """ if self.embedding_model is None: raise Exception("This method can only be used if you did not use custom embeddings.") @@ -1471,7 +1475,11 @@ def find_topics( # Extract search_term embeddings and compare with topic embeddings if search_term is not None: - search_embedding = self._extract_embeddings([search_term], method="word", verbose=False).flatten() + search_terms = [search_term] if isinstance(search_term, str) else list(search_term) + if not search_terms: + raise ValueError("search_term must be a non-empty string or list of strings.") + search_embeddings = self._extract_embeddings(search_terms, method="word", verbose=False) + search_embedding = np.mean(search_embeddings, axis=0).flatten() elif image is not None: search_embedding = self._extract_embeddings( [None], images=[image], method="document", verbose=False diff --git a/tests/test_representation/test_representations.py b/tests/test_representation/test_representations.py index fa756625..b109f0c9 100644 --- a/tests/test_representation/test_representations.py +++ b/tests/test_representation/test_representations.py @@ -182,3 +182,28 @@ def test_find_topics(model, request): assert np.mean(similarity) > 0.1 assert len(similar_topics) > 0 + + +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ], +) +@pytest.mark.parametrize("search_term", [["car"], ["car", "vehicle"]]) +def test_find_topics_with_list(model, search_term, request): + """Regression test for #2392 / #2475: ``find_topics`` must accept a list + of search terms (including a single-element list) without crashing.""" + topic_model = copy.deepcopy(request.getfixturevalue(model)) + similar_topics, similarity = topic_model.find_topics(search_term) + + assert len(similar_topics) > 0 + assert len(similar_topics) == len(similarity) + + +def test_find_topics_empty_list_raises(base_topic_model): + """An empty ``search_term`` list must raise instead of silently returning NaN similarities.""" + topic_model = copy.deepcopy(base_topic_model) + with pytest.raises(ValueError, match="non-empty"): + topic_model.find_topics([])