From 5571e192084260b33f55525833f10c3cf3a8d402 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Fri, 30 Aug 2024 18:23:54 -0400 Subject: [PATCH] refactor: switch DoH to Service pattern --- config/config.go | 25 +++++++ config/http_services.go | 18 +++++ server/doh.go | 137 +++++++++++++++++++++++++++++++++++++ server/http.go | 6 +- server/server.go | 9 ++- server/server_endpoints.go | 103 +--------------------------- service/http.go | 8 +++ service/http_test.go | 36 ++++++++++ 8 files changed, 235 insertions(+), 107 deletions(-) create mode 100644 config/http_services.go create mode 100644 server/doh.go diff --git a/config/config.go b/config/config.go index 8890b996e..b1c940f71 100644 --- a/config/config.go +++ b/config/config.go @@ -234,6 +234,7 @@ type Config struct { Redis Redis `yaml:"redis"` Log log.Config `yaml:"log"` Ports Ports `yaml:"ports"` + Services Services `yaml:"-"` // not user exposed yet MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"` CertFile string `yaml:"certFile"` KeyFile string `yaml:"keyFile"` @@ -263,6 +264,17 @@ type Config struct { } `yaml:",inline"` } +// Services holds network service related configuration. +// +// The actual config layout is not decided yet. +// See https://github.com/0xERR0R/blocky/issues/1206 +// +// The `yaml` struct tags are just for manual testing, +// and require replacing `yaml:"-"` in Config to work. +type Services struct { + DoH DoHService `yaml:"dns-over-https"` +} + type Ports struct { DNS ListenConfig `yaml:"dns" default:"53"` HTTP ListenConfig `yaml:"http"` @@ -601,6 +613,19 @@ func (cfg *Config) validate(logger *logrus.Entry) { cfg.Upstreams.validate(logger) } +// CopyPortsToServices sets Services values to match Ports. +// +// This should be replaced with a migration once everything from Ports is supported in Services. +// Done this way for now to avoid creating temporary generic services and updating all Ports related code at once. +func (cfg *Config) CopyPortsToServices() { + cfg.Services = Services{ + DoH: DoHService{Addrs: DoHAddrs{ + HTTPAddrs: HTTPAddrs{HTTP: cfg.Ports.HTTP}, + HTTPSAddrs: HTTPSAddrs{HTTPS: cfg.Ports.HTTPS}, + }}, + } +} + // ConvertPort converts string representation into a valid port (0 - 65535) func ConvertPort(in string) (uint16, error) { const ( diff --git a/config/http_services.go b/config/http_services.go new file mode 100644 index 000000000..6ac98ed3e --- /dev/null +++ b/config/http_services.go @@ -0,0 +1,18 @@ +package config + +type DoHService struct { + Addrs DoHAddrs `yaml:"addrs"` +} + +type DoHAddrs struct { + HTTPAddrs `yaml:",inline"` + HTTPSAddrs `yaml:",inline"` +} + +type HTTPAddrs struct { + HTTP ListenConfig `yaml:"http"` +} + +type HTTPSAddrs struct { + HTTPS ListenConfig `yaml:"https"` +} diff --git a/server/doh.go b/server/doh.go new file mode 100644 index 000000000..50c7a42c6 --- /dev/null +++ b/server/doh.go @@ -0,0 +1,137 @@ +package server + +import ( + "encoding/base64" + "io" + "net/http" + + "github.com/0xERR0R/blocky/config" + "github.com/0xERR0R/blocky/service" + "github.com/0xERR0R/blocky/util" + "github.com/go-chi/chi/v5" + "github.com/miekg/dns" +) + +type dohService struct { + service.HTTPInfo + + handler dnsHandler +} + +func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService { + endpoints := util.ConcatSlices( + service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Addrs.HTTP), + service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Addrs.HTTPS), + ) + + s := &dohService{ + HTTPInfo: service.HTTPInfo{ + Info: service.Info{ + Name: "DoH", + Endpoints: endpoints, + }, + + Mux: chi.NewMux(), + }, + + handler: handler, + } + + s.Mux.Route("/dns-query", func(mux chi.Router) { + // Handlers for / also handle /dns-query without trailing slash + + mux.Get("/", s.handleGET) + mux.Get("/{clientID}", s.handleGET) + + mux.Post("/", s.handlePOST) + mux.Post("/{clientID}", s.handlePOST) + }) + + return s +} + +func (s *dohService) Merge(other service.Service) (service.Merger, error) { + return service.MergeHTTP(s, other) +} + +func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) { + dnsParam, ok := req.URL.Query()["dns"] + if !ok || len(dnsParam[0]) < 1 { + http.Error(rw, "dns param is missing", http.StatusBadRequest) + + return + } + + rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0]) + if err != nil { + http.Error(rw, "wrong message format", http.StatusBadRequest) + + return + } + + if len(rawMsg) > dohMessageLimit { + http.Error(rw, "URI Too Long", http.StatusRequestURITooLong) + + return + } + + s.processDohMessage(rawMsg, rw, req) +} + +func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) { + contentType := req.Header.Get("Content-type") + if contentType != dnsContentType { + http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType) + + return + } + + rawMsg, err := io.ReadAll(req.Body) + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + + return + } + + if len(rawMsg) > dohMessageLimit { + http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge) + + return + } + + s.processDohMessage(rawMsg, rw, req) +} + +func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { + msg := new(dns.Msg) + if err := msg.Unpack(rawMsg); err != nil { + logger().Error("can't deserialize message: ", err) + http.Error(rw, err.Error(), http.StatusBadRequest) + + return + } + + ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) + + s.handler(ctx, dnsReq, httpMsgWriter{rw}) +} + +type httpMsgWriter struct { + rw http.ResponseWriter +} + +func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error { + b, err := msg.Pack() + if err != nil { + return err + } + + r.rw.Header().Set("content-type", dnsContentType) + + // https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 + r.rw.WriteHeader(http.StatusOK) + + _, err = r.rw.Write(b) + + return err +} diff --git a/server/http.go b/server/http.go index 7c4da3230..b13c3389a 100644 --- a/server/http.go +++ b/server/http.go @@ -23,9 +23,7 @@ type httpMiscService struct { service.HTTPInfo } -func newHTTPMiscService( - cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler, -) *httpMiscService { +func newHTTPMiscService(cfg *config.Config, openAPIImpl api.StrictServerInterface) *httpMiscService { endpoints := util.ConcatSlices( service.EndpointsFromAddrs(service.HTTPProtocol, cfg.Ports.HTTP), service.EndpointsFromAddrs(service.HTTPSProtocol, cfg.Ports.HTTPS), @@ -38,7 +36,7 @@ func newHTTPMiscService( Endpoints: endpoints, }, - Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler), + Mux: createHTTPRouter(cfg, openAPIImpl), }, } } diff --git a/server/server.go b/server/server.go index e5d4817b2..7c9a5e227 100644 --- a/server/server.go +++ b/server/server.go @@ -109,7 +109,11 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) { } // NewServer creates new server instance with passed config +// +//nolint:funlen func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) { + cfg.CopyPortsToServices() + var tlsCfg *tls.Config if len(cfg.Ports.HTTPS) > 0 || len(cfg.Ports.TLS) > 0 { @@ -179,7 +183,8 @@ func (s *Server) createServices() ([]service.Service, error) { } res := []service.Service{ - newHTTPMiscService(s.cfg, openAPIImpl, s.handleReq), + newHTTPMiscService(s.cfg, openAPIImpl), + newDoHService(s.cfg.Services.DoH, s.handleReq), } // Remove services the user has not enabled @@ -228,6 +233,8 @@ func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config err := errors.Join( newListeners(ctx, service.HTTPProtocol, cfg.Ports.HTTP, service.ListenTCP, res), newListeners(ctx, service.HTTPSProtocol, cfg.Ports.HTTPS, listenTLS, res), + newListeners(ctx, service.HTTPProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res), + newListeners(ctx, service.HTTPSProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res), ) if err != nil { return nil, err diff --git a/server/server_endpoints.go b/server/server_endpoints.go index cac287f43..7024170fe 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -2,10 +2,8 @@ package server import ( "context" - "encoding/base64" "fmt" "html/template" - "io" "net" "net/http" @@ -52,103 +50,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil } -func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) { - const pathDohQuery = "/dns-query" - - s := &dohServer{dnsHandler} - - router.Get(pathDohQuery, s.dohGetRequestHandler) - router.Get(pathDohQuery+"/", s.dohGetRequestHandler) - router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler) - router.Post(pathDohQuery, s.dohPostRequestHandler) - router.Post(pathDohQuery+"/", s.dohPostRequestHandler) - router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler) -} - -type dohServer struct{ handler dnsHandler } - -func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) { - dnsParam, ok := req.URL.Query()["dns"] - if !ok || len(dnsParam[0]) < 1 { - http.Error(rw, "dns param is missing", http.StatusBadRequest) - - return - } - - rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0]) - if err != nil { - http.Error(rw, "wrong message format", http.StatusBadRequest) - - return - } - - if len(rawMsg) > dohMessageLimit { - http.Error(rw, "URI Too Long", http.StatusRequestURITooLong) - - return - } - - s.processDohMessage(rawMsg, rw, req) -} - -func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) { - contentType := req.Header.Get("Content-type") - if contentType != dnsContentType { - http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType) - - return - } - - rawMsg, err := io.ReadAll(req.Body) - if err != nil { - http.Error(rw, err.Error(), http.StatusBadRequest) - - return - } - - if len(rawMsg) > dohMessageLimit { - http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge) - - return - } - - s.processDohMessage(rawMsg, rw, req) -} - -func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) { - msg := new(dns.Msg) - if err := msg.Unpack(rawMsg); err != nil { - logger().Error("can't deserialize message: ", err) - http.Error(rw, err.Error(), http.StatusBadRequest) - - return - } - - ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg) - - s.handler(ctx, dnsReq, httpMsgWriter{rw}) -} - -type httpMsgWriter struct { - rw http.ResponseWriter -} - -func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error { - b, err := msg.Pack() - if err != nil { - return err - } - - r.rw.Header().Set("content-type", dnsContentType) - - // https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1 - r.rw.WriteHeader(http.StatusOK) - - _, err = r.rw.Write(b) - - return err -} - func (s *Server) Query( ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, ) (*model.Response, error) { @@ -160,7 +61,7 @@ func (s *Server) Query( return s.resolve(ctx, req) } -func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux { +func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux { router := chi.NewRouter() api.RegisterOpenAPIEndpoints(router, openAPIImpl) @@ -173,8 +74,6 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, configureRootHandler(cfg, router) - registerDoHEndpoints(router, dnsHandler) - metrics.Start(router, cfg.Prometheus) return router diff --git a/service/http.go b/service/http.go index 6b634c821..5eed4342d 100644 --- a/service/http.go +++ b/service/http.go @@ -82,6 +82,14 @@ func (m *httpMerger) Merge(other Service) (Merger, error) { _ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error { m.router.With(middlewares...).Method(method, route, handler) + // Expose /example/ as /example too + // Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830 + // The main point of this is for DoH's `/dns-query` endpoint. + if strings.HasSuffix(route, "/") { + route := strings.TrimSuffix(route, "/") + m.router.With(middlewares...).Method(method, route, handler) + } + return nil }) diff --git a/service/http_test.go b/service/http_test.go index 20c1a76e7..c7f405ea6 100644 --- a/service/http_test.go +++ b/service/http_test.go @@ -1,6 +1,8 @@ package service import ( + "net/http" + "github.com/go-chi/chi/v5" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -66,6 +68,40 @@ var _ = Describe("Service HTTP", func() { _, err = sut.Merge(nonHTTPSvc) Expect(err).Should(MatchError(ContainSubstring("not an HTTPService"))) }) + + It("doesn't modify what HTTP routes match", func() { + apiSvc := newFakeHTTPService("API", ":443") + apiSvc.Router().Post("/api", nil) + apiSvc.Router().Get("/api/get/", nil) + + dohSvc := newFakeHTTPService("DoH", ":443") + dohSvc.Router().Route("/dns-query", func(mux chi.Router) { + mux.Get("/", nil) + mux.Post("/", nil) + }) + + merged, err := apiSvc.Merge(dohSvc) + Expect(err).Should(Succeed()) + + casted, ok := merged.(HTTPService) + Expect(ok).Should(BeTrue()) + + chiCtx := chi.NewRouteContext() + mux := casted.Router() + + Expect(mux.Match(chiCtx, http.MethodPost, "/api")).Should(BeTrue()) + Expect(mux.Match(chiCtx, http.MethodPost, "/api/")).Should(BeFalse()) + Expect(mux.Match(chiCtx, http.MethodGet, "/api")).Should(BeFalse()) + + Expect(mux.Match(chiCtx, http.MethodGet, "/api/get/")).Should(BeTrue()) + Expect(mux.Match(chiCtx, http.MethodGet, "/api/get")).Should(BeFalse()) + Expect(mux.Match(chiCtx, http.MethodPost, "/api/get/")).Should(BeFalse()) + + Expect(mux.Match(chiCtx, http.MethodGet, "/dns-query")).Should(BeTrue()) + Expect(mux.Match(chiCtx, http.MethodGet, "/dns-query/")).Should(BeTrue()) + Expect(mux.Match(chiCtx, http.MethodPost, "/dns-query")).Should(BeTrue()) + Expect(mux.Match(chiCtx, http.MethodPost, "/dns-query/")).Should(BeTrue()) + }) }) })