Update generated (Azure#16049)
Use more generated code for config and pipeline creation
seankane-msft authored and rakshith91 committed Jan 15, 2021
1 parent a08f290 commit 8505636
Showing 15 changed files with 210 additions and 223 deletions.
89 changes: 23 additions & 66 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
@@ -30,7 +30,6 @@
from azure.core.exceptions import ClientAuthenticationError, ResourceNotFoundError
from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import (
RequestsTransport,
HttpTransport,
HttpRequest,
)
@@ -66,7 +65,6 @@
from ._models import BatchErrorException
from ._sdk_moniker import SDK_MONIKER


_LOGGER = logging.getLogger(__name__)
_SERVICE_PARAMS = {
"blob": {"primary": "BlobEndpoint", "secondary": "BlobSecondaryEndpoint"},
@@ -124,9 +122,27 @@ def __init__(
self.require_encryption = kwargs.get("require_encryption", False)
self.key_encryption_key = kwargs.get("key_encryption_key")
self.key_resolver_function = kwargs.get("key_resolver_function")
self._config, self._pipeline = self._create_pipeline(
self.credential, storage_sdk=service, **kwargs
)

self._configure_credential(self.credential)
kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
kwargs.setdefault("read_timeout", READ_TIMEOUT)

self._policies = [
StorageHeadersPolicy(**kwargs),
ProxyPolicy(**kwargs),
UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs),
StorageContentValidation(),
StorageRequestHook(**kwargs),
self._credential_policy,
ContentDecodePolicy(response_encoding="utf-8"),
RedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
kwargs.get("retry_policy") or TablesRetryPolicy(**kwargs),
StorageLoggingPolicy(**kwargs),
StorageResponseHook(**kwargs),
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]

def __enter__(self):
self._client.__enter__()
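
With this change the client keeps an ordered policy list and lets the generated client assemble the pipeline from it. For orientation, here is a minimal sketch of how such a policy list becomes an azure-core Pipeline; the policy selection and the RequestsTransport default are illustrative, not the SDK's exact configuration:

from azure.core.pipeline import Pipeline
from azure.core.pipeline.transport import RequestsTransport
from azure.core.pipeline.policies import (
    HeadersPolicy,
    RetryPolicy,
    NetworkTraceLoggingPolicy,
)

def build_pipeline(transport=None, **kwargs):
    # Policies run in list order on the request and in reverse order on the response.
    policies = [
        HeadersPolicy(**kwargs),
        RetryPolicy(**kwargs),
        NetworkTraceLoggingPolicy(**kwargs),
    ]
    return Pipeline(transport or RequestsTransport(**kwargs), policies=policies)

Keeping the list on self._policies means the same ordering can be handed to the generated client instead of being rebuilt here.
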
@@ -233,7 +249,7 @@ def _format_query_string(
credential = None
return query_str.rstrip("?&"), credential

def _create_pipeline(self, credential, **kwargs):
def _configure_credential(self, credential):
# type: (Any) -> None
self._credential_policy = None
if hasattr(credential, "get_token"):
@@ -245,32 +261,6 @@ def _create_pipeline(self, credential, **kwargs):
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))

config = kwargs.get("_configuration") or create_configuration(**kwargs)
if kwargs.get("_pipeline"):
return config, kwargs["_pipeline"]
config.transport = kwargs.get("transport") # type: ignore
kwargs.setdefault("connection_timeout", CONNECTION_TIMEOUT)
kwargs.setdefault("read_timeout", READ_TIMEOUT)
if not config.transport:
config.transport = RequestsTransport(**kwargs)
policies = [
config.headers_policy,
config.proxy_policy,
config.user_agent_policy,
StorageContentValidation(),
StorageRequestHook(**kwargs),
self._credential_policy,
ContentDecodePolicy(response_encoding="utf-8"),
RedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
config.retry_policy,
config.logging_policy,
StorageResponseHook(**kwargs),
DistributedTracingPolicy(**kwargs),
HttpLoggingPolicy(**kwargs),
]
return config, Pipeline(config.transport, policies=policies)

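The replacement _configure_credential only selects a credential policy; its middle branches are collapsed in this view. A minimal sketch of that kind of dispatch, assuming the token-credential case is wrapped in azure-core's BearerTokenCredentialPolicy with the storage scope shown here; the collapsed shared-key handling is omitted:

from azure.core.pipeline.policies import BearerTokenCredentialPolicy

def select_credential_policy(credential):
    # Sketch only: mirrors the branches of _configure_credential visible above.
    if credential is None:
        return None
    if hasattr(credential, "get_token"):
        # TokenCredential (for example from azure-identity): OAuth bearer tokens.
        return BearerTokenCredentialPolicy(credential, "https://storage.azure.com/.default")
    raise TypeError("Unsupported credential: {}".format(credential))
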
def _batch_send( # pylint: disable=inconsistent-return-statements
self,
entities, # type: List[TableEntity]
@@ -302,7 +292,7 @@ def _batch_send( # pylint: disable=inconsistent-return-statements
boundary="batch_{}".format(uuid4()),
)

pipeline_response = self._pipeline.run(request, **kwargs)
pipeline_response = self._client._client._pipeline.run(request, **kwargs) # pylint:disable=protected-access
response = pipeline_response.http_response

if response.status_code == 403:
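
_batch_send now reuses the pipeline owned by the generated client (self._client._client._pipeline) rather than a locally built one. A minimal sketch of the run-and-inspect pattern it relies on; the helper name, URL handling, and the error raised here are placeholders:

from azure.core.pipeline.transport import HttpRequest

def send_multipart(pipeline, url, body, **kwargs):
    # Build the HTTP request, run it through the pipeline, then inspect the
    # transport-level response, as _batch_send does for the table batch call.
    request = HttpRequest("POST", url)
    request.set_bytes_body(body)
    pipeline_response = pipeline.run(request, **kwargs)
    response = pipeline_response.http_response
    if response.status_code == 403:
        raise PermissionError("batch request was rejected: {}".format(response.reason))
    return response
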
@@ -461,39 +451,6 @@ def parse_connection_str(conn_str, credential, service, keyword_args):
return primary, credential


def create_configuration(**kwargs):
# type: (**Any) -> Configuration
config = Configuration(**kwargs)
config.headers_policy = StorageHeadersPolicy(**kwargs)
config.user_agent_policy = UserAgentPolicy(sdk_moniker=SDK_MONIKER, **kwargs)
# sdk_moniker="storage-{}/{}".format(kwargs.pop('storage_sdk'), VERSION), **kwargs)
config.retry_policy = kwargs.get("retry_policy") or TablesRetryPolicy(**kwargs)
config.logging_policy = StorageLoggingPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)

# Storage settings
config.max_single_put_size = kwargs.get("max_single_put_size", 64 * 1024 * 1024)
config.copy_polling_interval = 15

# Block blob uploads
config.max_block_size = kwargs.get("max_block_size", 4 * 1024 * 1024)
config.min_large_block_upload_threshold = kwargs.get(
"min_large_block_upload_threshold", 4 * 1024 * 1024 + 1
)
config.use_byte_buffer = kwargs.get("use_byte_buffer", False)

# Page blob uploads
config.max_page_size = kwargs.get("max_page_size", 4 * 1024 * 1024)

# Blob downloads
config.max_single_get_size = kwargs.get("max_single_get_size", 32 * 1024 * 1024)
config.max_chunk_get_size = kwargs.get("max_chunk_get_size", 4 * 1024 * 1024)

# File uploads
config.max_range_size = kwargs.get("max_range_size", 4 * 1024 * 1024)
return config


def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
(next changed file; filename not captured in this view)
@@ -44,6 +44,7 @@ def __init__(

client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
self._serialize = Serializer(client_models)
self._serialize.client_side_validation = False
self._deserialize = Deserializer(client_models)

self.table = TableOperations(
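
The generated client above (and an identical hunk in the next generated client) turns off client-side validation on the msrest Serializer, so invalid values are rejected by the service instead of raising a local msrest ValidationError. A small sketch of the same toggle; the models import path is an assumption:

from msrest import Serializer, Deserializer
from azure.data.tables._generated import models  # assumed path to the generated models

client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
serializer = Serializer(client_models)
serializer.client_side_validation = False   # defer validation to the service
deserializer = Deserializer(client_models)
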
(next changed file; filename not captured in this view)
@@ -39,6 +39,7 @@ def __init__(

client_models = {k: v for k, v in models.__dict__.items() if isinstance(v, type)}
self._serialize = Serializer(client_models)
self._serialize.client_side_validation = False
self._deserialize = Deserializer(client_models)

self.table = TableOperations(
(next changed file; filename not captured in this view)
@@ -12,7 +12,7 @@
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest

from ... import models
from ... import models as _models

T = TypeVar('T')
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]]
@@ -31,7 +31,7 @@ class ServiceOperations:
:param deserializer: An object model deserializer.
"""

models = models
models = _models

def __init__(self, client, config, serializer, deserializer) -> None:
self._client = client
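
Importing the module as _models while still exposing it through the models class attribute keeps the reference used inside method bodies from being shadowed by anything else named models. A small, hypothetical sketch of the pattern:

from types import SimpleNamespace

# Stand-in for the generated models module.
_models = SimpleNamespace(TableServiceError=type("TableServiceError", (), {}))

class ServiceOperations:
    models = _models              # public handle, e.g. ServiceOperations.models

    def parse_error(self, response):
        # Method bodies reference the module alias directly, never the class attribute.
        return _models.TableServiceError()
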
@@ -41,7 +41,7 @@ def __init__(self, client, config, serializer, deserializer) -> None:

async def set_properties(
self,
table_service_properties: "models.TableServiceProperties",
table_service_properties: "_models.TableServiceProperties",
timeout: Optional[int] = None,
request_id_parameter: Optional[str] = None,
**kwargs
@@ -102,7 +102,7 @@ async def set_properties(

if response.status_code not in [202]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.TableServiceError, response)
error = self._deserialize(_models.TableServiceError, response)
raise HttpResponseError(response=response, model=error)

response_headers = {}
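
When the service returns anything other than 202, the operation deserializes the body into TableServiceError and raises HttpResponseError with that model attached. A minimal sketch of how a caller can read the typed error back; the message attribute on the model is an assumption:

from azure.core.exceptions import HttpResponseError

async def set_service_properties(service_ops, properties):
    try:
        await service_ops.set_properties(properties)
    except HttpResponseError as exc:
        service_error = exc.model          # the deserialized TableServiceError, if present
        if service_error is not None:
            print("service rejected the request:", service_error.message)
        raise
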
@@ -120,7 +120,7 @@ async def get_properties(
timeout: Optional[int] = None,
request_id_parameter: Optional[str] = None,
**kwargs
) -> "models.TableServiceProperties":
) -> "_models.TableServiceProperties":
"""Gets the properties of an account's Table service, including properties for Analytics and CORS
(Cross-Origin Resource Sharing) rules.
@@ -134,7 +134,7 @@ async def get_properties(
:rtype: ~azure.data.tables.models.TableServiceProperties
:raises: ~azure.core.exceptions.HttpResponseError
"""
cls = kwargs.pop('cls', None) # type: ClsType["models.TableServiceProperties"]
cls = kwargs.pop('cls', None) # type: ClsType["_models.TableServiceProperties"]
error_map = {
401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError
}
@@ -170,7 +170,7 @@ async def get_properties(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.TableServiceError, response)
error = self._deserialize(_models.TableServiceError, response)
raise HttpResponseError(response=response, model=error)

response_headers = {}
@@ -190,7 +190,7 @@ async def get_statistics(
timeout: Optional[int] = None,
request_id_parameter: Optional[str] = None,
**kwargs
) -> "models.TableServiceStats":
) -> "_models.TableServiceStats":
"""Retrieves statistics related to replication for the Table service. It is only available on the
secondary location endpoint when read-access geo-redundant replication is enabled for the
account.
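
Replication statistics are only served from the secondary location, so the request must target the account's secondary endpoint (conventionally <account>-secondary.table.core.windows.net) and the account must have read-access geo-redundant replication enabled. A hedged sketch of reading the result; the attribute names on TableServiceStats are assumptions modeled on the equivalent Storage types:

async def print_replication_status(service_ops):
    # service_ops is a ServiceOperations instance wired to the secondary endpoint
    # of an RA-GRS account.
    stats = await service_ops.get_statistics()
    geo = stats.geo_replication            # assumed GeoReplication sub-model
    print("status:", geo.status)
    print("last sync time:", geo.last_sync_time)
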
@@ -205,7 +205,7 @@ async def get_statistics(
:rtype: ~azure.data.tables.models.TableServiceStats
:raises: ~azure.core.exceptions.HttpResponseError
"""
cls = kwargs.pop('cls', None) # type: ClsType["models.TableServiceStats"]
cls = kwargs.pop('cls', None) # type: ClsType["_models.TableServiceStats"]
error_map = {
401: ClientAuthenticationError, 404: ResourceNotFoundError, 409: ResourceExistsError
}
@@ -241,7 +241,7 @@ async def get_statistics(

if response.status_code not in [200]:
map_error(status_code=response.status_code, response=response, error_map=error_map)
error = self._deserialize(models.TableServiceError, response)
error = self._deserialize(_models.TableServiceError, response)
raise HttpResponseError(response=response, model=error)

response_headers = {}
(remaining changed files not loaded in this view)
