Skip to content

Commit

Permalink
Merge pull request #622 from PINTO0309/fix_if_reducemax
Browse files Browse the repository at this point in the history
Improved conversion stability of subgraphs of `If` operations.
  • Loading branch information
PINTO0309 authored May 5, 2024
2 parents 4207e99 + b63d080 commit bf8e894
Show file tree
Hide file tree
Showing 24 changed files with 71 additions and 36 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.20.7
ghcr.io/pinto0309/onnx2tf:1.20.8

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.20.7
docker.io/pinto0309/onnx2tf:1.20.8

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.20.7'
__version__ = '1.20.8'
3 changes: 2 additions & 1 deletion onnx2tf/ops/Concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,8 @@ def define_concat(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/Conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/ConvInteger.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/DepthToSpace.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
1 change: 1 addition & 0 deletions onnx2tf/ops/Einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def make_node(
equation = graph_node.attrs['equation']
onnx_tensor_infos_for_validation: Dict[str: np.ndarray] = kwargs['onnx_tensor_infos_for_validation']
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None \
and graph_node_output.name in onnx_tensor_infos_for_validation:
onnx_output_shape = list(onnx_tensor_infos_for_validation[graph_node_output.name].shape)
graph_node_output.shape = onnx_output_shape
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/GroupNorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/InstanceNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ def make_node(
and 'nhwc' in tf_layers_dict[graph_node_input.name].keys() else False
}

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(
Expand Down Expand Up @@ -164,7 +165,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/LayerNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output_1.name, None) is not None:
onnx_tensor_infos = {
graph_node_output_1.name:
onnx_tensor_infos_for_validation[graph_node_output_1.name]
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/MatMul.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ def make_node(
onnx_tensor_infos = None
validation_data_1 = None
validation_data_2 = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos, validation_data_1, validation_data_2 = \
acquisition_of_validation_data(
input_tensor_1=input_tensor_1,
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceL1.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -154,7 +155,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceL2.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -154,7 +155,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceLogSum.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -154,7 +155,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceLogSumExp.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -154,7 +155,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceMax.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -161,7 +162,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceMean.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -155,7 +156,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceMin.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -155,7 +156,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceProd.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -155,7 +156,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceSum.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -154,7 +155,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
6 changes: 4 additions & 2 deletions onnx2tf/ops/ReduceSumSquare.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ def make_node(
onnx_tensor_infos = None
validation_data = None

if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -155,7 +156,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/Reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down
7 changes: 5 additions & 2 deletions onnx2tf/ops/Softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ def make_node(
and onnx_input_shapes[pre_convert_axis] == tf_input_shapes[axis]:
acc_check_pass_flg = True

if onnx_tensor_infos_for_validation is not None and not acc_check_pass_flg:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None \
and not acc_check_pass_flg:
# Get the output tensor of one previous OP of TensorFlow only once
if not disable_strict_mode:
tf_model_inputs = get_tf_model_inputs(tf_layers_dict=tf_layers_dict)
Expand Down Expand Up @@ -180,7 +182,8 @@ def make_node(

# Get ONNX inference results
onnx_tensor_infos = None
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name:
onnx_tensor_infos_for_validation[graph_node_output.name]
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5660,7 +5660,8 @@ def acquisition_of_validation_data(

# Get ONNX inference results
onnx_tensor_infos = {}
if onnx_tensor_infos_for_validation is not None:
if onnx_tensor_infos_for_validation is not None \
and onnx_tensor_infos_for_validation.get(graph_node_output.name, None) is not None:
onnx_tensor_infos = {
graph_node_output.name: onnx_tensor_infos_for_validation[graph_node_output.name]
}
Expand Down

0 comments on commit bf8e894

Please sign in to comment.