diff --git a/sys/net/application_layer/dns/cache.c b/sys/net/application_layer/dns/cache.c index deb3b82e85d1..d6881e68664a 100644 --- a/sys/net/application_layer/dns/cache.c +++ b/sys/net/application_layer/dns/cache.c @@ -17,11 +17,12 @@ #include "bitfield.h" #include "checksum/fletcher32.h" -#include "time_units.h" +#include "mutex.h" #include "net/af.h" #include "net/dns/cache.h" #include "net/ipv4/addr.h" #include "net/ipv6/addr.h" +#include "time_units.h" #include "ztimer.h" #define ENABLE_DEBUG 0 @@ -39,6 +40,7 @@ static struct dns_cache_entry { #endif } addr; } cache[CONFIG_DNS_CACHE_SIZE]; +static mutex_t cache_mutex = MUTEX_INIT; #if IS_ACTIVE(CONFIG_DNS_CACHE_A) && IS_ACTIVE(CONFIG_DNS_CACHE_AAAA) BITFIELD(cache_is_v6, CONFIG_DNS_CACHE_SIZE); @@ -118,10 +120,12 @@ static uint32_t _hash(const void *data, size_t len) int dns_cache_query(const char *domain_name, void *addr_out, int family) { + int res = 0; uint32_t now = ztimer_now(ZTIMER_MSEC) / MS_PER_SEC; uint32_t hash = _hash(domain_name, strlen(domain_name)); uint8_t addr_len = _addr_len(family); + mutex_lock(&cache_mutex); for (unsigned i = 0; i < CONFIG_DNS_CACHE_SIZE; ++i) { /* empty slot */ if (_is_empty(i)) { @@ -137,12 +141,15 @@ int dns_cache_query(const char *domain_name, void *addr_out, int family) if (cache[i].hash == hash && (!addr_len || addr_len == _get_len(i))) { DEBUG("dns_cache[%u] hit\n", i); memcpy(addr_out, &cache[i].addr, _get_len(i)); - return _get_len(i); + res = _get_len(i); + break; } } - DEBUG("dns_cache miss\n"); - - return 0; + if (res == 0) { + DEBUG("dns_cache miss\n"); + } + mutex_unlock(&cache_mutex); + return res; } static void _add_entry(uint8_t i, uint32_t hash, const void *addr_out, @@ -166,15 +173,16 @@ void dns_cache_add(const char *domain_name, const void *addr_out, assert(addr_len == 4 || addr_len == 16); DEBUG("dns_cache: lifetime of %s is %"PRIu32" s\n", domain_name, ttl); + mutex_lock(&cache_mutex); for (unsigned i = 0; i < CONFIG_DNS_CACHE_SIZE; ++i) { if (now > cache[i].expires || _is_empty(i)) { _add_entry(i, hash, addr_out, addr_len, now + ttl); - return; + goto exit; } if (cache[i].hash == hash && _get_len(i) == addr_len) { DEBUG("dns_cache[%u] update ttl\n", i); cache[i].expires = now + ttl; - return; + goto exit; } uint32_t _ttl = cache[i].expires - now; if (_ttl < oldest) { @@ -187,4 +195,6 @@ void dns_cache_add(const char *domain_name, const void *addr_out, DEBUG("dns_cache: evict first entry to expire\n"); _add_entry(idx, hash, addr_out, addr_len, now + ttl); } +exit: + mutex_unlock(&cache_mutex); }