From 520f621b5ce2bb0f8e6672ba1e42a298575fbb0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Kardos?= Date: Mon, 19 Feb 2024 08:26:32 +0100 Subject: [PATCH] Made graphs more readable, normalized topic importances in dynamic GMM --- turftopic/dynamic.py | 21 +++++++++++++++++---- turftopic/models/gmm.py | 2 ++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/turftopic/dynamic.py b/turftopic/dynamic.py index 82d5a49..f70912a 100644 --- a/turftopic/dynamic.py +++ b/turftopic/dynamic.py @@ -117,10 +117,15 @@ def print_topics_over_time( start_str = start_dt.strftime(date_format) end_str = end_dt.strftime(date_format) slice_names.append(f"{start_str} - {end_str}") + n_topics = self.temporal_components_.shape[1] + try: + topic_names = self.topic_names + except AttributeError: + topic_names = [f"Topic {i}" for i in range(n_topics)] table = Table(show_lines=True) table.add_column("Time Slice") - for i_topic in range(self.temporal_components_.shape[1]): - table.add_column(f"Topic {i_topic}") + for topic in topic_names: + table.add_column(topic) for slice_name, components in zip(slice_names, temporal_components): fields = [] fields.append(slice_name) @@ -159,6 +164,11 @@ def plot_topics_over_time(self, top_k: int = 6): ) from e fig = go.Figure() vocab = self.get_vocab() + n_topics = self.temporal_components_.shape[1] + try: + topic_names = self.topic_names + except AttributeError: + topic_names = [f"Topic {i}" for i in range(n_topics)] for i_topic, topic_imp_t in enumerate(self.temporal_importance_.T): component_over_time = self.temporal_components_[:, i_topic, :] name_over_time = [] @@ -177,9 +187,12 @@ def plot_topics_over_time(self, top_k: int = 6): y=topic_imp_t, mode="markers+lines", text=name_over_time, - name=f"Topic {i_topic}", + name=topic_names[i_topic], hovertemplate="%{text}", - marker=dict(line=dict(width=2, color="black"), size=14), + marker=dict( + line=dict(width=2, color="black"), + size=14, + ), line=dict(width=3), ) ) diff --git a/turftopic/models/gmm.py b/turftopic/models/gmm.py index 7f67967..cc1d2ef 100644 --- a/turftopic/models/gmm.py +++ b/turftopic/models/gmm.py @@ -153,6 +153,8 @@ def fit_transform_dynamic( topic_importances = doc_topic_matrix[time_labels == i_timebin].sum( axis=0 ) + # Normalizing + topic_importances = topic_importances / topic_importances.sum() components = soft_ctf_idf( doc_topic_matrix[time_labels == i_timebin], document_term_matrix[time_labels == i_timebin], # type: ignore