diff --git a/README.md b/README.md index c9a309e..e112b32 100644 --- a/README.md +++ b/README.md @@ -9,9 +9,10 @@ - Semantic Signal Separation - SΒ³ 🧭 - KeyNMF πŸ”‘ (paper in progress ⏳) - GMM :gem: (paper soon) - - Implementations of existing transformer-based topic models + - Implementations of other transformer-based topic models - Clustering Topic Models: BERTopic and Top2Vec - Autoencoding Topic Models: CombinedTM and ZeroShotTM + - FASTopic - Streamlined scikit-learn compatible API πŸ› οΈ - Easy topic interpretation πŸ” - Dynamic Topic Modeling πŸ“ˆ (GMM, ClusteringTopicModel and KeyNMF) @@ -19,43 +20,45 @@ > This package is still work in progress and scientific papers on some of the novel methods are currently undergoing peer-review. If you use this package and you encounter any problem, let us know by opening relevant issues. -### New in version 0.4.0 +### New in version 0.5.0 -#### Online KeyNMF +#### Hierarchical KeyNMF -You can now online fit and finetune KeyNMF as you wish! +You can now subdivide topics in KeyNMF at will. ```python -from itertools import batched from turftopic import KeyNMF -model = KeyNMF(10, top_n=5) - -corpus = ["some string", "etc", ...] -for batch in batched(corpus, 200): - batch = list(batch) - model.partial_fit(batch) +model = KeyNMF(2, top_n=15, random_state=42).fit(corpus) +model.hierarchy.divide_children(n_subtopics=3) +print(model.hierarchy) ``` -#### $S^3$ Concept Compasses +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
+. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
+. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
+
+
-You can now produce a compass of concepts along two semantic axes using $S^3$. - - - - - -
- -```python -model = SemanticSignalSeparation(10).fit(corpus) -fig = model.concept_compass(topic_x=1, topic_y=4) -fig.show() -``` +#### FASTopic *(Experimental)* -
+You can now use [FASTopic](https://github.com/BobXWu/FASTopic) inside Turftopic. +```python +from turftopic import FASTopic + +model = FASTopic(10).fit(corpus) +model.print_topics() +``` ## Basics [(Documentation)](https://x-tabdeveloping.github.io/turftopic/) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/x-tabdeveloping/turftopic/blob/main/examples/basic_example_20newsgroups.ipynb) @@ -180,6 +183,7 @@ Alternatively you can use the [Figures API](https://x-tabdeveloping.github.io/to ## References - Kardos, M., Kostkan, J., Vermillet, A., Nielbo, K., Enevoldsen, K., & Rocca, R. (2024, June 13). $S^3$ - Semantic Signal separation. arXiv.org. https://arxiv.org/abs/2406.09556 +- Wu, X., Nguyen, T., Zhang, D. C., Wang, W. Y., & Luu, A. T. (2024). FASTopic: A Fast, Adaptive, Stable, and Transferable Topic Modeling Paradigm. ArXiv Preprint ArXiv:2405.17978. - Grootendorst, M. (2022, March 11). BERTopic: Neural topic modeling with a class-based TF-IDF procedure. arXiv.org. https://arxiv.org/abs/2203.05794 - Angelov, D. (2020, August 19). Top2VEC: Distributed representations of topics. arXiv.org. https://arxiv.org/abs/2008.09470 - Bianchi, F., Terragni, S., & Hovy, D. (2020, April 8). Pre-training is a Hot Topic: Contextualized Document Embeddings Improve Topic Coherence. arXiv.org. https://arxiv.org/abs/2004.03974 diff --git a/docs/FASTopic.md b/docs/FASTopic.md new file mode 100644 index 0000000..9338469 --- /dev/null +++ b/docs/FASTopic.md @@ -0,0 +1,15 @@ +# FASTopic + +FASTopic is a neural topic model based on Dual Semantic-relation Reconstruction. + +> Turftopic contains an implementation repurposed for our API, but the implementation is mostly from the [original FASTopic package](https://github.com/BobXWu/FASTopic). + +:warning: This part of the documentation is still under construction :warning: + +## References + +Wu, X., Nguyen, T., Zhang, D. C., Wang, W. Y., & Luu, A. T. (2024). FASTopic: A Fast, Adaptive, Stable, and Transferable Topic Modeling Paradigm. ArXiv Preprint ArXiv:2405.17978. + +## API Reference + +::: turftopic.models.fastopic.FASTopic diff --git a/docs/KeyNMF.md b/docs/KeyNMF.md index c785f62..85742b7 100644 --- a/docs/KeyNMF.md +++ b/docs/KeyNMF.md @@ -309,6 +309,47 @@ for batch in batched(zip(corpus, timestamps)): model.partial_fit_dynamic(text_batch, timestamps=ts_batch, bins=bins) ``` +## Hierarchical Topic Modeling + +When you suspect that subtopics might be present in the topics you find with the model, KeyNMF can be used to discover topics further down the hierarchy. + +This is done by utilising a special case of **weighted NMF**, where documents are weighted by how high they score on the parent topic. +In other words: + +1. Decompose keyword matrix $M \approx WH$ +2. To find subtopics in topic $j$, define document weights $w$ as the $j$th column of $W$. +3. Estimate subcomponents with **wNMF** $M \approx \mathring{W} \mathring{H}$ with document weight $w$ + 1. Initialise $\mathring{H}$ and $\mathring{W}$ randomly. + 2. Perform multiplicative updates until convergence.
+ $\mathring{W}^T = \mathring{W}^T \odot \frac{\mathring{H} \cdot (M^T \odot w)}{\mathring{H} \cdot \mathring{H}^T \cdot (\mathring{W}^T \odot w)}$
+ $\mathring{H}^T = \mathring{H}^T \odot \frac{ (M^T \odot w)\cdot \mathring{W}}{\mathring{H}^T \cdot (\mathring{W}^T \odot w) \cdot \mathring{W}}$ +4. To sufficiently differentiate the subcomponents from each other a pseudo-c-tf-idf weighting scheme is applied to $\mathring{H}$: + 1. $\mathring{H} = \mathring{H}_{ij} \odot ln(1 + \frac{A}{1+\sum_k \mathring{H}_{kj}})$, where $A$ is the average of all elements in $\mathring{H}$ + +To create a hierarchical model, you can use the `hierarchy` property of the model. + +```python +# This divides each of the topics in the model to 3 subtopics. +model.hierarchy.divide_children(n_subtopics=3) +print(model.hierarchy) +``` + +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
+. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
+. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
+
+
+ +For a detailed tutorial on hierarchical modeling click [here](hierarchical.md). + ## Considerations ### Strengths diff --git a/docs/dynamic.md b/docs/dynamic.md index 772a60b..6693cf5 100644 --- a/docs/dynamic.md +++ b/docs/dynamic.md @@ -77,7 +77,7 @@ model.plot_topics_over_time(top_k=5)
Topics over time on a Figure
-## Interface +## API reference All dynamic topic models have a `temporal_components_` attribute, which contains the topic-term matrices for each time slice, along with a `temporal_importance_` attribute, which contains the importance of each topic in each time slice. diff --git a/docs/hierarchical.md b/docs/hierarchical.md new file mode 100644 index 0000000..de5696e --- /dev/null +++ b/docs/hierarchical.md @@ -0,0 +1,152 @@ +# Hierarchical Topic Modeling + +> Note: Hierarchical topic modeling in Turftopic is still in its early stages, you can expect more visualization utilities, tools and models in the future :sparkles: + +You might expect some topics in your corpus to belong to a hierarchy of topics. +Some models in Turftopic (currently only [KeyNMF](KeyNMF.md)) allow you to investigate hierarchical relations and build a taxonomy of topics in a corpus. + +## Divisive Hierarchical Modeling + +Currently Turftopic, in contrast with other topic modeling libraries only allows for hierarchical modeling in a divisive context. +This means that topics can be divided into subtopics in a **top-down** manner. +[KeyNMF](KeyNMF.md) does not discover a topic hierarchy automatically, + but you can manually instruct the model to find subtopics in larger topics. + +As a demonstration, let's load a corpus, that we know to have hierarchical themes. + +```python +from sklearn.datasets import fetch_20newsgroups + +corpus = fetch_20newsgroups( + subset="all", + remove=("headers", "footers", "quotes"), + categories=[ + "comp.os.ms-windows.misc", + "comp.sys.ibm.pc.hardware", + "talk.religion.misc", + "alt.atheism", + ], +).data +``` + +In this case, we have two base themes, which are **computers**, and **religion**. +Let us fit a KeyNMF model with two topics to see if the model finds these. + +```python +from turftopic import KeyNMF + +model = KeyNMF(2, top_n=15, random_state=42).fit(corpus) +model.print_topics() +``` + +| Topic ID | Highest Ranking | +| - | - | +| 0 | windows, dos, os, disk, card, drivers, file, pc, files, microsoft | +| 1 | atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs | + +The results conform our intuition. Topic 0 seems to revolve around IT, while Topic 1 around atheism and religion. +We can already suspect, however that more granular topics could be discovered in this corpus. +For instance Topic 0 contains terms related to operating systems, like *windows* and *dos*, but also components, like *disk* and *card*. + +We can access the hierarchy of topics in the model at the current stage, with the model's `hierarchy` property. + +```python +print(model.hierarchy) +``` + +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+
+
+ +There isn't much to see yet, the model contains a flat hierarchy of the two topics we discovered and we are at root level. +We can dissect these topics, by adding a level to the hierarchy. + +Let us add 3 subtopics to each topic on the root level. + +```python +model.hierarchy.divide_children(n_subtopics=3) +``` + +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
+. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
+. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
+
+
+ +As you can see, the model managed to identify meaningful subtopics of the two larger topics we found earlier. +Topic 0 got divided into a topic mostly concerned with dos and windows, a topic on operating systems in general, and one about hardware, +while Topic 1 contains a topic about newsgroups, one about atheism, and one about morality and christianity. + +You can also easily access nodes of the hierarchy by indexing it: +```python +model.hierarchy[0] +``` + +
+ +0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+└── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+
+
+ +You can also divide individual topics to a number of subtopics, by using the `divide()` method. +Let us divide Topic 0.0 to 5 subtopics. + +```python +model.hierarchy[0][0].divide(5) +model.hierarchy +``` + +
+ +Root
+β”œβ”€β”€ 0: windows, dos, os, disk, card, drivers, file, pc, files, microsoft
+β”‚ β”œβ”€β”€ 0.0: dos, file, disk, files, program, windows, disks, shareware, norton, memory
+β”‚ β”‚ β”œβ”€β”€ 0.0.1: file, files, ftp, bmp, program, windows, shareware, directory, bitmap, zip
+β”‚ β”‚ β”œβ”€β”€ 0.0.2: os, windows, unix, microsoft, crash, apps, crashes, nt, pc, operating
+β”‚ β”‚ β”œβ”€β”€ 0.0.3: disk, disks, floppy, drive, drives, scsi, boot, hd, norton, ide
+β”‚ β”‚ β”œβ”€β”€ 0.0.4: dos, modem, command, ms, emm386, serial, commands, 386, drivers, batch
+β”‚ β”‚ └── 0.0.5: printer, print, printing, fonts, font, postscript, hp, printers, output, driver
+β”‚ β”œβ”€β”€ 0.1: os, unix, windows, microsoft, apps, nt, ibm, ms, os2, platform
+β”‚ └── 0.2: card, drivers, monitor, driver, vga, ram, motherboard, cards, graphics, ati
+└── 1: atheism, atheist, atheists, religion, christians, religious, belief, christian, god, beliefs
+. β”œβ”€β”€ 1.0: atheism, alt, newsgroup, reading, faq, islam, questions, read, newsgroups, readers
+. β”œβ”€β”€ 1.1: atheists, atheist, belief, theists, beliefs, religious, religion, agnostic, gods, religions
+. └── 1.2: morality, bible, christian, christians, moral, christianity, biblical, immoral, god, religion
+
+
+ +## Visualization +You can visualize hierarchies in Turftopic by using the `plot_tree()` method of a topic hierarchy. +The plot is interactive and you can zoom in or hover on individual topics to get an overview of the most important words. + +```python +model.hierarchy.plot_tree() +``` + +
+ +
Tree plot of the hierarchy.
+
+ + +## API reference + +::: turftopic.hierarchical.TopicNode + + + diff --git a/docs/images/hierarchy_tree.png b/docs/images/hierarchy_tree.png new file mode 100644 index 0000000..9696d28 Binary files /dev/null and b/docs/images/hierarchy_tree.png differ diff --git a/docs/index.md b/docs/index.md index 0aca15a..3c9b874 100644 --- a/docs/index.md +++ b/docs/index.md @@ -23,7 +23,7 @@ pip install turftopic[pyro-ppl] You can use most transformer-based topic models in Turftopic, these include: - [Semantic Signal Separation - $S^3$](s3.md) :compass: -- [KeyNMF](KeyNMF.md) :key: + - [KeyNMF](KeyNMF.md) :key: - [Gaussian Mixture Models (GMM)](gmm.md) - [Clustering Topic Models](clustering.md): - [BERTopic](clustering.md#bertopic_and_top2vec) @@ -31,6 +31,8 @@ You can use most transformer-based topic models in Turftopic, these include: - [Auto-encoding Topic Models](ctm.md): - CombinedTM - ZeroShotTM + - [FASTopic](fastopic.md) :zap: + ## Basic Usage diff --git a/mkdocs.yml b/mkdocs.yml index 3a4e8f9..ce5048b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -7,6 +7,7 @@ nav: - Using Turftopic: basics.md - Dynamic Topic Modeling: dynamic.md - Online Topic Modeling: online.md + - Hierarchical Topic Modeling: hierarchical.md - Model Persistence: persistence.md - Models: - Model Overview: model_overview.md @@ -15,6 +16,7 @@ nav: - GMM: GMM.md - Clustering Models: clustering.md - Autoencoding Models: ctm.md + - FASTopic: fastopic.md - Encoders: encoders.md theme: name: material diff --git a/pyproject.toml b/pyproject.toml index 5587b32..384d349 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ line-length=79 [tool.poetry] name = "turftopic" -version = "0.4.5" +version = "0.5.0" description = "Topic modeling with contextual representations from sentence transformers." authors = ["MΓ‘rton Kardos "] license = "MIT" diff --git a/tests/test_integration.py b/tests/test_integration.py index 33e3361..e73c7fd 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -11,8 +11,14 @@ from sklearn.datasets import fetch_20newsgroups from sklearn.decomposition import PCA -from turftopic import (GMM, AutoEncodingTopicModel, ClusteringTopicModel, - FASTopic, KeyNMF, SemanticSignalSeparation) +from turftopic import ( + GMM, + AutoEncodingTopicModel, + ClusteringTopicModel, + FASTopic, + KeyNMF, + SemanticSignalSeparation, +) def batched(iterable, n: int): @@ -95,17 +101,6 @@ def generate_dates( online_models = [KeyNMF(3, encoder=trf)] -@pytest.mark.parametrize("model", models) -def test_fit_export_table(model): - doc_topic_matrix = model.fit_transform(texts, embeddings=embeddings) - table = model.export_topics(format="csv") - with tempfile.TemporaryDirectory() as tmpdirname: - out_path = Path(tmpdirname).joinpath("topics.csv") - with out_path.open("w") as out_file: - out_file.write(table) - df = pd.read_csv(out_path) - - @pytest.mark.parametrize("model", dynamic_models) def test_fit_dynamic(model): doc_topic_matrix = model.fit_transform_dynamic( @@ -138,7 +133,7 @@ def test_fit_online(model): @pytest.mark.parametrize("model", models) -def test_prepare_topic_data(model): +def test_prepare_topic_data_export_table(model): topic_data = model.prepare_topic_data(texts, embeddings=embeddings) for key, value in topic_data.items(): # We allow transform() to be None for transductive models @@ -146,3 +141,16 @@ def test_prepare_topic_data(model): continue if value is None: raise TypeError(f"Field {key} is None in topic_data.") + table = model.export_topics(format="csv") + with tempfile.TemporaryDirectory() as tmpdirname: + out_path = Path(tmpdirname).joinpath("topics.csv") + with out_path.open("w") as out_file: + out_file.write(table) + df = pd.read_csv(out_path) + + +def test_hierarchical(): + model = KeyNMF(2).fit(texts, embeddings=embeddings) + model.hierarchy.divide_children(3) + model.hierarchy[0][0].divide(3) + repr = str(model.hierarchy) diff --git a/turftopic/__init__.py b/turftopic/__init__.py index 541a9cb..ecfed28 100644 --- a/turftopic/__init__.py +++ b/turftopic/__init__.py @@ -2,6 +2,7 @@ from turftopic.error import NotInstalled from turftopic.models.cluster import ClusteringTopicModel from turftopic.models.decomp import SemanticSignalSeparation +from turftopic.models.fastopic import FASTopic from turftopic.models.gmm import GMM from turftopic.models.keynmf import KeyNMF @@ -10,10 +11,6 @@ except ModuleNotFoundError: AutoEncodingTopicModel = NotInstalled("AutoEncodingTopicModel", "pyro-ppl") -try: - from turftopic.models.fastopic import FASTopic -except ModuleNotFoundError: - FASTopic = NotInstalled("FASTopic", "torch") __all__ = [ "ClusteringTopicModel", diff --git a/turftopic/hierarchical.py b/turftopic/hierarchical.py new file mode 100644 index 0000000..d0a5144 --- /dev/null +++ b/turftopic/hierarchical.py @@ -0,0 +1,277 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +from rich.console import Console +from rich.tree import Tree + +from turftopic.base import ContextualModel + +COLOR_PER_LEVEL = [ + "bright_blue", + "bright_magenta", + "bright_cyan", + "bright_green", + "bright_red", + "bright_yellow", + "cyan", + "magenta", + "blue", + "white", +] + + +def _tree_plot(hierarchy: TopicNode): + """Plots hierarchy with Plotly as a Tree""" + try: + import igraph as ig + import plotly.express as px + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "You will need to install plotly and igraph to use hierarchical plotting functionality." + ) from e + + def traverse(h, nodes, edges, parent=None): + nodes.append(h) + if parent is not None: + edges.append([parent._simple_desc, h._simple_desc]) + if h.children is not None: + for child in h.children: + traverse(child, nodes, edges, parent=h) + + def word_table(h): + entries = [] + words = h.get_words(top_k=10) + for word, imp in words: + entries.append(f"{word}: {imp:.2f}") + return "
".join(entries) + + nodes = [] + edges = [] + for child in hierarchy.children: + traverse(child, nodes, edges) + node_names = [node._simple_desc for node in nodes] + node_to_idx = {node_name: idx for idx, node_name in enumerate(node_names)} + edges_idx = [ + [node_to_idx[start], node_to_idx[end]] for start, end in edges + ] + tables = [word_table(node) for node in nodes] + graph = ig.Graph(len(nodes), edges=edges_idx, directed=True) + layout = graph.layout("rt") + layout.rotate(-90) + x, y = np.array(layout.coords).T + xmin, xmax = np.min(x), np.max(x) + xpad = (xmax - xmin) * 0.35 + fig = px.scatter(x=x, y=y, text=node_names, template="plotly_white") + fig = fig.update_traces( + customdata=[[table] for table in tables], + hovertemplate="%{text}

%{customdata[0]}", + ) + fig = fig.update_traces(marker=dict(size=20, color="rgba(0,0,0,0.2)")) + fig = fig.update_layout(margin=dict(l=0, r=0, t=0, b=0)) + fig = fig.update_yaxes(showgrid=False, visible=False, zeroline=False) + fig = fig.update_xaxes( + showgrid=False, + visible=False, + zeroline=False, + range=(xmin - xpad, xmax + xpad), + ) + for start, end in edges_idx: + fig.add_shape( + type="line", + xref="x", + yref="y", + x0=x[start], + y0=y[start], + x1=x[end], + y1=y[end], + opacity=0.2, + ) + return fig + + +@dataclass +class TopicNode: + """Node for a topic in a topic hierarchy. + + Parameters + ---------- + model: ContextualModel + Underlying topic model, which the hierarchy is based on. + path: tuple[int], default () + Path that leads to this node from the root of the tree. + word_importance: ndarray of shape (n_vocab), default None + Importance of each word in the vocabulary for given topic. + document_topic_vector: ndarray of shape (n_documents), default None + Importance of the topic in all documents in the corpus. + children: list[TopicNode], default None + List of subtopics within this topic. + """ + + model: ContextualModel + path: tuple[int] = () + word_importance: Optional[np.ndarray] = None + document_topic_vector: Optional[np.ndarray] = None + children: Optional[list[TopicNode]] = None + + @classmethod + def create_root( + cls, + model: ContextualModel, + components: np.ndarray, + document_topic_matrix: np.ndarray, + ) -> TopicNode: + """Creates root node from a topic models' components and topic importances in documents.""" + children = [] + n_components = components.shape[0] + for i, comp, doc_top in zip( + range(n_components), components, document_topic_matrix.T + ): + children.append( + cls( + model, + path=(i,), + word_importance=comp, + document_topic_vector=doc_top, + children=None, + ) + ) + return TopicNode( + model, + path=(), + word_importance=None, + document_topic_vector=None, + children=children, + ) + + @property + def level(self) -> int: + """Indicates how deep down the hierarchy the topic is.""" + return len(self.path) + + def get_words(self, top_k: int = 10) -> list[tuple[str, float]]: + """Returns top words and words importances for the topic. + + Parameters + ---------- + top_k: int, default 10 + Number of top words to return. + + Returns + ------- + list[tuple[str, float]] + List of word, importance pairs. + """ + if (self.word_importance is None) or ( + self.document_topic_vector + ) is None: + return [] + idx = np.argpartition(-self.word_importance, top_k)[:top_k] + order = np.argsort(-self.word_importance[idx]) + idx = idx[order] + imp = self.word_importance[idx] + words = self.model.get_vocab()[idx] + return list(zip(words, imp)) + + @property + def description(self) -> str: + """Returns a high level description of the topic with its path in the tree + and top words.""" + if not len(self.path): + path = "Root" + else: + path = ".".join([str(idx) for idx in self.path]) + words = [] + for word, imp in self.get_words(top_k=10): + words.append(word) + concat_words = ", ".join(words) + color = COLOR_PER_LEVEL[min(self.level, len(COLOR_PER_LEVEL) - 1)] + stylized = f"[{color} bold]{path}[/]: [italic]{concat_words}[/]" + console = Console() + with console.capture() as capture: + console.print(stylized, end="") + return capture.get() + + @property + def _simple_desc(self) -> str: + if not len(self.path): + path = "Root" + else: + path = ".".join([str(idx) for idx in self.path]) + words = [] + for word, imp in self.get_words(top_k=5): + words.append(word) + concat_words = ", ".join(words) + return f"{path}: {concat_words}" + + def _build_tree(self, tree: Tree = None, top_k: int = 10) -> Tree: + if tree is None: + tree = Tree(self.description) + else: + tree = tree.add(self.description) + if self.children is not None: + for child in self.children: + child._build_tree(tree) + return tree + + def __str__(self): + tree = self._build_tree(top_k=10) + console = Console() + with console.capture() as capture: + console.print(tree) + return capture.get() + + def __repr__(self): + return str(self) + + def clear(self): + """Deletes children of the given node.""" + self.children = None + return self + + def __getitem__(self, index: int): + if self.children is None: + raise IndexError("Current node is a leaf and has not children.") + return self.children[index] + + def divide(self, n_subtopics: int, **kwargs): + """Divides current node into smaller subtopics. + Only works when the underlying model is a divisive hierarchical model. + + Parameters + ---------- + n_subtopics: int + Number of topics to divide the topic into. + """ + try: + self.children = self.model.divide_topic( + node=self, n_subtopics=n_subtopics, **kwargs + ) + except AttributeError as e: + raise AttributeError( + "Looks like your model is not a divisive hierarchical model." + ) from e + return self + + def divide_children(self, n_subtopics: int, **kwargs): + """Divides all children of the current node to smaller topics. + Only works when the underlying model is a divisive hierarchical model. + + Parameters + ---------- + n_subtopics: int + Number of topics to divide the topics into. + """ + if self.children is None: + raise ValueError( + "Current Node is a leaf, children can't be subdivided." + ) + for child in self.children: + child.divide(n_subtopics, **kwargs) + return self + + def plot_tree(self): + """Plots hierarchy as an interactive tree in Plotly.""" + return _tree_plot(self) diff --git a/turftopic/models/_keynmf.py b/turftopic/models/_keynmf.py index d3ca8b7..e48e585 100644 --- a/turftopic/models/_keynmf.py +++ b/turftopic/models/_keynmf.py @@ -1,16 +1,13 @@ import itertools +import warnings from datetime import datetime from typing import Iterable, Optional import numpy as np import scipy.sparse as spr from sklearn.base import clone -from sklearn.decomposition._nmf import ( - NMF, - MiniBatchNMF, - _initialize_nmf, - _update_coordinate_descent, -) +from sklearn.decomposition._nmf import (NMF, MiniBatchNMF, _initialize_nmf, + _update_coordinate_descent) from sklearn.exceptions import NotFittedError from sklearn.feature_extraction.text import CountVectorizer from sklearn.metrics.pairwise import cosine_similarity @@ -123,7 +120,7 @@ def batch_extract_keywords( if not np.any(mask): keywords.append(dict()) continue - important_terms = np.squeeze(np.asarray(mask)) + important_terms = np.ravel(np.asarray(mask)) word_embeddings = [ self.term_embeddings[self.key_to_index[term]] for term in batch_vocab[important_terms] @@ -241,13 +238,19 @@ def transform(self, keywords: list[dict[str, float]]): def partial_fit(self, keyword_batch: list[dict[str, float]]): X = self.vectorize(keyword_batch, fitting=True) - check_non_negative(X, "NMF (input X)") - self._add_word_components(X) - W, _ = _initialize_nmf(X, self.n_components, random_state=self.seed) - _minibatchnmf = MiniBatchNMF( - self.n_components, init="custom", random_state=self.seed - ).partial_fit(X, W=W, H=self.components) - self.components = _minibatchnmf.components_.astype(X.dtype) + try: + check_non_negative(X, "NMF (input X)") + self._add_word_components(X) + W, _ = _initialize_nmf( + X, self.n_components, random_state=self.seed + ) + _minibatchnmf = MiniBatchNMF( + self.n_components, init="custom", random_state=self.seed + ).partial_fit(X, W=W, H=self.components) + self.components = _minibatchnmf.components_.astype(X.dtype) + except ValueError as e: + warnings.warn(f"Batch failed with error: {e}, skipping.") + return self return self def fit_transform_dynamic( diff --git a/turftopic/models/keynmf.py b/turftopic/models/keynmf.py index 8d6bebc..b083cdb 100644 --- a/turftopic/models/keynmf.py +++ b/turftopic/models/keynmf.py @@ -1,16 +1,22 @@ +import warnings from datetime import datetime from typing import Optional, Union import numpy as np +import scipy.sparse as spr from rich.console import Console from sentence_transformers import SentenceTransformer from sklearn.exceptions import NotFittedError from sklearn.feature_extraction.text import CountVectorizer +from sklearn.preprocessing import normalize from turftopic.base import ContextualModel, Encoder from turftopic.data import TopicData from turftopic.dynamic import DynamicTopicModel +from turftopic.hierarchical import TopicNode from turftopic.models._keynmf import KeywordExtractor, KeywordNMF +from turftopic.models.wnmf import weighted_nmf +from turftopic.vectorizer import default_vectorizer class KeyNMF(ContextualModel, DynamicTopicModel): @@ -57,12 +63,13 @@ def __init__( self.n_components = n_components self.top_n = top_n self.encoder = encoder + self._has_custom_vectorizer = vectorizer is not None if isinstance(encoder, str): self.encoder_ = SentenceTransformer(encoder) else: self.encoder_ = encoder if vectorizer is None: - self.vectorizer = CountVectorizer() + self.vectorizer = default_vectorizer() else: self.vectorizer = vectorizer self.model = KeywordNMF( @@ -92,6 +99,52 @@ def extract_keywords( batch_or_document, embeddings=embeddings ) + def vectorize( + self, + raw_documents=None, + embeddings: Optional[np.ndarray] = None, + keywords: Optional[list[dict[str, float]]] = None, + ) -> spr.csr_array: + """Creates document-term-matrix from documents.""" + if keywords is None: + keywords = self.extract_keywords( + raw_documents, embeddings=embeddings + ) + return self.model.vectorize(keywords) + + def divide_topic( + self, + node: TopicNode, + n_subtopics: int, + ) -> list[TopicNode]: + document_term_matrix = getattr(self, "document_term_matrix", None) + if document_term_matrix is None: + raise ValueError( + "document_term_matrix is needed for computing hierarchies. Perhaps you fitted the model online?" + ) + dtm = document_term_matrix + subtopics = [] + weight = node.document_topic_vector + subcomponents, sub_doc_topic = weighted_nmf( + dtm, weight, n_subtopics, self.random_state, max_iter=200 + ) + subcomponents = subcomponents * np.log( + 1 + subcomponents.mean() / (subcomponents.sum(axis=0) + 1) + ) + subcomponents = normalize(subcomponents, axis=1, norm="l2") + for i, component, doc_topic_vector in zip( + range(n_subtopics), subcomponents, sub_doc_topic.T + ): + sub = TopicNode( + self, + path=(*node.path, i), + word_importance=component, + document_topic_vector=doc_topic_vector, + children=None, + ) + subtopics.append(sub) + return subtopics + def fit_transform( self, raw_documents=None, @@ -130,6 +183,11 @@ def fit_transform( doc_topic_matrix = self.model.fit_transform(keywords) self.components_ = self.model.components console.log("Model fitting done.") + self.document_topic_matrix = doc_topic_matrix + self.document_term_matrix = self.model.vectorize(keywords) + self.hierarchy = TopicNode.create_root( + self, self.components_, self.document_topic_matrix + ) return doc_topic_matrix def fit( @@ -205,6 +263,18 @@ def partial_fit( keywords: list[dict[str, float]], optional Precomputed keyword dictionaries. """ + if not self._has_custom_vectorizer: + self.vectorizer = CountVectorizer(stop_words="english") + self._has_custom_vectorizer = True + min_df = self.vectorizer.min_df + max_df = self.vectorizer.max_df + if (min_df != 1) or (max_df != 1.0): + warnings.warn(f"""When applying partial fitting, the vectorizer is fitted batch-wise in KeyNMF. + You have a vectorizer with min_df={min_df}, and max_df={max_df}. + If you continue with these settings, all tokens might get filtered out. + We recommend setting min_df=1 and max_df=1.0 for online fitting. + `model = KeyNMF(10, vectorizer=CountVectorizer(min_df=1, max_df=1.0)` + """) if keywords is None and raw_documents is None: raise ValueError( "You have to pass either keywords or raw_documents." @@ -247,6 +317,11 @@ def prepare_topic_data( self.components_ = self.model.components console.log("Model fitting done.") document_term_matrix = self.model.vectorize(keywords) + self.document_topic_matrix = doc_topic_matrix + self.document_term_matrix = document_term_matrix + self.hierarchy = TopicNode.create_root( + self, self.components_, self.document_topic_matrix + ) res: TopicData = { "corpus": corpus, "document_term_matrix": document_term_matrix, @@ -291,6 +366,11 @@ def fit_transform_dynamic( ).T self.temporal_components_ = self.model.temporal_components self.components_ = self.model.components + self.document_topic_matrix = doc_topic_matrix + self.document_term_matrix = self.model.vectorize(keywords) + self.hierarchy = TopicNode.create_root( + self, self.components_, self.document_topic_matrix + ) return doc_topic_matrix def partial_fit_dynamic( diff --git a/turftopic/models/wnmf.py b/turftopic/models/wnmf.py new file mode 100644 index 0000000..8a4304c --- /dev/null +++ b/turftopic/models/wnmf.py @@ -0,0 +1,48 @@ +import numpy as np +from sklearn.decomposition._nmf import _beta_divergence, _initialize_nmf +from sklearn.utils.extmath import safe_sparse_dot + +EPSILON = np.finfo(np.float32).eps + + +def weighted_nmf( + dtm: np.ndarray, + weight: np.ndarray, + n_components: int, + seed: int, + max_iter: int = 200, + tol: float = 1e-4, +) -> tuple[np.ndarray, np.ndarray]: + """Multiplicative Update algorithm for a special case of weighted NMF, where + only the rows are weighted, but not the individual elements in the data matrix.""" + doc_topic_matrix, components = _initialize_nmf( + dtm, n_components, random_state=seed + ) + U = components.T + V = doc_topic_matrix.T + weighted_A = dtm.T.multiply(weight) # .T + prev_error = np.inf + for i in range(0, max_iter): + # Update V + numerator = safe_sparse_dot(U.T, weighted_A) + denominator = np.linalg.multi_dot((U.T, U, V * weight)) + denominator[denominator <= 0] = EPSILON + delta = numerator + delta /= denominator + delta[np.isinf(delta) & (V == 0)] = 0 + V *= delta + # Update U + numerator = safe_sparse_dot(weighted_A, V.T) + denominator = np.linalg.multi_dot((U, V * weight, V.T)) + denominator[denominator <= 0] = EPSILON + delta = numerator + delta /= denominator + delta[np.isinf(delta) & (U == 0)] = 0 + U *= delta + if (tol > 0) and (i % 10 == 0): + error = _beta_divergence(dtm, V.T, U.T, 2) + if (error - prev_error) > tol: + break + prev_error = error + components, doc_topic_matrix = U.T, V.T + return components, doc_topic_matrix