Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add credential schema policy classes to improve code flow #974

Merged
merged 5 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 33 additions & 51 deletions autorest/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,11 @@
from .models.operation_group import OperationGroup
from .models.parameter import Parameter
from .models.parameter_list import ParameterList
from .models.credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .models.credential_schema_policy import get_credential_schema_policy, CredentialSchemaPolicy
from .serializers import JinjaSerializer


def _get_credential_default_policy_type_has_async_version(credential_default_policy_type: str) -> bool:
mapping = {
"BearerTokenCredentialPolicy": True,
"AzureKeyCredentialPolicy": False
}
return mapping[credential_default_policy_type]

_LOGGER = logging.getLogger(__name__)
class CodeGenerator(Plugin):
@staticmethod
Expand Down Expand Up @@ -62,6 +57,8 @@ def _build_exceptions_set(yaml_data: List[Dict[str, Any]]) -> Set[int]:
def _create_code_model(self, yaml_data: Dict[str, Any], options: Dict[str, Union[str, bool]]) -> CodeModel:
# Create a code model
code_model = CodeModel(options)
if code_model.options['credential']:
self._handle_default_authentication_policy(code_model)
code_model.module_name = yaml_data["info"]["python_title"]
code_model.class_name = yaml_data["info"]["pascal_case_title"]
code_model.description = (
Expand Down Expand Up @@ -138,70 +135,68 @@ def _get_credential_scopes(self, credential):
)
return credential_scopes

def _get_credential_param(self, azure_arm, credential, credential_default_policy_type):
credential_scopes = self._get_credential_scopes(credential)
def _initialize_credential_schema_policy(
self, code_model: CodeModel, credential_schema_policy: CredentialSchemaPolicy
):
credential_scopes = self._get_credential_scopes(code_model.options['credential'])
credential_key_header_name = self._autorestapi.get_value('credential-key-header-name')
azure_arm = code_model.options['azure_arm']
credential = code_model.options['credential']

if credential_default_policy_type == "BearerTokenCredentialPolicy":
if hasattr(credential_schema_policy, "credential_scopes"):
if not credential_scopes:
if azure_arm:
credential_scopes = ["https://management.azure.com/.default"]
elif credential:
# If add-credential is specified, we still want to add a credential_scopes variable.
# Will make it an empty list so we can differentiate between this case and None
_LOGGER.warning(
"You have default credential policy BearerTokenCredentialPolicy"
"You have default credential policy %s "
"but not the --credential-scopes flag set while generating non-management plane code. "
"This is not recommend because it forces the customer to pass credential scopes "
"through kwargs if they want to authenticate."
"through kwargs if they want to authenticate.",
credential_schema_policy.name
)
credential_scopes = []

if credential_key_header_name:
raise ValueError(
"You have passed in a credential key header name with default credential policy type "
"BearerTokenCredentialPolicy. This is not allowed, since credential key header name is tied with "
"AzureKeyCredentialPolicy. Instead, with this policy it is recommend you pass in "
"--credential-scopes."
f"{credential_schema_policy.name}. This is not allowed, since credential key header "
"name is tied with AzureKeyCredentialPolicy. Instead, with this policy it is recommend you "
"pass in --credential-scopes."
)
credential_schema_policy.initialize(
credential=TokenCredentialSchema(async_mode=False),
credential_scopes=credential_scopes,
)
else:
# currently the only other credential policy is AzureKeyCredentialPolicy
if credential_scopes:
raise ValueError(
"You have passed in credential scopes with default credential policy type "
"AzureKeyCredentialPolicy. This is not allowed, since credential scopes is tied with "
"BearerTokenCredentialPolicy. Instead, with this policy you must pass in "
f"{code_model.default_authentication_policy.name}. Instead, with this policy you must pass in "
"--credential-key-header-name."
)
if not credential_key_header_name:
credential_key_header_name = "api-key"
_LOGGER.info(
"Defaulting the AzureKeyCredentialPolicy header's name to 'api-key'"
)
return credential_scopes, credential_key_header_name

def _handle_default_authentication_policy(self, azure_arm, credential):

passed_in_credential_default_policy_type = (
self._autorestapi.get_value("credential-default-policy-type") or "BearerTokenCredentialPolicy"
)

# right now, we only allow BearerTokenCredentialPolicy and AzureKeyCredentialPolicy
allowed_policies = ["BearerTokenCredentialPolicy", "AzureKeyCredentialPolicy"]
try:
credential_default_policy_type = [
cp for cp in allowed_policies if cp.lower() == passed_in_credential_default_policy_type.lower()
][0]
except IndexError:
raise ValueError(
"The credential you pass in with --credential-default-policy-type must be either "
"BearerTokenCredentialPolicy or AzureKeyCredentialPolicy"
credential_schema_policy.initialize(
credential=AzureKeyCredentialSchema(),
credential_key_header_name=credential_key_header_name,
)

credential_scopes, credential_key_header_name = self._get_credential_param(
azure_arm, credential, credential_default_policy_type
def _handle_default_authentication_policy(self, code_model: CodeModel):
credential_schema_policy_name = (
self._autorestapi.get_value("credential-default-policy-type") or
code_model.default_authentication_policy.name
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
)

return credential_default_policy_type, credential_scopes, credential_key_header_name
credential_schema_policy = get_credential_schema_policy(credential_schema_policy_name)
self._initialize_credential_schema_policy(code_model, credential_schema_policy)
code_model.credential_schema_policy = credential_schema_policy


def _build_code_model_options(self) -> Dict[str, Any]:
Expand All @@ -213,13 +208,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
self._autorestapi.get_boolean_value("add-credential", False)
)

credential_default_policy_type, credential_scopes, credential_key_header_name = (
self._handle_default_authentication_policy(
azure_arm, credential
)
)


license_header = self._autorestapi.get_value("header-text")
if license_header:
license_header = license_header.replace("\n", "\n# ")
Expand All @@ -231,8 +219,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
options: Dict[str, Any] = {
"azure_arm": azure_arm,
"credential": credential,
"credential_scopes": credential_scopes,
"credential_key_header_name": credential_key_header_name,
"head_as_boolean": self._autorestapi.get_boolean_value("head-as-boolean", False),
"license_header": license_header,
"keep_version_file": self._autorestapi.get_boolean_value("keep-version-file", False),
Expand All @@ -244,10 +230,6 @@ def _build_code_model_options(self) -> Dict[str, Any]:
"client_side_validation": self._autorestapi.get_boolean_value("client-side-validation", False),
"tracing": self._autorestapi.get_boolean_value("trace", False),
"multiapi": self._autorestapi.get_boolean_value("multiapi", False),
"credential_default_policy_type": credential_default_policy_type,
"credential_default_policy_type_has_async_version": (
_get_credential_default_policy_type_has_async_version(credential_default_policy_type)
)
}

if options["basic_setup_py"] and not options["package_version"]:
Expand Down
28 changes: 20 additions & 8 deletions autorest/codegen/models/code_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
# --------------------------------------------------------------------------
from itertools import chain
import logging
from typing import cast, List, Dict, Optional, Any, Set, Union
from typing import cast, List, Dict, Optional, Any, Set

from .base_schema import BaseSchema
from .credential_schema import AzureKeyCredentialSchema, TokenCredentialSchema
from .credential_schema_policy import (
BearerTokenCredentialPolicy, CredentialSchemaPolicy
)
from .enum_schema import EnumSchema
from .object_schema import ObjectSchema
from .operation_group import OperationGroup
Expand Down Expand Up @@ -68,6 +70,7 @@ def __init__(self, options: Dict[str, Any]) -> None:
self.custom_base_url: Optional[str] = None
self.base_url: Optional[str] = None
self.service_client: Client = Client()
self._credential_schema_policy: Optional[CredentialSchemaPolicy] = None

def lookup_schema(self, schema_id: int) -> BaseSchema:
"""Looks to see if the schema has already been created.
Expand Down Expand Up @@ -124,14 +127,9 @@ def add_credential_global_parameter(self) -> None:
:return: None
:rtype: None
"""
credential_schema: Union[AzureKeyCredentialSchema, TokenCredentialSchema]
if self.options["credential_default_policy_type"] == "BearerTokenCredentialPolicy":
credential_schema = TokenCredentialSchema(async_mode=False)
else:
credential_schema = AzureKeyCredentialSchema()
credential_parameter = Parameter(
yaml_data={},
schema=credential_schema,
schema=self.credential_schema_policy.credential,
serialized_name="credential",
rest_api_name="credential",
implementation="Client",
Expand Down Expand Up @@ -203,6 +201,20 @@ def _lookup_operation(yaml_id: int) -> Operation:
operation for operation in operation_group.operations if operation not in next_operations
]

@property
def default_authentication_policy(self) -> CredentialSchemaPolicy:
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
return BearerTokenCredentialPolicy()

@property
def credential_schema_policy(self) -> CredentialSchemaPolicy:
if not self._credential_schema_policy:
raise ValueError("You want to find the Credential Schema Policy, but have not given a value")
return self._credential_schema_policy

@credential_schema_policy.setter
def credential_schema_policy(self, val: CredentialSchemaPolicy) -> None:
self._credential_schema_policy = val

@staticmethod
def _add_properties_from_inheritance_helper(schema, properties) -> List[Property]:
if not schema.base_models:
Expand Down
73 changes: 73 additions & 0 deletions autorest/codegen/models/credential_schema_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from abc import abstractmethod
from typing import Optional
from .credential_schema import CredentialSchema

class CredentialSchemaPolicy:
name: Optional[str] = None
lmazuel marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self) -> None:
self._credential = None

@property
def credential(self) -> CredentialSchema:
if not self._credential:
raise ValueError(
"You have not initialized this policy with its credential yet"
)
return self._credential

def initialize(self, credential, **kwargs): # pylint: disable=unused-argument
"""Initialize your schema policy"""
self._credential = credential

@abstractmethod
def call(self, async_mode: bool) -> str:
...


class BearerTokenCredentialPolicy(CredentialSchemaPolicy):
name = "BearerTokenCredentialPolicy"

def __init__(self) -> None:
super().__init__()
self.credential_scopes = None

def initialize(self, credential, **kwargs):
super().initialize(credential)
self.credential_scopes = kwargs.pop("credential_scopes")

def call(self, async_mode: bool) -> str:
policy_name = f"Async{self.name}" if async_mode else self.name
return f"policies.{policy_name}(self.credential, *self.credential_scopes, **kwargs)"


class AzureKeyCredentialPolicy(CredentialSchemaPolicy):
name = "AzureKeyCredentialPolicy"

def __init__(self) -> None:
super().__init__()
self.credential_key_header_name = None

def initialize(self, credential, **kwargs):
lmazuel marked this conversation as resolved.
Show resolved Hide resolved
super().initialize(credential)
self.credential_key_header_name = kwargs.pop("credential_key_header_name")

def call(self, async_mode: bool) -> str:
return f'policies.AzureKeyCredentialPolicy(self.credential, "{self.credential_key_header_name}", **kwargs)'

def get_credential_schema_policy(name):
policies = [BearerTokenCredentialPolicy(), AzureKeyCredentialPolicy()]
try:
return next(p for p in policies if p.name.lower() == name.lower())
except StopIteration:
raise ValueError(
"The credential policy you pass in with --credential-default-policy-type must be either "
"{}".format(
" or ".join([p.name for p in policies])
)
)
4 changes: 2 additions & 2 deletions autorest/codegen/serializers/general_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def _service_client_imports() -> FileImport:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand All @@ -75,7 +75,7 @@ def serialize_config_file(self) -> str:

if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
self._correct_credential_parameter()

Expand Down
2 changes: 1 addition & 1 deletion autorest/codegen/serializers/metadata_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _is_paging(operation):
async_global_parameters = self.code_model.global_parameters
if (
self.code_model.options['credential'] and
self.code_model.options['credential_default_policy_type'] == "BearerTokenCredentialPolicy"
isinstance(self.code_model.credential_schema_policy.credential, TokenCredentialSchema)
):
# this ensures that the TokenCredentialSchema showing up in the list of code model's global parameters
# is sync. This way we only have to make a copy for an async_credential
Expand Down
10 changes: 4 additions & 6 deletions autorest/codegen/templates/config.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.{{ constant_parameter.serialized_name }} = {{ constant_parameter.constant_declaration }}
{% endfor %}
{% endif %}
{% if code_model.options['credential_scopes'] is not none %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.options['credential_scopes'] }})
{% if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined %}
self.credential_scopes = kwargs.pop('credential_scopes', {{ code_model.credential_schema_policy.credential_scopes }})
{% endif %}
kwargs.setdefault('sdk_moniker', '{{ sdk_moniker }}/{}'.format(VERSION))
self._configure(**kwargs)
Expand All @@ -74,12 +74,10 @@ class {{ code_model.class_name }}Configuration(Configuration):
self.authentication_policy = kwargs.get('authentication_policy')
{% if code_model.options['credential'] %}
{# only adding this if credential_scopes is not passed during code generation #}
{% if code_model.options["credential_scopes"] is not none and code_model.options["credential_scopes"]|length == 0 %}
{% if code_model.credential_schema_policy.credential_scopes is defined and code_model.credential_schema_policy.credential_scopes|length == 0 %}
if not self.credential_scopes and not self.authentication_policy:
raise ValueError("You must provide either credential_scopes or authentication_policy as kwargs")
{% endif %}
if self.credential and not self.authentication_policy:
{% set credential_default_policy_type = ("Async" if (async_mode and code_model.options['credential_default_policy_type_has_async_version']) else "") + code_model.options['credential_default_policy_type'] %}
{% set credential_param_type = ("'" + code_model.options['credential_key_header_name'] + "', ") if code_model.options['credential_key_header_name'] else ("*self.credential_scopes, " if "BearerTokenCredentialPolicy" in credential_default_policy_type else "") %}
self.authentication_policy = policies.{{ credential_default_policy_type }}(self.credential, {{ credential_param_type if credential_param_type }}**kwargs)
self.authentication_policy = {{ code_model.credential_schema_policy.call(async_mode) }}
{% endif %}
7 changes: 3 additions & 4 deletions autorest/codegen/templates/metadata.json.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@
},
"config": {
"credential": {{ code_model.options['credential'] | tojson }},
"credential_scopes": {{ code_model.options['credential_scopes'] | tojson }},
"credential_default_policy_type": {{ code_model.options['credential_default_policy_type'] | tojson }},
"credential_default_policy_type_has_async_version": {{ code_model.options['credential_default_policy_type_has_async_version'] | tojson }},
"credential_key_header_name": {{ code_model.options['credential_key_header_name'] | tojson }},
"credential_scopes": {{ (code_model.credential_schema_policy.credential_scopes if code_model.options['credential'] and code_model.credential_schema_policy.credential_scopes is defined else None)| tojson}},
"credential_call_sync": {{ (code_model.credential_schema_policy.call(async_mode=False) if code_model.options['credential'] else None) | tojson }},
"credential_call_async": {{ (code_model.credential_schema_policy.call(async_mode=True) if code_model.options['credential'] else None) | tojson }},
"sync_imports": {{ sync_config_imports | tojson }},
"async_imports": {{ async_config_imports | tojson }}
},
Expand Down
10 changes: 5 additions & 5 deletions autorest/multiapi/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class Config:
def __init__(self, default_version_metadata: Dict[str, Any]):
self.credential = default_version_metadata["config"]["credential"]
self.credential_scopes = default_version_metadata["config"]["credential_scopes"]
self.credential_default_policy_type = default_version_metadata["config"]["credential_default_policy_type"]
self.credential_default_policy_type_has_async_version = (
default_version_metadata["config"]["credential_default_policy_type_has_async_version"]
)
self.credential_key_header_name = default_version_metadata["config"]["credential_key_header_name"]
self.default_version_metadata = default_version_metadata

def imports(self, async_mode: bool) -> FileImport:
imports_to_load = "async_imports" if async_mode else "sync_imports"
return FileImport(json.loads(self.default_version_metadata['config'][imports_to_load]))

def credential_call(self, async_mode: bool) -> str:
if async_mode:
return self.default_version_metadata["config"]["credential_call_async"]
return self.default_version_metadata["config"]["credential_call_sync"]
Loading