Merge pull request #57 from x-tabdeveloping/fastopic
Added FASTopic implementation with Turftopic API
x-tabdeveloping authored Jul 30, 2024
2 parents 75babba + 452f44c commit 4ae4047
Showing 4 changed files with 381 additions and 12 deletions.
31 changes: 19 additions & 12 deletions tests/test_integration.py
@@ -7,10 +7,12 @@
 import pandas as pd
 import pytest
 from sentence_transformers import SentenceTransformer
+from sklearn.cluster import KMeans
 from sklearn.datasets import fetch_20newsgroups
+from sklearn.decomposition import PCA
 
 from turftopic import (GMM, AutoEncodingTopicModel, ClusteringTopicModel,
-                       KeyNMF, SemanticSignalSeparation)
+                       FASTopic, KeyNMF, SemanticSignalSeparation)
 
 
 def batched(iterable, n: int):
@@ -50,42 +52,47 @@ def generate_dates(
 timestamps = generate_dates(n_dates=len(texts))
 
 models = [
-    GMM(5, encoder=trf),
-    SemanticSignalSeparation(5, encoder=trf),
-    KeyNMF(5, encoder=trf),
+    GMM(3, encoder=trf),
+    SemanticSignalSeparation(3, encoder=trf),
+    KeyNMF(3, encoder=trf),
     ClusteringTopicModel(
         n_reduce_to=5,
+        dimensionality_reduction=PCA(10),
+        clustering=KMeans(3),
         feature_importance="c-tf-idf",
         encoder=trf,
         reduction_method="agglomerative",
     ),
     ClusteringTopicModel(
         n_reduce_to=5,
+        dimensionality_reduction=PCA(10),
+        clustering=KMeans(3),
         feature_importance="centroid",
         encoder=trf,
         reduction_method="smallest",
     ),
-    AutoEncodingTopicModel(5, combined=True),
+    AutoEncodingTopicModel(3, combined=True),
+    FASTopic(3, batch_size=None),
 ]
 
 dynamic_models = [
-    GMM(5, encoder=trf),
+    GMM(3, encoder=trf),
     ClusteringTopicModel(
         n_reduce_to=5,
+        dimensionality_reduction=PCA(10),
+        clustering=KMeans(3),
         feature_importance="centroid",
         encoder=trf,
         reduction_method="smallest",
     ),
     ClusteringTopicModel(
         n_reduce_to=5,
+        dimensionality_reduction=PCA(10),
+        clustering=KMeans(3),
         feature_importance="soft-c-tf-idf",
         encoder=trf,
         reduction_method="smallest",
     ),
-    KeyNMF(5, encoder=trf),
+    KeyNMF(3, encoder=trf),
 ]
 
-online_models = [KeyNMF(5, encoder=trf)]
+online_models = [KeyNMF(3, encoder=trf)]
 
 
 @pytest.mark.parametrize("model", models)
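For orientation, here is a minimal sketch of how the newly added FASTopic model is exercised through the shared Turftopic interface that the parametrized tests above rely on. The toy corpus and the fit/print_topics calls are illustrative and not part of this diff.

# Illustrative usage sketch, not part of this commit. Assumes FASTopic exposes
# the common Turftopic model interface (fit, print_topics) that the
# parametrized integration tests exercise.
from turftopic import FASTopic

corpus = [
    "GPUs for training large models are easier to rent this year.",
    "Parliament debated the new budget for most of the night.",
    "A stoppage-time goal decided the cup final.",
]

model = FASTopic(3, batch_size=None)  # same constructor call as in the tests
model.fit(corpus)
model.print_topics()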
6 changes: 6 additions & 0 deletions turftopic/__init__.py
@@ -10,11 +10,17 @@
 except ModuleNotFoundError:
     AutoEncodingTopicModel = NotInstalled("AutoEncodingTopicModel", "pyro-ppl")
 
+try:
+    from turftopic.models.fastopic import FASTopic
+except ModuleNotFoundError:
+    FASTopic = NotInstalled("FASTopic", "torch")
+
 __all__ = [
     "ClusteringTopicModel",
     "SemanticSignalSeparation",
     "GMM",
     "KeyNMF",
     "AutoEncodingTopicModel",
     "ContextualModel",
+    "FASTopic",
 ]
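The try/except import mirrors how the other optional models are registered: when torch is unavailable, FASTopic is bound to a NotInstalled placeholder that only fails when the user actually tries to instantiate it. A rough sketch of how such a placeholder typically behaves follows; turftopic's real NotInstalled helper is not shown in this diff and may differ in detail.

# Sketch of the optional-dependency placeholder pattern; turftopic's actual
# NotInstalled helper is not part of this diff and may differ.
class NotInstalled:
    def __init__(self, name: str, dependency: str):
        self.name = name
        self.dependency = dependency

    def __call__(self, *args, **kwargs):
        raise ModuleNotFoundError(
            f"{self.name} requires the optional dependency '{self.dependency}'. "
            f"Install it (e.g. `pip install {self.dependency}`) to use this model."
        )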
160 changes: 160 additions & 0 deletions turftopic/models/_fastopic.py
@@ -0,0 +1,160 @@
import random
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


def pairwise_euclidean_distance(x, y):
    cost = (
        torch.sum(x**2, axis=1, keepdim=True)
        + torch.sum(y**2, dim=1)
        - 2 * torch.matmul(x, y.t())
    )
    return cost


class ETP(nn.Module):
    def __init__(
        self,
        sinkhorn_alpha,
        init_a_dist=None,
        init_b_dist=None,
        OT_max_iter=5000,
        stopThr=0.5e-2,
    ):
        super().__init__()
        self.sinkhorn_alpha = sinkhorn_alpha
        self.OT_max_iter = OT_max_iter
        self.stopThr = stopThr
        self.epsilon = 1e-16
        self.init_a_dist = init_a_dist
        self.init_b_dist = init_b_dist
        if init_a_dist is not None:
            self.a_dist = init_a_dist
        if init_b_dist is not None:
            self.b_dist = init_b_dist

    def forward(self, x, y):
        # Sinkhorn's algorithm
        M = pairwise_euclidean_distance(x, y)
        device = M.device
        if self.init_a_dist is None:
            a = (torch.ones(M.shape[0]) / M.shape[0]).unsqueeze(1).to(device)
        else:
            a = F.softmax(self.a_dist, dim=0).to(device)
        if self.init_b_dist is None:
            b = (torch.ones(M.shape[1]) / M.shape[1]).unsqueeze(1).to(device)
        else:
            b = F.softmax(self.b_dist, dim=0).to(device)
        u = (torch.ones_like(a) / a.size()[0]).to(device)  # Kx1
        K = torch.exp(-M * self.sinkhorn_alpha)
        err = 1
        cpt = 0
        while err > self.stopThr and cpt < self.OT_max_iter:
            v = torch.div(b, torch.matmul(K.t(), u) + self.epsilon)
            u = torch.div(a, torch.matmul(K, v) + self.epsilon)
            cpt += 1
            if cpt % 50 == 1:
                bb = torch.mul(v, torch.matmul(K.t(), u))
                err = torch.norm(
                    torch.sum(torch.abs(bb - b), dim=0), p=float("inf")
                )
        transp = u * (K * v.T)
        loss_ETP = torch.sum(transp * M)
        return loss_ETP, transp


class fastopic(nn.Module):
    def __init__(
        self,
        num_topics: int,
        theta_temp: float = 1.0,
        DT_alpha: float = 3.0,
        TW_alpha: float = 2.0,
        random_state: Optional[int] = None,
    ):
        super().__init__()

        self.num_topics = num_topics
        self.DT_alpha = DT_alpha
        self.TW_alpha = TW_alpha
        self.theta_temp = theta_temp
        self.seed = random_state or random.randint(0, 10_000)
        self.epsilon = 1e-12

    def init(self, vocab_size: int, embed_size: int):
        torch.manual_seed(self.seed)
        self.word_embeddings = nn.init.trunc_normal_(
            torch.empty(vocab_size, embed_size)
        )
        self.word_embeddings = nn.Parameter(F.normalize(self.word_embeddings))
        self.topic_embeddings = torch.empty((self.num_topics, embed_size))
        nn.init.trunc_normal_(self.topic_embeddings, std=0.1)
        self.topic_embeddings = nn.Parameter(
            F.normalize(self.topic_embeddings)
        )
        self.word_weights = nn.Parameter(
            (torch.ones(vocab_size) / vocab_size).unsqueeze(1)
        )
        self.topic_weights = nn.Parameter(
            (torch.ones(self.num_topics) / self.num_topics).unsqueeze(1)
        )
        self.DT_ETP = ETP(self.DT_alpha, init_b_dist=self.topic_weights)
        self.TW_ETP = ETP(self.TW_alpha, init_b_dist=self.word_weights)

    def get_transp_DT(
        self,
        doc_embeddings,
    ):
        torch.manual_seed(self.seed)
        topic_embeddings = self.topic_embeddings.detach().to(
            doc_embeddings.device
        )
        _, transp = self.DT_ETP(doc_embeddings, topic_embeddings)
        return transp.detach().cpu().numpy()

    # only for testing
    def get_beta(self):
        torch.manual_seed(self.seed)
        _, transp_TW = self.TW_ETP(self.topic_embeddings, self.word_embeddings)
        # use transport plan as beta
        beta = transp_TW * transp_TW.shape[0]
        return beta

    # only for testing
    def get_theta(self, doc_embeddings, train_doc_embeddings):
        torch.manual_seed(self.seed)
        topic_embeddings = self.topic_embeddings.detach().to(
            doc_embeddings.device
        )
        dist = pairwise_euclidean_distance(doc_embeddings, topic_embeddings)
        train_dist = pairwise_euclidean_distance(
            train_doc_embeddings, topic_embeddings
        )
        exp_dist = torch.exp(-dist / self.theta_temp)
        exp_train_dist = torch.exp(-train_dist / self.theta_temp)
        theta = exp_dist / (exp_train_dist.sum(0))
        theta = theta / theta.sum(1, keepdim=True)
        return theta

    def forward(self, train_bow, doc_embeddings):
        torch.manual_seed(self.seed)
        loss_DT, transp_DT = self.DT_ETP(doc_embeddings, self.topic_embeddings)
        loss_TW, transp_TW = self.TW_ETP(
            self.topic_embeddings, self.word_embeddings
        )
        loss_ETP = loss_DT + loss_TW
        theta = transp_DT * transp_DT.shape[0]
        beta = transp_TW * transp_TW.shape[0]
        # Dual Semantic-relation Reconstruction
        recon = torch.matmul(theta, beta)
        loss_DSR = (
            -(train_bow * (recon + self.epsilon).log()).sum(axis=1).mean()
        )
        loss = loss_DSR + loss_ETP
        rst_dict = {
            "loss": loss,
        }
        return rst_dict
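The fastopic module above is the low-level backbone; the user-facing wrapper with the Turftopic API lives in turftopic/models/fastopic.py, whose diff is not expanded in this view. As a rough illustration of how this backbone is meant to be driven, here is a hypothetical training loop on random data; the optimizer, learning rate, epoch count, and tensor shapes are illustrative assumptions, not values taken from this commit.

# Hypothetical training loop for the fastopic backbone above; all
# hyperparameters and data here are illustrative assumptions.
import torch
import torch.nn.functional as F

from turftopic.models._fastopic import fastopic

n_docs, vocab_size, embed_size = 200, 1000, 384
doc_embeddings = F.normalize(torch.randn(n_docs, embed_size))
train_bow = torch.rand(n_docs, vocab_size)  # stand-in for a real doc-term matrix

model = fastopic(num_topics=10)
model.init(vocab_size=vocab_size, embed_size=embed_size)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

for epoch in range(100):
    optimizer.zero_grad()
    loss = model(train_bow, doc_embeddings)["loss"]  # ETP + reconstruction loss
    loss.backward()
    optimizer.step()

# Document-topic proportions and topic-word distributions after training:
theta = model.get_theta(doc_embeddings, doc_embeddings)
beta = model.get_beta()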
(Diff for the fourth changed file not loaded in this view.)
