From 36d443728dd19290991f91b21eba81774e66806f Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 2 Apr 2024 22:06:09 -0400 Subject: [PATCH] refactor(server): move middleware setup to `httpServer` --- server/http.go | 50 +++++++++++++++++++++++++++++++++++++- server/server_endpoints.go | 36 --------------------------- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/server/http.go b/server/http.go index 78f7fe0df..cac0e8102 100644 --- a/server/http.go +++ b/server/http.go @@ -5,6 +5,9 @@ import ( "net" "net/http" "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/cors" ) type httpServer struct { @@ -26,7 +29,7 @@ func newHTTPServer(name string, handler http.Handler) *httpServer { ReadHeaderTimeout: readHeaderTimeout, WriteTimeout: writeTimeout, - Handler: handler, + Handler: withCommonMiddleware(handler), }, name: name, @@ -46,3 +49,48 @@ func (s *httpServer) Serve(ctx context.Context, l net.Listener) error { return s.inner.Serve(l) } + +func withCommonMiddleware(inner http.Handler) *chi.Mux { + // Middleware must be defined before routes, so + // create a new router and mount the inner handler + mux := chi.NewMux() + + mux.Use( + secureHeadersMiddleware, + newCORSMiddleware(), + ) + + mux.Mount("/", inner) + + return mux +} + +type httpMiddleware = func(http.Handler) http.Handler + +func secureHeadersMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.TLS != nil { + w.Header().Set("strict-transport-security", "max-age=63072000") + w.Header().Set("x-frame-options", "DENY") + w.Header().Set("x-content-type-options", "nosniff") + w.Header().Set("x-xss-protection", "1; mode=block") + } + + next.ServeHTTP(w, r) + }) +} + +func newCORSMiddleware() httpMiddleware { + const corsMaxAge = 5 * time.Minute + + options := cors.Options{ + AllowCredentials: true, + AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, + AllowedMethods: []string{"GET", "POST"}, + AllowedOrigins: []string{"*"}, + ExposedHeaders: []string{"Link"}, + MaxAge: int(corsMaxAge.Seconds()), + } + + return cors.New(options).Handler +} diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 582a21127..af862a980 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "time" "github.com/0xERR0R/blocky/metrics" "github.com/0xERR0R/blocky/resolver" @@ -23,7 +22,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/go-chi/cors" "github.com/miekg/dns" ) @@ -33,22 +31,8 @@ const ( dnsContentType = "application/dns-message" htmlContentType = "text/html; charset=UTF-8" yamlContentType = "text/yaml" - corsMaxAge = 5 * time.Minute ) -func secureHeader(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.TLS != nil { - w.Header().Set("strict-transport-security", "max-age=63072000") - w.Header().Set("x-frame-options", "DENY") - w.Header().Set("x-content-type-options", "nosniff") - w.Header().Set("x-xss-protection", "1; mode=block") - } - - next.ServeHTTP(w, r) - }) -} - func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, err error) { bControl, err := resolver.GetFromChainWithType[api.BlockingControl](s.queryResolver) if err != nil { @@ -175,10 +159,6 @@ func (s *Server) Query( func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() - configureSecureHeaderHandler(router) - - configureCorsHandler(router) - api.RegisterOpenAPIEndpoints(router, openAPIImpl) configureDebugHandler(router) @@ -265,22 +245,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit } } -func configureSecureHeaderHandler(router *chi.Mux) { - router.Use(secureHeader) -} - func configureDebugHandler(router *chi.Mux) { router.Mount("/debug", middleware.Profiler()) } - -func configureCorsHandler(router *chi.Mux) { - crs := cors.New(cors.Options{ - AllowedOrigins: []string{"*"}, - AllowedMethods: []string{"GET", "POST"}, - AllowedHeaders: []string{"Accept", "Authorization", "Content-Type", "X-CSRF-Token"}, - ExposedHeaders: []string{"Link"}, - AllowCredentials: true, - MaxAge: int(corsMaxAge.Seconds()), - }) - router.Use(crs.Handler) -}