diff --git a/pyalex/api.py b/pyalex/api.py index 053e110..ba3ea32 100644 --- a/pyalex/api.py +++ b/pyalex/api.py @@ -3,6 +3,7 @@ from urllib.parse import quote_plus import requests +from requests.auth import AuthBase from urllib3.util import Retry try: @@ -22,6 +23,7 @@ def __setattr__(self, key, value): config = AlexConfig( email=None, api_key=None, + user_agent="pyalex/" + __version__, openalex_url="https://api.openalex.org", max_retries=0, retry_backoff_factor=0.1, @@ -160,6 +162,32 @@ def __next__(self): return results +class OpenAlexAuth(AuthBase): + """OpenAlex auth class based on requests auth + + Includes the email, api_key and user-agent headers. + + arguments: + config: an AlexConfig object + + """ + + def __init__(self, config): + self.config = config + + def __call__(self, r): + if self.config.api_key: + r.headers["Authorization"] = f"Bearer {self.config.api_key}" + + if self.config.email: + r.headers["From"] = self.config.email + + if self.config.user_agent: + r.headers["User-Agent"] = self.config.user_agent + + return r + + class BaseOpenAlex: """Base class for OpenAlex objects.""" @@ -222,13 +250,7 @@ def count(self): return m["count"] def _get_from_url(self, url, return_meta=False): - params = {"api_key": config.api_key} if config.api_key else {} - - res = _get_requests_session().get( - url, - headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, - params=params, - ) + res = _get_requests_session().get(url, auth=OpenAlexAuth(config)) # handle query errors if res.status_code == 403: @@ -334,11 +356,9 @@ def __getitem__(self, key): def ngrams(self, return_meta=False): openalex_id = self["id"].split("/")[-1] + n_gram_url = f"{config.openalex_url}/works/{openalex_id}/ngrams" - res = _get_requests_session().get( - f"{config.openalex_url}/works/{openalex_id}/ngrams", - headers={"User-Agent": "pyalex/" + __version__, "email": config.email}, - ) + res = _get_requests_session().get(n_gram_url, auth=OpenAlexAuth(config)) res.raise_for_status() results = res.json() diff --git a/tests/test_pyalex.py b/tests/test_pyalex.py index f3d5867..165a4c6 100644 --- a/tests/test_pyalex.py +++ b/tests/test_pyalex.py @@ -288,3 +288,16 @@ def test_sample_seed(): def test_subset(): url = "https://api.openalex.org/works?select=id,doi,display_name" assert url == Works().select(["id", "doi", "display_name"]).url + + +def test_auth(): + w_no_auth = Works().get() + pyalex.config.email = "pyalex_github_unittests@example.com" + pyalex.config.api_key = "my_api_key" + + w_auth = Works().get() + + pyalex.config.email = None + pyalex.config.api_key = None + + assert len(w_no_auth) == len(w_auth)