Skip to content

Commit

Permalink
Merge pull request #624 from PINTO0309/restore_metadata
Browse files Browse the repository at this point in the history
Fixed to restore metadata
  • Loading branch information
PINTO0309 authored May 7, 2024
2 parents 30f25f7 + b624190 commit a851283
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.20.9
ghcr.io/pinto0309/onnx2tf:1.20.10

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.20.9
docker.io/pinto0309/onnx2tf:1.20.10

or

Expand All @@ -288,6 +288,8 @@ Video speed is adjusted approximately 50 times slower than actual speed.
&& pip install -U onnxruntime==1.17.1 \
&& pip install -U onnxsim==0.4.33 \
&& pip install -U simple_onnx_processing_tools \
&& pip install -U sne4onnx>=1.0.13 \
&& pip install -U sng4onnx>=1.0.4 \
&& pip install -U tensorflow==2.16.1 \
&& pip install -U protobuf==3.20.3 \
&& pip install -U onnx2tf \
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.20.9'
__version__ = '1.20.10'
22 changes: 19 additions & 3 deletions onnx2tf/onnx2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,10 @@ def convert(
tmp_onnx_graph = onnx.load(input_onnx_file_path)
domain: str = tmp_onnx_graph.domain
ir_version: int = tmp_onnx_graph.ir_version
meta_data = {'domain': domain, 'ir_version': ir_version}
metadata_props = None
if hasattr(tmp_onnx_graph, 'metadata_props'):
metadata_props = tmp_onnx_graph.metadata_props
tmp_graph = gs.import_onnx(tmp_onnx_graph)
output_clear = False
for graph_output in tmp_graph.outputs:
Expand All @@ -603,7 +607,10 @@ def convert(
graph_output.shape = None
output_clear = True
if output_clear:
estimated_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(tmp_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
if metadata_props is not None:
exported_onnx_graph.metadata_props.extend(metadata_props)
estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
onnx.save(estimated_graph, f=input_onnx_file_path)
del estimated_graph
except:
Expand Down Expand Up @@ -669,6 +676,10 @@ def convert(

domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version
meta_data = {'domain': domain, 'ir_version': ir_version}
metadata_props = None
if hasattr(onnx_graph, 'metadata_props'):
metadata_props = onnx_graph.metadata_props
graph = gs.import_onnx(onnx_graph)

# List Output
Expand Down Expand Up @@ -758,7 +769,9 @@ def sanitizing(node):
new_output_names.append(output_name)
output_names = new_output_names
try:
onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})
onnx_graph = gs.export_onnx(graph=graph, do_type_check=False, **meta_data)
if metadata_props is not None:
onnx_graph.metadata_props.extend(metadata_props)
except Exception as ex:
# Workaround for SequenceConstruct terminating abnormally with onnx_graphsurgeon
pass
Expand Down Expand Up @@ -984,7 +997,10 @@ def sanitizing(node):
onnx_output_shape = list(onnx_tensor_infos_for_validation[correction_op_output.name].shape)
correction_op_output.shape = onnx_output_shape
try:
estimated_graph = onnx.shape_inference.infer_shapes(gs.export_onnx(graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version}))
exported_onnx_graph = gs.export_onnx(graph, do_type_check=False, **meta_data)
if metadata_props is not None:
exported_onnx_graph.metadata_props.extend(metadata_props)
estimated_graph = onnx.shape_inference.infer_shapes(exported_onnx_graph)
if input_onnx_file_path is not None:
onnx.save(estimated_graph, input_onnx_file_path)
if not not_use_onnxsim:
Expand Down
8 changes: 7 additions & 1 deletion onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3657,6 +3657,10 @@ def dummy_onnx_inference(
# Separate onnx at specified output_names position
domain: str = onnx_graph.domain
ir_version: int = onnx_graph.ir_version
meta_data = {'domain': domain, 'ir_version': ir_version}
metadata_props = None
if hasattr(onnx_graph, 'metadata_props'):
metadata_props = onnx_graph.metadata_props
gs_graph = gs.import_onnx(onnx_graph)

# reduce all axes except batch axis
Expand Down Expand Up @@ -3709,7 +3713,9 @@ def dummy_onnx_inference(
if node_output.dtype is not None:
gs_graph.outputs.append(node_output)

new_onnx_graph = gs.export_onnx(graph=gs_graph, do_type_check=False, **{'domain': domain, 'ir_version': ir_version})
new_onnx_graph = gs.export_onnx(graph=gs_graph, do_type_check=False, **meta_data)
if metadata_props is not None:
new_onnx_graph.metadata_props.extend(metadata_props)
tmp_onnx_path = ''
tmp_onnx_external_weights_path =''
try:
Expand Down

0 comments on commit a851283

Please sign in to comment.