Skip to content

Commit

Permalink
Moving DecodeSegement to Parser (#278)
Browse files Browse the repository at this point in the history
* Moving `DecodeSegement` to `Parser`

This would allow us to remove some global variables and move them to parser options as well as potentially introduce interfaces for json and b64 encoding/decoding to replace the std lib, if someone wanted to do that for performance reasons.

We keep the functions exported because of explicit user demand.

* Sign/Verify does take the decoded form now
  • Loading branch information
oxisto authored Mar 24, 2023
1 parent c6ec5a2 commit b357385
Show file tree
Hide file tree
Showing 19 changed files with 212 additions and 196 deletions.
22 changes: 7 additions & 15 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,7 @@ func (m *SigningMethodECDSA) Alg() string {

// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ecdsa.PublicKey struct
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error

// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}

func (m *SigningMethodECDSA) Verify(signingString string, sig []byte, key interface{}) error {
// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
Expand Down Expand Up @@ -97,19 +89,19 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ecdsa.PrivateKey struct
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte, error) {
// Get the key
var ecdsaKey *ecdsa.PrivateKey
switch k := key.(type) {
case *ecdsa.PrivateKey:
ecdsaKey = k
default:
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

// Create the hasher
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := m.Hash.New()
Expand All @@ -120,7 +112,7 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
curveBits := ecdsaKey.Curve.Params().BitSize

if m.CurveBits != curveBits {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

keyBytes := curveBits / 8
Expand All @@ -135,8 +127,8 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) (string
r.FillBytes(out[0:keyBytes]) // r is assigned to the first half of output.
s.FillBytes(out[keyBytes:]) // s is assigned to the second half of output.

return EncodeSegment(out), nil
return out, nil
} else {
return "", err
return nil, err
}
}
26 changes: 21 additions & 5 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt_test
import (
"crypto/ecdsa"
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -65,7 +66,7 @@ func TestECDSAVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ecdsaKey)
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ecdsaKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -90,12 +91,13 @@ func TestECDSASign(t *testing.T) {
toSign := strings.Join(parts[0:2], ".")
method := jwt.GetSigningMethod(data.alg)
sig, err := method.Sign(toSign, ecdsaKey)

if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}

err = method.Verify(toSign, sig, ecdsaKey.Public())
Expand Down Expand Up @@ -155,10 +157,24 @@ func BenchmarkECDSASigning(b *testing.B) {
if err != nil {
b.Fatalf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] {
if reflect.DeepEqual(sig, decodeSegment(b, parts[2])) {
b.Fatalf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)
}
}
})
}
}

func decodeSegment(t interface{ Fatalf(string, ...any) }, signature string) (sig []byte) {
var err error
sig, err = jwt.NewParser().DecodeSegment(signature)
if err != nil {
t.Fatalf("could not decode segment: %v", err)
}

return
}

func encodeSegment(sig []byte) string {
return (&jwt.Token{}).EncodeSegment(sig)
}
25 changes: 10 additions & 15 deletions ed25519.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ func (m *SigningMethodEd25519) Alg() string {

// Verify implements token verification for the SigningMethod.
// For this verify method, key must be an ed25519.PublicKey
func (m *SigningMethodEd25519) Verify(signingString, signature string, key interface{}) error {
var err error
func (m *SigningMethodEd25519) Verify(signingString string, sig []byte, key interface{}) error {
var ed25519Key ed25519.PublicKey
var ok bool

Expand All @@ -47,12 +46,6 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter
return ErrInvalidKey
}

// Decode the signature
var sig []byte
if sig, err = DecodeSegment(signature); err != nil {
return err
}

// Verify the signature
if !ed25519.Verify(ed25519Key, []byte(signingString), sig) {
return ErrEd25519Verification
Expand All @@ -63,23 +56,25 @@ func (m *SigningMethodEd25519) Verify(signingString, signature string, key inter

// Sign implements token signing for the SigningMethod.
// For this signing method, key must be an ed25519.PrivateKey
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]byte, error) {
var ed25519Key crypto.Signer
var ok bool

if ed25519Key, ok = key.(crypto.Signer); !ok {
return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}

if _, ok := ed25519Key.Public().(ed25519.PublicKey); !ok {
return "", ErrInvalidKey
return nil, ErrInvalidKey
}

// Sign the string and return the encoded result
// ed25519 performs a two-pass hash as part of its algorithm. Therefore, we need to pass a non-prehashed message into the Sign function, as indicated by crypto.Hash(0)
// Sign the string and return the result. ed25519 performs a two-pass hash
// as part of its algorithm. Therefore, we need to pass a non-prehashed
// message into the Sign function, as indicated by crypto.Hash(0)
sig, err := ed25519Key.Sign(rand.Reader, []byte(signingString), crypto.Hash(0))
if err != nil {
return "", err
return nil, err
}
return EncodeSegment(sig), nil

return sig, nil
}
8 changes: 5 additions & 3 deletions ed25519_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestEd25519Verify(t *testing.T) {

method := jwt.GetSigningMethod(data.alg)

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], ed25519Key)
err = method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), ed25519Key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand Down Expand Up @@ -77,8 +77,10 @@ func TestEd25519Sign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], sig)

ssig := encodeSegment(sig)
if ssig == parts[2] && !data.valid {
t.Errorf("[%v] Identical signatures\nbefore:\n%v\nafter:\n%v", data.name, parts[2], ssig)
}
}
}
16 changes: 5 additions & 11 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,13 @@ func (m *SigningMethodHMAC) Alg() string {
}

// Verify implements token verification for the SigningMethod. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
func (m *SigningMethodHMAC) Verify(signingString string, sig []byte, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}

// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
return err
}

// Can we use the specified hashing method?
if !m.Hash.Available() {
return ErrHashUnavailable
Expand All @@ -79,17 +73,17 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac

// Sign implements token signing for the SigningMethod.
// Key must be []byte
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) (string, error) {
func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, error) {
if keyBytes, ok := key.([]byte); ok {
if !m.Hash.Available() {
return "", ErrHashUnavailable
return nil, ErrHashUnavailable
}

hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))

return EncodeSegment(hasher.Sum(nil)), nil
return hasher.Sum(nil), nil
}

return "", ErrInvalidKeyType
return nil, ErrInvalidKeyType
}
5 changes: 3 additions & 2 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt_test

import (
"os"
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -53,7 +54,7 @@ func TestHMACVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], hmacTestKey)
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), hmacTestKey)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -72,7 +73,7 @@ func TestHMACSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
11 changes: 6 additions & 5 deletions none.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ func (m *signingMethodNone) Alg() string {
}

// Only allow 'none' alg type if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Verify(signingString, signature string, key interface{}) (err error) {
func (m *signingMethodNone) Verify(signingString string, sig []byte, key interface{}) (err error) {
// Key must be UnsafeAllowNoneSignatureType to prevent accidentally
// accepting 'none' signing method
if _, ok := key.(unsafeNoneMagicConstant); !ok {
return NoneSignatureTypeDisallowedError
}
// If signing method is none, signature must be an empty string
if signature != "" {
if string(sig) != "" {
return newError("'none' signing method with non-empty signature", ErrTokenUnverifiable)
}

Expand All @@ -41,9 +41,10 @@ func (m *signingMethodNone) Verify(signingString, signature string, key interfac
}

// Only allow 'none' signing if UnsafeAllowNoneSignatureType is specified as the key
func (m *signingMethodNone) Sign(signingString string, key interface{}) (string, error) {
func (m *signingMethodNone) Sign(signingString string, key interface{}) ([]byte, error) {
if _, ok := key.(unsafeNoneMagicConstant); ok {
return "", nil
return []byte{}, nil
}
return "", NoneSignatureTypeDisallowedError

return nil, NoneSignatureTypeDisallowedError
}
5 changes: 3 additions & 2 deletions none_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"reflect"
"strings"
"testing"

Expand Down Expand Up @@ -46,7 +47,7 @@ func TestNoneVerify(t *testing.T) {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], data.key)
err := method.Verify(strings.Join(parts[0:2], "."), decodeSegment(t, parts[2]), data.key)
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
Expand All @@ -65,7 +66,7 @@ func TestNoneSign(t *testing.T) {
if err != nil {
t.Errorf("[%v] Error signing token: %v", data.name, err)
}
if sig != parts[2] {
if !reflect.DeepEqual(sig, decodeSegment(t, parts[2])) {
t.Errorf("[%v] Incorrect signature.\nwas:\n%v\nexpecting:\n%v", data.name, sig, parts[2])
}
}
Expand Down
Loading

0 comments on commit b357385

Please sign in to comment.