Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

splitup ChainState type #1731

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ This assists us in knowing when to make the next release a breaking release and

### shotover rust API

* `Transform::transform` now takes `&mut Wrapper` instead of `Wrapper`.
* `Wrapper` is renamed to ChainState.
`Transform::transform` previously took a `Wrapper` type as an argument.
That has now been split into 2 separate types: `&mut ChainState` and `DownChainTransforms`.

## 0.4.0

Expand Down
10 changes: 6 additions & 4 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, MessageType, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{
ChainState, Transform, TransformBuilder, TransformConfig, TransformContextConfig,
ChainState, DownChainTransforms, Transform, TransformBuilder, TransformConfig,
TransformContextConfig,
};
use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol};

Expand Down Expand Up @@ -64,9 +65,10 @@ impl Transform for RedisGetRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for message in chain_state.requests.iter_mut() {
if let Some(frame) = message.frame() {
Expand All @@ -75,7 +77,7 @@ impl Transform for RedisGetRewrite {
}
}
}
let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in responses.iter_mut() {
if response
Expand Down
10 changes: 5 additions & 5 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,14 @@ fn cassandra_parsed_query(query: &str) -> ChainState {
)
}

struct BenchInput<'a> {
struct BenchInput {
chain: TransformChain,
chain_state: ChainState<'a>,
chain_state: ChainState,
}

impl<'a> BenchInput<'a> {
impl BenchInput {
// Setup the bench such that the chain is completely fresh
fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self {
fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
BenchInput {
chain: chain.build(TransformContextBuilder::new_test()),
chain_state: chain_state.clone(),
Expand All @@ -358,7 +358,7 @@ impl<'a> BenchInput<'a> {
// Setup the bench such that the chain has already had the test chain_state passed through it.
// This ensures that any adhoc setup for that message type has been performed.
// This is a more realistic bench for typical usage.
fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self {
fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
let mut chain = chain.build(TransformContextBuilder::new_test());

// Run the chain once so we are measuring the chain once each transform has been fully initialized
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,11 @@ impl<C: CodecBuilder + 'static> Handler<C> {
out_tx: &mpsc::UnboundedSender<Messages>,
requests: Messages,
) -> Result<Option<CloseReason>> {
let mut wrapper = ChainState::new_with_addr(requests, local_addr);
let mut chain_state = ChainState::new_with_addr(requests, local_addr);

self.pending_requests.process_requests(&wrapper.requests);
let responses = match self.chain.process_request(&mut wrapper).await {
self.pending_requests
.process_requests(&chain_state.requests);
let responses = match self.chain.process_request(&mut chain_state).await {
Ok(x) => x,
Err(err) => {
let err = err.context("Chain failed to send and/or receive messages, the connection will now be closed.");
Expand All @@ -752,7 +753,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}

// if requested by a transform, close connection AFTER sending any responses back to the client
if wrapper.close_client_connection {
if chain_state.close_client_connection {
return Ok(Some(CloseReason::TransformRequested));
}

Expand Down
11 changes: 6 additions & 5 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::frame::MessageType;
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, UpChainProtocol,
};
use crate::{
frame::{
Expand Down Expand Up @@ -79,9 +79,10 @@ impl Transform for CassandraPeersRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
// Find the indices of queries to system.peers & system.peers_v2
// we need to know which columns in which CQL queries in which messages have system peers
Expand All @@ -90,7 +91,7 @@ impl Transform for CassandraPeersRewrite {
self.column_names_to_rewrite.insert(request.id(), sys_peers);
}

let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in &mut responses {
if let Some(Frame::Cassandra(frame)) = response.frame() {
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, M
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -761,9 +761,10 @@ impl Transform for CassandraSinkCluster {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
_down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut chain_state.requests))
.await
Expand Down
9 changes: 5 additions & 4 deletions shotover/src/transforms/cassandra/sink_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::frame::MessageType;
use crate::message::{Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -212,9 +212,10 @@ impl Transform for CassandraSinkSingle {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
_down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut chain_state.requests))
.await
Expand Down
26 changes: 11 additions & 15 deletions shotover/src/transforms/chain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::TransformContextBuilder;
use super::{DownChainTransforms, TransformContextBuilder};
use crate::message::Messages;
use crate::transforms::{ChainState, Transform, TransformBuilder};
use anyhow::{anyhow, Result};
Expand Down Expand Up @@ -72,7 +72,7 @@ pub struct BufferedChain {
impl BufferedChain {
pub async fn process_request(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<Messages> {
self.process_request_with_receiver(chain_state, buffer_timeout_micros)
Expand All @@ -82,7 +82,7 @@ impl BufferedChain {

async fn process_request_with_receiver(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<oneshot::Receiver<Result<Messages>>> {
let (one_tx, one_rx) = oneshot::channel::<Result<Messages>>();
Expand Down Expand Up @@ -119,7 +119,7 @@ impl BufferedChain {

pub async fn process_request_no_return(
&mut self,
chain_state: ChainState<'_>,
chain_state: ChainState,
buffer_timeout_micros: Option<u64>,
) -> Result<()> {
if chain_state.flush {
Expand Down Expand Up @@ -158,16 +158,12 @@ impl BufferedChain {
}

impl TransformChain {
pub async fn process_request<'shorter, 'longer: 'shorter>(
&'longer mut self,
chain_state: &'shorter mut ChainState<'longer>,
) -> Result<Messages> {
pub async fn process_request(&mut self, state: &mut ChainState) -> Result<Messages> {
let start = Instant::now();
chain_state.reset(&mut self.chain);
let down_chain = DownChainTransforms::new(&mut self.chain);

self.chain_batch_size
.record(chain_state.requests.len() as f64);
let result = chain_state.call_next_transform().await;
self.chain_batch_size.record(state.requests.len() as f64);
let result = down_chain.call_next_transform(state).await;
self.chain_total.increment(1);
if result.is_err() {
self.chain_failures.increment(1);
Expand Down Expand Up @@ -322,9 +318,9 @@ impl TransformChainBuilder {
count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}

let mut chain_state = ChainState::new_with_addr(messages, local_addr);
chain_state.flush = flush;
let chain_response = chain.process_request(&mut chain_state).await;
let mut wrapper = ChainState::new_with_addr(messages, local_addr);
wrapper.flush = flush;
let chain_response = chain.process_request(&mut wrapper).await;

if let Err(e) = &chain_response {
error!("Internal error in buffered chain: {e:?}");
Expand Down
22 changes: 15 additions & 7 deletions shotover/src/transforms/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol};
use super::{
DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig,
UpChainProtocol,
};
use crate::message::Messages;
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig};
use anyhow::Result;
Expand Down Expand Up @@ -81,9 +84,10 @@ impl Transform for Coalesce {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.buffer.append(&mut chain_state.requests);

Expand All @@ -102,7 +106,7 @@ impl Transform for Coalesce {
self.last_write = Instant::now()
}
std::mem::swap(&mut self.buffer, &mut chain_state.requests);
chain_state.call_next_transform().await
down_chain.call_next_transform(chain_state).await
} else {
Ok(vec![])
}
Expand All @@ -116,7 +120,7 @@ mod test {
use crate::transforms::chain::TransformAndMetrics;
use crate::transforms::coalesce::Coalesce;
use crate::transforms::loopback::Loopback;
use crate::transforms::{ChainState, Transform};
use crate::transforms::{ChainState, DownChainTransforms, Transform};
use pretty_assertions::assert_eq;
use std::time::{Duration, Instant};

Expand Down Expand Up @@ -199,9 +203,13 @@ mod test {
expected_len: usize,
) {
let mut wrapper = ChainState::new_test(requests.to_vec());
wrapper.reset(chain);
let transforms = DownChainTransforms::new(chain);
assert_eq!(
coalesce.transform(&mut wrapper).await.unwrap().len(),
coalesce
.transform(&mut wrapper, transforms)
.await
.unwrap()
.len(),
expected_len
);
}
Expand Down
8 changes: 5 additions & 3 deletions shotover/src/transforms/debug/force_parse.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::message::Messages;
use crate::transforms::DownChainTransforms;
/// This transform will by default parse requests and responses that pass through it.
/// request and response parsing can be individually disabled if desired.
///
Expand Down Expand Up @@ -105,9 +106,10 @@ impl Transform for DebugForceParse {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for message in &mut chain_state.requests {
if self.parse_requests {
Expand All @@ -118,7 +120,7 @@ impl Transform for DebugForceParse {
}
}

let mut response = chain_state.call_next_transform().await;
let mut response = down_chain.call_next_transform(chain_state).await;

if let Ok(response) = response.as_mut() {
for message in response {
Expand Down
11 changes: 7 additions & 4 deletions shotover/src/transforms/debug/log_to_file.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::message::{Encodable, Message};
use crate::transforms::{ChainState, Transform, TransformBuilder, TransformContextBuilder};
use crate::transforms::{
ChainState, DownChainTransforms, Transform, TransformBuilder, TransformContextBuilder,
};
#[cfg(feature = "alpha-transforms")]
use crate::transforms::{DownChainProtocol, UpChainProtocol};
use anyhow::{Context, Result};
Expand Down Expand Up @@ -89,9 +91,10 @@ impl Transform for DebugLogToFile {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Vec<Message>> {
for message in &chain_state.requests {
self.request_counter += 1;
Expand All @@ -101,7 +104,7 @@ impl Transform for DebugLogToFile {
log_message(message, path.as_path()).await?;
}

let response = chain_state.call_next_transform().await?;
let response = down_chain.call_next_transform(chain_state).await?;

for message in &response {
self.response_counter += 1;
Expand Down
11 changes: 6 additions & 5 deletions shotover/src/transforms/debug/printer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::message::Messages;
use crate::transforms::{
ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig,
TransformContextBuilder, TransformContextConfig, UpChainProtocol,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::Result;
use async_trait::async_trait;
Expand Down Expand Up @@ -65,16 +65,17 @@ impl Transform for DebugPrinter {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
chain_state: &'shorter mut ChainState<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for request in &mut chain_state.requests {
info!("Request: {}", request.to_high_level_string());
}

self.counter += 1;
let mut responses = chain_state.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in &mut responses {
info!("Response: {}", response.to_high_level_string());
Expand Down
Loading
Loading