
Commit ff5c54d

Read stream per chunk
renaudhartert-db committed Nov 5, 2024
1 parent 8b14f51 commit ff5c54d
Showing 2 changed files with 50 additions and 2 deletions.
15 changes: 13 additions & 2 deletions databricks/sdk/_base_client.py
@@ -50,7 +50,8 @@ def __init__(self,
http_timeout_seconds: float = None,
extra_error_customizers: List[_ErrorCustomizer] = None,
debug_headers: bool = False,
clock: Clock = None):
clock: Clock = None,
streaming_buffer_size: int = 1024 * 1024): # 1MB
"""
:param debug_truncate_bytes:
:param retry_timeout_seconds:
@@ -68,6 +69,8 @@ def __init__(self,
:param extra_error_customizers:
:param debug_headers: Whether to include debug headers in the request log.
:param clock: Clock object to use for time-related operations.
:param streaming_buffer_size: The size of the buffer to use for streaming responses. If None, all the
response content is loaded into memory at once.
"""

self._debug_truncate_bytes = debug_truncate_bytes or 96
@@ -78,6 +81,7 @@ def __init__(self,
self._clock = clock or RealClock()
self._session = requests.Session()
self._session.auth = self._authenticate
self._streaming_buffer_size = streaming_buffer_size

# We don't use `max_retries` from HTTPAdapter to align with a more production-ready
# retry strategy established in the Databricks SDK for Go. See _is_retryable and
@@ -158,7 +162,9 @@ def do(self,
for header in response_headers if response_headers else []:
resp[header] = response.headers.get(Casing.to_header_case(header))
if raw:
resp["contents"] = _StreamingResponse(response)
streaming_response = _StreamingResponse(response)
streaming_response.set_chunk_size(self._streaming_buffer_size)
resp["contents"] = streaming_response
return resp
if not len(response.content):
return resp
@@ -283,6 +289,11 @@ def isatty(self) -> bool:
return False

def read(self, n: int = -1) -> bytes:
"""
Read up to n bytes from the response stream. If n is negative, read
until the end of the stream.
"""

self._open()
read_everything = n < 0
remaining_bytes = n
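
The hunk above only shows the first lines of read(). Below is a minimal sketch of how a chunked read over requests' iter_content() can satisfy arbitrary read(n) calls while keeping memory bounded by the buffer size. The class name _ChunkedReader and its internals (_iterator, _buffer) are illustrative assumptions, not the SDK's actual _StreamingResponse implementation.

import requests


class _ChunkedReader:
    """Illustrative wrapper that serves read(n) calls from iter_content() chunks."""

    def __init__(self, response: requests.Response, chunk_size: int = 1024 * 1024):
        self._response = response
        self._chunk_size = chunk_size
        self._iterator = None
        self._buffer = b""

    def set_chunk_size(self, chunk_size: int) -> None:
        self._chunk_size = chunk_size

    def _open(self) -> None:
        # Lazily start streaming; requests then pulls chunk_size bytes at a time
        # from the socket instead of materializing the whole body in memory.
        if self._iterator is None:
            self._iterator = self._response.iter_content(chunk_size=self._chunk_size)

    def read(self, n: int = -1) -> bytes:
        # Read up to n bytes; a negative n drains the rest of the stream.
        self._open()
        read_everything = n < 0
        while read_everything or len(self._buffer) < n:
            try:
                self._buffer += next(self._iterator)
            except StopIteration:
                break
        if read_everything:
            result, self._buffer = self._buffer, b""
        else:
            result, self._buffer = self._buffer[:n], self._buffer[n:]
        return result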
37 changes: 37 additions & 0 deletions tests/test_base_client.py
@@ -3,6 +3,7 @@

import pytest
import requests
from unittest.mock import Mock

from databricks.sdk import errors, useragent
from databricks.sdk._base_client import _BaseClient, _StreamingResponse
@@ -276,3 +277,39 @@ def inner(h: BaseHTTPRequestHandler):
assert 'foo' in res

assert len(requests) == 2


@pytest.mark.parametrize('chunk_size,expected_chunks,data_size', [
(5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks
(10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks
(200, 1, 100), # 100 bytes fit into a single 200-byte chunk = 1 chunk
])
def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size):
test_data = b"0" * data_size
content_chunks = []

mock_response = Mock(spec=requests.Response)
def mock_iter_content(chunk_size):
# Simulate how requests would chunk the data
for i in range(0, len(test_data), chunk_size):
chunk = test_data[i:i + chunk_size]
content_chunks.append(chunk) # Track chunks for verification
yield chunk

mock_response.iter_content = mock_iter_content

# Create streaming response and set chunk size
stream = _StreamingResponse(mock_response)
stream.set_chunk_size(chunk_size)

# Read all data
received_data = b""
while True:
chunk = stream.read(1)
if not chunk:
break
received_data += chunk

assert received_data == test_data # All data was received correctly
assert len(content_chunks) == expected_chunks # Correct number of chunks
assert all(len(c) <= chunk_size for c in content_chunks) # Chunks don't exceed size
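
To round off the change, here is a minimal usage sketch of the new knob: reading a raw response in fixed-size pieces keeps memory flat regardless of the payload size. It assumes do() accepts an HTTP method, a full URL, and raw=True as in the _base_client.py hunk above; the URL, output file name, and 64 KiB buffer are placeholders, not values from this commit.

from databricks.sdk._base_client import _BaseClient

# Configure a smaller streaming buffer than the 1 MB default.
client = _BaseClient(streaming_buffer_size=64 * 1024)

# raw=True makes do() return the response body as a streaming object
# under resp["contents"] instead of loading it into memory.
resp = client.do("GET", "https://example.com/large-file", raw=True)

with open("large-file.bin", "wb") as sink:
    while True:
        chunk = resp["contents"].read(64 * 1024)
        if not chunk:
            break
        sink.write(chunk)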
