From b27d06b717a78f5d7509ac5681a12486a986d158 Mon Sep 17 00:00:00 2001 From: Danny Allen Date: Mon, 30 Mar 2020 14:32:17 -0700 Subject: [PATCH] [dvs] Add generic polling utility (#1233) Signed-off-by: Danny Allen --- tests/dvslib/dvs_common.py | 59 ++++++++++ tests/dvslib/dvs_database.py | 205 +++++++++++------------------------ 2 files changed, 121 insertions(+), 143 deletions(-) create mode 100644 tests/dvslib/dvs_common.py diff --git a/tests/dvslib/dvs_common.py b/tests/dvslib/dvs_common.py new file mode 100644 index 000000000000..fe2e77d2b724 --- /dev/null +++ b/tests/dvslib/dvs_common.py @@ -0,0 +1,59 @@ +""" + dvs_common contains common infrastructure for writing tests for the + virtual switch. +""" + +import collections +import time + +_PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict') + +class PollingConfig(_PollingConfig): + """ + PollingConfig provides parameters that are used to control the behavior + for polling functions. + + Params: + polling_interval (int): How often to poll, in seconds. + timeout (int): The maximum amount of time to wait, in seconds. + strict (bool): If the strict flag is set, reaching the timeout + will cause tests to fail (e.g. assert False) + """ + + pass + +def wait_for_result(polling_function, polling_config): + """ + wait_for_result will periodically run `polling_function` + using the parameters described in `polling_config` and return the + output of the polling function. + + Args: + polling_config (PollingConfig): The parameters to use to poll + the db. + polling_function (Callable[[], (bool, Any)]): The function being + polled. The function takes no arguments and must return a + status which indicates if the function was succesful or + not, as well as some return value. + + Returns: + Any: The output of the polling function, if it is succesful, + None otherwise. + """ + if polling_config.polling_interval == 0: + iterations = 1 + else: + iterations = int(polling_config.timeout // polling_config.polling_interval) + 1 + + for _ in range(iterations): + (status, result) = polling_function() + + if status: + return result + + time.sleep(polling_config.polling_interval) + + if polling_config.strict: + assert False + + return None diff --git a/tests/dvslib/dvs_database.py b/tests/dvslib/dvs_database.py index 5748c47f1159..ea667cd27c7f 100644 --- a/tests/dvslib/dvs_database.py +++ b/tests/dvslib/dvs_database.py @@ -2,21 +2,8 @@ dvs_database contains utilities for interacting with redis when writing tests for the virtual switch. """ -from __future__ import print_function - -import time -import collections - from swsscommon import swsscommon - - -# PollingConfig provides parameters that are used to control polling behavior -# when accessing redis: -# - polling_interval: how often to check for updates in redis -# - timeout: the max amount of time to wait for updates in redis -# - strict: if the strict flag is set, failure to receive updates will cause -# the polling method to cause tests to fail (e.g. assert False) -PollingConfig = collections.namedtuple('PollingConfig', 'polling_interval timeout strict') +from dvslib.dvs_common import wait_for_result, PollingConfig class DVSDatabase(object): @@ -56,27 +43,27 @@ def create_entry(self, table_name, key, entry): formatted_entry = swsscommon.FieldValuePairs(entry.items()) table.set(key, formatted_entry) - def wait_for_entry(self, table_name, key, - polling_config=DEFAULT_POLLING_CONFIG): + def get_entry(self, table_name, key): """ - Gets the entry stored at `key` in the specified table. This method - will wait for the entry to exist. + Gets the entry stored at `key` in the specified table. Args: table_name (str): The name of the table where the entry is stored. key (str): The key that maps to the entry being retrieved. - polling_config (PollingConfig): The parameters to use to poll - the db. Returns: Dict[str, str]: The entry stored at `key`. If no entry is found, then an empty Dict will be returned. - """ - access_function = self._get_entry_access_function(table_name, key, True) - return self._db_poll(polling_config, access_function) + table = swsscommon.Table(self.db_connection, table_name) + (status, fv_pairs) = table.get(key) + + if not status: + return {} + + return dict(fv_pairs) def delete_entry(self, table_name, key): """ @@ -91,6 +78,49 @@ def delete_entry(self, table_name, key): table = swsscommon.Table(self.db_connection, table_name) table._del(key) # pylint: disable=protected-access + def get_keys(self, table_name): + """ + Gets all of the keys stored in the specified table. + + Args: + table_name (str): The name of the table from which to fetch + the keys. + + Returns: + List[str]: The keys stored in the table. If no keys are found, + then an empty List will be returned. + """ + + table = swsscommon.Table(self.db_connection, table_name) + keys = table.getKeys() + + return keys if keys else [] + + def wait_for_entry(self, table_name, key, + polling_config=DEFAULT_POLLING_CONFIG): + """ + Gets the entry stored at `key` in the specified table. This method + will wait for the entry to exist. + + Args: + table_name (str): The name of the table where the entry is + stored. + key (str): The key that maps to the entry being retrieved. + polling_config (PollingConfig): The parameters to use to poll + the db. + + Returns: + Dict[str, str]: The entry stored at `key`. If no entry is found, + then an empty Dict will be returned. + + """ + + def _access_function(): + fv_pairs = self.get_entry(table_name, key) + return (bool(fv_pairs), fv_pairs) + + return wait_for_result(_access_function, polling_config) + def wait_for_empty_entry(self, table_name, key, @@ -109,8 +139,11 @@ def wait_for_empty_entry(self, bool: True if no entry exists at `key`, False otherwise. """ - access_function = self._get_entry_access_function(table_name, key, False) - return not self._db_poll(polling_config, access_function) + def _access_function(): + fv_pairs = self.get_entry(table_name, key) + return (not fv_pairs, fv_pairs) + + return wait_for_result(_access_function, polling_config) def wait_for_n_keys(self, table_name, @@ -133,122 +166,8 @@ def wait_for_n_keys(self, then an empty List will be returned. """ - access_function = self._get_keys_access_function(table_name, num_keys) - return self._db_poll(polling_config, access_function) - - def _get_keys_access_function(self, table_name, num_keys): - """ - Generates an access function to check for `num_keys` in the given - table and return the list of keys if successful. - - Args: - table_name (str): The name of the table from which to fetch - the keys. - num_keys (int): The number of keys to check for in the table. - If this is set to None, then this function will just return - whatever keys are in the table. - - Returns: - Callable([[], (bool, List[str])]): A function that can be - called to access the database. - - If `num_keys` keys are found in the given table, or left - unspecified, then the function will return True along with - the list of keys that were found. Otherwise, the function will - return False and some undefined list of keys. - """ - - table = swsscommon.Table(self.db_connection, table_name) - - def _accessor(): - keys = table.getKeys() - if not keys: - keys = [] - - if not num_keys and num_keys != 0: - status = True - else: - status = len(keys) == num_keys - - return (status, keys) - - return _accessor - - def _get_entry_access_function(self, table_name, key, expect_entry): - """ - Generates an access function to check for existence of an entry - at `key` and return it if successful. - - Args: - table_name (str): The name of the table from which to fetch - the entry. - key (str): The key that maps to the entry being retrieved. - expect_entry (bool): Whether or not we expect to see an entry - at `key`. - - Returns: - Callable([[], (bool, Dict[str, str])]): A function that can be - called to access the database. - - If `expect_entry` is set and an entry is found, then the - function will return True along with the entry that was found. - - If `expect_entry` is not set and no entry is found, then the - function will return True along with an empty Dict. - - In all other cases, the function will return False with some - undefined Dict. - """ - - table = swsscommon.Table(self.db_connection, table_name) - - def _accessor(): - (status, fv_pairs) = table.get(key) - - status = expect_entry == status - - if fv_pairs: - entry = dict(fv_pairs) - else: - entry = {} - - return (status, entry) - - return _accessor - - @staticmethod - def _db_poll(polling_config, access_function): - """ - _db_poll will periodically run `access_function` on the database - using the parameters described in `polling_config` and return the - output of the access function. - - Args: - polling_config (PollingConfig): The parameters to use to poll - the db. - access_function (Callable[[], (bool, Any)]): The function used - for polling the db. Note that the function must return a - status which indicates if the function was succesful or - not, as well as some return value. - - Returns: - Any: The output of the access function, if it is succesful, - None otherwise. - """ - if polling_config.polling_interval == 0: - iterations = 1 - else: - iterations = int(polling_config.timeout // polling_config.polling_interval) + 1 - - for _ in range(iterations): - (status, result) = access_function() - - if status: - return result - - time.sleep(polling_config.polling_interval) - - if polling_config.strict: - assert False + def _access_function(): + keys = self.get_keys(table_name) + return (len(keys) == num_keys, keys) - return None + return wait_for_result(_access_function, polling_config)