From 9377dea65b2d90b0c88347129b650e180537dc2f Mon Sep 17 00:00:00 2001 From: JericHunter Date: Mon, 25 Apr 2022 07:56:59 -0700 Subject: [PATCH 1/8] Create test_crud_async.py creating async tests for the crud tests --- .../azure-cosmos/test/test_crud_async.py | 2663 +++++++++++++++++ 1 file changed, 2663 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/test/test_crud_async.py diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py new file mode 100644 index 000000000000..ac0d654ead79 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -0,0 +1,2663 @@ +# -*- coding: utf-8 -*- +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""End to end test. 
+""" +import asyncio +import json +import logging +import os.path +import unittest +import time +from typing import Mapping + +import urllib.parse as urllib +import uuid +import pytest +from azure.core import MatchConditions +from azure.core.exceptions import AzureError, ServiceResponseError +from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse +import azure.cosmos.documents as documents +import azure.cosmos.exceptions as exceptions +from azure.cosmos.http_constants import HttpHeaders, StatusCodes +import test_config +import azure.cosmos._base as base +import azure.cosmos.cosmos_client as cosmos_client +from azure.cosmos.diagnostics import RecordDiagnostics +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos import _retry_utility +import requests +from urllib3.util.retry import Retry + +pytestmark = pytest.mark.cosmosEmulator + +# IMPORTANT NOTES: +# Most test cases in this file create collections in your Azure Cosmos account. +# Collections are billing entities. By running these test cases, you may incur monetary costs on your account. + +# To Run the test, replace the two member fields (masterKey and host) with values +# associated with your Azure Cosmos account. + + +class TimeoutTransport(RequestsTransport): + + def __init__(self, response): + self._response = response + super(TimeoutTransport, self).__init__() + + async def send(self, *args, **kwargs): + if kwargs.pop("passthrough", False): + return super(TimeoutTransport, self).send(*args, **kwargs) + + time.sleep(5) + if isinstance(self._response, Exception): + raise self._response + output = requests.Response() + output.status_code = self._response + response = RequestsTransportResponse(None, output) + return response + + +@pytest.mark.usefixtures("teardown") +class CRUDTests(unittest.TestCase): + """Python CRUD Tests. 
+ """ + + configs = test_config._test_config + host = configs.host + masterKey = configs.masterKey + connectionPolicy = configs.connectionPolicy + last_headers = [] + + async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): + """Assert HTTP failure with status. + + :Parameters: + - `status_code`: int + - `func`: function + """ + try: + func(*args, **kwargs) + self.assertFalse(True, 'function should fail.') + except exceptions.CosmosHttpResponseError as inst: + self.assertEqual(inst.status_code, status_code) + + @classmethod + async def setUpClass(cls): + if (cls.masterKey == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, connection_policy=cls.connectionPolicy) + cls.databaseForTest = cls.configs.create_database_if_not_exist(cls.client) + + async def setUp(self): + self.client = cosmos_client.CosmosClient(self.host, self.masterKey, "Session", + connection_policy=self.connectionPolicy) + async def test_database_crud(self): + # read databases. + databases = list(self.client.list_databases()) + # create a database. + before_create_databases_count = len(databases) + database_id = str(uuid.uuid4()) + created_db = self.client.create_database(database_id) + self.assertEqual(created_db.id, database_id) + # Read databases after creation. + databases = list(self.client.list_databases()) + self.assertEqual(len(databases), + before_create_databases_count + 1, + 'create should increase the number of databases') + # query databases. + databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': database_id} + ] + })) + self.assertTrue(databases, 'number of results for the query should be > 0') + + # read database. 
+ self.client.get_database_client(created_db.id) + + # delete database. + self.client.delete_database(created_db.id) + # read database after deletion + read_db = self.client.get_database_client(created_db.id) + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + read_db.read) + + database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=10000) + self.assertEqual(database_id, database_proxy.id) + self.assertEqual(10000, database_proxy.read_offer().offer_throughput) + + database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=9000) + self.assertEqual(database_id, database_proxy.id) + self.assertEqual(10000, database_proxy.read_offer().offer_throughput) + + self.client.delete_database(database_id) + + @pytest.mark.skip("skipping as the TestResources subscription doesn't support this offer") + async def test_database_level_offer_throughput(self): + # Create a database with throughput + offer_throughput = 1000 + database_id = str(uuid.uuid4()) + created_db = self.client.create_database( + id=database_id, + offer_throughput=offer_throughput + ) + self.assertEqual(created_db.id, database_id) + + # Verify offer throughput for database + offer = created_db.read_offer() + self.assertEqual(offer.offer_throughput, offer_throughput) + + # Update database offer throughput + new_offer_throughput = 2000 + offer = created_db.replace_throughput(new_offer_throughput) + self.assertEqual(offer.offer_throughput, new_offer_throughput) + self.client.delete_database(created_db.id) + + async def test_sql_query_crud(self): + # create two databases. + db1 = self.client.create_database('database 1' + str(uuid.uuid4())) + db2 = self.client.create_database('database 2' + str(uuid.uuid4())) + + # query with parameters. 
+ databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': db1.id} + ] + })) + self.assertEqual(1, len(databases), 'Unexpected number of query results.') + + # query without parameters. + databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id="database non-existing"' + })) + self.assertEqual(0, len(databases), 'Unexpected number of query results.') + + # query with a string. + databases = list(self.client.query_databases('SELECT * FROM root r WHERE r.id="' + db2.id + '"')) # nosec + self.assertEqual(1, len(databases), 'Unexpected number of query results.') + self.client.delete_database(db1.id) + self.client.delete_database(db2.id) + + async def test_collection_crud(self): + created_db = self.databaseForTest + collections = list(created_db.list_containers()) + # create a collection + before_create_collections_count = len(collections) + collection_id = 'test_collection_crud ' + str(uuid.uuid4()) + collection_indexing_policy = {'indexingMode': 'consistent'} + created_recorder = RecordDiagnostics() + created_collection = created_db.create_container(id=collection_id, + indexing_policy=collection_indexing_policy, + partition_key=PartitionKey(path="/pk", kind="Hash"), + response_hook=created_recorder) + self.assertEqual(collection_id, created_collection.id) + assert isinstance(created_recorder.headers, Mapping) + assert 'Content-Type' in created_recorder.headers + assert isinstance(created_recorder.body, Mapping) + assert 'id' in created_recorder.body + + created_properties = created_collection.read() + self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode']) + + # read collections after creation + collections = list(created_db.list_containers()) + self.assertEqual(len(collections), + before_create_collections_count + 1, + 'create should increase the number of collections') + # query collections + collections = 
list(created_db.query_containers( + { + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': collection_id} + ] + })) + + self.assertTrue(collections) + # delete collection + created_db.delete_container(created_collection.id) + # read collection after deletion + created_container = created_db.get_container_client(created_collection.id) + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + created_container.read) + + container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, partition_key=PartitionKey(path='/id', kind='Hash')) + self.assertEqual(created_collection.id, container_proxy.id) + self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) + + container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, partition_key=created_properties['partitionKey']) + self.assertEqual(created_container.id, container_proxy.id) + self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) + + created_db.delete_container(created_collection.id) + + async def test_partitioned_collection(self): + created_db = self.databaseForTest + + collection_definition = {'id': 'test_partitioned_collection ' + str(uuid.uuid4()), + 'partitionKey': + { + 'paths': ['/id'], + 'kind': documents.PartitionKind.Hash + } + } + + offer_throughput = 10100 + created_collection = created_db.create_container(id=collection_definition['id'], + partition_key=collection_definition['partitionKey'], + offer_throughput=offer_throughput) + + self.assertEqual(collection_definition.get('id'), created_collection.id) + + created_collection_properties = created_collection.read() + self.assertEqual(collection_definition.get('partitionKey').get('paths')[0], + created_collection_properties['partitionKey']['paths'][0]) + self.assertEqual(collection_definition.get('partitionKey').get('kind'), + 
created_collection_properties['partitionKey']['kind']) + + expected_offer = created_collection.read_offer() + + self.assertIsNotNone(expected_offer) + + self.assertEqual(expected_offer.offer_throughput, offer_throughput) + + created_db.delete_container(created_collection.id) + + async def test_partitioned_collection_quota(self): + created_db = self.databaseForTest + + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + retrieved_collection = created_db.get_container_client( + container=created_collection.id + ) + + retrieved_collection_properties = retrieved_collection.read( + populate_partition_key_range_statistics=True, + populate_quota_info=True) + self.assertIsNotNone(retrieved_collection_properties.get("statistics")) + self.assertIsNotNone(created_db.client_connection.last_response_headers.get("x-ms-resource-usage")) + + async def test_partitioned_collection_partition_key_extraction(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_partition_key_extraction ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address/state', kind=documents.PartitionKind.Hash) + ) + + document_definition = {'id': 'document1', + 'address': {'street': '1 Microsoft Way', + 'city': 'Redmond', + 'state': 'WA', + 'zip code': 98052 + } + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # create document without partition key being specified + created_document = created_collection.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[1], '["WA"]') + del self.last_headers[:] + + self.assertEqual(created_document.get('id'), document_definition.get('id')) + self.assertEqual(created_document.get('address').get('state'), 
document_definition.get('address').get('state')) + + collection_id = 'test_partitioned_collection_partition_key_extraction1 ' + str(uuid.uuid4()) + created_collection1 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address', kind=documents.PartitionKind.Hash) + ) + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # Create document with partitionkey not present as a leaf level property but a dict + created_document = created_collection1.create_item(document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[1], [{}]) + del self.last_headers[:] + + #self.assertEqual(options['partitionKey'], documents.Undefined) + + collection_id = 'test_partitioned_collection_partition_key_extraction2 ' + str(uuid.uuid4()) + created_collection2 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address/state/city', kind=documents.PartitionKind.Hash) + ) + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # Create document with partitionkey not present in the document + created_document = created_collection2.create_item(document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[1], [{}]) + del self.last_headers[:] + + #self.assertEqual(options['partitionKey'], documents.Undefined) + + created_db.delete_container(created_collection.id) + created_db.delete_container(created_collection1.id) + created_db.delete_container(created_collection2.id) + + async def test_partitioned_collection_partition_key_extraction_special_chars(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars1 ' + str(uuid.uuid4()) + + created_collection1 = created_db.create_container( + 
id=collection_id, + partition_key=PartitionKey(path='/\"level\' 1*()\"/\"le/vel2\"', kind=documents.PartitionKind.Hash) + ) + document_definition = {'id': 'document1', + "level' 1*()": {"le/vel2": 'val1'} + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + created_document = created_collection1.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[1], '["val1"]') + del self.last_headers[:] + + collection_definition2 = { + 'id': 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()), + 'partitionKey': + { + 'paths': ['/\'level\" 1*()\'/\'le/vel2\''], + 'kind': documents.PartitionKind.Hash + } + } + + collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()) + + created_collection2 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/\'level\" 1*()\'/\'le/vel2\'', kind=documents.PartitionKind.Hash) + ) + + document_definition = {'id': 'document2', + 'level\" 1*()': {'le/vel2': 'val2'} + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # create document without partition key being specified + created_document = created_collection2.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[1], '["val2"]') + del self.last_headers[:] + + created_db.delete_container(created_collection1.id) + created_db.delete_container(created_collection2.id) + + async def test_partitioned_collection_path_parser(self): + test_dir = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(test_dir, "BaselineTest.PathParser.json")) as json_file: + entries = json.loads(json_file.read()) + for entry in entries: + parts = 
base.ParsePaths([entry['path']]) + self.assertEqual(parts, entry['parts']) + + paths = ["/\"Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1\"/*"] + parts = ["Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1", "*"] + self.assertEqual(parts, base.ParsePaths(paths)) + + paths = ["/'Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1'/*"] + parts = ["Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1", "*"] + self.assertEqual(parts, base.ParsePaths(paths)) + + async def test_partitioned_collection_document_crud_and_query(self): + created_db = self.databaseForTest + + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + document_definition = {'id': 'document', + 'key': 'value'} + + created_document = created_collection.create_item( + body=document_definition + ) + + self.assertEqual(created_document.get('id'), document_definition.get('id')) + self.assertEqual(created_document.get('key'), document_definition.get('key')) + + # read document + read_document = created_collection.read_item( + item=created_document.get('id'), + partition_key=created_document.get('id') + ) + + self.assertEqual(read_document.get('id'), created_document.get('id')) + self.assertEqual(read_document.get('key'), created_document.get('key')) + + # Read document feed doesn't require partitionKey as it's always a cross partition query + documentlist = list(created_collection.read_all_items()) + self.assertEqual(1, len(documentlist)) + + # replace document + document_definition['key'] = 'new value' + + replaced_document = created_collection.replace_item( + item=read_document, + body=document_definition + ) + + self.assertEqual(replaced_document.get('key'), document_definition.get('key')) + + # upsert document(create scenario) + document_definition['id'] = 'document2' + document_definition['key'] = 'value2' + + upserted_document = created_collection.upsert_item(body=document_definition) + + self.assertEqual(upserted_document.get('id'), 
document_definition.get('id')) + self.assertEqual(upserted_document.get('key'), document_definition.get('key')) + + documentlist = list(created_collection.read_all_items()) + self.assertEqual(2, len(documentlist)) + + # delete document + created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('id')) + + # query document on the partition key specified in the predicate will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value + documentlist = list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.id=\'' + replaced_document.get('id') + '\'' # nosec + })) + self.assertEqual(1, len(documentlist)) + + # query document on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value + try: + list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec + })) + except Exception: + pass + + # cross partition query + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + enable_cross_partition_query=True + )) + + self.assertEqual(1, len(documentlist)) + + # query document by providing the partitionKey value + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + partition_key=replaced_document.get('id') + )) + + self.assertEqual(1, len(documentlist)) + + async def test_partitioned_collection_permissions(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_permissions all collection' + str(uuid.uuid4()) + + all_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) + ) + + collection_id = 'test_partitioned_collection_permissions read 
collection' + str(uuid.uuid4()) + + read_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) + ) + + user = created_db.create_user(body={'id': 'user' + str(uuid.uuid4())}) + + permission_definition = { + 'id': 'all permission', + 'permissionMode': documents.PermissionMode.All, + 'resource': all_collection.container_link, + 'resourcePartitionKey': [1] + } + + all_permission = user.create_permission(body=permission_definition) + + permission_definition = { + 'id': 'read permission', + 'permissionMode': documents.PermissionMode.Read, + 'resource': read_collection.container_link, + 'resourcePartitionKey': [1] + } + + read_permission = user.create_permission(body=permission_definition) + + resource_tokens = {} + # storing the resource tokens based on Resource IDs + resource_tokens["dbs/" + created_db.id + "/colls/" + all_collection.id] = (all_permission.properties['_token']) + resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = (read_permission.properties['_token']) + + restricted_client = cosmos_client.CosmosClient( + CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) + + document_definition = {'id': 'document1', + 'key': 1 + } + + all_collection.client_connection = restricted_client.client_connection + read_collection.client_connection = restricted_client.client_connection + + # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions + created_document = all_collection.create_item(body=document_definition) + + # Create document in read_collection should fail since it has only read permissions for this collection + await self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + read_collection.create_item, + document_definition) + + document_definition['key'] = 2 + # Create document should fail since 
the partitionKey is 2 which is different from what is specified as resourcePartitionKey in permission object
self.__AssertHTTPFailureWithStatus( + StatusCodes.BAD_REQUEST, + created_collection.scripts.execute_stored_procedure, + created_sproc['id'], + 3) + + async def test_partitioned_collection_partition_key_value_types(self): + created_db = self.databaseForTest + + created_collection = created_db.create_container( + id='test_partitioned_collection_partition_key_value_types ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/pk', kind='Hash') + ) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': None, + 'spam': 'eggs'} + + # create document with partitionKey set as None here + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'spam': 'eggs'} + + # create document with partitionKey set as Undefined here + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': True, + 'spam': 'eggs'} + + # create document with bool partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 'value', + 'spam': 'eggs'} + + # create document with string partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 100, + 'spam': 'eggs'} + + # create document with int partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 10.50, + 'spam': 'eggs'} + + # create document with float partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'name': 'sample document', + 'spam': 'eggs', + 'pk': 'value'} + + # Should throw an error because automatic id generation is disabled always. 
+ await self.__AssertHTTPFailureWithStatus( + StatusCodes.BAD_REQUEST, + created_collection.create_item, + document_definition + ) + + created_db.delete_container(created_collection) + + async def test_partitioned_collection_conflict_crud_and_query(self): + created_db = self.databaseForTest + + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + conflict_definition = {'id': 'new conflict', + 'resourceId': 'doc1', + 'operationType': 'create', + 'resourceType': 'document' + } + + # read conflict here will return resource not found(404) since there is no conflict here + await self.__AssertHTTPFailureWithStatus( + StatusCodes.NOT_FOUND, + created_collection.get_conflict, + conflict_definition['id'], + conflict_definition['id'] + ) + + # Read conflict feed doesn't requires partitionKey to be specified as it's a cross partition thing + conflictlist = list(created_collection.list_conflicts()) + self.assertEqual(0, len(conflictlist)) + + # delete conflict here will return resource not found(404) since there is no conflict here + await self.__AssertHTTPFailureWithStatus( + StatusCodes.NOT_FOUND, + created_collection.delete_conflict, + conflict_definition['id'], + conflict_definition['id'] + ) + + # query conflicts on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value + try: + list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get( # nosec + 'resourceType') + '\'' + )) + except Exception: + pass + + conflictlist = list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', # nosec + enable_cross_partition_query=True + )) + + self.assertEqual(0, len(conflictlist)) + + # query conflicts by providing the partitionKey value + options = {'partitionKey': conflict_definition.get('id')} + conflictlist = 
list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', # nosec + partition_key=conflict_definition['id'] + )) + + self.assertEqual(0, len(conflictlist)) + + async def test_document_crud(self): + # create database + created_db = self.databaseForTest + # create collection + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + # read documents + documents = list(created_collection.read_all_items()) + # create a document + before_create_documents_count = len(documents) + + # create a document with auto ID generation + document_definition = {'name': 'sample document', + 'spam': 'eggs', + 'key': 'value'} + + created_document = created_collection.create_item(body=document_definition, enable_automatic_id_generation=True) + self.assertEqual(created_document.get('name'), + document_definition['name']) + + document_definition = {'name': 'sample document', + 'spam': 'eggs', + 'key': 'value', + 'id': str(uuid.uuid4())} + + created_document = created_collection.create_item(body=document_definition) + self.assertEqual(created_document.get('name'), + document_definition['name']) + self.assertEqual(created_document.get('id'), + document_definition['id']) + + # duplicated documents are not allowed when 'id' is provided. 
+ duplicated_definition_with_id = document_definition.copy() + await self.__AssertHTTPFailureWithStatus(StatusCodes.CONFLICT, + created_collection.create_item, + duplicated_definition_with_id) + # read documents after creation + documents = list(created_collection.read_all_items()) + self.assertEqual( + len(documents), + before_create_documents_count + 2, + 'create should increase the number of documents') + # query documents + documents = list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.name=@name', + 'parameters': [ + {'name': '@name', 'value': document_definition['name']} + ] + }, enable_cross_partition_query=True + )) + self.assertTrue(documents) + documents = list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.name=@name', + 'parameters': [ + {'name': '@name', 'value': document_definition['name']} + ], + }, enable_cross_partition_query=True, + enable_scan_in_query=True + )) + self.assertTrue(documents) + # replace document. 
+ created_document['name'] = 'replaced document' + created_document['spam'] = 'not eggs' + old_etag = created_document['_etag'] + replaced_document = created_collection.replace_item( + item=created_document['id'], + body=created_document + ) + self.assertEqual(replaced_document['name'], + 'replaced document', + 'document id property should change') + self.assertEqual(replaced_document['spam'], + 'not eggs', + 'property should have changed') + self.assertEqual(created_document['id'], + replaced_document['id'], + 'document id should stay the same') + + # replace document based on condition + replaced_document['name'] = 'replaced document based on condition' + replaced_document['spam'] = 'new spam field' + + # should fail for stale etag + await self.__AssertHTTPFailureWithStatus( + StatusCodes.PRECONDITION_FAILED, + created_collection.replace_item, + replaced_document['id'], + replaced_document, + if_match=old_etag, + ) + + # should fail if only etag specified + with self.assertRaises(ValueError): + created_collection.replace_item( + etag=replaced_document['_etag'], + item=replaced_document['id'], + body=replaced_document + ) + + # should fail if only match condition specified + with self.assertRaises(ValueError): + created_collection.replace_item( + match_condition=MatchConditions.IfNotModified, + item=replaced_document['id'], + body=replaced_document + ) + with self.assertRaises(ValueError): + created_collection.replace_item( + match_condition=MatchConditions.IfModified, + item=replaced_document['id'], + body=replaced_document + ) + + # should fail if invalid match condition specified + with self.assertRaises(TypeError): + created_collection.replace_item( + match_condition=replaced_document['_etag'], + item=replaced_document['id'], + body=replaced_document + ) + + # should pass for most recent etag + replaced_document_conditional = created_collection.replace_item( + match_condition=MatchConditions.IfNotModified, + etag=replaced_document['_etag'], + 
item=replaced_document['id'], + body=replaced_document + ) + self.assertEqual(replaced_document_conditional['name'], + 'replaced document based on condition', + 'document id property should change') + self.assertEqual(replaced_document_conditional['spam'], + 'new spam field', + 'property should have changed') + self.assertEqual(replaced_document_conditional['id'], + replaced_document['id'], + 'document id should stay the same') + # read document + one_document_from_read = created_collection.read_item( + item=replaced_document['id'], + partition_key=replaced_document['id'] + ) + self.assertEqual(replaced_document['id'], + one_document_from_read['id']) + # delete document + created_collection.delete_item( + item=replaced_document, + partition_key=replaced_document['id'] + ) + # read documents after deletion + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + created_collection.read_item, + replaced_document['id'], + replaced_document['id']) + + async def test_document_upsert(self): + # create database + created_db = self.databaseForTest + + # create collection + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + # read documents and check count + documents = list(created_collection.read_all_items()) + before_create_documents_count = len(documents) + + # create document definition + document_definition = {'id': 'doc', + 'name': 'sample document', + 'spam': 'eggs', + 'key': 'value'} + + # create document using Upsert API + created_document = created_collection.upsert_item(body=document_definition) + + # verify id property + self.assertEqual(created_document['id'], + document_definition['id']) + + # test error for non-string id + with pytest.raises(TypeError): + document_definition['id'] = 7 + created_collection.upsert_item(body=document_definition) + + # read documents after creation and verify updated count + documents = list(created_collection.read_all_items()) + self.assertEqual( + len(documents), + 
before_create_documents_count + 1, + 'create should increase the number of documents') + + # update document + created_document['name'] = 'replaced document' + created_document['spam'] = 'not eggs' + + # should replace document since it already exists + upserted_document = created_collection.upsert_item(body=created_document) + + # verify the changed properties + self.assertEqual(upserted_document['name'], + created_document['name'], + 'document name property should change') + self.assertEqual(upserted_document['spam'], + created_document['spam'], + 'property should have changed') + + # verify id property + self.assertEqual(upserted_document['id'], + created_document['id'], + 'document id should stay the same') + + # read documents after upsert and verify count doesn't increase again + documents = list(created_collection.read_all_items()) + self.assertEqual( + len(documents), + before_create_documents_count + 1, + 'number of documents should remain same') + + created_document['id'] = 'new id' + + # Upsert should create new document since the id is different + new_document = created_collection.upsert_item(body=created_document) + + # Test modified access conditions + created_document['spam'] = 'more eggs' + created_collection.upsert_item(body=created_document) + with pytest.raises(exceptions.CosmosHttpResponseError): + created_collection.upsert_item( + body=created_document, + match_condition=MatchConditions.IfNotModified, + etag=new_document['_etag']) + + # verify id property + self.assertEqual(created_document['id'], + new_document['id'], + 'document id should be same') + + # read documents after upsert and verify count increases + documents = list(created_collection.read_all_items()) + self.assertEqual( + len(documents), + before_create_documents_count + 2, + 'upsert should increase the number of documents') + + # delete documents + created_collection.delete_item(item=upserted_document, partition_key=upserted_document['id']) + 
created_collection.delete_item(item=new_document, partition_key=new_document['id']) + + # read documents after delete and verify count is same as original + documents = list(created_collection.read_all_items()) + self.assertEqual( + len(documents), + before_create_documents_count, + 'number of documents should remain same') + + + async def _test_spatial_index(self): + db = self.databaseForTest + # partial policy specified + collection = db.create_container( + id='collection with spatial index ' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/"Location"/?', + 'indexes': [ + { + 'kind': 'Spatial', + 'dataType': 'Point' + } + ] + }, + { + 'path': '/' + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection.create_item( + body={ + 'id': 'loc1', + 'Location': { + 'type': 'Point', + 'coordinates': [20.0, 20.0] + } + } + ) + collection.create_item( + body={ + 'id': 'loc2', + 'Location': { + 'type': 'Point', + 'coordinates': [100.0, 100.0] + } + } + ) + results = list(collection.query_items( + query="SELECT * FROM root WHERE (ST_DISTANCE(root.Location, {type: 'Point', coordinates: [20.1, 20]}) < 20000)", + enable_cross_partition_query=True + )) + self.assertEqual(1, len(results)) + self.assertEqual('loc1', results[0]['id']) + + db.delete_container(container=collection) + + # CRUD test for User resource + async def test_user_crud(self): + # Should do User CRUD operations successfully. 
+ # create database + db = self.databaseForTest + # list users + users = list(db.list_users()) + before_create_count = len(users) + # create user + user_id = 'new user' + str(uuid.uuid4()) + user = db.create_user(body={'id': user_id}) + self.assertEqual(user.id, user_id, 'user id error') + # list users after creation + users = list(db.list_users()) + self.assertEqual(len(users), before_create_count + 1) + # query users + results = list(db.query_users( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': user_id} + ] + )) + self.assertTrue(results) + + # replace user + replaced_user_id = 'replaced user' + str(uuid.uuid4()) + user_properties = user.read() + user_properties['id'] = replaced_user_id + replaced_user = db.replace_user(user_id, user_properties) + self.assertEqual(replaced_user.id, + replaced_user_id, + 'user id should change') + self.assertEqual(user_properties['id'], + replaced_user.id, + 'user id should stay the same') + # read user + user = db.get_user_client(replaced_user.id) + self.assertEqual(replaced_user.id, user.id) + # delete user + db.delete_user(user.id) + # read user after deletion + deleted_user = db.get_user_client(user.id) + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + deleted_user.read) + + async def test_user_upsert(self): + # create database + db = self.databaseForTest + + # read users and check count + users = list(db.list_users()) + before_create_count = len(users) + + # create user using Upsert API + user_id = 'user' + str(uuid.uuid4()) + user = db.upsert_user(body={'id': user_id}) + + # verify id property + self.assertEqual(user.id, user_id, 'user id error') + + # read users after creation and verify updated count + users = list(db.list_users()) + self.assertEqual(len(users), before_create_count + 1) + + # Should replace the user since it already exists, there is no public property to change here + user_properties = user.read() + upserted_user = db.upsert_user(user_properties) 
+ + # verify id property + self.assertEqual(upserted_user.id, + user.id, + 'user id should remain same') + + # read users after upsert and verify count doesn't increases again + users = list(db.list_users()) + self.assertEqual(len(users), before_create_count + 1) + + user_properties = user.read() + user_properties['id'] = 'new user' + str(uuid.uuid4()) + user.id = user_properties['id'] + + # Upsert should create new user since id is different + new_user = db.upsert_user(user_properties) + + # verify id property + self.assertEqual(new_user.id, user.id, 'user id error') + + # read users after upsert and verify count increases + users = list(db.list_users()) + self.assertEqual(len(users), before_create_count + 2) + + # delete users + db.delete_user(upserted_user.id) + db.delete_user(new_user.id) + + # read users after delete and verify count remains the same + users = list(db.list_users()) + self.assertEqual(len(users), before_create_count) + + async def test_permission_crud(self): + # Should do Permission CRUD operations successfully + # create database + db = self.databaseForTest + # create user + user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) + # list permissions + permissions = list(user.list_permissions()) + before_create_count = len(permissions) + permission = { + 'id': 'new permission', + 'permissionMode': documents.PermissionMode.Read, + 'resource': 'dbs/AQAAAA==/colls/AQAAAJ0fgTc=' # A random one. 
+ } + # create permission + permission = user.create_permission(permission) + self.assertEqual(permission.id, + 'new permission', + 'permission id error') + # list permissions after creation + permissions = list(user.list_permissions()) + self.assertEqual(len(permissions), before_create_count + 1) + # query permissions + results = list(user.query_permissions( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': permission.id} + ] + )) + self.assertTrue(results) + + # replace permission + change_permission = permission.properties.copy() + permission.properties['id'] = 'replaced permission' + permission.id = permission.properties['id'] + replaced_permission = user.replace_permission(change_permission['id'], permission.properties) + self.assertEqual(replaced_permission.id, + 'replaced permission', + 'permission id should change') + self.assertEqual(permission.id, + replaced_permission.id, + 'permission id should stay the same') + # read permission + permission = user.get_permission(replaced_permission.id) + self.assertEqual(replaced_permission.id, permission.id) + # delete permission + user.delete_permission(replaced_permission.id) + # read permission after deletion + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + user.get_permission, + permission.id) + + async def test_permission_upsert(self): + # create database + db = self.databaseForTest + + # create user + user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) + + # read permissions and check count + permissions = list(user.list_permissions()) + before_create_count = len(permissions) + + permission_definition = { + 'id': 'permission', + 'permissionMode': documents.PermissionMode.Read, + 'resource': 'dbs/AQAAAA==/colls/AQAAAJ0fgTc=' # A random one. 
+ } + + # create permission using Upsert API + created_permission = user.upsert_permission(permission_definition) + + # verify id property + self.assertEqual(created_permission.id, + permission_definition['id'], + 'permission id error') + + # read permissions after creation and verify updated count + permissions = list(user.list_permissions()) + self.assertEqual(len(permissions), before_create_count + 1) + + # update permission mode + permission_definition['permissionMode'] = documents.PermissionMode.All + + # should replace the permission since it already exists + upserted_permission = user.upsert_permission(permission_definition) + # verify id property + self.assertEqual(upserted_permission.id, + created_permission.id, + 'permission id should remain same') + + # verify changed property + self.assertEqual(upserted_permission.permission_mode, + permission_definition['permissionMode'], + 'permissionMode should change') + + # read permissions and verify count doesn't increase again + permissions = list(user.list_permissions()) + self.assertEqual(len(permissions), before_create_count + 1) + + # update permission id + created_permission.properties['id'] = 'new permission' + created_permission.id = created_permission.properties['id'] + # resource needs to be changed along with the id in order to create a new permission + created_permission.properties['resource'] = 'dbs/N9EdAA==/colls/N9EdAIugXgA=' + created_permission.resource_link = created_permission.properties['resource'] + + # should create new permission since id has changed + new_permission = user.upsert_permission(created_permission.properties) + + # verify id and resource property + self.assertEqual(new_permission.id, + created_permission.id, + 'permission id should be same') + + self.assertEqual(new_permission.resource_link, + created_permission.resource_link, + 'permission resource should be same') + + # read permissions and verify count increases + permissions = list(user.list_permissions()) + 
self.assertEqual(len(permissions), before_create_count + 2) + + # delete permissions + user.delete_permission(upserted_permission.id) + user.delete_permission(new_permission.id) + + # read permissions and verify count remains the same + permissions = list(user.list_permissions()) + self.assertEqual(len(permissions), before_create_count) + + async def test_authorization(self): + async def __SetupEntities(client): + """ + Sets up entities for this test. + + :Parameters: + - `client`: cosmos_client_connection.CosmosClientConnection + + :Returns: + dict + + """ + # create database + db = self.databaseForTest + # create collection + collection = db.create_container( + id='test_authorization' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/id', kind='Hash') + ) + # create document1 + document = collection.create_item( + body={'id': 'doc1', + 'spam': 'eggs', + 'key': 'value'}, + ) + + # create user + user = db.create_user(body={'id': 'user' + str(uuid.uuid4())}) + + # create permission for collection + permission = { + 'id': 'permission On Coll', + 'permissionMode': documents.PermissionMode.Read, + 'resource': "dbs/" + db.id + "/colls/" + collection.id + } + permission_on_coll = user.create_permission(body=permission) + self.assertIsNotNone(permission_on_coll.properties['_token'], + 'permission token is invalid') + + # create permission for document + permission = { + 'id': 'permission On Doc', + 'permissionMode': documents.PermissionMode.All, + 'resource': "dbs/" + db.id + "/colls/" + collection.id + "/docs/" + document["id"] + } + permission_on_doc = user.create_permission(body=permission) + self.assertIsNotNone(permission_on_doc.properties['_token'], + 'permission token is invalid') + + entities = { + 'db': db, + 'coll': collection, + 'doc': document, + 'user': user, + 'permissionOnColl': permission_on_coll, + 'permissionOnDoc': permission_on_doc, + } + return entities + + # Client without any authorization will fail. 
+ client = cosmos_client.CosmosClient(CRUDTests.host, {}, "Session", connection_policy=CRUDTests.connectionPolicy) + await self.__AssertHTTPFailureWithStatus(StatusCodes.UNAUTHORIZED, + list, + client.list_databases()) + # Client with master key. + client = cosmos_client.CosmosClient(CRUDTests.host, + CRUDTests.masterKey, + "Session", + connection_policy=CRUDTests.connectionPolicy) + # setup entities + entities = __SetupEntities(client) + resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: + entities['permissionOnColl'].properties['_token']} + col_client = cosmos_client.CosmosClient( + CRUDTests.host, resource_tokens,"Session", connection_policy=CRUDTests.connectionPolicy) + db = entities['db'] + + old_client_connection = db.client_connection + db.client_connection = col_client.client_connection + # 1. Success-- Use Col Permission to Read + success_coll = db.get_container_client(container=entities['coll']) + # 2. Failure-- Use Col Permission to delete + await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, + db.delete_container, + success_coll) + # 3. Success-- Use Col Permission to Read All Docs + success_documents = list(success_coll.read_all_items()) + self.assertTrue(success_documents != None, + 'error reading documents') + self.assertEqual(len(success_documents), + 1, + 'Expected 1 Document to be succesfully read') + # 4. Success-- Use Col Permission to Read Doc + + docId = entities['doc']['id'] + success_doc = success_coll.read_item( + item=docId, + partition_key=docId + ) + self.assertTrue(success_doc != None, 'error reading document') + self.assertEqual( + success_doc['id'], + entities['doc']['id'], + 'Expected to read children using parent permissions') + + #5. 
Failure-- Use Col Permission to Delete Doc + await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, + success_coll.delete_item, + docId, docId) + + resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId : + entities['permissionOnDoc'].properties['_token']} + + doc_client = cosmos_client.CosmosClient( + CRUDTests.host, resource_tokens,"Session", connection_policy=CRUDTests.connectionPolicy) + + #6. Success-- Use Doc permission to read doc + read_doc = doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId) + self.assertEqual(read_doc["id"], docId) + + #6. Success-- Use Doc permission to delete doc + doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) + self.assertEqual(read_doc["id"], docId) + + db.client_connection = old_client_connection + db.delete_container(entities['coll']) + + async def test_trigger_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + # read triggers + triggers = list(collection.scripts.list_triggers()) + # create a trigger + before_create_triggers_count = len(triggers) + trigger_definition = { + 'id': 'sample trigger', + 'serverScript': 'function() {var x = 10;}', + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + } + trigger = collection.scripts.create_trigger(body=trigger_definition) + for property in trigger_definition: + if property != "serverScript": + self.assertEqual( + trigger[property], + trigger_definition[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(trigger['body'], + 'function() {var x = 10;}') + + # read triggers after creation + triggers = list(collection.scripts.list_triggers()) + self.assertEqual(len(triggers), + before_create_triggers_count + 1, + 'create should 
increase the number of triggers') + # query triggers + triggers = list(collection.scripts.query_triggers( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': trigger_definition['id']} + ] + )) + self.assertTrue(triggers) + + # replace trigger + change_trigger = trigger.copy() + trigger['body'] = 'function() {var x = 20;}' + replaced_trigger = collection.scripts.replace_trigger(change_trigger['id'], trigger) + for property in trigger_definition: + if property != "serverScript": + self.assertEqual( + replaced_trigger[property], + trigger[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(replaced_trigger['body'], + 'function() {var x = 20;}') + + # read trigger + trigger = collection.scripts.get_trigger(replaced_trigger['id']) + self.assertEqual(replaced_trigger['id'], trigger['id']) + # delete trigger + collection.scripts.delete_trigger(replaced_trigger['id']) + # read triggers after deletion + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.delete_trigger, + replaced_trigger['id']) + + async def test_udf_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + # read udfs + udfs = list(collection.scripts.list_user_defined_functions()) + # create a udf + before_create_udfs_count = len(udfs) + udf_definition = { + 'id': 'sample udf', + 'body': 'function() {var x = 10;}' + } + udf = collection.scripts.create_user_defined_function(body=udf_definition) + for property in udf_definition: + self.assertEqual( + udf[property], + udf_definition[property], + 'property {property} should match'.format(property=property)) + + # read udfs after creation + udfs = list(collection.scripts.list_user_defined_functions()) + self.assertEqual(len(udfs), + before_create_udfs_count + 1, + 'create should increase the number of udfs') + # query udfs + 
results = list(collection.scripts.query_user_defined_functions( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': udf_definition['id']} + ] + )) + self.assertTrue(results) + # replace udf + change_udf = udf.copy() + udf['body'] = 'function() {var x = 20;}' + replaced_udf = collection.scripts.replace_user_defined_function(udf=udf['id'], body=udf) + for property in udf_definition: + self.assertEqual( + replaced_udf[property], + udf[property], + 'property {property} should match'.format(property=property)) + # read udf + udf = collection.scripts.get_user_defined_function(replaced_udf['id']) + self.assertEqual(replaced_udf['id'], udf['id']) + # delete udf + collection.scripts.delete_user_defined_function(replaced_udf['id']) + # read udfs after deletion + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.get_user_defined_function, + replaced_udf['id']) + + async def test_sproc_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + # read sprocs + sprocs = list(collection.scripts.list_stored_procedures()) + # create a sproc + before_create_sprocs_count = len(sprocs) + sproc_definition = { + 'id': 'sample sproc', + 'serverScript': 'function() {var x = 10;}' + } + sproc = collection.scripts.create_stored_procedure(body=sproc_definition) + for property in sproc_definition: + if property != "serverScript": + self.assertEqual( + sproc[property], + sproc_definition[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(sproc['body'], 'function() {var x = 10;}') + + # read sprocs after creation + sprocs = list(collection.scripts.list_stored_procedures()) + self.assertEqual(len(sprocs), + before_create_sprocs_count + 1, + 'create should increase the number of sprocs') + # query sprocs + sprocs = list(collection.scripts.query_stored_procedures( + 
query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': sproc_definition['id']} + ] + )) + self.assertIsNotNone(sprocs) + # replace sproc + change_sproc = sproc.copy() + sproc['body'] = 'function() {var x = 20;}' + replaced_sproc = collection.scripts.replace_stored_procedure(sproc=change_sproc['id'], body=sproc) + for property in sproc_definition: + if property != 'serverScript': + self.assertEqual( + replaced_sproc[property], + sproc[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(replaced_sproc['body'], + "function() {var x = 20;}") + # read sproc + sproc = collection.scripts.get_stored_procedure(replaced_sproc['id']) + self.assertEqual(replaced_sproc['id'], sproc['id']) + # delete sproc + collection.scripts.delete_stored_procedure(replaced_sproc['id']) + # read sprocs after deletion + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.get_stored_procedure, + replaced_sproc['id']) + + async def test_script_logging_execute_stored_procedure(self): + created_db = self.databaseForTest + + created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + sproc = { + 'id': 'storedProcedure' + str(uuid.uuid4()), + 'body': ( + 'function () {' + + ' var mytext = \'x\';' + + ' var myval = 1;' + + ' try {' + + ' console.log(\'The value of %s is %s.\', mytext, myval);' + + ' getContext().getResponse().setBody(\'Success!\');' + + ' }' + + ' catch (err) {' + + ' getContext().getResponse().setBody(\'inline err: [\' + err.number + \'] \' + err);' + + ' }' + '}') + } + + created_sproc = created_collection.scripts.create_stored_procedure(body=sproc) + + result = created_collection.scripts.execute_stored_procedure( + sproc=created_sproc['id'], + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertFalse(HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + + result 
= created_collection.scripts.execute_stored_procedure( + sproc=created_sproc['id'], + enable_script_logging=True, + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertEqual(urllib.quote('The value of x is 1.'), + created_collection.scripts.client_connection.last_response_headers.get(HttpHeaders.ScriptLogResults)) + + result = created_collection.scripts.execute_stored_procedure( + sproc=created_sproc['id'], + enable_script_logging=False, + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertFalse(HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + + async def test_collection_indexing_policy(self): + # create database + db = self.databaseForTest + # create collection + collection = db.create_container( + id='test_collection_indexing_policy default policy' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + collection_properties = collection.read() + self.assertEqual(collection_properties['indexingPolicy']['indexingMode'], + documents.IndexingMode.Consistent, + 'default indexing mode should be consistent') + + db.delete_container(container=collection) + + consistent_collection = db.create_container( + id='test_collection_indexing_policy consistent collection ' + str(uuid.uuid4()), + indexing_policy={ + 'indexingMode': documents.IndexingMode.Consistent + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + consistent_collection_properties = consistent_collection.read() + self.assertEqual(consistent_collection_properties['indexingPolicy']['indexingMode'], + documents.IndexingMode.Consistent, + 'indexing mode should be consistent') + + db.delete_container(container=consistent_collection) + + collection_with_indexing_policy = db.create_container( + id='CollectionWithIndexingPolicy ' + str(uuid.uuid4()), + indexing_policy={ + 'automatic': True, + 'indexingMode': documents.IndexingMode.Consistent, + 'includedPaths': [ + { + 'path': '/', + 
'indexes': [ + { + 'kind': documents.IndexKind.Hash, + 'dataType': documents.DataType.Number, + 'precision': 2 + } + ] + } + ], + 'excludedPaths': [ + { + 'path': '/"systemMetadata"/*' + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + collection_with_indexing_policy_properties = collection_with_indexing_policy.read() + self.assertEqual(1, + len(collection_with_indexing_policy_properties['indexingPolicy']['includedPaths']), + 'Unexpected includedPaths length') + self.assertEqual(2, + len(collection_with_indexing_policy_properties['indexingPolicy']['excludedPaths']), + 'Unexpected excluded path count') + db.delete_container(container=collection_with_indexing_policy) + + async def test_create_default_indexing_policy(self): + # create database + db = self.databaseForTest + + # no indexing policy specified + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # partial policy specified + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), + indexing_policy={ + 'indexingMode': documents.IndexingMode.Consistent, 'automatic': True + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # default policy + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy03' + str(uuid.uuid4()), + indexing_policy={}, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + await 
self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # missing indexes + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy04' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/*' + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # missing precision + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy05' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/*', + 'indexes': [ + { + 'kind': documents.IndexKind.Hash, + 'dataType': documents.DataType.String + }, + { + 'kind': documents.IndexKind.Range, + 'dataType': documents.DataType.Number + } + ] + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): + # create database + db = self.databaseForTest + + indexing_policy = { + "spatialIndexes": [ + { + "path": "/path0/*", + "types": [ + "Point", + "LineString", + "Polygon", + "MultiPolygon" + ] + }, + { + "path": "/path1/*", + "types": [ + "Point", + "LineString", + "Polygon", + "MultiPolygon" + ] + } + ], + "compositeIndexes": [ + [ + { + "path": "/path1", + "order": "ascending" + }, + { + "path": "/path2", + "order": "descending" + }, + { + "path": "/path3", + "order": "ascending" + } + ], + [ + { + "path": "/path4", + "order": "ascending" + }, + { + "path": "/path5", + "order": "descending" + }, + { + "path": "/path6", + "order": "ascending" + } + ] + ] + } + + 
custom_logger = logging.getLogger("CustomLogger") + created_container = db.create_container( + id='composite_index_spatial_index' + str(uuid.uuid4()), + indexing_policy=indexing_policy, + partition_key=PartitionKey(path='/id', kind='Hash'), + headers={"Foo":"bar"}, + user_agent="blah", + user_agent_overwrite=True, + logging_enable=True, + logger=custom_logger, + ) + created_properties = created_container.read(logger=custom_logger) + read_indexing_policy = created_properties['indexingPolicy'] + + if 'localhost' in self.host or '127.0.0.1' in self.host: # TODO: Differing result between live and emulator + self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) + else: + # All types are returned for spatial Indexes + self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) + + self.assertListEqual(indexing_policy['compositeIndexes'], read_indexing_policy['compositeIndexes']) + db.delete_container(container=created_container) + + async def _check_default_indexing_policy_paths(self, indexing_policy): + async def __get_first(array): + if array: + return array[0] + else: + return None + + # '/_etag' is present in excluded paths by default + self.assertEqual(1, len(indexing_policy['excludedPaths'])) + # included paths should be 1: '/'. 
+ self.assertEqual(1, len(indexing_policy['includedPaths'])) + + root_included_path = __get_first([included_path for included_path in indexing_policy['includedPaths'] + if included_path['path'] == '/*']) + self.assertFalse(root_included_path.get('indexes')) + + async def test_client_request_timeout(self): + # Test is flaky on Emulator + if not('localhost' in self.host or '127.0.0.1' in self.host): + connection_policy = documents.ConnectionPolicy() + # making timeout 0 ms to make sure it will throw + connection_policy.RequestTimeout = 0.000000000001 + + with self.assertRaises(Exception): + # client does a getDatabaseAccount on initialization, which will time out + cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", connection_policy=connection_policy) + + async def test_client_request_timeout_when_connection_retry_configuration_specified(self): + connection_policy = documents.ConnectionPolicy() + # making timeout 0 ms to make sure it will throw + connection_policy.RequestTimeout = 0.000000000001 + connection_policy.ConnectionRetryConfiguration = Retry( + total=3, + read=3, + connect=3, + backoff_factor=0.3, + status_forcelist=(500, 502, 504) + ) + with self.assertRaises(AzureError): + # client does a getDatabaseAccount on initialization, which will time out + cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", connection_policy=connection_policy) + + async def test_client_connection_retry_configuration(self): + total_time_for_two_retries = self.initialize_client_with_connection_urllib_retry_config(2) + total_time_for_three_retries = self.initialize_client_with_connection_urllib_retry_config(3) + self.assertGreater(total_time_for_three_retries, total_time_for_two_retries) + + total_time_for_two_retries = self.initialize_client_with_connection_core_retry_config(2) + total_time_for_three_retries = self.initialize_client_with_connection_core_retry_config(3) + self.assertGreater(total_time_for_three_retries, 
total_time_for_two_retries) + + async def initialize_client_with_connection_urllib_retry_config(self, retries): + retry_policy = Retry( + total=retries, + read=retries, + connect=retries, + backoff_factor=0.3, + status_forcelist=(500, 502, 504) + ) + start_time = time.time() + try: + cosmos_client.CosmosClient( + "https://localhost:9999", + CRUDTests.masterKey, + "Session", + connection_retry_policy=retry_policy) + self.fail() + except AzureError as e: + end_time = time.time() + return end_time - start_time + + async def initialize_client_with_connection_core_retry_config(self, retries): + start_time = time.time() + try: + cosmos_client.CosmosClient( + "https://localhost:9999", + CRUDTests.masterKey, + "Session", + retry_total=retries, + retry_read=retries, + retry_connect=retries, + retry_status=retries) + self.fail() + except AzureError as e: + end_time = time.time() + return end_time - start_time + + async def test_absolute_client_timeout(self): + with self.assertRaises(exceptions.CosmosClientTimeoutError): + cosmos_client.CosmosClient( + "https://localhost:9999", + CRUDTests.masterKey, + "Session", + retry_total=3, + timeout=1) + + error_response = ServiceResponseError("Read timeout") + timeout_transport = TimeoutTransport(error_response) + client = cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + + with self.assertRaises(exceptions.CosmosClientTimeoutError): + client.create_database_if_not_exists("test", timeout=2) + + status_response = 500 # Users connection level retry + timeout_transport = TimeoutTransport(status_response) + client = cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + client.create_database("test", timeout=2) + + databases = client.list_databases(timeout=2) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + list(databases) + + 
status_response = 429 # Uses Cosmos custom retry + timeout_transport = TimeoutTransport(status_response) + client = cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + client.create_database_if_not_exists("test", timeout=2) + + databases = client.list_databases(timeout=2) + with self.assertRaises(exceptions.CosmosClientTimeoutError): + list(databases) + + async def test_query_iterable_functionality(self): + def __create_resources(client): + """Creates resources for this test. + + :Parameters: + - `client`: cosmos_client_connection.CosmosClientConnection + + :Returns: + dict + + """ + collection = self.configs.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + doc1 = collection.create_item(body={'id': 'doc1', 'prop1': 'value1'}) + doc2 = collection.create_item(body={'id': 'doc2', 'prop1': 'value2'}) + doc3 = collection.create_item(body={'id': 'doc3', 'prop1': 'value3'}) + resources = { + 'coll': collection, + 'doc1': doc1, + 'doc2': doc2, + 'doc3': doc3 + } + return resources + + # Validate QueryIterable by converting it to a list. + resources = __create_resources(self.client) + results = resources['coll'].read_all_items(max_item_count=2) + docs = list(iter(results)) + self.assertEqual(3, + len(docs), + 'QueryIterable should return all documents' + + ' using continuation') + self.assertEqual(resources['doc1']['id'], docs[0]['id']) + self.assertEqual(resources['doc2']['id'], docs[1]['id']) + self.assertEqual(resources['doc3']['id'], docs[2]['id']) + + # Validate QueryIterable iterator with 'for'. + results = resources['coll'].read_all_items(max_item_count=2) + counter = 0 + # test QueryIterable with 'for'. 
+ for doc in iter(results): + counter += 1 + if counter == 1: + self.assertEqual(resources['doc1']['id'], + doc['id'], + 'first document should be doc1') + elif counter == 2: + self.assertEqual(resources['doc2']['id'], + doc['id'], + 'second document should be doc2') + elif counter == 3: + self.assertEqual(resources['doc3']['id'], + doc['id'], + 'third document should be doc3') + self.assertEqual(counter, 3) + + # Get query results page by page. + results = resources['coll'].read_all_items(max_item_count=2) + + page_iter = results.by_page() + first_block = list(next(page_iter)) + self.assertEqual(2, len(first_block), 'First block should have 2 entries.') + self.assertEqual(resources['doc1']['id'], first_block[0]['id']) + self.assertEqual(resources['doc2']['id'], first_block[1]['id']) + self.assertEqual(1, len(list(next(page_iter))), 'Second block should have 1 entry.') + with self.assertRaises(StopIteration): + next(page_iter) + + async def test_trigger_functionality(self): + triggers_in_collection1 = [ + { + 'id': 't1', + 'body': ( + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toUpperCase() + \'t1\';' + + ' getContext().getRequest().setBody(item);' + + '}'), + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': 'response1', + 'body': ( + 'function() {' + + ' var prebody = getContext().getRequest().getBody();' + + ' if (prebody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + + ' var postbody = getContext().getResponse().getBody();' + + ' if (postbody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + '}'), + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': 'response2', + # can't be used because setValue is currently disabled + 'body': ( + 'function() {' + + ' var predoc = getContext().getRequest().getBody();' + + ' var postdoc = 
getContext().getResponse().getBody();' + + ' getContext().getResponse().setValue(' + + ' \'predocname\', predoc.id + \'response2\');' + + ' getContext().getResponse().setValue(' + + ' \'postdocname\', postdoc.id + \'response2\');' + + '}'), + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.All, + }] + triggers_in_collection2 = [ + { + 'id': "t2", + 'body': "function() { }", # trigger already stringified + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': "t3", + 'body': ( + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toLowerCase() + \'t3\';' + + ' getContext().getRequest().setBody(item);' + + '}'), + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }] + triggers_in_collection3 = [ + { + 'id': 'triggerOpType', + 'body': 'function() { }', + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.Delete, + }] + + async def __CreateTriggers(collection, triggers): + """Creates triggers. 
+ + :Parameters: + - `client`: cosmos_client_connection.CosmosClientConnection + - `collection`: dict + + """ + for trigger_i in triggers: + trigger = collection.scripts.create_trigger(body=trigger_i) + for property in trigger_i: + self.assertEqual( + trigger[property], + trigger_i[property], + 'property {property} should match'.format(property=property)) + + # create database + db = self.databaseForTest + # create collections + pkd = PartitionKey(path='/id', kind='Hash') + collection1 = db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) + collection2 = db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) + collection3 = db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) + # create triggers + await __CreateTriggers(collection1, triggers_in_collection1) + await __CreateTriggers(collection2, triggers_in_collection2) + await __CreateTriggers(collection3, triggers_in_collection3) + # create document + triggers_1 = list(collection1.scripts.list_triggers()) + self.assertEqual(len(triggers_1), 3) + document_1_1 = collection1.create_item( + body={'id': 'doc1', + 'key': 'value'}, + pre_trigger_include='t1' + ) + self.assertEqual(document_1_1['id'], + 'DOC1t1', + 'id should be capitalized') + + document_1_2 = collection1.create_item( + body={'id': 'testing post trigger', 'key': 'value'}, + pre_trigger_include='t1', + post_trigger_include='response1', + ) + self.assertEqual(document_1_2['id'], 'TESTING POST TRIGGERt1') + + document_1_3 = collection1.create_item( + body={'id': 'responseheaders', 'key': 'value'}, + pre_trigger_include='t1' + ) + self.assertEqual(document_1_3['id'], "RESPONSEHEADERSt1") + + triggers_2 = list(collection2.scripts.list_triggers()) + self.assertEqual(len(triggers_2), 2) + document_2_1 = collection2.create_item( + 
body={'id': 'doc2', + 'key': 'value2'}, + pre_trigger_include='t2' + ) + self.assertEqual(document_2_1['id'], + 'doc2', + 'id shouldn\'t change') + document_2_2 = collection2.create_item( + body={'id': 'Doc3', + 'prop': 'empty', + 'key': 'value2'}, + pre_trigger_include='t3') + self.assertEqual(document_2_2['id'], 'doc3t3') + + triggers_3 = list(collection3.scripts.list_triggers()) + self.assertEqual(len(triggers_3), 1) + with self.assertRaises(Exception): + collection3.create_item( + body={'id': 'Docoptype', 'key': 'value2'}, + post_trigger_include='triggerOpType' + ) + + db.delete_container(collection1) + db.delete_container(collection2) + db.delete_container(collection3) + + async def test_stored_procedure_functionality(self): + # create database + db = self.databaseForTest + # create collection + collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + sproc1 = { + 'id': 'storedProcedure1' + str(uuid.uuid4()), + 'body': ( + 'function () {' + + ' for (var i = 0; i < 1000; i++) {' + + ' var item = getContext().getResponse().getBody();' + + ' if (i > 0 && item != i - 1) throw \'body mismatch\';' + + ' getContext().getResponse().setBody(i);' + + ' }' + + '}') + } + + retrieved_sproc = collection.scripts.create_stored_procedure(body=sproc1) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc['id'], + partition_key=1 + ) + self.assertEqual(result, 999) + sproc2 = { + 'id': 'storedProcedure2' + str(uuid.uuid4()), + 'body': ( + 'function () {' + + ' for (var i = 0; i < 10; i++) {' + + ' getContext().getResponse().appendValue(\'Body\', i);' + + ' }' + + '}') + } + retrieved_sproc2 = collection.scripts.create_stored_procedure(body=sproc2) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc2['id'], + partition_key=1 + ) + self.assertEqual(int(result), 123456789) + sproc3 = { + 'id': 'storedProcedure3' + str(uuid.uuid4()), + 'body': ( + 'function (input) {' + + ' 
getContext().getResponse().setBody(' + + ' \'a\' + input.temp);' + + '}') + } + retrieved_sproc3 = collection.scripts.create_stored_procedure(body=sproc3) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc3['id'], + params={'temp': 'so'}, + partition_key=1 + ) + self.assertEqual(result, 'aso') + + async def __ValidateOfferResponseBody(self, offer, expected_coll_link, expected_offer_type): + # type: (Offer, str, Any) -> None + self.assertIsNotNone(offer.properties['id'], 'Id cannot be null.') + self.assertIsNotNone(offer.properties.get('_rid'), 'Resource Id (Rid) cannot be null.') + self.assertIsNotNone(offer.properties.get('_self'), 'Self Link cannot be null.') + self.assertIsNotNone(offer.properties.get('resource'), 'Resource Link cannot be null.') + self.assertTrue(offer.properties['_self'].find(offer.properties['id']) != -1, + 'Offer id not contained in offer self link.') + self.assertEqual(expected_coll_link.strip('/'), offer.properties['resource'].strip('/')) + if (expected_offer_type): + self.assertEqual(expected_offer_type, offer.properties.get('offerType')) + + async def test_offer_read_and_query(self): + # Create database. + db = self.databaseForTest + + # Create collection. + collection = db.create_container( + id='test_offer_read_and_query ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/id', kind='Hash') + ) + # Read the offer. + expected_offer = collection.read_offer() + collection_properties = collection.read() + await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) + + # Now delete the collection. + db.delete_container(container=collection) + # Reading fails. + await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, collection.read_offer) + + async def test_offer_replace(self): + # Create database. + db = self.databaseForTest + # Create collection. 
+ collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + # Read Offer + expected_offer = collection.read_offer() + collection_properties = collection.read() + await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) + # Replace the offer. + replaced_offer = collection.replace_throughput(expected_offer.offer_throughput + 100) + collection_properties = collection.read() + await self.__ValidateOfferResponseBody(replaced_offer, collection_properties.get('_self'), None) + # Check if the replaced offer is what we expect. + self.assertEqual(expected_offer.properties.get('content').get('offerThroughput') + 100, + replaced_offer.properties.get('content').get('offerThroughput')) + self.assertEqual(expected_offer.offer_throughput + 100, + replaced_offer.offer_throughput) + + async def test_database_account_functionality(self): + # Validate database account functionality. + database_account = self.client.get_database_account() + self.assertEqual(database_account.DatabasesLink, '/dbs/') + self.assertEqual(database_account.MediaLink, '/media/') + if (HttpHeaders.MaxMediaStorageUsageInMB in + self.client.client_connection.last_response_headers): + self.assertEqual( + database_account.MaxMediaStorageUsageInMB, + self.client.client_connection.last_response_headers[ + HttpHeaders.MaxMediaStorageUsageInMB]) + if (HttpHeaders.CurrentMediaStorageUsageInMB in + self.client.client_connection.last_response_headers): + self.assertEqual( + database_account.CurrentMediaStorageUsageInMB, + self.client.client_connection.last_response_headers[ + HttpHeaders.CurrentMediaStorageUsageInMB]) + self.assertIsNotNone(database_account.ConsistencyPolicy['defaultConsistencyLevel']) + + async def test_index_progress_headers(self): + created_db = self.databaseForTest + consistent_coll = created_db.create_container( + id='test_index_progress_headers consistent_coll ' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/id", 
kind='Hash'), + ) + created_container = created_db.get_container_client(container=consistent_coll) + created_container.read(populate_quota_info=True) + self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) + self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) + + none_coll = created_db.create_container( + id='test_index_progress_headers none_coll ' + str(uuid.uuid4()), + indexing_policy={ + 'indexingMode': documents.IndexingMode.NoIndex, + 'automatic': False + }, + partition_key=PartitionKey(path="/id", kind='Hash') + ) + created_container = created_db.get_container_client(container=none_coll) + created_container.read(populate_quota_info=True) + self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) + self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) + + created_db.delete_container(consistent_coll) + created_db.delete_container(none_coll) + + async def test_id_validation(self): + # Id shouldn't end with space. + try: + self.client.create_database(id='id_with_space ') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id ends with a space.', e.args[0]) + # Id shouldn't contain '/'. + + try: + self.client.create_database(id='id_with_illegal/_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '\\'. + + try: + self.client.create_database(id='id_with_illegal\\_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '?'. + + try: + self.client.create_database(id='id_with_illegal?_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '#'. 
+ + try: + self.client.create_database(id='id_with_illegal#_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + + # Id can begin with space + db = self.client.create_database(id=' id_begin_space') + self.assertTrue(True) + + self.client.delete_database(database=db) + + async def test_id_case_validation(self): + # create database + created_db = self.databaseForTest + + uuid_string = str(uuid.uuid4()) + collection_id1 = 'sampleCollection ' + uuid_string + collection_id2 = 'SampleCollection ' + uuid_string + + # Verify that no collections exist + collections = list(created_db.list_containers()) + number_of_existing_collections = len(collections) + + # create 2 collections with different casing of IDs + # pascalCase + created_collection1 = created_db.create_container( + id=collection_id1, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + # CamelCase + created_collection2 = created_db.create_container( + id=collection_id2, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + collections = list(created_db.list_containers()) + + # verify if a total of 2 collections got created + self.assertEqual(len(collections), number_of_existing_collections + 2) + + # verify that collections are created with specified IDs + self.assertEqual(collection_id1, created_collection1.id) + self.assertEqual(collection_id2, created_collection2.id) + + created_db.delete_container(created_collection1) + created_db.delete_container(created_collection2) + + #TODO: fix test + @pytest.mark.skip + async def test_id_unicode_validation(self): + # create database + created_db = self.databaseForTest + + # unicode chars in Hindi for Id which translates to: "Hindi is the national language of India" + collection_id1 = u'हिन्दी भारत की राष्ट्रीय भाषा है' + + # Special chars for Id + collection_id2 = "!@$%^&*()-~`'_[]{}|;:,.<>" + + # verify that collections are created with specified IDs + created_collection1 = 
created_db.create_container( + id=collection_id1, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + created_collection2 = created_db.create_container( + id=collection_id2, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + self.assertEqual(collection_id1, created_collection1.id) + self.assertEqual(collection_id2, created_collection2.id) + + created_collection1_properties = created_collection1.read() + created_collection2_properties = created_collection2.read() + + created_db.client_connection.DeleteContainer(created_collection1_properties['_self']) + created_db.client_connection.DeleteContainer(created_collection2_properties['_self']) + + async def test_get_resource_with_dictionary_and_object(self): + created_db = self.databaseForTest + + # read database with id + read_db = self.client.get_database_client(created_db.id) + self.assertEqual(read_db.id, created_db.id) + + # read database with instance + read_db = self.client.get_database_client(created_db) + self.assertEqual(read_db.id, created_db.id) + + # read database with properties + read_db = self.client.get_database_client(created_db.read()) + self.assertEqual(read_db.id, created_db.id) + + created_container = self.configs.create_multi_partition_collection_if_not_exist(self.client) + + # read container with id + read_container = created_db.get_container_client(created_container.id) + self.assertEqual(read_container.id, created_container.id) + + # read container with instance + read_container = created_db.get_container_client(created_container) + self.assertEqual(read_container.id, created_container.id) + + # read container with properties + created_properties = created_container.read() + read_container = created_db.get_container_client(created_properties) + self.assertEqual(read_container.id, created_container.id) + + created_item = created_container.create_item({'id':'1' + str(uuid.uuid4())}) + + # read item with id + read_item = created_container.read_item(item=created_item['id'], 
partition_key=created_item['id']) + self.assertEqual(read_item['id'], created_item['id']) + + # read item with properties + read_item = created_container.read_item(item=created_item, partition_key=created_item['id']) + self.assertEqual(read_item['id'], created_item['id']) + + created_sproc = created_container.scripts.create_stored_procedure({ + 'id': 'storedProcedure' + str(uuid.uuid4()), + 'body': 'function () { }' + }) + + # read sproc with id + read_sproc = created_container.scripts.get_stored_procedure(created_sproc['id']) + self.assertEqual(read_sproc['id'], created_sproc['id']) + + # read sproc with properties + read_sproc = created_container.scripts.get_stored_procedure(created_sproc) + self.assertEqual(read_sproc['id'], created_sproc['id']) + + created_trigger = created_container.scripts.create_trigger({ + 'id': 'sample trigger' + str(uuid.uuid4()), + 'serverScript': 'function() {var x = 10;}', + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }) + + # read trigger with id + read_trigger = created_container.scripts.get_trigger(created_trigger['id']) + self.assertEqual(read_trigger['id'], created_trigger['id']) + + # read trigger with properties + read_trigger = created_container.scripts.get_trigger(created_trigger) + self.assertEqual(read_trigger['id'], created_trigger['id']) + + created_udf = created_container.scripts.create_user_defined_function({ + 'id': 'sample udf' + str(uuid.uuid4()), + 'body': 'function() {var x = 10;}' + }) + + # read udf with id + read_udf = created_container.scripts.get_user_defined_function(created_udf['id']) + self.assertEqual(created_udf['id'], read_udf['id']) + + # read udf with properties + read_udf = created_container.scripts.get_user_defined_function(created_udf) + self.assertEqual(created_udf['id'], read_udf['id']) + + created_user = created_db.create_user({ + 'id': 'user' + str(uuid.uuid4()) + }) + + # read user with id + read_user = 
created_db.get_user_client(created_user.id) + self.assertEqual(read_user.id, created_user.id) + + # read user with instance + read_user = created_db.get_user_client(created_user) + self.assertEqual(read_user.id, created_user.id) + + # read user with properties + created_user_properties = created_user.read() + read_user = created_db.get_user_client(created_user_properties) + self.assertEqual(read_user.id, created_user.id) + + created_permission = created_user.create_permission({ + 'id': 'all permission' + str(uuid.uuid4()), + 'permissionMode': documents.PermissionMode.All, + 'resource': created_container.container_link, + 'resourcePartitionKey': [1] + }) + + # read permission with id + read_permission = created_user.get_permission(created_permission.id) + self.assertEqual(read_permission.id, created_permission.id) + + # read permission with instance + read_permission = created_user.get_permission(created_permission) + self.assertEqual(read_permission.id, created_permission.id) + + # read permission with properties + read_permission = created_user.get_permission(created_permission.properties) + self.assertEqual(read_permission.id, created_permission.id) + + # Temporarily commenting analytical storage tests until emulator support comes. 
+ # def test_create_container_with_analytical_store_off(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # created_db = self.databaseForTest + # collection_id = 'test_create_container_with_analytical_store_off_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container(id=collection_id, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key not in properties or properties[ttl_key] == None) + + # def test_create_container_with_analytical_store_on(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # created_db = self.databaseForTest + # collection_id = 'test_create_container_with_analytical_store_on_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + # def test_create_container_if_not_exists_with_analytical_store_on(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # # first, try when we know the container doesn't exist. 
+ # created_db = self.databaseForTest + # collection_id = 'test_create_container_if_not_exists_with_analytical_store_on_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container_if_not_exists(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + # # next, try when we know the container DOES exist. This way both code paths are tested. + # created_collection = created_db.create_container_if_not_exists(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + async def _MockExecuteFunction(self, function, *args, **kwargs): + self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey] + if HttpHeaders.PartitionKey in args[4].headers else '') + return self.OriginalExecuteFunction(function, *args, **kwargs) + +if __name__ == '__main__': + try: + loop = asyncio.get_event_loop() + loop.run_until_complete(unittest.main()) + except SystemExit as inst: + if inst.args[0] is True: # raised by sys.exit(True) when tests failed + raise From 03388140a5f96568c169631ff200d1d5870c700f Mon Sep 17 00:00:00 2001 From: JericHunter Date: Mon, 25 Apr 2022 09:01:32 -0700 Subject: [PATCH 2/8] Update test_crud_async.py --- .../azure-cosmos/test/test_crud_async.py | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index ac0d654ead79..a709dbfd67f2 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -39,9 +39,8 @@ import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions from azure.cosmos.http_constants import HttpHeaders, StatusCodes -import test_config import azure.cosmos._base as base -import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.aio.cosmos_client as cosmos_client from azure.cosmos.diagnostics import RecordDiagnostics from azure.cosmos.partition_key import PartitionKey from azure.cosmos import _retry_utility @@ -82,10 +81,6 @@ class CRUDTests(unittest.TestCase): """Python CRUD Tests. """ - configs = test_config._test_config - host = configs.host - masterKey = configs.masterKey - connectionPolicy = configs.connectionPolicy last_headers = [] async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): @@ -141,7 +136,7 @@ async def test_database_crud(self): self.client.get_database_client(created_db.id) # delete database. 
- self.client.delete_database(created_db.id) + await self.client.delete_database(created_db.id) # read database after deletion read_db = self.client.get_database_client(created_db.id) await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, @@ -155,7 +150,7 @@ async def test_database_crud(self): self.assertEqual(database_id, database_proxy.id) self.assertEqual(10000, database_proxy.read_offer().offer_throughput) - self.client.delete_database(database_id) + await self.client.delete_database(database_id) @pytest.mark.skip("skipping as the TestResources subscription doesn't support this offer") async def test_database_level_offer_throughput(self): @@ -176,7 +171,7 @@ async def test_database_level_offer_throughput(self): new_offer_throughput = 2000 offer = created_db.replace_throughput(new_offer_throughput) self.assertEqual(offer.offer_throughput, new_offer_throughput) - self.client.delete_database(created_db.id) + await self.client.delete_database(created_db.id) async def test_sql_query_crud(self): # create two databases. @@ -201,8 +196,8 @@ async def test_sql_query_crud(self): # query with a string. databases = list(self.client.query_databases('SELECT * FROM root r WHERE r.id="' + db2.id + '"')) # nosec self.assertEqual(1, len(databases), 'Unexpected number of query results.') - self.client.delete_database(db1.id) - self.client.delete_database(db2.id) + await self.client.delete_database(db1.id) + await self.client.delete_database(db2.id) async def test_collection_crud(self): created_db = self.databaseForTest @@ -1410,7 +1405,7 @@ async def __SetupEntities(client): self.assertEqual(read_doc["id"], docId) #6. 
Success-- Use Doc permission to delete doc - doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) + await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) self.assertEqual(read_doc["id"], docId) db.client_connection = old_client_connection @@ -1966,14 +1961,14 @@ async def test_absolute_client_timeout(self): self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) with self.assertRaises(exceptions.CosmosClientTimeoutError): - client.create_database_if_not_exists("test", timeout=2) + await client.create_database_if_not_exists("test", timeout=2) status_response = 500 # Users connection level retry timeout_transport = TimeoutTransport(status_response) client = cosmos_client.CosmosClient( self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) with self.assertRaises(exceptions.CosmosClientTimeoutError): - client.create_database("test", timeout=2) + await client.create_database("test", timeout=2) databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): @@ -1984,7 +1979,7 @@ async def test_absolute_client_timeout(self): client = cosmos_client.CosmosClient( self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) with self.assertRaises(exceptions.CosmosClientTimeoutError): - client.create_database_if_not_exists("test", timeout=2) + await client.create_database_if_not_exists("test", timeout=2) databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): @@ -2360,35 +2355,35 @@ async def test_index_progress_headers(self): async def test_id_validation(self): # Id shouldn't end with space. 
try: - self.client.create_database(id='id_with_space ') + await self.client.create_database(id='id_with_space ') self.assertFalse(True) except ValueError as e: self.assertEqual('Id ends with a space.', e.args[0]) # Id shouldn't contain '/'. try: - self.client.create_database(id='id_with_illegal/_char') + await self.client.create_database(id='id_with_illegal/_char') self.assertFalse(True) except ValueError as e: self.assertEqual('Id contains illegal chars.', e.args[0]) # Id shouldn't contain '\\'. try: - self.client.create_database(id='id_with_illegal\\_char') + await self.client.create_database(id='id_with_illegal\\_char') self.assertFalse(True) except ValueError as e: self.assertEqual('Id contains illegal chars.', e.args[0]) # Id shouldn't contain '?'. try: - self.client.create_database(id='id_with_illegal?_char') + await self.client.create_database(id='id_with_illegal?_char') self.assertFalse(True) except ValueError as e: self.assertEqual('Id contains illegal chars.', e.args[0]) # Id shouldn't contain '#'. 
try: - self.client.create_database(id='id_with_illegal#_char') + await self.client.create_database(id='id_with_illegal#_char') self.assertFalse(True) except ValueError as e: self.assertEqual('Id contains illegal chars.', e.args[0]) @@ -2397,7 +2392,7 @@ async def test_id_validation(self): db = self.client.create_database(id=' id_begin_space') self.assertTrue(True) - self.client.delete_database(database=db) + await self.client.delete_database(database=db) async def test_id_case_validation(self): # create database From 8ea1a8822397be4267016bc5aee45cfd3465d8f7 Mon Sep 17 00:00:00 2001 From: JericHunter Date: Wed, 27 Apr 2022 14:10:52 -0700 Subject: [PATCH 3/8] Update test_crud_async.py --- .../azure-cosmos/test/test_crud_async.py | 630 +++++++++--------- 1 file changed, 321 insertions(+), 309 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index a709dbfd67f2..d0abc299a98a 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -29,7 +29,7 @@ import unittest import time from typing import Mapping - +import test_config import urllib.parse as urllib import uuid import pytest @@ -49,6 +49,7 @@ pytestmark = pytest.mark.cosmosEmulator + # IMPORTANT NOTES: # Most test cases in this file create collections in your Azure Cosmos account. # Collections are billing entities. By running these test cases, you may incur monetary costs on your account. 
@@ -91,46 +92,42 @@ async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs - `func`: function """ try: - func(*args, **kwargs) + await func(*args, **kwargs) self.assertFalse(True, 'function should fail.') except exceptions.CosmosHttpResponseError as inst: self.assertEqual(inst.status_code, status_code) @classmethod async def setUpClass(cls): - if (cls.masterKey == '[YOUR_KEY_HERE]' or - cls.host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, connection_policy=cls.connectionPolicy) - cls.databaseForTest = cls.configs.create_database_if_not_exist(cls.client) + cls.database = await cls.client.create_database_if_nflot_exists(test_config._test_config.TEST_DATABASE_ID) async def setUp(self): self.client = cosmos_client.CosmosClient(self.host, self.masterKey, "Session", connection_policy=self.connectionPolicy) + async def test_database_crud(self): # read databases. - databases = list(self.client.list_databases()) + databases = [database async for database in self.client.list_databases()] # create a database. before_create_databases_count = len(databases) database_id = str(uuid.uuid4()) - created_db = self.client.create_database(database_id) + created_db = await self.client.create_database(database_id) self.assertEqual(created_db.id, database_id) # Read databases after creation. - databases = list(self.client.list_databases()) + databases = [database async for database in self.client.list_databases()] self.assertEqual(len(databases), before_create_databases_count + 1, 'create should increase the number of databases') # query databases. 
- databases = list(self.client.query_databases({ - 'query': 'SELECT * FROM root r WHERE r.id=@id', - 'parameters': [ + databases = [database async for database in self.client.query_databases( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ {'name': '@id', 'value': database_id} ] - })) - self.assertTrue(databases, 'number of results for the query should be > 0') + )] + + self.assertTrue(len(databases) > 0) # read database. self.client.get_database_client(created_db.id) @@ -142,13 +139,13 @@ async def test_database_crud(self): await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, read_db.read) - database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=10000) + database_proxy = await self.client.create_database_if_not_exists(id=database_id, offer_throughput=10000) self.assertEqual(database_id, database_proxy.id) - self.assertEqual(10000, database_proxy.read_offer().offer_throughput) + self.assertEqual(10000, await database_proxy.read_offer().offer_throughput) - database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=9000) + database_proxy = await self.client.create_database_if_not_exists(id=database_id, offer_throughput=9000) self.assertEqual(database_id, database_proxy.id) - self.assertEqual(10000, database_proxy.read_offer().offer_throughput) + self.assertEqual(10000, await database_proxy.read_offer().offer_throughput) await self.client.delete_database(database_id) @@ -157,7 +154,7 @@ async def test_database_level_offer_throughput(self): # Create a database with throughput offer_throughput = 1000 database_id = str(uuid.uuid4()) - created_db = self.client.create_database( + created_db = await self.client.create_database( id=database_id, offer_throughput=offer_throughput ) @@ -175,33 +172,34 @@ async def test_database_level_offer_throughput(self): async def test_sql_query_crud(self): # create two databases. 
- db1 = self.client.create_database('database 1' + str(uuid.uuid4())) - db2 = self.client.create_database('database 2' + str(uuid.uuid4())) + db1 = await self.client.create_database('database 1' + str(uuid.uuid4())) + db2 = await self.client.create_database('database 2' + str(uuid.uuid4())) # query with parameters. - databases = list(self.client.query_databases({ - 'query': 'SELECT * FROM root r WHERE r.id=@id', - 'parameters': [ + databases = [database async for database in self.client.query_databases( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ {'name': '@id', 'value': db1.id} ] - })) + )] self.assertEqual(1, len(databases), 'Unexpected number of query results.') # query without parameters. - databases = list(self.client.query_databases({ - 'query': 'SELECT * FROM root r WHERE r.id="database non-existing"' - })) + databases = [database async for database in self.client.query_databases( + query='SELECT * FROM root r WHERE r.id="database non-existing"' + )] self.assertEqual(0, len(databases), 'Unexpected number of query results.') # query with a string. 
- databases = list(self.client.query_databases('SELECT * FROM root r WHERE r.id="' + db2.id + '"')) # nosec + databases = [database async for database in + self.client.query_databases('SELECT * FROM root r WHERE r.id="' + db2.id + '"')] # nosec self.assertEqual(1, len(databases), 'Unexpected number of query results.') await self.client.delete_database(db1.id) await self.client.delete_database(db2.id) async def test_collection_crud(self): created_db = self.databaseForTest - collections = list(created_db.list_containers()) + collections = [collection async for collection in created_db.list_containers()] # create a collection before_create_collections_count = len(collections) collection_id = 'test_collection_crud ' + str(uuid.uuid4()) @@ -209,7 +207,7 @@ async def test_collection_crud(self): created_recorder = RecordDiagnostics() created_collection = created_db.create_container(id=collection_id, indexing_policy=collection_indexing_policy, - partition_key=PartitionKey(path="/pk", kind="Hash"), + partition_key=PartitionKey(path="/pk", kind="Hash"), response_hook=created_recorder) self.assertEqual(collection_id, created_collection.id) assert isinstance(created_recorder.headers, Mapping) @@ -221,18 +219,18 @@ async def test_collection_crud(self): self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode']) # read collections after creation - collections = list(created_db.list_containers()) + collections = [collection async for collection in created_db.list_containers()] self.assertEqual(len(collections), before_create_collections_count + 1, 'create should increase the number of collections') # query collections - collections = list(created_db.query_containers( - { - 'query': 'SELECT * FROM root r WHERE r.id=@id', - 'parameters': [ - {'name': '@id', 'value': collection_id} - ] - })) + collections = [collection async for collection in created_db.query_containers( + + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 
'value': collection_id} + ] + )] self.assertTrue(collections) # delete collection @@ -242,11 +240,13 @@ async def test_collection_crud(self): await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, created_container.read) - container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, partition_key=PartitionKey(path='/id', kind='Hash')) + container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, + partition_key=PartitionKey(path='/id', kind='Hash')) self.assertEqual(created_collection.id, container_proxy.id) self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) - container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, partition_key=created_properties['partitionKey']) + container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, + partition_key=created_properties['partitionKey']) self.assertEqual(created_container.id, container_proxy.id) self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) @@ -265,8 +265,8 @@ async def test_partitioned_collection(self): offer_throughput = 10100 created_collection = created_db.create_container(id=collection_definition['id'], - partition_key=collection_definition['partitionKey'], - offer_throughput=offer_throughput) + partition_key=collection_definition['partitionKey'], + offer_throughput=offer_throughput) self.assertEqual(collection_definition.get('id'), created_collection.id) @@ -341,7 +341,7 @@ async def test_partitioned_collection_partition_key_extraction(self): self.assertEqual(self.last_headers[1], [{}]) del self.last_headers[:] - #self.assertEqual(options['partitionKey'], documents.Undefined) + # self.assertEqual(options['partitionKey'], documents.Undefined) collection_id = 'test_partitioned_collection_partition_key_extraction2 ' + str(uuid.uuid4()) created_collection2 = created_db.create_container( @@ -357,7 
+357,7 @@ async def test_partitioned_collection_partition_key_extraction(self): self.assertEqual(self.last_headers[1], [{}]) del self.last_headers[:] - #self.assertEqual(options['partitionKey'], documents.Undefined) + # self.assertEqual(options['partitionKey'], documents.Undefined) created_db.delete_container(created_collection.id) created_db.delete_container(created_collection1.id) @@ -390,7 +390,7 @@ async def test_partitioned_collection_partition_key_extraction_special_chars(sel 'paths': ['/\'level\" 1*()\'/\'le/vel2\''], 'kind': documents.PartitionKind.Hash } - } + } collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()) @@ -455,7 +455,7 @@ async def test_partitioned_collection_document_crud_and_query(self): self.assertEqual(read_document.get('key'), created_document.get('key')) # Read document feed doesn't require partitionKey as it's always a cross partition query - documentlist = list(created_collection.read_all_items()) + documentlist = [document async for document in created_collection.read_all_items()] self.assertEqual(1, len(documentlist)) # replace document @@ -484,34 +484,34 @@ async def test_partitioned_collection_document_crud_and_query(self): created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('id')) # query document on the partition key specified in the predicate will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value - documentlist = list(created_collection.query_items( - { - 'query': 'SELECT * FROM root r WHERE r.id=\'' + replaced_document.get('id') + '\'' # nosec - })) + documentlist = [document async for document in created_collection.query_items( + + query='SELECT * FROM root r WHERE r.id=\'' + replaced_document.get('id') + '\'' # nosec + )] self.assertEqual(1, len(documentlist)) # query document on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey 
value try: - list(created_collection.query_items( - { - 'query': 'SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec - })) + [document async for document in created_collection.query_items( + + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec + )] except Exception: pass # cross partition query - documentlist = list(created_collection.query_items( + documentlist = [document async for document in created_collection.query_items( query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec enable_cross_partition_query=True - )) + )] self.assertEqual(1, len(documentlist)) # query document by providing the partitionKey value - documentlist = list(created_collection.query_items( + documentlist = [document async for document in created_collection.query_items( query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec partition_key=replaced_document.get('id') - )) + )] self.assertEqual(1, len(documentlist)) @@ -555,7 +555,8 @@ async def test_partitioned_collection_permissions(self): resource_tokens = {} # storing the resource tokens based on Resource IDs resource_tokens["dbs/" + created_db.id + "/colls/" + all_collection.id] = (all_permission.properties['_token']) - resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = (read_permission.properties['_token']) + resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = ( + read_permission.properties['_token']) restricted_client = cosmos_client.CosmosClient( CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) @@ -585,7 +586,8 @@ async def test_partitioned_collection_permissions(self): document_definition['key'] = 1 # Delete document should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object - created_document = all_collection.delete_item(item=created_document['id'], 
partition_key=document_definition['key']) + created_document = all_collection.delete_item(item=created_document['id'], + partition_key=document_definition['key']) # Delete document in read_collection should fail since it has only read permissions for this collection await self.__AssertHTTPFailureWithStatus( @@ -606,14 +608,14 @@ async def test_partitioned_collection_execute_stored_procedure(self): sproc = { 'id': 'storedProcedure' + str(uuid.uuid4()), 'body': ( - 'function () {' + - ' var client = getContext().getCollection();' + - ' client.createDocument(client.getSelfLink(), { id: \'testDoc\', pk : 2}, {}, function(err, docCreated, options) { ' + - ' if(err) throw new Error(\'Error while creating document: \' + err.message);' + - ' else {' + - ' getContext().getResponse().setBody(1);' + - ' }' + - ' });}') + 'function () {' + + ' var client = getContext().getCollection();' + + ' client.createDocument(client.getSelfLink(), { id: \'testDoc\', pk : 2}, {}, function(err, docCreated, options) { ' + + ' if(err) throw new Error(\'Error while creating document: \' + err.message);' + + ' else {' + + ' getContext().getResponse().setBody(1);' + + ' }' + + ' });}') } created_sproc = created_collection.scripts.create_stored_procedure(body=sproc) @@ -711,7 +713,7 @@ async def test_partitioned_collection_conflict_crud_and_query(self): ) # Read conflict feed doesn't requires partitionKey to be specified as it's a cross partition thing - conflictlist = list(created_collection.list_conflicts()) + conflictlist = [conflict async for conflict in created_collection.list_conflicts()] self.assertEqual(0, len(conflictlist)) # delete conflict here will return resource not found(404) since there is no conflict here @@ -724,26 +726,28 @@ async def test_partitioned_collection_conflict_crud_and_query(self): # query conflicts on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value try: - 
list(created_collection.query_conflicts( - query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get( # nosec - 'resourceType') + '\'' - )) + [conflict async for conflict in created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get( # nosec + 'resourceType') + '\'' + )] except Exception: pass - conflictlist = list(created_collection.query_conflicts( - query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', # nosec - enable_cross_partition_query=True - )) + conflictlist = [conflict async for conflict in created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', + # nosec + enable_cross_partition_query=True + )] self.assertEqual(0, len(conflictlist)) # query conflicts by providing the partitionKey value options = {'partitionKey': conflict_definition.get('id')} - conflictlist = list(created_collection.query_conflicts( - query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', # nosec + conflictlist = [conflict async for conflict in created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', + # nosec partition_key=conflict_definition['id'] - )) + )] self.assertEqual(0, len(conflictlist)) @@ -753,7 +757,7 @@ async def test_document_crud(self): # create collection created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read documents - documents = list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] # create a document before_create_documents_count = len(documents) @@ -783,30 +787,30 @@ async def test_document_crud(self): created_collection.create_item, duplicated_definition_with_id) # read documents after creation - documents = 
list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] self.assertEqual( len(documents), before_create_documents_count + 2, 'create should increase the number of documents') # query documents - documents = list(created_collection.query_items( - { - 'query': 'SELECT * FROM root r WHERE r.name=@name', - 'parameters': [ - {'name': '@name', 'value': document_definition['name']} - ] - }, enable_cross_partition_query=True - )) + documents = [document async for document in created_collection.query_items( + + query='SELECT * FROM root r WHERE r.name=@name', + parameter=[ + {'name': '@name', 'value': document_definition['name']} + ] + , enable_cross_partition_query=True + )] self.assertTrue(documents) - documents = list(created_collection.query_items( - { - 'query': 'SELECT * FROM root r WHERE r.name=@name', - 'parameters': [ - {'name': '@name', 'value': document_definition['name']} - ], - }, enable_cross_partition_query=True, + documents = [document async for document in created_collection.query_items( + + query='SELECT * FROM root r WHERE r.name=@name', + parameter=[ + {'name': '@name', 'value': document_definition['name']} + ] + , enable_cross_partition_query=True, enable_scan_in_query=True - )) + )] self.assertTrue(documents) # replace document. 
created_document['name'] = 'replaced document' @@ -871,11 +875,11 @@ async def test_document_crud(self): # should pass for most recent etag replaced_document_conditional = created_collection.replace_item( - match_condition=MatchConditions.IfNotModified, - etag=replaced_document['_etag'], - item=replaced_document['id'], - body=replaced_document - ) + match_condition=MatchConditions.IfNotModified, + etag=replaced_document['_etag'], + item=replaced_document['id'], + body=replaced_document + ) self.assertEqual(replaced_document_conditional['name'], 'replaced document based on condition', 'document id property should change') @@ -911,7 +915,7 @@ async def test_document_upsert(self): created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read documents and check count - documents = list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] before_create_documents_count = len(documents) # create document definition @@ -933,7 +937,7 @@ async def test_document_upsert(self): created_collection.upsert_item(body=document_definition) # read documents after creation and verify updated count - documents = list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] self.assertEqual( len(documents), before_create_documents_count + 1, @@ -960,7 +964,7 @@ async def test_document_upsert(self): 'document id should stay the same') # read documents after upsert and verify count doesn't increases again - documents = list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] self.assertEqual( len(documents), before_create_documents_count + 1, @@ -986,7 +990,7 @@ async def test_document_upsert(self): 'document id should be same') # read documents after upsert and verify count increases - documents = list(created_collection.read_all_items()) + documents = 
[document async for document in created_collection.read_all_items()] self.assertEqual( len(documents), before_create_documents_count + 2, @@ -997,43 +1001,42 @@ async def test_document_upsert(self): created_collection.delete_item(item=new_document, partition_key=new_document['id']) # read documents after delete and verify count is same as original - documents = list(created_collection.read_all_items()) + documents = [document async for document in created_collection.read_all_items()] self.assertEqual( len(documents), before_create_documents_count, 'number of documents should remain same') - async def _test_spatial_index(self): db = self.databaseForTest # partial policy specified collection = db.create_container( id='collection with spatial index ' + str(uuid.uuid4()), indexing_policy={ - 'includedPaths': [ - { - 'path': '/"Location"/?', - 'indexes': [ - { - 'kind': 'Spatial', - 'dataType': 'Point' - } - ] - }, - { - 'path': '/' - } - ] - }, + 'includedPaths': [ + { + 'path': '/"Location"/?', + 'indexes': [ + { + 'kind': 'Spatial', + 'dataType': 'Point' + } + ] + }, + { + 'path': '/' + } + ] + }, partition_key=PartitionKey(path='/id', kind='Hash') - ) + ) collection.create_item( body={ - 'id': 'loc1', - 'Location': { + 'id': 'loc1', + 'Location': { 'type': 'Point', 'coordinates': [20.0, 20.0] - } + } } ) collection.create_item( @@ -1045,10 +1048,10 @@ async def _test_spatial_index(self): } } ) - results = list(collection.query_items( + results = [result async for result in collection.query_items( query="SELECT * FROM root WHERE (ST_DISTANCE(root.Location, {type: 'Point', coordinates: [20.1, 20]}) < 20000)", enable_cross_partition_query=True - )) + )] self.assertEqual(1, len(results)) self.assertEqual('loc1', results[0]['id']) @@ -1060,22 +1063,22 @@ async def test_user_crud(self): # create database db = self.databaseForTest # list users - users = list(db.list_users()) + users = [user async for user in db.list_users()] before_create_count = len(users) # create user 
user_id = 'new user' + str(uuid.uuid4()) user = db.create_user(body={'id': user_id}) self.assertEqual(user.id, user_id, 'user id error') # list users after creation - users = list(db.list_users()) + users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count + 1) # query users - results = list(db.query_users( - query='SELECT * FROM root r WHERE r.id=@id', - parameters=[ - {'name': '@id', 'value': user_id} - ] - )) + results = [user async for user in db.query_users( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': user_id} + ] + )] self.assertTrue(results) # replace user @@ -1104,7 +1107,7 @@ async def test_user_upsert(self): db = self.databaseForTest # read users and check count - users = list(db.list_users()) + users = [user async for user in db.list_users()] before_create_count = len(users) # create user using Upsert API @@ -1115,7 +1118,7 @@ async def test_user_upsert(self): self.assertEqual(user.id, user_id, 'user id error') # read users after creation and verify updated count - users = list(db.list_users()) + users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count + 1) # Should replace the user since it already exists, there is no public property to change here @@ -1128,7 +1131,7 @@ async def test_user_upsert(self): 'user id should remain same') # read users after upsert and verify count doesn't increases again - users = list(db.list_users()) + users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count + 1) user_properties = user.read() @@ -1142,7 +1145,7 @@ async def test_user_upsert(self): self.assertEqual(new_user.id, user.id, 'user id error') # read users after upsert and verify count increases - users = list(db.list_users()) + users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count + 2) # delete users @@ -1150,7 +1153,7 @@ async def test_user_upsert(self): 
db.delete_user(new_user.id) # read users after delete and verify count remains the same - users = list(db.list_users()) + users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count) async def test_permission_crud(self): @@ -1160,7 +1163,7 @@ async def test_permission_crud(self): # create user user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) # list permissions - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] before_create_count = len(permissions) permission = { 'id': 'new permission', @@ -1173,15 +1176,15 @@ async def test_permission_crud(self): 'new permission', 'permission id error') # list permissions after creation - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] self.assertEqual(len(permissions), before_create_count + 1) # query permissions - results = list(user.query_permissions( - query='SELECT * FROM root r WHERE r.id=@id', - parameters=[ - {'name': '@id', 'value': permission.id} - ] - )) + results = [permission async for permission in user.query_permissions( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': permission.id} + ] + )] self.assertTrue(results) # replace permission @@ -1213,7 +1216,7 @@ async def test_permission_upsert(self): user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) # read permissions and check count - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] before_create_count = len(permissions) permission_definition = { @@ -1231,7 +1234,7 @@ async def test_permission_upsert(self): 'permission id error') # read permissions after creation and verify updated count - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] self.assertEqual(len(permissions), 
before_create_count + 1) # update permission mode @@ -1250,7 +1253,7 @@ async def test_permission_upsert(self): 'permissionMode should change') # read permissions and verify count doesn't increases again - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] self.assertEqual(len(permissions), before_create_count + 1) # update permission id @@ -1273,7 +1276,7 @@ async def test_permission_upsert(self): 'permission resource should be same') # read permissions and verify count increases - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] self.assertEqual(len(permissions), before_create_count + 2) # delete permissions @@ -1281,7 +1284,7 @@ async def test_permission_upsert(self): user.delete_permission(new_permission.id) # read permissions and verify count remains the same - permissions = list(user.list_permissions()) + permissions = [permission async for permission in user.list_permissions()] self.assertEqual(len(permissions), before_create_count) async def test_authorization(self): @@ -1305,9 +1308,9 @@ async def __SetupEntities(client): ) # create document1 document = collection.create_item( - body={'id': 'doc1', - 'spam': 'eggs', - 'key': 'value'}, + body={'id': 'doc1', + 'spam': 'eggs', + 'key': 'value'}, ) # create user @@ -1321,7 +1324,7 @@ async def __SetupEntities(client): } permission_on_coll = user.create_permission(body=permission) self.assertIsNotNone(permission_on_coll.properties['_token'], - 'permission token is invalid') + 'permission token is invalid') # create permission for document permission = { @@ -1331,7 +1334,7 @@ async def __SetupEntities(client): } permission_on_doc = user.create_permission(body=permission) self.assertIsNotNone(permission_on_doc.properties['_token'], - 'permission token is invalid') + 'permission token is invalid') entities = { 'db': db, @@ -1356,9 +1359,9 @@ async def 
__SetupEntities(client): # setup entities entities = __SetupEntities(client) resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: - entities['permissionOnColl'].properties['_token']} + entities['permissionOnColl'].properties['_token']} col_client = cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens,"Session", connection_policy=CRUDTests.connectionPolicy) + CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) db = entities['db'] old_client_connection = db.client_connection @@ -1370,7 +1373,7 @@ async def __SetupEntities(client): db.delete_container, success_coll) # 3. Success-- Use Col Permission to Read All Docs - success_documents = list(success_coll.read_all_items()) + success_documents = [document async for document in success_coll.read_all_items()] self.assertTrue(success_documents != None, 'error reading documents') self.assertEqual(len(success_documents), @@ -1389,22 +1392,22 @@ async def __SetupEntities(client): entities['doc']['id'], 'Expected to read children using parent permissions') - #5. Failure-- Use Col Permission to Delete Doc + # 5. Failure-- Use Col Permission to Delete Doc await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, success_coll.delete_item, docId, docId) - resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId : - entities['permissionOnDoc'].properties['_token']} + resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId: + entities['permissionOnDoc'].properties['_token']} doc_client = cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens,"Session", connection_policy=CRUDTests.connectionPolicy) + CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) - #6. Success-- Use Doc permission to read doc - read_doc = doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId) + # 6. 
Success-- Use Doc permission to read doc + read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId) self.assertEqual(read_doc["id"], docId) - #6. Success-- Use Doc permission to delete doc + # 6. Success-- Use Doc permission to delete doc await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) self.assertEqual(read_doc["id"], docId) @@ -1417,7 +1420,7 @@ async def test_trigger_crud(self): # create collection collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read triggers - triggers = list(collection.scripts.list_triggers()) + triggers = [trigger async for trigger in collection.scripts.list_triggers()] # create a trigger before_create_triggers_count = len(triggers) trigger_definition = { @@ -1438,17 +1441,17 @@ async def test_trigger_crud(self): 'function() {var x = 10;}') # read triggers after creation - triggers = list(collection.scripts.list_triggers()) + triggers = [trigger async for trigger in collection.scripts.list_triggers()] self.assertEqual(len(triggers), before_create_triggers_count + 1, 'create should increase the number of triggers') # query triggers - triggers = list(collection.scripts.query_triggers( - query='SELECT * FROM root r WHERE r.id=@id', - parameters=[ - {'name': '@id', 'value': trigger_definition['id']} - ] - )) + triggers = [trigger async for trigger in collection.scripts.query_triggers( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': trigger_definition['id']} + ] + )] self.assertTrue(triggers) # replace trigger @@ -1481,7 +1484,7 @@ async def test_udf_crud(self): # create collection collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read udfs - udfs = list(collection.scripts.list_user_defined_functions()) + udfs = [udf async for udf in collection.scripts.list_user_defined_functions()] # create a udf 
before_create_udfs_count = len(udfs) udf_definition = { @@ -1496,17 +1499,17 @@ async def test_udf_crud(self): 'property {property} should match'.format(property=property)) # read udfs after creation - udfs = list(collection.scripts.list_user_defined_functions()) + udfs = [udf async for udf in collection.scripts.list_user_defined_functions()] self.assertEqual(len(udfs), before_create_udfs_count + 1, 'create should increase the number of udfs') # query udfs - results = list(collection.scripts.query_user_defined_functions( - query='SELECT * FROM root r WHERE r.id=@id', - parameters=[ - {'name': '@id', 'value': udf_definition['id']} - ] - )) + results = [udf async for udf in collection.scripts.query_user_defined_functions( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': udf_definition['id']} + ] + )] self.assertTrue(results) # replace udf change_udf = udf.copy() @@ -1533,7 +1536,7 @@ async def test_sproc_crud(self): # create collection collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read sprocs - sprocs = list(collection.scripts.list_stored_procedures()) + sprocs = [sproc async for sproc in collection.scripts.list_stored_procedures()] # create a sproc before_create_sprocs_count = len(sprocs) sproc_definition = { @@ -1551,17 +1554,17 @@ async def test_sproc_crud(self): self.assertEqual(sproc['body'], 'function() {var x = 10;}') # read sprocs after creation - sprocs = list(collection.scripts.list_stored_procedures()) + sprocs = [sproc async for sproc in collection.scripts.list_stored_procedures()] self.assertEqual(len(sprocs), before_create_sprocs_count + 1, 'create should increase the number of sprocs') # query sprocs - sprocs = list(collection.scripts.query_stored_procedures( - query='SELECT * FROM root r WHERE r.id=@id', - parameters=[ - {'name': '@id', 'value': sproc_definition['id']} - ] - )) + sprocs = [sproc async for sproc in collection.scripts.query_stored_procedures( + 
query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': sproc_definition['id']} + ] + )] self.assertIsNotNone(sprocs) # replace sproc change_sproc = sproc.copy() @@ -1594,17 +1597,17 @@ async def test_script_logging_execute_stored_procedure(self): sproc = { 'id': 'storedProcedure' + str(uuid.uuid4()), 'body': ( - 'function () {' + - ' var mytext = \'x\';' + - ' var myval = 1;' + - ' try {' + - ' console.log(\'The value of %s is %s.\', mytext, myval);' + - ' getContext().getResponse().setBody(\'Success!\');' + - ' }' + - ' catch (err) {' + - ' getContext().getResponse().setBody(\'inline err: [\' + err.number + \'] \' + err);' + - ' }' - '}') + 'function () {' + + ' var mytext = \'x\';' + + ' var myval = 1;' + + ' try {' + + ' console.log(\'The value of %s is %s.\', mytext, myval);' + + ' getContext().getResponse().setBody(\'Success!\');' + + ' }' + + ' catch (err) {' + + ' getContext().getResponse().setBody(\'inline err: [\' + err.number + \'] \' + err);' + + ' }' + '}') } created_sproc = created_collection.scripts.create_stored_procedure(body=sproc) @@ -1615,7 +1618,8 @@ async def test_script_logging_execute_stored_procedure(self): ) self.assertEqual(result, 'Success!') - self.assertFalse(HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + self.assertFalse( + HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) result = created_collection.scripts.execute_stored_procedure( sproc=created_sproc['id'], @@ -1625,7 +1629,8 @@ async def test_script_logging_execute_stored_procedure(self): self.assertEqual(result, 'Success!') self.assertEqual(urllib.quote('The value of x is 1.'), - created_collection.scripts.client_connection.last_response_headers.get(HttpHeaders.ScriptLogResults)) + created_collection.scripts.client_connection.last_response_headers.get( + HttpHeaders.ScriptLogResults)) result = created_collection.scripts.execute_stored_procedure( 
sproc=created_sproc['id'], @@ -1634,7 +1639,8 @@ async def test_script_logging_execute_stored_procedure(self): ) self.assertEqual(result, 'Success!') - self.assertFalse(HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + self.assertFalse( + HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) async def test_collection_indexing_policy(self): # create database @@ -1719,8 +1725,8 @@ async def test_create_default_indexing_policy(self): collection = db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), indexing_policy={ - 'indexingMode': documents.IndexingMode.Consistent, 'automatic': True - }, + 'indexingMode': documents.IndexingMode.Consistent, 'automatic': True + }, partition_key=PartitionKey(path='/id', kind='Hash') ) collection_properties = collection.read() @@ -1741,12 +1747,12 @@ async def test_create_default_indexing_policy(self): collection = db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy04' + str(uuid.uuid4()), indexing_policy={ - 'includedPaths': [ - { - 'path': '/*' - } - ] - }, + 'includedPaths': [ + { + 'path': '/*' + } + ] + }, partition_key=PartitionKey(path='/id', kind='Hash') ) collection_properties = collection.read() @@ -1757,22 +1763,22 @@ async def test_create_default_indexing_policy(self): collection = db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy05' + str(uuid.uuid4()), indexing_policy={ - 'includedPaths': [ - { - 'path': '/*', - 'indexes': [ - { - 'kind': documents.IndexKind.Hash, - 'dataType': documents.DataType.String - }, - { - 'kind': documents.IndexKind.Range, - 'dataType': documents.DataType.Number - } - ] - } - ] - }, + 'includedPaths': [ + { + 'path': '/*', + 'indexes': [ + { + 'kind': documents.IndexKind.Hash, + 'dataType': documents.DataType.String + }, + { + 'kind': documents.IndexKind.Range, + 'dataType': 
documents.DataType.Number + } + ] + } + ] + }, partition_key=PartitionKey(path='/id', kind='Hash') ) collection_properties = collection.read() @@ -1841,7 +1847,7 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): id='composite_index_spatial_index' + str(uuid.uuid4()), indexing_policy=indexing_policy, partition_key=PartitionKey(path='/id', kind='Hash'), - headers={"Foo":"bar"}, + headers={"Foo": "bar"}, user_agent="blah", user_agent_overwrite=True, logging_enable=True, @@ -1877,29 +1883,31 @@ async def __get_first(array): async def test_client_request_timeout(self): # Test is flaky on Emulator - if not('localhost' in self.host or '127.0.0.1' in self.host): + if not ('localhost' in self.host or '127.0.0.1' in self.host): connection_policy = documents.ConnectionPolicy() # making timeout 0 ms to make sure it will throw - connection_policy.RequestTimeout = 0.000000000001 + connection_policy.RequestTimeout = 0.000000000001 with self.assertRaises(Exception): # client does a getDatabaseAccount on initialization, which will time out - cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", connection_policy=connection_policy) + cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", + connection_policy=connection_policy) async def test_client_request_timeout_when_connection_retry_configuration_specified(self): connection_policy = documents.ConnectionPolicy() # making timeout 0 ms to make sure it will throw - connection_policy.RequestTimeout = 0.000000000001 + connection_policy.RequestTimeout = 0.000000000001 connection_policy.ConnectionRetryConfiguration = Retry( - total=3, - read=3, - connect=3, - backoff_factor=0.3, - status_forcelist=(500, 502, 504) - ) + total=3, + read=3, + connect=3, + backoff_factor=0.3, + status_forcelist=(500, 502, 504) + ) with self.assertRaises(AzureError): # client does a getDatabaseAccount on initialization, which will time out - cosmos_client.CosmosClient(CRUDTests.host, 
CRUDTests.masterKey, "Session", connection_policy=connection_policy) + cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", + connection_policy=connection_policy) async def test_client_connection_retry_configuration(self): total_time_for_two_retries = self.initialize_client_with_connection_urllib_retry_config(2) @@ -1972,7 +1980,7 @@ async def test_absolute_client_timeout(self): databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): - list(databases) + databases = [database async for database in databases] status_response = 429 # Uses Cosmos custom retry timeout_transport = TimeoutTransport(status_response) @@ -1983,7 +1991,7 @@ async def test_absolute_client_timeout(self): databases = client.list_databases(timeout=2) with self.assertRaises(exceptions.CosmosClientTimeoutError): - list(databases) + databases = [database async for database in databases] async def test_query_iterable_functionality(self): def __create_resources(client): @@ -2011,7 +2019,7 @@ def __create_resources(client): # Validate QueryIterable by converting it to a list. 
resources = __create_resources(self.client) results = resources['coll'].read_all_items(max_item_count=2) - docs = list(iter(results)) + docs = [doc async for doc in results] self.assertEqual(3, len(docs), 'QueryIterable should return all documents' + @@ -2044,11 +2052,11 @@ def __create_resources(client): results = resources['coll'].read_all_items(max_item_count=2) page_iter = results.by_page() - first_block = list(next(page_iter)) + first_block = [page async for page in next(page_iter)] self.assertEqual(2, len(first_block), 'First block should have 2 entries.') self.assertEqual(resources['doc1']['id'], first_block[0]['id']) self.assertEqual(resources['doc2']['id'], first_block[1]['id']) - self.assertEqual(1, len(list(next(page_iter))), 'Second block should have 1 entry.') + self.assertEqual(1, len([page async for page in next(page_iter)]), 'Second block should have 1 entry.') with self.assertRaises(StopIteration): next(page_iter) @@ -2057,25 +2065,25 @@ async def test_trigger_functionality(self): { 'id': 't1', 'body': ( - 'function() {' + - ' var item = getContext().getRequest().getBody();' + - ' item.id = item.id.toUpperCase() + \'t1\';' + - ' getContext().getRequest().setBody(item);' + - '}'), + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toUpperCase() + \'t1\';' + + ' getContext().getRequest().setBody(item);' + + '}'), 'triggerType': documents.TriggerType.Pre, 'triggerOperation': documents.TriggerOperation.All }, { 'id': 'response1', 'body': ( - 'function() {' + - ' var prebody = getContext().getRequest().getBody();' + - ' if (prebody.id != \'TESTING POST TRIGGERt1\')' - ' throw \'id mismatch\';' + - ' var postbody = getContext().getResponse().getBody();' + - ' if (postbody.id != \'TESTING POST TRIGGERt1\')' - ' throw \'id mismatch\';' - '}'), + 'function() {' + + ' var prebody = getContext().getRequest().getBody();' + + ' if (prebody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + + ' var 
postbody = getContext().getResponse().getBody();' + + ' if (postbody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + '}'), 'triggerType': documents.TriggerType.Post, 'triggerOperation': documents.TriggerOperation.All }, @@ -2083,14 +2091,14 @@ async def test_trigger_functionality(self): 'id': 'response2', # can't be used because setValue is currently disabled 'body': ( - 'function() {' + - ' var predoc = getContext().getRequest().getBody();' + - ' var postdoc = getContext().getResponse().getBody();' + - ' getContext().getResponse().setValue(' + - ' \'predocname\', predoc.id + \'response2\');' + - ' getContext().getResponse().setValue(' + - ' \'postdocname\', postdoc.id + \'response2\');' + - '}'), + 'function() {' + + ' var predoc = getContext().getRequest().getBody();' + + ' var postdoc = getContext().getResponse().getBody();' + + ' getContext().getResponse().setValue(' + + ' \'predocname\', predoc.id + \'response2\');' + + ' getContext().getResponse().setValue(' + + ' \'postdocname\', postdoc.id + \'response2\');' + + '}'), 'triggerType': documents.TriggerType.Post, 'triggerOperation': documents.TriggerOperation.All, }] @@ -2104,11 +2112,11 @@ async def test_trigger_functionality(self): { 'id': "t3", 'body': ( - 'function() {' + - ' var item = getContext().getRequest().getBody();' + - ' item.id = item.id.toLowerCase() + \'t3\';' + - ' getContext().getRequest().setBody(item);' + - '}'), + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toLowerCase() + \'t3\';' + + ' getContext().getRequest().setBody(item);' + + '}'), 'triggerType': documents.TriggerType.Pre, 'triggerOperation': documents.TriggerOperation.All }] @@ -2140,15 +2148,18 @@ async def __CreateTriggers(collection, triggers): db = self.databaseForTest # create collections pkd = PartitionKey(path='/id', kind='Hash') - collection1 = db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', 
kind='Hash')) - collection2 = db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) - collection3 = db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) + collection1 = db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) + collection2 = db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) + collection3 = db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) # create triggers await __CreateTriggers(collection1, triggers_in_collection1) await __CreateTriggers(collection2, triggers_in_collection2) await __CreateTriggers(collection3, triggers_in_collection3) # create document - triggers_1 = list(collection1.scripts.list_triggers()) + triggers_1 = [trigger async for trigger in collection1.scripts.list_triggers()] self.assertEqual(len(triggers_1), 3) document_1_1 = collection1.create_item( body={'id': 'doc1', @@ -2172,7 +2183,7 @@ async def __CreateTriggers(collection, triggers): ) self.assertEqual(document_1_3['id'], "RESPONSEHEADERSt1") - triggers_2 = list(collection2.scripts.list_triggers()) + triggers_2 = [trigger async for trigger in collection2.scripts.list_triggers()] self.assertEqual(len(triggers_2), 2) document_2_1 = collection2.create_item( body={'id': 'doc2', @@ -2189,7 +2200,7 @@ async def __CreateTriggers(collection, triggers): pre_trigger_include='t3') self.assertEqual(document_2_2['id'], 'doc3t3') - triggers_3 = list(collection3.scripts.list_triggers()) + triggers_3 = [trigger async for trigger in collection3.scripts.list_triggers()] self.assertEqual(len(triggers_3), 1) with self.assertRaises(Exception): collection3.create_item( @@ -2210,13 +2221,13 @@ async def 
test_stored_procedure_functionality(self): sproc1 = { 'id': 'storedProcedure1' + str(uuid.uuid4()), 'body': ( - 'function () {' + - ' for (var i = 0; i < 1000; i++) {' + - ' var item = getContext().getResponse().getBody();' + - ' if (i > 0 && item != i - 1) throw \'body mismatch\';' + - ' getContext().getResponse().setBody(i);' + - ' }' + - '}') + 'function () {' + + ' for (var i = 0; i < 1000; i++) {' + + ' var item = getContext().getResponse().getBody();' + + ' if (i > 0 && item != i - 1) throw \'body mismatch\';' + + ' getContext().getResponse().setBody(i);' + + ' }' + + '}') } retrieved_sproc = collection.scripts.create_stored_procedure(body=sproc1) @@ -2228,11 +2239,11 @@ async def test_stored_procedure_functionality(self): sproc2 = { 'id': 'storedProcedure2' + str(uuid.uuid4()), 'body': ( - 'function () {' + - ' for (var i = 0; i < 10; i++) {' + - ' getContext().getResponse().appendValue(\'Body\', i);' + - ' }' + - '}') + 'function () {' + + ' for (var i = 0; i < 10; i++) {' + + ' getContext().getResponse().appendValue(\'Body\', i);' + + ' }' + + '}') } retrieved_sproc2 = collection.scripts.create_stored_procedure(body=sproc2) result = collection.scripts.execute_stored_procedure( @@ -2243,10 +2254,10 @@ async def test_stored_procedure_functionality(self): sproc3 = { 'id': 'storedProcedure3' + str(uuid.uuid4()), 'body': ( - 'function (input) {' + - ' getContext().getResponse().setBody(' + - ' \'a\' + input.temp);' + - '}') + 'function (input) {' + + ' getContext().getResponse().setBody(' + + ' \'a\' + input.temp);' + + '}') } retrieved_sproc3 = collection.scripts.create_stored_procedure(body=sproc3) result = collection.scripts.execute_stored_procedure( @@ -2403,7 +2414,7 @@ async def test_id_case_validation(self): collection_id2 = 'SampleCollection ' + uuid_string # Verify that no collections exist - collections = list(created_db.list_containers()) + collections = [collection async for collection in created_db.list_containers()] number_of_existing_collections 
= len(collections) # create 2 collections with different casing of IDs @@ -2419,7 +2430,7 @@ async def test_id_case_validation(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - collections = list(created_db.list_containers()) + collections = [collection async for collection in created_db.list_containers()] # verify if a total of 2 collections got created self.assertEqual(len(collections), number_of_existing_collections + 2) @@ -2431,7 +2442,7 @@ async def test_id_case_validation(self): created_db.delete_container(created_collection1) created_db.delete_container(created_collection2) - #TODO: fix test + # TODO: fix test @pytest.mark.skip async def test_id_unicode_validation(self): # create database @@ -2455,7 +2466,7 @@ async def test_id_unicode_validation(self): self.assertEqual(collection_id1, created_collection1.id) self.assertEqual(collection_id2, created_collection2.id) - + created_collection1_properties = created_collection1.read() created_collection2_properties = created_collection2.read() @@ -2492,7 +2503,7 @@ async def test_get_resource_with_dictionary_and_object(self): read_container = created_db.get_container_client(created_properties) self.assertEqual(read_container.id, created_container.id) - created_item = created_container.create_item({'id':'1' + str(uuid.uuid4())}) + created_item = created_container.create_item({'id': '1' + str(uuid.uuid4())}) # read item with id read_item = created_container.read_item(item=created_item['id'], partition_key=created_item['id']) @@ -2644,11 +2655,12 @@ async def test_get_resource_with_dictionary_and_object(self): # ttl_key = "analyticalStorageTtl" # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) - async def _MockExecuteFunction(self, function, *args, **kwargs): + def _MockExecuteFunction(self, function, *args, **kwargs): self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey] - if HttpHeaders.PartitionKey in args[4].headers else '') + if HttpHeaders.PartitionKey in 
args[4].headers else '') return self.OriginalExecuteFunction(function, *args, **kwargs) + if __name__ == '__main__': try: loop = asyncio.get_event_loop() From 011bc02b0d290eca287b1bb58ea8a982f79c5f7c Mon Sep 17 00:00:00 2001 From: JericHunter Date: Wed, 27 Apr 2022 14:14:23 -0700 Subject: [PATCH 4/8] Update test_crud_async.py --- sdk/cosmos/azure-cosmos/test/test_crud_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index d0abc299a98a..d6fcaf49b65d 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -477,7 +477,7 @@ async def test_partitioned_collection_document_crud_and_query(self): self.assertEqual(upserted_document.get('id'), document_definition.get('id')) self.assertEqual(upserted_document.get('key'), document_definition.get('key')) - documentlist = list(created_collection.read_all_items()) + documentlist = [document async for document in created_collection.read_all_items()] self.assertEqual(2, len(documentlist)) # delete document From c28bc24548a4dcdcf6cc8f713b1fe7d5bb07383c Mon Sep 17 00:00:00 2001 From: JericHunter Date: Fri, 29 Apr 2022 08:48:58 -0700 Subject: [PATCH 5/8] Update test_crud_async.py --- .../azure-cosmos/test/test_crud_async.py | 481 +++++++++--------- 1 file changed, 241 insertions(+), 240 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index d6fcaf49b65d..fa7f801b1971 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -100,7 +100,7 @@ async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs @classmethod async def setUpClass(cls): cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, connection_policy=cls.connectionPolicy) - cls.database = await 
cls.client.create_database_if_nflot_exists(test_config._test_config.TEST_DATABASE_ID) + cls.databaseForTest = cls.client.create_database_if_not_exists(test_config._test_config.TEST_DATABASE_ID) async def setUp(self): self.client = cosmos_client.CosmosClient(self.host, self.masterKey, "Session", @@ -161,12 +161,12 @@ async def test_database_level_offer_throughput(self): self.assertEqual(created_db.id, database_id) # Verify offer throughput for database - offer = created_db.read_offer() + offer = await created_db.read_offer() self.assertEqual(offer.offer_throughput, offer_throughput) # Update database offer throughput new_offer_throughput = 2000 - offer = created_db.replace_throughput(new_offer_throughput) + offer = await created_db.replace_throughput(new_offer_throughput) self.assertEqual(offer.offer_throughput, new_offer_throughput) await self.client.delete_database(created_db.id) @@ -205,17 +205,17 @@ async def test_collection_crud(self): collection_id = 'test_collection_crud ' + str(uuid.uuid4()) collection_indexing_policy = {'indexingMode': 'consistent'} created_recorder = RecordDiagnostics() - created_collection = created_db.create_container(id=collection_id, - indexing_policy=collection_indexing_policy, - partition_key=PartitionKey(path="/pk", kind="Hash"), - response_hook=created_recorder) + created_collection = await created_db.create_container(id=collection_id, + indexing_policy=collection_indexing_policy, + partition_key=PartitionKey(path="/pk", kind="Hash"), + response_hook=created_recorder) self.assertEqual(collection_id, created_collection.id) assert isinstance(created_recorder.headers, Mapping) assert 'Content-Type' in created_recorder.headers assert isinstance(created_recorder.body, Mapping) assert 'id' in created_recorder.body - created_properties = created_collection.read() + created_properties = await created_collection.read() self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode']) # read collections after creation @@ 
-234,23 +234,23 @@ async def test_collection_crud(self): self.assertTrue(collections) # delete collection - created_db.delete_container(created_collection.id) + await created_db.delete_container(created_collection.id) # read collection after deletion - created_container = created_db.get_container_client(created_collection.id) + created_container = await created_db.get_container_client(created_collection.id) await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, created_container.read) - container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, + container_proxy = await created_db.create_container_if_not_exists(id=created_collection.id, partition_key=PartitionKey(path='/id', kind='Hash')) self.assertEqual(created_collection.id, container_proxy.id) self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) - container_proxy = created_db.create_container_if_not_exists(id=created_collection.id, + container_proxy = await created_db.create_container_if_not_exists(id=created_collection.id, partition_key=created_properties['partitionKey']) self.assertEqual(created_container.id, container_proxy.id) self.assertDictEqual(PartitionKey(path='/id', kind='Hash'), container_proxy._properties['partitionKey']) - created_db.delete_container(created_collection.id) + await created_db.delete_container(created_collection.id) async def test_partitioned_collection(self): created_db = self.databaseForTest @@ -264,36 +264,36 @@ async def test_partitioned_collection(self): } offer_throughput = 10100 - created_collection = created_db.create_container(id=collection_definition['id'], + created_collection = await created_db.create_container(id=collection_definition['id'], partition_key=collection_definition['partitionKey'], offer_throughput=offer_throughput) self.assertEqual(collection_definition.get('id'), created_collection.id) - created_collection_properties = created_collection.read() + 
created_collection_properties = await created_collection.read() self.assertEqual(collection_definition.get('partitionKey').get('paths')[0], created_collection_properties['partitionKey']['paths'][0]) self.assertEqual(collection_definition.get('partitionKey').get('kind'), created_collection_properties['partitionKey']['kind']) - expected_offer = created_collection.read_offer() + expected_offer = await created_collection.read_offer() self.assertIsNotNone(expected_offer) self.assertEqual(expected_offer.offer_throughput, offer_throughput) - created_db.delete_container(created_collection.id) + await created_db.delete_container(created_collection.id) async def test_partitioned_collection_quota(self): created_db = self.databaseForTest created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) - retrieved_collection = created_db.get_container_client( + retrieved_collection = await created_db.get_container_client( container=created_collection.id ) - retrieved_collection_properties = retrieved_collection.read( + retrieved_collection_properties = await retrieved_collection.read( populate_partition_key_range_statistics=True, populate_quota_info=True) self.assertIsNotNone(retrieved_collection_properties.get("statistics")) @@ -303,7 +303,7 @@ async def test_partitioned_collection_partition_key_extraction(self): created_db = self.databaseForTest collection_id = 'test_partitioned_collection_partition_key_extraction ' + str(uuid.uuid4()) - created_collection = created_db.create_container( + created_collection = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/address/state', kind=documents.PartitionKind.Hash) ) @@ -319,7 +319,7 @@ async def test_partitioned_collection_partition_key_extraction(self): self.OriginalExecuteFunction = _retry_utility.ExecuteFunction _retry_utility.ExecuteFunction = self._MockExecuteFunction # create document without partition key being specified - created_document = 
created_collection.create_item(body=document_definition) + created_document = await created_collection.create_item(body=document_definition) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], '["WA"]') del self.last_headers[:] @@ -328,7 +328,7 @@ async def test_partitioned_collection_partition_key_extraction(self): self.assertEqual(created_document.get('address').get('state'), document_definition.get('address').get('state')) collection_id = 'test_partitioned_collection_partition_key_extraction1 ' + str(uuid.uuid4()) - created_collection1 = created_db.create_container( + created_collection1 = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/address', kind=documents.PartitionKind.Hash) ) @@ -344,7 +344,7 @@ async def test_partitioned_collection_partition_key_extraction(self): # self.assertEqual(options['partitionKey'], documents.Undefined) collection_id = 'test_partitioned_collection_partition_key_extraction2 ' + str(uuid.uuid4()) - created_collection2 = created_db.create_container( + created_collection2 = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/address/state/city', kind=documents.PartitionKind.Hash) ) @@ -352,23 +352,23 @@ async def test_partitioned_collection_partition_key_extraction(self): self.OriginalExecuteFunction = _retry_utility.ExecuteFunction _retry_utility.ExecuteFunction = self._MockExecuteFunction # Create document with partitionkey not present in the document - created_document = created_collection2.create_item(document_definition) + created_document = await created_collection2.create_item(document_definition) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], [{}]) del self.last_headers[:] # self.assertEqual(options['partitionKey'], documents.Undefined) - created_db.delete_container(created_collection.id) - created_db.delete_container(created_collection1.id) - 
created_db.delete_container(created_collection2.id) + await created_db.delete_container(created_collection.id) + await created_db.delete_container(created_collection1.id) + await created_db.delete_container(created_collection2.id) async def test_partitioned_collection_partition_key_extraction_special_chars(self): created_db = self.databaseForTest collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars1 ' + str(uuid.uuid4()) - created_collection1 = created_db.create_container( + created_collection1 = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/\"level\' 1*()\"/\"le/vel2\"', kind=documents.PartitionKind.Hash) ) @@ -394,7 +394,7 @@ async def test_partitioned_collection_partition_key_extraction_special_chars(sel collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()) - created_collection2 = created_db.create_container( + created_collection2 = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/\'level\" 1*()\'/\'le/vel2\'', kind=documents.PartitionKind.Hash) ) @@ -406,13 +406,13 @@ async def test_partitioned_collection_partition_key_extraction_special_chars(sel self.OriginalExecuteFunction = _retry_utility.ExecuteFunction _retry_utility.ExecuteFunction = self._MockExecuteFunction # create document without partition key being specified - created_document = created_collection2.create_item(body=document_definition) + created_document = await created_collection2.create_item(body=document_definition) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], '["val2"]') del self.last_headers[:] - created_db.delete_container(created_collection1.id) - created_db.delete_container(created_collection2.id) + await created_db.delete_container(created_collection1.id) + await created_db.delete_container(created_collection2.id) async def test_partitioned_collection_path_parser(self): 
test_dir = os.path.dirname(os.path.abspath(__file__)) @@ -438,7 +438,7 @@ async def test_partitioned_collection_document_crud_and_query(self): document_definition = {'id': 'document', 'key': 'value'} - created_document = created_collection.create_item( + created_document = await created_collection.create_item( body=document_definition ) @@ -446,7 +446,7 @@ async def test_partitioned_collection_document_crud_and_query(self): self.assertEqual(created_document.get('key'), document_definition.get('key')) # read document - read_document = created_collection.read_item( + read_document = await created_collection.read_item( item=created_document.get('id'), partition_key=created_document.get('id') ) @@ -461,7 +461,7 @@ async def test_partitioned_collection_document_crud_and_query(self): # replace document document_definition['key'] = 'new value' - replaced_document = created_collection.replace_item( + replaced_document = await created_collection.replace_item( item=read_document, body=document_definition ) @@ -472,7 +472,7 @@ async def test_partitioned_collection_document_crud_and_query(self): document_definition['id'] = 'document2' document_definition['key'] = 'value2' - upserted_document = created_collection.upsert_item(body=document_definition) + upserted_document = await created_collection.upsert_item(body=document_definition) self.assertEqual(upserted_document.get('id'), document_definition.get('id')) self.assertEqual(upserted_document.get('key'), document_definition.get('key')) @@ -520,19 +520,19 @@ async def test_partitioned_collection_permissions(self): collection_id = 'test_partitioned_collection_permissions all collection' + str(uuid.uuid4()) - all_collection = created_db.create_container( + all_collection = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) ) collection_id = 'test_partitioned_collection_permissions read collection' + str(uuid.uuid4()) - read_collection = 
created_db.create_container( + read_collection = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) ) - user = created_db.create_user(body={'id': 'user' + str(uuid.uuid4())}) + user = await created_db.create_user(body={'id': 'user' + str(uuid.uuid4())}) permission_definition = { 'id': 'all permission', @@ -541,7 +541,7 @@ async def test_partitioned_collection_permissions(self): 'resourcePartitionKey': [1] } - all_permission = user.create_permission(body=permission_definition) + all_permission = await user.create_permission(body=permission_definition) permission_definition = { 'id': 'read permission', @@ -550,7 +550,7 @@ async def test_partitioned_collection_permissions(self): 'resourcePartitionKey': [1] } - read_permission = user.create_permission(body=permission_definition) + read_permission = await user.create_permission(body=permission_definition) resource_tokens = {} # storing the resource tokens based on Resource IDs @@ -565,11 +565,11 @@ async def test_partitioned_collection_permissions(self): 'key': 1 } - all_collection.client_connection = restricted_client.client_connection - read_collection.client_connection = restricted_client.client_connection + all_collection.client_connection = restricted_client.client_connection + read_collection.client_connection = restricted_client.client_connection # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions - created_document = all_collection.create_item(body=document_definition) + created_document = await all_collection.create_item(body=document_definition) # Create document in read_collection should fail since it has only read permissions for this collection await self.__AssertHTTPFailureWithStatus( @@ -597,8 +597,8 @@ async def test_partitioned_collection_permissions(self): document_definition['id'] ) -
created_db.delete_container(all_collection) - created_db.delete_container(read_collection) + await created_db.delete_container(all_collection) + await created_db.delete_container(read_collection) async def test_partitioned_collection_execute_stored_procedure(self): created_db = self.databaseForTest @@ -618,10 +618,10 @@ async def test_partitioned_collection_execute_stored_procedure(self): ' });}') } - created_sproc = created_collection.scripts.create_stored_procedure(body=sproc) + created_sproc = await created_collection.scripts.create_stored_procedure(body=sproc) # Partiton Key value same as what is specified in the stored procedure body - result = created_collection.scripts.execute_stored_procedure(sproc=created_sproc['id'], partition_key=2) + result = await created_collection.scripts.execute_stored_procedure(sproc=created_sproc['id'], partition_key=2) self.assertEqual(result, 1) # Partiton Key value different than what is specified in the stored procedure body will cause a bad request(400) error @@ -634,7 +634,7 @@ async def test_partitioned_collection_execute_stored_procedure(self): async def test_partitioned_collection_partition_key_value_types(self): created_db = self.databaseForTest - created_collection = created_db.create_container( + created_collection = await created_db.create_container( id='test_partitioned_collection_partition_key_value_types ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/pk', kind='Hash') ) @@ -644,41 +644,41 @@ async def test_partitioned_collection_partition_key_value_types(self): 'spam': 'eggs'} # create document with partitionKey set as None here - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) document_definition = {'id': 'document1' + str(uuid.uuid4()), 'spam': 'eggs'} # create document with partitionKey set as Undefined here - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) 
document_definition = {'id': 'document1' + str(uuid.uuid4()), 'pk': True, 'spam': 'eggs'} # create document with bool partitionKey - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) document_definition = {'id': 'document1' + str(uuid.uuid4()), 'pk': 'value', 'spam': 'eggs'} # create document with string partitionKey - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) document_definition = {'id': 'document1' + str(uuid.uuid4()), 'pk': 100, 'spam': 'eggs'} # create document with int partitionKey - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) document_definition = {'id': 'document1' + str(uuid.uuid4()), 'pk': 10.50, 'spam': 'eggs'} # create document with float partitionKey - created_collection.create_item(body=document_definition) + await created_collection.create_item(body=document_definition) document_definition = {'name': 'sample document', 'spam': 'eggs', @@ -691,7 +691,7 @@ async def test_partitioned_collection_partition_key_value_types(self): document_definition ) - created_db.delete_container(created_collection) + await created_db.delete_container(created_collection) async def test_partitioned_collection_conflict_crud_and_query(self): created_db = self.databaseForTest @@ -766,7 +766,7 @@ async def test_document_crud(self): 'spam': 'eggs', 'key': 'value'} - created_document = created_collection.create_item(body=document_definition, enable_automatic_id_generation=True) + created_document = await created_collection.create_item(body=document_definition, enable_automatic_id_generation=True) self.assertEqual(created_document.get('name'), document_definition['name']) @@ -775,7 +775,7 @@ async def test_document_crud(self): 'key': 'value', 'id': str(uuid.uuid4())} - created_document = created_collection.create_item(body=document_definition) + 
created_document = await created_collection.create_item(body=document_definition) self.assertEqual(created_document.get('name'), document_definition['name']) self.assertEqual(created_document.get('id'), @@ -816,7 +816,7 @@ async def test_document_crud(self): created_document['name'] = 'replaced document' created_document['spam'] = 'not eggs' old_etag = created_document['_etag'] - replaced_document = created_collection.replace_item( + replaced_document = await created_collection.replace_item( item=created_document['id'], body=created_document ) @@ -874,7 +874,7 @@ async def test_document_crud(self): ) # should pass for most recent etag - replaced_document_conditional = created_collection.replace_item( + replaced_document_conditional = await created_collection.replace_item( match_condition=MatchConditions.IfNotModified, etag=replaced_document['_etag'], item=replaced_document['id'], @@ -890,14 +890,14 @@ async def test_document_crud(self): replaced_document['id'], 'document id should stay the same') # read document - one_document_from_read = created_collection.read_item( + one_document_from_read = await created_collection.read_item( item=replaced_document['id'], partition_key=replaced_document['id'] ) self.assertEqual(replaced_document['id'], one_document_from_read['id']) # delete document - created_collection.delete_item( + await created_collection.delete_item( item=replaced_document, partition_key=replaced_document['id'] ) @@ -925,7 +925,7 @@ async def test_document_upsert(self): 'key': 'value'} # create document using Upsert API - created_document = created_collection.upsert_item(body=document_definition) + created_document = await created_collection.upsert_item(body=document_definition) # verify id property self.assertEqual(created_document['id'], @@ -934,7 +934,7 @@ async def test_document_upsert(self): # test error for non-string id with pytest.raises(TypeError): document_definition['id'] = 7 - created_collection.upsert_item(body=document_definition) + await 
created_collection.upsert_item(body=document_definition) # read documents after creation and verify updated count documents = [document async for document in created_collection.read_all_items()] @@ -948,7 +948,7 @@ async def test_document_upsert(self): created_document['spam'] = 'not eggs' # should replace document since it already exists - upserted_document = created_collection.upsert_item(body=created_document) + upserted_document = await created_collection.upsert_item(body=created_document) # verify the changed properties self.assertEqual(upserted_document['name'], @@ -973,13 +973,13 @@ async def test_document_upsert(self): created_document['id'] = 'new id' # Upsert should create new document since the id is different - new_document = created_collection.upsert_item(body=created_document) + new_document = await created_collection.upsert_item(body=created_document) # Test modified access conditions created_document['spam'] = 'more eggs' - created_collection.upsert_item(body=created_document) + await created_collection.upsert_item(body=created_document) with pytest.raises(exceptions.CosmosHttpResponseError): - created_collection.upsert_item( + await created_collection.upsert_item( body=created_document, match_condition=MatchConditions.IfNotModified, etag=new_document['_etag']) @@ -997,8 +997,8 @@ async def test_document_upsert(self): 'upsert should increase the number of documents') # delete documents - created_collection.delete_item(item=upserted_document, partition_key=upserted_document['id']) - created_collection.delete_item(item=new_document, partition_key=new_document['id']) + await created_collection.delete_item(item=upserted_document, partition_key=upserted_document['id']) + await created_collection.delete_item(item=new_document, partition_key=new_document['id']) # read documents after delete and verify count is same as original documents = [document async for document in created_collection.read_all_items()] @@ -1010,7 +1010,7 @@ async def 
test_document_upsert(self): async def _test_spatial_index(self): db = self.databaseForTest # partial policy specified - collection = db.create_container( + collection = await db.create_container( id='collection with spatial index ' + str(uuid.uuid4()), indexing_policy={ 'includedPaths': [ @@ -1030,7 +1030,7 @@ async def _test_spatial_index(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection.create_item( + await collection.create_item( body={ 'id': 'loc1', 'Location': { @@ -1039,7 +1039,7 @@ async def _test_spatial_index(self): } } ) - collection.create_item( + await collection.create_item( body={ 'id': 'loc2', 'Location': { @@ -1055,7 +1055,7 @@ async def _test_spatial_index(self): self.assertEqual(1, len(results)) self.assertEqual('loc1', results[0]['id']) - db.delete_container(container=collection) + await db.delete_container(container=collection) # CRUD test for User resource async def test_user_crud(self): @@ -1067,7 +1067,7 @@ async def test_user_crud(self): before_create_count = len(users) # create user user_id = 'new user' + str(uuid.uuid4()) - user = db.create_user(body={'id': user_id}) + user = await db.create_user(body={'id': user_id}) self.assertEqual(user.id, user_id, 'user id error') # list users after creation users = [user async for user in db.list_users()] @@ -1083,9 +1083,9 @@ async def test_user_crud(self): # replace user replaced_user_id = 'replaced user' + str(uuid.uuid4()) - user_properties = user.read() + user_properties = await user.read() user_properties['id'] = replaced_user_id - replaced_user = db.replace_user(user_id, user_properties) + replaced_user = await db.replace_user(user_id, user_properties) self.assertEqual(replaced_user.id, replaced_user_id, 'user id should change') @@ -1093,12 +1093,12 @@ async def test_user_crud(self): replaced_user.id, 'user id should stay the same') # read user - user = db.get_user_client(replaced_user.id) + user = await db.get_user_client(replaced_user.id) 
self.assertEqual(replaced_user.id, user.id) # delete user - db.delete_user(user.id) + await db.delete_user(user.id) # read user after deletion - deleted_user = db.get_user_client(user.id) + deleted_user = await db.get_user_client(user.id) await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, deleted_user.read) @@ -1112,7 +1112,7 @@ async def test_user_upsert(self): # create user using Upsert API user_id = 'user' + str(uuid.uuid4()) - user = db.upsert_user(body={'id': user_id}) + user = await db.upsert_user(body={'id': user_id}) # verify id property self.assertEqual(user.id, user_id, 'user id error') @@ -1122,8 +1122,8 @@ async def test_user_upsert(self): self.assertEqual(len(users), before_create_count + 1) # Should replace the user since it already exists, there is no public property to change here - user_properties = user.read() - upserted_user = db.upsert_user(user_properties) + user_properties = await user.read() + upserted_user = await db.upsert_user(user_properties) # verify id property self.assertEqual(upserted_user.id, @@ -1134,12 +1134,12 @@ async def test_user_upsert(self): users = [user async for user in db.list_users()] self.assertEqual(len(users), before_create_count + 1) - user_properties = user.read() + user_properties = await user.read() user_properties['id'] = 'new user' + str(uuid.uuid4()) user.id = user_properties['id'] # Upsert should create new user since id is different - new_user = db.upsert_user(user_properties) + new_user = await db.upsert_user(user_properties) # verify id property self.assertEqual(new_user.id, user.id, 'user id error') @@ -1149,8 +1149,8 @@ async def test_user_upsert(self): self.assertEqual(len(users), before_create_count + 2) # delete users - db.delete_user(upserted_user.id) - db.delete_user(new_user.id) + await db.delete_user(upserted_user.id) + await db.delete_user(new_user.id) # read users after delete and verify count remains the same users = [user async for user in db.list_users()] @@ -1161,7 +1161,7 @@ 
async def test_permission_crud(self): # create database db = self.databaseForTest # create user - user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) + user = await db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) # list permissions permissions = [permission async for permission in user.list_permissions()] before_create_count = len(permissions) @@ -1171,7 +1171,7 @@ async def test_permission_crud(self): 'resource': 'dbs/AQAAAA==/colls/AQAAAJ0fgTc=' # A random one. } # create permission - permission = user.create_permission(permission) + permission = await user.create_permission(permission) self.assertEqual(permission.id, 'new permission', 'permission id error') @@ -1188,7 +1188,7 @@ async def test_permission_crud(self): self.assertTrue(results) # replace permission - change_permission = permission.properties.copy() + change_permission = permission.properties.copy() permission.properties['id'] = 'replaced permission' permission.id = permission.properties['id'] replaced_permission = user.replace_permission(change_permission['id'], permission.properties) @@ -1199,10 +1199,10 @@ async def test_permission_crud(self): replaced_permission.id, 'permission id should stay the same') # read permission - permission = user.get_permission(replaced_permission.id) + permission = await user.get_permission(replaced_permission.id) self.assertEqual(replaced_permission.id, permission.id) # delete permission - user.delete_permission(replaced_permission.id) + await user.delete_permission(replaced_permission.id) # read permission after deletion await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, user.get_permission, @@ -1213,7 +1213,7 @@ async def test_permission_upsert(self): db = self.databaseForTest # create user - user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) + user = await db.create_user(body={'id': 'new user' + str(uuid.uuid4())}) # read permissions and check count permissions = [permission async for permission in
user.list_permissions()] @@ -1226,7 +1226,7 @@ async def test_permission_upsert(self): } # create permission using Upsert API - created_permission = user.upsert_permission(permission_definition) + created_permission = await user.upsert_permission(permission_definition) # verify id property self.assertEqual(created_permission.id, @@ -1241,7 +1241,7 @@ async def test_permission_upsert(self): permission_definition['permissionMode'] = documents.PermissionMode.All # should repace the permission since it already exists - upserted_permission = user.upsert_permission(permission_definition) + upserted_permission = await user.upsert_permission(permission_definition) # verify id property self.assertEqual(upserted_permission.id, created_permission.id, @@ -1264,7 +1264,7 @@ async def test_permission_upsert(self): created_permission.resource_link = created_permission.properties['resource'] # should create new permission since id has changed - new_permission = user.upsert_permission(created_permission.properties) + new_permission = await user.upsert_permission(created_permission.properties) # verify id and resource property self.assertEqual(new_permission.id, @@ -1280,8 +1280,8 @@ async def test_permission_upsert(self): self.assertEqual(len(permissions), before_create_count + 2) # delete permissions - user.delete_permission(upserted_permission.id) - user.delete_permission(new_permission.id) + await user.delete_permission(upserted_permission.id) + await user.delete_permission(new_permission.id) # read permissions and verify count remains the same permissions = [permission async for permission in user.list_permissions()] @@ -1302,19 +1302,19 @@ async def __SetupEntities(client): # create database db = self.databaseForTest # create collection - collection = db.create_container( + collection = await db.create_container( id='test_authorization' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) # create document1 - document = collection.create_item( + document 
= await collection.create_item( body={'id': 'doc1', 'spam': 'eggs', 'key': 'value'}, ) # create user - user = db.create_user(body={'id': 'user' + str(uuid.uuid4())}) + user = await db.create_user(body={'id': 'user' + str(uuid.uuid4())}) # create permission for collection permission = { @@ -1322,7 +1322,7 @@ async def __SetupEntities(client): 'permissionMode': documents.PermissionMode.Read, 'resource': "dbs/" + db.id + "/colls/" + collection.id } - permission_on_coll = user.create_permission(body=permission) + permission_on_coll = await user.create_permission(body=permission) self.assertIsNotNone(permission_on_coll.properties['_token'], 'permission token is invalid') @@ -1332,7 +1332,7 @@ async def __SetupEntities(client): 'permissionMode': documents.PermissionMode.All, 'resource': "dbs/" + db.id + "/colls/" + collection.id + "/docs/" + document["id"] } - permission_on_doc = user.create_permission(body=permission) + permission_on_doc = await user.create_permission(body=permission) self.assertIsNotNone(permission_on_doc.properties['_token'], 'permission token is invalid') @@ -1357,17 +1357,17 @@ async def __SetupEntities(client): "Session", connection_policy=CRUDTests.connectionPolicy) # setup entities - entities = __SetupEntities(client) + entities = await __SetupEntities(client) resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: entities['permissionOnColl'].properties['_token']} col_client = cosmos_client.CosmosClient( CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) db = entities['db'] - old_client_connection = db.client_connection - db.client_connection = col_client.client_connection + old_client_connection = db.client_connection + db.client_connection = col_client.client_connection # 1. Success-- Use Col Permission to Read - success_coll = db.get_container_client(container=entities['coll']) + success_coll = db.get_container_client(container=entities['coll']) # 2.
Failure-- Use Col Permission to delete await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, db.delete_container, @@ -1382,7 +1382,7 @@ async def __SetupEntities(client): # 4. Success-- Use Col Permission to Read Doc docId = entities['doc']['id'] - success_doc = success_coll.read_item( + success_doc = await success_coll.read_item( item=docId, partition_key=docId ) @@ -1404,14 +1404,15 @@ async def __SetupEntities(client): CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) # 6. Success-- Use Doc permission to read doc - read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId) + read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, + docId) self.assertEqual(read_doc["id"], docId) # 6. Success-- Use Doc permission to delete doc await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) self.assertEqual(read_doc["id"], docId) - db.client_connection = old_client_connection + db.client_connection = old_client_connection db.delete_container(entities['coll']) async def test_trigger_crud(self): @@ -1429,7 +1430,7 @@ async def test_trigger_crud(self): 'triggerType': documents.TriggerType.Pre, 'triggerOperation': documents.TriggerOperation.All } - trigger = collection.scripts.create_trigger(body=trigger_definition) + trigger = await collection.scripts.create_trigger(body=trigger_definition) for property in trigger_definition: if property != "serverScript": self.assertEqual( @@ -1455,9 +1456,9 @@ async def test_trigger_crud(self): self.assertTrue(triggers) # replace trigger - change_trigger = trigger.copy() + change_trigger = trigger.copy() trigger['body'] = 'function() {var x = 20;}' - replaced_trigger = collection.scripts.replace_trigger(change_trigger['id'], trigger) + replaced_trigger = await collection.scripts.replace_trigger(change_trigger['id'],
trigger) for property in trigger_definition: if property != "serverScript": self.assertEqual( @@ -1469,10 +1470,10 @@ async def test_trigger_crud(self): 'function() {var x = 20;}') # read trigger - trigger = collection.scripts.get_trigger(replaced_trigger['id']) + trigger = await collection.scripts.get_trigger(replaced_trigger['id']) self.assertEqual(replaced_trigger['id'], trigger['id']) # delete trigger - collection.scripts.delete_trigger(replaced_trigger['id']) + await collection.scripts.delete_trigger(replaced_trigger['id']) # read triggers after deletion await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, collection.scripts.delete_trigger, @@ -1491,7 +1492,7 @@ async def test_udf_crud(self): 'id': 'sample udf', 'body': 'function() {var x = 10;}' } - udf = collection.scripts.create_user_defined_function(body=udf_definition) + udf = await collection.scripts.create_user_defined_function(body=udf_definition) for property in udf_definition: self.assertEqual( udf[property], @@ -1512,19 +1513,19 @@ async def test_udf_crud(self): )] self.assertTrue(results) # replace udf - change_udf = udf.copy() + change_udf = udf.copy() udf['body'] = 'function() {var x = 20;}' - replaced_udf = collection.scripts.replace_user_defined_function(udf=udf['id'], body=udf) + replaced_udf = await collection.scripts.replace_user_defined_function(udf=udf['id'], body=udf) for property in udf_definition: self.assertEqual( replaced_udf[property], udf[property], 'property {property} should match'.format(property=property)) # read udf - udf = collection.scripts.get_user_defined_function(replaced_udf['id']) + udf = await collection.scripts.get_user_defined_function(replaced_udf['id']) self.assertEqual(replaced_udf['id'], udf['id']) # delete udf - collection.scripts.delete_user_defined_function(replaced_udf['id']) + await collection.scripts.delete_user_defined_function(replaced_udf['id']) # read udfs after deletion await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND,
collection.scripts.get_user_defined_function, @@ -1543,7 +1544,7 @@ async def test_sproc_crud(self): 'id': 'sample sproc', 'serverScript': 'function() {var x = 10;}' } - sproc = collection.scripts.create_stored_procedure(body=sproc_definition) + sproc = await collection.scripts.create_stored_procedure(body=sproc_definition) for property in sproc_definition: if property != "serverScript": self.assertEqual( @@ -1567,9 +1568,9 @@ async def test_sproc_crud(self): )] self.assertIsNotNone(sprocs) # replace sproc - change_sproc = sproc.copy() + change_sproc = sproc.copy() sproc['body'] = 'function() {var x = 20;}' - replaced_sproc = collection.scripts.replace_stored_procedure(sproc=change_sproc['id'], body=sproc) + replaced_sproc = await collection.scripts.replace_stored_procedure(sproc=change_sproc['id'], body=sproc) for property in sproc_definition: if property != 'serverScript': self.assertEqual( @@ -1580,10 +1581,10 @@ async def test_sproc_crud(self): self.assertEqual(replaced_sproc['body'], "function() {var x = 20;}") # read sproc - sproc = collection.scripts.get_stored_procedure(replaced_sproc['id']) + sproc = await collection.scripts.get_stored_procedure(replaced_sproc['id']) self.assertEqual(replaced_sproc['id'], sproc['id']) # delete sproc - collection.scripts.delete_stored_procedure(replaced_sproc['id']) + await collection.scripts.delete_stored_procedure(replaced_sproc['id']) # read sprocs after deletion await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, collection.scripts.get_stored_procedure, @@ -1610,9 +1611,9 @@ async def test_script_logging_execute_stored_procedure(self): '}') } - created_sproc = created_collection.scripts.create_stored_procedure(body=sproc) + created_sproc = await created_collection.scripts.create_stored_procedure(body=sproc) - result = created_collection.scripts.execute_stored_procedure( + result = await created_collection.scripts.execute_stored_procedure( sproc=created_sproc['id'], partition_key=1 ) @@ -1621,7 +1622,7
@@ async def test_script_logging_execute_stored_procedure(self): self.assertFalse( HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) - result = created_collection.scripts.execute_stored_procedure( + result = await created_collection.scripts.execute_stored_procedure( sproc=created_sproc['id'], enable_script_logging=True, partition_key=1 @@ -1632,7 +1633,7 @@ async def test_script_logging_execute_stored_procedure(self): created_collection.scripts.client_connection.last_response_headers.get( HttpHeaders.ScriptLogResults)) - result = created_collection.scripts.execute_stored_procedure( + result = await created_collection.scripts.execute_stored_procedure( sproc=created_sproc['id'], enable_script_logging=False, partition_key=1 @@ -1646,19 +1647,19 @@ async def test_collection_indexing_policy(self): # create database db = self.databaseForTest # create collection - collection = db.create_container( + collection = await db.create_container( id='test_collection_indexing_policy default policy' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() self.assertEqual(collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'default indexing mode should be consistent') - db.delete_container(container=collection) + await db.delete_container(container=collection) - consistent_collection = db.create_container( + consistent_collection = await db.create_container( id='test_collection_indexing_policy consistent collection ' + str(uuid.uuid4()), indexing_policy={ 'indexingMode': documents.IndexingMode.Consistent @@ -1666,14 +1667,14 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - consistent_collection_properties = consistent_collection.read() + consistent_collection_properties = await consistent_collection.read() 
self.assertEqual(consistent_collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'indexing mode should be consistent') - db.delete_container(container=consistent_collection) + await db.delete_container(container=consistent_collection) - collection_with_indexing_policy = db.create_container( + collection_with_indexing_policy = await db.create_container( id='CollectionWithIndexingPolicy ' + str(uuid.uuid4()), indexing_policy={ 'automatic': True, @@ -1699,52 +1700,52 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_with_indexing_policy_properties = collection_with_indexing_policy.read() + collection_with_indexing_policy_properties = await collection_with_indexing_policy.read() self.assertEqual(1, len(collection_with_indexing_policy_properties['indexingPolicy']['includedPaths']), 'Unexpected includedPaths length') self.assertEqual(2, len(collection_with_indexing_policy_properties['indexingPolicy']['excludedPaths']), 'Unexpected excluded path count') - db.delete_container(container=collection_with_indexing_policy) + await db.delete_container(container=collection_with_indexing_policy) async def test_create_default_indexing_policy(self): # create database db = self.databaseForTest # no indexing policy specified - collection = db.create_container( + collection = await db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) - db.delete_container(container=collection) + await db.delete_container(container=collection) # partial policy specified - collection = db.create_container( + collection = await db.create_container( id='test_create_default_indexing_policy 
TestCreateDefaultPolicy01' + str(uuid.uuid4()), indexing_policy={ 'indexingMode': documents.IndexingMode.Consistent, 'automatic': True }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) - db.delete_container(container=collection) + await db.delete_container(container=collection) # default policy - collection = db.create_container( + collection = await db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy03' + str(uuid.uuid4()), indexing_policy={}, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) - db.delete_container(container=collection) + await db.delete_container(container=collection) # missing indexes - collection = db.create_container( + collection = await db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy04' + str(uuid.uuid4()), indexing_policy={ 'includedPaths': [ @@ -1755,12 +1756,12 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) - db.delete_container(container=collection) + await db.delete_container(container=collection) # missing precision - collection = db.create_container( + collection = await db.create_container( id='test_create_default_indexing_policy TestCreateDefaultPolicy05' + str(uuid.uuid4()), indexing_policy={ 'includedPaths': [ @@ -1781,9 +1782,9 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - 
collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) - db.delete_container(container=collection) + await db.delete_container(container=collection) async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): # create database @@ -1843,7 +1844,7 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): } custom_logger = logging.getLogger("CustomLogger") - created_container = db.create_container( + created_container = await db.create_container( id='composite_index_spatial_index' + str(uuid.uuid4()), indexing_policy=indexing_policy, partition_key=PartitionKey(path='/id', kind='Hash'), @@ -1853,8 +1854,8 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): logging_enable=True, logger=custom_logger, ) - created_properties = created_container.read(logger=custom_logger) - read_indexing_policy = created_properties['indexingPolicy'] + created_properties = await created_container.read(logger=custom_logger) + read_indexing_policy = created_properties['indexingPolicy'] if 'localhost' in self.host or '127.0.0.1' in self.host: # TODO: Differing result between live and emulator self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) @@ -1863,7 +1864,7 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) self.assertListEqual(indexing_policy['compositeIndexes'], read_indexing_policy['compositeIndexes']) - db.delete_container(container=created_container) + await db.delete_container(container=created_container) async def _check_default_indexing_policy_paths(self, indexing_policy): async def __get_first(array): @@ -1877,7 +1878,7 @@ async def __get_first(array): # included paths should be 1: '/'.
self.assertEqual(1, len(indexing_policy['includedPaths'])) - root_included_path = __get_first([included_path for included_path in indexing_policy['includedPaths'] + root_included_path = await __get_first([included_path for included_path in indexing_policy['includedPaths'] if included_path['path'] == '/*']) self.assertFalse(root_included_path.get('indexes')) @@ -1910,12 +1911,12 @@ async def test_client_request_timeout_when_connection_retry_configuration_specif connection_policy=connection_policy) async def test_client_connection_retry_configuration(self): - total_time_for_two_retries = self.initialize_client_with_connection_urllib_retry_config(2) - total_time_for_three_retries = self.initialize_client_with_connection_urllib_retry_config(3) + total_time_for_two_retries = await self.initialize_client_with_connection_urllib_retry_config(2) + total_time_for_three_retries = await self.initialize_client_with_connection_urllib_retry_config(3) self.assertGreater(total_time_for_three_retries, total_time_for_two_retries) - total_time_for_two_retries = self.initialize_client_with_connection_core_retry_config(2) - total_time_for_three_retries = self.initialize_client_with_connection_core_retry_config(3) + total_time_for_two_retries = await self.initialize_client_with_connection_core_retry_config(2) + total_time_for_three_retries = await self.initialize_client_with_connection_core_retry_config(3) self.assertGreater(total_time_for_three_retries, total_time_for_two_retries) async def initialize_client_with_connection_urllib_retry_config(self, retries): @@ -1994,7 +1995,7 @@ async def test_absolute_client_timeout(self): databases = [database async for database in databases] async def test_query_iterable_functionality(self): - def __create_resources(client): + async def __create_resources(client): """Creates resources for this test. 
:Parameters: @@ -2005,9 +2006,9 @@ def __create_resources(client): """ collection = self.configs.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) - doc1 = collection.create_item(body={'id': 'doc1', 'prop1': 'value1'}) - doc2 = collection.create_item(body={'id': 'doc2', 'prop1': 'value2'}) - doc3 = collection.create_item(body={'id': 'doc3', 'prop1': 'value3'}) + doc1 = await collection.create_item(body={'id': 'doc1', 'prop1': 'value1'}) + doc2 = await collection.create_item(body={'id': 'doc2', 'prop1': 'value2'}) + doc3 = await collection.create_item(body={'id': 'doc3', 'prop1': 'value3'}) resources = { 'coll': collection, 'doc1': doc1, @@ -2017,8 +2018,8 @@ def __create_resources(client): return resources # Validate QueryIterable by converting it to a list. - resources = __create_resources(self.client) - results = resources['coll'].read_all_items(max_item_count=2) + resources = await __create_resources(self.client) + results = await resources['coll'].read_all_items(max_item_count=2) docs = [doc async for doc in results] self.assertEqual(3, len(docs), @@ -2029,7 +2030,7 @@ def __create_resources(client): self.assertEqual(resources['doc3']['id'], docs[2]['id']) # Validate QueryIterable iterator with 'for'. - results = resources['coll'].read_all_items(max_item_count=2) + results = await resources['coll'].read_all_items(max_item_count=2) counter = 0 # test QueryIterable with 'for'. for doc in iter(results): @@ -2049,9 +2050,9 @@ def __create_resources(client): self.assertEqual(counter, 3) # Get query results page by page. 
- results = resources['coll'].read_all_items(max_item_count=2) + results = await resources['coll'].read_all_items(max_item_count=2) - page_iter = results.by_page() + page_iter = await results.by_page() first_block = [page async for page in next(page_iter)] self.assertEqual(2, len(first_block), 'First block should have 2 entries.') self.assertEqual(resources['doc1']['id'], first_block[0]['id']) @@ -2137,7 +2138,7 @@ async def __CreateTriggers(collection, triggers): """ for trigger_i in triggers: - trigger = collection.scripts.create_trigger(body=trigger_i) + trigger = await collection.scripts.create_trigger(body=trigger_i) for property in trigger_i: self.assertEqual( trigger[property], @@ -2148,11 +2149,11 @@ async def __CreateTriggers(collection, triggers): db = self.databaseForTest # create collections pkd = PartitionKey(path='/id', kind='Hash') - collection1 = db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), + collection1 = await db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) - collection2 = db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), + collection2 = await db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) - collection3 = db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), + collection3 = await db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/key', kind='Hash')) # create triggers await __CreateTriggers(collection1, triggers_in_collection1) @@ -2170,14 +2171,14 @@ async def __CreateTriggers(collection, triggers): 'DOC1t1', 'id should be capitalized') - document_1_2 = collection1.create_item( + document_1_2 = await collection1.create_item( body={'id': 'testing post trigger', 'key': 'value'}, pre_trigger_include='t1', post_trigger_include='response1', ) 
self.assertEqual(document_1_2['id'], 'TESTING POST TRIGGERt1') - document_1_3 = collection1.create_item( + document_1_3 = await collection1.create_item( body={'id': 'responseheaders', 'key': 'value'}, pre_trigger_include='t1' ) @@ -2185,7 +2186,7 @@ async def __CreateTriggers(collection, triggers): triggers_2 = [trigger async for trigger in collection2.scripts.list_triggers()] self.assertEqual(len(triggers_2), 2) - document_2_1 = collection2.create_item( + document_2_1 = await collection2.create_item( body={'id': 'doc2', 'key': 'value2'}, pre_trigger_include='t2' @@ -2193,7 +2194,7 @@ async def __CreateTriggers(collection, triggers): self.assertEqual(document_2_1['id'], 'doc2', 'id shouldn\'t change') - document_2_2 = collection2.create_item( + document_2_2 = await collection2.create_item( body={'id': 'Doc3', 'prop': 'empty', 'key': 'value2'}, @@ -2208,9 +2209,9 @@ async def __CreateTriggers(collection, triggers): post_trigger_include='triggerOpType' ) - db.delete_container(collection1) - db.delete_container(collection2) - db.delete_container(collection3) + await db.delete_container(collection1) + await db.delete_container(collection2) + await db.delete_container(collection3) async def test_stored_procedure_functionality(self): # create database @@ -2230,8 +2231,8 @@ async def test_stored_procedure_functionality(self): '}') } - retrieved_sproc = collection.scripts.create_stored_procedure(body=sproc1) - result = collection.scripts.execute_stored_procedure( + retrieved_sproc = await collection.scripts.create_stored_procedure(body=sproc1) + result = await collection.scripts.execute_stored_procedure( sproc=retrieved_sproc['id'], partition_key=1 ) @@ -2245,8 +2246,8 @@ async def test_stored_procedure_functionality(self): ' }' + '}') } - retrieved_sproc2 = collection.scripts.create_stored_procedure(body=sproc2) - result = collection.scripts.execute_stored_procedure( + retrieved_sproc2 = await collection.scripts.create_stored_procedure(body=sproc2) + result = await 
collection.scripts.execute_stored_procedure( sproc=retrieved_sproc2['id'], partition_key=1 ) @@ -2259,8 +2260,8 @@ async def test_stored_procedure_functionality(self): ' \'a\' + input.temp);' + '}') } - retrieved_sproc3 = collection.scripts.create_stored_procedure(body=sproc3) - result = collection.scripts.execute_stored_procedure( + retrieved_sproc3 = await collection.scripts.create_stored_procedure(body=sproc3) + result = await collection.scripts.execute_stored_procedure( sproc=retrieved_sproc3['id'], params={'temp': 'so'}, partition_key=1 @@ -2284,17 +2285,17 @@ async def test_offer_read_and_query(self): db = self.databaseForTest # Create collection. - collection = db.create_container( + collection = await db.create_container( id='test_offer_read_and_query ' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) # Read the offer. - expected_offer = collection.read_offer() - collection_properties = collection.read() + expected_offer = await collection.read_offer() + collection_properties = await collection.read() await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) # Now delete the collection. - db.delete_container(container=collection) + await db.delete_container(container=collection) # Reading fails. await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, collection.read_offer) @@ -2304,12 +2305,12 @@ async def test_offer_replace(self): # Create collection. collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) # Read Offer - expected_offer = collection.read_offer() - collection_properties = collection.read() + expected_offer = await collection.read_offer() + collection_properties = await collection.read() await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) # Replace the offer. 
- replaced_offer = collection.replace_throughput(expected_offer.offer_throughput + 100) - collection_properties = collection.read() + replaced_offer = await collection.replace_throughput(expected_offer.offer_throughput + 100) + collection_properties = await collection.read() await self.__ValidateOfferResponseBody(replaced_offer, collection_properties.get('_self'), None) # Check if the replaced offer is what we expect. self.assertEqual(expected_offer.properties.get('content').get('offerThroughput') + 100, @@ -2319,35 +2320,35 @@ async def test_offer_replace(self): async def test_database_account_functionality(self): # Validate database account functionality. - database_account = self.client.get_database_account() + database_account = await self.client.get_database_account() self.assertEqual(database_account.DatabasesLink, '/dbs/') self.assertEqual(database_account.MediaLink, '/media/') if (HttpHeaders.MaxMediaStorageUsageInMB in - self.client.client_connection.last_response_headers): + await self.client.client_connection.last_response_headers): self.assertEqual( database_account.MaxMediaStorageUsageInMB, self.client.client_connection.last_response_headers[ HttpHeaders.MaxMediaStorageUsageInMB]) if (HttpHeaders.CurrentMediaStorageUsageInMB in - self.client.client_connection.last_response_headers): + await self.client.client_connection.last_response_headers): self.assertEqual( database_account.CurrentMediaStorageUsageInMB, - self.client.client_connection.last_response_headers[ + await self.client.client_connection.last_response_headers[ HttpHeaders.CurrentMediaStorageUsageInMB]) self.assertIsNotNone(database_account.ConsistencyPolicy['defaultConsistencyLevel']) async def test_index_progress_headers(self): created_db = self.databaseForTest - consistent_coll = created_db.create_container( + consistent_coll = await created_db.create_container( id='test_index_progress_headers consistent_coll ' + str(uuid.uuid4()), partition_key=PartitionKey(path="/id", kind='Hash'), ) - 
created_container = created_db.get_container_client(container=consistent_coll) - created_container.read(populate_quota_info=True) + created_container = await created_db.get_container_client(container=consistent_coll) + await created_container.read(populate_quota_info=True) self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) - none_coll = created_db.create_container( + none_coll = await created_db.create_container( id='test_index_progress_headers none_coll ' + str(uuid.uuid4()), indexing_policy={ 'indexingMode': documents.IndexingMode.NoIndex, @@ -2355,13 +2356,13 @@ async def test_index_progress_headers(self): }, partition_key=PartitionKey(path="/id", kind='Hash') ) - created_container = created_db.get_container_client(container=none_coll) - created_container.read(populate_quota_info=True) + created_container = await created_db.get_container_client(container=none_coll) + await created_container.read(populate_quota_info=True) self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) - created_db.delete_container(consistent_coll) - created_db.delete_container(none_coll) + await created_db.delete_container(consistent_coll) + await created_db.delete_container(none_coll) async def test_id_validation(self): # Id shouldn't end with space. 
@@ -2400,7 +2401,7 @@ async def test_id_validation(self): self.assertEqual('Id contains illegal chars.', e.args[0]) # Id can begin with space - db = self.client.create_database(id=' id_begin_space') + db = await self.client.create_database(id=' id_begin_space') self.assertTrue(True) await self.client.delete_database(database=db) @@ -2419,13 +2420,13 @@ async def test_id_case_validation(self): # create 2 collections with different casing of IDs # pascalCase - created_collection1 = created_db.create_container( + created_collection1 = await created_db.create_container( id=collection_id1, partition_key=PartitionKey(path='/id', kind='Hash') ) # CamelCase - created_collection2 = created_db.create_container( + created_collection2 = await created_db.create_container( id=collection_id2, partition_key=PartitionKey(path='/id', kind='Hash') ) @@ -2439,8 +2440,8 @@ async def test_id_case_validation(self): self.assertEqual(collection_id1, created_collection1.id) self.assertEqual(collection_id2, created_collection2.id) - created_db.delete_container(created_collection1) - created_db.delete_container(created_collection2) + await created_db.delete_container(created_collection1) + await created_db.delete_container(created_collection2) # TODO: fix test @pytest.mark.skip @@ -2455,11 +2456,11 @@ async def test_id_unicode_validation(self): collection_id2 = "!@$%^&*()-~`'_[]{}|;:,.<>" # verify that collections are created with specified IDs - created_collection1 = created_db.create_container( + created_collection1 = await created_db.create_container( id=collection_id1, partition_key=PartitionKey(path='/id', kind='Hash') ) - created_collection2 = created_db.create_container( + created_collection2 = await created_db.create_container( id=collection_id2, partition_key=PartitionKey(path='/id', kind='Hash') ) @@ -2470,8 +2471,8 @@ async def test_id_unicode_validation(self): created_collection1_properties = created_collection1.read() created_collection2_properties = created_collection2.read() - 
created_db.client_connection.DeleteContainer(created_collection1_properties['_self']) - created_db.client_connection.DeleteContainer(created_collection2_properties['_self']) + await created_db.client_connection.DeleteContainer(created_collection1_properties['_self']) + await created_db.client_connection.DeleteContainer(created_collection2_properties['_self']) async def test_get_resource_with_dictionary_and_object(self): created_db = self.databaseForTest @@ -2491,42 +2492,42 @@ async def test_get_resource_with_dictionary_and_object(self): created_container = self.configs.create_multi_partition_collection_if_not_exist(self.client) # read container with id - read_container = created_db.get_container_client(created_container.id) + read_container = await created_db.get_container_client(created_container.id) self.assertEqual(read_container.id, created_container.id) # read container with instance - read_container = created_db.get_container_client(created_container) + read_container = await created_db.get_container_client(created_container) self.assertEqual(read_container.id, created_container.id) # read container with properties - created_properties = created_container.read() - read_container = created_db.get_container_client(created_properties) + created_properties = await created_container.read() + read_container = await created_db.get_container_client(created_properties) self.assertEqual(read_container.id, created_container.id) - created_item = created_container.create_item({'id': '1' + str(uuid.uuid4())}) + created_item = await created_container.create_item({'id': '1' + str(uuid.uuid4())}) # read item with id - read_item = created_container.read_item(item=created_item['id'], partition_key=created_item['id']) + read_item = await created_container.read_item(item=created_item['id'], partition_key=created_item['id']) self.assertEqual(read_item['id'], created_item['id']) # read item with properties - read_item = created_container.read_item(item=created_item, 
partition_key=created_item['id']) + read_item = await created_container.read_item(item=created_item, partition_key=created_item['id']) self.assertEqual(read_item['id'], created_item['id']) - created_sproc = created_container.scripts.create_stored_procedure({ + created_sproc = await created_container.scripts.create_stored_procedure({ 'id': 'storedProcedure' + str(uuid.uuid4()), 'body': 'function () { }' }) # read sproc with id - read_sproc = created_container.scripts.get_stored_procedure(created_sproc['id']) + read_sproc = await created_container.scripts.get_stored_procedure(created_sproc['id']) self.assertEqual(read_sproc['id'], created_sproc['id']) # read sproc with properties - read_sproc = created_container.scripts.get_stored_procedure(created_sproc) + read_sproc = await created_container.scripts.get_stored_procedure(created_sproc) self.assertEqual(read_sproc['id'], created_sproc['id']) - created_trigger = created_container.scripts.create_trigger({ + created_trigger = await created_container.scripts.create_trigger({ 'id': 'sample trigger' + str(uuid.uuid4()), 'serverScript': 'function() {var x = 10;}', 'triggerType': documents.TriggerType.Pre, @@ -2541,37 +2542,37 @@ async def test_get_resource_with_dictionary_and_object(self): read_trigger = created_container.scripts.get_trigger(created_trigger) self.assertEqual(read_trigger['id'], created_trigger['id']) - created_udf = created_container.scripts.create_user_defined_function({ + created_udf = await created_container.scripts.create_user_defined_function({ 'id': 'sample udf' + str(uuid.uuid4()), 'body': 'function() {var x = 10;}' }) # read udf with id - read_udf = created_container.scripts.get_user_defined_function(created_udf['id']) + read_udf = await created_container.scripts.get_user_defined_function(created_udf['id']) self.assertEqual(created_udf['id'], read_udf['id']) # read udf with properties - read_udf = created_container.scripts.get_user_defined_function(created_udf) + read_udf = await 
created_container.scripts.get_user_defined_function(created_udf) self.assertEqual(created_udf['id'], read_udf['id']) - created_user = created_db.create_user({ + created_user = await created_db.create_user({ 'id': 'user' + str(uuid.uuid4()) }) # read user with id - read_user = created_db.get_user_client(created_user.id) + read_user = await created_db.get_user_client(created_user.id) self.assertEqual(read_user.id, created_user.id) # read user with instance - read_user = created_db.get_user_client(created_user) + read_user = await created_db.get_user_client(created_user) self.assertEqual(read_user.id, created_user.id) # read user with properties - created_user_properties = created_user.read() - read_user = created_db.get_user_client(created_user_properties) + created_user_properties = await created_user.read() + read_user = await created_db.get_user_client(created_user_properties) self.assertEqual(read_user.id, created_user.id) - created_permission = created_user.create_permission({ + created_permission = await created_user.create_permission({ 'id': 'all permission' + str(uuid.uuid4()), 'permissionMode': documents.PermissionMode.All, 'resource': created_container.container_link, @@ -2579,15 +2580,15 @@ async def test_get_resource_with_dictionary_and_object(self): }) # read permission with id - read_permission = created_user.get_permission(created_permission.id) + read_permission = await created_user.get_permission(created_permission.id) self.assertEqual(read_permission.id, created_permission.id) # read permission with instance - read_permission = created_user.get_permission(created_permission) + read_permission = await created_user.get_permission(created_permission) self.assertEqual(read_permission.id, created_permission.id) # read permission with properties - read_permission = created_user.get_permission(created_permission.properties) + read_permission = await created_user.get_permission(created_permission.properties) self.assertEqual(read_permission.id, 
created_permission.id) # Temporarily commenting analytical storage tests until emulator support comes. From d8187e1c3788ec3c6e7f64c1bd547747cd811be0 Mon Sep 17 00:00:00 2001 From: JericHunter Date: Tue, 3 May 2022 09:39:40 -0700 Subject: [PATCH 6/8] Update test_crud_async.py --- .../azure-cosmos/test/test_crud_async.py | 176 +++++++++--------- 1 file changed, 89 insertions(+), 87 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index fa7f801b1971..249b2914d703 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -100,7 +100,7 @@ async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs @classmethod async def setUpClass(cls): cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, connection_policy=cls.connectionPolicy) - cls.databaseForTest = cls.client.create_database_if_not_exists(test_config._test_config.TEST_DATABASE_ID) + cls.databaseForTest = await cls.client.create_database_if_not_exists(test_config._test_config.TEST_DATABASE_ID) async def setUp(self): self.client = cosmos_client.CosmosClient(self.host, self.masterKey, "Session", @@ -215,7 +215,7 @@ async def test_collection_crud(self): assert isinstance(created_recorder.body, Mapping) assert 'id' in created_recorder.body - created_properties = await created_collection.read() + created_properties = created_collection.read() self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode']) # read collections after creation @@ -236,7 +236,7 @@ async def test_collection_crud(self): # delete collection await created_db.delete_container(created_collection.id) # read collection after deletion - created_container = await created_db.get_container_client(created_collection.id) + created_container = created_db.get_container_client(created_collection.id) await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, 
created_container.read) @@ -287,9 +287,9 @@ async def test_partitioned_collection(self): async def test_partitioned_collection_quota(self): created_db = self.databaseForTest - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) - retrieved_collection = await created_db.get_container_client( + retrieved_collection = created_db.get_container_client( container=created_collection.id ) @@ -433,7 +433,7 @@ async def test_partitioned_collection_path_parser(self): async def test_partitioned_collection_document_crud_and_query(self): created_db = self.databaseForTest - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) document_definition = {'id': 'document', 'key': 'value'} @@ -481,7 +481,7 @@ async def test_partitioned_collection_document_crud_and_query(self): self.assertEqual(2, len(documentlist)) # delete document - created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('id')) + await created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('id')) # query document on the partition key specified in the predicate will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value documentlist = [document async for document in created_collection.query_items( @@ -558,15 +558,15 @@ async def test_partitioned_collection_permissions(self): resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = ( read_permission.properties['_token']) - restricted_client = cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) + async with cosmos_client.CosmosClient( + CRUDTests.host, 
resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as restricted_client: print('Async Initialization') document_definition = {'id': 'document1', 'key': 1 } - all_collection.client_connection = await restricted_client.client_connection - read_collection.client_connection = await restricted_client.client_connection + all_collection.client_connection = restricted_client.client_connection + read_collection.client_connection = restricted_client.client_connection # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions created_document = await all_collection.create_item(body=document_definition) @@ -603,7 +603,7 @@ async def test_partitioned_collection_permissions(self): async def test_partitioned_collection_execute_stored_procedure(self): created_db = self.databaseForTest - created_collection = self.configs.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION_WITH_CUSTOM_PK_PARTITION_KEY, PartitionKey(path="/pk")) sproc = { 'id': 'storedProcedure' + str(uuid.uuid4()), @@ -696,7 +696,7 @@ async def test_partitioned_collection_partition_key_value_types(self): async def test_partitioned_collection_conflict_crud_and_query(self): created_db = self.databaseForTest - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) conflict_definition = {'id': 'new conflict', 'resourceId': 'doc1', @@ -755,7 +755,7 @@ async def test_document_crud(self): # create database created_db = self.databaseForTest # create collection - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await 
self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read documents documents = [document async for document in created_collection.read_all_items()] # create a document @@ -845,7 +845,7 @@ async def test_document_crud(self): # should fail if only etag specified with self.assertRaises(ValueError): - created_collection.replace_item( + await created_collection.replace_item( etag=replaced_document['_etag'], item=replaced_document['id'], body=replaced_document @@ -853,13 +853,13 @@ async def test_document_crud(self): # should fail if only match condition specified with self.assertRaises(ValueError): - created_collection.replace_item( + await created_collection.replace_item( match_condition=MatchConditions.IfNotModified, item=replaced_document['id'], body=replaced_document ) with self.assertRaises(ValueError): - created_collection.replace_item( + await created_collection.replace_item( match_condition=MatchConditions.IfModified, item=replaced_document['id'], body=replaced_document @@ -867,7 +867,7 @@ async def test_document_crud(self): # should fail if invalid match condition specified with self.assertRaises(TypeError): - created_collection.replace_item( + await created_collection.replace_item( match_condition=replaced_document['_etag'], item=replaced_document['id'], body=replaced_document @@ -912,7 +912,7 @@ async def test_document_upsert(self): created_db = self.databaseForTest # create collection - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read documents and check count documents = [document async for document in created_collection.read_all_items()] @@ -1093,12 +1093,12 @@ async def test_user_crud(self): replaced_user.id, 'user id should stay the same') # read user - user = await db.get_user_client(replaced_user.id) + user = 
db.get_user_client(replaced_user.id) self.assertEqual(replaced_user.id, user.id) # delete user await db.delete_user(user.id) # read user after deletion - deleted_user = await db.get_user_client(user.id) + deleted_user = db.get_user_client(user.id) await self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, deleted_user.read) @@ -1188,7 +1188,7 @@ async def test_permission_crud(self): self.assertTrue(results) # replace permission - change_permission = await permission.properties.copy() + change_permission = permission.properties.copy() permission.properties['id'] = 'replaced permission' permission.id = permission.properties['id'] replaced_permission = user.replace_permission(change_permission['id'], permission.properties) @@ -1347,27 +1347,27 @@ async def __SetupEntities(client): return entities # Client without any authorization will fail. - client = cosmos_client.CosmosClient(CRUDTests.host, {}, "Session", connection_policy=CRUDTests.connectionPolicy) + async with cosmos_client.CosmosClient(CRUDTests.host, {}, "Session", connection_policy=CRUDTests.connectionPolicy) as client: print('Async Initialization') await self.__AssertHTTPFailureWithStatus(StatusCodes.UNAUTHORIZED, list, client.list_databases()) # Client with master key. 
- client = cosmos_client.CosmosClient(CRUDTests.host, + async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", - connection_policy=CRUDTests.connectionPolicy) + connection_policy=CRUDTests.connectionPolicy) as client: print('Async Initialization') # setup entities entities = await __SetupEntities(client) resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: entities['permissionOnColl'].properties['_token']} - col_client = cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) + async with cosmos_client.CosmosClient( + CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as col_client: print('Async Initialization') db = entities['db'] - old_client_connection = await db.client_connection - db.client_connection = await col_client.client_connection + old_client_connection = db.client_connection + db.client_connection = col_client.client_connection # 1. Success-- Use Col Permission to Read - success_coll = await db.get_container_client(container=entities['coll']) + success_coll = db.get_container_client(container=entities['coll']) # 2. Failure-- Use Col Permission to delete await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, db.delete_container, @@ -1400,8 +1400,8 @@ async def __SetupEntities(client): resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId: entities['permissionOnDoc'].properties['_token']} - doc_client = cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) + async with cosmos_client.CosmosClient( + CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as doc_client: print('Async Initialization') # 6. 
Success-- Use Doc permission to read doc read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, @@ -1412,14 +1412,14 @@ async def __SetupEntities(client): await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) self.assertEqual(read_doc["id"], docId) - db.client_connection = await old_client_connection - db.delete_container(entities['coll']) + db.client_connection = old_client_connection + await db.delete_container(entities['coll']) async def test_trigger_crud(self): # create database db = self.databaseForTest # create collection - collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read triggers triggers = [trigger async for trigger in collection.scripts.list_triggers()] # create a trigger @@ -1456,7 +1456,7 @@ async def test_trigger_crud(self): self.assertTrue(triggers) # replace trigger - change_trigger = await trigger.copy() + change_trigger = trigger.copy() trigger['body'] = 'function() {var x = 20;}' replaced_trigger = await collection.scripts.replace_trigger(change_trigger['id'], trigger) for property in trigger_definition: @@ -1483,7 +1483,7 @@ async def test_udf_crud(self): # create database db = self.databaseForTest # create collection - collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read udfs udfs = [udf async for udf in collection.scripts.list_user_defined_functions()] # create a udf @@ -1513,7 +1513,7 @@ async def test_udf_crud(self): )] self.assertTrue(results) # replace udf - change_udf = await udf.copy() + change_udf = udf.copy() udf['body'] = 'function() {var x = 20;}' replaced_udf = await 
collection.scripts.replace_user_defined_function(udf=udf['id'], body=udf) for property in udf_definition: @@ -1535,7 +1535,7 @@ async def test_sproc_crud(self): # create database db = self.databaseForTest # create collection - collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read sprocs sprocs = [sproc async for sproc in collection.scripts.list_stored_procedures()] # create a sproc @@ -1568,7 +1568,7 @@ async def test_sproc_crud(self): )] self.assertIsNotNone(sprocs) # replace sproc - change_sproc = await sproc.copy() + change_sproc = sproc.copy() sproc['body'] = 'function() {var x = 20;}' replaced_sproc = await collection.scripts.replace_stored_procedure(sproc=change_sproc['id'], body=sproc) for property in sproc_definition: @@ -1593,7 +1593,7 @@ async def test_sproc_crud(self): async def test_script_logging_execute_stored_procedure(self): created_db = self.databaseForTest - created_collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) sproc = { 'id': 'storedProcedure' + str(uuid.uuid4()), @@ -1652,7 +1652,7 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() self.assertEqual(collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'default indexing mode should be consistent') @@ -1667,7 +1667,7 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - consistent_collection_properties = await consistent_collection.read() + consistent_collection_properties = consistent_collection.read() 
self.assertEqual(consistent_collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'indexing mode should be consistent') @@ -1718,7 +1718,7 @@ async def test_create_default_indexing_policy(self): id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1730,7 +1730,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1740,7 +1740,7 @@ async def test_create_default_indexing_policy(self): indexing_policy={}, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1756,7 +1756,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1782,7 +1782,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = await collection.read() + collection_properties = collection.read() await 
self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1854,8 +1854,8 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): logging_enable=True, logger=custom_logger, ) - created_properties = await created_container.read(logger=custom_logger) - read_indexing_policy = await created_properties['indexingPolicy'] + created_properties = created_container.read(logger=custom_logger) + read_indexing_policy = created_properties['indexingPolicy'] if 'localhost' in self.host or '127.0.0.1' in self.host: # TODO: Differing result between live and emulator self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) @@ -1878,7 +1878,7 @@ async def __get_first(array): # included paths should be 1: '/'. self.assertEqual(1, len(indexing_policy['includedPaths'])) - root_included_path = await __get_first([included_path for included_path in indexing_policy['includedPaths'] + root_included_path = __get_first([included_path for included_path in indexing_policy['includedPaths'] if included_path['path'] == '/*']) self.assertFalse(root_included_path.get('indexes')) @@ -1891,8 +1891,8 @@ async def test_client_request_timeout(self): with self.assertRaises(Exception): # client does a getDatabaseAccount on initialization, which will time out - cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", - connection_policy=connection_policy) + async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", + connection_policy=connection_policy) as client: print('Async initialization') async def test_client_request_timeout_when_connection_retry_configuration_specified(self): connection_policy = documents.ConnectionPolicy() @@ -1907,8 +1907,8 @@ async def test_client_request_timeout_when_connection_retry_configuration_specif ) with self.assertRaises(AzureError): # client does a getDatabaseAccount on initialization, 
which will time out - cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", - connection_policy=connection_policy) + async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", + connection_policy=connection_policy) as client: print('Async Initialization') async def test_client_connection_retry_configuration(self): total_time_for_two_retries = await self.initialize_client_with_connection_urllib_retry_config(2) @@ -1929,11 +1929,11 @@ async def initialize_client_with_connection_urllib_retry_config(self, retries): ) start_time = time.time() try: - cosmos_client.CosmosClient( + async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, "Session", - connection_retry_policy=retry_policy) + connection_retry_policy=retry_policy) as client: print('Async initialization') self.fail() except AzureError as e: end_time = time.time() @@ -1942,14 +1942,14 @@ async def initialize_client_with_connection_urllib_retry_config(self, retries): async def initialize_client_with_connection_core_retry_config(self, retries): start_time = time.time() try: - cosmos_client.CosmosClient( + async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, "Session", retry_total=retries, retry_read=retries, retry_connect=retries, - retry_status=retries) + retry_status=retries) as client: print('Async initialization') self.fail() except AzureError as e: end_time = time.time() @@ -1957,25 +1957,26 @@ async def initialize_client_with_connection_core_retry_config(self, retries): async def test_absolute_client_timeout(self): with self.assertRaises(exceptions.CosmosClientTimeoutError): - cosmos_client.CosmosClient( + async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, "Session", retry_total=3, - timeout=1) + timeout=1)as client: print('Async initialization') error_response = ServiceResponseError("Read timeout") timeout_transport = TimeoutTransport(error_response) - client = 
cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + async with cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print('Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database_if_not_exists("test", timeout=2) status_response = 500 # Users connection level retry timeout_transport = TimeoutTransport(status_response) - client = cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + async with cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print( + 'Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database("test", timeout=2) @@ -1985,8 +1986,9 @@ async def test_absolute_client_timeout(self): status_response = 429 # Uses Cosmos custom retry timeout_transport = TimeoutTransport(status_response) - client = cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) + async with cosmos_client.CosmosClient( + self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print( + 'Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database_if_not_exists("test", timeout=2) @@ -2005,7 +2007,7 @@ async def __create_resources(client): dict """ - collection = self.configs.create_multi_partition_collection_with_custom_pk_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION_WITH_CUSTOM_PK_PARTITION_KEY, PartitionKey(path="/pk")) doc1 = await collection.create_item(body={'id': 'doc1', 'prop1': 'value1'}) doc2 = await collection.create_item(body={'id': 'doc2', 'prop1': 
'value2'}) doc3 = await collection.create_item(body={'id': 'doc3', 'prop1': 'value3'}) @@ -2019,7 +2021,7 @@ async def __create_resources(client): # Validate QueryIterable by converting it to a list. resources = await __create_resources(self.client) - results = await resources['coll'].read_all_items(max_item_count=2) + results = resources['coll'].read_all_items(max_item_count=2) docs = [doc async for doc in results] self.assertEqual(3, len(docs), @@ -2030,10 +2032,10 @@ async def __create_resources(client): self.assertEqual(resources['doc3']['id'], docs[2]['id']) # Validate QueryIterable iterator with 'for'. - results = await resources['coll'].read_all_items(max_item_count=2) + results = resources['coll'].read_all_items(max_item_count=2) counter = 0 # test QueryIterable with 'for'. - for doc in iter(results): + async for doc in iter(results): counter += 1 if counter == 1: self.assertEqual(resources['doc1']['id'], @@ -2050,9 +2052,9 @@ async def __create_resources(client): self.assertEqual(counter, 3) # Get query results page by page. 
- results = await resources['coll'].read_all_items(max_item_count=2) + results = resources['coll'].read_all_items(max_item_count=2) - page_iter = await results.by_page() + page_iter = results.by_page() first_block = [page async for page in next(page_iter)] self.assertEqual(2, len(first_block), 'First block should have 2 entries.') self.assertEqual(resources['doc1']['id'], first_block[0]['id']) @@ -2137,9 +2139,9 @@ async def __CreateTriggers(collection, triggers): - `collection`: dict """ - for trigger_i in triggers: + async for trigger_i in triggers: trigger = await collection.scripts.create_trigger(body=trigger_i) - for property in trigger_i: + async for property in trigger_i: self.assertEqual( trigger[property], trigger_i[property], @@ -2204,7 +2206,7 @@ async def __CreateTriggers(collection, triggers): triggers_3 = [trigger async for trigger in collection3.scripts.list_triggers()] self.assertEqual(len(triggers_3), 1) with self.assertRaises(Exception): - collection3.create_item( + await collection3.create_item( body={'id': 'Docoptype', 'key': 'value2'}, post_trigger_include='triggerOpType' ) @@ -2217,7 +2219,7 @@ async def test_stored_procedure_functionality(self): # create database db = self.databaseForTest # create collection - collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) sproc1 = { 'id': 'storedProcedure1' + str(uuid.uuid4()), @@ -2303,7 +2305,7 @@ async def test_offer_replace(self): # Create database. db = self.databaseForTest # Create collection. 
- collection = self.configs.create_multi_partition_collection_if_not_exist(self.client) + collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # Read Offer expected_offer = await collection.read_offer() collection_properties = await collection.read() @@ -2343,7 +2345,7 @@ async def test_index_progress_headers(self): id='test_index_progress_headers consistent_coll ' + str(uuid.uuid4()), partition_key=PartitionKey(path="/id", kind='Hash'), ) - created_container = await created_db.get_container_client(container=consistent_coll) + created_container = created_db.get_container_client(container=consistent_coll) await created_container.read(populate_quota_info=True) self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) @@ -2356,7 +2358,7 @@ async def test_index_progress_headers(self): }, partition_key=PartitionKey(path="/id", kind='Hash') ) - created_container = await created_db.get_container_client(container=none_coll) + created_container = created_db.get_container_client(container=none_coll) await created_container.read(populate_quota_info=True) self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) @@ -2489,19 +2491,19 @@ async def test_get_resource_with_dictionary_and_object(self): read_db = self.client.get_database_client(created_db.read()) self.assertEqual(read_db.id, created_db.id) - created_container = self.configs.create_multi_partition_collection_if_not_exist(self.client) + created_container = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # read container with id - read_container = await 
created_db.get_container_client(created_container.id) + read_container = created_db.get_container_client(created_container.id) self.assertEqual(read_container.id, created_container.id) # read container with instance - read_container = await created_db.get_container_client(created_container) + read_container = created_db.get_container_client(created_container) self.assertEqual(read_container.id, created_container.id) # read container with properties created_properties = await created_container.read() - read_container = await created_db.get_container_client(created_properties) + read_container = created_db.get_container_client(created_properties) self.assertEqual(read_container.id, created_container.id) created_item = await created_container.create_item({'id': '1' + str(uuid.uuid4())}) @@ -2560,16 +2562,16 @@ async def test_get_resource_with_dictionary_and_object(self): }) # read user with id - read_user = await created_db.get_user_client(created_user.id) + read_user = created_db.get_user_client(created_user.id) self.assertEqual(read_user.id, created_user.id) # read user with instance - read_user = await created_db.get_user_client(created_user) + read_user = created_db.get_user_client(created_user) self.assertEqual(read_user.id, created_user.id) # read user with properties created_user_properties = await created_user.read() - read_user = await created_db.get_user_client(created_user_properties) + read_user = created_db.get_user_client(created_user_properties) self.assertEqual(read_user.id, created_user.id) created_permission = await created_user.create_permission({ From a5301d2d186f73a13c3ddedc442136517e3d9071 Mon Sep 17 00:00:00 2001 From: simorenoh Date: Tue, 3 May 2022 18:19:12 -0400 Subject: [PATCH 7/8] Update test_crud_async.py --- .../azure-cosmos/test/test_crud_async.py | 338 ++++++++---------- 1 file changed, 156 insertions(+), 182 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index 
249b2914d703..985111a250b9 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # The MIT License (MIT) -# Copyright (c) 2014 Microsoft Corporation +# Copyright (c) 2022 Microsoft Corporation # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,7 +22,6 @@ """End to end test. """ -import asyncio import json import logging import os.path @@ -43,7 +42,7 @@ import azure.cosmos.aio.cosmos_client as cosmos_client from azure.cosmos.diagnostics import RecordDiagnostics from azure.cosmos.partition_key import PartitionKey -from azure.cosmos import _retry_utility +from azure.cosmos.aio import _retry_utility_async import requests from urllib3.util.retry import Retry @@ -99,13 +98,9 @@ async def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs @classmethod async def setUpClass(cls): - cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, connection_policy=cls.connectionPolicy) + cls.client = cosmos_client.CosmosClient(cls.host, cls.masterKey, consistency_level="Session", connection_policy=cls.connectionPolicy) cls.databaseForTest = await cls.client.create_database_if_not_exists(test_config._test_config.TEST_DATABASE_ID) - async def setUp(self): - self.client = cosmos_client.CosmosClient(self.host, self.masterKey, "Session", - connection_policy=self.connectionPolicy) - async def test_database_crud(self): # read databases. 
databases = [database async for database in self.client.list_databases()] @@ -215,7 +210,7 @@ async def test_collection_crud(self): assert isinstance(created_recorder.body, Mapping) assert 'id' in created_recorder.body - created_properties = created_collection.read() + created_properties = await created_collection.read() self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode']) # read collections after creation @@ -316,11 +311,11 @@ async def test_partitioned_collection_partition_key_extraction(self): } } - self.OriginalExecuteFunction = _retry_utility.ExecuteFunction - _retry_utility.ExecuteFunction = self._MockExecuteFunction + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunction # create document without partition key being specified created_document = await created_collection.create_item(body=document_definition) - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], '["WA"]') del self.last_headers[:] @@ -333,27 +328,25 @@ async def test_partitioned_collection_partition_key_extraction(self): partition_key=PartitionKey(path='/address', kind=documents.PartitionKind.Hash) ) - self.OriginalExecuteFunction = _retry_utility.ExecuteFunction - _retry_utility.ExecuteFunction = self._MockExecuteFunction + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunction # Create document with partitionkey not present as a leaf level property but a dict - created_document = created_collection1.create_item(document_definition) - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + created_document = await created_collection1.create_item(document_definition) + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction 
self.assertEqual(self.last_headers[1], [{}]) del self.last_headers[:] - # self.assertEqual(options['partitionKey'], documents.Undefined) - collection_id = 'test_partitioned_collection_partition_key_extraction2 ' + str(uuid.uuid4()) created_collection2 = await created_db.create_container( id=collection_id, partition_key=PartitionKey(path='/address/state/city', kind=documents.PartitionKind.Hash) ) - self.OriginalExecuteFunction = _retry_utility.ExecuteFunction - _retry_utility.ExecuteFunction = self._MockExecuteFunction + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunction # Create document with partitionkey not present in the document created_document = await created_collection2.create_item(document_definition) - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], [{}]) del self.last_headers[:] @@ -376,22 +369,13 @@ async def test_partitioned_collection_partition_key_extraction_special_chars(sel "level' 1*()": {"le/vel2": 'val1'} } - self.OriginalExecuteFunction = _retry_utility.ExecuteFunction - _retry_utility.ExecuteFunction = self._MockExecuteFunction - created_document = created_collection1.create_item(body=document_definition) - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunction + created_document = await created_collection1.create_item(body=document_definition) + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], '["val1"]') del self.last_headers[:] - collection_definition2 = { - 'id': 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()), - 'partitionKey': - { - 'paths': ['/\'level\" 
1*()\'/\'le/vel2\''], - 'kind': documents.PartitionKind.Hash - } - } - collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()) created_collection2 = await created_db.create_container( @@ -403,11 +387,11 @@ async def test_partitioned_collection_partition_key_extraction_special_chars(sel 'level\" 1*()': {'le/vel2': 'val2'} } - self.OriginalExecuteFunction = _retry_utility.ExecuteFunction - _retry_utility.ExecuteFunction = self._MockExecuteFunction + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunction # create document without partition key being specified created_document = await created_collection2.create_item(body=document_definition) - _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction self.assertEqual(self.last_headers[1], '["val2"]') del self.last_headers[:] @@ -433,14 +417,12 @@ async def test_partitioned_collection_path_parser(self): async def test_partitioned_collection_document_crud_and_query(self): created_db = self.databaseForTest - created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) + created_collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION, PartitionKey(path="/id")) document_definition = {'id': 'document', 'key': 'value'} - created_document = await created_collection.create_item( - body=document_definition - ) + created_document = await created_collection.create_item(body=document_definition) self.assertEqual(created_document.get('id'), document_definition.get('id')) self.assertEqual(created_document.get('key'), document_definition.get('key')) @@ -485,7 +467,6 @@ async def test_partitioned_collection_document_crud_and_query(self): # query document on the partition key specified in the predicate 
will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value documentlist = [document async for document in created_collection.query_items( - query='SELECT * FROM root r WHERE r.id=\'' + replaced_document.get('id') + '\'' # nosec )] self.assertEqual(1, len(documentlist)) @@ -493,7 +474,6 @@ async def test_partitioned_collection_document_crud_and_query(self): # query document on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value try: [document async for document in created_collection.query_items( - query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec )] except Exception: @@ -555,50 +535,50 @@ async def test_partitioned_collection_permissions(self): resource_tokens = {} # storing the resource tokens based on Resource IDs resource_tokens["dbs/" + created_db.id + "/colls/" + all_collection.id] = (all_permission.properties['_token']) - resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = ( - read_permission.properties['_token']) + resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = (read_permission.properties['_token']) async with cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as restricted_client: print('Async Initialization') - - document_definition = {'id': 'document1', - 'key': 1 - } - - all_collection.client_connection = restricted_client.client_connection - read_collection.client_connection = restricted_client.client_connection - - # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions - created_document = await all_collection.create_item(body=document_definition) - - # Create document in read_collection should fail since it has only read permissions for this collection - await 
self.__AssertHTTPFailureWithStatus( - StatusCodes.FORBIDDEN, - read_collection.create_item, - document_definition) - - document_definition['key'] = 2 - # Create document should fail since the partitionKey is 2 which is different that what is specified as resourcePartitionKey in permission object - await self.__AssertHTTPFailureWithStatus( - StatusCodes.FORBIDDEN, - all_collection.create_item, - document_definition) - - document_definition['key'] = 1 - # Delete document should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object - created_document = all_collection.delete_item(item=created_document['id'], - partition_key=document_definition['key']) - - # Delete document in read_collection should fail since it has only read permissions for this collection - await self.__AssertHTTPFailureWithStatus( - StatusCodes.FORBIDDEN, - read_collection.delete_item, - document_definition['id'], - document_definition['id'] - ) + CRUDTests.host, resource_tokens, consistency_level="Session", connection_policy=CRUDTests.connectionPolicy) as restricted_client: + print('Async Initialization') + + document_definition = {'id': 'document1', + 'key': 1 + } + + all_collection.client_connection = restricted_client.client_connection + read_collection.client_connection = restricted_client.client_connection + + # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions + created_document = await all_collection.create_item(body=document_definition) + + # Create document in read_collection should fail since it has only read permissions for this collection + await self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + read_collection.create_item, + document_definition) + + document_definition['key'] = 2 + # Create document should fail since the partitionKey is 2 which is different that what is specified as 
resourcePartitionKey in permission object + await self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + all_collection.create_item, + document_definition) + + document_definition['key'] = 1 + # Delete document should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object + created_document = await all_collection.delete_item(item=created_document['id'], + partition_key=document_definition['key']) + + # Delete document in read_collection should fail since it has only read permissions for this collection + await self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + read_collection.delete_item, + document_definition['id'], + document_definition['id'] + ) - await created_db.delete_container(all_collection) - await created_db.delete_container(read_collection) + await created_db.delete_container(all_collection) + await created_db.delete_container(read_collection) async def test_partitioned_collection_execute_stored_procedure(self): created_db = self.databaseForTest @@ -707,7 +687,7 @@ async def test_partitioned_collection_conflict_crud_and_query(self): # read conflict here will return resource not found(404) since there is no conflict here await self.__AssertHTTPFailureWithStatus( StatusCodes.NOT_FOUND, - created_collection.get_conflict, + created_collection.read_conflict, conflict_definition['id'], conflict_definition['id'] ) @@ -1347,73 +1327,75 @@ async def __SetupEntities(client): return entities # Client without any authorization will fail. 
- async with cosmos_client.CosmosClient(CRUDTests.host, {}, "Session", connection_policy=CRUDTests.connectionPolicy) as client: print('Async Initialization') - await self.__AssertHTTPFailureWithStatus(StatusCodes.UNAUTHORIZED, - list, - client.list_databases()) + async with cosmos_client.CosmosClient(CRUDTests.host, {}, consistency_level="Session", connection_policy=CRUDTests.connectionPolicy) as client: + try: + db_list = [db async for db in client.list_databases()] + except exceptions.CosmosHttpResponseError as e: + assert e.status_code == 401 + # Client with master key. async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, - "Session", - connection_policy=CRUDTests.connectionPolicy) as client: print('Async Initialization') - # setup entities - entities = await __SetupEntities(client) - resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: - entities['permissionOnColl'].properties['_token']} + consistency_level="Session", + connection_policy=CRUDTests.connectionPolicy) as client: + # setup entities + entities = await __SetupEntities(client) + resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id: + entities['permissionOnColl'].properties['_token']} + async with cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as col_client: print('Async Initialization') - db = entities['db'] - - old_client_connection = db.client_connection - db.client_connection = col_client.client_connection - # 1. Success-- Use Col Permission to Read - success_coll = db.get_container_client(container=entities['coll']) - # 2. Failure-- Use Col Permission to delete - await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, - db.delete_container, - success_coll) - # 3. 
Success-- Use Col Permission to Read All Docs - success_documents = [document async for document in success_coll.read_all_items()] - self.assertTrue(success_documents != None, - 'error reading documents') - self.assertEqual(len(success_documents), - 1, - 'Expected 1 Document to be succesfully read') - # 4. Success-- Use Col Permission to Read Doc - - docId = entities['doc']['id'] - success_doc = await success_coll.read_item( - item=docId, - partition_key=docId - ) - self.assertTrue(success_doc != None, 'error reading document') - self.assertEqual( - success_doc['id'], - entities['doc']['id'], - 'Expected to read children using parent permissions') + CRUDTests.host, resource_tokens, consistency_level="Session", connection_policy=CRUDTests.connectionPolicy) as col_client: + db = entities['db'] + + old_client_connection = db.client_connection + db.client_connection = col_client.client_connection + # 1. Success-- Use Col Permission to Read + success_coll = db.get_container_client(container=entities['coll']) + # 2. Failure-- Use Col Permission to delete + await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, + db.delete_container, + success_coll) + # 3. Success-- Use Col Permission to Read All Docs + success_documents = [document async for document in success_coll.read_all_items()] + self.assertTrue(success_documents != None, + 'error reading documents') + self.assertEqual(len(success_documents), + 1, + 'Expected 1 Document to be succesfully read') + # 4. Success-- Use Col Permission to Read Doc + + docId = entities['doc']['id'] + success_doc = await success_coll.read_item( + item=docId, + partition_key=docId + ) + self.assertTrue(success_doc != None, 'error reading document') + self.assertEqual( + success_doc['id'], + entities['doc']['id'], + 'Expected to read children using parent permissions') - # 5. Failure-- Use Col Permission to Delete Doc - await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, - success_coll.delete_item, - docId, docId) + # 5. 
Failure-- Use Col Permission to Delete Doc + await self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN, + success_coll.delete_item, + docId, docId) - resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId: - entities['permissionOnDoc'].properties['_token']} + resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId: + entities['permissionOnDoc'].properties['_token']} async with cosmos_client.CosmosClient( - CRUDTests.host, resource_tokens, "Session", connection_policy=CRUDTests.connectionPolicy) as doc_client: print('Async Initialization') + CRUDTests.host, resource_tokens, consistency_level="Session", connection_policy=CRUDTests.connectionPolicy) as doc_client: - # 6. Success-- Use Doc permission to read doc - read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, - docId) - self.assertEqual(read_doc["id"], docId) + # 6. Success-- Use Doc permission to read doc + read_doc = await doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId) + self.assertEqual(read_doc["id"], docId) - # 6. Success-- Use Doc permission to delete doc - await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) - self.assertEqual(read_doc["id"], docId) + # 6. 
Success-- Use Doc permission to delete doc + await doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId) + self.assertEqual(read_doc["id"], docId) - db.client_connection = old_client_connection - await db.delete_container(entities['coll']) + db.client_connection = old_client_connection + await db.delete_container(entities['coll']) async def test_trigger_crud(self): # create database @@ -1652,7 +1634,7 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() self.assertEqual(collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'default indexing mode should be consistent') @@ -1667,7 +1649,7 @@ async def test_collection_indexing_policy(self): partition_key=PartitionKey(path='/id', kind='Hash') ) - consistent_collection_properties = consistent_collection.read() + consistent_collection_properties = await consistent_collection.read() self.assertEqual(consistent_collection_properties['indexingPolicy']['indexingMode'], documents.IndexingMode.Consistent, 'indexing mode should be consistent') @@ -1718,7 +1700,7 @@ async def test_create_default_indexing_policy(self): id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1730,7 +1712,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await 
db.delete_container(container=collection) @@ -1740,7 +1722,7 @@ async def test_create_default_indexing_policy(self): indexing_policy={}, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1756,7 +1738,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1782,7 +1764,7 @@ async def test_create_default_indexing_policy(self): }, partition_key=PartitionKey(path='/id', kind='Hash') ) - collection_properties = collection.read() + collection_properties = await collection.read() await self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) await db.delete_container(container=collection) @@ -1854,7 +1836,7 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): logging_enable=True, logger=custom_logger, ) - created_properties = created_container.read(logger=custom_logger) + created_properties = await created_container.read(logger=custom_logger) read_indexing_policy = created_properties['indexingPolicy'] if 'localhost' in self.host or '127.0.0.1' in self.host: # TODO: Differing result between live and emulator @@ -1867,7 +1849,7 @@ async def test_create_indexing_policy_with_composite_and_spatial_indexes(self): await db.delete_container(container=created_container) async def _check_default_indexing_policy_paths(self, indexing_policy): - async def __get_first(array): + def __get_first(array): if array: return array[0] else: @@ -1891,8 +1873,9 @@ async def test_client_request_timeout(self): with 
self.assertRaises(Exception): # client does a getDatabaseAccount on initialization, which will time out - async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", - connection_policy=connection_policy) as client: print('Async initialization') + async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, consistency_level="Session", + connection_policy=connection_policy) as client: + print('Async initialization') async def test_client_request_timeout_when_connection_retry_configuration_specified(self): connection_policy = documents.ConnectionPolicy() @@ -1907,8 +1890,9 @@ async def test_client_request_timeout_when_connection_retry_configuration_specif ) with self.assertRaises(AzureError): # client does a getDatabaseAccount on initialization, which will time out - async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, "Session", - connection_policy=connection_policy) as client: print('Async Initialization') + async with cosmos_client.CosmosClient(CRUDTests.host, CRUDTests.masterKey, consistency_level="Session", + connection_policy=connection_policy) as client: + print('Async Initialization') async def test_client_connection_retry_configuration(self): total_time_for_two_retries = await self.initialize_client_with_connection_urllib_retry_config(2) @@ -1932,7 +1916,7 @@ async def initialize_client_with_connection_urllib_retry_config(self, retries): async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, - "Session", + consistency_level="Session", connection_retry_policy=retry_policy) as client: print('Async initialization') self.fail() except AzureError as e: @@ -1945,7 +1929,7 @@ async def initialize_client_with_connection_core_retry_config(self, retries): async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, - "Session", + consistency_level="Session", retry_total=retries, retry_read=retries, retry_connect=retries, @@ -1960,14 +1944,14 @@ async 
def test_absolute_client_timeout(self): async with cosmos_client.CosmosClient( "https://localhost:9999", CRUDTests.masterKey, - "Session", + consistency_level="Session", retry_total=3, timeout=1)as client: print('Async initialization') error_response = ServiceResponseError("Read timeout") timeout_transport = TimeoutTransport(error_response) async with cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print('Async initialization') + self.host, self.masterKey, consistency_level="Session", transport=timeout_transport, passthrough=True) as client: print('Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database_if_not_exists("test", timeout=2) @@ -1975,7 +1959,7 @@ async def test_absolute_client_timeout(self): status_response = 500 # Users connection level retry timeout_transport = TimeoutTransport(status_response) async with cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print( + self.host, self.masterKey, consistency_level="Session", transport=timeout_transport, passthrough=True) as client: print( 'Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database("test", timeout=2) @@ -1987,7 +1971,7 @@ async def test_absolute_client_timeout(self): status_response = 429 # Uses Cosmos custom retry timeout_transport = TimeoutTransport(status_response) async with cosmos_client.CosmosClient( - self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True) as client: print( + self.host, self.masterKey, consistency_level="Session", transport=timeout_transport, passthrough=True) as client: print( 'Async initialization') with self.assertRaises(exceptions.CosmosClientTimeoutError): await client.create_database_if_not_exists("test", timeout=2) @@ -2035,7 +2019,7 @@ async def __create_resources(client): 
results = resources['coll'].read_all_items(max_item_count=2) counter = 0 # test QueryIterable with 'for'. - async for doc in iter(results): + async for doc in results: counter += 1 if counter == 1: self.assertEqual(resources['doc1']['id'], @@ -2322,7 +2306,7 @@ async def test_offer_replace(self): async def test_database_account_functionality(self): # Validate database account functionality. - database_account = await self.client.get_database_account() + database_account = await self.client._get_database_account() self.assertEqual(database_account.DatabasesLink, '/dbs/') self.assertEqual(database_account.MediaLink, '/media/') if (HttpHeaders.MaxMediaStorageUsageInMB in @@ -2445,8 +2429,6 @@ async def test_id_case_validation(self): await created_db.delete_container(created_collection1) await created_db.delete_container(created_collection2) - # TODO: fix test - @pytest.mark.skip async def test_id_unicode_validation(self): # create database created_db = self.databaseForTest @@ -2470,11 +2452,11 @@ async def test_id_unicode_validation(self): self.assertEqual(collection_id1, created_collection1.id) self.assertEqual(collection_id2, created_collection2.id) - created_collection1_properties = created_collection1.read() - created_collection2_properties = created_collection2.read() + created_collection1_properties = await created_collection1.read() + created_collection2_properties = await created_collection2.read() - await created_db.client_connection.DeleteContainer(created_collection1_properties['_self']) - await created_db.client_connection.DeleteContainer(created_collection2_properties['_self']) + await created_db.delete_container(created_collection1_properties) + await created_db.delete_container(created_collection2_properties) async def test_get_resource_with_dictionary_and_object(self): created_db = self.databaseForTest @@ -2488,10 +2470,10 @@ async def test_get_resource_with_dictionary_and_object(self): self.assertEqual(read_db.id, created_db.id) # read database with 
properties - read_db = self.client.get_database_client(created_db.read()) + read_db = self.client.get_database_client(await created_db.read()) self.assertEqual(read_db.id, created_db.id) - created_container = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) + created_container = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION, PartitionKey(path="/id")) # read container with id read_container = created_db.get_container_client(created_container.id) @@ -2537,11 +2519,11 @@ async def test_get_resource_with_dictionary_and_object(self): }) # read trigger with id - read_trigger = created_container.scripts.get_trigger(created_trigger['id']) + read_trigger = await created_container.scripts.get_trigger(created_trigger['id']) self.assertEqual(read_trigger['id'], created_trigger['id']) # read trigger with properties - read_trigger = created_container.scripts.get_trigger(created_trigger) + read_trigger = await created_container.scripts.get_trigger(created_trigger) self.assertEqual(read_trigger['id'], created_trigger['id']) created_udf = await created_container.scripts.create_user_defined_function({ @@ -2658,16 +2640,8 @@ async def test_get_resource_with_dictionary_and_object(self): # ttl_key = "analyticalStorageTtl" # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) - def _MockExecuteFunction(self, function, *args, **kwargs): + async def _MockExecuteFunction(self, function, *args, **kwargs): self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey] if HttpHeaders.PartitionKey in args[4].headers else '') - return self.OriginalExecuteFunction(function, *args, **kwargs) - + return await self.OriginalExecuteFunction(function, *args, **kwargs) -if __name__ == '__main__': - try: - loop = asyncio.get_event_loop() - loop.run_until_complete(unittest.main()) - except SystemExit as inst: - if inst.args[0] is True: # raised by sys.exit(True) when tests 
failed - raise From 831a838507359f4e8821408bbe8227a0554d9b70 Mon Sep 17 00:00:00 2001 From: simorenoh Date: Fri, 13 May 2022 12:00:26 -0400 Subject: [PATCH 8/8] Update test_crud_async.py --- sdk/cosmos/azure-cosmos/test/test_crud_async.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index 985111a250b9..cb1cb586f26d 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -136,11 +136,11 @@ async def test_database_crud(self): database_proxy = await self.client.create_database_if_not_exists(id=database_id, offer_throughput=10000) self.assertEqual(database_id, database_proxy.id) - self.assertEqual(10000, await database_proxy.read_offer().offer_throughput) + self.assertEqual(10000, await database_proxy.get_throughput().offer_throughput) database_proxy = await self.client.create_database_if_not_exists(id=database_id, offer_throughput=9000) self.assertEqual(database_id, database_proxy.id) - self.assertEqual(10000, await database_proxy.read_offer().offer_throughput) + self.assertEqual(10000, await database_proxy.get_throughput().offer_throughput) await self.client.delete_database(database_id) @@ -156,7 +156,7 @@ async def test_database_level_offer_throughput(self): self.assertEqual(created_db.id, database_id) # Verify offer throughput for database - offer = await created_db.read_offer() + offer = await created_db.get_throughput() self.assertEqual(offer.offer_throughput, offer_throughput) # Update database offer throughput @@ -271,7 +271,7 @@ async def test_partitioned_collection(self): self.assertEqual(collection_definition.get('partitionKey').get('kind'), created_collection_properties['partitionKey']['kind']) - expected_offer = await created_collection.read_offer() + expected_offer = await created_collection.get_throughput() self.assertIsNotNone(expected_offer) @@ -2276,7 +2276,7 
@@ async def test_offer_read_and_query(self): partition_key=PartitionKey(path='/id', kind='Hash') ) # Read the offer. - expected_offer = await collection.read_offer() + expected_offer = await collection.get_throughput() collection_properties = await collection.read() await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) @@ -2291,7 +2291,7 @@ async def test_offer_replace(self): # Create collection. collection = await self.databaseForTest.create_container(test_config._test_config.TEST_COLLECTION_MULTI_PARTITION) # Read Offer - expected_offer = await collection.read_offer() + expected_offer = await collection.get_throughput() collection_properties = await collection.read() await self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) # Replace the offer.