Skip to content

Commit

Permalink
fix(adserver): adserver returns cloudevents compatible response (#5348)
Browse files Browse the repository at this point in the history
* modify tests to check that adserver returns CE-compatible responses

* refactor server post handler to return CE-compatible responses
  • Loading branch information
michaelcheah authored Feb 19, 2024
1 parent 33dc760 commit 1f32089
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 39 deletions.
8 changes: 4 additions & 4 deletions components/alibi-detect-server/adserver/base/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import logging
import tempfile
from distutils.util import strtobool
from typing import Optional

ARTIFACT_DOWNLOAD_LOCATION = os.environ.get("DRIFT_ARTIFACTS_DIR", "/tmp")

Expand All @@ -18,13 +18,13 @@


class Rclone:
def __init__(self, cfg_file: str = None):
def __init__(self, cfg_file: Optional[str] = None):
self.cfg_file = cfg_file

def copy(self, src: str, dest: str = None):
def copy(self, src: str, dest: Optional[str] = None):
if rclone is None:
raise RuntimeError(
"rclone binary not found - rclone-based storage funcionality disabled"
"rclone binary not found - rclone-based storage functionality disabled"
)

if dest is None:
Expand Down
4 changes: 2 additions & 2 deletions components/alibi-detect-server/adserver/cm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
SELDON_PREDICTOR_ID = DEFAULT_LABELS["predictor_name"]


def _load_class_module(module_path: str) -> str:
def _load_class_module(module_path: str):
components = module_path.split(".")
mod = __import__(".".join(components[:-1]))
for comp in components[1:]:
Expand All @@ -32,7 +32,7 @@ def _load_class_module(module_path: str) -> str:

class CustomMetricsModel(CEModel): # pylint:disable=c-extension-no-member
def __init__(
self, name: str, storage_uri: str, elasticsearch_uri: str = None, model=None
self, name: str, storage_uri: str, elasticsearch_uri: Optional[str] = None, model=None
):
"""
Custom Metrics Model
Expand Down
104 changes: 71 additions & 33 deletions components/alibi-detect-server/adserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
event_type: str,
event_source: str,
http_port: int = DEFAULT_HTTP_PORT,
reply_url: str = None,
reply_url: Optional[str] = None,
):
"""
CloudEvents server
Expand Down Expand Up @@ -146,29 +146,21 @@ def get_request_handler(protocol, request: Dict) -> RequestHandler:
raise Exception(f"Unknown protocol {protocol}")


def sendCloudEvent(event: v1.Event, url: str):
def forward_request(headers, data, url):
"""
Send CloudEvent
Forward request
Parameters
----------
event
CloudEvent to send
headers
Headers to forward
data
Data to forward
url
Url to send event
Url to forward to
"""
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
binary_headers, binary_data = http_marshaller.ToRequest(
event, converters.TypeBinary, json.dumps
)

logging.info("binary CloudEvent")
for k, v in binary_headers.items():
logging.info("{0}: {1}\r\n".format(k, v))
logging.info(binary_data)

response = requests.post(url, headers=binary_headers, data=binary_data)
response = requests.post(url, headers=headers, data=data)
response.raise_for_status()


Expand Down Expand Up @@ -252,27 +244,73 @@ def post(self):
else:
logging.error("Metrics returned are invalid: " + str(runtime_metrics))

if response.data is not None:
revent = create_cloud_event(
response.data,
self.event_type,
self.event_source,
event_id=event.EventID(),
extensions=event.Extensions(),
)

if response.data is not None:
# Create event from response if reply_url is active
revent_headers, revent_data = http_marshaller.ToRequest(
revent, converters.TypeBinary, json.dumps
)

if not self.reply_url == "":
if event.EventID() is None or event.EventID() == "":
resp_event_id = uuid.uuid1().hex
else:
resp_event_id = event.EventID()
revent = (
v1.Event()
.SetContentType("application/json")
.SetData(response.data)
.SetEventID(resp_event_id)
.SetSource(self.event_source)
.SetEventType(self.event_type)
.SetExtensions(event.Extensions())
)
logging.debug(json.dumps(revent.Properties()))
sendCloudEvent(revent, self.reply_url)
self.write(json.dumps(response.data))
logging.info("binary CloudEvent")
for k, v in revent_headers.items():
logging.info("{0}: {1}\r\n".format(k, v))
logging.info(revent_data)
forward_request(revent_headers, revent_data, self.reply_url)

self.set_header("Content-Type", "application/json")
for headers in revent_headers:
self.set_header(headers, revent_headers[headers])
self.write(revent_data)


def create_cloud_event(
data: dict,
event_type: str,
event_source: str,
extensions: dict,
event_id: str = None,
) -> v1.Event:
"""
Create a CloudEvent
Parameters
----------
data
The data to send
event_type
The CE event type
event_source
The CE event source
extensions
Any extensions to add
event_id
The event id
Returns
-------
A CloudEvent
"""
if event_id is None or event_id == "":
event_id = uuid.uuid1().hex

event = (
v1.Event()
.SetData(data)
.SetEventID(event_id if event_id else str(uuid.uuid1().hex))
.SetSource(event_source)
.SetEventType(event_type)
.SetExtensions(extensions)
)
return event

class LivenessHandler(tornado.web.RequestHandler):
def get(self):
Expand Down
23 changes: 23 additions & 0 deletions components/alibi-detect-server/adserver/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
from typing import List, Dict, Optional, Union
import json
import requests_mock
from cloudevents.sdk import converters
from cloudevents.sdk import marshaller
from cloudevents.sdk.event import v1


class TestProtocol(AsyncHTTPTestCase):
Expand Down Expand Up @@ -74,11 +77,31 @@ def test_basic(self):
)
self.assertEqual(response.code, 200)
expectedResponse = DummyModel.getResponse().data
# assert that the expected response conforms to the CloudEvent spec
event = v1.Event()
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
try:
event = http_marshaller.FromRequest(
event, response.headers, response.body, json.loads
)
except Exception as e:
assert False, f"Failed to unmarshall data with error: {type(e).__name__}('{e}')"

# assert cloud event properties have been set correctly in response
self.assertEqual(event.Data(), expectedResponse)
self.assertEqual(event.Source(), self.eventSource)
self.assertEqual(event.EventType(), self.eventType)
self.assertEqual(event.ContentType(), "application/json")
self.assertEqual(event.EventID(), "1234")
self.assertEqual(event.CloudEventVersion(), "1.0")
self.assertEqual(response.body.decode("utf-8"), json.dumps(expectedResponse))

# assert requests have been made with the correct headers and data
self.assertEqual(m.request_history[0].json(), expectedResponse)
headers: Dict = m.request_history[0]._request.headers
self.assertEqual(headers["ce-source"], self.eventSource)
self.assertEqual(headers["ce-type"], self.eventType)
self.assertNotIn("ce-datacontenttype", headers)


class TestKFservingV2HttpModel(AsyncHTTPTestCase):
Expand Down

0 comments on commit 1f32089

Please sign in to comment.