Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

authproxy: MDM device identity authenticated HTTP requests #80

Merged
merged 13 commits into from
Aug 28, 2023
42 changes: 30 additions & 12 deletions cmd/nanomdm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/micromdm/nanomdm/cli"
mdmhttp "github.com/micromdm/nanomdm/http"
httpapi "github.com/micromdm/nanomdm/http/api"
"github.com/micromdm/nanomdm/http/authproxy"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log/stdlogfmt"
"github.com/micromdm/nanomdm/push/buford"
Expand All @@ -34,6 +35,8 @@ const (
endpointMDM = "/mdm"
endpointCheckin = "/checkin"

endpointAuthProxy = "/authproxy/"

endpointAPIPushCert = "/v1/pushcert"
endpointAPIPush = "/v1/push/"
endpointAPIEnqueue = "/v1/enqueue/"
Expand Down Expand Up @@ -62,6 +65,7 @@ func main() {
flMigration = flag.Bool("migration", false, "HTTP endpoint for enrollment migrations")
flRetro = flag.Bool("retro", false, "Allow retroactive certificate-authorization association")
flDMURLPfx = flag.String("dm", "", "URL to send Declarative Management requests to")
flAuthProxy = flag.String("auth-proxy-url", "", "Reverse proxy URL target for MDM-authenticated HTTP requests")
)
flag.Parse()

Expand Down Expand Up @@ -129,6 +133,17 @@ func main() {
mdmService = dump.New(mdmService, os.Stdout)
}

// helper for authorizing MDM clients requests
certAuthMiddleware := func(h http.Handler) http.Handler {
h = httpmdm.CertVerifyMiddleware(h, verifier, logger.With("handler", "cert-verify"))
if *flCertHeader != "" {
h = httpmdm.CertExtractPEMHeaderMiddleware(h, *flCertHeader, logger.With("handler", "cert-extract"))
} else {
h = httpmdm.CertExtractMdmSignatureMiddleware(h, logger.With("handler", "cert-extract"))
}
return h
}

// register 'core' MDM HTTP handler
var mdmHandler http.Handler
if *flCheckin {
Expand All @@ -138,26 +153,29 @@ func main() {
// if we don't use a check-in handler then do both
mdmHandler = httpmdm.CheckinAndCommandHandler(mdmService, logger.With("handler", "checkin-command"))
}
mdmHandler = httpmdm.CertVerifyMiddleware(mdmHandler, verifier, logger.With("handler", "cert-verify"))
if *flCertHeader != "" {
mdmHandler = httpmdm.CertExtractPEMHeaderMiddleware(mdmHandler, *flCertHeader, logger.With("handler", "cert-extract"))
} else {
mdmHandler = httpmdm.CertExtractMdmSignatureMiddleware(mdmHandler, logger.With("handler", "cert-extract"))
}
mdmHandler = certAuthMiddleware(mdmHandler)
mux.Handle(endpointMDM, mdmHandler)

if *flCheckin {
// if we specified a separate check-in handler, set it up
var checkinHandler http.Handler
checkinHandler = httpmdm.CheckinHandler(mdmService, logger.With("handler", "checkin"))
checkinHandler = httpmdm.CertVerifyMiddleware(checkinHandler, verifier, logger.With("handler", "cert-verify"))
if *flCertHeader != "" {
checkinHandler = httpmdm.CertExtractPEMHeaderMiddleware(checkinHandler, *flCertHeader, logger.With("handler", "cert-extract"))
} else {
checkinHandler = httpmdm.CertExtractMdmSignatureMiddleware(checkinHandler, logger.With("handler", "cert-extract"))
}
checkinHandler = certAuthMiddleware(checkinHandler)
mux.Handle(endpointCheckin, checkinHandler)
}

if *flAuthProxy != "" {
var authProxyHandler http.Handler
authProxyHandler, err = authproxy.New(*flAuthProxy, logger.With("handler", "authproxy"))
if err != nil {
stdlog.Fatal(err)
}
logger.Debug("msg", "authproxy setup", "url", *flAuthProxy)
authProxyHandler = http.StripPrefix(endpointAuthProxy, authProxyHandler)
authProxyHandler = httpmdm.CertWithEnrollmentIDMiddleware(authProxyHandler, certauth.HashCert, mdmStorage, true, logger.With("handler", "with-enrollment-id"))
authProxyHandler = certAuthMiddleware(authProxyHandler)
mux.Handle(endpointAuthProxy, authProxyHandler)
}
}

if *flAPIKey != "" {
Expand Down
51 changes: 51 additions & 0 deletions http/authproxy/authproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Package authproxy is a simple reverse proxy for Apple MDM clients.
package authproxy

import (
"net/http"
"net/http/httputil"
"net/url"

mdmhttp "github.com/micromdm/nanomdm/http"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log"
"github.com/micromdm/nanomdm/log/ctxlog"
)

const (
EnrollmentIDHeader = "X-Enrollment-ID"
TraceIDHeader = "X-Trace-ID"
)

// New creates a new NanoMDM enrollment authenticating reverse proxy.
// This reverse proxy is mostly the standard httputil proxy. It depends
// on middleware HTTP handlers to enforce authentication and set the
// context value for the enrollment ID.
func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) {
target, err := url.Parse(dest)
if err != nil {
return nil, err
}
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
ctxlog.Logger(r.Context(), logger).Info("err", err)
// use the same error as the standrad reverse proxy
w.WriteHeader(http.StatusBadGateway)
}
dir := proxy.Director
proxy.Director = func(req *http.Request) {
dir(req)
req.Host = target.Host
// save the effort of forwarding this huge header
req.Header.Del("Mdm-Signature")
jessepeterson marked this conversation as resolved.
Show resolved Hide resolved
if id := httpmdm.GetEnrollmentID(req.Context()); id != "" {
req.Header.Set(EnrollmentIDHeader, id)
}
// TODO: this couples us to our specific idea of trace logging
// Perhaps have an optional config for header specificaiton?
if id := mdmhttp.GetTraceID(req.Context()); id != "" {
req.Header.Set(TraceIDHeader, id)
jessepeterson marked this conversation as resolved.
Show resolved Hide resolved
}
}
return proxy, nil
}
6 changes: 6 additions & 0 deletions http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func VersionHandler(version string) http.HandlerFunc {

type ctxKeyTraceID struct{}

// GetTraceID returns the trace ID from ctx.
func GetTraceID(ctx context.Context) string {
id, _ := ctx.Value(ctxKeyTraceID{}).(string)
return id
}

// TraceLoggingMiddleware sets up a trace ID in the request context and
// logs HTTP requests.
func TraceLoggingMiddleware(next http.Handler, logger log.Logger, traceID func(*http.Request) string) http.HandlerFunc {
Expand Down
68 changes: 68 additions & 0 deletions http/mdm/mdm_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import (
mdmhttp "github.com/micromdm/nanomdm/http"
"github.com/micromdm/nanomdm/log"
"github.com/micromdm/nanomdm/log/ctxlog"
"github.com/micromdm/nanomdm/storage"
)

type contextKeyCert struct{}

var contextEnrollmentID struct{}

// CertExtractPEMHeaderMiddleware extracts the MDM enrollment identity
// certificate from the request into the HTTP request context. It looks
// at the request header which should be a URL-encoded PEM certificate.
Expand Down Expand Up @@ -128,3 +131,68 @@ func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.L
next.ServeHTTP(w, r)
}
}

// GetEnrollmentID retrieves the MDM enrollment ID from ctx.
func GetEnrollmentID(ctx context.Context) string {
id, _ := ctx.Value(contextEnrollmentID).(string)
return id
}

type HashFn func(*x509.Certificate) string

// CertWithEnrollmentIDMiddleware tries to associate the enrollment ID to the request context.
// It does this by looking up the certificate on the context, hashing it with
// hasher, looking up the hash in storage, and setting the ID on the context.
//
// The next handler will be called even if cert or ID is not found unless
// enforce is true. This way next is able to use the existence of the ID on
// the context to make its own decisions.
func CertWithEnrollmentIDMiddleware(next http.Handler, hasher HashFn, store storage.CertAuthRetriever, enforce bool, logger log.Logger) http.HandlerFunc {
jessepeterson marked this conversation as resolved.
Show resolved Hide resolved
if store == nil || hasher == nil {
panic("store and hasher must not be nil")
}
return func(w http.ResponseWriter, r *http.Request) {
cert := GetCert(r.Context())
if cert == nil {
if enforce {
ctxlog.Logger(r.Context(), logger).Info(
"err", "missing certificate",
)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest)
return
} else {
ctxlog.Logger(r.Context(), logger).Debug(
"msg", "missing certificate",
)
next.ServeHTTP(w, r)
return
}
}
mr, err := store.EnrollmentFromHash(r.Context(), hasher(cert))
if err != nil {
ctxlog.Logger(r.Context(), logger).Info(
"msg", "retreiving enrollment from hash",
"err", err,
)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if mr == nil || mr.ID == "" {
if enforce {
ctxlog.Logger(r.Context(), logger).Info(
"err", "missing enrollment id",
)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest)
return
} else {
ctxlog.Logger(r.Context(), logger).Debug(
"msg", "missing enrollment id",
)
next.ServeHTTP(w, r)
return
}
}
ctx := context.WithValue(r.Context(), contextEnrollmentID, mr.ID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
7 changes: 4 additions & 3 deletions service/certauth/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func New(next service.CheckinAndCommandService, storage storage.CertAuthStore, o
return certAuth
}

func hashCert(cert *x509.Certificate) string {
// HashCert returns the string representation
func HashCert(cert *x509.Certificate) string {
hashed := sha256.Sum256(cert.Raw)
b := make([]byte, len(hashed))
copy(b, hashed[:])
Expand All @@ -112,7 +113,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error {
return err
}
logger := ctxlog.Logger(r.Context, s.logger)
hash := hashCert(r.Certificate)
hash := HashCert(r.Certificate)
if hasHash, err := s.storage.HasCertHash(r, hash); err != nil {
return err
} else if hasHash {
Expand Down Expand Up @@ -157,7 +158,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
return err
}
logger := ctxlog.Logger(r.Context, s.logger)
hash := hashCert(r.Certificate)
hash := HashCert(r.Certificate)
if isAssoc, err := s.storage.IsCertHashAssociated(r, hash); err != nil {
return err
} else if isAssoc {
Expand Down
1 change: 1 addition & 0 deletions storage/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type AllStorage interface {
PushCertStore
CommandEnqueuer
CertAuthStore
CertAuthRetriever
StoreMigrator
TokenUpdateTallyStore
}
9 changes: 9 additions & 0 deletions storage/allmulti/certauth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package allmulti

import (
"context"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)
Expand Down Expand Up @@ -32,3 +34,10 @@ func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error
})
return err
}

func (ms *MultiAllStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) {
return s.EnrollmentFromHash(ctx, hash)
})
return val.(*mdm.Request), err
}
21 changes: 21 additions & 0 deletions storage/file/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package file

import (
"bufio"
"context"
"errors"
"os"
"path"
Expand Down Expand Up @@ -68,3 +69,23 @@ func (s *FileStorage) AssociateCertHash(r *mdm.Request, hash string) error {
e := s.newEnrollment(r.ID)
return e.writeFile(CertAuthFilename, []byte(hash))
}

func (s *FileStorage) EnrollmentFromHash(_ context.Context, hash string) (*mdm.Request, error) {
f, err := os.Open(path.Join(s.path, CertAuthAssociationsFilename))
if err != nil {
return nil, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
text := scanner.Text()
if strings.Contains(text, hash) {
split := strings.Split(text, ",")
if len(split) < 2 {
return nil, errors.New("hash and enrollment id not present on line")
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: split[0]}}, nil
}
}
return nil, nil
}
15 changes: 15 additions & 0 deletions storage/mysql/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mysql

import (
"context"
"database/sql"
"errors"
"strings"

"github.com/micromdm/nanomdm/mdm"
Expand Down Expand Up @@ -49,3 +51,16 @@ UPDATE sha256 = new.sha256;`,
)
return err
}

func (s *MySQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
var id string
err := s.db.QueryRowContext(
ctx,
`SELECT id FROM cert_auth_associations WHERE sha256 = ? LIMIT 1;`,
hash,
).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err
}
15 changes: 15 additions & 0 deletions storage/pgsql/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package pgsql

import (
"context"
"database/sql"
"errors"
"strings"

"github.com/micromdm/nanomdm/mdm"
Expand Down Expand Up @@ -50,3 +52,16 @@ ON CONFLICT ON CONSTRAINT cert_auth_associations_pkey DO UPDATE SET updated_at=n
)
return err
}

func (s *PgSQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
var id string
err := s.db.QueryRowContext(
ctx,
`SELECT id FROM cert_auth_associations WHERE sha256 = $1 LIMIT 1;`,
hash,
).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err
}
7 changes: 7 additions & 0 deletions storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ type CertAuthStore interface {
AssociateCertHash(r *mdm.Request, hash string) error
}

type CertAuthRetriever interface {
// EnrollmentFromHash retrieves an MDM request from a cert hash.
// Implementations should return a nil pointer if no result is found.
// The ID member ought to be populated when non-nil.
EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error)
jessepeterson marked this conversation as resolved.
Show resolved Hide resolved
}

// StoreMigrator retrieves MDM check-ins
type StoreMigrator interface {
// RetrieveMigrationCheckins sends the (decoded) forms of
Expand Down