Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dt use patch #36977

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"sdk/synapse/azure-synapse/**",
"sdk/synapse/azure-synapse-artifacts/**",
"sdk/translation/azure-ai-translation-document/samples/assets/**",
"sdk/translation/azure-ai-translation-document/doc/**",
"sdk/translation/azure-ai-translation-document/tests/glossaries-valid.csv",
"sdk/storage/azure-storage-blob/**",
"sdk/storage/azure-storage-extensions/**",
Expand Down
3 changes: 2 additions & 1 deletion sdk/translation/azure-ai-translation-document/MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ recursive-include tests *.py
recursive-include samples *.py *.md
include azure/__init__.py
include azure/ai/__init__.py
include azure/ai/translation/__init__.py
include azure/ai/translation/__init__.py
recursive-include doc *.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from copy import deepcopy
from typing import Any, TYPE_CHECKING, Union
from typing_extensions import Self

from azure.core import PipelineClient
from azure.core.credentials import AzureKeyCredential
Expand Down Expand Up @@ -97,7 +98,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "DocumentTranslationClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

Expand Down Expand Up @@ -177,7 +178,7 @@ def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs:
def close(self) -> None:
self._client.close()

def __enter__(self) -> "SingleDocumentTranslationClient":
def __enter__(self) -> Self:
self._client.__enter__()
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def as_dict(self, *, exclude_readonly: bool = False) -> typing.Dict[str, typing.
"""

result = {}
readonly_props = []
if exclude_readonly:
readonly_props = [p._rest_name for p in self._attr_to_rest_field.values() if _is_readonly(p)]
for k, v in self.items():
Expand Down Expand Up @@ -883,5 +884,6 @@ def rest_discriminator(
*,
name: typing.Optional[str] = None,
type: typing.Optional[typing.Callable] = None, # pylint: disable=redefined-builtin
visibility: typing.Optional[typing.List[str]] = None,
) -> typing.Any:
return _RestField(name=name, type=type, is_discriminator=True)
return _RestField(name=name, type=type, is_discriminator=True, visibility=visibility)

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,41 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines,protected-access
"""Customize generated code here.

Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize
"""

import sys
from typing import Any, IO, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union, cast, overload
from azure.core.polling import NoPolling, PollingMethod
from azure.core.polling.base_polling import LROBasePolling
from typing import Any, IO, Callable, Dict, Iterator, List, Optional, Type, TypeVar, Union, cast, overload, Tuple
from azure.core.tracing.decorator import distributed_trace
from azure.core.utils import case_insensitive_dict
from azure.core.rest import HttpRequest, HttpResponse
from azure.core.rest import HttpRequest, HttpResponse, AsyncHttpResponse
from azure.core.pipeline import PipelineResponse
from azure.core.exceptions import (
ClientAuthenticationError,
HttpResponseError,
ResourceExistsError,
ResourceNotFoundError,
ResourceNotModifiedError,
ODataV4Format,
map_error,
)
from ..models import _models
from .. import _model_base
from azure.core.polling import LROPoller, NoPolling, PollingMethod
from azure.core.polling.base_polling import (
LROBasePolling,
OperationResourcePolling,
_is_empty,
_as_json,
BadResponse,
OperationFailed,
_raise_if_bad_http_status_and_method,
)
from .. import _model_base, models as _models
from ..models import (
TranslationStatus,
)
from .._model_base import _deserialize
from ._operations import (
DocumentTranslationClientOperationsMixin as GeneratedDocumentTranslationClientOperationsMixin,
Expand All @@ -34,7 +45,7 @@
ClsType,
build_single_document_translation_document_translate_request,
)
from .._patch import DocumentTranslationLROPoller

from .._vendor import prepare_multipart_form_data

if sys.version_info >= (3, 9):
Expand All @@ -46,10 +57,228 @@
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] # type: ignore


ResponseType = Union[HttpResponse, AsyncHttpResponse]
PipelineResponseType = PipelineResponse[HttpRequest, ResponseType]
PollingReturnType_co = TypeVar("PollingReturnType_co", covariant=True)

_FINISHED = frozenset(["succeeded", "cancelled", "cancelling", "failed"])
_FAILED = frozenset(["validationfailed"])


def convert_status(status, ll=False):
catalinaperalta marked this conversation as resolved.
Show resolved Hide resolved
if ll is False:
if status == "Cancelled":
return "Canceled"
if status == "Cancelling":
return "Canceling"
elif ll is True:
if status == "Canceled":
return "Cancelled"
if status == "Canceling":
return "Cancelling"
return status


class DocumentTranslationLROPoller(LROPoller[PollingReturnType_co]):
"""A custom poller implementation for Document Translation. Call `result()` on the poller to return
a pageable of :class:`~azure.ai.translation.document.DocumentStatus`."""

_polling_method: "DocumentTranslationLROPollingMethod"

@property
def id(self) -> str:
"""The ID for the translation operation

:return: The str ID for the translation operation.
:rtype: str
"""
if self._polling_method._current_body:
return self._polling_method._current_body.id
return self._polling_method._get_id_from_headers()

@property
def details(self) -> TranslationStatus:
"""The details for the translation operation

:return: The details for the translation operation.
:rtype: ~azure.ai.translation.document.TranslationStatus
"""
if self._polling_method._current_body:
return TranslationStatus(self._polling_method._current_body)
return TranslationStatus(id=self._polling_method._get_id_from_headers()) # type: ignore

@classmethod
def from_continuation_token( # pylint: disable=docstring-missing-return,docstring-missing-param,docstring-missing-rtype
cls, polling_method, continuation_token, **kwargs: Any
):
"""
:meta private:
"""
(
client,
initial_response,
deserialization_callback,
) = polling_method.from_continuation_token(continuation_token, **kwargs)

return cls(client, initial_response, deserialization_callback, polling_method)


class DocumentTranslationLROPollingMethod(LROBasePolling):
"""A custom polling method implementation for Document Translation."""

def __init__(self, *args, **kwargs):
self._cont_token_response = kwargs.pop("cont_token_response")
super().__init__(*args, **kwargs)

@property
def _current_body(self) -> TranslationStatus:
try:
return TranslationStatus(self._pipeline_response.http_response.json())
except Exception: # pylint: disable=broad-exception-caught
return TranslationStatus() # type: ignore[call-overload]

def _get_id_from_headers(self) -> str:
return (
self._initial_response.http_response.headers["Operation-Location"]
.split("/batches/")[1]
.split("?api-version")[0]
)

def finished(self) -> bool:
"""Is this polling finished?

:return: True/False for whether polling is complete.
:rtype: bool
"""
return self._finished(self.status())

@staticmethod
def _finished(status) -> bool:
if hasattr(status, "value"):
status = status.value
return str(status).lower() in _FINISHED

@staticmethod
def _failed(status) -> bool:
if hasattr(status, "value"):
status = status.value
return str(status).lower() in _FAILED

def get_continuation_token(self) -> str:
if self._current_body:
return self._current_body.id
return self._get_id_from_headers()

# pylint: disable=arguments-differ
def from_continuation_token(self, continuation_token: str, **kwargs: Any) -> Tuple: # type: ignore[override]
try:
client = kwargs["client"]
except KeyError as exc:
raise ValueError("Need kwarg 'client' to be recreated from continuation_token") from exc

try:
deserialization_callback = kwargs["deserialization_callback"]
except KeyError as exc:
raise ValueError("Need kwarg 'deserialization_callback' to be recreated from continuation_token") from exc

return client, self._cont_token_response, deserialization_callback

def _poll(self) -> None:
"""Poll status of operation so long as operation is incomplete and
we have an endpoint to query.

:raises: OperationFailed if operation status 'Failed' or 'Canceled'.
:raises: BadStatus if response status invalid.
:raises: BadResponse if response invalid.
"""

while not self.finished():
self.update_status()
while not self.finished():
self._delay()
self.update_status()

if self._failed(self.status()):
raise OperationFailed("Operation failed or canceled")

final_get_url = self._operation.get_final_get_url(self._pipeline_response)
if final_get_url:
self._pipeline_response = self.request_status(final_get_url)
_raise_if_bad_http_status_and_method(self._pipeline_response.http_response)


class TranslationPolling(OperationResourcePolling):
"""Implements a Location polling."""

def can_poll(self, pipeline_response: PipelineResponseType) -> bool:
"""Answer if this polling method could be used.

:param pipeline_response: The PipelineResponse type
:type pipeline_response: PipelineResponseType
:return: Whether polling should be performed.
:rtype: bool
"""
response = pipeline_response.http_response
can_poll = self._operation_location_header in response.headers
if can_poll:
return True

if not _is_empty(response):
body = _as_json(response)
status = body.get("status")
if status:
return True
return False

def _set_async_url_if_present(self, response: ResponseType) -> None:
location_header = response.headers.get(self._operation_location_header)
if location_header:
self._async_url = location_header
else:
self._async_url = response.request.url

def get_status(self, pipeline_response: PipelineResponseType) -> str:
"""Process the latest status update retrieved from a 'location' header.

:param azure.core.pipeline.PipelineResponse pipeline_response: latest REST call response.
:return: The current operation status
:rtype: str
:raises: BadResponse if response has no body and not status 202.
"""
response = pipeline_response.http_response
if not _is_empty(response):
body = _as_json(response)
status = body.get("status")
if status:
return self._map_nonstandard_statuses(status, body)
raise BadResponse("No status found in body")
raise BadResponse("The response from long running operation does not contain a body.")

def _map_nonstandard_statuses(self, status: str, body: Dict[str, Any]) -> str:
"""Map non-standard statuses.

:param str status: lro process status.
:param str body: pipeline response body.
:return: The current operation status.
:rtype: str
"""
if status == "ValidationFailed":
self.raise_error(body)
return status

def raise_error(self, body: Dict[str, Any]) -> None:
error = body["error"]
if body["error"].get("innerError", None):
error = body["error"]["innerError"]
http_response_error = HttpResponseError(message="({}): {}".format(error["code"], error["message"]))
http_response_error.error = ODataV4Format(error) # set error.code
raise http_response_error


class DocumentTranslationClientOperationsMixin(GeneratedDocumentTranslationClientOperationsMixin):

@distributed_trace
def begin_start_translation( # type: ignore[override]
def _begin_start_translation( # type: ignore[override]
catalinaperalta marked this conversation as resolved.
Show resolved Hide resolved
self, body: Union[_models.StartTranslationDetails, JSON, IO[bytes]], **kwargs: Any
) -> DocumentTranslationLROPoller[_models.TranslationStatus]:
_headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
Expand All @@ -61,7 +290,7 @@ def begin_start_translation( # type: ignore[override]
lro_delay = kwargs.pop("polling_interval", self._config.polling_interval)
cont_token: Optional[str] = kwargs.pop("continuation_token", None)
if cont_token is None:
raw_result = self._start_translation_initial( # type: ignore[func-returns-value]
raw_result = self.__begin_start_translation_initial( # type: ignore[func-returns-value]
body=body, content_type=content_type, cls=lambda x, y, z: x, headers=_headers, params=_params, **kwargs
)
kwargs.pop("error_map", None)
Expand Down
Loading