diff --git a/client/rpc_client.go b/client/rpc_client.go index d643ea7c9..1882a310e 100644 --- a/client/rpc_client.go +++ b/client/rpc_client.go @@ -2,7 +2,7 @@ package client import ( "bufio" - "fmt" + "errors" "log" "net" "sync" @@ -20,7 +20,9 @@ const ( ) var ( - clientClosed = fmt.Errorf("client closed") + errClientClosed = errors.New("client closed") + errStreamClosed = errors.New("stream closed") + errRequestTimeout = errors.New("request timeout") ) type seqCallback struct { @@ -50,6 +52,11 @@ type Config struct { // If provided, overrides the DefaultTimeout used for // IO deadlines Timeout time.Duration + + // Logger is a custom logger which you provide. If Logger is set, it will use + // this for the internal logger. If Logger is not set, it will fall back to the + // default logger from the log package. + Logger *log.Logger } // RPCClient is used to make requests to the Agent using an RPC mechanism. @@ -65,6 +72,7 @@ type RPCClient struct { dec *codec.Decoder enc *codec.Encoder writeLock sync.Mutex + logger *log.Logger dispatch map[uint64]seqHandler dispatchLock sync.Mutex @@ -81,7 +89,7 @@ func (c *RPCClient) send(header *requestHeader, obj interface{}) error { defer c.writeLock.Unlock() if c.shutdown { - return clientClosed + return errClientClosed } // Setup an IO deadline, this way we won't wait indefinitely @@ -140,6 +148,10 @@ func ClientFromConfig(c *Config) (*RPCClient, error) { writer: bufio.NewWriter(conn), dispatch: make(map[uint64]seqHandler), shutdownCh: make(chan struct{}), + logger: c.Logger, + } + if client.logger == nil { + client.logger = log.Default() } client.dec = codec.NewDecoder(client.reader, &codec.MsgpackHandle{RawToString: true, WriteExt: true}) @@ -403,47 +415,72 @@ func (c *RPCClient) GetCoordinate(node string) (*coordinate.Coordinate, error) { } type monitorHandler struct { + // These fields are constant client *RPCClient - closed bool - init bool + seq uint64 + + // These fields relate to the initial response. Once the initial response has been received, init + // is atomically set and the initial response is put into initCh. + init uint32 // atomic initCh chan<- error + + // These fields relate to whether or not the stream handler is still open and the log channel. + // The two following fields are protected by the mutex. + mtx sync.Mutex + closed bool logCh chan<- string - seq uint64 } func (mh *monitorHandler) Handle(resp *responseHeader) { // Initialize on the first response - if !mh.init { - mh.init = true + if atomic.CompareAndSwapUint32(&mh.init, 0, 1) { mh.initCh <- strToError(resp.Error) return } - // Decode logs for all other responses + // Decode the log var rec logRecord if err := mh.client.dec.Decode(&rec); err != nil { - log.Printf("[ERR] Failed to decode log: %v", err) + mh.client.logger.Printf("[ERR] Failed to decode log: %v", err) mh.client.deregisterHandler(mh.seq) return } + + // Take the mutex for the remainder of this function to ensure safe access to member variables + mh.mtx.Lock() + defer mh.mtx.Unlock() + + // If we're closed, dump the response + if mh.closed { + mh.client.logger.Printf("[WARN] Dropping monitor response, handler closed") + return + } + + // Not closed, so feed the response to the log channel select { case mh.logCh <- rec.Log: default: - log.Printf("[ERR] Dropping log! Monitor channel full") + mh.client.logger.Printf("[ERR] Dropping log! Monitor channel full") } } func (mh *monitorHandler) Cleanup() { - if !mh.closed { - if !mh.init { - mh.init = true - mh.initCh <- fmt.Errorf("Stream closed") - } - if mh.logCh != nil { - close(mh.logCh) - } - mh.closed = true + if atomic.CompareAndSwapUint32(&mh.init, 0, 1) { + mh.initCh <- errStreamClosed + } + + mh.mtx.Lock() + defer mh.mtx.Unlock() + + if mh.closed { + return } + + if mh.logCh != nil { + close(mh.logCh) + } + + mh.closed = true } // Monitor is used to subscribe to the logs of the agent @@ -460,6 +497,7 @@ func (c *RPCClient) Monitor(level logutils.LogLevel, ch chan<- string) (StreamHa // Create a monitor handler initCh := make(chan error, 1) + defer close(initCh) handler := &monitorHandler{ client: c, initCh: initCh, @@ -480,52 +518,80 @@ func (c *RPCClient) Monitor(level logutils.LogLevel, ch chan<- string) (StreamHa return StreamHandle(seq), err case <-c.shutdownCh: c.deregisterHandler(seq) - return 0, clientClosed + return 0, errClientClosed + case <-time.After(c.timeout): + c.deregisterHandler(seq) + return 0, errRequestTimeout } } type streamHandler struct { - client *RPCClient + // These fields are constant + client *RPCClient + seq uint64 + + // These fields relate to the initial response. Once the initial response has been received, init + // is atomically set and the initial response is put into initCh. + init uint32 // atomic + initCh chan<- error + + // These fields relate to whether or not the stream handler is still open and the event channel. + // The two following fields are protected by the mutex. + mtx sync.Mutex closed bool - init bool - initCh chan<- error eventCh chan<- map[string]interface{} - seq uint64 } func (sh *streamHandler) Handle(resp *responseHeader) { // Initialize on the first response - if !sh.init { - sh.init = true + if atomic.CompareAndSwapUint32(&sh.init, 0, 1) { sh.initCh <- strToError(resp.Error) return } - // Decode logs for all other responses + // Decode the event var rec map[string]interface{} if err := sh.client.dec.Decode(&rec); err != nil { - log.Printf("[ERR] Failed to decode stream record: %v", err) + sh.client.logger.Printf("[ERR] Failed to decode stream record: %v", err) sh.client.deregisterHandler(sh.seq) return } + + // Take the mutex for the remainder of this function to ensure safe access to member variables + sh.mtx.Lock() + defer sh.mtx.Unlock() + + // If we're closed, dump the response + if sh.closed { + sh.client.logger.Printf("[WARN] Dropping stream response, handler closed") + return + } + + // Not closed, so feed the response to the event channel select { case sh.eventCh <- rec: default: - log.Printf("[ERR] Dropping event! Stream channel full") + sh.client.logger.Printf("[ERR] Dropping event! Stream channel full") } } func (sh *streamHandler) Cleanup() { - if !sh.closed { - if !sh.init { - sh.init = true - sh.initCh <- fmt.Errorf("Stream closed") - } - if sh.eventCh != nil { - close(sh.eventCh) - } - sh.closed = true + if atomic.CompareAndSwapUint32(&sh.init, 0, 1) { + sh.initCh <- errStreamClosed + } + + sh.mtx.Lock() + defer sh.mtx.Unlock() + + if sh.closed { + return + } + + if sh.eventCh != nil { + close(sh.eventCh) } + + sh.closed = true } // Stream is used to subscribe to events @@ -542,6 +608,7 @@ func (c *RPCClient) Stream(filter string, ch chan<- map[string]interface{}) (Str // Create a monitor handler initCh := make(chan error, 1) + defer close(initCh) handler := &streamHandler{ client: c, initCh: initCh, @@ -562,24 +629,34 @@ func (c *RPCClient) Stream(filter string, ch chan<- map[string]interface{}) (Str return StreamHandle(seq), err case <-c.shutdownCh: c.deregisterHandler(seq) - return 0, clientClosed + return 0, errClientClosed + case <-time.After(c.timeout): + c.deregisterHandler(seq) + return 0, errRequestTimeout } } type queryHandler struct { + // These fields are constant client *RPCClient - closed bool - init bool + seq uint64 + + // These fields relate to the initial response. Once the initial response has been received, init + // is atomically set and the initial response is put into initCh. + init uint32 // atomic initCh chan<- error + + // These fields relate to whether or not the query handler is still open and the ACK and response + // channels. The three following fields are protected by the mutex. + mtx sync.Mutex + closed bool ackCh chan<- string respCh chan<- NodeResponse - seq uint64 } func (qh *queryHandler) Handle(resp *responseHeader) { // Initialize on the first response - if !qh.init { - qh.init = true + if atomic.CompareAndSwapUint32(&qh.init, 0, 1) { qh.initCh <- strToError(resp.Error) return } @@ -587,49 +664,86 @@ func (qh *queryHandler) Handle(resp *responseHeader) { // Decode the query response var rec queryRecord if err := qh.client.dec.Decode(&rec); err != nil { - log.Printf("[ERR] Failed to decode query response: %v", err) + qh.client.logger.Printf("[ERR] Failed to decode query response: %v", err) qh.client.deregisterHandler(qh.seq) return } + // We want to "defer qh.mtx.Unlock()" after locking, but we need to unlock before calling + // deregisterHandler below; so this variable and these helper functions allow us to "unlock" + // multiple times -- one in a defer, and one manually before deregistering the handler. + locked := false + lockSafely := func() { + if !locked { + qh.mtx.Lock() + locked = true + } + } + unlockSafely := func() { + if locked { + qh.mtx.Unlock() + locked = false + } + } + + // Lock the mutex for the remainder of this function to ensure safe access to member variables + lockSafely() + defer unlockSafely() + + // If we're closed, dump the response + if qh.closed { + qh.client.logger.Printf("[WARN] Dropping query response, handler closed") + return + } + + // Not closed, so feed the response to the appropriate channel switch rec.Type { case queryRecordAck: select { case qh.ackCh <- rec.From: default: - log.Printf("[ERR] Dropping query ack, channel full") + qh.client.logger.Printf("[ERR] Dropping query ack, channel full") } case queryRecordResponse: select { case qh.respCh <- NodeResponse{rec.From, rec.Payload}: default: - log.Printf("[ERR] Dropping query response, channel full") + qh.client.logger.Printf("[ERR] Dropping query response, channel full") } case queryRecordDone: // No further records coming + // XXX: We need to unlock the mutex before calling deregisterHandler, as it will call Cleanup, + // which wants to lock the mutex! + unlockSafely() qh.client.deregisterHandler(qh.seq) default: - log.Printf("[ERR] Unrecognized query record type: %s", rec.Type) + qh.client.logger.Printf("[ERR] Unrecognized query record type: %s", rec.Type) } } func (qh *queryHandler) Cleanup() { - if !qh.closed { - if !qh.init { - qh.init = true - qh.initCh <- fmt.Errorf("Stream closed") - } - if qh.ackCh != nil { - close(qh.ackCh) - } - if qh.respCh != nil { - close(qh.respCh) - } - qh.closed = true + if atomic.CompareAndSwapUint32(&qh.init, 0, 1) { + qh.initCh <- errStreamClosed + } + + qh.mtx.Lock() + defer qh.mtx.Unlock() + + if qh.closed { + return + } + + if qh.ackCh != nil { + close(qh.ackCh) } + if qh.respCh != nil { + close(qh.respCh) + } + + qh.closed = true } // QueryParam is provided to query set various settings. @@ -668,6 +782,7 @@ func (c *RPCClient) Query(params *QueryParam) error { // Create a query handler initCh := make(chan error, 1) + defer close(initCh) handler := &queryHandler{ client: c, initCh: initCh, @@ -683,13 +798,22 @@ func (c *RPCClient) Query(params *QueryParam) error { return err } + // Use the lower of either the channel timeout of the query params timeout (if provided) + timeout := c.timeout + if params.Timeout != 0 && params.Timeout < timeout { + timeout = params.Timeout + } + // Wait for a response select { case err := <-initCh: return err case <-c.shutdownCh: c.deregisterHandler(seq) - return clientClosed + return errClientClosed + case <-time.After(timeout): + c.deregisterHandler(seq) + return errRequestTimeout } } @@ -765,14 +889,14 @@ func (c *RPCClient) genericRPC(header *requestHeader, req interface{}, resp inte case err := <-errCh: return err case <-c.shutdownCh: - return clientClosed + return errClientClosed } } // strToError converts a string to an error if not blank func strToError(s string) error { if s != "" { - return fmt.Errorf(s) + return errors.New(s) } return nil } @@ -785,12 +909,13 @@ func (c *RPCClient) getSeq() uint64 { // deregisterAll is used to deregister all handlers func (c *RPCClient) deregisterAll() { c.dispatchLock.Lock() - defer c.dispatchLock.Unlock() + dispatch := c.dispatch + c.dispatch = make(map[uint64]seqHandler) + c.dispatchLock.Unlock() - for _, seqH := range c.dispatch { + for _, seqH := range dispatch { seqH.Cleanup() } - c.dispatch = make(map[uint64]seqHandler) } // deregisterHandler is used to deregister a handler @@ -833,7 +958,7 @@ func (c *RPCClient) listen() { for { if err := c.dec.Decode(&respHeader); err != nil { if !c.shutdown { - log.Printf("[ERR] agent.client: Failed to decode response header: %v", err) + c.logger.Printf("[ERR] agent.client: Failed to decode response header: %v", err) } break }