[benchmarks] Fix AMP setup for torchbench models. (#7067)
ysiraichi authored May 16, 2024
1 parent 9e18935 commit aeed89e
Showing 1 changed file with 27 additions and 16 deletions.
benchmarks/torchbench_model.py: 43 changes (27 additions & 16 deletions)
@@ -224,7 +224,7 @@ def _cleanup(self):
     gc.collect()
 
     # If we are using CUDA, clean-up its cache left-over.
-    if self.benchmark_experiment.accelerator == "cuda":
+    if self.is_accelerator_cuda():
       torch.cuda.empty_cache()
 
   def set_up(self):
@@ -253,7 +253,7 @@ def set_up(self):
     if self.benchmark_experiment.xla:
       # First, move the model and the inputs to CPU.
       # This avoids having duplicated data on CUDA.
-      if self.benchmark_experiment.accelerator == "cuda":
+      if self.is_accelerator_cuda():
         self.module = self.module.to("cpu")
         self.example_inputs = move_to_device(self.example_inputs, "cpu")
         self._cleanup()
@@ -305,8 +305,9 @@ def load_benchmark(self):
     # torch.backends.__allow_nonbracketed_mutation_flag = True
 
     # torchbench uses `xla` as device instead of `tpu`
-    if (device := self.benchmark_experiment.accelerator) == 'tpu':
-      device = str(self.benchmark_experiment.get_device())
+    device = (
+        str(self.benchmark_experiment.get_device())
+        if self.is_accelerator_tpu() else self.benchmark_experiment.accelerator)
 
     return self.benchmark_cls()(
         test=self.benchmark_experiment.test,
@@ -330,6 +331,12 @@ def is_inference(self):
   def is_training(self):
     return self.benchmark_experiment.test == "train"
 
+  def is_accelerator_cuda(self):
+    return self.benchmark_experiment.accelerator == "cuda"
+
+  def is_accelerator_tpu(self):
+    return self.benchmark_experiment.accelerator == "tpu"
+
   def use_amp(self):
     return self.is_training(
     ) or self.model_name in FORCE_AMP_FOR_FP16_BF16_MODELS
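
For reference, the reworked `_get_autocast_with_kwargs` in the following hunk resolves to one of two AMP configurations. The snippet below is an illustrative sketch of the equivalent context-manager calls, not code from this commit; the XLA variant assumes a working PyTorch/XLA installation.

import torch

# CUDA (inductor and XLA:CUDA experiments): AMP runs in float16.
with torch.cuda.amp.autocast(dtype=torch.float16):
    pass  # the model's forward/backward would run here

# TPU (PyTorch/XLA): AMP runs in bfloat16 on the "xla" device type.
# Assumes torch_xla is installed so that the "xla" device type is registered.
with torch.amp.autocast(device_type="xla", dtype=torch.bfloat16):
    pass  # the model's forward/backward would run here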
@@ -350,23 +357,27 @@ def conversion_dtype(self):
 
   def _get_autocast_with_kwargs(self):
     kwargs = {}
-
     if self.use_amp():
-      # Set the default data-type based on the accelerator.
-      if self.benchmark_experiment.accelerator == "cuda":
+      # TODO: Should call device specific autocast implementations.
+      # Specifically, we should be using:
+      # - torch.cuda.amp.autocast for inductor
+      # - torch_xla.amp.autocast for PyTorch/XLA experiments.
+      # PyTorch/XLA autocast does not run with dynamo, though:
+      # https://github.com/pytorch/xla/issues/6511
+      if self.is_accelerator_cuda():
+        # For inductor and XLA:CUDA, we use CUDA autocast.
+        autocast = torch.cuda.amp.autocast
         kwargs["dtype"] = torch.float16
-      else:
-        # Both CPU and TPU autocast mode defaults to bfloat16.
-        kwargs["dtype"] = torch.bfloat16
-
-      if self.benchmark_experiment.xla:
-        # Should call device specific autocast implementations.
-        # PyTorch/XLA autocast does not run with dynamo, though:
-        # https://github.com/pytorch/xla/issues/6511
+      elif self.is_accelerator_tpu():
         autocast = torch.amp.autocast
         kwargs["device_type"] = "xla"
+        kwargs["dtype"] = torch.bfloat16
       else:
-        autocast = torch.cuda.amp.autocast
+        # Error: AMP is only supported on XLA:CUDA and XLA:TPU.
+        name = self.model_name
+        accelerator = self.benchmark_experiment.accelerator
+        raise RuntimeError(f"Tried to run {name} with AMP on {accelerator}. "
+                           "However, AMP is only supported on cuda and tpu.")
     else:
       autocast = contextlib.nullcontext
     return (autocast, kwargs)
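The call site that consumes the (autocast, kwargs) pair returned by `_get_autocast_with_kwargs` is outside this diff. A minimal sketch of the presumed usage, with the runner-side names assumed rather than taken from the repository:

# Hypothetical call site (variable names assumed for illustration).
autocast, autocast_kwargs = model._get_autocast_with_kwargs()
with autocast(**autocast_kwargs):
    output = model.module(*model.example_inputs)

Since kwargs stays empty when AMP is disabled, the same call also covers the contextlib.nullcontext fallback.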
