Skip casting model inputs to fp32 if weights and inputs are all fp16
jeethu committed Sep 18, 2024
1 parent e6be416 commit 87f3b8c
Showing 2 changed files with 46 additions and 1 deletion.
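In user terms, the change targets the case where the PyTorch model has been put in half precision and the Core ML inputs are declared as float16. A minimal sketch of that scenario (model, shape, and input name are illustrative, not taken from the diff):

import numpy as np
import torch
import coremltools as ct

# An all-fp16 model traced with an fp16 example input.
model = torch.nn.Linear(10, 10).eval().half()
traced = torch.jit.trace(model, torch.rand(1, 10).half())

# With fp16 weights and fp16 inputs, the converter no longer inserts an
# fp16 -> fp32 cast at the graph inputs after this commit.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=(1, 10), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    minimum_deployment_target=ct.target.iOS16,
)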
23 changes: 22 additions & 1 deletion coremltools/converters/mil/frontend/torch/converter.py
@@ -585,7 +585,23 @@ def __init__(
self.opset_version = _target(opset_version) if opset_version is not None else None
self._prog = mil.Program()

self.src_model_has_all_fp16_weights = False

if isinstance(loaded_model, torch.jit.ScriptModule):
# src_model_has_all_fp16_weights will be True if the model has at least
# one trainable parameter and every trainable parameter has the fp16 dtype,
# e.g. if pytorch_model.half() has been called explicitly.
num_trainable_layers = 0
num_trainable_fp16_layers = 0
for param in loaded_model.parameters():
if param.requires_grad:
num_trainable_layers += 1
if param.dtype == torch.float16:
num_trainable_fp16_layers += 1
if num_trainable_layers > 0:
self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers

self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
self.graph = InternalTorchIRGraph.from_torchscript(
torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
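The check above can be reproduced outside the converter. A small standalone sketch of the same rule (the helper name is hypothetical):

import torch

def all_trainable_weights_fp16(model: torch.nn.Module) -> bool:
    # True iff the model has at least one trainable parameter and every
    # trainable parameter is stored as torch.float16.
    trainable = [p for p in model.parameters() if p.requires_grad]
    return len(trainable) > 0 and all(p.dtype == torch.float16 for p in trainable)

m = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU())
assert not all_trainable_weights_fp16(m)      # parameters are fp32 by default
assert all_trainable_weights_fp16(m.half())   # True once .half() has been applied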
@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
user_names = list(ssa_func_inputs.keys())
internal_names = list(self.graph.inputs.keys())
internal_names.extend(user_names[len(internal_names) :])
input_dtypes = []
for torch_name, ssa_name in zip(internal_names, user_names):
input_var = ssa_func.inputs[ssa_name]
input_dtypes.append(input_var.dtype)
all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
for torch_name, ssa_name in zip(internal_names, user_names):
input_var = ssa_func.inputs[ssa_name]
if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1272,7 +1293,7 @@
# So here we perform the "cast input to fp32" step
if (
types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
- ) and input_var.dtype == types.fp16:
+ ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
# This cast should have placeholder scope
with mb.scope(
ScopeInfo(
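Taken together, the new state means the fp32 cast at the inputs is skipped only when every function input is fp16 and all trainable weights of the source model are fp16. A condensed, runnable sketch of that gating logic (the helper is hypothetical; plain dtype strings stand in for the MIL type objects):

def should_skip_fp32_input_cast(input_dtypes, src_model_has_all_fp16_weights):
    # Mirrors the new condition in convert(): skip the cast only when *all*
    # inputs are fp16 and *all* trainable weights in the source model are fp16.
    all_fp16_inputs = all(dtype == "fp16" for dtype in input_dtypes)
    return all_fp16_inputs and src_model_has_all_fp16_weights

assert should_skip_fp32_input_cast(["fp16", "fp16"], True)
assert not should_skip_fp32_input_cast(["fp16", "fp32"], True)    # mixed inputs: keep the cast
assert not should_skip_fp32_input_cast(["fp16", "fp16"], False)   # fp32 weights: keep the cast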
24 changes: 24 additions & 0 deletions (second changed file: tests)
@@ -1522,6 +1522,30 @@ def forward(self, x, y):
result[name], expected.detach().numpy(), rtol=rtol, atol=atol
)

@staticmethod
@pytest.mark.parametrize(
"backend",
backends,
)
def test_torch_fp16_model_with_fp16_inputs(torch_model, backend):
if backend[0] == "neuralnetwork":
pytest.skip(
"Input float16 needs target >= iOS16, which doesn't support neuralnetwork."
)
traced_torch_model = torch.jit.trace(torch_model.half(), torch.rand(1, 10).half())
ct.convert(
traced_torch_model,
source="pytorch",
inputs=[
ct.TensorType(
shape=(1, 10),
)
],
outputs=[ct.TensorType(dtype=np.float16)],
convert_to=backend[0],
minimum_deployment_target=ct.target.macOS13,
)
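One rough way to see the effect is to convert to the internal MIL representation and look for cast ops; a sketch assuming convert_to="milinternal" can be combined with a deployment target and that the resulting program exposes its main function's operations (the input name is illustrative):

import numpy as np
import torch
import coremltools as ct

model = torch.nn.Linear(10, 10).eval().half()
traced = torch.jit.trace(model, torch.rand(1, 10).half())

prog = ct.convert(
    traced,
    inputs=[ct.TensorType(name="x", shape=(1, 10), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    convert_to="milinternal",
    minimum_deployment_target=ct.target.iOS16,
)

# With all-fp16 weights and inputs, no fp16 -> fp32 cast should be inserted
# right after the placeholder inputs.
cast_ops = [op for op in prog.functions["main"].operations if op.op_type == "cast"]
print(cast_ops)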


@pytest.fixture
def int32_input_model():
