Skip to content

Commit

Permalink
[Storage] Fix batch APIs for Azurite (#36862)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincenttran-msft authored Aug 20, 2024
1 parent 8a6980c commit 76328e9
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1412,6 +1412,8 @@ def delete_blobs( # pylint: disable=delete-operation-wrong-return-type
"""
if len(blobs) == 0:
return iter([])
if self._is_localhost:
kwargs['url_prepend'] = self.account_name

reqs, options = _generate_delete_blobs_options(
self._query_str,
Expand Down Expand Up @@ -1494,6 +1496,8 @@ def set_standard_blob_tier_blobs(
:return: An iterator of responses, one for each blob in order
:rtype: Iterator[~azure.core.pipeline.transport.HttpResponse]
"""
if self._is_localhost:
kwargs['url_prepend'] = self.account_name
reqs, options = _generate_set_tiers_options(
self._query_str,
self.container_name,
Expand Down Expand Up @@ -1553,6 +1557,8 @@ def set_premium_page_blob_tier_blobs(
:return: An iterator of responses, one for each blob in order
:rtype: Iterator[~azure.core.pipeline.transport.HttpResponse]
"""
if self._is_localhost:
kwargs['url_prepend'] = self.account_name
reqs, options = _generate_set_tiers_options(
self._query_str,
self.container_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def _generate_delete_blobs_options(
if_modified_since = kwargs.pop('if_modified_since', None)
if_unmodified_since = kwargs.pop('if_unmodified_since', None)
if_tags_match_condition = kwargs.pop('if_tags_match_condition', None)
url_prepend = kwargs.pop('url_prepend', None)
kwargs.update({'raise_on_any_failure': raise_on_any_failure,
'sas': query_str.replace('?', '&'),
'timeout': '&timeout=' + str(timeout) if timeout else "",
Expand Down Expand Up @@ -157,9 +158,11 @@ def _generate_delete_blobs_options(

req = HttpRequest(
"DELETE",
f"/{quote(container_name)}/{quote(str(blob_name), safe='/~')}{query_str}",
(f"{'/' + quote(url_prepend) if url_prepend else ''}/"
f"{quote(container_name)}/{quote(str(blob_name), safe='/~')}{query_str}"),
headers=header_parameters
)

req.format_parameters(query_parameters)
reqs.append(req)

Expand Down Expand Up @@ -223,6 +226,7 @@ def _generate_set_tiers_options(
raise_on_any_failure = kwargs.pop('raise_on_any_failure', True)
rehydrate_priority = kwargs.pop('rehydrate_priority', None)
if_tags = kwargs.pop('if_tags_match_condition', None)
url_prepend = kwargs.pop('url_prepend', None)
kwargs.update({'raise_on_any_failure': raise_on_any_failure,
'sas': query_str.replace('?', '&'),
'timeout': '&timeout=' + str(timeout) if timeout else "",
Expand Down Expand Up @@ -252,7 +256,8 @@ def _generate_set_tiers_options(

req = HttpRequest(
"PUT",
f"/{quote(container_name)}/{quote(str(blob_name), safe='/~')}{query_str}",
(f"{'/' + quote(url_prepend) if url_prepend else ''}/"
f"{quote(container_name)}/{quote(str(blob_name), safe='/~')}{query_str}"),
headers=header_parameters
)
req.format_parameters(query_parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
self.scheme = parsed_url.scheme
self._is_localhost = False

if service not in ["blob", "queue", "file-share", "dfs"]:
raise ValueError(f"Invalid service: {service}")
Expand All @@ -85,6 +86,7 @@ def __init__(
self.account_name = account[0] if len(account) > 1 else None
if not self.account_name and parsed_url.netloc.startswith("localhost") \
or parsed_url.netloc.startswith("127.0.0.1"):
self._is_localhost = True
self.account_name = parsed_url.path.strip("/")

self.credential = _format_shared_key_credential(self.account_name, credential)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,8 @@ async def delete_blobs(
"""
if len(blobs) == 0:
return AsyncList([])
if self._is_localhost:
kwargs['url_prepend'] = self.account_name

reqs, options = _generate_delete_blobs_options(
self._query_str,
Expand Down Expand Up @@ -1485,6 +1487,8 @@ async def set_standard_blob_tier_blobs(
:return: An async iterator of responses, one for each blob in order
:rtype: asynciterator[~azure.core.pipeline.transport.AsyncHttpResponse]
"""
if self._is_localhost:
kwargs['url_prepend'] = self.account_name
reqs, options = _generate_set_tiers_options(
self._query_str,
self.container_name,
Expand Down Expand Up @@ -1544,6 +1548,8 @@ async def set_premium_page_blob_tier_blobs(
:return: An async iterator of responses, one for each blob in order
:rtype: asynciterator[~azure.core.pipeline.transport.AsyncHttpResponse]
"""
if self._is_localhost:
kwargs['url_prepend'] = self.account_name
reqs, options = _generate_set_tiers_options(
self._query_str,
self.container_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
self.scheme = parsed_url.scheme
self._is_localhost = False

if service not in ["blob", "queue", "file-share", "dfs"]:
raise ValueError(f"Invalid service: {service}")
Expand All @@ -85,6 +86,7 @@ def __init__(
self.account_name = account[0] if len(account) > 1 else None
if not self.account_name and parsed_url.netloc.startswith("localhost") \
or parsed_url.netloc.startswith("127.0.0.1"):
self._is_localhost = True
self.account_name = parsed_url.path.strip("/")

self.credential = _format_shared_key_credential(self.account_name, credential)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
self.scheme = parsed_url.scheme
self._is_localhost = False

if service not in ["blob", "queue", "file-share", "dfs"]:
raise ValueError(f"Invalid service: {service}")
Expand All @@ -85,6 +86,7 @@ def __init__(
self.account_name = account[0] if len(account) > 1 else None
if not self.account_name and parsed_url.netloc.startswith("localhost") \
or parsed_url.netloc.startswith("127.0.0.1"):
self._is_localhost = True
self.account_name = parsed_url.path.strip("/")

self.credential = _format_shared_key_credential(self.account_name, credential)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
self._location_mode = kwargs.get("_location_mode", LocationMode.PRIMARY)
self._hosts = kwargs.get("_hosts")
self.scheme = parsed_url.scheme
self._is_localhost = False

if service not in ["blob", "queue", "file-share", "dfs"]:
raise ValueError(f"Invalid service: {service}")
Expand All @@ -85,6 +86,7 @@ def __init__(
self.account_name = account[0] if len(account) > 1 else None
if not self.account_name and parsed_url.netloc.startswith("localhost") \
or parsed_url.netloc.startswith("127.0.0.1"):
self._is_localhost = True
self.account_name = parsed_url.path.strip("/")

self.credential = _format_shared_key_credential(self.account_name, credential)
Expand Down

0 comments on commit 76328e9

Please sign in to comment.