Skip to content

Commit

Permalink
Merge pull request #50 from tensorplex-labs/fix/reuse-weight-utils
Browse files Browse the repository at this point in the history
fix/reuse-weight-utils
  • Loading branch information
jarvis8x7b authored Oct 27, 2024
2 parents cb7be96 + 2e45a0a commit 8cd6ad2
Showing 1 changed file with 76 additions and 76 deletions.
152 changes: 76 additions & 76 deletions neurons/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,27 +675,51 @@ async def set_weights(self):
# based on uids, adjusted based on metagraph during `resync_metagraph`
uids = torch.tensor(list(range(len(safe_normalized_weights))))

min_allowed_weights = self.subtensor.min_allowed_weights(self.config.netuid)
max_weight_limit = self.subtensor.max_weight_limit(self.config.netuid)
logger.debug(f"min_allowed_weights: {min_allowed_weights}")
logger.debug(f"max_weight_limit: {max_weight_limit}")
(
final_uids,
final_weights,
) = bt.utils.weight_utils.process_weights_for_netuid( # type: ignore
uids=uids,
weights=safe_normalized_weights,
netuid=self.config.netuid,
subtensor=self.subtensor,
metagraph=self.metagraph,
)
logger.debug(f"weights:\n{safe_normalized_weights}")
logger.debug(f"uids:\n{uids}")

logger.debug(f"weights: {safe_normalized_weights}")
logger.debug(f"uids: {uids}")
_terminal_plot(
f"pre-processed weights, block: {self.block}",
safe_normalized_weights.numpy(),
)

# dependent on underlying `set_weights` call
result, message = await self.set_weights_in_thread(
uids, safe_normalized_weights
logger.debug(f"final weights:\n{final_weights}")
logger.debug(f"final uids:\n{final_uids}")

_terminal_plot(
f"final weights, block: {self.block}",
final_weights.numpy(),
)
if not result:
logger.error(f"Failed to set weights: {message}")

# dependent on underlying `set_weights` call
try:
result, message = await asyncio.wait_for(
self._set_weights(final_uids, final_weights), timeout=90
)
if not result:
logger.error(f"Failed to set weights: {message}")
return

logger.success(f"Set weights successfully: {message}")
except asyncio.TimeoutError:
logger.error("Setting weights timed out after 90 seconds")
return

logger.success(f"Set weights successfully: {message}")
return

async def set_weights_in_thread(self, uids: torch.Tensor, weights: torch.Tensor):
"""Wrapper function to set weights in a separate thread
async def _set_weights(self, uids: torch.Tensor, weights: torch.Tensor):
"""Wrapper function to set weights so we can ensure set weights happens
within a timeout.
Args:
uids (torch.Tensor): uids to set weights for
Expand All @@ -704,76 +728,52 @@ async def set_weights_in_thread(self, uids: torch.Tensor, weights: torch.Tensor)
Returns:
tuple[bool, str]: Returns the result of _set_weights function
"""
logger.trace("Attempting to set weights in another thread")

async def _set_weights() -> tuple[bool, str]:
"""LOCAL FUNCTION to set weights, we pass in a lock because of how
we are calling this function from the main thread, sending it
to a separate thread to avoid blocking the main thread, so the lock
MUST be acquired by the separate thread.
Returns:
tuple[bool, str]: Returns a tuple of a boolean and a string
- boolean: True if weights were set successfully, False otherwise
- string: Message indicating the result of set weights
"""
max_attempts = 5
attempt = 0
result = False
while attempt < max_attempts and not result:
try:
logger.debug(
f"Set weights attempt {attempt+1}/{max_attempts} at block: {self.block},time: {time.time()}"
)
try:
await asyncio.wait_for(
self._ensure_subtensor_ws_connected(), timeout=10
)
except asyncio.TimeoutError:
pass

result, message = self.subtensor.set_weights(
wallet=self.wallet,
netuid=self.config.netuid, # type: ignore
uids=uids.tolist(),
weights=weights.tolist(),
wait_for_finalization=True,
wait_for_inclusion=False,
version_key=self.spec_version,
max_retries=1,
)
if result:
logger.success(f"Set weights successfully: {message}")
return result, message

raise SetWeightsFailed(
f"Failed to set weights with message:{message}"
)

except Exception:
logger.warning(
f"Failed to set weights with attempt {attempt+1}/{max_attempts} due to: {message}"
max_attempts = 5
attempt = 0
result = False
while attempt < max_attempts and not result:
try:
logger.debug(
f"Set weights attempt {attempt+1}/{max_attempts} at block: {self.block},time: {time.time()}"
)
try:
await asyncio.wait_for(
self._ensure_subtensor_ws_connected(), timeout=10
)
except asyncio.TimeoutError:
pass

if attempt == max_attempts:
logger.error("Max attempts reached. Could not set weights.")
return False, "Max attempts reached"
result, message = self.subtensor.set_weights(
wallet=self.wallet,
netuid=self.config.netuid, # type: ignore
uids=uids.tolist(),
weights=weights.tolist(),
wait_for_finalization=True,
wait_for_inclusion=False,
version_key=self.spec_version,
max_retries=1,
)
if result:
logger.success(f"Set weights successfully: {message}")
return result, message

await asyncio.sleep(12)
finally:
attempt += 1
raise SetWeightsFailed(f"Failed to set weights with message:{message}")

return False, "Max attempts reached"
except Exception:
logger.warning(
f"Failed to set weights with attempt {attempt+1}/{max_attempts} due to: {message}"
)

logger.trace("Submitting callable func to executor")
if attempt == max_attempts:
logger.error("Max attempts reached. Could not set weights.")
return False, "Max attempts reached"

try:
result, message = await asyncio.wait_for(_set_weights(), timeout=90)
except asyncio.TimeoutError:
logger.error("Setting weights timed out after 90 seconds")
return False, "Failed to set weights within time limit"
await asyncio.sleep(12)
finally:
attempt += 1

return result, message
return False, "Max attempts reached"

async def resync_metagraph(self):
"""Resyncs the metagraph and updates the hotkeys and moving averages based on the new metagraph."""
Expand Down

0 comments on commit 8cd6ad2

Please sign in to comment.