Resolving multiple issues (#296)
* chore: avoid trying to add Statoil certificate if Equinor Root CA is already found
* chore: avoid default caching web ids in PiHandlerWeb
* chore: allow numpy integers and floats as timedelta arguments and add timeout to search method
* fix: allow None-timedelta when reading Snapshots or Raw data

Refs. #294, #288, #276, #275
mortendaehli authored Dec 5, 2023
1 parent 9469c22 commit c0ac503
Showing 9 changed files with 132 additions and 54 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -38,6 +38,8 @@ print(c.search("tag*"))
df = c.read_tags(["tag1", "tag2"], "18.06.2020 08:00:00", "18.06.2020 09:00:00", 60)
```

Note: you can pass a timeout argument to the search method to avoid long-running search queries.
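
A minimal example along the lines of the quickstart above (reusing the client `c` from that snippet):

```python
# Cap the tag search at 10 seconds to avoid a long-running query
print(c.search("tag*", timeout=10))
```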

### Jupyter Notebook Quickstart
Jupyter Notebook examples can be found in /examples. To run these examples, you need to install the
optional dependencies.
42 changes: 35 additions & 7 deletions tagreader/clients.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.error import HTTPError

import numpy as np
import pandas as pd
import pytz

@@ -199,6 +200,7 @@ def get_handler(
options: Dict[str, Union[int, float, str]],
verifySSL: Optional[bool],
auth: Optional[Any],
cache: Optional[Union[SmartCache, BucketCache]] = None,
):
if imstype is None:
try:
@@ -224,6 +226,7 @@
options=options,
verify_ssl=verifySSL,
auth=auth,
cache=cache,
)

if imstype == IMSType.ASPENONE:
@@ -278,29 +281,36 @@ def __init__(
f"timezone argument 'tz' needs to be either a valid timezone string or a tzinfo-object. Given type was {type(tz)}"
)

self.cache = cache
self.handler = get_handler(
imstype=imstype,
datasource=datasource,
url=url,
options=handler_options,
verifySSL=verifySSL,
auth=auth,
cache=self.cache,
)
self.cache = cache

def connect(self) -> None:
self.handler.connect()

def search_tag(
self, tag: Optional[str] = None, desc: Optional[str] = None
self,
tag: Optional[str] = None,
desc: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[Tuple[str, str]]:
logger.warning("This function is deprecated. Please call 'search()' instead")
return self.search(tag=tag, desc=desc)
return self.search(tag=tag, desc=desc, timeout=timeout)

def search(
self, tag: Optional[str] = None, desc: Optional[str] = None
self,
tag: Optional[str] = None,
desc: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[Tuple[str, str]]:
return self.handler.search(tag=tag, desc=desc)
return self.handler.search(tag=tag, desc=desc, timeout=timeout)

def _get_metadata(self, tag: str):
return self.handler._get_tag_metadata(
@@ -487,7 +497,7 @@ def read(
tags: Union[str, List[str]],
start_time: Optional[Union[datetime, pd.Timestamp, str]] = None,
end_time: Optional[Union[datetime, pd.Timestamp, str]] = None,
ts: Union[timedelta, pd.Timedelta, int] = timedelta(seconds=60),
ts: Optional[Union[timedelta, pd.Timedelta, int]] = timedelta(seconds=60),
read_type: ReaderType = ReaderType.INT,
get_status: bool = False,
) -> pd.DataFrame:
@@ -540,8 +550,26 @@ def read(

if isinstance(ts, pd.Timedelta):
ts = ts.to_pytimedelta()
elif isinstance(ts, (int, float)):
elif isinstance(
ts,
(
int,
float,
np.int32,
np.int64,
np.float32,
np.float64,
np.number,
np.integer,
),
):
ts = timedelta(seconds=int(ts))
elif not ts and read_type not in [ReaderType.SNAPSHOT, ReaderType.RAW]:
raise ValueError(
"ts needs to be a timedelta or an integer (number of seconds)"
" unless you are reading raw or snapshot data."
f" Given type: {type(ts)}"
)
elif not isinstance(ts, timedelta):
raise ValueError(
"ts needs to be either a None, timedelta or and integer (number of seconds)."
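Taken together with the commit message, these changes mean `ts` now also accepts numpy scalars and may be `None` for raw and snapshot reads. A usage sketch, assuming a connected `IMSClient` named `c` and tags that exist in the source:

```python
import numpy as np

from tagreader.utils import ReaderType

# numpy integers/floats are now coerced via timedelta(seconds=int(ts))
df = c.read(["tag1"], "18.06.2020 08:00:00", "18.06.2020 09:00:00", ts=np.int64(60))

# ts=None is allowed for RAW and SNAPSHOT reads, where a sampling
# interval does not apply; other read types still require ts
df_raw = c.read(
    ["tag1"],
    "18.06.2020 08:00:00",
    "18.06.2020 09:00:00",
    ts=None,
    read_type=ReaderType.RAW,
)
```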
2 changes: 1 addition & 1 deletion tagreader/utils.py
@@ -130,7 +130,7 @@ class ReaderType(enum.IntEnum):


def add_statoil_root_certificate() -> bool:
return add_equinor_root_certificate(True) and add_equinor_root_certificate(False)
return add_equinor_root_certificate(True) or add_equinor_root_certificate(False)


def add_equinor_root_certificate(get_equinor: bool = True) -> bool:
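For context (an illustration, not part of the diff): `or` short-circuits, so the fixed version skips the Statoil call whenever the Equinor Root CA call already succeeds (the behaviour named in the commit message), whereas the old `and` always executed both calls and required both to succeed:

```python
def found_equinor_ca() -> bool:
    print("Equinor Root CA found")
    return True

def add_statoil_cert() -> bool:
    print("adding Statoil certificate")
    return True

found_equinor_ca() or add_statoil_cert()   # prints only the first line
found_equinor_ca() and add_statoil_cert()  # old behaviour: both calls run
```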
90 changes: 56 additions & 34 deletions tagreader/web_handlers.py
@@ -14,7 +14,7 @@
import urllib3
from requests_kerberos import OPTIONAL, HTTPKerberosAuth

from tagreader.cache import BaseCache
from tagreader.cache import BaseCache, BucketCache, SmartCache
from tagreader.logger import logger
from tagreader.utils import ReaderType, is_mac, is_windows, urljoin

@@ -115,8 +115,20 @@ def __init__(
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
self.session.verify = verify_ssl if verify_ssl is not None else get_verify_ssl()

def fetch(self, url, params: Optional[Union[str, Dict[str, str]]] = None) -> Dict:
res = self.session.get(url, params=params)
def fetch(
self,
url,
params: Optional[Union[str, Dict[str, str]]] = None,
timeout: Optional[int] = None,
) -> Dict:
res = self.session.get(
url,
params=params,
timeout=(
None,
timeout,
),
) # Noqa. Read timeout, No connect timeout.
res.raise_for_status()

if len(res.text) == 0:
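
For reference (a standalone sketch, not part of the diff): `requests` reads a two-element `timeout` as `(connect_timeout, read_timeout)`, so `(None, timeout)` leaves connecting unbounded and applies the caller's limit only to waiting for the response:

```python
import requests

# (connect timeout, read timeout): None disables the connect limit;
# the server must start responding within 10 seconds
res = requests.get("https://example.com", timeout=(None, 10))
res.raise_for_status()
```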
@@ -366,7 +378,9 @@ def _get_default_mapname(self, tagname: str):
if v:
return k

def search(self, tag: Optional[str], desc: Optional[str]) -> List[Tuple[str, str]]:
def search(
self, tag: Optional[str], desc: Optional[str], timeout: Optional[int] = None
) -> List[Tuple[str, str]]:
if tag is None:
raise ValueError("Tag is a required argument")

@@ -386,7 +400,7 @@ def search(self, tag: Optional[str], desc: Optional[str]) -> List[Tuple[str, str]]:
)
url = urljoin(self.base_url, "Browse?")
url += encoded_params
data = self.fetch(url)
data = self.fetch(url, timeout=timeout)

if "tags" not in data["data"]:
return []
@@ -602,6 +616,7 @@ def __init__(
auth: Optional[Any],
verify_ssl: bool,
options: Dict[str, Union[int, float, str]],
cache: Optional[Union[SmartCache, BucketCache]],
):
self._max_rows = options.get("max_rows", 10000)
if url is None:
@@ -615,7 +630,7 @@ def __init__(
verify_ssl=verify_ssl,
)
self._max_rows = options.get("max_rows", 10000)
self.web_id_cache = BaseCache(directory=Path(".") / ".cache" / datasource)
self.web_id_cache = cache

@staticmethod
def _time_to_UTC_string(time: datetime) -> str:
@@ -772,7 +787,10 @@ def verify_connection(self, datasource: str) -> bool:
return False

def search(
self, tag: Optional[str] = None, desc: Optional[str] = None
self,
tag: Optional[str] = None,
desc: Optional[str] = None,
timeout: Optional[int] = None,
) -> List[Tuple]:
params = self.generate_search_query(
tag=tag, desc=desc, datasource=self.datasource
@@ -781,7 +799,7 @@ def search(
done = False
ret = []
while not done:
data = self.fetch(url, params=params)
data = self.fetch(url, params=params, timeout=timeout)

for item in data["Items"]:
description = item["Description"] if "Description" in item else ""
@@ -823,33 +841,37 @@ def tag_to_web_id(self, tag: str) -> Optional[str]:
:return: WebId
:rtype: str
"""
if tag not in self.web_id_cache:
params = self.generate_search_query(
tag=tag, datasource=self.datasource, desc=None
)
params["fields"] = "name;webid"
url = urljoin(self.base_url, "search", "query")
data = self.fetch(url, params=params)

if len(data["Errors"]) > 0:
msg = f"Received error from server when searching for WebId for {tag}: {data['Errors']}"
raise ValueError(msg)

if len(data["Items"]) > 1:
# Compare elements and if same, return the first
first = data["Items"][0]
for item in data["Items"][1:]:
if item != first:
raise AssertionError(
f"Received {len(data['Items'])} results when trying to find unique WebId for {tag}."
)
elif len(data["Items"]) == 0:
logger.warning(f"Tag {tag} not found")
return None

web_id = data["Items"][0]["WebId"]
if self.web_id_cache and tag in self.web_id_cache:
return self.web_id_cache[tag]

params = self.generate_search_query(
tag=tag, datasource=self.datasource, desc=None
)
params["fields"] = "name;webid"
url = urljoin(self.base_url, "search", "query")
data = self.fetch(url, params=params)

if len(data["Errors"]) > 0:
msg = f"Received error from server when searching for WebId for {tag}: {data['Errors']}"
raise ValueError(msg)

if len(data["Items"]) > 1:
# Compare elements and if same, return the first
first = data["Items"][0]
for item in data["Items"][1:]:
if item != first:
raise AssertionError(
f"Received {len(data['Items'])} results when trying to find unique WebId for {tag}."
)
elif len(data["Items"]) == 0:
logger.warning(f"Tag {tag} not found")
return None

web_id = data["Items"][0]["WebId"]

if self.web_id_cache:
self.web_id_cache[tag] = web_id
return self.web_id_cache[tag]
return web_id

@staticmethod
def _is_summary(read_type: ReaderType) -> bool:
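In effect, `tag_to_web_id` now follows a cache-aside pattern against a caller-supplied cache (injected from `IMSClient` via `get_handler`, see clients.py above) instead of an always-on `BaseCache`. A minimal sketch of the same flow, with `lookup_web_id` as a hypothetical stand-in for the server round-trip:

```python
from typing import Dict, Optional

def tag_to_web_id(tag: str, cache: Optional[Dict[str, str]] = None) -> Optional[str]:
    # 1. Serve from the cache when one was supplied and the tag is known
    if cache is not None and tag in cache:
        return cache[tag]
    # 2. Otherwise resolve against the server (hypothetical helper)
    web_id = lookup_web_id(tag)
    # 3. Populate the cache only when the caller opted in to caching
    if cache is not None and web_id is not None:
        cache[tag] = web_id
    return web_id
```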
Empty file removed tests/__init__.py
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
from pathlib import Path
from typing import Generator

import pytest

from tagreader.cache import SmartCache


@pytest.fixture # type: ignore[misc]
def cache(tmp_path: Path) -> Generator[SmartCache, None, None]:
cache = SmartCache(directory=tmp_path, size_limit=int(4e9))
yield cache
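
Because the fixture now lives in tests/conftest.py, pytest injects it into any test in the directory by parameter name, with no import of the fixture itself needed. A hypothetical test using the SmartCache API as exercised elsewhere in this diff:

```python
from tagreader.cache import SmartCache

def test_cache_roundtrip(cache: SmartCache) -> None:
    # add(), `in` and item access mirror how the handlers use the cache
    cache.add(key="tag1", value="webid1")
    assert "tag1" in cache
    assert cache["tag1"] == "webid1"
```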
20 changes: 16 additions & 4 deletions tests/test_PIHandlerREST.py
@@ -1,8 +1,9 @@
from datetime import timedelta
from typing import Generator
from typing import Generator, cast

import pytest

from tagreader.cache import SmartCache
from tagreader.utils import ReaderType, ensure_datetime_with_tz
from tagreader.web_handlers import PIHandlerWeb

@@ -12,11 +13,18 @@


@pytest.fixture # type: ignore[misc]
def pi_handler() -> Generator[PIHandlerWeb, None, None]:
def pi_handler(cache: SmartCache) -> Generator[PIHandlerWeb, None, None]:
h = PIHandlerWeb(
datasource="sourcename", auth=None, options={}, url=None, verify_ssl=True
datasource="sourcename",
auth=None,
options={},
url=None,
verify_ssl=True,
cache=cache,
)
h.web_id_cache["alreadyknowntag"] = "knownwebid"
if not isinstance(h.web_id_cache, SmartCache):
raise ValueError("Expected SmartCache in the web client.")
h.web_id_cache.add(key="alreadyknowntag", value="knownwebid")
yield h


@@ -89,6 +97,8 @@ def test_is_summary(pi_handler: PIHandlerWeb) -> None:
],
)
def test_generate_read_query(pi_handler: PIHandlerWeb, read_type: str) -> None:
if not isinstance(pi_handler.web_id_cache, SmartCache):
raise ValueError("Expected SmartCache in the fixture.")
start = ensure_datetime_with_tz(START_TIME)
stop = ensure_datetime_with_tz(STOP_TIME)
ts = timedelta(seconds=SAMPLE_TIME)
@@ -161,6 +171,8 @@ def test_generate_read_query(pi_handler: PIHandlerWeb, read_type: str) -> None:
def test_generate_read_query_with_status(
pi_handler: PIHandlerWeb, read_type: str
) -> None:
if not isinstance(pi_handler.web_id_cache, SmartCache):
raise ValueError("Expected SmartCache in the fixture.")
start = ensure_datetime_with_tz(START_TIME)
stop = ensure_datetime_with_tz(STOP_TIME)
ts = timedelta(seconds=SAMPLE_TIME)
12 changes: 10 additions & 2 deletions tests/test_PIHandlerREST_connect.py
@@ -4,6 +4,7 @@

import pytest

from tagreader.cache import SmartCache
from tagreader.clients import IMSClient, list_sources
from tagreader.utils import ReaderType, ensure_datetime_with_tz
from tagreader.web_handlers import PIHandlerWeb, get_verify_ssl, list_piwebapi_sources
@@ -46,10 +47,17 @@ def client() -> Generator[IMSClient, None, None]:


@pytest.fixture # type: ignore[misc]
def pi_handler() -> Generator[PIHandlerWeb, None, None]:
def pi_handler(cache: SmartCache) -> Generator[PIHandlerWeb, None, None]:
h = PIHandlerWeb(
datasource=SOURCE, verify_ssl=bool(verifySSL), auth=None, options={}, url=None
datasource=SOURCE,
verify_ssl=bool(verifySSL),
auth=None,
options={},
url=None,
cache=cache,
)
if not isinstance(h.web_id_cache, SmartCache):
raise ValueError("Expected SmartCache in the web client.")
h.web_id_cache["alreadyknowntag"] = "knownwebid"
yield h

6 changes: 0 additions & 6 deletions tests/test_cache.py
@@ -25,12 +25,6 @@ def data() -> Generator[pd.DataFrame, None, None]:
yield df_total


@pytest.fixture # type: ignore[misc]
def cache(tmp_path: Path) -> Generator[SmartCache, None, None]:
cache = SmartCache(directory=tmp_path, size_limit=int(4e9))
yield cache


def test_base_cache(tmp_path: Path) -> None:
webidcache = BaseCache(directory=tmp_path)

Expand Down
