diff --git a/ouroboros-network-api/src/Ouroboros/Network/SizeInBytes.hs b/ouroboros-network-api/src/Ouroboros/Network/SizeInBytes.hs index e92259d8de0..437b6865348 100644 --- a/ouroboros-network-api/src/Ouroboros/Network/SizeInBytes.hs +++ b/ouroboros-network-api/src/Ouroboros/Network/SizeInBytes.hs @@ -1,11 +1,15 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE DerivingVia #-} {-# LANGUAGE GeneralisedNewtypeDeriving #-} -module Ouroboros.Network.SizeInBytes (SizeInBytes (..)) where +module Ouroboros.Network.SizeInBytes ( + SizeInBytes (..) + , WithBytes (..)) where import Control.DeepSeq (NFData (..)) +import Data.ByteString.Short (ShortByteString) import Data.Monoid (Sum (..)) import Data.Word (Word32) import GHC.Generics @@ -14,6 +18,13 @@ import Data.Measure qualified as Measure import NoThunks.Class (NoThunks (..)) import Quiet (Quiet (..)) +data WithBytes a = WithBytes { wbValue :: !a, + unannotate :: ShortByteString } + deriving (Eq, Show) + +instance NFData a => NFData (WithBytes a) where + rnf (WithBytes a b) = rnf a `seq` rnf b + newtype SizeInBytes = SizeInBytes { getSizeInBytes :: Word32 } deriving (Eq, Ord) deriving Show via Quiet SizeInBytes diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs index 372c333bbb8..21957c04a01 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Limits.hs @@ -46,25 +46,30 @@ import Ouroboros.Network.Protocol.Limits import Ouroboros.Network.Util.ShowProxy -driverWithLimits :: forall ps failure bytes m. +driverWithLimits :: forall ps failure bytes m f annotator. ( MonadThrow m , Show failure , ShowProxy ps , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') + , Monoid bytes ) - => Tracer m (TraceSendRecv ps) - -> TimeoutFn m - -> Codec ps failure m bytes - -> ProtocolSizeLimits ps bytes - -> ProtocolTimeLimits ps - -> Channel m bytes - -> Driver ps (Maybe bytes) m + => Tracer m (TraceSendRecv ps) + -> TimeoutFn m + -> Codec ps failure m annotator bytes + -> ProtocolSizeLimits ps bytes + -> ProtocolTimeLimits ps + -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) + -- ^ project out the byte consuming function which produces + -- the decoded message + -> Driver ps (Maybe bytes) m + driverWithLimits tracer timeoutFn Codec{encode, decode} ProtocolSizeLimits{sizeLimitForState, dataSize} ProtocolTimeLimits{timeLimitForState} - channel@Channel{send} = + channel@Channel{send} runAnnotator = Driver { sendMessage, recvMessage, startDState = Nothing } where sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps). @@ -84,8 +89,8 @@ driverWithLimits tracer timeoutFn let sizeLimit = sizeLimitForState stok timeLimit = fromMaybe (-1) (timeLimitForState stok) result <- timeoutFn timeLimit $ - runDecoderWithLimit sizeLimit dataSize - channel trailing decoder + runDecoderWithLimit sizeLimit dataSize channel + runAnnotator trailing decoder case result of Just (Right x@(SomeMessage msg, _trailing')) -> do traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg)) @@ -95,17 +100,19 @@ driverWithLimits tracer timeoutFn Nothing -> throwIO (ExceededTimeLimit stok) runDecoderWithLimit - :: forall m bytes failure a. Monad m + :: forall m bytes failure annotator st. + (Monad m, Monoid bytes) => Word -- ^ message size limit -> (bytes -> Word) -- ^ byte size -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Maybe bytes - -> DecodeStep bytes failure m a - -> m (Either (Maybe failure) (a, Maybe bytes)) -runDecoderWithLimit limit size Channel{recv} = - go 0 + -> DecodeStep bytes failure m (annotator st) + -> m (Either (Maybe failure) (SomeMessage st, Maybe bytes)) +runDecoderWithLimit limit size Channel{recv} runAnnotator = + go 0 <*> fromMaybe mempty where -- Our strategy here is as follows... -- @@ -121,29 +128,30 @@ runDecoderWithLimit limit size Channel{recv} = -- final chunk, we must check if it consumed too much of the final chunk. -- go :: Word -- ^ size of consumed input so far - -> Maybe bytes -- ^ any trailing data - -> DecodeStep bytes failure m a - -> m (Either (Maybe failure) (a, Maybe bytes)) + -> Maybe bytes -- ^ any queued data + -> bytes -- ^ consumed so far + -> DecodeStep bytes failure m (annotator st) + -> m (Either (Maybe failure) (SomeMessage st, Maybe bytes)) - go !sz _ (DecodeDone x trailing) + go !sz _ consumed (DecodeDone annotator trailing) | let sz' = sz - maybe 0 size trailing , sz' > limit = return (Left Nothing) - | otherwise = return (Right (x, trailing)) + | otherwise = return (Right (runAnnotator annotator consumed, trailing)) - go !_ _ (DecodeFail failure) = return (Left (Just failure)) + go !_ _ _ (DecodeFail failure) = return (Left (Just failure)) - go !sz trailing (DecodePartial k) + go !sz queued consumed (DecodePartial k) | sz > limit = return (Left Nothing) - | otherwise = case trailing of + | otherwise = case queued of Nothing -> do mbs <- recv let !sz' = sz + maybe 0 size mbs - go sz' Nothing =<< k mbs - Just bs -> do let sz' = sz + size bs - go sz' Nothing =<< k (Just bs) + go sz' Nothing (consumed <> fromMaybe mempty mbs) =<< k mbs + Just queued' -> do let sz' = sz + size queued' + go sz' Nothing (consumed <> queued') =<< k queued runPeerWithLimits - :: forall ps (st :: ps) pr failure bytes m a . + :: forall ps (st :: ps) pr failure bytes annotator m a. ( MonadAsync m , MonadFork m , MonadMask m @@ -153,17 +161,20 @@ runPeerWithLimits , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps , Show failure + , Monoid bytes ) => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes + -> Codec ps failure m annotator bytes -> ProtocolSizeLimits ps bytes -> ProtocolTimeLimits ps -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Peer ps pr st m a -> m (a, Maybe bytes) -runPeerWithLimits tracer codec slimits tlimits channel peer = +runPeerWithLimits tracer codec slimits tlimits channel runAnnotator peer = withTimeoutSerial $ \timeoutFn -> - let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel + let driver = driverWithLimits tracer timeoutFn codec slimits + tlimits channel runAnnotator in runPeerWithDriver driver peer (startDState driver) @@ -175,7 +186,7 @@ runPeerWithLimits tracer codec slimits tlimits channel peer = -- 'MonadAsync' constraint. -- runPipelinedPeerWithLimits - :: forall ps (st :: ps) pr failure bytes m a. + :: forall ps (st :: ps) pr failure bytes annotator m a. ( MonadAsync m , MonadFork m , MonadMask m @@ -185,15 +196,18 @@ runPipelinedPeerWithLimits , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps , Show failure + , Monoid bytes ) => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes + -> Codec ps failure m annotator bytes -> ProtocolSizeLimits ps bytes -> ProtocolTimeLimits ps -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> PeerPipelined ps pr st m a -> m (a, Maybe bytes) -runPipelinedPeerWithLimits tracer codec slimits tlimits channel peer = +runPipelinedPeerWithLimits tracer codec slimits tlimits channel runAnnotator peer = withTimeoutSerial $ \timeoutFn -> - let driver = driverWithLimits tracer timeoutFn codec slimits tlimits channel + let driver = driverWithLimits tracer timeoutFn codec + slimits tlimits channel runAnnotator in runPipelinedPeerWithDriver driver peer (startDState driver) diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs index 868d93dcee8..b0925e601ec 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Driver/Simple.hs @@ -18,12 +18,10 @@ module Ouroboros.Network.Driver.Simple -- $intro -- * Normal peers runPeer - , runAnnotatedPeer , TraceSendRecv (..) , DecoderFailure (..) -- * Pipelined peers , runPipelinedPeer - , runPipelinedAnnotatedPeer -- * Connected peers -- TODO: move these to a test lib , Role (..) @@ -111,31 +109,23 @@ instance Show DecoderFailure where instance Exception DecoderFailure where -mkSimpleDriver :: forall ps failure bytes m f annotator. +mkSimpleDriver :: forall ps failure bytes m annotator. ( MonadThrow m , Show failure , forall (st :: ps). Show (ClientHasAgency st) , forall (st :: ps). Show (ServerHasAgency st) , ShowProxy ps + , Monoid bytes ) - => (forall a. - Channel m bytes - -> Maybe bytes - -> DecodeStep bytes failure m (f a) - -> m (Either failure (a, Maybe bytes)) - ) - -- ^ run incremental decoder against a channel - - -> (forall st. annotator st -> f (SomeMessage st)) - -- ^ transform annotator to a container holding the decoded - -- message - - -> Tracer m (TraceSendRecv ps) - -> Codec' ps failure m annotator bytes - -> Channel m bytes - -> Driver ps (Maybe bytes) m - -mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{send} = + => (forall st. annotator st -> bytes -> SomeMessage st) + -- ^ project out the byte consuming function which produces + -- the decoded message + -> Tracer m (TraceSendRecv ps) + -> Codec ps failure m annotator bytes + -> Channel m bytes + -> Driver ps (Maybe bytes) m + +mkSimpleDriver runAnnotator tracer Codec{encode, decode} channel@Channel{send} = Driver { sendMessage, recvMessage, startDState = Nothing } where sendMessage :: forall (pr :: PeerRole) (st :: ps) (st' :: ps). @@ -152,7 +142,7 @@ mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{s -> m (SomeMessage st, Maybe bytes) recvMessage stok trailing = do decoder <- decode stok - result <- runDecodeSteps channel trailing (nat <$> decoder) + result <- runDecoderWithChannel channel runAnnotator trailing decoder case result of Right x@(SomeMessage msg, _trailing') -> do traceWith tracer (TraceRecvMsg (AnyMessageAndAgency stok msg)) @@ -160,82 +150,29 @@ mkSimpleDriver runDecodeSteps nat tracer Codec{encode, decode} channel@Channel{s Left failure -> throwIO (DecoderFailure stok failure) - -simpleDriver :: forall ps failure bytes m. - ( MonadThrow m - , Show failure - , forall (st :: ps). Show (ClientHasAgency st) - , forall (st :: ps). Show (ServerHasAgency st) - , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes - -> Channel m bytes - -> Driver ps (Maybe bytes) m -simpleDriver = mkSimpleDriver runDecoderWithChannel Identity - - -annotatedSimpleDriver - :: forall ps failure bytes m. - ( MonadThrow m - , Monoid bytes - , Show failure - , forall (st :: ps). Show (ClientHasAgency st) - , forall (st :: ps). Show (ServerHasAgency st) - , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> AnnotatedCodec ps failure m bytes - -> Channel m bytes - -> Driver ps (Maybe bytes) m -annotatedSimpleDriver = mkSimpleDriver runAnnotatedDecoderWithChannel runAnnotator - - -- | Run a peer with the given channel via the given codec. -- -- This runs the peer to completion (if the protocol allows for termination). -- runPeer - :: forall ps (st :: ps) pr failure bytes m a . + :: forall ps (st :: ps) pr failure bytes annotator m a. ( MonadThrow m , Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes - -> Channel m bytes - -> Peer ps pr st m a - -> m (a, Maybe bytes) -runPeer tracer codec channel peer = - runPeerWithDriver driver peer (startDState driver) - where - driver = simpleDriver tracer codec channel - - --- | Run a peer with the given channel via the given annotated codec. --- --- This runs the peer to completion (if the protocol allows for termination). --- -runAnnotatedPeer - :: forall ps (st :: ps) pr failure bytes m a . - ( MonadThrow m , Monoid bytes - , Show failure - , forall (st' :: ps). Show (ClientHasAgency st') - , forall (st' :: ps). Show (ServerHasAgency st') - , ShowProxy ps ) => Tracer m (TraceSendRecv ps) - -> AnnotatedCodec ps failure m bytes + -> Codec ps failure m annotator bytes -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Peer ps pr st m a -> m (a, Maybe bytes) -runAnnotatedPeer tracer codec channel peer = +runPeer tracer codec channel runAnnotator peer = runPeerWithDriver driver peer (startDState driver) where - driver = annotatedSimpleDriver tracer codec channel + driver = mkSimpleDriver runAnnotator tracer codec channel -- | Run a pipelined peer with the given channel via the given codec. @@ -246,52 +183,25 @@ runAnnotatedPeer tracer codec channel peer = -- 'MonadAsync' constraint. -- runPipelinedPeer - :: forall ps (st :: ps) pr failure bytes m a. + :: forall ps (st :: ps) pr failure annotator bytes m a. ( MonadAsync m , MonadThrow m , Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps - ) - => Tracer m (TraceSendRecv ps) - -> Codec ps failure m bytes - -> Channel m bytes - -> PeerPipelined ps pr st m a - -> m (a, Maybe bytes) -runPipelinedPeer tracer codec channel peer = - runPipelinedPeerWithDriver driver peer (startDState driver) - where - driver = simpleDriver tracer codec channel - - --- | Run a pipelined peer with the given channel via the given annotated codec. --- --- This runs the peer to completion (if the protocol allows for termination). --- --- Unlike normal peers, running pipelined peers rely on concurrency, hence the --- 'MonadAsync' constraint. --- -runPipelinedAnnotatedPeer - :: forall ps (st :: ps) pr failure bytes m a. - ( MonadAsync m - , MonadThrow m , Monoid bytes - , Show failure - , forall (st' :: ps). Show (ClientHasAgency st') - , forall (st' :: ps). Show (ServerHasAgency st') - , ShowProxy ps ) => Tracer m (TraceSendRecv ps) - -> AnnotatedCodec ps failure m bytes + -> Codec ps failure m annotator bytes -> Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> PeerPipelined ps pr st m a -> m (a, Maybe bytes) -runPipelinedAnnotatedPeer tracer codec channel peer = +runPipelinedPeer tracer codec channel runAnnotator peer = runPipelinedPeerWithDriver driver peer (startDState driver) where - driver = annotatedSimpleDriver tracer codec channel - + driver = mkSimpleDriver runAnnotator tracer codec channel -- -- Utils @@ -300,38 +210,19 @@ runPipelinedAnnotatedPeer tracer codec channel peer = -- | Run a codec incremental decoder 'DecodeStep' against a channel. It also -- takes any extra input data and returns any unused trailing data. -- -runDecoderWithChannel :: Monad m +runDecoderWithChannel :: (Monad m, Monoid bytes) => Channel m bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Maybe bytes - -> DecodeStep bytes failure m (Identity a) - -> m (Either failure (a, Maybe bytes)) + -> DecodeStep bytes failure m (annotator st) + -> m (Either failure (SomeMessage st, Maybe bytes)) -runDecoderWithChannel Channel{recv} = go +runDecoderWithChannel Channel{recv} runAnnotator = go <*> fromMaybe mempty where - go _ (DecodeDone (Identity x) trailing) = return (Right (x, trailing)) - go _ (DecodeFail failure) = return (Left failure) - go Nothing (DecodePartial k) = recv >>= k >>= go Nothing - go (Just trailing) (DecodePartial k) = k (Just trailing) >>= go Nothing - - -runAnnotatedDecoderWithChannel - :: forall m bytes failure a. - ( Monad m - , Monoid bytes - ) - => Channel m bytes - -> Maybe bytes - -> DecodeStep bytes failure m (bytes -> a) - -> m (Either failure (a, Maybe bytes)) - -runAnnotatedDecoderWithChannel Channel{recv} bs0 = go (fromMaybe mempty bs0) bs0 - where - go :: bytes -> Maybe bytes -> DecodeStep bytes failure m (bytes -> a) -> m (Either failure (a, Maybe bytes)) - go bytes _ (DecodeDone f trailing) = return $ Right (f bytes, trailing) - go _bytes _ (DecodeFail failure) = return (Left failure) - go bytes Nothing (DecodePartial k) = recv >>= \bs -> k bs >>= go (bytes <> fromMaybe mempty bs) Nothing - go bytes (Just trailing) (DecodePartial k) = k (Just trailing) >>= go (bytes <> trailing) Nothing - + go _ consumed (DecodeDone ann trailing) = return (Right (runAnnotator ann consumed, trailing)) + go _ _ (DecodeFail failure) = return (Left failure) + go Nothing consumed (DecodePartial k) = recv >>= \bs -> k bs >>= go Nothing (consumed <> fromMaybe mempty bs) + go (Just queued) consumed (DecodePartial k) = k (Just queued) >>= go Nothing (consumed <> queued) data Role = Client | Server @@ -348,19 +239,21 @@ runConnectedPeers :: ( MonadAsync m , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps + , Monoid bytes ) => m (Channel m bytes, Channel m bytes) -> Tracer m (Role, TraceSendRecv ps) - -> Codec ps failure m bytes + -> Codec ps failure m annotator bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Peer ps pr st m a -> Peer ps (FlipAgency pr) st m b -> m (a, b) -runConnectedPeers createChannels tracer codec client server = +runConnectedPeers createChannels tracer codec runAnnotator client server = createChannels >>= \(clientChannel, serverChannel) -> - (fst <$> runPeer tracerClient codec clientChannel client) + (fst <$> runPeer tracerClient codec clientChannel runAnnotator client) `concurrently` - (fst <$> runPeer tracerServer codec serverChannel server) + (fst <$> runPeer tracerServer codec serverChannel runAnnotator server) where tracerClient = contramap ((,) Client) tracer tracerServer = contramap ((,) Server) tracer @@ -376,20 +269,22 @@ runConnectedPeersAsymmetric , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps + , Monoid bytes ) => m (Channel m bytes, Channel m bytes) -> Tracer m (Role, TraceSendRecv ps) - -> Codec ps failure m bytes - -> Codec ps failure m bytes + -> Codec ps failure m annotator bytes + -> Codec ps failure m annotator bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> Peer ps pr st m a -> Peer ps (FlipAgency pr) st m b -> m (a, b) -runConnectedPeersAsymmetric createChannels tracer codec codec' client server = +runConnectedPeersAsymmetric createChannels tracer codec codec' runAnnotator client server = createChannels >>= \(clientChannel, serverChannel) -> - (fst <$> runPeer tracerClient codec clientChannel client) + (fst <$> runPeer tracerClient codec clientChannel runAnnotator client) `concurrently` - (fst <$> runPeer tracerServer codec' serverChannel server) + (fst <$> runPeer tracerServer codec' serverChannel runAnnotator server) where tracerClient = contramap ((,) Client) tracer tracerServer = contramap ((,) Server) tracer @@ -401,20 +296,21 @@ runConnectedPeersPipelined :: ( MonadAsync m , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps + , Monoid bytes ) => m (Channel m bytes, Channel m bytes) -> Tracer m (Role, TraceSendRecv ps) - -> Codec ps failure m bytes + -> Codec ps failure m annotator bytes + -> (forall st. annotator st -> bytes -> SomeMessage st) -> PeerPipelined ps pr st m a -> Peer ps (FlipAgency pr) st m b -> m (a, b) -runConnectedPeersPipelined createChannels tracer codec client server = +runConnectedPeersPipelined createChannels tracer codec runAnnotator client server = createChannels >>= \(clientChannel, serverChannel) -> - (fst <$> runPipelinedPeer tracerClient codec clientChannel client) + (fst <$> runPipelinedPeer tracerClient codec clientChannel runAnnotator client) `concurrently` - (fst <$> runPeer tracerServer codec serverChannel server) + (fst <$> runPeer tracerServer codec serverChannel runAnnotator server) where tracerClient = contramap ((,) Client) tracer tracerServer = contramap ((,) Server) tracer - diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs b/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs index d6a9caba51d..3bd93e9b90e 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Mux.hs @@ -5,11 +5,13 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PatternSynonyms #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} @@ -309,27 +311,29 @@ data MiniProtocolCb ctx bytes m a where -> MiniProtocolCb ctx bytes m a MuxPeer - :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m a. + :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m annotator a. ( Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps ) => (ctx -> ( Tracer m (TraceSendRecv ps) - , Codec ps failure m bytes + , Codec ps failure m annotator bytes + , forall st. annotator st -> bytes -> SomeMessage st , Peer ps pr st m a )) -> MiniProtocolCb ctx bytes m a MuxPeerPipelined - :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m a. + :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m annotator a. ( Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps ) => (ctx -> ( Tracer m (TraceSendRecv ps) - , Codec ps failure m bytes + , Codec ps failure m annotator bytes + , forall st. annotator st -> bytes -> SomeMessage st , PeerPipelined ps pr st m a )) -> MiniProtocolCb ctx bytes m a @@ -357,15 +361,17 @@ pattern MuxPeerRaw { runMuxPeer } = MiniProtocolCb runMuxPeer -- | Create a 'MuxPeer' from a tracer, codec and 'Peer'. -- mkMiniProtocolCbFromPeer - :: forall (pr :: PeerRole) ps (st :: ps) failure bytes ctx m a. + :: forall (pr :: PeerRole) ps (st :: ps) failure bytes ctx m annotator a. ( MonadThrow m , Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps + , Monoid bytes ) => (ctx -> ( Tracer m (TraceSendRecv ps) - , Codec ps failure m bytes + , Codec ps failure m annotator bytes + , forall st. annotator st -> bytes -> SomeMessage st , Peer ps pr st m a ) ) @@ -373,23 +379,25 @@ mkMiniProtocolCbFromPeer mkMiniProtocolCbFromPeer fn = MiniProtocolCb $ \ctx channel -> case fn ctx of - (tracer, codec, peer) -> - runPeer tracer codec channel peer + (tracer, codec, runAnnotator, peer) -> + runPeer tracer codec channel runAnnotator peer -- | Create a 'MuxPeer' from a tracer, codec and 'PeerPipelined'. -- mkMiniProtocolCbFromPeerPipelined - :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m a. + :: forall (pr :: PeerRole) ps (st :: ps) failure ctx bytes m annotator a. ( MonadAsync m , MonadThrow m , Show failure , forall (st' :: ps). Show (ClientHasAgency st') , forall (st' :: ps). Show (ServerHasAgency st') , ShowProxy ps + , Monoid bytes ) => (ctx -> ( Tracer m (TraceSendRecv ps) - , Codec ps failure m bytes + , Codec ps failure m annotator bytes + , forall st. annotator st -> bytes -> SomeMessage st , PeerPipelined ps pr st m a ) ) @@ -397,8 +405,8 @@ mkMiniProtocolCbFromPeerPipelined mkMiniProtocolCbFromPeerPipelined fn = MiniProtocolCb $ \ctx channel -> case fn ctx of - (tracer, codec, peer) -> - runPipelinedPeer tracer codec channel peer + (tracer, codec, runAnnotator, peer) -> + runPipelinedPeer tracer codec channel runAnnotator peer -- | Run a 'MuxPeer' using supplied 'ctx' and 'Mux.Channel' diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs index 2160cfe90af..31c16b55801 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake.hs @@ -88,7 +88,7 @@ data HandshakeArguments connectionId vNumber vData m = HandshakeArguments { -- | Codec for protocol messages. -- haHandshakeCodec - :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m BL.ByteString, + :: Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m SomeMessage BL.ByteString, -- | A codec for protocol parameters. -- @@ -144,6 +144,7 @@ runHandshakeClient bearer byteLimitsHandshake haTimeLimits (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum InitiatorDir)) + const (handshakeClientPeer haVersionDataCodec haAcceptVersion versions)) @@ -182,6 +183,7 @@ runHandshakeServer bearer byteLimitsHandshake haTimeLimits (fromChannel (muxBearerAsChannel bearer handshakeProtocolNum ResponderDir)) + const (handshakeServerPeer haVersionDataCodec haAcceptVersion haQueryVersion versions)) -- | A 20s delay after query result was send back, before we close the diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs index 8862485433f..c50305420c9 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Codec.hs @@ -127,7 +127,7 @@ codecHandshake , Show failure ) => CodecCBORTerm (failure, Maybe Int) vNumber - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure m SomeMessage ByteString codecHandshake versionNumberCodec = mkCodecCborLazyBS encodeMsg decodeMsg where encodeMsg @@ -301,7 +301,7 @@ decodeRefuseReason versionNumberCodec = do -- nodeToNodeHandshakeCodec :: MonadST m => Codec (Handshake NodeToNodeVersion CBOR.Term) - CBOR.DeserialiseFailure m BL.ByteString + CBOR.DeserialiseFailure m SomeMessage BL.ByteString nodeToNodeHandshakeCodec = codecHandshake nodeToNodeVersionCodec @@ -309,5 +309,5 @@ nodeToNodeHandshakeCodec = codecHandshake nodeToNodeVersionCodec -- nodeToClientHandshakeCodec :: MonadST m => Codec (Handshake NodeToClientVersion CBOR.Term) - CBOR.DeserialiseFailure m BL.ByteString + CBOR.DeserialiseFailure m SomeMessage BL.ByteString nodeToClientHandshakeCodec = codecHandshake nodeToClientVersionCodec diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Unversioned.hs b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Unversioned.hs index c74ec555c28..8e3b873470e 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Unversioned.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Protocol/Handshake/Unversioned.hs @@ -126,7 +126,7 @@ dataFlowProtocol dataFlow = -- unversionedHandshakeCodec :: MonadST m => Codec (Handshake UnversionedProtocol CBOR.Term) - CBOR.DeserialiseFailure m ByteString + CBOR.DeserialiseFailure m SomeMessage ByteString unversionedHandshakeCodec = codecHandshake unversionedProtocolCodec where unversionedProtocolCodec :: CodecCBORTerm (String, Maybe Int) UnversionedProtocol diff --git a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs index e90ad9b3314..cd8dc02b49a 100644 --- a/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs +++ b/ouroboros-network-framework/src/Ouroboros/Network/Socket.hs @@ -265,7 +265,7 @@ connectToNode => Snocket IO fd addr -> Mx.MakeBearer IO fd -> (fd -> IO ()) -- ^ configure a socket - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> NetworkConnectTracers addr vNumber @@ -307,7 +307,7 @@ connectToNode' ) => Snocket IO fd addr -> Mx.MakeBearer IO fd - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> NetworkConnectTracers addr vNumber @@ -372,7 +372,7 @@ connectToNodeSocket , Mx.HasInitiator appType ~ True ) => IOManager - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> NetworkConnectTracers Socket.SockAddr vNumber @@ -442,7 +442,7 @@ beginConnection => Mx.MakeBearer IO fd -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) Mx.MuxTrace) -> Tracer IO (Mx.WithMuxBearer (ConnectionId addr) (TraceSendRecv (Handshake vNumber CBOR.Term))) - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> HandshakeCallbacks vData @@ -618,7 +618,7 @@ runServerThread -> Mx.MakeBearer IO fd -> fd -> AcceptedConnectionsLimit - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> HandshakeCallbacks vData @@ -736,7 +736,7 @@ withServerNode -> NetworkMutableState addr -> AcceptedConnectionsLimit -> addr - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> HandshakeCallbacks vData @@ -811,7 +811,7 @@ withServerNode' -- ^ a configured socket to be used be the server. The server will call -- `bind` and `listen` methods but it will not set any socket or tcp options -- on it. - -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO BL.ByteString + -> Codec (Handshake vNumber CBOR.Term) CBOR.DeserialiseFailure IO SomeMessage BL.ByteString -> ProtocolTimeLimits (Handshake vNumber CBOR.Term) -> VersionDataCodec CBOR.Term vNumber vData -> HandshakeCallbacks vData diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs index af76098adfc..76eec5ba800 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/BlockFetch/Codec.hs @@ -67,7 +67,7 @@ codecBlockFetch -> (forall s. CBOR.Decoder s block) -> (point -> CBOR.Encoding) -> (forall s. CBOR.Decoder s point) - -> Codec (BlockFetch block point) CBOR.DeserialiseFailure m LBS.ByteString + -> Codec (BlockFetch block point) CBOR.DeserialiseFailure m SomeMessage LBS.ByteString codecBlockFetch encodeBlock decodeBlock encodePoint decodePoint = mkCodecCborLazyBS encode decode @@ -124,6 +124,7 @@ codecBlockFetchId :: forall block point m. Monad m => Codec (BlockFetch block point) CodecFailure m + SomeMessage (AnyMessage (BlockFetch block point)) codecBlockFetchId = Codec { encode, decode } where diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs index 43a976f7fff..741fef6bb85 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/ChainSync/Codec.hs @@ -98,7 +98,7 @@ codecChainSync -> (tip -> CBOR.Encoding) -> (forall s. CBOR.Decoder s tip) -> Codec (ChainSync header point tip) - CBOR.DeserialiseFailure m LBS.ByteString + CBOR.DeserialiseFailure m SomeMessage LBS.ByteString codecChainSync encodeHeader decodeHeader encodePoint decodePoint encodeTip decodeTip = @@ -215,7 +215,9 @@ decodeList dec = do -- codecChainSyncId :: forall header point tip m. Monad m => Codec (ChainSync header point tip) - CodecFailure m (AnyMessage (ChainSync header point tip)) + CodecFailure m + SomeMessage + (AnyMessage (ChainSync header point tip)) codecChainSyncId = Codec { encode, decode } where encode :: forall (pr :: PeerRole) st st'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs index ae73bc13913..a4f40e13c39 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/KeepAlive/Codec.hs @@ -34,7 +34,7 @@ import Ouroboros.Network.Protocol.Limits codecKeepAlive_v2 :: forall m. MonadST m - => Codec KeepAlive CBOR.DeserialiseFailure m ByteString + => Codec KeepAlive CBOR.DeserialiseFailure m SomeMessage ByteString codecKeepAlive_v2 = mkCodecCborLazyBS encodeMsg decodeMsg where encodeMsg :: forall (pr :: PeerRole) st st'. @@ -96,7 +96,7 @@ codecKeepAliveId :: forall m. ( Monad m ) - => Codec KeepAlive CodecFailure m (AnyMessage KeepAlive) + => Codec KeepAlive CodecFailure m SomeMessage (AnyMessage KeepAlive) codecKeepAliveId = Codec encodeMsg decodeMsg where encodeMsg :: forall (pr :: PeerRole) st st'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalStateQuery/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalStateQuery/Codec.hs index 4d1bd32fb12..d9b086ffa2f 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalStateQuery/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalStateQuery/Codec.hs @@ -44,7 +44,7 @@ codecLocalStateQuery -> (forall s . CBOR.Decoder s (Some query)) -> (forall result . query result -> result -> CBOR.Encoding) -> (forall result . query result -> forall s . CBOR.Decoder s result) - -> Codec (LocalStateQuery block point query) CBOR.DeserialiseFailure m ByteString + -> Codec (LocalStateQuery block point query) CBOR.DeserialiseFailure m SomeMessage ByteString codecLocalStateQuery version encodePoint decodePoint encodeQuery decodeQuery @@ -209,6 +209,7 @@ codecLocalStateQueryId ) -> Codec (LocalStateQuery block point query) CodecFailure m + SomeMessage (AnyMessage (LocalStateQuery block point query)) codecLocalStateQueryId eqQuery = Codec { encode, decode } diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxMonitor/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxMonitor/Codec.hs index 67974580bec..f27facb3e02 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxMonitor/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxMonitor/Codec.hs @@ -37,7 +37,7 @@ codecLocalTxMonitor :: -> (forall s. CBOR.Decoder s tx) -> (slot -> CBOR.Encoding) -> (forall s. CBOR.Decoder s slot) - -> Codec (LocalTxMonitor txid tx slot) CBOR.DeserialiseFailure m ByteString + -> Codec (LocalTxMonitor txid tx slot) CBOR.DeserialiseFailure m SomeMessage ByteString codecLocalTxMonitor encodeTxId decodeTxId encodeTx decodeTx encodeSlot decodeSlot = @@ -153,7 +153,7 @@ codecLocalTxMonitorId :: ( Monad m , ptcl ~ LocalTxMonitor txid tx slot ) - => Codec ptcl CodecFailure m (AnyMessage ptcl) + => Codec ptcl CodecFailure m SomeMessage (AnyMessage ptcl) codecLocalTxMonitorId = Codec { encode, decode } where diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxSubmission/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxSubmission/Codec.hs index 536b1162787..c168640d348 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxSubmission/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/LocalTxSubmission/Codec.hs @@ -31,7 +31,7 @@ codecLocalTxSubmission -> (forall s . CBOR.Decoder s tx) -> (reject -> CBOR.Encoding) -> (forall s . CBOR.Decoder s reject) - -> Codec (LocalTxSubmission tx reject) CBOR.DeserialiseFailure m ByteString + -> Codec (LocalTxSubmission tx reject) CBOR.DeserialiseFailure m SomeMessage ByteString codecLocalTxSubmission encodeTx decodeTx encodeReject decodeReject = mkCodecCborLazyBS encode decode where @@ -85,7 +85,8 @@ codecLocalTxSubmissionId :: forall tx reject m. Monad m => Codec (LocalTxSubmission tx reject) - CodecFailure m + CodecFailure m + SomeMessage (AnyMessage (LocalTxSubmission tx reject)) codecLocalTxSubmissionId = Codec { encode, decode } diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs index c20f8c212be..41a24dfde48 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/PeerSharing/Codec.hs @@ -31,6 +31,7 @@ codecPeerSharing :: forall m peerAddress. -> Codec (PeerSharing peerAddress) CBOR.DeserialiseFailure m + SomeMessage ByteString codecPeerSharing encodeAddress decodeAddress = mkCodecCborLazyBS encodeMsg decodeMsg where @@ -86,7 +87,7 @@ codecPeerSharing encodeAddress decodeAddress = mkCodecCborLazyBS encodeMsg decod codecPeerSharingId :: forall peerAddress m. Monad m - => Codec (PeerSharing peerAddress) CodecFailure m (AnyMessage (PeerSharing peerAddress)) + => Codec (PeerSharing peerAddress) CodecFailure m SomeMessage (AnyMessage (PeerSharing peerAddress)) codecPeerSharingId = Codec encodeMsg decodeMsg where encodeMsg :: forall (pr :: PeerRole) st st'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Client.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Client.hs index f7a6a1f2c80..ff4b55463f3 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Client.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Client.hs @@ -29,6 +29,7 @@ module Ouroboros.Network.Protocol.TxSubmission2.Client import Network.TypedProtocol.Core +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.Protocol.TxSubmission2.Type @@ -74,7 +75,7 @@ data ClientStTxIds blocking txid tx m a where data ClientStTxs txid tx m a where - SendMsgReplyTxs :: [tx] + SendMsgReplyTxs :: [WithBytes tx] -> ClientStIdle txid tx m a -> ClientStTxs txid tx m a diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs index 4a91658de5c..08e71902915 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Codec.hs @@ -17,17 +17,21 @@ module Ouroboros.Network.Protocol.TxSubmission2.Codec import Control.Monad.Class.MonadST import Control.Monad.Class.MonadTime.SI +import Data.Functor import Data.List.NonEmpty qualified as NonEmpty import Codec.CBOR.Decoding qualified as CBOR import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Data.ByteString.Lazy (ByteString) +import Data.ByteString.Lazy qualified as BSL +import Data.ByteString.Short (toShort) import Text.Printf import Network.TypedProtocol.Codec.CBOR import Ouroboros.Network.Protocol.Limits +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.Protocol.TxSubmission2.Type -- | Byte Limits. @@ -70,7 +74,7 @@ codecTxSubmission2 -> (forall s . CBOR.Decoder s txid) -> (tx -> CBOR.Encoding) -> (forall s . CBOR.Decoder s tx) - -> AnnotatedCodec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m ByteString + -> Codec (TxSubmission2 txid tx) CBOR.DeserialiseFailure m (Annotator ByteString) ByteString codecTxSubmission2 encodeTxId decodeTxId encodeTx decodeTx = mkCodecCborLazyBS @@ -141,7 +145,7 @@ encodeTxSubmission2 encodeTxId encodeTx = encode CBOR.encodeListLen 2 <> CBOR.encodeWord 3 <> CBOR.encodeListLenIndef - <> foldr (\txid r -> encodeTx txid <> r) CBOR.encodeBreak txs + <> foldr (\(WithBytes { wbValue = tx }) r -> encodeTx tx <> r) CBOR.encodeBreak txs encode (ClientAgency (TokTxIds TokBlocking)) MsgDone = CBOR.encodeListLen 1 @@ -167,15 +171,15 @@ decodeTxSubmission2 decodeTxId decodeTx = decode decode stok len key = do case (stok, len, key) of (ClientAgency TokInit, 1, 6) -> - return (Annotator $ \_ -> SomeMessage MsgInit) + return . Annotator . const $ SomeMessage MsgInit (ServerAgency TokIdle, 4, 0) -> do blocking <- CBOR.decodeBool ackNo <- NumTxIdsToAck <$> CBOR.decodeWord16 reqNo <- NumTxIdsToReq <$> CBOR.decodeWord16 - return $! + return . Annotator . const $ if blocking - then Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokBlocking ackNo reqNo) - else Annotator $ \_ -> SomeMessage (MsgRequestTxIds TokNonBlocking ackNo reqNo) + then SomeMessage $ MsgRequestTxIds TokBlocking ackNo reqNo + else SomeMessage $ MsgRequestTxIds TokNonBlocking ackNo reqNo (ClientAgency (TokTxIds b), 2, 1) -> do CBOR.decodeListLenIndef @@ -185,14 +189,13 @@ decodeTxSubmission2 decodeTxId decodeTx = decode txid <- decodeTxId sz <- CBOR.decodeWord32 return (txid, SizeInBytes sz)) + case (b, txids) of (TokBlocking, t:ts) -> - return $ Annotator $ \_ -> - SomeMessage (MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts))) + return . Annotator . const . SomeMessage $ MsgReplyTxIds (BlockingReply (t NonEmpty.:| ts)) (TokNonBlocking, ts) -> - return $ Annotator $ \_ -> - SomeMessage (MsgReplyTxIds (NonBlockingReply ts)) + return . Annotator . const . SomeMessage $ MsgReplyTxIds (NonBlockingReply ts) (TokBlocking, []) -> fail "codecTxSubmission: MsgReplyTxIds: empty list not permitted" @@ -201,26 +204,21 @@ decodeTxSubmission2 decodeTxId decodeTx = decode (ServerAgency TokIdle, 2, 2) -> do CBOR.decodeListLenIndef txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTxId - return (Annotator $ \_ -> SomeMessage (MsgRequestTxs txids)) + return . Annotator . const . SomeMessage $ MsgRequestTxs txids (ClientAgency TokTxs, 2, 3) -> do CBOR.decodeListLenIndef - txids <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse decodeTx - -- ^ TODO: `txids -> txs` :grin: - return (Annotator $ - -- TODO: here we have access to bytes from which the message was decoded. - -- we can use `Codec.CBOR.Decoding.decodeWithByteSpan` - -- around each `tx` and wrap each `tx` in `WithBytes`. - -- - -- `decodeTxSubmission2` can be polymorphic by adding an - -- extra argument of type - -- `ByteString -> ByteOffSet -> ByteOffset -> tx -> a` - -- this way we could wrap `tx` in `WithBytes` or just - -- return `tx`. - \_bytes -> SomeMessage (MsgReplyTxs txids)) + txs <- CBOR.decodeSequenceLenIndef (flip (:)) [] reverse (CBOR.decodeWithByteSpan decodeTx) + return $ Annotator (\bytes -> + SomeMessage . MsgReplyTxs $ + txs <&> \(tx, start, end) -> + let slice = BSL.take (end - start ) $ BSL.drop start bytes + in WithBytes { + wbValue = tx, + unannotate = toShort . BSL.toStrict $ slice }) (ClientAgency (TokTxIds TokBlocking), 1, 4) -> - return (Annotator $ \_ -> SomeMessage MsgDone) + return . Annotator . const $ SomeMessage MsgDone -- -- failures per protocol state @@ -239,7 +237,7 @@ decodeTxSubmission2 decodeTxId decodeTx = decode codecTxSubmission2Id :: forall txid tx m. Monad m - => Codec (TxSubmission2 txid tx) CodecFailure m (AnyMessage (TxSubmission2 txid tx)) + => Codec (TxSubmission2 txid tx) CodecFailure m SomeMessage (AnyMessage (TxSubmission2 txid tx)) codecTxSubmission2Id = Codec { encode, decode } where encode :: forall (pr :: PeerRole) st st'. diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Server.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Server.hs index c29517ca208..a5d71d87f4c 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Server.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Server.hs @@ -28,6 +28,7 @@ import Data.List.NonEmpty (NonEmpty) import Network.TypedProtocol.Core import Network.TypedProtocol.Pipelined +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.Protocol.TxSubmission2.Type @@ -50,7 +51,7 @@ data Collect txid tx = -- contains the transactions sent, but this pairs them up with the -- transactions requested. This is because the peer can determine that -- some transactions are no longer needed. - | CollectTxs [txid] [tx] + | CollectTxs [txid] [WithBytes tx] data ServerStIdle (n :: N) txid tx m a where diff --git a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Type.hs b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Type.hs index 220f1f2ab31..44d596ad7b2 100644 --- a/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Type.hs +++ b/ouroboros-network-protocols/src/Ouroboros/Network/Protocol/TxSubmission2/Type.hs @@ -43,7 +43,7 @@ import Quiet (Quiet (..)) import Network.TypedProtocol.Core -import Ouroboros.Network.SizeInBytes (SizeInBytes (..)) +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.Util.ShowProxy -- | Transactions are typically not big, but in principle in future we could @@ -262,7 +262,7 @@ instance Protocol (TxSubmission2 txid tx) where -- be valid and available from another peer). -- MsgReplyTxs - :: [tx] + :: [WithBytes tx] -> Message (TxSubmission2 txid tx) StTxs StIdle -- | Termination message, initiated by the client when the server is diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs index cf216548a02..cd69c49df09 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Inbound.hs @@ -41,6 +41,7 @@ import Network.TypedProtocol.Pipelined (N, Nat (..), natToInt) import Ouroboros.Network.NodeToNode.Version (NodeToNodeVersion) import Ouroboros.Network.Protocol.TxSubmission2.Server import Ouroboros.Network.Protocol.TxSubmission2.Type +import Ouroboros.Network.SizeInBytes import Ouroboros.Network.TxSubmission.Mempool.Reader (MempoolSnapshot (..), TxSubmissionMempoolReader (..)) @@ -304,7 +305,7 @@ txSubmissionInbound tracer (NumTxIdsToAck maxUnacked) mpReader mpWriter _version -- approach to this and check it. -- let txsMap :: Map txid tx - txsMap = Map.fromList [ (txId tx, tx) | tx <- txs ] + txsMap = Map.fromList [(txId tx, tx) | WithBytes { wbValue = tx } <- txs] txidsReceived = Map.keysSet txsMap txidsRequested = Set.fromList txids diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Reader.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Reader.hs index 771bbbfe6b1..54b7db67fe2 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Reader.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Mempool/Reader.hs @@ -6,7 +6,7 @@ module Ouroboros.Network.TxSubmission.Mempool.Reader ) where import Control.Monad.Class.MonadSTM (MonadSTM, STM) -import Ouroboros.Network.SizeInBytes (SizeInBytes) +import Ouroboros.Network.SizeInBytes -- | The consensus layer functionality that the inbound and outbound side of -- the tx submission logic requires. @@ -28,7 +28,7 @@ data TxSubmissionMempoolReader txid tx idx m = mapTxSubmissionMempoolReader :: MonadSTM m - => (tx -> tx') + => (WithBytes tx -> WithBytes tx') -> TxSubmissionMempoolReader txid tx idx m -> TxSubmissionMempoolReader txid tx' idx m mapTxSubmissionMempoolReader f rdr = @@ -53,12 +53,12 @@ mapTxSubmissionMempoolReader f rdr = data MempoolSnapshot txid tx idx = MempoolSnapshot { mempoolTxIdsAfter :: idx -> [(txid, idx, SizeInBytes)], - mempoolLookupTx :: idx -> Maybe tx, + mempoolLookupTx :: idx -> Maybe (WithBytes tx), mempoolHasTx :: txid -> Bool } mapMempoolSnapshot :: - (tx -> tx') + (WithBytes tx -> WithBytes tx') -> MempoolSnapshot txid tx idx -> MempoolSnapshot txid tx' idx mapMempoolSnapshot f snap = diff --git a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Outbound.hs b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Outbound.hs index d5cac825788..79da9dd3e3c 100644 --- a/ouroboros-network/src/Ouroboros/Network/TxSubmission/Outbound.hs +++ b/ouroboros-network/src/Ouroboros/Network/TxSubmission/Outbound.hs @@ -29,6 +29,7 @@ import Ouroboros.Network.Protocol.TxSubmission2.Client import Ouroboros.Network.Protocol.TxSubmission2.Type import Ouroboros.Network.TxSubmission.Mempool.Reader (MempoolSnapshot (..), TxSubmissionMempoolReader (..)) +import Ouroboros.Network.SizeInBytes (WithBytes(..)) data TraceTxSubmissionOutbound txid tx @@ -187,6 +188,6 @@ txSubmissionOutbound tracer maxUnacked TxSubmissionMempoolReader{..} _version co client' = client unackedSeq lastIdx -- Trace the transactions to be sent in the response. - traceWith tracer (TraceTxSubmissionOutboundSendMsgReplyTxs txs) + traceWith tracer (TraceTxSubmissionOutboundSendMsgReplyTxs $ wbValue <$> txs) - return $ SendMsgReplyTxs txs client' + return $ SendMsgReplyTxs undefined client'