[PT FE] Save args on meta device when patching
mvafin committed Oct 21, 2024
1 parent d34cdda commit 163ae03
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions src/bindings/python/src/openvino/frontend/pytorch/patch_model.py
@@ -21,7 +21,7 @@ def __exit__(self, *args):
self.state = None


-def patch_model(model, module_extensions, orig_forward_name):
+def patch_model(model, module_extensions, orig_forward_name, use_meta=False):
def module_patcher(m, name):
extension = None
if m in module_extensions:
@@ -32,13 +32,14 @@ def module_patcher(m, name):
extension = module_extensions[name]

if extension:
log.debug("Patching module %s", m)
# The Trampoline class is instantiated for every module replacement, so we can use class members individually for each module.

class Trampoline(torch.autograd.Function):
target_extension = extension
original_module = m
-stashed_args = None
-stashed_kwargs = None
+stashed_args = tuple()
+stashed_kwargs = {}

@staticmethod
@torch.jit.ignore
@@ -53,26 +54,34 @@ def forward(*args, **kwargs):
# set original forward for the module
m.forward = getattr(m, orig_forward_name)
# call user code
-results = extension.evaluate(
-    m, *Trampoline.stashed_args, **Trampoline.stashed_kwargs)  # call user code
+results = extension.evaluate(m, *Trampoline.stashed_args,
+                             **Trampoline.stashed_kwargs)
m.forward = patched_forward # return patched forward back
return results

def new_forward(*args, **kwargs):
-Trampoline.stashed_args = args
-Trampoline.stashed_kwargs = kwargs
+# use meta device to store args, to save memory
+if use_meta:
+    d = torch.device("meta")
+    Trampoline.stashed_args = tuple(a.to(d) for a in args)
+    Trampoline.stashed_kwargs = dict((k, v.to(d)) for k, v in kwargs.items())
+else:
+    Trampoline.stashed_args = args
+    Trampoline.stashed_kwargs = kwargs
return extension.convert(m, Trampoline.apply, *args, **kwargs)

setattr(m, orig_forward_name, m.forward)
m.forward = new_forward

for name, m in model.named_modules():
if hasattr(m, orig_forward_name):
-# already patched, skipping with a warning because it is unexpected
-log.warning("Unexpectedly found already patched module %s while applying "
-            "ModuleExtension during PyTorch model conversion. "
-            "Result of the conversion maybe broken. Depending on the exact issue "
-            "it may lead to broken original model.", name)
+# already patched, skipping. It may happen when patching applied for same module twice
+log.debug("Unexpectedly found already patched module %s while applying "
+          "ModuleExtension during PyTorch model conversion. "
+          "Result of the conversion maybe broken. Depending on the exact issue "
+          "it may lead to broken original model.", name)
continue

module_patcher(m, name)
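
The key change is in new_forward: with use_meta=True the stashed call arguments are converted to meta tensors, which keep only shape and dtype metadata and hold no real storage, so large activations are not kept alive just for the later evaluate call (the 16-bit evaluate stubs in the next hunk only read args[0].shape). A minimal standalone sketch of the idea, assuming a PyTorch build with meta-device support; it is not part of the commit:

    import torch

    args = (torch.randn(2, 1024, 1024),)      # real activation, roughly 8 MB of storage
    d = torch.device("meta")
    stashed = tuple(a.to(d) for a in args)    # same pattern as new_forward above

    print(stashed[0].device)                  # meta
    print(stashed[0].shape)                   # torch.Size([2, 1024, 1024]); shape metadata survives
    # stashed[0] carries no data, but shape-only consumers such as
    # list(stashed[0].shape[:-1]) still work.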


@@ -99,25 +108,34 @@ def __make_16bit_traceable(model: torch.nn.Module):
torch.nn.Linear, "ov_ext::linear",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.out_features], 0.5, dtype=torch.float32),
-convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias)),
+convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
+                                                              module.weight,
+                                                              module.bias)),
torch.nn.Embedding: ModuleExtension(
torch.nn.Embedding, "ov_ext::embedding",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape) + [module.embedding_dim], 0.5, dtype=torch.float32),
-convert=lambda module, target_op, *args, **kwargs: target_op(module.weight, args[0], module.padding_idx, module.scale_grad_by_freq, module.sparse)),
+convert=lambda module, target_op, *args, **kwargs: target_op(module.weight,
+                                                              args[0],
+                                                              module.padding_idx,
+                                                              module.scale_grad_by_freq,
+                                                              module.sparse)),
}
try:
from transformers.pytorch_utils import Conv1D
extensions[Conv1D] = ModuleExtension(
Conv1D, "ov_ext::conv1d",
evaluate=lambda module, *args, **kwargs: torch.full(
list(args[0].shape[:-1]) + [module.nf], 0.5, dtype=torch.float32),
-convert=lambda module, target_op, *args, **kwargs: target_op(args[0], module.weight, module.bias))
+convert=lambda module, target_op, *args, **kwargs: target_op(args[0],
+                                                              module.weight,
+                                                              module.bias))
except:
pass
patch_model(model, extensions,
"_openvino_module_extension_patch_orig_forward")
"_openvino_module_extension_patch_orig_forward", use_meta=True)
for _, module in model.named_modules():
-if module.__class__ not in extensions and (any([p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False)])
-                                            or any([b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False)])):
+if module.__class__ not in extensions and (any(p.dtype in [torch.float16, torch.bfloat16] for p in module.parameters(False))
+                                           or any(b.dtype in [torch.float16, torch.bfloat16] for b in module.buffers(False))):
log.debug("Casting module %s to float32", module)
module.float()
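
Note that the evaluate stubs registered above only need to return a tensor with the correct output shape and a float32 dtype so tracing can proceed; the real computation is delegated to the ov_ext::* target operations during conversion. A standalone sketch of what the torch.nn.Linear stub produces (illustrative values only, not part of the commit):

    import torch

    module = torch.nn.Linear(8, 4).to(torch.bfloat16)   # example 16-bit layer
    x = torch.randn(2, 8, dtype=torch.bfloat16)

    # mirrors the evaluate lambda registered for torch.nn.Linear above
    stub = torch.full(list(x.shape[:-1]) + [module.out_features], 0.5,
                      dtype=torch.float32)
    print(stub.shape)   # torch.Size([2, 4]) -- the shape module(x) would produce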
