feat: enable (and expose) torch autocast and inference_mode
percevalw committed Jun 25, 2024
1 parent b73336a commit 9d1a640
Showing 4 changed files with 51 additions and 10 deletions.
4 changes: 4 additions & 0 deletions changelog.md
@@ -4,6 +4,10 @@

### Added

- `data.set_processing(...)` now exposes an `autocast` parameter to disable or tweak the automatic casting of tensors
  during processing. Autocasting should result in a slight speedup but may lead to numerical instability.
- Use `torch.inference_mode` to disable view tracking and version counter bumps during inference.

### Changed

### Fixed
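For context, a minimal usage sketch of the new `autocast` parameter described above; the `edsnlp.blank`, `edsnlp.data.from_iterable` and `map_pipeline` calls are assumptions about the library's data API and are not part of this diff:

```python
import torch
import edsnlp

nlp = edsnlp.blank("eds")  # hypothetical pipeline; add your own torch-based pipes here

docs = edsnlp.data.from_iterable(["Patient admis pour une pneumopathie."])
docs = docs.map_pipeline(nlp)
docs = docs.set_processing(
    backend="simple",
    autocast=torch.float16,  # True (default), False to disable, or an explicit dtype
)
results = list(docs)
```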
22 changes: 17 additions & 5 deletions edsnlp/core/lazy_collection.py
@@ -131,6 +131,10 @@ def gpu_worker_devices(self):
def cpu_worker_devices(self):
return self.config.get("cpu_worker_devices")

@property
def autocast(self):
return self.config.get("autocast")

@property
def backend(self):
backend = self.config.get("backend")
@@ -156,8 +160,9 @@ def set_processing(
num_gpu_workers: Optional[int] = INFER,
disable_implicit_parallelism: bool = True,
backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
gpu_pipe_names: Optional[List[str]] = INFER,
autocast: Union[bool, Any] = True,
show_progress: bool = False,
gpu_pipe_names: Optional[List[str]] = INFER,
process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
gpu_worker_devices: Optional[List[str]] = INFER,
cpu_worker_devices: Optional[List[str]] = INFER,
@@ -203,10 +208,6 @@ def set_processing(
disable_implicit_parallelism: bool
Whether to disable OpenMP and Huggingface tokenizers implicit parallelism in
multiprocessing mode. Defaults to True.
gpu_pipe_names: Optional[List[str]]
List of pipe names to accelerate on a GPUWorker, defaults to all pipes
that inherit from TorchComponent. Only used with "multiprocessing" backend.
Inferred from the pipeline if not set.
backend: Optional[Literal["simple", "multiprocessing", "spark"]]
The backend to use for parallel processing. If not set, the backend is
automatically selected based on the input data and the number of workers.
@@ -217,9 +218,20 @@
`num_gpu_workers` is greater than 0.
- "spark" is used when the input data is a Spark dataframe and the output
writer is a Spark writer.
autocast: Union[bool, Any]
Whether to use
[automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html)
for the forward pass of the deep-learning components. If True (the default),
AMP will be used with the default settings. If False, AMP will not be used.
If a dtype is provided, it will be passed to the `torch.autocast` context
manager.
show_progress: Optional[bool]
Whether to show progress bars (only applicable with "simple" and
"multiprocessing" backends).
gpu_pipe_names: Optional[List[str]]
List of pipe names to accelerate on a GPUWorker, defaults to all pipes
that inherit from TorchComponent. Only used with "multiprocessing" backend.
Inferred from the pipeline if not set.
process_start_method: Optional[Literal["fork", "spawn"]]
Whether to use "fork" or "spawn" as the start method for the multiprocessing
backend. The default is "fork" on Unix systems and "spawn" on Windows.
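The docstring above distinguishes three cases for `autocast` (`True`, `False`, or a dtype). A hypothetical helper, not the actual edsnlp implementation, that maps those documented values onto a `torch.autocast` context could look like this:

```python
from contextlib import nullcontext

import torch


def resolve_autocast(autocast, device_type: str):
    """Illustrative only: map the documented `autocast` values to a context manager."""
    if autocast is False:
        return nullcontext()  # AMP disabled entirely
    if autocast is True:
        return torch.autocast(device_type=device_type)  # default dtype for the device
    return torch.autocast(device_type=device_type, dtype=autocast)  # explicit dtype


# e.g. resolve_autocast(True, "cpu") or resolve_autocast(torch.bfloat16, "cuda")
```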
11 changes: 10 additions & 1 deletion edsnlp/processing/multiprocessing.py
@@ -656,13 +656,22 @@ def run(self):
if name in self.gpu_pipe_names
]

autocast_ctx = (
torch.autocast(
device_type=self.device,
dtype=lc.autocast,
)
if lc.autocast is not None
else nullcontext()
)

del lc
logging.info(f"Starting {self} on {os.getpid()}")

# Inform the main process that we are ready
self.exchanger.put_results((None, 0, None, None))

with torch.no_grad(): # , torch.cuda.amp.autocast():
with torch.no_grad(), autocast_ctx, torch.inference_mode():
while True:
stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
if task is None:
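The GPU worker above builds the autocast context conditionally and stacks it with `torch.no_grad()` and `torch.inference_mode()`. A self-contained sketch of the same pattern, with `device_type` and `autocast_dtype` as stand-ins for the worker's attributes:

```python
from contextlib import nullcontext

import torch


def inference_contexts(device_type: str, autocast_dtype=None):
    """Sketch of the context stacking used in the GPU worker loop."""
    autocast_ctx = (
        torch.autocast(device_type=device_type, dtype=autocast_dtype)
        if autocast_dtype is not None
        else nullcontext()  # fall back to a no-op context when autocast is off
    )
    return torch.no_grad(), autocast_ctx, torch.inference_mode()


no_grad_ctx, autocast_ctx, inference_ctx = inference_contexts("cpu", torch.bfloat16)
with no_grad_ctx, autocast_ctx, inference_ctx:
    pass  # forward passes of the torch components would run here
```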
24 changes: 20 additions & 4 deletions edsnlp/processing/simple.py
@@ -25,9 +25,25 @@ def execute_simple_backend(
batch on the current process in a sequential manner.
"""
try:
no_grad = sys.modules["torch"].no_grad
torch = sys.modules["torch"]
no_grad_ctx = torch.no_grad()
autocast_ctx = (
torch.autocast(
device_type=next(
p.device for pipe in lc.torch_components for p in lc.parameters
),
dtype=lc.autocast,
)
if lc.autocast is not None
else nullcontext()
)
inference_mode_ctx = (
torch.inference_mode()
if hasattr(torch, "inference_mode")
else nullcontext()
)
except (KeyError, AttributeError):
no_grad = nullcontext
no_grad_ctx = autocast_ctx = inference_mode_ctx = nullcontext()
reader = lc.reader
writer = lc.writer
show_progress = lc.show_progress
@@ -48,7 +64,7 @@ def process():

bar = tqdm(smoothing=0.1, mininterval=5.0)

with bar, lc.eval():
with bar, lc.eval(), autocast_ctx, inference_mode_ctx:
for docs in batchify(
(
subtask
@@ -64,7 +80,7 @@

for batch in batchify_fns[lc.batch_by](docs, lc.batch_size):
count = len(batch)
with no_grad(), lc.cache():
with no_grad_ctx, lc.cache():
batch = apply_basic_pipes(batch, batch_components)

if writer is not None:
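The `hasattr(torch, "inference_mode")` guard in the simple backend keeps it working on older torch releases that only provide `torch.no_grad`. A small, standalone illustration of the difference between the two modes (my own example, not from the repository):

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    y = x * 2  # gradient recording is disabled, but y is a regular tensor

with torch.inference_mode():
    z = x * 2  # additionally skips view tracking and version counter bumps

print(y.requires_grad, z.requires_grad)  # False False
# z is an "inference tensor": using it later in autograd-recorded ops raises an error
```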
