diff --git a/server/server.go b/server/server.go index d3f4d4c88..dffba376e 100644 --- a/server/server.go +++ b/server/server.go @@ -117,19 +117,11 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err return nil, fmt.Errorf("server creation failed: %w", err) } - httpRouter := createHTTPRouter(cfg) - httpsRouter := createHTTPSRouter(cfg) - httpListeners, httpsListeners, err := createHTTPListeners(cfg) if err != nil { return nil, err } - if len(httpListeners) != 0 || len(httpsListeners) != 0 { - metrics.Start(httpRouter, cfg.Prometheus) - metrics.Start(httpsRouter, cfg.Prometheus) - } - metrics.RegisterEventListeners() bootstrap, err := resolver.NewBootstrap(ctx, cfg) @@ -156,24 +148,26 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err cfg: cfg, httpListeners: httpListeners, httpsListeners: httpsListeners, - httpMux: httpRouter, - httpsMux: httpsRouter, cert: cert, } server.printConfiguration() - server.registerDNSHandlers(ctx) - err = server.registerAPIEndpoints(httpRouter) - + openAPIImpl, err := server.createOpenAPIInterfaceImpl() if err != nil { return nil, err } - err = server.registerAPIEndpoints(httpsRouter) + server.registerDNSHandlers(ctx) - if err != nil { - return nil, err + if len(cfg.Ports.HTTP) != 0 { + server.httpMux = createHTTPRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(server.httpMux) + } + + if len(cfg.Ports.HTTPS) != 0 { + server.httpsMux = createHTTPSRouter(cfg, openAPIImpl) + server.registerDoHEndpoints(server.httpsMux) } return server, err diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 54f43a125..5661184c1 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -10,6 +10,7 @@ import ( "net/http" "time" + "github.com/0xERR0R/blocky/metrics" "github.com/0xERR0R/blocky/resolver" "github.com/0xERR0R/blocky/api" @@ -64,24 +65,15 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func (s *Server) registerAPIEndpoints(router *chi.Mux) error { +func (s *Server) registerDoHEndpoints(router *chi.Mux) { const pathDohQuery = "/dns-query" - openAPIImpl, err := s.createOpenAPIInterfaceImpl() - if err != nil { - return err - } - - api.RegisterOpenAPIEndpoints(router, openAPIImpl) - router.Get(pathDohQuery, s.dohGetRequestHandler) router.Get(pathDohQuery+"/", s.dohGetRequestHandler) router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) router.Post(pathDohQuery, s.dohPostRequestHandler) router.Post(pathDohQuery+"/", s.dohPostRequestHandler) router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) - - return nil } func (s *Server) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { @@ -177,27 +169,29 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPSRouter(cfg *config.Config) *chi.Mux { +func createHTTPSRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() configureSecureHeaderHandler(router) - registerHandlers(cfg, router) + registerHandlers(cfg, router, openAPIImpl) return router } -func createHTTPRouter(cfg *config.Config) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() - registerHandlers(cfg, router) + registerHandlers(cfg, router, openAPIImpl) return router } -func registerHandlers(cfg *config.Config, router *chi.Mux) { +func registerHandlers(cfg *config.Config, router *chi.Mux, openAPIImpl api.StrictServerInterface) { configureCorsHandler(router) + api.RegisterOpenAPIEndpoints(router, openAPIImpl) + configureDebugHandler(router) configureDocsHandler(router) @@ -205,6 +199,8 @@ func registerHandlers(cfg *config.Config, router *chi.Mux) { configureStaticAssetsHandler(router) configureRootHandler(cfg, router) + + metrics.Start(router, cfg.Prometheus) } func configureDocsHandler(router *chi.Mux) {