Automatic limiting of local batchsize bounds after OOM #90

Open
wants to merge 15 commits into base: master
64 changes: 64 additions & 0 deletions adaptdl/adaptdl/retry.py
@@ -0,0 +1,64 @@
# Copyright 2020 Petuum, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import logging

import adaptdl.checkpoint
import adaptdl.env
from adaptdl.torch._metrics import _report_sched_hints
from adaptdl.torch.data import current_dataloader

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)


# Fraction by which current_local_bsz is reduced to obtain the new upper
# bound on local_bsz_bounds after an OOM
LOCAL_BSZ_CUTOFF_PCT = 0.1


def cudaoom(e):
    return "CUDA out of memory" in str(e)


def retry(func):
    @functools.wraps(func)
    def inner(*args, **kwargs):
        try:
            func(*args, **kwargs)
        except RuntimeError as e:
            LOG.info(f"{e}")
            dataloader = current_dataloader()
            if (dataloader is not None and
                    dataloader.local_bsz_bounds is not None and
                    cudaoom(e)):
                current_local_bsz = dataloader.current_local_bsz
                low, high = dataloader.local_bsz_bounds
                assert current_local_bsz <= high
                new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz)
                if new_high < low:
                    raise e
                dataloader.local_bsz_bounds = (low, new_high)
                LOG.info(f"Local batch size bounds changed to "
                         f"{dataloader.local_bsz_bounds}")
                if adaptdl.env.replica_rank() == 0:
                    _report_sched_hints()
                adaptdl.checkpoint.save_all_states()
                exit(143)
            else:
                raise e
    return inner
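
In plain terms, retry() catches a CUDA OOM raised by the wrapped training function, caps the upper local batch size bound at 90% of the batch size that just failed, reports updated scheduling hints from replica rank 0, checkpoints all states, and exits with code 143 so the replica can be restarted under the tightened bounds. Here is the bound-tightening arithmetic in isolation (an illustrative sketch; tighten_bounds and the numbers below are made up and not part of the diff):

LOCAL_BSZ_CUTOFF_PCT = 0.1

def tighten_bounds(current_local_bsz, bounds):
    # Return the new (low, high) bounds after an OOM, or None to give up.
    low, high = bounds
    new_high = int((1. - LOCAL_BSZ_CUTOFF_PCT) * current_local_bsz)
    if new_high < low:
        return None  # no room left to shrink; retry() re-raises the OOM
    return (low, new_high)

# An OOM at local batch size 128 with bounds (32, 256) caps the upper
# bound at int(0.9 * 128) = 115.
assert tighten_bounds(128, (32, 256)) == (32, 115)
# Repeated OOMs eventually push the cap below the lower bound, at which
# point the error propagates instead: int(0.9 * 34) = 30 < 32.
assert tighten_bounds(34, (32, 256)) is None
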
38 changes: 25 additions & 13 deletions adaptdl/adaptdl/torch/data.py
@@ -142,7 +142,6 @@ class AdaptiveDataLoaderHelper(object):
     def __init__(self, batch_size=1):
         # Autoscale batch size fields.
         self._max_batch_size = None
-        self._local_bsz_bounds = None
         # Create and load state.
         self._state = _AdaptiveDataLoaderState()
         adaptdl.checkpoint.load_state(self._state)
@@ -198,7 +197,11 @@ def local_bsz_bounds(self):
         The local batch size bounds on each replica. A pair of integers,
         (min_local_bsz, max_local_bsz).
         """
-        return self._local_bsz_bounds
+        return self._state.local_bsz_bounds
+
+    @local_bsz_bounds.setter
+    def local_bsz_bounds(self, bounds):
+        self._state.local_bsz_bounds = bounds
 
     @property
     def current_local_bsz(self):
@@ -263,7 +266,8 @@ def autoscale_batch_size(self, max_batch_size, local_bsz_bounds=None,
                 local_bsz_bounds[1] < self.batch_size):
             raise ValueError("invalid local_bsz_bounds")
         self._max_batch_size = max_batch_size
-        self._local_bsz_bounds = local_bsz_bounds
+        if self.local_bsz_bounds is None:
+            self.local_bsz_bounds = local_bsz_bounds
         self._gradient_accumulation = gradient_accumulation
         self.train()
 
@@ -279,7 +283,7 @@ def _sync_local_bsz(self):
             _, atomic_bsz, accum_steps = goodput_fn.optimize(
                 adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
                 max_batch_size=self._max_batch_size,
-                atomic_bsz_range=self._local_bsz_bounds,
+                atomic_bsz_range=self.local_bsz_bounds,
                 accumulation=self._gradient_accumulation)
             self._state.current_local_bsz = atomic_bsz
             self._state.accumulation_steps = accum_steps
@@ -288,7 +292,7 @@ def _sync_local_bsz(self):
             suggest_goodput, atomic_bsz, accum_steps = goodput_fn.optimize(
                 adaptdl.env.num_nodes(), adaptdl.env.num_replicas(),
                 max_batch_size=self._max_batch_size,
-                atomic_bsz_range=self._local_bsz_bounds,
+                atomic_bsz_range=self.local_bsz_bounds,
                 accumulation=self._gradient_accumulation)
             # get current goodput
             current_goodput = goodput_fn(
@@ -299,6 +303,7 @@ def _sync_local_bsz(self):
             if speedup > self._speedup_threshold:
                 self._state.current_local_bsz = atomic_bsz
                 self._state.accumulation_steps = accum_steps
+
         self._state.current_local_bsz, self._state.accumulation_steps = \
             adaptdl.collective.broadcast((self._state.current_local_bsz,
                                           self._state.accumulation_steps))
@@ -340,18 +345,23 @@ def context(self):
         proper cleanup of elastic context at the end of each epoch.
         """
         epoch = current_epoch()
+        exception = False
         try:
             if AdaptiveDataLoaderHelper._current is not None:
                 raise RuntimeError("overlapping dataloader \
                     iterations detected")
             AdaptiveDataLoaderHelper._current = self
             yield
+        except GeneratorExit:
+            # Generic Exception outside of the dataloader
+            exception = True
         finally:
-            self._state.current_index = 0
-            self._state.end_index = 0
-            self._state.last_position[epoch] = self._position[epoch]
-            self._position[epoch] += 1
-            AdaptiveDataLoaderHelper._current = None
+            if not exception:
+                self._state.current_index = 0
+                self._state.end_index = 0
+                self._state.last_position[epoch] = self._position[epoch]
+                self._position[epoch] += 1
+                AdaptiveDataLoaderHelper._current = None
 
     @property
     def current_batch_size(self):
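
The context() change above relies on how Python generators shut down: when the training function raises (or retry() exits the process) in the middle of an epoch, the dataloader's __iter__ generator is closed, which raises GeneratorExit at its yield point inside this context manager; in that case the loop-progress fields are deliberately left untouched so the restarted replica can resume where it left off. A small self-contained illustration of that mechanism (not adaptdl code):

from contextlib import contextmanager

state = {"current_index": 42}

@contextmanager
def loop_context():
    exception = False
    try:
        yield
    except GeneratorExit:
        # Raised when the surrounding generator is closed early, e.g.
        # because the training step raised an OOM mid-epoch.
        exception = True
    finally:
        if not exception:
            state["current_index"] = 0  # reset only after a clean finish

def batches():
    with loop_context():
        for i in range(10):
            yield i

it = batches()
next(it)    # take one batch, then abandon the loop
it.close()  # what CPython does when the abandoned iterator is collected
assert state["current_index"] == 42   # progress preserved for resumption

for _ in batches():
    pass
assert state["current_index"] == 0    # reset after a clean, complete pass
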
@@ -490,6 +500,7 @@ class AdaptiveDataLoader(DataLoader, AdaptiveDataLoaderMixin):
 
     .. automethod:: __iter__
     """
+
     def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
         if kwargs.get("batch_sampler") is not None \
                 or kwargs.get("sampler") is not None:
@@ -563,13 +574,14 @@ def __init__(self):
         self.current_index = 0  # Index within the current dataloader loop.
         self.end_index = 0  # End index of the current DataLoader loop.
         self.last_position = {}  # Epoch -> position of last completed loop.
+        self.local_bsz_bounds = None
         self.current_local_bsz = 0
         self.accumulation_steps = 0
 
     def save(self, fileobj):
         pickle.dump((self.current_index, self.end_index,
-                     self.last_position), fileobj)
+                     self.last_position, self.local_bsz_bounds), fileobj)
 
     def load(self, fileobj):
-        self.current_index, self.end_index, self.last_position = \
-            pickle.load(fileobj)
+        self.current_index, self.end_index, self.last_position, \
+            self.local_bsz_bounds = pickle.load(fileobj)
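
Because local_bsz_bounds now lives in the checkpointed _AdaptiveDataLoaderState, a bound tightened by retry() survives the exit(143)/restart cycle, and since autoscale_batch_size only assigns the user-supplied bounds when none are set yet, the restart does not overwrite the tightened value. A minimal round-trip sketch of the pickled state (a simplified stand-in class, not the real _AdaptiveDataLoaderState):

import io
import pickle

class StateSketch:
    def __init__(self):
        self.current_index = 0
        self.end_index = 0
        self.last_position = {}
        self.local_bsz_bounds = None

    def save(self, fileobj):
        pickle.dump((self.current_index, self.end_index,
                     self.last_position, self.local_bsz_bounds), fileobj)

    def load(self, fileobj):
        self.current_index, self.end_index, self.last_position, \
            self.local_bsz_bounds = pickle.load(fileobj)

state = StateSketch()
state.local_bsz_bounds = (32, 115)  # e.g. tightened by retry() after an OOM
buf = io.BytesIO()
state.save(buf)

buf.seek(0)
restored = StateSketch()
restored.load(buf)
assert restored.local_bsz_bounds == (32, 115)
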
2 changes: 2 additions & 0 deletions examples/pytorch-cifar/main.py
@@ -30,6 +30,7 @@
 
 import adaptdl
 import adaptdl.torch as adl
+from adaptdl.retry import retry
 
 from torch.optim.lr_scheduler import MultiStepLR
 from torch.utils.tensorboard import SummaryWriter
@@ -99,6 +100,7 @@
 
 
 # Training
+@retry
 def train(epoch):
     print('\nEpoch: %d' % epoch)
     net.train()
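
For context, the decorator only engages when the dataloader was given explicit local_bsz_bounds through autoscale_batch_size (otherwise retry() re-raises the error). A hedged end-to-end sketch of the wiring, modeled on the pytorch-cifar example; the toy dataset, model, and all numeric values are placeholders, and the single-process "gloo" setup is an assumption:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import adaptdl.torch as adl
from adaptdl.retry import retry

adl.init_process_group("gloo")  # assumes a single local replica for the sketch
dataset = TensorDataset(torch.randn(512, 8), torch.randint(0, 2, (512,)))
model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model = adl.AdaptiveDataParallel(model, optimizer)
criterion = nn.CrossEntropyLoss()

loader = adl.AdaptiveDataLoader(dataset, batch_size=64, shuffle=True)
loader.autoscale_batch_size(1024, local_bsz_bounds=(16, 256))

@retry
def train(epoch):
    model.train()
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

for epoch in adl.remaining_epochs_until(3):
    train(epoch)
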
7 changes: 6 additions & 1 deletion sched/adaptdl_sched/controller.py
@@ -114,8 +114,10 @@ async def _sync_job(self, namespace, job_name):
         replicas = job["status"].get("replicas", 0)
         preemptible = job["spec"].get("preemptible", True)
         if (completion_status := self._detect_completion(pods, preemptible)):
-            # Job is already completed.
             job["status"].update(completion_status)
+        phase = job["status"]["phase"]
+        if phase in ("Succeeded", "Failed"):
+            # Job is already completed.
             job["status"].setdefault("completionTimestamp", current_ts)
             job["status"]["allocation"] = allocation = []
             await self._delete_pods(  # Keep failed pods for debug purposes.
@@ -294,14 +296,17 @@ def any143(pod):
                 # resources before this pod could bind to that node.
                 LOG.warning("UnexpectedAdmissionError for pod %s: %s",
                             pod.metadata.name, pod.status.message)
+                return {"phase": "Stopping"}
             elif str(pod.status.reason).startswith("Outof"):
                 # we might be temporarily out of pods on this node
                 LOG.warning(f"Pod {pod.metadata.name} is {pod.status.reason} "
                             f"on {pod.spec.node_name}")
+                return {"phase": "Stopping"}
             elif preemptible and (pod.metadata.deletion_timestamp is not None
                                   or any143(pod)):
                 # This pod was intentionally terminated.
                 LOG.warning(f"Pod {pod.metadata.name} terminated")
+                return {"phase": "Stopping"}
             else:
                 return {"phase": "Failed", "reason": "PodFailure",
                         "message": f"{pod.metadata.name} {pod.status.phase}"}
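
The preemptible branch above keys off any143(pod), which is defined earlier in controller.py but not shown in this excerpt; the idea is that exit code 143 (the exit(143) issued by retry(), i.e. 128 + SIGTERM) marks an intentional, checkpointed termination rather than a genuine failure. A plausible shape of such a check against the Kubernetes Python client's pod objects (an assumption, not the PR's exact code):

def any143(pod):
    # True if any container in the pod terminated with exit code 143.
    statuses = pod.status.container_statuses or []
    return any(s.state and s.state.terminated and
               s.state.terminated.exit_code == 143
               for s in statuses)
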