This is the official code repository for the paper:
Practical Computational Power of Linear Transformers and Their Recurrent and Self-Referential Extensions (EMNLP 2023, short paper) (Link coming soon)
This repository was originally forked from IDSIA/recurrent-fwp/algorithmic, and contains code from IDSIA/modern-srwm (SRWM implementation) and IDSIA/rtrl-elstm (e-LSTM implementation). The `fast_transformers` directory contains code taken and modified from idiap/fast-transformers (a specific license file is included).
- PyTorch. We used PyTorch `2.0.1+cu117` with Python 3.9 or 3.11.
- Ninja, to compile custom CUDA kernels (`pip install ninja`).
- Optionally: wandb for monitoring jobs (you can enable it by adding the `--use_wandb` flag). We did not use/need this option for our experiments.
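A minimal install sketch for these requirements (assuming a pip-based environment; adjust the exact PyTorch/CUDA combination to your system):

```bash
# Install sketch; versions other than the ones listed above may also work.
pip install torch==2.0.1   # we used 2.0.1+cu117 with Python 3.9 or 3.11
pip install ninja          # needed to compile the custom CUDA kernels
pip install wandb          # optional; only needed with the --use_wandb flag
```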
We obtained the datasets from satwik77/Transformer-Formal-Languages: for all datasets except reset Dyck-1, we downloaded their pre-generated datasets; for reset Dyck-1, we generated a dataset using their code. The exact set of datasets we used can also be downloaded from this Google Drive link.
A generic script to train a model is provided below.

- `TORCH_EXTENSIONS_DIR` is where the compiled custom CUDA kernels will be stored (this way, you will not need to re-compile them every time).
- The script below assumes the dataset to be available under `data/${task}`. The expected files are `train_50.src`, `train_50.tgt`, `valid_50.src`, `valid_50.tgt`, `test_50.src`, and `test_50.tgt` (the `*.src`/`*.tgt` files contain the input/output sequences; "test" corresponds to "valid2"; see examples in our files provided above, and see `main.py` if further details are needed; an example layout check is sketched after this list).
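For example, assuming the `parity` data has been placed under `data/parity`, a quick sanity check of the layout (a sketch, not part of the repository) could look like this:

```bash
task=parity
ls data/${task}
# expected:
# test_50.src  test_50.tgt  train_50.src  train_50.tgt  valid_50.src  valid_50.tgt
```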
`modeltype` specifies the model type by its ID. The IDs of the models used in the paper are as follows:

- `0`: LSTM
- `1`: (regular) Transformer
- `2`: DeltaNet
- `7`: Recurrent Delta
- `8`: Linear Transformer
- `10`: e-LSTM (called "Quasi-LSTM" in the code)
- `12`: SRWM
The corresponding CUDA implementations can be found in the following directories:

- `fast_transformers`: vanilla Linear Transformer
- `fast_weight`: DeltaNet
- `rec_update_fwm_tanh`: Recurrent Delta
- `self_ref_v0`: SRWM
When using our data files, we used `task` to specify the language to learn; it has to be consistent with the directory name:

- `parity`
- `aa-star`
- `abab-star`
- `counter-2`
- `counter-3`
- `shuffle-2`
- `dyck-1`
- `reset_dyck`

Otherwise, the task is implicitly specified through the data directory `--data_dir` (see below).
The `--level` flag can be ignored/removed. We initially thought about having various "levels" for each task, but in the end we did not implement such an option, and we always have `level=50`. Nevertheless, it should not be set to values other than 50, as the dataset file names have to contain the `level` value (see the file names above).
Other flags/parameters should be self-explanatory (see hyper-parameter tables in the appendix of the paper).
```bash
#!/bin/bash
export CUDA_VISIBLE_DEVICES=0
export TORCH_EXTENSIONS_DIR="arbitrary_path/torch_extensions/formal_lang"

task=
modeltype=
lay=
size=
head=
ff=
batch=
lr=

DATA_DIR='data/'${task}

python main.py \
  --data_dir ${DATA_DIR} \
  --level 50 \
  --model_type ${modeltype} \
  --num_layer ${lay} \
  --hidden_size ${size} \
  --n_head ${head} \
  --ff_factor ${ff} \
  --dropout 0.0 \
  --batch_size ${batch} \
  --learning_rate ${lr} \
  --seed 1 \
  --grad_cummulate 1 \
  --num_epoch 300 \
  --report_every 50 \
  --project_name "2023--formal-lang" \
  --remove_pos_enc
```
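For example, the empty variables above could be filled in as follows to train a DeltaNet on `parity`. The hyper-parameter values here are illustrative placeholders only, not the tuned configurations from the paper (see the appendix tables for those):

```bash
task=parity     # language/directory name (see the task list above)
modeltype=2     # 2 = DeltaNet (see the model-ID list above)
lay=2           # number of layers
size=128        # hidden size
head=8          # number of heads
ff=4            # feed-forward size factor
batch=64        # batch size
lr=3e-4         # learning rate
```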
NB:

- In training logs, `valid` corresponds to the validation set with the same distribution as the training set ("Bin0" in the paper), while `valid2` is the validation set with longer sequences ("Bin1" in the paper).
- A training run will stop when either 100% accuracy is achieved on the valid2 dataset or the maximum number of training epochs `num_epoch` is reached.
- In logs, `no-op acc` should be ignored (this was originally relevant for other tasks used in IDSIA/recurrent-fwp/algorithmic).
- Our final results are reported by consistently testing a single seed (`seed=1`) for each configuration described in the appendix. However, it is not impossible that certain good configurations still "fail" for some seeds; if some of the good configurations we report fail in your setting (i.e., do not achieve 100% accuracy on valid2 using the script above), we recommend trying other seeds (see the sketch below).
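A minimal sketch for trying additional seeds, reusing the variables from the training script above (the seed values here are arbitrary):

```bash
for seed in 1 2 3; do
  python main.py \
    --data_dir ${DATA_DIR} \
    --level 50 \
    --model_type ${modeltype} \
    --num_layer ${lay} \
    --hidden_size ${size} \
    --n_head ${head} \
    --ff_factor ${ff} \
    --dropout 0.0 \
    --batch_size ${batch} \
    --learning_rate ${lr} \
    --seed ${seed} \
    --grad_cummulate 1 \
    --num_epoch 300 \
    --report_every 50 \
    --project_name "2023--formal-lang" \
    --remove_pos_enc
done
```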
```bibtex
@inproceedings{irie2023practical,
  title={Practical Computational Power of Linear Transformers and Their Recurrent and Self-Referential Extensions},
  author={Kazuki Irie and R\'obert Csord\'as and J\"urgen Schmidhuber},
  booktitle={Proc. Conf. on Empirical Methods in Natural Language Processing (EMNLP)},
  address={Sentosa, Singapore},
  year={2023}
}
```