Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing in custom Session #170

Merged
merged 8 commits into from
May 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,9 @@ To lint, install requirements (included in the previous step) and run
```bash
make lint
```

## Acknowledgmnts

We would like to thank the following people for their contributions:

- [@aadamson](https://github.com/aadamson) for their contributions in supporting custom `requests.Session` objects [#170](https://github.com/sigopt/sigopt-python/pull/170)
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# For continuous integration and development
mock==1.0.1
mock>=3.0.5
pytest==2.8.7
twine==1.9.1
3 changes: 2 additions & 1 deletion sigopt/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class Connection(object):
Client-facing interface for creating Connections.
Shouldn't be changed without a major version change.
"""
def __init__(self, client_token=None, user_agent=None):
def __init__(self, client_token=None, user_agent=None, session=None):
client_token = client_token or os.environ.get('SIGOPT_API_TOKEN')
api_url = os.environ.get('SIGOPT_API_URL') or DEFAULT_API_URL
if not client_token:
Expand All @@ -289,6 +289,7 @@ def __init__(self, client_token=None, user_agent=None):
client_token,
'',
default_headers,
session=session,
)
self.impl = ConnectionImpl(requestor, api_url=api_url)

Expand Down
7 changes: 5 additions & 2 deletions sigopt/requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ def __init__(
user,
password,
headers,
verify_ssl_certs=True,
verify_ssl_certs=None,
proxies=None,
timeout=DEFAULT_HTTP_TIMEOUT,
client_ssl_certs=None,
session=None,
):
self._set_auth(user, password)
self.default_headers = headers or {}
self.verify_ssl_certs = verify_ssl_certs
self.proxies = proxies
self.timeout = timeout
self.client_ssl_certs = client_ssl_certs
self.session = session

def _set_auth(self, username, password):
if username is not None:
Expand All @@ -48,7 +50,8 @@ def delete(self, url, params=None, json=None, headers=None):
def request(self, method, url, params=None, json=None, headers=None):
headers = self._with_default_headers(headers)
try:
response = requests.request(
caller = (self.session or requests)
response = caller.request(
method=method,
url=url,
params=params,
Expand Down
16 changes: 15 additions & 1 deletion test/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,26 @@ class TestInterface(object):
def test_create(self):
conn = Connection(client_token='client_token')
assert conn.impl.api_url == 'https://api.sigopt.com'
assert conn.impl.requestor.verify_ssl_certs is True
assert conn.impl.requestor.verify_ssl_certs is None
assert conn.impl.requestor.session is None
assert conn.impl.requestor.proxies is None
assert conn.impl.requestor.timeout == DEFAULT_HTTP_TIMEOUT
assert isinstance(conn.clients, ApiResource)
assert isinstance(conn.experiments, ApiResource)

def test_create_uses_session_if_provided(self):
session = mock.Mock()
conn = Connection(client_token='client_token', session=session)
assert conn.impl.requestor.session is session

response = mock.Mock()
session.request.return_value = response
response.status_code = 200
response.text = '{}'
session.request.assert_not_called()
conn.experiments().fetch()
session.request.assert_called_once()

def test_environment_variable(self):
with mock.patch.dict(os.environ, {'SIGOPT_API_TOKEN': 'client_token'}):
Connection()
Expand Down