diff --git a/changelog.md b/changelog.md index 5ce164e72..46762d6ac 100644 --- a/changelog.md +++ b/changelog.md @@ -7,8 +7,8 @@ This assists us in knowing when to make the next release a breaking release and ### shotover rust API -* `Transform::transform` now takes `&mut Wrapper` instead of `Wrapper`. -* `Wrapper` is renamed to ChainState. +`Transform::transform` previously took a `Wrapper` type as an argument. +That has now been split into 2 separate types: `&mut ChainState` and `DownChainTransforms`. ## 0.4.0 diff --git a/custom-transforms-example/src/redis_get_rewrite.rs b/custom-transforms-example/src/redis_get_rewrite.rs index 735540794..9442817fe 100644 --- a/custom-transforms-example/src/redis_get_rewrite.rs +++ b/custom-transforms-example/src/redis_get_rewrite.rs @@ -4,7 +4,8 @@ use serde::{Deserialize, Serialize}; use shotover::frame::{Frame, MessageType, RedisFrame}; use shotover::message::{MessageIdSet, Messages}; use shotover::transforms::{ - ChainState, Transform, TransformBuilder, TransformConfig, TransformContextConfig, + ChainState, DownChainTransforms, Transform, TransformBuilder, TransformConfig, + TransformContextConfig, }; use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol}; @@ -64,9 +65,10 @@ impl Transform for RedisGetRewrite { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for message in chain_state.requests.iter_mut() { if let Some(frame) = message.frame() { @@ -75,7 +77,7 @@ impl Transform for RedisGetRewrite { } } } - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in responses.iter_mut() { if response diff --git a/shotover/benches/benches/chain.rs b/shotover/benches/benches/chain.rs index 7bb0055fb..1b49ade27 100644 --- a/shotover/benches/benches/chain.rs +++ b/shotover/benches/benches/chain.rs @@ -341,14 +341,14 @@ fn cassandra_parsed_query(query: &str) -> ChainState { ) } -struct BenchInput<'a> { +struct BenchInput { chain: TransformChain, - chain_state: ChainState<'a>, + chain_state: ChainState, } -impl<'a> BenchInput<'a> { +impl BenchInput { // Setup the bench such that the chain is completely fresh - fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self { + fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self { BenchInput { chain: chain.build(TransformContextBuilder::new_test()), chain_state: chain_state.clone(), @@ -358,7 +358,7 @@ impl<'a> BenchInput<'a> { // Setup the bench such that the chain has already had the test chain_state passed through it. // This ensures that any adhoc setup for that message type has been performed. // This is a more realistic bench for typical usage. - fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState<'a>) -> Self { + fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self { let mut chain = chain.build(TransformContextBuilder::new_test()); // Run the chain once so we are measuring the chain once each transform has been fully initialized diff --git a/shotover/src/server.rs b/shotover/src/server.rs index 1ed4ce45f..4f750470a 100644 --- a/shotover/src/server.rs +++ b/shotover/src/server.rs @@ -727,10 +727,11 @@ impl Handler { out_tx: &mpsc::UnboundedSender, requests: Messages, ) -> Result> { - let mut wrapper = ChainState::new_with_addr(requests, local_addr); + let mut chain_state = ChainState::new_with_addr(requests, local_addr); - self.pending_requests.process_requests(&wrapper.requests); - let responses = match self.chain.process_request(&mut wrapper).await { + self.pending_requests + .process_requests(&chain_state.requests); + let responses = match self.chain.process_request(&mut chain_state).await { Ok(x) => x, Err(err) => { let err = err.context("Chain failed to send and/or receive messages, the connection will now be closed."); @@ -752,7 +753,7 @@ impl Handler { } // if requested by a transform, close connection AFTER sending any responses back to the client - if wrapper.close_client_connection { + if chain_state.close_client_connection { return Ok(Some(CloseReason::TransformRequested)); } diff --git a/shotover/src/transforms/cassandra/peers_rewrite.rs b/shotover/src/transforms/cassandra/peers_rewrite.rs index 19f9c31b3..46b514e21 100644 --- a/shotover/src/transforms/cassandra/peers_rewrite.rs +++ b/shotover/src/transforms/cassandra/peers_rewrite.rs @@ -2,8 +2,8 @@ use crate::frame::MessageType; use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, UpChainProtocol, }; use crate::{ frame::{ @@ -79,9 +79,10 @@ impl Transform for CassandraPeersRewrite { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { // Find the indices of queries to system.peers & system.peers_v2 // we need to know which columns in which CQL queries in which messages have system peers @@ -90,7 +91,7 @@ impl Transform for CassandraPeersRewrite { self.column_names_to_rewrite.insert(request.id(), sys_peers); } - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in &mut responses { if let Some(Frame::Cassandra(frame)) = response.frame() { diff --git a/shotover/src/transforms/cassandra/sink_cluster/mod.rs b/shotover/src/transforms/cassandra/sink_cluster/mod.rs index 1d3695fa5..659e209f9 100644 --- a/shotover/src/transforms/cassandra/sink_cluster/mod.rs +++ b/shotover/src/transforms/cassandra/sink_cluster/mod.rs @@ -6,8 +6,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, M use crate::message::{Message, MessageIdMap, Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol, }; use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; @@ -761,9 +761,10 @@ impl Transform for CassandraSinkCluster { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { self.send_message(std::mem::take(&mut chain_state.requests)) .await diff --git a/shotover/src/transforms/cassandra/sink_single.rs b/shotover/src/transforms/cassandra/sink_single.rs index fe249de31..418adc1e9 100644 --- a/shotover/src/transforms/cassandra/sink_single.rs +++ b/shotover/src/transforms/cassandra/sink_single.rs @@ -5,8 +5,8 @@ use crate::frame::MessageType; use crate::message::{Messages, Metadata}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -212,9 +212,10 @@ impl Transform for CassandraSinkSingle { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { self.send_message(std::mem::take(&mut chain_state.requests)) .await diff --git a/shotover/src/transforms/chain.rs b/shotover/src/transforms/chain.rs index f27325568..152c0fdc7 100644 --- a/shotover/src/transforms/chain.rs +++ b/shotover/src/transforms/chain.rs @@ -1,4 +1,4 @@ -use super::TransformContextBuilder; +use super::{DownChainTransforms, TransformContextBuilder}; use crate::message::Messages; use crate::transforms::{ChainState, Transform, TransformBuilder}; use anyhow::{anyhow, Result}; @@ -72,7 +72,7 @@ pub struct BufferedChain { impl BufferedChain { pub async fn process_request( &mut self, - chain_state: ChainState<'_>, + chain_state: ChainState, buffer_timeout_micros: Option, ) -> Result { self.process_request_with_receiver(chain_state, buffer_timeout_micros) @@ -82,7 +82,7 @@ impl BufferedChain { async fn process_request_with_receiver( &mut self, - chain_state: ChainState<'_>, + chain_state: ChainState, buffer_timeout_micros: Option, ) -> Result>> { let (one_tx, one_rx) = oneshot::channel::>(); @@ -119,7 +119,7 @@ impl BufferedChain { pub async fn process_request_no_return( &mut self, - chain_state: ChainState<'_>, + chain_state: ChainState, buffer_timeout_micros: Option, ) -> Result<()> { if chain_state.flush { @@ -158,16 +158,12 @@ impl BufferedChain { } impl TransformChain { - pub async fn process_request<'shorter, 'longer: 'shorter>( - &'longer mut self, - chain_state: &'shorter mut ChainState<'longer>, - ) -> Result { + pub async fn process_request(&mut self, state: &mut ChainState) -> Result { let start = Instant::now(); - chain_state.reset(&mut self.chain); + let down_chain = DownChainTransforms::new(&mut self.chain); - self.chain_batch_size - .record(chain_state.requests.len() as f64); - let result = chain_state.call_next_transform().await; + self.chain_batch_size.record(state.requests.len() as f64); + let result = down_chain.call_next_transform(state).await; self.chain_total.increment(1); if result.is_err() { self.chain_failures.increment(1); @@ -322,9 +318,9 @@ impl TransformChainBuilder { count_clone.fetch_add(1, std::sync::atomic::Ordering::Relaxed); } - let mut chain_state = ChainState::new_with_addr(messages, local_addr); - chain_state.flush = flush; - let chain_response = chain.process_request(&mut chain_state).await; + let mut wrapper = ChainState::new_with_addr(messages, local_addr); + wrapper.flush = flush; + let chain_response = chain.process_request(&mut wrapper).await; if let Err(e) = &chain_response { error!("Internal error in buffered chain: {e:?}"); diff --git a/shotover/src/transforms/coalesce.rs b/shotover/src/transforms/coalesce.rs index 7399351df..e45d5fe8a 100644 --- a/shotover/src/transforms/coalesce.rs +++ b/shotover/src/transforms/coalesce.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::message::Messages; use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig}; use anyhow::Result; @@ -81,9 +84,10 @@ impl Transform for Coalesce { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { self.buffer.append(&mut chain_state.requests); @@ -102,7 +106,7 @@ impl Transform for Coalesce { self.last_write = Instant::now() } std::mem::swap(&mut self.buffer, &mut chain_state.requests); - chain_state.call_next_transform().await + down_chain.call_next_transform(chain_state).await } else { Ok(vec![]) } @@ -116,7 +120,7 @@ mod test { use crate::transforms::chain::TransformAndMetrics; use crate::transforms::coalesce::Coalesce; use crate::transforms::loopback::Loopback; - use crate::transforms::{ChainState, Transform}; + use crate::transforms::{ChainState, DownChainTransforms, Transform}; use pretty_assertions::assert_eq; use std::time::{Duration, Instant}; @@ -199,9 +203,13 @@ mod test { expected_len: usize, ) { let mut wrapper = ChainState::new_test(requests.to_vec()); - wrapper.reset(chain); + let transforms = DownChainTransforms::new(chain); assert_eq!( - coalesce.transform(&mut wrapper).await.unwrap().len(), + coalesce + .transform(&mut wrapper, transforms) + .await + .unwrap() + .len(), expected_len ); } diff --git a/shotover/src/transforms/debug/force_parse.rs b/shotover/src/transforms/debug/force_parse.rs index 4dee7c45d..27ee1cc16 100644 --- a/shotover/src/transforms/debug/force_parse.rs +++ b/shotover/src/transforms/debug/force_parse.rs @@ -1,4 +1,5 @@ use crate::message::Messages; +use crate::transforms::DownChainTransforms; /// This transform will by default parse requests and responses that pass through it. /// request and response parsing can be individually disabled if desired. /// @@ -105,9 +106,10 @@ impl Transform for DebugForceParse { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for message in &mut chain_state.requests { if self.parse_requests { @@ -118,7 +120,7 @@ impl Transform for DebugForceParse { } } - let mut response = chain_state.call_next_transform().await; + let mut response = down_chain.call_next_transform(chain_state).await; if let Ok(response) = response.as_mut() { for message in response { diff --git a/shotover/src/transforms/debug/log_to_file.rs b/shotover/src/transforms/debug/log_to_file.rs index 70711dff9..f31d4bfd6 100644 --- a/shotover/src/transforms/debug/log_to_file.rs +++ b/shotover/src/transforms/debug/log_to_file.rs @@ -1,5 +1,7 @@ use crate::message::{Encodable, Message}; -use crate::transforms::{ChainState, Transform, TransformBuilder, TransformContextBuilder}; +use crate::transforms::{ + ChainState, DownChainTransforms, Transform, TransformBuilder, TransformContextBuilder, +}; #[cfg(feature = "alpha-transforms")] use crate::transforms::{DownChainProtocol, UpChainProtocol}; use anyhow::{Context, Result}; @@ -89,9 +91,10 @@ impl Transform for DebugLogToFile { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result> { for message in &chain_state.requests { self.request_counter += 1; @@ -101,7 +104,7 @@ impl Transform for DebugLogToFile { log_message(message, path.as_path()).await?; } - let response = chain_state.call_next_transform().await?; + let response = down_chain.call_next_transform(chain_state).await?; for message in &response { self.response_counter += 1; diff --git a/shotover/src/transforms/debug/printer.rs b/shotover/src/transforms/debug/printer.rs index f36b89238..566ddf4c3 100644 --- a/shotover/src/transforms/debug/printer.rs +++ b/shotover/src/transforms/debug/printer.rs @@ -1,7 +1,7 @@ use crate::message::Messages; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol, }; use anyhow::Result; use async_trait::async_trait; @@ -65,16 +65,17 @@ impl Transform for DebugPrinter { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for request in &mut chain_state.requests { info!("Request: {}", request.to_high_level_string()); } self.counter += 1; - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in &mut responses { info!("Response: {}", response.to_high_level_string()); diff --git a/shotover/src/transforms/debug/returner.rs b/shotover/src/transforms/debug/returner.rs index 95c75c607..696e864ab 100644 --- a/shotover/src/transforms/debug/returner.rs +++ b/shotover/src/transforms/debug/returner.rs @@ -1,7 +1,7 @@ use crate::message::{Message, Messages}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol, }; use anyhow::{anyhow, Result}; use async_trait::async_trait; @@ -75,9 +75,10 @@ impl Transform for DebugReturner { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { chain_state .requests diff --git a/shotover/src/transforms/filter.rs b/shotover/src/transforms/filter.rs index c6cad8f7f..39ecbf104 100644 --- a/shotover/src/transforms/filter.rs +++ b/shotover/src/transforms/filter.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::message::{Message, MessageIdMap, Messages, QueryType}; use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig}; use anyhow::Result; @@ -64,9 +67,10 @@ impl Transform for QueryTypeFilter { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for request in chain_state.requests.iter_mut() { let filter_out = match &self.filter { @@ -87,7 +91,7 @@ impl Transform for QueryTypeFilter { } } - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in responses.iter_mut() { if let Some(request_id) = response.request_id() { if let Some(error_response) = self.filtered_requests.remove(&request_id) { @@ -110,6 +114,7 @@ mod test { use crate::transforms::chain::TransformAndMetrics; use crate::transforms::filter::QueryTypeFilter; use crate::transforms::loopback::Loopback; + use crate::transforms::DownChainTransforms; use crate::transforms::{ChainState, Transform}; use pretty_assertions::assert_eq; @@ -140,8 +145,11 @@ mod test { .collect(); let mut chain_state = ChainState::new_test(messages); - chain_state.reset(&mut chain); - let result = filter_transform.transform(&mut chain_state).await.unwrap(); + let transforms = DownChainTransforms::new(&mut chain); + let result = filter_transform + .transform(&mut chain_state, transforms) + .await + .unwrap(); assert_eq!(result.len(), 26); @@ -195,8 +203,11 @@ mod test { .collect(); let mut chain_state = ChainState::new_test(messages); - chain_state.reset(&mut chain); - let result = filter_transform.transform(&mut chain_state).await.unwrap(); + let transforms = DownChainTransforms::new(&mut chain); + let result = filter_transform + .transform(&mut chain_state, transforms) + .await + .unwrap(); assert_eq!(result.len(), 26); diff --git a/shotover/src/transforms/kafka/sink_cluster/mod.rs b/shotover/src/transforms/kafka/sink_cluster/mod.rs index caf3f3000..7387a2431 100644 --- a/shotover/src/transforms/kafka/sink_cluster/mod.rs +++ b/shotover/src/transforms/kafka/sink_cluster/mod.rs @@ -3,8 +3,8 @@ use crate::frame::{Frame, MessageType}; use crate::message::{Message, MessageIdMap, Messages}; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformContextBuilder, - UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformContextBuilder, UpChainProtocol, }; use crate::transforms::{TransformConfig, TransformContextConfig}; use anyhow::{anyhow, Context, Result}; @@ -341,9 +341,10 @@ impl Transform for KafkaSinkCluster { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { let mut responses = if chain_state.requests.is_empty() { // there are no requests, so no point sending any, but we should check for any responses without awaiting diff --git a/shotover/src/transforms/kafka/sink_single.rs b/shotover/src/transforms/kafka/sink_single.rs index 738701abf..b3e2e37b6 100644 --- a/shotover/src/transforms/kafka/sink_single.rs +++ b/shotover/src/transforms/kafka/sink_single.rs @@ -5,7 +5,8 @@ use crate::frame::{Frame, MessageType}; use crate::message::Messages; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ - ChainState, Transform, TransformBuilder, TransformContextBuilder, TransformContextConfig, + ChainState, DownChainTransforms, Transform, TransformBuilder, TransformContextBuilder, + TransformContextConfig, }; use crate::transforms::{DownChainProtocol, TransformConfig, UpChainProtocol}; use anyhow::Result; @@ -117,9 +118,10 @@ impl Transform for KafkaSinkSingle { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { if self.connection.is_none() { let codec = KafkaCodecBuilder::new(Direction::Sink, "KafkaSinkSingle".to_owned()); diff --git a/shotover/src/transforms/load_balance.rs b/shotover/src/transforms/load_balance.rs index 5691e7c27..711c91e06 100644 --- a/shotover/src/transforms/load_balance.rs +++ b/shotover/src/transforms/load_balance.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{BufferedChain, TransformChainBuilder}; @@ -85,9 +88,10 @@ impl Transform for ConnectionBalanceAndPool { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { if self.active_connection.is_none() { let mut all_connections = self.all_connections.lock().await; diff --git a/shotover/src/transforms/loopback.rs b/shotover/src/transforms/loopback.rs index eae995f1a..d82d02258 100644 --- a/shotover/src/transforms/loopback.rs +++ b/shotover/src/transforms/loopback.rs @@ -1,4 +1,4 @@ -use super::TransformContextBuilder; +use super::{DownChainTransforms, TransformContextBuilder}; use crate::message::Messages; use crate::transforms::{ChainState, Transform, TransformBuilder}; use anyhow::Result; @@ -29,9 +29,10 @@ impl Transform for Loopback { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { // This transform ultimately doesnt make a lot of sense semantically // but make a vague attempt to follow transform invariants anyway. diff --git a/shotover/src/transforms/mod.rs b/shotover/src/transforms/mod.rs index 60ee70835..054c58c74 100644 --- a/shotover/src/transforms/mod.rs +++ b/shotover/src/transforms/mod.rs @@ -146,10 +146,10 @@ pub struct TransformContextConfig { /// The [`Wrapper`] struct is passed into each transform and contains a list of mutable references to the /// remaining transforms that will process the messages attached to this [`Wrapper`]. /// Most [`Transform`] authors will only be interested in [`wrapper.requests`]. -pub struct ChainState<'a> { +#[derive(Clone)] +pub struct ChainState { /// Requests received from the client pub requests: Messages, - transforms: IterMut<'a, TransformAndMetrics>, /// Contains the shotover source's ip address and port which the message was received on pub local_addr: SocketAddr, /// When true transforms must flush any buffered messages into the messages field. @@ -163,32 +163,15 @@ pub struct ChainState<'a> { pub close_client_connection: bool, } -/// [`Wrapper`] will not (cannot) bring the current list of transforms that it needs to traverse with it -/// This is purely to make it convenient to clone all the data within Wrapper rather than it's transform -/// state. -impl<'a> Clone for ChainState<'a> { - fn clone(&self) -> Self { - ChainState { - requests: self.requests.clone(), - transforms: [].iter_mut(), - local_addr: self.local_addr, - flush: self.flush, - close_client_connection: self.close_client_connection, - } - } -} +pub struct DownChainTransforms<'a>(IterMut<'a, TransformAndMetrics>); -impl<'shorter, 'longer: 'shorter> ChainState<'longer> { - fn take(&mut self) -> Self { - ChainState { - requests: std::mem::take(&mut self.requests), - transforms: std::mem::take(&mut self.transforms), - local_addr: self.local_addr, - flush: self.flush, - close_client_connection: self.close_client_connection, - } +impl<'a> DownChainTransforms<'a> { + fn new(transforms: &'a mut [TransformAndMetrics]) -> Self { + DownChainTransforms(transforms.iter_mut()) } +} +impl DownChainTransforms<'_> { /// This function will take a mutable reference to the next transform out of the [`Wrapper`] structs /// vector of transform references. It then sets up the chain name and transform name in the local /// thread scope for structured logging. @@ -197,14 +180,14 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { /// the execution time of the [Transform::transform] function as a metrics latency histogram. /// /// The result of calling the next transform is then provided as a response. - pub async fn call_next_transform(&'shorter mut self) -> Result { + pub async fn call_next_transform(mut self, chain_state: &mut ChainState) -> Result { let TransformAndMetrics { transform, transform_total, transform_failures, transform_latency, .. - } = match self.transforms.next() { + } = match self.0.next() { Some(transform) => transform, None => panic!("The transform chain does not end with a terminating transform. If you want to throw the messages away use a NullSink transform, otherwise use a terminating sink transform to send the messages somewhere.") }; @@ -213,7 +196,7 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { let start = Instant::now(); let result = transform - .transform(self) + .transform(chain_state, self) .await .map_err(|e| e.context(anyhow!("{transform_name} transform failed"))); transform_total.increment(1); @@ -223,6 +206,17 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { transform_latency.record(start.elapsed()); result } +} + +impl ChainState { + fn take(&mut self) -> Self { + ChainState { + requests: std::mem::take(&mut self.requests), + local_addr: self.local_addr, + flush: self.flush, + close_client_connection: self.close_client_connection, + } + } pub fn clone_requests_into_hashmap(&self, destination: &mut MessageIdMap) { for request in &self.requests { @@ -234,8 +228,7 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { pub fn new_test(requests: Messages) -> Self { ChainState { requests, - transforms: [].iter_mut(), - local_addr: DUMMY_ADDRESS, + local_addr: "127.0.0.1:8000".parse().unwrap(), flush: false, close_client_connection: false, } @@ -244,7 +237,6 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { pub fn new_with_addr(requests: Messages, local_addr: SocketAddr) -> Self { ChainState { requests, - transforms: [].iter_mut(), local_addr, flush: false, close_client_connection: false, @@ -254,7 +246,6 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { pub fn flush() -> Self { ChainState { requests: vec![], - transforms: [].iter_mut(), // The connection is closed so we need to just fake an address here local_addr: DUMMY_ADDRESS, flush: true, @@ -273,10 +264,6 @@ impl<'shorter, 'longer: 'shorter> ChainState<'longer> { .collect::>(); format!("{:?}", messages) } - - pub fn reset(&mut self, transforms: &'longer mut [TransformAndMetrics]) { - self.transforms = transforms.iter_mut(); - } } const DUMMY_ADDRESS: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0)); @@ -346,9 +333,10 @@ pub trait Transform: Send { /// * Transform that do call subsquent chains via `chain_state.call_next_transform()` are non-terminating transforms. /// /// You can have have a transform that is both non-terminating and a sink. - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result; /// Name of the transform used in logs and displayed to the user diff --git a/shotover/src/transforms/null.rs b/shotover/src/transforms/null.rs index 4cedbefa0..3fca5caf7 100644 --- a/shotover/src/transforms/null.rs +++ b/shotover/src/transforms/null.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::message::Messages; use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig}; use anyhow::Result; @@ -52,9 +55,10 @@ impl Transform for NullSink { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { for request in &mut chain_state.requests { // reuse the requests to hold the responses to avoid an allocation diff --git a/shotover/src/transforms/opensearch/mod.rs b/shotover/src/transforms/opensearch/mod.rs index bee60bfe8..8846b064f 100644 --- a/shotover/src/transforms/opensearch/mod.rs +++ b/shotover/src/transforms/opensearch/mod.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::frame::MessageType; use crate::tcp; use crate::transforms::{ChainState, Messages, Transform, TransformBuilder, TransformConfig}; @@ -95,9 +98,10 @@ impl Transform for OpenSearchSinkSingle { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { // Return immediately if we have no messages. // If we tried to send no messages we would block forever waiting for a reply that will never come. diff --git a/shotover/src/transforms/parallel_map.rs b/shotover/src/transforms/parallel_map.rs index e131af6c7..29d75fbf2 100644 --- a/shotover/src/transforms/parallel_map.rs +++ b/shotover/src/transforms/parallel_map.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::config::chain::TransformChainConfig; use crate::message::Messages; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; @@ -108,9 +111,10 @@ impl Transform for ParallelMap { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { let mut results = Vec::with_capacity(chain_state.requests.len()); let mut message_iter = chain_state.requests.drain(..); diff --git a/shotover/src/transforms/protect/mod.rs b/shotover/src/transforms/protect/mod.rs index 1caaf800e..c6729b3cf 100644 --- a/shotover/src/transforms/protect/mod.rs +++ b/shotover/src/transforms/protect/mod.rs @@ -1,5 +1,5 @@ -use super::TransformContextBuilder; use super::{DownChainProtocol, UpChainProtocol}; +use super::{DownChainTransforms, TransformContextBuilder}; use crate::frame::MessageType; use crate::frame::{ value::GenericValue, CassandraFrame, CassandraOperation, CassandraResult, Frame, @@ -184,9 +184,10 @@ impl Transform for Protect { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { // encrypt the values included in any INSERT or UPDATE queries for message in chain_state.requests.iter_mut() { @@ -203,7 +204,7 @@ impl Transform for Protect { } chain_state.clone_requests_into_hashmap(&mut self.requests); - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in &mut responses { if let Some(request_id) = response.request_id() { diff --git a/shotover/src/transforms/query_counter.rs b/shotover/src/transforms/query_counter.rs index 41e7efa58..c1b0832dd 100644 --- a/shotover/src/transforms/query_counter.rs +++ b/shotover/src/transforms/query_counter.rs @@ -12,6 +12,7 @@ use serde::Serialize; use std::collections::HashMap; use super::DownChainProtocol; +use super::DownChainTransforms; use super::TransformContextConfig; use super::UpChainProtocol; @@ -64,9 +65,10 @@ impl Transform for QueryCounter { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for m in &mut chain_state.requests { match m.frame() { @@ -101,7 +103,7 @@ impl Transform for QueryCounter { } } - chain_state.call_next_transform().await + down_chain.call_next_transform(chain_state).await } } diff --git a/shotover/src/transforms/redis/cache.rs b/shotover/src/transforms/redis/cache.rs index 24d63ecfd..b15449af2 100644 --- a/shotover/src/transforms/redis/cache.rs +++ b/shotover/src/transforms/redis/cache.rs @@ -3,8 +3,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, Frame, MessageType, Redis use crate::message::{Message, MessageIdMap, Messages, Metadata}; use crate::transforms::chain::{TransformChain, TransformChainBuilder}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol, }; use anyhow::{bail, Result}; use async_trait::async_trait; @@ -376,7 +376,8 @@ impl SimpleRedisCache { /// calls the next transform and process the result for caching. async fn execute_upstream_and_write_to_cache( &mut self, - chain_state: &mut ChainState<'_>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { let local_addr = chain_state.local_addr; let mut request_messages: Vec<_> = chain_state @@ -384,7 +385,7 @@ impl SimpleRedisCache { .iter_mut() .map(|message| message.frame().cloned()) .collect(); - let mut response_messages = chain_state.call_next_transform().await?; + let mut response_messages = down_chain.call_next_transform(chain_state).await?; let mut cache_messages = vec![]; for (request, response) in request_messages @@ -618,9 +619,10 @@ impl Transform for SimpleRedisCache { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { self.read_from_cache(&mut chain_state.requests, chain_state.local_addr) .await @@ -634,7 +636,7 @@ impl Transform for SimpleRedisCache { &mut self.cache_miss_cassandra_requests, ); let mut responses = self - .execute_upstream_and_write_to_cache(chain_state) + .execute_upstream_and_write_to_cache(chain_state, down_chain) .await?; // add the cache hits to the final response diff --git a/shotover/src/transforms/redis/cluster_ports_rewrite.rs b/shotover/src/transforms/redis/cluster_ports_rewrite.rs index 8cf61ce31..1fa858402 100644 --- a/shotover/src/transforms/redis/cluster_ports_rewrite.rs +++ b/shotover/src/transforms/redis/cluster_ports_rewrite.rs @@ -3,6 +3,7 @@ use crate::frame::MessageType; use crate::frame::RedisFrame; use crate::message::{MessageIdMap, Messages}; use crate::transforms::DownChainProtocol; +use crate::transforms::DownChainTransforms; use crate::transforms::TransformContextBuilder; use crate::transforms::TransformContextConfig; use crate::transforms::UpChainProtocol; @@ -76,9 +77,10 @@ impl Transform for RedisClusterPortsRewrite { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for message in chain_state.requests.iter_mut() { let message_id = message.id(); @@ -95,7 +97,7 @@ impl Transform for RedisClusterPortsRewrite { } } - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; for response in &mut responses { if let Some(request_id) = response.request_id() { diff --git a/shotover/src/transforms/redis/sink_cluster.rs b/shotover/src/transforms/redis/sink_cluster.rs index 1b251bb8b..be086ebc7 100644 --- a/shotover/src/transforms/redis/sink_cluster.rs +++ b/shotover/src/transforms/redis/sink_cluster.rs @@ -8,8 +8,9 @@ use crate::transforms::redis::TransformError; use crate::transforms::util::cluster_connection_pool::{Authenticator, ConnectionPool}; use crate::transforms::util::{Request, Response}; use crate::transforms::{ - ChainState, DownChainProtocol, ResponseFuture, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, TransformContextConfig, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, ResponseFuture, Transform, + TransformBuilder, TransformConfig, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, }; use anyhow::{anyhow, bail, ensure, Context, Result}; use async_trait::async_trait; @@ -1017,9 +1018,10 @@ impl Transform for RedisSinkCluster { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { if !self.has_run_init { self.topology = (*self.shared_topology.read().await).clone(); diff --git a/shotover/src/transforms/redis/sink_single.rs b/shotover/src/transforms/redis/sink_single.rs index bc842a16b..e52bc01ef 100644 --- a/shotover/src/transforms/redis/sink_single.rs +++ b/shotover/src/transforms/redis/sink_single.rs @@ -4,8 +4,8 @@ use crate::frame::{Frame, MessageType, RedisFrame}; use crate::message::Messages; use crate::tls::{TlsConnector, TlsConnectorConfig}; use crate::transforms::{ - ChainState, DownChainProtocol, Transform, TransformBuilder, TransformConfig, - TransformContextBuilder, UpChainProtocol, + ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder, + TransformConfig, TransformContextBuilder, UpChainProtocol, }; use crate::{codec::redis::RedisCodecBuilder, transforms::TransformContextConfig}; use anyhow::Result; @@ -114,9 +114,10 @@ impl Transform for RedisSinkSingle { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + _down_chain: DownChainTransforms<'_>, ) -> Result { if self.connection.is_none() { let codec = RedisCodecBuilder::new(Direction::Sink, "RedisSinkSingle".to_owned()); diff --git a/shotover/src/transforms/tee.rs b/shotover/src/transforms/tee.rs index 8f2c81d80..d16f02da8 100644 --- a/shotover/src/transforms/tee.rs +++ b/shotover/src/transforms/tee.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::config::chain::TransformChainConfig; use crate::http::HttpServerError; use crate::message::{Message, MessageIdMap, Messages}; @@ -243,17 +246,18 @@ impl Transform for Tee { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { match &mut self.behavior { - ConsistencyBehavior::Ignore => self.ignore_behaviour(chain_state).await, + ConsistencyBehavior::Ignore => self.ignore_behaviour(chain_state, down_chain).await, ConsistencyBehavior::FailOnMismatch => { let (tee_result, chain_result) = tokio::join!( self.tx .process_request(chain_state.clone(), self.timeout_micros), - chain_state.call_next_transform() + down_chain.call_next_transform(chain_state) ); let keep: ResultSource = self.result_source.load(Ordering::Relaxed); @@ -283,7 +287,7 @@ impl Transform for Tee { let (tee_result, chain_result) = tokio::join!( self.tx .process_request(chain_state.clone(), self.timeout_micros), - chain_state.call_next_transform() + down_chain.call_next_transform(chain_state) ); let mut mismatched_requests = vec![]; @@ -311,7 +315,7 @@ impl Transform for Tee { let (tee_result, chain_result) = tokio::join!( self.tx .process_request(chain_state.clone(), self.timeout_micros), - chain_state.call_next_transform() + down_chain.call_next_transform(chain_state) ); let keep: ResultSource = self.result_source.load(Ordering::Relaxed); @@ -486,9 +490,10 @@ impl IncomingResponses { } impl Tee { - async fn ignore_behaviour<'shorter, 'longer: 'shorter>( + async fn ignore_behaviour( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { let result_source: ResultSource = self.result_source.load(Ordering::Relaxed); match result_source { @@ -496,7 +501,7 @@ impl Tee { let (tee_result, chain_result) = tokio::join!( self.tx .process_request_no_return(chain_state.clone(), self.timeout_micros), - chain_state.call_next_transform() + down_chain.call_next_transform(chain_state) ); if let Err(e) = tee_result { self.dropped_messages.increment(1); @@ -508,7 +513,7 @@ impl Tee { let (tee_result, chain_result) = tokio::join!( self.tx .process_request(chain_state.clone(), self.timeout_micros), - chain_state.call_next_transform() + down_chain.call_next_transform(chain_state) ); if let Err(e) = chain_result { self.dropped_messages.increment(1); diff --git a/shotover/src/transforms/throttling.rs b/shotover/src/transforms/throttling.rs index 79bf814fd..119cbfe0b 100644 --- a/shotover/src/transforms/throttling.rs +++ b/shotover/src/transforms/throttling.rs @@ -1,4 +1,7 @@ -use super::{DownChainProtocol, TransformContextBuilder, TransformContextConfig, UpChainProtocol}; +use super::{ + DownChainProtocol, DownChainTransforms, TransformContextBuilder, TransformContextConfig, + UpChainProtocol, +}; use crate::frame::MessageType; use crate::message::{Message, MessageIdMap, Messages}; use crate::transforms::{ChainState, Transform, TransformBuilder, TransformConfig}; @@ -81,9 +84,10 @@ impl Transform for RequestThrottling { NAME } - async fn transform<'shorter, 'longer: 'shorter>( + async fn transform( &mut self, - chain_state: &'shorter mut ChainState<'longer>, + chain_state: &mut ChainState, + down_chain: DownChainTransforms<'_>, ) -> Result { for request in &mut chain_state.requests { if let Ok(cell_count) = request.cell_count() { @@ -107,7 +111,7 @@ impl Transform for RequestThrottling { } // send allowed messages to Cassandra - let mut responses = chain_state.call_next_transform().await?; + let mut responses = down_chain.call_next_transform(chain_state).await?; // replace dummy responses with throttle messages for response in responses.iter_mut() {