Skip to content

Commit

Permalink
feat: simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
tedbee authored and karootplx committed Dec 9, 2024
1 parent ebb72f5 commit ed110e3
Show file tree
Hide file tree
Showing 8 changed files with 482 additions and 13 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/docker_build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ on:
- dev
- staging
- main
- simulator

jobs:
docker_publish:
Expand All @@ -28,7 +29,7 @@ jobs:
echo "BRANCH_NAME=$SANITIZED_BRANCH_NAME" >> $GITHUB_ENV
- name: Build and Push Docker Image with Branch Tag
if: github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/staging' || github.ref == 'refs/heads/main'
if: github.ref == 'refs/heads/dev' || github.ref == 'refs/heads/staging' || github.ref == 'refs/heads/main' || github.ref == 'refs/heads/simulator'
uses: macbre/push-to-ghcr@master
with:
image_name: ${{ github.repository }}
Expand Down
24 changes: 16 additions & 8 deletions commons/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,26 @@ class ObjectManager:

@classmethod
def get_miner(cls):
from neurons.miner import Miner

if cls._miner is None:
cls._miner = Miner()
if get_config().simulation:
from simulator.miner import MinerSim
if cls._miner is None:
cls._miner = MinerSim()
else:
from neurons.miner import Miner
if cls._miner is None:
cls._miner = Miner()
return cls._miner

@classmethod
def get_validator(cls):
from neurons.validator import Validator

if cls._validator is None:
cls._validator = Validator()
if get_config().simulation:
from simulator.validator import ValidatorSim
if cls._validator is None:
cls._validator = ValidatorSim()
else:
from neurons.validator import Validator
if cls._validator is None:
cls._validator = Validator()
return cls._validator

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion dojo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,4 @@ def get_dojo_api_base_url() -> str:
if base_url is None:
raise ValueError("DOJO_API_BASE_URL is not set in the environment.")

return base_url
return base_url
12 changes: 12 additions & 0 deletions dojo/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,18 @@ def add_args(parser):
help="Whether to run in fast mode, for developers to test locally.",
)

parser.add_argument(
"--simulation",
action="store_true",
help="Whether to run the validator in simulation mode",
)

parser.add_argument(
"--simulation_bad_miner",
action="store_true",
help="Set miner simluation to a bad one",
)

epoch_length = 100
known_args, _ = parser.parse_known_args()
if known_args := vars(known_args):
Expand Down
27 changes: 24 additions & 3 deletions entrypoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ if [ "$1" = 'miner' ]; then
echo "SUBTENSOR_ENDPOINT: ${SUBTENSOR_ENDPOINT}"
echo "NETUID: ${NETUID}"

EXTRA_ARGS=""
if [ "${SIMULATION}" = "true" ]; then
EXTRA_ARGS="${EXTRA_ARGS} --simulation"
fi
if [ "${FAST_MODE}" = "true" ]; then
EXTRA_ARGS="${EXTRA_ARGS} --fast_mode"
fi
if [ "${SIMULATION_BAD_MINER}" = "true" ]; then
EXTRA_ARGS="${EXTRA_ARGS} --simulation_bad_miner"
fi

python main_miner.py \
--netuid ${NETUID} \
--subtensor.network ${SUBTENSOR_NETWORK} \
Expand All @@ -29,7 +40,8 @@ if [ "$1" = 'miner' ]; then
--wallet.name ${WALLET_COLDKEY} \
--wallet.hotkey ${WALLET_HOTKEY} \
--axon.port ${AXON_PORT} \
--neuron.type miner
--neuron.type miner \
${EXTRA_ARGS}
fi

# If the first argument is 'validator', run the validator script
Expand All @@ -43,6 +55,14 @@ if [ "$1" = 'validator' ]; then
echo "NETUID: ${NETUID}"
echo "WANDB_PROJECT_NAME: ${WANDB_PROJECT_NAME}"

EXTRA_ARGS=""
if [ "${SIMULATION}" = "true" ]; then
EXTRA_ARGS="${EXTRA_ARGS} --simulation"
fi
if [ "${FAST_MODE}" = "true" ]; then
EXTRA_ARGS="${EXTRA_ARGS} --fast_mode"
fi

python main_validator.py \
--netuid ${NETUID} \
--subtensor.network ${SUBTENSOR_NETWORK} \
Expand All @@ -51,5 +71,6 @@ if [ "$1" = 'validator' ]; then
--wallet.name ${WALLET_COLDKEY} \
--wallet.hotkey ${WALLET_HOTKEY} \
--neuron.type validator \
--wandb.project_name ${WANDB_PROJECT_NAME}
fi
--wandb.project_name ${WANDB_PROJECT_NAME} \
${EXTRA_ARGS}
fi
Empty file added simulator/__init__.py
Empty file.
173 changes: 173 additions & 0 deletions simulator/miner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import os
import redis
import traceback
import asyncio
import random
import json
from datetime import datetime, timezone
from dojo.utils.config import get_config
from bittensor.btlogging import logging as logger
from neurons.miner import Miner
from dojo.protocol import (
FeedbackRequest,
TaskResultRequest,
TaskResult,
Result
)
from commons.utils import get_new_uuid


class MinerSim(Miner):
def __init__(self):
super().__init__()
try:
# Initialize Redis connection
host = os.getenv("REDIS_HOST", "localhost")
port = int(os.getenv("REDIS_PORT", 6379))
self.redis_client = redis.Redis(
host=host,
port=port,
db=0,
decode_responses=True
)
logger.info("Redis connection established")

self._configure_simulation()

self.is_bad_miner = get_config().simulation_bad_miner
logger.info(f"Miner role set to: {'bad' if self.is_bad_miner else 'good'}")

logger.info("Starting Miner Simulator")
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
raise

def _configure_simulation(self):
"""Configure simulation parameters with environment variables or defaults."""
self.response_behaviors = {
'normal': float(os.getenv("SIM_NORMAL_RESP_PROB", 0.8)),
'no_response': float(os.getenv("SIM_NO_RESP_PROB", 0.1)),
'timeout': float(os.getenv("SIM_TIMEOUT_PROB", 0.1))
}

async def forward_feedback_request(self, synapse: FeedbackRequest) -> FeedbackRequest:
try:
# Validate that synapse, dendrite, dendrite.hotkey, and response are not None
if not synapse or not synapse.dendrite or not synapse.dendrite.hotkey:
logger.error("Invalid synapse: dendrite or dendrite.hotkey is None.")
return synapse

if not synapse.completion_responses:
logger.error("Invalid synapse: response field is None.")
return synapse

# Empty out completion response since not needed in simulator
new_synapse = synapse.model_copy(deep=True)
new_synapse.completion_responses = []

synapse.dojo_task_id = synapse.request_id
self.hotkey_to_request[synapse.dendrite.hotkey] = synapse

redis_key = f"feedback:{synapse.request_id}"
self.redis_client.set(
redis_key,
new_synapse.model_dump_json(),
ex=86400 # expire after 24 hours
)
logger.info(f"Stored feedback request {synapse.request_id}")

synapse.ground_truth = {}
return synapse

except Exception as e:
logger.error(f"Error handling FeedbackRequest: {e}")
traceback.print_exc()
return synapse

async def forward_task_result_request(self, synapse: TaskResultRequest) -> TaskResultRequest | None:
try:
logger.info(f"Received TaskResultRequest for task id: {synapse.task_id}")
if not synapse or not synapse.task_id:
logger.error("Invalid TaskResultRequest: missing task_id")
return None

# Simulate different response behaviors
behavior = self._get_response_behavior()

if behavior in ['no_response', 'timeout']:
logger.debug(f"Simulating {behavior} for task {synapse.task_id}")
if behavior == 'timeout':
await asyncio.sleep(30)
return None

redis_key = f"feedback:{synapse.task_id}"
request_data = self.redis_client.get(redis_key)

request_dict = json.loads(request_data) if request_data else None
feedback_request = FeedbackRequest(**request_dict) if request_dict else None

if not feedback_request:
logger.debug(f"No task result found for task id: {synapse.task_id}")
return None

current_time = datetime.now(timezone.utc).isoformat()

task_results = []
for criteria_type in feedback_request.criteria_types:
result = Result(
type=criteria_type.type,
value=self._generate_scores(feedback_request.ground_truth)
)

task_result = TaskResult(
id=get_new_uuid(),
status='COMPLETED',
created_at=current_time,
updated_at=current_time,
result_data=[result],
worker_id=get_new_uuid(),
task_id=synapse.task_id
)
task_results.append(task_result)

synapse.task_results = task_results
logger.info(f"TaskResultRequest: {synapse}")

self.redis_client.delete(redis_key)
logger.debug(f"Processed task result for task {synapse.task_id}")

return synapse

except Exception as e:
traceback.print_exc()
logger.error(f"Error handling TaskResultRequest: {e}")
return None

def _get_response_behavior(self) -> str:
"""Determine the response behavior based on configured probabilities."""
return random.choices(
list(self.response_behaviors.keys()),
weights=list(self.response_behaviors.values())
)[0]

def _generate_scores(self, ground_truth: dict) -> dict:
scores = {}

for k, v in ground_truth.items():
if self.is_bad_miner:
deviation = random.randint(-5, 5)
else:
deviation = random.randint(-2, 2)
random_score = max(1, min(10, v + deviation))
score = int((random_score / (10 - 1)) * (100 - 1) + 1)
scores[k] = score

return scores

# def __del__(self):
# """Cleanup Redis connection on object destruction"""
# try:
# self.redis_client.close()
# logger.info("Redis connection closed")
# except:
# pass
Loading

0 comments on commit ed110e3

Please sign in to comment.