Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add custom policy #17340

Merged
merged 8 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdk/core/azure-core/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Release History

## 1.12.1 (Unreleased)
## 1.13.0 (Unreleased)

### Features

- Supported adding custom policies #16519


## 1.12.0 (2021-03-08)
Expand Down
43 changes: 34 additions & 9 deletions sdk/core/azure-core/azure/core/_pipeline_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@
# --------------------------------------------------------------------------

import logging
try:
from collections.abc import Iterable
except ImportError:
from collections import Iterable
from .configuration import Configuration
from .pipeline import Pipeline
from .pipeline.transport._base import PipelineClientBase
from .pipeline.policies import (
ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
RequestIdPolicy,
)
from .pipeline.transport import RequestsTransport

Expand Down Expand Up @@ -64,6 +71,10 @@ class PipelineClient(PipelineClientBase):
:keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
:keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[HTTPPolicy, SansIOHTTPPolicy, list[HTTPPolicy], list[SansIOHTTPPolicy]]
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
:return: A pipeline object.
:rtype: ~azure.core.pipeline.Pipeline
Expand Down Expand Up @@ -102,20 +113,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
policies = kwargs.get('policies')

if policies is None: # [] is a valid policy list
per_call_policies = kwargs.get('per_call_policies', [])
per_retry_policies = kwargs.get('per_retry_policies', [])
policies = [
RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)
ContentDecodePolicy(**kwargs)
]
if isinstance(per_call_policies, Iterable):
for policy in per_call_policies:
policies.append(policy)
Comment on lines +125 to +127
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to explicitly iterate over per_call_policies. list.extend(iterable) does the same thing.

else:
policies.append(per_call_policies)

policies = policies + [config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy]
Comment on lines +131 to +134
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes a new list object to be created and assigned to policies. list.extend(iterable) is a better choice as it modifies policies directly.

if isinstance(per_retry_policies, Iterable):
for policy in per_retry_policies:
policies.append(policy)
else:
policies.append(per_retry_policies)

policies = policies + [config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)]

if not transport:
transport = RequestsTransport(**kwargs)
Expand Down
38 changes: 33 additions & 5 deletions sdk/core/azure-core/azure/core/_pipeline_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@
# --------------------------------------------------------------------------

import logging
from collections.abc import Iterable
from .configuration import Configuration
from .pipeline import AsyncPipeline
from .pipeline.transport._base import PipelineClientBase
from .pipeline.policies import (
ContentDecodePolicy, DistributedTracingPolicy, HttpLoggingPolicy, RequestIdPolicy
ContentDecodePolicy,
DistributedTracingPolicy,
HttpLoggingPolicy,
RequestIdPolicy,
)

try:
Expand Down Expand Up @@ -62,8 +66,14 @@ class AsyncPipelineClient(PipelineClientBase):
:param str base_url: URL for the request.
:keyword ~azure.core.configuration.Configuration config: If omitted, the standard configuration is used.
:keyword Pipeline pipeline: If omitted, a Pipeline object is created and returned.
:keyword list[HTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword HttpTransport transport: If omitted, RequestsTransport is used for synchronous transport.
:keyword list[AsyncHTTPPolicy] policies: If omitted, the standard policies of the configuration object is used.
:keyword per_call_policies: If specified, the policies will be added into the policy list before RetryPolicy
:paramtype per_call_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword per_retry_policies: If specified, the policies will be added into the policy list after RetryPolicy
:paramtype per_retry_policies: Union[AsyncHTTPPolicy, SansIOHTTPPolicy,
list[AsyncHTTPPolicy], list[SansIOHTTPPolicy]]
:keyword AsyncHttpTransport transport: If omitted, AioHttpTransport is used for synchronous transport.
:return: An async pipeline object.
:rtype: ~azure.core.pipeline.AsyncPipeline

Expand Down Expand Up @@ -101,16 +111,34 @@ def _build_pipeline(self, config, **kwargs): # pylint: disable=no-self-use
policies = kwargs.get('policies')

if policies is None: # [] is a valid policy list
per_call_policies = kwargs.get('per_call_policies', [])
per_retry_policies = kwargs.get('per_retry_policies', [])
policies = [
RequestIdPolicy(**kwargs),
config.headers_policy,
config.user_agent_policy,
config.proxy_policy,
ContentDecodePolicy(**kwargs),
ContentDecodePolicy(**kwargs)
]
if isinstance(per_call_policies, Iterable):
for policy in per_call_policies:
policies.append(policy)
else:
policies.append(per_call_policies)

policies = policies + [
config.redirect_policy,
config.retry_policy,
config.authentication_policy,
config.custom_hook_policy,
config.custom_hook_policy
]
if isinstance(per_retry_policies, Iterable):
for policy in per_retry_policies:
policies.append(policy)
else:
policies.append(per_retry_policies)

policies = policies + [
config.logging_policy,
DistributedTracingPolicy(**kwargs),
config.http_logging_policy or HttpLoggingPolicy(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion sdk/core/azure-core/azure/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# regenerated.
# --------------------------------------------------------------------------

VERSION = "1.12.1"
VERSION = "1.13.0"
40 changes: 39 additions & 1 deletion sdk/core/azure-core/tests/async_tests/test_pipeline_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,42 @@ def send(*args):
policies = [AsyncRetryPolicy(), NaughtyPolicy()]
pipeline = AsyncPipeline(policies=policies, transport=None)
with pytest.raises(AzureError):
await pipeline.run(HttpRequest('GET', url='https://foo.bar'))
await pipeline.run(HttpRequest('GET', url='https://foo.bar'))

@pytest.mark.asyncio
async def test_add_custom_policy():
class BooPolicy(AsyncHTTPPolicy):
def send(*args):
raise AzureError('boo')

class FooPolicy(AsyncHTTPPolicy):
def send(*args):
raise AzureError('boo')

boo_policy = BooPolicy()
foo_policy = FooPolicy()
client = AsyncPipelineClient(base_url="test", per_call_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies
xiangyan99 marked this conversation as resolved.
Show resolved Hide resolved

client = AsyncPipelineClient(base_url="test", per_retry_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
client = AsyncPipelineClient(base_url="test", per_retry_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=boo_policy, per_retry_policies=foo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies

client = AsyncPipelineClient(base_url="test", per_call_policies=[boo_policy],
per_retry_policies=[foo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies
41 changes: 40 additions & 1 deletion sdk/core/azure-core/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@
SansIOHTTPPolicy,
UserAgentPolicy,
RedirectPolicy,
HttpLoggingPolicy
HttpLoggingPolicy,
HTTPPolicy,
SansIOHTTPPolicy
)
from azure.core.pipeline.transport._base import PipelineClientBase
from azure.core.pipeline.transport import (
Expand Down Expand Up @@ -332,6 +334,43 @@ def test_repr(self):
request = HttpRequest("GET", "hello.com")
assert repr(request) == "<HttpRequest [GET], url: 'hello.com'>"

def test_add_custom_policy(self):
class BooPolicy(HTTPPolicy):
def send(*args):
raise AzureError('boo')

class FooPolicy(HTTPPolicy):
def send(*args):
raise AzureError('boo')

boo_policy = BooPolicy()
foo_policy = FooPolicy()
client = PipelineClient(base_url="test", per_call_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_retry_policies=boo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
client = PipelineClient(base_url="test", per_retry_policies=[boo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=boo_policy, per_retry_policies=foo_policy)
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies

client = PipelineClient(base_url="test", per_call_policies=[boo_policy],
per_retry_policies=[foo_policy])
policies = client._pipeline._impl_policies
assert boo_policy in policies
assert foo_policy in policies


if __name__ == "__main__":
unittest.main()