Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor integration test #1145

Merged
merged 3 commits into from
Nov 27, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 106 additions & 114 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import stripe
import pytest
from queue import Queue
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

if platform.python_implementation() == "PyPy":
pytest.skip("skip integration tests with PyPy", allow_module_level=True)
Expand All @@ -17,6 +20,57 @@
from http.server import BaseHTTPRequestHandler, HTTPServer


class TestHandler(BaseHTTPRequestHandler):
num_requests = 0

requests = defaultdict(Queue)

@classmethod
def _add_request(cls, req):
richardm-stripe marked this conversation as resolved.
Show resolved Hide resolved
q = cls.requests[id(cls)]
q.put(req)

@classmethod
def get_requests(cls, n) -> List[BaseHTTPRequestHandler]:
reqs = []
for _ in range(n):
reqs.append(cls.requests[id(cls)].get(False))

assert cls.requests[id(cls)].empty()
return reqs

def do_GET(self):
return self._do_request()

def do_POST(self):
return self._do_request()

def _do_request(self):
n = self.__class__.num_requests
self.__class__.num_requests += 1
self._add_request(self)
richardm-stripe marked this conversation as resolved.
Show resolved Hide resolved

provided_status, provided_headers, provided_body = self.do_request(n)
status = provided_status or self.default_status
headers = provided_headers or self.default_headers
body = provided_body or self.default_body
self.send_response(status)
for header_name, header_value in headers.items():
self.send_header(header_name, header_value)
self.end_headers()
self.wfile.write(body)
return

default_status = 200
default_headers = {"Content-Type": "application/json; charset=utf-8"}
default_body = json.dumps({}).encode("utf-8")

def do_request(
self, n: int
) -> Tuple[Optional[int], Optional[Dict[str, str]], Optional[bytes]]:
return (self.default_status, self.default_headers, self.default_body)


class TestIntegration(object):
@pytest.fixture(autouse=True)
def close_mock_server(self):
Expand Down Expand Up @@ -63,40 +117,19 @@ def setup_mock_server(self, handler):
self.mock_server_thread.start()

def test_hits_api_base(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.Balance.retrieve()
assert MockServerRequestHandler.num_requests == 1
reqs = MockServerRequestHandler.get_requests(1)
assert reqs[0].path == "/v1/balance"

def test_hits_proxy_through_default_http_client(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

Expand All @@ -116,19 +149,8 @@ def do_GET(self):
assert MockServerRequestHandler.num_requests == 2

def test_hits_proxy_through_custom_client(self):
richardm-stripe marked this conversation as resolved.
Show resolved Hide resolved
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0

def do_GET(self):
self.__class__.num_requests += 1

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return
class MockServerRequestHandler(TestHandler):
pass

self.setup_mock_server(MockServerRequestHandler)

Expand All @@ -141,101 +163,71 @@ def do_GET(self):
assert MockServerRequestHandler.num_requests == 1

def test_passes_client_telemetry_when_enabled(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
class MockServerRequestHandler(TestHandler):
num_requests = 0

def do_GET(self):
try:
self.__class__.num_requests += 1
req_num = self.__class__.num_requests
if req_num == 1:
time.sleep(31 / 1000) # 31 ms
assert not self.headers.get(
"X-Stripe-Client-Telemetry"
)
elif req_num == 2:
assert self.headers.get("X-Stripe-Client-Telemetry")
telemetry = json.loads(
self.headers.get("x-stripe-client-telemetry")
)
assert "last_request_metrics" in telemetry
req_id = telemetry["last_request_metrics"][
"request_id"
]
duration_ms = telemetry["last_request_metrics"][
"request_duration_ms"
]
assert req_id == "req_1"
# The first request took 31 ms, so the client perceived
# latency shouldn't be outside this range.
assert 30 < duration_ms < 300
else:
assert False, (
"Should not have reached request %d" % req_num
)

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.send_header("Request-Id", "req_%d" % req_num)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
except AssertionError as ex:
# Throwing assertions on the server side causes a
# connection error to be logged instead of an assertion
# failure. Instead, we return the assertion failure as
# json so it can be logged as a StripeError.
self.send_response(400)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.end_headers()
self.wfile.write(
json.dumps(
{
"error": {
"type": "invalid_request_error",
"message": str(ex),
}
}
).encode("utf-8")
)
def do_request(self, req_num):
if req_num == 0:
time.sleep(31 / 1000) # 31 ms

return [
200,
{
"Content-Type": "application/json; charset=utf-8",
"Request-Id": "req_1",
},
None,
]

self.setup_mock_server(MockServerRequestHandler)
stripe.api_base = "http://localhost:%s" % self.mock_server_port
stripe.enable_telemetry = True

stripe.Balance.retrieve()
stripe.Balance.retrieve()

reqs = MockServerRequestHandler.get_requests(2)
assert MockServerRequestHandler.num_requests == 2
# req 1
assert not reqs[0].headers.get("X-Stripe-Client-Telemetry")
richardm-stripe marked this conversation as resolved.
Show resolved Hide resolved
# req 2
telemetry_raw = reqs[1].headers.get("x-stripe-client-telemetry")
richardm-stripe marked this conversation as resolved.
Show resolved Hide resolved

assert telemetry_raw is not None
telemetry = json.loads(telemetry_raw)
assert "last_request_metrics" in telemetry

duration_ms = telemetry["last_request_metrics"]["request_duration_ms"]
# The first request took 31 ms, so the client perceived
# latency shouldn't be outside this range.
assert 30 < duration_ms < 300

def test_uses_thread_local_client_telemetry(self):
class MockServerRequestHandler(BaseHTTPRequestHandler):
num_requests = 0
class MockServerRequestHandler(TestHandler):
local_num_requests = 0
seen_metrics = set()
stats_lock = Lock()

def do_GET(self):
def do_request(self, _n):
with self.__class__.stats_lock:
self.__class__.num_requests += 1
req_num = self.__class__.num_requests
self.__class__.local_num_requests += 1
req_num = self.__class__.local_num_requests

if self.headers.get("X-Stripe-Client-Telemetry"):
telemetry = json.loads(
self.headers.get("X-Stripe-Client-Telemetry")
)
raw_telemetry = self.headers.get("X-Stripe-Client-Telemetry")
if raw_telemetry:
telemetry = json.loads(raw_telemetry)
req_id = telemetry["last_request_metrics"]["request_id"]
with self.__class__.stats_lock:
self.__class__.seen_metrics.add(req_id)

self.send_response(200)
self.send_header(
"Content-Type", "application/json; charset=utf-8"
)
self.send_header("Request-Id", "req_%d" % req_num)
self.end_headers()
self.wfile.write(json.dumps({}).encode("utf-8"))
return [
200,
{
"Content-Type": "application/json; charset=utf-8",
"Request-Id": "req_%s" % (req_num),
},
None,
]

self.setup_mock_server(MockServerRequestHandler)
stripe.api_base = "http://localhost:%s" % self.mock_server_port
Expand Down