diff --git a/cabal.project b/cabal.project index 94906ef8..7fa242e0 100644 --- a/cabal.project +++ b/cabal.project @@ -6,5 +6,8 @@ package grapesy source-repository-package type: git - location: https://github.com/kazu-yamamoto/http2.git - tag: 059b24427ef33e7a0f8cb1a06dcf229590bd2d48 + -- https://github.com/kazu-yamamoto/http2/pull/81 + location: https://github.com/edsko/http2.git + tag: 9e7713bedc4788c0d117c3abea9d3dc5e046c377 + -- location: https://github.com/kazu-yamamoto/http2.git + -- tag: 059b24427ef33e7a0f8cb1a06dcf229590bd2d48 diff --git a/demo-client/Demo/Client/API/Core/NoFinal/Greeter.hs b/demo-client/Demo/Client/API/Core/NoFinal/Greeter.hs index d33ba307..96848786 100644 --- a/demo-client/Demo/Client/API/Core/NoFinal/Greeter.hs +++ b/demo-client/Demo/Client/API/Core/NoFinal/Greeter.hs @@ -2,12 +2,11 @@ module Demo.Client.API.Core.NoFinal.Greeter ( sayHello ) where -import Control.Concurrent.STM import Data.Default import Data.Proxy import Network.GRPC.Client -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common import Proto.Helloworld @@ -18,8 +17,8 @@ import Proto.Helloworld sayHello :: Connection -> HelloRequest -> IO () sayHello conn n = withRPC conn def (Proxy @(Protobuf Greeter "sayHello")) $ \call -> do - atomically $ sendInput call $ StreamElem n - out <- atomically $ recvOutput call - trailers <- atomically $ recvOutput call + sendInput call $ StreamElem n + out <- recvOutput call + trailers <- recvOutput call print (out, trailers) diff --git a/demo-client/Demo/Client/API/Protobuf/Pipes/RouteGuide.hs b/demo-client/Demo/Client/API/Protobuf/Pipes/RouteGuide.hs index 2397b7eb..2f17a82d 100644 --- a/demo-client/Demo/Client/API/Protobuf/Pipes/RouteGuide.hs +++ b/demo-client/Demo/Client/API/Protobuf/Pipes/RouteGuide.hs @@ -13,7 +13,7 @@ import Pipes.Safe import Network.GRPC.Client import Network.GRPC.Client.StreamType.Pipes -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common import 
Proto.RouteGuide @@ -34,23 +34,23 @@ listFeatures conn r = runSafeT . runEffect $ recordRoute :: Connection - -> Producer' (StreamElem () Point) (SafeT IO) () + -> Producer' (StreamElem NoMetadata Point) (SafeT IO) () -> IO () recordRoute conn ps = runSafeT . runEffect $ ps >-> (cons >>= logMsg) where - cons :: Consumer' (StreamElem () Point) (SafeT IO) RouteSummary + cons :: Consumer' (StreamElem NoMetadata Point) (SafeT IO) RouteSummary cons = clientStreaming conn def (Proxy @(Protobuf RouteGuide "recordRoute")) routeChat :: Connection - -> Producer' (StreamElem () RouteNote) IO () + -> Producer' (StreamElem NoMetadata RouteNote) IO () -> IO () routeChat conn ns = biDiStreaming conn def (Proxy @(Protobuf RouteGuide "routeChat")) aux where aux :: - Consumer' (StreamElem () RouteNote) IO () + Consumer' (StreamElem NoMetadata RouteNote) IO () -> Producer' RouteNote IO () -> IO () aux cons prod = diff --git a/demo-client/Demo/Client/API/Protobuf/RouteGuide.hs b/demo-client/Demo/Client/API/Protobuf/RouteGuide.hs index 85bc9fe5..dff79444 100644 --- a/demo-client/Demo/Client/API/Protobuf/RouteGuide.hs +++ b/demo-client/Demo/Client/API/Protobuf/RouteGuide.hs @@ -7,7 +7,7 @@ module Demo.Client.API.Protobuf.RouteGuide ( import Network.GRPC.Client import Network.GRPC.Client.StreamType -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common import Network.GRPC.Common.StreamType import Proto.RouteGuide @@ -29,13 +29,13 @@ listFeatures conn rect = do serverStreaming (rpc @(Protobuf RouteGuide "listFeatures") conn) rect $ logMsg -recordRoute :: Connection -> IO (StreamElem () Point) -> IO () +recordRoute :: Connection -> IO (StreamElem NoMetadata Point) -> IO () recordRoute conn getPoint = do summary <- clientStreaming (rpc @(Protobuf RouteGuide "recordRoute") conn) getPoint logMsg summary -routeChat :: Connection -> IO (StreamElem () RouteNote) -> IO () +routeChat :: Connection -> IO (StreamElem NoMetadata RouteNote) -> IO () routeChat conn getNote 
= do biDiStreaming (rpc @(Protobuf RouteGuide "routeChat") conn) getNote $ logMsg diff --git a/demo-client/Demo/Client/Util/DelayOr.hs b/demo-client/Demo/Client/Util/DelayOr.hs index 58dd69b4..93e26b86 100644 --- a/demo-client/Demo/Client/Util/DelayOr.hs +++ b/demo-client/Demo/Client/Util/DelayOr.hs @@ -11,7 +11,7 @@ import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty qualified as NE import Pipes -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common import Demo.Common.Logging @@ -24,13 +24,15 @@ isDelay :: DelayOr a -> Either Double a isDelay (Delay d) = Left d isDelay (Exec a) = Right a -execAll :: forall a. Show a => [DelayOr a] -> IO (IO (StreamElem () a)) +execAll :: forall a. Show a => [DelayOr a] -> IO (IO (StreamElem NoMetadata a)) execAll = fmap (flip modifyMVar getNext) . newMVar . alternating . map isDelay where - getNext :: AltLists Double a -> IO (AltLists Double a, (StreamElem () a)) + getNext :: + AltLists Double a + -> IO (AltLists Double a, (StreamElem NoMetadata a)) getNext (Alternating Nil) = - return (Alternating Nil, NoMoreElems ()) + return (Alternating Nil, NoMoreElems NoMetadata) getNext (Alternating (Lft ds xss)) = do let d = sum ds traceWith threadSafeTracer $ "Delay " ++ show d ++ "s" @@ -48,14 +50,14 @@ execAll = yieldAll :: forall a m. (MonadIO m, Show a) - => [DelayOr a] -> Producer' (StreamElem () a) m () + => [DelayOr a] -> Producer' (StreamElem NoMetadata a) m () yieldAll = withAlternating go . alternating . map isDelay where go :: Alt d (NonEmpty Double) (NonEmpty a) - -> Producer' (StreamElem () a) m () + -> Producer' (StreamElem NoMetadata a) m () go Nil = - yield $ NoMoreElems () + yield $ NoMoreElems NoMetadata go (Lft ds xss) = do let d = sum ds liftIO $ do @@ -73,9 +75,9 @@ yieldAll = withAlternating go . alternating . 
map isDelay checkIsFinal :: NonEmpty a -> Alt L (NonEmpty Double) (NonEmpty a) - -> StreamElem () a -checkIsFinal (a :| []) Nil = FinalElem a () -checkIsFinal (a :| []) (Lft _ Nil) = FinalElem a () + -> StreamElem NoMetadata a +checkIsFinal (a :| []) Nil = FinalElem a NoMetadata +checkIsFinal (a :| []) (Lft _ Nil) = FinalElem a NoMetadata checkIsFinal (a :| []) (Lft _ (Rgt _ _)) = StreamElem a checkIsFinal (a :| (_ : _)) _ = StreamElem a diff --git a/demo-server/Demo/Server/Service/Greeter.hs b/demo-server/Demo/Server/Service/Greeter.hs index a2cce348..43646ff7 100644 --- a/demo-server/Demo/Server/Service/Greeter.hs +++ b/demo-server/Demo/Server/Service/Greeter.hs @@ -12,7 +12,7 @@ import Data.ProtoLens.Labels () import Data.Proxy import Data.Text (Text) -import Network.GRPC.Common.CustomMetadata (CustomMetadata(..)) +import Network.GRPC.Common import Network.GRPC.Common.StreamType import Network.GRPC.Server import Network.GRPC.Server.Protobuf diff --git a/demo-server/Demo/Server/Service/RouteGuide.hs b/demo-server/Demo/Server/Service/RouteGuide.hs index 54d5b3cd..f5151d5c 100644 --- a/demo-server/Demo/Server/Service/RouteGuide.hs +++ b/demo-server/Demo/Server/Service/RouteGuide.hs @@ -15,7 +15,7 @@ import Data.ProtoLens.Labels () import Data.Proxy import Data.Time -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Common.StreamType import Network.GRPC.Server @@ -54,7 +54,7 @@ getFeature db p = return $ fromMaybe defMessage $ featureAt db p listFeatures :: [Feature] -> Rectangle -> (Feature -> IO ()) -> IO () listFeatures db r send = mapM_ send $ filter (inRectangle r . 
view #location) db -recordRoute :: [Feature] -> IO (StreamElem () Point) -> IO RouteSummary +recordRoute :: [Feature] -> IO (StreamElem NoMetadata Point) -> IO RouteSummary recordRoute db recv = do start <- getCurrentTime ps <- StreamElem.collect recv @@ -63,7 +63,7 @@ recordRoute db recv = do routeChat :: [Feature] - -> IO (StreamElem () RouteNote) + -> IO (StreamElem NoMetadata RouteNote) -> (RouteNote -> IO ()) -> IO () routeChat _db recv send = do diff --git a/demo-server/Main.hs b/demo-server/Main.hs index c16dedbb..0b31b2d2 100644 --- a/demo-server/Main.hs +++ b/demo-server/Main.hs @@ -68,13 +68,10 @@ getRouteGuideDb = do Nothing -> error "Could not parse the route guide DB" serverParams :: Cmdline -> ServerParams -serverParams cmd = ServerParams { +serverParams cmd = def { serverDebugTracer = if cmdDebug cmd then contramap show threadSafeTracer else serverDebugTracer def - - , serverCompression = - def } diff --git a/grapesy.cabal b/grapesy.cabal index a5bc8037..ea944485 100644 --- a/grapesy.cabal +++ b/grapesy.cabal @@ -36,6 +36,7 @@ common lang DeriveGeneric DeriveTraversable DerivingStrategies + DerivingVia DisambiguateRecordFields EmptyCase FlexibleContexts @@ -77,10 +78,9 @@ library Network.GRPC.Client.Binary Network.GRPC.Client.StreamType Network.GRPC.Client.StreamType.Pipes + Network.GRPC.Common Network.GRPC.Common.Binary Network.GRPC.Common.Compression - Network.GRPC.Common.CustomMetadata - Network.GRPC.Common.Exceptions Network.GRPC.Common.StreamElem Network.GRPC.Common.StreamType Network.GRPC.Server @@ -88,6 +88,7 @@ library Network.GRPC.Server.Protobuf Network.GRPC.Server.Run Network.GRPC.Server.StreamType + Network.GRPC.Spec other-modules: Network.GRPC.Client.Call Network.GRPC.Client.Connection @@ -98,7 +99,7 @@ library Network.GRPC.Server.Context Network.GRPC.Server.Handler Network.GRPC.Server.Session - Network.GRPC.Spec + Network.GRPC.Spec.Call Network.GRPC.Spec.Common Network.GRPC.Spec.Compression Network.GRPC.Spec.CustomMetadata @@ -110,6 
+111,9 @@ library Network.GRPC.Spec.RPC Network.GRPC.Spec.RPC.Binary Network.GRPC.Spec.RPC.Protobuf + Network.GRPC.Spec.RPC.StreamType + Network.GRPC.Spec.Status + Network.GRPC.Spec.Timeout Network.GRPC.Util.AccumulatedByteString Network.GRPC.Util.ByteString Network.GRPC.Util.HTTP2 @@ -117,6 +121,7 @@ library Network.GRPC.Util.HTTP2.TLS Network.GRPC.Util.Parser Network.GRPC.Util.Partial + Network.GRPC.Util.PrettyVal Network.GRPC.Util.RedundantConstraint Network.GRPC.Util.Session Network.GRPC.Util.Session.API @@ -156,6 +161,7 @@ library , network-run >= 0.2 && < 0.3 , pipes >= 4.3 && < 4.4 , pipes-safe >= 2.3 && < 2.4 + , pretty-show , proto-lens >= 0.7 && < 0.8 , sop-core >= 0.5 && < 0.6 , stm >= 2.5 && < 2.6 @@ -205,8 +211,15 @@ test-suite test-grapesy Paths_grapesy Test.Driver.ClientServer Test.Driver.Dialogue + Test.Driver.Dialogue.Definition + Test.Driver.Dialogue.Execution + Test.Driver.Dialogue.Generation + Test.Driver.Dialogue.TestClock + Test.Prop.Dialogue + Test.Prop.Serialization Test.Sanity.StreamingType.NonStreaming Test.Util.ClientServer + Test.Util.PrettyVal build-depends: -- Internal dependencies , grapesy @@ -218,12 +231,14 @@ test-suite test-grapesy , contra-tracer >= 0.2 && < 0.3 , data-default >= 0.7 && < 0.8 , exceptions >= 0.10 && < 0.11 + , http2 , mtl >= 2.2 && < 2.4 + , pretty-show , QuickCheck >= 2.14 && < 2.15 , stm >= 2.5 && < 2.6 , tasty >= 1.4 && < 1.5 , tasty-hunit >= 0.10 && < 0.11 --- , tasty-quickcheck >= 0.10 && < 0.11 + , tasty-quickcheck >= 0.10 && < 0.11 , text >= 1.2 && < 2.1 , tls >= 1.5 && < 1.8 @@ -341,6 +356,7 @@ executable test-stress , containers >= 0.6 && < 0.7 , contra-tracer >= 0.2 && < 0.3 , data-default >= 0.7 && < 0.8 + , http2 , optparse-applicative >= 0.16 && < 0.19 , text >= 1.2 && < 2.1 , tls >= 1.5 && < 1.8 diff --git a/src/Network/GRPC/Client.hs b/src/Network/GRPC/Client.hs index 750f60c7..0cb5fad4 100644 --- a/src/Network/GRPC/Client.hs +++ b/src/Network/GRPC/Client.hs @@ -21,8 +21,6 @@ module 
Network.GRPC.Client ( -- * Make RPCs , Call -- opaque , withRPC - , startRPC - , abortRPC -- ** Call parameters , CallParams(..) @@ -46,88 +44,21 @@ module Network.GRPC.Client ( , recvFinalOutput , recvAllOutputs + -- ** Low-level API + , sendInputSTM + , recvOutputSTM + , startRPC + , closeRPC + -- * Common serialization formats , Protobuf ) where -import Control.Monad.Catch -import Control.Monad.IO.Class -import Data.Proxy -import GHC.Stack - import Network.GRPC.Client.Call import Network.GRPC.Client.Connection import Network.GRPC.Spec -import Network.GRPC.Spec.PseudoHeaders (Scheme(..), Authority(..)) -import Network.GRPC.Spec.RPC -import Network.GRPC.Spec.RPC.Protobuf (Protobuf) import Network.GRPC.Util.TLS qualified as Util.TLS -{------------------------------------------------------------------------------- - Make RPCs --------------------------------------------------------------------------------} - --- | Start RPC call --- --- This is a non-blocking call; the connection will be set up in a background --- thread; if this takes time, then the first call to 'sendInput' or --- 'recvOutput' will block, but the call to 'startRPC' itself will not block. --- This non-blocking nature makes this safe to use in 'bracket' patterns. --- --- This is a low-level API. Consider using 'withRPC' instead. -startRPC :: - IsRPC rpc - => Connection -> CallParams -> Proxy rpc -> IO (Call rpc) -startRPC = initiateCall - --- | Stop RPC call --- --- This is a low-level API. Consider using 'withRPC' instead. --- --- TODO: Say say something about why this is called abort. --- --- NOTE: When an 'Call' is closed, it remembers the /reason/ why it was closed. --- For example, if it is closed due to a network exception, then this network --- exception is recorded as part of this reason. When the call is closed due to --- call to 'stopRPC', we use the 'CallAborted' exception. 
/If/ the call to --- 'stopRPC' /itself/ was due to an exception, this exception will not be --- recorded; if that is undesirable, consider using 'withRPC' instead. -abortRPC :: HasCallStack => Call rpc -> IO () -abortRPC call = abortCall call $ CallAborted callStack - --- | Scoped RPC call --- --- May throw --- --- * 'RpcImmediateError' when we fail to establish the call --- * 'CallClosed' when attempting to send or receive data an a closed call. -withRPC :: forall m rpc a. - (MonadMask m, MonadIO m, IsRPC rpc) - => Connection -> CallParams -> Proxy rpc -> (Call rpc -> m a) -> m a -withRPC conn params proxy = fmap aux . - generalBracket - (liftIO $ startRPC conn params proxy) - (\call -> liftIO . \case - ExitCaseSuccess _ -> abortCall call $ CallAborted callStack - ExitCaseException e -> abortCall call e - ExitCaseAbort -> abortCall call UnknownException - ) - where - aux :: (a, ()) -> a - aux = fst - --- | Exception corresponding to 'stopRPC' --- --- We record the callstack to 'stopRPC'. 
-data CallAborted = CallAborted CallStack - deriving stock (Show) - deriving anyclass (Exception) - --- | Exception corresponding to the 'ExitCaseAbort' case of 'ExitCase' -data UnknownException = UnknownException - deriving stock (Show) - deriving anyclass (Exception) - {------------------------------------------------------------------------------- Ongoing calls -------------------------------------------------------------------------------} diff --git a/src/Network/GRPC/Client/Binary.hs b/src/Network/GRPC/Client/Binary.hs index a1522113..25abeb1c 100644 --- a/src/Network/GRPC/Client/Binary.hs +++ b/src/Network/GRPC/Client/Binary.hs @@ -11,14 +11,13 @@ module Network.GRPC.Client.Binary ( , recvFinalOutput ) where -import Control.Concurrent.STM +import Control.Monad.IO.Class import Data.Binary import Network.GRPC.Client (Call) import Network.GRPC.Client qualified as Client import Network.GRPC.Common.Binary -import Network.GRPC.Common.CustomMetadata -import Network.GRPC.Common.StreamElem +import Network.GRPC.Common {------------------------------------------------------------------------------- Convenience wrappers using @binary@ for serialization/deserialization @@ -33,33 +32,32 @@ import Network.GRPC.Common.StreamElem argument, to facilitate the use of type arguments. -------------------------------------------------------------------------------} -sendInput :: forall inp serv meth. - Binary inp +sendInput :: forall inp serv meth m. + (Binary inp, MonadIO m) => Call (BinaryRpc serv meth) - -> StreamElem () inp - -> IO () -sendInput call inp = atomically $ - Client.sendInput call (encode <$> inp) + -> StreamElem NoMetadata inp + -> m () +sendInput call inp = Client.sendInput call (encode <$> inp) -sendFinalInput :: forall inp serv meth. - Binary inp +sendFinalInput :: forall inp serv meth m. 
+ (Binary inp, MonadIO m) => Call (BinaryRpc serv meth) -> inp - -> IO () + -> m () sendFinalInput call inp = Client.sendFinalInput call (encode inp) -recvOutput :: forall out serv meth. - Binary out +recvOutput :: forall out serv meth m. + (Binary out, MonadIO m) => Call (BinaryRpc serv meth) - -> IO (StreamElem [CustomMetadata] out) -recvOutput call = - atomically $ Client.recvOutput call >>= traverse decodeOrThrow + -> m (StreamElem [CustomMetadata] out) +recvOutput call = liftIO $ + Client.recvOutput call >>= traverse decodeOrThrow -recvFinalOutput :: forall out serv meth. - Binary out +recvFinalOutput :: forall out serv meth m. + (Binary out, MonadIO m) => Call (BinaryRpc serv meth) - -> IO (out, [CustomMetadata]) -recvFinalOutput call = do + -> m (out, [CustomMetadata]) +recvFinalOutput call = liftIO $ do (out, md) <- Client.recvFinalOutput call (, md) <$> decodeOrThrow out diff --git a/src/Network/GRPC/Client/Call.hs b/src/Network/GRPC/Client/Call.hs index fd75e62a..226b8a5c 100644 --- a/src/Network/GRPC/Client/Call.hs +++ b/src/Network/GRPC/Client/Call.hs @@ -6,8 +6,7 @@ module Network.GRPC.Client.Call ( Call -- opaque -- * Construction - , initiateCall - , abortCall + , withRPC -- * Open (ongoing) call , sendInput @@ -19,10 +18,19 @@ module Network.GRPC.Client.Call ( , sendAllInputs , recvFinalOutput , recvAllOutputs + + -- ** Low-level API + , sendInputSTM + , recvOutputSTM + , waitForOutbound + , startRPC + , closeRPC ) where import Control.Concurrent.STM import Control.Exception +import Control.Monad +import Control.Monad.Catch import Control.Monad.IO.Class import Control.Tracer import Data.Bifunctor @@ -35,14 +43,10 @@ import Network.GRPC.Client.Connection (Connection, ConnParams (..)) import Network.GRPC.Client.Connection qualified as Connection import Network.GRPC.Client.Meta qualified as Meta import Network.GRPC.Client.Session -import Network.GRPC.Common.Compression (Compression(..)) +import Network.GRPC.Common import Network.GRPC.Common.Compression 
qualified as Compression -import Network.GRPC.Common.Exceptions -import Network.GRPC.Common.StreamElem (StreamElem(..)) import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Spec -import Network.GRPC.Spec.CustomMetadata -import Network.GRPC.Spec.RPC import Network.GRPC.Util.Session qualified as Session {------------------------------------------------------------------------------- @@ -61,10 +65,33 @@ data Call rpc = IsRPC rpc => Call { Open a call -------------------------------------------------------------------------------} -initiateCall :: forall rpc. +-- | Scoped RPC call +withRPC :: forall m rpc a. + (MonadMask m, MonadIO m, IsRPC rpc) + => Connection -> CallParams -> Proxy rpc -> (Call rpc -> m a) -> m a +withRPC conn params proxy k = + (throwUnclean =<<) $ + generalBracket + (liftIO $ startRPC conn params proxy) + (\call -> liftIO . closeRPC call) + k + where + throwUnclean :: (a, Maybe ChannelUncleanClose) -> m a + throwUnclean (_, Just err) = throwM err + throwUnclean (x, Nothing) = return x + +-- | Start RPC call +-- +-- This is a non-blocking call; the connection will be set up in a background +-- thread; if this takes time, then the first call to 'sendInput' or +-- 'recvOutput' will block, but the call to 'startRPC' itself will not block. +-- This non-blocking nature makes this safe to use in 'bracket' patterns. +-- +-- This is a low-level API. Consider using 'withRPC' instead. +startRPC :: forall rpc. 
IsRPC rpc => Connection -> CallParams -> Proxy rpc -> IO (Call rpc) -initiateCall conn callParams _proxy = do +startRPC conn callParams _proxy = do cOut <- Meta.outboundCompression <$> Connection.currentMeta conn callChannel <- @@ -74,7 +101,7 @@ initiateCall conn callParams _proxy = do (Connection.connectionToServer conn) (Session.FlowStartRegular $ OutboundHeaders { outHeaders = requestHeaders cOut - , outCompression = fromMaybe Compression.identity cOut + , outCompression = fromMaybe noCompression cOut }) return Call{callSession, callChannel} where @@ -105,15 +132,15 @@ initiateCall conn callParams _proxy = do contramap (Connection.PeerDebugMsg @rpc) $ connDebugTracer (Connection.params conn) -{------------------------------------------------------------------------------- - Closing an open RPC --------------------------------------------------------------------------------} - --- | Abort an open RPC call +-- | Close an open RPC call +-- +-- This is a low-level API; most users should use 'withRPC' instead. -- --- TODO: Docs. -abortCall :: Exception e => Call rpc -> HasCallStack => e -> IO () -abortCall = Session.close . callChannel +-- See 'Session.close' for detailed discussion. +closeRPC :: + HasCallStack + => Call rpc -> ExitCase a -> IO (Maybe ChannelUncleanClose) +closeRPC = Session.close . callChannel {------------------------------------------------------------------------------- Open (ongoing) call @@ -121,37 +148,29 @@ abortCall = Session.close . callChannel -- | Send an input to the peer -- --- This lives in @STM@ for improved composability. For example, if the peer is --- currently busy then 'sendInput' will block, but you can then use 'orElse' to --- provide an alternative codepath. --- -- Calling 'sendInput' again after sending the final message is a bug. --- --- WARNING: Sending multiple messages on the same call within the same STM --- transaction will deadlock; you should enqueue them separately. 
--- --- TODO: We should have a way to detect this and throw an exception. -sendInput :: HasCallStack => Call rpc -> StreamElem () (Input rpc) -> STM () -sendInput = Session.send . callChannel +sendInput :: + (HasCallStack, MonadIO m) + => Call rpc + -> StreamElem NoMetadata (Input rpc) + -> m () +sendInput call msg = liftIO $ do + atomically $ sendInputSTM call msg + StreamElem.whenDefinitelyFinal msg $ \_ -> waitForOutbound call -- | Receive an output from the peer -- --- This lives in @STM@ for improved compositionality. For example, you can wait --- on multiple clients and see which one responds first. --- -- After the final 'Output', you will receive any 'CustomMetadata' (application -- defined trailers) that the server returns. We do /NOT/ include the -- 'GrpcStatus' here: a status of 'GrpcOk' carries no information, and any other -- status will result in a 'GrpcException'. Calling 'recvOutput' again after -- receiving the trailers is a bug and results in a 'RecvAfterFinal' exception. -recvOutput :: Call rpc -> STM (StreamElem [CustomMetadata] (Output rpc)) -recvOutput = fmap (first collapseTrailers) . Session.recv . 
callChannel - where - -- No difference between 'ProperTrailers' and 'TrailersOnly' - collapseTrailers :: - Either [CustomMetadata] [CustomMetadata] - -> [CustomMetadata] - collapseTrailers = either id id +recvOutput :: + MonadIO m + => Call rpc + -> m (StreamElem [CustomMetadata] (Output rpc)) +recvOutput call = + liftIO $ atomically $ recvOutputSTM call -- | The initial metadata that was included in the response headers -- @@ -176,6 +195,60 @@ recvResponseMetadata Call{callChannel} = aux (Left trailersOnly) = trailersOnly aux (Right headers) = responseMetadata $ inbHeaders headers +{------------------------------------------------------------------------------- + Low-level API +-------------------------------------------------------------------------------} + +-- | Send input +-- +-- This is a low-level API; most users should use 'sendInput' instead. +-- +-- The advantage of 'sendInputSTM' over 'sendInput' is the improved +-- compositionality of @STM@. For example, if the peer is currently busy then +-- 'sendInput' will block, but you can then use 'orElse' to provide an +-- alternative codepath. +-- +-- If you choose to use 'sendInputSTM' over 'sendInput', you have some +-- responsibilities: +-- +-- * You must call 'waitForOutbound' after sending the final message (and before +-- exiting the scope of 'withRPC'). +-- * You should not enqueue multiple messages on the same call within the same +-- STM transaction; doing so will deadlock. +sendInputSTM :: + HasCallStack + => Call rpc + -> StreamElem NoMetadata (Input rpc) + -> STM () +sendInputSTM = Session.send . callChannel + +-- | Receive output +-- +-- This is a low-level API; most users should use 'recvOutput' instead. +-- +-- The improved compositionality of STM can for example be used to wait on +-- multiple clients and see which one responds first. 
+-- +-- If you choose to use 'recvOutputSTM' over 'recvOutput', you should call +-- 'waitForInbound' after receiving the final message (and before exiting the +-- scope of 'withRPC'). +recvOutputSTM :: Call rpc -> STM (StreamElem [CustomMetadata] (Output rpc)) +recvOutputSTM = fmap (first collapseTrailers) . Session.recv . callChannel + where + -- No difference between 'ProperTrailers' and 'TrailersOnly' + collapseTrailers :: + Either [CustomMetadata] [CustomMetadata] + -> [CustomMetadata] + collapseTrailers = either id id + +-- | Wait for all outbound messages to have been processed +-- +-- This should be called before exiting the scope of 'withRPC'. However, +-- 'sendOutput' will call this function when sending the final messages, so +-- you only need to worry about this if using 'sendOutputSTM'. +waitForOutbound :: Call rpc -> IO () +waitForOutbound = void . Session.waitForOutbound . callChannel + {------------------------------------------------------------------------------- Protocol specific wrappers -------------------------------------------------------------------------------} @@ -185,23 +258,24 @@ sendFinalInput :: => Call rpc -> Input rpc -> m () -sendFinalInput call input = liftIO $ - atomically $ sendInput call (FinalElem input ()) +sendFinalInput call input = + sendInput call (FinalElem input NoMetadata) sendAllInputs :: forall m rpc. MonadIO m => Call rpc - -> m (StreamElem () (Input rpc)) + -> m (StreamElem NoMetadata (Input rpc)) -> m () sendAllInputs call produceInput = loop where loop :: m () loop = do inp <- produceInput - liftIO $ atomically $ sendInput call inp - case StreamElem.definitelyFinal inp of - Nothing -> loop - Just _ -> return () + sendInput call inp + case inp of + StreamElem{} -> loop + FinalElem{} -> return () + NoMoreElems{} -> return () -- | Receive output, which we expect to be the /final/ output -- @@ -214,12 +288,12 @@ recvFinalOutput :: forall m rpc. 
=> Call rpc -> m (Output rpc, [CustomMetadata]) recvFinalOutput call@Call{} = liftIO $ do - out1 <- atomically $ recvOutput call + out1 <- recvOutput call case out1 of - NoMoreElems ts -> throwIO $ TooFewOutputs @rpc ts + NoMoreElems ts -> throwM $ TooFewOutputs @rpc ts FinalElem out ts -> return (out, ts) StreamElem out -> do - out2 <- atomically $ recvOutput call + out2 <- recvOutput call case out2 of NoMoreElems ts -> return (out, ts) FinalElem out' _ -> throwIO $ TooManyOutputs @rpc out' @@ -234,7 +308,7 @@ recvAllOutputs call processOutput = loop where loop :: m [CustomMetadata] loop = do - mOut <- liftIO $ atomically $ recvOutput call + mOut <- recvOutput call case mOut of StreamElem out -> do processOutput out diff --git a/src/Network/GRPC/Client/Connection.hs b/src/Network/GRPC/Client/Connection.hs index 450d5447..8349f1cb 100644 --- a/src/Network/GRPC/Client/Connection.hs +++ b/src/Network/GRPC/Client/Connection.hs @@ -36,8 +36,6 @@ import Network.GRPC.Client.Meta qualified as Meta import Network.GRPC.Client.Session import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Spec -import Network.GRPC.Spec.PseudoHeaders -import Network.GRPC.Spec.RPC import Network.GRPC.Util.HTTP2.TLS import Network.GRPC.Util.Session qualified as Session import Network.GRPC.Util.TLS (ServerValidation(..)) diff --git a/src/Network/GRPC/Client/Meta.hs b/src/Network/GRPC/Client/Meta.hs index f5a85c00..2c43d0fc 100644 --- a/src/Network/GRPC/Client/Meta.hs +++ b/src/Network/GRPC/Client/Meta.hs @@ -16,7 +16,6 @@ import Prelude hiding (init) import Control.Monad.Catch import Data.List.NonEmpty (NonEmpty) -import Network.GRPC.Common.Compression (Compression, CompressionId) import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Spec diff --git a/src/Network/GRPC/Client/Session.hs b/src/Network/GRPC/Client/Session.hs index d2f254dd..7c543e95 100644 --- a/src/Network/GRPC/Client/Session.hs +++ b/src/Network/GRPC/Client/Session.hs @@ -10,16 +10,9 
@@ import Control.Monad import Data.Proxy import Network.HTTP.Types qualified as HTTP +import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr -import Network.GRPC.Common.Exceptions -import Network.GRPC.Spec qualified as GRPC -import Network.GRPC.Spec.Compression (Compression) -import Network.GRPC.Spec.CustomMetadata -import Network.GRPC.Spec.LengthPrefixed qualified as LP -import Network.GRPC.Spec.PseudoHeaders -import Network.GRPC.Spec.Request qualified as Req -import Network.GRPC.Spec.Response qualified as Resp -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec import Network.GRPC.Util.Session {------------------------------------------------------------------------------- @@ -28,7 +21,7 @@ import Network.GRPC.Util.Session data ClientSession rpc = ClientSession { clientCompression :: Compr.Negotation - , clientUpdateMeta :: GRPC.ResponseHeaders -> IO () + , clientUpdateMeta :: ResponseHeaders -> IO () } {------------------------------------------------------------------------------- @@ -40,56 +33,55 @@ data ClientOutbound rpc instance IsRPC rpc => DataFlow (ClientInbound rpc) where data Headers (ClientInbound rpc) = InboundHeaders { - inbHeaders :: GRPC.ResponseHeaders + inbHeaders :: ResponseHeaders , inbCompression :: Compression } deriving (Show) - type Message (ClientInbound rpc) = Output rpc - type ProperTrailers (ClientInbound rpc) = [CustomMetadata] - type TrailersOnly (ClientInbound rpc) = [CustomMetadata] + type Message (ClientInbound rpc) = Output rpc + type Trailers (ClientInbound rpc) = [CustomMetadata] + type NoMessages (ClientInbound rpc) = [CustomMetadata] instance IsRPC rpc => DataFlow (ClientOutbound rpc) where data Headers (ClientOutbound rpc) = OutboundHeaders { - outHeaders :: GRPC.RequestHeaders + outHeaders :: RequestHeaders , outCompression :: Compression } deriving (Show) - type Message (ClientOutbound rpc) = Input rpc - type ProperTrailers (ClientOutbound rpc) = () + type Message (ClientOutbound rpc) = 
Input rpc + type Trailers (ClientOutbound rpc) = NoMetadata -- gRPC does not support request trailers, but does not require that a request -- has any messages at all. When it does not, the request can be terminated -- immediately after the initial set of headers; this makes it essentially a -- 'Trailers-Only' case, but the headers are just the normal headers. - type TrailersOnly (ClientOutbound rpc) = GRPC.RequestHeaders + type NoMessages (ClientOutbound rpc) = RequestHeaders instance IsRPC rpc => IsSession (ClientSession rpc) where type Inbound (ClientSession rpc) = ClientInbound rpc type Outbound (ClientSession rpc) = ClientOutbound rpc - buildProperTrailers _client = \() -> - [] -- Request trailers are not supported by gRPC - parseProperTrailers _client = - processResponseTrailers $ Resp.parseProperTrailers (Proxy @rpc) + buildOutboundTrailers _client = \NoMetadata -> [] + parseInboundTrailers _client = + processResponseTrailers $ parseProperTrailers (Proxy @rpc) - parseMsg _ = LP.parseOutput (Proxy @rpc) . inbCompression - buildMsg _ = LP.buildInput (Proxy @rpc) . outCompression + parseMsg _ = parseOutput (Proxy @rpc) . inbCompression + buildMsg _ = buildInput (Proxy @rpc) . 
outCompression instance IsRPC rpc => InitiateSession (ClientSession rpc) where parseResponseRegular client info = do unless (HTTP.statusCode (responseStatus info) == 200) $ throwIO $ ResponseInvalidStatus (responseStatus info) - responseHeaders :: GRPC.ResponseHeaders <- - case Resp.parseHeaders (Proxy @rpc) (responseHeaders info) of + responseHeaders :: ResponseHeaders <- + case parseResponseHeaders (Proxy @rpc) (responseHeaders info) of Left err -> throwIO $ ResponseInvalidHeaders err Right parsed -> return parsed cIn :: Compression <- Compr.getSupported (clientCompression client) $ - GRPC.responseCompression responseHeaders + responseCompression responseHeaders clientUpdateMeta client responseHeaders @@ -98,20 +90,20 @@ instance IsRPC rpc => InitiateSession (ClientSession rpc) where , inbCompression = cIn } - parseResponseTrailersOnly _ info = do + parseResponseNoMessages _ info = do unless (HTTP.statusCode (responseStatus info) == 200) $ throwIO $ ResponseInvalidStatus (responseStatus info) processResponseTrailers - (fmap GRPC.getTrailersOnly . Resp.parseTrailersOnly (Proxy @rpc)) + (fmap getTrailersOnly . parseTrailersOnly (Proxy @rpc)) (responseHeaders info) buildRequestInfo _ start = RequestInfo { requestMethod = rawMethod resourceHeaders , requestPath = rawPath resourceHeaders - , requestHeaders = Req.buildHeaders (Proxy @rpc) $ + , requestHeaders = buildRequestHeaders (Proxy @rpc) $ case start of - FlowStartRegular headers -> outHeaders headers - FlowStartTrailersOnly headers -> headers + FlowStartRegular headers -> outHeaders headers + FlowStartNoMessages headers -> headers } where resourceHeaders :: RawResourceHeaders @@ -128,19 +120,20 @@ instance IsRPC rpc => InitiateSession (ClientSession rpc) where -- -- However, in practice gRPC servers can also respond with @Trailers-Only@ in -- non-error cases, simply indicating that the server considers the --- conversation over. 
To distinguish, we look at 'trailerGrpcStatus': in case --- of 'GrpcOk' we return the 'Trailers', and in the case of 'GrpcError' we --- throw 'RpcImmediateError'. +-- conversation over. To distinguish, we look at 'trailerGrpcStatus'. +-- +-- TODO: We are throwing away 'trailerGrpcMessage' in case of 'GrpcOk'; not +-- sure if that is problematic. processResponseTrailers :: - ([HTTP.Header] -> Either String GRPC.ProperTrailers) + ([HTTP.Header] -> Either String ProperTrailers) -> [HTTP.Header] -> IO [CustomMetadata] processResponseTrailers parse raw = case parse raw of Left err -> throwIO $ ResponseInvalidTrailers err Right trailers -> case grpcExceptionFromTrailers trailers of - Nothing -> return $ GRPC.trailerMetadata trailers - Just err -> throwIO err + Left (_msg, metadata) -> return metadata + Right err -> throwIO err {------------------------------------------------------------------------------- Exceptions diff --git a/src/Network/GRPC/Client/StreamType.hs b/src/Network/GRPC/Client/StreamType.hs index c17bfdfd..6f942edd 100644 --- a/src/Network/GRPC/Client/StreamType.hs +++ b/src/Network/GRPC/Client/StreamType.hs @@ -12,7 +12,7 @@ import Data.Proxy import Network.GRPC.Client import Network.GRPC.Common.StreamType -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec {------------------------------------------------------------------------------- Obtain handler for specific RPC call diff --git a/src/Network/GRPC/Client/StreamType/Pipes.hs b/src/Network/GRPC/Client/StreamType/Pipes.hs index be97bbaa..70b0b0d7 100644 --- a/src/Network/GRPC/Client/StreamType/Pipes.hs +++ b/src/Network/GRPC/Client/StreamType/Pipes.hs @@ -15,10 +15,9 @@ import Pipes.Safe import Network.GRPC.Client import Network.GRPC.Client.StreamType -import Network.GRPC.Common.StreamElem (StreamElem) +import Network.GRPC.Common import Network.GRPC.Common.StreamType (SupportsStreamingType) import Network.GRPC.Common.StreamType qualified as StreamType -import Network.GRPC.Spec.RPC import 
Network.GRPC.Util.RedundantConstraint {------------------------------------------------------------------------------- @@ -31,7 +30,7 @@ clientStreaming :: forall rpc. => Connection -> CallParams -> Proxy rpc - -> Consumer' (StreamElem () (Input rpc)) (SafeT IO) (Output rpc) + -> Consumer' (StreamElem NoMetadata (Input rpc)) (SafeT IO) (Output rpc) clientStreaming conn params proxy = StreamType.clientStreaming (rpcWith conn params proxy) @@ -68,7 +67,7 @@ biDiStreaming :: forall rpc a. => Connection -> CallParams -> Proxy rpc - -> ( Consumer' (StreamElem () (Input rpc)) IO () + -> ( Consumer' (StreamElem NoMetadata (Input rpc)) IO () -> Producer' (Output rpc) IO () -> IO a ) diff --git a/src/Network/GRPC/Common.hs b/src/Network/GRPC/Common.hs new file mode 100644 index 00000000..3cf00dbb --- /dev/null +++ b/src/Network/GRPC/Common.hs @@ -0,0 +1,68 @@ +-- | General infrastructure used by both the client and the server +-- +-- Intended for unqualified import. +module Network.GRPC.Common ( + -- * Abstraction over different serialization formats + IsRPC(..) + + -- * Stream elements + -- + -- We export only the main type here; for operations on 'StreamElem', see + -- "Network.GRPC.Common.StreamElem" (intended for qualified import). + , StreamElem(..) + + -- * Custom metadata + -- + -- Clients can include custom metadata in the initial request to the server, + -- and servers can include custom metadata both in the initial response to + -- the client as well as in the response trailers. + , CustomMetadata(..) + , HeaderName(HeaderName) + , AsciiValue(AsciiValue) + , BinaryValue(..) + , NoMetadata(..) + , customHeaderName + + -- * Exceptions + , GrpcException(..) + , GrpcError(..) + , ProtocolException(..) + -- ** Low-level + , ThreadCancelled(..) + , Session.ChannelClosed(..) + , Session.ChannelUncleanClose(..) 
+ + ) where + +import Control.Exception + +import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Spec +import Network.GRPC.Util.Session qualified as Session +import Network.GRPC.Util.Thread (ThreadCancelled(..)) + +{------------------------------------------------------------------------------- + Exceptions +-------------------------------------------------------------------------------} + +-- | Protocol exception +-- +-- A protocol exception arises when the client and the server disagree on the +-- sequence of inputs and outputs exchanged. This agreement might be part of a +-- formal specification such as Protobuf, or it might be implicit in the +-- implementation of a specific RPC. +data ProtocolException rpc = + -- | We expected an input but got none + TooFewInputs + + -- | We received an input when we expected no more inputs + | TooManyInputs (Input rpc) + + -- | We expected an output, but got trailers instead + | TooFewOutputs [CustomMetadata] + + -- | We expected trailers, but got an output instead + | TooManyOutputs (Output rpc) + +deriving instance IsRPC rpc => Show (ProtocolException rpc) +deriving instance IsRPC rpc => Exception (ProtocolException rpc) diff --git a/src/Network/GRPC/Common/Binary.hs b/src/Network/GRPC/Common/Binary.hs index 1ac06220..1defe404 100644 --- a/src/Network/GRPC/Common/Binary.hs +++ b/src/Network/GRPC/Common/Binary.hs @@ -15,7 +15,7 @@ import Data.Binary.Get qualified as Binary import Data.ByteString.Lazy qualified as BS.Lazy import Data.ByteString.Lazy qualified as Lazy (ByteString) -import Network.GRPC.Spec.RPC.Binary (BinaryRpc) +import Network.GRPC.Spec {------------------------------------------------------------------------------- Decoding diff --git a/src/Network/GRPC/Common/Compression.hs b/src/Network/GRPC/Common/Compression.hs index 787c7c4e..3dff056f 100644 --- a/src/Network/GRPC/Common/Compression.hs +++ b/src/Network/GRPC/Common/Compression.hs @@ -1,6 +1,6 @@ -- | Public 'Compression' API -- --- 
Intended for qualified import. +-- Intended for unqualified import. -- -- > import Network.GRPC.Common.Compression (Compression(..)) -- > import Network.GRPC.Common.Compression qualified as Compr @@ -9,9 +9,9 @@ module Network.GRPC.Common.Compression ( Compression(..) , CompressionId(..) -- * Standard compresion schemes - , identity + , noCompression , gzip - , allSupported + , allSupportedCompression -- * Negotation , Negotation(..) , getSupported @@ -33,7 +33,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map (Map) import Data.Map qualified as Map -import Network.GRPC.Spec.Compression +import Network.GRPC.Spec {------------------------------------------------------------------------------- Negotation @@ -63,18 +63,18 @@ data Negotation = Negotation { getSupported :: MonadThrow m => Negotation -> Maybe CompressionId -> m Compression -getSupported _ Nothing = return identity +getSupported _ Nothing = return noCompression getSupported compr (Just cid) = case Map.lookup cid (supported compr) of Nothing -> throwM $ UnsupportedCompression cid Just c -> return c instance Default Negotation where - def = chooseFirst allSupported + def = chooseFirst allSupportedCompression -- | Disable all compression none :: Negotation -none = require identity +none = require noCompression -- | Insist on the specified algorithm require :: Compression -> Negotation diff --git a/src/Network/GRPC/Common/CustomMetadata.hs b/src/Network/GRPC/Common/CustomMetadata.hs deleted file mode 100644 index 9238656c..00000000 --- a/src/Network/GRPC/Common/CustomMetadata.hs +++ /dev/null @@ -1,17 +0,0 @@ --- | Custom metadata --- --- Clients can include custom metadata in the initial request to the server, and --- servers can include custom metadata boh in the initial response to the client --- as well as in the response trailers. --- --- Intended for qualified import. 
--- --- > import Network.GRPC.Common.CustomMetadata (CustomMetadata(..), HeaderName(..)) --- > import Network.GRPC.Common.CustomMetadata qualified as Metadata -module Network.GRPC.Common.CustomMetadata ( - CustomMetadata(..) - , HeaderName(HeaderName) - , AsciiValue(AsciiValue) - ) where - -import Network.GRPC.Spec.CustomMetadata diff --git a/src/Network/GRPC/Common/Exceptions.hs b/src/Network/GRPC/Common/Exceptions.hs deleted file mode 100644 index bf2553d0..00000000 --- a/src/Network/GRPC/Common/Exceptions.hs +++ /dev/null @@ -1,80 +0,0 @@ --- | Exceptions --- --- Intended for unqualified import -module Network.GRPC.Common.Exceptions ( - -- * gRPC exception - GrpcException(..) - , GrpcError(..) - , grpcExceptionFromTrailers - , grpcExceptionToTrailers - -- * Protocol exception - , ProtocolException(..) - -- * Low-level exceptions - , ThreadCancelled(..) - , ChannelClosed(..) - ) where - -import Control.Exception -import Data.Text (Text) - -import Network.GRPC.Spec -import Network.GRPC.Spec.CustomMetadata -import Network.GRPC.Spec.RPC -import Network.GRPC.Util.Session.Channel (ChannelClosed(..)) -import Network.GRPC.Util.Thread (ThreadCancelled(..)) - -{------------------------------------------------------------------------------- - gRPC exception --------------------------------------------------------------------------------} - --- | Server indicated a gRPC error -data GrpcException = GrpcException { - grpcError :: GrpcError - , grpcErrorMessage :: Maybe Text - , grpcErrorMetadata :: [CustomMetadata] - } - deriving stock (Show, Eq) - deriving anyclass (Exception) - -grpcExceptionFromTrailers :: ProperTrailers -> Maybe GrpcException -grpcExceptionFromTrailers trailers = - case trailerGrpcStatus trailers of - GrpcOk -> Nothing - GrpcError err -> Just GrpcException{ - grpcError = err - , grpcErrorMessage = trailerGrpcMessage trailers - , grpcErrorMetadata = trailerMetadata trailers - } - -grpcExceptionToTrailers :: GrpcException -> ProperTrailers 
-grpcExceptionToTrailers err = ProperTrailers{ - trailerGrpcStatus = GrpcError (grpcError err) - , trailerGrpcMessage = grpcErrorMessage err - , trailerMetadata = grpcErrorMetadata err - } - -{------------------------------------------------------------------------------- - Protocol exception --------------------------------------------------------------------------------} - --- | Protocol exception --- --- A protocol exception arises when the client and the server disagree on the --- sequence of inputs and outputs exchanged. This agreement might be part of a --- formal specification such as Protobuf, or it might be implicit in the --- implementation of a specific RPC. -data ProtocolException rpc = - -- | We expected an input but got none - TooFewInputs - - -- | We received an input when we expected no more inputs - | TooManyInputs (Input rpc) - - -- | We expected an output, but got trailers instead - | TooFewOutputs [CustomMetadata] - - -- | We expected trailers, but got an output instead - | TooManyOutputs (Output rpc) - -deriving instance IsRPC rpc => Show (ProtocolException rpc) -deriving instance IsRPC rpc => Exception (ProtocolException rpc) diff --git a/src/Network/GRPC/Common/StreamElem.hs b/src/Network/GRPC/Common/StreamElem.hs index b762bb24..d8063c3b 100644 --- a/src/Network/GRPC/Common/StreamElem.hs +++ b/src/Network/GRPC/Common/StreamElem.hs @@ -2,12 +2,14 @@ -- -- Intended for qualified import. -- --- > import Network.GRPC.Common.StreamElem (StreamElem(..)) -- > import Network.GRPC.Common.StreamElem qualified as StreamElem +-- +-- "Network.GRPC.Common" (intended for unqualified import) exports 'StreamElem', +-- but none of the operations on 'StreamElem'. module Network.GRPC.Common.StreamElem ( StreamElem(..) 
, value - , definitelyFinal + , whenDefinitelyFinal , mapM_ , collect ) where @@ -19,6 +21,8 @@ import Control.Monad.Trans.Class import Data.Bifoldable import Data.Bifunctor import Data.Bitraversable +import GHC.Generics qualified as GHC +import Text.Show.Pretty -- | An element positioned in a stream data StreamElem b a = @@ -57,7 +61,8 @@ data StreamElem b a = -- * The final element was not marked as final. -- See 'StreamElem' for detailed additional discussion. | NoMoreElems b - deriving stock (Show, Eq, Functor, Foldable, Traversable) + deriving stock (Show, Eq, Functor, Foldable, Traversable, GHC.Generic) + deriving anyclass (PrettyVal) instance Bifunctor StreamElem where bimap g f (FinalElem a b) = FinalElem (f a) (g b) @@ -90,28 +95,29 @@ value = \case -- -- A 'False' result does not mean the element is not final; see 'StreamElem' for -- detailed discussion. -definitelyFinal :: StreamElem b a -> Maybe b -definitelyFinal = \case - StreamElem _ -> Nothing - FinalElem _ b -> Just b - NoMoreElems b -> Just b +whenDefinitelyFinal :: Applicative m => StreamElem b a -> (b -> m ()) -> m () +whenDefinitelyFinal msg k = + case msg of + StreamElem _ -> pure () + FinalElem _ b -> k b + NoMoreElems b -> k b -- | Map over all elements -mapM_ :: forall m a. Monad m => m (StreamElem () a) -> (a -> m ()) -> m () +mapM_ :: forall m a b. Monad m => m (StreamElem b a) -> (a -> m ()) -> m () mapM_ recv f = loop where loop :: m () loop = do x <- recv case x of - StreamElem a -> f a >> loop - FinalElem a () -> f a - NoMoreElems () -> return () + StreamElem a -> f a >> loop + FinalElem a _ -> f a + NoMoreElems _ -> return () -- | Collect all elements -- -- Returns the elements in the order they were received. -collect :: forall m a. Monad m => m (StreamElem () a) -> m [a] +collect :: forall m a b. 
Monad m => m (StreamElem b a) -> m [a] collect recv = reverse <$> execStateT go [] where diff --git a/src/Network/GRPC/Common/StreamType.hs b/src/Network/GRPC/Common/StreamType.hs index 782cbf44..63575b02 100644 --- a/src/Network/GRPC/Common/StreamType.hs +++ b/src/Network/GRPC/Common/StreamType.hs @@ -8,215 +8,26 @@ -- using the more general interface from "Network.GRPC.Client" or -- "Network.GRPC.Server" is not much more difficult. module Network.GRPC.Common.StreamType ( - -- * Communication patterns + -- * Abstraction StreamingType(..) , SupportsStreamingType , HasStreamingType(..) - -- * Handler types + -- * Handlers + , HandlerFor , NonStreamingHandler(..) , ClientStreamingHandler(..) , ServerStreamingHandler(..) , BiDiStreamingHandler(..) - , HandlerFor - -- ** Destructors + -- * Execution , nonStreaming , clientStreaming , serverStreaming , biDiStreaming - -- ** Constructors + -- * Construction , mkNonStreaming , mkClientStreaming , mkServerStreaming , mkBiDiStreaming ) where -import Data.Kind -import Network.GRPC.Common.StreamElem (StreamElem) -import Network.GRPC.Spec.RPC -import Network.GRPC.Util.RedundantConstraint - --- Borrow protolens 'StreamingType' (but this module is not Protobuf specific) -import Data.ProtoLens.Service.Types (StreamingType(..)) - -{------------------------------------------------------------------------------- - Communication patterns --------------------------------------------------------------------------------} - --- | This RPC supports the given streaming type --- --- This is a weaker condition than 'HasStreamingType', which maps each RPC to --- a /specific/ streaming type. Some (non-Protobuf) RPCs however may support --- more than one streaming type. 
-class SupportsStreamingType rpc (styp :: StreamingType) - -class SupportsStreamingType rpc (RpcStreamingType rpc) - => HasStreamingType rpc where - -- | Streaming type supported by this RPC - -- - -- The - -- [gRPC specification](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md) - -- does not distinguish between different kinds of communication patterns; - -- this is done in the - -- [Protobuf specification](https://protobuf.dev/reference/protobuf/proto3-spec/#service_definition). - -- Nonetheless, these streaming types can be applied to other serialization - -- formats also. - type RpcStreamingType rpc :: StreamingType - -{------------------------------------------------------------------------------- - Handler types - - Users don't typically work with these types directly; however, they serve - as as specification of the commonality between the client and the server. --------------------------------------------------------------------------------} - -newtype NonStreamingHandler m rpc = UnsafeNonStreamingHandler ( - Input rpc - -> m (Output rpc) - ) - -newtype ClientStreamingHandler m rpc = UnsafeClientStreamingHandler ( - m (StreamElem () (Input rpc)) - -> m (Output rpc) - ) - -newtype ServerStreamingHandler m rpc = UnsafeServerStreamingHandler ( - Input rpc - -> (Output rpc -> m ()) - -> m () - ) - -newtype BiDiStreamingHandler m rpc = UnsafeBiDiStreamingHandler ( - m (StreamElem () (Input rpc)) - -> (Output rpc -> m ()) - -> m () - ) - --- | Match 'StreamingType' to handler type --- --- This is occassionally useful to improve type inference. Users are not --- expected to need to use this in their own code. 
-type family HandlerFor (typ :: StreamingType) :: (Type -> Type) -> Type -> Type where - HandlerFor NonStreaming = NonStreamingHandler - HandlerFor ClientStreaming = ClientStreamingHandler - HandlerFor ServerStreaming = ServerStreamingHandler - HandlerFor BiDiStreaming = BiDiStreamingHandler - -{------------------------------------------------------------------------------- - Destructors - - These are just the record field accessors, but with an additional constraint - on them that ensures that the right type of handler is used with the right - type of RPC. --------------------------------------------------------------------------------} - --- | Make a non-streaming RPC --- --- Example usage: --- --- > features <- --- > nonStreaming (rpc @(Protobuf RouteGuide "getFeature") conn) point --- > logMsg features -nonStreaming :: forall rpc m. - SupportsStreamingType rpc NonStreaming - => NonStreamingHandler m rpc - -> Input rpc - -> m (Output rpc) -nonStreaming (UnsafeNonStreamingHandler h) = h - where - _ = addConstraint @(SupportsStreamingType rpc NonStreaming) - --- | Make a client-side streaming RPC --- --- Example usage: --- --- > summary <- --- > clientStreaming (rpc @(Protobuf RouteGuide "recordRoute") conn) getPoint -clientStreaming :: forall rpc m. - SupportsStreamingType rpc ClientStreaming - => ClientStreamingHandler m rpc - -> m (StreamElem () (Input rpc)) - -> m (Output rpc) -clientStreaming (UnsafeClientStreamingHandler h) = h - where - _ = addConstraint @(SupportsStreamingType rpc ClientStreaming) - --- | Make a server-side streaming RPC --- --- Example usage: --- --- > serverStreaming (rpc @(Protobuf RouteGuide "listFeatures") conn) rect $ --- > logMsg -serverStreaming :: forall rpc m. 
- SupportsStreamingType rpc ServerStreaming - => ServerStreamingHandler m rpc - -> Input rpc - -> (Output rpc -> m ()) - -> m () -serverStreaming (UnsafeServerStreamingHandler h) = h - where - _ = addConstraint @(SupportsStreamingType rpc ServerStreaming) - --- | Make a bidirectional RPC --- --- Example usage: --- --- > biDiStreaming (rpc @(Protobuf RouteGuide "routeChat") conn) getNote $ --- > logMsg -biDiStreaming :: forall rpc m. - SupportsStreamingType rpc BiDiStreaming - => BiDiStreamingHandler m rpc - -> m (StreamElem () (Input rpc)) - -> (Output rpc -> m ()) - -> m () -biDiStreaming (UnsafeBiDiStreamingHandler h) = h - where - _ = addConstraint @(SupportsStreamingType rpc BiDiStreaming) - -{------------------------------------------------------------------------------- - Constructors - - These are just the newtype constructors, but with an additional constraint, - similar to the destructors above. --------------------------------------------------------------------------------} - -mkNonStreaming :: forall m rpc. - SupportsStreamingType rpc NonStreaming - => ( Input rpc - -> m (Output rpc) - ) - -> NonStreamingHandler m rpc -mkNonStreaming = UnsafeNonStreamingHandler - where - _ = addConstraint @(SupportsStreamingType rpc NonStreaming) - -mkClientStreaming :: forall m rpc. - SupportsStreamingType rpc ClientStreaming - => ( m (StreamElem () (Input rpc)) - -> m (Output rpc) - ) - -> ClientStreamingHandler m rpc -mkClientStreaming = UnsafeClientStreamingHandler - where - _ = addConstraint @(SupportsStreamingType rpc ClientStreaming) - -mkServerStreaming :: forall m rpc. - SupportsStreamingType rpc ServerStreaming - => ( Input rpc - -> (Output rpc -> m ()) - -> m () - ) - -> ServerStreamingHandler m rpc -mkServerStreaming = UnsafeServerStreamingHandler - where - _ = addConstraint @(SupportsStreamingType rpc ServerStreaming) - -mkBiDiStreaming :: forall m rpc. 
- SupportsStreamingType rpc BiDiStreaming - => ( m (StreamElem () (Input rpc)) - -> (Output rpc -> m ()) - -> m () - ) - -> BiDiStreamingHandler m rpc -mkBiDiStreaming = UnsafeBiDiStreamingHandler - where - _ = addConstraint @(SupportsStreamingType rpc BiDiStreaming) +import Network.GRPC.Spec diff --git a/src/Network/GRPC/Server.hs b/src/Network/GRPC/Server.hs index ba8fb1a4..ea561182 100644 --- a/src/Network/GRPC/Server.hs +++ b/src/Network/GRPC/Server.hs @@ -34,10 +34,10 @@ module Network.GRPC.Server ( import Control.Exception import Control.Tracer -import Data.Text qualified as Text +import Network.HTTP.Types qualified as HTTP import Network.HTTP2.Server qualified as HTTP2 -import Network.GRPC.Common.Exceptions +import Network.GRPC.Common import Network.GRPC.Server.Call import Network.GRPC.Server.Connection (Connection, withConnection) import Network.GRPC.Server.Connection qualified as Connection @@ -46,8 +46,7 @@ import Network.GRPC.Server.Context qualified as Context import Network.GRPC.Server.Handler (RpcHandler(..)) import Network.GRPC.Server.Handler qualified as Handler import Network.GRPC.Spec -import Network.GRPC.Spec.PseudoHeaders -import Network.GRPC.Spec.RPC.Protobuf (Protobuf) +import Network.GRPC.Util.Session.Server qualified as Session.Server {------------------------------------------------------------------------------- Server proper @@ -60,24 +59,22 @@ import Network.GRPC.Spec.RPC.Protobuf (Protobuf) -- functions. 
withServer :: ServerParams -> [RpcHandler IO] -> (HTTP2.Server -> IO a) -> IO a withServer params handlers k = do - Context.withContext params $ \ctxt -> - k $ withConnection ctxt $ - handleRequest (Handler.constructMap handlers) - -handleRequest :: Handler.Map IO -> Connection -> IO () -handleRequest handlers conn = do + Context.withContext params $ \ctxt -> do + let server :: HTTP2.Server + server = withConnection ctxt $ + handleRequest params (Handler.constructMap handlers) + k server + +handleRequest :: ServerParams -> Handler.Map IO -> Connection -> IO () +handleRequest params handlers conn = do -- TODO: Proper "Apache style" logging (in addition to the debug logging) traceWith tracer $ Context.NewRequest path - - RpcHandler handler <- getHandler handlers path - call <- acceptCall conn - - -- TODO: Timeouts - -- - -- Wait-for-ready semantics makes this more complicated, maybe. - -- See example in the grpc-repo (python/wait_for_ready). - - handle (forwardException call) $ handler call + withHandler params handlers path conn $ \(RpcHandler h) -> do + -- TODO: Timeouts + -- + -- Wait-for-ready semantics makes this more complicated, maybe. + -- See example in the grpc-repo (python/wait_for_ready). + acceptCall params conn h where path :: Path path = Connection.path conn @@ -88,49 +85,34 @@ handleRequest handlers conn = do $ Context.params $ Connection.context conn --- | Forward exception to the client --- --- If the handler throws an exception, attempt to forward it to the client so --- that it is notified something went wrong. This is a best-effort only: --- --- * The nature of the exception might mean that we we cannot send anything to --- the client at all. --- * It is possible the exception was thrown /after/ the handler already send --- the trailers to the client. --- --- We therefore catch and suppress all exceptions here. 
-forwardException :: Call rpc -> SomeException -> IO () -forwardException call err = - handle ignoreExceptions $ - sendProperTrailers call trailers - where - trailers :: ProperTrailers - trailers - | Just (err' :: GrpcException) <- fromException err - = grpcExceptionToTrailers err' - - -- TODO: There might be a security concern here (server-side exceptions - -- could potentially leak some sensitive data). - | otherwise - = ProperTrailers { - trailerGrpcStatus = GrpcError GrpcUnknown - , trailerGrpcMessage = Just $ Text.pack $ show err - , trailerMetadata = [] - } - - -- See discussion above. - ignoreExceptions :: SomeException -> IO () - ignoreExceptions _ = return () - {------------------------------------------------------------------------------- Get handler for the request -------------------------------------------------------------------------------} -getHandler :: Handler.Map m -> Path -> IO (RpcHandler m) -getHandler handlers path = do +withHandler :: + ServerParams + -> Handler.Map m + -> Path + -> Connection + -> (RpcHandler m -> IO ()) -> IO () +withHandler params handlers path conn k = case Handler.lookup path handlers of - Nothing -> throwIO $ grpcUnimplemented path - Just h -> return h + Just h -> k h + Nothing -> do + traceWith (Context.serverExceptionTracer params) (toException err) + Session.Server.respond conn' $ + HTTP2.responseNoBody + HTTP.ok200 + (buildTrailersOnly trailers) + where + conn' :: Session.Server.ConnectionToClient + conn' = Connection.connectionToClient conn + + err :: GrpcException + err = grpcUnimplemented path + + trailers :: TrailersOnly + trailers = TrailersOnly $ grpcExceptionToTrailers err grpcUnimplemented :: Path -> GrpcException grpcUnimplemented path = GrpcException { diff --git a/src/Network/GRPC/Server/Binary.hs b/src/Network/GRPC/Server/Binary.hs index f66263a5..f3314804 100644 --- a/src/Network/GRPC/Server/Binary.hs +++ b/src/Network/GRPC/Server/Binary.hs @@ -20,8 +20,7 @@ import Control.Monad.Catch import 
Data.Binary import Network.GRPC.Common.Binary -import Network.GRPC.Common.CustomMetadata (CustomMetadata) -import Network.GRPC.Common.StreamElem +import Network.GRPC.Common import Network.GRPC.Common.StreamType qualified as StreamType import Network.GRPC.Server (Call) import Network.GRPC.Server qualified as Server @@ -49,7 +48,7 @@ sendFinalOutput call (out, trailers) = recvInput :: Binary a => Call (BinaryRpc serv meth) - -> IO (StreamElem () a) + -> IO (StreamElem NoMetadata a) recvInput call = do Server.recvInput call >>= traverse decodeOrThrow @@ -76,7 +75,7 @@ mkNonStreaming f = StreamType.mkNonStreaming $ \inp -> do mkClientStreaming :: forall m out serv meth. (MonadThrow m, Binary out) - => ( (forall inp. Binary inp => m (StreamElem () inp)) + => ( (forall inp. Binary inp => m (StreamElem NoMetadata inp)) -> m out ) -> StreamType.ClientStreamingHandler m (BinaryRpc serv meth) @@ -97,7 +96,7 @@ mkServerStreaming f = StreamType.mkServerStreaming $ \inp send -> do mkBiDiStreaming :: forall m serv meth. MonadThrow m - => ( (forall inp. Binary inp => m (StreamElem () inp)) + => ( (forall inp. Binary inp => m (StreamElem NoMetadata inp)) -> (forall out. 
Binary out => out -> m ()) -> m () ) diff --git a/src/Network/GRPC/Server/Call.hs b/src/Network/GRPC/Server/Call.hs index e14569e9..f21b8d63 100644 --- a/src/Network/GRPC/Server/Call.hs +++ b/src/Network/GRPC/Server/Call.hs @@ -26,33 +26,29 @@ module Network.GRPC.Server.Call ( , sendOutputSTM , initiateResponse , sendTrailersOnly + , waitForOutbound -- ** Internal API , sendProperTrailers ) where import Control.Concurrent.STM -import Control.Exception import Control.Monad +import Control.Monad.Catch import Control.Tracer import Data.Bifunctor import Data.Text qualified as Text import Network.HTTP.Types qualified as HTTP import Network.HTTP2.Server qualified as HTTP2 -import Network.GRPC.Common.Compression (Compression) +import Network.GRPC.Common import Network.GRPC.Common.Compression qualified as Compr -import Network.GRPC.Common.Compression qualified as Compression -import Network.GRPC.Common.Exceptions -import Network.GRPC.Common.StreamElem (StreamElem(..)) +import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Server.Connection (Connection) import Network.GRPC.Server.Connection qualified as Connection import Network.GRPC.Server.Context qualified as Context import Network.GRPC.Server.Session import Network.GRPC.Spec -import Network.GRPC.Spec.CustomMetadata -import Network.GRPC.Spec.Response qualified as Resp -import Network.GRPC.Spec.RPC import Network.GRPC.Util.Session qualified as Session import Network.GRPC.Util.Session.Server qualified as Server @@ -113,29 +109,96 @@ data Kickoff = -- -- If an exception is thrown during call setup, we will send an error response -- to the client (and then rethrow the exception). -acceptCall :: forall rpc. IsRPC rpc => Connection -> IO (Call rpc) -acceptCall conn = do +acceptCall :: forall rpc. 
+ IsRPC rpc + => Context.ServerParams + -> Connection + -> (Call rpc -> IO ()) + -> IO () +acceptCall params conn k = do callRequestMetadata <- newEmptyTMVarIO callResponseMetadata <- newTVarIO [] callResponseKickoff <- newEmptyTMVarIO - callChannel <- - Session.initiateResponse - callSession - (contramap (Context.PeerDebugMsg @rpc) tracer) - (Connection.connectionToClient conn) - ( handle sendErrorResponse - . mkOutboundHeaders - callRequestMetadata - callResponseMetadata - callResponseKickoff - ) - return Call{ - callSession - , callChannel - , callRequestMetadata - , callResponseMetadata - , callResponseKickoff - } + + let createChannel :: IO (Session.Channel (ServerSession rpc)) + createChannel = + Session.initiateResponse + callSession + (contramap (Context.PeerDebugMsg @rpc) tracer) + (Connection.connectionToClient conn) + ( handle sendErrorResponse + . mkOutboundHeaders + callRequestMetadata + callResponseMetadata + callResponseKickoff + ) + + closeChannel :: + Call rpc + -> Either SomeException () + -> IO () + closeChannel call@Call{callChannel} mRes = do + mUnclean <- + case mRes of + Right () -> + -- Handler terminated successfully + -- (but it might not have sent the final message, see below) + Session.close callChannel $ ExitCaseSuccess () + Left err -> do + -- The handler threw an exception. /Try/ to tell the client + -- (but see discussion of 'forwardException') + traceWith (Context.serverExceptionTracer params) err + forwardException call err + Session.close callChannel $ ExitCaseException err + + case mUnclean of + Nothing -> return () -- Channel was closed cleanly + Just err -> + -- An unclean shutdown can have 3 causes: + -- + -- 1. We failed to set up the call. + -- + -- If the failure is due to network comms, we definitely have + -- no way of communicating with the client. If the failure + -- was due to a setup failure in 'mkOutboundHeaders', then + -- the client has /already/ been notified of the problem. + -- + -- 2. 
We lost communication during the call. + -- + -- In this case we also have no way of telling the client + -- that something went wrong. + -- + -- 3. The handler failed to properly terminate the communication + -- (send the final message and call 'waitForOutbound'). + -- + -- This is the trickiest case. We don't really know what + -- state the handler left the channel in; for example, we + -- might have killed the thread halfway through sending a + -- message. + -- + -- So the only thing we really can do here is log the error, + -- and nothing else. (We should not rethrow it, as doing so will + -- cause http2 to reset the stream, which is not always the + -- right thing to do (for example, in case (1)). + traceWith (Context.serverExceptionTracer params) $ + toException err + + -- We don't use bracket here, we don't want to rethrow any exceptions + -- (see discussion of 'closeChannel') + mask $ \unmask -> do + callChannel <- createChannel + + let call :: Call rpc + call = Call{ + callSession + , callChannel + , callRequestMetadata + , callResponseMetadata + , callResponseKickoff + } + + mRes <- try $ unmask $ k call + closeChannel call mRes where callSession :: ServerSession rpc callSession = ServerSession { @@ -174,16 +237,16 @@ acceptCall conn = do KickoffRegular -> do cOut :: Compression <- case requestAcceptCompression inboundHeaders of - Nothing -> return Compression.identity + Nothing -> return noCompression Just cids -> -- If the requests explicitly lists compression algorithms, -- and that list does /not/ include @identity@, then we - -- should not default to 'Compression.identity', even if all - -- other algorithms are unsupported. This gives the client - -- the option to /insist/ on compression. - case Compression.choose compr cids of + -- should not default to 'Compr.identity', even if all other + -- algorithms are unsupported. This gives the client the + -- option to /insist/ on compression. 
+ case Compr.choose compr cids of Right c -> return c - Left err -> throwIO err + Left err -> throwM err return $ Session.FlowStartRegular $ OutboundHeaders { outHeaders = ResponseHeaders { responseCompression = Just $ Compr.compressionId cOut @@ -193,29 +256,68 @@ acceptCall conn = do , outCompression = cOut } KickoffTrailersOnly trailers -> - return $ Session.FlowStartTrailersOnly trailers + return $ Session.FlowStartNoMessages trailers where inboundHeaders :: RequestHeaders inboundHeaders = case requestStart of - Session.FlowStartRegular headers -> inbHeaders headers - Session.FlowStartTrailersOnly headers -> headers + Session.FlowStartRegular headers -> inbHeaders headers + Session.FlowStartNoMessages trailers -> trailers -- | Send response when 'mkOutboundHeaders' fails - sendErrorResponse :: SomeException -> IO a + -- + -- TODO: Some duplication with 'forwardException' (and 'withHandler'). + sendErrorResponse :: forall x. SomeException -> IO x sendErrorResponse err = do + putStrLn $ "sendErrorResponse: " ++ show err traceWith tracer $ Context.AcceptCallFailed err Server.respond (Connection.connectionToClient conn) $ HTTP2.responseNoBody HTTP.ok200 -- gRPC uses HTTP 200 even when there are gRPC errors - (Resp.buildTrailersOnly $ TrailersOnly $ ProperTrailers { + (buildTrailersOnly $ TrailersOnly $ ProperTrailers { trailerGrpcStatus = GrpcError GrpcUnknown -- TODO: Potential security concern here -- (showing the exception)? - , trailerGrpcMessage = Just $ Text.pack $ show err + , trailerGrpcMessage = Just $ Text.pack (show err) , trailerMetadata = [] }) - throwIO err + throwM err + +-- | Process exception thrown by a handler +-- +-- Trace the exception and forward it to the client. +-- +-- The attempt to forward it to the client is a best-effort only: +-- +-- * The nature of the exception might mean that we cannot send anything to +-- the client at all. 
+-- * It is possible the exception was thrown /after/ the handler already sent +-- the trailers to the client. +-- +-- We therefore catch and suppress all exceptions here. +forwardException :: Call rpc -> SomeException -> IO () +forwardException call err = do + handle ignoreExceptions $ + sendProperTrailers call trailers + where + trailers :: ProperTrailers + trailers + | Just (err' :: GrpcException) <- fromException err + = grpcExceptionToTrailers err' + + -- TODO: There might be a security concern here (server-side exceptions + -- could potentially leak some sensitive data). + | otherwise + = ProperTrailers { + trailerGrpcStatus = GrpcError GrpcUnknown + , trailerGrpcMessage = Just $ Text.pack (show err) + , trailerMetadata = [] + } + + -- See discussion above. + ignoreExceptions :: SomeException -> IO () + ignoreExceptions _ = return () + {------------------------------------------------------------------------------- Open (ongoing) call @@ -225,8 +327,9 @@ acceptCall conn = do -- -- We do not return trailers, since gRPC does not support sending trailers from -- the client to the server (only from the server to the client). -recvInput :: forall rpc. Call rpc -> IO (StreamElem () (Input rpc)) -recvInput = atomically . recvInputSTM +recvInput :: Call rpc -> IO (StreamElem NoMetadata (Input rpc)) +recvInput call = + atomically $ recvInputSTM call -- | Send RPC output to the client -- @@ -234,10 +337,13 @@ recvInput = atomically . recvInputSTM -- to indicate something went wrong), the server handler should throw a -- 'GrpcException' (the @grapesy@ client API treats this the same way: a -- @grpc-status@ other than @0@ will be raised as a 'GrpcException'). +-- +-- Will block if this is the final output. 
sendOutput :: Call rpc -> StreamElem [CustomMetadata] (Output rpc) -> IO () sendOutput call msg = do _updated <- initiateResponse call atomically $ sendOutputSTM call msg + StreamElem.whenDefinitelyFinal msg $ \_ -> waitForOutbound call -- | Get request metadata -- @@ -279,21 +385,39 @@ setResponseMetadata Call{ callResponseMetadata -- | STM version of 'recvInput' -- +-- This is a low-level API; most users can use 'recvInput' instead. +-- -- Most server handlers will deal with single clients, but in principle --- 'recvInputSTM' could be used to wait for the first message from any number --- of clients. -recvInputSTM :: forall rpc. Call rpc -> STM (StreamElem () (Input rpc)) +-- 'recvInputSTM' could be used to wait for the first message from any number of +-- clients. If you choose to use this function, you should call 'waitForInbound' +-- after receiving the final message (and before exiting the scope of +-- 'acceptCall'). +recvInputSTM :: forall rpc. + Call rpc + -> STM (StreamElem NoMetadata (Input rpc)) recvInputSTM Call{callChannel} = first ignoreTrailersOnly <$> Session.recv callChannel where - ignoreTrailersOnly :: Either RequestHeaders () -> () - ignoreTrailersOnly _ = () + ignoreTrailersOnly :: + Either RequestHeaders NoMetadata + -> NoMetadata + ignoreTrailersOnly _ = NoMetadata -- | STM version of 'sendOutput' -- --- You /MUST/ call 'initiateResponse' before calling 'sendOutputSTM'; throws --- 'ResponseNotInitiated' otherwise. This is a low-level API; most users can use --- 'sendOutput' instead. +-- This is a low-level API; most users can use 'sendOutput' instead. +-- +-- If you choose to use 'sendOutputSTM' instead, you have two responsibilities: +-- +-- * You must call 'initiateResponse' before calling 'sendOutputSTM'; +-- 'sendOutputSTM' will throw 'ResponseNotInitiated' otherwise. 
+-- * You must call 'waitForOutbound' after sending the final message (and before +-- exiting the scope of 'acceptCall'); if you don't, you risk that the HTTP2 +-- stream is cancelled. +-- +-- Implementation note: we cannot call 'waitForOutbound' /in/ 'acceptCall': if +-- the handler for whatever reason never writes the final message, then such a +-- call would block indefinitely. sendOutputSTM :: Call rpc -> StreamElem [CustomMetadata] (Output rpc) -> STM () sendOutputSTM Call{callChannel, callResponseKickoff} msg = do mKickoff <- tryReadTMVar callResponseKickoff @@ -343,7 +467,7 @@ sendTrailersOnly :: Call rpc -> [CustomMetadata] -> IO () sendTrailersOnly Call{callResponseKickoff} metadata = do updated <- atomically $ tryPutTMVar callResponseKickoff $ KickoffTrailersOnly trailers - unless updated $ throwIO ResponseAlreadyInitiated + unless updated $ throwM ResponseAlreadyInitiated where trailers :: TrailersOnly trailers = TrailersOnly $ ProperTrailers { @@ -352,6 +476,14 @@ sendTrailersOnly Call{callResponseKickoff} metadata = do , trailerMetadata = metadata } +-- | Wait for all outbound messages to have been processed +-- +-- This function /must/ be called before leaving the scope of 'acceptCall'. +-- However, 'sendOutput' will call 'waitForOutbound' when the last message is +-- sent, so you only need to worry about this if you are using 'sendOutputSTM'. +waitForOutbound :: Call rpc -> IO () +waitForOutbound = void . Session.waitForOutbound . callChannel + data ResponseKickoffException = ResponseAlreadyInitiated | ResponseNotInitiated @@ -372,14 +504,14 @@ recvFinalInput :: forall rpc. 
Call rpc -> IO (Input rpc) recvFinalInput call@Call{} = do inp1 <- recvInput call case inp1 of - NoMoreElems () -> throwIO $ TooFewInputs @rpc - FinalElem inp () -> return inp - StreamElem inp -> do + NoMoreElems NoMetadata -> throwM $ TooFewInputs @rpc + FinalElem inp NoMetadata -> return inp + StreamElem inp -> do inp2 <- recvInput call case inp2 of - NoMoreElems () -> return inp - FinalElem inp' _ -> throwIO $ TooManyInputs @rpc inp' - StreamElem inp' -> throwIO $ TooManyInputs @rpc inp' + NoMoreElems NoMetadata -> return inp + FinalElem inp' NoMetadata -> throwM $ TooManyInputs @rpc inp' + StreamElem inp' -> throwM $ TooManyInputs @rpc inp' -- | Send final output -- diff --git a/src/Network/GRPC/Server/Connection.hs b/src/Network/GRPC/Server/Connection.hs index 77b08775..56116d2c 100644 --- a/src/Network/GRPC/Server/Connection.hs +++ b/src/Network/GRPC/Server/Connection.hs @@ -23,7 +23,7 @@ import Network.HTTP.Types qualified as HTTP import Network.HTTP2.Server qualified as HTTP2 import Network.GRPC.Server.Context -import Network.GRPC.Spec.PseudoHeaders +import Network.GRPC.Spec import Network.GRPC.Util.Session (ConnectionToClient(..)) {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Server/Context.hs b/src/Network/GRPC/Server/Context.hs index f2daeb46..8b7fbcbe 100644 --- a/src/Network/GRPC/Server/Context.hs +++ b/src/Network/GRPC/Server/Context.hs @@ -20,8 +20,7 @@ import Data.Default import Network.GRPC.Common.Compression qualified as Compr import Network.GRPC.Server.Session (ServerSession) -import Network.GRPC.Spec.PseudoHeaders -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec import Network.GRPC.Util.Session qualified as Session {------------------------------------------------------------------------------- @@ -42,14 +41,26 @@ withContext params k = k $ ServerContext{params} -------------------------------------------------------------------------------} data ServerParams = ServerParams 
{ - serverDebugTracer :: Tracer IO ServerDebugMsg - , serverCompression :: Compr.Negotation + -- | Server compression preferences + serverCompression :: Compr.Negotation + + -- | Tracer for exceptions thrown by handlers + -- + -- The default uses 'stdoutTracer'. + , serverExceptionTracer :: Tracer IO SomeException + + -- | Tracer for debug messages + -- + -- This is primarily for debugging @grapesy@ itself; most client code will + -- probably want to use 'nullTracer' here. + , serverDebugTracer :: Tracer IO ServerDebugMsg } instance Default ServerParams where def = ServerParams { - serverDebugTracer = nullTracer - , serverCompression = def + serverCompression = def + , serverExceptionTracer = contramap show stdoutTracer + , serverDebugTracer = nullTracer } {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Server/Handler.hs b/src/Network/GRPC/Server/Handler.hs index 2a52c67a..094f54de 100644 --- a/src/Network/GRPC/Server/Handler.hs +++ b/src/Network/GRPC/Server/Handler.hs @@ -28,8 +28,7 @@ import Data.Proxy import Data.Typeable import Network.GRPC.Server.Call -import Network.GRPC.Spec.PseudoHeaders -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec {------------------------------------------------------------------------------- Handlers diff --git a/src/Network/GRPC/Server/Protobuf.hs b/src/Network/GRPC/Server/Protobuf.hs index 0c87fb8d..38244ac2 100644 --- a/src/Network/GRPC/Server/Protobuf.hs +++ b/src/Network/GRPC/Server/Protobuf.hs @@ -12,7 +12,7 @@ import Data.Kind import Data.ProtoLens.Service.Types import GHC.TypeLits -import Network.GRPC.Spec.RPC.Protobuf +import Network.GRPC.Spec {------------------------------------------------------------------------------- Compute full Protobuf API diff --git a/src/Network/GRPC/Server/Session.hs b/src/Network/GRPC/Server/Session.hs index 1ab548f4..5da468a0 100644 --- a/src/Network/GRPC/Server/Session.hs +++ b/src/Network/GRPC/Server/Session.hs @@ -10,13 
+10,8 @@ import Data.List.NonEmpty (NonEmpty) import Data.Proxy import Network.HTTP.Types qualified as HTTP -import Network.GRPC.Common.Compression (Compression, CompressionId) import Network.GRPC.Common.Compression qualified as Compr -import Network.GRPC.Spec qualified as GRPC -import Network.GRPC.Spec.LengthPrefixed qualified as LP -import Network.GRPC.Spec.Request qualified as Req -import Network.GRPC.Spec.Response qualified as Resp -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec import Network.GRPC.Util.Session {------------------------------------------------------------------------------- @@ -36,56 +31,56 @@ data ServerOutbound rpc instance IsRPC rpc => DataFlow (ServerInbound rpc) where data Headers (ServerInbound rpc) = InboundHeaders { - inbHeaders :: GRPC.RequestHeaders + inbHeaders :: RequestHeaders , inbCompression :: Compression } deriving (Show) - type Message (ServerInbound rpc) = Input rpc - type ProperTrailers (ServerInbound rpc) = () + type Message (ServerInbound rpc) = Input rpc + type Trailers (ServerInbound rpc) = NoMetadata -- See discussion of 'TrailersOnly' in 'ClientOutbound' - type TrailersOnly (ServerInbound rpc) = GRPC.RequestHeaders + type NoMessages (ServerInbound rpc) = RequestHeaders instance IsRPC rpc => DataFlow (ServerOutbound rpc) where data Headers (ServerOutbound rpc) = OutboundHeaders { - outHeaders :: GRPC.ResponseHeaders + outHeaders :: ResponseHeaders , outCompression :: Compression } deriving (Show) - type Message (ServerOutbound rpc) = Output rpc - type ProperTrailers (ServerOutbound rpc) = GRPC.ProperTrailers - type TrailersOnly (ServerOutbound rpc) = GRPC.TrailersOnly + type Message (ServerOutbound rpc) = Output rpc + type Trailers (ServerOutbound rpc) = ProperTrailers + type NoMessages (ServerOutbound rpc) = TrailersOnly instance IsRPC rpc => IsSession (ServerSession rpc) where type Inbound (ServerSession rpc) = ServerInbound rpc type Outbound (ServerSession rpc) = ServerOutbound rpc - parseProperTrailers _ = 
\_ -> return () - buildProperTrailers _ = Resp.buildProperTrailers + parseInboundTrailers _ = \_ -> return NoMetadata + buildOutboundTrailers _ = buildProperTrailers - parseMsg _ = LP.parseInput (Proxy @rpc) . inbCompression - buildMsg _ = LP.buildOutput (Proxy @rpc) . outCompression + parseMsg _ = parseInput (Proxy @rpc) . inbCompression + buildMsg _ = buildOutput (Proxy @rpc) . outCompression instance IsRPC rpc => AcceptSession (ServerSession rpc) where parseRequestRegular server info = do - requestHeaders :: GRPC.RequestHeaders <- - case Req.parseHeaders (Proxy @rpc) (requestHeaders info) of + requestHeaders :: RequestHeaders <- + case parseRequestHeaders (Proxy @rpc) (requestHeaders info) of Left err -> throwIO $ RequestInvalidHeaders err Right hdrs -> return hdrs cIn :: Compression <- Compr.getSupported (serverCompression server) $ - GRPC.requestCompression requestHeaders + requestCompression requestHeaders return InboundHeaders { inbHeaders = requestHeaders , inbCompression = cIn } - parseRequestTrailersOnly _ info = - case Req.parseHeaders (Proxy @rpc) (requestHeaders info) of + parseRequestNoMessages _ info = + case parseRequestHeaders (Proxy @rpc) (requestHeaders info) of Left err -> throwIO $ RequestInvalidHeaders err Right hdrs -> return hdrs @@ -94,9 +89,9 @@ instance IsRPC rpc => AcceptSession (ServerSession rpc) where , responseHeaders = case start of FlowStartRegular headers -> - Resp.buildHeaders (Proxy @rpc) (outHeaders headers) - FlowStartTrailersOnly trailers -> - Resp.buildTrailersOnly trailers + buildResponseHeaders (Proxy @rpc) (outHeaders headers) + FlowStartNoMessages trailers -> + buildTrailersOnly trailers } {------------------------------------------------------------------------------- diff --git a/src/Network/GRPC/Server/StreamType.hs b/src/Network/GRPC/Server/StreamType.hs index 1f9e97fe..e982d3e9 100644 --- a/src/Network/GRPC/Server/StreamType.hs +++ b/src/Network/GRPC/Server/StreamType.hs @@ -16,7 +16,7 @@ import Data.Proxy import 
Network.GRPC.Common.StreamType import Network.GRPC.Server -import Network.GRPC.Spec.RPC +import Network.GRPC.Spec {------------------------------------------------------------------------------- Construct 'RpcHandler' diff --git a/src/Network/GRPC/Spec.hs b/src/Network/GRPC/Spec.hs index 43b19684..c689defb 100644 --- a/src/Network/GRPC/Spec.hs +++ b/src/Network/GRPC/Spec.hs @@ -1,398 +1,117 @@ --- | Parts of the gRPC specification that are not HTTP2 specific. +-- | Pure implementation of the gRPC spec +-- +-- Most code will not need to use this module directly. -- -- Intended for unqualified import. module Network.GRPC.Spec ( - -- * Call parameters - CallParams(..) + -- * RPC + IsRPC(..) + -- ** Instances + , Protobuf + , BinaryRpc + -- ** Serialization + , parseInput + , parseOutput + , buildInput + , buildOutput + -- * Streaming types + , StreamingType(..) + , SupportsStreamingType + , HasStreamingType(..) + -- ** Handlers + , HandlerFor + , NonStreamingHandler(..) + , ClientStreamingHandler(..) + , ServerStreamingHandler(..) + , BiDiStreamingHandler(..) + -- ** Execution + , nonStreaming + , clientStreaming + , serverStreaming + , biDiStreaming + -- ** Construction + , mkNonStreaming + , mkClientStreaming + , mkServerStreaming + , mkBiDiStreaming + -- * Compression + , CompressionId(..) + , Compression(..) + -- ** Compression algorithms + , noCompression + , gzip + , allSupportedCompression + -- * Requests + , RequestHeaders(..) + -- ** Parameters + , CallParams(..) + -- ** Pseudo-headers + , PseudoHeaders(..) + , ServerHeaders(..) + , ResourceHeaders(..) + , Path(..) + , Authority(..) + , Scheme(..) + , Method(..) + , rpcPath + -- ** Serialization + , RawPseudoHeaders(..) + , RawServerHeaders(..) + , RawResourceHeaders(..) + , InvalidPseudoHeaders(..) + , buildResourceHeaders + , buildServerHeaders + , parsePseudoHeaders + -- ** Headers + , buildRequestHeaders + , parseRequestHeaders -- ** Timeouts , Timeout(..) 
- , TimeoutValue(TimeoutValue, getTimeoutValue) + , TimeoutValue(..) , TimeoutUnit(..) , timeoutToMicro - -- * Inputs (message sent to the peer) - , RequestHeaders(..) - , IsFinal(..) - -- * Outputs (messages received from the peer) + -- * Responses + -- ** Headers , ResponseHeaders(..) + , buildResponseHeaders + , parseResponseHeaders + -- ** Trailers , ProperTrailers(..) , TrailersOnly(..) - -- * GRPC status + , parseProperTrailers + , parseTrailersOnly + , buildProperTrailers + , buildTrailersOnly + -- *** Status , GrpcStatus(..) , GrpcError(..) - , fromGrpcStatus - , toGrpcStatus + -- *** Classification + , GrpcException(..) + , grpcExceptionToTrailers + , grpcExceptionFromTrailers + -- * Metadata + , CustomMetadata(..) + , HeaderName(..) + , BinaryValue(..) + , AsciiValue(..) + , NoMetadata(..) + , customHeaderName + , parseCustomMetadata + , buildCustomMetadata + , safeHeaderName + , safeAsciiValue ) where -import Control.Exception -import Data.Default -import Data.List.NonEmpty (NonEmpty) -import Data.Text (Text) -import Generics.SOP qualified as SOP -import GHC.Generics qualified as GHC -import GHC.Show - -import Network.GRPC.Spec.Compression (CompressionId) +import Network.GRPC.Spec.Call +import Network.GRPC.Spec.Compression import Network.GRPC.Spec.CustomMetadata - -{------------------------------------------------------------------------------- - Requests --------------------------------------------------------------------------------} - --- | RPC parameters that can be chosen on a per-call basis -data CallParams = CallParams { - -- | Timeout - -- - -- If Timeout is omitted a server should assume an infinite timeout. - -- Client implementations are free to send a default minimum timeout based - -- on their deployment requirements. - callTimeout :: Maybe Timeout - - -- | Custom metadata - -- - -- This is the metadata included in the request. (The server can include - -- its own metadata in the response: see 'responseMetadata' and - -- 'trailerMetadatda'.) 
- , callRequestMetadata :: [CustomMetadata] - } - deriving stock (Show, Eq) - --- | Default 'CallParams' -instance Default CallParams where - def = CallParams { - callTimeout = Nothing - , callRequestMetadata = [] - } - -{------------------------------------------------------------------------------- - Timeouts --------------------------------------------------------------------------------} - -data Timeout = Timeout TimeoutUnit TimeoutValue - deriving stock (Show, Eq) - --- | Positive integer with ASCII representation of at most 8 digits -newtype TimeoutValue = UnsafeTimeoutValue { - getTimeoutValue :: Word - } - deriving newtype (Eq) - --- | 'Show' instance relies on the 'TimeoutValue' pattern synonym -instance Show TimeoutValue where - showsPrec p (UnsafeTimeoutValue val) = showParen (p >= appPrec1) $ - showString "TimeoutValue " - . showsPrec appPrec1 val - -pattern TimeoutValue :: Word -> TimeoutValue -pattern TimeoutValue t <- UnsafeTimeoutValue t - where - TimeoutValue t - | isValidTimeoutValue t = UnsafeTimeoutValue t - | otherwise = error $ "invalid TimeoutValue: " ++ show t - -{-# COMPLETE TimeoutValue #-} - -isValidTimeoutValue :: Word -> Bool -isValidTimeoutValue t = length (show t) <= 8 - -data TimeoutUnit = - Hour - | Minute - | Second - | Millisecond - | Microsecond - | Nanosecond - deriving stock (Show, Eq) - --- | Translate 'Timeout' to microseconds --- --- For 'Nanosecond' timeout we round up. 
-timeoutToMicro :: Timeout -> Integer -timeoutToMicro = \case - Timeout Hour (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000 * 60 * 24 - Timeout Minute (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000 * 60 - Timeout Second (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000 - Timeout Millisecond (TimeoutValue n) -> mult n $ 1 * 1_000 - Timeout Microsecond (TimeoutValue n) -> mult n $ 1 - Timeout Nanosecond (TimeoutValue n) -> nano n - where - mult :: Word -> Integer -> Integer - mult n m = fromIntegral n * m - - nano :: Word -> Integer - nano n = fromIntegral $ - mu + if n' == 0 then 0 else 1 - where - (mu, n') = divMod n 1_000 - -{------------------------------------------------------------------------------- - Inputs (message sent to the peer) --------------------------------------------------------------------------------} - --- | Full set of call parameters required to construct the RPC call --- --- This is constructed internally; it is not part of the public API. -data RequestHeaders = RequestHeaders { - -- | Timeout - requestTimeout :: Maybe Timeout - - -- | Custom metadata - , requestMetadata :: [CustomMetadata] - - -- | Compression used for outgoing messages - , requestCompression :: Maybe CompressionId - - -- | Accepted compression algorithms for incoming messages - -- - -- @Maybe (NonEmpty ..)@ is perhaps a bit strange (why not just @[]@), but - -- it emphasizes the specification: /if/ the header is present, it must be - -- a non-empty list. 
- , requestAcceptCompression :: Maybe (NonEmpty CompressionId) - } - deriving stock (Show, Eq) - deriving stock (GHC.Generic) - deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) - --- | Mark a input sent as final -data IsFinal = Final | NotFinal - deriving stock (Show, Eq) - -{------------------------------------------------------------------------------- - Outputs (messages received from the peer) --------------------------------------------------------------------------------} - --- | Response headers -data ResponseHeaders = ResponseHeaders { - responseCompression :: Maybe CompressionId - , responseAcceptCompression :: Maybe (NonEmpty CompressionId) - , responseMetadata :: [CustomMetadata] - } - deriving stock (Show, Eq) - deriving stock (GHC.Generic) - deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) - --- | Information sent by the peer after the final output --- --- Response trailers are a --- [HTTP2 concept](https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.3): --- they are HTTP headers that are sent /after/ the content body. For example, --- imagine the server is streaming a file that it's reading from disk; it could --- use trailers to give the client an MD5 checksum when streaming is complete. -data ProperTrailers = ProperTrailers { - trailerGrpcStatus :: GrpcStatus - , trailerGrpcMessage :: Maybe Text - , trailerMetadata :: [CustomMetadata] - } - deriving stock (Show, Eq) - deriving stock (GHC.Generic) - deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) - --- | Trailers sent in the gRPC Trailers-Only case --- --- In the current version of the spec, the information in 'TrailersOnly' is --- identical to the 'ProperTrailers' case (but they do require a slightly --- different function to parse/unparse). 
-newtype TrailersOnly = TrailersOnly { - getTrailersOnly :: ProperTrailers - } - deriving stock (Show, Eq) - -{------------------------------------------------------------------------------- - gRPC status --------------------------------------------------------------------------------} - --- | gRPC status --- --- Defined in . -data GrpcStatus = - GrpcOk - | GrpcError GrpcError - deriving stock (Show, Eq) - --- | gRPC error code --- --- This is a subset of the gRPC status codes. See 'GrpcStatus'. -data GrpcError = - -- | Cancelled - -- - -- The operation was cancelled, typically by the caller. - GrpcCancelled - - -- | Unknown error - -- - -- For example, this error may be returned when a @Status@ value received - -- from another address space belongs to an error space that is not known in - -- this address space. Also errors raised by APIs that do not return enough - -- error information may be converted to this error. - | GrpcUnknown - - -- | Invalid argument - -- - -- The client specified an invalid argument. Note that this differs from - -- 'GrpcFailedPrecondition': 'GrpcInvalidArgumen'` indicates arguments that - -- are problematic regardless of the state of the system (e.g., a malformed - -- file name). - | GrpcInvalidArgument - - -- | Deadline exceeded - -- - -- The deadline expired before the operation could complete. For operations - -- that change the state of the system, this error may be returned even if - -- the operation has completed successfully. For example, a successful - -- response from a server could have been delayed long. - | GrpcDeadlineExceeded - - -- | Not found - -- - -- Some requested entity (e.g., file or directory) was not found. - -- - -- Note to server developers: if a request is denied for an entire class of - -- users, such as gradual feature rollout or undocumented allowlist, - -- 'GrpcNotFound' may be used. 
- -- - -- If a request is denied for some users within a class of users, such as - -- user-based access control, 'GrpcPermissionDenied' must be used. - | GrpcNotFound - - -- | Already exists - -- - -- The entity that a client attempted to create (e.g., file or directory) - -- already exists. - | GrpcAlreadyExists - - -- | Permission denied - -- - -- The caller does not have permission to execute the specified operation. - -- - -- * 'GrpcPermissionDenied' must not be used for rejections caused by - -- exhausting some resource (use 'GrpcResourceExhausted' instead for those - -- errors). - -- * 'GrpcPermissionDenoed' must not be used if the caller can not be - -- identified (use 'GrpcUnauthenticated' instead for those errors). - -- - -- This error code does not imply the request is valid or the requested - -- entity exists or satisfies other pre-conditions. - | GrpcPermissionDenied - - -- | Resource exhausted - -- - -- Some resource has been exhausted, perhaps a per-user quota, or perhaps - -- the entire file system is out of space. - | GrpcResourceExhausted - - -- | Failed precondition - -- - -- The operation was rejected because the system is not in a state required - -- for the operation's execution. For example, the directory to be deleted - -- is non-empty, an rmdir operation is applied to a non-directory, etc. - -- - -- Service implementors can use the following guidelines to decide between - -- 'GrpcFailedPrecondition', 'GrpcAborted', and 'GrpcUnvailable': - -- - -- (a) Use 'GrpcUnavailable' if the client can retry just the failing call. - -- (b) Use 'GrpcAborted' if the client should retry at a higher level (e.g., - -- when a client-specified test-and-set fails, indicating the client - -- should restart a read-modify-write sequence). - -- (c) Use `GrpcFailedPrecondition` if the client should not retry until the - -- system state has been explicitly fixed. 
E.g., if an @rmdir@ fails - -- because the directory is non-empty, 'GrpcFailedPrecondition' should - -- be returned since the client should not retry unless the files are - -- deleted from the directory. - | GrpcFailedPrecondition - - -- | Aborted - -- - -- The operation was aborted, typically due to a concurrency issue such as a - -- sequencer check failure or transaction abort. See the guidelines above - -- for deciding between 'GrpcFailedPrecondition', 'GrpcAborted', and - -- 'GrpcUnavailable'. - | GrpcAborted - - -- | Out of range - -- - -- The operation was attempted past the valid range. E.g., seeking or - -- reading past end-of-file. - -- - -- Unlike 'GrpcInvalidArgument', this error indicates a problem that may be - -- fixed if the system state changes. For example, a 32-bit file system will - -- generate 'GrpcInvalidArgument' if asked to read at an offset that is not - -- in the range @[0, 2^32-1]@, but it will generate 'GrpcOutOfRange' if - -- asked to read from an offset past the current file size. - -- - -- There is a fair bit of overlap between 'GrpcFailedPrecondition' and - -- 'GrpcOutOfRange'. We recommend using 'GrpcOutOfRange' (the more specific - -- error) when it applies so that callers who are iterating through a space - -- can easily look for an 'GrpcOutOfRange' error to detect when they are - -- done. - | GrpcOutOfRange - - -- | Unimplemented - -- - -- The operation is not implemented or is not supported/enabled in this - -- service. - | GrpcUnimplemented - - -- | Internal errors - -- - -- This means that some invariants expected by the underlying system have - -- been broken. This error code is reserved for serious errors. - | GrpcInternal - - -- | Unavailable - -- - -- The service is currently unavailable. This is most likely a transient - -- condition, which can be corrected by retrying with a backoff. Note that - -- it is not always safe to retry non-idempotent operations. 
- | GrpcUnavailable - - -- | Data loss - -- - -- Unrecoverable data loss or corruption. - | GrpcDataLoss - - -- | Unauthenticated - -- - -- The request does not have valid authentication credentials for the - -- operation. - | GrpcUnauthenticated - deriving stock (Show, Eq) - deriving anyclass (Exception) - -fromGrpcStatus :: GrpcStatus -> Word -fromGrpcStatus GrpcOk = 0 -fromGrpcStatus (GrpcError GrpcCancelled) = 1 -fromGrpcStatus (GrpcError GrpcUnknown) = 2 -fromGrpcStatus (GrpcError GrpcInvalidArgument) = 3 -fromGrpcStatus (GrpcError GrpcDeadlineExceeded) = 4 -fromGrpcStatus (GrpcError GrpcNotFound) = 5 -fromGrpcStatus (GrpcError GrpcAlreadyExists) = 6 -fromGrpcStatus (GrpcError GrpcPermissionDenied) = 7 -fromGrpcStatus (GrpcError GrpcResourceExhausted) = 8 -fromGrpcStatus (GrpcError GrpcFailedPrecondition) = 9 -fromGrpcStatus (GrpcError GrpcAborted) = 10 -fromGrpcStatus (GrpcError GrpcOutOfRange) = 11 -fromGrpcStatus (GrpcError GrpcUnimplemented) = 12 -fromGrpcStatus (GrpcError GrpcInternal) = 13 -fromGrpcStatus (GrpcError GrpcUnavailable) = 14 -fromGrpcStatus (GrpcError GrpcDataLoss) = 15 -fromGrpcStatus (GrpcError GrpcUnauthenticated) = 16 - -toGrpcStatus :: Word -> Maybe GrpcStatus -toGrpcStatus 0 = Just $ GrpcOk -toGrpcStatus 1 = Just $ GrpcError $ GrpcCancelled -toGrpcStatus 2 = Just $ GrpcError $ GrpcUnknown -toGrpcStatus 3 = Just $ GrpcError $ GrpcInvalidArgument -toGrpcStatus 4 = Just $ GrpcError $ GrpcDeadlineExceeded -toGrpcStatus 5 = Just $ GrpcError $ GrpcNotFound -toGrpcStatus 6 = Just $ GrpcError $ GrpcAlreadyExists -toGrpcStatus 7 = Just $ GrpcError $ GrpcPermissionDenied -toGrpcStatus 8 = Just $ GrpcError $ GrpcResourceExhausted -toGrpcStatus 9 = Just $ GrpcError $ GrpcFailedPrecondition -toGrpcStatus 10 = Just $ GrpcError $ GrpcAborted -toGrpcStatus 11 = Just $ GrpcError $ GrpcOutOfRange -toGrpcStatus 12 = Just $ GrpcError $ GrpcUnimplemented -toGrpcStatus 13 = Just $ GrpcError $ GrpcInternal -toGrpcStatus 14 = Just $ GrpcError $ 
GrpcUnavailable -toGrpcStatus 15 = Just $ GrpcError $ GrpcDataLoss -toGrpcStatus 16 = Just $ GrpcError $ GrpcUnauthenticated -toGrpcStatus _ = Nothing - +import Network.GRPC.Spec.LengthPrefixed +import Network.GRPC.Spec.PseudoHeaders +import Network.GRPC.Spec.Request +import Network.GRPC.Spec.Response +import Network.GRPC.Spec.RPC +import Network.GRPC.Spec.RPC.Binary +import Network.GRPC.Spec.RPC.Protobuf +import Network.GRPC.Spec.RPC.StreamType +import Network.GRPC.Spec.Status +import Network.GRPC.Spec.Timeout diff --git a/src/Network/GRPC/Spec/Call.hs b/src/Network/GRPC/Spec/Call.hs new file mode 100644 index 00000000..1b715916 --- /dev/null +++ b/src/Network/GRPC/Spec/Call.hs @@ -0,0 +1,39 @@ +module Network.GRPC.Spec.Call ( + -- * Parameters + CallParams(..) + ) where + +import Data.Default + +import Network.GRPC.Spec.CustomMetadata +import Network.GRPC.Spec.Timeout + +{------------------------------------------------------------------------------- + Parameters +-------------------------------------------------------------------------------} + +-- | RPC parameters that can be chosen on a per-call basis +data CallParams = CallParams { + -- | Timeout + -- + -- If Timeout is omitted a server should assume an infinite timeout. + -- Client implementations are free to send a default minimum timeout based + -- on their deployment requirements. + callTimeout :: Maybe Timeout + + -- | Custom metadata + -- + -- This is the metadata included in the request. (The server can include + -- its own metadata in the response: see 'responseMetadata' and + -- 'trailerMetadata'.) 
+ , callRequestMetadata :: [CustomMetadata] + } + deriving stock (Show, Eq) + +-- | Default 'CallParams' +instance Default CallParams where + def = CallParams { + callTimeout = Nothing + , callRequestMetadata = [] + } + diff --git a/src/Network/GRPC/Spec/Common.hs b/src/Network/GRPC/Spec/Common.hs index d124086f..03363861 100644 --- a/src/Network/GRPC/Spec/Common.hs +++ b/src/Network/GRPC/Spec/Common.hs @@ -29,8 +29,7 @@ import Data.List.NonEmpty (NonEmpty(..)) import Data.Proxy import Network.HTTP.Types qualified as HTTP -import Network.GRPC.Spec.Compression (CompressionId) -import Network.GRPC.Spec.Compression qualified as Compr +import Network.GRPC.Spec.Compression import Network.GRPC.Spec.RPC import Network.GRPC.Util.ByteString import Network.GRPC.Util.Partial @@ -67,7 +66,7 @@ parseContentType proxy hdr = buildMessageEncoding :: CompressionId -> HTTP.Header buildMessageEncoding compr = ( "grpc-encoding" - , Compr.serializeId compr + , serializeCompressionId compr ) parseMessageEncoding :: @@ -75,7 +74,7 @@ parseMessageEncoding :: => HTTP.Header -> m CompressionId parseMessageEncoding (_name, value) = - return $ Compr.deserializeId value + return $ deserializeCompressionId value {------------------------------------------------------------------------------- > Message-Accept-Encoding → @@ -85,7 +84,7 @@ parseMessageEncoding (_name, value) = buildMessageAcceptEncoding :: NonEmpty CompressionId -> HTTP.Header buildMessageAcceptEncoding compr = ( "grpc-accept-encoding" - , mconcat . intersperse "," . map Compr.serializeId $ toList compr + , mconcat . intersperse "," . map serializeCompressionId $ toList compr ) parseMessageAcceptEncoding :: @@ -94,6 +93,6 @@ parseMessageAcceptEncoding :: -> m (NonEmpty CompressionId) parseMessageAcceptEncoding (_name, value) = expectAtLeastOne - . map (Compr.deserializeId . strip) + . map (deserializeCompressionId . strip) . 
BS.Strict.splitWith (== ascii ',') $ value \ No newline at end of file diff --git a/src/Network/GRPC/Spec/Compression.hs b/src/Network/GRPC/Spec/Compression.hs index 2c19895e..a67b156a 100644 --- a/src/Network/GRPC/Spec/Compression.hs +++ b/src/Network/GRPC/Spec/Compression.hs @@ -2,21 +2,18 @@ -- | Compression -- --- Intended for qualified import. --- --- > import Network.GRPC.Spec.Compression (Compression) --- > import Network.GRPC.Spec.Compression qualified as Compr +-- Intended for unqualified import. module Network.GRPC.Spec.Compression ( -- * Definition Compression(..) - , allSupported - , isIdentity + , allSupportedCompression + , compressionIsIdentity -- ** ID , CompressionId(..) - , serializeId - , deserializeId + , serializeCompressionId + , deserializeCompressionId -- * Specific coders - , identity + , noCompression , gzip ) where @@ -50,8 +47,8 @@ instance Show Compression where -- -- The order of this list is important: algorithms listed earlier are preferred -- over algorithms listed later. 
-allSupported :: NonEmpty Compression -allSupported = gzip :| [identity] +allSupportedCompression :: NonEmpty Compression +allSupportedCompression = gzip :| [noCompression] {------------------------------------------------------------------------------- Compression ID @@ -70,35 +67,35 @@ data CompressionId = | Custom Strict.ByteString deriving (Eq, Ord) -serializeId :: CompressionId -> Strict.ByteString -serializeId Identity = "identity" -serializeId GZip = "gzip" -serializeId Deflate = "deflate" -serializeId Snappy = "snappy" -serializeId (Custom i) = i +serializeCompressionId :: CompressionId -> Strict.ByteString +serializeCompressionId Identity = "identity" +serializeCompressionId GZip = "gzip" +serializeCompressionId Deflate = "deflate" +serializeCompressionId Snappy = "snappy" +serializeCompressionId (Custom i) = i -deserializeId :: Strict.ByteString -> CompressionId -deserializeId "identity" = Identity -deserializeId "gzip" = GZip -deserializeId "deflate" = Deflate -deserializeId "snappy" = Snappy -deserializeId i = Custom i +deserializeCompressionId :: Strict.ByteString -> CompressionId +deserializeCompressionId "identity" = Identity +deserializeCompressionId "gzip" = GZip +deserializeCompressionId "deflate" = Deflate +deserializeCompressionId "snappy" = Snappy +deserializeCompressionId i = Custom i instance Show CompressionId where - show = BS.UTF8.toString . serializeId + show = BS.UTF8.toString . serializeCompressionId instance IsString CompressionId where - fromString = deserializeId . BS.UTF8.fromString + fromString = deserializeCompressionId . BS.UTF8.fromString -isIdentity :: Compression -> Bool -isIdentity = (== Identity) . compressionId +compressionIsIdentity :: Compression -> Bool +compressionIsIdentity = (== Identity) . 
compressionId {------------------------------------------------------------------------------- Compression algorithms -------------------------------------------------------------------------------} -identity :: Compression -identity = Compression { +noCompression :: Compression +noCompression = Compression { compressionId = Identity , compress = id , decompress = id diff --git a/src/Network/GRPC/Spec/CustomMetadata.hs b/src/Network/GRPC/Spec/CustomMetadata.hs index ffc27db2..8ac4e561 100644 --- a/src/Network/GRPC/Spec/CustomMetadata.hs +++ b/src/Network/GRPC/Spec/CustomMetadata.hs @@ -8,14 +8,17 @@ module Network.GRPC.Spec.CustomMetadata ( -- * Definition CustomMetadata(..) + , customHeaderName -- * Header-Name , HeaderName(HeaderName) , getHeaderName , safeHeaderName - -- * ASCII-Value + -- * ASCII value , AsciiValue(AsciiValue) , getAsciiValue , safeAsciiValue + -- * Binary value + , BinaryValue(..) -- * To and from HTTP headers , buildCustomMetadata , parseCustomMetadata @@ -28,10 +31,13 @@ import Data.ByteString.Base64 qualified as BS.Strict.B64 import Data.CaseInsensitive qualified as CI import Data.String import Data.Word +import GHC.Generics qualified as GHC import GHC.Show import Network.HTTP.Types qualified as HTTP +import Text.Show.Pretty import Network.GRPC.Util.ByteString (strip, ascii, dropEnd) +import Network.GRPC.Util.PrettyVal {------------------------------------------------------------------------------- Definition @@ -58,7 +64,7 @@ data CustomMetadata = -- decoding as headers are sent and received). -- -- Since these considered binary data, padding considerations do not apply. - BinaryHeader HeaderName Strict.ByteString + BinaryHeader HeaderName BinaryValue -- | ASCII header -- @@ -71,7 +77,12 @@ data CustomMetadata = -- as "space and horizontal tab" -- . 
| AsciiHeader HeaderName AsciiValue - deriving stock (Show, Eq) + deriving stock (Show, Eq, Ord, GHC.Generic) + deriving anyclass (PrettyVal) + +customHeaderName :: CustomMetadata -> HeaderName +customHeaderName (BinaryHeader n _) = n +customHeaderName (AsciiHeader n _) = n {------------------------------------------------------------------------------- Header-Name @@ -90,6 +101,8 @@ newtype HeaderName = UnsafeHeaderName { getHeaderName :: Strict.ByteString } deriving stock (Eq, Ord) + deriving newtype (IsString) + deriving (PrettyVal) via StrictByteString_IsString HeaderName -- | 'Show' instance relies on the 'HeaderName' pattern synonym instance Show HeaderName where @@ -97,9 +110,6 @@ instance Show HeaderName where showString "HeaderName " . showsPrec appPrec1 name -instance IsString HeaderName where - fromString = HeaderName . fromString - pattern HeaderName :: Strict.ByteString -> HeaderName pattern HeaderName n <- UnsafeHeaderName n where @@ -146,6 +156,7 @@ newtype AsciiValue = UnsafeAsciiValue { getAsciiValue :: Strict.ByteString } deriving stock (Eq, Ord) + deriving PrettyVal via StrictByteString_IsString AsciiValue -- | 'Show' instance relies on the 'AsciiValue' pattern synonym instance Show AsciiValue where @@ -177,6 +188,16 @@ isValidAsciiValue bs = and [ , BS.Strict.all (\c -> 0x20 <= c && c <= 0x7E) bs ] +{------------------------------------------------------------------------------- + Binary value +-------------------------------------------------------------------------------} + +newtype BinaryValue = BinaryValue { + getBinaryValue :: Strict.ByteString + } + deriving stock (Show, Eq, Ord) + deriving PrettyVal via StrictByteString_Binary "BinaryValue" BinaryValue + {------------------------------------------------------------------------------- To/from HTTP2 -------------------------------------------------------------------------------} @@ -184,31 +205,37 @@ isValidAsciiValue bs = and [ buildCustomMetadata :: CustomMetadata -> HTTP.Header 
buildCustomMetadata (BinaryHeader name value) = ( CI.mk $ getHeaderName name <> "-bin" - , BS.Strict.B64.encode value + , BS.Strict.B64.encode (getBinaryValue value) ) buildCustomMetadata (AsciiHeader name value) = ( CI.mk $ getHeaderName name , getAsciiValue value ) -parseCustomMetadata :: MonadError String m => HTTP.Header-> m CustomMetadata +parseCustomMetadata :: MonadError String m => HTTP.Header -> m CustomMetadata parseCustomMetadata (name, value) | "grpc-" `BS.Strict.isPrefixOf` CI.foldedCase name = throwError $ "Reserved header: " ++ show (name, value) | "-bin" `BS.Strict.isSuffixOf` CI.foldedCase name - = case safeHeaderName (dropEnd 4 $ CI.foldedCase name) of - Just name' -> - return $ BinaryHeader name' value - _otherwise -> - throwError $ "Invalid custom binary header: " ++ show (name, value) + = case ( safeHeaderName (dropEnd 4 $ CI.foldedCase name) + , BS.Strict.B64.decode value + ) of + (Nothing, _) -> + throwError $ "Invalid header name: " ++ show (name, value) + (_, Left err) -> + throwError $ "Cannot decode binary header: " ++ err + (Just name', Right value') -> + return $ BinaryHeader name' (BinaryValue value') | otherwise = case ( safeHeaderName (CI.foldedCase name) , safeAsciiValue value ) of + (Nothing, _) -> + throwError $ "Invalid header name: " ++ show name + (_, Nothing) -> + throwError $ "Invalid ASCII header value: " ++ show value (Just name', Just value') -> return $ AsciiHeader name' value' - _otherwise -> - throwError $ "Invalid custom ASCII header: " ++ show (name, value) diff --git a/src/Network/GRPC/Spec/LengthPrefixed.hs b/src/Network/GRPC/Spec/LengthPrefixed.hs index c58fa95d..565c6906 100644 --- a/src/Network/GRPC/Spec/LengthPrefixed.hs +++ b/src/Network/GRPC/Spec/LengthPrefixed.hs @@ -1,11 +1,6 @@ -- | Length-prefixed messages -- -- These are used both for inputs and outputs. --- --- Intended for qualified import. 
--- --- > import Network.GRPC.Spec.LengthPrefixed (MessagePrefix) --- > import Network.GRPC.Spec.LengthPrefixed qualified as LP module Network.GRPC.Spec.LengthPrefixed ( -- * Message prefix MessagePrefix(..) @@ -23,13 +18,12 @@ import Data.ByteString.Builder (Builder) import Data.ByteString.Builder qualified as Builder import Data.ByteString.Lazy qualified as BS.Lazy import Data.ByteString.Lazy qualified as Lazy (ByteString) +import Data.Proxy import Data.Word -import Network.GRPC.Spec.Compression (Compression) -import Network.GRPC.Spec.Compression qualified as Compr +import Network.GRPC.Spec.Compression import Network.GRPC.Spec.RPC import Network.GRPC.Util.Parser (Parser(..)) -import Data.Proxy {------------------------------------------------------------------------------- Message prefix @@ -88,11 +82,11 @@ buildMsg build compr x = mconcat [ ] where compressed :: Lazy.ByteString - compressed = Compr.compress compr $ build x + compressed = compress compr $ build x prefix :: MessagePrefix prefix = MessagePrefix { - msgIsCompressed = not $ Compr.isIdentity compr + msgIsCompressed = not $ compressionIsIdentity compr , msgLength = fromIntegral $ BS.Lazy.length compressed } @@ -137,7 +131,7 @@ parseMsg parse compr = | otherwise = let (msg, rest) = BS.Lazy.splitAt (fromIntegral $ msgLength prefix) acc serialized = if msgIsCompressed prefix - then Compr.decompress compr msg + then decompress compr msg else msg in case parse serialized of Left err -> ParserError err diff --git a/src/Network/GRPC/Spec/RPC/Binary.hs b/src/Network/GRPC/Spec/RPC/Binary.hs index 99b2b236..3f88be8e 100644 --- a/src/Network/GRPC/Spec/RPC/Binary.hs +++ b/src/Network/GRPC/Spec/RPC/Binary.hs @@ -7,8 +7,8 @@ import Data.Proxy import Data.Text qualified as Text import GHC.TypeLits -import Network.GRPC.Common.StreamType import Network.GRPC.Spec.RPC +import Network.GRPC.Spec.RPC.StreamType {------------------------------------------------------------------------------- Binary format diff --git 
a/src/Network/GRPC/Spec/RPC/Protobuf.hs b/src/Network/GRPC/Spec/RPC/Protobuf.hs index 204a4f9c..869d6185 100644 --- a/src/Network/GRPC/Spec/RPC/Protobuf.hs +++ b/src/Network/GRPC/Spec/RPC/Protobuf.hs @@ -15,8 +15,8 @@ import Data.Text qualified as Text import Data.Typeable import GHC.TypeLits -import Network.GRPC.Common.StreamType import Network.GRPC.Spec.RPC +import Network.GRPC.Spec.RPC.StreamType {------------------------------------------------------------------------------- The spec defines the following in Appendix A, "GRPC for Protobuf": diff --git a/src/Network/GRPC/Spec/RPC/StreamType.hs b/src/Network/GRPC/Spec/RPC/StreamType.hs new file mode 100644 index 00000000..d0fc1e2b --- /dev/null +++ b/src/Network/GRPC/Spec/RPC/StreamType.hs @@ -0,0 +1,216 @@ +-- | Streaming types +-- +-- The +-- [gRPC specification](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md) +-- does not distinguish between different kinds of communication patterns; +-- this is done in the +-- [Protobuf specification](https://protobuf.dev/reference/protobuf/proto3-spec/#service_definition). +-- Nonetheless, these streaming types can be applied to other serialization +-- formats also. +module Network.GRPC.Spec.RPC.StreamType ( + -- * Communication patterns + StreamingType(..) + , SupportsStreamingType + , HasStreamingType(..) + -- ** Handlers + , NonStreamingHandler(..) + , ClientStreamingHandler(..) + , ServerStreamingHandler(..) + , BiDiStreamingHandler(..) 
+ , HandlerFor + -- ** Execution + , nonStreaming + , clientStreaming + , serverStreaming + , biDiStreaming + -- ** Constructors + , mkNonStreaming + , mkClientStreaming + , mkServerStreaming + , mkBiDiStreaming + ) where + +import Data.Kind + +-- Borrow protolens 'StreamingType' (but this module is not Protobuf specific) +import Data.ProtoLens.Service.Types (StreamingType(..)) + +import Network.GRPC.Common.StreamElem +import Network.GRPC.Spec.Request +import Network.GRPC.Spec.RPC +import Network.GRPC.Util.RedundantConstraint + +{------------------------------------------------------------------------------- + Communication patterns +-------------------------------------------------------------------------------} + +-- | This RPC supports the given streaming type +-- +-- This is a weaker condition than 'HasStreamingType', which maps each RPC to +-- a /specific/ streaming type. Some (non-Protobuf) RPCs however may support +-- more than one streaming type. +class SupportsStreamingType rpc (styp :: StreamingType) + +class SupportsStreamingType rpc (RpcStreamingType rpc) + => HasStreamingType rpc where + -- | Streaming type supported by this RPC + type RpcStreamingType rpc :: StreamingType + +{------------------------------------------------------------------------------- + Handler types + + Users don't typically work with these types directly; however, they serve + as a specification of the commonality between the client and the server. 
+-------------------------------------------------------------------------------} + +newtype NonStreamingHandler m rpc = UnsafeNonStreamingHandler ( + Input rpc + -> m (Output rpc) + ) + +newtype ClientStreamingHandler m rpc = UnsafeClientStreamingHandler ( + m (StreamElem NoMetadata (Input rpc)) + -> m (Output rpc) + ) + +newtype ServerStreamingHandler m rpc = UnsafeServerStreamingHandler ( + Input rpc + -> (Output rpc -> m ()) + -> m () + ) + +newtype BiDiStreamingHandler m rpc = UnsafeBiDiStreamingHandler ( + m (StreamElem NoMetadata (Input rpc)) + -> (Output rpc -> m ()) + -> m () + ) + +-- | Match 'StreamingType' to handler type +-- +-- This is occasionally useful to improve type inference. Users are not +-- expected to need to use this in their own code. +type family HandlerFor (typ :: StreamingType) :: (Type -> Type) -> Type -> Type where + HandlerFor NonStreaming = NonStreamingHandler + HandlerFor ClientStreaming = ClientStreamingHandler + HandlerFor ServerStreaming = ServerStreamingHandler + HandlerFor BiDiStreaming = BiDiStreamingHandler + +{------------------------------------------------------------------------------- + Destructors + + These are just the record field accessors, but with an additional constraint + on them that ensures that the right type of handler is used with the right + type of RPC. +-------------------------------------------------------------------------------} + +-- | Make a non-streaming RPC +-- +-- Example usage: +-- +-- > features <- +-- >   nonStreaming (rpc @(Protobuf RouteGuide "getFeature") conn) point +-- > logMsg features +nonStreaming :: forall rpc m. 
+ SupportsStreamingType rpc NonStreaming + => NonStreamingHandler m rpc + -> Input rpc + -> m (Output rpc) +nonStreaming (UnsafeNonStreamingHandler h) = h + where + _ = addConstraint @(SupportsStreamingType rpc NonStreaming) + +-- | Make a client-side streaming RPC +-- +-- Example usage: +-- +-- > summary <- +-- > clientStreaming (rpc @(Protobuf RouteGuide "recordRoute") conn) getPoint +clientStreaming :: forall rpc m. + SupportsStreamingType rpc ClientStreaming + => ClientStreamingHandler m rpc + -> m (StreamElem NoMetadata (Input rpc)) + -> m (Output rpc) +clientStreaming (UnsafeClientStreamingHandler h) = h + where + _ = addConstraint @(SupportsStreamingType rpc ClientStreaming) + +-- | Make a server-side streaming RPC +-- +-- Example usage: +-- +-- > serverStreaming (rpc @(Protobuf RouteGuide "listFeatures") conn) rect $ +-- > logMsg +serverStreaming :: forall rpc m. + SupportsStreamingType rpc ServerStreaming + => ServerStreamingHandler m rpc + -> Input rpc + -> (Output rpc -> m ()) + -> m () +serverStreaming (UnsafeServerStreamingHandler h) = h + where + _ = addConstraint @(SupportsStreamingType rpc ServerStreaming) + +-- | Make a bidirectional RPC +-- +-- Example usage: +-- +-- > biDiStreaming (rpc @(Protobuf RouteGuide "routeChat") conn) getNote $ +-- > logMsg +biDiStreaming :: forall rpc m. + SupportsStreamingType rpc BiDiStreaming + => BiDiStreamingHandler m rpc + -> m (StreamElem NoMetadata (Input rpc)) + -> (Output rpc -> m ()) + -> m () +biDiStreaming (UnsafeBiDiStreamingHandler h) = h + where + _ = addConstraint @(SupportsStreamingType rpc BiDiStreaming) + +{------------------------------------------------------------------------------- + Constructors + + These are just the newtype constructors, but with an additional constraint, + similar to the destructors above. +-------------------------------------------------------------------------------} + +mkNonStreaming :: forall m rpc. 
+ SupportsStreamingType rpc NonStreaming + => ( Input rpc + -> m (Output rpc) + ) + -> NonStreamingHandler m rpc +mkNonStreaming = UnsafeNonStreamingHandler + where + _ = addConstraint @(SupportsStreamingType rpc NonStreaming) + +mkClientStreaming :: forall m rpc. + SupportsStreamingType rpc ClientStreaming + => ( m (StreamElem NoMetadata (Input rpc)) + -> m (Output rpc) + ) + -> ClientStreamingHandler m rpc +mkClientStreaming = UnsafeClientStreamingHandler + where + _ = addConstraint @(SupportsStreamingType rpc ClientStreaming) + +mkServerStreaming :: forall m rpc. + SupportsStreamingType rpc ServerStreaming + => ( Input rpc + -> (Output rpc -> m ()) + -> m () + ) + -> ServerStreamingHandler m rpc +mkServerStreaming = UnsafeServerStreamingHandler + where + _ = addConstraint @(SupportsStreamingType rpc ServerStreaming) + +mkBiDiStreaming :: forall m rpc. + SupportsStreamingType rpc BiDiStreaming + => ( m (StreamElem NoMetadata (Input rpc)) + -> (Output rpc -> m ()) + -> m () + ) + -> BiDiStreamingHandler m rpc +mkBiDiStreaming = UnsafeBiDiStreamingHandler + where + _ = addConstraint @(SupportsStreamingType rpc BiDiStreaming) diff --git a/src/Network/GRPC/Spec/Request.hs b/src/Network/GRPC/Spec/Request.hs index 3d8a317a..fdf64cff 100644 --- a/src/Network/GRPC/Spec/Request.hs +++ b/src/Network/GRPC/Spec/Request.hs @@ -6,8 +6,13 @@ -- -- > import Network.GRPC.Spec.Request qualified as Req module Network.GRPC.Spec.Request ( - buildHeaders - , parseHeaders + -- * Inputs (message sent to the peer) + RequestHeaders(..) + , IsFinal(..) + , NoMetadata(..) 
+ -- * Serialization + , buildRequestHeaders + , parseRequestHeaders ) where import Data.ByteString.Char8 qualified as BS.Strict.C8 @@ -16,18 +21,58 @@ import Data.List.NonEmpty (NonEmpty) import Data.Maybe (catMaybes) import Data.SOP import Data.Version +import Generics.SOP qualified as SOP +import GHC.Generics qualified as GHC import Network.HTTP.Types qualified as HTTP +import Text.Show.Pretty -import Network.GRPC.Spec import Network.GRPC.Spec.Common import Network.GRPC.Spec.Compression (CompressionId) import Network.GRPC.Spec.CustomMetadata import Network.GRPC.Spec.PercentEncoding qualified as PercentEncoding import Network.GRPC.Spec.RPC +import Network.GRPC.Spec.Timeout import Network.GRPC.Util.Partial import Paths_grapesy qualified as Grapesy +{------------------------------------------------------------------------------- + Inputs (message sent to the peer) +-------------------------------------------------------------------------------} + +-- | Full set of call parameters required to construct the RPC call +-- +-- This is constructed internally; it is not part of the public API. +data RequestHeaders = RequestHeaders { + -- | Timeout + requestTimeout :: Maybe Timeout + + -- | Custom metadata + , requestMetadata :: [CustomMetadata] + + -- | Compression used for outgoing messages + , requestCompression :: Maybe CompressionId + + -- | Accepted compression algorithms for incoming messages + -- + -- @Maybe (NonEmpty ..)@ is perhaps a bit strange (why not just @[]@), but + -- it emphasizes the specification: /if/ the header is present, it must be + -- a non-empty list. 
+ , requestAcceptCompression :: Maybe (NonEmpty CompressionId) + } + deriving stock (Show, Eq) + deriving stock (GHC.Generic) + deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) + +-- | Mark an input sent as final +data IsFinal = Final | NotFinal + deriving stock (Show, Eq) + +-- | gRPC does not support request trailers (only response trailers) +data NoMetadata = NoMetadata + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (PrettyVal) + {------------------------------------------------------------------------------- Construction -------------------------------------------------------------------------------} @@ -37,8 +82,8 @@ import Paths_grapesy qualified as Grapesy -- > Request-Headers → -- > Call-Definition -- > *Custom-Metadata -buildHeaders :: IsRPC rpc => Proxy rpc -> RequestHeaders -> [HTTP.Header] -buildHeaders proxy callParams@RequestHeaders{requestMetadata} = concat [ +buildRequestHeaders :: IsRPC rpc => Proxy rpc -> RequestHeaders -> [HTTP.Header] +buildRequestHeaders proxy callParams@RequestHeaders{requestMetadata} = concat [ callDefinition proxy callParams , map buildCustomMetadata requestMetadata ] @@ -150,11 +195,11 @@ callDefinition proxy = \hdrs -> catMaybes [ -------------------------------------------------------------------------------} -- | Parse request headers -parseHeaders :: +parseRequestHeaders :: IsRPC rpc => Proxy rpc -> [HTTP.Header] -> Either String RequestHeaders -parseHeaders proxy = +parseRequestHeaders proxy = runPartialParser uninitRequestHeaders . mapM_ parseHeader where diff --git a/src/Network/GRPC/Spec/Response.hs b/src/Network/GRPC/Spec/Response.hs index c966ae45..01bb74dc 100644 --- a/src/Network/GRPC/Spec/Response.hs +++ b/src/Network/GRPC/Spec/Response.hs @@ -6,31 +6,117 @@ -- -- > import Network.GRPC.Spec.Response qualified as Resp module Network.GRPC.Spec.Response ( - -- * Construction - buildHeaders + -- * Headers + ResponseHeaders(..) + , ProperTrailers(..) + , TrailersOnly(..) 
+ -- * Distinguish between 'GrpcOk' and 'GrpcError' + , GrpcException(..) + , grpcExceptionFromTrailers + , grpcExceptionToTrailers + -- * Serialization + -- ** Construction + , buildResponseHeaders , buildProperTrailers , buildTrailersOnly - -- * Parsing - , parseHeaders + -- ** Parsing + , parseResponseHeaders , parseProperTrailers , parseTrailersOnly ) where +import Control.Exception import Control.Monad.Except import Data.ByteString.Char8 qualified as BS.Strict.C8 import Data.List.NonEmpty (NonEmpty) import Data.SOP +import Data.Text (Text) +import Generics.SOP qualified as SOP +import GHC.Generics qualified as GHC import Network.HTTP.Types qualified as HTTP import Text.Read (readMaybe) +import Text.Show.Pretty -import Network.GRPC.Spec import Network.GRPC.Spec.Common import Network.GRPC.Spec.Compression (CompressionId) import Network.GRPC.Spec.CustomMetadata import Network.GRPC.Spec.PercentEncoding qualified as PercentEncoding import Network.GRPC.Spec.RPC +import Network.GRPC.Spec.Status import Network.GRPC.Util.Partial +{------------------------------------------------------------------------------- + Outputs (messages received from the peer) +-------------------------------------------------------------------------------} + +-- | Response headers +data ResponseHeaders = ResponseHeaders { + responseCompression :: Maybe CompressionId + , responseAcceptCompression :: Maybe (NonEmpty CompressionId) + , responseMetadata :: [CustomMetadata] + } + deriving stock (Show, Eq) + deriving stock (GHC.Generic) + deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) + +-- | Information sent by the peer after the final output +-- +-- Response trailers are a +-- [HTTP2 concept](https://datatracker.ietf.org/doc/html/rfc7540#section-8.1.3): +-- they are HTTP headers that are sent /after/ the content body. For example, +-- imagine the server is streaming a file that it's reading from disk; it could +-- use trailers to give the client an MD5 checksum when streaming is complete. 
+data ProperTrailers = ProperTrailers { + trailerGrpcStatus :: GrpcStatus + , trailerGrpcMessage :: Maybe Text + , trailerMetadata :: [CustomMetadata] + } + deriving stock (Show, Eq) + deriving stock (GHC.Generic) + deriving anyclass (SOP.Generic, SOP.HasDatatypeInfo) + +-- | Trailers sent in the gRPC Trailers-Only case +-- +-- In the current version of the spec, the information in 'TrailersOnly' is +-- identical to the 'ProperTrailers' case (but they do require a slightly +-- different function to parse/unparse). +newtype TrailersOnly = TrailersOnly { + getTrailersOnly :: ProperTrailers + } + deriving stock (Show, Eq) + +{------------------------------------------------------------------------------- + Distinguish between 'GrpcOk' and 'GrpcError' +-------------------------------------------------------------------------------} + +-- | Server indicated a gRPC error +data GrpcException = GrpcException { + grpcError :: GrpcError + , grpcErrorMessage :: Maybe Text + , grpcErrorMetadata :: [CustomMetadata] + } + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (Exception, PrettyVal) + +grpcExceptionFromTrailers :: + ProperTrailers + -> Either (Maybe Text, [CustomMetadata]) GrpcException +grpcExceptionFromTrailers (ProperTrailers status msg metadata) = + case status of + GrpcOk -> Left (msg, metadata) + GrpcError err -> Right GrpcException{ + grpcError = err + , grpcErrorMessage = msg + , grpcErrorMetadata = metadata + } + +grpcExceptionToTrailers :: GrpcException -> ProperTrailers +grpcExceptionToTrailers err = ProperTrailers{ + trailerGrpcStatus = GrpcError (grpcError err) + , trailerGrpcMessage = grpcErrorMessage err + , trailerMetadata = grpcErrorMetadata err + } + {------------------------------------------------------------------------------- > Response-Headers → > HTTP-Status @@ -52,8 +138,8 @@ import Network.GRPC.Util.Partial -------------------------------------------------------------------------------} -- | Build response headers -buildHeaders :: 
IsRPC rpc => Proxy rpc -> ResponseHeaders -> [HTTP.Header] -buildHeaders proxy +buildResponseHeaders :: IsRPC rpc => Proxy rpc -> ResponseHeaders -> [HTTP.Header] +buildResponseHeaders proxy ResponseHeaders{ responseCompression , responseAcceptCompression , responseMetadata @@ -71,10 +157,10 @@ buildHeaders proxy ] -- | Parse response headers -parseHeaders :: +parseResponseHeaders :: IsRPC rpc => Proxy rpc -> [HTTP.Header] -> Either String ResponseHeaders -parseHeaders proxy = +parseResponseHeaders proxy = runPartialParser uninitResponseHeaders . mapM_ parseHeader where diff --git a/src/Network/GRPC/Spec/Status.hs b/src/Network/GRPC/Spec/Status.hs new file mode 100644 index 00000000..4f492aa6 --- /dev/null +++ b/src/Network/GRPC/Spec/Status.hs @@ -0,0 +1,212 @@ +module Network.GRPC.Spec.Status ( + -- * GRPC status + GrpcStatus(..) + , GrpcError(..) + , fromGrpcStatus + , toGrpcStatus + ) where + +import Control.Exception +import GHC.Generics qualified as GHC +import Text.Show.Pretty + +{------------------------------------------------------------------------------- + gRPC status +-------------------------------------------------------------------------------} + +-- | gRPC status +-- +-- Defined in . +data GrpcStatus = + GrpcOk + | GrpcError GrpcError + deriving stock (Show, Eq) + +-- | gRPC error code +-- +-- This is a subset of the gRPC status codes. See 'GrpcStatus'. +data GrpcError = + -- | Cancelled + -- + -- The operation was cancelled, typically by the caller. + GrpcCancelled + + -- | Unknown error + -- + -- For example, this error may be returned when a @Status@ value received + -- from another address space belongs to an error space that is not known in + -- this address space. Also errors raised by APIs that do not return enough + -- error information may be converted to this error. + | GrpcUnknown + + -- | Invalid argument + -- + -- The client specified an invalid argument. 
Note that this differs from
+    -- 'GrpcFailedPrecondition': 'GrpcInvalidArgument' indicates arguments that
+    -- are problematic regardless of the state of the system (e.g., a malformed
+    -- file name).
+  | GrpcInvalidArgument
+
+    -- | Deadline exceeded
+    --
+    -- The deadline expired before the operation could complete. For operations
+    -- that change the state of the system, this error may be returned even if
+    -- the operation has completed successfully. For example, a successful
+    -- response from a server could have been delayed long enough for the
+    -- deadline to expire.
+  | GrpcDeadlineExceeded
+
+    -- | Not found
+    --
+    -- Some requested entity (e.g., file or directory) was not found.
+    --
+    -- Note to server developers: if a request is denied for an entire class of
+    -- users, such as gradual feature rollout or undocumented allowlist,
+    -- 'GrpcNotFound' may be used.
+    --
+    -- If a request is denied for some users within a class of users, such as
+    -- user-based access control, 'GrpcPermissionDenied' must be used.
+  | GrpcNotFound
+
+    -- | Already exists
+    --
+    -- The entity that a client attempted to create (e.g., file or directory)
+    -- already exists.
+  | GrpcAlreadyExists
+
+    -- | Permission denied
+    --
+    -- The caller does not have permission to execute the specified operation.
+    --
+    -- * 'GrpcPermissionDenied' must not be used for rejections caused by
+    --   exhausting some resource (use 'GrpcResourceExhausted' instead for those
+    --   errors).
+    -- * 'GrpcPermissionDenied' must not be used if the caller can not be
+    --   identified (use 'GrpcUnauthenticated' instead for those errors).
+    --
+    -- This error code does not imply the request is valid or the requested
+    -- entity exists or satisfies other pre-conditions.
+  | GrpcPermissionDenied
+
+    -- | Resource exhausted
+    --
+    -- Some resource has been exhausted, perhaps a per-user quota, or perhaps
+    -- the entire file system is out of space.
+  | GrpcResourceExhausted
+
+    -- | Failed precondition
+    --
+    -- The operation was rejected because the system is not in a state required
+    -- for the operation's execution. For example, the directory to be deleted
+    -- is non-empty, an rmdir operation is applied to a non-directory, etc.
+    --
+    -- Service implementors can use the following guidelines to decide between
+    -- 'GrpcFailedPrecondition', 'GrpcAborted', and 'GrpcUnavailable':
+    --
+    -- (a) Use 'GrpcUnavailable' if the client can retry just the failing call.
+    -- (b) Use 'GrpcAborted' if the client should retry at a higher level (e.g.,
+    --     when a client-specified test-and-set fails, indicating the client
+    --     should restart a read-modify-write sequence).
+    -- (c) Use 'GrpcFailedPrecondition' if the client should not retry until the
+    --     system state has been explicitly fixed. E.g., if an @rmdir@ fails
+    --     because the directory is non-empty, 'GrpcFailedPrecondition' should
+    --     be returned since the client should not retry unless the files are
+    --     deleted from the directory.
+  | GrpcFailedPrecondition
+
+    -- | Aborted
+    --
+    -- The operation was aborted, typically due to a concurrency issue such as a
+    -- sequencer check failure or transaction abort. See the guidelines above
+    -- for deciding between 'GrpcFailedPrecondition', 'GrpcAborted', and
+    -- 'GrpcUnavailable'.
+  | GrpcAborted
+
+    -- | Out of range
+    --
+    -- The operation was attempted past the valid range. E.g., seeking or
+    -- reading past end-of-file.
+    --
+    -- Unlike 'GrpcInvalidArgument', this error indicates a problem that may be
+    -- fixed if the system state changes. For example, a 32-bit file system will
+    -- generate 'GrpcInvalidArgument' if asked to read at an offset that is not
+    -- in the range @[0, 2^32-1]@, but it will generate 'GrpcOutOfRange' if
+    -- asked to read from an offset past the current file size.
+    --
+    -- There is a fair bit of overlap between 'GrpcFailedPrecondition' and
+    -- 'GrpcOutOfRange'.
We recommend using 'GrpcOutOfRange' (the more specific + -- error) when it applies so that callers who are iterating through a space + -- can easily look for an 'GrpcOutOfRange' error to detect when they are + -- done. + | GrpcOutOfRange + + -- | Unimplemented + -- + -- The operation is not implemented or is not supported/enabled in this + -- service. + | GrpcUnimplemented + + -- | Internal errors + -- + -- This means that some invariants expected by the underlying system have + -- been broken. This error code is reserved for serious errors. + | GrpcInternal + + -- | Unavailable + -- + -- The service is currently unavailable. This is most likely a transient + -- condition, which can be corrected by retrying with a backoff. Note that + -- it is not always safe to retry non-idempotent operations. + | GrpcUnavailable + + -- | Data loss + -- + -- Unrecoverable data loss or corruption. + | GrpcDataLoss + + -- | Unauthenticated + -- + -- The request does not have valid authentication credentials for the + -- operation. 
+ | GrpcUnauthenticated + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (Exception, PrettyVal) + +fromGrpcStatus :: GrpcStatus -> Word +fromGrpcStatus GrpcOk = 0 +fromGrpcStatus (GrpcError GrpcCancelled) = 1 +fromGrpcStatus (GrpcError GrpcUnknown) = 2 +fromGrpcStatus (GrpcError GrpcInvalidArgument) = 3 +fromGrpcStatus (GrpcError GrpcDeadlineExceeded) = 4 +fromGrpcStatus (GrpcError GrpcNotFound) = 5 +fromGrpcStatus (GrpcError GrpcAlreadyExists) = 6 +fromGrpcStatus (GrpcError GrpcPermissionDenied) = 7 +fromGrpcStatus (GrpcError GrpcResourceExhausted) = 8 +fromGrpcStatus (GrpcError GrpcFailedPrecondition) = 9 +fromGrpcStatus (GrpcError GrpcAborted) = 10 +fromGrpcStatus (GrpcError GrpcOutOfRange) = 11 +fromGrpcStatus (GrpcError GrpcUnimplemented) = 12 +fromGrpcStatus (GrpcError GrpcInternal) = 13 +fromGrpcStatus (GrpcError GrpcUnavailable) = 14 +fromGrpcStatus (GrpcError GrpcDataLoss) = 15 +fromGrpcStatus (GrpcError GrpcUnauthenticated) = 16 + +toGrpcStatus :: Word -> Maybe GrpcStatus +toGrpcStatus 0 = Just $ GrpcOk +toGrpcStatus 1 = Just $ GrpcError $ GrpcCancelled +toGrpcStatus 2 = Just $ GrpcError $ GrpcUnknown +toGrpcStatus 3 = Just $ GrpcError $ GrpcInvalidArgument +toGrpcStatus 4 = Just $ GrpcError $ GrpcDeadlineExceeded +toGrpcStatus 5 = Just $ GrpcError $ GrpcNotFound +toGrpcStatus 6 = Just $ GrpcError $ GrpcAlreadyExists +toGrpcStatus 7 = Just $ GrpcError $ GrpcPermissionDenied +toGrpcStatus 8 = Just $ GrpcError $ GrpcResourceExhausted +toGrpcStatus 9 = Just $ GrpcError $ GrpcFailedPrecondition +toGrpcStatus 10 = Just $ GrpcError $ GrpcAborted +toGrpcStatus 11 = Just $ GrpcError $ GrpcOutOfRange +toGrpcStatus 12 = Just $ GrpcError $ GrpcUnimplemented +toGrpcStatus 13 = Just $ GrpcError $ GrpcInternal +toGrpcStatus 14 = Just $ GrpcError $ GrpcUnavailable +toGrpcStatus 15 = Just $ GrpcError $ GrpcDataLoss +toGrpcStatus 16 = Just $ GrpcError $ GrpcUnauthenticated +toGrpcStatus _ = Nothing + diff --git a/src/Network/GRPC/Spec/Timeout.hs 
b/src/Network/GRPC/Spec/Timeout.hs new file mode 100644 index 00000000..b008ec06 --- /dev/null +++ b/src/Network/GRPC/Spec/Timeout.hs @@ -0,0 +1,76 @@ +module Network.GRPC.Spec.Timeout ( + -- * Timeouts + Timeout(..) + , TimeoutValue(TimeoutValue, getTimeoutValue) + , TimeoutUnit(..) + -- * Translation + , timeoutToMicro + ) where + +import GHC.Show + +{------------------------------------------------------------------------------- + Timeouts +-------------------------------------------------------------------------------} + +data Timeout = Timeout TimeoutUnit TimeoutValue + deriving stock (Show, Eq) + +-- | Positive integer with ASCII representation of at most 8 digits +newtype TimeoutValue = UnsafeTimeoutValue { + getTimeoutValue :: Word + } + deriving newtype (Eq) + +-- | 'Show' instance relies on the 'TimeoutValue' pattern synonym +instance Show TimeoutValue where + showsPrec p (UnsafeTimeoutValue val) = showParen (p >= appPrec1) $ + showString "TimeoutValue " + . showsPrec appPrec1 val + +pattern TimeoutValue :: Word -> TimeoutValue +pattern TimeoutValue t <- UnsafeTimeoutValue t + where + TimeoutValue t + | isValidTimeoutValue t = UnsafeTimeoutValue t + | otherwise = error $ "invalid TimeoutValue: " ++ show t + +{-# COMPLETE TimeoutValue #-} + +isValidTimeoutValue :: Word -> Bool +isValidTimeoutValue t = length (show t) <= 8 + +data TimeoutUnit = + Hour + | Minute + | Second + | Millisecond + | Microsecond + | Nanosecond + deriving stock (Show, Eq) + +{------------------------------------------------------------------------------- + Translation +-------------------------------------------------------------------------------} + +-- | Translate 'Timeout' to microseconds +-- +-- For 'Nanosecond' timeout we round up. 
+timeoutToMicro :: Timeout -> Integer
+timeoutToMicro = \case
+    Timeout Hour        (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000 * 60 * 60
+    Timeout Minute      (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000 * 60
+    Timeout Second      (TimeoutValue n) -> mult n $ 1 * 1_000 * 1_000
+    Timeout Millisecond (TimeoutValue n) -> mult n $ 1 * 1_000
+    Timeout Microsecond (TimeoutValue n) -> mult n $ 1
+    Timeout Nanosecond  (TimeoutValue n) -> nano n
+  where
+    mult :: Word -> Integer -> Integer
+    mult n m = fromIntegral n * m
+
+    nano :: Word -> Integer
+    nano n = fromIntegral $
+        mu + if n' == 0 then 0 else 1
+      where
+        (mu, n') = divMod n 1_000
+
diff --git a/src/Network/GRPC/Util/HTTP2/Stream.hs b/src/Network/GRPC/Util/HTTP2/Stream.hs
index 408957e5..f8e55a13 100644
--- a/src/Network/GRPC/Util/HTTP2/Stream.hs
+++ b/src/Network/GRPC/Util/HTTP2/Stream.hs
@@ -10,10 +10,8 @@ module Network.GRPC.Util.HTTP2.Stream (
   , clientOutputStream
   ) where
 
-import Control.Monad
 import Data.Binary.Builder (Builder)
 import Data.ByteString qualified as Strict
-import Data.IORef
 import Network.HTTP.Types qualified as HTTP
 import Network.HTTP2.Client qualified as Client
 import Network.HTTP2.Server qualified as Server
@@ -30,9 +28,6 @@ data OutputStream = OutputStream {
 
     -- | Flush the stream (send frames to the peer)
   , flush :: IO ()
-
-    -- | Indicate that there will be no more data
-  , closeOutputStream :: IO ()
   }
 
 data InputStream = InputStream {
@@ -65,41 +60,35 @@ serverInputStream req = do
 -- > Most responses are expected to have both headers and trailers but
 -- > Trailers-Only is permitted for calls that produce an immediate error.
 --
--- If we compare this to the example server, however, we see that all headers
--- might delivered in a /single/ frame (i.e., the proper Trailers-Only case). An
--- example is @RouteGuide.listFeatures@: when there /are/ no features in the
--- specified rectangle, the server will send no messages back to the client.
The --- example Python server will use the gRPC Trailers-Only case here (and so we --- must be able to deal with that in our client implementation). --- --- Fortunately, it seems that if we do /not/ make use of the Trailers-Only case --- (our current server implementation) then the example Python client will still --- interpret the results correctly, so we don't have to follow suit. Indeed, --- technically what the Python server does is not conform spec, as mentioned --- above. --- --- If we /do/ want to do what the example Python server does, it's a bit tricky, --- as it leads to two requirements that seem in opposition: --- --- 1. We should wait until we create the request (and its headers) until we know --- for sure some data is written. If no data is written, we are in the --- @Trailers-Only@ case, and we should generate /different/ headers (this --- probably means we should introduce Trailers-Only data types after all; see --- discussion of 'parseTrailers'). --- 2. We should be able to send the response headers /before/ the first message --- (e.g. @wait_for_ready_with_client_timeout_example_client.py@ tests that it --- receives the initial metadata before the first response). +-- If we compare this to the official Python example @RouteGuide@ server, +-- however, we see that the @Trailers-Only@ case is sometimes also used in +-- non-error cases. An example is @RouteGuide.listFeatures@: when there /are/ no +-- features in the specified rectangle, the server will send no messages back to +-- the client. The example Python server will use the gRPC Trailers-Only case +-- here (and so we must be able to deal with that in our client implementation). -- --- Perhaps a way out here is that we can create the response as soon as we have --- a /promise/ of the first message, even if that first message may still need --- time to compute. 
+-- We do provide this functionality, but only through a specific API (see +-- 'sendTrailersOnly'); when that API is used, we do not make use of this +-- 'OutputStream' abstraction (indeed, we do not stream at all). In streaming +-- cases (the default) we do not make use of @Trailers-Only@. serverOutputStream :: (Builder -> IO ()) -> IO () -> IO OutputStream serverOutputStream writeChunk flush = do - return OutputStream { - writeChunk - , flush - , closeOutputStream = return () - } + -- Make sure that http2 does not wait for the first message before sending + -- the response headers. This is important: the client might want the + -- initial response metadata before the first message. + -- + -- This does require some justification; if any of the reasons below is + -- no longer true, we might need to reconsider: + -- + -- * The extra cost of this flush is that we might need an additional TCP + -- packet; no big deal. + -- * We only create the 'OutputStream' once the user actually initiates the + -- response, at which point the headers are fixed. + -- * We do not use an 'OutputStream' at all when we are in the Trailers-Only + -- case (see discussion above). + flush + + return OutputStream {writeChunk, flush} {------------------------------------------------------------------------------- Client API @@ -115,25 +104,17 @@ clientInputStream resp = do clientOutputStream :: (Builder -> IO ()) -> IO () -> IO OutputStream clientOutputStream writeChunk flush = do - wroteSomethingRef <- newIORef False - return OutputStream { - writeChunk = \chunk -> do - atomicModifyIORef wroteSomethingRef $ \_ -> (True, ()) - writeChunk chunk - , flush - - -- The http2 client implementation has an explicit check that means the - -- response is not sent until /something/ has been written - -- ; to - -- workaround this limitation we write an empty chunk. 
This results in - -- an empty data frame in the stream, but this does not matter: gRPC - -- does not support request headers, which means that, unlike in the - -- server, we do not need to give the empty case special treatment. - -- TODO: Nonetheless, it might be better to patch http2 to avoid this - -- workaround. - , closeOutputStream = do - wroteSomething <- readIORef wroteSomethingRef - unless wroteSomething $ writeChunk mempty - } - + -- The http2 client implementation has an explicit check that means the + -- request is not initiated until /something/ has been written + -- ; to workaround + -- this limitation we write an empty chunk. This results in an empty data + -- frame in the stream, but this does not matter: gRPC does not support + -- request headers, which means that, unlike in the server, we do not need + -- to give the empty case special treatment. + -- + -- TODO: A better alternative might be to offer the user an explicit + -- 'initiateRequest' API, just like we offer for responses. This would also + -- improve the "duality" between the server/client API. + writeChunk mempty + return OutputStream {writeChunk, flush} diff --git a/src/Network/GRPC/Util/PrettyVal.hs b/src/Network/GRPC/Util/PrettyVal.hs new file mode 100644 index 00000000..fd2c84fd --- /dev/null +++ b/src/Network/GRPC/Util/PrettyVal.hs @@ -0,0 +1,43 @@ +module Network.GRPC.Util.PrettyVal ( + -- * Deriving-via support + StrictByteString_IsString(..) + , StrictByteString_Binary(..) 
+ ) where + +import Data.ByteString qualified as Strict +import Data.ByteString.Internal qualified as Strict +import Data.Coerce +import Data.Proxy +import Data.String +import GHC.TypeLits +import Text.Show.Pretty + +{------------------------------------------------------------------------------- + Deriving-via support +-------------------------------------------------------------------------------} + +newtype StrictByteString_IsString a = StrictByteString_IsString a + +instance ( Coercible a Strict.ByteString + , IsString a + ) => PrettyVal (StrictByteString_IsString a) where + prettyVal (StrictByteString_IsString x) = + -- The 'ByteString' 'IsString' instance is defined using 'packChars', + -- so here we do the opposite + String . show $ Strict.unpackChars (co x) + where + co :: a -> Strict.ByteString + co = coerce + +newtype StrictByteString_Binary (constr :: Symbol) a = StrictByteString_Binary a + +instance ( Coercible a Strict.ByteString + , KnownSymbol constr + ) => PrettyVal (StrictByteString_Binary constr a) where + prettyVal (StrictByteString_Binary x) = + Con (symbolVal (Proxy @constr)) [ + Con "pack" [prettyVal $ Strict.unpack (co x)] + ] + where + co :: a -> Strict.ByteString + co = coerce diff --git a/src/Network/GRPC/Util/Session.hs b/src/Network/GRPC/Util/Session.hs index 442077d8..a15eed38 100644 --- a/src/Network/GRPC/Util/Session.hs +++ b/src/Network/GRPC/Util/Session.hs @@ -26,10 +26,17 @@ module Network.GRPC.Util.Session ( , PeerException(..) -- * Channel , Channel -- opaque + -- ** Working with an open channel , getInboundHeaders - , recv , send + , recv + , RecvAfterFinal(..) + , SendAfterFinal(..) + -- ** Closing + , waitForOutbound , close + , ChannelUncleanClose(..) + , ChannelClosed(..) -- ** Construction -- *** Client , ConnectionToServer(..) 
diff --git a/src/Network/GRPC/Util/Session/API.hs b/src/Network/GRPC/Util/Session/API.hs index 90ddd144..beeb619b 100644 --- a/src/Network/GRPC/Util/Session/API.hs +++ b/src/Network/GRPC/Util/Session/API.hs @@ -52,31 +52,29 @@ data ResponseInfo = ResponseInfo { -- 3. Trailers -- -- However, in the case that there /are/ no messages, this whole thing collapses --- into a single 'TrailersOnly'. We need to treat this case separately, because +-- and we just have headers (in gRPC this is referred to as the Trailers-Only +-- case, but we avoid that terminology here). -- -- * It looks different on the wire: in the regular case, we will have /two/ --- HTTP @Headers@ frames, but in the Trailers-Only case, we only have one. --- * Applications may in turn treat the Trailers-Only case special, using a --- different set of headers (specifically, this is the case for gRPC). --- --- To avoid confusion, we refer to the trailers in the non-Trailers-Only case --- as "proper" trailers. -class ( Show (Headers flow) - , Show (Message flow) - , Show (ProperTrailers flow) - , Show (TrailersOnly flow) +-- HTTP @Headers@ frames, but in the absence of messages we only have one. +-- * Applications may in turn treat this case special, using a different set of +-- headers (specifically, this is the case for gRPC). +class ( Show (Headers flow) + , Show (Message flow) + , Show (Trailers flow) + , Show (NoMessages flow) ) => DataFlow flow where - data Headers flow :: Type - type Message flow :: Type - type ProperTrailers flow :: Type - type TrailersOnly flow :: Type + data Headers flow :: Type + type Message flow :: Type + type Trailers flow :: Type + type NoMessages flow :: Type -- | Start of data flow -- -- See 'DataFlow' for discussion. 
data FlowStart flow = - FlowStartRegular (Headers flow) - | FlowStartTrailersOnly (TrailersOnly flow) + FlowStartRegular (Headers flow) + | FlowStartNoMessages (NoMessages flow) deriving instance DataFlow flow => Show (FlowStart flow) @@ -98,14 +96,16 @@ class ( DataFlow (Inbound sess) type Outbound sess :: Type -- | Parse proper trailers - parseProperTrailers :: + parseInboundTrailers :: sess - -> [HTTP.Header] -> IO (ProperTrailers (Inbound sess)) + -> [HTTP.Header] + -> IO (Trailers (Inbound sess)) -- | Build proper trailers - buildProperTrailers :: + buildOutboundTrailers :: sess - -> ProperTrailers (Outbound sess) -> [HTTP.Header] + -> Trailers (Outbound sess) + -> [HTTP.Header] -- | Parse message parseMsg :: @@ -117,7 +117,8 @@ class ( DataFlow (Inbound sess) buildMsg :: sess -> Headers (Outbound sess) - -> Message (Outbound sess) -> Builder + -> Message (Outbound sess) + -> Builder -- | Initiate new session -- @@ -136,9 +137,9 @@ class IsSession sess => InitiateSession sess where -> ResponseInfo -> IO (Headers (Inbound sess)) -- | Parse 'ResponseInfo' from the server, Trailers-Only case - parseResponseTrailersOnly :: + parseResponseNoMessages :: sess - -> ResponseInfo -> IO (TrailersOnly (Inbound sess)) + -> ResponseInfo -> IO (NoMessages (Inbound sess)) -- | Accept session -- @@ -152,9 +153,9 @@ class IsSession sess => AcceptSession sess where -> RequestInfo -> IO (Headers (Inbound sess)) -- | Parse 'RequestInfo' from the client, Trailers-Only case - parseRequestTrailersOnly :: + parseRequestNoMessages :: sess - -> RequestInfo -> IO (TrailersOnly (Inbound sess)) + -> RequestInfo -> IO (NoMessages (Inbound sess)) -- | Build 'ResponseInfo' for the client buildResponseInfo :: diff --git a/src/Network/GRPC/Util/Session/Channel.hs b/src/Network/GRPC/Util/Session/Channel.hs index 4201e77b..2d8e8ac1 100644 --- a/src/Network/GRPC/Util/Session/Channel.hs +++ b/src/Network/GRPC/Util/Session/Channel.hs @@ -1,3 +1,7 @@ +-- | Channel +-- +-- You should not have to 
import this module directly; instead import +-- "Network.GRPC.Util.Session". module Network.GRPC.Util.Session.Channel ( -- * Main definition Channel(..) @@ -10,23 +14,26 @@ module Network.GRPC.Util.Session.Channel ( , getInboundHeaders , send , recv - , close - -- ** Exceptions , RecvAfterFinal(..) , SendAfterFinal(..) + -- * Closing + , waitForOutbound + , close + , ChannelUncleanClose(..) , ChannelClosed(..) + -- ** Exceptions -- * Constructing channels , sendMessageLoop , recvMessageLoop - , processInboundTrailers - , processOutboundTrailers + , outboundTrailersMaker -- ** Logging , DebugMsg(..) ) where import Control.Concurrent.STM import Control.Exception -import Control.Monad +import Control.Monad.Catch +import Control.Monad.IO.Class import Control.Tracer import Data.Bifunctor import Data.ByteString qualified as BS.Strict @@ -36,7 +43,7 @@ import GHC.Stack import Network.HTTP.Types qualified as HTTP import Network.HTTP2.Internal qualified as HTTP2 -import Network.GRPC.Common.StreamElem (StreamElem (..)) +import Network.GRPC.Common.StreamElem (StreamElem(..)) import Network.GRPC.Common.StreamElem qualified as StreamElem import Network.GRPC.Util.Parser import Network.GRPC.Util.Session.API @@ -70,10 +77,10 @@ import Network.GRPC.Util.HTTP2.Stream -- -- Each channel is constructed for a /single/ session (request/response). data Channel sess = Channel { - -- | Thread state of the thread receiving messages from the sess's peer + -- | Thread state of the thread receiving messages from the peer channelInbound :: TVar (ThreadState (TMVar (FlowState (Inbound sess)))) - -- | Thread state of the thread sending messages to the sess's peer + -- | Thread state of the thread sending messages to the peer , channelOutbound :: TVar (ThreadState (TMVar (FlowState (Outbound sess)))) -- | 'CallStack' of the final call to 'send' @@ -90,39 +97,66 @@ data Channel sess = Channel { } -- | Data flow state --- --- We maintain three separate pieces of state: --- --- 1. 
'flowHeaders' is written to once, when we receive the inbound headers. It --- is intended to support client code that wants to inspects these headers; a --- 'readTMVar' of this variable will block until the headers are received, --- but will return /before/ the first message is received. --- --- 2. 'flowMsg' is written to for every incoming message. This can be regarded --- as a 1-place buffer, and it provides backpressure: we cannot receive --- messages from the peer faster than the user is reading them. --- --- TODO: It might make sense to generalize this to an @N@-place buffer, for --- configurable @N@. This might provide for better latency masking. --- --- 3. 'flowTrailers' is written to once, when we receive the trailers. --- --- Invariant: when the trailers are written to 'flowTrailers', they will also --- (atomically) be written to 'flowMsg'. We maintain both to allow the --- 'HTTP2.TrailersMaker' to wait on 'flowTrailers'. --- --- NOTE: 'flowTrailers' and 'channelRecvFinal' are set at different times: --- 'flowTrailers' is set immediately upon receiving the trailers from the peer; --- 'channelRecvFinal' is set when the user does the final call to 'recv' to --- actually receive those trailers. data FlowState flow = FlowStateRegular (RegularFlowState flow) - | FlowStateTrailersOnly (TrailersOnly flow) + | FlowStateNoMessages (NoMessages flow) +-- | Regular (streaming) flow state data RegularFlowState flow = RegularFlowState { - flowHeaders :: Headers flow - , flowMsg :: TMVar (StreamElem (ProperTrailers flow) (Message flow)) - , flowTrailers :: TMVar (ProperTrailers flow) + -- | Headers + -- + -- On the client side, the outbound headers are specified when the request + -- is made ('callRequestMetadata'), and the inbound headers are recorded + -- once the responds starts to come in; clients can block-and-wait for + -- these headers ('getInboundHeaders'). 
+ -- + -- On the server side, the inbound headers are recorded when the request + -- comes in, and the outbound headers are specified + -- ('setResponseMetadata') before the response is initiated + -- ('initiateResponse'/'sendTrailersOnly'). + flowHeaders :: Headers flow + + -- | Messages + -- + -- This TMVar is written to for incoming messages ('recvMessageLoop') and + -- read from for outgoing messages ('sendMessageLoop'). It acts as a + -- one-place buffer, providing backpressure in both directions. + -- + -- TODO: It might make sense to generalize this to an @N@-place buffer, + -- for configurable @N@. This might result in better latency masking. + , flowMsg :: TMVar (StreamElem (Trailers flow) (Message flow)) + + -- | Trailers + -- + -- Unlike 'flowMsg', which is /written/ to in 'recvMessageLoop' and /read/ + -- from in 'sendMessageLoop', both loops /set/ 'flowTerminated', once, + -- just before they terminate. + -- + -- * For 'sendMessageLoop', this means that the last message has been + -- written (that is, the last call to 'writeChunk' has happened). + -- This has two consequences: + -- + -- 1. @http2@ can now construct the trailers ('outboundTrailersMaker') + -- 2. Higher layers can wait on 'flowTerminated' to be /sure/ that the + -- last message has been written. + -- + -- * For 'recvMessageLoop', this means that the trailers have been + -- received from the peer. Higher layers can use this to check for, or + -- block-and-wait, to receive those trailers. + -- + -- == Relation to 'channelSentFinal'/'channelRecvFinal' + -- + -- 'flowTerminated' is set at different times than 'channelSentFinal' and + -- 'channelRecvFinal' are: + -- + -- * 'channelSentFinal' is set on the last call to 'send', but /before/ + -- the message is processed by 'sendMessageLoop'. + -- * 'channelRecvFinal', dually, is set on the last call to 'recv, which + -- must (necessarily) happen /before/ that message is actually made + -- available by 'recvMessageLoop'. 
+ -- + -- /Their/ sole purpose is to catch user errors, not capture data flow. + , flowTerminated :: TMVar (Trailers flow) } {------------------------------------------------------------------------------- @@ -132,8 +166,8 @@ data RegularFlowState flow = RegularFlowState { initChannel :: IO (Channel sess) initChannel = Channel - <$> newTVarIO ThreadNotStarted - <*> newTVarIO ThreadNotStarted + <$> newThreadState + <*> newThreadState <*> newEmptyTMVarIO <*> newEmptyTMVarIO @@ -152,12 +186,12 @@ initFlowStateRegular headers = do -- Will block if the inbound headers have not yet been received. getInboundHeaders :: Channel sess - -> STM (Either (TrailersOnly (Inbound sess)) (Headers (Inbound sess))) + -> STM (Either (NoMessages (Inbound sess)) (Headers (Inbound sess))) getInboundHeaders Channel{channelInbound} = do st <- readTMVar =<< getThreadInterface channelInbound return $ case st of - FlowStateRegular regular -> Right $ flowHeaders regular - FlowStateTrailersOnly trailers -> Left trailers + FlowStateRegular regular -> Right $ flowHeaders regular + FlowStateNoMessages trailers -> Left trailers -- | Send a message to the node's peer -- @@ -167,7 +201,7 @@ getInboundHeaders Channel{channelInbound} = do send :: HasCallStack => Channel sess - -> StreamElem (ProperTrailers (Outbound sess)) (Message (Outbound sess)) + -> StreamElem (Trailers (Outbound sess)) (Message (Outbound sess)) -> STM () send Channel{channelOutbound, channelSentFinal} msg = do -- By checking that we haven't sent the final message yet, we know that this @@ -182,10 +216,10 @@ send Channel{channelOutbound, channelSentFinal} msg = do st <- readTMVar =<< getThreadInterface channelOutbound case st of FlowStateRegular regular -> do - forM_ (StreamElem.definitelyFinal msg) $ \_trailers -> + StreamElem.whenDefinitelyFinal msg $ \_trailers -> putTMVar channelSentFinal callStack putTMVar (flowMsg regular) msg - FlowStateTrailersOnly _ -> + FlowStateNoMessages _ -> -- For outgoing messages, the caller decides 
to use Trailers-Only, -- so if they then subsequently call 'send', we throw an exception. -- This is different for /inbound/ messages; see 'recv', below. @@ -199,7 +233,7 @@ recv :: HasCallStack => Channel sess -> STM ( StreamElem - (Either (TrailersOnly (Inbound sess)) (ProperTrailers (Inbound sess))) + (Either (NoMessages (Inbound sess)) (Trailers (Inbound sess))) (Message (Inbound sess)) ) recv Channel{channelInbound, channelRecvFinal} = do @@ -218,32 +252,13 @@ recv Channel{channelInbound, channelRecvFinal} = do msg <- takeTMVar (flowMsg regular) -- We update 'channelRecvFinal' in the same tx as the read, to -- atomically change from "there is a value" to "all values read". - forM_ (StreamElem.definitelyFinal msg) $ \_trailers -> + StreamElem.whenDefinitelyFinal msg $ \_trailers -> putTMVar channelRecvFinal callStack return $ first Right msg - FlowStateTrailersOnly trailers -> do + FlowStateNoMessages trailers -> do putTMVar channelRecvFinal callStack return $ NoMoreElems (Left trailers) --- | Close the channel --- --- This should only be used under exceptional circumstances; the normal --- way to terminate a connection is to --- --- * Read until receiving the final message --- * Write until sending the final message --- --- At this point a call to 'close' is unnecessary though harmless: calling --- 'close' on a channel that is terminated is a no-op. --- --- TODO: @http2@ does not offer an API for indicating that we want to ignore --- further output. We should check that this does not result in memory leak if --- the server keeps sending data and we're not listening.) -close :: (HasCallStack, Exception e) => Channel sess -> e -> IO () -close Channel{channelInbound, channelOutbound} e = do - cancelThread channelInbound $ ChannelClosed callStack (toException e) - cancelThread channelOutbound $ ChannelClosed callStack (toException e) - -- | Thrown by 'send' -- -- The 'CallStack' is the callstack of the final call to 'send'. 
@@ -269,11 +284,82 @@ data RecvAfterFinal = deriving stock (Show) deriving anyclass (Exception) --- | Channel was closed -data ChannelClosed = ChannelClosed { - channelClosedAt :: CallStack - , channelClosedReason :: SomeException - } +{------------------------------------------------------------------------------- + Closing +-------------------------------------------------------------------------------} + +-- | Wait for the outbound thread to terminate +-- +-- See 'forceClose' for discussion. +waitForOutbound :: Channel sess -> IO (FlowState (Outbound sess)) +waitForOutbound Channel{channelOutbound} = atomically $ + readTMVar =<< waitForThread channelOutbound + +-- | Close the channel +-- +-- Before a channel can be closed, you should 'send' the final outbound message +-- and then 'waitForOutbound' until all outbound messages have been processed. +-- Not doing so is considered a bug (it is not possible to do this implicitly, +-- because the final call to 'send' involves a choice of trailers, and calling +-- 'waitForOutbound' /without/ a final call to 'send' will result in deadlock). +-- Typically code will also process all /incoming/ messages, but doing so is of +-- course not mandatory. +-- +-- Calling 'close' will kill the inbound thread (executing 'recvMessageLoop') +-- and outbound thread ('sendMessageLoop'), /if/ they are still running. If the +-- outbound thread was terminated with an exception, this could mean one of two +-- things: +-- +-- * The connection to the peer was lost +-- * Proper procedure for outbound messages was not followed (see above) +-- +-- In this case, 'close' will return a 'ChannelUncleanClose' exception. +-- +-- TODO: @http2@ does not offer an API for indicating that we want to ignore +-- further output. We should check that this does not result in memory leak (if +-- the server keeps sending data and we're not listening.)
+close :: + HasCallStack + => Channel sess + -> ExitCase a -- ^ The reason why the channel is being closed + -> IO (Maybe ChannelUncleanClose) +close Channel{channelInbound, channelOutbound} reason = liftIO $ do + _inbound <- cancelThread channelInbound $ toException channelClosed + outbound <- cancelThread channelOutbound $ toException channelClosed + case outbound of + Right _ -> return $ Nothing + Left err -> return $ Just (ChannelUncleanClose err) + where + channelClosed :: ChannelClosed + channelClosed = + case reason of + ExitCaseSuccess _ -> ChannelDiscarded callStack + ExitCaseException e -> ChannelException callStack e + ExitCaseAbort -> ChannelAborted callStack + +-- | Thrown by 'close' if not all outbound messages have been processed +-- +-- See 'close' for discussion. +data ChannelUncleanClose = ChannelUncleanClose SomeException + deriving stock (Show) + deriving anyclass (Exception) + +-- | Thrown to the inbound and outbound threads by 'close' +data ChannelClosed = + -- | Channel was closed because it was discarded + -- + -- This typically corresponds to leaving the scope of 'acceptCall' or + -- 'withRPC' (without throwing an exception). + ChannelDiscarded CallStack + + -- | Channel was closed with an exception + | ChannelException CallStack SomeException + + -- | Channel was closed for an unknown reason + -- + -- This will only be used in monad stacks that have error mechanisms other + -- than exceptions. 
+ | ChannelAborted CallStack deriving stock (Show) deriving anyclass (Exception) @@ -293,9 +379,11 @@ sendMessageLoop sess tracer st stream = go $ buildMsg sess (flowHeaders st) where go :: (Message (Outbound sess) -> Builder) -> IO () - go build = loop + go build = do + trailers <- loop + atomically $ putTMVar (flowTerminated st) trailers where - loop :: IO () + loop :: IO (Trailers (Outbound sess)) loop = do traceWith tracer $ NodeSendAwaitMsg msg <- atomically $ takeTMVar (flowMsg st) @@ -311,13 +399,14 @@ sendMessageLoop sess tracer st stream = -- stream as END_STREAM (rather than having to send a separate -- empty data frame). writeChunk stream $ build x - atomically $ putTMVar (flowTrailers st) trailers + return trailers NoMoreElems trailers -> do - closeOutputStream stream - atomically $ putTMVar (flowTrailers st) trailers + return trailers -- | Receive all messages sent by the node's peer -- +-- Should be called with exceptions masked. +-- -- TODO: This is wrong, we are never marking the final element as final. -- (But fixing this requires a patch to http2.) recvMessageLoop :: forall sess. 
@@ -331,9 +420,13 @@ recvMessageLoop sess tracer st stream = go $ parseMsg sess (flowHeaders st) where go :: Parser (Message (Inbound sess)) -> IO () - go = loop + go = \parser -> do + trailers <- loop parser >>= parseInboundTrailers sess + traceWith tracer $ NodeRecvFinal trailers + atomically $ putTMVar (flowTerminated st) $ trailers + atomically $ putTMVar (flowMsg st) $ NoMoreElems trailers where - loop :: Parser (Message (Inbound sess)) -> IO () + loop :: Parser (Message (Inbound sess)) -> IO [HTTP.Header] loop (ParserError err) = throwIO $ PeerSentMalformedMessage err loop (ParserDone x p') = do @@ -350,34 +443,30 @@ recvMessageLoop sess tracer st stream = | not (BS.Lazy.null acc) -> throwIO PeerSentIncompleteMessage - | otherwise -> do - processInboundTrailers sess tracer st =<< getTrailers stream + | otherwise -> + getTrailers stream -processInboundTrailers :: forall sess. - IsSession sess - => sess - -> Tracer IO (DebugMsg sess) - -> RegularFlowState (Inbound sess) - -> [HTTP.Header] -> IO () -processInboundTrailers sess tracer st trailers = do - inboundTrailers <- parseProperTrailers sess trailers - traceWith tracer $ NodeRecvFinal inboundTrailers - atomically $ do - putTMVar (flowMsg st) $ NoMoreElems inboundTrailers - putTMVar (flowTrailers st) $ inboundTrailers - -processOutboundTrailers :: forall sess. +outboundTrailersMaker :: forall sess. IsSession sess => sess - -> RegularFlowState (Outbound sess) + -> Channel sess -> HTTP2.TrailersMaker -processOutboundTrailers sess st = go +outboundTrailersMaker sess channel = go where go :: HTTP2.TrailersMaker go (Just _) = return $ HTTP2.NextTrailersMaker go go Nothing = do - outboundTrailers <- atomically $ readTMVar (flowTrailers st) - return $ HTTP2.Trailers $ buildProperTrailers sess outboundTrailers + -- Wait for the thread to terminate + -- + -- If the thread was killed, this will throw an exception (which will + -- then result in @http2@ cancelling the corresponding stream). 
+ flowState <- waitForOutbound channel + trailers <- case flowState of + FlowStateRegular regular -> + atomically $ readTMVar $ flowTerminated regular + FlowStateNoMessages _ -> + error "unexpected FlowStateNoMessages" + return $ HTTP2.Trailers $ buildOutboundTrailers sess trailers data DebugMsg sess = -- | Thread sending messages is awaiting a message @@ -386,8 +475,8 @@ data DebugMsg sess = -- | Thread sending message will send a message | NodeSendMsg ( StreamElem - (ProperTrailers (Outbound sess)) - (Message (Outbound sess)) + (Trailers (Outbound sess)) + (Message (Outbound sess)) ) -- | Receive thread requires data @@ -398,11 +487,16 @@ data DebugMsg sess = -- | Receive thread received a message | NodeRecvMsg ( StreamElem - (Either (TrailersOnly (Inbound sess)) (ProperTrailers (Inbound sess))) + (Either (NoMessages (Inbound sess)) (Trailers (Inbound sess))) (Message (Inbound sess)) ) -- | Receive thread received the trailers - | NodeRecvFinal (ProperTrailers (Inbound sess)) + | NodeRecvFinal (Trailers (Inbound sess)) deriving instance IsSession sess => Show (DebugMsg sess) + +-- | Thrown in 'outboundTrailersMaker' +data FailedToMakeTrailers = FailedToMakeTrailers SomeException + deriving stock (Show) + deriving anyclass (Exception) \ No newline at end of file diff --git a/src/Network/GRPC/Util/Session/Client.hs b/src/Network/GRPC/Util/Session/Client.hs index 65727e32..6f0e74c9 100644 --- a/src/Network/GRPC/Util/Session/Client.hs +++ b/src/Network/GRPC/Util/Session/Client.hs @@ -4,10 +4,9 @@ module Network.GRPC.Util.Session.Client ( , initiateRequest ) where -import Control.Concurrent import Control.Concurrent.STM -import Control.Exception import Control.Monad +import Control.Monad.Catch import Control.Tracer import Network.HTTP2.Client qualified as Client @@ -83,20 +82,21 @@ initiateRequest sess tracer ConnectionToServer{sendRequest} outboundStart = do FlowStartRegular headers -> do regular <- initFlowStateRegular headers let req :: Client.Request - req = 
setRequestTrailers sess regular - $ Client.requestStreaming + req = setRequestTrailers sess channel + $ Client.requestStreamingUnmask (requestMethod requestInfo) (requestPath requestInfo) (requestHeaders requestInfo) - $ \write flush -> do - stream <- clientOutputStream write flush + $ \unmask write flush -> unmask $ do threadBody (channelOutbound channel) + unmask (newTMVarIO (FlowStateRegular regular)) $ \_stVar -> do + stream <- clientOutputStream write flush sendMessageLoop sess tracer regular stream forkRequest channel req - FlowStartTrailersOnly trailers -> do - stVar <- newTMVarIO $ FlowStateTrailersOnly trailers + FlowStartNoMessages trailers -> do + stVar <- newTMVarIO $ FlowStateNoMessages trailers atomically $ writeTVar (channelOutbound channel) $ ThreadDone stVar let req :: Client.Request req = Client.requestNoBody @@ -108,19 +108,19 @@ initiateRequest sess tracer ConnectionToServer{sendRequest} outboundStart = do return channel where forkRequest :: Channel sess -> Client.Request -> IO () - forkRequest channel req = void $ forkIO $ - threadBody (channelInbound channel) newEmptyTMVarIO $ \stVar -> do + forkRequest channel req = + forkThread (channelInbound channel) newEmptyTMVarIO $ \stVar -> do setup <- try $ sendRequest req $ \resp -> do responseStatus <- case Client.responseStatus resp of Just x -> return x - Nothing -> throwIO PeerMissingPseudoHeaderStatus + Nothing -> throwM PeerMissingPseudoHeaderStatus let responseHeaders = fromHeaderTable $ Client.responseHeaders resp responseInfo = ResponseInfo {responseHeaders, responseStatus} inboundStart <- if Client.responseBodySize resp == Just 0 - then FlowStartTrailersOnly <$> - parseResponseTrailersOnly sess responseInfo + then FlowStartNoMessages <$> + parseResponseNoMessages sess responseInfo else FlowStartRegular <$> parseResponseRegular sess responseInfo @@ -130,12 +130,12 @@ initiateRequest sess tracer ConnectionToServer{sendRequest} outboundStart = do stream <- clientInputStream resp atomically $ 
putTMVar stVar $ FlowStateRegular regular recvMessageLoop sess tracer regular stream - FlowStartTrailersOnly trailers -> - atomically $ putTMVar stVar $ FlowStateTrailersOnly trailers + FlowStartNoMessages trailers -> + atomically $ putTMVar stVar $ FlowStateNoMessages trailers case setup of - Right () -> return () - Left (e :: SomeException)-> close channel e + Right () -> return () + Left e -> void $ close channel $ ExitCaseException e {------------------------------------------------------------------------------- Auxiliary http2 @@ -144,8 +144,8 @@ initiateRequest sess tracer ConnectionToServer{sendRequest} outboundStart = do setRequestTrailers :: IsSession sess => sess - -> RegularFlowState (Outbound sess) + -> Channel sess -> Client.Request -> Client.Request -setRequestTrailers sess st req = +setRequestTrailers sess channel req = Client.setRequestTrailersMaker req $ - processOutboundTrailers sess st + outboundTrailersMaker sess channel diff --git a/src/Network/GRPC/Util/Session/Server.hs b/src/Network/GRPC/Util/Session/Server.hs index 13166f3c..c98daff0 100644 --- a/src/Network/GRPC/Util/Session/Server.hs +++ b/src/Network/GRPC/Util/Session/Server.hs @@ -4,10 +4,8 @@ module Network.GRPC.Util.Session.Server ( , initiateResponse ) where -import Control.Concurrent import Control.Concurrent.STM import Control.Exception -import Control.Monad import Control.Tracer import Network.HTTP2.Server qualified as Server @@ -62,45 +60,43 @@ initiateResponse sess tracer conn startOutbound = do inboundStart <- if Server.requestBodySize (request conn) == Just 0 then - FlowStartTrailersOnly <$> parseRequestTrailersOnly sess requestInfo + FlowStartNoMessages <$> parseRequestNoMessages sess requestInfo else FlowStartRegular <$> parseRequestRegular sess requestInfo - void $ forkIO $ - threadBody (channelInbound channel) newEmptyTMVarIO $ \stVar -> do - case inboundStart of - FlowStartRegular headers -> do - regular <- initFlowStateRegular headers - stream <- serverInputStream 
(request conn) - atomically $ putTMVar stVar $ FlowStateRegular regular - recvMessageLoop sess tracer regular stream - FlowStartTrailersOnly trailers -> - atomically $ putTMVar stVar $ FlowStateTrailersOnly trailers + forkThread (channelInbound channel) newEmptyTMVarIO $ \stVar -> + case inboundStart of + FlowStartRegular headers -> do + regular <- initFlowStateRegular headers + stream <- serverInputStream (request conn) + atomically $ putTMVar stVar $ FlowStateRegular regular + recvMessageLoop sess tracer regular stream + FlowStartNoMessages trailers -> + atomically $ putTMVar stVar $ FlowStateNoMessages trailers - void $ forkIO $ - threadBody (channelOutbound channel) newEmptyTMVarIO $ \stVar -> do - outboundStart <- startOutbound inboundStart - let responseInfo = buildResponseInfo sess outboundStart - case outboundStart of - FlowStartRegular headers -> do - regular <- initFlowStateRegular headers - atomically $ putTMVar stVar $ FlowStateRegular regular - let resp :: Server.Response - resp = setResponseTrailers sess regular - $ Server.responseStreaming - (responseStatus responseInfo) - (responseHeaders responseInfo) - $ \write flush -> do - stream <- serverOutputStream write flush - sendMessageLoop sess tracer regular stream - respond conn resp - FlowStartTrailersOnly trailers -> do - atomically $ putTMVar stVar $ FlowStateTrailersOnly trailers - let resp :: Server.Response - resp = Server.responseNoBody - (responseStatus responseInfo) - (responseHeaders responseInfo) - respond conn $ resp + forkThread (channelOutbound channel) newEmptyTMVarIO $ \stVar -> do + outboundStart <- startOutbound inboundStart + let responseInfo = buildResponseInfo sess outboundStart + case outboundStart of + FlowStartRegular headers -> do + regular <- initFlowStateRegular headers + atomically $ putTMVar stVar $ FlowStateRegular regular + let resp :: Server.Response + resp = setResponseTrailers sess channel + $ Server.responseStreaming + (responseStatus responseInfo) + (responseHeaders 
responseInfo) + $ \write flush -> do + stream <- serverOutputStream write flush + sendMessageLoop sess tracer regular stream + respond conn resp + FlowStartNoMessages trailers -> do + atomically $ putTMVar stVar $ FlowStateNoMessages trailers + let resp :: Server.Response + resp = Server.responseNoBody + (responseStatus responseInfo) + (responseHeaders responseInfo) + respond conn $ resp return channel @@ -111,8 +107,8 @@ initiateResponse sess tracer conn startOutbound = do setResponseTrailers :: IsSession sess => sess - -> RegularFlowState (Outbound sess) + -> Channel sess -> Server.Response -> Server.Response -setResponseTrailers sess st resp = +setResponseTrailers sess channel resp = Server.setResponseTrailersMaker resp $ - processOutboundTrailers sess st + outboundTrailersMaker sess channel diff --git a/src/Network/GRPC/Util/Thread.hs b/src/Network/GRPC/Util/Thread.hs index 7ef803d1..15a916ee 100644 --- a/src/Network/GRPC/Util/Thread.hs +++ b/src/Network/GRPC/Util/Thread.hs @@ -3,8 +3,13 @@ -- Intended for unqualified import. module Network.GRPC.Util.Thread ( ThreadState(..) + -- * Creating threads + , newThreadState + , forkThread , threadBody + -- * Access thread state , cancelThread + , waitForThread , getThreadInterface , ThreadCancelled(..) 
) where @@ -42,39 +47,38 @@ data ThreadState a = | ThreadException SomeException {------------------------------------------------------------------------------- - Running + Creating threads -------------------------------------------------------------------------------} +newThreadState :: IO (TVar (ThreadState a)) +newThreadState = newTVarIO ThreadNotStarted + +forkThread :: + TVar (ThreadState a) + -> IO a -- ^ Initialize the thread (runs with exceptions masked) + -> (a -> IO ()) -- ^ Main thread body + -> IO () +forkThread state initThread body = + void $ mask_ $ forkIOWithUnmask $ \unmask -> + threadBody state unmask initThread body + -- | Wrap the thread body -- --- If the 'ThreadState' is anything other than 'ThreadNotStarted' on entry, --- this function terminates immediately. --- --- == Discussion +-- This should be wrapped around the body of the thread, and should be called +-- with exceptions masked. -- -- This is intended for integration with existing libraries (such as @http2@), --- which might do the forking under the hood. For this reason, this does /NOT/ --- do the actual fork, but should instead be wrapped around the body of the --- thread. --- --- It would be better if the thread would be called with exceptions --- masked and passed an @unmask@ function, but we cannot require this for the --- same reason. This leaves a small window of time in between the thread being --- forked by the external code and entering the scope of the 'mask' inside --- 'threadBody'; if the thread is now killed using 'killThread', a call to --- 'getThreadInterface' could block indefinitely. +-- which might do the forking under the hood. -- --- However, a call to 'cancelThread' is unproblematic, as it will itself --- correctly update the 'ThreadState' if the thread is in 'ThreadNotStarted' --- state. --- --- TL;DR: The thread should be killed with 'cancelThread', not 'killThread'. 
+-- If the 'ThreadState' is anything other than 'ThreadNotStarted' on entry, +-- this function terminates immediately. threadBody :: TVar (ThreadState a) + -> (forall x. IO x -> IO x) -- ^ Unmask exceptions -> IO a -- ^ Initialize the thread (runs with exceptions masked) -> (a -> IO ()) -- ^ Main thread body -> IO () -threadBody state initThread body = mask $ \unmask -> do +threadBody state unmask initThread body = do tid <- myThreadId shouldStart <- atomically $ do st <- readTVar state @@ -109,22 +113,33 @@ threadBody state initThread body = mask $ \unmask -> do -- 'ThreadNotStarted'); instead, we rely on the exception handler inside -- 'threadBody' to do so (we are guaranteed that this thread handler is -- installed if the thread state is anything other than 'ThreadNotStarted'). -cancelThread :: Exception e => TVar (ThreadState a) -> e -> IO () -cancelThread state e = do - mTid <- atomically $ do - st <- readTVar state - case st of - ThreadNotStarted -> do - writeTVar state $ ThreadException (toException e) - return Nothing - ThreadInitializing tid -> return $ Just tid - ThreadRunning tid _ -> return $ Just tid - ThreadDone _ -> return Nothing - ThreadException _ -> return Nothing - - case mTid of - Nothing -> return () - Just tid -> throwTo tid $ ThreadCancelled (toException e) +-- +-- Returns the reason the thread was killed (either the exception passed as an +-- argument to 'cancelThread', or an earlier exception if the thread was already +-- killed), or the thread interface if the thread had already terminated. +cancelThread :: + TVar (ThreadState a) + -> SomeException + -> IO (Either SomeException a) +cancelThread state e = join . 
atomically $ do + st <- readTVar state + case st of + ThreadNotStarted -> do + writeTVar state $ ThreadException e + return $ return (Left e) + ThreadInitializing tid -> + return $ kill tid + ThreadRunning tid _ -> + return $ kill tid + ThreadException e' -> + return $ return (Left e') + ThreadDone a -> + return $ return (Right a) + where + kill :: ThreadId -> IO (Either SomeException a) + kill tid = do + throwTo tid $ ThreadCancelled e + return $ Left e -- | Exception thrown by 'cancelThread' to the thread to the cancelled newtype ThreadCancelled = ThreadCancelled { @@ -158,3 +173,18 @@ getThreadInterface state = do ThreadRunning _ a -> return a ThreadDone a -> return a ThreadException e -> throwSTM e + +-- | Wait for the thread to terminate +-- +-- This is similar to 'getThreadInterface', but retries when the thread is still +-- running. +waitForThread :: TVar (ThreadState a) -> STM a +waitForThread state = do + st <- readTVar state + case st of + ThreadNotStarted -> retry + ThreadInitializing _ -> retry + ThreadRunning _ _ -> retry + ThreadDone a -> return a + ThreadException e -> throwSTM e + diff --git a/test-common/Test/Util/ClientServer.hs b/test-common/Test/Util/ClientServer.hs index bd3d51ef..a817f879 100644 --- a/test-common/Test/Util/ClientServer.hs +++ b/test-common/Test/Util/ClientServer.hs @@ -10,6 +10,7 @@ module Test.Util.ClientServer ( -- ** Evaluation , isExpectedException -- * Run + , UnexpectedHandlerException(..) 
, runTestClientServer -- ** Lower-level functionality , runTestClient @@ -24,12 +25,13 @@ import Data.Default import Data.Map qualified as Map import Data.Set qualified as Set import Data.Text qualified as Text +import Network.HTTP2.Internal qualified as HTTP2 import Network.TLS import Network.GRPC.Client qualified as Client +import Network.GRPC.Common +import Network.GRPC.Common.Compression (CompressionNegotationFailed) import Network.GRPC.Common.Compression qualified as Compr -import Network.GRPC.Common.Compression qualified as Compression -import Network.GRPC.Common.Exceptions import Network.GRPC.Server qualified as Server import Network.GRPC.Server.Run qualified as Server @@ -40,15 +42,15 @@ import Paths_grapesy -------------------------------------------------------------------------------} data ClientServerConfig = ClientServerConfig { - clientCompr :: Compression.Negotation - , serverCompr :: Compression.Negotation + clientCompr :: Compr.Negotation + , serverCompr :: Compr.Negotation , useTLS :: Maybe TlsSetup } instance Default ClientServerConfig where def = ClientServerConfig { - clientCompr = Compression.none - , serverCompr = Compression.none + clientCompr = Compr.none + , serverCompr = Compr.none , useTLS = Nothing } @@ -113,8 +115,7 @@ isExpectedException cfg err False | Just (grpcException :: GrpcException) <- fromException err - = if | Set.disjoint (Map.keysSet (Compr.supported (clientCompr cfg))) - (Map.keysSet (Compr.supported (serverCompr cfg))) + = if | compressionNegotationFailure , GrpcUnknown <- grpcError grpcException , Just msg <- grpcErrorMessage grpcException , "CompressionNegotationFailed" `Text.isInfixOf` msg @@ -123,14 +124,26 @@ isExpectedException cfg err | otherwise -> False + | Just (_ :: CompressionNegotationFailed) <- fromException err + , compressionNegotationFailure + = True + | Just (threadCancelled :: ThreadCancelled) <- fromException err = isExpectedException cfg (threadCancelledReason threadCancelled) - | Just (channelClosed 
:: ChannelClosed) <- fromException err - = isExpectedException cfg (channelClosedReason channelClosed) + | Just (ChannelException _ err') <- fromException err + = isExpectedException cfg err' + + | Just (ChannelUncleanClose err') <- fromException err + = isExpectedException cfg err' | otherwise = False + where + compressionNegotationFailure :: Bool + compressionNegotationFailure = + Set.disjoint (Map.keysSet (Compr.supported (clientCompr cfg))) + (Map.keysSet (Compr.supported (serverCompr cfg))) {------------------------------------------------------------------------------- Server @@ -138,9 +151,10 @@ isExpectedException cfg err runTestServer :: ClientServerConfig + -> Tracer IO SomeException -> [Server.RpcHandler IO] -> IO () -runTestServer cfg serverHandlers = do +runTestServer cfg serverExceptions serverHandlers = do pubCert <- getDataFileName "grpc-demo.cert" privKey <- getDataFileName "grpc-demo.priv" @@ -178,8 +192,9 @@ runTestServer cfg serverHandlers = do serverParams :: Server.ServerParams serverParams = Server.ServerParams { - serverDebugTracer = nullTracer - , serverCompression = serverCompr cfg + serverCompression = serverCompr cfg + , serverExceptionTracer = serverExceptions + , serverDebugTracer = nullTracer } Server.withServer serverParams serverHandlers $ @@ -253,13 +268,48 @@ runTestClient cfg clientRun = do Main entry point: run server and client together -------------------------------------------------------------------------------} +data UnexpectedHandlerException = UnexpectedHandlerException SomeException + deriving stock (Show) + deriving anyclass (Exception) + runTestClientServer :: forall a. 
ClientServerConfig + -> (SomeException -> Maybe UnexpectedHandlerException) -> (Client.Connection -> IO a) -> [Server.RpcHandler IO] -> IO a -runTestClientServer cfg clientRun serverHandlers = do - withAsync (runTestServer cfg serverHandlers) $ \_serverThread -> do +runTestClientServer cfg classifyHandlerException clientRun serverHandlers = do + testThread <- myThreadId + + -- Normally, when a handler throws an exception, that request is simply + -- aborted, but the server does not shut down. However, for the sake of + -- testing, if a handler throws an unexpected exception, the test should + -- fail. We therefore monitor for these exceptions and rethrow them to the + let serverExceptions :: Tracer IO SomeException + serverExceptions = arrow $ emit $ \err -> + if -- Exceptions due to invalid config are always expected + | isExpectedException cfg err + -> return () + + -- If the client shuts down the server, don't record that as an + -- unexpected exception + | Just AsyncCancelled <- fromException err + -> return () + + -- HTTP2 might terminate a handler when the client disappears + -- + -- TODO: Is this really expected behaviour..? Do we want to + -- unceremoniously kill a handler when a client disappears..? + | Just HTTP2.KilledByHttp2ThreadPoolManager <- fromException err + -> return () + + | Just err' <- classifyHandlerException err + -> throwTo testThread err' + + | otherwise + -> return () + + withAsync (runTestServer cfg serverExceptions serverHandlers) $ \_ -> do -- TODO: This threadDelay is obviously awful. When we fix proper -- wait-for-ready semantics, we can avoid it. threadDelay 100_000 diff --git a/test-common/Test/Util/PrettyVal.hs b/test-common/Test/Util/PrettyVal.hs new file mode 100644 index 00000000..293c9fdf --- /dev/null +++ b/test-common/Test/Util/PrettyVal.hs @@ -0,0 +1,52 @@ +{-# OPTIONS_GHC -Wno-orphans #-} + +module Test.Util.PrettyVal ( + -- * Deriving-via support + ShowAsPretty(..) + -- * CallStack + , PrettyCallStack(..) 
+ ) where + +import Data.Bifunctor +import Data.Set (Set) +import Data.Set qualified as Set +import GHC.Stack +import Text.Show.Pretty + +{------------------------------------------------------------------------------- + Deriving-via support +-------------------------------------------------------------------------------} + +newtype ShowAsPretty a = ShowAsPretty a + +instance PrettyVal a => Show (ShowAsPretty a) where + show (ShowAsPretty x) = dumpStr x + +{------------------------------------------------------------------------------- + CallStack +-------------------------------------------------------------------------------} + +newtype PrettyCallStack = PrettyCallStack CallStack + deriving Show via ShowAsPretty PrettyCallStack + +newtype PrettySrcLoc = PrettySrcLoc SrcLoc + deriving Show via ShowAsPretty PrettySrcLoc + +instance PrettyVal PrettySrcLoc where + prettyVal (PrettySrcLoc x) = String (prettySrcLoc x) + +instance PrettyVal PrettyCallStack where + prettyVal (PrettyCallStack x) = + prettyVal $ map (second PrettySrcLoc) (getCallStack x) + +{------------------------------------------------------------------------------- + Orphans + + These live only in the testsuite, not the main lib. 
+-------------------------------------------------------------------------------} + +instance PrettyVal () where + prettyVal () = Con "()" [] + +instance PrettyVal a => PrettyVal (Set a) where + prettyVal xs = Con "Set.fromList" [prettyVal $ Set.toList xs] \ No newline at end of file diff --git a/test-grapesy/Main.hs b/test-grapesy/Main.hs index 333d903b..db2e120c 100644 --- a/test-grapesy/Main.hs +++ b/test-grapesy/Main.hs @@ -2,6 +2,8 @@ module Main (main) where import Test.Tasty +import Test.Prop.Dialogue qualified as Dialogue +import Test.Prop.Serialization qualified as Serialization import Test.Sanity.StreamingType.NonStreaming qualified as StreamingType.NonStreaming main :: IO () @@ -11,4 +13,8 @@ main = defaultMain $ testGroup "grapesy" [ StreamingType.NonStreaming.tests ] ] + , testGroup "Prop" [ + Serialization.tests + , Dialogue.tests + ] ] diff --git a/test-grapesy/Test/Driver/ClientServer.hs b/test-grapesy/Test/Driver/ClientServer.hs index d232cacd..05761a39 100644 --- a/test-grapesy/Test/Driver/ClientServer.hs +++ b/test-grapesy/Test/Driver/ClientServer.hs @@ -4,14 +4,18 @@ module Test.Driver.ClientServer ( -- * Basic client-server test ClientServerTest(..) , testClientServer + , propClientServer -- * Re-exports , module Test.Util.ClientServer ) where import Control.Exception +import Control.Monad.IO.Class import Data.Default import Data.Typeable -import Test.Tasty.HUnit +import Test.QuickCheck.Monadic qualified as QuickCheck +import Test.Tasty.HUnit qualified as HUnit +import Test.Tasty.QuickCheck qualified as QuickCheck import Network.GRPC.Client qualified as Client import Network.GRPC.Server qualified as Server @@ -35,20 +39,24 @@ instance Default ClientServerTest where , server = [] } -testClientServer :: (forall a. (ClientServerTest -> IO a) -> IO a) -> IO String -testClientServer withTest = +-- | Run client server test, and check for expected failures +testClientServer :: + (SomeException -> Maybe UnexpectedHandlerException) + -> (forall a. 
Show a => (ClientServerTest -> IO a) -> IO a) + -> IO String +testClientServer classifyHandlerException withTest = withTest $ \ClientServerTest{config, client, server} -> do - mRes <- try $ runTestClientServer config client server + mRes <- try $ runTestClientServer config classifyHandlerException client server case mRes of Left err - | Just (testFailure :: HUnitFailure) <- fromException err + | Just (testFailure :: HUnit.HUnitFailure) <- fromException err -> throwIO testFailure | isExpectedException config err -> return $ "Got expected error: " ++ show err | otherwise - -> assertFailure $ concat [ + -> HUnit.assertFailure $ concat [ "Unexpected exception of type " , case err of SomeException e -> show (typeOf e) @@ -58,3 +66,14 @@ testClientServer withTest = Right () -> return "" +-- | Turn client server test into property +-- +-- This does /not/ test for expected failures: we're not testing invalid +-- configurations when doing property based testing. +propClientServer :: + (SomeException -> Maybe UnexpectedHandlerException) + -> (forall a. 
Show a => (ClientServerTest -> IO a) -> IO a) + -> QuickCheck.Property +propClientServer classifyHandlerException withTest = QuickCheck.monadicIO $ + liftIO $ withTest $ \ClientServerTest{config, client, server} -> + runTestClientServer config classifyHandlerException client server \ No newline at end of file diff --git a/test-grapesy/Test/Driver/Dialogue.hs b/test-grapesy/Test/Driver/Dialogue.hs index 51249102..5cb9faa1 100644 --- a/test-grapesy/Test/Driver/Dialogue.hs +++ b/test-grapesy/Test/Driver/Dialogue.hs @@ -1,725 +1,10 @@ {-# LANGUAGE OverloadedStrings #-} module Test.Driver.Dialogue ( - execGlobalSteps + module X ) where -import Control.Concurrent -import Control.Concurrent.Async -import Control.Concurrent.STM -import Control.Monad -import Control.Monad.Catch -import Control.Monad.State -import Data.ByteString qualified as Strict -import Data.Default -import Data.List (sortBy) -import Data.List.NonEmpty (NonEmpty(..)) -import Data.Maybe (mapMaybe) -import Data.Ord (comparing) -import Data.Proxy -import Data.Text qualified as Text -import GHC.Stack -import Test.QuickCheck - -import Network.GRPC.Client (Timeout) -import Network.GRPC.Client qualified as Client -import Network.GRPC.Client.Binary qualified as Client.Binary -import Network.GRPC.Common.Binary -import Network.GRPC.Common.CustomMetadata -import Network.GRPC.Common.Exceptions -import Network.GRPC.Common.StreamElem (StreamElem(..)) -import Network.GRPC.Server qualified as Server -import Network.GRPC.Server.Binary qualified as Server.Binary - -import Test.Driver.ClientServer - -{------------------------------------------------------------------------------- - RPC --------------------------------------------------------------------------------} - -type TestRpc = BinaryRpc "binary" "test" - -{------------------------------------------------------------------------------- - Test clock - - These are concurrent tests, looking at the behaviour of @n@ concurrent - connections between a client and a server. 
To make the interleaving of actions - easier to see, we introduce a global "test clock". Every test action is - annotated with a particular "test tick" on this clock. If multiple actions are - annotated with the same tick, their interleaving is essentially - non-deterministic. --------------------------------------------------------------------------------} - -newtype TestClock = TestClock (TVar TestClockTick) - -newtype TestClockTick = TestClockTick Word - deriving stock (Show) - deriving newtype (Eq, Ord, Enum) - -withTestClock :: (TestClock -> IO a) -> IO a -withTestClock k = do - tvar <- newTVarIO (TestClockTick 0) - withAsync (tickClock tvar) $ \_ -> - k $ TestClock tvar - where - tickClock :: TVar TestClockTick -> IO () - tickClock tvar = forever $ do - threadDelay $ fromIntegral (tickDuration * 1_000) - atomically $ modifyTVar tvar succ - --- | Tick duration, in millseconds --- --- Setting this is finicky: too long and the test take a long time to run, --- too short and the illusion that individual actions take no time fails. --- --- TODO: For now we set this to 1second, which is almost certainly too long. -tickDuration :: Word -tickDuration = 1_000 - -waitForTestClock :: MonadIO m => TestClock -> TestClockTick -> m () -waitForTestClock (TestClock tvar) tick = liftIO $ atomically $ do - clock <- readTVar tvar - unless (clock >= tick) retry - -{------------------------------------------------------------------------------- - Timeouts - - We use the test clock also to test timeouts: a slow server is modelled by one - that waits for a particular test clock tick. To avoid ambiguity, we use - /even/ clockticks for actions and /odd/ clockticks for timeouts. --------------------------------------------------------------------------------} - --- | Does the current 'TestClock' time exceed an (optional) timeout? 
-exceedsTimeout :: TestClockTick -> Maybe TestClockTick -> Bool -now `exceedsTimeout` Just timeout | now == timeout = error "impossible" -now `exceedsTimeout` Just timeout = now > timeout -_ `exceedsTimeout` Nothing = False - -tickToTimeout :: TestClockTick -> Timeout -tickToTimeout (TestClockTick n) - | even n = error "tickToTimeout: expected odd timeout" - | otherwise = Client.Timeout Client.Millisecond $ - Client.TimeoutValue (n * tickDuration) - -{------------------------------------------------------------------------------- - Test failures --------------------------------------------------------------------------------} - -data TestFailure = TestFailure CallStack Failure - deriving stock (Show) - deriving anyclass (Exception) - -data Failure = - -- | Thrown by the server when an unexpected new RPC is initiated - UnexpectedRequest - - -- | Received an unexpected value - | Unexpected ReceivedUnexpected - deriving stock (Show) - -data ReceivedUnexpected = forall a. Show a => ReceivedUnexpected { - expected :: a - , received :: a - } - -deriving stock instance Show ReceivedUnexpected - -expect :: - (MonadThrow m, Eq a, Show a, HasCallStack) - => a -- ^ Expected - -> a -- ^ Actually received - -> m () -expect expected received - | expected == received - = return () - - | otherwise - = throwM $ TestFailure callStack $ - Unexpected $ ReceivedUnexpected{expected, received} - -{------------------------------------------------------------------------------- - Single channel - - TODO: We should test that the Trailers-Only case gets triggered if no messages - were exchanged before the exception (not sure this is observable without - Wireshark..?). --------------------------------------------------------------------------------} - -data LocalStep = - -- | Server initiates response to the client - -- - -- If there is no explicit 'InitiateResponse', there will be an implicit - -- one, with empty metadata. 
- InitiateResponse [CustomMetadata] - - -- | Client sends a message the server - | ClientToServer (StreamElem () Int) - - -- | Server sends a message to the client - | ServerToClient (StreamElem [CustomMetadata] Int) - - -- | Client throws an exception - | ClientException - - -- | Server throws an exception - | ServerException - deriving stock (Show) - -newtype LocalSteps = LocalSteps { - getLocalSteps :: [(TestClockTick, LocalStep)] - } - deriving stock (Show) - -data SomeServerException = SomeServerException TestClockTick - deriving stock (Show) - deriving anyclass (Exception) - -data SomeClientException = SomeClientException TestClockTick - deriving stock (Show) - deriving anyclass (Exception) - -{------------------------------------------------------------------------------- - Many channels (birds-eye view) --------------------------------------------------------------------------------} - -data GlobalStep = - -- | Spawn another channel - -- - -- The timeout (if given) is defined in terms of 'TimeClockTick'. - Spawn (Maybe TestClockTick) [CustomMetadata] LocalSteps - deriving stock (Show) - -newtype GlobalSteps = GlobalSteps { - getGlobalSteps :: [(TestClockTick, GlobalStep)] - } - deriving stock (Show) - -{------------------------------------------------------------------------------- - Health --------------------------------------------------------------------------------} - --- | Health --- --- When the client is expecting a response from the server, it needs to know the --- "health" of the server, that is, is the server still alive, or did it fail --- with some kind exception? The same is true for the server when it expects a --- response from the client. Therefore, the client interpretation keeps track of --- the health of the server, and vice versa. 
-data Health e = Alive | Failed e - -type ServerHealth = Health SomeServerException -type ClientHealth = Health SomeClientException - -{------------------------------------------------------------------------------- - Client-side interpretation --------------------------------------------------------------------------------} - -clientGlobal :: TestClock -> Client.Connection -> GlobalSteps -> IO () -clientGlobal testClock conn = - mapM_ (uncurry go) . getGlobalSteps - where - go :: TestClockTick -> GlobalStep -> IO () - go tick = \case - Spawn timeout metadata localSteps -> do - waitForTestClock testClock tick - let params :: Client.CallParams - params = Client.CallParams { - callTimeout = tickToTimeout <$> timeout - , callRequestMetadata = metadata - } - void $ forkIO $ - Client.withRPC conn params (Proxy @TestRpc) $ \call -> - clientLocal testClock call timeout localSteps - -clientLocal :: - HasCallStack - => TestClock - -> Client.Call TestRpc - -> Maybe TestClockTick -- ^ Timeout specified to the server - -> LocalSteps - -> IO () -clientLocal testClock call serverTimeout = - flip evalStateT Alive . mapM_ (uncurry go) . 
getLocalSteps - where - go :: TestClockTick -> LocalStep -> StateT ServerHealth IO () - go tick = \case - InitiateResponse expectedMetadata -> liftIO $ do - receivedMetadata <- atomically $ Client.recvResponseMetadata call - expect expectedMetadata $ receivedMetadata - ClientToServer x -> do - waitForTestClock testClock tick - expected <- adjustExpectation tick () - expect expected =<< liftIO (try $ Client.Binary.sendInput call x) - ServerToClient expectedElem -> do - expected <- adjustExpectation tick expectedElem - expect expected =<< liftIO (try $ Client.Binary.recvOutput call) - ClientException -> do - waitForTestClock testClock tick - throwM $ SomeClientException tick - ServerException -> do - previousServerFailure <- adjustExpectation tick () - case previousServerFailure of - Left _ -> return () - Right () -> put $ Failed (SomeServerException tick) - - -- Adjust expectation when communicating with the server - -- - -- If the server handler died for some reason, we won't get the regular - -- result, but should instead see the exception reported to the client. - adjustExpectation :: forall a. - TestClockTick - -> a - -> StateT ServerHealth IO (Either GrpcException a) - adjustExpectation now x = - return . aux =<< get - where - aux :: ServerHealth -> Either GrpcException a - aux serverHealth - -- The gRPC specification does not actually say anywhere which gRPC - -- error code should be used for timeouts, but - -- suggests @CANCELLED@. 
- | now `exceedsTimeout` serverTimeout - = Left GrpcException { - grpcError = GrpcCancelled - , grpcErrorMessage = Nothing - , grpcErrorMetadata = [] - } - - | Failed err <- serverHealth - = Left GrpcException { - grpcError = GrpcUnknown - , grpcErrorMessage = Just $ Text.pack $ show err - , grpcErrorMetadata = [] - } - - | Alive <- serverHealth - = Right x - -{------------------------------------------------------------------------------- - Server-side interpretation - - The server-side is slightly different, since the infrastructure spawns - threads on our behalf (one for each incoming RPC). --------------------------------------------------------------------------------} - -serverGlobal :: - HasCallStack - => TestClock - -> MVar GlobalSteps - -> Server.Call TestRpc - -> IO () -serverGlobal testClock globalStepsVar call = do - serverLocal testClock call - =<< modifyMVar globalStepsVar (getNextSteps . getGlobalSteps) - where - -- The client starts new calls from a /single/ thread, so the order that - -- these come in is determinstic. Here we peel off the next 'LocalSteps' - -- whilst holding the lock to ensure the same determinism server-side. - getNextSteps :: [(TestClockTick, GlobalStep)] -> IO (GlobalSteps, LocalSteps) - getNextSteps [] = - throwM $ TestFailure callStack $ UnexpectedRequest - -- We don't care about the timeout the client sets; if the server takes too - -- long, the /client/ will check that it gets the expected exception. - getNextSteps ((tick, Spawn _timeout metadata localSteps) : global') = do - receivedMetadata <- Server.getRequestMetadata call - expect metadata $ receivedMetadata - waitForTestClock testClock tick - return (GlobalSteps global', localSteps) - -serverLocal :: - TestClock - -> Server.Call TestRpc - -> LocalSteps -> IO () -serverLocal testClock call = - flip evalStateT Alive . mapM_ (uncurry go) . 
getLocalSteps - where - go :: TestClockTick -> LocalStep -> StateT ClientHealth IO () - go tick = \case - InitiateResponse metadata -> liftIO $ do - waitForTestClock testClock tick - Server.setResponseMetadata call metadata - void $ Server.initiateResponse call - ClientToServer expectedElem -> do - expected <- adjustExpectation expectedElem - expect expected =<< liftIO (try $ Server.Binary.recvInput call) - ServerToClient x -> do - waitForTestClock testClock tick - expected <- adjustExpectation () - expect expected =<< liftIO (try $ Server.Binary.sendOutput call x) - ServerException -> do - waitForTestClock testClock tick - throwM $ SomeServerException tick - ClientException -> do - previousClientFailure <- adjustExpectation () - case previousClientFailure of - Left _ -> return () - Right () -> put $ Failed (SomeClientException tick) - - -- Adjust expectation when communicating with the server - -- - -- If the server handler died for some reason, we won't get the regular - -- result, but should instead see the exception reported to the client. - adjustExpectation :: forall a. - a - -> StateT ClientHealth IO (Either ClientDisconnected a) - adjustExpectation x = - return . aux =<< get - where - aux :: ClientHealth -> Either ClientDisconnected a - aux clientHealth - | Failed _err <- clientHealth - = Left ClientDisconnected - - | Alive <- clientHealth - = Right x - --- | Exception raised in a handler when the client disappeared --- --- TODO: This is merely wishful thinking at the moment; this needs to move to --- the main library and be implemented there. 
-data ClientDisconnected = ClientDisconnected - deriving stock (Show, Eq) - deriving anyclass (Exception) - -{------------------------------------------------------------------------------- - Top-level --------------------------------------------------------------------------------} - -execGlobalSteps :: GlobalSteps -> (ClientServerTest -> IO a) -> IO a -execGlobalSteps steps k = do - globalStepsVar <- newMVar steps - withTestClock $ \testClock -> - k $ def { - client = \conn -> clientGlobal testClock conn steps - , server = [ Server.mkRpcHandler (Proxy @TestRpc) $ - serverGlobal testClock globalStepsVar - ] - } - -{------------------------------------------------------------------------------- - QuickCheck: simple values - - We don't (usually) want to have to look at tests failures with unnecessarily - complicated names or values, this simply distracts. The only case where we - /do/ want that is if we are specifically testing parsers/builders. --------------------------------------------------------------------------------} - -simpleHeaderName :: [HeaderName] -simpleHeaderName = ["md1", "md2", "md3"] - -simpleAsciiValue :: [AsciiValue] -simpleAsciiValue = ["a", "b", "c"] - -{------------------------------------------------------------------------------- - QuickCheck --------------------------------------------------------------------------------} - -data LocalGenState = LocalGenState { - generatedInitiateResponse :: Bool - , clientSentFinalMessage :: Bool - , serverSentFinalMessage :: Bool - } - -initLocalGenState :: LocalGenState -initLocalGenState = LocalGenState { - generatedInitiateResponse = False - , clientSentFinalMessage = False - , serverSentFinalMessage = False - } - --- | Generate 'LocalStep's --- --- We are careful which actions we generate: the goal of these tests is to --- ensure correct behaviour of the library, /provided/ the library is used --- correctly. --- --- We do not insert timing here; we do this globally. 
-genLocalSteps :: Gen [LocalStep] -genLocalSteps = sized $ \sz -> do - n <- choose (0, sz) - flip evalStateT initLocalGenState $ - replicateM n (StateT go) - where - go :: LocalGenState -> Gen (LocalStep, LocalGenState) - go st = oneof $ concat [ - [ (,st) <$> goStateless ] - , [ clientSendMessage st - | not (clientSentFinalMessage st) - ] - , [ serverSendMessage st - | not (serverSentFinalMessage st) - ] - , [ do metadata <- genMetadata - return ( - InitiateResponse metadata - , st { generatedInitiateResponse = True } - ) - | not (generatedInitiateResponse st) - ] - ] - - -- Steps that do not depend on the 'LocalGenState' - goStateless :: Gen LocalStep - goStateless = oneof [ - return ClientException - , return ServerException - ] - - -- Precondition: @not clientSentFinalMessage@ - clientSendMessage :: LocalGenState -> Gen (LocalStep, LocalGenState) - clientSendMessage st = oneof [ - do msg <- genMsg - return ( - ClientToServer $ StreamElem msg - , st - ) - , do msg <- genMsg - return ( - ClientToServer $ FinalElem msg () - , st { clientSentFinalMessage = True } - ) - , do return ( - ClientToServer $ NoMoreElems () - , st { clientSentFinalMessage = True } - ) - ] - - -- Precondition: @not serverSentFinalMessage@ - serverSendMessage :: LocalGenState -> Gen (LocalStep, LocalGenState) - serverSendMessage st = oneof [ - do msg <- genMsg - return ( - ServerToClient (StreamElem msg) - , st - ) - , do msg <- genMsg - metadata <- genMetadata - return ( - ServerToClient $ FinalElem msg metadata - , st { serverSentFinalMessage = True } - ) - , do metadata <- genMetadata - return ( - ServerToClient $ NoMoreElems metadata - , st { serverSentFinalMessage = True } - ) - ] - - genMsg :: Gen Int - genMsg = choose (0, 99) - --- | Generate 'GlobalSteps' by assigning timing to local threads -genGlobalSteps :: [[LocalStep]] -> Gen GlobalSteps -genGlobalSteps = - -- start with clocktick 2, so that we have space for a timeout before - go [] (TestClockTick 2) . 
mapMaybe initAcc - where - initAcc :: [a] -> Maybe (Maybe b, NonEmpty a) - initAcc [] = Nothing -- filter out empty threads - initAcc (s:ss) = Just (Nothing, s :| ss) - - go :: - -- Fully completed threads - [(TestClockTick, LocalSteps)] - -- Next available clock tick (incremented by two each step) - -> TestClockTick - -- Not yet completed threads, with an accumulator per thread - -> [(Maybe (TestClockTick, LocalSteps), NonEmpty LocalStep)] - -> Gen GlobalSteps - go completed _ [] = do - GlobalSteps <$> mapM mkGlobalStep (sortBy (comparing fst) completed) - go completed now todo = do - (t1, (mSpawned, left@(s :| ss)), t2) <- isolate todo - case mSpawned of - Nothing -> - go completed - (succ (succ now)) - (t1 ++ [(Just (now, LocalSteps []), left)] ++ t2) - Just (spawned, LocalSteps acc) -> do - let acc' = ((now, s) : acc) - case ss of - [] -> - go ((spawned, LocalSteps (reverse acc')) : completed) - (succ (succ now)) - (t1 ++ t2) - s':ss' -> - go completed - (succ (succ now)) - (t1 ++ [(Just (spawned, LocalSteps acc'), s' :| ss')] ++ t2) - - -- Once we have assigned timestamps to all actions, the only thing left to - -- do is to assign timeouts and initial metadata - mkGlobalStep :: - (TestClockTick, LocalSteps) - -> Gen (TestClockTick, GlobalStep) - mkGlobalStep (spawned, steps) = do - metadata <- genMetadata - -- Too many timeouts might reduce test coverage - genTimeout <- frequency [(3, return False), (1, return True)] - timeout <- if genTimeout - then Just <$> elements (pickTimeout steps) - else return Nothing - return (spawned, Spawn timeout metadata steps) - -pickTimeout :: LocalSteps -> [TestClockTick] -pickTimeout = concatMap (\t -> [pred t, succ t]) . map fst . 
getLocalSteps - -genMetadata :: Gen [CustomMetadata] -genMetadata = do - n <- choose (0, 2) - replicateM n $ oneof [ - BinaryHeader <$> genHeaderName <*> genBinaryValue - , AsciiHeader <$> genHeaderName <*> genAsciiValue - ] - where - -- TODO: Test invalid names - -- TODO: Test "awkward" names - -- TODO: Does it matter if there are duplicates..? - genHeaderName :: Gen HeaderName - genHeaderName = elements simpleHeaderName - - genBinaryValue :: Gen Strict.ByteString - genBinaryValue = sized $ \sz -> do - n <- choose (0, sz) - Strict.pack <$> replicateM n arbitrary - - -- TODO: Test invalid values - -- TODO: Test "awkward" values - genAsciiValue :: Gen AsciiValue - genAsciiValue = elements simpleAsciiValue - -{------------------------------------------------------------------------------- - Shrinking - - Shrinking is simpler than generation: we need 'LocalGenState' to prevent - generating some actions /after/ certain other actions, but it's fine to - /remove/ arbitrary actions from the list. --------------------------------------------------------------------------------} - -shrinkLocalSteps :: LocalSteps -> [LocalSteps] -shrinkLocalSteps = - map LocalSteps - . shrinkList (shrinkTimed shrinkLocalStep) - . 
getLocalSteps - -shrinkLocalStep :: LocalStep -> [LocalStep] -shrinkLocalStep = \case - InitiateResponse metadata -> - map InitiateResponse $ shrinkMetadataList metadata - ClientToServer x -> - map ClientToServer $ shrinkElem (const []) x - ServerToClient x -> - map ServerToClient $ shrinkElem shrinkMetadataList x - ClientException -> - [] - ServerException -> - [] - -shrinkGlobalStep :: GlobalStep -> [GlobalStep] -shrinkGlobalStep = \case - Spawn mTimeout metadata steps -> concat [ - [ Spawn mTimeout metadata steps' - | steps' <- shrinkLocalSteps steps - ] - , [ Spawn mTimeout metadata' steps - | metadata' <- shrinkMetadataList metadata - ] - - -- Disable timeout - , [ Spawn Nothing metadata steps - | Just _timeout <- [mTimeout] - ] - - -- Timeouts complicate the interpretation of a test, so moving the - -- timeout later in the test reduces how many we need to consider the - -- timeout, and hence results in a simpler test - , [ Spawn (Just timeout') metadata steps - | Just timeout <- [mTimeout] - , timeout' <- pickTimeout steps - , timeout' > timeout - ] - ] - --- TODO: If duplicate header names are a problem (see generation), we need to be --- more careful there too -shrinkMetadataList :: [CustomMetadata] -> [[CustomMetadata]] -shrinkMetadataList = shrinkList shrinkMetadata - -shrinkMetadata :: CustomMetadata -> [CustomMetadata] -shrinkMetadata (BinaryHeader nm val) = concat [ - [ BinaryHeader nm' val - | nm' <- filter (< nm) simpleHeaderName - ] - , [ BinaryHeader nm (Strict.pack val') - | val' <- shrink (Strict.unpack val) - ] - , [ AsciiHeader nm val' - | val' <- simpleAsciiValue - ] - ] -shrinkMetadata (AsciiHeader nm val) = concat [ - [ AsciiHeader nm' val - | nm' <- filter (< nm) simpleHeaderName - ] - , [ AsciiHeader nm val' - | val' <- filter (< val) simpleAsciiValue - ] - ] - --- TODO: We currently don't change the nature of the elem. 
Not sure what the --- right definition of "simpler" is here -shrinkElem :: (a -> [a]) -> StreamElem a Int -> [StreamElem a Int] -shrinkElem _ (StreamElem x) = concat [ - [ StreamElem x' - | x' <- shrink x - ] - ] -shrinkElem f (FinalElem x y) = concat [ - [ FinalElem x' y - | x' <- shrink x - ] - , [ FinalElem x y' - | y' <- f y - ] - ] -shrinkElem f (NoMoreElems y) = concat [ - [ NoMoreElems y' - | y' <- f y - ] - ] - -shrinkTimed :: (a -> [a]) -> (TestClockTick, a) -> [(TestClockTick, a)] -shrinkTimed f (tick, a) = (tick, ) <$> f a - -{------------------------------------------------------------------------------- - 'Arbitrary' instance - - TODO: When we shrink, we might drop steps, but we don't change timings. - This shouldn't really matter, except that it might mean that the test takes - longer than it really needs to. --------------------------------------------------------------------------------} - -instance Arbitrary GlobalSteps where - arbitrary = do - concurrency <- choose (1, 3) - threads <- replicateM concurrency genLocalSteps - genGlobalSteps threads - - shrink (GlobalSteps steps) = - GlobalSteps <$> shrinkList (shrinkTimed shrinkGlobalStep) steps - -{------------------------------------------------------------------------------- - Auxiliary: QuickCheck --------------------------------------------------------------------------------} - -isolate :: [a] -> Gen ([a], a, [a]) -isolate = \case - [] -> error "isolate: empty list" - xs -> do n <- choose (0, length xs - 1) - return $ go [] xs n - where - go :: [a] -> [a] -> Int -> ([a], a, [a]) - go _ [] _ = error "isolate: impossible" - go prev (x:xs) 0 = (reverse prev, x, xs) - go prev (x:xs) n = go (x:prev) xs (pred n) - +import Test.Driver.Dialogue.Definition as X +import Test.Driver.Dialogue.Execution as X +import Test.Driver.Dialogue.Generation as X +import Test.Driver.Dialogue.TestClock as X diff --git a/test-grapesy/Test/Driver/Dialogue/Definition.hs b/test-grapesy/Test/Driver/Dialogue/Definition.hs 
new file mode 100644 index 00000000..2f364c87 --- /dev/null +++ b/test-grapesy/Test/Driver/Dialogue/Definition.hs @@ -0,0 +1,83 @@ +module Test.Driver.Dialogue.Definition ( + -- * Local + LocalStep(..) + , RPC(..) + -- * Bird's-eye view + , GlobalSteps(..) + , LocalSteps(..) + ) where + +import Data.Set (Set) +import GHC.Generics qualified as GHC +import Text.Show.Pretty + +import Network.GRPC.Common + +import Test.Driver.Dialogue.TestClock +import Test.Util.PrettyVal () + +{------------------------------------------------------------------------------- + Single channel + + TODO: We should test that the Trailers-Only case gets triggered if no messages + were exchanged before the exception (not sure this is observable without + Wireshark..?). + + TODO: Test what happens when either peer simply disappears. +-------------------------------------------------------------------------------} + +data LocalStep = + -- | Client initiates request to the server + -- + -- When the client initiates a request, they can specify a timeout, initial + -- metadata for the request, as well as which endpoint to connect to. + -- + -- This must happen before anything else. + ClientInitiateRequest (Maybe TestClockTick) Metadata RPC + + -- | Server initiates response to the client + -- + -- If there is no explicit 'ServerInitiate', there will be an implicit + -- one, with empty metadata. 
+ | ServerInitiateResponse Metadata + + -- | Client sends a message the server + | ClientToServer (StreamElem NoMetadata Int) + + -- | Server sends a message to the client + | ServerToClient (StreamElem Metadata Int) + + -- | Client throws an exception + | ClientException + + -- | Server throws an exception + | ServerException + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (PrettyVal) + +data RPC = RPC1 | RPC2 | RPC3 + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (PrettyVal) + +-- | Metadata +-- +-- We use 'Set' for 'CustomMetadata' rather than a list, because we do not +-- want to test that the /order/ of the metadata is matched. +type Metadata = Set CustomMetadata + +{------------------------------------------------------------------------------- + Many channels (bird's-eye view) +-------------------------------------------------------------------------------} + +newtype LocalSteps = LocalSteps { + getLocalSteps :: [(TestClockTick, LocalStep)] + } + deriving stock (Show, GHC.Generic) + deriving anyclass (PrettyVal) + +newtype GlobalSteps = GlobalSteps { + getGlobalSteps :: [LocalSteps] + } + deriving stock (Show, GHC.Generic) + deriving anyclass (PrettyVal) + diff --git a/test-grapesy/Test/Driver/Dialogue/Execution.hs b/test-grapesy/Test/Driver/Dialogue/Execution.hs new file mode 100644 index 00000000..92838bf5 --- /dev/null +++ b/test-grapesy/Test/Driver/Dialogue/Execution.hs @@ -0,0 +1,372 @@ +module Test.Driver.Dialogue.Execution ( + execGlobalSteps + ) where + +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.Catch +import Control.Monad.State +import Data.Bifunctor +import Data.Default +import Data.Proxy +import Data.Set qualified as Set +import Data.Text qualified as Text +import GHC.Generics qualified as GHC +import GHC.Stack +import Text.Show.Pretty + +import Network.GRPC.Client qualified as Client +import Network.GRPC.Client.Binary qualified as 
Client.Binary +import Network.GRPC.Common +import Network.GRPC.Common.Binary +import Network.GRPC.Server qualified as Server +import Network.GRPC.Server.Binary qualified as Server.Binary + +import Test.Driver.ClientServer +import Test.Driver.Dialogue.Definition +import Test.Driver.Dialogue.TestClock +import Test.Util.PrettyVal + +{------------------------------------------------------------------------------- + Endpoints +-------------------------------------------------------------------------------} + +type TestRpc1 = BinaryRpc "dialogue" "test1" +type TestRpc2 = BinaryRpc "dialogue" "test2" +type TestRpc3 = BinaryRpc "dialogue" "test3" + +withProxy :: + RPC + -> (forall serv meth. + IsRPC (BinaryRpc serv meth) + => Proxy (BinaryRpc serv meth) + -> a) + -> a +withProxy RPC1 k = k (Proxy @TestRpc1) +withProxy RPC2 k = k (Proxy @TestRpc2) +withProxy RPC3 k = k (Proxy @TestRpc3) + +{------------------------------------------------------------------------------- + Test failures +-------------------------------------------------------------------------------} + +data TestFailure = TestFailure PrettyCallStack Failure + deriving stock (GHC.Generic) + deriving anyclass (Exception, PrettyVal) + deriving Show via ShowAsPretty TestFailure + +data Failure = + -- | Thrown by the server when an unexpected new RPC is initiated + UnexpectedRequest + + -- | Received an unexpected value + | Unexpected ReceivedUnexpected + deriving stock (Show, GHC.Generic) + deriving anyclass (Exception, PrettyVal) + +data ReceivedUnexpected = forall a. 
(Show a, PrettyVal a) => ReceivedUnexpected { + expected :: a + , received :: a + } + +deriving stock instance Show ReceivedUnexpected + +instance PrettyVal ReceivedUnexpected where + prettyVal (ReceivedUnexpected{expected, received}) = + Rec "ReceivedUnexpected" [ + ("expected", prettyVal expected) + , ("received", prettyVal received) + ] + +expect :: + (MonadThrow m, Eq a, Show a, PrettyVal a, HasCallStack) + => a -- ^ Expected + -> a -- ^ Actually received + -> m () +expect expected received + | expected == received + = return () + + | otherwise + = throwM $ TestFailure (PrettyCallStack callStack) $ + Unexpected $ ReceivedUnexpected{expected, received} + +{------------------------------------------------------------------------------- + Health +-------------------------------------------------------------------------------} + +-- | Health +-- +-- When the client is expecting a response from the server, it needs to know the +-- "health" of the server, that is, is the server still alive, or did it fail +-- with some kind exception? The same is true for the server when it expects a +-- response from the client. Therefore, the client interpretation keeps track of +-- the health of the server, and vice versa. +data Health e = Alive | Failed e + +type ServerHealth = Health SomeServerException +type ClientHealth = Health SomeClientException + +{------------------------------------------------------------------------------- + Exceptions + + When the dialogue calls for the client or the server to throw an exception, + we throw one of these. Their sole purpose is to be "any" kind of exception + (not a specific one). The 'TestClockTick' argument just makes it easier to + relate the exception back to the dialogue. 
+-------------------------------------------------------------------------------} + +data SomeServerException = SomeServerException TestClockTick + deriving stock (Show) + deriving anyclass (Exception) + +data SomeClientException = SomeClientException TestClockTick + deriving stock (Show) + deriving anyclass (Exception) + +{------------------------------------------------------------------------------- + Client-side interpretation +-------------------------------------------------------------------------------} + +clientLocal :: + HasCallStack + => TestClock + -> Client.Call (BinaryRpc meth srv) + -> Maybe TestClockTick -- ^ Timeout specified to the server + -> LocalSteps + -> IO () +clientLocal testClock call _serverTimeout = + flip evalStateT Alive . mapM_ (uncurry go) . getLocalSteps + where + go :: TestClockTick -> LocalStep -> StateT ServerHealth IO () + go tick = \case + ClientInitiateRequest{} -> + error "clientLocal: unexpected ClientInitiateRequest" + ServerInitiateResponse expectedMetadata -> liftIO $ do + receivedMetadata <- atomically $ Client.recvResponseMetadata call + expect expectedMetadata $ Set.fromList receivedMetadata + ClientToServer x -> do + waitForTestClock testClock tick + expected <- adjustExpectation tick () + expect expected =<< liftIO (try $ Client.Binary.sendInput call x) + + ServerToClient (FinalElem a b) -> do + -- Known bug (limitation in http2). See recvMessageLoop. + go tick $ ServerToClient (StreamElem a) + go tick $ ServerToClient (NoMoreElems b) + + ServerToClient expectedElem -> do + expected <- adjustExpectation tick expectedElem + received <- liftIO . 
try $ + fmap (first Set.fromList) $ + Client.Binary.recvOutput call + expect expected received + ClientException -> do + waitForTestClock testClock tick + throwM $ SomeClientException tick + ServerException -> do + previousServerFailure <- adjustExpectation tick () + case previousServerFailure of + Left _ -> return () + Right () -> put $ Failed (SomeServerException tick) + + -- Adjust expectation when communicating with the server + -- + -- If the server handler died for some reason, we won't get the regular + -- result, but should instead see the exception reported to the client. + adjustExpectation :: forall a. + TestClockTick + -> a + -> StateT ServerHealth IO (Either GrpcException a) + adjustExpectation _now x = + return . aux =<< get + where + aux :: ServerHealth -> Either GrpcException a + aux serverHealth + -- The gRPC specification does not actually say anywhere which gRPC + -- error code should be used for timeouts, but + -- suggests @CANCELLED@. + -- + -- TODO: Enable this once we have support timeouts. + -- | now `exceedsTimeout` serverTimeout + -- = Left GrpcException { + -- grpcError = GrpcCancelled + -- , grpcErrorMessage = Nothing + -- , grpcErrorMetadata = [] + -- } + + | Failed err <- serverHealth + = Left GrpcException { + grpcError = GrpcUnknown + , grpcErrorMessage = Just $ Text.pack $ show err + , grpcErrorMetadata = [] + } + + | Alive <- serverHealth + = Right x + +clientGlobal :: TestClock -> Client.Connection -> GlobalSteps -> IO () +clientGlobal testClock conn = \(GlobalSteps globalSteps) -> + go [] globalSteps + where + go :: [Async ()] -> [LocalSteps] -> IO () + go threads [] = + -- Wait for all threads to finish + -- + -- This also ensures that if any of these threads threw an exception, + -- that is now rethrown here in the main test. 
+ mapM_ wait threads + go threads (c:cs) = + withAsync (runLocalSteps c) $ \newThread -> + go (newThread:threads) cs + + runLocalSteps :: LocalSteps -> IO () + runLocalSteps (LocalSteps steps) = do + case steps of + (tick, ClientInitiateRequest timeout metadata rpc) : steps' -> do + waitForTestClock testClock tick + + -- TODO: We are setting the timeout here, but actually the grpc + -- server infrastructure doesn't respect timeouts yet, so this is + -- not actually testing anything (because we're also not testing + -- for expected timeouts). + + let params :: Client.CallParams + params = Client.CallParams { + callTimeout = tickToTimeout <$> timeout + , callRequestMetadata = Set.toList metadata + } + + withProxy rpc $ \proxy -> + Client.withRPC conn params proxy $ \call -> + clientLocal testClock call timeout (LocalSteps steps') + + _otherwise -> + error "clientGlobal: expected ClientInitiateRequest" + +{------------------------------------------------------------------------------- + Server-side interpretation + + The server-side is slightly different, since the infrastructure spawns + threads on our behalf (one for each incoming RPC). +-------------------------------------------------------------------------------} + +serverLocal :: + TestClock + -> Server.Call (BinaryRpc serv meth) + -> LocalSteps -> IO () +serverLocal testClock call = \(LocalSteps steps) -> do + flip evalStateT Alive $ mapM_ (uncurry go) steps + where + go :: TestClockTick -> LocalStep -> StateT ClientHealth IO () + go tick = \case + ClientInitiateRequest{} -> + error "serverLocal: unexpected ClientInitiateRequest" + ServerInitiateResponse metadata -> liftIO $ do + waitForTestClock testClock tick + Server.setResponseMetadata call (Set.toList metadata) + void $ Server.initiateResponse call + + ClientToServer (FinalElem a b) -> do + -- Known bug (limitation in http2). See recvMessageLoop. 
+ go tick $ ClientToServer (StreamElem a) + go tick $ ClientToServer (NoMoreElems b) + + ClientToServer expectedElem -> do + expected <- adjustExpectation expectedElem + expect expected =<< liftIO (try $ Server.Binary.recvInput call) + ServerToClient x -> do + waitForTestClock testClock tick + expected <- adjustExpectation () + received <- try . liftIO $ + Server.Binary.sendOutput call (first Set.toList x) + expect expected received + ServerException -> do + waitForTestClock testClock tick + throwM $ SomeServerException tick + ClientException -> do + previousClientFailure <- adjustExpectation () + case previousClientFailure of + Left _ -> return () + Right () -> put $ Failed (SomeClientException tick) + + -- Adjust expectation when communicating with the server + -- + -- If the server handler died for some reason, we won't get the regular + -- result, but should instead see the exception reported to the client. + adjustExpectation :: forall a. + a + -> StateT ClientHealth IO (Either ClientDisconnected a) + adjustExpectation x = + return . aux =<< get + where + aux :: ClientHealth -> Either ClientDisconnected a + aux clientHealth + | Failed _err <- clientHealth + = Left ClientDisconnected + + | Alive <- clientHealth + = Right x + +serverGlobal :: + HasCallStack + => TestClock + -> MVar GlobalSteps + -- ^ Unlike in the client case, the grapesy infrastructure spawns a new + -- thread for each incoming connection. To know which part of the test this + -- particular handler corresponds to, we take the next 'LocalSteps' from + -- this @MVar@. Since all requests are started by the client from /one/ + -- thread, the order of these incoming requests is deterministic. + -> Server.Call (BinaryRpc serv meth) + -> IO () +serverGlobal testClock globalStepsVar call = do + localSteps <- modifyMVar globalStepsVar (getNextSteps . 
getGlobalSteps) + serverLocal testClock call localSteps + where + getNextSteps :: [LocalSteps] -> IO (GlobalSteps, LocalSteps) + getNextSteps [] = + throwM $ TestFailure (PrettyCallStack callStack) $ UnexpectedRequest + getNextSteps (LocalSteps steps:global') = + case steps of + (tick, ClientInitiateRequest _timeout metadata _rpc) : steps' -> do + -- We don't care about the timeout the client sets; if the server + -- takes too long, the /client/ will check that it gets the expected + -- exception. + receivedMetadata <- Server.getRequestMetadata call + expect metadata $ Set.fromList receivedMetadata + waitForTestClock testClock tick + return (GlobalSteps global', LocalSteps steps') + _otherwise -> + error "serverGlobal: expected ClientInitiateRequest" + +-- | Exception raised in a handler when the client disappeared +-- +-- TODO: This is merely wishful thinking at the moment; this needs to move to +-- the main library and be implemented there. +data ClientDisconnected = ClientDisconnected + deriving stock (Show, Eq, GHC.Generic) + deriving anyclass (Exception, PrettyVal) + +{------------------------------------------------------------------------------- + Top-level +-------------------------------------------------------------------------------} + +execGlobalSteps :: GlobalSteps -> (ClientServerTest -> IO a) -> IO a +execGlobalSteps steps k = do + globalStepsVar <- newMVar steps + withTestClock $ \testClock -> do + mRes :: Either SomeException a <- try $ k $ def { + client = \conn -> clientGlobal testClock conn steps + , server = [ Server.mkRpcHandler (Proxy @TestRpc1) $ + serverGlobal testClock globalStepsVar + , Server.mkRpcHandler (Proxy @TestRpc2) $ + serverGlobal testClock globalStepsVar + , Server.mkRpcHandler (Proxy @TestRpc3) $ + serverGlobal testClock globalStepsVar + ] + } + case mRes of + Left err -> throwM err + Right a -> return a diff --git a/test-grapesy/Test/Driver/Dialogue/Generation.hs b/test-grapesy/Test/Driver/Dialogue/Generation.hs new file mode 
100644 index 00000000..d3753d44 --- /dev/null +++ b/test-grapesy/Test/Driver/Dialogue/Generation.hs @@ -0,0 +1,620 @@ +{-# LANGUAGE OverloadedStrings #-} + +module Test.Driver.Dialogue.Generation ( + Dialogue -- opaque + , dialogueGlobalSteps + , DialogueWithoutExceptions(..) + , DialogueWithExceptions(..) + ) where + +import Control.Monad +import Data.ByteString qualified as Strict +import Data.List.NonEmpty (NonEmpty(..)) +import Data.List.NonEmpty qualified as NE +import Data.Map (Map) +import Data.Map qualified as Map +import Data.Maybe (mapMaybe) +import Data.Set (Set) +import Data.Set qualified as Set +import GHC.Generics qualified as GHC +import Test.QuickCheck +import Text.Show.Pretty + +import Network.GRPC.Common + +import Test.Util.PrettyVal +import Test.Driver.Dialogue.TestClock +import Test.Driver.Dialogue.Definition + +{------------------------------------------------------------------------------- + Metadata + + We don't (usually) want to have to look at tests failures with unnecessarily + complicated names or values, this simply distracts. The only case where we + /do/ want that is if we are specifically testing parsers/builders. +-------------------------------------------------------------------------------} + +simpleHeaderName :: [HeaderName] +simpleHeaderName = ["md1", "md2", "md3"] + +simpleAsciiValue :: [AsciiValue] +simpleAsciiValue = ["a", "b", "c"] + +genMetadata :: Gen (Set CustomMetadata) +genMetadata = do + n <- choose (0, 2) + + -- TODO: For now we avoid generating duplicates; it's not entirely clear + -- what the semantics of that is. 
+ names <- replicateM n genHeaderName `suchThat` allDisjoint + fmap Set.fromList $ forM names $ \nm -> oneof [ + BinaryHeader nm <$> genBinaryValue + , AsciiHeader nm <$> genAsciiValue + ] + where + -- TODO: Test invalid names + -- TODO: Test "awkward" names + -- (Perhaps invalid/awkward names should be separate property tests) + genHeaderName :: Gen HeaderName + genHeaderName = elements simpleHeaderName + + genBinaryValue :: Gen BinaryValue + genBinaryValue = sized $ \sz -> do + n <- choose (0, sz) + BinaryValue . Strict.pack <$> replicateM n arbitrary + + -- TODO: Test invalid values + -- TODO: Test "awkward" values + genAsciiValue :: Gen AsciiValue + genAsciiValue = elements simpleAsciiValue + +{------------------------------------------------------------------------------- + Local steps +-------------------------------------------------------------------------------} + +-- | Generate 'LocalStep's +-- +-- We are careful which actions we generate: the goal of these tests is to +-- ensure correct behaviour of the library, /provided/ the library is used +-- correctly. +-- +-- We do not insert timing here; we do this globally. +genLocalSteps :: + Bool -- ^ Should we generate exceptions? + -> Gen [LocalStep] +genLocalSteps genExceptions = sized $ \sz -> do + n <- choose (0, sz) + (:) <$> genInitStep <*> replicateM n genStep + where + genInitStep :: Gen LocalStep + genInitStep = do + mTimeout <- frequency [ + (3, return Nothing) + , (1, Just . 
mkTimeout <$> choose (0, 50)) + ] + metadata <- genMetadata + rpc <- elements [RPC1, RPC2, RPC3] + return $ ClientInitiateRequest mTimeout metadata rpc + where + -- Timeouts are always odd numbers (see 'TestClock') + mkTimeout :: Int -> TestClockTick + mkTimeout x = TestClockTick (x * 2 + 1) + + genStep :: Gen LocalStep + genStep = frequency [ + (if genExceptions then 1 else 0, genException) + , (3, ServerInitiateResponse <$> genMetadata) + , (3, ClientToServer <$> genElem (pure NoMetadata)) + , (3, ServerToClient <$> genElem genMetadata) + ] + + genException :: Gen LocalStep + genException = elements [ClientException, ServerException] + + genElem :: Gen b -> Gen (StreamElem b Int) + genElem genTrailers = oneof [ + StreamElem <$> genMsg + , FinalElem <$> genMsg <*> genTrailers + , NoMoreElems <$> genTrailers + ] + + genMsg :: Gen Int + genMsg = choose (0, 99) + +{------------------------------------------------------------------------------- + Ensure correct library usage + + We have two essentially different kinds of tests: correctness of the library + given correct library usage, and reasonable error reporting given incorrect + library usage. We focus on the former here. 
+-------------------------------------------------------------------------------} + +data LocalGenState = LocalGenState { + clientInitiatedRequest :: Bool + , serverInitiatedResponse :: Bool + , clientSentFinalMessage :: Bool + , serverSentFinalMessage :: Bool + , clientThrewException :: Bool + , serverThrewException :: Bool + } + +initLocalGenState :: LocalGenState +initLocalGenState = LocalGenState { + clientInitiatedRequest = False + , serverInitiatedResponse = False + , clientSentFinalMessage = False + , serverSentFinalMessage = False + , clientThrewException = False + , serverThrewException = False + } + +ensureCorrectUsage :: [(Int, LocalStep)] -> [(Int, LocalStep)] +ensureCorrectUsage = go Map.empty [] + where + go :: + Map Int LocalGenState -- State of each concurrent conversation + -> [(Int, LocalStep)] -- Accumulator (reverse order) + -> [(Int, LocalStep)] -- Still to consider + -> [(Int, LocalStep)] -- Result + go sts acc [] = concat [ + reverse acc + , concatMap (\(i, st) -> (i,) <$> ensureCleanClose st) $ Map.toList sts + ] + go sts acc ((i, s):ss) = case s of + -- Request must be initiated before any messages can be sent + -- + -- During generation we will generate this 'ClientInitiate' step as the + -- first step for each channel, which is then interleaved with all the + -- other steps. If however during shrinking that 'ClientInitiate' got + -- removed, we have to insert it before the first action. + -- + -- We don't have to worry about /multiple/ ClientInitiate messages (they + -- are not generated). 
+ + ClientInitiateRequest{} -> + go (upd st{clientInitiatedRequest = True}) ((i, s) : acc) ss + + _anyStep | not (clientInitiatedRequest st) -> + go sts acc $ (i, ClientInitiateRequest Nothing Set.empty RPC1) + : (i, s) + : ss + + -- Response can only be initiated once, and is initiated implicitly + -- on the first response if not initiated explicitly + + ServerInitiateResponse{} | serverInitiatedResponse st -> + go sts acc ss + + ServerInitiateResponse{} -> + go (upd st{serverInitiatedResponse = True}) ((i, s) : acc) ss + + ServerToClient{} | not (serverInitiatedResponse st) -> + go (upd st{serverInitiatedResponse = True}) acc ((i, s) : ss) + + -- Multiple exceptions arent't /really/ about " correct " usage, but + -- it's simply impossible: if a client throws an exception (and doens't + -- catch it), it just cannot throw another one. + + ClientException{} | clientThrewException st -> + go sts acc ss + + ClientException{} -> + go (upd st{clientThrewException = True}) ((i, s) : acc) ss + + ServerException{} | serverThrewException st -> + go sts acc ss + + ServerException{} -> + go (upd st{serverThrewException = True}) ((i, s) : acc) ss + + -- Make sure no messages are sent after the final one + + ClientToServer{} | clientSentFinalMessage st -> + go sts acc ss + + ClientToServer (StreamElem{}) -> + go sts ((i, s) : acc) ss + + ClientToServer{} -> + go (upd st{clientSentFinalMessage = True}) ((i, s) : acc) ss + + ServerToClient{} | serverSentFinalMessage st -> + go sts acc ss + + ServerToClient (StreamElem{}) -> + go sts ((i, s) : acc) ss + + ServerToClient{} -> + go (upd st{serverSentFinalMessage = True}) ((i, s) : acc) ss + + where + st :: LocalGenState + st = Map.findWithDefault initLocalGenState i sts + + upd :: LocalGenState -> Map Int LocalGenState + upd st' = Map.insert i st' sts + + -- Make sure all channels are closed cleanly + ensureCleanClose :: LocalGenState -> [LocalStep] + ensureCleanClose st = concat [ + [ ClientToServer $ NoMoreElems NoMetadata + | 
clientInitiatedRequest st + , not $ clientSentFinalMessage st + , not $ clientThrewException st + ] + + , [ ServerToClient $ NoMoreElems Set.empty + | clientInitiatedRequest st + , not $ serverSentFinalMessage st + , not $ serverThrewException st + ] + ] + +{------------------------------------------------------------------------------- + Dialogue + + We generate the steps for each channel separately first, and then choose a + particular interleaving ('interleave'). Before execution, we then assign + timings ('assignTimings'), separating the channels again. The advantage of + this representation is that it is easier to shrink: we can shrink the /global/ + step of steps, whilst keeping the choice of interleaving; that is much harder + to do when the channels are kept separate. + + In addition, we apply 'ensureCorrectUsage' just before execution. We do /not/ + do this as part of generation/shrinking, as doing so could shrinking + non-wellfounded (shrinking might remove some action which gets reinserted by + 'ensureCorrectUsage', and shrinking will loop). +-------------------------------------------------------------------------------} + +newtype Dialogue = Dialogue { + getDialogue :: [(Int, LocalStep)] + } + deriving stock (Eq, GHC.Generic) + deriving anyclass (PrettyVal) + deriving Show via ShowAsPretty Dialogue + +dialogueGlobalSteps :: Dialogue -> GlobalSteps +dialogueGlobalSteps = + GlobalSteps + . map LocalSteps + . assignTimings + . ensureCorrectUsage + . getDialogue + +-- | Shrink dialogue +-- +-- We ignore 'dialogueCleanup', and simply reconstruct it. See 'Dialogue' for +-- further discussion. +shrinkDialogue :: Dialogue -> [Dialogue] +shrinkDialogue = + map Dialogue + . shrinkList (shrinkInterleaved shrinkLocalStep) + . 
getDialogue + +{------------------------------------------------------------------------------- + Shrinking +-------------------------------------------------------------------------------} + +shrinkLocalStep :: LocalStep -> [LocalStep] +shrinkLocalStep = \case + ClientInitiateRequest mTimeout metadata rpc -> concat [ + -- Disable the timeout altogether + -- + -- Apart from disabling the timeout, it's not entirely clear what + -- constitutes a "simpler" timeout, so for now we don't attempt any + -- other shrinking of the timeout + [ ClientInitiateRequest Nothing metadata rpc + | Just _ <- [mTimeout] + ] + , [ ClientInitiateRequest mTimeout metadata' rpc + | metadata' <- shrinkMetadataSet metadata + ] + , [ ClientInitiateRequest mTimeout metadata rpc' + | rpc' <- shrinkRPC rpc + ] + ] + ServerInitiateResponse metadata -> + map ServerInitiateResponse $ shrinkMetadataSet metadata + ClientToServer x -> + map ClientToServer $ shrinkElem (const []) x + ServerToClient x -> + map ServerToClient $ shrinkElem shrinkMetadataSet x + ClientException -> + [] + ServerException -> + [] + +shrinkRPC :: RPC -> [RPC] +shrinkRPC RPC1 = [] +shrinkRPC RPC2 = [RPC1] +shrinkRPC RPC3 = [RPC1, RPC2] + +shrinkMetadataSet :: Set CustomMetadata -> [Set CustomMetadata] +shrinkMetadataSet = + map Set.fromList + . filter (allDisjoint . map customHeaderName) + . shrinkList shrinkMetadata + . 
Set.toList + +shrinkMetadata :: CustomMetadata -> [CustomMetadata] +shrinkMetadata (BinaryHeader nm (BinaryValue val)) = concat [ + -- prefer ASCII headers over binary headers + [ AsciiHeader nm val' + | val' <- simpleAsciiValue + ] + -- shrink the name + , [ BinaryHeader nm' (BinaryValue val) + | nm' <- filter (< nm) simpleHeaderName + ] + -- aggressively try to shrink to a single byte + , [ BinaryHeader nm (BinaryValue (Strict.pack [x])) + | x:_:_ <- [Strict.unpack val] + ] + -- normal shrinking of binary values + , [ BinaryHeader nm (BinaryValue (Strict.pack val')) + | val' <- shrink (Strict.unpack val) + ] + ] +shrinkMetadata (AsciiHeader nm val) = concat [ + [ AsciiHeader nm' val + | nm' <- filter (< nm) simpleHeaderName + ] + , [ AsciiHeader nm val' + | val' <- filter (< val) simpleAsciiValue + ] + ] + +-- TODO: We currently don't change the nature of the elem. Not sure what the +-- right definition of "simpler" is here +shrinkElem :: (a -> [a]) -> StreamElem a Int -> [StreamElem a Int] +shrinkElem _ (StreamElem x) = concat [ + [ StreamElem x' + | x' <- shrink x + ] + ] +shrinkElem f (FinalElem x y) = concat [ + [ FinalElem x' y + | x' <- shrink x + ] + , [ FinalElem x y' + | y' <- f y + ] + ] +shrinkElem f (NoMoreElems y) = concat [ + [ NoMoreElems y' + | y' <- f y + ] + ] + +shrinkInterleaved :: (a -> [a]) -> (Int, a) -> [(Int, a)] +shrinkInterleaved f (i, a) = (i, ) <$> f a + +{------------------------------------------------------------------------------- + Arbitrary instance +-------------------------------------------------------------------------------} + +newtype DialogueWithoutExceptions = DialogueWithoutExceptions Dialogue + deriving stock (Eq, GHC.Generic) + deriving anyclass (PrettyVal) + deriving Show via ShowAsPretty DialogueWithoutExceptions + +newtype DialogueWithExceptions = DialogueWithExceptions Dialogue + deriving stock (Eq, GHC.Generic) + deriving anyclass (PrettyVal) + deriving Show via ShowAsPretty DialogueWithExceptions + +instance 
Arbitrary DialogueWithoutExceptions where + arbitrary = do + concurrency <- choose (1, 1) -- TODO + threads <- replicateM concurrency $ genLocalSteps False + DialogueWithoutExceptions . Dialogue <$> interleave threads + + shrink (DialogueWithoutExceptions dialogue) = + DialogueWithoutExceptions <$> shrinkDialogue dialogue + +instance Arbitrary DialogueWithExceptions where + arbitrary = do + concurrency <- choose (1, 3) + threads <- replicateM concurrency $ genLocalSteps True + DialogueWithExceptions . Dialogue <$> interleave threads + + shrink (DialogueWithExceptions dialogue) = + DialogueWithExceptions <$> shrinkDialogue dialogue + +{------------------------------------------------------------------------------- + Auxiliary: Interleavings + + We assign timings /from/ interleavings. This is more suitable to shrinking; + during generation we pick an arbitrary interleaving, then as we shrink, we + leave the interleaving (mostly) alone, but re-assign timings. +-------------------------------------------------------------------------------} + +-- | Pick an arbitrary interleaving +-- +-- This flattens the list; each item in the result is an element from /one/ +-- of the input lists, annotated with the index of that particular list. +-- +-- > ghci> sample (interleave ["abc", "de"]) +-- > [(0,'a'),(1,'d'),(0,'b'),(0,'c'),(1,'e')] +-- > [(1,'d'),(0,'a'),(0,'b'),(1,'e'),(0,'c')] +-- > [(1,'d'),(1,'e'),(0,'a'),(0,'b'),(0,'c')] +interleave :: forall a. [[a]] -> Gen [(Int, a)] +interleave = + go . mapMaybe (\(i, xs) -> (i,) <$> NE.nonEmpty xs) . zip [0..] 
+ where + go :: [(Int, NonEmpty a)] -> Gen [(Int, a)] + go [] = return [] + go xs = do + (as, (i, b :| bs), cs) <- isolate xs + fmap ((i, b) :) $ + case bs of + [] -> go (as ++ cs) + b' : bs' -> go (as ++ [(i, b' :| bs')] ++ cs) + +-- | Assign timings, given an interleaving +-- +-- In some sense this is an inverse to 'interleave': +-- +-- > assignTimings [(0,'a'),(0,'b'),(1,'d'),(1,'e'),(0,'c')] +-- > == [ [ (TestClockTick 0,'a') +-- > , (TestClockTick 2,'b') +-- > , (TestClockTick 8,'c') +-- > ] +-- > , [ (TestClockTick 4,'d') +-- > ,( TestClockTick 6,'e') +-- > ] +-- > ] +-- +-- Put another way +-- +-- > map (map snd) . assignTimings <$> interleave xs +-- +-- will just generate @xs@. +assignTimings :: [(Int, a)] -> [[(TestClockTick, a)]] +assignTimings = go (TestClockTick 0) + where + go :: TestClockTick -> [(Int, a)] -> [[(TestClockTick, a)]] + go _ [] = [] + go t ((i, x):xs) = insert i (t, x) (go (succ (succ t)) xs) + + insert :: Int -> a -> [[a]] -> [[a]] + insert i x [] = replicate i [] ++ [[x]] + insert 0 x (xs:xss) = (x:xs) : xss + insert i x (xs:xss) = xs : insert (i - 1) x xss + +{------------------------------------------------------------------------------- + Auxiliary: QuickCheck +-------------------------------------------------------------------------------} + +isolate :: [a] -> Gen ([a], a, [a]) +isolate = \case + [] -> error "isolate: empty list" + xs -> do n <- choose (0, length xs - 1) + return $ go [] xs n + where + go :: [a] -> [a] -> Int -> ([a], a, [a]) + go _ [] _ = error "isolate: impossible" + go prev (x:xs) 0 = (reverse prev, x, xs) + go prev (x:xs) n = go (x:prev) xs (pred n) + +allDisjoint :: Ord a => [a] -> Bool +allDisjoint xs = Set.size (Set.fromList xs) == length xs + +{------------------------------------------------------------------------------- + TODO: Temp +-------------------------------------------------------------------------------} + + +{- +-- this shrinks to the same thing? 
+-- actually, we cannot see this here, we need the dialogue /before/ the +-- calling dialogueGlobalSteps. +example :: Dialogue +example = Dialogue { + dialogueMain = + [ ( 0 + , ClientInitiateRequest Nothing (Set.fromList []) RPC1 + ) + , ( 0 + , ServerInitiateResponse + (Set.fromList + [ BinaryHeader + "md1" + (BinaryValue + (Strict.pack + [ 120 + , 115 + , 227 + , 99 + , 152 + , 159 + , 235 + , 50 + , 193 + , 167 + , 26 + , 229 + , 209 + , 66 + , 78 + , 236 + , 49 + , 191 + , 62 + , 82 + , 138 + , 34 + , 221 + , 158 + , 39 + , 86 + , 229 + , 220 + , 133 + , 84 + , 196 + , 47 + , 175 + , 80 + , 72 + , 100 + , 209 + , 231 + , 14 + , 213 + , 141 + ])) + , BinaryHeader + "md3" + (BinaryValue + (Strict.pack + [ 187 + , 0 + , 148 + , 178 + , 43 + , 106 + , 45 + , 102 + , 73 + , 92 + , 140 + , 59 + , 206 + , 72 + , 18 + , 234 + , 59 + , 134 + , 29 + , 32 + , 225 + , 1 + , 102 + , 195 + , 235 + , 228 + , 34 + , 165 + , 45 + , 67 + , 220 + , 177 + , 39 + , 19 + , 12 + , 138 + , 190 + , 199 + , 183 + , 59 + ])) + ]) + ) + ] + , dialogueCleanup = [ + ( 0 , ClientToServer (NoMoreElems NoMetadata) ) + , ( 0 , ServerToClient (NoMoreElems (Set.fromList []))) + ] + } +-} \ No newline at end of file diff --git a/test-grapesy/Test/Driver/Dialogue/TestClock.hs b/test-grapesy/Test/Driver/Dialogue/TestClock.hs new file mode 100644 index 00000000..aed97dc0 --- /dev/null +++ b/test-grapesy/Test/Driver/Dialogue/TestClock.hs @@ -0,0 +1,82 @@ +module Test.Driver.Dialogue.TestClock ( + TestClock -- opaque + , TestClockTick(..) 
+ , withTestClock + , waitForTestClock + -- * Timeouts + , exceedsTimeout + , tickToTimeout + ) where + +import Control.Concurrent +import Control.Concurrent.Async +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.State +import GHC.Generics qualified as GHC +import Text.Show.Pretty + +import Network.GRPC.Client (Timeout) +import Network.GRPC.Client qualified as Client + +{------------------------------------------------------------------------------- + Test clock + + These are concurrent tests, looking at the behaviour of @n@ concurrent + connections between a client and a server. To make the interleaving of actions + easier to see, we introduce a global "test clock". Every test action is + annotated with a particular "test tick" on this clock. If multiple actions are + annotated with the same tick, their interleaving is essentially + non-deterministic. +-------------------------------------------------------------------------------} + +newtype TestClock = TestClock (TVar TestClockTick) + +newtype TestClockTick = TestClockTick Int + deriving stock (Show, GHC.Generic) + deriving newtype (Eq, Ord, Enum) + deriving anyclass (PrettyVal) + +withTestClock :: (TestClock -> IO a) -> IO a +withTestClock k = do + tvar <- newTVarIO (TestClockTick 0) + withAsync (tickClock tvar) $ \_ -> + k $ TestClock tvar + where + tickClock :: TVar TestClockTick -> IO () + tickClock tvar = forever $ do + threadDelay $ fromIntegral (tickDuration * 1_000) + atomically $ modifyTVar tvar succ + +-- | Tick duration, in millseconds +-- +-- Setting this is finicky: too long and the test take a long time to run, +-- too short and the illusion that individual actions take no time fails. 
+tickDuration :: Word +tickDuration = 50 + +waitForTestClock :: MonadIO m => TestClock -> TestClockTick -> m () +waitForTestClock (TestClock tvar) tick = liftIO $ atomically $ do + clock <- readTVar tvar + unless (clock >= tick) retry + +{------------------------------------------------------------------------------- + Timeouts + + We use the test clock also to test timeouts: a slow server is modelled by one + that waits for a particular test clock tick. To avoid ambiguity, we use + /even/ clockticks for actions and /odd/ clockticks for timeouts. +-------------------------------------------------------------------------------} + +-- | Does the current 'TestClock' time exceed an (optional) timeout? +exceedsTimeout :: TestClockTick -> Maybe TestClockTick -> Bool +now `exceedsTimeout` Just timeout | now == timeout = error "impossible" +now `exceedsTimeout` Just timeout = now > timeout +_ `exceedsTimeout` Nothing = False + +tickToTimeout :: TestClockTick -> Timeout +tickToTimeout (TestClockTick n) + | even n = error "tickToTimeout: expected odd timeout" + | otherwise = Client.Timeout Client.Millisecond $ + Client.TimeoutValue (fromIntegral n * tickDuration) + diff --git a/test-grapesy/Test/Prop/Dialogue.hs b/test-grapesy/Test/Prop/Dialogue.hs new file mode 100644 index 00000000..954ea0c9 --- /dev/null +++ b/test-grapesy/Test/Prop/Dialogue.hs @@ -0,0 +1,53 @@ +module Test.Prop.Dialogue (tests) where + +import Control.Exception +import Test.Tasty +import Test.Tasty.QuickCheck + +import Test.Driver.Dialogue +import Test.Driver.ClientServer + +tests :: TestTree +tests = testGroup "Test.Prop.Dialogue" [ + testGroup "Regression" [ + ] + , testGroup "Setup" [ + testProperty "shrinkingWellFounded" prop_shrinkingWellFounded + ] + , testGroup "Arbitrary" [ + testProperty "withoutExceptions" arbitraryWithoutExceptions +-- , testProperty "withExceptions" arbitraryWithExceptions + ] + ] + +{------------------------------------------------------------------------------- + Verify 
setup is correct +-------------------------------------------------------------------------------} + +prop_shrinkingWellFounded :: Property +prop_shrinkingWellFounded = + -- Explicit 'forAll' so that we can disable shrinking here + -- (If shrinking is /not/ well-founded, then we should not try to shrink!) + forAll arbitrary $ \(d :: DialogueWithoutExceptions) -> + d `notElem` shrink d + +{------------------------------------------------------------------------------- + Running the tests +-------------------------------------------------------------------------------} + +arbitraryWithoutExceptions :: DialogueWithoutExceptions -> Property +arbitraryWithoutExceptions (DialogueWithoutExceptions dialogue) = + propClientServer classifyHandlerException $ + execGlobalSteps $ dialogueGlobalSteps dialogue + +_arbitraryWithExceptions :: DialogueWithExceptions -> Property +_arbitraryWithExceptions (DialogueWithExceptions dialogue) = + propClientServer classifyHandlerException $ + execGlobalSteps $ dialogueGlobalSteps dialogue + +_regression :: GlobalSteps -> IO String +_regression steps = + testClientServer classifyHandlerException $ execGlobalSteps steps + +classifyHandlerException :: SomeException -> Maybe UnexpectedHandlerException +classifyHandlerException = Just . 
UnexpectedHandlerException diff --git a/test-grapesy/Test/Prop/Serialization.hs b/test-grapesy/Test/Prop/Serialization.hs new file mode 100644 index 00000000..cf626e06 --- /dev/null +++ b/test-grapesy/Test/Prop/Serialization.hs @@ -0,0 +1,66 @@ +module Test.Prop.Serialization (tests) where + +import Control.Monad.Except +import Test.Tasty +import Test.Tasty.QuickCheck + +import Network.GRPC.Spec +import Data.ByteString qualified as Strict + +tests :: TestTree +tests = testGroup "Test.Prop.Serialization" [ + testGroup "Roundtrip" [ + testProperty "CustomMetadata" $ + roundtrip buildCustomMetadata parseCustomMetadata + ] + ] + +{------------------------------------------------------------------------------- + Roundtrip tests +-------------------------------------------------------------------------------} + +roundtrip :: forall a b. + (Eq a, Show a, Show b) + => (a -> b) -> (b -> Except String a) -> Awkward a -> Property +roundtrip there back (Awkward a) = + counterexample (show b) $ + runExcept (back b) === Right a + where + b :: b + b = there a + +{------------------------------------------------------------------------------- + Arbitrary instances +-------------------------------------------------------------------------------} + +-- | Newtype wrapper for \"awkward\" 'Arbitrary' instances +-- +-- In most property tests we don't explore edge cases, preferring for example +-- to use only simple header names, rather than check encoding issues. But in +-- these serialization tests the edge cases are of course important. 
+newtype Awkward a = Awkward { getAwkward :: a } + deriving (Show) + +awkward :: Arbitrary (Awkward a) => Gen a +awkward = getAwkward <$> arbitrary + +instance Arbitrary (Awkward CustomMetadata) where + arbitrary = Awkward <$> do + name <- awkward + oneof [ + BinaryHeader name <$> awkward + , AsciiHeader name <$> awkward + ] + +instance Arbitrary (Awkward HeaderName) where + arbitrary = Awkward <$> + suchThatMap (Strict.pack <$> arbitrary) safeHeaderName + +instance Arbitrary (Awkward BinaryValue) where + arbitrary = Awkward <$> + BinaryValue . Strict.pack <$> arbitrary + +instance Arbitrary (Awkward AsciiValue) where + arbitrary = Awkward <$> + suchThatMap (Strict.pack <$> arbitrary) safeAsciiValue + diff --git a/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs b/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs index 71e058ed..05380ef0 100644 --- a/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs +++ b/test-grapesy/Test/Sanity/StreamingType/NonStreaming.hs @@ -1,5 +1,6 @@ module Test.Sanity.StreamingType.NonStreaming (tests) where +import Control.Exception import Data.Default import Data.Proxy import Data.Word @@ -73,7 +74,7 @@ tests = testGroup "Test.Sanity.StreamingType.NonStreaming" [ type BinaryIncrement = BinaryRpc "binary" "increment" test_increment :: ClientServerConfig -> IO String -test_increment config = testClientServer $ \k -> k def { +test_increment config = testClientServer classifyHandlerException $ \k -> k def { config , client = \conn -> do Client.withRPC conn def (Proxy @BinaryIncrement) $ \call -> do @@ -86,4 +87,7 @@ test_increment config = testClientServer $ \k -> k def { return (succ n) ] } + where + classifyHandlerException :: SomeException -> Maybe UnexpectedHandlerException + classifyHandlerException = Just . 
UnexpectedHandlerException diff --git a/test-stress/Main.hs b/test-stress/Main.hs index bef1d64f..bb270f97 100644 --- a/test-stress/Main.hs +++ b/test-stress/Main.hs @@ -13,6 +13,7 @@ import Network.GRPC.Server.StreamType import Test.Stress.Cmdline import Test.Util.ClientServer +import Control.Tracer {------------------------------------------------------------------------------- Barebones stress test @@ -52,7 +53,7 @@ runClient test = runServer :: IO () runServer = -- Server.withServer params handlers $ Server.runServer config - runTestServer def [ + runTestServer def (contramap show stdoutTracer) [ streamingRpcHandler (Proxy @ManyShortLived) $ Binary.mkNonStreaming serverManyShortLived ]