Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[form recognizer] Remove unnecessary code #14257

Merged
2 commits merged into from
Oct 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
TYPE_CHECKING
)
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling

from ._response_handlers import (
prepare_receipt,
prepare_content_result,
prepare_form_result
)
from ._helpers import get_content_type, error_map
from ._helpers import get_content_type
from ._form_base_client import FormRecognizerClientBase
from ._polling import AnalyzePolling
if TYPE_CHECKING:
from azure.core.polling import LROPoller
from ._models import FormPage, RecognizedForm


Expand Down Expand Up @@ -104,14 +105,11 @@ def begin_recognize_receipts(self, receipt, **kwargs):
:caption: Recognize US sales receipt fields.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
include_field_elements = kwargs.pop("include_field_elements", False)
if content_type == "application/json":
raise TypeError("Call begin_recognize_receipts_from_url() to analyze a receipt from a URL.")
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if content_type is None:
content_type = get_content_type(receipt)

Expand All @@ -123,9 +121,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
content_type=content_type,
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -161,20 +157,15 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
:caption: Recognize US sales receipt fields from a URL.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_field_elements = kwargs.pop("include_field_elements", False)
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if self.api_version == "2.1-preview.1" and locale:
kwargs.update({"locale": locale})
return self._client.begin_analyze_receipt_async( # type: ignore
file_stream={"source": receipt_url},
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -212,9 +203,6 @@ def begin_recognize_content(self, form, **kwargs):
:dedent: 8
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_content_from_url() to analyze a document from a URL.")
Expand All @@ -226,9 +214,7 @@ def begin_recognize_content(self, form, **kwargs):
file_stream=form,
content_type=content_type,
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand All @@ -249,15 +235,10 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

return self._client.begin_analyze_layout_async( # type: ignore
file_stream={"source": form_url},
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -299,9 +280,8 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_custom_forms_from_url() to analyze a document from a URL.")
Expand All @@ -314,16 +294,13 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream=form,
model_id=model_id,
include_text_details=include_field_elements,
content_type=content_type,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down Expand Up @@ -351,24 +328,20 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

include_field_elements = kwargs.pop("include_field_elements", False)

def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argument
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream={"source": form_url},
model_id=model_id,
include_text_details=include_field_elements,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
CopyRequest,
CopyAuthorizationResult
)
from ._helpers import (
error_map,
TransportWrapper
)
from ._helpers import TransportWrapper

from ._models import (
CustomFormModelInfo,
AccountProperties,
Expand Down Expand Up @@ -152,7 +150,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
)
),
cls=lambda pipeline_response, _, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponseType

Expand All @@ -176,7 +173,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
cls=deserialization_callback,
continuation_token=continuation_token,
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[TrainingPolling()], **kwargs),
error_map=error_map,
**kwargs
)

Expand Down Expand Up @@ -204,11 +200,7 @@ def delete_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

self._client.delete_custom_model(
model_id=model_id,
error_map=error_map,
**kwargs
)
self._client.delete_custom_model(model_id=model_id, **kwargs)

@distributed_trace
def list_custom_models(self, **kwargs):
Expand All @@ -231,7 +223,6 @@ def list_custom_models(self, **kwargs):
"""
return self._client.list_custom_models( # type: ignore
cls=kwargs.pop("cls", lambda objs: [CustomFormModelInfo._from_generated(x) for x in objs]),
error_map=error_map,
**kwargs
)

Expand All @@ -254,7 +245,7 @@ def get_account_properties(self, **kwargs):
:dedent: 8
:caption: Get properties for the form recognizer account.
"""
response = self._client.get_custom_models(error_map=error_map, **kwargs)
response = self._client.get_custom_models(**kwargs)
return AccountProperties._from_generated(response.summary)

@distributed_trace
Expand All @@ -281,7 +272,7 @@ def get_custom_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

response = self._client.get_custom_model(model_id=model_id, include_keys=True, error_map=error_map, **kwargs)
response = self._client.get_custom_model(model_id=model_id, include_keys=True, **kwargs)
return CustomFormModel._from_generated(response)

@distributed_trace
Expand Down Expand Up @@ -314,7 +305,6 @@ def get_copy_authorization(self, resource_id, resource_region, **kwargs):

response = self._client.generate_model_copy_authorization( # type: ignore
cls=lambda pipeline_response, deserialized, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponse
target = json.loads(response.http_response.text())
Expand Down Expand Up @@ -359,9 +349,7 @@ def begin_copy_model(

if not model_id:
raise ValueError("model_id cannot be None or empty.")

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
copy_result = self._deserialize(self._generated_models.CopyOperationResult, raw_response)
Expand All @@ -380,8 +368,6 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
),
cls=kwargs.pop("cls", _copy_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,10 @@
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.exceptions import (
ResourceNotFoundError,
ResourceExistsError,
ClientAuthenticationError
)

POLLING_INTERVAL = 5
COGNITIVE_KEY_HEADER = "Ocp-Apim-Subscription-Key"


error_map = {
404: ResourceNotFoundError,
409: ResourceExistsError,
401: ClientAuthenticationError
}

def _get_deserialize():
from ._generated.v2_1_preview_1 import FormRecognizerClient
return FormRecognizerClient("dummy", "dummy")._deserialize # pylint: disable=protected-access
Expand Down
Loading