Commit

Load checkpoint using S3DPReader
rajdchak committed Sep 26, 2024
1 parent 21d822d commit 13cca76
Showing 1 changed file with 64 additions and 18 deletions.
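In summary, the loading path this commit adds boils down to constructing an S3DPReader and passing it to DCP.load as the storage_reader, then restoring the model from the populated state dict. A minimal single-process sketch of that flow (the bucket, prefix, region, and the small nn.Linear module are placeholders, and it assumes a checkpoint with a matching "model" entry already exists at that prefix):

import torch.distributed.checkpoint as DCP
from torch import nn

from s3torchconnector import S3DPReader

# Placeholder bucket, prefix, and region for illustration only.
reader = S3DPReader(region="eu-west-2", path="s3://my-bucket/epoch_2/")

# DCP.load fills the state dict in place, so it must already contain the
# structure being loaded; a tiny stand-in module keeps the sketch self-contained.
model = nn.Linear(16, 8)
state_dict = {"model": model.state_dict()}

DCP.load(state_dict=state_dict, storage_reader=reader)
model.load_state_dict(state_dict["model"])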
82 changes: 64 additions & 18 deletions s3torchconnector/src/s3torchconnector/dcp/toy_model.py
@@ -2,17 +2,13 @@
import torch.distributed.checkpoint as DCP
from torch import nn
import argparse

import torch.distributed as dist

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

from s3torchconnector import S3StorageWriter, S3StorageReader, S3DPWriter
from s3torchconnector import S3StorageWriter, S3StorageReader, S3DPWriter, S3DPReader

CHECKPOINT_DIR = "checkpoint"


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
@@ -56,15 +52,7 @@ def run_fsdp_checkpoint_save_example(rank, backend):
model(torch.rand(8, 16, device=torch.device("cpu"))).sum().backward()
optimizer.step()

loaded_state_dict = {}
# DCP.load(
# loaded_state_dict,
# # storage_reader=DCP.FileSystemReader(CHECKPOINT_DIR)
# storage_reader=S3StorageReader(region="eu-north-1", s3_uri="s3://dcp-poc-test/", thread_count=world_size)
# )

# set FSDP StateDictType to SHARDED_STATE_DICT so we can use DCP to checkpoint sharded model state dict
# note that we do not support FSDP StateDictType.LOCAL_STATE_DICT
# Set FSDP StateDictType to SHARDED_STATE_DICT to checkpoint the model state dict
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
@@ -74,25 +62,74 @@ def run_fsdp_checkpoint_save_example(rank, backend):
}

thread_count = 4
bucket = "dcp-poc-test-2"
path = f"s3://{bucket}/epoech_1/"
bucket = "dcp-poc-test-3"
path = f"s3://{bucket}/epoch_1/"
region = "eu-west-2"
# writer_to_use = "local"
writer_to_use = "s3_fs"
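# Choose which storage writer get_writer returns: "local" keeps the checkpoint on local disk, the S3 options write it to the bucket configured above.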
writer = get_writer(region, path, thread_count, writer_to_use)

DCP.save(state_dict=state_dict, storage_writer=writer)
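# With SHARDED_STATE_DICT, each rank writes its own shards of the model state dict to the S3 prefix.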

print("Checkpoint saved for epoch 1.")

# Save for another epoch
state_dict = {
"model": model.state_dict(),
"prefix": "bla",
}
optimizer.step()

path = f"s3://{bucket}/epoech_2/"
path = f"s3://{bucket}/epoch_2/"
writer = get_writer(region, path, thread_count, writer_to_use)
DCP.save(state_dict=state_dict, storage_writer=writer)

print("Checkpoint saved for epoch 2.")

return state_dict # Return the state dict for verification

def run_fsdp_checkpoint_load_example(rank, backend, state_dict):
print(f"Running basic FSDP checkpoint loading example on rank {rank}.")

# Need to put the model on a GPU device for the nccl backend
if backend == "nccl":
device_id = rank % torch.cuda.device_count()
model = ToyModel().to(device_id)
model = FSDP(model, device_id=device_id)
elif backend == "gloo":
model = ToyModel().to(device=torch.device("cpu"))
model = FSDP(model, device_id=torch.cpu.current_device())
else:
raise Exception(f"Unknown backend type: {backend}")

# Set FSDP StateDictType to SHARDED_STATE_DICT to load the sharded model state dict
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
)

# Prepare a state_dict to load into. DCP.load populates it in place, so it must
# already contain the sharded model state dict structure to be filled.
loaded_state_dict = {"model": model.state_dict()}

thread_count = 1
bucket = "dcp-poc-test-3"
path = f"s3://{bucket}/epoch_2/"
region = "eu-west-2"
reader_to_use = "s3_fs"
reader = get_reader(region, path, thread_count, reader_to_use)
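# This reader targets the epoch_2 checkpoint written by the save example above.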

# Load the checkpoint
DCP.load(state_dict=loaded_state_dict, storage_reader=reader)

# Load the model state dict from the checkpoint
model.load_state_dict(loaded_state_dict["model"])
print("Checkpoint loaded and model state dict restored.")

# Verify the saved and loaded state dicts match. Both are mappings of ShardedTensors
# (SHARDED_STATE_DICT), so compare each entry's local shard rather than passing the
# dicts themselves to torch.allclose.
models_match = all(
torch.allclose(state_dict["model"][key].local_tensor(), loaded_state_dict["model"][key].local_tensor())
for key in loaded_state_dict["model"]
)
if models_match:
print("The saved and loaded model state dicts match on this rank.")
else:
print("The saved and loaded model state dicts differ on this rank.")

def get_writer(region, path, thread_count, writer_to_use):
if writer_to_use == "local":
@@ -103,6 +140,14 @@ def get_writer(region, path, thread_count, writer_to_use):
writer = S3StorageWriter(region=region, s3_uri=path, thread_count=thread_count)
return writer

def get_reader(region, path, thread_count, reader_to_use):
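# "local" reads from CHECKPOINT_DIR on disk, "s3_fs" uses the new S3DPReader, and anything else falls back to S3StorageReader.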
if reader_to_use == "local":
reader = DCP.FileSystemReader(CHECKPOINT_DIR)
elif reader_to_use == "s3_fs":
reader = S3DPReader(region=region, path=path)
else:
reader = S3StorageReader(region=region, s3_uri=path, thread_count=thread_count)
return reader

if __name__ == "__main__":
"""
@@ -152,6 +197,7 @@ def get_writer(region, path, thread_count, writer_to_use):
world_size = dist.get_world_size()
print(f"Starting for rank {rank}, world_size is {world_size}")

run_fsdp_checkpoint_save_example(rank, args.backend)
state_dict = run_fsdp_checkpoint_save_example(rank, args.backend)
run_fsdp_checkpoint_load_example(rank, args.backend, state_dict)

cleanup()
