# train.yaml
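#
# One way to launch (the cluster name and the key value are placeholders):
#   sky launch -c vicuna train.yaml --env WANDB_API_KEY=<your-wandb-key>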
envs:
  MODEL_SIZE: 7
  SEQ_LEN: 2048
  GC_SCALE: 4
  USE_FLASH_ATTN: 0
  WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass.
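
# How the variables above are used further down in this file:
#   MODEL_SIZE     - selects the base model (huggyllama/llama-${MODEL_SIZE}b) and the checkpoint path.
#   SEQ_LEN        - passed as --model_max_length; also used to size the per-GPU batch.
#   GC_SCALE       - scales the per-GPU batch size up (gradient checkpointing is enabled below, which frees activation memory).
#   USE_FLASH_ATTN - 1 selects FastChat's flash-attention script (train_mem.py) instead of train.py.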
resources:
  accelerators: A100-80GB:8
  disk_size: 1000
  use_spot: true
  disk_tier: best

num_nodes: 1
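
# use_spot requests (preemptible) spot instances; the run section below saves and
# restores checkpoints through the /artifacts bucket mount so a preempted job can
# pick up where it left off. /data/mydata.json maps the local dummy.json into the
# VM as the training data.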
file_mounts:
  /artifacts:
    name: YOUR_OWN_BUCKET_NAME # Change to your own bucket name
    mode: MOUNT
  /data/mydata.json: ./dummy.json
workdir: .

setup: |
  # Set up the conda environment
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi
  # Install pytorch
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
  # Install huggingface transformers, pinned to the LLaMA commit
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout 41a2f3529c6b56866c317031375ffd3e7b8bea01
  pip install .
  cd -
  # Install fastchat
  git clone https://github.com/lm-sys/FastChat.git
  cd FastChat
  pip install -e .
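  # flash-attn is only needed when USE_FLASH_ATTN=1, which makes the run section
  # below use fastchat/train/train_mem.py instead of fastchat/train/train.py.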
  if [ $USE_FLASH_ATTN -eq 1 ]; then
    pip install flash-attn
  fi

run: |
  cd FastChat
  conda activate chatbot
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
  if [ $USE_FLASH_ATTN -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
    USE_FLASH_SUFFIX="-flash"
  else
    TRAIN_SCRIPT=fastchat/train/train.py
    USE_FLASH_SUFFIX=""
  fi
  PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
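  # With the defaults above this is 2048 * 4 / 2048 = 4 sequences per GPU.
  # SKYPILOT_NODE_IPS is provided by SkyPilot with one node IP per line; the first
  # line is used below as the master address for torchrun.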
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
  # Write checkpoints to local disk and sync them to the bucket in the background,
  # so that saving checkpoints does not slow down training.
  mkdir -p ~/.checkpoints
  LOCAL_CKPT_PATH=~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/skypilot-chatbot
  mkdir -p $CKPT_PATH
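  # If the bucket already holds checkpoints (e.g. from a run that was preempted),
  # pull the latest one down locally so training can resume from it instead of
  # starting from scratch.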
  last_ckpt=$(ls ${CKPT_PATH} | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
  if [ -n "$last_ckpt" ]; then
    mkdir -p ~/.checkpoints/${last_ckpt}
    gsutil -m rsync -r ${CKPT_PATH}/${last_ckpt}/ ~/.checkpoints/${last_ckpt}
  fi
  bash ../scripts/sync_local_checkpoint.sh ${LOCAL_CKPT_PATH} ${CKPT_PATH} > sync.log 2>&1 &
  # Turn off wandb if no API key is provided
  if [ -z "$WANDB_API_KEY" ]; then
    export WANDB_MODE="offline"
  fi
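  # torchrun starts one process per GPU on each node. gradient_accumulation_steps
  # below is chosen so that every optimizer step covers 128 * 512 = 65,536 tokens
  # in total, regardless of SEQ_LEN, the per-GPU batch size, and the cluster shape
  # (with the defaults here: 65536 / 2048 / 4 / 1 / 8 = 1).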
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path huggyllama/llama-${MODEL_SIZE}b \
    --data_path /data/mydata.json \
    --bf16 True \
    --output_dir $LOCAL_CKPT_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 3 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name $SKYPILOT_TASK_ID \
    --gradient_checkpointing True \
    --lazy_preprocess True
  returncode=$?
  # Sync any files not in the checkpoint-* folders
  gsutil -m rsync -r -x 'checkpoint-*' $LOCAL_CKPT_PATH/ $CKPT_PATH/
  exit $returncode