Skip to content

Commit

Permalink
Merge pull request #1487 from danrjohnson/support_post_auth
Browse files Browse the repository at this point in the history
nsqd: support POST auth
  • Loading branch information
mreiferson authored May 12, 2024
2 parents 62fa868 + 0db445c commit 4de1606
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 28 deletions.
1 change: 1 addition & 0 deletions apps/nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func nsqdFlagSet(opts *nsqd.Options) *flag.FlagSet {

authHTTPAddresses := app.StringArray{}
flagSet.Var(&authHTTPAddresses, "auth-http-address", "<addr>:<port> or a full url to query auth server (may be given multiple times)")
flagSet.String("auth-http-request-method", opts.AuthHTTPRequestMethod, "HTTP method to use for auth server requests")
flagSet.String("broadcast-address", opts.BroadcastAddress, "address that will be registered with lookupd (defaults to the OS hostname)")
flagSet.Int("broadcast-tcp-port", opts.BroadcastTCPPort, "TCP port that will be registered with lookupd (defaults to the TCP port that this nsqd is listening on)")
flagSet.Int("broadcast-http-port", opts.BroadcastHTTPPort, "HTTP port that will be registered with lookupd (defaults to the HTTP port that this nsqd is listening on)")
Expand Down
23 changes: 15 additions & 8 deletions internal/auth/authorizations.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,13 @@ func (a *State) IsExpired() bool {
}

func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var retErr error
start := rand.Int()
n := len(authd)
for i := 0; i < n; i++ {
a := authd[(i+start)%n]
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout, httpRequestMethod)
if err != nil {
es := fmt.Sprintf("failed to auth against %s - %s", a, err)
if retErr != nil {
Expand All @@ -97,7 +97,8 @@ func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName
}

func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration, httpRequestMethod string) (*State, error) {
var authState State
v := url.Values{}
v.Set("remote_ip", remoteIP)
if tlsEnabled {
Expand All @@ -110,15 +111,21 @@ func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName strin

var endpoint string
if strings.Contains(authd, "://") {
endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
endpoint = authd
} else {
endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
endpoint = fmt.Sprintf("http://%s/auth", authd)
}

var authState State
client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
if httpRequestMethod == "post" {
if err := client.POSTV1(endpoint, v, &authState); err != nil {
return nil, err
}
} else {
endpoint = fmt.Sprintf("%s?%s", endpoint, v.Encode())
if err := client.GETV1(endpoint, &authState); err != nil {
return nil, err
}
}

// validation on response
Expand Down
4 changes: 2 additions & 2 deletions internal/clusterinfo/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ func (c *ClusterInfo) nsqlookupdPOST(addrs []string, uri string, qs string) erro
for _, addr := range addrs {
endpoint := fmt.Sprintf("http://%s/%s?%s", addr, uri, qs)
c.logf("CI: querying nsqlookupd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand All @@ -894,7 +894,7 @@ func (c *ClusterInfo) producersPOST(pl Producers, uri string, qs string) error {
for _, p := range pl {
endpoint := fmt.Sprintf("http://%s/%s?%s", p.HTTPAddress(), uri, qs)
c.logf("CI: querying nsqd %s", endpoint)
err := c.client.POSTV1(endpoint)
err := c.client.POSTV1(endpoint, nil, nil)
if err != nil {
errs = append(errs, err)
}
Expand Down
21 changes: 19 additions & 2 deletions internal/http_api/api_request.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http_api

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -86,14 +87,26 @@ retry:

// PostV1 is a helper function to perform a V1 HTTP request
// and parse our NSQ daemon's expected response format, with deadlines.
func (c *Client) POSTV1(endpoint string) error {
func (c *Client) POSTV1(endpoint string, data url.Values, v interface{}) error {
retry:
req, err := http.NewRequest("POST", endpoint, nil)
var reqBody io.Reader
if data != nil {
js, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal POST data to endpoint: %v", endpoint)
}
reqBody = bytes.NewBuffer(js)
}

req, err := http.NewRequest("POST", endpoint, reqBody)
if err != nil {
return err
}

req.Header.Add("Accept", "application/vnd.nsq; version=1.0")
if reqBody != nil {
req.Header.Add("Content-Type", "application/json")
}

resp, err := c.c.Do(req)
if err != nil {
Expand All @@ -116,6 +129,10 @@ retry:
return fmt.Errorf("got response %s %q", resp.Status, body)
}

if v != nil {
return json.Unmarshal(body, &v)
}

return nil
}

Expand Down
4 changes: 3 additions & 1 deletion nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,9 @@ func (c *clientV2) QueryAuthd() error {
remoteIP, tlsEnabled, commonName, c.AuthSecret,
c.nsqd.clientTLSConfig,
c.nsqd.getOpts().HTTPClientConnectTimeout,
c.nsqd.getOpts().HTTPClientRequestTimeout)
c.nsqd.getOpts().HTTPClientRequestTimeout,
c.nsqd.getOpts().AuthHTTPRequestMethod,
)
if err != nil {
return err
}
Expand Down
4 changes: 4 additions & 0 deletions nsqd/nsqd.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ func New(opts *Options) (*NSQD, error) {
}
n.clientTLSConfig = clientTLSConfig

if opts.AuthHTTPRequestMethod != "post" && opts.AuthHTTPRequestMethod != "get" {
return nil, errors.New("--auth-http-request-method must be post or get")
}

for _, v := range opts.E2EProcessingLatencyPercentiles {
if v <= 0 || v > 1 {
return nil, fmt.Errorf("invalid E2E processing latency percentile: %v", v)
Expand Down
6 changes: 3 additions & 3 deletions nsqd/nsqd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,11 @@ func TestCluster(t *testing.T) {
test.Nil(t, err)

url := fmt.Sprintf("http://%s/topic/create?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

url = fmt.Sprintf("http://%s/channel/create?topic=%s&channel=ch", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down Expand Up @@ -394,7 +394,7 @@ func TestCluster(t *testing.T) {
test.Equal(t, "ch", lr.Channels[0])

url = fmt.Sprintf("http://%s/topic/delete?topic=%s", nsqd.RealHTTPAddr(), topicName)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(url, nil, nil)
test.Nil(t, err)

// allow some time for nsqd to push info to nsqlookupd
Expand Down
2 changes: 2 additions & 0 deletions nsqd/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type Options struct {
BroadcastHTTPPort int `flag:"broadcast-http-port"`
NSQLookupdTCPAddresses []string `flag:"lookupd-tcp-address" cfg:"nsqlookupd_tcp_addresses"`
AuthHTTPAddresses []string `flag:"auth-http-address" cfg:"auth_http_addresses"`
AuthHTTPRequestMethod string `flag:"auth-http-request-method" cfg:"auth_http_request_method"`
HTTPClientConnectTimeout time.Duration `flag:"http-client-connect-timeout" cfg:"http_client_connect_timeout"`
HTTPClientRequestTimeout time.Duration `flag:"http-client-request-timeout" cfg:"http_client_request_timeout"`

Expand Down Expand Up @@ -110,6 +111,7 @@ func NewOptions() *Options {

NSQLookupdTCPAddresses: make([]string, 0),
AuthHTTPAddresses: make([]string, 0),
AuthHTTPRequestMethod: "get",

HTTPClientConnectTimeout: 2 * time.Second,
HTTPClientRequestTimeout: 5 * time.Second,
Expand Down
38 changes: 29 additions & 9 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -1476,24 +1477,30 @@ func TestClientAuth(t *testing.T) {
authSuccess := ""
tlsEnabled := false
commonName := ""
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
httpAuthRequestMethod := "get"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// now one that will succeed
authResponse = `{"ttl":10, "authorizations":
[{"topic":"test", "channels":[".*"], "permissions":["subscribe","publish"]}]
}`
authError = ""
authSuccess = `{"identity":"","identity_url":"","permission_count":1}`
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// one with TLS enabled
tlsEnabled = true
commonName = "test.local"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName)
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

// test POST based authentication
httpAuthRequestMethod = "post"
runAuthTest(t, authResponse, authSecret, authError, authSuccess, tlsEnabled, commonName, httpAuthRequestMethod)

}

func runAuthTest(t *testing.T, authResponse string, authSecret string, authError string,
authSuccess string, tlsEnabled bool, commonName string) {
authSuccess string, tlsEnabled bool, commonName string, httpAuthRequestMethod string) {
var err error
var expectedRemoteIP string
expectedTLS := "false"
Expand All @@ -1503,11 +1510,23 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError

authd := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Logf("in test auth handler %s", r.RequestURI)
r.ParseForm()
test.Equal(t, expectedRemoteIP, r.Form.Get("remote_ip"))
test.Equal(t, expectedTLS, r.Form.Get("tls"))
test.Equal(t, commonName, r.Form.Get("common_name"))
test.Equal(t, authSecret, r.Form.Get("secret"))
test.Equal(t, httpAuthRequestMethod, strings.ToLower(r.Method))

var values url.Values

if r.Method == "POST" {
err = json.NewDecoder(r.Body).Decode(&values)
if err != nil {
t.Error(err)
}
} else {
r.ParseForm()
values = r.Form
}
test.Equal(t, expectedRemoteIP, values.Get("remote_ip"))
test.Equal(t, expectedTLS, values.Get("tls"))
test.Equal(t, commonName, values.Get("common_name"))
test.Equal(t, authSecret, values.Get("secret"))
fmt.Fprint(w, authResponse)
}))
defer authd.Close()
Expand All @@ -1519,6 +1538,7 @@ func runAuthTest(t *testing.T, authResponse string, authSecret string, authError
opts.Logger = test.NewTestLogger(t)
opts.LogLevel = LOG_DEBUG
opts.AuthHTTPAddresses = []string{addr.Host}
opts.AuthHTTPRequestMethod = httpAuthRequestMethod
if tlsEnabled {
opts.TLSCert = "./test/certs/server.pem"
opts.TLSKey = "./test/certs/server.key"
Expand Down
6 changes: 3 additions & 3 deletions nsqlookupd/nsqlookupd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ func TestTombstoneRecover(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -263,7 +263,7 @@ func TestTombstoneUnregister(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

pr := ProducersDoc{}
Expand Down Expand Up @@ -348,7 +348,7 @@ func TestTombstonedNodes(t *testing.T) {

endpoint := fmt.Sprintf("http://%s/topic/tombstone?topic=%s&node=%s:%d",
httpAddr, topicName, HostAddr, HTTPPort)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint)
err = http_api.NewClient(nil, ConnectTimeout, RequestTimeout).POSTV1(endpoint, nil, nil)
test.Nil(t, err)

producers, _ = ci.GetLookupdProducers(lookupdHTTPAddrs)
Expand Down

0 comments on commit 4de1606

Please sign in to comment.