def _grab_total_items(resp):
    """Extract the paging counters from a response's ``Content-Range`` header.

    The header is expected to look like ``items <first>-<last>/<total>``.

    Args:
        resp: Response object exposing a ``headers`` mapping and
            ``request.url`` (the latter is only read to build the error
            message).

    Returns:
        tuple: ``(items_in_this_response, total_items_available)`` where the
        first element is ``last - first + 1``.

    Raises:
        InvalidJSONError: If the ``Content-Range`` header is missing or does
            not have the expected ``items <first>-<last>/<total>`` form.
    """
    try:
        results = re.match(r"^items (\d+)-(\d+)/(\d+)$", resp.headers["Content-Range"])
        # ``results`` is None when the header does not match; the resulting
        # AttributeError is translated below together with a missing header
        # (KeyError) or a non-string header value (TypeError).
        first, last, total = (int(group) for group in results.groups())
    except (AttributeError, KeyError, TypeError, ValueError) as e:
        six.raise_from(InvalidJSONError(
            "Invalid Content-Range was received from " + resp.request.url
        ), e)
    return last - first + 1, total
def as_pages(func, start=0, per_request=0, *args, **kwargs):
    """Create a generator over the pages of a paginated TAXII 2.0 endpoint.

    Args:
        func (callable): Invoked as
            ``func(start=..., per_request=..., *args, **kwargs)`` for every
            page. Its return value is passed to ``_to_json`` and to
            ``_grab_total_items``, so it must be the raw response carrying a
            ``Content-Range`` header (e.g. ``Collection.get_objects`` or
            ``Collection.get_manifest`` called with ``per_request > 0``).
        start (int): Index of the first item to request.
        per_request (int): Number of items to ask for in each request.
        *args: Extra positional arguments forwarded to ``func`` on each call.
        **kwargs: Extra keyword arguments forwarded to ``func`` on each call.

    Yields:
        The decoded JSON body of each page, in order.
    """
    resp = func(start=start, per_request=per_request, *args, **kwargs)
    yield _to_json(resp)
    received, total_available = _grab_total_items(resp)

    # A page of a different size than requested only signals a server-imposed
    # page size when more items remain; a short *final* page is normal and
    # must not trigger the warning.
    if received != per_request and start + received < total_available:
        log.warning(
            "TAXII Server response with different amount of objects! "
            "Setting per_request=%s", received
        )
        per_request = received

    # Advance by what the server actually delivered so that a short page never
    # causes items to be skipped. Stop if the server reports no progress,
    # otherwise the same range would be requested forever.
    start += received
    while received > 0 and start < total_available:
        resp = func(start=start, per_request=per_request, *args, **kwargs)
        yield _to_json(resp)

        received, total_available = _grab_total_items(resp)
        start += received
def get_manifest(self, accept=MEDIA_TYPE_TAXII_V20, start=0, per_request=0, **filter_kwargs):
    """Implement the ``Get Object Manifests`` endpoint (section 5.6).

    A ``Range`` header of the form ``items <start>-<start + per_request - 1>``
    is sent only when ``per_request`` is greater than zero; otherwise the
    request is made without one. For pagination requests use the ``as_pages``
    method.
    """
    self._verify_can_read()
    params = _filter_kwargs_to_query_params(filter_kwargs)

    request_headers = {"Accept": accept}
    if per_request > 0:
        last_item = (start + per_request) - 1
        request_headers["Range"] = "items {}-{}".format(start, last_item)

    return self._conn.get(
        self.manifest_url,
        headers=request_headers,
        params=params,
    )