From 30b196e4061d4b943008aaec1e32547bae2db883 Mon Sep 17 00:00:00 2001 From: iscai-msft <43154838+iscai-msft@users.noreply.github.com> Date: Wed, 22 Sep 2021 12:32:29 -0400 Subject: [PATCH] [rest] add backcompat mixin to rest requests (#20599) --- .../azure-core/azure/core/_pipeline_client.py | 15 +- .../azure/core/_pipeline_client_async.py | 5 +- .../azure/core/pipeline/transport/_base.py | 121 +----- .../azure-core/azure/core/rest/_helpers.py | 279 ++++++++---- sdk/core/azure-core/azure/core/rest/_rest.py | 25 +- .../azure-core/azure/core/rest/_rest_py3.py | 28 +- .../utils/_pipeline_transport_rest_shared.py | 204 +++++++++ .../async_tests/test_authentication_async.py | 60 +-- .../async_tests/test_base_polling_async.py | 104 +++-- .../async_tests/test_basic_transport_async.py | 186 ++++---- .../test_http_logging_policy_async.py | 25 +- .../tests/async_tests/test_pipeline_async.py | 37 +- .../tests/async_tests/test_request_asyncio.py | 14 +- .../tests/async_tests/test_request_trio.py | 13 +- .../async_tests/test_retry_policy_async.py | 55 ++- .../test_stream_generator_async.py | 9 +- .../tests/async_tests/test_streaming_async.py | 41 +- .../async_tests/test_testserver_async.py | 8 +- .../test_tracing_decorator_async.py | 37 +- .../async_tests/test_universal_http_async.py | 28 +- .../azure-core/tests/test_authentication.py | 96 +++-- .../azure-core/tests/test_base_polling.py | 106 +++-- .../azure-core/tests/test_basic_transport.py | 230 +++++----- .../tests/test_custom_hook_policy.py | 31 +- sdk/core/azure-core/tests/test_error_map.py | 22 +- .../tests/test_http_logging_policy.py | 24 +- sdk/core/azure-core/tests/test_pipeline.py | 105 +++-- .../tests/test_request_id_policy.py | 15 +- .../tests/test_requests_universal.py | 10 +- .../azure-core/tests/test_rest_backcompat.py | 403 ++++++++++++++++++ .../tests/test_rest_http_request.py | 117 ++++- .../azure-core/tests/test_retry_policy.py | 57 +-- .../azure-core/tests/test_stream_generator.py | 11 +- sdk/core/azure-core/tests/test_streaming.py | 42 +- sdk/core/azure-core/tests/test_testserver.py | 9 +- .../tests/test_tracing_decorator.py | 43 +- .../azure-core/tests/test_tracing_policy.py | 29 +- .../tests/test_universal_pipeline.py | 27 +- .../tests/test_user_agent_policy.py | 13 +- sdk/core/azure-core/tests/utils.py | 43 ++ 40 files changed, 1873 insertions(+), 854 deletions(-) create mode 100644 sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py create mode 100644 sdk/core/azure-core/tests/test_rest_backcompat.py create mode 100644 sdk/core/azure-core/tests/utils.py diff --git a/sdk/core/azure-core/azure/core/_pipeline_client.py b/sdk/core/azure-core/azure/core/_pipeline_client.py index 44e5995c5f84..f29f391c4ed2 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client.py @@ -65,17 +65,6 @@ _LOGGER = logging.getLogger(__name__) -def _prepare_request(request): - # returns the request ready to run through pipelines - # and a bool telling whether we ended up converting it - rest_request = False - try: - request_to_run = request._to_pipeline_transport_request() # pylint: disable=protected-access - rest_request = True - except AttributeError: - request_to_run = request - return rest_request, request_to_run - class PipelineClient(PipelineClientBase): """Service client core methods. @@ -204,9 +193,9 @@ def send_request(self, request, **kwargs): :return: The response of your network call. Does not do error handling on your response. :rtype: ~azure.core.rest.HttpResponse # """ - rest_request, request_to_run = _prepare_request(request) + rest_request = hasattr(request, "content") return_pipeline_response = kwargs.pop("_return_pipeline_response", False) - pipeline_response = self._pipeline.run(request_to_run, **kwargs) # pylint: disable=protected-access + pipeline_response = self._pipeline.run(request, **kwargs) # pylint: disable=protected-access response = pipeline_response.http_response if rest_request: response = _to_rest_response(response) diff --git a/sdk/core/azure-core/azure/core/_pipeline_client_async.py b/sdk/core/azure-core/azure/core/_pipeline_client_async.py index 357b3d9b917d..ad8b758d07af 100644 --- a/sdk/core/azure-core/azure/core/_pipeline_client_async.py +++ b/sdk/core/azure-core/azure/core/_pipeline_client_async.py @@ -37,7 +37,6 @@ RequestIdPolicy, AsyncRetryPolicy, ) -from ._pipeline_client import _prepare_request from .pipeline._tools_async import to_rest_response as _to_rest_response try: @@ -175,10 +174,10 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use return AsyncPipeline(transport, policies) async def _make_pipeline_call(self, request, **kwargs): - rest_request, request_to_run = _prepare_request(request) + rest_request = hasattr(request, "content") return_pipeline_response = kwargs.pop("_return_pipeline_response", False) pipeline_response = await self._pipeline.run( - request_to_run, **kwargs # pylint: disable=protected-access + request, **kwargs # pylint: disable=protected-access ) response = pipeline_response.http_response if rest_request: diff --git a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py index 761c61caa9ae..7c420a392059 100644 --- a/sdk/core/azure-core/azure/core/pipeline/transport/_base.py +++ b/sdk/core/azure-core/azure/core/pipeline/transport/_base.py @@ -34,7 +34,6 @@ from io import BytesIO import json import logging -import os import time import copy @@ -50,7 +49,6 @@ TYPE_CHECKING, Generic, TypeVar, - cast, IO, List, Union, @@ -63,7 +61,7 @@ Type ) -from six.moves.http_client import HTTPConnection, HTTPResponse as _HTTPResponse +from six.moves.http_client import HTTPResponse as _HTTPResponse from azure.core.exceptions import HttpResponseError from azure.core.pipeline import ( @@ -75,6 +73,12 @@ ) from .._tools import await_result as _await_result from ...utils._utils import _case_insensitive_dict +from ...utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _prepare_multipart_body_helper, + _serialize_request, + _format_data_helper, +) if TYPE_CHECKING: @@ -127,36 +131,6 @@ def _urljoin(base_url, stub_url): parsed = parsed._replace(path=parsed.path.rstrip("/") + "/" + stub_url) return parsed.geturl() - -class _HTTPSerializer(HTTPConnection, object): - """Hacking the stdlib HTTPConnection to serialize HTTP request as strings. - """ - - def __init__(self, *args, **kwargs): - self.buffer = b"" - kwargs.setdefault("host", "fakehost") - super(_HTTPSerializer, self).__init__(*args, **kwargs) - - def putheader(self, header, *values): - if header in ["Host", "Accept-Encoding"]: - return - super(_HTTPSerializer, self).putheader(header, *values) - - def send(self, data): - self.buffer += data - - -def _serialize_request(http_request): - serializer = _HTTPSerializer() - serializer.request( - method=http_request.method, - url=http_request.url, - body=http_request.body, - headers=http_request.headers, - ) - return serializer.buffer - - class HttpTransport( AbstractContextManager, ABC, Generic[HTTPRequestType, HTTPResponseType] ): # type: ignore @@ -253,16 +227,7 @@ def _format_data(data): :param data: The request field data. :type data: str or file-like object. """ - if hasattr(data, "read"): - data = cast(IO, data) - data_name = None - try: - if data.name[0] != "<" and data.name[-1] != ">": - data_name = os.path.basename(data.name) - except (AttributeError, TypeError): - pass - return (data_name, data, "application/octet-stream") - return (None, cast(str, data)) + return _format_data_helper(data) def format_parameters(self, params): # type: (Dict[str, str]) -> None @@ -272,26 +237,7 @@ def format_parameters(self, params): :param dict params: A dictionary of parameters. """ - query = urlparse(self.url).query - if query: - self.url = self.url.partition("?")[0] - existing_params = { - p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] - } - params.update(existing_params) - query_params = [] - for k, v in params.items(): - if isinstance(v, list): - for w in v: - if w is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, w)) - else: - if v is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, v)) - query = "?" + "&".join(query_params) - self.url = self.url + query + return _format_parameters_helper(self, params) def set_streamed_data_body(self, data): """Set a streamable data body. @@ -416,54 +362,7 @@ def prepare_multipart_body(self, content_index=0): :returns: The updated index after all parts in this request have been added. :rtype: int """ - if not self.multipart_mixed_info: - return 0 - - requests = self.multipart_mixed_info[0] # type: List[HttpRequest] - boundary = self.multipart_mixed_info[2] # type: Optional[str] - - # Update the main request with the body - main_message = Message() - main_message.add_header("Content-Type", "multipart/mixed") - if boundary: - main_message.set_boundary(boundary) - - for req in requests: - part_message = Message() - if req.multipart_mixed_info: - content_index = req.prepare_multipart_body(content_index=content_index) - part_message.add_header("Content-Type", req.headers['Content-Type']) - payload = req.serialize() - # We need to remove the ~HTTP/1.1 prefix along with the added content-length - payload = payload[payload.index(b'--'):] - else: - part_message.add_header("Content-Type", "application/http") - part_message.add_header("Content-Transfer-Encoding", "binary") - part_message.add_header("Content-ID", str(content_index)) - payload = req.serialize() - content_index += 1 - part_message.set_payload(payload) - main_message.attach(part_message) - - try: - from email.policy import HTTP - - full_message = main_message.as_bytes(policy=HTTP) - eol = b"\r\n" - except ImportError: # Python 2.7 - # Right now we decide to not support Python 2.7 on serialization, since - # it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it) - raise NotImplementedError( - "Multipart request are not supported on Python 2.7" - ) - # full_message = main_message.as_string() - # eol = b'\n' - _, _, body = full_message.split(eol, 2) - self.set_bytes_body(body) - self.headers["Content-Type"] = ( - "multipart/mixed; boundary=" + main_message.get_boundary() - ) - return content_index + return _prepare_multipart_body_helper(self, content_index) def serialize(self): # type: () -> bytes diff --git a/sdk/core/azure-core/azure/core/rest/_helpers.py b/sdk/core/azure-core/azure/core/rest/_helpers.py index 072d9f70992d..3aba4c5a8fda 100644 --- a/sdk/core/azure-core/azure/core/rest/_helpers.py +++ b/sdk/core/azure-core/azure/core/rest/_helpers.py @@ -23,12 +23,13 @@ # IN THE SOFTWARE. # # -------------------------------------------------------------------------- -import os import codecs import cgi -from enum import Enum from json import dumps -import collections +try: + import collections.abc as collections +except ImportError: + import collections # type: ignore from typing import ( Optional, Union, @@ -40,17 +41,23 @@ Any, Dict, Iterable, - Iterator, - cast, - Callable, ) import xml.etree.ElementTree as ET import six try: + binary_type = str from urlparse import urlparse # type: ignore except ImportError: + binary_type = bytes # type: ignore from urllib.parse import urlparse from azure.core.serialization import AzureJSONEncoder +from ..utils._pipeline_transport_rest_shared import ( + _format_parameters_helper, + _pad_attr_name, + _prepare_multipart_body_helper, + _serialize_request, + _format_data_helper, +) ################################### TYPES SECTION ######################### @@ -73,19 +80,6 @@ ContentTypeBase = Union[str, bytes, Iterable[bytes]] -class HttpVerbs(str, Enum): - GET = "GET" - PUT = "PUT" - POST = "POST" - HEAD = "HEAD" - PATCH = "PATCH" - DELETE = "DELETE" - MERGE = "MERGE" - -########################### ERRORS SECTION ################################# - - - ########################### HELPER SECTION ################################# def _verify_data_object(name, value): @@ -102,25 +96,6 @@ def _verify_data_object(name, value): ) ) -def _format_data(data): - # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] - """Format field data according to whether it is a stream or - a string for a form-data request. - - :param data: The request field data. - :type data: str or file-like object. - """ - if hasattr(data, "read"): - data = cast(IO, data) - data_name = None - try: - if data.name[0] != "<" and data.name[-1] != ">": - data_name = os.path.basename(data.name) - except (AttributeError, TypeError): - pass - return (data_name, data, "application/octet-stream") - return (None, cast(str, data)) - def set_urlencoded_body(data, has_files): body = {} default_headers = {} @@ -141,7 +116,7 @@ def set_urlencoded_body(data, has_files): def set_multipart_body(files): formatted_files = { - f: _format_data(d) for f, d in files.items() if d is not None + f: _format_data_helper(d) for f, d in files.items() if d is not None } return {}, formatted_files @@ -189,35 +164,6 @@ def set_json_body(json): "Content-Length": str(len(body)) }, body -def format_parameters(url, params): - """Format parameters into a valid query string. - It's assumed all parameters have already been quoted as - valid URL strings. - - :param dict params: A dictionary of parameters. - """ - query = urlparse(url).query - if query: - url = url.partition("?")[0] - existing_params = { - p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] - } - params.update(existing_params) - query_params = [] - for k, v in params.items(): - if isinstance(v, list): - for w in v: - if w is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, w)) - else: - if v is None: - raise ValueError("Query parameter {} cannot be None".format(k)) - query_params.append("{}={}".format(k, v)) - query = "?" + "&".join(query_params) - url += query - return url - def lookup_encoding(encoding): # type: (str) -> bool # including check for whether encoding is known taken from httpx @@ -227,25 +173,6 @@ def lookup_encoding(encoding): except LookupError: return False -def to_pipeline_transport_request_helper(rest_request): - from ..pipeline.transport import HttpRequest as PipelineTransportHttpRequest - return PipelineTransportHttpRequest( - method=rest_request.method, - url=rest_request.url, - headers=rest_request.headers, - files=rest_request._files, # pylint: disable=protected-access - data=rest_request._data # pylint: disable=protected-access - ) - -def from_pipeline_transport_request_helper(request_class, pipeline_transport_request): - return request_class( - method=pipeline_transport_request.method, - url=pipeline_transport_request.url, - headers=pipeline_transport_request.headers, - files=pipeline_transport_request.files, - data=pipeline_transport_request.data - ) - def get_charset_encoding(response): # type: (...) -> Optional[str] content_type = response.headers.get("Content-Type") @@ -267,3 +194,181 @@ def decode_to_text(encoding, content): if encoding: return content.decode(encoding) return codecs.getincrementaldecoder("utf-8-sig")(errors="replace").decode(content) + +class HttpRequestBackcompatMixin(object): + + def __getattr__(self, attr): + backcompat_attrs = [ + "files", + "data", + "multipart_mixed_info", + "query", + "body", + "format_parameters", + "set_streamed_data_body", + "set_text_body", + "set_xml_body", + "set_json_body", + "set_formdata_body", + "set_bytes_body", + "set_multipart_mixed", + "prepare_multipart_body", + "serialize", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + return self.__getattribute__(attr) + + def __setattr__(self, attr, value): + backcompat_attrs = [ + "multipart_mixed_info", + "files", + "data", + "body", + ] + attr = _pad_attr_name(attr, backcompat_attrs) + super(HttpRequestBackcompatMixin, self).__setattr__(attr, value) + + @property + def _multipart_mixed_info(self): + """DEPRECATED: Information used to make multipart mixed requests. + This is deprecated and will be removed in a later release. + """ + try: + return self._multipart_mixed_info_val + except AttributeError: + return None + + @_multipart_mixed_info.setter + def _multipart_mixed_info(self, val): + """DEPRECATED: Set information to make multipart mixed requests. + This is deprecated and will be removed in a later release. + """ + self._multipart_mixed_info_val = val + + @property + def _query(self): + """DEPRECATED: Query parameters passed in by user + This is deprecated and will be removed in a later release. + """ + query = urlparse(self.url).query + if query: + return {p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")]} + return {} + + @property + def _body(self): + """DEPRECATED: Body of the request. You should use the `content` property instead + This is deprecated and will be removed in a later release. + """ + return self._data + + @_body.setter + def _body(self, val): + """DEPRECATED: Set the body of the request + This is deprecated and will be removed in a later release. + """ + self._data = val + + def _format_parameters(self, params): + """DEPRECATED: Format the query parameters + This is deprecated and will be removed in a later release. + You should pass the query parameters through the kwarg `params` + instead. + """ + return _format_parameters_helper(self, params) + + def _set_streamed_data_body(self, data): + """DEPRECATED: Set the streamed request body. + This is deprecated and will be removed in a later release. + You should pass your stream content through the `content` kwarg instead + """ + if not isinstance(data, binary_type) and not any( + hasattr(data, attr) for attr in ["read", "__iter__", "__aiter__"] + ): + raise TypeError( + "A streamable data source must be an open file-like object or iterable." + ) + headers = self._set_body(content=data) + self._files = None + self.headers.update(headers) + + def _set_text_body(self, data): + """DEPRECATED: Set the text body + This is deprecated and will be removed in a later release. + You should pass your text content through the `content` kwarg instead + """ + headers = self._set_body(content=data) + self.headers.update(headers) + self._files = None + + def _set_xml_body(self, data): + """DEPRECATED: Set the xml body. + This is deprecated and will be removed in a later release. + You should pass your xml content through the `content` kwarg instead + """ + headers = self._set_body(content=data) + self.headers.update(headers) + self._files = None + + def _set_json_body(self, data): + """DEPRECATED: Set the json request body. + This is deprecated and will be removed in a later release. + You should pass your json content through the `json` kwarg instead + """ + headers = self._set_body(json=data) + self.headers.update(headers) + self._files = None + + def _set_formdata_body(self, data=None): + """DEPRECATED: Set the formrequest body. + This is deprecated and will be removed in a later release. + You should pass your stream content through the `files` kwarg instead + """ + if data is None: + data = {} + content_type = self.headers.pop("Content-Type", None) if self.headers else None + + if content_type and content_type.lower() == "application/x-www-form-urlencoded": + headers = self._set_body(data=data) + self._files = None + else: # Assume "multipart/form-data" + headers = self._set_body(files=data) + self._data = None + self.headers.update(headers) + + def _set_bytes_body(self, data): + """DEPRECATED: Set the bytes request body. + This is deprecated and will be removed in a later release. + You should pass your bytes content through the `content` kwarg instead + """ + headers = self._set_body(content=data) + # we don't want default Content-Type + # in 2.7, byte strings are still strings, so they get set with text/plain content type + + headers.pop("Content-Type", None) + self.headers.update(headers) + self._files = None + + def _set_multipart_mixed(self, *requests, **kwargs): + """DEPRECATED: Set the multipart mixed info. + This is deprecated and will be removed in a later release. + """ + self.multipart_mixed_info = ( + requests, + kwargs.pop("policies", []), + kwargs.pop("boundary", None), + kwargs + ) + + def _prepare_multipart_body(self, content_index=0): + """DEPRECATED: Prepare your request body for multipart requests. + This is deprecated and will be removed in a later release. + """ + return _prepare_multipart_body_helper(self, content_index) + + def _serialize(self): + """DEPRECATED: Serialize this request using application/http spec. + This is deprecated and will be removed in a later release. + :rtype: bytes + """ + return _serialize_request(self) diff --git a/sdk/core/azure-core/azure/core/rest/_rest.py b/sdk/core/azure-core/azure/core/rest/_rest.py index 88c7835b1df6..d021c26b569e 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest.py +++ b/sdk/core/azure-core/azure/core/rest/_rest.py @@ -32,16 +32,14 @@ from ..utils._utils import _case_insensitive_dict from ._helpers import ( - FilesType, set_content_body, set_json_body, set_multipart_body, set_urlencoded_body, - format_parameters, - to_pipeline_transport_request_helper, - from_pipeline_transport_request_helper, + _format_parameters_helper, get_charset_encoding, decode_to_text, + HttpRequestBackcompatMixin, ) from ..exceptions import ResponseNotReadError if TYPE_CHECKING: @@ -63,7 +61,7 @@ ################################## CLASSES ###################################### -class HttpRequest(object): +class HttpRequest(HttpRequestBackcompatMixin): """Provisional object that represents an HTTP request. **This object is provisional**, meaning it may be changed in a future release. @@ -107,7 +105,7 @@ def __init__(self, method, url, **kwargs): params = kwargs.pop("params", None) if params: - self.url = format_parameters(self.url, params) + _format_parameters_helper(self, params) self._files = None self._data = None @@ -127,10 +125,14 @@ def __init__(self, method, url, **kwargs): ) ) - def _set_body(self, content, data, files, json): - # type: (Optional[ContentType], Optional[dict], Optional[FilesType], Any) -> HeadersType + def _set_body(self, **kwargs): + # type: (Any) -> HeadersType """Sets the body of the request, and returns the default headers """ + content = kwargs.pop("content", None) + data = kwargs.pop("data", None) + files = kwargs.pop("files", None) + json = kwargs.pop("json", None) default_headers = {} if data is not None and not isinstance(data, dict): # should we warn? @@ -183,13 +185,6 @@ def __deepcopy__(self, memo=None): except (ValueError, TypeError): return copy.copy(self) - def _to_pipeline_transport_request(self): - return to_pipeline_transport_request_helper(self) - - @classmethod - def _from_pipeline_transport_request(cls, pipeline_transport_request): - return from_pipeline_transport_request_helper(cls, pipeline_transport_request) - class _HttpResponseBase(object): # pylint: disable=too-many-instance-attributes def __init__(self, **kwargs): diff --git a/sdk/core/azure-core/azure/core/rest/_rest_py3.py b/sdk/core/azure-core/azure/core/rest/_rest_py3.py index a69199815f81..e66cad7d34ee 100644 --- a/sdk/core/azure-core/azure/core/rest/_rest_py3.py +++ b/sdk/core/azure-core/azure/core/rest/_rest_py3.py @@ -31,11 +31,10 @@ Any, AsyncIterable, AsyncIterator, - Dict, Iterable, Iterator, Optional, - Type, Union, + cast, ) @@ -47,15 +46,13 @@ ParamsType, FilesType, HeadersType, - cast, set_json_body, set_multipart_body, set_urlencoded_body, - format_parameters, - to_pipeline_transport_request_helper, - from_pipeline_transport_request_helper, + _format_parameters_helper, get_charset_encoding, decode_to_text, + HttpRequestBackcompatMixin, ) from ._helpers_py3 import set_content_body from ..exceptions import ResponseNotReadError @@ -84,7 +81,7 @@ async def close(self): ################################## CLASSES ###################################### -class HttpRequest: +class HttpRequest(HttpRequestBackcompatMixin): """**Provisional** object that represents an HTTP request. **This object is provisional**, meaning it may be changed in a future release. @@ -137,7 +134,7 @@ def __init__( self.method = method if params: - self.url = format_parameters(self.url, params) + _format_parameters_helper(self, params) self._files = None self._data = None # type: Any @@ -159,10 +156,10 @@ def __init__( def _set_body( self, - content: Optional[ContentType], - data: Optional[dict], - files: Optional[FilesType], - json: Any, + content: Optional[ContentType] = None, + data: Optional[dict] = None, + files: Optional[FilesType] = None, + json: Any = None, ) -> HeadersType: """Sets the body of the request, and returns the default headers """ @@ -209,13 +206,6 @@ def __deepcopy__(self, memo=None) -> "HttpRequest": except (ValueError, TypeError): return copy.copy(self) - def _to_pipeline_transport_request(self): - return to_pipeline_transport_request_helper(self) - - @classmethod - def _from_pipeline_transport_request(cls, pipeline_transport_request): - return from_pipeline_transport_request_helper(cls, pipeline_transport_request) - class _HttpResponseBase: # pylint: disable=too-many-instance-attributes def __init__( diff --git a/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py new file mode 100644 index 000000000000..8a66e238090e --- /dev/null +++ b/sdk/core/azure-core/azure/core/utils/_pipeline_transport_rest_shared.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +from __future__ import absolute_import +import os +from typing import TYPE_CHECKING, cast, IO + +from email.message import Message +from six.moves.http_client import HTTPConnection + +try: + binary_type = str + from urlparse import urlparse # type: ignore +except ImportError: + binary_type = bytes # type: ignore + from urllib.parse import urlparse + +if TYPE_CHECKING: + from typing import ( # pylint: disable=ungrouped-imports + Dict, + List, + Union, + Tuple, + Optional, + ) + # importing both the py3 RestHttpRequest and the fallback RestHttpRequest + from azure.core.rest._rest_py3 import HttpRequest as RestHttpRequestPy3 + from azure.core.rest._rest import HttpRequest as RestHttpRequestPy2 + from azure.core.pipeline.transport import ( + HttpRequest as PipelineTransportHttpRequest + ) + HTTPRequestType = Union[ + RestHttpRequestPy3, RestHttpRequestPy2, PipelineTransportHttpRequest + ] + +def _format_parameters_helper(http_request, params): + """Helper for format_parameters. + + Format parameters into a valid query string. + It's assumed all parameters have already been quoted as + valid URL strings. + + :param http_request: The http request whose parameters + we are trying to format + :param dict params: A dictionary of parameters. + """ + query = urlparse(http_request.url).query + if query: + http_request.url = http_request.url.partition("?")[0] + existing_params = { + p[0]: p[-1] for p in [p.partition("=") for p in query.split("&")] + } + params.update(existing_params) + query_params = [] + for k, v in params.items(): + if isinstance(v, list): + for w in v: + if w is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, w)) + else: + if v is None: + raise ValueError("Query parameter {} cannot be None".format(k)) + query_params.append("{}={}".format(k, v)) + query = "?" + "&".join(query_params) + http_request.url = http_request.url + query + +def _pad_attr_name(attr, backcompat_attrs): + # type: (str, List[str]) -> str + """Pad hidden attributes so users can access them. + + Currently, for our backcompat attributes, we define them + as private, so they're hidden from intellisense and sphinx, + but still allow users to access them as public attributes + for backcompat purposes. This function is called so if + users access publicly call a private backcompat attribute, + we can return them the private variable in getattr + """ + return "_{}".format(attr) if attr in backcompat_attrs else attr + +def _prepare_multipart_body_helper(http_request, content_index=0): + # type: (HTTPRequestType, int) -> int + """Helper for prepare_multipart_body. + + Will prepare the body of this request according to the multipart information. + + This call assumes the on_request policies have been applied already in their + correct context (sync/async) + + Does nothing if "set_multipart_mixed" was never called. + :param http_request: The http request whose multipart body we are trying + to prepare + :param int content_index: The current index of parts within the batch message. + :returns: The updated index after all parts in this request have been added. + :rtype: int + """ + if not http_request.multipart_mixed_info: + return 0 + + requests = http_request.multipart_mixed_info[0] # type: List[HTTPRequestType] + boundary = http_request.multipart_mixed_info[2] # type: Optional[str] + + # Update the main request with the body + main_message = Message() + main_message.add_header("Content-Type", "multipart/mixed") + if boundary: + main_message.set_boundary(boundary) + + for req in requests: + part_message = Message() + if req.multipart_mixed_info: + content_index = req.prepare_multipart_body(content_index=content_index) + part_message.add_header("Content-Type", req.headers['Content-Type']) + payload = req.serialize() + # We need to remove the ~HTTP/1.1 prefix along with the added content-length + payload = payload[payload.index(b'--'):] + else: + part_message.add_header("Content-Type", "application/http") + part_message.add_header("Content-Transfer-Encoding", "binary") + part_message.add_header("Content-ID", str(content_index)) + payload = req.serialize() + content_index += 1 + part_message.set_payload(payload) + main_message.attach(part_message) + + try: + from email.policy import HTTP + + full_message = main_message.as_bytes(policy=HTTP) + eol = b"\r\n" + except ImportError: # Python 2.7 + # Right now we decide to not support Python 2.7 on serialization, since + # it doesn't serialize a valid HTTP request (and our main scenario Storage refuses it) + raise NotImplementedError( + "Multipart request are not supported on Python 2.7" + ) + # full_message = main_message.as_string() + # eol = b'\n' + _, _, body = full_message.split(eol, 2) + http_request.set_bytes_body(body) + http_request.headers["Content-Type"] = ( + "multipart/mixed; boundary=" + main_message.get_boundary() + ) + return content_index + +class _HTTPSerializer(HTTPConnection, object): + """Hacking the stdlib HTTPConnection to serialize HTTP request as strings. + """ + + def __init__(self, *args, **kwargs): + self.buffer = b"" + kwargs.setdefault("host", "fakehost") + super(_HTTPSerializer, self).__init__(*args, **kwargs) + + def putheader(self, header, *values): + if header in ["Host", "Accept-Encoding"]: + return + super(_HTTPSerializer, self).putheader(header, *values) + + def send(self, data): + self.buffer += data + +def _serialize_request(http_request): + # type: (HTTPRequestType) -> bytes + """Helper for serialize. + + Serialize a request using the application/http spec/ + + :param http_request: The http request which we are trying + to serialize. + :rtype: bytes + """ + serializer = _HTTPSerializer() + serializer.request( + method=http_request.method, + url=http_request.url, + body=http_request.body, + headers=http_request.headers, + ) + return serializer.buffer + +def _format_data_helper(data): + # type: (Union[str, IO]) -> Union[Tuple[None, str], Tuple[Optional[str], IO, str]] + """Helper for _format_data. + + Format field data according to whether it is a stream or + a string for a form-data request. + + :param data: The request field data. + :type data: str or file-like object. + """ + if hasattr(data, "read"): + data = cast(IO, data) + data_name = None + try: + if data.name[0] != "<" and data.name[-1] != ">": + data_name = os.path.basename(data.name) + except (AttributeError, TypeError): + pass + return (data_name, data, "application/octet-stream") + return (None, cast(str, data)) diff --git a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py index 7230018aa37f..0c1362f93594 100644 --- a/sdk/core/azure-core/tests/async_tests/test_authentication_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_authentication_async.py @@ -4,6 +4,7 @@ # license information. # ------------------------------------------------------------------------- import asyncio +from email.policy import HTTP import time from unittest.mock import Mock @@ -15,9 +16,10 @@ import pytest pytestmark = pytest.mark.asyncio +from utils import HTTP_REQUESTS - -async def test_bearer_policy_adds_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_adds_header(http_request): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) expected_token = AccessToken("expected_token", 2524608000) @@ -37,17 +39,18 @@ async def get_token(_): policies = [AsyncBearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = AsyncPipeline(transport=Mock(), policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None) + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) assert get_token_calls == 1 - await pipeline.run(HttpRequest("GET", "https://spam.eggs"), context=None) + await pipeline.run(http_request("GET", "https://spam.eggs"), context=None) # Didn't need a new token assert get_token_calls == 1 -async def test_bearer_policy_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_send(http_request): """The bearer token policy should invoke the next policy's send method and return the result""" - expected_request = HttpRequest("GET", "https://spam.eggs") + expected_request = http_request("GET", "https://spam.eggs") expected_response = Mock() async def verify_request(request): @@ -61,7 +64,8 @@ async def verify_request(request): assert response is expected_response -async def test_bearer_policy_token_caching(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_token_caching(http_request): good_for_one_hour = AccessToken("token", time.time() + 3600) expected_token = good_for_one_hour get_token_calls = 0 @@ -78,10 +82,10 @@ async def get_token(_): ] pipeline = AsyncPipeline(transport=Mock, policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 # policy has no token at first request -> it should call get_token - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 # token is good for an hour -> policy should return it from cache expired_token = AccessToken("token", time.time()) @@ -93,14 +97,15 @@ async def get_token(_): ] pipeline = AsyncPipeline(transport=Mock(), policies=policies) - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 1 - await pipeline.run(HttpRequest("GET", "https://spam.eggs")) + await pipeline.run(http_request("GET", "https://spam.eggs")) assert get_token_calls == 2 # token expired -> policy should call get_token -async def test_bearer_policy_optionally_enforces_https(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_optionally_enforces_https(http_request): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" async def assert_option_popped(request, **kwargs): @@ -114,20 +119,21 @@ async def assert_option_popped(request, **kwargs): # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure")) + await pipeline.run(http_request("GET", "http://not.secure")) with pytest.raises(ServiceRequestError): - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=True) # when enforce_https=False, an insecure request should pass - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) # https requests should always pass - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - await pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - await pipeline.run(HttpRequest("GET", "https://secure")) + await pipeline.run(http_request("GET", "https://secure"), enforce_https=False) + await pipeline.run(http_request("GET", "https://secure"), enforce_https=True) + await pipeline.run(http_request("GET", "https://secure")) -async def test_bearer_policy_preserves_enforce_https_opt_out(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_preserves_enforce_https_opt_out(http_request): """The policy should use request context to preserve an opt out from https enforcement""" class ContextValidator(SansIOHTTPPolicy): @@ -140,10 +146,11 @@ def on_request(self, request): policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) - await pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + await pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) -async def test_bearer_policy_context_unmodified_by_default(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_context_unmodified_by_default(http_request): """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" class ContextValidator(SansIOHTTPPolicy): @@ -156,10 +163,11 @@ def on_request(self, request): policies = [AsyncBearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = AsyncPipeline(transport=Mock(send=lambda *_, **__: get_completed_future(Mock())), policies=policies) - await pipeline.run(HttpRequest("GET", "https://secure")) + await pipeline.run(http_request("GET", "https://secure")) -async def test_bearer_policy_calls_sansio_methods(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_bearer_policy_calls_sansio_methods(http_request): """AsyncBearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOAsyncHTTPPolicyRunner""" class TestPolicy(AsyncBearerTokenCredentialPolicy): @@ -179,7 +187,7 @@ async def send(self, request): transport = Mock(send=Mock(return_value=get_completed_future(Mock(status_code=200)))) pipeline = AsyncPipeline(transport=transport, policies=[policy]) - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) policy.on_request.assert_called_once_with(policy.request) policy.on_response.assert_called_once_with(policy.request, policy.response) @@ -193,7 +201,7 @@ class TestException(Exception): policy = TestPolicy(credential, "scope") pipeline = AsyncPipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) policy.on_exception.assert_called_once_with(policy.request) # ...or the second @@ -209,7 +217,7 @@ async def fake_send(*args, **kwargs): transport = Mock(send=Mock(wraps=fake_send)) pipeline = AsyncPipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - await pipeline.run(HttpRequest("GET", "https://localhost")) + await pipeline.run(http_request("GET", "https://localhost")) assert transport.send.call_count == 2 policy.on_challenge.assert_called_once() policy.on_exception.assert_called_once_with(policy.request) diff --git a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py index f010abaaddf4..5165c523913d 100644 --- a/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_base_polling_async.py @@ -27,6 +27,7 @@ import json import pickle import re +from utils import HTTP_REQUESTS, is_rest import types import unittest try: @@ -83,8 +84,9 @@ class BadEndpointError(Exception): POLLING_STATUS = 200 CLIENT = AsyncPipelineClient("http://example.org") +CLIENT.http_request_type = None async def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(request.url) + return TestBasePolling.mock_update(client_self.http_request_type, request.url) CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -162,12 +164,14 @@ def test_base_polling_continuation_token(client, polling_response): @pytest.mark.asyncio -async def test_post(async_pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_post(async_pipeline_client_builder, deserialization_cb, http_request): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, 'POST', 202, { @@ -182,12 +186,14 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded'} @@ -213,12 +219,14 @@ async def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body=None ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded'} @@ -238,12 +246,14 @@ async def send(request, **kwargs): @pytest.mark.asyncio -async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_post_resource_location(async_pipeline_client_builder, deserialization_cb, http_request): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, 'POST', 202, { @@ -257,12 +267,14 @@ async def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -285,7 +297,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -302,15 +314,23 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - response.request.headers, - body, - None, # form_content - None # stream_content - ) + if is_rest(http_request): + request = http_request( + response.request.method, + response.request.url, + headers=response.request.headers, + content=body, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + response.request.headers, + body, + None, # form_content + None # stream_content + ) return PipelineResponse( request, @@ -322,7 +342,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): ) @staticmethod - def mock_update(url, headers=None): + def mock_update(http_request, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -354,14 +374,9 @@ def mock_update(url, headers=None): else: raise Exception('URL does not match') - request = CLIENT._request( + request = http_request( response.request.method, response.request.url, - None, # params - {}, # request has no headers - None, # Request has no body - None, # form_content - None # stream_content ) return PipelineResponse( @@ -404,11 +419,14 @@ def mock_deserialization_no_body(pipeline_response): return None @pytest.mark.asyncio -async def test_long_running_put(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_long_running_put(http_request): #TODO: Test custom header field - + CLIENT.http_request_type = http_request # Test throw on non LRO related status code - response = TestBasePolling.mock_send('PUT', 1000, {}) + response = TestBasePolling.mock_send( + http_request, 'PUT', 1000, {} + ) with pytest.raises(HttpResponseError): await async_poller(CLIENT, response, TestBasePolling.mock_outputs, @@ -420,6 +438,7 @@ async def test_long_running_put(): 'name': TEST_NAME } response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {}, response_body ) @@ -435,6 +454,7 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'operation-location': ASYNC_URL}) polling_method = AsyncLROBasePolling(0) @@ -446,6 +466,7 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': LOCATION_URL}) polling_method = AsyncLROBasePolling(0) @@ -458,6 +479,7 @@ def no_update_allowed(url, headers=None): # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': LOCATION_URL}, response_body) polling_method = AsyncLROBasePolling(0) @@ -469,6 +491,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -478,6 +501,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -486,10 +510,12 @@ def no_update_allowed(url, headers=None): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_patch(): - +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_long_running_patch(http_request): + CLIENT.http_request_type = http_request # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -502,6 +528,7 @@ async def test_long_running_patch(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -514,6 +541,7 @@ async def test_long_running_patch(): # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -526,6 +554,7 @@ async def test_long_running_patch(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -538,6 +567,7 @@ async def test_long_running_patch(): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -547,6 +577,7 @@ async def test_long_running_patch(): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -555,9 +586,12 @@ async def test_long_running_patch(): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_delete(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_long_running_delete(http_request): # Test polling from operation-location header + CLIENT.http_request_type = http_request response = TestBasePolling.mock_send( + http_request, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" @@ -570,10 +604,12 @@ async def test_long_running_delete(): assert polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None @pytest.mark.asyncio -async def test_long_running_post(): - +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_long_running_post(http_request): + CLIENT.http_request_type = http_request # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -585,6 +621,7 @@ async def test_long_running_post(): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -596,6 +633,7 @@ async def test_long_running_post(): # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -608,6 +646,7 @@ async def test_long_running_post(): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -617,6 +656,7 @@ async def test_long_running_post(): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -625,13 +665,15 @@ async def test_long_running_post(): AsyncLROBasePolling(0)) @pytest.mark.asyncio -async def test_long_running_negative(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_long_running_negative(http_request): global LOCATION_BODY global POLLING_STATUS - + CLIENT.http_request_type = http_request # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller( @@ -645,6 +687,7 @@ async def test_long_running_negative(): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, @@ -656,6 +699,7 @@ async def test_long_running_negative(): LOCATION_BODY = '{' POLLING_STATUS = 203 response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = async_poller(CLIENT, response, diff --git a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py index 1ae80db6f60c..5632a2097131 100644 --- a/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_basic_transport_async.py @@ -3,11 +3,11 @@ # Licensed under the MIT License. See LICENSE.txt in the project root for # license information. # ------------------------------------------------------------------------- -from azure.core.pipeline.transport import HttpRequest, AsyncHttpResponse, AsyncHttpTransport, AioHttpTransport +from azure.core.pipeline.transport import AsyncHttpResponse, AsyncHttpTransport, AioHttpTransport from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import AsyncPipeline from azure.core.exceptions import HttpResponseError - +from utils import HTTP_REQUESTS import pytest @@ -34,9 +34,10 @@ def body(self): @pytest.mark.asyncio -async def test_basic_options_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_options_aiohttp(port, http_request): - request = HttpRequest("OPTIONS", "http://localhost:{}/basic/string".format(port)) + request = http_request("OPTIONS", "http://localhost:{}/basic/string".format(port)) async with AsyncPipeline(AioHttpTransport(), policies=[]) as pipeline: response = await pipeline.run(request) @@ -45,7 +46,8 @@ async def test_basic_options_aiohttp(port): @pytest.mark.asyncio -async def test_multipart_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send(http_request): transport = MockAsyncHttpTransport() class RequestPolicy(object): @@ -53,10 +55,10 @@ async def on_request(self, request): # type: (PipelineRequest) -> None request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -91,7 +93,8 @@ async def on_request(self, request): @pytest.mark.asyncio -async def test_multipart_send_with_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_context(http_request): transport = MockAsyncHttpTransport() header_policy = HeadersPolicy() @@ -101,10 +104,10 @@ async def on_request(self, request): # type: (PipelineRequest) -> None request.http_request.headers['x-ms-date'] = 'Thu, 14 Jun 2018 16:46:54 GMT' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -142,19 +145,20 @@ async def on_request(self, request): @pytest.mark.asyncio -async def test_multipart_send_with_one_changeset(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_one_changeset(http_request): transport = MockAsyncHttpTransport() requests = [ - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ] - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( *requests, boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" @@ -190,22 +194,23 @@ async def test_multipart_send_with_one_changeset(): @pytest.mark.asyncio -async def test_multipart_send_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_multiple_changesets(http_request): transport = MockAsyncHttpTransport() - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3"), + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3"), boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset1, changeset2, @@ -263,19 +268,20 @@ async def test_multipart_send_with_multiple_changesets(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_first(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -317,17 +323,18 @@ async def test_multipart_send_with_combination_changeset_first(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_last(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -370,18 +377,19 @@ async def test_multipart_send_with_combination_changeset_last(): @pytest.mark.asyncio -async def test_multipart_send_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_send_with_combination_changeset_middle(http_request): transport = MockAsyncHttpTransport() - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container1/blob1"), boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -423,7 +431,8 @@ async def test_multipart_send_with_combination_changeset_middle(): @pytest.mark.asyncio -async def test_multipart_receive(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive(http_request): class ResponsePolicy(object): def on_response(self, request, response): @@ -435,10 +444,10 @@ async def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None response.http_response.headers['x-ms-async-fun'] = 'true' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -496,14 +505,15 @@ async def on_response(self, request, response): @pytest.mark.asyncio -async def test_multipart_receive_with_one_changeset(): - changeset = HttpRequest("", "") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_one_changeset(http_request): + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -550,20 +560,21 @@ async def test_multipart_receive_with_one_changeset(): @pytest.mark.asyncio -async def test_multipart_receive_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_multiple_changesets(http_request): - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -636,16 +647,17 @@ async def test_multipart_receive_with_multiple_changesets(): @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_combination_changeset_first(http_request): - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(changeset, HttpRequest("DELETE", "/container2/blob2")) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' @@ -712,16 +724,17 @@ def test_raise_for_status_good_response(): @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_combination_changeset_middle(http_request): - changeset = HttpRequest("", "") - changeset.set_multipart_mixed(HttpRequest("DELETE", "/container1/blob1")) + changeset = http_request("", "") + changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -778,16 +791,17 @@ async def test_multipart_receive_with_combination_changeset_middle(): @pytest.mark.asyncio -async def test_multipart_receive_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_combination_changeset_last(http_request): - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(HttpRequest("DELETE", "/container0/blob0"), changeset) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -844,11 +858,12 @@ async def test_multipart_receive_with_combination_changeset_last(): @pytest.mark.asyncio -async def test_multipart_receive_with_bom(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_multipart_receive_with_bom(http_request): - req0 = HttpRequest("DELETE", "/container0/blob0") + req0 = http_request("DELETE", "/container0/blob0") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) body_as_bytes = ( b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" @@ -883,12 +898,13 @@ async def test_multipart_receive_with_bom(): @pytest.mark.asyncio -async def test_recursive_multipart_receive(): - req0 = HttpRequest("DELETE", "/container0/blob0") - internal_req0 = HttpRequest("DELETE", "/container0/blob0") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_recursive_multipart_receive(http_request): + req0 = http_request("DELETE", "/container0/blob0") + internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) internal_body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" diff --git a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py index 0c4e931bc4e2..640d085103aa 100644 --- a/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_http_logging_policy_async.py @@ -15,15 +15,15 @@ PipelineContext ) from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, ) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) +from utils import HTTP_REQUESTS - -def test_http_logger(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -41,7 +41,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') http_response = HttpResponse(universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -136,8 +136,8 @@ def emit(self, record): mock_handler.reset() - -def test_http_logger_operation_level(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger_operation_level(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -156,7 +156,7 @@ def emit(self, record): policy = HttpLoggingPolicy() kwargs={'logger': logger} - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') http_response = HttpResponse(universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -207,8 +207,8 @@ def emit(self, record): mock_handler.reset() - -def test_http_logger_with_body(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger_with_body(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -226,7 +226,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" http_response = HttpResponse(universal_request, None) http_response.status_code = 202 @@ -248,8 +248,9 @@ def emit(self, record): mock_handler.reset() +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) @pytest.mark.skipif(sys.version_info < (3, 6), reason="types.AsyncGeneratorType does not exist in 3.5") -def test_http_logger_with_generator_body(): +def test_http_logger_with_generator_body(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -267,7 +268,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') mock = Mock() mock.__class__ = types.AsyncGeneratorType universal_request.body = mock diff --git a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py index a84d53093ca3..7117d5c50bf6 100644 --- a/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_pipeline_async.py @@ -38,7 +38,6 @@ ) from azure.core.pipeline.transport import ( AsyncHttpTransport, - HttpRequest, AsyncioRequestsTransport, TrioRequestsTransport, AioHttpTransport @@ -55,10 +54,12 @@ import trio import pytest +from utils import HTTP_REQUESTS @pytest.mark.asyncio -async def test_sans_io_exception(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_sans_io_exception(http_request): class BrokenSender(AsyncHttpTransport): async def send(self, request, **config): raise ValueError("Broken") @@ -75,7 +76,7 @@ async def __aexit__(self, exc_type, exc_value, traceback): pipeline = AsyncPipeline(BrokenSender(), [SansIOHTTPPolicy()]) - req = HttpRequest('GET', '/') + req = http_request('GET', '/') with pytest.raises(ValueError): await pipeline.run(req) @@ -89,9 +90,10 @@ def on_exception(self, requests, **kwargs): await pipeline.run(req) @pytest.mark.asyncio -async def test_basic_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -104,10 +106,11 @@ async def test_basic_aiohttp(port): assert isinstance(response.http_response.status_code, int) @pytest.mark.asyncio -async def test_basic_aiohttp_separate_session(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp_separate_session(port, http_request): session = aiohttp.ClientSession() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -123,9 +126,10 @@ async def test_basic_aiohttp_separate_session(port): await transport.session.close() @pytest.mark.asyncio -async def test_basic_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -185,9 +189,10 @@ def test_pass_in_http_logging_policy(): assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) @pytest.mark.asyncio -async def test_conf_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_conf_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -197,10 +202,11 @@ async def test_conf_async_requests(port): assert isinstance(response.http_response.status_code, int) -def test_conf_async_trio_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conf_async_trio_requests(port, http_request): async def do(): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), AsyncRedirectPolicy() @@ -212,7 +218,8 @@ async def do(): assert isinstance(response.http_response.status_code, int) @pytest.mark.asyncio -async def test_retry_without_http_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_without_http_response(http_request): class NaughtyPolicy(AsyncHTTPPolicy): def send(*args): raise AzureError('boo') @@ -220,7 +227,7 @@ def send(*args): policies = [AsyncRetryPolicy(), NaughtyPolicy()] pipeline = AsyncPipeline(policies=policies, transport=None) with pytest.raises(AzureError): - await pipeline.run(HttpRequest('GET', url='https://foo.bar')) + await pipeline.run(http_request('GET', url='https://foo.bar')) @pytest.mark.asyncio async def test_add_custom_policy(): diff --git a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py index 55856ae6ca94..8a8199e65429 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_asyncio.py @@ -5,13 +5,14 @@ # ------------------------------------------------------------------------- import json -from azure.core.pipeline.transport import AsyncioRequestsTransport, HttpRequest - +from azure.core.pipeline.transport import AsyncioRequestsTransport +from utils import HTTP_REQUESTS import pytest @pytest.mark.asyncio -async def test_async_gen_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_async_gen_data(port, http_request): class AsyncGen: def __init__(self): self._range = iter([b"azerty"]) @@ -26,13 +27,14 @@ async def __anext__(self): raise StopAsyncIteration async with AsyncioRequestsTransport() as transport: - req = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.asyncio -async def test_send_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_send_data(port, http_request): async with AsyncioRequestsTransport() as transport: - req = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" diff --git a/sdk/core/azure-core/tests/async_tests/test_request_trio.py b/sdk/core/azure-core/tests/async_tests/test_request_trio.py index 11a0058404f5..7fafa0d41c28 100644 --- a/sdk/core/azure-core/tests/async_tests/test_request_trio.py +++ b/sdk/core/azure-core/tests/async_tests/test_request_trio.py @@ -5,13 +5,15 @@ # ------------------------------------------------------------------------- import json -from azure.core.pipeline.transport import TrioRequestsTransport, HttpRequest +from azure.core.pipeline.transport import TrioRequestsTransport +from utils import HTTP_REQUESTS import pytest @pytest.mark.trio -async def test_async_gen_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_async_gen_data(port, http_request): class AsyncGen: def __init__(self): self._range = iter([b"azerty"]) @@ -26,14 +28,15 @@ async def __anext__(self): raise StopAsyncIteration async with TrioRequestsTransport() as transport: - req = HttpRequest('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) + req = http_request('GET', 'http://localhost:{}/basic/anything'.format(port), data=AsyncGen()) response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" @pytest.mark.trio -async def test_send_data(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_send_data(port, http_request): async with TrioRequestsTransport() as transport: - req = HttpRequest('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") + req = http_request('PUT', 'http://localhost:{}/basic/anything'.format(port), data=b"azerty") response = await transport.send(req) assert json.loads(response.text())['data'] == "azerty" \ No newline at end of file diff --git a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py index faa0944e6399..270cc5c5684d 100644 --- a/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_retry_policy_async.py @@ -24,7 +24,6 @@ ) from azure.core.pipeline import AsyncPipeline, PipelineResponse from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, AsyncHttpTransport, ) @@ -32,6 +31,8 @@ import os import time import asyncio +from itertools import product +from utils import HTTP_REQUESTS def test_retry_code_class_variables(): retry_policy = AsyncRetryPolicy() @@ -59,10 +60,10 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) +def test_retry_after(retry_after_input, http_request): retry_policy = AsyncRetryPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") response = HttpResponse(request, None) response.headers["retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) @@ -77,10 +78,10 @@ def test_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_x_ms_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) +def test_x_ms_retry_after(retry_after_input, http_request): retry_policy = AsyncRetryPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") response = HttpResponse(request, None) response.headers["x-ms-retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) @@ -96,7 +97,8 @@ def test_x_ms_retry_after(retry_after_input): assert retry_after == float(retry_after_input) @pytest.mark.asyncio -async def test_retry_on_429(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_on_429(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 @@ -113,7 +115,7 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe response.status_code = 429 return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = AsyncRetryPolicy(retry_total = 1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) @@ -121,7 +123,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe assert transport._count == 2 @pytest.mark.asyncio -async def test_no_retry_on_201(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_no_retry_on_201(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._count = 0 @@ -140,7 +143,7 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe response.headers = headers return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = AsyncRetryPolicy(retry_total = 1) transport = MockTransport() pipeline = AsyncPipeline(transport, [http_retry]) @@ -148,7 +151,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe assert transport._count == 1 @pytest.mark.asyncio -async def test_retry_seekable_stream(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_seekable_stream(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True @@ -171,14 +175,15 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe return response data = BytesIO(b"Lots of dataaaa") - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_request.set_streamed_data_body(data) http_retry = AsyncRetryPolicy(retry_total = 1) pipeline = AsyncPipeline(MockTransport(), [http_retry]) await pipeline.run(http_request) @pytest.mark.asyncio -async def test_retry_seekable_file(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_seekable_file(http_request): class MockTransport(AsyncHttpTransport): def __init__(self): self._first = True @@ -209,7 +214,7 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe file = tempfile.NamedTemporaryFile(delete=False) file.write(b'Lots of dataaaa') file.close() - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') headers = {'Content-Type': "multipart/form-data"} http_request.headers = headers with open(file.name, 'rb') as f: @@ -225,7 +230,8 @@ async def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> Pipe @pytest.mark.asyncio -async def test_retry_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_retry_timeout(http_request): timeout = 1 def send(request, **kwargs): @@ -241,11 +247,12 @@ def send(request, **kwargs): pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)]) with pytest.raises(ServiceResponseTimeoutError): - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost/")) @pytest.mark.asyncio -async def test_timeout_defaults(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_timeout_defaults(http_request): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" async def send(request, **kwargs): @@ -262,18 +269,20 @@ async def send(request, **kwargs): ) pipeline = AsyncPipeline(transport, [AsyncRetryPolicy()]) - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost/")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" +combinations = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)] @pytest.mark.asyncio @pytest.mark.parametrize( - "transport_error,expected_timeout_error", - ((ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)), + "combinations,http_request", + product(combinations, HTTP_REQUESTS), ) -async def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): +async def test_does_not_sleep_after_timeout(combinations, http_request): # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s. # It should not sleep the second time when given timeout=1 + transport_error, expected_timeout_error = combinations timeout = 1 transport = Mock( @@ -284,6 +293,6 @@ async def test_does_not_sleep_after_timeout(transport_error, expected_timeout_er pipeline = AsyncPipeline(transport, [AsyncRetryPolicy(timeout=timeout)]) with pytest.raises(expected_timeout_error): - await pipeline.run(HttpRequest("GET", "http://localhost/")) + await pipeline.run(http_request("GET", "http://localhost/")) assert transport.sleep.call_count == 1 diff --git a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py index d90b5b15b4c9..f213b92dbbf7 100644 --- a/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_stream_generator_async.py @@ -4,7 +4,6 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - HttpRequest, AsyncHttpResponse, AsyncHttpTransport, AsyncioRequestsTransportResponse, @@ -14,9 +13,11 @@ from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator from unittest import mock import pytest +from utils import HTTP_REQUESTS @pytest.mark.asyncio -async def test_connection_error_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_connection_error_response(http_request): class MockSession(object): def __init__(self): self.auto_decompress = True @@ -38,7 +39,7 @@ async def open(self): pass async def send(self, request, **kwargs): - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') response = AsyncHttpResponse(request, None) response.status_code = 200 return response @@ -65,7 +66,7 @@ class AsyncMock(mock.MagicMock): async def __call__(self, *args, **kwargs): return super(AsyncMock, self).__call__(*args, **kwargs) - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') pipeline = AsyncPipeline(MockTransport()) http_response = AsyncHttpResponse(http_request, None) http_response.internal_response = MockInternalResponse() diff --git a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py index 6b4c9ac3b912..8ca4413be936 100644 --- a/sdk/core/azure-core/tests/async_tests/test_streaming_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_streaming_async.py @@ -27,16 +27,18 @@ import pytest from azure.core import AsyncPipelineClient from azure.core.exceptions import DecodeError +from utils import HTTP_REQUESTS @pytest.mark.asyncio -async def test_decompress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -47,14 +49,15 @@ async def test_decompress_plain_no_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_compress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -65,14 +68,15 @@ async def test_compress_plain_no_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_decompress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -86,14 +90,15 @@ async def test_decompress_compressed_no_header(): pass @pytest.mark.asyncio -async def test_compress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -107,7 +112,8 @@ async def test_compress_compressed_no_header(): pass @pytest.mark.asyncio -async def test_decompress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_plain_header(http_request): # expect error import zlib account_name = "coretests" @@ -115,7 +121,7 @@ async def test_decompress_plain_header(): url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -128,14 +134,15 @@ async def test_decompress_plain_header(): pass @pytest.mark.asyncio -async def test_compress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_plain_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -146,14 +153,15 @@ async def test_compress_plain_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_decompress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_decompress_compressed_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -164,14 +172,15 @@ async def test_decompress_compressed_header(): assert decoded == "test" @pytest.mark.asyncio -async def test_compress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_compress_compressed_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = AsyncPipelineClient(account_url) async with client: - request = client.get(url) + request = http_request("GET", url) pipeline_response = await client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) diff --git a/sdk/core/azure-core/tests/async_tests/test_testserver_async.py b/sdk/core/azure-core/tests/async_tests/test_testserver_async.py index 4501e2dc3887..d6557b2b3e7c 100644 --- a/sdk/core/azure-core/tests/async_tests/test_testserver_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_testserver_async.py @@ -24,12 +24,14 @@ # # -------------------------------------------------------------------------- import pytest -from azure.core.pipeline.transport import HttpRequest, AioHttpTransport +from azure.core.pipeline.transport import AioHttpTransport +from utils import HTTP_REQUESTS """This file does a simple call to the testserver to make sure we can use the testserver""" @pytest.mark.asyncio -async def test_smoke(port): - request = HttpRequest(method="GET", url="http://localhost:{}/basic/string".format(port)) +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_smoke(port, http_request): + request = http_request(method="GET", url="http://localhost:{}/basic/string".format(port)) async with AioHttpTransport() as sender: response = await sender.send(request) response.raise_for_status() diff --git a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py index 62b606218a0e..1b813ba53cde 100644 --- a/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_tracing_decorator_async.py @@ -15,11 +15,12 @@ import pytest from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.policies import HTTPPolicy -from azure.core.pipeline.transport import HttpTransport, HttpRequest +from azure.core.pipeline.transport import HttpTransport from azure.core.settings import settings from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async from tracing_common import FakeSpan +from utils import HTTP_REQUESTS @pytest.fixture(scope="module") @@ -29,9 +30,9 @@ def fake_span(): class MockClient: @distributed_trace - def __init__(self, policies=None, assert_current_span=False): + def __init__(self, http_request, policies=None, assert_current_span=False): time.sleep(0.001) - self.request = HttpRequest("GET", "http://localhost") + self.request = http_request("GET", "http://localhost") if policies is None: policies = [] policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request)) @@ -88,9 +89,10 @@ async def raising_exception(self): class TestAsyncDecorator(object): @pytest.mark.asyncio - async def test_decorator_tracing_attr(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_decorator_tracing_attr(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.tracing_attr() assert len(parent.children) == 2 @@ -100,9 +102,10 @@ async def test_decorator_tracing_attr(self): @pytest.mark.asyncio - async def test_decorator_has_different_name(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_decorator_has_different_name(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.check_name_is_different() assert len(parent.children) == 2 assert parent.children[0].name == "MockClient.__init__" @@ -110,9 +113,10 @@ async def test_decorator_has_different_name(self): @pytest.mark.asyncio - async def test_used(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_used(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient(policies=[]) + client = MockClient(http_request, policies=[]) await client.get_foo(parent_span=parent) await client.get_foo() @@ -126,9 +130,10 @@ async def test_used(self): @pytest.mark.asyncio - async def test_span_merge_span(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_merge_span(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.merge_span_method() await client.no_merge_span_method() @@ -142,9 +147,10 @@ async def test_span_merge_span(self): @pytest.mark.asyncio - async def test_span_complicated(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_complicated(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) await client.make_request(2) with parent.span("child") as child: time.sleep(0.001) @@ -163,11 +169,12 @@ async def test_span_complicated(self): assert not parent.children[3].children @pytest.mark.asyncio - async def test_span_with_exception(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + async def test_span_with_exception(self, http_request): """Assert that if an exception is raised, the next sibling method is actually a sibling span. """ with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) try: await client.raising_exception() except: diff --git a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py index fa65fc3353c5..b9229c209fa2 100644 --- a/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py +++ b/sdk/core/azure-core/tests/async_tests/test_universal_http_async.py @@ -24,7 +24,6 @@ # #-------------------------------------------------------------------------- from azure.core.pipeline.transport import ( - HttpRequest, AioHttpTransport, AioHttpTransportResponse, AsyncioRequestsTransport, @@ -34,12 +33,13 @@ import trio import pytest - +from utils import HTTP_REQUESTS, create_http_request @pytest.mark.asyncio -async def test_basic_aiohttp(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_aiohttp(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AioHttpTransport() as sender: response = await sender.send(request) assert response.body() is not None @@ -48,18 +48,20 @@ async def test_basic_aiohttp(port): assert isinstance(response.status_code, int) @pytest.mark.asyncio -async def test_aiohttp_auto_headers(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_aiohttp_auto_headers(port, http_request): - request = HttpRequest("POST", "http://localhost:{}/basic/string".format(port)) + request = http_request("POST", "http://localhost:{}/basic/string".format(port)) async with AioHttpTransport() as sender: response = await sender.send(request) auto_headers = response.internal_response.request_info.headers assert 'Content-Type' not in auto_headers @pytest.mark.asyncio -async def test_basic_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_basic_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AsyncioRequestsTransport() as sender: response = await sender.send(request) assert response.body() is not None @@ -67,19 +69,21 @@ async def test_basic_async_requests(port): assert isinstance(response.status_code, int) @pytest.mark.asyncio -async def test_conf_async_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +async def test_conf_async_requests(port, http_request): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with AsyncioRequestsTransport() as sender: response = await sender.send(request) assert response.body() is not None assert isinstance(response.status_code, int) -def test_conf_async_trio_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conf_async_trio_requests(port, http_request): async def do(): - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) async with TrioRequestsTransport() as sender: return await sender.send(request) assert response.body() is not None diff --git a/sdk/core/azure-core/tests/test_authentication.py b/sdk/core/azure-core/tests/test_authentication.py index de029e8ea352..1be8b8e12cf3 100644 --- a/sdk/core/azure-core/tests/test_authentication.py +++ b/sdk/core/azure-core/tests/test_authentication.py @@ -4,7 +4,7 @@ # license information. # ------------------------------------------------------------------------- import time - +from itertools import product import azure.core from azure.core.credentials import AccessToken, AzureKeyCredential, AzureSasCredential, AzureNamedKeyCredential from azure.core.exceptions import ServiceRequestError @@ -15,7 +15,7 @@ AzureKeyCredentialPolicy, AzureSasCredentialPolicy, ) -from azure.core.pipeline.transport import HttpRequest +from utils import HTTP_REQUESTS import pytest @@ -26,7 +26,8 @@ from mock import Mock -def test_bearer_policy_adds_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_adds_header(http_request): """The bearer token policy should add a header containing a token from its credential""" # 2524608000 == 01/01/2050 @ 12:00am (UTC) expected_token = AccessToken("expected_token", 2524608000) @@ -39,19 +40,20 @@ def verify_authorization_header(request): policies = [BearerTokenCredentialPolicy(fake_credential, "scope"), Mock(send=verify_authorization_header)] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert fake_credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) # Didn't need a new token assert fake_credential.get_token.call_count == 1 -def test_bearer_policy_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_send(http_request): """The bearer token policy should invoke the next policy's send method and return the result""" - expected_request = HttpRequest("GET", "https://spam.eggs") + expected_request = http_request("GET", "https://spam.eggs") expected_response = Mock() def verify_request(request): @@ -65,15 +67,16 @@ def verify_request(request): assert response is expected_response -def test_bearer_policy_token_caching(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_token_caching(http_request): good_for_one_hour = AccessToken("token", time.time() + 3600) credential = Mock(get_token=Mock(return_value=good_for_one_hour)) pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 # policy has no token at first request -> it should call get_token - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 # token is good for an hour -> policy should return it from cache expired_token = AccessToken("token", time.time()) @@ -81,14 +84,15 @@ def test_bearer_policy_token_caching(): credential.get_token.return_value = expired_token pipeline = Pipeline(transport=Mock(), policies=[BearerTokenCredentialPolicy(credential, "scope")]) - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 1 - pipeline.run(HttpRequest("GET", "https://spam.eggs")) + pipeline.run(http_request("GET", "https://spam.eggs")) assert credential.get_token.call_count == 2 # token expired -> policy should call get_token -def test_bearer_policy_optionally_enforces_https(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_optionally_enforces_https(http_request): """HTTPS enforcement should be controlled by a keyword argument, and enabled by default""" def assert_option_popped(request, **kwargs): @@ -102,20 +106,21 @@ def assert_option_popped(request, **kwargs): # by default and when enforce_https=True, the policy should raise when given an insecure request with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure")) + pipeline.run(http_request("GET", "http://not.secure")) with pytest.raises(ServiceRequestError): - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=True) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=True) # when enforce_https=False, an insecure request should pass - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) # https requests should always pass - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=False) - pipeline.run(HttpRequest("GET", "https://secure"), enforce_https=True) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(http_request("GET", "https://secure"), enforce_https=False) + pipeline.run(http_request("GET", "https://secure"), enforce_https=True) + pipeline.run(http_request("GET", "https://secure")) -def test_bearer_policy_preserves_enforce_https_opt_out(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_preserves_enforce_https_opt_out(http_request): """The policy should use request context to preserve an opt out from https enforcement""" class ContextValidator(SansIOHTTPPolicy): @@ -127,10 +132,11 @@ def on_request(self, request): policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "http://not.secure"), enforce_https=False) + pipeline.run(http_request("GET", "http://not.secure"), enforce_https=False) -def test_bearer_policy_default_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_default_context(http_request): """The policy should call get_token with the scopes given at construction, and no keyword arguments, by default""" expected_scope = "scope" token = AccessToken("", 0) @@ -138,12 +144,13 @@ def test_bearer_policy_default_context(): policy = BearerTokenCredentialPolicy(credential, expected_scope) pipeline = Pipeline(transport=Mock(), policies=[policy]) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) credential.get_token.assert_called_once_with(expected_scope) -def test_bearer_policy_context_unmodified_by_default(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_context_unmodified_by_default(http_request): """When no options for the policy accompany a request, the policy shouldn't add anything to the request context""" class ContextValidator(SansIOHTTPPolicy): @@ -154,10 +161,11 @@ def on_request(self, request): policies = [BearerTokenCredentialPolicy(credential, "scope"), ContextValidator()] pipeline = Pipeline(transport=Mock(), policies=policies) - pipeline.run(HttpRequest("GET", "https://secure")) + pipeline.run(http_request("GET", "https://secure")) -def test_bearer_policy_calls_on_challenge(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_calls_on_challenge(http_request): """BearerTokenCredentialPolicy should call its on_challenge method when it receives an authentication challenge""" class TestPolicy(BearerTokenCredentialPolicy): @@ -173,12 +181,13 @@ def on_challenge(self, request, challenge): transport = Mock(send=Mock(return_value=response)) pipeline = Pipeline(transport=transport, policies=policies) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) assert TestPolicy.called -def test_bearer_policy_cannot_complete_challenge(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_cannot_complete_challenge(http_request): """BearerTokenCredentialPolicy should return the 401 response when it can't complete its challenge""" expected_scope = "scope" @@ -189,14 +198,15 @@ def test_bearer_policy_cannot_complete_challenge(): policies = [BearerTokenCredentialPolicy(credential, expected_scope)] pipeline = Pipeline(transport=transport, policies=policies) - response = pipeline.run(HttpRequest("GET", "https://localhost")) + response = pipeline.run(http_request("GET", "https://localhost")) assert response.http_response is expected_response assert transport.send.call_count == 1 credential.get_token.assert_called_once_with(expected_scope) -def test_bearer_policy_calls_sansio_methods(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_bearer_policy_calls_sansio_methods(http_request): """BearerTokenCredentialPolicy should call SansIOHttpPolicy methods as does _SansIOHTTPPolicyRunner""" class TestPolicy(BearerTokenCredentialPolicy): @@ -216,7 +226,7 @@ def send(self, request): transport = Mock(send=Mock(return_value=Mock(status_code=200))) pipeline = Pipeline(transport=transport, policies=[policy]) - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) policy.on_request.assert_called_once_with(policy.request) policy.on_response.assert_called_once_with(policy.request, policy.response) @@ -230,7 +240,7 @@ class TestException(Exception): policy = TestPolicy(credential, "scope") pipeline = Pipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) policy.on_exception.assert_called_once_with(policy.request) # ...or the second @@ -246,14 +256,15 @@ def raise_the_second_time(*args, **kwargs): transport = Mock(send=Mock(wraps=raise_the_second_time)) pipeline = Pipeline(transport=transport, policies=[policy]) with pytest.raises(TestException): - pipeline.run(HttpRequest("GET", "https://localhost")) + pipeline.run(http_request("GET", "https://localhost")) assert transport.send.call_count == 2 policy.on_challenge.assert_called_once() policy.on_exception.assert_called_once_with(policy.request) @pytest.mark.skipif(azure.core.__version__ >= "2", reason="this test applies only to azure-core 1.x") -def test_key_vault_regression(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_key_vault_regression(http_request): """Test for regression affecting azure-keyvault-* 4.0.0. This test must pass, unmodified, for all 1.x versions.""" from azure.core.pipeline.policies._authentication import _BearerTokenCredentialPolicyBase @@ -273,7 +284,8 @@ def test_key_vault_regression(): assert policy._token.token == token -def test_azure_key_credential_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_azure_key_credential_policy(http_request): """Tests to see if we can create an AzureKeyCredentialPolicy""" key_header = "api_key" @@ -287,7 +299,7 @@ def verify_authorization_header(request): credential_policy = AzureKeyCredentialPolicy(credential=credential, name=key_header) pipeline = Pipeline(transport=transport, policies=[credential_policy]) - pipeline.run(HttpRequest("GET", "https://test_key_credential")) + pipeline.run(http_request("GET", "https://test_key_credential")) def test_azure_key_credential_policy_raises(): @@ -313,7 +325,7 @@ def test_azure_key_credential_updates(): credential.update(api_key) assert credential.key == api_key -@pytest.mark.parametrize("sas,url,expected_url", [ +combinations = [ ("sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential", "https://test_sas_credential?sig=test_signature"), ("sig=test_signature", "https://test_sas_credential?sig=test_signature", "https://test_sas_credential?sig=test_signature"), @@ -322,10 +334,12 @@ def test_azure_key_credential_updates(): ("?sig=test_signature", "https://test_sas_credential?", "https://test_sas_credential?sig=test_signature"), ("sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), ("?sig=test_signature", "https://test_sas_credential?foo=bar", "https://test_sas_credential?foo=bar&sig=test_signature"), -]) -def test_azure_sas_credential_policy(sas, url, expected_url): - """Tests to see if we can create an AzureSasCredentialPolicy""" +] +@pytest.mark.parametrize("combinations,http_request", product(combinations, HTTP_REQUESTS)) +def test_azure_sas_credential_policy(combinations, http_request): + """Tests to see if we can create an AzureSasCredentialPolicy""" + sas, url, expected_url = combinations def verify_authorization(request): assert request.url == expected_url @@ -334,7 +348,7 @@ def verify_authorization(request): credential_policy = AzureSasCredentialPolicy(credential=credential) pipeline = Pipeline(transport=transport, policies=[credential_policy]) - pipeline.run(HttpRequest("GET", url)) + pipeline.run(http_request("GET", url)) def test_azure_sas_credential_updates(): """Tests AzureSasCredential updates""" diff --git a/sdk/core/azure-core/tests/test_base_polling.py b/sdk/core/azure-core/tests/test_base_polling.py index f5c1db343302..c75b401df01d 100644 --- a/sdk/core/azure-core/tests/test_base_polling.py +++ b/sdk/core/azure-core/tests/test_base_polling.py @@ -48,7 +48,7 @@ from azure.core.polling.base_polling import LROBasePolling from azure.core.pipeline.policies._utils import _FixedOffset - +from utils import HTTP_REQUESTS, is_rest class SimpleResource: """An implementation of Python 3 SimpleNamespace. @@ -82,8 +82,9 @@ class BadEndpointError(Exception): POLLING_STATUS = 200 CLIENT = PipelineClient("http://example.org") +CLIENT.http_request_type = None def mock_run(client_self, request, **kwargs): - return TestBasePolling.mock_update(request.url, request.headers) + return TestBasePolling.mock_update(client_self.http_request_type, request.url, request.headers) CLIENT._pipeline.run = types.MethodType(mock_run, CLIENT) @@ -183,13 +184,14 @@ def test_delay_extraction_httpdate(polling_response): assert polling._extract_delay() == 60*60 # one hour in seconds assert str(mock_datetime.now.call_args[0][0]) == "" - -def test_post(pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_post(pipeline_client_builder, deserialization_cb, http_request): # Test POST LRO with both Location and Operation-Location # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, 'POST', 202, { @@ -204,12 +206,14 @@ def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded'} @@ -235,12 +239,14 @@ def send(request, **kwargs): if request.url == 'http://example.org/location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body=None ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded'} @@ -258,13 +264,14 @@ def send(request, **kwargs): result = poll.result() assert result is None - -def test_post_resource_location(pipeline_client_builder, deserialization_cb): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_post_resource_location(pipeline_client_builder, deserialization_cb, http_request): # ResourceLocation # The initial response contains both Location and Operation-Location, a 202 and no Body initial_response = TestBasePolling.mock_send( + http_request, 'POST', 202, { @@ -278,12 +285,14 @@ def send(request, **kwargs): if request.url == 'http://example.org/resource_location': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'location_result': True} ).http_response elif request.url == 'http://example.org/async_monitor': return TestBasePolling.mock_send( + http_request, 'GET', 200, body={'status': 'Succeeded', 'resourceLocation': 'http://example.org/resource_location'} @@ -307,7 +316,7 @@ class TestBasePolling(object): convert = re.compile('([a-z0-9])([A-Z])') @staticmethod - def mock_send(method, status, headers=None, body=RESPONSE_BODY): + def mock_send(http_request, method, status, headers=None, body=RESPONSE_BODY): if headers is None: headers = {} response = Response() @@ -324,15 +333,23 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): response.headers.update({"content-type": "application/json; charset=utf8"}) response.reason = "OK" - request = CLIENT._request( - response.request.method, - response.request.url, - None, # params - response.request.headers, - body, - None, # form_content - None # stream_content - ) + if is_rest(http_request): + request = http_request( + response.request.method, + response.request.url, + headers=response.request.headers, + content=body, + ) + else: + request = CLIENT._request( + response.request.method, + response.request.url, + None, # params + response.request.headers, + body, + None, # form_content + None # stream_content + ) return PipelineResponse( request, @@ -344,7 +361,7 @@ def mock_send(method, status, headers=None, body=RESPONSE_BODY): ) @staticmethod - def mock_update(url, headers=None): + def mock_update(http_request, url, headers=None): response = Response() response._content_consumed = True response.request = mock.create_autospec(Request) @@ -375,15 +392,9 @@ def mock_update(url, headers=None): else: raise Exception('URL does not match') - - request = CLIENT._request( + request = http_request( response.request.method, response.request.url, - None, # params - {}, # request has no headers - None, # Request has no body - None, # form_content - None # stream_content ) return PipelineResponse( @@ -425,11 +436,14 @@ def mock_deserialization_no_body(pipeline_response): """ return None - def test_long_running_put(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_long_running_put(self, http_request): #TODO: Test custom header field # Test throw on non LRO related status code - response = TestBasePolling.mock_send('PUT', 1000, {}) + response = TestBasePolling.mock_send( + http_request, 'PUT', 1000, {}) + CLIENT.http_request_type = http_request with pytest.raises(HttpResponseError): LROPoller(CLIENT, response, TestBasePolling.mock_outputs, @@ -441,6 +455,7 @@ def test_long_running_put(self): 'name': TEST_NAME } response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {}, response_body ) @@ -455,6 +470,7 @@ def no_update_allowed(url, headers=None): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'operation-location': ASYNC_URL}) poll = LROPoller(CLIENT, response, @@ -465,6 +481,7 @@ def no_update_allowed(url, headers=None): # Test polling location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -476,6 +493,7 @@ def no_update_allowed(url, headers=None): # Test polling initial payload invalid (SQLDb) response_body = {} # Empty will raise response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': LOCATION_URL}, response_body) poll = LROPoller(CLIENT, response, @@ -486,6 +504,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -495,6 +514,7 @@ def no_update_allowed(url, headers=None): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'PUT', 201, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -502,10 +522,12 @@ def no_update_allowed(url, headers=None): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_patch(self): - + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_long_running_patch(self, http_request): + CLIENT.http_request_type = http_request # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -517,6 +539,7 @@ def test_long_running_patch(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -528,6 +551,7 @@ def test_long_running_patch(self): # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 200, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -539,6 +563,7 @@ def test_long_running_patch(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 200, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -550,6 +575,7 @@ def test_long_running_patch(self): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -559,6 +585,7 @@ def test_long_running_patch(self): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'PATCH', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -566,27 +593,33 @@ def test_long_running_patch(self): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_delete(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_long_running_delete(self, http_request): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'DELETE', 202, {'operation-location': ASYNC_URL}, body="" ) + CLIENT.http_request_type = http_request poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) poll.wait() assert poll._polling_method._pipeline_response.http_response.internal_response.randomFieldFromPollAsyncOpHeader is None - def test_long_running_post_legacy(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_long_running_post_legacy(self, http_request): # Former oooooold tests to refactor one day to something more readble # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 201, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) + CLIENT.http_request_type = http_request poll = LROPoller(CLIENT, response, TestBasePolling.mock_deserialization_no_body, LROBasePolling(0)) @@ -595,6 +628,7 @@ def test_long_running_post_legacy(self): # Test polling from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'operation-location': ASYNC_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -606,6 +640,7 @@ def test_long_running_post_legacy(self): # Test polling from location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}, body={'properties':{'provisioningState': 'Succeeded'}}) @@ -617,6 +652,7 @@ def test_long_running_post_legacy(self): # Test fail to poll from operation-location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'operation-location': ERROR}) with pytest.raises(BadEndpointError): @@ -626,6 +662,7 @@ def test_long_running_post_legacy(self): # Test fail to poll from location header response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': ERROR}) with pytest.raises(BadEndpointError): @@ -633,13 +670,15 @@ def test_long_running_post_legacy(self): TestBasePolling.mock_outputs, LROBasePolling(0)).result() - def test_long_running_negative(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_long_running_negative(self, http_request): global LOCATION_BODY global POLLING_STATUS - + CLIENT.http_request_type = http_request # Test LRO PUT throws for invalid json LOCATION_BODY = '{' response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller( @@ -653,6 +692,7 @@ def test_long_running_negative(self): LOCATION_BODY = '{\'"}' response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -664,6 +704,7 @@ def test_long_running_negative(self): LOCATION_BODY = '{' POLLING_STATUS = 203 response = TestBasePolling.mock_send( + http_request, 'POST', 202, {'location': LOCATION_URL}) poll = LROPoller(CLIENT, response, @@ -675,4 +716,3 @@ def test_long_running_negative(self): LOCATION_BODY = json.dumps({ 'name': TEST_NAME }) POLLING_STATUS = 200 - diff --git a/sdk/core/azure-core/tests/test_basic_transport.py b/sdk/core/azure-core/tests/test_basic_transport.py index 7c57f53dfeee..4b4fd280dc0c 100644 --- a/sdk/core/azure-core/tests/test_basic_transport.py +++ b/sdk/core/azure-core/tests/test_basic_transport.py @@ -12,13 +12,14 @@ except ImportError: import mock -from azure.core.pipeline.transport import HttpRequest, HttpResponse, RequestsTransport +from azure.core.pipeline.transport import HttpResponse, RequestsTransport from azure.core.pipeline.transport._base import HttpClientTransportResponse, HttpTransport, _deserialize_response, _urljoin from azure.core.pipeline.policies import HeadersPolicy from azure.core.pipeline import Pipeline from azure.core.exceptions import HttpResponseError import logging import pytest +from utils import HTTP_REQUESTS class MockResponse(HttpResponse): @@ -31,9 +32,10 @@ def body(self): return self._body @pytest.mark.skipif(sys.version_info < (3, 6), reason="Multipart serialization not supported on 2.7 + dict order not deterministic on 3.5") -def test_http_request_serialization(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_request_serialization(http_request): # Method + Url - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") serialized = request.serialize() expected = ( @@ -44,7 +46,7 @@ def test_http_request_serialization(): assert serialized == expected # Method + Url + Headers - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", # Use OrderedDict to get consistent test result on 3.5 where order is not guaranteed @@ -67,7 +69,7 @@ def test_http_request_serialization(): # Method + Url + Headers + Body - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", headers={ @@ -87,16 +89,18 @@ def test_http_request_serialization(): assert serialized == expected -def test_url_join(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_url_join(http_request): assert _urljoin('devstoreaccount1', '') == 'devstoreaccount1/' assert _urljoin('devstoreaccount1', 'testdir/') == 'devstoreaccount1/testdir/' assert _urljoin('devstoreaccount1/', '') == 'devstoreaccount1/' assert _urljoin('devstoreaccount1/', 'testdir/') == 'devstoreaccount1/testdir/' -def test_http_client_response(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_client_response(port, http_request): # Create a core request - request = HttpRequest("GET", "http://localhost:{}".format(port)) + request = http_request("GET", "http://localhost:{}".format(port)) # Fake a transport based on http.client conn = HTTPConnection("localhost", port) @@ -115,10 +119,11 @@ def test_http_client_response(port): assert "Content-Type" in response.headers -def test_response_deserialization(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_deserialization(http_request): # Method + Url - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") body = ( b'HTTP/1.1 202 Accepted\r\n' b'x-ms-request-id: 778fdc83-801e-0000-62ff-0334671e284f\r\n' @@ -135,7 +140,7 @@ def test_response_deserialization(): } # Method + Url + Headers + Body - request = HttpRequest( + request = http_request( "DELETE", "/container0/blob0", headers={ @@ -161,9 +166,10 @@ def test_response_deserialization(): } assert response.text() == "I am groot" -def test_response_deserialization_utf8_bom(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_deserialization_utf8_bom(http_request): - request = HttpRequest("DELETE", "/container0/blob0") + request = http_request("DELETE", "/container0/blob0") body = ( b'HTTP/1.1 400 One of the request inputs is not valid.\r\n' b'x-ms-error-code: InvalidInput\r\n' @@ -181,7 +187,8 @@ def test_response_deserialization_utf8_bom(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -189,10 +196,10 @@ def test_multipart_send(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -227,17 +234,18 @@ def test_multipart_send(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_context(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_context(http_request): transport = mock.MagicMock(spec=HttpTransport) header_policy = HeadersPolicy({ 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -275,7 +283,8 @@ def test_multipart_send_with_context(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_one_changeset(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_one_changeset(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -284,18 +293,18 @@ def test_multipart_send_with_one_changeset(): }) requests = [ - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ] - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( *requests, policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525", @@ -333,7 +342,8 @@ def test_multipart_send_with_one_changeset(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_multiple_changesets(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -341,22 +351,22 @@ def test_multipart_send_with_multiple_changesets(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset1 = HttpRequest("", "") + changeset1 = http_request("", "") changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - changeset2 = HttpRequest("", "") + changeset2 = http_request("", "") changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3"), + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3"), policies=[header_policy], boundary="changeset_8b9e487e-a353-4dcb-a6f4-0688191e0314" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset1, changeset2, @@ -419,7 +429,8 @@ def test_multipart_send_with_multiple_changesets(): @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_first(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -427,17 +438,17 @@ def test_multipart_send_with_combination_changeset_first(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -482,7 +493,8 @@ def test_multipart_send_with_combination_changeset_first(): ) @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_last(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -490,16 +502,16 @@ def test_multipart_send_with_combination_changeset_last(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" @@ -545,7 +557,8 @@ def test_multipart_send_with_combination_changeset_last(): ) @pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") -def test_multipart_send_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_send_with_combination_changeset_middle(http_request): transport = mock.MagicMock(spec=HttpTransport) @@ -553,17 +566,17 @@ def test_multipart_send_with_combination_changeset_middle(): 'x-ms-date': 'Thu, 14 Jun 2018 16:46:54 GMT' }) - changeset = HttpRequest("", "") + changeset = http_request("", "") changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), + http_request("DELETE", "/container1/blob1"), policies=[header_policy], boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2"), + http_request("DELETE", "/container2/blob2"), policies=[header_policy], boundary="batch_357de4f7-6d0b-4e02-8cd2-6361411a9525" ) @@ -608,17 +621,18 @@ def test_multipart_send_with_combination_changeset_middle(): ) -def test_multipart_receive(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive(http_request): class ResponsePolicy(object): def on_response(self, request, response): # type: (PipelineRequest, PipelineResponse) -> None response.http_response.headers['x-ms-fun'] = 'true' - req0 = HttpRequest("DELETE", "/container0/blob0") - req1 = HttpRequest("DELETE", "/container1/blob1") + req0 = http_request("DELETE", "/container0/blob0") + req1 = http_request("DELETE", "/container1/blob1") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( req0, req1, @@ -670,27 +684,30 @@ def on_response(self, request, response): assert res1.status_code == 404 assert res1.headers['x-ms-fun'] == 'true' -def test_raise_for_status_bad_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_raise_for_status_bad_response(http_request): response = MockResponse(request=None, body=None, content_type=None) response.status_code = 400 with pytest.raises(HttpResponseError): response.raise_for_status() -def test_raise_for_status_good_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_raise_for_status_good_response(http_request): response = MockResponse(request=None, body=None, content_type=None) response.status_code = 200 response.raise_for_status() -def test_multipart_receive_with_one_changeset(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_one_changeset(http_request): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset) body_as_bytes = ( @@ -737,20 +754,21 @@ def test_multipart_receive_with_one_changeset(): assert res0.status_code == 202 -def test_multipart_receive_with_multiple_changesets(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_multiple_changesets(http_request): - changeset1 = HttpRequest(None, None) + changeset1 = http_request(None, None) changeset1.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - changeset2 = HttpRequest(None, None) + changeset2 = http_request(None, None) changeset2.set_multipart_mixed( - HttpRequest("DELETE", "/container2/blob2"), - HttpRequest("DELETE", "/container3/blob3") + http_request("DELETE", "/container2/blob2"), + http_request("DELETE", "/container3/blob3") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(changeset1, changeset2) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -822,16 +840,17 @@ def test_multipart_receive_with_multiple_changesets(): assert parts[3].status_code == 409 -def test_multipart_receive_with_combination_changeset_first(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_combination_changeset_first(http_request): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), - HttpRequest("DELETE", "/container1/blob1") + http_request("DELETE", "/container0/blob0"), + http_request("DELETE", "/container1/blob1") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(changeset, HttpRequest("DELETE", "/container2/blob2")) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(changeset, http_request("DELETE", "/container2/blob2")) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' b'Content-Type: multipart/mixed; boundary="changeset_357de4f7-6d0b-4e02-8cd2-6361411a9525"\r\n' @@ -886,16 +905,17 @@ def test_multipart_receive_with_combination_changeset_first(): assert parts[2].status_code == 404 -def test_multipart_receive_with_combination_changeset_middle(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_combination_changeset_middle(http_request): - changeset = HttpRequest(None, None) - changeset.set_multipart_mixed(HttpRequest("DELETE", "/container1/blob1")) + changeset = http_request(None, None) + changeset.set_multipart_mixed(http_request("DELETE", "/container1/blob1")) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed( - HttpRequest("DELETE", "/container0/blob0"), + http_request("DELETE", "/container0/blob0"), changeset, - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container2/blob2") ) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -951,16 +971,17 @@ def test_multipart_receive_with_combination_changeset_middle(): assert parts[2].status_code == 404 -def test_multipart_receive_with_combination_changeset_last(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_combination_changeset_last(http_request): - changeset = HttpRequest(None, None) + changeset = http_request(None, None) changeset.set_multipart_mixed( - HttpRequest("DELETE", "/container1/blob1"), - HttpRequest("DELETE", "/container2/blob2") + http_request("DELETE", "/container1/blob1"), + http_request("DELETE", "/container2/blob2") ) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") - request.set_multipart_mixed(HttpRequest("DELETE", "/container0/blob0"), changeset) + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") + request.set_multipart_mixed(http_request("DELETE", "/container0/blob0"), changeset) body_as_bytes = ( b'--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n' @@ -1016,11 +1037,12 @@ def test_multipart_receive_with_combination_changeset_last(): assert parts[2].status_code == 404 -def test_multipart_receive_with_bom(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_multipart_receive_with_bom(http_request): - req0 = HttpRequest("DELETE", "/container0/blob0") + req0 = http_request("DELETE", "/container0/blob0") - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) body_as_bytes = ( b"--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\n" @@ -1052,12 +1074,13 @@ def test_multipart_receive_with_bom(): assert res0.body().startswith(b'\xef\xbb\xbf') -def test_recursive_multipart_receive(): - req0 = HttpRequest("DELETE", "/container0/blob0") - internal_req0 = HttpRequest("DELETE", "/container0/blob0") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_recursive_multipart_receive(http_request): + req0 = http_request("DELETE", "/container0/blob0") + internal_req0 = http_request("DELETE", "/container0/blob0") req0.set_multipart_mixed(internal_req0) - request = HttpRequest("POST", "http://account.blob.core.windows.net/?comp=batch") + request = http_request("POST", "http://account.blob.core.windows.net/?comp=batch") request.set_multipart_mixed(req0) internal_body_as_str = ( "--batchresponse_66925647-d0cb-4109-b6d3-28efe3e1e5ed\r\n" @@ -1107,10 +1130,11 @@ def test_close_unopened_transport(): transport.close() -def test_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with caplog.at_level(logging.WARNING, logger="azure.core.pipeline.transport"): with Pipeline(transport) as pipeline: @@ -1119,10 +1143,11 @@ def test_timeout(caplog, port): assert "Tuple timeout setting is deprecated" not in caplog.text -def test_tuple_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_tuple_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with caplog.at_level(logging.WARNING, logger="azure.core.pipeline.transport"): with Pipeline(transport) as pipeline: @@ -1131,10 +1156,11 @@ def test_tuple_timeout(caplog, port): assert "Tuple timeout setting is deprecated" in caplog.text -def test_conflict_timeout(caplog, port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_conflict_timeout(caplog, port, http_request): transport = RequestsTransport() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) with pytest.raises(ValueError): with Pipeline(transport) as pipeline: diff --git a/sdk/core/azure-core/tests/test_custom_hook_policy.py b/sdk/core/azure-core/tests/test_custom_hook_policy.py index e553ca5a9811..8ff38e0c203f 100644 --- a/sdk/core/azure-core/tests/test_custom_hook_policy.py +++ b/sdk/core/azure-core/tests/test_custom_hook_policy.py @@ -11,8 +11,10 @@ from azure.core.pipeline.policies import CustomHookPolicy, UserAgentPolicy from azure.core.pipeline.transport import HttpTransport import pytest +from utils import HTTP_REQUESTS -def test_response_hook_policy_in_init(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_init(http_request): def test_callback(response): raise ValueError() @@ -24,11 +26,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) -def test_response_hook_policy_in_request(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_request(http_request): def test_callback(response): raise ValueError() @@ -40,11 +43,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_response_hook=test_callback) -def test_response_hook_policy_in_both(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_response_hook_policy_in_both(http_request): def test_callback(response): raise ValueError() @@ -59,11 +63,12 @@ def test_callback_request(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(TypeError): client._pipeline.run(request, raw_response_hook=test_callback_request) -def test_request_hook_policy_in_init(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_init(http_request): def test_callback(response): raise ValueError() @@ -75,11 +80,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request) -def test_request_hook_policy_in_request(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_request(http_request): def test_callback(response): raise ValueError() @@ -91,11 +97,12 @@ def test_callback(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(ValueError): client._pipeline.run(request, raw_request_hook=test_callback) -def test_request_hook_policy_in_both(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_hook_policy_in_both(http_request): def test_callback(response): raise ValueError() @@ -110,6 +117,6 @@ def test_callback_request(response): custom_hook_policy ] client = PipelineClient(base_url=url, policies=policies, transport=transport) - request = client.get(url) + request = http_request("GET", url) with pytest.raises(TypeError): client._pipeline.run(request, raw_request_hook=test_callback_request) diff --git a/sdk/core/azure-core/tests/test_error_map.py b/sdk/core/azure-core/tests/test_error_map.py index 6fa355b490ab..f3f4e4c848ae 100644 --- a/sdk/core/azure-core/tests/test_error_map.py +++ b/sdk/core/azure-core/tests/test_error_map.py @@ -31,12 +31,13 @@ ErrorMap, ) from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, ) +from utils import HTTP_REQUESTS -def test_error_map(): - request = HttpRequest("GET", "") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_error_map(http_request): + request = http_request("GET", "") response = HttpResponse(request, None) error_map = { 404: ResourceNotFoundError @@ -44,8 +45,9 @@ def test_error_map(): with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -def test_error_map_no_default(): - request = HttpRequest("GET", "") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_error_map_no_default(http_request): + request = http_request("GET", "") response = HttpResponse(request, None) error_map = ErrorMap({ 404: ResourceNotFoundError @@ -53,8 +55,9 @@ def test_error_map_no_default(): with pytest.raises(ResourceNotFoundError): map_error(404, response, error_map) -def test_error_map_with_default(): - request = HttpRequest("GET", "") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_error_map_with_default(http_request): + request = http_request("GET", "") response = HttpResponse(request, None) error_map = ErrorMap({ 404: ResourceNotFoundError @@ -62,8 +65,9 @@ def test_error_map_with_default(): with pytest.raises(ResourceExistsError): map_error(401, response, error_map) -def test_only_default(): - request = HttpRequest("GET", "") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_only_default(http_request): + request = http_request("GET", "") response = HttpResponse(request, None) error_map = ErrorMap(default_error=ResourceExistsError) with pytest.raises(ResourceExistsError): diff --git a/sdk/core/azure-core/tests/test_http_logging_policy.py b/sdk/core/azure-core/tests/test_http_logging_policy.py index c69b8b5f0e7c..a745946d4c3e 100644 --- a/sdk/core/azure-core/tests/test_http_logging_policy.py +++ b/sdk/core/azure-core/tests/test_http_logging_policy.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ """Tests for the HttpLoggingPolicy.""" - +import pytest import logging import types try: @@ -16,15 +16,16 @@ PipelineContext ) from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, ) from azure.core.pipeline.policies import ( HttpLoggingPolicy, ) +from utils import HTTP_REQUESTS -def test_http_logger(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -42,7 +43,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') http_response = HttpResponse(universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None)) @@ -138,7 +139,8 @@ def emit(self, record): -def test_http_logger_operation_level(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger_operation_level(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -157,7 +159,7 @@ def emit(self, record): policy = HttpLoggingPolicy() kwargs={'logger': logger} - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') http_response = HttpResponse(universal_request, None) http_response.status_code = 202 request = PipelineRequest(universal_request, PipelineContext(None, **kwargs)) @@ -209,7 +211,8 @@ def emit(self, record): mock_handler.reset() -def test_http_logger_with_body(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger_with_body(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -227,7 +230,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') universal_request.body = "testbody" http_response = HttpResponse(universal_request, None) http_response.status_code = 202 @@ -249,7 +252,8 @@ def emit(self, record): mock_handler.reset() -def test_http_logger_with_generator_body(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_http_logger_with_generator_body(http_request): class MockHandler(logging.Handler): def __init__(self): @@ -267,7 +271,7 @@ def emit(self, record): policy = HttpLoggingPolicy(logger=logger) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') mock = Mock() mock.__class__ = types.GeneratorType universal_request.body = mock diff --git a/sdk/core/azure-core/tests/test_pipeline.py b/sdk/core/azure-core/tests/test_pipeline.py index 6d260c07e9fe..fb5a063e4f7b 100644 --- a/sdk/core/azure-core/tests/test_pipeline.py +++ b/sdk/core/azure-core/tests/test_pipeline.py @@ -50,21 +50,23 @@ ) from azure.core.pipeline.transport._base import PipelineClientBase from azure.core.pipeline.transport import ( - HttpRequest, HttpTransport, RequestsTransport, ) +from utils import HTTP_REQUESTS, is_rest from azure.core.exceptions import AzureError -def test_default_http_logging_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_default_http_logging_policy(http_request): config = Configuration() pipeline_client = PipelineClient(base_url="test") pipeline = pipeline_client._build_pipeline(config) http_logging_policy = pipeline._impl_policies[-1]._policy assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST -def test_pass_in_http_logging_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_pass_in_http_logging_policy(http_request): config = Configuration() http_logging_policy = HttpLoggingPolicy() http_logging_policy.allowed_header_names.update( @@ -78,7 +80,8 @@ def test_pass_in_http_logging_policy(): assert http_logging_policy.allowed_header_names == HttpLoggingPolicy.DEFAULT_HEADERS_WHITELIST.union({"x-ms-added-header"}) -def test_sans_io_exception(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_sans_io_exception(http_request): class BrokenSender(HttpTransport): def send(self, request, **config): raise ValueError("Broken") @@ -95,7 +98,7 @@ def __exit__(self, exc_type, exc_value, traceback): pipeline = Pipeline(BrokenSender(), [SansIOHTTPPolicy()]) - req = HttpRequest("GET", "/") + req = http_request("GET", "/") with pytest.raises(ValueError): pipeline.run(req) @@ -108,9 +111,10 @@ def on_exception(self, requests, **kwargs): with pytest.raises(NotImplementedError): pipeline.run(req) -def test_requests_socket_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_socket_timeout(http_request): conf = Configuration() - request = HttpRequest("GET", "https://bing.com") + request = http_request("GET", "https://bing.com") policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -179,26 +183,29 @@ def test_format_incorrect_endpoint(): client.format_url("foo/bar") assert str(exp.value) == "The value provided for the url part Endpoint was incorrect, and resulted in an invalid url" -def test_request_json(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_json(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") data = "Lots of dataaaa" request.set_json_body(data) assert request.data == json.dumps(data) assert request.headers.get("Content-Length") == "17" -def test_request_data(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_data(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") data = "Lots of dataaaa" request.set_bytes_body(data) assert request.data == data assert request.headers.get("Content-Length") == "15" -def test_request_stream(): - request = HttpRequest("GET", "/") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_stream(http_request): + request = http_request("GET", "/") data = b"Lots of dataaaa" request.set_streamed_data_body(data) @@ -216,45 +223,51 @@ def data_gen(): assert request.data == data -def test_request_xml(): - request = HttpRequest("GET", "/") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_xml(http_request): + request = http_request("GET", "/") data = ET.Element("root") request.set_xml_body(data) assert request.data == b"\n" -def test_request_url_with_params(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" request.format_parameters({"g": "h"}) assert request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] -def test_request_url_with_params_as_list(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_as_list(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" request.format_parameters({"g": ["h","i"]}) assert request.url in ["a/b/c?g=h&g=i&t=y", "a/b/c?t=y&g=h&g=i"] -def test_request_url_with_params_with_none_in_list(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_with_none_in_list(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" with pytest.raises(ValueError): request.format_parameters({"g": ["h",None]}) -def test_request_url_with_params_with_none(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_url_with_params_with_none(http_request): - request = HttpRequest("GET", "/") + request = http_request("GET", "/") request.url = "a/b/c?t=y" with pytest.raises(ValueError): request.format_parameters({"g": None}) -def test_repr(): - request = HttpRequest("GET", "hello.com") +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_repr(http_request): + request = http_request("GET", "hello.com") assert repr(request) == "" def test_add_custom_policy(): @@ -355,10 +368,11 @@ def send(*args): with pytest.raises(ValueError): client = PipelineClient(base_url="test", policies=policies, per_retry_policies=[foo_policy]) -def test_basic_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_requests(port, http_request): conf = Configuration() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -369,9 +383,10 @@ def test_basic_requests(port): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) -def test_basic_options_requests(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_options_requests(port, http_request): - request = HttpRequest("OPTIONS", "http://localhost:{}/basic/string".format(port)) + request = http_request("OPTIONS", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -382,10 +397,11 @@ def test_basic_options_requests(port): assert pipeline._transport.session is None assert isinstance(response.http_response.status_code, int) -def test_basic_requests_separate_session(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_basic_requests_separate_session(port, http_request): session = requests.Session() - request = HttpRequest("GET", "http://localhost:{}/basic/string".format(port)) + request = http_request("GET", "http://localhost:{}/basic/string".format(port)) policies = [ UserAgentPolicy("myusergant"), RedirectPolicy() @@ -400,21 +416,28 @@ def test_basic_requests_separate_session(port): assert transport.session transport.session.close() -def test_request_text(port): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_text(port, http_request): client = PipelineClientBase("http://localhost:{}".format(port)) - request = client.get( - "/", - content="foo" - ) + if is_rest(http_request): + request = http_request("GET", "/", json="foo") + else: + request = client.get( + "/", + content="foo" + ) # In absence of information, everything is JSON (double quote added) assert request.data == json.dumps("foo") - request = client.post( - "/", - headers={'content-type': 'text/whatever'}, - content="foo" - ) + if is_rest(http_request): + request = http_request("POST", "/", headers={'content-type': 'text/whatever'}, content="foo") + else: + request = client.post( + "/", + headers={'content-type': 'text/whatever'}, + content="foo" + ) # We want a direct string assert request.data == "foo" diff --git a/sdk/core/azure-core/tests/test_request_id_policy.py b/sdk/core/azure-core/tests/test_request_id_policy.py index 7da467b467c2..44c8647982cb 100644 --- a/sdk/core/azure-core/tests/test_request_id_policy.py +++ b/sdk/core/azure-core/tests/test_request_id_policy.py @@ -4,7 +4,6 @@ # ------------------------------------ """Tests for the request id policy.""" from azure.core.pipeline.policies import RequestIdPolicy -from azure.core.pipeline.transport import HttpRequest from azure.core.pipeline import PipelineRequest, PipelineContext try: from unittest import mock @@ -12,15 +11,16 @@ import mock from itertools import product import pytest +from utils import HTTP_REQUESTS auto_request_id_values = (True, False, None) request_id_init_values = ("foo", None, "_unset") request_id_set_values = ("bar", None, "_unset") request_id_req_values = ("baz", None, "_unset") -full_combination = list(product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values)) +full_combination = list(product(auto_request_id_values, request_id_init_values, request_id_set_values, request_id_req_values, HTTP_REQUESTS)) -@pytest.mark.parametrize("auto_request_id, request_id_init, request_id_set, request_id_req", full_combination) -def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req): +@pytest.mark.parametrize("auto_request_id, request_id_init, request_id_set, request_id_req, http_request", full_combination) +def test_request_id_policy(auto_request_id, request_id_init, request_id_set, request_id_req, http_request): """Test policy with no other policy and happy path""" kwargs = {} if auto_request_id is not None: @@ -30,7 +30,7 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req request_id_policy = RequestIdPolicy(**kwargs) if request_id_set != "_unset": request_id_policy.set_request_id(request_id_set) - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') pipeline_request = PipelineRequest(request, PipelineContext(None)) if request_id_req != "_unset": pipeline_request.context.options['request_id'] = request_id_req @@ -55,10 +55,11 @@ def test_request_id_policy(auto_request_id, request_id_init, request_id_set, req else: assert not "x-ms-client-request-id" in request.headers -def test_request_id_already_exists(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_id_already_exists(http_request): """Test policy with no other policy and happy path""" request_id_policy = RequestIdPolicy() - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') request.headers["x-ms-client-request-id"] = "VALUE" pipeline_request = PipelineRequest(request, PipelineContext(None)) request_id_policy.on_request(pipeline_request) diff --git a/sdk/core/azure-core/tests/test_requests_universal.py b/sdk/core/azure-core/tests/test_requests_universal.py index b965d4fd3d69..891a364cfd61 100644 --- a/sdk/core/azure-core/tests/test_requests_universal.py +++ b/sdk/core/azure-core/tests/test_requests_universal.py @@ -25,8 +25,9 @@ # -------------------------------------------------------------------------- import concurrent.futures import requests.utils - -from azure.core.pipeline.transport import HttpRequest, RequestsTransport, RequestsTransportResponse +import pytest +from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse +from utils import HTTP_REQUESTS def test_threading_basic_requests(): @@ -44,8 +45,9 @@ def thread_body(local_sender): future = executor.submit(thread_body, sender) assert future.result() -def test_requests_auto_headers(port): - request = HttpRequest("POST", "http://localhost:{}/basic/string".format(port)) +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_requests_auto_headers(port, http_request): + request = http_request("POST", "http://localhost:{}/basic/string".format(port)) with RequestsTransport() as sender: response = sender.send(request) auto_headers = response.internal_response.request.headers diff --git a/sdk/core/azure-core/tests/test_rest_backcompat.py b/sdk/core/azure-core/tests/test_rest_backcompat.py new file mode 100644 index 000000000000..56998caec3e5 --- /dev/null +++ b/sdk/core/azure-core/tests/test_rest_backcompat.py @@ -0,0 +1,403 @@ +# -*- coding: utf-8 -*- +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import sys +import pytest +import json +import xml.etree.ElementTree as ET +from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest +from azure.core.rest import HttpRequest as RestHttpRequest +try: + import collections.abc as collections +except ImportError: + import collections + +@pytest.fixture +def old_request(): + return PipelineTransportHttpRequest("GET", "/") + +@pytest.fixture +def new_request(): + return RestHttpRequest("GET", "/") + +def test_request_attr_parity(old_request, new_request): + for attr in dir(old_request): + if not attr[0] == "_": + # if not a private attr, we want parity + assert hasattr(new_request, attr) + +def test_request_set_attrs(old_request, new_request): + for attr in dir(old_request): + if attr[0] == "_": + continue + try: + # if we can set it on the old request, we want to + # be able to set it on the new + setattr(old_request, attr, "foo") + except: + pass + else: + setattr(new_request, attr, "foo") + assert getattr(old_request, attr) == getattr(new_request, attr) == "foo" + +def test_request_multipart_mixed_info(old_request, new_request): + old_request.multipart_mixed_info = "foo" + new_request.multipart_mixed_info = "foo" + assert old_request.multipart_mixed_info == new_request.multipart_mixed_info == "foo" + +def test_request_files_attr(old_request, new_request): + assert old_request.files == new_request.files == None + old_request.files = {"hello": "world"} + new_request.files = {"hello": "world"} + assert old_request.files == new_request.files == {"hello": "world"} + +def test_request_data_attr(old_request, new_request): + assert old_request.data == new_request.data == None + old_request.data = {"hello": "world"} + new_request.data = {"hello": "world"} + assert old_request.data == new_request.data == {"hello": "world"} + +def test_request_query(old_request, new_request): + assert old_request.query == new_request.query == {} + old_request.url = "http://localhost:5000?a=b&c=d" + new_request.url = "http://localhost:5000?a=b&c=d" + assert old_request.query == new_request.query == {'a': 'b', 'c': 'd'} + +def test_request_query_and_params_kwarg(old_request): + # should be same behavior if we pass in query params through the params kwarg in the new requests + old_request.url = "http://localhost:5000?a=b&c=d" + new_request = RestHttpRequest("GET", "http://localhost:5000", params={'a': 'b', 'c': 'd'}) + assert old_request.query == new_request.query == {'a': 'b', 'c': 'd'} + +def test_request_body(old_request, new_request): + assert old_request.body == new_request.body == None + old_request.data = {"hello": "world"} + new_request.data = {"hello": "world"} + assert ( + old_request.body == + new_request.body == + new_request.content == + {"hello": "world"} + ) + # files will not override data + old_request.files = {"foo": "bar"} + new_request.files = {"foo": "bar"} + assert ( + old_request.body == + new_request.body == + new_request.content == + {"hello": "world"} + ) + + # nullify data + old_request.data = None + new_request.data = None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + None + ) + +def test_format_parameters(old_request, new_request): + old_request.url = "a/b/c?t=y" + new_request.url = "a/b/c?t=y" + assert old_request.url == new_request.url == "a/b/c?t=y" + old_request.format_parameters({"g": "h"}) + new_request.format_parameters({"g": "h"}) + + # ordering can vary, so not sticking on order + assert old_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + +def test_request_format_parameters_and_params_kwarg(old_request): + # calling format_parameters on an old request should be the same + # behavior as passing in params to new request + old_request.url = "a/b/c?t=y" + old_request.format_parameters({"g": "h"}) + new_request = RestHttpRequest( + "GET", "a/b/c?t=y", params={"g": "h"} + ) + assert old_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + + # additionally, calling format_parameters on a new request + # should be the same as passing the params to a new request + assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + assert new_request.url in ["a/b/c?g=h&t=y", "a/b/c?t=y&g=h"] + +def test_request_streamed_data_body(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.files = new_request.files = "foo" + # passing in iterable + def streaming_body(data): + yield data # pragma: nocover + old_request.set_streamed_data_body(streaming_body("i will be streamed")) + new_request.set_streamed_data_body(streaming_body("i will be streamed")) + + assert old_request.files == new_request.files == None + assert isinstance(old_request.data, collections.Iterable) + assert isinstance(new_request.data, collections.Iterable) + assert isinstance(old_request.body, collections.Iterable) + assert isinstance(new_request.body, collections.Iterable) + assert isinstance(new_request.content, collections.Iterable) + assert old_request.headers == new_request.headers == {} + +def test_request_streamed_data_body_non_iterable(old_request, new_request): + # should fail before nullifying the files property + old_request.files = new_request.files = "foo" + # passing in non iterable + with pytest.raises(TypeError) as ex: + old_request.set_streamed_data_body(1) + assert "A streamable data source must be an open file-like object or iterable" in str(ex.value) + assert old_request.data is None + assert old_request.files == "foo" + + with pytest.raises(TypeError) as ex: + new_request.set_streamed_data_body(1) + assert "A streamable data source must be an open file-like object or iterable" in str(ex.value) + assert old_request.data is None + assert old_request.files == "foo" + assert old_request.headers == new_request.headers == {} + +def test_request_streamed_data_body_and_content_kwarg(old_request): + # passing stream bodies to set_streamed_data_body + # and passing a stream body to the content kwarg of the new request should be the same + def streaming_body(data): + yield data # pragma: nocover + old_request.set_streamed_data_body(streaming_body("stream")) + new_request = RestHttpRequest("GET", "/", content=streaming_body("stream")) + assert old_request.files == new_request.files == None + assert isinstance(old_request.data, collections.Iterable) + assert isinstance(new_request.data, collections.Iterable) + assert isinstance(old_request.body, collections.Iterable) + assert isinstance(new_request.body, collections.Iterable) + assert isinstance(new_request.content, collections.Iterable) + assert old_request.headers == new_request.headers == {} + +def test_request_text_body(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.files = new_request.files = "foo" + old_request.set_text_body("i am text") + new_request.set_text_body("i am text") + + assert old_request.files == new_request.files == None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + "i am text" + ) + assert old_request.headers['Content-Length'] == new_request.headers['Content-Length'] == '9' + assert not old_request.headers.get("Content-Type") + assert new_request.headers["Content-Type"] == "text/plain" + +def test_request_text_body_and_content_kwarg(old_request): + old_request.set_text_body("i am text") + new_request = RestHttpRequest("GET", "/", content="i am text") + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + "i am text" + ) + assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "9" + assert old_request.files == new_request.files == None + +def test_request_xml_body(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.files = new_request.files = "foo" + xml_input = ET.Element("root") + old_request.set_xml_body(xml_input) + new_request.set_xml_body(xml_input) + + assert old_request.files == new_request.files == None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + b"\n" + ) + assert old_request.headers == new_request.headers == {'Content-Length': '47'} + +def test_request_xml_body_and_content_kwarg(old_request): + old_request.set_text_body("i am text") + new_request = RestHttpRequest("GET", "/", content="i am text") + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + "i am text" + ) + assert old_request.headers["Content-Length"] == new_request.headers["Content-Length"] == "9" + assert old_request.files == new_request.files == None + +def test_request_json_body(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.files = new_request.files = "foo" + json_input = {"hello": "world"} + old_request.set_json_body(json_input) + new_request.set_json_body(json_input) + + assert old_request.files == new_request.files == None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + json.dumps(json_input) + ) + assert old_request.headers["Content-Length"] == new_request.headers['Content-Length'] == '18' + assert not old_request.headers.get("Content-Type") + assert new_request.headers["Content-Type"] == "application/json" + +def test_request_json_body_and_json_kwarg(old_request): + json_input = {"hello": "world"} + old_request.set_json_body(json_input) + new_request = RestHttpRequest("GET", "/", json=json_input) + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + json.dumps(json_input) + ) + assert old_request.headers["Content-Length"] == new_request.headers['Content-Length'] == '18' + assert not old_request.headers.get("Content-Type") + assert new_request.headers["Content-Type"] == "application/json" + assert old_request.files == new_request.files == None + +def test_request_formdata_body_files(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.data = new_request.data = "foo" + old_request.files = new_request.files = "bar" + + # without the urlencoded content type, set_formdata_body + # will set it as files + old_request.set_formdata_body({"fileName": "hello.jpg"}) + new_request.set_formdata_body({"fileName": "hello.jpg"}) + + assert old_request.data == new_request.data == None + assert ( + old_request.files == + new_request.files == + new_request.content == + {'fileName': (None, 'hello.jpg')} + ) + + # we don't set any multipart headers with boundaries + # we rely on the transport to boundary calculating + assert old_request.headers == new_request.headers == {} + +def test_request_formdata_body_data(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.data = new_request.data = "foo" + old_request.files = new_request.files = "bar" + + # with the urlencoded content type, set_formdata_body + # will set it as data + old_request.headers["Content-Type"] = "application/x-www-form-urlencoded" + new_request.headers["Content-Type"] = "application/x-www-form-urlencoded" + old_request.set_formdata_body({"fileName": "hello.jpg"}) + new_request.set_formdata_body({"fileName": "hello.jpg"}) + + assert old_request.files == new_request.files == None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + {"fileName": "hello.jpg"} + ) + # old behavior would pop out the Content-Type header + # new behavior doesn't do that + assert old_request.headers == {} + assert new_request.headers == {'Content-Type': "application/x-www-form-urlencoded"} + +def test_request_formdata_body_and_files_kwarg(old_request): + files = {"fileName": "hello.jpg"} + old_request.set_formdata_body(files) + new_request = RestHttpRequest("GET", "/", files=files) + assert old_request.data == new_request.data == None + assert old_request.body == new_request.body == None + assert old_request.headers == new_request.headers == {} + assert old_request.files == new_request.files == {'fileName': (None, 'hello.jpg')} + +def test_request_formdata_body_and_data_kwarg(old_request): + data = {"fileName": "hello.jpg"} + # with the urlencoded content type, set_formdata_body + # will set it as data + old_request.headers["Content-Type"] = "application/x-www-form-urlencoded" + old_request.set_formdata_body(data) + new_request = RestHttpRequest("GET", "/", data=data) + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + {"fileName": "hello.jpg"} + ) + assert old_request.headers == {} + assert new_request.headers == {"Content-Type": "application/x-www-form-urlencoded"} + assert old_request.files == new_request.files == None + +def test_request_bytes_body(old_request, new_request): + assert old_request.files == new_request.files == None + assert old_request.data == new_request.data == None + old_request.files = new_request.files = "foo" + bytes_input = b"hello, world!" + old_request.set_bytes_body(bytes_input) + new_request.set_bytes_body(bytes_input) + + assert old_request.files == new_request.files == None + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + bytes_input + ) + assert old_request.headers == new_request.headers == {'Content-Length': '13'} + +def test_request_bytes_body_and_content_kwarg(old_request): + bytes_input = b"hello, world!" + old_request.set_bytes_body(bytes_input) + new_request = RestHttpRequest("GET", "/", content=bytes_input) + assert ( + old_request.data == + new_request.data == + old_request.body == + new_request.body == + new_request.content == + bytes_input + ) + if sys.version_info < (3, 0): + # in 2.7, b'' is a string, so we're setting content-type headers + assert old_request.headers["Content-Length"] == new_request.headers['Content-Length'] == '13' + assert new_request.headers["Content-Type"] == "text/plain" + else: + assert old_request.headers == new_request.headers == {'Content-Length': '13'} + assert old_request.files == new_request.files diff --git a/sdk/core/azure-core/tests/test_rest_http_request.py b/sdk/core/azure-core/tests/test_rest_http_request.py index 2c43f2b9b0bf..a773b99b61ba 100644 --- a/sdk/core/azure-core/tests/test_rest_http_request.py +++ b/sdk/core/azure-core/tests/test_rest_http_request.py @@ -11,7 +11,15 @@ import pytest import sys import collections + +from azure.core.configuration import Configuration from azure.core.rest import HttpRequest +from azure.core.pipeline.policies import ( + CustomHookPolicy, UserAgentPolicy, SansIOHTTPPolicy, RetryPolicy +) +from utils import is_rest +from rest_client import TestRestClient +from azure.core import PipelineClient @pytest.fixture def assert_iterator_body(): @@ -284,6 +292,113 @@ def test_use_custom_json_encoder(): request = HttpRequest("GET", "/headers", json=bytearray("mybytes", "utf-8")) assert request.content == '"bXlieXRlcw=="' +def test_request_policies_raw_request_hook(port): + # test that the request all the way through the pipeline is a new request + request = HttpRequest("GET", "/headers") + def callback(request): + assert is_rest(request.http_request) + raise ValueError("I entered the callback!") + custom_hook_policy = CustomHookPolicy(raw_request_hook=callback) + policies = [ + UserAgentPolicy("myuseragent"), + custom_hook_policy + ] + client = TestRestClient(port=port, policies=policies) + + with pytest.raises(ValueError) as ex: + client.send_request(request) + assert "I entered the callback!" in str(ex.value) + +@pytest.mark.skipif(sys.version_info < (3, 0), reason="Multipart serialization not supported on 2.7") +def test_request_policies_chain(port): + class OldPolicyModifyBody(SansIOHTTPPolicy): + def on_request(self, request): + assert is_rest(request.http_request) # first make sure this is a new request + # deals with request like an old request + request.http_request.set_json_body({"hello": "world"}) + + class NewPolicyModifyHeaders(SansIOHTTPPolicy): + def on_request(self, request): + assert is_rest(request.http_request) + assert request.http_request.content == '{"hello": "world"}' + + # modify header to know we entered this callback + request.http_request.headers = { + "x-ms-date": "Thu, 14 Jun 2018 16:46:54 GMT", + "Authorization": "SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=", + "Content-Length": "0", + } + + class OldPolicySerializeRequest(SansIOHTTPPolicy): + def on_request(self, request): + assert is_rest(request.http_request) + # don't want to deal with content in serialize, so let's first just remove it + request.http_request.data = None + expected = ( + b'DELETE http://localhost:5000/container0/blob0 HTTP/1.1\r\n' + b'x-ms-date: Thu, 14 Jun 2018 16:46:54 GMT\r\n' + b'Authorization: SharedKey account:G4jjBXA7LI/RnWKIOQ8i9xH4p76pAQ+4Fs4R1VxasaE=\r\n' + b'Content-Length: 0\r\n' + b'\r\n' + ) + assert request.http_request.serialize() == expected + raise ValueError("Passed through the policies!") + + policies = [ + OldPolicyModifyBody(), + NewPolicyModifyHeaders(), + OldPolicySerializeRequest(), + ] + request = HttpRequest("DELETE", "/container0/blob0") + client = TestRestClient(port="5000", policies=policies) + with pytest.raises(ValueError) as ex: + client.send_request( + request, + content="I should be overriden", + ) + assert "Passed through the policies!" in str(ex.value) + + +def test_per_call_policies_old_then_new(port): + config = Configuration() + retry_policy = RetryPolicy() + config.retry_policy = retry_policy + + class OldPolicy(SansIOHTTPPolicy): + """A policy that deals with a rest request thinking that it's an old request""" + + def on_request(self, pipeline_request): + request = pipeline_request.http_request + assert is_rest(request) + assert request.body == '{"hello": "world"}' # old request has property body + request.set_text_body("change to me!") + return pipeline_request + + class NewPolicy(SansIOHTTPPolicy): + + def on_request(self, pipeline_request): + request = pipeline_request.http_request + assert is_rest(request) + assert request.content == 'change to me!' # new request has property content + raise ValueError("I entered the policies!") + + pipeline_client = PipelineClient( + base_url="http://localhost:{}".format(port), + config=config, + per_call_policies=[OldPolicy(), NewPolicy()] + ) + client = TestRestClient(port=port) + client._client = pipeline_client + + with pytest.raises(ValueError) as ex: + client.send_request(HttpRequest("POST", "/basic/anything", json={"hello": "world"})) + + # since we don't have all policies set up, the call ends up failing + # but that's ok with us, we want to make sure that chaining the requests + # work + assert "I entered the policies!" in str(ex.value) + + # NOTE: For files, we don't allow list of tuples yet, just dict. Will uncomment when we add this capability # def test_multipart_multiple_files_single_input_content(): # files = [ @@ -309,4 +424,4 @@ def test_use_custom_json_encoder(): # b"\r\n", # b"--+++--\r\n", # ] -# ) \ No newline at end of file +# ) diff --git a/sdk/core/azure-core/tests/test_retry_policy.py b/sdk/core/azure-core/tests/test_retry_policy.py index ce1a590fba26..8e989aed487c 100644 --- a/sdk/core/azure-core/tests/test_retry_policy.py +++ b/sdk/core/azure-core/tests/test_retry_policy.py @@ -8,6 +8,7 @@ except ImportError: from cStringIO import StringIO as BytesIO import pytest +from itertools import product from azure.core.configuration import ConnectionConfiguration from azure.core.exceptions import ( AzureError, @@ -22,7 +23,6 @@ ) from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, HttpTransport, ) @@ -34,6 +34,7 @@ from unittest.mock import Mock except ImportError: from mock import Mock +from utils import HTTP_REQUESTS def test_retry_code_class_variables(): @@ -62,10 +63,10 @@ def test_retry_types(): backoff_time = retry_policy.get_backoff_time(settings) assert backoff_time == 4 -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) +def test_retry_after(retry_after_input, http_request): retry_policy = RetryPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") response = HttpResponse(request, None) response.headers["retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) @@ -80,10 +81,10 @@ def test_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -@pytest.mark.parametrize("retry_after_input", [('0'), ('800'), ('1000'), ('1200')]) -def test_x_ms_retry_after(retry_after_input): +@pytest.mark.parametrize("retry_after_input,http_request", product(['0', '800', '1000', '1200'], HTTP_REQUESTS)) +def test_x_ms_retry_after(retry_after_input, http_request): retry_policy = RetryPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") response = HttpResponse(request, None) response.headers["x-ms-retry-after-ms"] = retry_after_input pipeline_response = PipelineResponse(request, response, None) @@ -98,7 +99,8 @@ def test_x_ms_retry_after(retry_after_input): retry_after = retry_policy.get_retry_after(pipeline_response) assert retry_after == float(retry_after_input) -def test_retry_on_429(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_on_429(http_request): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -115,14 +117,15 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe response.status_code = 429 return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = RetryPolicy(retry_total = 1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 2 -def test_no_retry_on_201(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_no_retry_on_201(http_request): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -141,14 +144,15 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe response.headers = headers return response - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_retry = RetryPolicy(retry_total = 1) transport = MockTransport() pipeline = Pipeline(transport, [http_retry]) pipeline.run(http_request) assert transport._count == 1 -def test_retry_seekable_stream(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_seekable_stream(http_request): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -171,13 +175,14 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe return response data = BytesIO(b"Lots of dataaaa") - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') http_request.set_streamed_data_body(data) http_retry = RetryPolicy(retry_total = 1) pipeline = Pipeline(MockTransport(), [http_retry]) pipeline.run(http_request) -def test_retry_seekable_file(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_seekable_file(http_request): class MockTransport(HttpTransport): def __init__(self): self._first = True @@ -208,7 +213,7 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe file = tempfile.NamedTemporaryFile(delete=False) file.write(b'Lots of dataaaa') file.close() - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') headers = {'Content-Type': "multipart/form-data"} http_request.headers = headers with open(file.name, 'rb') as f: @@ -222,8 +227,8 @@ def send(self, request, **kwargs): # type: (PipelineRequest, Any) -> PipelineRe pipeline.run(http_request) os.unlink(f.name) - -def test_retry_timeout(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_timeout(http_request): timeout = 1 def send(request, **kwargs): @@ -239,10 +244,10 @@ def send(request, **kwargs): pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)]) with pytest.raises(ServiceResponseTimeoutError): - response = pipeline.run(HttpRequest("GET", "http://localhost/")) - + response = pipeline.run(http_request("GET", "http://localhost/")) -def test_timeout_defaults(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_timeout_defaults(http_request): """When "timeout" is not set, the policy should not override the transport's timeout configuration""" def send(request, **kwargs): @@ -259,17 +264,19 @@ def send(request, **kwargs): ) pipeline = Pipeline(transport, [RetryPolicy()]) - pipeline.run(HttpRequest("GET", "http://localhost/")) + pipeline.run(http_request("GET", "http://localhost/")) assert transport.send.call_count == 1, "policy should not retry: its first send succeeded" +combinations = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)] @pytest.mark.parametrize( - "transport_error,expected_timeout_error", - ((ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)), + "combinations,http_request", + product(combinations, HTTP_REQUESTS), ) -def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): +def test_does_not_sleep_after_timeout(combinations, http_request): # With default settings policy will sleep twice before exhausting its retries: 1.6s, 3.2s. # It should not sleep the second time when given timeout=1 + transport_error,expected_timeout_error = combinations timeout = 1 transport = Mock( @@ -280,6 +287,6 @@ def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error): pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)]) with pytest.raises(expected_timeout_error): - pipeline.run(HttpRequest("GET", "http://localhost/")) + pipeline.run(http_request("GET", "http://localhost/")) assert transport.sleep.call_count == 1 diff --git a/sdk/core/azure-core/tests/test_stream_generator.py b/sdk/core/azure-core/tests/test_stream_generator.py index 1d8b3c7172c4..4c9f5190e9c6 100644 --- a/sdk/core/azure-core/tests/test_stream_generator.py +++ b/sdk/core/azure-core/tests/test_stream_generator.py @@ -4,7 +4,6 @@ # ------------------------------------ import requests from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, HttpTransport, RequestsTransport, @@ -17,8 +16,10 @@ except ImportError: import mock import pytest +from utils import HTTP_REQUESTS -def test_connection_error_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_connection_error_response(http_request): class MockTransport(HttpTransport): def __init__(self): self._count = 0 @@ -31,7 +32,7 @@ def open(self): pass def send(self, request, **kwargs): - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') response = HttpResponse(request, None) response.status_code = 200 return response @@ -43,7 +44,7 @@ def __next__(self): if self._count == 0: self._count += 1 raise requests.exceptions.ConnectionError - + def stream(self, chunk_size, decode_content=False): if self._count == 0: self._count += 1 @@ -58,7 +59,7 @@ def __init__(self): def close(self): pass - http_request = HttpRequest('GET', 'http://localhost/') + http_request = http_request('GET', 'http://localhost/') pipeline = Pipeline(MockTransport()) http_response = HttpResponse(http_request, None) http_response.internal_response = MockInternalResponse() diff --git a/sdk/core/azure-core/tests/test_streaming.py b/sdk/core/azure-core/tests/test_streaming.py index 2a5e6a4d0bb8..87e51851722f 100644 --- a/sdk/core/azure-core/tests/test_streaming.py +++ b/sdk/core/azure-core/tests/test_streaming.py @@ -23,16 +23,19 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- +import pytest from azure.core import PipelineClient from azure.core.exceptions import DecodeError +from utils import HTTP_REQUESTS -def test_decompress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -40,13 +43,14 @@ def test_decompress_plain_no_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_compress_plain_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_plain_no_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -54,13 +58,14 @@ def test_compress_plain_no_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_decompress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -71,13 +76,14 @@ def test_decompress_compressed_no_header(): except UnicodeDecodeError: pass -def test_compress_compressed_no_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_compressed_no_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -88,14 +94,15 @@ def test_compress_compressed_no_header(): except UnicodeDecodeError: pass -def test_decompress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_plain_header(http_request): # expect error import requests account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -105,13 +112,14 @@ def test_decompress_plain_header(): except (requests.exceptions.ContentDecodingError, DecodeError): pass -def test_compress_plain_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_plain_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.txt".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) @@ -119,13 +127,14 @@ def test_compress_plain_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_decompress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_decompress_compressed_header(http_request): # expect plain text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=True) @@ -133,13 +142,14 @@ def test_decompress_compressed_header(): decoded = content.decode('utf-8') assert decoded == "test" -def test_compress_compressed_header(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_compress_compressed_header(http_request): # expect compressed text account_name = "coretests" account_url = "https://{}.blob.core.windows.net".format(account_name) url = "https://{}.blob.core.windows.net/tests/test_with_header.tar.gz".format(account_name) client = PipelineClient(account_url) - request = client.get(url) + request = http_request("GET", url) pipeline_response = client._pipeline.run(request, stream=True) response = pipeline_response.http_response data = response.stream_download(client._pipeline, decompress=False) diff --git a/sdk/core/azure-core/tests/test_testserver.py b/sdk/core/azure-core/tests/test_testserver.py index 0cf5e4ffc26e..544778e32a79 100644 --- a/sdk/core/azure-core/tests/test_testserver.py +++ b/sdk/core/azure-core/tests/test_testserver.py @@ -23,11 +23,14 @@ # THE SOFTWARE. # # -------------------------------------------------------------------------- -from azure.core.pipeline.transport import HttpRequest, RequestsTransport +from azure.core.pipeline.transport import RequestsTransport +from utils import HTTP_REQUESTS +import pytest """This file does a simple call to the testserver to make sure we can use the testserver""" -def test_smoke(port): - request = HttpRequest(method="GET", url="http://localhost:{}/basic/string".format(port)) +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_smoke(port, http_request): + request = http_request(method="GET", url="http://localhost:{}/basic/string".format(port)) with RequestsTransport() as sender: response = sender.send(request) response.raise_for_status() diff --git a/sdk/core/azure-core/tests/test_tracing_decorator.py b/sdk/core/azure-core/tests/test_tracing_decorator.py index be9a820747e3..4c4b91d81420 100644 --- a/sdk/core/azure-core/tests/test_tracing_decorator.py +++ b/sdk/core/azure-core/tests/test_tracing_decorator.py @@ -14,12 +14,12 @@ import pytest from azure.core.pipeline import Pipeline, PipelineResponse from azure.core.pipeline.policies import HTTPPolicy -from azure.core.pipeline.transport import HttpTransport, HttpRequest +from azure.core.pipeline.transport import HttpTransport from azure.core.settings import settings from azure.core.tracing import common from azure.core.tracing.decorator import distributed_trace from tracing_common import FakeSpan - +from utils import HTTP_REQUESTS @pytest.fixture(scope="module") def fake_span(): @@ -28,9 +28,9 @@ def fake_span(): class MockClient: @distributed_trace - def __init__(self, policies=None, assert_current_span=False): + def __init__(self, http_request, policies=None, assert_current_span=False): time.sleep(0.001) - self.request = HttpRequest("GET", "http://localhost") + self.request = http_request("GET", "http://localhost") if policies is None: policies = [] policies.append(mock.Mock(spec=HTTPPolicy, send=self.verify_request)) @@ -87,8 +87,9 @@ def random_function(): pass -def test_get_function_and_class_name(): - client = MockClient() +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_get_function_and_class_name(http_request): + client = MockClient(http_request) assert common.get_function_and_class_name(client.get_foo, client) == "MockClient.get_foo" assert common.get_function_and_class_name(random_function) == "random_function" @@ -96,9 +97,10 @@ def test_get_function_and_class_name(): @pytest.mark.usefixtures("fake_span") class TestDecorator(object): - def test_decorator_tracing_attr(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_decorator_tracing_attr(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.tracing_attr() assert len(parent.children) == 2 @@ -106,18 +108,20 @@ def test_decorator_tracing_attr(self): assert parent.children[1].name == "MockClient.tracing_attr" assert parent.children[1].attributes == {'foo': 'bar'} - def test_decorator_has_different_name(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_decorator_has_different_name(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.check_name_is_different() assert len(parent.children) == 2 assert parent.children[0].name == "MockClient.__init__" assert parent.children[1].name == "different name" - def test_used(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_used(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient(policies=[]) + client = MockClient(http_request, policies=[]) client.get_foo(parent_span=parent) client.get_foo() @@ -129,9 +133,10 @@ def test_used(self): assert parent.children[2].name == "MockClient.get_foo" assert not parent.children[2].children - def test_span_merge_span(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_merge_span(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.merge_span_method() client.no_merge_span_method() @@ -143,9 +148,10 @@ def test_span_merge_span(self): assert parent.children[2].name == "MockClient.no_merge_span_method" assert parent.children[2].children[0].name == "MockClient.get_foo" - def test_span_complicated(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_complicated(self, http_request): with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) client.make_request(2) with parent.span("child") as child: time.sleep(0.001) @@ -163,11 +169,12 @@ def test_span_complicated(self): assert parent.children[3].name == "MockClient.make_request" assert not parent.children[3].children - def test_span_with_exception(self): + @pytest.mark.parametrize("http_request", HTTP_REQUESTS) + def test_span_with_exception(self, http_request): """Assert that if an exception is raised, the next sibling method is actually a sibling span. """ with FakeSpan(name="parent") as parent: - client = MockClient() + client = MockClient(http_request) try: client.raising_exception() except: diff --git a/sdk/core/azure-core/tests/test_tracing_policy.py b/sdk/core/azure-core/tests/test_tracing_policy.py index 2a0fc03a78e7..ef46a53197e3 100644 --- a/sdk/core/azure-core/tests/test_tracing_policy.py +++ b/sdk/core/azure-core/tests/test_tracing_policy.py @@ -7,10 +7,12 @@ from azure.core.pipeline import PipelineResponse, PipelineRequest, PipelineContext from azure.core.pipeline.policies import DistributedTracingPolicy, UserAgentPolicy -from azure.core.pipeline.transport import HttpRequest, HttpResponse +from azure.core.pipeline.transport import HttpResponse from azure.core.settings import settings from tracing_common import FakeSpan import time +import pytest +from utils import HTTP_REQUESTS try: from unittest import mock @@ -18,13 +20,14 @@ import mock -def test_distributed_tracing_policy_solo(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_distributed_tracing_policy_solo(http_request): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) @@ -68,7 +71,8 @@ def test_distributed_tracing_policy_solo(): assert network_span.attributes.get("http.status_code") == 504 -def test_distributed_tracing_policy_attributes(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_distributed_tracing_policy_attributes(http_request): """Test policy with no other policy and happy path""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: @@ -76,7 +80,7 @@ def test_distributed_tracing_policy_attributes(): 'myattr': 'myvalue' }) - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") pipeline_request = PipelineRequest(request, PipelineContext(None)) policy.on_request(pipeline_request) @@ -92,13 +96,14 @@ def test_distributed_tracing_policy_attributes(): assert network_span.attributes.get("myattr") == "myvalue" -def test_distributed_tracing_policy_badurl(caplog): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_distributed_tracing_policy_badurl(caplog, http_request): """Test policy with a bad url that will throw, and be sure policy ignores it""" settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://[[[") + request = http_request("GET", "http://[[[") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) @@ -125,14 +130,15 @@ def test_distributed_tracing_policy_badurl(caplog): assert len(root_span.children) == 0 -def test_distributed_tracing_policy_with_user_agent(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_distributed_tracing_policy_with_user_agent(http_request): """Test policy working with user agent.""" settings.tracing_implementation.set_value(FakeSpan) with mock.patch.dict('os.environ', {"AZURE_HTTP_USER_AGENT": "mytools"}): with FakeSpan(name="parent") as root_span: policy = DistributedTracingPolicy() - request = HttpRequest("GET", "http://localhost") + request = http_request("GET", "http://localhost") request.headers["x-ms-client-request-id"] = "some client request id" pipeline_request = PipelineRequest(request, PipelineContext(None)) @@ -183,11 +189,12 @@ def test_distributed_tracing_policy_with_user_agent(): assert network_span.status == 'Transport trouble' -def test_span_namer(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_span_namer(http_request): settings.tracing_implementation.set_value(FakeSpan) with FakeSpan(name="parent") as root_span: - request = HttpRequest("GET", "http://localhost/temp?query=query") + request = http_request("GET", "http://localhost/temp?query=query") pipeline_request = PipelineRequest(request, PipelineContext(None)) def fixed_namer(http_request): diff --git a/sdk/core/azure-core/tests/test_universal_pipeline.py b/sdk/core/azure-core/tests/test_universal_pipeline.py index ea5676374458..9d532608228b 100644 --- a/sdk/core/azure-core/tests/test_universal_pipeline.py +++ b/sdk/core/azure-core/tests/test_universal_pipeline.py @@ -42,7 +42,6 @@ PipelineContext ) from azure.core.pipeline.transport import ( - HttpRequest, HttpResponse, RequestsTransportResponse, ) @@ -54,6 +53,7 @@ RetryPolicy, HTTPPolicy, ) +from utils import HTTP_REQUESTS, create_http_request def test_pipeline_context(): kwargs={ @@ -86,26 +86,28 @@ def test_pipeline_context(): assert len(revived_context) == 1 -def test_request_history(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_history(http_request): class Non_deep_copiable(object): def __deepcopy__(self, memodict={}): raise ValueError() body = Non_deep_copiable() - request = HttpRequest('GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers assert request_history.http_request.url == request.url assert request_history.http_request.method == request.method -def test_request_history_type_error(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_request_history_type_error(http_request): class Non_deep_copiable(object): def __deepcopy__(self, memodict={}): raise TypeError() body = Non_deep_copiable() - request = HttpRequest('GET', 'http://localhost/', {'user-agent': 'test_request_history'}) + request = create_http_request(http_request, 'GET', 'http://localhost/', {'user-agent': 'test_request_history'}) request.body = body request_history = RequestHistory(request) assert request_history.http_request.headers == request.headers @@ -113,8 +115,9 @@ def __deepcopy__(self, memodict={}): assert request_history.http_request.method == request.method @mock.patch('azure.core.pipeline.policies._universal._LOGGER') -def test_no_log(mock_http_logger): - universal_request = HttpRequest('GET', 'http://localhost/') +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_no_log(mock_http_logger, http_request): + universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, PipelineContext(None)) http_logger = NetworkTraceLoggingPolicy() response = PipelineResponse(request, HttpResponse(universal_request, None), request.context) @@ -178,7 +181,8 @@ def test_no_log(mock_http_logger): second_count = mock_http_logger.debug.call_count assert second_count == first_count * 2 -def test_retry_without_http_response(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_retry_without_http_response(http_request): class NaughtyPolicy(HTTPPolicy): def send(*args): raise AzureError('boo') @@ -186,12 +190,13 @@ def send(*args): policies = [RetryPolicy(), NaughtyPolicy()] pipeline = Pipeline(policies=policies, transport=None) with pytest.raises(AzureError): - pipeline.run(HttpRequest('GET', url='https://foo.bar')) + pipeline.run(http_request('GET', url='https://foo.bar')) -def test_raw_deserializer(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_raw_deserializer(http_request): raw_deserializer = ContentDecodePolicy() context = PipelineContext(None, stream=False) - universal_request = HttpRequest('GET', 'http://localhost/') + universal_request = http_request('GET', 'http://localhost/') request = PipelineRequest(universal_request, context) def build_response(body, content_type=None): diff --git a/sdk/core/azure-core/tests/test_user_agent_policy.py b/sdk/core/azure-core/tests/test_user_agent_policy.py index 4f8b01c93b7e..afed968f149f 100644 --- a/sdk/core/azure-core/tests/test_user_agent_policy.py +++ b/sdk/core/azure-core/tests/test_user_agent_policy.py @@ -4,14 +4,16 @@ # ------------------------------------ """Tests for the user agent policy.""" from azure.core.pipeline.policies import UserAgentPolicy -from azure.core.pipeline.transport import HttpRequest from azure.core.pipeline import PipelineRequest, PipelineContext try: from unittest import mock except ImportError: import mock +import pytest +from utils import HTTP_REQUESTS -def test_user_agent_policy(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_user_agent_policy(http_request): user_agent = UserAgentPolicy(base_user_agent='foo') assert user_agent._user_agent == 'foo' @@ -21,7 +23,7 @@ def test_user_agent_policy(): user_agent = UserAgentPolicy(base_user_agent='foo', user_agent='bar', user_agent_use_env=False) assert user_agent._user_agent == 'bar foo' - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') pipeline_request = PipelineRequest(request, PipelineContext(None)) pipeline_request.context.options['user_agent'] = 'xyz' @@ -29,12 +31,13 @@ def test_user_agent_policy(): assert request.headers['User-Agent'] == 'xyz bar foo' -def test_user_agent_environ(): +@pytest.mark.parametrize("http_request", HTTP_REQUESTS) +def test_user_agent_environ(http_request): with mock.patch.dict('os.environ', {'AZURE_HTTP_USER_AGENT': "mytools"}): policy = UserAgentPolicy(None) assert policy.user_agent.endswith("mytools") - request = HttpRequest('GET', 'http://localhost/') + request = http_request('GET', 'http://localhost/') policy.on_request(PipelineRequest(request, PipelineContext(None))) assert request.headers["user-agent"].endswith("mytools") diff --git a/sdk/core/azure-core/tests/utils.py b/sdk/core/azure-core/tests/utils.py new file mode 100644 index 000000000000..ab3be4d12d1d --- /dev/null +++ b/sdk/core/azure-core/tests/utils.py @@ -0,0 +1,43 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import pytest +############################## LISTS USED TO PARAMETERIZE TESTS ############################## +from azure.core.rest import HttpRequest as RestHttpRequest +from azure.core.pipeline.transport import HttpRequest as PipelineTransportHttpRequest + +HTTP_REQUESTS = [PipelineTransportHttpRequest, RestHttpRequest] + + +############################## HELPER FUNCTIONS ############################## + +def is_rest(http_request): + return hasattr(http_request, "content") + +def create_http_request(http_request, *args, **kwargs): + if hasattr(http_request, "content"): + method = args[0] + url = args[1] + try: + headers = args[2] + except IndexError: + headers = None + try: + files = args[3] + except IndexError: + files = None + try: + data = args[4] + except IndexError: + data = None + return http_request( + method=method, + url=url, + headers=headers, + files=files, + data=data, + **kwargs + ) + return http_request(*args, **kwargs)