diff --git a/neurons/validator.py b/neurons/validator.py index 3477ddd2..0827ee09 100644 --- a/neurons/validator.py +++ b/neurons/validator.py @@ -675,27 +675,51 @@ async def set_weights(self): # based on uids, adjusted based on metagraph during `resync_metagraph` uids = torch.tensor(list(range(len(safe_normalized_weights)))) - min_allowed_weights = self.subtensor.min_allowed_weights(self.config.netuid) - max_weight_limit = self.subtensor.max_weight_limit(self.config.netuid) - logger.debug(f"min_allowed_weights: {min_allowed_weights}") - logger.debug(f"max_weight_limit: {max_weight_limit}") + ( + final_uids, + final_weights, + ) = bt.utils.weight_utils.process_weights_for_netuid( # type: ignore + uids=uids, + weights=safe_normalized_weights, + netuid=self.config.netuid, + subtensor=self.subtensor, + metagraph=self.metagraph, + ) + logger.debug(f"weights:\n{safe_normalized_weights}") + logger.debug(f"uids:\n{uids}") - logger.debug(f"weights: {safe_normalized_weights}") - logger.debug(f"uids: {uids}") + _terminal_plot( + f"pre-processed weights, block: {self.block}", + safe_normalized_weights.numpy(), + ) - # dependent on underlying `set_weights` call - result, message = await self.set_weights_in_thread( - uids, safe_normalized_weights + logger.debug(f"final weights:\n{final_weights}") + logger.debug(f"final uids:\n{final_uids}") + + _terminal_plot( + f"final weights, block: {self.block}", + final_weights.numpy(), ) - if not result: - logger.error(f"Failed to set weights: {message}") + + # dependent on underlying `set_weights` call + try: + result, message = await asyncio.wait_for( + self._set_weights(final_uids, final_weights), timeout=90 + ) + if not result: + logger.error(f"Failed to set weights: {message}") + return + + logger.success(f"Set weights successfully: {message}") + except asyncio.TimeoutError: + logger.error("Setting weights timed out after 90 seconds") return - logger.success(f"Set weights successfully: {message}") return - async def set_weights_in_thread(self, uids: torch.Tensor, weights: torch.Tensor): - """Wrapper function to set weights in a separate thread + async def _set_weights(self, uids: torch.Tensor, weights: torch.Tensor): + """Wrapper function to set weights so we can ensure set weights happens + within a timeout. Args: uids (torch.Tensor): uids to set weights for @@ -704,76 +728,52 @@ async def set_weights_in_thread(self, uids: torch.Tensor, weights: torch.Tensor) Returns: tuple[bool, str]: Returns the result of _set_weights function """ - logger.trace("Attempting to set weights in another thread") - - async def _set_weights() -> tuple[bool, str]: - """LOCAL FUNCTION to set weights, we pass in a lock because of how - we are calling this function from the main thread, sending it - to a separate thread to avoid blocking the main thread, so the lock - MUST be acquired by the separate thread. - - Returns: - tuple[bool, str]: Returns a tuple of a boolean and a string - - boolean: True if weights were set successfully, False otherwise - - string: Message indicating the result of set weights - """ - max_attempts = 5 - attempt = 0 - result = False - while attempt < max_attempts and not result: - try: - logger.debug( - f"Set weights attempt {attempt+1}/{max_attempts} at block: {self.block},time: {time.time()}" - ) - try: - await asyncio.wait_for( - self._ensure_subtensor_ws_connected(), timeout=10 - ) - except asyncio.TimeoutError: - pass - - result, message = self.subtensor.set_weights( - wallet=self.wallet, - netuid=self.config.netuid, # type: ignore - uids=uids.tolist(), - weights=weights.tolist(), - wait_for_finalization=True, - wait_for_inclusion=False, - version_key=self.spec_version, - max_retries=1, - ) - if result: - logger.success(f"Set weights successfully: {message}") - return result, message - - raise SetWeightsFailed( - f"Failed to set weights with message:{message}" - ) - except Exception: - logger.warning( - f"Failed to set weights with attempt {attempt+1}/{max_attempts} due to: {message}" + max_attempts = 5 + attempt = 0 + result = False + while attempt < max_attempts and not result: + try: + logger.debug( + f"Set weights attempt {attempt+1}/{max_attempts} at block: {self.block},time: {time.time()}" + ) + try: + await asyncio.wait_for( + self._ensure_subtensor_ws_connected(), timeout=10 ) + except asyncio.TimeoutError: + pass - if attempt == max_attempts: - logger.error("Max attempts reached. Could not set weights.") - return False, "Max attempts reached" + result, message = self.subtensor.set_weights( + wallet=self.wallet, + netuid=self.config.netuid, # type: ignore + uids=uids.tolist(), + weights=weights.tolist(), + wait_for_finalization=True, + wait_for_inclusion=False, + version_key=self.spec_version, + max_retries=1, + ) + if result: + logger.success(f"Set weights successfully: {message}") + return result, message - await asyncio.sleep(12) - finally: - attempt += 1 + raise SetWeightsFailed(f"Failed to set weights with message:{message}") - return False, "Max attempts reached" + except Exception: + logger.warning( + f"Failed to set weights with attempt {attempt+1}/{max_attempts} due to: {message}" + ) - logger.trace("Submitting callable func to executor") + if attempt == max_attempts: + logger.error("Max attempts reached. Could not set weights.") + return False, "Max attempts reached" - try: - result, message = await asyncio.wait_for(_set_weights(), timeout=90) - except asyncio.TimeoutError: - logger.error("Setting weights timed out after 90 seconds") - return False, "Failed to set weights within time limit" + await asyncio.sleep(12) + finally: + attempt += 1 - return result, message + return False, "Max attempts reached" async def resync_metagraph(self): """Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""