Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AppConfig] Add lock to SyncTokenPolicy #19395

Merged
merged 5 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ def __init__(self, base_url, credential, **kwargs):
pipeline = kwargs.get("pipeline")

if pipeline is None:
self._sync_token_policy = SyncTokenPolicy()
aad_mode = not isinstance(credential, AppConfigConnectionStringCredential)
pipeline = self._create_appconfig_pipeline(
credential=credential, aad_mode=aad_mode, base_url=base_url, **kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#
# --------------------------------------------------------------------------
from typing import Any, Dict
from threading import Lock
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

Expand Down Expand Up @@ -63,18 +64,20 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument
# type: (**Any) -> None
self._sync_token_header = "Sync-Token"
self._sync_tokens = {} # type: Dict[str, Any]
self._lock = Lock()

def on_request(self, request): # type: ignore # pylint: disable=arguments-differ
# type: (PipelineRequest) -> None
"""This is executed before sending the request to the next policy.
:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
"""
sync_token_header = ",".join(str(x) for x in self._sync_tokens.values())
if sync_token_header:
request.http_request.headers.update(
{self._sync_token_header: sync_token_header}
)
with self._lock:
sync_token_header = ",".join(str(x) for x in self._sync_tokens.values())
if sync_token_header:
request.http_request.headers.update(
{self._sync_token_header: sync_token_header}
)

def on_response(self, request, response): # type: ignore # pylint: disable=arguments-differ
# type: (PipelineRequest, PipelineResponse) -> None
Expand Down Expand Up @@ -105,9 +108,10 @@ def _update_sync_token(self, sync_token):
# type: (SyncToken) -> None
if not sync_token:
return
existing_token = self._sync_tokens.get(sync_token.token_id, None)
if not existing_token:
self._sync_tokens[sync_token.token_id] = sync_token
return
if existing_token.sequence_number < sync_token.sequence_number:
self._sync_tokens[sync_token.token_id] = sync_token
with self._lock:
existing_token = self._sync_tokens.get(sync_token.token_id, None)
if not existing_token:
self._sync_tokens[sync_token.token_id] = sync_token
return
if existing_token.sequence_number < sync_token.sequence_number:
self._sync_tokens[sync_token.token_id] = sync_token
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is "add_token"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add_token is update_sync_token

Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
from .._azure_appconfiguration_credential import AppConfigConnectionStringCredential
from .._generated.models import KeyValue
from .._models import ConfigurationSetting
from .._sync_token import SyncTokenPolicy
from .._user_agent import USER_AGENT
from ._sync_token_async import AsyncSyncTokenPolicy

try:
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -87,10 +87,9 @@ def __init__(self, base_url, credential, **kwargs):
)

pipeline = kwargs.get("pipeline")
self._sync_token_policy = SyncTokenPolicy()
self._sync_token_policy = AsyncSyncTokenPolicy()

if pipeline is None:
self._sync_token_policy = SyncTokenPolicy()
aad_mode = not isinstance(credential, AppConfigConnectionStringCredential)
pipeline = self._create_appconfig_pipeline(
credential=credential, aad_mode=aad_mode, base_url=base_url, **kwargs
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the ""Software""), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from typing import Any, Dict
from asyncio import Lock
from azure.core.pipeline import PipelineRequest, PipelineResponse
from azure.core.pipeline.policies import SansIOHTTPPolicy

from .._sync_token import SyncToken


class AsyncSyncTokenPolicy(SansIOHTTPPolicy):
"""A simple policy that enable the given callback
with the response.
:keyword callback raw_response_hook: Callback function. Will be invoked on response.
"""

def __init__(self, **kwargs): # pylint: disable=unused-argument
# type: (**Any) -> None
self._sync_token_header = "Sync-Token"
self._sync_tokens = {} # type: Dict[str, Any]
self._lock = Lock()

async def on_request(self, request): # type: ignore # pylint: disable=arguments-differ, invalid-overridden-method
# type: (PipelineRequest) -> None
"""This is executed before sending the request to the next policy.
:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
"""
async with self._lock:
sync_token_header = ",".join(str(x) for x in self._sync_tokens.values())
if sync_token_header:
request.http_request.headers.update(
{self._sync_token_header: sync_token_header}
)

async def on_response(self, request, response): # type: ignore # pylint: disable=arguments-differ, invalid-overridden-method
# type: (PipelineRequest, PipelineResponse) -> None
"""This is executed after the request comes back from the policy.
:param request: The PipelineRequest object.
:type request: ~azure.core.pipeline.PipelineRequest
:param response: The PipelineResponse object.
:type response: ~azure.core.pipeline.PipelineResponse
"""
sync_token_header = response.http_response.headers.get(self._sync_token_header)
if not sync_token_header:
return
sync_token_strings = sync_token_header.split(",")
if not sync_token_strings:
return
for sync_token_string in sync_token_strings:
sync_token = SyncToken.from_sync_token_string(sync_token_string)
await self._update_sync_token(sync_token)

async def add_token(self, full_raw_tokens):
# type: (str) -> None
raw_tokens = full_raw_tokens.split(",")
for raw_token in raw_tokens:
sync_token = SyncToken.from_sync_token_string(raw_token)
await self._update_sync_token(sync_token)

async def _update_sync_token(self, sync_token):
# type: (SyncToken) -> None
if not sync_token:
return
async with self._lock:
existing_token = self._sync_tokens.get(sync_token.token_id, None)
if not existing_token:
self._sync_tokens[sync_token.token_id] = sync_token
return
if existing_token.sequence_number < sync_token.sequence_number:
self._sync_tokens[sync_token.token_id] = sync_token