diff --git a/api/service.go b/api/service.go index 1b2b575f4..ee86c08bc 100644 --- a/api/service.go +++ b/api/service.go @@ -4,12 +4,11 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/service" "github.com/0xERR0R/blocky/util" - "github.com/go-chi/chi/v5" ) // Service implements service.HTTPService. type Service struct { - service.HTTPInfo + service.SimpleHTTP } func NewService(cfg config.APIService, server StrictServerInterface) *Service { @@ -19,21 +18,10 @@ func NewService(cfg config.APIService, server StrictServerInterface) *Service { ) s := &Service{ - HTTPInfo: service.HTTPInfo{ - Info: service.Info{ - Name: "API", - Endpoints: endpoints, - }, - - Mux: chi.NewMux(), - }, + SimpleHTTP: service.NewSimpleHTTP("API", endpoints), } registerOpenAPIEndpoints(s.Mux, server) return s } - -func (s *Service) Merge(other service.Service) (service.Merger, error) { - return service.MergeHTTP(s, other) -} diff --git a/metrics/service.go b/metrics/service.go index 2bac7199a..32b49e6cf 100644 --- a/metrics/service.go +++ b/metrics/service.go @@ -4,13 +4,12 @@ import ( "github.com/0xERR0R/blocky/config" "github.com/0xERR0R/blocky/service" "github.com/0xERR0R/blocky/util" - "github.com/go-chi/chi/v5" "github.com/prometheus/client_golang/prometheus/promhttp" ) // Service implements service.HTTPService. type Service struct { - service.HTTPInfo + service.SimpleHTTP } func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service { @@ -25,14 +24,7 @@ func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service { } s := &Service{ - HTTPInfo: service.HTTPInfo{ - Info: service.Info{ - Name: "Metrics", - Endpoints: endpoints, - }, - - Mux: chi.NewMux(), - }, + SimpleHTTP: service.NewSimpleHTTP("Metrics", endpoints), } s.Mux.Handle( @@ -42,7 +34,3 @@ func NewService(cfg config.MetricsService, metricsCfg config.Metrics) *Service { return s } - -func (s *Service) Merge(other service.Service) (service.Merger, error) { - return service.MergeHTTP(s, other) -} diff --git a/server/doh.go b/server/doh.go index 50c7a42c6..64354c610 100644 --- a/server/doh.go +++ b/server/doh.go @@ -13,7 +13,7 @@ import ( ) type dohService struct { - service.HTTPInfo + service.SimpleHTTP handler dnsHandler } @@ -25,14 +25,7 @@ func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService { ) s := &dohService{ - HTTPInfo: service.HTTPInfo{ - Info: service.Info{ - Name: "DoH", - Endpoints: endpoints, - }, - - Mux: chi.NewMux(), - }, + SimpleHTTP: service.NewSimpleHTTP("DoH", endpoints), handler: handler, } @@ -50,10 +43,6 @@ func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService { return s } -func (s *dohService) Merge(other service.Service) (service.Merger, error) { - return service.MergeHTTP(s, other) -} - func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) { dnsParam, ok := req.URL.Query()["dns"] if !ok || len(dnsParam[0]) < 1 { diff --git a/server/http.go b/server/http.go index 0ee157800..02d11688e 100644 --- a/server/http.go +++ b/server/http.go @@ -19,7 +19,7 @@ import ( // that expose everything. The goal is to split it up // and remove it. type httpMiscService struct { - service.HTTPInfo + service.SimpleHTTP } func newHTTPMiscService(cfg *config.Config) *httpMiscService { @@ -28,20 +28,13 @@ func newHTTPMiscService(cfg *config.Config) *httpMiscService { service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS), ) - return &httpMiscService{ - HTTPInfo: service.HTTPInfo{ - Info: service.Info{ - Name: "HTTP", - Endpoints: endpoints, - }, - - Mux: createHTTPRouter(cfg), - }, + s := &httpMiscService{ + SimpleHTTP: service.NewSimpleHTTP("HTTP", endpoints), } -} -func (s *httpMiscService) Merge(other service.Service) (service.Merger, error) { - return service.MergeHTTP(s, other) + configureHTTPRouter(s.Router(), cfg) + + return s } // httpServer implements subServer for HTTP. diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 39ade0112..fca8f4b69 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -60,9 +60,7 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPRouter(cfg *config.Config) *chi.Mux { - router := chi.NewRouter() - +func configureHTTPRouter(router chi.Router, cfg *config.Config) { configureDebugHandler(router) configureDocsHandler(router) @@ -70,11 +68,9 @@ func createHTTPRouter(cfg *config.Config) *chi.Mux { configureStaticAssetsHandler(router) configureRootHandler(cfg, router) - - return router } -func configureDocsHandler(router *chi.Mux) { +func configureDocsHandler(router chi.Router) { router.Get("/docs/openapi.yaml", func(writer http.ResponseWriter, request *http.Request) { writer.Header().Set(contentTypeHeader, yamlContentType) _, err := writer.Write([]byte(docs.OpenAPI)) @@ -82,7 +78,7 @@ func configureDocsHandler(router *chi.Mux) { }) } -func configureStaticAssetsHandler(router *chi.Mux) { +func configureStaticAssetsHandler(router chi.Router) { assets, err := web.Assets() util.FatalOnError("unable to load static asset files", err) @@ -90,7 +86,7 @@ func configureStaticAssetsHandler(router *chi.Mux) { router.Handle("/static/*", http.StripPrefix("/static/", fs)) } -func configureRootHandler(cfg *config.Config, router *chi.Mux) { +func configureRootHandler(cfg *config.Config, router chi.Router) { router.Get("/", func(writer http.ResponseWriter, request *http.Request) { writer.Header().Set(contentTypeHeader, htmlContentType) @@ -149,6 +145,6 @@ func logAndResponseWithError(err error, message string, writer http.ResponseWrit } } -func configureDebugHandler(router *chi.Mux) { +func configureDebugHandler(router chi.Router) { router.Mount("/debug", middleware.Profiler()) } diff --git a/service/http.go b/service/http.go index 5eed4342d..a15d823df 100644 --- a/service/http.go +++ b/service/http.go @@ -30,8 +30,29 @@ type HTTPInfo struct { Mux *chi.Mux } +func NewHTTPInfo(name string, endpoints []Endpoint) HTTPInfo { + return HTTPInfo{ + Info: NewInfo(name, endpoints), + + Mux: chi.NewMux(), + } +} + func (i *HTTPInfo) Router() chi.Router { return i.Mux } +var _ HTTPService = (*SimpleHTTP)(nil) + +// SimpleHTTP implements HTTPService usinig the default HTTP merger. +type SimpleHTTP struct{ HTTPInfo } + +func NewSimpleHTTP(name string, endpoints []Endpoint) SimpleHTTP { + return SimpleHTTP{HTTPInfo: NewHTTPInfo(name, endpoints)} +} + +func (s *SimpleHTTP) Merge(other Service) (Merger, error) { + return MergeHTTP(s, other) +} + // MergeHTTP merges two compatible HTTPServices. // // The second parameter is of type `Service` to make it easy to call diff --git a/service/http_test.go b/service/http_test.go index c7f405ea6..5abc50261 100644 --- a/service/http_test.go +++ b/service/http_test.go @@ -14,7 +14,7 @@ var _ = Describe("Service HTTP", func() { Describe("HTTPInfo", func() { It("returns the expected router", func() { endpoints := EndpointsFromAddrs("proto", []string{":1", "localhost:2"}) - sut := HTTPInfo{Info{"name", endpoints}, chi.NewMux()} + sut := NewHTTPInfo("name", endpoints) Expect(sut.ServiceName()).Should(Equal("name")) Expect(sut.ExposeOn()).Should(Equal(endpoints)) @@ -70,17 +70,17 @@ var _ = Describe("Service HTTP", func() { }) It("doesn't modify what HTTP routes match", func() { - apiSvc := newFakeHTTPService("API", ":443") + apiSvc := NewSimpleHTTP("API", EndpointsFromAddrs("https", []string{":443"})) apiSvc.Router().Post("/api", nil) apiSvc.Router().Get("/api/get/", nil) - dohSvc := newFakeHTTPService("DoH", ":443") + dohSvc := NewSimpleHTTP("DoH", EndpointsFromAddrs("https", []string{":443"})) dohSvc.Router().Route("/dns-query", func(mux chi.Router) { mux.Get("/", nil) mux.Post("/", nil) }) - merged, err := apiSvc.Merge(dohSvc) + merged, err := apiSvc.Merge(&dohSvc) Expect(err).Should(Succeed()) casted, ok := merged.(HTTPService) @@ -105,20 +105,9 @@ var _ = Describe("Service HTTP", func() { }) }) -type fakeHTTPService struct { - HTTPInfo -} - -func newFakeHTTPService(name string, addrs ...string) *fakeHTTPService { - mux := chi.NewMux() - mux.Get("/"+name, nil) - - return &fakeHTTPService{HTTPInfo{ - Info: Info{Name: name, Endpoints: EndpointsFromAddrs("http", addrs)}, - Mux: mux, - }} -} +func newFakeHTTPService(name string, addrs ...string) HTTPService { + svc := NewSimpleHTTP(name, EndpointsFromAddrs("http", addrs)) + svc.Router().Get("/"+name, nil) -func (s *fakeHTTPService) Merge(other Service) (Merger, error) { - return MergeHTTP(s, other) + return &svc } diff --git a/service/service.go b/service/service.go index e0e34db3e..d452c1796 100644 --- a/service/service.go +++ b/service/service.go @@ -49,6 +49,13 @@ type Info struct { Endpoints []Endpoint } +func NewInfo(name string, endpoints []Endpoint) Info { + return Info{ + Name: name, + Endpoints: endpoints, + } +} + func (i *Info) ServiceName() string { return i.Name } func (i *Info) ExposeOn() []Endpoint { return i.Endpoints } func (i *Info) String() string { return svcString(i) } diff --git a/service/service_test.go b/service/service_test.go index a77a52a83..208353eb6 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,7 +10,7 @@ var _ = Describe("Service", func() { Describe("Info", func() { endpoints := EndpointsFromAddrs("proto", []string{":1", "localhost:2"}) - sut := Info{"name", endpoints} + sut := NewInfo("name", endpoints) It("implements Service", func() { var svc Service = &sut