Skip to content

Commit

Permalink
Merge pull request #689 from PINTO0309/fix_flatten_2dim
Browse files Browse the repository at this point in the history
Addressed the issue of missing conversions when multi-dimensional flattening is performed and the batch size of the first dimension is an undefined dimension.
  • Loading branch information
PINTO0309 authored Sep 11, 2024
2 parents c8ace3b + 6b5b1f4 commit 5baf18d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.25.9
ghcr.io/pinto0309/onnx2tf:1.25.10

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.25.9
docker.io/pinto0309/onnx2tf:1.25.10

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.25.9'
__version__ = '1.25.10'
16 changes: 14 additions & 2 deletions onnx2tf/ops/Flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,20 @@ def make_node(
cal_shape = (1, -1)
elif axis >= input_tensor_rank:
cal_shape = (-1, 1)
elif graph_node_output.shape is not None and len(graph_node_output.shape) == 2 and axis == input_tensor_rank - 1:
cal_shape = (1, -1)
elif graph_node_output.shape is not None \
and len(graph_node_output.shape) == 2 \
and axis == input_tensor_rank - 1 \
and not isinstance(graph_node_output.shape[0], str):
cal_shape = (graph_node_output.shape[0], -1)
elif graph_node_output.shape is not None \
and len(graph_node_output.shape) == 2 \
and axis == input_tensor_rank - 1 \
and isinstance(graph_node_output.shape[0], str):
try:
dim_prod = int(np.prod(graph_node_output.shape[1:]))
cal_shape = (-1, dim_prod)
except:
cal_shape = (1, -1)
elif input_tensor_rank >= 2 \
and input_tensor_shape[0] is None \
and len([idx for idx in input_tensor_shape[1:] if idx is not None]) == input_tensor_rank - 1 \
Expand Down

0 comments on commit 5baf18d

Please sign in to comment.