-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2273 from locustio/add-RestUser
Add RestUser
- Loading branch information
Showing
9 changed files
with
305 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import time | ||
from typing import Any, Callable | ||
import grpc | ||
import grpc.experimental.gevent as grpc_gevent | ||
from grpc_interceptor import ClientInterceptor | ||
from locust import User | ||
from locust.exception import LocustError | ||
|
||
# patch grpc so that it uses gevent instead of asyncio | ||
grpc_gevent.init_gevent() | ||
|
||
|
||
class LocustInterceptor(ClientInterceptor): | ||
def __init__(self, environment, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.env = environment | ||
|
||
def intercept( | ||
self, | ||
method: Callable, | ||
request_or_iterator: Any, | ||
call_details: grpc.ClientCallDetails, | ||
): | ||
response = None | ||
exception = None | ||
start_perf_counter = time.perf_counter() | ||
response_length = 0 | ||
try: | ||
response = method(request_or_iterator, call_details) | ||
response_length = response.result().ByteSize() | ||
except grpc.RpcError as e: | ||
exception = e | ||
|
||
self.env.events.request.fire( | ||
request_type="grpc", | ||
name=call_details.method, | ||
response_time=(time.perf_counter() - start_perf_counter) * 1000, | ||
response_length=response_length, | ||
response=response, | ||
context=None, | ||
exception=exception, | ||
) | ||
return response | ||
|
||
|
||
class GrpcUser(User): | ||
abstract = True | ||
stub_class = None | ||
|
||
def __init__(self, environment): | ||
super().__init__(environment) | ||
for attr_value, attr_name in ((self.host, "host"), (self.stub_class, "stub_class")): | ||
if attr_value is None: | ||
raise LocustError(f"You must specify the {attr_name}.") | ||
|
||
self._channel = grpc.insecure_channel(self.host) | ||
interceptor = LocustInterceptor(environment=environment) | ||
self._channel = grpc.intercept_channel(self._channel, interceptor) | ||
|
||
self.stub = self.stub_class(self._channel) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,87 +1,21 @@ | ||
# make sure you use grpc version 1.39.0 or later, | ||
# because of https://github.com/grpc/grpc/issues/15880 that affected earlier versions | ||
from typing import Callable, Any | ||
import time | ||
|
||
import grpc | ||
import grpc.experimental.gevent as grpc_gevent | ||
import gevent | ||
from locust import events, User, task | ||
from locust.exception import LocustError | ||
from grpc_interceptor import ClientInterceptor | ||
|
||
import hello_pb2_grpc | ||
import grpc_user | ||
import hello_pb2 | ||
|
||
import hello_pb2_grpc | ||
from hello_server import start_server | ||
|
||
# patch grpc so that it uses gevent instead of asyncio | ||
grpc_gevent.init_gevent() | ||
from locust import events, task | ||
|
||
|
||
# Start the dummy server. This is not something you would do in a real test. | ||
@events.init.add_listener | ||
def run_grpc_server(environment, **_kwargs): | ||
# Start the dummy server. This is not something you would do in a real test. | ||
gevent.spawn(start_server) | ||
|
||
|
||
class LocustInterceptor(ClientInterceptor): | ||
def __init__(self, environment, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
|
||
self.env = environment | ||
|
||
def intercept( | ||
self, | ||
method: Callable, | ||
request_or_iterator: Any, | ||
call_details: grpc.ClientCallDetails, | ||
): | ||
response = None | ||
exception = None | ||
start_perf_counter = time.perf_counter() | ||
response_length = 0 | ||
try: | ||
response = method(request_or_iterator, call_details) | ||
response_length = response.result().ByteSize() | ||
except grpc.RpcError as e: | ||
exception = e | ||
|
||
self.env.events.request.fire( | ||
request_type="grpc", | ||
name=call_details.method, | ||
response_time=(time.perf_counter() - start_perf_counter) * 1000, | ||
response_length=response_length, | ||
response=response, | ||
context=None, | ||
exception=exception, | ||
) | ||
return response | ||
|
||
|
||
class GrpcUser(User): | ||
abstract = True | ||
|
||
stub_class = None | ||
|
||
def __init__(self, environment): | ||
super().__init__(environment) | ||
for attr_value, attr_name in ((self.host, "host"), (self.stub_class, "stub_class")): | ||
if attr_value is None: | ||
raise LocustError(f"You must specify the {attr_name}.") | ||
|
||
self._channel = grpc.insecure_channel(self.host) | ||
interceptor = LocustInterceptor(environment=environment) | ||
self._channel = grpc.intercept_channel(self._channel, interceptor) | ||
|
||
self.stub = self.stub_class(self._channel) | ||
|
||
|
||
class HelloGrpcUser(GrpcUser): | ||
class HelloGrpcUser(grpc_user.GrpcUser): | ||
host = "localhost:50051" | ||
stub_class = hello_pb2_grpc.HelloServiceStub | ||
|
||
@task | ||
def sayHello(self): | ||
self.stub.SayHello(hello_pb2.HelloRequest(name="Test")) | ||
time.sleep(1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from contextlib import contextmanager | ||
from locust import task, run_single_user | ||
from locust.contrib.fasthttp import ResponseContextManager | ||
from locust.user.wait_time import constant | ||
from locust.contrib.rest import RestUser | ||
|
||
|
||
class MyUser(RestUser): | ||
host = "https://postman-echo.com" | ||
wait_time = constant(180) # be nice to postman-echo.com, and dont run this at scale. | ||
|
||
@task | ||
def t(self): | ||
# should work | ||
with self.rest("GET", "/get", json={"foo": 1}) as resp: | ||
if resp.js["args"]["foo"] != 1: | ||
resp.failure(f"Unexpected value of foo in response {resp.text}") | ||
|
||
# should work | ||
with self.rest("POST", "/post", json={"foo": 1}) as resp: | ||
if resp.js["data"]["foo"] != 1: | ||
resp.failure(f"Unexpected value of foo in response {resp.text}") | ||
# assertions are a nice short way of expressiont your expectations about the response. The AssertionError thrown will be caught | ||
# and fail the request, including the message and the payload in the failure content | ||
assert resp.js["data"]["foo"] == 1, "Unexpected value of foo in response" | ||
|
||
# assertions are a nice short way to validate the response. The AssertionError they raise | ||
# will be caught by rest() and mark the request as failed | ||
|
||
with self.rest("POST", "/post", json={"foo": 1}) as resp: | ||
# mark the request as failed with the message "Assertion failed" | ||
assert resp.js["data"]["foo"] == 2 | ||
|
||
with self.rest("POST", "/post", json={"foo": 1}) as resp: | ||
# custom failure message | ||
assert resp.js["data"]["foo"] == 2, "my custom error message" | ||
|
||
with self.rest("POST", "/post", json={"foo": 1}) as resp: | ||
# use a trailing comma to append the response text to the custom message | ||
assert resp.js["data"]["foo"] == 2, "my custom error message with response text," | ||
|
||
# this only works in python 3.8 and up, so it is commented out: | ||
# if sys.version_info >= (3, 8): | ||
# with self.rest("", "/post", json={"foo": 1}) as resp: | ||
# # assign and assert in one line | ||
# assert (foo := resp.js["foo"]) | ||
# print(f"the number {foo} is awesome") | ||
|
||
# rest() catches most exceptions, so any programming mistakes you make automatically marks the request as a failure | ||
# and stores the callstack in the failure message | ||
with self.rest("POST", "/post", json={"foo": 1}) as resp: | ||
1 / 0 # pylint: disable=pointless-statement | ||
|
||
# response isnt even json, but RestUser will already have been marked it as a failure, so we dont have to do it again | ||
with self.rest("GET", "/") as resp: | ||
pass | ||
|
||
with self.rest("GET", "/") as resp: | ||
# If resp.js is None (which it will be when there is a connection failure, a non-json responses etc), | ||
# reading from resp.js will raise a TypeError (instead of an AssertionError), so lets avoid that: | ||
if resp.js: | ||
assert resp.js["foo"] == 2 | ||
# or, as a mildly confusing oneliner: | ||
assert not resp.js or resp.js["foo"] == 2 | ||
|
||
# 404 | ||
with self.rest("GET", "http://example.com/") as resp: | ||
pass | ||
|
||
# connection closed | ||
with self.rest("POST", "http://example.com:42/", json={"foo": 1}) as resp: | ||
pass | ||
|
||
|
||
# An example of how you might write a common base class for an API that always requires | ||
# certain headers, or where you always want to check the response in a certain way | ||
class RestUserThatLooksAtErrors(RestUser): | ||
abstract = True | ||
|
||
@contextmanager | ||
def rest(self, method, url, **kwargs) -> ResponseContextManager: | ||
extra_headers = {"my_header": "my_value"} | ||
with super().rest(method, url, headers=extra_headers, **kwargs) as resp: | ||
resp: ResponseContextManager | ||
if resp.js and "error" in resp.js and resp.js["error"] is not None: | ||
resp.failure(resp.js["error"]) | ||
yield resp | ||
|
||
|
||
class MyOtherRestUser(RestUserThatLooksAtErrors): | ||
host = "https://postman-echo.com" | ||
wait_time = constant(180) # be nice to postman-echo.com, and dont run this at scale. | ||
|
||
@task | ||
def t(self): | ||
with self.rest("GET", "/") as _resp: | ||
pass | ||
|
||
|
||
if __name__ == "__main__": | ||
run_single_user(MyUser) |
Oops, something went wrong.