diff --git a/chemlactica/config/create_train_config.py b/chemlactica/config/create_train_config.py index a30a24f..0877590 100644 --- a/chemlactica/config/create_train_config.py +++ b/chemlactica/config/create_train_config.py @@ -20,6 +20,9 @@ model_train_configs["125m"][ "tokenizer_path" ] = "chemlactica/tokenizer/ChemLacticaTokenizer66" +model_train_configs["small_opt"][ + "tokenizer_path" +] = "chemlactica/tokenizer/ChemLacticaTokenizer66" model_train_configs["1.3b"][ "tokenizer_path" ] = "chemlactica/tokenizer/ChemLacticaTokenizer66" diff --git a/chemlactica/jsonl_dataset.py b/chemlactica/jsonl_dataset.py index 9fb4eb2..0610885 100644 --- a/chemlactica/jsonl_dataset.py +++ b/chemlactica/jsonl_dataset.py @@ -1,8 +1,9 @@ from typing import List -import torch -# from io import StringIO import os +from accelerate.state import PartialState + +distributed_state = PartialState() def generator_init_print(shared_jsonl_files, files): @@ -23,51 +24,39 @@ def setup_generator(shared_jsonl_files, files): return file_states -def get_batch(file, state, chunk_size): - with open(file) as f: - f.seek(state["position"]) - batch = f.read(chunk_size) - if not batch: - raise StopIteration - - batch += f.readline() - batch = batch.splitlines() +def should_yield_on_current_rank(i, num_processes, process_index): + return i % num_processes == process_index - # batch = [line.rstrip("\n") for line in batch] - state["position"] = f.tell() - batch_len = len(batch) - state["line_number"] += batch_len - return batch, batch_len, state - -def format_sample(sample, return_line_info, batch_len, file, state, i): +def format_sample(line): + sample = line.strip() ret = {"text": sample} - if return_line_info: - ret["line_info"] = { - "file": file, - "line_number": state["line_number"] - batch_len + i, - } return ret def samples_generator( files: List[str], shared_jsonl_files, chunk_size=25000, return_line_info=False ): - if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - file_states = setup_generator(shared_jsonl_files, files) - - returned = True - while returned: - returned = False - for file, state in file_states.items(): - try: - batch, batch_len, state = get_batch(file, state, chunk_size) - except StopIteration: - break - for i, sample in enumerate(batch, start=1): - returned = True - ret = format_sample( - sample, return_line_info, batch_len, file, state, i - ) - yield ret - shared_jsonl_files[file] = state + file_states = setup_generator(shared_jsonl_files, files) + + returned = True + while returned: + returned = False + for file, state in file_states.items(): + with open(file) as f: + f.seek(state["position"]) + line = f.readline() + counter = 0 + while line: + state["position"] = f.tell() + if should_yield_on_current_rank( + counter, + distributed_state.num_processes, + distributed_state.process_index, + ): + returned = True + ret = format_sample(line) + yield ret + counter = counter + 1 + shared_jsonl_files[file] = state + line = f.readline() diff --git a/chemlactica/train.py b/chemlactica/train.py index 8e7ec43..54fb044 100644 --- a/chemlactica/train.py +++ b/chemlactica/train.py @@ -50,7 +50,6 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "caching_allocator" # os.environ["TOKENIZERS_PARALLELISM"] = "false" - # signal.signal(signal.SIGINT, signal_handler) # signal.signal(signal.SIGTERM, signal_handler) @@ -192,9 +191,9 @@ def train( accelerator.wait_for_everyone() - with multiprocessing.Manager() if accelerator.is_main_process else nullcontext() as manager: + with 
multiprocessing.Manager() as manager: shared_jsonl_files = None - if accelerator.is_main_process and train_type == "pretrain": + if train_type == "pretrain": shared_jsonl_files = manager.dict() trainer_callback_dict[ "json_dataset_resume_callback" diff --git a/chemlactica/utils/model_utils.py b/chemlactica/utils/model_utils.py index 09cfebe..cad599f 100644 --- a/chemlactica/utils/model_utils.py +++ b/chemlactica/utils/model_utils.py @@ -105,7 +105,7 @@ def load_model( ffn_dim=model_config["ffn_dim"], max_position_embeddings=model_config["max_position_embeddings"], num_attention_heads=model_config["num_attention_heads"], - word_embed_proj_dim=model_config["word_sembed_proj_dim"], + word_embed_proj_dim=model_config["word_embed_proj_dim"], ) ) if "galactica" in from_pretrained.lower(): diff --git a/environment.yml b/environment.yml index 15a58e5..31a5e1c 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: cl11.8_t_4.37 +name: cl11.8_t_4.39 channels: - pytorch - nvidia @@ -7,12 +7,10 @@ channels: dependencies: - _libgcc_mutex=0.1=conda_forge - _openmp_mutex=4.5=2_kmp_llvm - - asttokens=2.4.1=pyhd8ed1ab_0 - blas=2.121=openblas - blas-devel=3.9.0=21_linux64_openblas - bzip2=1.0.8=hd590300_5 - ca-certificates=2024.2.2=hbcca054_0 - - comm=0.2.1=pyhd8ed1ab_0 - cuda=11.8.0=0 - cuda-cccl=11.8.89=0 - cuda-command-line-tools=11.8.0=0 @@ -45,27 +43,16 @@ dependencies: - cuda-sanitizer-api=11.8.86=0 - cuda-toolkit=11.8.0=0 - cuda-tools=11.8.0=0 - - cuda-version=11.8=h70ddcb2_2 + - cuda-version=11.8=h70ddcb2_3 - cuda-visual-tools=11.8.0=0 - cudatoolkit=11.8.0=h4ba93d1_13 - - cudnn=8.8.0.121=hcdd5f01_4 - - debugpy=1.8.0=py310hc6cd4ac_1 - - decorator=5.1.1=pyhd8ed1ab_0 - - exceptiongroup=1.2.0=pyhd8ed1ab_2 - - executing=2.0.1=pyhd8ed1ab_0 + - cudnn=8.9.7.29=hbc23b4c_3 - filelock=3.13.1=pyhd8ed1ab_0 - gds-tools=1.4.0.31=0 - - gmp=6.3.0=h59595ed_0 + - gmp=6.3.0=h59595ed_1 - gmpy2=2.1.2=py310h3ec546c_1 - icu=73.2=h59595ed_0 - - importlib-metadata=7.0.1=pyha770c72_0 - - importlib_metadata=7.0.1=hd8ed1ab_0 - - ipykernel=6.29.2=pyhd33586a_0 - - ipython=8.21.0=pyh707e725_0 - - jedi=0.19.1=pyhd8ed1ab_0 - jinja2=3.1.3=pyhd8ed1ab_0 - - jupyter_client=8.6.0=pyhd8ed1ab_0 - - jupyter_core=5.7.1=py310hff52083_0 - ld_impl_linux-64=2.40=h41732ed_0 - libabseil=20230802.1=cxx17_h59595ed_0 - libblas=3.9.0=21_linux64_openblas @@ -74,9 +61,9 @@ dependencies: - libcublas-dev=11.11.3.6=0 - libcufft=10.9.0.58=0 - libcufft-dev=10.9.0.58=0 - - libcufile=1.8.1.2=0 + - libcufile=1.9.0.20=0 - libcufile-dev=1.4.0.31=0 - - libcurand=10.3.4.107=0 + - libcurand=10.3.5.119=0 - libcurand-dev=10.3.0.86=0 - libcusolver=11.4.1.48=0 - libcusolver-dev=11.4.1.48=0 @@ -92,81 +79,59 @@ dependencies: - liblapack=3.9.0=21_linux64_openblas - liblapacke=3.9.0=21_linux64_openblas - libmagma=2.7.2=h09b5827_2 - - libmagma_sparse=2.7.2=h09b5827_2 + - libmagma_sparse=2.7.2=h09b5827_3 - libnpp=11.8.0.86=0 - libnpp-dev=11.8.0.86=0 - libnsl=2.0.1=hd590300_0 - libnvjpeg=11.9.0.86=0 - libnvjpeg-dev=11.9.0.86=0 - libopenblas=0.3.26=pthreads_h413a1c8_0 - - libprotobuf=4.25.1=hf27288f_1 - - libsodium=1.0.18=h36c2ea0_1 - - libsqlite=3.44.2=h2797004_0 + - libprotobuf=4.25.1=hf27288f_2 + - libsqlite=3.45.2=h2797004_0 - libstdcxx-ng=13.2.0=h7e041cc_5 - libtorch=2.1.2=cuda118_h12fe058_301 - libuuid=2.38.1=h0b41bf4_0 - - libuv=1.46.0=hd590300_0 + - libuv=1.48.0=hd590300_0 - libxcrypt=4.4.36=hd590300_1 - - libxml2=2.12.5=h232c23b_0 + - libxml2=2.12.6=h232c23b_0 - libzlib=1.2.13=hd590300_5 - - llvm-openmp=17.0.6=h4dfa4b3_0 - - magma=2.7.2=h4aca40b_2 + 
- llvm-openmp=18.1.2=h4dfa4b3_0 + - magma=2.7.2=h4aca40b_3 - markupsafe=2.1.5=py310h2372a71_0 - - matplotlib-inline=0.1.6=pyhd8ed1ab_0 - mkl=2023.2.0=h84fe81f_50496 - mkl-devel=2023.2.0=ha770c72_50496 - mkl-include=2023.2.0=h84fe81f_50496 - mpc=1.3.1=hfe3b2da_0 - mpfr=4.2.1=h9458935_0 - mpmath=1.3.0=pyhd8ed1ab_0 - - nccl=2.19.4.1=h6103f9b_0 - - ncurses=6.4=h59595ed_2 - - nest-asyncio=1.6.0=pyhd8ed1ab_0 + - nccl=2.20.5.1=h6103f9b_0 + - ncurses=6.4.20240210=h59595ed_0 - networkx=3.2.1=pyhd8ed1ab_0 - nsight-compute=2022.3.0.22=0 - numpy=1.26.4=py310hb13e2d6_0 - openblas=0.3.26=pthreads_h7a3da1a_0 - - openssl=3.2.1=hd590300_0 - - packaging=23.2=pyhd8ed1ab_0 - - parso=0.8.3=pyhd8ed1ab_0 - - pexpect=4.9.0=pyhd8ed1ab_0 - - pickleshare=0.7.5=py_1003 + - openssl=3.2.1=hd590300_1 - pip=24.0=pyhd8ed1ab_0 - - platformdirs=4.2.0=pyhd8ed1ab_0 - - prompt-toolkit=3.0.42=pyha770c72_0 - - psutil=5.9.8=py310h2372a71_0 - - ptyprocess=0.7.0=pyhd3deb0d_0 - - pure_eval=0.2.2=pyhd8ed1ab_0 - - pygments=2.17.2=pyhd8ed1ab_0 - - python=3.10.13=hd12c33a_1_cpython - - python-dateutil=2.8.2=pyhd8ed1ab_0 + - python=3.10.14=hd12c33a_0_cpython - python_abi=3.10=4_cp310 - pytorch=2.1.2=cuda118_py310h59774e7_301 - pytorch-cuda=11.8=h7e8668a_5 - pytorch-mutex=1.0=cuda - pyyaml=6.0.1=py310h2372a71_1 - - pyzmq=25.1.2=py310h795f18f_0 - readline=8.2=h8228510_1 - - setuptools=69.0.3=pyhd8ed1ab_0 - - six=1.16.0=pyh6c4a22f_0 + - setuptools=69.2.0=pyhd8ed1ab_0 - sleef=3.5.1=h9b69904_2 - - stack_data=0.6.2=pyhd8ed1ab_0 - sympy=1.12=pypyh9d50eac_103 - tbb=2021.11.0=h00ab1b0_1 - tk=8.6.13=noxft_h4845f30_101 - torchtriton=2.2.0=py310 - - tornado=6.3.3=py310h2372a71_1 - - traitlets=5.14.1=pyhd8ed1ab_0 - - typing_extensions=4.9.0=pyha770c72_0 - - wcwidth=0.2.13=pyhd8ed1ab_0 - - wheel=0.42.0=pyhd8ed1ab_0 + - typing_extensions=4.10.0=pyha770c72_0 + - wheel=0.43.0=pyhd8ed1ab_0 - xz=5.2.6=h166bdaf_0 - yaml=0.2.5=h7f98852_2 - - zeromq=4.3.5=h59595ed_0 - - zipp=3.17.0=pyhd8ed1ab_0 - zstd=1.5.5=hfc55251_0 - pip: - - accelerate==0.26.1 + - accelerate==0.28.0 - aim==3.17.0 - aim-ui==3.17.0 - aimrecords==0.0.7 @@ -176,10 +141,11 @@ dependencies: - aiosignal==1.3.1 - alembic==1.13.1 - annotated-types==0.6.0 - - anyio==4.2.0 + - anyio==4.3.0 - argon2-cffi==23.1.0 - argon2-cffi-bindings==21.2.0 - arrow==1.3.0 + - asttokens==2.4.1 - async-lru==2.0.4 - async-timeout==4.0.3 - attrs==23.2.0 @@ -189,7 +155,7 @@ dependencies: - beautifulsoup4==4.12.3 - bitsandbytes==0.43.0 - bleach==6.1.0 - - cachetools==5.3.2 + - cachetools==5.3.3 - certifi==2024.2.2 - cffi==1.16.0 - cfgv==3.4.0 @@ -197,120 +163,140 @@ dependencies: - chemlactica==0.0.1 - click==8.1.7 - cloudpickle==3.0.0 - - conda-pack==0.7.1+23.g60c249a - - contourpy==1.2.0 - - cryptography==42.0.2 - - cycler==0.12.1 - - datasets==2.14.7 + - comm==0.2.2 + - conda-pack==0.7.1 + - cryptography==42.0.5 + - datasets==2.18.0 + - debugpy==1.8.1 + - decorator==5.1.1 - defusedxml==0.7.1 - - dill==0.3.7 + - dill==0.3.8 - distlib==0.3.8 - - docstring-parser==0.15 + - docstring-parser==0.16 - einops==0.7.0 - - fastapi==0.109.2 + - exceptiongroup==1.2.0 + - executing==2.0.1 + - fastapi==0.110.0 - fastjsonschema==2.19.1 - - flash-attn==2.4.3.post1 - - fonttools==4.48.1 + - flash-attn==2.5.6 - fqdn==1.5.1 - frozenlist==1.4.1 - - fsspec==2023.10.0 + - fsspec==2024.2.0 - greenlet==3.0.3 - - grpcio==1.60.1 + - grpcio==1.62.1 - h11==0.14.0 - - httpcore==1.0.3 - - httpx==0.26.0 - - huggingface-hub==0.20.3 - - identify==2.5.34 + - httpcore==1.0.4 + - httpx==0.27.0 + - huggingface-hub==0.21.4 + - identify==2.5.35 - idna==3.6 
+ - ipykernel==6.29.3 + - ipython==8.22.2 - ipywidgets==8.1.2 - isoduration==20.11.0 + - jedi==0.19.1 - joblib==1.3.2 - - json5==0.9.17 + - json5==0.9.24 - jsonpointer==2.4 - jsonschema==4.21.1 - jsonschema-specifications==2023.12.1 - jupyter==1.0.0 + - jupyter-client==8.6.1 - jupyter-console==6.6.3 - - jupyter-events==0.9.0 - - jupyter-lsp==2.2.2 - - jupyter-server==2.12.5 - - jupyter-server-terminals==0.5.2 - - jupyterlab==4.1.2 + - jupyter-core==5.7.2 + - jupyter-events==0.10.0 + - jupyter-lsp==2.2.4 + - jupyter-server==2.13.0 + - jupyter-server-terminals==0.5.3 + - jupyterlab==4.1.5 - jupyterlab-pygments==0.3.0 - - jupyterlab-server==2.25.3 + - jupyterlab-server==2.25.4 - jupyterlab-widgets==3.0.10 - - kiwisolver==1.4.5 - mako==1.3.2 - markdown-it-py==3.0.0 - - matplotlib==3.8.2 + - matplotlib-inline==0.1.6 - mdurl==0.1.2 - mistune==3.0.2 - multidict==6.0.5 - - multiprocess==0.70.15 - - nbclient==0.9.0 - - nbconvert==7.16.1 - - nbformat==5.9.2 + - multiprocess==0.70.16 + - nbclient==0.10.0 + - nbconvert==7.16.3 + - nbformat==5.10.3 + - nest-asyncio==1.6.0 - ninja==1.11.1.1 - nodeenv==1.8.0 - - notebook==7.1.0 + - notebook==7.1.2 - notebook-shim==0.2.4 - orjson==3.9.15 - overrides==7.7.0 - - pandas==2.2.0 + - packaging==23.2 + - pandas==2.2.1 - pandocfilters==1.5.1 + - parso==0.8.3 + - pexpect==4.9.0 - pillow==10.2.0 - - pre-commit==3.6.1 + - platformdirs==4.2.0 + - pre-commit==3.6.2 + - prettyprint==0.1.5 - prometheus-client==0.20.0 - - protobuf==4.25.2 + - prompt-toolkit==3.0.43 + - protobuf==4.25.3 + - psutil==5.9.8 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 - py3nvml==0.2.7 - - pyarrow==15.0.0 + - pyarrow==15.0.2 - pyarrow-hotfix==0.6 - pycparser==2.21 - - pydantic==2.6.1 - - pydantic-core==2.16.2 + - pydantic==2.6.4 + - pydantic-core==2.16.3 + - pygments==2.17.2 - pyjwt==2.8.0 - - pyparsing==3.1.1 + - python-dateutil==2.9.0.post0 - python-json-logger==2.0.7 - pytz==2024.1 + - pyzmq==25.1.2 - qtconsole==5.5.1 - qtpy==2.4.1 - - rdkit==2023.9.4 - - referencing==0.33.0 + - referencing==0.34.0 - regex==2023.12.25 - requests==2.31.0 - - restrictedpython==7.0 + - restrictedpython==7.1 - rfc3339-validator==0.1.4 - rfc3986-validator==0.1.1 - - rich==13.7.0 + - rich==13.7.1 - rpds-py==0.18.0 - safetensors==0.4.2 - scikit-learn==1.4.1.post1 - scipy==1.12.0 - - seaborn==0.13.2 - - segment-analytics-python==2.3.1 + - segment-analytics-python==2.3.2 - send2trash==1.8.2 - - shtab==1.6.5 - - sniffio==1.3.0 + - shtab==1.7.1 + - six==1.16.0 + - sniffio==1.3.1 - soupsieve==2.5 - - sqlalchemy==1.4.51 + - sqlalchemy==1.4.52 + - stack-data==0.6.3 - starlette==0.36.3 - - stringzilla==3.7.0 - submitit==1.5.1 - - terminado==0.18.0 - - threadpoolctl==3.3.0 + - terminado==0.18.1 + - threadpoolctl==3.4.0 - tinycss2==1.2.1 - - tokenizers==0.15.1 + - tokenizers==0.15.2 - tomli==2.0.1 - - tqdm==4.66.1 - - transformers==4.37.0 - - trl==0.7.11 - - types-python-dateutil==2.8.19.20240106 - - tyro==0.7.2 - - tzdata==2023.4 + - tornado==6.4 + - tqdm==4.66.2 + - traitlets==5.14.2 + - transformers==4.39.0 + - trl==0.8.1 + - types-python-dateutil==2.9.0.20240316 + - tyro==0.7.3 + - tzdata==2024.1 - uri-template==1.3.0 - - urllib3==2.2.0 - - uvicorn==0.27.0.post1 - - virtualenv==20.25.0 + - urllib3==2.2.1 + - uvicorn==0.29.0 + - virtualenv==20.25.1 + - wcwidth==0.2.13 - webcolors==1.13 - webencodings==0.5.1 - websocket-client==1.7.0 @@ -318,4 +304,4 @@ dependencies: - xmltodict==0.13.0 - xxhash==3.4.1 - yarl==1.9.4 -prefix: /auto/home/menuab/miniforge3/envs/cl11.8_t_4.37 +prefix: 
/home/philipp/miniforge3/envs/cl11.8_t_4.39 diff --git a/test_status.yaml b/test_status.yaml index b9d553d..188d21f 100644 --- a/test_status.yaml +++ b/test_status.yaml @@ -1 +1 @@ -51f0d0f146c9c7dfaebbcb53722a25dde8c89534: PASS +e82662850b8b2cec825f73168a9d53b13af95500: PASS diff --git a/tests/_test_dataloader_speed.py b/tests/_test_dataloader_speed.py index 37873df..9646174 100644 --- a/tests/_test_dataloader_speed.py +++ b/tests/_test_dataloader_speed.py @@ -5,7 +5,7 @@ from datasets import load_dataset from torch.utils.data import DataLoader -from dataset_utils import process_dataset +from chemlactica.utils.dataset_utils import process_dataset from transformers import TrainingArguments from transformers.data.data_collator import default_data_collator from transformers.trainer_pt_utils import IterableDatasetShard diff --git a/tests/dataset/_test_resumable_dataset.py b/tests/dataset/_test_resumable_dataset.py index 78280a6..9f3bc81 100644 --- a/tests/dataset/_test_resumable_dataset.py +++ b/tests/dataset/_test_resumable_dataset.py @@ -3,7 +3,6 @@ import glob import multiprocessing import os -import sys import torch from torch.utils.data import DataLoader @@ -11,12 +10,12 @@ from datasets.dataset_dict import IterableDatasetDict from transformers.trainer_utils import seed_worker -from jsonl_dataset import samples_generator +from chemlactica.jsonl_dataset import samples_generator +from chemlactica.utils.dataset_utils import process_dataset from test_utils import TD_PATH class TestDataloader(unittest.TestCase): - def setUp(self): # clean up gc.collect() @@ -28,13 +27,13 @@ def tearDown(self): torch.cuda.empty_cache() def get_train_dataloader( - self, train_dataset, batch_size, - num_workers, pin_memory - ) -> DataLoader: + self, train_dataset, batch_size, num_workers, pin_memory + ) -> DataLoader: """ Returns the training [`~torch.utils.data.DataLoader`]. - Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + Will use no sampler if `train_dataset` does not implement `__len__`, + a random sampler (adapted to distributed training if necessary) otherwise. Subclass and override this method if you want to inject some custom behavior. @@ -56,41 +55,49 @@ def get_train_dataloader( def test_dataloader(self): """ - The following code replicates the resumed training + The following code replicates the resumed training - The test works the following way - 1. We return line information along with each read line by the dataset - and check if the read line corresponds to the line in the file - 2. We do 2 stages, the first one is a dataset starting from scratch - and the second one a dataset (fast) resumed + The test works the following way + 1. We return line information along with each read line by the dataset + and check if the read line corresponds to the line in the file + 2. 
We do 2 stages, the first one is a dataset starting from scratch + and the second one a dataset (fast) resumed """ + default_tokenizer_path = "chemlactica/tokenizer/ChemLacticaTokenizer66" + train_config = {"tokenizer_path": default_tokenizer_path, "block_size": 2048} with multiprocessing.Manager() as manager: shared_jsonl_files = manager.dict() - training_data_dirs = [os.path.join(TD_PATH, "comp_train"), os.path.join(TD_PATH, "assay_train")] + training_data_dirs = [ + os.path.join(TD_PATH, "comp_train"), + os.path.join(TD_PATH, "assay_train"), + ] + dir_data_types = ["computed", "assay"] valid_data_dir = os.path.join(TD_PATH, "comp_valid") shuffle_buffer_size = 4 train_dataset_dict = {} print("---Training dataset names---") - for i, (training_data_dir, dir_data_type) in enumerate(zip(training_data_dirs, dir_data_types)): + for i, (training_data_dir, dir_data_type) in enumerate( + zip(training_data_dirs, dir_data_types) + ): training_data_files = glob.glob(training_data_dir + "/*.jsonl") ds_name = f"{dir_data_type}_1" is_assay_split = "assay" in dir_data_type dataset = IterableDataset.from_generator( samples_generator, - gen_kwargs = { - "files" : training_data_files, - "shared_jsonl_files" : shared_jsonl_files - } + gen_kwargs={ + "files": training_data_files, + "shared_jsonl_files": shared_jsonl_files, + }, ) dataset = process_dataset( dataset=dataset, train_config=train_config, process_batch_sizes=(50, 50), is_eval=False, - assay=is_assay_split + assay=is_assay_split, ) if is_assay_split: dataset.shuffle(buffer_size=shuffle_buffer_size) @@ -101,90 +108,97 @@ def test_dataloader(self): # combine small train and valid to have 2 files to test training_data_files.extend(glob.glob(valid_data_dir + "/*.jsonl")) - initial_train_dataset = IterableDatasetDict({ - "train": IterableDataset.from_generator( - samples_generator, - gen_kwargs={ - "files": training_data_files, - "shared_jsonl_files": shared_jsonl_files, - "return_line_info": True, - } - ) - }) + initial_train_dataset = IterableDatasetDict( + { + "train": IterableDataset.from_generator( + samples_generator, + gen_kwargs={ + "files": training_data_files, + "shared_jsonl_files": shared_jsonl_files, + "return_line_info": True, + }, + ) + } + ) """ get_train_dataloader is a function similar to (Trainer.get_train_dataloader) - what we do in training (we exclude accelerate.prepare operation, because that is for distributed training) + what we do in training. We exclude accelerate.prepare operation + because that is for distributed training. 
Use get_train_dataloader to replicate our dataloader during the training and use and test our dataset with multiprocessing """ initial_train_dataloader = self.get_train_dataloader( - initial_train_dataset["train"], 16, - num_workers=2, pin_memory=True + initial_train_dataset["train"], 16, num_workers=2, pin_memory=True ) """ - keep all the lines in the memory to check if the read line by dataset matches the actual line in the file + keep all the lines in the memory to check if the + read line by dataset matches the actual line in the file (this is kind of the "ground truth") """ loaded_files = {} for file in training_data_files: with open(file, "r") as _f: - loaded_files[file] = [{ - "text": line.rstrip("\n"), - "is_read": False - } for line in _f.readlines()] + loaded_files[file] = [ + {"text": line.rstrip("\n"), "is_read": False} + for line in _f.readlines() + ] sample_to_pass = 10 for i, samples in enumerate(initial_train_dataloader): for text, file, line_number in zip( - samples["text"], - samples["line_info"]["file"], - samples["line_info"]["line_number"].tolist() - ): + samples["text"], + samples["line_info"]["file"], + samples["line_info"]["line_number"].tolist(), + ): # check if the line matches with what is actually in the file assert loaded_files[file][line_number - 1]["text"] == text assert not loaded_files[file][line_number - 1]["is_read"] loaded_files[file][line_number - 1]["is_read"] = True - print(f'{file} {line_number} passed') + print(f"{file} {line_number} passed") if i == sample_to_pass: break fixed_shared_jsonl_files = {k: v for k, v in shared_jsonl_files.items()} - resumed_train_dataset = IterableDatasetDict({ - "train": IterableDataset.from_generator( - samples_generator, - gen_kwargs={ - "files": training_data_files, - "shared_jsonl_files": shared_jsonl_files, - "return_line_info": True, - } - ) - }) + resumed_train_dataset = IterableDatasetDict( + { + "train": IterableDataset.from_generator( + samples_generator, + gen_kwargs={ + "files": training_data_files, + "shared_jsonl_files": shared_jsonl_files, + "return_line_info": True, + }, + ) + } + ) resumed_train_dataloader = self.get_train_dataloader( - resumed_train_dataset["train"], 16, - num_workers=2, pin_memory=True + resumed_train_dataset["train"], 16, num_workers=2, pin_memory=True ) for samples in resumed_train_dataloader: for text, file, line_number in zip( - samples["text"], - samples["line_info"]["file"], - samples["line_info"]["line_number"].tolist() - ): + samples["text"], + samples["line_info"]["file"], + samples["line_info"]["line_number"].tolist(), + ): # check if the line matches with what is actually in the file assert loaded_files[file][line_number - 1]["text"] == text assert fixed_shared_jsonl_files[file]["line_number"] < line_number # assert not loaded_files[file][line_number - 1]["is_read"] loaded_files[file][line_number - 1]["is_read"] = True - print(f'{file} {line_number} passed') + print(f"{file} {line_number} passed") for file_name, lines in loaded_files.items(): number_of_read: int = 0 for i, line in enumerate(lines, start=1): # assert line["is_read"], f"'{file_name}' line {i} is not read." number_of_read += int(line["is_read"]) - print(f"File: {file_name}: number of read line {number_of_read}, number of not read {len(lines) - number_of_read}.") + print( + f"File: {file_name}: number of read line {number_of_read}, \ + number of not read {len(lines) - number_of_read}." 
+ ) if __name__ == "__main__": diff --git a/tests/dataset/distributed_dataset_iter.py b/tests/dataset/distributed_dataset_iter.py new file mode 100644 index 0000000..8dfa709 --- /dev/null +++ b/tests/dataset/distributed_dataset_iter.py @@ -0,0 +1,114 @@ +import os +import multiprocessing +import glob +import shutil +from datetime import timedelta +import torch +import signal +import random +import numpy +from accelerate import Accelerator, logging, InitProcessGroupKwargs +import hashlib +from datasets.iterable_dataset import IterableDataset +from chemlactica.jsonl_dataset import samples_generator +from chemlactica.utils.utils import ( + signal_handler, +) + +from accelerate.state import PartialState + +distributed_state = PartialState() +torch.manual_seed(42) +random.seed(42) +numpy.random.seed(42) +logger = logging.get_logger("transformers") + +signal.signal(signal.SIGINT, signal_handler) +signal.signal(signal.SIGTERM, signal_handler) + + +def run(): + kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200)) + accelerator = Accelerator( + kwargs_handlers=[kwargs], log_with="all", project_dir=None + ) + + directory_name = ".tmp" + if distributed_state.is_main_process: + if os.path.exists(directory_name): + print(f"test directory '{directory_name}' already exists. Clearing it now") + shutil.rmtree(directory_name) + os.makedirs(directory_name) + print(f"test directory '{directory_name}' created successfully.") + num_files = 5 + num_lines = 100 + + for i in range(num_files): + file_name = os.path.join(directory_name, f"test_file_{i}.jsonl") + with open(file_name, "w") as file: + for j in range(num_lines): + sha3_hash = hashlib.sha3_256( + str.encode(f"test_file_{i}.jsonl - Line {j}") + ).hexdigest() + file.write(f"{sha3_hash}\n") + accelerator.wait_for_everyone() + with multiprocessing.Manager() as manager: + shared_jsonl_files = manager.dict() + + training_data_files = glob.glob(directory_name + "/*.jsonl") + training_data_files = [os.path.abspath(path) for path in training_data_files] + + print(training_data_files) + + accelerator.wait_for_everyone() + dataset = IterableDataset.from_generator( + samples_generator, + gen_kwargs={ + "files": training_data_files, + "shared_jsonl_files": shared_jsonl_files, + }, + ) + + file_name_mapping = {} + for process_index in range(distributed_state.num_processes): + file_name_mapping[process_index] = f"dataload_proc{process_index}.jsonl" + + for example in dataset: + file_to_write = file_name_mapping[distributed_state.process_index] + with open(file_to_write, "a") as f: + f.write(example["text"] + "\n") + accelerator.wait_for_everyone() + file_line_sets = [] + + # for process_index in range(distributed_state.num_processes): + if distributed_state.is_main_process: + for process_index, file_to_check in file_name_mapping.items(): + file_lines = load_file_contents(file_to_check) + file_line_set = set(file_lines) + file_line_sets.append(file_line_set) + print(f"file line set length {len(file_line_set)}") + print(f"file line length {len(file_lines)}") + assert len(file_lines) == len(file_line_set) + + num_sets = len(file_line_sets) + for i in range(num_sets): + for j in range(i + 1, num_sets): + set1 = file_line_sets[i] + set2 = file_line_sets[j] + assert set1.isdisjoint(set2) + + accelerator.wait_for_everyone() + if distributed_state.is_main_process: + for process_index in file_name_mapping: + file_to_check = file_name_mapping[process_index] + os.remove(file_to_check) + + +def load_file_contents(file): + with open(file, "r") as f: + lines = 
[line.strip() for line in f.readlines()]
+    return lines
+
+
+if __name__ == "__main__":
+    run()
diff --git a/tests/dataset/test_line_by_line_dataset.py b/tests/dataset/test_line_by_line_dataset.py
new file mode 100644
index 0000000..0755754
--- /dev/null
+++ b/tests/dataset/test_line_by_line_dataset.py
@@ -0,0 +1,20 @@
+import unittest
+import subprocess
+from test_utils import create_train_command
+
+
+class TestLineByLineDataloader(unittest.TestCase):
+    def test_line_by_line_dataloader(self):
+        command = create_train_command(
+            module="accelerate.commands.launch",
+            module_args={
+                "config_file": "chemlactica/config/test_configs/fsdp_config.yaml"
+            },
+            script="tests/dataset/distributed_dataset_iter.py",
+            script_args={},
+        )
+
+        print(f"Running command: {command}")
+        out = subprocess.run(command, shell=True, capture_output=False)
+        if out.returncode != 0:
+            raise Exception(f"command failed with return code {out.returncode}")
diff --git a/tests/fsdp/test_model_training.py b/tests/fsdp/test_model_training.py
index b95170b..f0f00e1 100644
--- a/tests/fsdp/test_model_training.py
+++ b/tests/fsdp/test_model_training.py
@@ -53,7 +53,7 @@ def test_model_train(self):
             "max_steps": 300,
             "eval_steps": 2000,
             "save_steps": 2000,
-            "dataloader_num_workers": 8,
+            "dataloader_num_workers": 1,
             "checkpoints_root_dir": os.path.join(TEST_DIR, "checkpoints"),
             "experiment_name": "fsdp_model_train",
             "gradient_accumulation_steps": 1,
@@ -88,7 +88,7 @@ def test_model_train_interleaved(self):
             "max_steps": 300,
             "eval_steps": 2000,
             "save_steps": 2000,
-            "dataloader_num_workers": 8,
+            "dataloader_num_workers": 1,
             "checkpoints_root_dir": os.path.join(TEST_DIR, "checkpoints"),
             "experiment_name": "fsdp_model_train",
             "gradient_accumulation_steps": 1,
@@ -122,7 +122,7 @@ def test_model_valid(self):
             "max_steps": 100,
             "eval_steps": 10,
             "save_steps": 2000,
-            "dataloader_num_workers": 8,
+            "dataloader_num_workers": 1,
             "checkpoints_root_dir": os.path.join(TEST_DIR, "checkpoints"),
             "experiment_name": "fsdp_model_valid",
             "gradient_accumulation_steps": 1,
@@ -156,7 +156,7 @@ def test_model_resume(self):
             "max_steps": 20,
             "eval_steps": 10,
             "save_steps": 10,
-            "dataloader_num_workers": 8,
+            "dataloader_num_workers": 1,
             "checkpoints_root_dir": os.path.join(TEST_DIR, "checkpoints"),
             "experiment_name": "fsdp_model_resume",
             "gradient_accumulation_steps": 1,
@@ -187,7 +187,7 @@
             "max_steps": 40,
             "eval_steps": 10,
             "save_steps": 10,
-            "dataloader_num_workers": 8,
+            "dataloader_num_workers": 1,
             "checkpoints_root_dir": os.path.join(TEST_DIR, "checkpoints"),
             "experiment_name": "fsdp_model_resume",
             "gradient_accumulation_steps": 1,
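
Below is a minimal, framework-free sketch (not part of the patch) of the round-robin sharding that the reworked samples_generator in chemlactica/jsonl_dataset.py relies on. should_yield_on_current_rank is copied from the patch above; shard_lines, the hard-coded 4-rank setup, and the sample data are hypothetical stand-ins for what accelerate's PartialState would normally supply.

# Standalone sketch of the round-robin line sharding introduced in
# chemlactica/jsonl_dataset.py; uses only the standard library.
from typing import Iterator, List


def should_yield_on_current_rank(i, num_processes, process_index):
    # Same modulo rule as the patched generator: rank r keeps lines
    # r, r + num_processes, r + 2 * num_processes, ...
    return i % num_processes == process_index


def shard_lines(lines: List[str], num_processes: int, process_index: int) -> Iterator[str]:
    # Illustrative helper: mimics samples_generator's per-rank filtering
    # for a single in-memory "file".
    for i, line in enumerate(lines):
        if should_yield_on_current_rank(i, num_processes, process_index):
            yield line.strip()


if __name__ == "__main__":
    lines = [f"line {j}\n" for j in range(10)]
    shards = [list(shard_lines(lines, 4, rank)) for rank in range(4)]
    # Shards are pairwise disjoint and together cover the file: the same
    # property tests/dataset/distributed_dataset_iter.py asserts with sets.
    assert sorted(line for shard in shards for line in shard) == sorted(
        line.strip() for line in lines
    )
    for rank, shard in enumerate(shards):
        print(rank, shard)

Note that, unlike the previous chunk-based reader that ran only on rank 0, every rank now scans every line of each file and simply skips the lines assigned to other ranks, so per-rank order follows file order and no rank sees a duplicate.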