Skip to content

Commit

Permalink
feat: auto discovery of indexer-service in k8s
Browse files Browse the repository at this point in the history
Signed-off-by: Alexis Asseman <alexis@semiotic.ai>
  • Loading branch information
aasseman committed Apr 4, 2023
1 parent db30582 commit c88d9df
Show file tree
Hide file tree
Showing 9 changed files with 1,526 additions and 1,155 deletions.
21 changes: 19 additions & 2 deletions autoagora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,29 @@ def init_config(argv: Optional[Sequence[str]] = None):
#
# Query volume metrics
#
argparser.add_argument(

indexer_service_metrics_endpoint_group = argparser.add_argument_group(
"Indexer-service metrics endpoint. Exactly one argument required"
)
indexer_service_metrics_endpoint_exclusive_group = (
indexer_service_metrics_endpoint_group.add_mutually_exclusive_group(
required=True
)
)
indexer_service_metrics_endpoint_exclusive_group.add_argument(
"--indexer-service-metrics-endpoint",
env_var="INDEXER_SERVICE_METRICS_ENDPOINT",
required=True,
help="HTTP endpoint for the indexer-service metrics. Can be a comma-separated for multiple endpoints.",
)
indexer_service_metrics_endpoint_exclusive_group.add_argument(
"--indexer-service-metrics-k8s-service",
env_var="INDEXER_SERVICE_METRICS_K8S_SERVICE",
help="""
Kubernetes service name for the indexer-service and pod port serving its
metrics. Will watch the service's endpoint IPs continuously for changes.
Format: <scheme>://<service_name>:<pod_metrics_port>/<path>.
""",
)

#
# Price multiplier (Absolute price)
Expand Down
90 changes: 90 additions & 0 deletions autoagora/k8s_service_watcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2022-, Semiotic AI, Inc.
# SPDX-License-Identifier: Apache-2.0

import asyncio as aio
import logging

from kubernetes import client, config, watch
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException


class K8SServiceEndpointsWatcher:
def __init__(self, service_name: str) -> None:
"""Maintains an automatically, asynchronously updated list of endpoints backing
a kubernetes service in the current namespace.
This is supposed to be run from within a Kubernetes pod. The pod will need a
role that grants it:
```
rules:
- apiGroups: [""]
resources: ["endpoints"]
verbs: ["watch"]
```
Args:
service_name (str): Kubernetes service name.
Raises:
FileNotFoundError: couldn't find
`/var/run/secrets/kubernetes.io/serviceaccount/namespace`, which is
expected when running within a Kubernetes pod container.
"""
self.endpoint_ips = []
self._service_name = service_name

try:
with open(
"/var/run/secrets/kubernetes.io/serviceaccount/namespace", "r"
) as f:
self._namespace = f.read().strip()
except FileNotFoundError:
logging.exception("Probably not running in Kubernetes.")
raise

# Starts the async _loop immediately
self._future = aio.ensure_future(self._watch_loop())

async def _watch_loop(self) -> None:
"""Restarts the k8s watch on expiration."""
while True:
try:
await self._watch()
except ApiException as api_exc:
if api_exc.status == watch.watch.HTTP_STATUS_GONE:
logging.debug("k8s_service_watcher 410 timeout.")
else:
raise
logging.debug("k8s_service_watcher restarted")

async def _watch(self) -> None:
"""Watches for changes in k8s service endpoints."""
config.load_incluster_config()

api = ApiClient()
v1 = client.CoreV1Api(api)
w = watch.Watch()
event_stream = w.stream(
v1.list_namespaced_endpoints,
namespace=self._namespace,
field_selector=f"metadata.name={self._service_name}",
)

loop = aio.get_running_loop()

while event := await loop.run_in_executor(None, next, event_stream):
result = event["object"] # type: ignore

self.endpoint_ips = [
address.ip
for subset in result.subsets # type: ignore
for address in subset.addresses # type: ignore
]

logging.debug(
"Got endpoint IPs for service %s: %s",
self._service_name,
self.endpoint_ips,
)
17 changes: 16 additions & 1 deletion autoagora/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
from autoagora.indexer_utils import get_allocated_subgraphs, set_cost_model
from autoagora.model_builder import model_update_loop
from autoagora.price_multiplier import price_bandit_loop
from autoagora.query_metrics import (
K8SServiceWatcherMetricsEndpoints,
StaticMetricsEndpoints,
)

init_config()

Expand All @@ -36,6 +40,7 @@ async def allocated_subgraph_watcher():
(args.relative_query_costs_exclude_subgraphs or "").split(",")
)

# Initialize connection pool to PG database
try:
pgpool = await asyncpg.create_pool(
host=args.postgres_host,
Expand All @@ -53,6 +58,16 @@ async def allocated_subgraph_watcher():
)
raise

# Initialize indexer-service metrics endpoints
if args.indexer_service_metrics_endpoint: # static list
metrics_endpoints = StaticMetricsEndpoints(
args.indexer_service_metrics_endpoint
)
else: # auto from k8s
metrics_endpoints = K8SServiceWatcherMetricsEndpoints(
args.indexer_service_metrics_k8s_service
)

while True:
try:
allocated_subgraphs = (await get_allocated_subgraphs()) - excluded_subgraphs
Expand Down Expand Up @@ -85,7 +100,7 @@ async def allocated_subgraph_watcher():

# Launch the price multiplier update loop for the new subgraph
update_loops[new_subgraph].bandit = aio.ensure_future(
price_bandit_loop(new_subgraph, pgpool)
price_bandit_loop(new_subgraph, pgpool, metrics_endpoints)
)
logging.info(
"Added price multiplier update loop for subgraph %s", new_subgraph
Expand Down
9 changes: 6 additions & 3 deletions autoagora/price_multiplier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from autoagora.config import args
from autoagora.price_save_state_db import PriceSaveStateDB
from autoagora.query_metrics import MetricsEndpoints
from autoagora.subgraph_wrapper import SubgraphWrapper

reward_gauge = Gauge(
Expand All @@ -34,7 +35,9 @@
)


async def price_bandit_loop(subgraph: str, pgpool: asyncpg.Pool):
async def price_bandit_loop(
subgraph: str, pgpool: asyncpg.Pool, metrics_endpoints: MetricsEndpoints
):
try:
# Instantiate environment.
environment = SubgraphWrapper(subgraph)
Expand Down Expand Up @@ -85,7 +88,7 @@ async def price_bandit_loop(subgraph: str, pgpool: asyncpg.Pool):

# Update the save state
# NOTE: `bid_scale` is specific to "scaled_gaussian" agent action type
logging.debug("Price bandit %s - Saving state to DB.")
logging.debug("Price bandit %s - Saving state to DB.", subgraph)
await save_state_db.save_state(
subgraph=subgraph,
mean=bandit.bid_scale(bandit.mean().item()),
Expand All @@ -106,7 +109,7 @@ async def price_bandit_loop(subgraph: str, pgpool: asyncpg.Pool):
# 3. Get the reward.
# Get queries per second.
queries_per_second = await environment.queries_per_second(
args.qps_observation_duration
metrics_endpoints, args.qps_observation_duration
)
logging.debug(
"Price bandit %s - Queries per second: %s", subgraph, queries_per_second
Expand Down
104 changes: 93 additions & 11 deletions autoagora/query_metrics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,95 @@
# Copyright 2022-, Semiotic AI, Inc.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import logging
import re
from abc import ABC, abstractmethod
from functools import reduce
from typing import List
from urllib.parse import urlparse

import aiohttp
import backoff

from autoagora.config import args
from autoagora.k8s_service_watcher import K8SServiceEndpointsWatcher


class MetricsEndpoints(ABC):
"""Defines an interface for an object that provides a list of metrics endpoints."""

def __init__(self) -> None:
super().__init__()

@abstractmethod
def __call__(self) -> List[str]:
"""Retrieves a list of metrics endpoints.
Returns:
List[str]: A list of strings representing the metrics endpoints.
"""
pass


class StaticMetricsEndpoints(MetricsEndpoints):
"""Defines an interface for an object that provides a list of static metrics endpoints."""

def __init__(self, comma_separated_endpoints: str) -> None:
"""Initializes a new instance of StaticMetricsEndpoints with the given
comma-separated string of metrics endpoints.
Args:
comma_separated_endpoints (str): A comma-separated string of metrics
endpoints.
"""
super().__init__()
self._endpoints = comma_separated_endpoints.split(",")

def __call__(self) -> List[str]:
"""Returns a list of metrics endpoints.
Returns:
List[str]: A list of metrics endpoints.
"""
return self._endpoints


class K8SServiceWatcherMetricsEndpoints(MetricsEndpoints):
"""Implementation of MetricsEndpoints that returns a continuously-updating a list of
metrics endpoints from a Kubernetes service URL."""

def __init__(self, url: str) -> None:
"""Initializes a new instance of K8SServiceWatcherMetricsEndpoints with the
given Kubernetes service URL.
Args:
url (str): A string representing the Kubernetes service URL in the format
<scheme>://<service_name>:<pod_metrics_port>/<path>.
"""
super().__init__()
self._parsed_url = urlparse(url)
# Assuming the "hostname" is actually the k8s service name, as indicated in the
# arguments documentation.

service_name = self._parsed_url.hostname
# Check that service_name is non-empty
assert service_name, "k8s service name is empty."
# Check that service_name is a valid RFC-1123 DNS label
assert re.fullmatch(
r"[a-z0-9]([-a-z0-9]*[a-z0-9])?", service_name
), "Invalid k8s service name."
self._k8s_service_watcher = K8SServiceEndpointsWatcher(service_name)

def __call__(self) -> List[str]:
"""Retrieves a list of metrics endpoints.
Returns:
List[str]: A list of strings representing the metrics endpoints.
"""
port = self._parsed_url.port
return [
self._parsed_url._replace(netloc=f"{endpoint_ip}:{port}").geturl()
for endpoint_ip in self._k8s_service_watcher.endpoint_ips
]


class HTTPError(Exception):
Expand All @@ -19,8 +99,10 @@ class HTTPError(Exception):
@backoff.on_exception(
backoff.expo, (aiohttp.ClientError, HTTPError), max_time=30, logger=logging.root
)
async def subgraph_query_count(subgraph: str) -> int:
endpoints = args.indexer_service_metrics_endpoint.split(",")
async def subgraph_query_count(
subgraph: str, metrics_endpoints: MetricsEndpoints
) -> int:
endpoints = metrics_endpoints()
results = []
async with aiohttp.ClientSession() as session:
for endpoint in endpoints:
Expand All @@ -37,6 +119,13 @@ async def subgraph_query_count(subgraph: str) -> int:
)
)

logging.debug(
"Number of queries for subgraph %s from %s: %s",
subgraph,
endpoint,
results[-1:], # Will return empty list if empty, instead of error
)

if len(results) == 0:
# The subgraph query count will not be in the metric if it hasn't received any
# queries.
Expand All @@ -45,10 +134,3 @@ async def subgraph_query_count(subgraph: str) -> int:
return int(results[0])
else:
return reduce(lambda x, y: int(x) + int(y), results)


if __name__ == "__main__":
res = asyncio.run(
subgraph_query_count("Qmaz1R8vcv9v3gUfksqiS9JUz7K9G8S5By3JYn8kTiiP5K")
)
print(res)
21 changes: 17 additions & 4 deletions autoagora/subgraph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
from time import time
from typing import Optional

import aiohttp
import backoff

from autoagora.indexer_utils import get_cost_variables, set_cost_model
from autoagora.query_metrics import subgraph_query_count
from autoagora.query_metrics import HTTPError, MetricsEndpoints, subgraph_query_count


class SubgraphWrapper:
Expand All @@ -22,19 +25,29 @@ async def set_cost_multiplier(self, cost_multiplier: float):
await set_cost_model(self.subgraph, variables=cost_variables)
self.last_change_time = time()

async def queries_per_second(self, average_duration: float = 1):
# Timeout occurs when e.g. restarting the indexer-service. So we give it up to 10
# minutes to recover.
# Though subgraph_query_count has its own backoff, it is shorter. If the
# indexer-service outage lasts for longer, we prefer re-running the whole
# queries_per_second.
@backoff.on_exception(
backoff.expo, (aiohttp.ClientError, HTTPError), max_time=600, max_tries=10
)
async def queries_per_second(
self, metrics_endpoints: MetricsEndpoints, average_duration: float = 1
):
# Wait for the gateway to take our new costs into account
if self.last_change_time is not None:
time_since_last_change = time() - self.last_change_time
if time_since_last_change < SubgraphWrapper.GATEWAY_DELAY:
await sleep(SubgraphWrapper.GATEWAY_DELAY - time_since_last_change)

query_count_1 = await subgraph_query_count(self.subgraph)
query_count_1 = await subgraph_query_count(self.subgraph, metrics_endpoints)
timestamp_1 = time()

await sleep(average_duration)

query_count_2 = await subgraph_query_count(self.subgraph)
query_count_2 = await subgraph_query_count(self.subgraph, metrics_endpoints)
timestamp_2 = time()

queries_per_second = (query_count_2 - query_count_1) / (
Expand Down
Loading

0 comments on commit c88d9df

Please sign in to comment.