diff --git a/coremltools/converters/mil/frontend/torch/converter.py b/coremltools/converters/mil/frontend/torch/converter.py
index b5cd8277e..fdf5f5c33 100644
--- a/coremltools/converters/mil/frontend/torch/converter.py
+++ b/coremltools/converters/mil/frontend/torch/converter.py
@@ -585,7 +585,23 @@ def __init__(
         self.opset_version = _target(opset_version) if opset_version is not None else None
         self._prog = mil.Program()
 
+        self.src_model_has_all_fp16_weights = False
+        if isinstance(loaded_model, torch.jit.ScriptModule):
+            # src_model_has_all_fp16_weights will be True
+            # if the model has at least one trainable layer
+            # and all of its trainable layers have the fp16 dtype,
+            # e.g. if pytorch_model.half() has been explicitly used.
+            num_trainable_layers = 0
+            num_trainable_fp16_layers = 0
+            for param in loaded_model.parameters():
+                if param.requires_grad:
+                    num_trainable_layers += 1
+                    if param.dtype == torch.float16:
+                        num_trainable_fp16_layers += 1
+            if num_trainable_layers > 0:
+                self.src_model_has_all_fp16_weights = num_trainable_layers == num_trainable_fp16_layers
+
         self.context = TranscriptionContext(frontend=TorchFrontend.TORCHSCRIPT)
         self.graph = InternalTorchIRGraph.from_torchscript(
             torchscript=loaded_model, inputs=self.inputs, cut_at_symbols=cut_at_symbols
         )
@@ -1261,6 +1277,11 @@ def convert(self) -> Program:
             user_names = list(ssa_func_inputs.keys())
             internal_names = list(self.graph.inputs.keys())
             internal_names.extend(user_names[len(internal_names) :])
+            input_dtypes = []
+            for torch_name, ssa_name in zip(internal_names, user_names):
+                input_var = ssa_func.inputs[ssa_name]
+                input_dtypes.append(input_var.dtype)
+            all_fp16_inputs = all(x == types.fp16 for x in input_dtypes)
             for torch_name, ssa_name in zip(internal_names, user_names):
                 input_var = ssa_func.inputs[ssa_name]
                 if self.context.frontend == TorchFrontend.TORCHSCRIPT:
@@ -1272,7 +1293,7 @@ def convert(self) -> Program:
                     # So here we perform the "cast input to fp32" step
                     if (
                         types.is_tensor(input_var.sym_type) or types.is_scalar(input_var.sym_type)
-                    ) and input_var.dtype == types.fp16:
+                    ) and input_var.dtype == types.fp16 and not (all_fp16_inputs and self.src_model_has_all_fp16_weights):
                         # This cast should have placeholder scope
                         with mb.scope(
                             ScopeInfo(
diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py
index f76d89734..5ba67e524 100644
--- a/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py
+++ b/coremltools/converters/mil/frontend/torch/test/test_torch_conversion_api.py
@@ -1522,6 +1522,30 @@ def forward(self, x, y):
                 result[name], expected.detach().numpy(), rtol=rtol, atol=atol
             )
 
+    @staticmethod
+    @pytest.mark.parametrize(
+        "backend",
+        backends,
+    )
+    def test_torch_fp16_model_with_fp16_inputs(torch_model, backend):
+        if backend[0] == "neuralnetwork":
+            pytest.skip(
+                "Input float16 needs target >= iOS16, which doesn't support neuralnetwork."
+            )
+        traced_torch_model = torch.jit.trace(torch_model.half(), torch.rand(1, 10).half())
+        ct.convert(
+            traced_torch_model,
+            source="pytorch",
+            inputs=[
+                ct.TensorType(
+                    shape=(1, 10),
+                )
+            ],
+            outputs=[ct.TensorType(dtype=np.float16)],
+            convert_to=backend[0],
+            minimum_deployment_target=ct.target.macOS13,
+        )
+
 
 @pytest.fixture
 def int32_input_model():
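
Below is a minimal usage sketch (not part of the patch) of the conversion path this change optimizes. It assumes a hypothetical toy torch.nn.Linear model and mirrors the new test, with the input dtype made explicit; when all trainable weights and all inputs are fp16, the converter can now skip inserting the fp16-to-fp32 input cast.

import numpy as np
import torch
import coremltools as ct

# Hypothetical toy model; .half() casts every trainable parameter to fp16,
# which is what the new src_model_has_all_fp16_weights flag detects.
model = torch.nn.Linear(10, 10).eval().half()
traced = torch.jit.trace(model, torch.rand(1, 10).half())

mlmodel = ct.convert(
    traced,
    source="pytorch",
    # fp16 I/O requires a deployment target of iOS16 / macOS13 or newer.
    inputs=[ct.TensorType(shape=(1, 10), dtype=np.float16)],
    outputs=[ct.TensorType(dtype=np.float16)],
    convert_to="mlprogram",
    minimum_deployment_target=ct.target.macOS13,
)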