Skip to content

Commit

Permalink
Bug fix. ReduceL1, ReduceL2, ReduceLogSum, ReduceLogSumExp, ReduceMax…
Browse files Browse the repository at this point in the history
…, ReduceMean, ReduceMin, ReduceProd, ReduceSum, ReduceSumSquare
  • Loading branch information
PINTO0309 committed May 5, 2024
1 parent 2517eb9 commit 1537586
Show file tree
Hide file tree
Showing 11 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceL1.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceL2.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceLogSum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceLogSumExp.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceMax.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceMean.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceMin.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceProd.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceSum.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/ops/ReduceSumSquare.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def make_node(
del onnx_tensor_infos_for_validation

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
if onnx_tensor_infos is not None and validation_data is not None and axes is not None:
# Shape Unmatch Error Mitigation Measures
# Search for and transpose shapes that do not cause shape unmatch errors
min_abs_err = sys.maxsize
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/ops/Reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,8 @@ def make_node(
block_size = 1
spase_to_depth_final_shape = []
if not tf_layers_dict[graph_node_output.name].get('unnecessary_reshape', False):
if len(final_shape) == 6:
if (isinstance(final_shape, np.ndarray) or isinstance(final_shape, list)) \
and len(final_shape) == 6:
channel_size = final_shape[1] if isinstance(final_shape[1], int) else None
block_size = final_shape[3] if isinstance(final_shape[3], int) else None

Expand Down

0 comments on commit 1537586

Please sign in to comment.