Skip to content

Commit

Permalink
Made graphs more readable, normalized topic importances in dynamic GMM
Browse files Browse the repository at this point in the history
  • Loading branch information
x-tabdeveloping committed Feb 19, 2024
1 parent d88c72c commit 520f621
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
21 changes: 17 additions & 4 deletions turftopic/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,15 @@ def print_topics_over_time(
start_str = start_dt.strftime(date_format)
end_str = end_dt.strftime(date_format)
slice_names.append(f"{start_str} - {end_str}")
n_topics = self.temporal_components_.shape[1]
try:
topic_names = self.topic_names
except AttributeError:
topic_names = [f"Topic {i}" for i in range(n_topics)]
table = Table(show_lines=True)
table.add_column("Time Slice")
for i_topic in range(self.temporal_components_.shape[1]):
table.add_column(f"Topic {i_topic}")
for topic in topic_names:
table.add_column(topic)
for slice_name, components in zip(slice_names, temporal_components):
fields = []
fields.append(slice_name)
Expand Down Expand Up @@ -159,6 +164,11 @@ def plot_topics_over_time(self, top_k: int = 6):
) from e
fig = go.Figure()
vocab = self.get_vocab()
n_topics = self.temporal_components_.shape[1]
try:
topic_names = self.topic_names
except AttributeError:
topic_names = [f"Topic {i}" for i in range(n_topics)]
for i_topic, topic_imp_t in enumerate(self.temporal_importance_.T):
component_over_time = self.temporal_components_[:, i_topic, :]
name_over_time = []
Expand All @@ -177,9 +187,12 @@ def plot_topics_over_time(self, top_k: int = 6):
y=topic_imp_t,
mode="markers+lines",
text=name_over_time,
name=f"Topic {i_topic}",
name=topic_names[i_topic],
hovertemplate="<b>%{text}</b>",
marker=dict(line=dict(width=2, color="black"), size=14),
marker=dict(
line=dict(width=2, color="black"),
size=14,
),
line=dict(width=3),
)
)
Expand Down
2 changes: 2 additions & 0 deletions turftopic/models/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def fit_transform_dynamic(
topic_importances = doc_topic_matrix[time_labels == i_timebin].sum(
axis=0
)
# Normalizing
topic_importances = topic_importances / topic_importances.sum()
components = soft_ctf_idf(
doc_topic_matrix[time_labels == i_timebin],
document_term_matrix[time_labels == i_timebin], # type: ignore
Expand Down

0 comments on commit 520f621

Please sign in to comment.