Skip to content

Commit

Permalink
auth: add the caching sha2 algorithm for authentication (#1232)
Browse files Browse the repository at this point in the history
This allows validating passwords against the `authentication_string`
data that MySQL stores for caching_sha2 passwords.

Related:
- pingcap/tidb#9411
  • Loading branch information
dveeden authored Jun 2, 2021
1 parent 1686fda commit 10b704a
Show file tree
Hide file tree
Showing 9 changed files with 428 additions and 93 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ bin/
y.go
*.output
.idea/
.vscode/
coverage.txt
66 changes: 0 additions & 66 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,9 @@
package auth

import (
"bytes"
"crypto/sha1"
"encoding/hex"
"fmt"

"github.com/pingcap/errors"
"github.com/pingcap/parser/format"
"github.com/pingcap/parser/terror"
)

// UserIdentity represents username and hostname.
Expand Down Expand Up @@ -79,64 +74,3 @@ func (role *RoleIdentity) String() string {
// TODO: Escape username and hostname.
return fmt.Sprintf("`%s`@`%s`", role.Username, role.Hostname)
}

// CheckScrambledPassword check scrambled password received from client.
// The new authentication is performed in following manner:
// SERVER: public_seed=create_random_string()
// send(public_seed)
// CLIENT: recv(public_seed)
// hash_stage1=sha1("password")
// hash_stage2=sha1(hash_stage1)
// reply=xor(hash_stage1, sha1(public_seed,hash_stage2)
// // this three steps are done in scramble()
// send(reply)
// SERVER: recv(reply)
// hash_stage1=xor(reply, sha1(public_seed,hash_stage2))
// candidate_hash2=sha1(hash_stage1)
// check(candidate_hash2==hash_stage2)
// // this three steps are done in check_scramble()
func CheckScrambledPassword(salt, hpwd, auth []byte) bool {
crypt := sha1.New()
_, err := crypt.Write(salt)
terror.Log(errors.Trace(err))
_, err = crypt.Write(hpwd)
terror.Log(errors.Trace(err))
hash := crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
if len(auth) != len(hash) {
return false
}
for i := range hash {
hash[i] ^= auth[i]
}

return bytes.Equal(hpwd, Sha1Hash(hash))
}

// Sha1Hash is an util function to calculate sha1 hash.
func Sha1Hash(bs []byte) []byte {
crypt := sha1.New()
_, err := crypt.Write(bs)
terror.Log(errors.Trace(err))
return crypt.Sum(nil)
}

// EncodePassword converts plaintext password to hashed hex string.
func EncodePassword(pwd string) string {
if len(pwd) == 0 {
return ""
}
hash1 := Sha1Hash([]byte(pwd))
hash2 := Sha1Hash(hash1)

return fmt.Sprintf("*%X", hash2)
}

// DecodePassword converts hex string password without prefix '*' to byte array.
func DecodePassword(pwd string) ([]byte, error) {
x, err := hex.DecodeString(pwd[1:])
if err != nil {
return nil, errors.Trace(err)
}
return x, nil
}
29 changes: 4 additions & 25 deletions auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package auth

import (
"testing"

. "github.com/pingcap/check"
)

Expand All @@ -22,29 +24,6 @@ var _ = Suite(&testAuthSuite{})
type testAuthSuite struct {
}

func (s *testAuthSuite) TestEncodePassword(c *C) {
pwd := "123"
c.Assert(EncodePassword(pwd), Equals, "*23AE809DDACAF96AF0FD78ED04B6A265E05AA257")
}

func (s *testAuthSuite) TestDecodePassword(c *C) {
x, err := DecodePassword(EncodePassword("123"))
c.Assert(err, IsNil)
c.Assert(x, DeepEquals, Sha1Hash(Sha1Hash([]byte("123"))))
}

func (s *testAuthSuite) TestCheckScramble(c *C) {
pwd := "abc"
salt := []byte{85, 92, 45, 22, 58, 79, 107, 6, 122, 125, 58, 80, 12, 90, 103, 32, 90, 10, 74, 82}
auth := []byte{24, 180, 183, 225, 166, 6, 81, 102, 70, 248, 199, 143, 91, 204, 169, 9, 161, 171, 203, 33}
encodepwd := EncodePassword(pwd)
hpwd, err := DecodePassword(encodepwd)
c.Assert(err, IsNil)

res := CheckScrambledPassword(salt, hpwd, auth)
c.Assert(res, IsTrue)

// Do not panic for invalid input.
res = CheckScrambledPassword(salt, hpwd, []byte("xxyyzz"))
c.Assert(res, IsFalse)
func TestT(t *testing.T) {
TestingT(t)
}
210 changes: 210 additions & 0 deletions auth/caching_sha2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package auth

// Resources:
// - https://dev.mysql.com/doc/refman/8.0/en/caching-sha2-pluggable-authentication.html
// - https://dev.mysql.com/doc/dev/mysql-server/latest/page_caching_sha2_authentication_exchanges.html
// - https://dev.mysql.com/doc/dev/mysql-server/latest/namespacesha2__password.html
// - https://www.akkadia.org/drepper/SHA-crypt.txt
// - https://dev.mysql.com/worklog/task/?id=9591
//
// CREATE USER 'foo'@'%' IDENTIFIED BY 'foobar';
// SELECT HEX(authentication_string) FROM mysql.user WHERE user='foo';
// 24412430303524031A69251C34295C4B35167C7F1E5A7B63091349503974624D34504B5A424679354856336868686F52485A736E4A733368786E427575516C73446469496537
//
// Format:
// Split on '$':
// - digest type ("A")
// - iterations (divided by ITERATION_MULTIPLIER)
// - salt+hash
//

import (
"bytes"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"strconv"
)

const (
MIXCHARS = 32
SALT_LENGTH = 20
ITERATION_MULTIPLIER = 1000
)

func b64From24bit(b []byte, n int) []byte {
b64t := []byte("./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz")

w := (int64(b[0]) << 16) | (int64(b[1]) << 8) | int64(b[2])
ret := make([]byte, 0, n)
for n > 0 {
n--
ret = append(ret, b64t[w&0x3f])
w >>= 6
}

return ret
}

func sha256crypt(plaintext string, salt []byte, iterations int) string {
// Numbers in the comments refer to the description of the algorithm on https://www.akkadia.org/drepper/SHA-crypt.txt

// 1, 2, 3
tmpA := sha256.New()
tmpA.Write([]byte(plaintext))
tmpA.Write(salt)

// 4, 5, 6, 7, 8
tmpB := sha256.New()
tmpB.Write([]byte(plaintext))
tmpB.Write(salt)
tmpB.Write([]byte(plaintext))
sumB := tmpB.Sum(nil)

// 9, 10
var i int
for i = len(plaintext); i > MIXCHARS; i -= MIXCHARS {
tmpA.Write(sumB[:MIXCHARS])
}
tmpA.Write(sumB[:i])

// 11
for i = len(plaintext); i > 0; i >>= 1 {
if i%2 == 0 {
tmpA.Write([]byte(plaintext))
} else {
tmpA.Write(sumB)
}
}

// 12
sumA := tmpA.Sum(nil)

// 13, 14, 15
tmpDP := sha256.New()
for range []byte(plaintext) {
tmpDP.Write([]byte(plaintext))
}
sumDP := tmpDP.Sum(nil)

// 16
p := make([]byte, 0, sha256.Size)
for i = len(plaintext); i > 0; i -= MIXCHARS {
if i > MIXCHARS {
p = append(p, sumDP...)
} else {
p = append(p, sumDP[0:i]...)
}
}

// 17, 18, 19
tmpDS := sha256.New()
for i = 0; i < 16+int(sumA[0]); i++ {
tmpDS.Write(salt)
}
sumDS := tmpDS.Sum(nil)

// 20
s := []byte{}
for i = len(salt); i > 0; i -= MIXCHARS {
if i > MIXCHARS {
s = append(s, sumDS...)
} else {
s = append(s, sumDS[0:i]...)
}
}

// 21
tmpC := sha256.New()
var sumC []byte
for i = 0; i < iterations; i++ {
tmpC.Reset()

if i&1 != 0 {
tmpC.Write(p)
} else {
tmpC.Write(sumA)
}
if i%3 != 0 {
tmpC.Write(s)
}
if i%7 != 0 {
tmpC.Write(p)
}
if i&1 != 0 {
tmpC.Write(sumA)
} else {
tmpC.Write(p)
}
sumC = tmpC.Sum(nil)
copy(sumA, tmpC.Sum(nil))
}

// 22
buf := bytes.Buffer{}
buf.Grow(100) // FIXME
buf.Write([]byte{'$', 'A', '$'})
rounds := fmt.Sprintf("%03d", iterations/ITERATION_MULTIPLIER)
buf.Write([]byte(rounds))
buf.Write([]byte{'$'})
buf.Write(salt)

buf.Write(b64From24bit([]byte{sumC[0], sumC[10], sumC[20]}, 4))
buf.Write(b64From24bit([]byte{sumC[21], sumC[1], sumC[11]}, 4))
buf.Write(b64From24bit([]byte{sumC[12], sumC[22], sumC[2]}, 4))
buf.Write(b64From24bit([]byte{sumC[3], sumC[13], sumC[23]}, 4))
buf.Write(b64From24bit([]byte{sumC[24], sumC[4], sumC[14]}, 4))
buf.Write(b64From24bit([]byte{sumC[15], sumC[25], sumC[5]}, 4))
buf.Write(b64From24bit([]byte{sumC[6], sumC[16], sumC[26]}, 4))
buf.Write(b64From24bit([]byte{sumC[27], sumC[7], sumC[17]}, 4))
buf.Write(b64From24bit([]byte{sumC[18], sumC[28], sumC[8]}, 4))
buf.Write(b64From24bit([]byte{sumC[9], sumC[19], sumC[29]}, 4))
buf.Write(b64From24bit([]byte{0, sumC[31], sumC[30]}, 3))

return buf.String()
}

// Checks if a MySQL style caching_sha2 authentication string matches a password
func CheckShaPassword(pwhash []byte, password string) (bool, error) {
pwhash_parts := bytes.Split(pwhash, []byte("$"))
if len(pwhash_parts) != 4 {
return false, errors.New("failed to decode hash parts")
}

hash_type := string(pwhash_parts[1])
if hash_type != "A" {
return false, errors.New("digest type is incompatible")
}

iterations, err := strconv.Atoi(string(pwhash_parts[2]))
if err != nil {
return false, errors.New("failed to decode iterations")
}
iterations = iterations * ITERATION_MULTIPLIER
salt := pwhash_parts[3][:SALT_LENGTH]

newHash := sha256crypt(password, salt, iterations)

return bytes.Equal(pwhash, []byte(newHash)), nil
}

func NewSha2Password(pwd string) string {
salt := make([]byte, SALT_LENGTH)
rand.Read(salt)

return sha256crypt(pwd, salt, 5*ITERATION_MULTIPLIER)
}
Loading

0 comments on commit 10b704a

Please sign in to comment.