#include "config.h"

#include <assert.h>
#include <errno.h>
#include <pthread.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "logerr.h"
#include "random.h"

#include "hash.h"


/*
 * Guards every read and write of SEED done by hash_init().
 *
 * A default (PTHREAD_MUTEX_INITIALIZER) mutex is sufficient: it is never
 * locked recursively, and every successful lock is paired with exactly one
 * unlock by the same thread. Unlocking it while not holding it is undefined
 * behaviour, so hash_init() only unlocks on the paths where it has locked.
 */
static pthread_mutex_t SEED_INIT_MUTEX = PTHREAD_MUTEX_INITIALIZER;

#define PRIV_SEED_LENGTH 16

struct Seed {
	bool initialized;
	uint8_t bytes[PRIV_SEED_LENGTH];
};

/* Process-wide key passed to siphash() by hash(). */
static struct Seed SEED = {
	.initialized = false,
	.bytes = { 0 },
};


/**
 * Fill the process-wide hash seed with bytes obtained from urandom_bytes().
 *
 * May be called any number of times, from any number of threads: only the
 * first successful call writes SEED; later calls return 0 without touching it.
 *
 * Both the check of SEED.initialized and the write of SEED happen while
 * holding SEED_INIT_MUTEX. There is intentionally no unlocked "fast path"
 * check: reading the plain (non-atomic) bool while another thread writes it
 * is a data race, and without acquire/release ordering a reader could observe
 * `initialized == true` before the seed bytes themselves are visible.
 *
 * @return 0 on success or when already initialized, -1 on failure (the cause
 *         is reported through logerr()).
 */
int
hash_init(void) {
	int rc = 0;
	int ret = 0;
	struct Seed seed = { 0 };

	ret = pthread_mutex_lock(&SEED_INIT_MUTEX);
	if (ret) {
		logerr(
			"pthread_mutex_lock(&SEED_INIT_MUTEX): %s;\n",
			strerror(ret)
		);
		rc = -1;
		/* The mutex is not held: skip the unlock. */
		goto out;
	}

	if (SEED.initialized) {
		/* Someone else (possibly a concurrent caller) already did it. */
		goto out_unlock;
	}

	/* Fill a local copy first so SEED is only written on success. */
	if (urandom_bytes(sizeof(seed.bytes), &seed.bytes)) {
		logerr("urandom_bytes(sizeof(seed.bytes), &seed.bytes);\n");
		rc = -1;
		goto out_unlock;
	}

	memcpy(SEED.bytes, seed.bytes, sizeof(SEED.bytes));
	SEED.initialized = true;

out_unlock:
	ret = pthread_mutex_unlock(&SEED_INIT_MUTEX);
	if (ret) {
		logerr(
			"pthread_mutex_unlock(&SEED_INIT_MUTEX): %s\n",
			strerror(ret)
		);
		rc = -1;
	}
out:
	return rc;
}

/**
 * Compute the keyed siphash of the `len` bytes at `in`, writing
 * OUTPUT_LENGTH bytes into `out`.
 *
 * Precondition: hash_init() has returned 0 (checked only by assert()).
 *
 * @tags infallible
 */
void
hash(const size_t len, const void *const in, uint8_t out[OUTPUT_LENGTH]) {
	assert(SEED.initialized && "Missing call to hash_init() before using hash().");
	/* Ask siphash for exactly as many bytes as `out` can hold. */
	siphash(in, len, SEED.bytes, out, OUTPUT_LENGTH);
}



#ifdef TEST
#include "testing.h"


/* pthread start routine: run hash_init() and store its result in *data. */
static void *
hash_init_thread(void *data) {
	int *ret = (int *)data;
	*ret = hash_init();
	return NULL;
}

static int
test_hash_init(void) {
	int rc = 0;
	int ret;
	/*
	 * Whether this thread currently holds SEED_INIT_MUTEX, so that the
	 * cleanup at `out` never unlocks a mutex that is not held.
	 */
	bool locked = false;

	test_start("hash_init()");
	const uint8_t length = sizeof(SEED.bytes) / sizeof(SEED.bytes[0]);
	assert(length == 16);
	for (uint8_t i = 0; i < length; i++) {
		assert(SEED.bytes[i] == 0);
	}

	{
		testing("when is initialized we do nothing");

		SEED.initialized = true;
		assert(hash_init() == 0);
		for (uint8_t i = 0; i < length; i++) {
			assert(SEED.bytes[i] == 0);
		}

		test_ok();
	}
	{
		testing("do nothing when initialized while waiting for the lock");
		/*
		0. grab SEED_INIT_MUTEX, so the other thread blocks in hash_init()
		1. start with an uninitialized seed
		2. start the other thread
		3. set SEED.initialized
		4. release the lock
		5. join the other thread
		6. assert SEED.bytes is still empty
		*/

		ret = pthread_mutex_lock(&SEED_INIT_MUTEX); // 0
		if (ret) {
			logerr(
				"pthread_mutex_lock(&SEED_INIT_MUTEX): %s;\n",
				strerror(ret)
			);
			rc = -1;
			goto out;
		}
		locked = true;

		SEED.initialized = false; // 1

		pthread_t t;
		int t_rc = -1;
		ret = pthread_create(&t, NULL, hash_init_thread, &t_rc); // 2
		if (ret) {
			logerr("pthread_create(%s, %s, %s, %s): %s\n",
				"&t", "NULL", "hash_init", "&t_rc", strerror(ret));
			rc = -1;
			goto out;
		}

		SEED.initialized = true; // 3

		ret = pthread_mutex_unlock(&SEED_INIT_MUTEX); // 4
		locked = false;
		if (ret) {
			logerr(
				"pthread_mutex_unlock(&SEED_INIT_MUTEX): %s\n",
				strerror(ret)
			);
			rc = -1;
			goto out;
		}

		ret = pthread_join(t, NULL); // 5
		if (ret) {
			logerr("pthread_join(t, &rc): %s\n", strerror(ret));
			rc = -1;
			goto out;
		}
		if (t_rc) {
			logerr("hash_init() in pthread_t\n");
			rc = -1;
			goto out;
		}

		for (size_t i = 0; i < length; i++) {
			assert(SEED.bytes[i] == 0); // 6
		}

		test_ok();
	}
	{
		testing("competing threads can init the seed besides us");
		/*
		0. grab SEED_INIT_MUTEX, so both threads queue up on it
		1. start with an uninitialized seed
		2. start the other 2 threads
		3. release the lock
		4. join the other threads
		5. assert(SEED.initialized == true);
		*/

		ret = pthread_mutex_lock(&SEED_INIT_MUTEX); // 0
		if (ret) {
			logerr(
				"pthread_mutex_lock(&SEED_INIT_MUTEX): %s;\n",
				strerror(ret)
			);
			rc = -1;
			goto out;
		}
		locked = true;

		SEED.initialized = false; // 1

		pthread_t t1;
		int t1_rc = -1;
		ret = pthread_create(&t1, NULL, hash_init_thread, &t1_rc); // 2
		if (ret) {
			logerr("pthread_create(%s, %s, %s, %s): %s\n",
				"&t1", "NULL", "hash_init", "&t1_rc", strerror(ret));
			rc = -1;
			goto out;
		}

		pthread_t t2;
		int t2_rc = -1;
		ret = pthread_create(&t2, NULL, hash_init_thread, &t2_rc); // 2
		if (ret) {
			logerr("pthread_create(%s, %s, %s, %s): %s\n",
				"&t2", "NULL", "hash_init", "&t2_rc", strerror(ret));
			rc = -1;
			goto out;
		}

		ret = pthread_mutex_unlock(&SEED_INIT_MUTEX); // 3
		locked = false;
		if (ret) {
			logerr(
				"pthread_mutex_unlock(&SEED_INIT_MUTEX): %s\n",
				strerror(ret)
			);
			rc = -1;
			goto out;
		}

		ret = pthread_join(t1, NULL); // 4
		if (ret) {
			logerr("pthread_join(t1, &rc): %s\n", strerror(ret));
			rc = -1;
			goto out;
		}
		if (t1_rc) {
			logerr("hash_init() in pthread_t t1\n");
			rc = -1;
			goto out;
		}

		ret = pthread_join(t2, NULL); // 4
		if (ret) {
			logerr("pthread_join(t2, &rc): %s\n", strerror(ret));
			rc = -1;
			goto out;
		}
		if (t2_rc) {
			logerr("hash_init() in pthread_t t2\n");
			rc = -1;
			goto out;
		}

		assert(SEED.initialized == true); // 5

		test_ok();
	}

out:
	if (locked) {
		ret = pthread_mutex_unlock(&SEED_INIT_MUTEX);
		if (ret) {
			logerr(
				"pthread_mutex_unlock(&SEED_INIT_MUTEX): %s\n",
				strerror(ret)
			);
			rc = -1;
		}
	}
	return rc;
}

static void
test_hash(void) {
	test_start("hash()");

	{
		testing("we get the same hash given the same input");

		uint8_t out1[OUTPUT_LENGTH] = { 0 };
		uint8_t out2[OUTPUT_LENGTH] = { 0 };
		const char *const input_str = "the input";
		const uint8_t *const input = (const uint8_t *)input_str;
		const size_t input_length = strlen(input_str);

		hash(input_length, input, out1);
		hash(input_length, input, out2);
		for (int i = 0; i < OUTPUT_LENGTH; i++) {
			assert(out1[i] == out2[i]);
		}

		test_ok();
	}
	{
		testing("hashing twice overwrites whatever was in `out`");

		uint8_t out1[OUTPUT_LENGTH];
		uint8_t out2[OUTPUT_LENGTH];
		const char *const input_str = "another input";
		const uint8_t *const input = (const uint8_t *)input_str;
		const size_t input_length = strlen(input_str);

		hash(6, (const uint8_t *)"before", out1);
		hash(input_length, input, out1);
		hash(input_length, input, out2);
		for (int i = 0; i < OUTPUT_LENGTH; i++) {
			assert(out1[i] == out2[i]);
		}

		test_ok();
	}

	return;
}


int
main(void) {
	int rc = 0;

	if (test_hash_init()) {
		logerr("test_hash_init();\n");
		rc = -1;
		goto out;
	}
	test_hash();

out:
	return !!rc;
}
#endif