diff --git a/python/tests/api/writer/test_whylabs_integration.py b/python/tests/api/writer/test_whylabs_integration.py index 4a390cc6f..dc6b559ce 100644 --- a/python/tests/api/writer/test_whylabs_integration.py +++ b/python/tests/api/writer/test_whylabs_integration.py @@ -71,9 +71,8 @@ def _get_org() -> str: @httpretty.activate(allow_net_connect=False, verbose=True) def test_whylabs_writer_throttle_retry(): ENDPOINT = os.environ["WHYLABS_API_ENDPOINT"] - ORG_ID = _get_org() MODEL_ID = "XXX" - uri = f"{ENDPOINT}/v0/organizations/{ORG_ID}/log/async/{MODEL_ID}" + uri = f"{ENDPOINT}/v1/log/async" httpretty.register_uri(httpretty.POST, uri, status=429) # Fake WhyLabs that throttles why.init(reinit=True, force_local=True) data = {"col1": 1, "col2": "foo"} @@ -271,9 +270,8 @@ def test_put_column_schema_retry(): @httpretty.activate(allow_net_connect=False, verbose=True) def test_log_async_retry(): ENDPOINT = os.environ["WHYLABS_API_ENDPOINT"] - ORG_ID = _get_org() MODEL_ID = "xxx" - uri = f"{ENDPOINT}/v0/organizations/{ORG_ID}/log/async/{MODEL_ID}" + uri = f"{ENDPOINT}/v1/log/async" httpretty.register_uri(httpretty.POST, uri, status=429) # Fake WhyLabs that throttles writer = WhyLabsWriter(dataset_id=MODEL_ID) @@ -651,8 +649,9 @@ def test_transaction_aborted(): transaction_id = writer.start_transaction() status, id = writer.write(result) assert status - writer._whylabs_client.abort_transaction(transaction_id) + writer.abort_transaction() status, id = writer.write(result) + assert transaction_id == writer._whylabs_client._transaction_id with pytest.raises(TransactionAbortedException) as e: writer.commit_transaction() assert str(e) == "Transaction has been aborted" diff --git a/python/whylogs/api/writer/whylabs_transaction_writer.py b/python/whylogs/api/writer/whylabs_transaction_writer.py index 8fdca2c38..1da665d5d 100644 --- a/python/whylogs/api/writer/whylabs_transaction_writer.py +++ b/python/whylogs/api/writer/whylabs_transaction_writer.py @@ -67,7 +67,7 @@ def __init__( transaction_id: Optional[str] = None, ): super().__init__(org_id, api_key, dataset_id, api_client, ssl_ca_cert, _timeout_seconds, whylabs_client) - transaction_id = transaction_id or self._whylabs_client.get_transaction_id() # type: ignore + transaction_id = transaction_id or self._whylabs_client._transaction_id or self._whylabs_client.get_transaction_id() # type: ignore self._whylabs_client._transaction_id = transaction_id # type: ignore self._aborted: bool = False