diff --git a/tensorflow_serving/model_servers/BUILD b/tensorflow_serving/model_servers/BUILD index 63871980ad0..bc71895c6c7 100644 --- a/tensorflow_serving/model_servers/BUILD +++ b/tensorflow_serving/model_servers/BUILD @@ -497,6 +497,7 @@ cc_library( "//tensorflow_serving/servables/tensorflow:get_model_metadata_impl", "//tensorflow_serving/servables/tensorflow:multi_inference", "//tensorflow_serving/servables/tensorflow:predict_impl", + "//tensorflow_serving/servables/tensorflow:predict_response_tensor_serialization_option", "//tensorflow_serving/servables/tensorflow:regression_service", "//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter", "//tensorflow_serving/servables/tensorflow:session_bundle_config_cc_proto", diff --git a/tensorflow_serving/model_servers/main.cc b/tensorflow_serving/model_servers/main.cc index 16f327a82f6..32454b86b27 100644 --- a/tensorflow_serving/model_servers/main.cc +++ b/tensorflow_serving/model_servers/main.cc @@ -307,7 +307,11 @@ int main(int argc, char** argv) { "Whether to skip auto initializing TPU."), tensorflow::Flag("enable_grpc_healthcheck_service", &options.enable_grpc_healthcheck_service, - "Enable the standard gRPC healthcheck service.")}; + "Enable the standard gRPC healthcheck service."), + tensorflow::Flag( + "enable_serialization_as_tensor_content", + &options.enable_serialization_as_tensor_content, + "Enable serialization of predict response as tensor content.")}; const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list); if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { diff --git a/tensorflow_serving/model_servers/server.cc b/tensorflow_serving/model_servers/server.cc index 0c9147cacd9..0655b3f33b5 100644 --- a/tensorflow_serving/model_servers/server.cc +++ b/tensorflow_serving/model_servers/server.cc @@ -52,6 +52,7 @@ limitations under the License. #include "tensorflow_serving/model_servers/model_platform_types.h" #include "tensorflow_serving/model_servers/server_core.h" #include "tensorflow_serving/model_servers/server_init.h" +#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h" #include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h" #include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h" #include "tensorflow_serving/servables/tensorflow/util.h" @@ -321,6 +322,10 @@ Status Server::BuildAndStart(const Options& server_options) { options.force_allow_any_version_labels_for_unavailable_models = server_options.force_allow_any_version_labels_for_unavailable_models; options.enable_cors_support = server_options.enable_cors_support; + if (server_options.enable_serialization_as_tensor_content) { + options.predict_response_tensor_serialization_option = + internal::PredictResponseTensorSerializationOption::kAsProtoContent; + } TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_)); diff --git a/tensorflow_serving/model_servers/server.h b/tensorflow_serving/model_servers/server.h index babd34ee766..ac7828fd4d8 100644 --- a/tensorflow_serving/model_servers/server.h +++ b/tensorflow_serving/model_servers/server.h @@ -105,6 +105,8 @@ class Server { bool skip_initialize_tpu = false; // Misc GRPC features bool enable_grpc_healthcheck_service = false; + // Control whether to serialize predict response as tensor content. + bool enable_serialization_as_tensor_content = false; Options(); }; diff --git a/tensorflow_serving/servables/tensorflow/BUILD b/tensorflow_serving/servables/tensorflow/BUILD index e21397955ae..12f785d41d9 100644 --- a/tensorflow_serving/servables/tensorflow/BUILD +++ b/tensorflow_serving/servables/tensorflow/BUILD @@ -501,6 +501,9 @@ cc_test( cc_library( name = "predict_response_tensor_serialization_option", hdrs = ["predict_response_tensor_serialization_option.h"], + visibility = [ + "//visibility:public", + ], ) cc_library(