Add Azure-specific headers when uploading to blob storage #1784

Merged: 2 commits, Aug 17, 2023
14 changes: 12 additions & 2 deletions flytekit/remote/remote.py
@@ -790,20 +790,24 @@ def upload_file(
             filename=to_upload.name,
         )

+        extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url)
         encoded_md5 = b64encode(md5_bytes)
         with open(str(to_upload), "+rb") as local_file:
             content = local_file.read()
             content_length = len(content)
+            headers = {"Content-Length": str(content_length), "Content-MD5": encoded_md5}
+            headers.update(extra_headers)
             rsp = requests.put(
                 upload_location.signed_url,
                 data=content,
-                headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5},
+                headers=headers,
                 verify=False
                 if self._config.platform.insecure_skip_verify is True
                 else self._config.platform.ca_cert_file_path,
             )

-            if rsp.status_code != requests.codes["OK"]:
+            # Check both HTTP 201 and 200, because some storage backends (e.g. Azure) return 201 instead of 200.
+            if rsp.status_code not in (requests.codes["OK"], requests.codes["created"]):
Collaborator:

Why do we have to check this value as well?

Contributor Author:

Ah sorry, forgot to specify in my PR description.
Azure returns an HTTP 201, so this was throwing an exception even though the request was successful.

Contributor:

Can you add this as a comment please?

Contributor Author:

Done

                 raise FlyteValueException(
                     rsp.status_code,
                     f"Request to send data {upload_location.signed_url} failed.",
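
The status check discussed in the review thread above can be sketched in isolation. This is a minimal illustration, not the flytekit code: `upload_succeeded` is a hypothetical helper name, and the standard library's `http.HTTPStatus` stands in for `requests.codes` so the snippet runs without third-party packages.

```python
from http import HTTPStatus

# Hypothetical constant, not part of flytekit: status codes treated as a
# successful signed-URL upload. Per the thread, Azure answers the PUT with
# 201 Created, while other backends answer with 200 OK.
SUCCESSFUL_UPLOAD_STATUSES = (HTTPStatus.OK, HTTPStatus.CREATED)


def upload_succeeded(status_code: int) -> bool:
    """Return True if the PUT response code means the object was stored."""
    return status_code in SUCCESSFUL_UPLOAD_STATUSES
```

Keeping the accepted codes in one named tuple makes it easy to see, and to extend, which responses count as success if another backend turns out to behave differently.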
@@ -1925,3 +1929,9 @@ def launch_backfill(
             return remote_wf

         return self.execute(remote_wf, inputs={}, project=project, domain=domain, execution_name=execution_name)
+
+    @staticmethod
+    def get_extra_headers_for_protocol(native_url):
+        if native_url.startswith("abfs://"):
+            return {"x-ms-blob-type": "BlockBlob"}
Member:

In the case of blockblob type, the return code will become 200, won't it?
https://stackoverflow.com/questions/67015459/why-does-azure-blob-storage-return-201-but-not-upload

Contributor Author:

That matches what happened when I tested it

+        return {}
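
To show how the protocol-specific headers combine with the base upload headers, here is a rough standalone sketch. Only the `abfs://` prefix and the `x-ms-blob-type: BlockBlob` header come from this PR. The names `EXTRA_HEADERS_BY_SCHEME`, `extra_headers_for`, and `build_upload_headers` are hypothetical, the dict-based lookup is an illustrative alternative to the `if` in the diff, and decoding the MD5 value to `str` is a choice made here rather than what the diff does.

```python
import base64
import hashlib

# Assumed lookup table from URL prefix to extra headers. Only the "abfs://"
# entry is taken from the PR; the table itself is an illustrative design.
EXTRA_HEADERS_BY_SCHEME = {
    "abfs://": {"x-ms-blob-type": "BlockBlob"},
}


def extra_headers_for(native_url: str) -> dict:
    """Return a copy of the extra headers for the URL's prefix, or {} if none."""
    for prefix, headers in EXTRA_HEADERS_BY_SCHEME.items():
        if native_url.startswith(prefix):
            return dict(headers)
    return {}


def build_upload_headers(native_url: str, content: bytes) -> dict:
    """Combine the base upload headers with any protocol-specific ones."""
    md5_digest = hashlib.md5(content).digest()
    headers = {
        "Content-Length": str(len(content)),
        "Content-MD5": base64.b64encode(md5_digest).decode("ascii"),
    }
    # Protocol-specific headers are merged last, mirroring headers.update(...)
    # in the diff above.
    headers.update(extra_headers_for(native_url))
    return headers
```

Calling `build_upload_headers("abfs://container/path", data)` would include the `x-ms-blob-type` header, while an `s3://` URL would yield only `Content-Length` and `Content-MD5`.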
12 changes: 12 additions & 0 deletions tests/flytekit/unit/remote/test_remote.py
@@ -203,6 +203,18 @@ def test_more_stuff(mock_client):
     assert computed_v2 != computed_v3


+def test_get_extra_headers_azure_blob_storage():
+    native_url = "abfs://flyte@storageaccount/container/path/to/file"
+    headers = FlyteRemote.get_extra_headers_for_protocol(native_url)
+    assert headers["x-ms-blob-type"] == "BlockBlob"
+
+
+def test_get_extra_headers_s3():
+    native_url = "s3://flyte@storageaccount/container/path/to/file"
+    headers = FlyteRemote.get_extra_headers_for_protocol(native_url)
+    assert headers == {}
+
+
 @patch("flytekit.remote.remote.SynchronousFlyteClient")
 def test_generate_console_http_domain_sandbox_rewrite(mock_client):
     _, temp_filename = tempfile.mkstemp(suffix=".yaml")