Skip to content

Commit

Permalink
Refactor integration test (#1145)
Browse files Browse the repository at this point in the history
* Refactor integration test

* Update tests/test_integration.py
  • Loading branch information
richardm-stripe authored Nov 27, 2023
1 parent fdad6c8 commit 219c603
Showing 1 changed file with 106 additions and 114 deletions.
220 changes: 106 additions & 114 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import stripe
import pytest
from queue import Queue
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

if platform.python_implementation() == "PyPy":
pytest.skip("skip integration tests with PyPy", allow_module_level=True)
Expand All @@ -17,6 +20,57 @@
from http.server import BaseHTTPRequestHandler, HTTPServer


class TestHandler(BaseHTTPRequestHandler):
num_requests = 0

requests = defaultdict(Queue)

@classmethod
def _add_request(cls, req):
q = cls.requests[id(cls)]
q.put(req)

@classmethod
def get_requests(cls, n) -> List[BaseHTTPRequestHandler]:
reqs = []
for _ in range(n):
reqs.append(cls.requests[id(cls)].get(False))

assert cls.requests[id(cls)].empty()
return reqs

def do_GET(self):
return self._do_request()

def do_POST(self):
return self._do_request()

def _do_request(self):
n = self.__class__.num_requests
self.__class__.num_requests += 1
self._add_request(self)

provided_status, provided_headers, provided_body = self.do_request(n)
status = provided_status or self.default_status
headers = provided_headers or self.default_headers
body = provided_body or self.default_body
self.send_response(status)
for header_name, header_value in headers.items():
self.send_header(header_name, header_value)
self.end_headers()
self.wfile.write(body)
return

default_status = 200
default_headers = {"Content-Type": "application/json; charset=utf-8"}
default_body = json.dumps({}).encode("utf-8")

def do_request(
self, n: int
) -> Tuple[Optional[int], Optional[Dict[str, str]], Optional[bytes]]:
return (self.default_status, self.default_headers, self.default_body)


class TestIntegration(object):
@pytest.fixture(autouse=True)
def close_mock_server(self):
Expand Down Expand Up @@ -63,40 +117,19 @@ def setup_mock_server(self, handler):
self.mock_server_thread.start()

def test_hits_api_base(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 1
reqs = MockServerRequestHandler.get_requests(1)
assert reqs[0].path == "/v1/balance"

def test_hits_proxy_through_default_http_client(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

Expand All @@ -116,19 +149,8 @@ def do_GET(self):
assert MockServerRequestHandler.num_requests == 2

def test_hits_proxy_through_custom_client(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

Expand All @@ -141,101 +163,71 @@ def do_GET(self):
assert MockServerRequestHandler.num_requests == 1

def test_passes_client_telemetry_when_enabled(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
class MockServerRequestHandler(TestHandler):
num_requests = 0

def do_GET(self):
try:
self.__class__.num_requests += 1
req_num = self.__class__.num_requests
if req_num == 1:
time.sleep(31 / 1000) # 31 ms
assert not self.headers.get(
"X-Stripe-Client-Telemetry"
)
elif req_num == 2:
assert self.headers.get("X-Stripe-Client-Telemetry")
telemetry = json.loads(
self.headers.get("x-stripe-client-telemetry")
)
assert "last_request_metrics" in telemetry
req_id = telemetry["last_request_metrics"][
"request_id"
]
duration_ms = telemetry["last_request_metrics"][
"request_duration_ms"
]
assert req_id == "req_1"
# The first request took 31 ms, so the client perceived
# latency shouldn't be outside this range.
assert 30 < duration_ms < 300
else:
assert False, (
"Should not have reached request %d" % req_num
)

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.send_header("Request-Id", "req_%d" % req_num)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
except AssertionError as ex:
# Throwing assertions on the server side causes a
# connection error to be logged instead of an assertion
# failure. Instead, we return the assertion failure as
# json so it can be logged as a StripeError.
self.send_response(400)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(
json.dumps(
{
"error": {
"type": "invalid_request_error",
"message": str(ex),
}
}
).encode("utf-8")
)
def do_request(self, req_num):
if req_num == 0:
time.sleep(31 / 1000) # 31 ms

return [
200,
{
"Content-Type": "application/json; charset=utf-8",
"Request-Id": "req_1",
},
None,
]

self.setup_mock_server(MockServerRequestHandler)
stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.enable_telemetry = True

stripe.Balance.retrieve()
stripe.Balance.retrieve()

reqs = MockServerRequestHandler.get_requests(2)
assert MockServerRequestHandler.num_requests == 2
# req 1
assert not reqs[0].headers.get("x-stripe-client-telemetry")
# req 2
telemetry_raw = reqs[1].headers.get("x-stripe-client-telemetry")

assert telemetry_raw is not None
telemetry = json.loads(telemetry_raw)
assert "last_request_metrics" in telemetry

duration_ms = telemetry["last_request_metrics"]["request_duration_ms"]
# The first request took 31 ms, so the client perceived
# latency shouldn't be outside this range.
assert 30 < duration_ms < 300

def test_uses_thread_local_client_telemetry(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0
class MockServerRequestHandler(TestHandler):
local_num_requests = 0
seen_metrics = set()
stats_lock = Lock()

def do_GET(self):
def do_request(self, _n):
with self.__class__.stats_lock:
self.__class__.num_requests += 1
req_num = self.__class__.num_requests
self.__class__.local_num_requests += 1
req_num = self.__class__.local_num_requests

if self.headers.get("X-Stripe-Client-Telemetry"):
telemetry = json.loads(
self.headers.get("X-Stripe-Client-Telemetry")
)
raw_telemetry = self.headers.get("X-Stripe-Client-Telemetry")
if raw_telemetry:
telemetry = json.loads(raw_telemetry)
req_id = telemetry["last_request_metrics"]["request_id"]
with self.__class__.stats_lock:
self.__class__.seen_metrics.add(req_id)

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.send_header("Request-Id", "req_%d" % req_num)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return [
200,
{
"Content-Type": "application/json; charset=utf-8",
"Request-Id": "req_%s" % (req_num),
},
None,
]

self.setup_mock_server(MockServerRequestHandler)
stripe.api_base = "http://localhost:%s" % self.mock_server_port
Expand Down

0 comments on commit 219c603

Please sign in to comment.