remove unused weight swapping functions from utils.py
kohya-ss committed Nov 5, 2024
1 parent 81c0c96 commit aab943c
Showing 1 changed file with 0 additions and 185 deletions.
library/utils.py: 185 changes (0 additions & 185 deletions)
@@ -94,26 +94,6 @@ def setup_logging(args=None, log_level=None, reset=False):

# region PyTorch utils

# def swap_weights(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
#             # print(f"Swapping {layer_to_cpu.__class__.__name__}-{module_to_cpu.__class__.__name__}.")
#             # cpu_tensor = module_to_cuda.weight.data
#             # cuda_tensor = module_to_cpu.weight.data
#             # assert cuda_tensor.device.type == "cuda"
#             # temp_cpu_tensor = cuda_tensor.to("cpu", non_blocking=True)
#             # torch.cuda.current_stream().synchronize()
#             # cuda_tensor.copy_(cpu_tensor, non_blocking=True)
#             # torch.cuda.current_stream().synchronize()
#             # cpu_tensor.copy_(temp_cpu_tensor, non_blocking=True)
#             # module_to_cpu.weight.data, module_to_cuda.weight.data = cpu_tensor, cuda_tensor
#             cuda_tensor_view = module_to_cpu.weight.data
#             cpu_tensor_view = module_to_cuda.weight.data
#             module_to_cpu.weight.data = module_to_cpu.weight.to("cpu", non_blocking=True).detach().clone()
#             module_to_cuda.weight.data = cuda_tensor_view
#             module_to_cuda.weight.data.copy_(cpu_tensor_view)


def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__
@@ -143,171 +123,6 @@ def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_2st(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    stream_to_cpu = torch.cuda.Stream()
    stream_to_cuda = torch.cuda.Stream()

    events = []
    with torch.cuda.stream(stream_to_cpu):
        # cuda to offload
        offloaded_weights = []
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))
            event = torch.cuda.Event()
            event.record(stream=stream_to_cpu)
            events.append(event)

    with torch.cuda.stream(stream_to_cuda):
        # cpu to cuda
        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event in zip(weight_swap_jobs, events):
            event.synchronize()
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

        # offload to cpu
        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), offloaded_weight in zip(
            weight_swap_jobs, offloaded_weights
        ):
            module_to_cpu.weight.data = offloaded_weight

    stream_to_cuda.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_failed(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    stream_to_cpu = torch.cuda.Stream()
    stream_to_cuda = torch.cuda.Stream()

    # cuda to offload
    events = []
    with torch.cuda.stream(stream_to_cpu):
        offloaded_weights = []
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view in weight_swap_jobs:
            cuda_data_view.record_stream(stream_to_cpu)
            offloaded_weights.append(cuda_data_view.to("cpu", non_blocking=True))

            event = torch.cuda.Event()
            event.record(stream=stream_to_cpu)
            events.append(event)

    # cpu to cuda
    with torch.cuda.stream(stream_to_cuda):
        for (module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view), event, offloaded_weight in zip(
            weight_swap_jobs, events, offloaded_weights
        ):
            event.synchronize()
            cuda_data_view.record_stream(stream_to_cuda)
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

            module_to_cpu.weight.data = offloaded_weight

    stream_to_cuda.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value
    # torch.cuda.current_stream().wait_stream(stream_to_cuda)
    # for job in weight_swap_jobs:
    #     job[2].record_stream(torch.cuda.current_stream())  # record the ownership of the tensor


def swap_weight_devices_works_2(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            if not (hasattr(module_to_cpu, "offloaded_weight") or hasattr(module_to_cuda, "offloaded_weight")):
                # one of the modules must have the tensor to offload
                module_to_cpu.offloaded_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
                module_to_cpu.offloaded_weight.pin_memory()
            offloaded_weight = (
                module_to_cpu.offloaded_weight if hasattr(module_to_cpu, "offloaded_weight") else module_to_cuda.offloaded_weight
            )
            assert module_to_cpu.weight.device.type == "cuda" and module_to_cuda.weight.device.type == "cpu"
            weight_swap_jobs.append(
                (module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data, offloaded_weight)
            )

    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        # cuda to offload
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            cuda_data_view.record_stream(stream)
            offloaded_weight.copy_(module_to_cpu.weight.data, non_blocking=True)

        stream.synchronize()

        # cpu to cuda
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            cuda_data_view.copy_(module_to_cuda.weight.data, non_blocking=True)
            module_to_cuda.weight.data = cuda_data_view

        # offload to cpu
        for module_to_cpu, module_to_cuda, cuda_data_view, cpu_data_view, offloaded_weight in weight_swap_jobs:
            module_to_cpu.weight.data = offloaded_weight
            offloaded_weight = cpu_data_view
            module_to_cpu.offloaded_weight = offloaded_weight
            module_to_cuda.offloaded_weight = offloaded_weight

        stream.synchronize()

    torch.cuda.current_stream().synchronize()  # this prevents the illegal loss value


def swap_weight_devices_safe_works(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
    assert layer_to_cpu.__class__ == layer_to_cuda.__class__

    weight_swap_jobs = []
    for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
        if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
            if not (hasattr(module_to_cpu, "__cached_cpu_weight") or hasattr(module_to_cuda, "__cached_cuda_weight")):
                # one of the modules must have the tensor to cache
                module_to_cpu.__cached_cpu_weight = torch.zeros_like(module_to_cpu.weight.data, device="cpu")
                module_to_cpu.__cached_cpu_weight.pin_memory()

            weight_swap_jobs.append((module_to_cpu, module_to_cuda, module_to_cpu.weight.data, module_to_cuda.weight.data))

    for module_to_cpu, module_to_cuda, cuda_tensor_view, cpu_tensor_view in weight_swap_jobs:
        module_to_cpu.weight.data = cuda_tensor_view.to("cpu", non_blocking=True)
        module_to_cuda.weight.data = cpu_tensor_view.to("cuda", non_blocking=True)

    torch.cuda.current_stream().synchronize()  # wait for the copy from cache to cpu to finish
    torch.cuda.empty_cache()


# def swap_weight_devices(layer_to_cpu: nn.Module, layer_to_cuda: nn.Module):
#     assert layer_to_cpu.__class__ == layer_to_cuda.__class__
#     for module_to_cpu, module_to_cuda in zip(layer_to_cpu.modules(), layer_to_cuda.modules()):
#         if hasattr(module_to_cpu, "weight") and module_to_cpu.weight is not None:
#             assert module_to_cuda.weight.device.type == "cpu" and module_to_cpu.weight.device.type == "cuda"
#             weight_on_cuda = module_to_cpu.weight
#             weight_on_cpu = module_to_cuda.weight
#             cuda_to_cpu_data = weight_on_cuda.data.to("cpu", non_blocking=True)
#             event = torch.cuda.current_stream().record_event()
#             event.synchronize()
#             weight_on_cuda.data.copy_(weight_on_cpu.data, non_blocking=True)
#             weight_on_cpu.data = cuda_to_cpu_data
#             weight_on_cpu.grad, weight_on_cuda.grad = weight_on_cuda.grad, weight_on_cpu.grad

#             module_to_cpu.weight = weight_on_cpu
#             module_to_cuda.weight = weight_on_cuda


def weighs_to_device(layer: nn.Module, device: torch.device):
    for module in layer.modules():
        if hasattr(module, "weight") and module.weight is not None:
