Skip to content

Commit

Permalink
Merge pull request #141 from smlx/setkey
Browse files Browse the repository at this point in the history
fix: synchronise access to security key hardware
  • Loading branch information
smlx authored Nov 4, 2022
2 parents 4a270a5 + be80dd6 commit eb7534d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 19 deletions.
21 changes: 14 additions & 7 deletions internal/keyservice/piv/ecdhkey.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"regexp"
"sync"

pivgo "github.com/go-piv/piv-go/piv"
"github.com/smlx/piv-agent/internal/assuan"
Expand All @@ -17,12 +18,16 @@ var ciphertextECDH = regexp.MustCompile(

// ECDHKey implements ECDH using an underlying ECDSA key.
type ECDHKey struct {
ecdsa *pivgo.ECDSAPrivateKey
mu *sync.Mutex
*pivgo.ECDSAPrivateKey
}

// Decrypt performs ECDH as per gpg-agent.
// Decrypt performs ECDH as per gpg-agent, and implements the crypto.Decrypter
// interface.
func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte,
_ crypto.DecrypterOpts) ([]byte, error) {
k.mu.Lock()
defer k.mu.Unlock()
// parse out the ephemeral public key
matches := ciphertextECDH.FindAllSubmatch(sexp, -1)
ciphertext := matches[0][2]
Expand All @@ -40,7 +45,7 @@ func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte,
Y: ephPubY,
}
// marshal, encode, and return the result
shared, err := k.ecdsa.SharedKey(&ephPub)
shared, err := k.SharedKey(&ephPub)
if err != nil {
return nil, fmt.Errorf("couldn't generate shared secret: %v", err)
}
Expand All @@ -49,8 +54,10 @@ func (k *ECDHKey) Decrypt(_ io.Reader, sexp []byte,
return []byte(fmt.Sprintf("D (5:value%d:%s)\nOK\n", sharedLen, shared)), nil
}

// Public implements the other required method of the crypto.Decrypter and
// crypto.Signer interfaces.
func (k *ECDHKey) Public() crypto.PublicKey {
return k.ecdsa.Public()
// Sign wraps the underlying private key Sign operation in a mutex.
func (k *ECDHKey) Sign(rand io.Reader, digest []byte,
opts crypto.SignerOpts) ([]byte, error) {
k.mu.Lock()
defer k.mu.Unlock()
return k.ECDSAPrivateKey.Sign(rand, digest, opts)
}
30 changes: 22 additions & 8 deletions internal/keyservice/piv/keyservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ func (*KeyService) Name() string {
// Keygrips returns a single slice of concatenated keygrip byteslices - one for
// each cryptographic key available on the keyservice.
func (p *KeyService) Keygrips() ([][]byte, error) {
p.mu.Lock()
defer p.mu.Unlock()
var grips [][]byte
securityKeys, err := p.SecurityKeys()
securityKeys, err := p.getSecurityKeys()
if err != nil {
return nil, fmt.Errorf("couldn't get security keys: %w", err)
}
Expand All @@ -64,7 +66,9 @@ func (p *KeyService) Keygrips() ([][]byte, error) {
// HaveKey takes a list of keygrips, and returns a boolean indicating if any of
// the given keygrips were found, the found keygrip, and an error, if any.
func (p *KeyService) HaveKey(keygrips [][]byte) (bool, []byte, error) {
securityKeys, err := p.SecurityKeys()
p.mu.Lock()
defer p.mu.Unlock()
securityKeys, err := p.getSecurityKeys()
if err != nil {
return false, nil, fmt.Errorf("couldn't get security keys: %w", err)
}
Expand All @@ -90,7 +94,7 @@ func (p *KeyService) HaveKey(keygrips [][]byte) (bool, []byte, error) {
}

func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) {
securityKeys, err := p.SecurityKeys()
securityKeys, err := p.getSecurityKeys()
if err != nil {
return nil, fmt.Errorf("couldn't get security keys: %w", err)
}
Expand All @@ -110,7 +114,11 @@ func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) {
if err != nil {
return nil, fmt.Errorf("couldn't get private key from slot")
}
return privKey, nil
pivGoPrivKey, ok := privKey.(*pivgo.ECDSAPrivateKey)
if !ok {
return nil, fmt.Errorf("unexpected private key type: %T", privKey)
}
return &ECDHKey{mu: &p.mu, ECDSAPrivateKey: pivGoPrivKey}, nil
}
}
}
Expand All @@ -119,33 +127,39 @@ func (p *KeyService) getPrivateKey(keygrip []byte) (crypto.PrivateKey, error) {

// GetSigner returns a crypto.Signer associated with the given keygrip.
func (p *KeyService) GetSigner(keygrip []byte) (crypto.Signer, error) {
p.mu.Lock()
defer p.mu.Unlock()
privKey, err := p.getPrivateKey(keygrip)
if err != nil {
return nil, fmt.Errorf("couldn't get private key: %v", err)
}
signingPrivKey, ok := privKey.(crypto.Signer)
if !ok {
return nil, fmt.Errorf("private key is invalid type")
return nil, fmt.Errorf("private key is not a signer")
}
return signingPrivKey, nil
}

// GetDecrypter returns a crypto.Decrypter associated with the given keygrip.
func (p *KeyService) GetDecrypter(keygrip []byte) (crypto.Decrypter, error) {
p.mu.Lock()
defer p.mu.Unlock()
privKey, err := p.getPrivateKey(keygrip)
if err != nil {
return nil, fmt.Errorf("couldn't get private key: %v", err)
}
ecdsaPrivKey, ok := privKey.(*pivgo.ECDSAPrivateKey)
decryptingPrivKey, ok := privKey.(crypto.Decrypter)
if !ok {
return nil, fmt.Errorf("private key is invalid type")
return nil, fmt.Errorf("private key is not a decrypter")
}
return &ECDHKey{ecdsa: ecdsaPrivKey}, nil
return decryptingPrivKey, nil
}

// CloseAll closes all security keys without checking for errors.
// This should be called to clean up connections to `pcscd`.
func (p *KeyService) CloseAll() {
p.mu.Lock()
defer p.mu.Unlock()
p.log.Debug("closing security keys", zap.Int("count", len(p.securityKeys)))
for _, k := range p.securityKeys {
if err := k.Close(); err != nil {
Expand Down
13 changes: 9 additions & 4 deletions internal/keyservice/piv/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ func (p *KeyService) reloadSecurityKeys() error {
return nil
}

// SecurityKeys returns a slice containing all available security keys.
func (p *KeyService) SecurityKeys() ([]SecurityKey, error) {
p.mu.Lock()
defer p.mu.Unlock()
func (p *KeyService) getSecurityKeys() ([]SecurityKey, error) {
var err error
// check if any securityKeys are cached, and if not then cache them
if len(p.securityKeys) == 0 {
Expand All @@ -68,6 +65,7 @@ func (p *KeyService) SecurityKeys() ([]SecurityKey, error) {
// check they are healthy, and reload if not
for _, k := range p.securityKeys {
if _, err = k.AttestationCertificate(); err != nil {
p.log.Debug("PIV KeyService: couldn't get AttestationCertificate()", zap.Error(err))
if err = p.reloadSecurityKeys(); err != nil {
return nil, fmt.Errorf("couldn't reload security keys: %v", err)
}
Expand All @@ -76,3 +74,10 @@ func (p *KeyService) SecurityKeys() ([]SecurityKey, error) {
}
return p.securityKeys, nil
}

// SecurityKeys returns a slice containing all available security keys.
func (p *KeyService) SecurityKeys() ([]SecurityKey, error) {
p.mu.Lock()
defer p.mu.Unlock()
return p.getSecurityKeys()
}

0 comments on commit eb7534d

Please sign in to comment.