diff --git a/doctr/file_utils.py b/doctr/file_utils.py index fc1129b0c..53263345b 100644 --- a/doctr/file_utils.py +++ b/doctr/file_utils.py @@ -80,10 +80,7 @@ def ensure_keras_v2() -> None: # pragma: no cover else: logging.info(f"TensorFlow version {_tf_version} available.") ensure_keras_v2() - import tensorflow as tf - # Enable eager execution - this is required for some models to work properly - tf.config.run_functions_eagerly(True) else: # pragma: no cover logging.info("Disabling Tensorflow because USE_TORCH is set") _tf_available = False diff --git a/doctr/models/recognition/parseq/tensorflow.py b/doctr/models/recognition/parseq/tensorflow.py index d8c54527b..0a72d33fc 100644 --- a/doctr/models/recognition/parseq/tensorflow.py +++ b/doctr/models/recognition/parseq/tensorflow.py @@ -250,7 +250,6 @@ def decode( target_query = self.dropout(target_query, **kwargs) return self.decoder(target_query, content, memory, target_mask, **kwargs) - @tf.function def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor: """Generate predictions for the given features.""" max_length = max_len if max_len is not None else self.max_length diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py index c04a4b289..792089bda 100644 --- a/doctr/models/utils/tensorflow.py +++ b/doctr/models/utils/tensorflow.py @@ -161,6 +161,10 @@ def export_model_to_onnx( ------- the path to the exported model and a list with the output layer names """ + # get the users eager mode + eager_mode = tf.executing_eagerly() + # set eager mode to true to avoid issues with tf2onnx + tf.config.run_functions_eagerly(True) large_model = kwargs.get("large_model", False) model_proto, _ = tf2onnx.convert.from_keras( model, @@ -171,6 +175,9 @@ def export_model_to_onnx( # Get the output layer names output = [n.name for n in model_proto.graph.output] + # reset the eager mode to the users mode + tf.config.run_functions_eagerly(eager_mode) + # models which are too large (weights > 2GB while converting to ONNX) needs to be handled # about an external tensor storage where the graph and weights are seperatly stored in a archive if large_model: