diff --git a/README.md b/README.md
index 7a8bbe2..406fb93 100644
--- a/README.md
+++ b/README.md
@@ -95,9 +95,9 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server
# Install via a package manager
-| OS | Method | Command |
-|-------------------------------|------------|--------------------------------------------|
-| Arch Linux | AUR | `yay -S onnxruntime-server` |
+| OS | Method | Command |
+|------------|--------|-----------------------------|
+| Arch Linux | AUR | `yay -S onnxruntime-server` |
----
@@ -127,11 +127,12 @@ sudo cmake --install build --prefix /usr/local/onnxruntime-server
## Options
-| Option | Environment | Description |
-|-------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| `--workers` | `ONNX_SERVER_WORKERS` | Worker thread pool size.
Default: `4` |
-| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path
The onnx model files must be located in the following path:
`${model_dir}/${model_name}/${model_version}/model.onnx`
Default: `models` |
-| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.
Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.
Available session_options are
- cuda=device_id`[ or true or false]`
eg) `model1:v1 model2:v9`
`model1:v1(cuda=true) model2:v9(cuda=1)` |
+| Option | Environment | Description |
+|---------------------------|-------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| `--workers` | `ONNX_SERVER_WORKERS` | Worker thread pool size.
Default: `4` |
+| `--request-payload-limit` | `ONNX_SERVER_REQUEST_PAYLOAD_LIMIT` | HTTP/HTTPS request payload size limit.
Default: `1024 * 1024 * 10` (10MB) |
+| `--model-dir` | `ONNX_SERVER_MODEL_DIR` | Model directory path
The onnx model files must be located in the following path:
`${model_dir}/${model_name}/${model_version}/model.onnx`
Default: `models` |
+| `--prepare-model` | `ONNX_SERVER_PREPARE_MODEL` | Pre-create some model sessions at server startup.
Format as a space-separated list of `model_name:model_version` or `model_name:model_version(session_options, ...)`.
Available session_options are
- cuda=device_id`[ or true or false]`
eg) `model1:v1 model2:v9`
`model1:v1(cuda=true) model2:v9(cuda=1)` |
### Backend options
diff --git a/src/onnxruntime_server.hpp b/src/onnxruntime_server.hpp
index d0e8743..0a85e39 100644
--- a/src/onnxruntime_server.hpp
+++ b/src/onnxruntime_server.hpp
@@ -311,6 +311,7 @@ namespace onnxruntime_server {
std::string model_dir;
std::string prepare_model;
model_bin_getter_t model_bin_getter{};
+ long request_payload_limit = 1024 * 1024 * 10;
};
namespace transport {
@@ -322,6 +323,7 @@ namespace onnxruntime_server {
asio::socket socket;
asio::acceptor acceptor;
uint_least16_t assigned_port = 0;
+ long request_payload_limit_;
onnx::session_manager *onnx_session_manager;
@@ -331,12 +333,13 @@ namespace onnxruntime_server {
public:
server(
boost::asio::io_context &io_context, onnx::session_manager *onnx_session_manager,
- builtin_thread_pool *worker_pool, int port
+ builtin_thread_pool *worker_pool, int port, long request_payload_limit
);
~server();
builtin_thread_pool *get_worker_pool();
onnx::session_manager *get_onnx_session_manager();
+ [[nodiscard]] long request_payload_limit() const;
[[nodiscard]] uint_least16_t port() const;
};
diff --git a/src/standalone/standalone.cpp b/src/standalone/standalone.cpp
index e41abbc..df7db1c 100644
--- a/src/standalone/standalone.cpp
+++ b/src/standalone/standalone.cpp
@@ -15,9 +15,13 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) {
po_desc.add_options()("help,h", "Produce help message\n");
// env: ONNX_WORKERS
po_desc.add_options()(
- "workers", po::value()->default_value(4),
+ "workers", po::value()->default_value(4),
"env: ONNX_SERVER_WORKERS\nWorker thread pool size.\nDefault: 4"
);
+ po_desc.add_options()(
+ "request-payload-limit", po::value()->default_value(1024 * 1024 * 10),
+ "env: ONNX_SERVER_REQUEST_PAYLOAD_LIMIT\nHTTP/HTTPS request payload size limit.\nDefault: 1024 * 1024 * 10(10MB)"
+ );
po_desc.add_options()(
"model-dir", po::value()->default_value("models"),
"env: ONNX_SERVER_MODEL_DIR\nModel directory path.\nThe onnx model files must be located in the "
@@ -156,7 +160,10 @@ int onnxruntime_server::standalone::init_config(int argc, char **argv) {
AixLog::Log::init({log_file, log_access_file});
if (vm.count("workers"))
- config.num_threads = vm["workers"].as();
+ config.num_threads = vm["workers"].as();
+
+ if (vm.count("request-payload-limit"))
+		config.request_payload_limit = vm["request-payload-limit"].as<long>();
if (vm.count("model-dir"))
config.model_dir = vm["model-dir"].as();
diff --git a/src/test/e2e/e2e_test_http_server.cpp b/src/test/e2e/e2e_test_http_server.cpp
index c0f84a5..6a4aa58 100644
--- a/src/test/e2e/e2e_test_http_server.cpp
+++ b/src/test/e2e/e2e_test_http_server.cpp
@@ -87,6 +87,137 @@ TEST(test_onnxruntime_server_http, HttpServerTest) {
ASSERT_GT(res_json["output"][0], 0);
}
+ { // API: Execute session large request
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ int size = 1000000;
+ for (int i = 0; i < size; i++) {
+ input["x"].push_back(input["x"][0]);
+ input["y"].push_back(input["y"][0]);
+ input["z"].push_back(input["z"][0]);
+ }
+ std::cout << input.dump().length() << " bytes\n";
+
+ bool exception = false;
+ try {
+ TIME_MEASURE_START
+ auto res =
+ http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ } catch (std::exception &e) {
+ exception = true;
+ std::cout << e.what() << std::endl;
+ }
+ ASSERT_TRUE(exception);
+ }
+
+ { // API: Destroy session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Destroy sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_TRUE(res_json);
+ }
+
+ { // API: List session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json.size(), 0);
+ }
+
+ running = false;
+ server_thread.join();
+}
+
+TEST(test_onnxruntime_server_http, HttpServerLargeRequestTest) {
+ Orts::config config;
+ config.http_port = 0;
+ config.model_bin_getter = test_model_bin_getter;
+ config.request_payload_limit = 1024 * 1024 * 1024;
+
+ boost::asio::io_context io_context;
+ Orts::onnx::session_manager manager(config.model_bin_getter);
+ Orts::builtin_thread_pool worker_pool(config.num_threads);
+ Orts::transport::http::http_server server(io_context, config, &manager, &worker_pool);
+
+ bool running = true;
+ std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); });
+
+ TIME_MEASURE_INIT
+
+ { // API: Create session
+ json body = json::parse(R"({"model":"sample","version":"1"})");
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Create session\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json["model"], "sample");
+ ASSERT_EQ(res_json["version"], "1");
+ }
+
+ { // API: Get session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Get session\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json["model"], "sample");
+ ASSERT_EQ(res_json["version"], "1");
+ }
+
+ { // API: List session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json.size(), 1);
+ ASSERT_EQ(res_json[0]["model"], "sample");
+ ASSERT_EQ(res_json[0]["version"], "1");
+ }
+
+ { // API: Execute session
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Execute sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_TRUE(res_json.contains("output"));
+ ASSERT_EQ(res_json["output"].size(), 1);
+ ASSERT_GT(res_json["output"][0], 0);
+ }
+
+ { // API: Execute session large request
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ int size = 1000000;
+ for (int i = 0; i < size; i++) {
+ input["x"].push_back(input["x"][0]);
+ input["y"].push_back(input["y"][0]);
+ input["z"].push_back(input["z"][0]);
+ }
+ std::cout << input.dump().length() << " bytes\n";
+
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ ASSERT_TRUE(res_json.contains("output"));
+ ASSERT_EQ(res_json["output"].size(), size + 1);
+ ASSERT_GT(res_json["output"][0], 0);
+ }
+
{ // API: Destroy session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
@@ -132,6 +263,5 @@ http_request(beast::http::verb method, const std::string &target, short port, st
beast::http::read(socket, buffer, res);
- socket.close();
return res;
}
diff --git a/src/test/e2e/e2e_test_https_server.cpp b/src/test/e2e/e2e_test_https_server.cpp
index 28d886f..4f0b1b4 100644
--- a/src/test/e2e/e2e_test_https_server.cpp
+++ b/src/test/e2e/e2e_test_https_server.cpp
@@ -89,6 +89,30 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) {
ASSERT_GT(res_json["output"][0], 0);
}
+ { // API: Execute session large request
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ int size = 1000000;
+ for (int i = 0; i < size; i++) {
+ input["x"].push_back(input["x"][0]);
+ input["y"].push_back(input["y"][0]);
+ input["z"].push_back(input["z"][0]);
+ }
+ std::cout << input.dump().length() << " bytes\n";
+
+ bool exception = false;
+ try {
+
+ TIME_MEASURE_START
+ auto res =
+ http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ } catch (std::exception &e) {
+ exception = true;
+ std::cout << e.what() << std::endl;
+ }
+ ASSERT_TRUE(exception);
+ }
+
{ // API: Destroy session
TIME_MEASURE_START
auto res = http_request(boost::beast::http::verb::delete_, "/api/sessions/sample/1", server.port(), "");
@@ -113,6 +137,96 @@ TEST(test_onnxruntime_server_http, HttpsServerTest) {
server_thread.join();
}
+TEST(test_onnxruntime_server_http, HttpsServerLargeRequestTest) {
+ Orts::config config;
+ config.https_port = 0;
+ config.https_cert = (test_dir / "ssl" / "server-cert.pem").string();
+ config.https_key = (test_dir / "ssl" / "server-key.pem").string();
+ config.model_bin_getter = test_model_bin_getter;
+ config.request_payload_limit = 1024 * 1024 * 1024;
+
+ boost::asio::io_context io_context;
+ Orts::onnx::session_manager manager(config.model_bin_getter);
+ Orts::builtin_thread_pool worker_pool(config.num_threads);
+ Orts::transport::http::https_server server(io_context, config, &manager, &worker_pool);
+
+ bool running = true;
+ std::thread server_thread([&io_context, &running]() { test_server_run(io_context, &running); });
+
+ TIME_MEASURE_INIT
+
+ { // API: Create session
+ json body = json::parse(R"({"model":"sample","version":"1"})");
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions", server.port(), body.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Create session\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json["model"], "sample");
+ ASSERT_EQ(res_json["version"], "1");
+ }
+
+ { // API: Get session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::get, "/api/sessions/sample/1", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Get session\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json["model"], "sample");
+ ASSERT_EQ(res_json["version"], "1");
+ }
+
+ { // API: List session
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::get, "/api/sessions", server.port(), "");
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: List sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_EQ(res_json.size(), 1);
+ ASSERT_EQ(res_json[0]["model"], "sample");
+ ASSERT_EQ(res_json[0]["version"], "1");
+ }
+
+ { // API: Execute session
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ std::cout << "API: Execute sessions\n" << res_json.dump(2) << "\n";
+ ASSERT_TRUE(res_json.contains("output"));
+ ASSERT_EQ(res_json["output"].size(), 1);
+ ASSERT_GT(res_json["output"][0], 0);
+ }
+
+ { // API: Execute session large request
+ auto input = json::parse(R"({"x":[[1]],"y":[[2]],"z":[[3]]})");
+ int size = 1000000;
+ for (int i = 0; i < size; i++) {
+ input["x"].push_back(input["x"][0]);
+ input["y"].push_back(input["y"][0]);
+ input["z"].push_back(input["z"][0]);
+ }
+ std::cout << input.dump().length() << " bytes\n";
+
+ TIME_MEASURE_START
+ auto res = http_request(boost::beast::http::verb::post, "/api/sessions/sample/1", server.port(), input.dump());
+ TIME_MEASURE_STOP
+ ASSERT_EQ(res.result(), boost::beast::http::status::ok);
+ json res_json = json::parse(boost::beast::buffers_to_string(res.body().data()));
+ ASSERT_TRUE(res_json.contains("output"));
+ ASSERT_EQ(res_json["output"].size(), size + 1);
+ ASSERT_GT(res_json["output"][0], 0);
+ }
+
+ running = false;
+ server_thread.join();
+}
+
+beast::http::response<beast::http::dynamic_body>
http_request(beast::http::verb method, const std::string &target, short port, std::string body) {
boost::asio::io_context ioc;
@@ -143,7 +257,5 @@ http_request(beast::http::verb method, const std::string &target, short port, st
beast::http::read(stream, buffer, res);
- stream.shutdown();
- stream.lowest_layer().close();
return res;
}
diff --git a/src/transport/http/http_server.cpp b/src/transport/http/http_server.cpp
index c0d2efb..606ea4f 100644
--- a/src/transport/http/http_server.cpp
+++ b/src/transport/http/http_server.cpp
@@ -9,7 +9,7 @@ onnxruntime_server::transport::http::http_server::http_server(
onnxruntime_server::onnx::session_manager *onnx_session_manager,
onnxruntime_server::builtin_thread_pool *worker_pool
)
- : server(io_context, onnx_session_manager, worker_pool, config.http_port), swagger(config.swagger_url_path) {
+ : server(io_context, onnx_session_manager, worker_pool, config.http_port, config.request_payload_limit), swagger(config.swagger_url_path) {
acceptor.set_option(boost::asio::socket_base::reuse_address(true));
}
diff --git a/src/transport/http/http_server.hpp b/src/transport/http/http_server.hpp
index e3e1f72..03ef00d 100644
--- a/src/transport/http/http_server.hpp
+++ b/src/transport/http/http_server.hpp
@@ -36,7 +36,7 @@ namespace onnxruntime_server::transport::http {
template <class session_type> class http_session_base : public std::enable_shared_from_this<session_type> {
protected:
beast::flat_buffer buffer;
- beast::http::request<beast::http::string_body> req;
+ std::shared_ptr<beast::http::request_parser<beast::http::string_body>> req_parser;
virtual onnx::session_manager *get_onnx_session_manager() = 0;
std::shared_ptr<beast::http::response<beast::http::string_body>> handle_request();
diff --git a/src/transport/http/http_session.cpp b/src/transport/http/http_session.cpp
index 5281eb4..1a75fc7 100644
--- a/src/transport/http/http_session.cpp
+++ b/src/transport/http/http_session.cpp
@@ -30,10 +30,12 @@ void onnxruntime_server::transport::http::http_session::do_read() {
_remote_endpoint = stream.socket().remote_endpoint().address().to_string() + ":" +
std::to_string(stream.socket().remote_endpoint().port());
- req = {};
- stream.expires_after(std::chrono::seconds(30));
+ // stream.expires_after(std::chrono::seconds(300));
+ req_parser->body_limit(server->request_payload_limit());
- beast::http::async_read(stream, buffer, req, beast::bind_front_handler(&http_session::on_read, shared_from_this()));
+ beast::http::async_read(
+ stream, buffer, *req_parser, beast::bind_front_handler(&http_session::on_read, shared_from_this())
+ );
}
void onnxruntime_server::transport::http::http_session::on_read(beast::error_code ec, std::size_t bytes_transferred) {
@@ -55,6 +57,7 @@ void onnxruntime_server::transport::http::http_session::on_read(beast::error_cod
void onnxruntime_server::transport::http::http_session::do_write(
std::shared_ptr<beast::http::response<beast::http::string_body>> msg
) {
+ auto req = req_parser->get();
PLOG(L_INFO, "ACCESS") << get_remote_endpoint() << " task: " << req.method_string() << " " << req.target()
<< " status: " << msg->result_int() << " duration: " << request_time.get_duration()
<< std::endl;
@@ -68,7 +71,7 @@ void onnxruntime_server::transport::http::http_session::do_write(
return self->close();
}
- if (!self->req.keep_alive())
+ if (!self->req_parser->get().keep_alive())
return self->close();
self->do_read();
diff --git a/src/transport/http/http_session_base.cpp b/src/transport/http/http_session_base.cpp
index 8000249..6cab8f1 100644
--- a/src/transport/http/http_session_base.cpp
+++ b/src/transport/http/http_session_base.cpp
@@ -1,7 +1,8 @@
#include "http_server.hpp"
template <class session_type>
-onnxruntime_server::transport::http::http_session_base<session_type>::http_session_base() : buffer(), req() {
+onnxruntime_server::transport::http::http_session_base<session_type>::http_session_base() : buffer() {
+	req_parser = std::make_shared<beast::http::request_parser<beast::http::string_body>>();
}
#define CONTENT_TYPE_PLAIN_TEXT "text/plain"
@@ -10,8 +11,11 @@ onnxruntime_server::transport::http::http_session_base::http_session_ba
template <class session_type>
std::shared_ptr<beast::http::response<beast::http::string_body>>
onnxruntime_server::transport::http::http_session_base<session_type>::handle_request() {
+ auto req = req_parser->get();
+
auto const simple_response =
[this](beast::http::status method, beast::string_view content_type, beast::string_view body) {
+ auto req = req_parser->get();
auto res = std::make_shared<beast::http::response<beast::http::string_body>>(method, req.version());
res->set(beast::http::field::content_type, content_type);
res->keep_alive(req.keep_alive());
diff --git a/src/transport/http/https_server.cpp b/src/transport/http/https_server.cpp
index 34ac785..f45e718 100644
--- a/src/transport/http/https_server.cpp
+++ b/src/transport/http/https_server.cpp
@@ -9,7 +9,7 @@ onnxruntime_server::transport::http::https_server::https_server(
onnxruntime_server::onnx::session_manager *onnx_session_manager,
onnxruntime_server::builtin_thread_pool *worker_pool
)
- : server(io_context, onnx_session_manager, worker_pool, config.https_port), ctx(boost::asio::ssl::context::sslv23),
+ : server(io_context, onnx_session_manager, worker_pool, config.https_port, config.request_payload_limit), ctx(boost::asio::ssl::context::sslv23),
swagger(config.swagger_url_path) {
boost::system::error_code ec;
ctx.set_options(
diff --git a/src/transport/http/https_session.cpp b/src/transport/http/https_session.cpp
index 2eb33ee..6f39227 100644
--- a/src/transport/http/https_session.cpp
+++ b/src/transport/http/https_session.cpp
@@ -35,10 +35,10 @@ void onnxruntime_server::transport::http::https_session::do_read() {
_remote_endpoint = stream.lowest_layer().remote_endpoint().address().to_string() + ":" +
std::to_string(stream.lowest_layer().remote_endpoint().port());
- req = {};
+ req_parser->body_limit(server->request_payload_limit());
beast::http::async_read(
- stream, buffer, req, beast::bind_front_handler(&https_session::on_read, shared_from_this())
+ stream, buffer, *req_parser, beast::bind_front_handler(&https_session::on_read, shared_from_this())
);
}
@@ -59,6 +59,7 @@ void onnxruntime_server::transport::http::https_session::on_read(beast::error_co
void onnxruntime_server::transport::http::https_session::do_write(
std::shared_ptr<beast::http::response<beast::http::string_body>> msg
) {
+ auto req = req_parser->get();
PLOG(L_INFO, "ACCESS") << get_remote_endpoint() << " task: " << req.method_string() << " " << req.target()
<< " status: " << msg->result_int() << " duration: " << request_time.get_duration()
<< std::endl;
@@ -72,7 +73,7 @@ void onnxruntime_server::transport::http::https_session::do_write(
return self->close();
}
- if (!self->req.keep_alive())
+ if (!self->req_parser->get().keep_alive())
return self->close();
self->do_read();
diff --git a/src/transport/server.cpp b/src/transport/server.cpp
index cd97aef..6f95f2c 100644
--- a/src/transport/server.cpp
+++ b/src/transport/server.cpp
@@ -6,10 +6,10 @@
Orts::transport::server::server(
boost::asio::io_context &io_context, Orts::onnx::session_manager *onnx_session_manager,
- Orts::builtin_thread_pool *worker_pool, int port
+ Orts::builtin_thread_pool *worker_pool, int port, long request_payload_limit
)
: io_context(io_context), acceptor(io_context, asio::endpoint(asio::v4(), port)), socket(io_context),
- onnx_session_manager(onnx_session_manager), worker_pool(worker_pool) {
+ onnx_session_manager(onnx_session_manager), worker_pool(worker_pool), request_payload_limit_(request_payload_limit) {
assigned_port = acceptor.local_endpoint().port();
@@ -37,6 +37,10 @@ Orts::onnx::session_manager *Orts::transport::server::get_onnx_session_manager()
return onnx_session_manager;
}
+long Orts::transport::server::request_payload_limit() const {
+ return request_payload_limit_;
+}
+
uint_least16_t Orts::transport::server::port() const {
return assigned_port;
}
diff --git a/src/transport/tcp/tcp_server.cpp b/src/transport/tcp/tcp_server.cpp
index 9cb3fee..42baef8 100644
--- a/src/transport/tcp/tcp_server.cpp
+++ b/src/transport/tcp/tcp_server.cpp
@@ -8,7 +8,7 @@ onnxruntime_server::transport::tcp::tcp_server::tcp_server(
onnxruntime_server::onnx::session_manager *onnx_session_manager,
onnxruntime_server::builtin_thread_pool *worker_pool
)
- : server(io_context, onnx_session_manager, worker_pool, config.tcp_port) {
+ : server(io_context, onnx_session_manager, worker_pool, config.tcp_port, config.request_payload_limit) {
acceptor.set_option(boost::asio::socket_base::reuse_address(true));
}