Skip to content

Commit

Permalink
add credential schema policy classes to improve code flow (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Jul 15, 2021
1 parent d4b50a5 commit ce1c5c9
Show file tree
Hide file tree
Showing 40 changed files with 206 additions and 168 deletions.
112 changes: 47 additions & 65 deletions autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# --------------------------------------------------------------------------
import logging
import sys
from typing import Dict, Any, Set, Union, List
from typing import Dict, Any, Set, Union, List, Type
import yaml

from .. import Plugin
Expand All @@ -16,14 +16,8 @@
from .models.parameter_list import GlobalParameterList
from .models.rest import Rest
from .serializers import JinjaSerializer


def _get_credential_default_policy_type_has_async_version(credential_default_policy_type: str) -> bool:
mapping = {
"BearerTokenCredentialPolicy": True,
"AzureKeyCredentialPolicy": False
}
return mapping[credential_default_policy_type]
from .models.credential_schema_policy import CredentialSchemaPolicy, get_credential_schema_policy_type
from .models.credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema

def _build_convenience_layer(yaml_data: Dict[str, Any], code_model: CodeModel) -> None:
# Create operations
Expand Down Expand Up @@ -111,6 +105,8 @@ def _create_code_model(self, yaml_data: Dict[str, Any], options: Dict[str, Union
only_path_and_body_params_positional=only_path_and_body_params_positional,
options=options,
)
if code_model.options['credential']:
self._handle_default_authentication_policy(code_model)
code_model.module_name = yaml_data["info"]["python_title"]
code_model.class_name = yaml_data["info"]["pascal_case_title"]
code_model.description = (
Expand Down Expand Up @@ -176,70 +172,69 @@ def _get_credential_scopes(self, credential):
)
return credential_scopes

def _get_credential_param(self, azure_arm, credential, credential_default_policy_type):
credential_scopes = self._get_credential_scopes(credential)
def _initialize_credential_schema_policy(
self, code_model: CodeModel, credential_schema_policy: Type[CredentialSchemaPolicy]
) -> CredentialSchemaPolicy:
credential_scopes = self._get_credential_scopes(code_model.options['credential'])
credential_key_header_name = self._autorestapi.get_value('credential-key-header-name')
azure_arm = code_model.options['azure_arm']
credential = code_model.options['credential']

if credential_default_policy_type == "BearerTokenCredentialPolicy":
if hasattr(credential_schema_policy, "credential_scopes"):
if not credential_scopes:
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
# If add-credential is specified, we still want to add a credential_scopes variable.
# Will make it an empty list so we can differentiate between this case and None
_LOGGER.warning(
"You have default credential policy BearerTokenCredentialPolicy"
"You have default credential policy %s "
"but not the --credential-scopes flag set while generating non-management plane code. "
"This is not recommend because it forces the customer to pass credential scopes "
"through kwargs if they want to authenticate."
"through kwargs if they want to authenticate.",
credential_schema_policy.name()
)
credential_scopes = []

if credential_key_header_name:
raise ValueError(
"You have passed in a credential key header name with default credential policy type "
"BearerTokenCredentialPolicy. This is not allowed, since credential key header name is tied with "
"AzureKeyCredentialPolicy. Instead, with this policy it is recommend you pass in "
"--credential-scopes."
)
else:
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
"BearerTokenCredentialPolicy. Instead, with this policy you must pass in "
"--credential-key-header-name."
)
if not credential_key_header_name:
credential_key_header_name = "api-key"
_LOGGER.info(
"Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
f"{credential_schema_policy.name()}. This is not allowed, since credential key header "
"name is tied with AzureKeyCredentialPolicy. Instead, with this policy it is recommend you "
"pass in --credential-scopes."
)
return credential_scopes, credential_key_header_name

def _handle_default_authentication_policy(self, azure_arm, credential):

passed_in_credential_default_policy_type = (
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
)

# right now, we only allow BearerTokenCredentialPolicy and AzureKeyCredentialPolicy
allowed_policies = ["BearerTokenCredentialPolicy", "AzureKeyCredentialPolicy"]
try:
credential_default_policy_type = [
cp for cp in allowed_policies if cp.lower() == passed_in_credential_default_policy_type.lower()
][0]
except IndexError:
return credential_schema_policy(
credential=TokenCredentialSchema(async_mode=False),
credential_scopes=credential_scopes,
)
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"The credential you pass in with --credential-default-policy-type must be either "
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
f"{code_model.default_authentication_policy.name()}. Instead, with this policy you must pass in "
"--credential-key-header-name."
)

credential_scopes, credential_key_header_name = self._get_credential_param(
azure_arm, credential, credential_default_policy_type
if not credential_key_header_name:
credential_key_header_name = "api-key"
_LOGGER.info(
"Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
)
return credential_schema_policy(
credential=AzureKeyCredentialSchema(),
credential_key_header_name=credential_key_header_name,
)

return credential_default_policy_type, credential_scopes, credential_key_header_name
def _handle_default_authentication_policy(self, code_model: CodeModel):
credential_schema_policy_name = (
self._autorestapi.get_value("credential-default-policy-type") or
code_model.default_authentication_policy.name()
)
credential_schema_policy_type = get_credential_schema_policy_type(credential_schema_policy_name)
credential_schema_policy = self._initialize_credential_schema_policy(
code_model, credential_schema_policy_type
)
code_model.credential_schema_policy = credential_schema_policy


def _build_code_model_options(self) -> Dict[str, Any]:
Expand All @@ -251,13 +246,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
self._autorestapi.get_boolean_value("add-credential", False)
)

credential_default_policy_type, credential_scopes, credential_key_header_name = (
self._handle_default_authentication_policy(
azure_arm, credential
)
)


license_header = self._autorestapi.get_value("header-text")
if license_header:
license_header = license_header.replace("\n", "\n# ")
Expand All @@ -269,8 +257,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
options: Dict[str, Any] = {
"azure_arm": azure_arm,
"credential": credential,
"credential_scopes": credential_scopes,
"credential_key_header_name": credential_key_header_name,
"head_as_boolean": self._autorestapi.get_boolean_value("head-as-boolean", False),
"license_header": license_header,
"keep_version_file": self._autorestapi.get_boolean_value("keep-version-file", False),
Expand All @@ -282,10 +268,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"client_side_validation": self._autorestapi.get_boolean_value("client-side-validation", False),
"tracing": self._autorestapi.get_boolean_value("trace", False),
"multiapi": self._autorestapi.get_boolean_value("multiapi", False),
"credential_default_policy_type": credential_default_policy_type,
"credential_default_policy_type_has_async_version": (
_get_credential_default_policy_type_has_async_version(credential_default_policy_type)
),
"polymorphic_examples": self._autorestapi.get_value("polymorphic-examples") or 5,
}

Expand Down
28 changes: 20 additions & 8 deletions autorest/codegen/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set, Union
from typing import cast, List, Dict, Optional, Any, Set, Type

from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .credential_schema_policy import (
BearerTokenCredentialPolicy, CredentialSchemaPolicy
)
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup
Expand Down Expand Up @@ -84,6 +86,7 @@ def __init__(
self.service_client: Client = Client(self, GlobalParameterList())
self._rest: Optional[Rest] = None
self.request_builder_ids: Dict[int, RequestBuilder] = {}
self._credential_schema_policy: Optional[CredentialSchemaPolicy] = None

@property
def global_parameters(self) -> GlobalParameterList:
Expand Down Expand Up @@ -158,14 +161,9 @@ def add_credential_global_parameter(self) -> None:
:return: None
:rtype: None
"""
credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
credential_schema = TokenCredentialSchema(async_mode=False)
else:
credential_schema = AzureKeyCredentialSchema()
credential_parameter = Parameter(
yaml_data={},
schema=credential_schema,
schema=self.credential_schema_policy.credential,
serialized_name="credential",
rest_api_name="credential",
implementation="Client",
Expand Down Expand Up @@ -217,6 +215,20 @@ def _lookup_operation(yaml_id: int) -> Operation:
operation for operation in operation_group.operations if operation not in next_operations
]

@property
def default_authentication_policy(self) -> Type[CredentialSchemaPolicy]:
return BearerTokenCredentialPolicy

@property
def credential_schema_policy(self) -> CredentialSchemaPolicy:
if not self._credential_schema_policy:
raise ValueError("You want to find the Credential Schema Policy, but have not given a value")
return self._credential_schema_policy

@credential_schema_policy.setter
def credential_schema_policy(self, val: CredentialSchemaPolicy) -> None:
self._credential_schema_policy = val

@staticmethod
def _add_properties_from_inheritance_helper(schema, properties) -> List[Property]:
if not schema.base_models:
Expand Down
70 changes: 70 additions & 0 deletions autorest/codegen/models/credential_schema_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from abc import abstractmethod
from typing import List
from .credential_schema import CredentialSchema

class CredentialSchemaPolicy:

def __init__(self, credential: CredentialSchema, *args, **kwargs) -> None: # pylint: disable=unused-argument
self.credential = credential

@abstractmethod
def call(self, async_mode: bool) -> str:
...

@classmethod
def name(cls):
return cls.__name__


class BearerTokenCredentialPolicy(CredentialSchemaPolicy):

def __init__(
self,
credential: CredentialSchema,
credential_scopes: List[str]
) -> None:
super().__init__(credential)
self._credential_scopes = credential_scopes

@property
def credential_scopes(self):
return self._credential_scopes

def call(self, async_mode: bool) -> str:
policy_name = f"Async{self.name()}" if async_mode else self.name()
return f"policies.{policy_name}(self.credential, *self.credential_scopes, **kwargs)"


class AzureKeyCredentialPolicy(CredentialSchemaPolicy):

def __init__(
self,
credential: CredentialSchema,
credential_key_header_name: str
) -> None:
super().__init__(credential)
self._credential_key_header_name = credential_key_header_name

@property
def credential_key_header_name(self):
return self._credential_key_header_name

def call(self, async_mode: bool) -> str:
return f'policies.AzureKeyCredentialPolicy(self.credential, "{self.credential_key_header_name}", **kwargs)'

def get_credential_schema_policy_type(name):
policies = [BearerTokenCredentialPolicy, AzureKeyCredentialPolicy]
try:
return next(p for p in policies if p.name().lower() == name.lower())
except StopIteration:
raise ValueError(
"The credential policy you pass in with --credential-default-policy-type must be either "
"{}".format(
" or ".join([p.name() for p in policies])
)
)
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def serialize_service_client_file(self) -> str:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand All @@ -71,7 +71,7 @@ def serialize_config_file(self) -> str:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand Down
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def _is_paging(operation):
# for typing purposes.
async_global_parameters = self.code_model.global_parameters
if (
self.code_model.options["credential"]
and self.code_model.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy"
self.code_model.options['credential'] and
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
# this ensures that the TokenCredentialSchema showing up in the list of code model's global parameters
# is sync. This way we only have to make a copy for an async_credential
Expand Down
10 changes: 4 additions & 6 deletions autorest/codegen/templates/config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.{{ constant_parameter.serialized_name }} = {{ constant_parameter.constant_declaration }}
{% endfor %}
{% endif %}
{% if code_model.options['credential_scopes'] is not none %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.options['credential_scopes'] }})
{% if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.credential_schema_policy.credential_scopes }})
{% endif %}
kwargs.setdefault('sdk_moniker', '{{ sdk_moniker }}/{}'.format(VERSION))
self._configure(**kwargs)
Expand All @@ -73,12 +73,10 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.authentication_policy = kwargs.get('authentication_policy')
{% if code_model.options['credential'] %}
{# only adding this if credential_scopes is not passed during code generation #}
{% if code_model.options["credential_scopes"] is not none and code_model.options["credential_scopes"]|length == 0 %}
{% if code_model.credential_schema_policy.credential_scopes is defined and code_model.credential_schema_policy.credential_scopes|length == 0 %}
if not self.credential_scopes and not self.authentication_policy:
raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
{% endif %}
if self.credential and not self.authentication_policy:
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
{% set credential_param_type = ("'" + code_model.options['credential_key_header_name'] + "', ") if code_model.options['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
self.authentication_policy = {{ code_model.credential_schema_policy.call(async_mode) }}
{% endif %}
7 changes: 3 additions & 4 deletions autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@
},
"config": {
"credential": {{ code_model.options['credential'] | tojson }},
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }},
"credential_scopes": {{ (code_model.credential_schema_policy.credential_scopes if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined else None)| tojson}},
"credential_call_sync": {{ (code_model.credential_schema_policy.call(async_mode=False) if code_model.options['credential'] else None) | tojson }},
"credential_call_async": {{ (code_model.credential_schema_policy.call(async_mode=True) if code_model.options['credential'] else None) | tojson }},
"sync_imports": {{ sync_config_imports | tojson }},
"async_imports": {{ async_config_imports | tojson }}
},
Expand Down
Loading

0 comments on commit ce1c5c9

Please sign in to comment.