diff --git a/hedis.cabal b/hedis.cabal index b504a9df..e5802984 100644 --- a/hedis.cabal +++ b/hedis.cabal @@ -108,7 +108,8 @@ library Database.Redis.Commands, Database.Redis.ManualCommands, Database.Redis.URL, - Database.Redis.ConnectionContext + Database.Redis.ConnectionContext, + Database.Redis.Hooks other-extensions: StrictData benchmark hedis-benchmark diff --git a/src/Database/Redis.hs b/src/Database/Redis.hs index b8ec8849..1112816e 100644 --- a/src/Database/Redis.hs +++ b/src/Database/Redis.hs @@ -176,6 +176,9 @@ module Database.Redis ( -- * Pub\/Sub module Database.Redis.PubSub, + -- * Hooks + Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook, defaultHooks, + -- * Low-Level Command API sendRequest, Reply(..), Status(..), RedisArg(..), RedisResult(..), ConnectionLostException(..), diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 2f577bc2..d9664987 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -15,6 +15,7 @@ module Database.Redis.Cluster , disconnect , requestPipelined , nodes + , hooks ) where import qualified Data.ByteString as B @@ -36,6 +37,7 @@ import System.IO.Unsafe(unsafeInterleaveIO) import Database.Redis.Protocol(Reply(Error), renderRequest, reply) import qualified Database.Redis.Cluster.Command as CMD +import Database.Redis.Hooks (Hooks) -- This module implements a clustered connection whilst maintaining -- compatibility with the original Hedis codebase. In particular it still @@ -48,7 +50,7 @@ import qualified Database.Redis.Cluster.Command as CMD -- | A connection to a redis cluster, it is compoesed of a map from Node IDs to -- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap' -data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap +data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap Hooks -- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection' data NodeConnection = NodeConnection CC.ConnectionContext (IOR.IORef (Maybe B.ByteString)) NodeID @@ -100,13 +102,13 @@ instance Exception UnsupportedClusterCommandException newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable) instance Exception CrossSlotException -connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection -connect commandInfos shardMapVar timeoutOpt = do +connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Hooks -> IO Connection +connect commandInfos shardMapVar timeoutOpt hooks' = do shardMap <- readMVar shardMapVar stateVar <- newMVar $ Pending [] pipelineVar <- newMVar $ Pipeline stateVar nodeConns <- nodeConnections shardMap - return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) where + return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) hooks' where nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection) nodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap) connectNode :: Node -> IO (NodeID, NodeConnection) @@ -116,14 +118,14 @@ connect commandInfos shardMapVar timeoutOpt = do return (n, NodeConnection ctx ref n) disconnect :: Connection -> IO () -disconnect (Connection nodeConnMap _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where +disconnect (Connection nodeConnMap _ _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where disconnectNode (NodeConnection nodeCtx _ _) = CC.disconnect nodeCtx -- Add a request to the current pipeline for this connection. The pipeline will -- be executed implicitly as soon as any result returned from this function is -- evaluated. requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply -requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do +requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do (newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case Pending requests | isMulti nextRequest -> do replies <- evaluatePipeline shardMapVar refreshAction conn requests @@ -228,7 +230,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies = -- there is one. case last replies of (Error errString) | B.isPrefixOf "MOVED" errString -> do - let (Connection _ _ _ infoMap) = conn + let (Connection _ _ _ infoMap _) = conn keys <- mconcat <$> mapM (requestKeys infoMap) requests hashSlot <- hashSlotForKeys (CrossSlotException requests) keys nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot @@ -250,7 +252,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies = evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply] evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do let requests = reverse requests' - let (Connection _ _ _ infoMap) = conn + let (Connection _ _ _ infoMap _) = conn keys <- mconcat <$> mapM (requestKeys infoMap) requests -- In cluster mode Redis expects commands in transactions to all work on the -- same hashslot. We find that hashslot here. @@ -296,7 +298,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection nodeConnForHashSlot shardMapVar conn exception hashSlot = do - let (Connection nodeConns _ _ _) = conn + let (Connection nodeConns _ _ _ _) = conn (ShardMap shardMap) <- hasLocked $ readMVar shardMapVar node <- case IntMap.lookup (fromEnum hashSlot) shardMap of @@ -339,12 +341,12 @@ moved _ = False nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection -nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _) host port = do +nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do node <- nodeWithHostAndPort shardMap host port HM.lookup (nodeId node) nodeConns nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection] -nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shardMap) request = +nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request = case request of ("FLUSHALL" : _) -> allNodes ("FLUSHDB" : _) -> allNodes @@ -364,7 +366,7 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shard Just allNodes' -> return allNodes' allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection] -allMasterNodes (Connection nodeConns _ _ _) (ShardMap shardMap) = +allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) = mapM (flip HM.lookup nodeConns . nodeId) masterNodes where masterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap) @@ -410,3 +412,6 @@ hasLocked action = action `catches` [ Handler $ \exc@BlockedIndefinitelyOnMVar -> throwIO exc ] + +hooks :: Connection -> Hooks +hooks (Connection _ _ _ _ h) = h diff --git a/src/Database/Redis/Connection.hs b/src/Database/Redis/Connection.hs index 156662ec..9afae9de 100644 --- a/src/Database/Redis/Connection.hs +++ b/src/Database/Redis/Connection.hs @@ -31,7 +31,7 @@ import qualified Network.Socket as NS import qualified Data.HashMap.Strict as HM import qualified Database.Redis.ProtocolPipelining as PP -import Database.Redis.Core(Redis, runRedisInternal, runRedisClusteredInternal) +import Database.Redis.Core(Redis, Hooks, runRedisInternal, runRedisClusteredInternal, defaultHooks) import Database.Redis.Protocol(Reply(..)) import Database.Redis.Cluster(ShardMap(..), Node, Shard(..)) import qualified Database.Redis.Cluster as Cluster @@ -97,6 +97,7 @@ data ConnectInfo = ConnInfo -- get connected in this interval of time. , connectTLSParams :: Maybe ClientParams -- ^ Optional TLS parameters. TLS will be enabled if this is provided. + , connectHooks :: Hooks } deriving Show data ConnectError = ConnectAuthError Reply @@ -117,6 +118,7 @@ instance Exception ConnectError -- connectMaxIdleTime = 30 -- Keep open for 30 seconds -- connectTimeout = Nothing -- Don't add timeout logic -- connectTLSParams = Nothing -- Do not use TLS +-- connectHooks = defaultHooks -- Do nothing -- @ -- defaultConnectInfo :: ConnectInfo @@ -130,13 +132,14 @@ defaultConnectInfo = ConnInfo , connectMaxIdleTime = 30 , connectTimeout = Nothing , connectTLSParams = Nothing + , connectHooks = defaultHooks } createConnection :: ConnectInfo -> IO PP.Connection createConnection ConnInfo{..} = do let timeoutOptUs = round . (1000000 *) <$> connectTimeout - conn <- PP.connect connectHost connectPort timeoutOptUs + conn <- PP.connectWithHooks connectHost connectPort timeoutOptUs connectHooks conn' <- case connectTLSParams of Nothing -> return conn Just tlsParams -> PP.enableTLS tlsParams conn @@ -231,9 +234,9 @@ connectCluster bootstrapConnInfo = do Left e -> throwIO $ ClusterConnectError e Right infos -> do #if MIN_VERSION_resource_pool(0,3,0) - pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) + pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) #else - pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) + pool <- createPool (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) #endif return $ ClusteredConnection shardMapVar pool @@ -255,7 +258,7 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort) refreshShardMap :: Cluster.Connection -> IO ShardMap -refreshShardMap (Cluster.Connection nodeConns _ _ _) = do +refreshShardMap (Cluster.Connection nodeConns _ _ _ _) = do let (Cluster.NodeConnection ctx _ _) = head $ HM.elems nodeConns pipelineConn <- PP.fromCtx ctx _ <- PP.beginReceiving pipelineConn diff --git a/src/Database/Redis/Core.hs b/src/Database/Redis/Core.hs index 36cf1904..49d2b866 100644 --- a/src/Database/Redis/Core.hs +++ b/src/Database/Redis/Core.hs @@ -5,9 +5,11 @@ module Database.Redis.Core ( Redis(), unRedis, reRedis, RedisCtx(..), MonadRedis(..), + Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook, send, recv, sendRequest, runRedisInternal, runRedisClusteredInternal, + defaultHooks, RedisEnv(..), ) where @@ -24,6 +26,7 @@ import qualified Database.Redis.ProtocolPipelining as PP import Database.Redis.Types import Database.Redis.Cluster(ShardMap) import qualified Database.Redis.Cluster as Cluster +import Database.Redis.Hooks -------------------------------------------------------------------------------- -- The Redis Monad @@ -118,8 +121,8 @@ sendRequest req = do env <- ask case env of NonClusteredEnv{..} -> do - r <- liftIO $ PP.request envConn (renderRequest req) + r <- liftIO $ sendRequestHook (PP.hooks envConn) (PP.request envConn . renderRequest) req setLastReply r return r - ClusteredEnv{..} -> liftIO $ Cluster.requestPipelined refreshAction connection req + ClusteredEnv{..} -> liftIO $ sendRequestHook (Cluster.hooks connection) (Cluster.requestPipelined refreshAction connection) req returnDecode r' diff --git a/src/Database/Redis/Hooks.hs b/src/Database/Redis/Hooks.hs new file mode 100644 index 00000000..bd6ba362 --- /dev/null +++ b/src/Database/Redis/Hooks.hs @@ -0,0 +1,44 @@ +module Database.Redis.Hooks where + +import Data.ByteString (ByteString) +import Database.Redis.Protocol (Reply) +import {-# SOURCE #-} Database.Redis.PubSub (Message, PubSub) + +data Hooks = + Hooks + { sendRequestHook :: SendRequestHook + , sendPubSubHook :: SendPubSubHook + , callbackHook :: CallbackHook + , sendHook :: SendHook + , receiveHook :: ReceiveHook + } + +-- | A hook for sending commands to the server and receiving replys from the server. +type SendRequestHook = ([ByteString] -> IO Reply) -> [ByteString] -> IO Reply + +-- | A hook for sending pub/sub messages to the server. +type SendPubSubHook = ([ByteString] -> IO ()) -> [ByteString] -> IO () + +-- | A hook for invoking callbacks with pub/sub messages. +type CallbackHook = (Message -> IO PubSub) -> Message -> IO PubSub + +-- | A hook for just sending raw data to the server. +type SendHook = (ByteString -> IO ()) -> ByteString -> IO () + +-- | A hook for receiving raw data from the server. +type ReceiveHook = IO Reply -> IO Reply + +-- | The default hooks. +-- Every hook is the identity function. +defaultHooks :: Hooks +defaultHooks = + Hooks + { sendRequestHook = id + , sendPubSubHook = id + , callbackHook = id + , sendHook = id + , receiveHook = id + } + +instance Show Hooks where + show _ = "Hooks {sendRequestHook = _, sendPubSubHook = _, callbackHook = _, sendHook = _, receiveHook = _}" diff --git a/src/Database/Redis/ProtocolPipelining.hs b/src/Database/Redis/ProtocolPipelining.hs index dc0d22fe..1677dd0b 100644 --- a/src/Database/Redis/ProtocolPipelining.hs +++ b/src/Database/Redis/ProtocolPipelining.hs @@ -17,7 +17,7 @@ -- module Database.Redis.ProtocolPipelining ( Connection, - connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx + connect, connectWithHooks, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx, hooks ) where import Prelude @@ -31,6 +31,7 @@ import System.IO.Unsafe import Database.Redis.Protocol import qualified Database.Redis.ConnectionContext as CC +import Database.Redis.Hooks data Connection = Conn { connCtx :: CC.ConnectionContext -- ^ Connection socket-handle. @@ -42,14 +43,18 @@ data Connection = Conn -- ^ Number of pending replies and thus the difference length between -- 'connReplies' and 'connPending'. -- length connPending - pendingCount = length connReplies + , hooks :: Hooks } fromCtx :: CC.ConnectionContext -> IO Connection -fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0 +fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0 <*> pure defaultHooks connect :: NS.HostName -> CC.PortID -> Maybe Int -> IO Connection -connect hostName portId timeoutOpt = do +connect hostName portId timeoutOpt = connectWithHooks hostName portId timeoutOpt defaultHooks + +connectWithHooks :: NS.HostName -> CC.PortID -> Maybe Int -> Hooks -> IO Connection +connectWithHooks hostName portId timeoutOpt hooks = do connCtx <- CC.connect hostName portId timeoutOpt connReplies <- newIORef [] connPending <- newIORef [] @@ -74,7 +79,7 @@ disconnect Conn{..} = CC.disconnect connCtx -- The 'Handle' is 'hFlush'ed when reading replies from the 'connCtx'. send :: Connection -> S.ByteString -> IO () send Conn{..} s = do - CC.send connCtx s + sendHook hooks (CC.send connCtx) s -- Signal that we expect one more reply from Redis. n <- atomicModifyIORef' connPendingCnt $ \n -> let n' = n+1 in (n', n') @@ -88,10 +93,11 @@ send Conn{..} s = do -- |Take a reply-thunk from the list of future replies. recv :: Connection -> IO Reply -recv Conn{..} = do - (r:rs) <- readIORef connReplies - writeIORef connReplies rs - return r +recv Conn{..} = + receiveHook hooks $ do + (r:rs) <- readIORef connReplies + writeIORef connReplies rs + return r -- | Flush the socket. Normally, the socket is flushed in 'recv' (actually 'conGetReplies'), but -- for the multithreaded pub/sub code, the sending thread needs to explicitly flush the subscription diff --git a/src/Database/Redis/PubSub.hs b/src/Database/Redis/PubSub.hs index 71022b2c..df449f68 100644 --- a/src/Database/Redis/PubSub.hs +++ b/src/Database/Redis/PubSub.hs @@ -28,6 +28,7 @@ import Control.Concurrent.Async (withAsync, waitEitherCatch, waitEitherCatchSTM) import Control.Concurrent.STM import Control.Exception (throwIO) import Control.Monad +import Control.Monad.Reader (asks) import Control.Monad.State import Data.ByteString.Char8 (ByteString) import Data.List (foldl') @@ -42,6 +43,8 @@ import qualified Database.Redis.Connection as Connection import qualified Database.Redis.ProtocolPipelining as PP import Database.Redis.Protocol (Reply(..), renderRequest) import Database.Redis.Types +import Control.Monad.IO.Unlift (MonadUnliftIO(withRunInIO)) +import Data.Functor (($>)) -- |While in PubSub mode, we keep track of the number of current subscriptions -- (as reported by Redis replies) and the number of messages we expect to @@ -111,8 +114,10 @@ class Command a where sendCmd :: (Command (Cmd a b)) => Cmd a b -> StateT PubSubState Core.Redis () sendCmd DoNothing = return () sendCmd cmd = do - lift $ Core.send (redisCmd cmd : changes cmd) - modifyPending (updatePending cmd) + conn <- lift $ Core.reRedis $ asks Core.envConn + let hook = Core.sendPubSubHook $ PP.hooks conn + lift $ withRunInIO $ \runInIO -> hook (runInIO . Core.send) (redisCmd cmd : changes cmd) + modifyPending (updatePending cmd) cmdCount :: Cmd a b -> Int cmdCount DoNothing = 0 @@ -124,7 +129,10 @@ totalPendingChanges (PubSub{..}) = rawSendCmd :: (Command (Cmd a b)) => PP.Connection -> Cmd a b -> IO () rawSendCmd _ DoNothing = return () -rawSendCmd conn cmd = PP.send conn $ renderRequest $ redisCmd cmd : changes cmd +rawSendCmd conn cmd = + let hook = Core.sendPubSubHook $ PP.hooks conn + msg = redisCmd cmd : changes cmd + in hook (PP.send conn . renderRequest) msg plusChangeCnt :: Cmd a b -> Int -> Int plusChangeCnt DoNothing = id @@ -246,9 +254,10 @@ pubSub initial callback recv :: StateT PubSubState Core.Redis () recv = do + hook <- lift $ Core.reRedis $ asks $ Core.callbackHook . PP.hooks . Core.envConn reply <- lift Core.recv case decodeMsg reply of - Msg msg -> liftIO (callback msg) >>= send + Msg msg -> liftIO (hook callback msg) >>= send Subscribed -> modifyPending (subtract 1) >> recv Unsubscribed n -> do putSubCnt n @@ -492,16 +501,16 @@ listenThread :: PubSubController -> PP.Connection -> IO () listenThread ctrl rawConn = forever $ do msg <- PP.recv rawConn case decodeMsg msg of - Msg (Message channel msgCt) -> do + Msg message@(Message channel _) -> do cm <- atomically $ readTVar (callbacks ctrl) case HM.lookup channel cm of Nothing -> return () - Just c -> mapM_ (\(_,x) -> x msgCt) c - Msg (PMessage pattern channel msgCt) -> do + Just c -> void $ Core.callbackHook (PP.hooks rawConn) (\m -> mapM_ (\(_,x) -> x $ msgMessage m) c $> mempty) message + Msg message@(PMessage pattern _ _) -> do pm <- atomically $ readTVar (pcallbacks ctrl) case HM.lookup pattern pm of Nothing -> return () - Just c -> mapM_ (\(_,x) -> x channel msgCt) c + Just c -> void $ Core.callbackHook (PP.hooks rawConn) (\m -> mapM_ (\(_,x) -> x (msgChannel m) (msgMessage m)) c $> mempty) message Subscribed -> atomically $ modifyTVar (pendingCnt ctrl) (\x -> x - 1) Unsubscribed _ -> atomically $ diff --git a/src/Database/Redis/PubSub.hs-boot b/src/Database/Redis/PubSub.hs-boot new file mode 100644 index 00000000..f596af55 --- /dev/null +++ b/src/Database/Redis/PubSub.hs-boot @@ -0,0 +1,8 @@ +module Database.Redis.PubSub ( + Message, + PubSub, +) where + +data PubSub + +data Message diff --git a/src/Database/Redis/Sentinel.hs b/src/Database/Redis/Sentinel.hs index 8c15b3c5..d3a4f0d8 100644 --- a/src/Database/Redis/Sentinel.hs +++ b/src/Database/Redis/Sentinel.hs @@ -46,7 +46,6 @@ import Control.Exception (Exception, IOException, evaluate, throwI import Control.Monad import Control.Monad.Catch (Handler (..), MonadCatch, catches, throwM) import Control.Monad.Except -import Control.Monad.IO.Class(MonadIO(liftIO)) import Data.ByteString (ByteString) import qualified Data.ByteString as BS import qualified Data.ByteString.Char8 as BS8