diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index 90379b86..2622c86c 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -477,24 +477,34 @@ def fit_transform(
         # Create documents from images if we have images only
         if documents.Document.values[0] is None:
             custom_documents = self._images_to_text(documents, embeddings)
-
-            # Extract topics by calculating c-TF-IDF
-            self._extract_topics(custom_documents, embeddings=embeddings)
-            self._create_topic_vectors(documents=documents, embeddings=embeddings)
-
-            # Reduce topics
+            # Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
+            self._extract_topics(custom_documents, embeddings=embeddings, calculate_representation=not self.nr_topics)
+
             if self.nr_topics:
-                custom_documents = self._reduce_topics(custom_documents)
+                if (isinstance(self.nr_topics, int) and self.nr_topics < len(self.get_topics())) or isinstance(self.nr_topics, str):
+                    custom_documents = self._reduce_topics(custom_documents)
+                else:
+                    logger.info(f"Number of topics ({self.nr_topics}) is equal or higher than the clustered topics ({len(self.get_topics())}).")
+                    custom_documents = self._sort_mappings_by_frequency(custom_documents)
+                    self._extract_topics(custom_documents, embeddings=embeddings)
+
+            self._create_topic_vectors(documents=documents, embeddings=embeddings)
 
             # Save the top 3 most representative documents per topic
             self._save_representative_docs(custom_documents)
+
         else:
-            # Extract topics by calculating c-TF-IDF
-            self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)
+            # Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
+            self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose, calculate_representation=not self.nr_topics)
 
-            # Reduce topics
             if self.nr_topics:
-                documents = self._reduce_topics(documents)
+                if (isinstance(self.nr_topics, int) and self.nr_topics < len(self.get_topics())) or isinstance(self.nr_topics, str):
+                    # Reduce topics
+                    documents = self._reduce_topics(documents)
+                else:
+                    logger.info(f"Number of topics ({self.nr_topics}) is equal or higher than the clustered topics ({len(self.get_topics())}).")
+                    documents = self._sort_mappings_by_frequency(documents)
+                    self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)
 
             # Save the top 3 most representative documents per topic
             self._save_representative_docs(documents)
@@ -3972,6 +3982,7 @@ def _extract_topics(
         embeddings: np.ndarray = None,
         mappings=None,
         verbose: bool = False,
+        calculate_representation: bool = True,
     ):
         """Extract topics from the clusters using a class-based TF-IDF.
 
@@ -3980,18 +3991,25 @@ def _extract_topics(
             embeddings: The document embeddings
             mappings: The mappings from topic to word
             verbose: Whether to log the process of extracting topics
+            calculate_representation: Whether to extract the topic representations
 
         Returns:
             c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
         """
         if verbose:
-            logger.info("Representation - Extracting topics from clusters using representation models.")
+            action = "Representation" if calculate_representation else "Topics"
+            logger.info(f"{action} - Extracting topics from clusters{' using representation models' if calculate_representation else ''}.")
+
         documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
         self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
-        self.topic_representations_ = self._extract_words_per_topic(words, documents)
+        self.topic_representations_ = self._extract_words_per_topic(
+            words, documents,
+            calculate_representation=calculate_representation,
+            calculate_aspects=calculate_representation)
         self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)
+
         if verbose:
-            logger.info("Representation - Completed \u2713")
+            logger.info(f"{action} - Completed \u2713")
 
     def _save_representative_docs(self, documents: pd.DataFrame):
         """Save the 3 most representative docs per topic.
@@ -4245,6 +4263,7 @@ def _extract_words_per_topic(
         words: List[str],
         documents: pd.DataFrame,
         c_tf_idf: csr_matrix = None,
+        calculate_representation: bool = True,
         calculate_aspects: bool = True,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Based on tf_idf scores per topic, extract the top n words per topic.
@@ -4258,6 +4277,7 @@ def _extract_words_per_topic(
             words: List of all words (sorted according to tf_idf matrix position)
             documents: DataFrame with documents and their topic IDs
             c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
+            calculate_representation: Whether to calculate the topic representations
             calculate_aspects: Whether to calculate additional topic aspects
 
         Returns:
@@ -4288,15 +4308,15 @@ def _extract_words_per_topic(
         # Fine-tune the topic representations
         topics = base_topics.copy()
-        if not self.representation_model:
+        if not self.representation_model or not calculate_representation:
             # Default representation: c_tf_idf + top_n_words
             topics = {label: values[: self.top_n_words] for label, values in topics.items()}
-        elif isinstance(self.representation_model, list):
+        elif calculate_representation and isinstance(self.representation_model, list):
             for tuner in self.representation_model:
                 topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
-        elif isinstance(self.representation_model, BaseRepresentation):
+        elif calculate_representation and isinstance(self.representation_model, BaseRepresentation):
             topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
-        elif isinstance(self.representation_model, dict):
+        elif calculate_representation and isinstance(self.representation_model, dict):
             if self.representation_model.get("Main"):
                 main_model = self.representation_model["Main"]
                 if isinstance(main_model, BaseRepresentation):
@@ -4412,7 +4432,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
 
         # Update representations
         documents = self._sort_mappings_by_frequency(documents)
-        self._extract_topics(documents, mappings=mappings)
+        self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
         self._update_topic_size(documents)
 
         return documents
@@ -4468,7 +4488,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
         # Update documents and topics
         self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
         documents = self._sort_mappings_by_frequency(documents)
-        self._extract_topics(documents, mappings=mappings)
+        self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
         self._update_topic_size(documents)
 
         return documents