Skip to content

Commit

Permalink
refactor: cleanup TLS self-signed cert generation
Browse files Browse the repository at this point in the history
It's now actually a self-signed cert, instead of using a CA no one will
ever see.
  • Loading branch information
ThinkChaos committed Aug 30, 2024
1 parent 006afd1 commit 02944fc
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 89 deletions.
60 changes: 60 additions & 0 deletions helpertest/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package helpertest

import (
"crypto/tls"
"crypto/x509"
"sync"

"github.com/0xERR0R/blocky/util"
. "github.com/onsi/gomega"
)

const tlsTestServerName = "test.blocky.invalid"

type tlsData struct {
ServerCfg *tls.Config
ClientCfg *tls.Config
}

// Lazy init
//
//nolint:gochecknoglobals
var (
initTLSData sync.Once
tlsDataStorage tlsData
)

func getTLSData() tlsData {
initTLSData.Do(func() {
cert, err := util.TLSGenerateSelfSignedCert([]string{tlsTestServerName})
Expect(err).Should(Succeed())

tlsDataStorage.ServerCfg = &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}

certPool := x509.NewCertPool()
certPool.AddCert(cert.Leaf)

tlsDataStorage.ClientCfg = &tls.Config{
RootCAs: certPool,
ServerName: tlsTestServerName,
MinVersion: tls.VersionTLS13,
}
})

return tlsDataStorage
}

// TLSTestServerConfig returns a TLS Config for use by test servers.
func TLSTestServerConfig() *tls.Config {
return getTLSData().ServerCfg.Clone()
}

// TLSTestServerConfig returns a TLS Config for use by test clients.
//
// This is required to connect to a test TLS server, otherwise TLS verification fails.
func TLSTestClientConfig() *tls.Config {
return getTLSData().ClientCfg.Clone()
}
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type NewServerFunc func(address string) (*dns.Server, error)

func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
if cfg.CertFile == "" && cfg.KeyFile == "" {
cert, err = util.CreateSelfSignedCert()
cert, err = util.TLSGenerateSelfSignedCert([]string{"blocky.invalid", "*"})
if err != nil {
return tls.Certificate{}, fmt.Errorf("unable to generate self-signed certificate: %w", err)
}
Expand Down
9 changes: 3 additions & 6 deletions service/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"net"
"time"

"github.com/0xERR0R/blocky/util"
"github.com/0xERR0R/blocky/helpertest"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)
Expand Down Expand Up @@ -108,14 +108,11 @@ var _ = Describe("Service Listener", func() {
Entry("ListenTLS",
entryFuncs{
Listen: func(ctx context.Context, endpoint Endpoint) (Listener, error) {
cert, err := util.CreateSelfSignedCert()
Expect(err).Should(Succeed())

return ListenTLS(ctx, endpoint, &tls.Config{Certificates: []tls.Certificate{cert}})
return ListenTLS(ctx, endpoint, helpertest.TLSTestServerConfig())
},
Dial: func(ctx context.Context, addr string) (net.Conn, error) {
d := tls.Dialer{
Config: &tls.Config{InsecureSkipVerify: true},
Config: helpertest.TLSTestClientConfig(),
}

return d.DialContext(ctx, "tcp", addr)
Expand Down
113 changes: 31 additions & 82 deletions util/tls.go
Original file line number Diff line number Diff line change
@@ -1,118 +1,67 @@
package util

import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"math"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand"
"time"
)

const (
caExpiryYears = 10
certExpiryYears = 5
certSerialMaxBits = 128
certExpiryYears = 5
)

//nolint:funlen
func CreateSelfSignedCert() (tls.Certificate, error) {
// Create CA
ca := &x509.Certificate{
//nolint:gosec
SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))),
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(caExpiryYears, 0, 0),
IsCA: true,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}

caPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
// TLSGenerateSelfSignedCert returns a new self-signed cert for the given domains.
//
// Being self-signed, no client will trust this certificate.
func TLSGenerateSelfSignedCert(domains []string) (tls.Certificate, error) {
serialMax := new(big.Int).Lsh(big.NewInt(1), certSerialMaxBits)
serial, err := rand.Int(rand.Reader, serialMax)
if err != nil {
return tls.Certificate{}, err
}

caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey)
if err != nil {
return tls.Certificate{}, err
}
template := &x509.Certificate{
SerialNumber: serial,

caPEM := new(bytes.Buffer)
if err = pem.Encode(caPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: caBytes,
}); err != nil {
return tls.Certificate{}, err
}
Subject: pkix.Name{Organization: []string{"Blocky"}},
DNSNames: domains,

caPrivKeyPEM := new(bytes.Buffer)

b, err := x509.MarshalECPrivateKey(caPrivKey)
if err != nil {
return tls.Certificate{}, err
}

if err = pem.Encode(caPrivKeyPEM, &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: b,
}); err != nil {
return tls.Certificate{}, err
}
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(certExpiryYears, 0, 0),

// Create certificate
cert := &x509.Certificate{
//nolint:gosec
SerialNumber: big.NewInt(int64(mrand.Intn(math.MaxInt))),
DNSNames: []string{"*"},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(certExpiryYears, 0, 0),
SubjectKeyId: []byte{1, 2, 3, 4, 6},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature,
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}

certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return tls.Certificate{}, err
return tls.Certificate{}, fmt.Errorf("unable to generate private key: %w", err)
}

certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey)
der, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
if err != nil {
return tls.Certificate{}, err
return tls.Certificate{}, fmt.Errorf("cert creation from template failed: %w", err)
}

certPEM := new(bytes.Buffer)
if err = pem.Encode(certPEM, &pem.Block{
Type: "CERTIFICATE",
Bytes: certBytes,
}); err != nil {
return tls.Certificate{}, err
}

certPrivKeyPEM := new(bytes.Buffer)

b, err = x509.MarshalECPrivateKey(certPrivKey)
// Parse the generated DER back into a useable cert
// This avoids needing to do it for each TLS handshake (see tls.Certificate.Leaf comment)
cert, err := x509.ParseCertificate(der)
if err != nil {
return tls.Certificate{}, err
}

if err = pem.Encode(certPrivKeyPEM, &pem.Block{
Type: "EC PRIVATE KEY",
Bytes: b,
}); err != nil {
return tls.Certificate{}, err
return tls.Certificate{}, fmt.Errorf("generated cert DER could not be parsed: %w", err)
}

keyPair, err := tls.X509KeyPair(certPEM.Bytes(), certPrivKeyPEM.Bytes())
if err != nil {
return tls.Certificate{}, err
tlsCert := tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: privKey,
Leaf: cert,
}

return keyPair, nil
return tlsCert, nil
}

0 comments on commit 02944fc

Please sign in to comment.