diff options
Diffstat (limited to 'src/scrypt.go')
-rw-r--r-- | src/scrypt.go | 357 |
1 files changed, 55 insertions, 302 deletions
diff --git a/src/scrypt.go b/src/scrypt.go index fe6cb8d..a4f03d2 100644 --- a/src/scrypt.go +++ b/src/scrypt.go @@ -1,28 +1,30 @@ package scrypt import ( - "crypto/hmac" "crypto/rand" - "crypto/sha256" - "encoding/binary" "encoding/hex" "errors" "fmt" - "hash" "io" - "math/bits" "os" "slices" ) +/* +#define _XOPEN_SOURCE 700 +#include <stdlib.h> +#include <scrypt-kdf.h> +*/ +import "C" + + + const ( - saltMinLength = 32 - desiredLength = 32 - maxInt = int((^uint(0)) >> 1) MinimumPasswordLength = 16 - + _SALT_MIN_LENGTH = 32 + _DESIRED_LENGTH = 32 _N = 1 << 15 r = 8 p = 1 @@ -30,274 +32,12 @@ const ( var ( - ErrBadN = errors.New("scrypt: N must be > 1 and a power of 2") - ErrParamsTooLarge = errors.New("scrypt: parameters are too large") - ErrSaltTooSmall = errors.New("scrypt: salt is too small") + ErrSaltTooSmall = errors.New("scrypt: salt is too small") + ErrInternal = errors.New("scrypt: internal error") ) -// Package pbkdf2 implements the key derivation function PBKDF2 as defined in -// RFC 2898 / PKCS #5 v2.0. -// -// A key derivation function is useful when encrypting data based on a password -// or any other not-fully-random data. It uses a pseudorandom function to derive -// a secure encryption key based on the password. -// -// While v2.0 of the standard defines only one pseudorandom function to use, -// HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS -// Approved Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for -// HMAC. To choose, you can pass the `New` functions from the different SHA -// packages to pbkdf2.Key. -// -// -// Key derives a key from the password, salt and iteration count, returning a -// []byte of length keylen that can be used as cryptographic key. The key is -// derived based on the method described as PBKDF2 with the HMAC variant using -// the supplied hash function. -// -// For example, to use a HMAC-SHA-1 based PBKDF2 key derivation function, you -// can get a derived key for e.g. AES-256 (which needs a 32-byte key) by -// doing: -// -// dk := pbkdf2.Key([]byte("some password"), salt, 4096, 32, sha1.New) -// -// Remember to get a good random salt. At least 8 bytes is recommended by the -// RFC. -// -// Using a higher iteration count will increase the cost of an exhaustive -// search but will also make derivation proportionally slower. -func _PBKDF2Key( - password []byte, - salt []byte, - iter int, - keyLen int, - h func() hash.Hash, -) []byte { - prf := hmac.New(h, password) - hashLen := prf.Size() - numBlocks := (keyLen + hashLen - 1) / hashLen - - var buffer [4]byte - dk := make([]byte, 0, numBlocks*hashLen) - U := make([]byte, hashLen) - for block := 1; block <= numBlocks; block++ { - // N.B.: || means concatenation, ^ means XOR - // for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter - // U_1 = PRF(password, salt || uint(i)) - prf.Reset() - prf.Write(salt) - buffer[0] = byte(block >> 24) - buffer[1] = byte(block >> 16) - buffer[2] = byte(block >> 8) - buffer[3] = byte(block >> 0) - prf.Write(buffer[:4]) - dk = prf.Sum(dk) - T := dk[len(dk) - hashLen:] - copy(U, T) - - // U_n = PRF(password, U_(n - 1)) - for n := 2; n <= iter; n++ { - prf.Reset() - prf.Write(U) - U = U[:0] - U = prf.Sum(U) - for x := range U { - T[x] ^= U[x] - } - } - } - return dk[:keyLen] -} - -// blockCopy copies n numbers from src into dst. -func blockCopy(dst []uint32, src []uint32, n int) { - copy(dst, src[:n]) -} - -// blockXOR XORs numbers from dst with n numbers from src. -func blockXOR(dst []uint32, src []uint32, n int) { - for i, v := range src[:n] { - dst[i] ^= v - } -} - -// salsaXOR applies Salsa20/8 to the XOR of 16 numbers from tmp and in, -// and puts the result into both tmp and out. -func salsaXOR(tmp *[16]uint32, in []uint32, out []uint32) { - w0 := tmp[0] ^ in[0] - w1 := tmp[1] ^ in[1] - w2 := tmp[2] ^ in[2] - w3 := tmp[3] ^ in[3] - w4 := tmp[4] ^ in[4] - w5 := tmp[5] ^ in[5] - w6 := tmp[6] ^ in[6] - w7 := tmp[7] ^ in[7] - w8 := tmp[8] ^ in[8] - w9 := tmp[9] ^ in[9] - w10 := tmp[10] ^ in[10] - w11 := tmp[11] ^ in[11] - w12 := tmp[12] ^ in[12] - w13 := tmp[13] ^ in[13] - w14 := tmp[14] ^ in[14] - w15 := tmp[15] ^ in[15] - - x0 := w0 - x1 := w1 - x2 := w2 - x3 := w3 - x4 := w4 - x5 := w5 - x6 := w6 - x7 := w7 - x8 := w8 - x9 := w9 - x10 := w10 - x11 := w11 - x12 := w12 - x13 := w13 - x14 := w14 - x15 := w15 - - for i := 0; i < 8; i += 2 { - x4 ^= bits.RotateLeft32(x0 + x12, 7) - x8 ^= bits.RotateLeft32(x4 + x0, 9) - x12 ^= bits.RotateLeft32(x8 + x4, 13) - x0 ^= bits.RotateLeft32(x12 + x8, 18) - - x9 ^= bits.RotateLeft32(x5 + x1, 7) - x13 ^= bits.RotateLeft32(x9 + x5, 9) - x1 ^= bits.RotateLeft32(x13 + x9, 13) - x5 ^= bits.RotateLeft32(x1 + x13, 18) - - x14 ^= bits.RotateLeft32(x10 + x6, 7) - x2 ^= bits.RotateLeft32(x14 + x10, 9) - x6 ^= bits.RotateLeft32(x2 + x14, 13) - x10 ^= bits.RotateLeft32(x6 + x2, 18) - - x3 ^= bits.RotateLeft32(x15 + x11, 7) - x7 ^= bits.RotateLeft32(x3 + x15, 9) - x11 ^= bits.RotateLeft32(x7 + x3, 13) - x15 ^= bits.RotateLeft32(x11 + x7, 18) - - x1 ^= bits.RotateLeft32(x0 + x3, 7) - x2 ^= bits.RotateLeft32(x1 + x0, 9) - x3 ^= bits.RotateLeft32(x2 + x1, 13) - x0 ^= bits.RotateLeft32(x3 + x2, 18) - - x6 ^= bits.RotateLeft32(x5 + x4, 7) - x7 ^= bits.RotateLeft32(x6 + x5, 9) - x4 ^= bits.RotateLeft32(x7 + x6, 13) - x5 ^= bits.RotateLeft32(x4 + x7, 18) - - x11 ^= bits.RotateLeft32(x10 + x9, 7) - x8 ^= bits.RotateLeft32(x11 + x10, 9) - x9 ^= bits.RotateLeft32(x8 + x11, 13) - x10 ^= bits.RotateLeft32(x9 + x8, 18) - - x12 ^= bits.RotateLeft32(x15 + x14, 7) - x13 ^= bits.RotateLeft32(x12 + x15, 9) - x14 ^= bits.RotateLeft32(x13 + x12, 13) - x15 ^= bits.RotateLeft32(x14 + x13, 18) - } - - x0 += w0 - x1 += w1 - x2 += w2 - x3 += w3 - x4 += w4 - x5 += w5 - x6 += w6 - x7 += w7 - x8 += w8 - x9 += w9 - x10 += w10 - x11 += w11 - x12 += w12 - x13 += w13 - x14 += w14 - x15 += w15 - - out[0], tmp[0] = x0, x0 - out[1], tmp[1] = x1, x1 - out[2], tmp[2] = x2, x2 - out[3], tmp[3] = x3, x3 - out[4], tmp[4] = x4, x4 - out[5], tmp[5] = x5, x5 - out[6], tmp[6] = x6, x6 - out[7], tmp[7] = x7, x7 - out[8], tmp[8] = x8, x8 - out[9], tmp[9] = x9, x9 - out[10], tmp[10] = x10, x10 - out[11], tmp[11] = x11, x11 - out[12], tmp[12] = x12, x12 - out[13], tmp[13] = x13, x13 - out[14], tmp[14] = x14, x14 - out[15], tmp[15] = x15, x15 -} - -func blockMix(tmp *[16]uint32, in []uint32, out []uint32, r int) { - blockCopy(tmp[:], in[(2 * r - 1) * 16:], 16) - for i := 0; i < 2 * r; i += 2 { - salsaXOR(tmp, in[i * 16:], out[i * 8:]) - salsaXOR(tmp, in[i * 16 + 16:], out[i * 8 + r * 16:]) - } -} - -func integer(b []uint32, r int) uint64 { - j := (2 * r - 1) * 16 - return uint64(b[j]) | (uint64(b[j + 1]) << 32) -} - -func smix(b []byte, r int, N int, v []uint32, xy []uint32) { - var tmp [16]uint32 - R := 32 * r - x := xy - y := xy[R:] - - j := 0 - for i := 0; i < R; i++ { - x[i] = binary.LittleEndian.Uint32(b[j:]) - j += 4 - } - for i := 0; i < N; i += 2 { - blockCopy(v[i * R:], x, R) - blockMix(&tmp, x, y, r) - - blockCopy(v[(i + 1) * R:], y, R) - blockMix(&tmp, y, x, r) - } - for i := 0; i < N; i += 2 { - j := int(integer(x, r) & uint64(N - 1)) - blockXOR(x, v[j * R:], R) - blockMix(&tmp, x, y, r) - - j = int(integer(y, r) & uint64(N - 1)) - blockXOR(y, v[j * R:], R) - blockMix(&tmp, y, x, r) - } - j = 0 - for _, v := range x[:R] { - binary.LittleEndian.PutUint32(b[j:], v) - j += 4 - } -} - -func validateParams(N int, r int, p int) error { - if N <= 1 || N & (N - 1) != 0 { - return ErrBadN - } - - if ((uint64(r) * uint64(p)) >= (1 << 30)) || - r > maxInt / 128 / p || - r > maxInt / 256 || - N > maxInt / 128 / r { - return ErrParamsTooLarge - } - - return nil -} - // Package scrypt implements the scrypt key derivation function as defined in // Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard // Functions" (https://www.tarsnap.com/scrypt/scrypt.pdf). @@ -325,39 +65,37 @@ func scrypt( N int, r int, p int, - keyLen int, + outlen int, ) ([]byte, error) { - err := validateParams(N, r, p) - if err != nil { - return nil, err - } - - xy := make([]uint32, 64 * r) - v := make([]uint32, 32 * r * N) - b := _PBKDF2Key(password, salt, 1, p * 128 * r, sha256.New) - - for i := 0; i < p; i++ { - smix(b[i * 128 * r:], r, N, v, xy) - } - - return _PBKDF2Key(password, b, 1, keyLen, sha256.New), nil -} - -func SaltFrom(r io.Reader) ([]byte, error) { - buffer := make([]byte, saltMinLength) - _, err := io.ReadFull(r, buffer) - if err != nil { - return nil, err + passwordbuf := C.CBytes(password) + saltbuf := C.CBytes(salt) + defer C.free(passwordbuf) + defer C.free(saltbuf) + + outbuf := C.malloc(C.size_t(outlen)) + defer C.free(outbuf) + + rv := C.scrypt_kdf( + (*C.uint8_t)(passwordbuf), + C.size_t(len(password)), + (*C.uint8_t)(saltbuf), + C.size_t(len(salt)), + C.uint64_t(N), + C.uint32_t(r), + C.uint32_t(p), + (*C.uint8_t)(outbuf), + C.size_t(outlen), + ) + if rv != 0 { + return nil, ErrInternal } - return buffer, nil -} -func Salt() ([]byte, error) { - return SaltFrom(rand.Reader) + out := C.GoBytes(outbuf, C.int(outlen)) + return out, nil } func Hash(password []byte, salt []byte) ([]byte, error) { - if len(salt) < saltMinLength { + if len(salt) < _SALT_MIN_LENGTH { return nil, ErrSaltTooSmall } @@ -367,7 +105,7 @@ func Hash(password []byte, salt []byte) ([]byte, error) { _N, r, p, - desiredLength, + _DESIRED_LENGTH, ) if err != nil { return nil, err @@ -376,6 +114,19 @@ func Hash(password []byte, salt []byte) ([]byte, error) { return hash, nil } +func SaltFrom(r io.Reader) ([]byte, error) { + buffer := make([]byte, _SALT_MIN_LENGTH) + _, err := io.ReadFull(r, buffer) + if err != nil { + return nil, err + } + return buffer, nil +} + +func Salt() ([]byte, error) { + return SaltFrom(rand.Reader) +} + func Check(password []byte, salt []byte, hash []byte) (bool, error) { candidate, err := Hash(password, salt) if err != nil { @@ -392,8 +143,10 @@ func Main() { fmt.Fprintf(os.Stderr, "Usage: scrypt PASSWORD SALT\n") os.Exit(2) } + password := os.Args[1] + salt := os.Args[2] - payload, err := Hash([]byte(os.Args[1]), []byte(os.Args[2])) + payload, err := Hash([]byte(password), []byte(salt)) if err != nil { if err == ErrSaltTooSmall { fmt.Fprintln(os.Stderr, err) |