Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add client close timeout on term #175

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 65 additions & 24 deletions src/amqproxy/cli.cr
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ require "ini"
require "log"

class AMQProxy::CLI
Log = ::Log.for(self)

@listen_address = ENV["LISTEN_ADDRESS"]? || "localhost"
@listen_port = ENV["LISTEN_PORT"]? || 5673
@log_level : Log::Severity = Log::Severity::Info
@log_level : ::Log::Severity = ::Log::Severity::Info
@idle_connection_timeout : Int32 = ENV.fetch("IDLE_CONNECTION_TIMEOUT", "5").to_i
@term_timeout = -1
@term_client_close_timeout = 0
@upstream = ENV["AMQP_URL"]?

def parse_config(path) # ameba:disable Metrics/CyclomaticComplexity
Expand All @@ -19,19 +22,20 @@ class AMQProxy::CLI
when "main", ""
section.each do |key, value|
case key
when "upstream" then @upstream = value
when "log_level" then @log_level = Log::Severity.parse(value)
when "idle_connection_timeout" then @idle_connection_timeout = value.to_i
when "term_timeout" then @term_timeout = value.to_i
else raise "Unsupported config #{name}/#{key}"
when "upstream" then @upstream = value
when "log_level" then @log_level = ::Log::Severity.parse(value)
when "idle_connection_timeout" then @idle_connection_timeout = value.to_i
when "term_timeout" then @term_timeout = value.to_i
when "term_client_close_timeout" then @term_client_close_timeout = value.to_i
else raise "Unsupported config #{name}/#{key}"
end
end
when "listen"
section.each do |key, value|
case key
when "port" then @listen_port = value
when "bind", "address" then @listen_address = value
when "log_level" then @log_level = Log::Severity.parse(value)
when "log_level" then @log_level = ::Log::Severity.parse(value)
else raise "Unsupported config #{name}/#{key}"
end
end
Expand All @@ -52,10 +56,13 @@ class AMQProxy::CLI
parser.on("-t IDLE_CONNECTION_TIMEOUT", "--idle-connection-timeout=SECONDS", "Maxiumum time in seconds an unused pooled connection stays open (default 5s)") do |v|
@idle_connection_timeout = v.to_i
end
parser.on("--term-timeout=SECONDS", "At TERM the server will wait this many seconds for clients to gracefully close their sockets (default: infinite)") do |v|
parser.on("--term-timeout=SECONDS", "At TERM the server waits SECONDS seconds for clients to gracefully close their sockets after Close has been sent (default: infinite)") do |v|
@term_timeout = v.to_i
end
parser.on("-d", "--debug", "Verbose logging") { @log_level = Log::Severity::Debug }
parser.on("--term-client-close-timeout=SECONDS", "At TERM the server waits SECONDS seconds for clients to send Close beforing sending Close to clients (default: 0s)") do |v|
@term_client_close_timeout = v.to_i
end
parser.on("-d", "--debug", "Verbose logging") { @log_level = ::Log::Severity::Debug }
parser.on("-c FILE", "--config=FILE", "Load config file") { |v| parse_config(v) }
parser.on("-h", "--help", "Show this help") { puts parser.to_s; exit 0 }
parser.on("-v", "--version", "Display version") { puts AMQProxy::VERSION.to_s; exit 0 }
Expand All @@ -77,42 +84,76 @@ class AMQProxy::CLI
tls = u.scheme == "amqps"

log_backend = if ENV.has_key?("JOURNAL_STREAM")
Log::IOBackend.new(formatter: JournalLogFormat, dispatcher: ::Log::DirectDispatcher)
::Log::IOBackend.new(formatter: Journal::LogFormat, dispatcher: ::Log::DirectDispatcher)
else
Log::IOBackend.new(formatter: StdoutLogFormat, dispatcher: ::Log::DirectDispatcher)
::Log::IOBackend.new(formatter: Stdout::LogFormat, dispatcher: ::Log::DirectDispatcher)
end
Log.setup_from_env(default_level: @log_level, backend: log_backend)
::Log.setup_from_env(default_level: @log_level, backend: log_backend)

server = AMQProxy::Server.new(u.hostname || "", port, tls, @idle_connection_timeout)

first_shutdown = true
shutdown = ->(_s : Signal) do
initiate_shutdown = ->(_s : Signal) do
if first_shutdown
first_shutdown = false
server.stop_accepting_clients
server.disconnect_clients
if @term_timeout >= 0
spawn do
sleep @term_timeout
abort "Exiting with #{server.client_connections} client connections still open"
end
end
else
abort "Exiting with #{server.client_connections} client connections still open"
end
end
Signal::INT.trap &shutdown
Signal::TERM.trap &shutdown
Signal::INT.trap &initiate_shutdown
Signal::TERM.trap &initiate_shutdown

server.listen(@listen_address, @listen_port.to_i)

shutdown server

# wait until all client connections are closed
until server.client_connections.zero?
sleep 0.2
end
Log.info { "No clients left. Exiting." }
end

def shutdown(server)
if server.client_connections > 0
if @term_client_close_timeout > 0
wait_for_clients_to_close(server, @term_client_close_timeout.seconds)
end
server.disconnect_clients
end

if server.client_connections > 0
if @term_timeout >= 0
spawn do
sleep @term_timeout
abort "Exiting with #{server.client_connections} client connections still open"
end
end
end
end

def wait_for_clients_to_close(server, close_timeout)
Log.info { "Waiting for clients to close their connections." }
ch = Channel(Bool).new
spawn do
loop do
ch.send true if server.client_connections.zero?
sleep 0.1.seconds
end
rescue Channel::ClosedError
end

select
when ch.receive?
Log.info { "All clients has closed their connections." }
when timeout close_timeout
ch.close
Log.info { "Timeout waiting for clients to close their connections." }
end
end

struct JournalLogFormat < Log::StaticFormatter
struct Journal::LogFormat < ::Log::StaticFormatter
def run
source
context(before: '[', after: ']')
Expand All @@ -122,7 +163,7 @@ class AMQProxy::CLI
end
end

struct StdoutLogFormat < Log::StaticFormatter
struct Stdout::LogFormat < ::Log::StaticFormatter
def run
timestamp
severity
Expand Down