diff --git a/apps/nsqd/options.go b/apps/nsqd/options.go index 7ec65ab06..f93597013 100644 --- a/apps/nsqd/options.go +++ b/apps/nsqd/options.go @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet { authHTTPAddresses := app.StringArray{} flagSet.Var(&authHTTPAddresses, "auth-http-address", ": or a full url to query auth server (may be given multiple times)") + flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests") flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)") flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)") flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)") diff --git a/internal/auth/authorizations.go b/internal/auth/authorizations.go index 0105d41f8..c9936cd3e 100644 --- a/internal/auth/authorizations.go +++ b/internal/auth/authorizations.go @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool { } func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, - clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { + clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) { var retErr error start := rand.Int() n := len(authd) for i := 0; i < n; i++ { a := authd[(i+start)%n] - authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout) + authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod) if err != nil { es := fmt.Sprintf("failed to auth against %s - %s", a, err) if retErr != nil { @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName } func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, - clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { + clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) { + var authState State v := url.Values{} v.Set("remote_ip", remoteIP) if tlsEnabled { @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin var endpoint string if strings.Contains(authd, "://") { - endpoint = fmt.Sprintf("%s?%s", authd, v.Encode()) + endpoint = authd } else { - endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode()) + endpoint = fmt.Sprintf("http://%s/auth", authd) } - var authState State client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout) - if err := client.GETV1(endpoint, &authState); err != nil { - return nil, err + if httpRequestMethod == "post" { + if err := client.POSTV1(endpoint, v, &authState); err != nil { + return nil, err + } + } else { + endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode()) + if err := client.GETV1(endpoint, &authState); err != nil { + return nil, err + } } // validation on response diff --git a/internal/clusterinfo/data.go b/internal/clusterinfo/data.go index 6de3e61be..d4aec07ea 100644 --- a/internal/clusterinfo/data.go +++ b/internal/clusterinfo/data.go @@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro for _, addr := range addrs { endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs) c.logf("CI: querying nsqlookupd %s", endpoint) - err := c.client.POSTV1(endpoint) + err := c.client.POSTV1(endpoint, nil, nil) if err != nil { errs = append(errs, err) } @@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error { for _, p := range pl { endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs) c.logf("CI: querying nsqd %s", endpoint) - err := c.client.POSTV1(endpoint) + err := c.client.POSTV1(endpoint, nil, nil) if err != nil { errs = append(errs, err) } diff --git a/internal/http_api/api_request.go b/internal/http_api/api_request.go index 4db3ad945..36042bd44 100644 --- a/internal/http_api/api_request.go +++ b/internal/http_api/api_request.go @@ -1,6 +1,7 @@ package http_api import ( + "bytes" "crypto/tls" "encoding/json" "fmt" @@ -86,14 +87,26 @@ retry: // PostV1 is a helper function to perform a V1 HTTP request // and parse our NSQ daemon's expected response format, with deadlines. -func (c *Client) POSTV1(endpoint string) error { +func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error { retry: - req, err := http.NewRequest("POST", endpoint, nil) + var reqBody io.Reader + if data != nil { + js, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint) + } + reqBody = bytes.NewBuffer(js) + } + + req, err := http.NewRequest("POST", endpoint, reqBody) if err != nil { return err } req.Header.Add("Accept", "application/vnd.nsq; version=1.0") + if reqBody != nil { + req.Header.Add("Content-Type", "application/json") + } resp, err := c.c.Do(req) if err != nil { @@ -116,6 +129,10 @@ retry: return fmt.Errorf("got response %s %q", resp.Status, body) } + if v != nil { + return json.Unmarshal(body, &v) + } + return nil } diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 32250e72c..8a40ea8d5 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error { remoteIP, tlsEnabled, commonName, c.AuthSecret, c.nsqd.clientTLSConfig, c.nsqd.getOpts().HTTPClientConnectTimeout, - c.nsqd.getOpts().HTTPClientRequestTimeout) + c.nsqd.getOpts().HTTPClientRequestTimeout, + c.nsqd.getOpts().AuthHTTPRequestMethod, + ) if err != nil { return err } diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 04404c5cc..523d1665b 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) { } n.clientTLSConfig = clientTLSConfig + if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" { + return nil, errors.New("--auth-http-request-method must be post or get") + } + for _, v := range opts.E2EProcessingLatencyPercentiles { if v <= 0 || v > 1 { return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v) diff --git a/nsqd/nsqd_test.go b/nsqd/nsqd_test.go index 87ae7a9c8..2351854c8 100644 --- a/nsqd/nsqd_test.go +++ b/nsqd/nsqd_test.go @@ -336,11 +336,11 @@ func TestCluster(t *testing.T) { test.Nil(t, err) url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) // allow some time for nsqd to push info to nsqlookupd @@ -394,7 +394,7 @@ func TestCluster(t *testing.T) { test.Equal(t, "ch", lr.Channels[0]) url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil) test.Nil(t, err) // allow some time for nsqd to push info to nsqlookupd diff --git a/nsqd/options.go b/nsqd/options.go index 30fe5ea1e..8997c7265 100644 --- a/nsqd/options.go +++ b/nsqd/options.go @@ -27,6 +27,7 @@ type Options struct { BroadcastHTTPPort int `flag:"broadcast-http-port"` NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"` AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"` + AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"` HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"` HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"` @@ -110,6 +111,7 @@ func NewOptions() *Options { NSQLookupdTCPAddresses: make([]string, 0), AuthHTTPAddresses: make([]string, 0), + AuthHTTPRequestMethod: "get", HTTPClientConnectTimeout: 2 * time.Second, HTTPClientRequestTimeout: 5 * time.Second, diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 5ef34e93e..19015d1e9 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -18,6 +18,7 @@ import ( "os" "runtime" "strconv" + "strings" "sync" "sync/atomic" "testing" @@ -1476,7 +1477,8 @@ func TestClientAuth(t *testing.T) { authSuccess := "" tlsEnabled := false commonName := "" - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + httpAuthRequestMethod := "get" + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) // now one that will succeed authResponse = `{"ttl":10, "authorizations": @@ -1484,16 +1486,21 @@ func TestClientAuth(t *testing.T) { }` authError = "" authSuccess = `{"identity":"","identity_url":"","permission_count":1}` - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) // one with TLS enabled tlsEnabled = true commonName = "test.local" - runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName) + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) + + // test POST based authentication + httpAuthRequestMethod = "post" + runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod) + } func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string, - authSuccess string, tlsEnabled bool, commonName string) { + authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) { var err error var expectedRemoteIP string expectedTLS := "false" @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Logf("in test auth handler %s", r.RequestURI) - r.ParseForm() - test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip")) - test.Equal(t, expectedTLS, r.Form.Get("tls")) - test.Equal(t, commonName, r.Form.Get("common_name")) - test.Equal(t, authSecret, r.Form.Get("secret")) + test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method)) + + var values url.Values + + if r.Method == "POST" { + err = json.NewDecoder(r.Body).Decode(&values) + if err != nil { + t.Error(err) + } + } else { + r.ParseForm() + values = r.Form + } + test.Equal(t, expectedRemoteIP, values.Get("remote_ip")) + test.Equal(t, expectedTLS, values.Get("tls")) + test.Equal(t, commonName, values.Get("common_name")) + test.Equal(t, authSecret, values.Get("secret")) fmt.Fprint(w, authResponse) })) defer authd.Close() @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError opts.Logger = test.NewTestLogger(t) opts.LogLevel = LOG_DEBUG opts.AuthHTTPAddresses = []string{addr.Host} + opts.AuthHTTPRequestMethod = httpAuthRequestMethod if tlsEnabled { opts.TLSCert = "./test/certs/server.pem" opts.TLSKey = "./test/certs/server.key" diff --git a/nsqlookupd/nsqlookupd_test.go b/nsqlookupd/nsqlookupd_test.go index 492d09b30..6afb18906 100644 --- a/nsqlookupd/nsqlookupd_test.go +++ b/nsqlookupd/nsqlookupd_test.go @@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) pr := ProducersDoc{} @@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) pr := ProducersDoc{} @@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) { endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d", httpAddr, topicName, HostAddr, HTTPPort) - err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint) + err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil) test.Nil(t, err) producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)