Unable to convert a model with 3d input shape of dynamic length into tflite int8 format #673

Open
gurudatta-patil opened this issue Jul 18, 2024 · 8 comments
Labels: Bug, Dynamic batch / Dynamic shape, OP:AveragePool, OP:BatchNormalization, OP:Expand

Comments

gurudatta-patil commented Jul 18, 2024

Issue Type

Others

OS

Linux

onnx2tf version number

1.25.6

onnx version number

1.16.1

onnxruntime version number

1.18.1

onnxsim (onnx_simplifier) version number

0.4.33

tensorflow version number

2.17.0

Download URL for ONNX

https://github.com/gurudatta-patil/ML-Campp/blob/main/cam%2B%2B_vin.onnx

Parameter Replacement JSON

~

Description

  1. Research
  2. Command: onnx2tf -i cam++_vin.onnx -osd -coion

```
  File "/usr/local/bin/onnx2tf", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/onnx2tf/onnx2tf.py", line 2574, in main
    model = convert(
  File "/usr/local/lib/python3.10/dist-packages/onnx2tf/onnx2tf.py", line 1295, in convert
    concrete_func = run_model.get_concrete_function()
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1251, in get_concrete_function
    concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 1221, in _get_concrete_function_garbage_collected
    self._initialize(args, kwargs, add_initializers_to=initializers)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 696, in _initialize
    self._concrete_variable_creation_fn = tracing_compilation.trace_function(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 178, in trace_function
    concrete_function = _maybe_define_function(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 283, in _maybe_define_function
    concrete_function = _create_concrete_function(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/tracing_compilation.py", line 310, in _create_concrete_function
    traced_func_graph = func_graph_module.func_graph_from_py_func(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/func_graph.py", line 1059, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py", line 599, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 52, in autograph_handler
    raise e.ag_error_metadata.to_exception(e)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/eager/polymorphic_function/autograph_util.py", line 41, in autograph_handler
    return api.converted_call(
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
    result = converted_f(*effective_args, **kwargs)
  File "/tmp/__autograph_generated_filej12tz1am.py", line 6, in <lambda>
    tf__lam = lambda *inputs: ag__.with_function_scope(lambda lscope: ag__.converted_call(model, (inputs,), None, lscope), 'lscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True))
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/core/function_wrappers.py", line 113, in with_function_scope
    return thunk(scope)
  File "/tmp/__autograph_generated_filej12tz1am.py", line 6, in <lambda>
    tf__lam = lambda *inputs: ag__.with_function_scope(lambda lscope: ag__.converted_call(model, (inputs,), None, lscope), 'lscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True))
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
    return _call_unconverted(f, args, kwargs, options)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/autograph/impl/api.py", line 460, in _call_unconverted
    return f(*args)
  File "/usr/local/lib/python3.10/dist-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py", line 1037, in _create_c_op
    raise ValueError(e.message)
ValueError: in user code:

    File "/usr/local/lib/python3.10/dist-packages/onnx2tf/onnx2tf.py", line 1292, in None  *
        lambda *inputs : model(inputs)
    File "/usr/local/lib/python3.10/dist-packages/tf_keras/src/utils/traceback_utils.py", line 70, in error_handler  **
        raise e.with_traceback(filtered_tb) from None

    ValueError: Exception encountered when calling layer 'tf.math.multiply_9' (type TFOpLambda).

    Dimensions must be equal, but are 2 and 128 for '{{node model_59/tf.math.multiply_9/Mul}} = Mul[T=DT_FLOAT](model_59/tf.expand_dims_4/ExpandDims, model_59/tf.ones/ones)' with input shapes: [1,2,1,128], [1,128,128,100].

    Call arguments received by layer 'tf.math.multiply_9' (type TFOpLambda):
      • x=tf.Tensor(shape=(1, 2, 1, 128), dtype=float32)
      • y=tf.Tensor(shape=(1, 128, 128, 100), dtype=float32)
      • name=None
```
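
The `Mul` failure above is an ordinary broadcasting error. TensorFlow, like NumPy, right-aligns the two shapes and requires each dimension pair to be equal or to contain a 1. The minimal sketch below is plain Python. The helper name `broadcast_shape` is hypothetical and is not part of onnx2tf or TensorFlow. It reproduces that check for the two input shapes in the message and reports every incompatible pair, starting with the `2` vs `128` pair that TensorFlow complains about.

```python
def broadcast_shape(a, b):
    """NumPy/TensorFlow-style broadcast of two static shapes.

    Shapes are right-aligned; each dimension pair must be equal or contain a 1.
    Raises ValueError listing every incompatible (axis, lhs, rhs) pair, left to right.
    """
    rank = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a = [1] * (rank - len(a)) + list(a)
    b = [1] * (rank - len(b)) + list(b)
    out, bad = [], []
    for axis, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            bad.append((axis, x, y))
    if bad:
        raise ValueError(f"incompatible dimensions (axis, lhs, rhs): {bad}")
    return out


# Shapes taken from the traceback: tf.expand_dims_4 output and tf.ones output.
try:
    broadcast_shape([1, 2, 1, 128], [1, 128, 128, 100])
except ValueError as err:
    print(err)

# For contrast, a pair that does broadcast.
print(broadcast_shape([1, 128, 1, 1], [1, 128, 128, 100]))
```

Axis 1 (`2` vs `128`) is the pair TensorFlow reports first, while axis 3 (`128` vs `100`) would also fail, so the problem is not a single off-by-one dimension.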

Input size: [1,-1,80]
name: input
tensor: float32[1,time_frames,80]

  1. I also tried a few other commands, including passing an .npy file as input.
    I am trying to get an int8 output for the model.

  2. I am trying to get this model into a lightweight format with minimal quantization error so that I can deploy it on an embedded device.

PINTO0309 added the TODO label Jul 18, 2024
PINTO0309 (Owner) commented Jul 18, 2024

I will keep notes on the material to research again when I have enough time to work on it.

  • log
    convert_log.txt.zip
  • Problem Location
    ```
    INFO: 50 / 1719
    INFO: onnx_op_type: Expand onnx_op_name: wa/xvector/block1/tdnnd1/cam_layer/Expand
    INFO:  input_name.1: wa/xvector/block1/tdnnd1/cam_layer/Unsqueeze_output_0 shape: [1, 128, 'unk__77', 1] dtype: float32
    INFO:  input_name.2: wa/xvector/block1/tdnnd1/cam_layer/Where_output_0 shape: [4] dtype: int64
    INFO:  output_name.1: wa/xvector/block1/tdnnd1/cam_layer/Expand_output_0 shape: ['unk__80', 128, 'unk__83', 'unk__86'] dtype: float32
    INFO: tf_op_type: Expand
    INFO:  input.1.input_tensor: name: tf.expand_dims_4/ExpandDims:0 shape: (1, 2, 1, 128) dtype: <dtype: 'float32'>
    INFO:  input.2.input_tensor_shape: name: tf.where/SelectV2:0 shape: (4,) dtype: <dtype: 'int64'>
    INFO:  output.1.output: name: tf.math.multiply_9/Mul:0 shape: (None, 2, None, 128) dtype: <dtype: 'float32'>
    ```
  • Logic that needs to be modified (Expand bug)
    • Do not perform optimization using tf.ones when undefined dimensions are present.
      ```python
      # tf.math.multiply does not support bool therefore use int32
      expanded_tensor = None
      if input_tensor.dtype is tf.bool:
          ones = tf.ones(input_tensor_shape, dtype=tf.int32)
          r = tf.cast(input_tensor, tf.int32) * ones
          expanded_tensor = tf.cast(r, tf.bool)
      else:
          ones = tf.ones(input_tensor_shape, dtype=input_tensor.dtype)
          expanded_tensor = input_tensor * ones
      tf_layers_dict[graph_node_output.name]['tf_node'] = expanded_tensor
      tf_type = 'Expand'
      ```
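
ONNX `Expand` uses bidirectional (NumPy-style) broadcasting: the output shape is the broadcast of the input's shape and the `shape` operand, so a target dimension of 1 keeps the input's size. The `tf.ones(input_tensor_shape) * input_tensor` shortcut above only matches that behaviour when the shape is known at trace time. The sketch below is plain Python. `expand_output_shape` and `can_use_static_ones` are hypothetical helpers, and `None` stands for an undefined dimension. It shows which output dimensions can be resolved statically. Treating all four target values from the `Where` output as unknown, the Expand in the log above yields `[None, 128, None, None]`, which matches the ONNX output shape `['unk__80', 128, 'unk__83', 'unk__86']`.

```python
from typing import List, Optional

Dim = Optional[int]  # None marks a dimension that is unknown at conversion time


def expand_output_shape(input_shape: List[Dim], target: List[Dim]) -> List[Dim]:
    """Bidirectional broadcast used by ONNX Expand, tolerating unknown dims."""
    rank = max(len(input_shape), len(target))
    a = [1] * (rank - len(input_shape)) + list(input_shape)
    b = [1] * (rank - len(target)) + list(target)
    out: List[Dim] = []
    for x, y in zip(a, b):
        if x == 1:
            out.append(y)      # input dim 1 stretches to the target (may be None)
        elif y == 1:
            out.append(x)      # target dim 1 keeps the input dim (may be None)
        elif x is None and y is None:
            out.append(None)   # cannot be decided statically
        elif x is None:
            out.append(y)      # a valid broadcast forces x to be 1 or y, so the result is y
        elif y is None:
            out.append(x)      # likewise, y must be 1 or x, so the result is x
        elif x == y:
            out.append(x)
        else:
            raise ValueError(f"cannot broadcast {x} and {y}")
    return out


def can_use_static_ones(input_shape: List[Dim], target: List[Dim]) -> bool:
    """True only if every output dim is known, so a ones tensor could be built statically."""
    return all(d is not None for d in expand_output_shape(input_shape, target))


print(expand_output_shape([1, 128, None, 1], [None, None, None, None]))
print(can_use_static_ones([1, 128, None, 1], [None, None, None, None]))
print(expand_output_shape([3, 1], [2, 1, 4]))
```

When `can_use_static_ones` is false, a converter would have to fall back to a runtime-shaped broadcast instead of a statically shaped ones tensor. That is the idea behind the note above, though the actual fix in onnx2tf may differ.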

PINTO0309 added the Dynamic batch / Dynamic shape, OP:Expand, and Bug labels Jul 18, 2024
gurudatta-patil (Author) commented Jul 19, 2024

Thank you!
In the meantime, is there a different set of steps I could follow to get this working?
Input shape: [1,-1,80], LSTM model; target: int8 TFLite model

PINTO0309 (Owner)

There is a JSON-based behavior correction feature (parameter replacement), but it is difficult to understand and takes a very long time to learn.

I'll be concentrating on other tasks for a while, so if you're in a hurry, try these tools. Their conversion success rate is said to be 100%.

https://github.com/google-ai-edge/ai-edge-torch

https://github.com/AlexanderLutsenko/nobuco

PINTO0309 (Owner) commented Jul 23, 2024

It turned out to be an AveragePool1D problem, not an Expand problem, and a rather tricky one. This issue has nothing to do with LSTM or INT8 quantization; it stems from differences between the frameworks' specifications for pooling.

Unfortunately, this AveragePool is not compatible with TensorFlow's AveragePool.

The padding size is calculated by rather complicated logic that forces it to conform to TensorFlow, so I need to investigate how to reduce the padding size to zero. Essentially, the output tensor of AveragePool must be [1, 100, 128] in TF and [1, 128, 100] in ONNX.

```python
# Generation of TF OP
tf_op_type = None
if len(kernel_shape) == 1:
    pooled_tensor = AveragePooling1D(
        pool_size=kernel_shape,
        strides=strides,
        padding=tf_pad_mode.upper(),
    )(padded_tensor)
    tf_op_type = AveragePooling1D
```
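
Much of the mismatch comes down to output-length arithmetic. ONNX AveragePool computes the spatial length with the formula visible in the trace log further down, `(i + pb + pe - d*(k-1) - 1) / s + 1`, rounded down, or up when `ceil_mode=1`. Keras `AveragePooling1D` only offers `'valid'` and `'same'` padding, and it defaults to channels-last `[N, W, C]`, which is consistent with the TF `[1, 100, 128]` vs ONNX `[1, 128, 100]` layouts mentioned above. The sketch below is plain Python with no dilation on the TensorFlow side; the function names are hypothetical and the example sizes are illustrative, not read from the model. It compares the lengths.

```python
import math


def onnx_avgpool_len(i, k, s, pb=0, pe=0, d=1, ceil_mode=False):
    """ONNX AveragePool output length, using the formula from onnx2tf's AveragePool.py trace.

    Note: this ignores the extra rule that drops a ceil_mode window starting in the padding.
    """
    rounding = math.ceil if ceil_mode else math.floor
    return int(rounding((i + pb + pe - d * (k - 1) - 1) / s + 1))


def tf_pool_len(i, k, s, padding):
    """Keras/TensorFlow pooling output length for 'valid' or 'same' (no dilation)."""
    if padding == "valid":
        return math.ceil((i - k + 1) / s)
    if padding == "same":
        return math.ceil(i / s)
    raise ValueError(f"unsupported padding: {padding}")


def ncw_to_nwc(shape):
    """ONNX channel-first [N, C, W] -> Keras channel-last [N, W, C]."""
    n, c, w = shape
    return [n, w, c]


# Illustrative values: 250 frames, kernel = stride = 100.
print(onnx_avgpool_len(250, 100, 100, ceil_mode=False))  # 2
print(onnx_avgpool_len(250, 100, 100, ceil_mode=True))   # 3
print(tf_pool_len(250, 100, 100, "valid"))               # 2
print(tf_pool_len(250, 100, 100, "same"))                # 3
print(ncw_to_nwc([1, 128, 100]))                         # [1, 100, 128]
```

With `ceil_mode=1`, ONNX produces one more window than `'valid'`. `'same'` happens to give the same length here, but TensorFlow splits that padding across both ends (25 + 25 in this example), so the windows are shifted. This is presumably why the converter pads explicitly before pooling.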

PINTO0309 added the OP:AveragePool label Jul 23, 2024
gurudatta-patil (Author)

We tried ai-edge-torch. The conversion initially failed and we had to modify the code, but after some tweaking we were able to convert the model. However, it still does not convert the average pooling layer correctly.

PINTO0309 (Owner)

Thanks for sharing your valuable experience. This is quite a difficult issue.

PINTO0309 (Owner) commented Jul 25, 2024

I'm adding some debugging resources.

  • Dynamic: avgpool1d_dynamic.onnx.zip
  • Static128: avgpool1d_static.onnx.zip
  • Static1: avgpool1d_static1.onnx.zip
```
onnx2tf -i avgpool1d_static1.onnx -cotof

INFO: validation_conditions: np.allclose(onnx_outputs, tf_outputs, rtol=0.0, atol=0.0001, equal_nan=True)
INFO: onnx_output_name: wa/xvector/block1/tdnnd1/cam_layer/AveragePool_output_0 tf_output_name: tf.compat.v1.squeeze/Squeeze:0 shape: (1, 128, 1) dtype: float32 validate_result:  Unmatched  max_abs_error: 0.9900000095367432
```
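
One way a mismatch like `max_abs_error: 0.99` can arise is sketched below. This is an assumption and is not confirmed in this thread. When `ceil_mode=1` makes the last window overhang the input, PyTorch (and, to our understanding, onnxruntime with zero pads) averages only over the elements that actually exist in that window. Zero-padding the input and running a `'valid'` TensorFlow pool divides by the full kernel size instead. The pure-Python reference below uses a hypothetical helper, and its kernel and stride of 100 on a single frame are illustrative. With a single frame of value 1.0 and a window of 100, the two conventions give 1.0 and 0.01, a gap of 0.99, which would match the log only if the validation input happened to be all ones.

```python
import math


def avg_pool1d_ref(x, k, s, ceil_mode, divide_by_full_kernel):
    """Reference 1-D average pool over a Python list (no explicit pads, dilation 1).

    divide_by_full_kernel=False: divide by the elements that actually fall inside the
        input, as PyTorch does for a ceil_mode window that overhangs the end.
    divide_by_full_kernel=True: divide by k, as if the overhang were zero-padded and
        pooled with 'valid' padding.
    """
    rounding = math.ceil if ceil_mode else math.floor
    out_len = int(rounding((len(x) - k) / s + 1))
    # PyTorch-style rule: a ceil_mode window must start inside the input.
    if ceil_mode and (out_len - 1) * s >= len(x):
        out_len -= 1
    out = []
    for j in range(out_len):
        window = x[j * s : j * s + k]  # slicing clips the window at the end of the input
        divisor = k if divide_by_full_kernel else len(window)
        out.append(sum(window) / divisor)
    return out


frame = [1.0]  # a single time step, as in a static length-1 input
a = avg_pool1d_ref(frame, k=100, s=100, ceil_mode=True, divide_by_full_kernel=False)
b = avg_pool1d_ref(frame, k=100, s=100, ceil_mode=True, divide_by_full_kernel=True)
print(a, b, abs(a[0] - b[0]))
```

Fixing only the output length is therefore not enough. The divisor of the overhanging window also has to match, and with a dynamic input length the number of real elements in that last window is not known at conversion time.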

WIP: main...fix_undef_expand

  • Bug: BatchNormalization 1D
    • INPUT: TensorShape([1, None, 128])
    • OUTPUT: TensorShape([1, 128, 128])
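
In inference mode, BatchNormalization is a per-channel affine transform, y = gamma * (x - mean) / sqrt(var + eps) + beta. Its output shape must therefore equal its input shape: `[1, None, 128]` should come out as `[1, None, 128]`, not `[1, 128, 128]`. The pure-Python sketch below makes that invariant explicit for any number of frames. The `batchnorm_nwc` and `shape` helpers are hypothetical, and the sketch assumes channel-last `[N][W][C]` nested lists and ONNX's default epsilon of 1e-5.

```python
import math


def batchnorm_nwc(x, mean, var, gamma, beta, eps=1e-5):
    """Inference-mode BatchNormalization over the last (channel) axis of [N][W][C] lists."""
    return [
        [
            [g * (v - m) / math.sqrt(s2 + eps) + b
             for v, m, s2, g, b in zip(frame, mean, var, gamma, beta)]
            for frame in sample
        ]
        for sample in x
    ]


def shape(x):
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims


channels = 4
stats = dict(mean=[0.0] * channels, var=[1.0] * channels,
             gamma=[1.0] * channels, beta=[0.0] * channels)
for frames in (1, 7, 128):
    x = [[[float(t)] * channels for t in range(frames)]]
    y = batchnorm_nwc(x, **stats)
    assert shape(y) == shape(x)  # the time axis is untouched, whatever its length
    print(frames, shape(y))
```

Only the channel axis interacts with the BN parameters, so a converter never needs to know the time length to emit this op.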

PINTO0309 (Owner) commented Jul 25, 2024

I have fixed and released fixes for the critical problems except AveragePool, but AveragePool (with ceil_mode=1) with a dynamic tensor as input is extremely difficult to fix due to compatibility issues with TensorFlow.

Previously, no error was raised at the AveragePool even though that is where the conversion should have failed; the latest onnx2tf now raises a conversion error at the AveragePool. The root cause is the difficulty of calculating the ExtraPadding needed to reconcile the differences between PyTorch's and TensorFlow's pooling specifications.

```
INFO: 39 / 1464
INFO: onnx_op_type: AveragePool onnx_op_name: wa/xvector/block1/tdnnd1/cam_layer/AveragePool
INFO:  input_name.1: wa/xvector/block1/tdnnd1/nonlinear2/relu/Relu_output_0 shape: [1, 128, 'unk__71'] dtype: float32
INFO:  output_name.1: wa/xvector/block1/tdnnd1/cam_layer/AveragePool_output_0 shape: [1, 128, 'unk__77'] dtype: float32
ERROR: The trace log is below.
Traceback (most recent call last):
  File "/home/xxxxx/git/onnx2tf/onnx2tf/utils/common_functions.py", line 312, in print_wrapper_func
    result = func(*args, **kwargs)
  File "/home/xxxxx/git/onnx2tf/onnx2tf/utils/common_functions.py", line 385, in inverted_operation_enable_disable_wrapper_func
    result = func(*args, **kwargs)
  File "/home/xxxxx/git/onnx2tf/onnx2tf/utils/common_functions.py", line 55, in get_replacement_parameter_wrapper_func
    func(*args, **kwargs)
  File "/home/xxxxx/git/onnx2tf/onnx2tf/ops/AveragePool.py", line 171, in make_node
    output_spatial_shape = [
  File "/home/xxxxx/git/onnx2tf/onnx2tf/ops/AveragePool.py", line 172, in <listcomp>
    func((i + pb + pe - d * (k - 1) - 1) / s + 1)
TypeError: unsupported operand type(s) for +: 'NoneType' and 'int'

ERROR: input_onnx_file_path: ../cam++_vin.onnx
ERROR: onnx_op_name: wa/xvector/block1/tdnnd1/cam_layer/AveragePool
ERROR: Read this and deal with it. https://github.com/PINTO0309/onnx2tf#parameter-replacement
ERROR: Alternatively, if the input OP has a dynamic dimension, use the -b or -ois option to rewrite it to a static shape and try again.
ERROR: If the input OP of ONNX before conversion is NHWC or an irregular channel arrangement other than NCHW, use the -kt or -kat option.
ERROR: Also, for models that include NonMaxSuppression in the post-processing, try the -onwdt option.
``` 
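
The `TypeError` above occurs because the input length `i` is `None` (a dynamic dimension) and is fed straight into the output-length formula. The sketch below is plain Python, and `extra_end_padding` is a hypothetical helper, not onnx2tf's implementation. It computes how many zeros would have to be appended so that a `'valid'` TensorFlow pool produces the same number of windows as ONNX with `ceil_mode=1`. When the length is unknown, it fails with an explicit message, in line with the log's advice to make the shape static with `-b` or `-ois`.

```python
import math
from typing import Optional


def extra_end_padding(i: Optional[int], k: int, s: int,
                      pb: int = 0, pe: int = 0, d: int = 1) -> int:
    """Zeros to append so TF 'valid' pooling yields the ONNX ceil_mode=1 window count."""
    if i is None:
        raise ValueError(
            "input length is dynamic; the ceil_mode output length cannot be computed "
            "statically (make the shape static, e.g. with onnx2tf -b or -ois)"
        )
    padded = i + pb + pe
    effective_k = d * (k - 1) + 1
    out_len = math.ceil((padded - effective_k) / s + 1)  # ONNX ceil_mode=1 length
    needed = (out_len - 1) * s + effective_k              # span covered by out_len windows
    return max(0, needed - padded)


print(extra_end_padding(250, 100, 100))  # 50
print(extra_end_padding(300, 100, 100))  # 0
print(extra_end_padding(1, 100, 100))    # 99
```

Appending zeros fixes only the window count. The padded zeros are still counted by a plain average unless the divisor of the last window is corrected separately.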
