diff --git a/internal/keyservice/piv/ecdhkey.go b/internal/keyservice/piv/ecdhkey.go index e459476..7acdafe 100644 --- a/internal/keyservice/piv/ecdhkey.go +++ b/internal/keyservice/piv/ecdhkey.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "regexp" + "sync" pivgo "github.com/go-piv/piv-go/piv" "github.com/smlx/piv-agent/internal/assuan" @@ -17,12 +18,16 @@ var ciphertextECDH = regexp.MustCompile( // ECDHKey implements ECDH using an underlying ECDSA key. type ECDHKey struct { - ecdsa *pivgo.ECDSAPrivateKey + mu *sync.Mutex + *pivgo.ECDSAPrivateKey } -// Decrypt performs ECDH as per gpg-agent. +// Decrypt performs ECDH as per gpg-agent, and implements the crypto.Decrypter +// interface. func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte, _ crypto.DecrypterOpts) ([]byte, error) { + k.mu.Lock() + defer k.mu.Unlock() // parse out the ephemeral public key matches := ciphertextECDH.FindAllSubmatch(sexp, -1) ciphertext := matches[0][2] @@ -40,7 +45,7 @@ func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte, Y: ephPubY, } // marshal, encode, and return the result - shared, err := k.ecdsa.SharedKey(&ephPub) + shared, err := k.SharedKey(&ephPub) if err != nil { return nil, fmt.Errorf("couldn't generate shared secret: %v", err) } @@ -49,8 +54,10 @@ func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte, return []byte(fmt.Sprintf("D (5:value%d:%s)\nOK\n", sharedLen, shared)), nil } -// Public implements the other required method of the crypto.Decrypter and -// crypto.Signer interfaces. -func (k *ECDHKey) Public() crypto.PublicKey { - return k.ecdsa.Public() +// Sign wraps the underlying private key Sign operation in a mutex. +func (k *ECDHKey) Sign(rand io.Reader, digest []byte, + opts crypto.SignerOpts) ([]byte, error) { + k.mu.Lock() + defer k.mu.Unlock() + return k.ECDSAPrivateKey.Sign(rand, digest, opts) } diff --git a/internal/keyservice/piv/keyservice.go b/internal/keyservice/piv/keyservice.go index 566bfdf..5b4b78e 100644 --- a/internal/keyservice/piv/keyservice.go +++ b/internal/keyservice/piv/keyservice.go @@ -39,8 +39,10 @@ func (*KeyService) Name() string { // Keygrips returns a single slice of concatenated keygrip byteslices - one for // each cryptographic key available on the keyservice. func (p *KeyService) Keygrips() ([][]byte, error) { + p.mu.Lock() + defer p.mu.Unlock() var grips [][]byte - securityKeys, err := p.SecurityKeys() + securityKeys, err := p.getSecurityKeys() if err != nil { return nil, fmt.Errorf("couldn't get security keys: %w", err) } @@ -64,7 +66,9 @@ func (p *KeyService) Keygrips() ([][]byte, error) { // HaveKey takes a list of keygrips, and returns a boolean indicating if any of // the given keygrips were found, the found keygrip, and an error, if any. func (p *KeyService) HaveKey(keygrips [][]byte) (bool, []byte, error) { - securityKeys, err := p.SecurityKeys() + p.mu.Lock() + defer p.mu.Unlock() + securityKeys, err := p.getSecurityKeys() if err != nil { return false, nil, fmt.Errorf("couldn't get security keys: %w", err) } @@ -90,7 +94,7 @@ func (p *KeyService) HaveKey(keygrips [][]byte) (bool, []byte, error) { } func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) { - securityKeys, err := p.SecurityKeys() + securityKeys, err := p.getSecurityKeys() if err != nil { return nil, fmt.Errorf("couldn't get security keys: %w", err) } @@ -110,7 +114,11 @@ func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) { if err != nil { return nil, fmt.Errorf("couldn't get private key from slot") } - return privKey, nil + pivGoPrivKey, ok := privKey.(*pivgo.ECDSAPrivateKey) + if !ok { + return nil, fmt.Errorf("unexpected private key type: %T", privKey) + } + return &ECDHKey{mu: &p.mu, ECDSAPrivateKey: pivGoPrivKey}, nil } } } @@ -119,33 +127,39 @@ func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) { // GetSigner returns a crypto.Signer associated with the given keygrip. func (p *KeyService) GetSigner(keygrip []byte) (crypto.Signer, error) { + p.mu.Lock() + defer p.mu.Unlock() privKey, err := p.getPrivateKey(keygrip) if err != nil { return nil, fmt.Errorf("couldn't get private key: %v", err) } signingPrivKey, ok := privKey.(crypto.Signer) if !ok { - return nil, fmt.Errorf("private key is invalid type") + return nil, fmt.Errorf("private key is not a signer") } return signingPrivKey, nil } // GetDecrypter returns a crypto.Decrypter associated with the given keygrip. func (p *KeyService) GetDecrypter(keygrip []byte) (crypto.Decrypter, error) { + p.mu.Lock() + defer p.mu.Unlock() privKey, err := p.getPrivateKey(keygrip) if err != nil { return nil, fmt.Errorf("couldn't get private key: %v", err) } - ecdsaPrivKey, ok := privKey.(*pivgo.ECDSAPrivateKey) + decryptingPrivKey, ok := privKey.(crypto.Decrypter) if !ok { - return nil, fmt.Errorf("private key is invalid type") + return nil, fmt.Errorf("private key is not a decrypter") } - return &ECDHKey{ecdsa: ecdsaPrivKey}, nil + return decryptingPrivKey, nil } // CloseAll closes all security keys without checking for errors. // This should be called to clean up connections to `pcscd`. func (p *KeyService) CloseAll() { + p.mu.Lock() + defer p.mu.Unlock() p.log.Debug("closing security keys", zap.Int("count", len(p.securityKeys))) for _, k := range p.securityKeys { if err := k.Close(); err != nil { diff --git a/internal/keyservice/piv/list.go b/internal/keyservice/piv/list.go index f4a410f..d745b52 100644 --- a/internal/keyservice/piv/list.go +++ b/internal/keyservice/piv/list.go @@ -54,10 +54,7 @@ func (p *KeyService) reloadSecurityKeys() error { return nil } -// SecurityKeys returns a slice containing all available security keys. -func (p *KeyService) SecurityKeys() ([]SecurityKey, error) { - p.mu.Lock() - defer p.mu.Unlock() +func (p *KeyService) getSecurityKeys() ([]SecurityKey, error) { var err error // check if any securityKeys are cached, and if not then cache them if len(p.securityKeys) == 0 { @@ -68,6 +65,7 @@ func (p *KeyService) SecurityKeys() ([]SecurityKey, error) { // check they are healthy, and reload if not for _, k := range p.securityKeys { if _, err = k.AttestationCertificate(); err != nil { + p.log.Debug("PIV KeyService: couldn't get AttestationCertificate()", zap.Error(err)) if err = p.reloadSecurityKeys(); err != nil { return nil, fmt.Errorf("couldn't reload security keys: %v", err) } @@ -76,3 +74,10 @@ func (p *KeyService) SecurityKeys() ([]SecurityKey, error) { } return p.securityKeys, nil } + +// SecurityKeys returns a slice containing all available security keys. +func (p *KeyService) SecurityKeys() ([]SecurityKey, error) { + p.mu.Lock() + defer p.mu.Unlock() + return p.getSecurityKeys() +}