diff --git a/cmd/nanomdm/main.go b/cmd/nanomdm/main.go index ffe4416..7586282 100644 --- a/cmd/nanomdm/main.go +++ b/cmd/nanomdm/main.go @@ -44,6 +44,11 @@ const ( endpointAPIVersion = "/version" ) +const ( + EnrollmentIDHeader = "X-Enrollment-ID" + TraceIDHeader = "X-Trace-ID" +) + func main() { cliStorage := cli.NewStorage() flag.Var(&cliStorage.Storage, "storage", "name of storage backend") @@ -166,7 +171,11 @@ func main() { if *flAuthProxy != "" { var authProxyHandler http.Handler - authProxyHandler, err = authproxy.New(*flAuthProxy, logger.With("handler", "authproxy")) + authProxyHandler, err = authproxy.New(*flAuthProxy, + authproxy.WithLogger(logger.With("handler", "authproxy")), + authproxy.WithHeaderFunc(EnrollmentIDHeader, httpmdm.GetEnrollmentID), + authproxy.WithHeaderFunc(TraceIDHeader, mdmhttp.GetTraceID), + ) if err != nil { stdlog.Fatal(err) } diff --git a/http/authproxy/authproxy.go b/http/authproxy/authproxy.go index 4bb8c5b..fd04287 100644 --- a/http/authproxy/authproxy.go +++ b/http/authproxy/authproxy.go @@ -2,33 +2,68 @@ package authproxy import ( + "context" "net/http" "net/http/httputil" "net/url" - mdmhttp "github.com/micromdm/nanomdm/http" - httpmdm "github.com/micromdm/nanomdm/http/mdm" "github.com/micromdm/nanomdm/log" "github.com/micromdm/nanomdm/log/ctxlog" ) -const ( - EnrollmentIDHeader = "X-Enrollment-ID" - TraceIDHeader = "X-Trace-ID" -) +// HeaderFunc takes an HTTP request and returns a string value. +// Ostensibly to be set in a header on the proxy target. +type HeaderFunc func(context.Context) string + +type config struct { + logger log.Logger + fwdSig bool + headerFuncs map[string]HeaderFunc +} + +type Option func(*config) + +// WithLogger sets a logger for error reporting. +func WithLogger(logger log.Logger) Option { + return func(c *config) { + c.logger = logger + } +} + +// WithHeaderFunc configures fn to be called and added as an HTTP header to the proxy target request. +func WithHeaderFunc(header string, fn HeaderFunc) Option { + return func(c *config) { + c.headerFuncs[header] = fn + } +} + +// WithForwardMDMSignature forwards the MDM-Signature header onto the proxy destination. +// This option is off by default because the header adds about two kilobytes to the request. +func WithForwardMDMSignature() Option { + return func(c *config) { + c.fwdSig = true + } +} // New creates a new NanoMDM enrollment authenticating reverse proxy. // This reverse proxy is mostly the standard httputil proxy. It depends // on middleware HTTP handlers to enforce authentication and set the // context value for the enrollment ID. -func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) { +func New(dest string, opts ...Option) (*httputil.ReverseProxy, error) { + config := &config{ + logger: log.NopLogger, + headerFuncs: make(map[string]HeaderFunc), + } + for _, opt := range opts { + opt(config) + } target, err := url.Parse(dest) if err != nil { return nil, err } proxy := httputil.NewSingleHostReverseProxy(target) proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { - ctxlog.Logger(r.Context(), logger).Info("err", err) + ctxlog.Logger(r.Context(), config.logger).Info("err", err) // use the same error as the standrad reverse proxy w.WriteHeader(http.StatusBadGateway) } @@ -36,15 +71,18 @@ func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) { proxy.Director = func(req *http.Request) { dir(req) req.Host = target.Host - // save the effort of forwarding this huge header - req.Header.Del("Mdm-Signature") - if id := httpmdm.GetEnrollmentID(req.Context()); id != "" { - req.Header.Set(EnrollmentIDHeader, id) + if !config.fwdSig { + // save the effort of forwarding this huge header + req.Header.Del("Mdm-Signature") } - // TODO: this couples us to our specific idea of trace logging - // Perhaps have an optional config for header specificaiton? - if id := mdmhttp.GetTraceID(req.Context()); id != "" { - req.Header.Set(TraceIDHeader, id) + // set any headers we want to forward. + for k, fn := range config.headerFuncs { + if k == "" || fn == nil { + continue + } + if v := fn(req.Context()); v != "" { + req.Header.Set(k, v) + } } } return proxy, nil