diff --git a/changelog.md b/changelog.md
index a0b1c9e28..a3a7dd75d 100644
--- a/changelog.md
+++ b/changelog.md
@@ -4,6 +4,10 @@
 
 ### Added
 
+- `data.set_processing(...)` now exposes an `autocast` parameter to disable or tweak the automatic casting of tensors
+  during processing. Autocasting should result in a slight speedup, but may lead to numerical instability.
+- Use `torch.inference_mode` to disable view tracking and version counter bumps during inference.
+
 ### Changed
 
 ### Fixed
diff --git a/edsnlp/core/lazy_collection.py b/edsnlp/core/lazy_collection.py
index 9b49f545d..07835a190 100644
--- a/edsnlp/core/lazy_collection.py
+++ b/edsnlp/core/lazy_collection.py
@@ -131,6 +131,10 @@ def gpu_worker_devices(self):
     def cpu_worker_devices(self):
         return self.config.get("cpu_worker_devices")
 
+    @property
+    def autocast(self):
+        return self.config.get("autocast")
+
     @property
     def backend(self):
         backend = self.config.get("backend")
@@ -156,8 +160,9 @@ def set_processing(
         num_gpu_workers: Optional[int] = INFER,
         disable_implicit_parallelism: bool = True,
         backend: Optional[Literal["simple", "multiprocessing", "mp", "spark"]] = INFER,
-        gpu_pipe_names: Optional[List[str]] = INFER,
+        autocast: Union[bool, Any] = True,
         show_progress: bool = False,
+        gpu_pipe_names: Optional[List[str]] = INFER,
         process_start_method: Optional[Literal["fork", "spawn"]] = INFER,
         gpu_worker_devices: Optional[List[str]] = INFER,
         cpu_worker_devices: Optional[List[str]] = INFER,
@@ -203,10 +208,6 @@ def set_processing(
         disable_implicit_parallelism: bool
             Whether to disable OpenMP and Huggingface tokenizers implicit parallelism in
             multiprocessing mode. Defaults to True.
-        gpu_pipe_names: Optional[List[str]]
-            List of pipe names to accelerate on a GPUWorker, defaults to all pipes
-            that inherit from TorchComponent. Only used with "multiprocessing" backend.
-            Inferred from the pipeline if not set.
         backend: Optional[Literal["simple", "multiprocessing", "spark"]]
             The backend to use for parallel processing. If not set, the backend is
             automatically selected based on the input data and the number of workers.
@@ -217,9 +218,20 @@ def set_processing(
               `num_gpu_workers` is greater than 0.
             - "spark" is used when the input data is a Spark dataframe and the output
               writer is a Spark writer.
+        autocast: Union[bool, Any]
+            Whether to use
+            [automatic mixed precision (AMP)](https://pytorch.org/docs/stable/amp.html)
+            for the forward pass of the deep-learning components. If True (the default),
+            AMP will be used with the default settings. If False, AMP will not be used.
+            If a dtype is provided, it will be passed to the `torch.autocast` context
+            manager.
         show_progress: Optional[bool]
             Whether to show progress bars (only applicable with "simple" and
             "multiprocessing" backends).
+        gpu_pipe_names: Optional[List[str]]
+            List of pipe names to accelerate on a GPUWorker, defaults to all pipes
+            that inherit from TorchComponent. Only used with "multiprocessing" backend.
+            Inferred from the pipeline if not set.
         process_start_method: Optional[Literal["fork", "spawn"]]
             Whether to use "fork" or "spawn" as the start method for the multiprocessing
             backend. The default is "fork" on Unix systems and "spawn" on Windows.
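Purely for illustration, and not part of the patch: a hedged sketch of how the new `autocast` option could be set from user code. The `nlp` pipeline and `texts` iterable are assumptions, and `nlp.pipe(...)` is assumed to return a lazy stream exposing `set_processing`; only the `autocast` argument itself comes from this change.

```python
import torch

# Assumptions (not defined by this patch): `nlp` is an edsnlp pipeline and
# `texts` is an iterable of strings; `nlp.pipe(...)` is assumed to return a
# lazy stream exposing `set_processing`.
docs = nlp.pipe(texts).set_processing(
    backend="simple",
    autocast=torch.bfloat16,  # True -> default AMP, False -> no AMP, dtype -> torch.autocast(dtype=...)
    show_progress=True,
)
for doc in docs:
    pass  # consume the stream; processing happens lazily here
```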
diff --git a/edsnlp/processing/multiprocessing.py b/edsnlp/processing/multiprocessing.py
index b173652ae..422b80fcd 100644
--- a/edsnlp/processing/multiprocessing.py
+++ b/edsnlp/processing/multiprocessing.py
@@ -656,13 +656,22 @@ def run(self):
                 if name in self.gpu_pipe_names
             ]
 
+            autocast_ctx = (
+                torch.autocast(
+                    device_type=torch.device(self.device).type,
+                    dtype=lc.autocast if lc.autocast is not True else None,
+                )
+                if lc.autocast
+                else nullcontext()
+            )
+
             del lc
             logging.info(f"Starting {self} on {os.getpid()}")
 
             # Inform the main process that we are ready
             self.exchanger.put_results((None, 0, None, None))
 
-            with torch.no_grad():  # , torch.cuda.amp.autocast():
+            with torch.no_grad(), autocast_ctx, torch.inference_mode():
                 while True:
                     stage, task = self.exchanger.get_gpu_task(self.gpu_idx)
                     if task is None:
diff --git a/edsnlp/processing/simple.py b/edsnlp/processing/simple.py
index 8ec267a8c..1be0b700d 100644
--- a/edsnlp/processing/simple.py
+++ b/edsnlp/processing/simple.py
@@ -25,9 +25,25 @@ def execute_simple_backend(
     batch on the current process in a sequential manner.
     """
     try:
-        no_grad = sys.modules["torch"].no_grad
+        torch = sys.modules["torch"]
+        no_grad_ctx = torch.no_grad()
+        autocast_ctx = (
+            torch.autocast(
+                device_type=next(
+                    p.device.type for pipe in lc.torch_components for p in pipe.parameters()
+                ),
+                dtype=lc.autocast if lc.autocast is not True else None,
+            )
+            if lc.autocast
+            else nullcontext()
+        )
+        inference_mode_ctx = (
+            torch.inference_mode()
+            if hasattr(torch, "inference_mode")
+            else nullcontext()
+        )
     except (KeyError, AttributeError):
-        no_grad = nullcontext
+        no_grad_ctx = autocast_ctx = inference_mode_ctx = nullcontext()
     reader = lc.reader
     writer = lc.writer
     show_progress = lc.show_progress
@@ -48,7 +64,7 @@ def process():
 
         bar = tqdm(smoothing=0.1, mininterval=5.0)
 
-        with bar, lc.eval():
+        with bar, lc.eval(), autocast_ctx, inference_mode_ctx:
            for docs in batchify(
                (
                    subtask
@@ -64,7 +80,7 @@
                for batch in batchify_fns[lc.batch_by](docs, lc.batch_size):
                    count = len(batch)
 
-                   with no_grad(), lc.cache():
+                   with no_grad_ctx, lc.cache():
                        batch = apply_basic_pipes(batch, batch_components)
 
                        if writer is not None:
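To make the runtime behaviour of the two backends easier to review, here is a minimal, self-contained sketch of the context stack they now wrap around the forward pass. `make_inference_ctx` and the toy `model` are hypothetical names introduced only for this sketch; the `autocast` convention (True = default AMP, False = disabled, a `torch.dtype` = explicit target dtype) mirrors the docstring added above.

```python
from contextlib import nullcontext

import torch


def make_inference_ctx(model: torch.nn.Module, autocast=True):
    """Build the (AMP, grad-disabling) context pair used around inference."""
    device_type = next(model.parameters()).device.type  # "cpu", "cuda", ...
    if autocast:
        # A plain True means "use the default dtype for this device type".
        dtype = autocast if isinstance(autocast, torch.dtype) else None
        amp_ctx = torch.autocast(device_type=device_type, dtype=dtype)
    else:
        amp_ctx = nullcontext()
    # torch.inference_mode (torch >= 1.9) also disables view tracking and
    # version counter bumps; fall back to no_grad on older versions.
    grad_ctx = (
        torch.inference_mode() if hasattr(torch, "inference_mode") else torch.no_grad()
    )
    return amp_ctx, grad_ctx


model = torch.nn.Linear(4, 2)
amp_ctx, grad_ctx = make_inference_ctx(model, autocast=torch.bfloat16)
with amp_ctx, grad_ctx:
    out = model(torch.randn(8, 4))
```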