Skip to content

Commit

Permalink
fix the gaps in blocking commands within transaction and fix paring o… (
Browse files Browse the repository at this point in the history
#36)

* fix the gaps in blocking commands within transaction and fix paring of block option for streams command

* fix keyindex for transaction

* fix key index for eval commands within transaction

* fix wrong number of arguments for subscription commands

* clean up the if elase for handling commands with no arguments

* add client remote ip in log when errenous command executed

* Add support for copy command

---------

Co-authored-by: suryadeepr <[email protected]>
  • Loading branch information
dinesh-murugiah and suryadeepr authored Oct 8, 2024
1 parent 69ece62 commit a851190
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 27 deletions.
12 changes: 12 additions & 0 deletions source/extensions/filters/network/common/redis/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,18 +280,28 @@ struct Transaction {

void setBlockingCommand() {is_blocking_command_ = true;}

void setDiscardTransaction() { discard_ = true; }

void setSendDiscardError() { send_discard_error_ = true; }

bool isTransactionMode() { return is_transaction_mode_; }

bool isSubscribedMode() { return is_subscribed_mode_; }

bool isBlockingCommand() { return is_blocking_command_; }

bool isDiscardTransaction() { return discard_; }

bool shouldSendDiscardError() { return send_discard_error_; }


void close() {
active_ = false;
is_transaction_mode_ = false;
is_subscribed_mode_ = false;
is_blocking_command_ = false;
discard_ = false;
send_discard_error_ = false;
key_.clear();
if (connection_established_) {
for (auto& client : clients_) {
Expand Down Expand Up @@ -320,6 +330,8 @@ struct Transaction {
}

bool active_{false};
bool discard_{false};
bool send_discard_error_{false};
bool connection_established_{false};
bool should_close_{false};
bool is_blocking_command_{false};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct SupportedCommands {
*/
static const absl::flat_hash_set<std::string>& simpleCommands() {
CONSTRUCT_ON_FIRST_USE(
absl::flat_hash_set<std::string>, "append", "bitcount", "bitfield", "bitpos", "decr",
absl::flat_hash_set<std::string>, "append", "bitcount", "bitfield", "bitpos", "copy", "decr",
"decrby", "dump", "expire", "expireat", "geoadd", "geodist", "geohash", "geopos",
"georadius_ro", "georadiusbymember_ro", "get", "getbit", "getdel", "getrange", "getset",
"hdel", "hexists", "hget", "hgetall", "hincrby", "hincrbyfloat", "hkeys", "hlen", "hmget",
Expand Down Expand Up @@ -64,7 +64,7 @@ struct SupportedCommands {
* @return commands which handle Redis transactions allowed non simple commands.
*/
static const absl::flat_hash_set<std::string>& transactionAllowedNonSimpleCommands() {
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "del", "exists", "touch", "unlink", "xack", "xadd", "xautoclaim", "xclaim", "xdel", "xgroup", "xinfo", "xlen", "xpending", "xrange", "xrevrange","xtrim");
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "del", "exists", "touch", "unlink", "xack", "xadd", "xautoclaim", "xclaim", "xdel", "xgroup", "xinfo", "xlen", "xpending", "xrange", "xrevrange","xtrim","eval", "evalsha","xread","xreadgroup","blpop", "brpop", "brpoplpush", "bzpopmax", "bzpopmin","blmove","mset","mget");
}

/**
Expand All @@ -78,7 +78,7 @@ struct SupportedCommands {
* @return commands allowed when a client is in subscribed state.
*/
static const absl::flat_hash_set<std::string>& subcrStateallowedCommands() {
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "subscribe", "psubscribe", "unsubscribe", "punsubscribe","quit", "ping","reset");
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "subscribe", "psubscribe", "unsubscribe", "punsubscribe","quit", "ping");
}

/**
Expand Down Expand Up @@ -113,7 +113,7 @@ struct SupportedCommands {
* @return client sub commands thats supported
*/
static const absl::flat_hash_set<std::string>& clientSubCommands() {
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "getname", "list","setname");
CONSTRUCT_ON_FIRST_USE(absl::flat_hash_set<std::string>, "getname","setname");
}

/**
Expand Down
137 changes: 114 additions & 23 deletions source/extensions/filters/network/redis_proxy/command_splitter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,11 @@ SingleServerRequest::~SingleServerRequest() { ASSERT(!handle_); }
void SingleServerRequest::onResponse(Common::Redis::RespValuePtr&& response) {
handle_ = nullptr;
updateStats(true);
if (callbacks_.transaction().isTransactionMode() && callbacks_.transaction().shouldSendDiscardError() && callbacks_.transaction().should_close_) {
response.reset();
callbacks_.onResponse(Common::Redis::Utility::makeError("EXECABORT Transaction discarded because of previous errors"));
return;
}
callbacks_.onResponse(std::move(response));
}

Expand Down Expand Up @@ -930,6 +935,13 @@ SplitRequestPtr PubSubRequest::create(Router& router, Common::Redis::RespValuePt
std::string key = std::string();
int32_t redisShardsCount=0;

if (Common::Redis::SupportedCommands::subscriptionCommands().contains(command_name) && incoming_request->asArray().size() < 2) {
ENVOY_LOG(debug, "Invalid request: '{}'", incoming_request->toString());
callbacks.onResponse(Common::Redis::Utility::makeError(Response::get().InvalidRequest));
return nullptr;
}


bool singleShardRequest = false;
bool allShardsRequest = false;
bool allShardwithSingleShardRequest = false;
Expand Down Expand Up @@ -984,8 +996,8 @@ SplitRequestPtr PubSubRequest::create(Router& router, Common::Redis::RespValuePt
if (transaction.active_) {
if (transaction.isSubscribedMode()) {
if (!Common::Redis::SupportedCommands::subcrStateallowedCommands().contains(command_name)) {
callbacks.onResponse(
Common::Redis::Utility::makeError("Not supported command in subscribe state"));
callbacks.onResponse(Common::Redis::Utility::makeError("Not supported command in subscribe state"));
return nullptr;
} else if (command_name == "quit") {
transaction.should_close_ = true;
transaction.subscribed_client_shard_index_ = -1;
Expand Down Expand Up @@ -1631,6 +1643,36 @@ void SplitKeysSumResultRequest::onChildResponse(Common::Redis::RespValuePtr&& va
}
}

int32_t TransactionRequest::getShardingKeyIndex(const std::string& command_name, const Common::Redis::RespValue& request) {
if (command_name == "xread" || command_name == "xreadgroup") {
int32_t count = request.asArray().size();
for (int32_t index = 0; index < count; ++index) {
if (absl::AsciiStrToLower(request.asArray()[index].asString()) == "streams") {
if (index + 1 < count) {
return index + 1; // Return the index of the key after "streams"
} else {
return -1; // "streams" is the last element
}
}
}
return -1; // "streams" not found
} else if (command_name == "xgroup" || command_name == "xinfo") {
if (request.asArray().size() > 2) {
return 2; // Return index 2 if there are more than 2 elements
} else {
return -1; // Not enough elements
}
} else if(Common::Redis::SupportedCommands::evalCommands().contains(command_name)) {
if (!(request.asArray().size() < 4)) {
return 3; // Return index 3 for eval commands
} else {
return -1; // Not enough arguments to process in transaction
}
}else {
return 1; // Default case for other commands
}
}

SplitRequestPtr TransactionRequest::create(Router& router,
Common::Redis::RespValuePtr&& incoming_request,
SplitCallbacks& callbacks, CommandStats& command_stats,
Expand All @@ -1646,6 +1688,8 @@ SplitRequestPtr TransactionRequest::create(Router& router,
callbacks.onResponse(Common::Redis::Utility::makeError(
fmt::format("'{}' command is not supported within transaction",
incoming_request->asArray()[0].asString())));
transaction.setDiscardTransaction();
ENVOY_LOG(debug, "Transaction command not supported: '{}', Setting Discard flag", incoming_request->asArray()[0].asString());
return nullptr;
}

Expand Down Expand Up @@ -1674,15 +1718,28 @@ SplitRequestPtr TransactionRequest::create(Router& router,
// Handle the case where the transaction is empty.
if (transaction.key_.empty()) {
if (command_name == "exec") {
Common::Redis::RespValuePtr empty_array{new Common::Redis::Client::EmptyArray{}};
callbacks.onResponse(std::move(empty_array));
if (transaction.isDiscardTransaction()){
callbacks.onResponse(Common::Redis::Utility::makeError("EXECABORT Transaction discarded because of previous errors"));
}else{
Common::Redis::RespValuePtr empty_array{new Common::Redis::Client::EmptyArray{}};
callbacks.onResponse(std::move(empty_array));
}
} else {
localResponse(callbacks, "OK");
}
transaction.close();
return nullptr;
}
//If Disard transaction flag is set, we will send discard in the place of exec
if (transaction.isDiscardTransaction() && command_name == "exec"){
auto new_request = std::make_unique<Common::Redis::RespValue>();
new_request->type(Common::Redis::RespType::Array);
addBulkString(*new_request, "DISCARD");
incoming_request = std::move(new_request);
ENVOY_LOG(debug, "Transaction command is set to discard, changing exec to discard");
transaction.setSendDiscardError();

}
// In all other cases we will close the transaction connection after sending the last command.
transaction.should_close_ = true;
}
Expand All @@ -1694,7 +1751,14 @@ SplitRequestPtr TransactionRequest::create(Router& router,

RouteSharedPtr route;
if (transaction.key_.empty()) {
transaction.key_ = incoming_request->asArray()[1].asString();
int32_t shardKeyIndex = getShardingKeyIndex(command_name,*incoming_request);
if (shardKeyIndex < 0) {
ENVOY_LOG(error, "unexpected command : '{}'", incoming_request->toString());
callbacks.onResponse(Common::Redis::Utility::makeError(fmt::format("unexpected command format")));
transaction.setDiscardTransaction();
return nullptr;
}
transaction.key_ = incoming_request->asArray()[shardKeyIndex].asString();
route = router.upstreamPool(transaction.key_, stream_info);
Common::Redis::RespValueSharedPtr multi_request =
std::make_shared<Common::Redis::Client::MultiRequest>();
Expand Down Expand Up @@ -1778,19 +1842,19 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request,
const StreamInfo::StreamInfo& stream_info) {
// Validate request type and contents.
if ((request->type() != Common::Redis::RespType::Array) || request->asArray().empty()) {
ENVOY_LOG(error,"invalid request - not an array or empty");
ENVOY_LOG(error,"invalid request - not an array or empty,sent from remote client ip '{}'",stream_info.downstreamAddressProvider().remoteAddress()->asString());
onInvalidRequest(callbacks);
return nullptr;
}

for (const Common::Redis::RespValue& value : request->asArray()) {
if (value.type() != Common::Redis::RespType::BulkString) {
ENVOY_LOG(error,"invalid request - not an array of bulk strings");
ENVOY_LOG(error,"invalid request - not an array of bulk strings,sent from remote client ip '{}' ",stream_info.downstreamAddressProvider().remoteAddress()->asString());
onInvalidRequest(callbacks);
return nullptr;
}
}
ENVOY_LOG(info, "command to process '{}'", request->toString());
ENVOY_LOG(info, "command to process '{}'- sent from remote client ip '{}'", request->toString(),stream_info.downstreamAddressProvider().remoteAddress()->asString());
// Extract command name
std::string command_name = absl::AsciiStrToLower(request->asArray()[0].asString());

Expand All @@ -1803,7 +1867,7 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request,
// Handle AUTH command
if (command_name == Common::Redis::SupportedCommands::auth()) {
if (request->asArray().size() < 2) {
ENVOY_LOG(error,"invalid request - not enough arguments for auth command");
ENVOY_LOG(error,"invalid request - not enough arguments for auth command - sent from remote client ip '{}'",stream_info.downstreamAddressProvider().remoteAddress()->asString());
onInvalidRequest(callbacks);
return nullptr;
}
Expand Down Expand Up @@ -1865,21 +1929,33 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request,
return nullptr;
}

if (request->asArray().size() < 2 &&(Common::Redis::SupportedCommands::transactionCommands().count(command_name) == 0)
&& (Common::Redis::SupportedCommands::subcrStateallowedCommands().count(command_name) == 0)
&& (Common::Redis::SupportedCommands::noArgCommands().count(command_name) == 0)) {
// Commands other than PING, TIME and transaction commands all have at least two arguments.
ENVOY_LOG(error,"invalid request - not enough arguments for command: '{}'", command_name);
onInvalidRequest(callbacks);
return nullptr;
if (request->asArray().size() < 2) {
if (Common::Redis::SupportedCommands::transactionCommands().count(command_name) > 0){

ENVOY_LOG(debug,"as this is a Transaction command lesser arguments are allowed, further argument count check would be done in transaction handler: '{}'", command_name);

} else if((Common::Redis::SupportedCommands::subcrStateallowedCommands().count(command_name) > 0) && callbacks.transaction().active_ && callbacks.transaction().isSubscribedMode()){

ENVOY_LOG(debug,"as we are in subscribed state we allow commands with less arguments, further argument count check would be done in pubsub handler: '{}'", command_name);

}else if (Common::Redis::SupportedCommands::noArgCommands().count(command_name) > 0) {

ENVOY_LOG(debug,"as it is no argument command lesser arguments allowed for: '{}'", command_name);

} else {
// Commands other than PING, TIME and transaction commands all have at least two arguments.
ENVOY_LOG(error,"invalid request - not enough arguments for command: '{}' - sent from remote client ip '{}'", command_name,stream_info.downstreamAddressProvider().remoteAddress()->asString());
onInvalidRequest(callbacks);
return nullptr;
}
}

// Handle CLIENT command locally
if (command_name == "client") {
std::string sub_command = absl::AsciiStrToLower(request->asArray()[1].asString());
if (Common::Redis::SupportedCommands::clientSubCommands().count(sub_command) == 0) {
stats_.unsupported_command_.inc();
ENVOY_LOG(error, "unsupported command '{}' '{}'",command_name, sub_command);
ENVOY_LOG(error, "unsupported command '{}' '{}'- sent from remote client ip '{}'",command_name, sub_command,stream_info.downstreamAddressProvider().remoteAddress()->asString());
callbacks.onResponse(Common::Redis::Utility::makeError(
fmt::format("unsupported command '{}' '{}'",command_name, sub_command)));
return nullptr;
Expand Down Expand Up @@ -1917,9 +1993,12 @@ SplitRequestPtr InstanceImpl::makeRequest(Common::Redis::RespValuePtr&& request,
//Stream commands are not listed directly under any handler , we need to check if it is a blocking or simple command and choose appropriate handler
handler=getHandlerForStreamsCommand(command_name,request);

}else if (callbacks.transaction().active_ && callbacks.transaction().isSubscribedMode() && Common::Redis::SupportedCommands::subcrStateallowedCommands().contains(command_name)) {
handler = handler_lookup_table_.find("subscribe");

}else{
stats_.unsupported_command_.inc();
ENVOY_LOG(error, "unsupported command '{}'", request->asArray()[0].asString());
ENVOY_LOG(error, "unsupported command '{}', sent from remote client ip '{}'", request->asArray()[0].asString(), stream_info.downstreamAddressProvider().remoteAddress()->asString());
callbacks.onResponse(Common::Redis::Utility::makeError(
fmt::format("unsupported command '{}'", request->asArray()[0].asString())));
return nullptr;
Expand Down Expand Up @@ -2002,12 +2081,24 @@ void InstanceImpl::addHandler(Stats::Scope& scope, const std::string& stat_prefi
}

InstanceImpl::HandlerDataPtr InstanceImpl::getHandlerForStreamsCommand(const std::string& command_name, const Common::Redis::RespValuePtr& request) {
// Check if the command is a stream blocking command.
// Check if the command is a stream blocking command.
if (Common::Redis::SupportedCommands::streamBlockingCommands().contains(command_name)) {
// Check for "block" keyword in the appropriate positions of the request array.
if ((request->asArray().size() > 1 && absl::AsciiStrToLower(request->asArray()[1].asString()) == "block") ||
(request->asArray().size() > 3 && absl::AsciiStrToLower(request->asArray()[3].asString()) == "block")) {
return handler_lookup_table_.find("xread_blocking_command");
const auto& args = request->asArray();
size_t streams_pos = args.size();

// Find STREAMS keyword
for (size_t i = 1; i < args.size(); ++i) {
if (absl::AsciiStrToLower(args[i].asString()) == "streams") {
streams_pos = i;
break;
}
}

// Check for BLOCK keyword before STREAMS
for (size_t i = 1; i < streams_pos; ++i) {
if (absl::AsciiStrToLower(args[i].asString()) == "block") {
return handler_lookup_table_.find("xread_blocking_command");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,7 @@ class TransactionRequest : public SingleServerRequest {
TransactionRequest(SplitCallbacks& callbacks, CommandStats& command_stats,
TimeSource& time_source, bool delay_command_latency)
: SingleServerRequest(callbacks, command_stats, time_source, delay_command_latency) {}
static int32_t getShardingKeyIndex(const std::string& command_name,const Common::Redis::RespValue& request);
};

/**
Expand Down

0 comments on commit a851190

Please sign in to comment.