Skip to content

Commit

Permalink
refactor(server): setup TLS listeners manually to remove ServeTLS use
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Aug 30, 2024
1 parent ae83f2e commit 39ae088
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 15 deletions.
32 changes: 21 additions & 11 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ type Server struct {
queryResolver resolver.ChainedResolver
cfg *config.Config
httpMux *chi.Mux
tlsCfg *tls.Config
}

func logger() *logrus.Entry {
Expand Down Expand Up @@ -117,8 +116,6 @@ func newTLSConfig(cfg *config.Config) (*tls.Config, error) {
}

// NewServer creates new server instance with passed config
//
//nolint:funlen
func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err error) {
var tlsCfg *tls.Config

Expand All @@ -134,7 +131,7 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
return nil, fmt.Errorf("server creation failed: %w", err)
}

httpListeners, httpsListeners, err := createHTTPListeners(cfg)
httpListeners, httpsListeners, err := createHTTPListeners(cfg, tlsCfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -165,7 +162,6 @@ func NewServer(ctx context.Context, cfg *config.Config) (server *Server, err err
cfg: cfg,
httpListeners: httpListeners,
httpsListeners: httpsListeners,
tlsCfg: tlsCfg,
}

server.printConfiguration()
Expand Down Expand Up @@ -211,21 +207,23 @@ func createServers(cfg *config.Config, tlsCfg *tls.Config) ([]*dns.Server, error
return dnsServers, err.ErrorOrNil()
}

func createHTTPListeners(cfg *config.Config) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newListeners("http", cfg.Ports.HTTP)
func createHTTPListeners(
cfg *config.Config, tlsCfg *tls.Config,
) (httpListeners, httpsListeners []net.Listener, err error) {
httpListeners, err = newTCPListeners("http", cfg.Ports.HTTP)
if err != nil {
return nil, nil, err
}

httpsListeners, err = newListeners("https", cfg.Ports.HTTPS)
httpsListeners, err = newTLSListeners("https", cfg.Ports.HTTPS, tlsCfg)
if err != nil {
return nil, nil, err
}

return httpListeners, httpsListeners, nil
}

func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
func newTCPListeners(proto string, addresses config.ListenConfig) ([]net.Listener, error) {
listeners := make([]net.Listener, 0, len(addresses))

for _, address := range addresses {
Expand All @@ -240,6 +238,19 @@ func newListeners(proto string, addresses config.ListenConfig) ([]net.Listener,
return listeners, nil
}

func newTLSListeners(proto string, addresses config.ListenConfig, tlsCfg *tls.Config) ([]net.Listener, error) {
listeners, err := newTCPListeners(proto, addresses)
if err != nil {
return nil, err
}

for i, inner := range listeners {
listeners[i] = tls.NewListener(inner, tlsCfg)
}

return listeners, nil
}

func createTLSServer(address string, tlsCfg *tls.Config) (*dns.Server, error) {
return &dns.Server{
Addr: address,
Expand Down Expand Up @@ -521,10 +532,9 @@ func (s *Server) Start(ctx context.Context, errCh chan<- error) {
ReadTimeout: readTimeout,
ReadHeaderTimeout: readHeaderTimeout,
WriteTimeout: writeTimeout,
TLSConfig: s.tlsCfg,
}

if err := server.ServeTLS(listener, "", ""); err != nil {
if err := server.Serve(listener); err != nil {
errCh <- fmt.Errorf("start https listener failed: %w", err)
}
}()
Expand Down
8 changes: 4 additions & 4 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -741,11 +740,12 @@ var _ = Describe("Running DNS server", func() {
cfg.KeyFile = ""
cfg.CertFile = ""
cfg.Ports = config.Ports{
HTTPS: []string{fmt.Sprintf(":%d", GetIntPort(httpsBasePort)+100)},
HTTPS: []string{":0"},
}
sut, err := NewServer(ctx, &cfg)

sut, err := newTLSConfig(&cfg)
Expect(err).Should(Succeed())
Expect(sut.tlsCfg.Certificates).ShouldNot(BeEmpty())
Expect(sut.Certificates).ShouldNot(BeEmpty())
})
})
})
Expand Down

0 comments on commit 39ae088

Please sign in to comment.