diff --git a/tests/test_integration.py b/tests/test_integration.py
index ee31f30..6d58f7b 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,3 +1,4 @@
+from datetime import datetime
 import tempfile
 from pathlib import Path
 
@@ -12,9 +13,24 @@
     AutoEncodingTopicModel,
     ClusteringTopicModel,
     KeyNMF,
-    SemanticSignalSeparation,
+    SemanticSignalSeparation
 )
+
+def generate_dates(
+    n_dates: int,
+) -> list[datetime]:
+    """ Generate random dates to test dynamic models """
+    dates = []
+    for n in range(n_dates):
+        d = np.random.randint(low=1, high=29)
+        m = np.random.randint(low=1, high=13)
+        y = np.random.randint(low=2000, high=2020)
+        date = datetime(year=y, month=m, day=d)
+        dates.append(date)
+    return dates
+
+
 
 newsgroups = fetch_20newsgroups(
     subset="all",
     categories=[
@@ -25,6 +41,7 @@
 texts = newsgroups.data
 trf = SentenceTransformer("all-MiniLM-L6-v2")
 embeddings = np.asarray(trf.encode(texts))
+timestamps = generate_dates(n_dates=len(texts))
 
 models = [
     GMM(5, encoder=trf),
@@ -46,6 +63,28 @@
     AutoEncodingTopicModel(5, combined=True),
 ]
 
+dynamic_models = [
+    GMM(5, encoder=trf),
+    ClusteringTopicModel(
+        n_reduce_to=5,
+        feature_importance="centroid",
+        encoder=trf,
+        reduction_method="smallest"
+    ),
+    ClusteringTopicModel(
+        n_reduce_to=5,
+        feature_importance="soft-c-tf-idf",
+        encoder=trf,
+        reduction_method="smallest"
+    ),
+    ClusteringTopicModel(
+        n_reduce_to=5,
+        feature_importance="c-tf-idf",
+        encoder=trf,
+        reduction_method="smallest"
+    ),
+]
+
 
 @pytest.mark.parametrize("model", models)
 def test_fit_export_table(model):
@@ -56,3 +95,16 @@ def test_fit_export_table(model):
     with out_path.open("w") as out_file:
         out_file.write(table)
     df = pd.read_csv(out_path)
+
+
+@pytest.mark.parametrize("model", dynamic_models)
+def test_fit_dynamic(model):
+    doc_topic_matrix = model.fit_transform_dynamic(
+        texts, embeddings=embeddings, timestamps=timestamps,
+    )
+    table = model.export_topics(format="csv")
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        out_path = Path(tmpdirname).joinpath("topics.csv")
+        with out_path.open("w") as out_file:
+            out_file.write(table)
+        df = pd.read_csv(out_path)