Commit

Fixed issue MaartenGr#2144
pipa666 committed Sep 16, 2024
1 parent 0b4265a commit 4ce0b1a
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions bertopic/_bertopic.py
@@ -477,24 +477,34 @@ def fit_transform(
# Create documents from images if we have images only
if documents.Document.values[0] is None:
custom_documents = self._images_to_text(documents, embeddings)

# Extract topics by calculating c-TF-IDF
self._extract_topics(custom_documents, embeddings=embeddings)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Reduce topics
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(custom_documents, embeddings=embeddings, calculate_representation=not self.nr_topics)

if self.nr_topics:
custom_documents = self._reduce_topics(custom_documents)
if (isinstance(self.nr_topics, int) and self.nr_topics < len(self.get_topics())) or isinstance(self.nr_topics, str):
custom_documents = self._reduce_topics(custom_documents)
else:
logger.info(f"Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(self.get_topics())}).")
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(custom_documents, embeddings=embeddings)

self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Save the top 3 most representative documents per topic
self._save_representative_docs(custom_documents)

else:
# Extract topics by calculating c-TF-IDF
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose, calculate_representation=not self.nr_topics)

# Reduce topics
if self.nr_topics:
documents = self._reduce_topics(documents)
if (isinstance(self.nr_topics, int) and self.nr_topics < len(self.get_topics())) or isinstance(self.nr_topics, str):
# Reduce topics
documents = self._reduce_topics(documents)
else:
logger.info(f"Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(self.get_topics())}).")
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)

# Save the top 3 most representative documents per topic
self._save_representative_docs(documents)
@@ -3972,6 +3982,7 @@ def _extract_topics(
embeddings: np.ndarray = None,
mappings=None,
verbose: bool = False,
calculate_representation: bool = True,
):
"""Extract topics from the clusters using a class-based TF-IDF.
@@ -3980,18 +3991,25 @@
embeddings: The document embeddings
mappings: The mappings from topic to word
verbose: Whether to log the process of extracting topics
calculate_representation: Whether to extract the topic representations
Returns:
c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
"""
if verbose:
logger.info("Representation - Extracting topics from clusters using representation models.")
action = "Representation" if calculate_representation else "Topics"
logger.info(f"{action} - Extracting topics from clusters{'using representation models' if calculate_representation else ''}.")

documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
self.topic_representations_ = self._extract_words_per_topic(words, documents)
self.topic_representations_ = self._extract_words_per_topic(
words, documents,
calculate_representation=calculate_representation,
calculate_aspects=calculate_representation)
self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)

if verbose:
logger.info("Representation - Completed \u2713")
logger.info(f"{action} - Completed \u2713")

def _save_representative_docs(self, documents: pd.DataFrame):
"""Save the 3 most representative docs per topic.
@@ -4245,6 +4263,7 @@ def _extract_words_per_topic(
words: List[str],
documents: pd.DataFrame,
c_tf_idf: csr_matrix = None,
calculate_representation: bool = True,
calculate_aspects: bool = True,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Based on tf_idf scores per topic, extract the top n words per topic.
@@ -4258,6 +4277,7 @@
words: List of all words (sorted according to tf_idf matrix position)
documents: DataFrame with documents and their topic IDs
c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
calculate_representation: Whether to calculate the topic representations
calculate_aspects: Whether to calculate additional topic aspects
Returns:
@@ -4288,15 +4308,15 @@

# Fine-tune the topic representations
topics = base_topics.copy()
if not self.representation_model:
if not self.representation_model or not calculate_representation:
# Default representation: c_tf_idf + top_n_words
topics = {label: values[: self.top_n_words] for label, values in topics.items()}
elif isinstance(self.representation_model, list):
elif calculate_representation and isinstance(self.representation_model, list):
for tuner in self.representation_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, BaseRepresentation):
elif calculate_representation and isinstance(self.representation_model, BaseRepresentation):
topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, dict):
elif calculate_representation and isinstance(self.representation_model, dict):
if self.representation_model.get("Main"):
main_model = self.representation_model["Main"]
if isinstance(main_model, BaseRepresentation):
@@ -4412,7 +4432,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)

# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)

self._update_topic_size(documents)
return documents
@@ -4468,7 +4488,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
# Update documents and topics
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
self._update_topic_size(documents)
return documents
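
A minimal usage sketch of the behaviour this commit targets: with nr_topics set, fit_transform first calls _extract_topics with calculate_representation=False, so the fine-tuned representations are computed only once, after topic reduction, rather than both before and after it. The corpus and the KeyBERTInspired representation model below are illustrative assumptions, not part of this commit.

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired

# Illustrative corpus; any list of documents works here.
docs = [
    "The stock market rallied after the quarterly earnings report.",
    "The goalkeeper made a stunning save in the final minute.",
    "New GPU architectures keep accelerating deep learning training.",
] * 50

# With nr_topics set, _extract_topics is first called with
# calculate_representation=False; the representation model only runs
# once the topics have been reduced.
topic_model = BERTopic(
    representation_model=KeyBERTInspired(),  # assumed example model
    nr_topics=5,
)
topics, probs = topic_model.fit_transform(docs)
print(topic_model.get_topic_info().head())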
