Skip to content

Commit

Permalink
cryto/tls: Implement kemtls with mutual auth #66
Browse files Browse the repository at this point in the history
  • Loading branch information
claucece committed Mar 30, 2021
1 parent 5ef1b90 commit c2495e4
Show file tree
Hide file tree
Showing 70 changed files with 5,902 additions and 3,430 deletions.
207 changes: 207 additions & 0 deletions src/crypto/kem/kem.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package kem

import (
"circl/dh/sidh"
"circl/kem/schemes"
"encoding/binary"
"errors"
"fmt"
"io"

"golang.org/x/crypto/curve25519"
)

// ID identifies each flavor of KEM.
type ID uint16

const (
// KEM25519 is X25519 as a KEM. Not quantum-safe.
KEM25519 ID = 0x01fb
// Kyber512 is a post-quantum KEM based on MLWE
Kyber512 ID = 0x01fc
// SIKEp434 is a post-quantum KEM
SIKEp434 ID = 0x01fd

// minimum
minKEM = KEM25519
// maximum
maxKEM = SIKEp434
)

// PrivateKey is a private key.
type PrivateKey struct {
KEMId ID
PrivateKey []byte
}

// PublicKey is a public key.
type PublicKey struct {
KEMId ID
PublicKey []byte
}

// MarshalBinary returns the byte representation of a public key.
func (pubKey *PublicKey) MarshalBinary() ([]byte, error) {
buf := make([]byte, 2+len(pubKey.PublicKey))
binary.LittleEndian.PutUint16(buf, uint16(pubKey.KEMId))
copy(buf[2:], pubKey.PublicKey)
return buf, nil
}

// UnmarshalBinary produces a PublicKey from a byte array.
func (pubKey *PublicKey) UnmarshalBinary(data []byte) error {
id := ID(binary.LittleEndian.Uint16(data[:2]))
if id < minKEM || id > maxKEM {
return errors.New("Invalid KEM type")
}

pubKey.KEMId = id
pubKey.PublicKey = data[2:]
return nil
}

// GenerateKey generates a keypair for a given KEM.
// It returns a public and private key.
func GenerateKey(rand io.Reader, kemID ID) (*PublicKey, *PrivateKey, error) {
switch kemID {
case Kyber512:
scheme := schemes.ByName("Kyber512")
seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rand, seed); err != nil {
return nil, nil, err
}
publicKey, privateKey := scheme.DeriveKeyPair(seed)
pk, _ := publicKey.MarshalBinary()
sk, _ := privateKey.MarshalBinary()

return &PublicKey{KEMId: kemID, PublicKey: pk}, &PrivateKey{KEMId: kemID, PrivateKey: sk}, nil
case KEM25519:
privateKey := make([]byte, curve25519.ScalarSize)
if _, err := io.ReadFull(rand, privateKey); err != nil {
return nil, nil, err
}
publicKey, err := curve25519.X25519(privateKey, curve25519.Basepoint)
if err != nil {
return nil, nil, err
}
return &PublicKey{KEMId: kemID, PublicKey: publicKey}, &PrivateKey{KEMId: kemID, PrivateKey: privateKey}, nil
case SIKEp434:
privateKey := sidh.NewPrivateKey(sidh.Fp434, sidh.KeyVariantSike)
publicKey := sidh.NewPublicKey(sidh.Fp434, sidh.KeyVariantSike)
if err := privateKey.Generate(rand); err != nil {
return nil, nil, err
}
privateKey.GeneratePublicKey(publicKey)

pubBytes := make([]byte, publicKey.Size())
privBytes := make([]byte, privateKey.Size())
publicKey.Export(pubBytes)
privateKey.Export(privBytes)
return &PublicKey{KEMId: kemID, PublicKey: pubBytes}, &PrivateKey{KEMId: kemID, PrivateKey: privBytes}, nil
default:
return nil, nil, fmt.Errorf("crypto/kem: internal error: unsupported KEM %d", kemID)
}

}

// Encapsulate returns a shared secret and a ciphertext.
func Encapsulate(rand io.Reader, pk *PublicKey) ([]byte, []byte, error) {
switch pk.KEMId {
case Kyber512:
scheme := schemes.ByName("Kyber512")
pub, err := scheme.UnmarshalBinaryPublicKey(pk.PublicKey)
if err != nil {
return nil, nil, err
}

seed := make([]byte, scheme.EncapsulationSeedSize())
if _, err := io.ReadFull(rand, seed); err != nil {
return nil, nil, err
}

ct, ss, err := scheme.EncapsulateDeterministically(pub, seed)
if err != nil {
return nil, nil, err
}

return ss, ct, nil
case KEM25519:
privateKey := make([]byte, curve25519.ScalarSize)
if _, err := io.ReadFull(rand, privateKey); err != nil {
return nil, nil, err
}
ciphertext, err := curve25519.X25519(privateKey, curve25519.Basepoint)
if err != nil {
return nil, nil, err
}
sharedSecret, err := curve25519.X25519(privateKey, pk.PublicKey)
if err != nil {
return nil, nil, err
}
return sharedSecret, ciphertext, nil
case SIKEp434:
kem := sidh.NewSike434(rand)
sikepk := sidh.NewPublicKey(sidh.Fp434, sidh.KeyVariantSike)
err := sikepk.Import(pk.PublicKey)
if err != nil {
return nil, nil, err
}

ct := make([]byte, kem.CiphertextSize())
ss := make([]byte, kem.SharedSecretSize())
err = kem.Encapsulate(ct, ss, sikepk)
if err != nil {
return nil, nil, err
}

return ss, ct, nil
default:
return nil, nil, errors.New("crypto/kem: internal error: unsupported KEM in Encapsulate")
}
}

// Decapsulate generates the shared secret.
func Decapsulate(privateKey *PrivateKey, ciphertext []byte) ([]byte, error) {
switch privateKey.KEMId {
case Kyber512:
scheme := schemes.ByName("Kyber512")
sk, err := scheme.UnmarshalBinaryPrivateKey(privateKey.PrivateKey)
if err != nil {
return nil, err
}
if len(ciphertext) != scheme.CiphertextSize() {
return nil, fmt.Errorf("crypto/kem: ciphertext is of len %d, expected %d", len(ciphertext), scheme.CiphertextSize())
}
ss, err := scheme.Decapsulate(sk, ciphertext)
if err != nil {
return nil, err
}

return ss, nil
case KEM25519:
sharedSecret, err := curve25519.X25519(privateKey.PrivateKey, ciphertext)
if err != nil {
return nil, err
}
return sharedSecret, nil
case SIKEp434:
kem := sidh.NewSike434(nil)
sikesk := sidh.NewPrivateKey(sidh.Fp434, sidh.KeyVariantSike)
err := sikesk.Import(privateKey.PrivateKey)
if err != nil {
return nil, err
}

sikepk := sidh.NewPublicKey(sidh.Fp434, sidh.KeyVariantSike)
sikesk.GeneratePublicKey(sikepk)
ss := make([]byte, kem.SharedSecretSize())
err = kem.Decapsulate(ss, sikesk, sikepk, ciphertext)
if err != nil {
return nil, err
}

return ss, nil
default:
return nil, errors.New("crypto/kem: internal error: unsupported KEM in Decapsulate")
}
}
58 changes: 58 additions & 0 deletions src/crypto/kem/kem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package kem

import (
"bytes"
"crypto/rand"
"testing"
)

func TestKemAPI(t *testing.T) {
tests := []struct {
name string
kemID ID
}{
{"Kem25519", KEM25519},
{"SIKEp434", SIKEp434},
{"Kyber512", Kyber512},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
publicKey, privateKey, err := GenerateKey(rand.Reader, tt.kemID)
if err != nil {
t.Fatal(err)
}
ss, ct, err := Encapsulate(rand.Reader, publicKey)
if err != nil {
t.Fatal(err)
}

ss2, err := Decapsulate(privateKey, ct)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(ss, ss2) {
t.Fatal("Decapsulated differing shared secret")
}

data, _ := publicKey.MarshalBinary()
pk2 := new(PublicKey)
err = pk2.UnmarshalBinary(data)
if err != nil {
t.Fatal("error unmarshaling")
}
if pk2.KEMId != publicKey.KEMId {
t.Fatal("Difference in Id")
}
if !bytes.Equal(publicKey.PublicKey, publicKey.PublicKey) {
t.Fatal("Difference in data for public keys")
}
})
}

// check if nonexisting kem fails
invalidKemID := ID(0)
if _, _, err := GenerateKey(rand.Reader, invalidKemID); err == nil {
t.Fatal("This KEM should've been invalid and failed")
}

}
70 changes: 43 additions & 27 deletions src/crypto/tls/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/kem"
"crypto/rsa"
"errors"
"fmt"
Expand Down Expand Up @@ -118,6 +119,8 @@ func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType
sigType = signatureECDSA
case Ed25519:
sigType = signatureEd25519
case KEMTLSWithSIKEp434, KEMTLSWithKyber512:
sigType = authKEMTLS
default:
scheme := circlPki.SchemeByTLSID(uint(signatureAlgorithm))
if scheme == nil {
Expand All @@ -140,6 +143,8 @@ func typeAndHashFromSignatureScheme(signatureAlgorithm SignatureScheme) (sigType
hash = crypto.SHA512
case Ed25519:
hash = directSigning
case KEMTLSWithSIKEp434, KEMTLSWithKyber512:
hash = directSigning
default:
scheme := circlPki.SchemeByTLSID(uint(signatureAlgorithm))
if scheme == nil {
Expand Down Expand Up @@ -267,39 +272,50 @@ func signatureSchemesForCertificate(version uint16, cert *Certificate) []Signatu
// This function must be kept in sync with supportedSignatureAlgorithmsDC.
func signatureSchemeForDelegatedCredential(version uint16, dc *DelegatedCredential) []SignatureScheme {
pub := dc.cred.publicKey

var sigAlgs []SignatureScheme
switch pub.(type) {
case *ecdsa.PublicKey:
pk, ok := pub.(*ecdsa.PublicKey)
if !ok {

kemPub, ok := pub.(*kem.PublicKey)
if ok {
if kemPub.KEMId == kem.SIKEp434 {
sigAlgs = []SignatureScheme{KEMTLSWithSIKEp434}
} else if kemPub.KEMId == kem.Kyber512 {
sigAlgs = []SignatureScheme{KEMTLSWithKyber512}
} else {
return nil
}
switch pk.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
} else {
switch pub.(type) {
case *ecdsa.PublicKey:
pk, ok := pub.(*ecdsa.PublicKey)
if !ok {
return nil
}
switch pk.Curve {
case elliptic.P256():
sigAlgs = []SignatureScheme{ECDSAWithP256AndSHA256}
case elliptic.P384():
sigAlgs = []SignatureScheme{ECDSAWithP384AndSHA384}
case elliptic.P521():
sigAlgs = []SignatureScheme{ECDSAWithP521AndSHA512}
default:
return nil
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
case circlSign.PublicKey:
pk, ok := pub.(circlSign.PublicKey)
if !ok {
return nil
}
scheme := pk.Scheme()
tlsScheme, ok := scheme.(circlPki.TLSScheme)
if !ok {
return nil
}
sigAlgs = []SignatureScheme{SignatureScheme(tlsScheme.TLSIdentifier())}
default:
return nil
}
case ed25519.PublicKey:
sigAlgs = []SignatureScheme{Ed25519}
case circlSign.PublicKey:
pk, ok := pub.(circlSign.PublicKey)
if !ok {
return nil
}
scheme := pk.Scheme()
tlsScheme, ok := scheme.(circlPki.TLSScheme)
if !ok {
return nil
}
sigAlgs = []SignatureScheme{SignatureScheme(tlsScheme.TLSIdentifier())}
default:
return nil
}

return sigAlgs
Expand Down
2 changes: 1 addition & 1 deletion src/crypto/tls/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func TestSupportedSignatureAlgorithms(t *testing.T) {
if sigType == 0 {
t.Errorf("%v: missing signature type", sigAlg)
}
if hash == 0 && sigAlg != Ed25519 && circlPki.SchemeByTLSID(uint(sigAlg)) == nil {
if hash == 0 && (sigAlg != Ed25519 && sigType != authKEMTLS) && circlPki.SchemeByTLSID(uint(sigAlg)) == nil {
t.Errorf("%v: missing hash", sigAlg)
}
}
Expand Down
Loading

0 comments on commit c2495e4

Please sign in to comment.