
Commit

add train/valid/resume test
tigranfah committed Nov 1, 2023
1 parent 7be97ae commit a3fb6c0
Showing 3 changed files with 170 additions and 3 deletions.
2 changes: 1 addition & 1 deletion test_status.yaml
@@ -1 +1 @@
-effe33b961b2f61f5aa5131da4f9e5c680a1cc36: PASS
+7be97ae21e5e6d5480c2ea678e9607ce509986c4: PASS
3 changes: 1 addition & 2 deletions tests/fsdp/test_model_consist.py
@@ -59,8 +59,7 @@ def test_consist_of_model_output(self):
         print(f"Running command: {command}")
         out = subprocess.run(command, shell=True, capture_output=True)
         if out.returncode != 0:
-            print(f"error: {out.stderr}")
-            raise Exception()
+            raise Exception(out.stderr.decode())


if __name__ == "__main__":
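The change above matters because subprocess.run(..., capture_output=True) captures stderr as a bytes object; decoding it makes the test failure message human-readable instead of a bytes repr. A self-contained illustration of the behavior (not part of the commit):

import subprocess

out = subprocess.run("echo oops >&2; exit 1", shell=True, capture_output=True)
if out.returncode != 0:
    # raises with "oops" rather than b'oops\n'
    raise Exception(out.stderr.decode())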
168 changes: 168 additions & 0 deletions tests/fsdp/test_model_training.py
@@ -0,0 +1,168 @@
import subprocess
import unittest
import gc
import os
import sys
import shutil

import torch

from test_utils import create_train_command

test_directory = "/tmp/chemlactica_fsdp_precommit_test"


class TestModelTraining(unittest.TestCase):

    def setUp(self):
        # clean up
        gc.collect()
        torch.cuda.empty_cache()

        if os.path.exists(test_directory):
            print(f"Removing {test_directory}")
            shutil.rmtree(test_directory)
        os.mkdir(test_directory)
        os.mkdir(f"{test_directory}/checkpoints")

    def tearDown(self):
        shutil.rmtree(test_directory)

        # clean up
        gc.collect()
        torch.cuda.empty_cache()

    def test_model_train(self):
        # clean up
        gc.collect()
        torch.cuda.empty_cache()

        command = create_train_command(
            module="accelerate.commands.launch",
            module_args={"config_file": "src/config/test_configs/fsdp_config.yaml"},
            script="src/train.py",
            script_args={
                "from_pretrained": "facebook/galactica-125m",
                "model_config": "125m",
                "training_data_dir": ".small_data/train",
                "valid_data_dir": ".small_data/valid",
                "train_batch_size": 4,
                "max_steps": 1000,
                "eval_steps": 2000,
                "save_steps": 2000,
                "dataloader_num_workers": 1,
                "checkpoints_root_dir": f"{test_directory}/checkpoints",
                "experiment_name": "fsdp_model_train",
                "gradient_accumulation_steps": 1,
                "no_track": "",
                "flash_attn": "",
            }
        )

        print(f"Running command: {command}")
        out = subprocess.run(command, shell=True, capture_output=True)
        if out.returncode != 0:
            raise Exception(out.stderr.decode())
        else:
            print(out.stdout.decode())

    def test_model_valid(self):
        # clean up
        gc.collect()
        torch.cuda.empty_cache()

        command = create_train_command(
            module="accelerate.commands.launch",
            module_args={"config_file": "src/config/test_configs/fsdp_config.yaml"},
            script="src/train.py",
            script_args={
                "from_pretrained": "facebook/galactica-125m",
                "model_config": "125m",
                "training_data_dir": ".small_data/train",
                "valid_data_dir": ".small_data/valid",
                "train_batch_size": 4,
                "max_steps": 100,
                "eval_steps": 10,
                "save_steps": 2000,
                "dataloader_num_workers": 1,
                "checkpoints_root_dir": f"{test_directory}/checkpoints",
                "experiment_name": "fsdp_model_valid",
                "gradient_accumulation_steps": 1,
                "no_track": "",
                "flash_attn": "",
            }
        )

        print(f"Running command: {command}")
        out = subprocess.run(command, shell=True, capture_output=True)
        if out.returncode != 0:
            raise Exception(out.stderr.decode())
        else:
            print(out.stdout.decode())

    def test_model_resume(self):
        # clean up
        gc.collect()
        torch.cuda.empty_cache()

        first_command = create_train_command(
            module="accelerate.commands.launch",
            module_args={"config_file": "src/config/test_configs/fsdp_config.yaml"},
            script="src/train.py",
            script_args={
                "from_pretrained": "facebook/galactica-125m",
                "model_config": "125m",
                "training_data_dir": ".small_data/train",
                "valid_data_dir": ".small_data/valid",
                "train_batch_size": 4,
                "max_steps": 20,
                "eval_steps": 10,
                "save_steps": 10,
                "dataloader_num_workers": 1,
                "checkpoints_root_dir": f"{test_directory}/checkpoints",
                "experiment_name": "fsdp_model_resume",
                "gradient_accumulation_steps": 1,
                "no_track": "",
                "flash_attn": "",
            }
        )

        print(f"Running command: {first_command}")
        out = subprocess.run(first_command, shell=True, capture_output=True)
        if out.returncode != 0:
            raise Exception(out.stderr.decode())
        else:
            print(out.stdout.decode())

        second_command = create_train_command(
            module="accelerate.commands.launch",
            module_args={"config_file": "src/config/test_configs/fsdp_config.yaml"},
            script="src/train.py",
            script_args={
                "from_pretrained": f"{test_directory}/checkpoints/facebook/galactica-125m/none/checkpoint-{20}",
                "model_config": "125m",
                "training_data_dir": ".small_data/train",
                "valid_data_dir": ".small_data/valid",
                "train_batch_size": 4,
                "max_steps": 40,
                "eval_steps": 10,
                "save_steps": 10,
                "dataloader_num_workers": 1,
                "checkpoints_root_dir": f"{test_directory}/checkpoints",
                "experiment_name": "fsdp_model_resume",
                "gradient_accumulation_steps": 1,
                "no_track": "",
                "flash_attn": "",
            }
        )

        print(f"Running command: {second_command}")
        out = subprocess.run(second_command, shell=True, capture_output=True)
        if out.returncode != 0:
            raise Exception(out.stderr.decode())
        else:
            print(out.stdout.decode())


if __name__ == "__main__":
    unittest.main(verbosity=2)
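All three tests build their launch command through create_train_command, imported from test_utils, which is not part of this diff. Below is a minimal sketch of what such a helper could look like, assuming it returns a single shell string (the tests pass its result to subprocess.run with shell=True) and that empty-string values such as "no_track" render as bare flags; the exact formatting is an inference, not the repository's implementation.

def create_train_command(module, module_args, script, script_args):
    # Sketch only; the real helper lives in tests/fsdp/test_utils.py.
    # Builds e.g. "python3 -m accelerate.commands.launch --config_file ... src/train.py --from_pretrained ..."
    module_part = " ".join(f"--{name} {value}" for name, value in module_args.items())
    # An empty value ("no_track": "") collapses to a bare flag like "--no_track".
    script_part = " ".join(f"--{name} {value}".strip() for name, value in script_args.items())
    return f"python3 -m {module} {module_part} {script} {script_part}"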
