diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 519fde3ed3..1036ad1fe7 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -790,20 +790,24 @@ def upload_file( filename=to_upload.name, ) + extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url) encoded_md5 = b64encode(md5_bytes) with open(str(to_upload), "+rb") as local_file: content = local_file.read() content_length = len(content) + headers = {"Content-Length": str(content_length), "Content-MD5": encoded_md5} + headers.update(extra_headers) rsp = requests.put( upload_location.signed_url, data=content, - headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5}, + headers=headers, verify=False if self._config.platform.insecure_skip_verify is True else self._config.platform.ca_cert_file_path, ) - if rsp.status_code != requests.codes["OK"]: + # Check both HTTP 201 and 200, because some storage backends (e.g. Azure) return 201 instead of 200. + if rsp.status_code not in (requests.codes["OK"], requests.codes["created"]): raise FlyteValueException( rsp.status_code, f"Request to send data {upload_location.signed_url} failed.", @@ -1925,3 +1929,9 @@ def launch_backfill( return remote_wf return self.execute(remote_wf, inputs={}, project=project, domain=domain, execution_name=execution_name) + + @staticmethod + def get_extra_headers_for_protocol(native_url): + if native_url.startswith("abfs://"): + return {"x-ms-blob-type": "BlockBlob"} + return {} diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 94b03b044a..5a9a7c959e 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -203,6 +203,18 @@ def test_more_stuff(mock_client): assert computed_v2 != computed_v3 +def test_get_extra_headers_azure_blob_storage(): + native_url = "abfs://flyte@storageaccount/container/path/to/file" + headers = FlyteRemote.get_extra_headers_for_protocol(native_url) + assert headers["x-ms-blob-type"] == "BlockBlob" + + +def test_get_extra_headers_s3(): + native_url = "s3://flyte@storageaccount/container/path/to/file" + headers = FlyteRemote.get_extra_headers_for_protocol(native_url) + assert headers == {} + + @patch("flytekit.remote.remote.SynchronousFlyteClient") def test_generate_console_http_domain_sandbox_rewrite(mock_client): _, temp_filename = tempfile.mkstemp(suffix=".yaml")