Skip to content

Commit

Permalink
Enable serialization of predict response as tensor content.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689691887
  • Loading branch information
tensorflower-gardener authored and tensorflow-copybara committed Oct 25, 2024
1 parent 36aa4a6 commit 1e16551
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions tensorflow_serving/model_servers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ cc_library(
"//tensorflow_serving/servables/tensorflow:get_model_metadata_impl",
"//tensorflow_serving/servables/tensorflow:multi_inference",
"//tensorflow_serving/servables/tensorflow:predict_impl",
"//tensorflow_serving/servables/tensorflow:predict_response_tensor_serialization_option",
"//tensorflow_serving/servables/tensorflow:regression_service",
"//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter",
"//tensorflow_serving/servables/tensorflow:session_bundle_config_cc_proto",
Expand Down
6 changes: 5 additions & 1 deletion tensorflow_serving/model_servers/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,11 @@ int main(int argc, char** argv) {
"Whether to skip auto initializing TPU."),
tensorflow::Flag("enable_grpc_healthcheck_service",
&options.enable_grpc_healthcheck_service,
"Enable the standard gRPC healthcheck service.")};
"Enable the standard gRPC healthcheck service."),
tensorflow::Flag(
"enable_serialization_as_tensor_content",
&options.enable_serialization_as_tensor_content,
"Enable serialization of predict response as tensor content.")};

const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list);
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {
Expand Down
5 changes: 5 additions & 0 deletions tensorflow_serving/model_servers/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ limitations under the License.
#include "tensorflow_serving/model_servers/model_platform_types.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/model_servers/server_init.h"
#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
#include "tensorflow_serving/servables/tensorflow/util.h"
Expand Down Expand Up @@ -321,6 +322,10 @@ Status Server::BuildAndStart(const Options& server_options) {
options.force_allow_any_version_labels_for_unavailable_models =
server_options.force_allow_any_version_labels_for_unavailable_models;
options.enable_cors_support = server_options.enable_cors_support;
if (server_options.enable_serialization_as_tensor_content) {
options.predict_response_tensor_serialization_option =
internal::PredictResponseTensorSerializationOption::kAsProtoContent;
}

TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_));

Expand Down
2 changes: 2 additions & 0 deletions tensorflow_serving/model_servers/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class Server {
bool skip_initialize_tpu = false;
// Misc GRPC features
bool enable_grpc_healthcheck_service = false;
// Control whether to serialize predict response as tensor content.
bool enable_serialization_as_tensor_content = false;
Options();
};

Expand Down
3 changes: 3 additions & 0 deletions tensorflow_serving/servables/tensorflow/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ cc_test(
cc_library(
name = "predict_response_tensor_serialization_option",
hdrs = ["predict_response_tensor_serialization_option.h"],
visibility = [
"//visibility:public",
],
)

cc_library(
Expand Down

0 comments on commit 1e16551

Please sign in to comment.