Skip to content

Commit

Permalink
init hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
kakkun61 committed Mar 15, 2024
1 parent 1155945 commit 0a876bc
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 38 deletions.
3 changes: 2 additions & 1 deletion hedis.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ library
Database.Redis.Commands,
Database.Redis.ManualCommands,
Database.Redis.URL,
Database.Redis.ConnectionContext
Database.Redis.ConnectionContext,
Database.Redis.Hooks
other-extensions: StrictData

benchmark hedis-benchmark
Expand Down
3 changes: 3 additions & 0 deletions src/Database/Redis.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ module Database.Redis (
-- * Pub\/Sub
module Database.Redis.PubSub,

-- * Hooks
Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook, defaultHooks,

-- * Low-Level Command API
sendRequest,
Reply(..), Status(..), RedisArg(..), RedisResult(..), ConnectionLostException(..),
Expand Down
29 changes: 17 additions & 12 deletions src/Database/Redis/Cluster.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ module Database.Redis.Cluster
, disconnect
, requestPipelined
, nodes
, hooks
) where

import qualified Data.ByteString as B
Expand All @@ -36,6 +37,7 @@ import System.IO.Unsafe(unsafeInterleaveIO)

import Database.Redis.Protocol(Reply(Error), renderRequest, reply)
import qualified Database.Redis.Cluster.Command as CMD
import Database.Redis.Hooks (Hooks)

-- This module implements a clustered connection whilst maintaining
-- compatibility with the original Hedis codebase. In particular it still
Expand All @@ -48,7 +50,7 @@ import qualified Database.Redis.Cluster.Command as CMD

-- | A connection to a redis cluster, it is compoesed of a map from Node IDs to
-- | 'NodeConnection's, a 'Pipeline', and a 'ShardMap'
data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap
data Connection = Connection (HM.HashMap NodeID NodeConnection) (MVar Pipeline) (MVar ShardMap) CMD.InfoMap Hooks

-- | A connection to a single node in the cluster, similar to 'ProtocolPipelining.Connection'
data NodeConnection = NodeConnection CC.ConnectionContext (IOR.IORef (Maybe B.ByteString)) NodeID
Expand Down Expand Up @@ -100,13 +102,13 @@ instance Exception UnsupportedClusterCommandException
newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable)
instance Exception CrossSlotException

connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection
connect commandInfos shardMapVar timeoutOpt = do
connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> Hooks -> IO Connection
connect commandInfos shardMapVar timeoutOpt hooks' = do
shardMap <- readMVar shardMapVar
stateVar <- newMVar $ Pending []
pipelineVar <- newMVar $ Pipeline stateVar
nodeConns <- nodeConnections shardMap
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) where
return $ Connection nodeConns pipelineVar shardMapVar (CMD.newInfoMap commandInfos) hooks' where
nodeConnections :: ShardMap -> IO (HM.HashMap NodeID NodeConnection)
nodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap)
connectNode :: Node -> IO (NodeID, NodeConnection)
Expand All @@ -116,14 +118,14 @@ connect commandInfos shardMapVar timeoutOpt = do
return (n, NodeConnection ctx ref n)

disconnect :: Connection -> IO ()
disconnect (Connection nodeConnMap _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where
disconnect (Connection nodeConnMap _ _ _ _) = mapM_ disconnectNode (HM.elems nodeConnMap) where
disconnectNode (NodeConnection nodeCtx _ _) = CC.disconnect nodeCtx

-- Add a request to the current pipeline for this connection. The pipeline will
-- be executed implicitly as soon as any result returned from this function is
-- evaluated.
requestPipelined :: IO ShardMap -> Connection -> [B.ByteString] -> IO Reply
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
requestPipelined refreshAction conn@(Connection _ pipelineVar shardMapVar _ _) nextRequest = modifyMVar pipelineVar $ \(Pipeline stateVar) -> do
(newStateVar, repliesIndex) <- hasLocked $ modifyMVar stateVar $ \case
Pending requests | isMulti nextRequest -> do
replies <- evaluatePipeline shardMapVar refreshAction conn requests
Expand Down Expand Up @@ -228,7 +230,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
-- there is one.
case last replies of
(Error errString) | B.isPrefixOf "MOVED" errString -> do
let (Connection _ _ _ infoMap) = conn
let (Connection _ _ _ infoMap _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
hashSlot <- hashSlotForKeys (CrossSlotException requests) keys
nodeConn <- nodeConnForHashSlot shardMapVar conn (MissingNodeException (head requests)) hashSlot
Expand All @@ -250,7 +252,7 @@ retryBatch shardMapVar refreshShardmapAction conn retryCount requests replies =
evaluateTransactionPipeline :: MVar ShardMap -> IO ShardMap -> Connection -> [[B.ByteString]] -> IO [Reply]
evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = do
let requests = reverse requests'
let (Connection _ _ _ infoMap) = conn
let (Connection _ _ _ infoMap _) = conn
keys <- mconcat <$> mapM (requestKeys infoMap) requests
-- In cluster mode Redis expects commands in transactions to all work on the
-- same hashslot. We find that hashslot here.
Expand Down Expand Up @@ -296,7 +298,7 @@ evaluateTransactionPipeline shardMapVar refreshShardmapAction conn requests' = d

nodeConnForHashSlot :: Exception e => MVar ShardMap -> Connection -> e -> HashSlot -> IO NodeConnection
nodeConnForHashSlot shardMapVar conn exception hashSlot = do
let (Connection nodeConns _ _ _) = conn
let (Connection nodeConns _ _ _ _) = conn
(ShardMap shardMap) <- hasLocked $ readMVar shardMapVar
node <-
case IntMap.lookup (fromEnum hashSlot) shardMap of
Expand Down Expand Up @@ -339,12 +341,12 @@ moved _ = False


nodeConnWithHostAndPort :: ShardMap -> Connection -> Host -> Port -> Maybe NodeConnection
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _) host port = do
nodeConnWithHostAndPort shardMap (Connection nodeConns _ _ _ _) host port = do
node <- nodeWithHostAndPort shardMap host port
HM.lookup (nodeId node) nodeConns

nodeConnectionForCommand :: Connection -> ShardMap -> [B.ByteString] -> IO [NodeConnection]
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shardMap) request =
nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap _) (ShardMap shardMap) request =
case request of
("FLUSHALL" : _) -> allNodes
("FLUSHDB" : _) -> allNodes
Expand All @@ -364,7 +366,7 @@ nodeConnectionForCommand conn@(Connection nodeConns _ _ infoMap) (ShardMap shard
Just allNodes' -> return allNodes'

allMasterNodes :: Connection -> ShardMap -> Maybe [NodeConnection]
allMasterNodes (Connection nodeConns _ _ _) (ShardMap shardMap) =
allMasterNodes (Connection nodeConns _ _ _ _) (ShardMap shardMap) =
mapM (flip HM.lookup nodeConns . nodeId) masterNodes
where
masterNodes = (\(Shard master _) -> master) <$> nub (IntMap.elems shardMap)
Expand Down Expand Up @@ -410,3 +412,6 @@ hasLocked action =
action `catches`
[ Handler $ \exc@BlockedIndefinitelyOnMVar -> throwIO exc
]

hooks :: Connection -> Hooks
hooks (Connection _ _ _ _ h) = h
13 changes: 8 additions & 5 deletions src/Database/Redis/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import qualified Network.Socket as NS
import qualified Data.HashMap.Strict as HM

import qualified Database.Redis.ProtocolPipelining as PP
import Database.Redis.Core(Redis, runRedisInternal, runRedisClusteredInternal)
import Database.Redis.Core(Redis, Hooks, runRedisInternal, runRedisClusteredInternal, defaultHooks)
import Database.Redis.Protocol(Reply(..))
import Database.Redis.Cluster(ShardMap(..), Node, Shard(..))
import qualified Database.Redis.Cluster as Cluster
Expand Down Expand Up @@ -97,6 +97,7 @@ data ConnectInfo = ConnInfo
-- get connected in this interval of time.
, connectTLSParams :: Maybe ClientParams
-- ^ Optional TLS parameters. TLS will be enabled if this is provided.
, connectHooks :: Hooks
} deriving Show

data ConnectError = ConnectAuthError Reply
Expand All @@ -117,6 +118,7 @@ instance Exception ConnectError
-- connectMaxIdleTime = 30 -- Keep open for 30 seconds
-- connectTimeout = Nothing -- Don't add timeout logic
-- connectTLSParams = Nothing -- Do not use TLS
-- connectHooks = defaultHooks -- Do nothing
-- @
--
defaultConnectInfo :: ConnectInfo
Expand All @@ -130,13 +132,14 @@ defaultConnectInfo = ConnInfo
, connectMaxIdleTime = 30
, connectTimeout = Nothing
, connectTLSParams = Nothing
, connectHooks = defaultHooks
}

createConnection :: ConnectInfo -> IO PP.Connection
createConnection ConnInfo{..} = do
let timeoutOptUs =
round . (1000000 *) <$> connectTimeout
conn <- PP.connect connectHost connectPort timeoutOptUs
conn <- PP.connect connectHost connectPort timeoutOptUs connectHooks
conn' <- case connectTLSParams of
Nothing -> return conn
Just tlsParams -> PP.enableTLS tlsParams conn
Expand Down Expand Up @@ -231,9 +234,9 @@ connectCluster bootstrapConnInfo = do
Left e -> throwIO $ ClusterConnectError e
Right infos -> do
#if MIN_VERSION_resource_pool(0,3,0)
pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo))
#else
pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
pool <- createPool (Cluster.connect infos shardMapVar Nothing $ connectHooks bootstrapConnInfo) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)
#endif
return $ ClusteredConnection shardMapVar pool

Expand All @@ -255,7 +258,7 @@ shardMapFromClusterSlotsResponse ClusterSlotsResponse{..} = ShardMap <$> foldr m
Cluster.Node clusterSlotsNodeID role hostname (toEnum clusterSlotsNodePort)

refreshShardMap :: Cluster.Connection -> IO ShardMap
refreshShardMap (Cluster.Connection nodeConns _ _ _) = do
refreshShardMap (Cluster.Connection nodeConns _ _ _ _) = do
let (Cluster.NodeConnection ctx _ _) = head $ HM.elems nodeConns
pipelineConn <- PP.fromCtx ctx
_ <- PP.beginReceiving pipelineConn
Expand Down
7 changes: 5 additions & 2 deletions src/Database/Redis/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
module Database.Redis.Core (
Redis(), unRedis, reRedis,
RedisCtx(..), MonadRedis(..),
Hooks(..), SendRequestHook, SendPubSubHook, CallbackHook, SendHook, ReceiveHook,
send, recv, sendRequest,
runRedisInternal,
runRedisClusteredInternal,
defaultHooks,
RedisEnv(..),
) where

Expand All @@ -24,6 +26,7 @@ import qualified Database.Redis.ProtocolPipelining as PP
import Database.Redis.Types
import Database.Redis.Cluster(ShardMap)
import qualified Database.Redis.Cluster as Cluster
import Database.Redis.Hooks

--------------------------------------------------------------------------------
-- The Redis Monad
Expand Down Expand Up @@ -118,8 +121,8 @@ sendRequest req = do
env <- ask
case env of
NonClusteredEnv{..} -> do
r <- liftIO $ PP.request envConn (renderRequest req)
r <- liftIO $ sendRequestHook (PP.hooks envConn) (PP.request envConn . renderRequest) req
setLastReply r
return r
ClusteredEnv{..} -> liftIO $ Cluster.requestPipelined refreshAction connection req
ClusteredEnv{..} -> liftIO $ sendRequestHook (Cluster.hooks connection) (Cluster.requestPipelined refreshAction connection) req
returnDecode r'
44 changes: 44 additions & 0 deletions src/Database/Redis/Hooks.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
module Database.Redis.Hooks where

import Data.ByteString (ByteString)
import Database.Redis.Protocol (Reply)
import {-# SOURCE #-} Database.Redis.PubSub (Message, PubSub)

data Hooks =
Hooks
{ sendRequestHook :: SendRequestHook
, sendPubSubHook :: SendPubSubHook
, callbackHook :: CallbackHook
, sendHook :: SendHook
, recieveHook :: ReceiveHook
}

-- | A hook for sending commands to the server and receiving replys from the server.
type SendRequestHook = ([ByteString] -> IO Reply) -> [ByteString] -> IO Reply

-- | A hook for sending pub/sub messages to the server.
type SendPubSubHook = ([ByteString] -> IO ()) -> [ByteString] -> IO ()

-- | A hook for invoking callbacks with pub/sub messages.
type CallbackHook = (Message -> IO PubSub) -> Message -> IO PubSub

-- | A hook for just sending raw data to the server.
type SendHook = (ByteString -> IO ()) -> ByteString -> IO ()

-- | A hook for receiving raw data from the server.
type ReceiveHook = IO Reply -> IO Reply

-- | The default hooks.
-- Every hook is the identity function.
defaultHooks :: Hooks
defaultHooks =
Hooks
{ sendRequestHook = id
, sendPubSubHook = id
, callbackHook = id
, sendHook = id
, recieveHook = id
}

instance Show Hooks where
show _ = "Hooks {sendRequestHook = _, sendPubSubHook = _, callbackHook = _, sendHook = _, recieveHook = _}"
21 changes: 12 additions & 9 deletions src/Database/Redis/ProtocolPipelining.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
--
module Database.Redis.ProtocolPipelining (
Connection,
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx
connect, enableTLS, beginReceiving, disconnect, request, send, recv, flush, fromCtx, hooks
) where

import Prelude
Expand All @@ -31,6 +31,7 @@ import System.IO.Unsafe

import Database.Redis.Protocol
import qualified Database.Redis.ConnectionContext as CC
import Database.Redis.Hooks

data Connection = Conn
{ connCtx :: CC.ConnectionContext -- ^ Connection socket-handle.
Expand All @@ -42,14 +43,15 @@ data Connection = Conn
-- ^ Number of pending replies and thus the difference length between
-- 'connReplies' and 'connPending'.
-- length connPending - pendingCount = length connReplies
, hooks :: Hooks
}


fromCtx :: CC.ConnectionContext -> IO Connection
fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0
fromCtx ctx = Conn ctx <$> newIORef [] <*> newIORef [] <*> newIORef 0 <*> pure defaultHooks

connect :: NS.HostName -> CC.PortID -> Maybe Int -> IO Connection
connect hostName portId timeoutOpt = do
connect :: NS.HostName -> CC.PortID -> Maybe Int -> Hooks -> IO Connection
connect hostName portId timeoutOpt hooks = do
connCtx <- CC.connect hostName portId timeoutOpt
connReplies <- newIORef []
connPending <- newIORef []
Expand All @@ -74,7 +76,7 @@ disconnect Conn{..} = CC.disconnect connCtx
-- The 'Handle' is 'hFlush'ed when reading replies from the 'connCtx'.
send :: Connection -> S.ByteString -> IO ()
send Conn{..} s = do
CC.send connCtx s
sendHook hooks (CC.send connCtx) s

-- Signal that we expect one more reply from Redis.
n <- atomicModifyIORef' connPendingCnt $ \n -> let n' = n+1 in (n', n')
Expand All @@ -88,10 +90,11 @@ send Conn{..} s = do

-- |Take a reply-thunk from the list of future replies.
recv :: Connection -> IO Reply
recv Conn{..} = do
(r:rs) <- readIORef connReplies
writeIORef connReplies rs
return r
recv Conn{..} =
recieveHook hooks $ do
(r:rs) <- readIORef connReplies
writeIORef connReplies rs
return r

-- | Flush the socket. Normally, the socket is flushed in 'recv' (actually 'conGetReplies'), but
-- for the multithreaded pub/sub code, the sending thread needs to explicitly flush the subscription
Expand Down
Loading

0 comments on commit 0a876bc

Please sign in to comment.