diff --git a/common/configdb.h b/common/configdb.h index f8b77ee69..012370748 100644 --- a/common/configdb.h +++ b/common/configdb.h @@ -57,6 +57,7 @@ class ConfigDBConnector_Native : public SonicV2Connector_Native ## Note: callback is difficult to implement by SWIG C++, so keep in python self.handlers = {} + self.fire_init_data = {} @property def KEY_SEPARATOR(self): @@ -71,10 +72,28 @@ class ConfigDBConnector_Native : public SonicV2Connector_Native return self.getDbName() ## Note: callback is difficult to implement by SWIG C++, so keep in python - def listen(self): - ## Start listen Redis keyspace events and will trigger corresponding handlers when content of a table changes. + def listen(self, init_data_handler=None): + ## Start listen Redis keyspace event. Pass a callback function to `init` to handle initial table data. self.pubsub = self.get_redis_client(self.db_name).pubsub() self.pubsub.psubscribe("__keyspace@{}__:*".format(self.get_dbid(self.db_name))) + + # Build a cache of data for all subscribed tables that will recieve the initial table data so we dont send duplicate event notifications + init_data = {tbl: self.get_table(tbl) for tbl in self.handlers if init_data_handler or self.fire_init_data[tbl]} + + # Function to send initial data as series of updates through individual table callback handlers + def load_data(tbl, data): + if self.fire_init_data[tbl]: + for row, x in data.items(): + self.__fire(tbl, row, x) + return False + return True + + init_callback_data = {tbl: data for tbl, data in init_data.items() if load_data(tbl, data)} + + # Pass all initial data that we DID NOT send as updates to handlers through the init callback if provided by caller + if init_data_handler: + init_data_handler(init_callback_data) + while True: item = self.pubsub.listen_message() if item['type'] == 'pmessage': @@ -84,6 +103,12 @@ class ConfigDBConnector_Native : public SonicV2Connector_Native if table in self.handlers: client = self.get_redis_client(self.db_name) data = self.raw_to_typed(client.hgetall(key)) + if table in init_data and row in init_data[table]: + cache_hit = init_data[table][row] == data + del init_data[table][row] + if not init_data[table]: + del init_data[table] + if cache_hit: continue self.__fire(table, row, data) except ValueError: pass #Ignore non table-formated redis entries @@ -153,8 +178,9 @@ class ConfigDBConnector_Native : public SonicV2Connector_Native handler = self.handlers[table] handler(table, key, data) - def subscribe(self, table, handler): + def subscribe(self, table, handler, fire_init_data=False): self.handlers[table] = handler + self.fire_init_data[table] = fire_init_data def unsubscribe(self, table): if table in self.handlers: diff --git a/tests/test_redis_ut.py b/tests/test_redis_ut.py index 1bebd2e04..f7634c867 100644 --- a/tests/test_redis_ut.py +++ b/tests/test_redis_ut.py @@ -1,6 +1,7 @@ import os import time import pytest +import multiprocessing from threading import Thread from pympler.tracker import SummaryTracker from swsscommon import swsscommon @@ -552,6 +553,50 @@ def thread_coming_entry(): config_db.unsubscribe(table_name) assert table_name not in config_db.handlers +def test_ConfigDBInit(): + table_name_1 = 'TEST_TABLE_1' + table_name_2 = 'TEST_TABLE_2' + test_key = 'key1' + test_data = {'field1': 'value1'} + test_data_update = {'field1': 'value2'} + + manager = multiprocessing.Manager() + ret_data = manager.dict() + + def test_handler(table, key, data, ret): + ret[table] = {key: data} + + def test_init_handler(data, ret): + ret.update(data) + + def thread_listen(ret): + config_db = ConfigDBConnector() + config_db.connect(wait_for_init=False) + + config_db.subscribe(table_name_1, lambda table, key, data: test_handler(table, key, data, ret), + fire_init_data=False) + config_db.subscribe(table_name_2, lambda table, key, data: test_handler(table, key, data, ret), + fire_init_data=True) + + config_db.listen(init_data_handler=lambda data: test_init_handler(data, ret)) + + config_db = ConfigDBConnector() + config_db.connect(wait_for_init=False) + client = config_db.get_redis_client(config_db.CONFIG_DB) + client.flushdb() + + # Init table data + config_db.set_entry(table_name_1, test_key, test_data) + config_db.set_entry(table_name_2, test_key, test_data) + + thread = multiprocessing.Process(target=thread_listen, args=(ret_data,)) + thread.start() + time.sleep(5) + thread.terminate() + + assert ret_data[table_name_1] == {test_key: test_data} + assert ret_data[table_name_2] == {test_key: test_data} + def test_DBConnectFailure(): """ Verify that a DB connection failure will not cause a process abort