Skip to content

Commit

Permalink
add tests for dynamic models
Browse files Browse the repository at this point in the history
  • Loading branch information
rbroc committed Mar 19, 2024
1 parent 39ad5ee commit 46944fa
Showing 1 changed file with 53 additions and 1 deletion.
54 changes: 53 additions & 1 deletion tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
import tempfile
from pathlib import Path

Expand All @@ -12,9 +13,24 @@
AutoEncodingTopicModel,
ClusteringTopicModel,
KeyNMF,
SemanticSignalSeparation,
SemanticSignalSeparation
)


def generate_dates(
    n_dates: int,
) -> list[datetime]:
    """Generate random dates to test dynamic models.

    Parameters
    ----------
    n_dates: int
        Number of dates to generate. Zero or a negative value yields an
        empty list.

    Returns
    -------
    list[datetime]
        ``n_dates`` naive datetimes at midnight, with the year in 2000-2019,
        the month in 1-12 and the day in 1-28.
    """
    # Days stop at 28 (the upper bound of randint is exclusive) so that every
    # drawn (year, month, day) triple is a valid calendar date, February
    # included.
    # Keyword arguments are evaluated left to right, so each date draws its
    # day first, then its month, then its year from NumPy's global RNG.
    return [
        datetime(
            day=np.random.randint(low=1, high=29),
            month=np.random.randint(low=1, high=13),
            year=np.random.randint(low=2000, high=2020),
        )
        for _ in range(n_dates)
    ]


newsgroups = fetch_20newsgroups(
subset="all",
categories=[
Expand All @@ -25,6 +41,7 @@
# Raw documents taken from the fetched newsgroups dataset.
texts = newsgroups.data
# Sentence encoder instance that is also handed to the models below via
# their ``encoder`` argument.
trf = SentenceTransformer("all-MiniLM-L6-v2")
# Document embeddings are computed once at import time and passed to the
# fit calls in the tests.
embeddings = np.asarray(trf.encode(texts))
# One randomly generated timestamp per document, used by the dynamic-model
# test.
timestamps = generate_dates(n_dates=len(texts))

models = [
GMM(5, encoder=trf),
Expand All @@ -46,6 +63,28 @@
AutoEncodingTopicModel(5, combined=True),
]

# Models used to parametrize the dynamic fitting test: one GMM, followed by
# one clustering model per feature-importance method (in the listed order).
dynamic_models = [
    GMM(5, encoder=trf),
    *(
        ClusteringTopicModel(
            n_reduce_to=5,
            feature_importance=importance,
            encoder=trf,
            reduction_method="smallest",
        )
        for importance in ("centroid", "soft-c-tf-idf", "c-tf-idf")
    ),
]


@pytest.mark.parametrize("model", models)
def test_fit_export_table(model):
Expand All @@ -56,3 +95,16 @@ def test_fit_export_table(model):
with out_path.open("w") as out_file:
out_file.write(table)
df = pd.read_csv(out_path)


@pytest.mark.parametrize("model", dynamic_models)
def test_fit_dynamic(model):
    """Fit a model dynamically and check that its topic export parses as CSV.

    The model is fitted on the module-level ``texts``, ``embeddings`` and
    ``timestamps``. Its CSV topic export is then written to a temporary file
    and read back with pandas. The test passes when none of these steps
    raises.
    """
    model.fit_transform_dynamic(
        texts,
        embeddings=embeddings,
        timestamps=timestamps,
    )
    table = model.export_topics(format="csv")
    with tempfile.TemporaryDirectory() as tmpdirname:
        out_path = Path(tmpdirname).joinpath("topics.csv")
        # Write UTF-8 explicitly: pandas reads CSV files as UTF-8 by default,
        # whereas a bare open("w") falls back to the platform's locale
        # encoding and can fail or mis-encode non-ASCII topic words.
        out_path.write_text(table, encoding="utf-8")
        pd.read_csv(out_path)

0 comments on commit 46944fa

Please sign in to comment.