diff --git a/benchmark/README.md b/benchmark/README.md index 3264dff5..afaee57b 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -209,3 +209,15 @@ Experimental setting: 20 Optuna search trials for XGBoost, CatBoost and LightGBM | FTTransformer | 0.872±0.005 (7004s) | 0.540±0.068 (3355s) | 0.908±0.004 (7514s) | | TabNet | **0.912±0.004 (219s)** | 0.995±0.001 (301s) | 0.919±0.003 (187s) | | TabTransformer | 0.843±0.003 (2810s) | 0.657±0.187 (2843s) | 0.854±0.001 (284s) | + +## Benchmarking pytorch-frame and pytorch-tabular + +`pytorch_tabular_benchmark` compares the performance of `pytorch-frame` to `pytorch-tabular`. `pytorch-tabular` excels in providing an accessible approach for standard tabular tasks, allowing users to quickly implement and experiment with existing tabular learning models. It also excels with its training loop modifications and explainability feature. On the other hand, `ptroch-frame` offers enhanced flexibility for exploring and building novel tabular learning approaches while still providing access to established models. It distinguishes itself through support for a wider array of data types, more sophisticated encoding schemas, and streamlined integration with LLMs. +The following table shows the speed comparison of `pytorch-frame` to `pytorch-tabular` on implementations of `TabNet` and `FTTransformer`. + +| Package | Model | Num iters/sec | +| :-------------- | :------------ | :------------ | +| PyTorch Tabular | TabNet | 41.7 | +| PyTorch Frame | TabNet | 45.0 | +| PyTorch Tabular | FTTransformer | 40.1 | +| PyTorch Frame | FTTransformer | 43.7 | diff --git a/benchmark/pytorch_tabular_benchmark.py b/benchmark/pytorch_tabular_benchmark.py index 5631af21..9540424d 100644 --- a/benchmark/pytorch_tabular_benchmark.py +++ b/benchmark/pytorch_tabular_benchmark.py @@ -1,5 +1,17 @@ """This script benchmarks the training time of TabTransformer using PyTorch Frame and PyTorch Tabular. + +Results form comparing Pytorch Tabular and Frame. Specifically the iteration +speed while trainig. + +------------------------------------- +Package | Model | Num iters/sec| +------------------------------------- +Tabular | TabNet | 41.7 +Frame | TabNet | 45.0 +Tabular | FTTrans | 40.1 +Frame | FTTrans | 43.7 +-------------------------------------- """ import argparse import os.path as osp