// Package gobang gathers shared helpers: scrypt-based password hashing,
// UUIDv7 generation and parsing, structured logging and metrics on top of
// log/slog, and small assertion helpers for tests.
package gobang

import (
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"fmt"
	"hash"
	"io"
	"log/slog"
	"math/big"
	"math/bits"
	"os"
	"reflect"
	"runtime/debug"
	"strings"
	"sync"
	"syscall"
	"time"
)

// logLevel is the verbosity threshold checked by Debug, Info, Warning and
// Error: a call is dropped when the package-level `level` is below the
// call's own level, so higher values are more verbose.
type logLevel int8

const (
	LevelNone    logLevel = 0
	LevelError   logLevel = 1
	LevelWarning logLevel = 2
	LevelInfo    logLevel = 3
	LevelDebug   logLevel = 4
)

// Sizes of a UUID: 16 raw bytes, rendered as 32 hex digits plus 4 dashes
// (the 8-4-4-4-12 textual form, 36 characters).
const (
	uuidDashCount     = 4
	uuidByteCount     = 16
	uuidEncodedLength = (uuidByteCount * 2) + uuidDashCount
)

// UUID holds the 16 raw bytes of a UUID.
type UUID struct {
	bytes [uuidByteCount]byte
}

// Gauge is the pair of callbacks returned by MakeGauge; Inc and Dec adjust
// the gauge by one and emit the new value together with any extra args.
type Gauge struct {
	Inc func(...any)
	Dec func(...any)
}

// CopyResult bundles a byte count, an error and a label. It is not used in
// this file; presumably it reports the outcome of an io.Copy-style
// operation — confirm against callers.
type CopyResult struct {
	Written int64
	Err     error
	Label   string
}

// maxInt is the largest value representable by int on this platform.
const maxInt = int((^uint(0)) >> 1)

// MinimumPasswordLength is exported for callers; nothing in this file
// enforces it (Hash performs no password-length check).
const MinimumPasswordLength = 16

// scrypt cost parameters (N, r, p), minimum salt length and derived-key
// length used by Hash and Salt.
const (
	scrypt_N            = 1 << 15
	scrypt_r            = 8
	scrypt_p            = 1
	scryptSaltMinLength = 32
	scryptDesiredLength = 32
)

// Private variables

// lastV7Time is the last time returned by getV7Time, stored as:
//
//	52 bits of time in milliseconds since epoch
//	12 bits of (fractional nanoseconds) >> 8
var lastV7Time int64

// timeMutex guards lastV7Time.
var timeMutex sync.Mutex

// Local variables. They are overwritten at start-up by Init via setLogLevel
// (LOG_LEVEL), setMetric (NO_METRIC) and setHostname (os.Hostname).
var (
	level      logLevel = LevelInfo
	emitMetric bool     = true
	hostname   string
)

// _PBKDF2Key implements the key derivation function PBKDF2 as defined in
// RFC 2898 / PKCS #5 v2.0.
//
// A key derivation function is useful when encrypting data based on a password
// or any other not-fully-random data. It uses a pseudorandom function to derive
// a secure encryption key based on the password.
//
// While v2.0 of the standard defines only one pseudorandom function to use,
// HMAC-SHA1, the drafted v2.1 specification allows use of all five FIPS
// Approved Hash Functions SHA-1, SHA-224, SHA-256, SHA-384 and SHA-512 for
// HMAC. To choose one, pass the corresponding `New` function (for example
// sha256.New) as the h argument of _PBKDF2Key.
//
// _PBKDF2Key derives a key from the password, salt and iteration count,
// returning a []byte of length keyLen that can be used as a cryptographic key.
// The key is derived following PBKDF2 with HMAC, built on the supplied hash
// constructor h, as the pseudorandom function.
//
// For example, a 32-byte key (e.g. for AES-256) based on HMAC-SHA-256 is
// obtained with:
//
//	dk := _PBKDF2Key([]byte("some password"), salt, 4096, 32, sha256.New)
//
// Use a good random salt; the RFC recommends at least 8 bytes. A higher
// iteration count makes an exhaustive search more expensive, but it also
// makes every derivation proportionally slower.
func _PBKDF2Key(
	password []byte,
	salt []byte,
	iter int,
	keyLen int,
	h func() hash.Hash,
) []byte {
	mac := hmac.New(h, password)
	size := mac.Size()
	blocks := (keyLen + size - 1) / size

	out := make([]byte, 0, blocks*size)
	u := make([]byte, size)

	for i := 1; i <= blocks; i++ {
		// U_1 = PRF(password, salt || INT_32_BE(i)), appended straight onto out.
		counter := [4]byte{byte(i >> 24), byte(i >> 16), byte(i >> 8), byte(i)}
		mac.Reset()
		mac.Write(salt)
		mac.Write(counter[:])
		out = mac.Sum(out)

		// t aliases the block just appended and accumulates
		// T_i = U_1 ^ U_2 ^ ... ^ U_iter in place.
		t := out[len(out)-size:]
		copy(u, t)

		// U_n = PRF(password, U_(n-1)); u's storage is reused every round.
		for n := 2; n <= iter; n++ {
			mac.Reset()
			mac.Write(u)
			u = mac.Sum(u[:0])
			for k, b := range u {
				t[k] ^= b
			}
		}
	}
	return out[:keyLen]
}

// blockCopy copies the first n words of src into dst.
func blockCopy(dst []uint32, src []uint32, n int) {
	head := src[:n]
	copy(dst, head)
}

// blockXOR XORs the first n words of src into dst, word by word.
func blockXOR(dst []uint32, src []uint32, n int) {
	src = src[:n]
	for i := range src {
		dst[i] ^= src[i]
	}
}

// salsaXOR XORs the 16 words of tmp with the first 16 words of in, runs
// Salsa20/8 over the result, and stores the output in both tmp and out.
func salsaXOR(tmp *[16]uint32, in []uint32, out []uint32) {
	// w = tmp ^ in[0:16]: the Salsa20/8 input, kept for the feed-forward
	// addition after the rounds.
	w0 := tmp[0] ^ in[0]
	w1 := tmp[1] ^ in[1]
	w2 := tmp[2] ^ in[2]
	w3 := tmp[3] ^ in[3]
	w4 := tmp[4] ^ in[4]
	w5 := tmp[5] ^ in[5]
	w6 := tmp[6] ^ in[6]
	w7 := tmp[7] ^ in[7]
	w8 := tmp[8] ^ in[8]
	w9 := tmp[9] ^ in[9]
	w10 := tmp[10] ^ in[10]
	w11 := tmp[11] ^ in[11]
	w12 := tmp[12] ^ in[12]
	w13 := tmp[13] ^ in[13]
	w14 := tmp[14] ^ in[14]
	w15 := tmp[15] ^ in[15]

	// x is the working state mutated by the rounds.
	x0 := w0
	x1 := w1
	x2 := w2
	x3 := w3
	x4 := w4
	x5 := w5
	x6 := w6
	x7 := w7
	x8 := w8
	x9 := w9
	x10 := w10
	x11 := w11
	x12 := w12
	x13 := w13
	x14 := w14
	x15 := w15

	// Each iteration is one double round: a column round (first 16 lines,
	// on columns {0,4,8,12}, {5,9,13,1}, {10,14,2,6}, {15,3,7,11}) followed
	// by a row round (last 16 lines, on rows {0,1,2,3}, {5,6,7,4},
	// {10,11,8,9}, {15,12,13,14}). Four iterations give the 8 rounds of
	// Salsa20/8.
	for i := 0; i < 8; i += 2 {
		x4 ^= bits.RotateLeft32(x0 + x12, 7)
		x8 ^= bits.RotateLeft32(x4 + x0, 9)
		x12 ^= bits.RotateLeft32(x8 + x4, 13)
		x0 ^= bits.RotateLeft32(x12 + x8, 18)

		x9 ^= bits.RotateLeft32(x5 + x1, 7)
		x13 ^= bits.RotateLeft32(x9 + x5, 9)
		x1 ^= bits.RotateLeft32(x13 + x9, 13)
		x5 ^= bits.RotateLeft32(x1 + x13, 18)

		x14 ^= bits.RotateLeft32(x10 + x6, 7)
		x2 ^= bits.RotateLeft32(x14 + x10, 9)
		x6 ^= bits.RotateLeft32(x2 + x14, 13)
		x10 ^= bits.RotateLeft32(x6 + x2, 18)

		x3 ^= bits.RotateLeft32(x15 + x11, 7)
		x7 ^= bits.RotateLeft32(x3 + x15, 9)
		x11 ^= bits.RotateLeft32(x7 + x3, 13)
		x15 ^= bits.RotateLeft32(x11 + x7, 18)

		x1 ^= bits.RotateLeft32(x0 + x3, 7)
		x2 ^= bits.RotateLeft32(x1 + x0, 9)
		x3 ^= bits.RotateLeft32(x2 + x1, 13)
		x0 ^= bits.RotateLeft32(x3 + x2, 18)

		x6 ^= bits.RotateLeft32(x5 + x4, 7)
		x7 ^= bits.RotateLeft32(x6 + x5, 9)
		x4 ^= bits.RotateLeft32(x7 + x6, 13)
		x5 ^= bits.RotateLeft32(x4 + x7, 18)

		x11 ^= bits.RotateLeft32(x10 + x9, 7)
		x8 ^= bits.RotateLeft32(x11 + x10, 9)
		x9 ^= bits.RotateLeft32(x8 + x11, 13)
		x10 ^= bits.RotateLeft32(x9 + x8, 18)

		x12 ^= bits.RotateLeft32(x15 + x14, 7)
		x13 ^= bits.RotateLeft32(x12 + x15, 9)
		x14 ^= bits.RotateLeft32(x13 + x12, 13)
		x15 ^= bits.RotateLeft32(x14 + x13, 18)
	}

	// Feed-forward: add the round input back in, word by word.
	x0 += w0
	x1 += w1
	x2 += w2
	x3 += w3
	x4 += w4
	x5 += w5
	x6 += w6
	x7 += w7
	x8 += w8
	x9 += w9
	x10 += w10
	x11 += w11
	x12 += w12
	x13 += w13
	x14 += w14
	x15 += w15

	// Publish the result to out and keep it in tmp for the next call.
	out[0], tmp[0] = x0, x0
	out[1], tmp[1] = x1, x1
	out[2], tmp[2] = x2, x2
	out[3], tmp[3] = x3, x3
	out[4], tmp[4] = x4, x4
	out[5], tmp[5] = x5, x5
	out[6], tmp[6] = x6, x6
	out[7], tmp[7] = x7, x7
	out[8], tmp[8] = x8, x8
	out[9], tmp[9] = x9, x9
	out[10], tmp[10] = x10, x10
	out[11], tmp[11] = x11, x11
	out[12], tmp[12] = x12, x12
	out[13], tmp[13] = x13, x13
	out[14], tmp[14] = x14, x14
	out[15], tmp[15] = x15, x15
}

// blockMix mixes the 2*r 16-word sub-blocks of in and writes the result to
// out. tmp is seeded with the last sub-block of in. Each sub-block is then
// XORed into tmp and run through salsaXOR. Outputs for even-indexed
// sub-blocks land in the first half of out (r sub-blocks), and those for
// odd-indexed sub-blocks land in the second half.
func blockMix(tmp *[16]uint32, in []uint32, out []uint32, r int) {
	blockCopy(tmp[:], in[(2 * r - 1) * 16:], 16)
	for i := 0; i < 2 * r; i += 2 {
		salsaXOR(tmp, in[i * 16:], out[i * 8:])
		salsaXOR(tmp, in[i * 16 + 16:], out[i * 8 + r * 16:])
	}
}

// integer returns the first two words of the last 16-word sub-block of b,
// combined little-endian into a uint64 (b[j] low, b[j+1] high).
func integer(b []uint32, r int) uint64 {
	j := (2 * r - 1) * 16
	return uint64(b[j]) | (uint64(b[j + 1]) << 32)
}

// smix performs the scrypt ROMix step in place on the first 128*r bytes of
// b. v must hold at least N*32*r words and xy at least 2*32*r words; both
// are scratch space. N must be a power of two (indices are masked with N-1)
// and even (both loops advance two steps per iteration). scrypt checks that
// N > 1 and that N is a power of 2.
func smix(b []byte, r int, N int, v []uint32, xy []uint32) {
	var tmp [16]uint32
	// R is the block size in 32-bit words (128*r bytes).
	R := 32 * r
	x := xy
	y := xy[R:]

	// Load the block from b as little-endian words.
	j := 0
	for i := 0; i < R; i++ {
		x[i] = binary.LittleEndian.Uint32(b[j:])
		j += 4
	}
	// Fill v with N successive states, alternating between x and y so no
	// copy back is needed.
	for i := 0; i < N; i += 2 {
		blockCopy(v[i * R:], x, R)
		blockMix(&tmp, x, y, r)

		blockCopy(v[(i + 1) * R:], y, R)
		blockMix(&tmp, y, x, r)
	}
	// N data-dependent lookups: XOR in the state selected by integer(...)
	// mod N, then mix again.
	for i := 0; i < N; i += 2 {
		j := int(integer(x, r) & uint64(N - 1))
		blockXOR(x, v[j * R:], R)
		blockMix(&tmp, x, y, r)

		j = int(integer(y, r) & uint64(N - 1))
		blockXOR(y, v[j * R:], R)
		blockMix(&tmp, y, x, r)
	}
	// Store the result back into b as little-endian words.
	j = 0
	for _, v := range x[:R] {
		binary.LittleEndian.PutUint32(b[j:], v)
		j += 4
	}
}

// scrypt implements the scrypt key derivation function as defined in
// Colin Percival's paper "Stronger Key Derivation via Sequential Memory-Hard
// Functions" (https://www.tarsnap.com/scrypt/scrypt.pdf).
//
// scrypt derives a key from the password, salt, and cost parameters,
// returning a byte slice of length keyLen that can be used as a
// cryptographic key.
//
// N is a CPU/memory cost parameter, which must be a power of 2 greater than 1.
// r and p must satisfy r * p < 2³⁰. If the parameters do not satisfy the
// limits, the function returns a nil byte slice and an error.
//
// For example, you can get a derived key for e.g.
AES-256 (which needs a // 32-byte key) by doing: // // dk, err := scrypt.Key([]byte("some password"), salt, 32768, 8, 1, 32) // // The recommended parameters for interactive logins as of 2017 are N=32768, r=8 // and p=1. The parameters N, r, and p should be increased as memory latency and // CPU parallelism increases; consider setting N to the highest power of 2 you // can derive within 100 milliseconds. Remember to get a good random salt. func scrypt( password []byte, salt []byte, N int, r int, p int, keyLen int, ) ([]byte, error) { if N <= 1 || N & (N - 1) != 0 { return nil, errors.New("scrypt: N must be > 1 and a power of 2") } if ((uint64(r) * uint64(p)) >= (1 << 30)) || r > maxInt / 128 / p || r > maxInt / 256 || N > maxInt / 128 / r { return nil, errors.New("scrypt: parameters are too large") } xy := make([]uint32, 64 * r) v := make([]uint32, 32 * N * r) b := _PBKDF2Key(password, salt, 1, p * 128 * r, sha256.New) for i := 0; i < p; i++ { smix(b[i * 128 * r:], r, N, v, xy) } return _PBKDF2Key(password, b, 1, keyLen, sha256.New), nil } func Random(length int) []byte { buffer := make([]byte, length) _, err := io.ReadFull(rand.Reader, buffer) FatalIf(err) return buffer } func Salt() []byte { return Random(scryptSaltMinLength) } func Hash(password []byte, salt []byte) []byte{ Assert(len(salt) >= scryptSaltMinLength, "salt is too small") hash, err := scrypt( password, salt, scrypt_N, scrypt_r, scrypt_p, scryptDesiredLength, ) FatalIf(err) return hash } // FIXME: finish rewriting // FIXME: add tests // // getV7Time returns the time in milliseconds and nanoseconds / 256. // The returned (milli << (12 + seq)) is guaranteed to be greater than // (milli << (12 + seq)) returned by any previous call to getV7Time. 
// `seq` Sequence number is between 0 and 3906 (nanoPerMilli >> 8) func getV7Time(nano int64) (int64, int64) { const nanoPerMilli = 1000 * 1000 milli := nano / nanoPerMilli seq := (nano - (milli * nanoPerMilli)) >> 8 now := milli << (12 + seq) timeMutex.Lock() defer timeMutex.Unlock() if now <= lastV7Time { now = lastV7Time + 1 milli = now >> 12 seq = now & 0xfff } lastV7Time = now return milli, seq } func newUUIDFrom(randomBuffer [uuidByteCount]byte, now int64) UUID { randomBuffer[6] = (randomBuffer[6] & 0x0f) | 0x40 // Version 4 randomBuffer[8] = (randomBuffer[8] & 0x3f) | 0x80 // Variant is 10 t, s := getV7Time(now) randomBuffer[0] = byte(t >> 40) randomBuffer[1] = byte(t >> 32) randomBuffer[2] = byte(t >> 24) randomBuffer[3] = byte(t >> 16) randomBuffer[4] = byte(t >> 8) randomBuffer[5] = byte(t >> 0) randomBuffer[6] = 0x70 | (0x0f & byte(s >> 8)) randomBuffer[7] = byte(s) return UUID { bytes: randomBuffer } } func NewUUID() UUID { var buffer [uuidByteCount]byte _, err := io.ReadFull(rand.Reader, buffer[7:]) FatalIf(err) now := time.Now().UnixNano() return newUUIDFrom(buffer, now) } func (uuid UUID) String() string { dst := [uuidEncodedLength]byte { 0, 0, 0, 0, 0, 0, 0, 0, '-', 0, 0, 0, 0, '-', 0, 0, 0, 0, '-', 0, 0, 0, 0, '-', 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, } hex.Encode(dst[ 0:8], uuid.bytes[0:4]) hex.Encode(dst[ 9:13], uuid.bytes[4:6]) hex.Encode(dst[14:18], uuid.bytes[6:8]) hex.Encode(dst[19:23], uuid.bytes[8:10]) hex.Encode(dst[24:36], uuid.bytes[10:]) return string(dst[:]) } var ( dashIndexes = []int { 8, 13, 18, 23 } emptyUUID = UUID {} badUUIDLengthError = errors.New("str isn't of the correct length") badUUIDDashCountError = errors.New("Bad count of dashes in string") badUUIDDashPositionError = errors.New("Bad char in string") ) func ParseUUID(str string) (UUID, error) { if len(str) != uuidEncodedLength { return emptyUUID, badUUIDLengthError } if strings.Count(str, "-") != uuidDashCount { return emptyUUID, badUUIDDashCountError } for _, idx := range 
dashIndexes { if str[idx] != '-' { return emptyUUID, badUUIDDashPositionError } } hexstr := strings.Join(strings.Split(str, "-"), "") data, err := hex.DecodeString(hexstr) if err != nil { return emptyUUID, err } return UUID { bytes: [uuidByteCount]byte(data) }, nil } func Debug(message string, type_ string, args ...any) { if level < LevelDebug { return } slog.Debug( message, append( []any { "id", NewUUID(), "kind", "log", "type", type_, }, args..., )..., ) } func Info(message string, type_ string, args ...any) { if level < LevelInfo { return } slog.Info( message, append( []any { "id", NewUUID(), "kind", "log", "type", type_, }, args..., )..., ) } func Warning(message string, type_ string, args ...any) { if level < LevelWarning { return } slog.Warn( message, append( []any { "id", NewUUID(), "kind", "log", "type", type_, }, args..., )..., ) } func Error(message string, type_ string, args ...any) { if level < LevelError { return } slog.Error( message, append( []any { "id", NewUUID(), "kind", "log", "type", type_, }, args..., )..., ) } func Metric(type_ string, label string, args ...any) { if !emitMetric { return } slog.Info( "_", append( []any { "id", NewUUID(), "kind", "metric", "type", type_, "label", label, }, args..., )..., ) } func MakeGauge(label string, staticArgs ...any) Gauge { var zero = big.NewInt(0) var one = big.NewInt(1) count := big.NewInt(0) emitGauge := func(dynamicArgs ...any) { if count.Cmp(zero) == -1 { Error( "Gauge went negative", "process-metric", append( []any { "value", count }, append( staticArgs, dynamicArgs..., )..., )..., ) return // avoid wrong metrics being emitted } Metric( "gauge", label, // TODO: we'll have slices.Concat on Go 1.22 append( []any { "value", count }, append( staticArgs, dynamicArgs..., )..., )..., ) } return Gauge { Inc: func(dynamicArgs ...any) { count.Add(count, one) emitGauge(dynamicArgs...) }, Dec: func(dynamicArgs ...any) { count.Sub(count, one) emitGauge(dynamicArgs...) 
}, } } func MakeCounter(label string) func(...any) { return func(args ...any) { Metric( "counter", label, append([]any { "value", 1 }, args...)..., ) } } func ErrorIf(err error) { if err != nil { fmt.Fprintf(os.Stderr, "Unexpected error: %#v\n", err) os.Exit(1) } } func ErrorNil(err error) { if err == nil { fmt.Fprintf(os.Stderr, "Expected error, got nil\n") os.Exit(1) } } func showColour() bool { return os.Getenv("NO_COLOUR") == "" } func TestStart(name string) { fmt.Fprintf(os.Stderr, "%s:\n", name) } func Testing(message string, body func()) { if showColour() { fmt.Fprintf( os.Stderr, "\033[0;33mtesting\033[0m: %s... ", message, ) body() fmt.Fprint(os.Stderr, "\033[0;32mOK\033[0m.\n") } else { fmt.Fprintf(os.Stderr, "testing: %s...", message) body() fmt.Fprint(os.Stderr, " OK.\n") } } func AssertEqual(given any, expected any) { if !reflect.DeepEqual(given, expected) { if showColour() { fmt.Fprintf(os.Stderr, "\033[0;31mERR\033[0m.\n") } else { fmt.Fprintf(os.Stderr, "ERR.\n") } fmt.Fprintf(os.Stderr, "given != expected\n") fmt.Fprintf(os.Stderr, "given: %#v\n", given) fmt.Fprintf(os.Stderr, "expected: %#v\n", expected) os.Exit(1) } } func AssertEqualI(i int, given any, expected any) { if !reflect.DeepEqual(given, expected) { if showColour() { fmt.Fprintf(os.Stderr, "\033[0;31mERR\033[0m.\n") } else { fmt.Fprintf(os.Stderr, "ERR.\n") } fmt.Fprintf(os.Stderr, "given != expected (i = %d)\n", i) fmt.Fprintf(os.Stderr, "given: %#v\n", given) fmt.Fprintf(os.Stderr, "expected: %#v\n", expected) os.Exit(1) } } func setLoggerOutput(w io.Writer) { slog.SetDefault(slog.New(slog.NewJSONHandler(w, &slog.HandlerOptions { AddSource: true, })).With( slog.Group( "info", "pid", os.Getpid(), "ppid", os.Getppid(), "puuid", NewUUID(), ), )) } func levelFromString(name string) (bool, logLevel) { label := strings.ToUpper(name) if label == "NONE" { return true, LevelNone } if label == "ERROR" { return true, LevelError } if label == "WARNING" { return true, LevelWarning } if label == 
"INFO" { return true, LevelInfo } if label == "DEBUG" { return true, LevelDebug } return false, level } func setLogLevel() { ok, envLevel := levelFromString(os.Getenv("LOG_LEVEL")) if ok { level = envLevel } } func setMetric() { if os.Getenv("NO_METRIC") != "" { emitMetric = false } } func setTraceback() { if os.Getenv("GOTRACEBACK") == "" { debug.SetTraceback("crash") } } func Fatal(err error) { Error( "Fatal error", "fatal-error", "error", err, "stack", string(debug.Stack()), ) syscall.Kill(os.Getpid(), syscall.SIGABRT) os.Exit(3) } func FatalIf(err error) { if err != nil { Fatal(err) } } func Assert(condition bool, message string) { if !condition { Fatal(errors.New(message)) } } func Unreachable() { Assert(false, "Unreachable code was reached") } func setHostname() { var err error hostname, err = os.Hostname() FatalIf(err) } func Init() { setLoggerOutput(os.Stdout) setLogLevel() setMetric() setTraceback() setHostname() }