Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removing create_configuration and using generated code instead #15908

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 6 additions & 97 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,52 +17,27 @@
)
import logging
from uuid import uuid4



import six
try:
from urllib.parse import parse_qs, quote
except ImportError:
from urlparse import parse_qs # type: ignore
from urllib2 import quote # type: ignore

import six
from azure.core.configuration import Configuration
from azure.core.exceptions import (
ClientAuthenticationError,
ResourceNotFoundError
)
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import (
RequestsTransport,
HttpTransport,
HttpRequest,
)
from azure.core.pipeline.policies import (
RedirectPolicy,
ContentDecodePolicy,
BearerTokenCredentialPolicy,
ProxyPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
UserAgentPolicy
)

from ._shared_access_signature import QueryStringConstants
from ._constants import STORAGE_OAUTH_SCOPE, SERVICE_HOST_BASE, CONNECTION_TIMEOUT, READ_TIMEOUT
from ._constants import SERVICE_HOST_BASE
from ._models import LocationMode, BatchTransactionResult
from ._authentication import SharedKeyCredentialPolicy
from ._policies import (
StorageHeadersPolicy,
StorageContentValidation,
StorageRequestHook,
StorageResponseHook,
StorageLoggingPolicy,
StorageHosts,
TablesRetryPolicy,
)
from ._policies import StorageHeadersPolicy
from ._models import BatchErrorException
from ._sdk_moniker import SDK_MONIKER


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -116,7 +91,8 @@ def __init__(
self.require_encryption = kwargs.get("require_encryption", False)
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

self._config = kwargs.pop("_config", None)

def __enter__(self):
self._client.__enter__()
Expand Down Expand Up @@ -221,42 +197,6 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps
credential = None
return query_str.rstrip("?&"), credential

def _create_pipeline(self, credential, **kwargs):
    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
    """Build the credential policy, configuration and synchronous HTTP pipeline
    for this client.

    :param credential: The credential used to authenticate requests. Accepted
        forms: an object exposing ``get_token`` (wrapped in a
        ``BearerTokenCredentialPolicy`` with ``STORAGE_OAUTH_SCOPE``), an
        already-constructed ``SharedKeyCredentialPolicy`` (used as-is), or
        ``None`` (anonymous access).
    :returns: A ``(Configuration, Pipeline)`` tuple.
    :raises TypeError: If ``credential`` is not ``None`` and matches neither
        supported form.
    """
    self._credential_policy = None
    if hasattr(credential, "get_token"):
        self._credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
    elif isinstance(credential, SharedKeyCredentialPolicy):
        self._credential_policy = credential
    elif credential is not None:
        raise TypeError("Unsupported credential: {}".format(credential))

    # An injected "_configuration" (e.g. from a parent service client) takes
    # precedence over building a fresh one from kwargs.
    config = kwargs.get("_configuration") or create_configuration(**kwargs)
    # An injected "_pipeline" short-circuits construction entirely so child
    # clients can share the parent's pipeline/transport.
    if kwargs.get("_pipeline"):
        return config, kwargs["_pipeline"]
    config.transport = kwargs.get("transport")  # type: ignore
    kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
    kwargs.setdefault("read_timeout", READ_TIMEOUT)
    if not config.transport:
        config.transport = RequestsTransport(**kwargs)
    # NOTE: the order of this list is the order requests traverse the
    # pipeline (and responses traverse it in reverse) — do not reorder.
    policies = [
        config.headers_policy,
        config.proxy_policy,
        config.user_agent_policy,
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        self._credential_policy,
        ContentDecodePolicy(response_encoding="utf-8"),
        RedirectPolicy(**kwargs),
        StorageHosts(hosts=self._hosts, **kwargs),
        config.retry_policy,
        config.logging_policy,
        StorageResponseHook(**kwargs),
        DistributedTracingPolicy(**kwargs),
        HttpLoggingPolicy(**kwargs)
    ]
    return config, Pipeline(config.transport, policies=policies)

def _batch_send( # pylint: disable=inconsistent-return-statements
self, entities, # type: List[TableEntity]
*reqs, # type: List[HttpRequest]
Expand Down Expand Up @@ -290,7 +230,7 @@ def _batch_send( # pylint: disable=inconsistent-return-statements
boundary="batch_{}".format(uuid4())
)

pipeline_response = self._pipeline.run(
pipeline_response = self._client._client._pipeline.run( # pylint: disable=protected-access
request, **kwargs
)
response = pipeline_response.http_response
Expand Down Expand Up @@ -415,37 +355,6 @@ def parse_connection_str(conn_str, credential, service, keyword_args):
return primary, credential


def create_configuration(**kwargs):
    # type: (**Any) -> Configuration
    """Assemble a ``Configuration`` carrying the Tables pipeline policies and
    the storage transfer-size settings, honoring any overrides passed in
    ``kwargs``.

    :returns: A populated :class:`~azure.core.configuration.Configuration`.
    """
    _MB = 1024 * 1024
    cfg = Configuration(**kwargs)

    # Pipeline policies; an explicitly supplied retry_policy wins over the
    # default TablesRetryPolicy.
    cfg.headers_policy = StorageHeadersPolicy(**kwargs)
    cfg.user_agent_policy = UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs)
    cfg.retry_policy = kwargs.get("retry_policy") or TablesRetryPolicy(**kwargs)
    cfg.logging_policy = StorageLoggingPolicy(**kwargs)
    cfg.proxy_policy = ProxyPolicy(**kwargs)

    # Storage settings (fixed, not overridable).
    cfg.copy_polling_interval = 15

    # Transfer-size knobs, each overridable through kwargs; values below are
    # the defaults applied when no override is supplied.
    transfer_defaults = {
        "max_single_put_size": 64 * _MB,            # single-shot upload cap
        "max_block_size": 4 * _MB,                  # block blob uploads
        "min_large_block_upload_threshold": 4 * _MB + 1,
        "use_byte_buffer": False,
        "max_page_size": 4 * _MB,                   # page blob uploads
        "max_single_get_size": 32 * _MB,            # blob downloads
        "max_chunk_get_size": 4 * _MB,
        "max_range_size": 4 * _MB,                  # file uploads
    }
    for setting, default in transfer_defaults.items():
        setattr(cfg, setting, kwargs.get(setting, default))

    return cfg


def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
Expand Down
4 changes: 1 addition & 3 deletions sdk/tables/azure-data-tables/azure/data/tables/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,7 @@ def _get_next_cb(self, continuation_token, **kwargs):
return self._command(
query_options=query_options,
next_table_name=continuation_token or None,
cls=kwargs.pop('cls', None) or _return_context_and_deserialized,
use_location=self.location_mode
cls=kwargs.pop('cls', None) or _return_context_and_deserialized
)
except HttpResponseError as error:
_process_table_error(error)
Expand Down Expand Up @@ -352,7 +351,6 @@ def _get_next_cb(self, continuation_token, **kwargs):
next_partition_key=next_partition_key,
table=self.table,
cls=kwargs.pop('cls', None) or _return_context_and_deserialized,
use_location=self.location_mode
)
except HttpResponseError as error:
_process_table_error(error)
Expand Down
24 changes: 19 additions & 5 deletions sdk/tables/azure-data-tables/azure/data/tables/_table_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,30 @@
from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
from azure.core.paging import ItemPaged
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline.policies import UserAgentPolicy, ProxyPolicy

from ._deserialize import _convert_to_entity, _trim_service_metadata
from ._entity import TableEntity
from ._error import _process_table_error
from ._generated import AzureTable
from ._generated.models import (
# AccessPolicy,
SignedIdentifier,
TableProperties,
)
from ._serialize import _get_match_headers, _add_entity_properties
from ._base_client import parse_connection_str
from ._constants import CONNECTION_TIMEOUT
from ._table_client_base import TableClientBase
from ._serialize import serialize_iso
from ._deserialize import _return_headers_and_deserialized

from ._table_batch import TableBatchOperations
from ._models import TableEntityPropertiesPaged, UpdateMode, AccessPolicy


from ._sdk_moniker import SDK_MONIKER
from ._policies import (
StorageHeadersPolicy,
StorageLoggingPolicy,
TablesRetryPolicy,
)
class TableClient(TableClientBase):
""" :ivar str account_name: Name of the storage account (Cosmos or Azure)"""

Expand Down Expand Up @@ -63,7 +67,17 @@ def __init__(
:returns: None
"""
super(TableClient, self).__init__(account_url, table_name, credential=credential, **kwargs)
self._client = AzureTable(self.url, pipeline=self._pipeline)
self._client = AzureTable(
self.url,
sdk_moniker=SDK_MONIKER,
headers_policy=StorageHeadersPolicy(**kwargs),
user_agent_policy=UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
retry_policy=kwargs.pop("retry_policy", None) or TablesRetryPolicy(**kwargs),
logging_policy=StorageLoggingPolicy(**kwargs),
proxy_policy=ProxyPolicy(**kwargs),
connection_timeout=kwargs.pop("connection_timeout", None) or CONNECTION_TIMEOUT,
**kwargs
)


@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,27 @@
from azure.core.exceptions import HttpResponseError, ResourceExistsError
from azure.core.paging import ItemPaged
from azure.core.tracing.decorator import distributed_trace
from azure.core.pipeline import Pipeline
from ._models import TableItem
from azure.core.pipeline.policies import UserAgentPolicy, ProxyPolicy

from ._generated import AzureTable
from ._generated.models import TableProperties, TableServiceProperties
from ._models import TablePropertiesPaged, service_stats_deserialize, service_properties_deserialize
from ._base_client import parse_connection_str, TransportWrapper
from ._models import LocationMode
from ._models import (
TablePropertiesPaged,
service_stats_deserialize,
service_properties_deserialize,
TableItem
)
from ._base_client import parse_connection_str
from ._constants import CONNECTION_TIMEOUT
from ._error import _process_table_error
from ._sdk_moniker import SDK_MONIKER
from ._table_client import TableClient
from ._table_service_client_base import TableServiceClientBase

from ._policies import (
StorageHeadersPolicy,
StorageLoggingPolicy,
TablesRetryPolicy,
)

class TableServiceClient(TableServiceClientBase):
""" :ivar str account_name: Name of the storage account (Cosmos or Azure)"""
Expand Down Expand Up @@ -61,7 +70,17 @@ def __init__(
"""

super(TableServiceClient, self).__init__(account_url, service='table', credential=credential, **kwargs)
self._client = AzureTable(self.url, pipeline=self._pipeline)
self._client = AzureTable(
self.url,
sdk_moniker=SDK_MONIKER,
headers_policy=StorageHeadersPolicy(**kwargs),
user_agent_policy=UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
retry_policy=kwargs.pop("retry_policy", None) or TablesRetryPolicy(**kwargs),
logging_policy=StorageLoggingPolicy(**kwargs),
proxy_policy=ProxyPolicy(**kwargs),
connection_timeout=kwargs.pop("connection_timeout", None) or CONNECTION_TIMEOUT,
**kwargs
)

@classmethod
def from_connection_string(
Expand Down Expand Up @@ -103,7 +122,7 @@ def get_service_stats(self, **kwargs):
try:
timeout = kwargs.pop('timeout', None)
stats = self._client.service.get_statistics( # type: ignore
timeout=timeout, use_location=LocationMode.SECONDARY, **kwargs)
timeout=timeout, **kwargs) #use_location=LocationMode.SECONDARY, **kwargs)
return service_stats_deserialize(stats)
except HttpResponseError as error:
_process_table_error(error)
Expand Down Expand Up @@ -344,13 +363,8 @@ def get_table_client(self, table_name, **kwargs):

"""

_pipeline = Pipeline(
transport=TransportWrapper(self._pipeline._transport), # pylint: disable = protected-access
policies=self._pipeline._impl_policies # pylint: disable = protected-access
)

return TableClient(
self.url, table_name=table_name, credential=self.credential,
key_resolver_function=self.key_resolver_function, require_encryption=self.require_encryption,
key_encryption_key=self.key_encryption_key, api_version=self.api_version, _pipeline=_pipeline,
key_encryption_key=self.key_encryption_key, api_version=self.api_version,
_configuration=self._config, _location_mode=self._location_mode, _hosts=self._hosts, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,13 @@
import logging
from uuid import uuid4

from azure.core.pipeline import AsyncPipeline
from azure.core.exceptions import (
ResourceNotFoundError,
ClientAuthenticationError
)
from azure.core.pipeline.policies import (
ContentDecodePolicy,
AsyncBearerTokenCredentialPolicy,
AsyncRedirectPolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
)
from azure.core.pipeline.transport import AsyncHttpTransport, HttpRequest

from .._constants import STORAGE_OAUTH_SCOPE, CONNECTION_TIMEOUT, READ_TIMEOUT
from .._authentication import SharedKeyCredentialPolicy
from .._base_client import create_configuration
from .._policies import (
StorageContentValidation,
StorageRequestHook,
StorageHosts,
StorageHeadersPolicy
)
from ._policies_async import AsyncStorageResponseHook
from .._policies import StorageHeadersPolicy
from .._models import BatchErrorException, BatchTransactionResult

if TYPE_CHECKING:
Expand Down Expand Up @@ -64,45 +47,6 @@ async def close(self):
"""
await self._client.close()

def _create_pipeline(self, credential, **kwargs):
    # type: (Any, **Any) -> Tuple[Configuration, Pipeline]
    """Build the credential policy, configuration and asynchronous HTTP
    pipeline for this client (async counterpart of the sync base client).

    :param credential: The credential used to authenticate requests. Accepted
        forms: an object exposing ``get_token`` (wrapped in an
        ``AsyncBearerTokenCredentialPolicy`` with ``STORAGE_OAUTH_SCOPE``), an
        already-constructed ``SharedKeyCredentialPolicy`` (used as-is), or
        ``None`` (anonymous access).
    :returns: A ``(Configuration, AsyncPipeline)`` tuple.
    :raises TypeError: If ``credential`` is not ``None`` and matches neither
        supported form.
    :raises ImportError: If no transport was supplied and aiohttp is not
        installed (needed for the default ``AioHttpTransport``).
    """
    self._credential_policy = None
    if hasattr(credential, 'get_token'):
        self._credential_policy = AsyncBearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
    elif isinstance(credential, SharedKeyCredentialPolicy):
        self._credential_policy = credential
    elif credential is not None:
        raise TypeError("Unsupported credential: {}".format(credential))
    # An injected "_configuration" / "_pipeline" (e.g. from a parent service
    # client) takes precedence over building fresh objects from kwargs.
    config = kwargs.get('_configuration') or create_configuration(**kwargs)
    if kwargs.get('_pipeline'):
        return config, kwargs['_pipeline']
    config.transport = kwargs.get('transport')  # type: ignore
    kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
    kwargs.setdefault("read_timeout", READ_TIMEOUT)
    if not config.transport:
        # Imported lazily so the sync-only install path does not require aiohttp.
        try:
            from azure.core.pipeline.transport import AioHttpTransport
        except ImportError:
            raise ImportError("Unable to create async transport. Please check aiohttp is installed.")
        config.transport = AioHttpTransport(**kwargs)
    # NOTE: the order of this list is the order requests traverse the
    # pipeline (and responses traverse it in reverse) — do not reorder.
    policies = [
        config.headers_policy,
        config.proxy_policy,
        config.user_agent_policy,
        StorageContentValidation(),
        StorageRequestHook(**kwargs),
        self._credential_policy,
        ContentDecodePolicy(response_encoding="utf-8"),
        AsyncRedirectPolicy(**kwargs),
        StorageHosts(hosts=self._hosts, **kwargs),  # type: ignore
        config.retry_policy,
        config.logging_policy,
        AsyncStorageResponseHook(**kwargs),
        DistributedTracingPolicy(**kwargs),
        HttpLoggingPolicy(**kwargs),
    ]
    return config, AsyncPipeline(config.transport, policies=policies)

async def _batch_send(
self, entities, # type: List[TableEntity]
*reqs: 'HttpRequest',
Expand Down Expand Up @@ -135,7 +79,7 @@ async def _batch_send(
boundary="batch_{}".format(uuid4())
)

pipeline_response = await self._pipeline.run(
pipeline_response = await self._client._client._pipeline.run( # pylint: disable=protected-access
request, **kwargs
)
response = pipeline_response.http_response
Expand Down
Loading