From 219c603396f518707398555c63bd655968b9b5a4 Mon Sep 17 00:00:00 2001 From: Richard Marmorstein <52928443+richardm-stripe@users.noreply.github.com> Date: Mon, 27 Nov 2023 08:37:45 -0800 Subject: [PATCH] Refactor integration test (#1145) * Refactor integration test * Update tests/test_integration.py --- tests/test_integration.py | 220 ++++++++++++++++++-------------------- 1 file changed, 106 insertions(+), 114 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index 6673f6175..3fdcf6b48 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -7,6 +7,9 @@ import stripe import pytest +from queue import Queue +from collections import defaultdict +from typing import List, Dict, Tuple, Optional if platform.python_implementation() == "PyPy": pytest.skip("skip integration tests with PyPy", allow_module_level=True) @@ -17,6 +20,57 @@ from http.server import BaseHTTPRequestHandler, HTTPServer +class TestHandler(BaseHTTPRequestHandler): + num_requests = 0 + + requests = defaultdict(Queue) + + @classmethod + def _add_request(cls, req): + q = cls.requests[id(cls)] + q.put(req) + + @classmethod + def get_requests(cls, n) -> List[BaseHTTPRequestHandler]: + reqs = [] + for _ in range(n): + reqs.append(cls.requests[id(cls)].get(False)) + + assert cls.requests[id(cls)].empty() + return reqs + + def do_GET(self): + return self._do_request() + + def do_POST(self): + return self._do_request() + + def _do_request(self): + n = self.__class__.num_requests + self.__class__.num_requests += 1 + self._add_request(self) + + provided_status, provided_headers, provided_body = self.do_request(n) + status = provided_status or self.default_status + headers = provided_headers or self.default_headers + body = provided_body or self.default_body + self.send_response(status) + for header_name, header_value in headers.items(): + self.send_header(header_name, header_value) + self.end_headers() + self.wfile.write(body) + return + + default_status = 200 + default_headers = {"Content-Type": "application/json; charset=utf-8"} + default_body = json.dumps({}).encode("utf-8") + + def do_request( + self, n: int + ) -> Tuple[Optional[int], Optional[Dict[str, str]], Optional[bytes]]: + return (self.default_status, self.default_headers, self.default_body) + + class TestIntegration(object): @pytest.fixture(autouse=True) def close_mock_server(self): @@ -63,40 +117,19 @@ def setup_mock_server(self, handler): self.mock_server_thread.start() def test_hits_api_base(self): - class MockServerRequestHandler(BaseHTTPRequestHandler): - num_requests = 0 - - def do_GET(self): - self.__class__.num_requests += 1 - - self.send_response(200) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.end_headers() - self.wfile.write(json.dumps({}).encode("utf-8")) - return + class MockServerRequestHandler(TestHandler): + pass self.setup_mock_server(MockServerRequestHandler) stripe.api_base = "http://localhost:%s" % self.mock_server_port stripe.Balance.retrieve() - assert MockServerRequestHandler.num_requests == 1 + reqs = MockServerRequestHandler.get_requests(1) + assert reqs[0].path == "/v1/balance" def test_hits_proxy_through_default_http_client(self): - class MockServerRequestHandler(BaseHTTPRequestHandler): - num_requests = 0 - - def do_GET(self): - self.__class__.num_requests += 1 - - self.send_response(200) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.end_headers() - self.wfile.write(json.dumps({}).encode("utf-8")) - return + class MockServerRequestHandler(TestHandler): + pass self.setup_mock_server(MockServerRequestHandler) @@ -116,19 +149,8 @@ def do_GET(self): assert MockServerRequestHandler.num_requests == 2 def test_hits_proxy_through_custom_client(self): - class MockServerRequestHandler(BaseHTTPRequestHandler): - num_requests = 0 - - def do_GET(self): - self.__class__.num_requests += 1 - - self.send_response(200) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.end_headers() - self.wfile.write(json.dumps({}).encode("utf-8")) - return + class MockServerRequestHandler(TestHandler): + pass self.setup_mock_server(MockServerRequestHandler) @@ -141,66 +163,21 @@ def do_GET(self): assert MockServerRequestHandler.num_requests == 1 def test_passes_client_telemetry_when_enabled(self): - class MockServerRequestHandler(BaseHTTPRequestHandler): + class MockServerRequestHandler(TestHandler): num_requests = 0 - def do_GET(self): - try: - self.__class__.num_requests += 1 - req_num = self.__class__.num_requests - if req_num == 1: - time.sleep(31 / 1000) # 31 ms - assert not self.headers.get( - "X-Stripe-Client-Telemetry" - ) - elif req_num == 2: - assert self.headers.get("X-Stripe-Client-Telemetry") - telemetry = json.loads( - self.headers.get("x-stripe-client-telemetry") - ) - assert "last_request_metrics" in telemetry - req_id = telemetry["last_request_metrics"][ - "request_id" - ] - duration_ms = telemetry["last_request_metrics"][ - "request_duration_ms" - ] - assert req_id == "req_1" - # The first request took 31 ms, so the client perceived - # latency shouldn't be outside this range. - assert 30 < duration_ms < 300 - else: - assert False, ( - "Should not have reached request %d" % req_num - ) - - self.send_response(200) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.send_header("Request-Id", "req_%d" % req_num) - self.end_headers() - self.wfile.write(json.dumps({}).encode("utf-8")) - except AssertionError as ex: - # Throwing assertions on the server side causes a - # connection error to be logged instead of an assertion - # failure. Instead, we return the assertion failure as - # json so it can be logged as a StripeError. - self.send_response(400) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.end_headers() - self.wfile.write( - json.dumps( - { - "error": { - "type": "invalid_request_error", - "message": str(ex), - } - } - ).encode("utf-8") - ) + def do_request(self, req_num): + if req_num == 0: + time.sleep(31 / 1000) # 31 ms + + return [ + 200, + { + "Content-Type": "application/json; charset=utf-8", + "Request-Id": "req_1", + }, + None, + ] self.setup_mock_server(MockServerRequestHandler) stripe.api_base = "http://localhost:%s" % self.mock_server_port @@ -208,34 +185,49 @@ def do_GET(self): stripe.Balance.retrieve() stripe.Balance.retrieve() + + reqs = MockServerRequestHandler.get_requests(2) assert MockServerRequestHandler.num_requests == 2 + # req 1 + assert not reqs[0].headers.get("x-stripe-client-telemetry") + # req 2 + telemetry_raw = reqs[1].headers.get("x-stripe-client-telemetry") + + assert telemetry_raw is not None + telemetry = json.loads(telemetry_raw) + assert "last_request_metrics" in telemetry + + duration_ms = telemetry["last_request_metrics"]["request_duration_ms"] + # The first request took 31 ms, so the client perceived + # latency shouldn't be outside this range. + assert 30 < duration_ms < 300 def test_uses_thread_local_client_telemetry(self): - class MockServerRequestHandler(BaseHTTPRequestHandler): - num_requests = 0 + class MockServerRequestHandler(TestHandler): + local_num_requests = 0 seen_metrics = set() stats_lock = Lock() - def do_GET(self): + def do_request(self, _n): with self.__class__.stats_lock: - self.__class__.num_requests += 1 - req_num = self.__class__.num_requests + self.__class__.local_num_requests += 1 + req_num = self.__class__.local_num_requests - if self.headers.get("X-Stripe-Client-Telemetry"): - telemetry = json.loads( - self.headers.get("X-Stripe-Client-Telemetry") - ) + raw_telemetry = self.headers.get("X-Stripe-Client-Telemetry") + if raw_telemetry: + telemetry = json.loads(raw_telemetry) req_id = telemetry["last_request_metrics"]["request_id"] with self.__class__.stats_lock: self.__class__.seen_metrics.add(req_id) - self.send_response(200) - self.send_header( - "Content-Type", "application/json; charset=utf-8" - ) - self.send_header("Request-Id", "req_%d" % req_num) - self.end_headers() - self.wfile.write(json.dumps({}).encode("utf-8")) + return [ + 200, + { + "Content-Type": "application/json; charset=utf-8", + "Request-Id": "req_%s" % (req_num), + }, + None, + ] self.setup_mock_server(MockServerRequestHandler) stripe.api_base = "http://localhost:%s" % self.mock_server_port