Skip to content

Commit

Permalink
feat: allow configuring DoH separately from other HTTP endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Apr 2, 2024
1 parent 80e0cad commit 13798e0
Show file tree
Hide file tree
Showing 8 changed files with 192 additions and 105 deletions.
12 changes: 12 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ type Config struct {
Redis Redis `yaml:"redis"`
Log log.Config `yaml:"log"`
Ports Ports `yaml:"ports"`
Services *Services `yaml:"services"`
MinTLSServeVer TLSVersion `yaml:"minTlsServeVersion" default:"1.2"`
CertFile string `yaml:"certFile"`
KeyFile string `yaml:"keyFile"`
Expand Down Expand Up @@ -262,6 +263,10 @@ type Config struct {
} `yaml:",inline"`
}

type Services struct {
DoH DoHService `yaml:"dns-over-https"`
}

type Ports struct {
DNS ListenConfig `yaml:"dns" default:"53"`
HTTP ListenConfig `yaml:"http"`
Expand Down Expand Up @@ -590,6 +595,13 @@ func (cfg *Config) migrate(logger *logrus.Entry) bool {
}),
})

// Ports forward-compat, should be switched to Services back-compat once everything is there
if cfg.Services == nil {
cfg.Services = new(Services)
cfg.Services.DoH.Addrs.HTTP = cfg.Ports.HTTP
cfg.Services.DoH.Addrs.HTTPS = cfg.Ports.HTTPS
}

usesDepredOpts = cfg.Blocking.migrate(logger) || usesDepredOpts
usesDepredOpts = cfg.HostsFile.migrate(logger) || usesDepredOpts

Expand Down
18 changes: 18 additions & 0 deletions config/doh_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package config

type DoHService struct {
Addrs DoHAddrs `yaml:"addrs"`
}

type DoHAddrs struct {
HTTPAddrs `yaml:",inline"`
HTTPSAddrs `yaml:",inline"`
}

type HTTPAddrs struct {
HTTP ListenConfig `yaml:"http"`
}

type HTTPSAddrs struct {
HTTPS ListenConfig `yaml:"https"`
}
137 changes: 137 additions & 0 deletions server/doh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
package server

import (
"encoding/base64"
"io"
"net/http"

"github.com/0xERR0R/blocky/config"
"github.com/0xERR0R/blocky/service"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/miekg/dns"
)

type dohService struct {
service.HTTPInfo

handler dnsHandler
}

func newDoHService(cfg config.DoHService, handler dnsHandler) *dohService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(httpProtocol, cfg.Addrs.HTTP),
service.EndpointsFromAddrs(httpsProtocol, cfg.Addrs.HTTPS),
)

s := &dohService{
HTTPInfo: service.HTTPInfo{
Info: service.Info{
Name: "DoH",
Endpoints: endpoints,
},

Mux: chi.NewMux(),
},

handler: handler,
}

s.Mux.Route("/dns-query", func(mux chi.Router) {
// Handlers for / also handle /dns-query without trailing slash

mux.Get("/", s.handleGET)
mux.Get("/{clientID}", s.handleGET)

mux.Post("/", s.handlePOST)
mux.Post("/{clientID}", s.handlePOST)
})

return s
}

func (s *dohService) Merge(other service.Service) (service.Merger, error) {
return service.MergeHTTP(s, other)
}

func (s *dohService) handleGET(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)

return
}

rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)

return
}

if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)

return
}

s.processDohMessage(rawMsg, rw, req)
}

func (s *dohService) handlePOST(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)

return
}

rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)

return
}

if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)

return
}

s.processDohMessage(rawMsg, rw, req)
}

func (s *dohService) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)

return
}

ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)

s.handler(ctx, dnsReq, httpMsgWriter{rw})
}

type httpMsgWriter struct {
rw http.ResponseWriter
}

func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}

r.rw.Header().Set("content-type", dnsContentType)

// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)

_, err = r.rw.Write(b)

return err
}
4 changes: 2 additions & 2 deletions server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type httpService struct {
service.HTTPInfo
}

func newHTTPService(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *httpService {
func newHTTPService(cfg *config.Config, openAPIImpl api.StrictServerInterface) *httpService {
endpoints := util.ConcatSlices(
service.EndpointsFromAddrs(httpProtocol, cfg.Ports.HTTP),
service.EndpointsFromAddrs(httpsProtocol, cfg.Ports.HTTPS),
Expand All @@ -35,7 +35,7 @@ func newHTTPService(cfg *config.Config, openAPIImpl api.StrictServerInterface, d
Endpoints: endpoints,
},

Mux: createHTTPRouter(cfg, openAPIImpl, dnsHandler),
Mux: createHTTPRouter(cfg, openAPIImpl),
},
}
}
Expand Down
5 changes: 4 additions & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ func (s *Server) createServices() ([]service.Service, error) {
}

res := []service.Service{
newHTTPService(s.cfg, openAPIImpl, s.handleReq),
newHTTPService(s.cfg, openAPIImpl),
newDoHService(s.cfg.Services.DoH, s.handleReq),
}

// Remove services the user has not enabled
Expand Down Expand Up @@ -237,6 +238,8 @@ func createListeners(ctx context.Context, cfg *config.Config, tlsCfg *tls.Config
err := errors.Join(
newListeners(ctx, httpProtocol, cfg.Ports.HTTP, service.ListenTCP, res),
newListeners(ctx, httpsProtocol, cfg.Ports.HTTPS, listenTLS, res),
newListeners(ctx, httpProtocol, cfg.Services.DoH.Addrs.HTTP, service.ListenTCP, res),
newListeners(ctx, httpsProtocol, cfg.Services.DoH.Addrs.HTTPS, listenTLS, res),
)
if err != nil {
return nil, err
Expand Down
103 changes: 1 addition & 102 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ package server

import (
"context"
"encoding/base64"
"fmt"
"html/template"
"io"
"net"
"net/http"
"time"
Expand Down Expand Up @@ -68,103 +66,6 @@ func (s *Server) createOpenAPIInterfaceImpl() (impl api.StrictServerInterface, e
return api.NewOpenAPIInterfaceImpl(bControl, s, refresher, cacheControl), nil
}

func registerDoHEndpoints(router *chi.Mux, dnsHandler dnsHandler) {
const pathDohQuery = "/dns-query"

s := &dohServer{dnsHandler}

router.Get(pathDohQuery, s.dohGetRequestHandler)
router.Get(pathDohQuery+"/", s.dohGetRequestHandler)
router.Get(pathDohQuery+"/{clientID}", s.dohGetRequestHandler)
router.Post(pathDohQuery, s.dohPostRequestHandler)
router.Post(pathDohQuery+"/", s.dohPostRequestHandler)
router.Post(pathDohQuery+"/{clientID}", s.dohPostRequestHandler)
}

type dohServer struct{ handler dnsHandler }

func (s *dohServer) dohGetRequestHandler(rw http.ResponseWriter, req *http.Request) {
dnsParam, ok := req.URL.Query()["dns"]
if !ok || len(dnsParam[0]) < 1 {
http.Error(rw, "dns param is missing", http.StatusBadRequest)

return
}

rawMsg, err := base64.RawURLEncoding.DecodeString(dnsParam[0])
if err != nil {
http.Error(rw, "wrong message format", http.StatusBadRequest)

return
}

if len(rawMsg) > dohMessageLimit {
http.Error(rw, "URI Too Long", http.StatusRequestURITooLong)

return
}

s.processDohMessage(rawMsg, rw, req)
}

func (s *dohServer) dohPostRequestHandler(rw http.ResponseWriter, req *http.Request) {
contentType := req.Header.Get("Content-type")
if contentType != dnsContentType {
http.Error(rw, "unsupported content type", http.StatusUnsupportedMediaType)

return
}

rawMsg, err := io.ReadAll(req.Body)
if err != nil {
http.Error(rw, err.Error(), http.StatusBadRequest)

return
}

if len(rawMsg) > dohMessageLimit {
http.Error(rw, "Payload Too Large", http.StatusRequestEntityTooLarge)

return
}

s.processDohMessage(rawMsg, rw, req)
}

func (s *dohServer) processDohMessage(rawMsg []byte, rw http.ResponseWriter, httpReq *http.Request) {
msg := new(dns.Msg)
if err := msg.Unpack(rawMsg); err != nil {
logger().Error("can't deserialize message: ", err)
http.Error(rw, err.Error(), http.StatusBadRequest)

return
}

ctx, dnsReq := newRequestFromHTTP(httpReq.Context(), httpReq, msg)

s.handler(ctx, dnsReq, httpMsgWriter{rw})
}

type httpMsgWriter struct {
rw http.ResponseWriter
}

func (r httpMsgWriter) WriteMsg(msg *dns.Msg) error {
b, err := msg.Pack()
if err != nil {
return err
}

r.rw.Header().Set("content-type", dnsContentType)

// https://www.rfc-editor.org/rfc/rfc8484#section-4.2.1
r.rw.WriteHeader(http.StatusOK)

_, err = r.rw.Write(b)

return err
}

func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
Expand All @@ -176,7 +77,7 @@ func (s *Server) Query(
return s.resolve(ctx, req)
}

func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface, dnsHandler dnsHandler) *chi.Mux {
func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface) *chi.Mux {
router := chi.NewRouter()

configureSecureHeaderHandler(router)
Expand All @@ -193,8 +94,6 @@ func createHTTPRouter(cfg *config.Config, openAPIImpl api.StrictServerInterface,

configureRootHandler(cfg, router)

registerDoHEndpoints(router, dnsHandler)

metrics.Start(router, cfg.Prometheus)

return router
Expand Down
9 changes: 9 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ var _ = BeforeSuite(func() {
},
}

cfg.Services = &config.Services{
DoH: config.DoHService{
Addrs: config.DoHAddrs{
HTTPAddrs: config.HTTPAddrs{HTTP: cfg.Ports.HTTP},
HTTPSAddrs: config.HTTPSAddrs{HTTPS: cfg.Ports.HTTPS},
},
},
}

// create server
sut, err = NewServer(ctx, cfg)
Expect(err).Should(Succeed())
Expand Down
9 changes: 9 additions & 0 deletions service/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ func (m *httpMerger) Merge(other Service) (Merger, error) {
_ = chi.Walk(httpSvc.Router(), func(method, route string, handler http.Handler, middlewares ...middleware) error {
m.router.With(middlewares...).Method(method, route, handler)

// Expose /example/ as /example too
// Workaround for chi.Walk missing the second form https://github.com/go-chi/chi/issues/830
// This means we expose the route without the slash even if it wasn't oringinally registered as such!
// The main point of this is for DoH (/dns-query).
if strings.HasSuffix(route, "/") {
route := strings.TrimSuffix(route, "/")
m.router.With(middlewares...).Method(method, route, handler)
}

return nil
})

Expand Down

0 comments on commit 13798e0

Please sign in to comment.