From 0db445c84b235d1ceb7f494fca2a60c1a36e2d98 Mon Sep 17 00:00:00 2001 From: Dan Johnson Date: Fri, 19 Apr 2024 11:21:52 -0500 Subject: [PATCH] nsqd: support for multiple Auth HTTP Methods Adds simple config option and flag to allow for auth to occur via POST request in addition to GET. Rationale: Errors from net/http requests are bubbled to nsqd when there is an error during authentication, such as if the nsq authentication server is unavailable. These errors include the full path, including any GET parameter, thus causing the authentication secret to be logged. This does not occur by default for the POST body thus helping protect secrets in transit between nsqd and the authentication server. --- apps/nsqd/options.go | 1 + internal/auth/authorizations.go | 23 ++++++++++++------- internal/clusterinfo/data.go | 4 ++-- internal/http_api/api_request.go | 21 ++++++++++++++++-- nsqd/client_v2.go | 4 +++- nsqd/nsqd.go | 4 ++++ nsqd/nsqd_test.go | 6 ++--- nsqd/options.go | 2 ++ nsqd/protocol_v2_test.go | 38 ++++++++++++++++++++++++-------- nsqlookupd/nsqlookupd_test.go | 6 ++--- 10 files changed, 81 insertions(+), 28 deletions(-) diff --git a/apps/nsqd/options.go b/apps/nsqd/options.go index 7ec65ab06..f93597013 100644 --- a/apps/nsqd/options.go +++ b/apps/nsqd/options.go @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet { authHTTPAddresses := app.StringArray{} flagSet.Var(&authHTTPAddresses, "auth-http-address", ": or a full url to query auth server (may be given multiple times)") + flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests") flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)") flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)") flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)") diff --git a/internal/auth/authorizations.go b/internal/auth/authorizations.go index 0105d41f8..c9936cd3e 100644 --- a/internal/auth/authorizations.go +++ b/internal/auth/authorizations.go @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool { } func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, - clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { + clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) { var retErr error start := rand.Int() n := len(authd) for i := 0; i < n; i++ { a := authd[(i+start)%n] - authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout) + authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod) if err != nil { es := fmt.Sprintf("failed to auth against %s - %s", a, err) if retErr != nil { @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName } func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, - clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { + clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) { + var authState State v := url.Values{} v.Set("remote_ip", remoteIP) if tlsEnabled { @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin var endpoint string if strings.Contains(authd, "://") { - endpoint = fmt.Sprintf("%s?%s", authd, v.Encode()) + endpoint = authd } else { - endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode()) + endpoint = fmt.Sprintf("http://%s/auth", authd) } - var authState State client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout) - if err := client.GETV1(endpoint, &authState); err != nil { - return nil, err + if httpRequestMethod == "post" { + if err := client.POSTV1(endpoint, v, &authState); err != nil { + return nil, err + } + } else { + endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode()) + if err := client.GETV1(endpoint, &authState); err != nil { + return nil, err + } } // validation on response diff --git a/internal/clusterinfo/data.go b/internal/clusterinfo/data.go index 6de3e61be..d4aec07ea 100644 --- a/internal/clusterinfo/data.go +++ b/internal/clusterinfo/data.go @@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro for _, addr := range addrs { endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs) c.logf("CI: querying nsqlookupd %s", endpoint) - err := c.client.POSTV1(endpoint) + err := c.client.POSTV1(endpoint, nil, nil) if err != nil { errs = append(errs, err) } @@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error { for _, p := range pl { endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs) c.logf("CI: querying nsqd %s", endpoint) - err := c.client.POSTV1(endpoint) + err := c.client.POSTV1(endpoint, nil, nil) if err != nil { errs = append(errs, err) } diff --git a/internal/http_api/api_request.go b/internal/http_api/api_request.go index 4db3ad945..36042bd44 100644 --- a/internal/http_api/api_request.go +++ b/internal/http_api/api_request.go @@ -1,6 +1,7 @@ package http_api import ( + "bytes" "crypto/tls" "encoding/json" "fmt" @@ -86,14 +87,26 @@ retry: // PostV1 is a helper function to perform a V1 HTTP request // and parse our NSQ daemon's expected response format, with deadlines. -func (c *Client) POSTV1(endpoint string) error { +func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error { retry: - req, err := http.NewRequest("POST", endpoint, nil) + var reqBody io.Reader + if data != nil { + js, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint) + } + reqBody = bytes.NewBuffer(js) + } + + req, err := http.NewRequest("POST", endpoint, reqBody) if err != nil { return err } req.Header.Add("Accept", "application/vnd.nsq; version=1.0") + if reqBody != nil { + req.Header.Add("Content-Type", "application/json") + } resp, err := c.c.Do(req) if err != nil { @@ -116,6 +129,10 @@ retry: return fmt.Errorf("got response %s %q", resp.Status, body) } + if v != nil { + return json.Unmarshal(body, &v) + } + return nil } diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 32250e72c..8a40ea8d5 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error { remoteIP, tlsEnabled, commonName, c.AuthSecret, c.nsqd.clientTLSConfig, c.nsqd.getOpts().HTTPClientConnectTimeout, - c.nsqd.getOpts().HTTPClientRequestTimeout) + c.nsqd.getOpts().HTTPClientRequestTimeout, + c.nsqd.getOpts().AuthHTTPRequestMethod, + ) if err != nil { return err } diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 04404c5cc..523d1665b 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) { } n.clientTLSConfig = clientTLSConfig + if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" { + return nil, errors.New("--auth-http-request-method must be post or get") + } + for _, v := range opts.E2EProcessingLatencyPercentiles { if v <= 0 || v > 1 { return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v) diff --git a/nsqd/nsqd_test.go b/nsqd/nsqd_test.go index 87ae7a9c8..2351854c8 100644 --- a/nsqd/nsqd_test.go +++ b/nsqd/nsqd_test.go @@ -336,11 +336,11 @@ func TestCluster(t *testing.T) { test.Nil(t, err) url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) // allow some time for nsqd to push info to nsqlookupd @@ -394,7 +394,7 @@ func TestCluster(t *testing.T) { test.Equal(t, "ch", lr.Channels[0]) url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) // allow some time for nsqd to push info to nsqlookupd diff --git a/nsqd/options.go b/nsqd/options.go index 30fe5ea1e..8997c7265 100644 --- a/nsqd/options.go +++ b/nsqd/options.go @@ -27,6 +27,7 @@ type Options struct { BroadcastHTTPPort int `flag:"broadcast-http-port"` NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"` AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"` + AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"` HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"` HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"` @@ -110,6 +111,7 @@ func NewOptions() *Options { NSQLookupdTCPAddresses: make([]string, 0), AuthHTTPAddresses: make([]string, 0), + AuthHTTPRequestMethod: "get", HTTPClientConnectTimeout: 2 * time.Second, HTTPClientRequestTimeout: 5 * time.Second, diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 5ef34e93e..19015d1e9 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -18,6 +18,7 @@ import ( "os" "runtime" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -1476,7 +1477,8 @@ func TestClientAuth(t *testing.T) { authSuccess := "" tlsEnabled := false commonName := "" - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + httpAuthRequestMethod := "get" + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) // now one that will succeed authResponse = `{"ttl":10, "authorizations": @@ -1484,16 +1486,21 @@ func TestClientAuth(t *testing.T) { }` authError = "" authSuccess = `{"identity":"","identity_url":"","permission_count":1}` - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) // one with TLS enabled tlsEnabled = true commonName = "test.local" - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) + + // test POST based authentication + httpAuthRequestMethod = "post" + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) + } func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string, - authSuccess string, tlsEnabled bool, commonName string) { + authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) { var err error var expectedRemoteIP string expectedTLS := "false" @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("in test auth handler %s", r.RequestURI) - r.ParseForm() - test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip")) - test.Equal(t, expectedTLS, r.Form.Get("tls")) - test.Equal(t, commonName, r.Form.Get("common_name")) - test.Equal(t, authSecret, r.Form.Get("secret")) + test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method)) + + var values url.Values + + if r.Method == "POST" { + err = json.NewDecoder(r.Body).Decode(&values) + if err != nil { + t.Error(err) + } + } else { + r.ParseForm() + values = r.Form + } + test.Equal(t, expectedRemoteIP, values.Get("remote_ip")) + test.Equal(t, expectedTLS, values.Get("tls")) + test.Equal(t, commonName, values.Get("common_name")) + test.Equal(t, authSecret, values.Get("secret")) fmt.Fprint(w, authResponse) })) defer authd.Close() @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError opts.Logger = test.NewTestLogger(t) opts.LogLevel = LOG_DEBUG opts.AuthHTTPAddresses = []string{addr.Host} + opts.AuthHTTPRequestMethod = httpAuthRequestMethod if tlsEnabled { opts.TLSCert = "./test/certs/server.pem" opts.TLSKey = "./test/certs/server.key" diff --git a/nsqlookupd/nsqlookupd_test.go b/nsqlookupd/nsqlookupd_test.go index 492d09b30..6afb18906 100644 --- a/nsqlookupd/nsqlookupd_test.go +++ b/nsqlookupd/nsqlookupd_test.go @@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) pr := ProducersDoc{} @@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) pr := ProducersDoc{} @@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)