feat(batch-exports): Buffer batches while asynchronously flushing
wip

feat(batch-exports): Buffer batches in queue while asynchronously flushing
tomasfarias committed Oct 16, 2024
1 parent 405dd59 commit 20d7cb5
Showing 6 changed files with 285 additions and 48 deletions.
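
This commit replaces the previous pull-based record iterator with a producer/consumer arrangement: a producer task streams record batches from ClickHouse into an asyncio.Queue, and the consumer drains the queue while flushes to the destination run as separate tasks, so buffering new batches is not blocked on flush I/O. The following is a minimal, self-contained sketch of that pattern; all names are illustrative and plain dicts stand in for Arrow record batches, so this is not PostHog's actual API.

import asyncio


async def produce(queue: asyncio.Queue, done: asyncio.Event) -> None:
    """Stand-in for streaming query results: put batches on the queue, then signal completion."""
    for batch in ({"rows": i} for i in range(10)):
        await queue.put(batch)
    done.set()


async def flush(batches: list[dict]) -> int:
    """Stand-in for a slow destination load (for example, a BigQuery load job)."""
    await asyncio.sleep(0.5)
    return sum(batch["rows"] for batch in batches)


async def consume(queue: asyncio.Queue, done: asyncio.Event) -> int:
    """Drain the queue, kicking off flushes as tasks so consuming is not blocked on flush I/O."""
    flush_tasks: list[asyncio.Task] = []
    buffered: list[dict] = []

    while not queue.empty() or not done.is_set():
        try:
            buffered.append(queue.get_nowait())
        except asyncio.QueueEmpty:
            await asyncio.sleep(0.1)
            continue

        if len(buffered) >= 3:  # flush threshold; the real code tracks bytes, not batch counts
            flush_tasks.append(asyncio.create_task(flush(buffered)))
            buffered = []

    if buffered:
        flush_tasks.append(asyncio.create_task(flush(buffered)))

    return sum(await asyncio.gather(*flush_tasks))


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    done = asyncio.Event()
    producer = asyncio.create_task(produce(queue, done))
    total = await consume(queue, done)
    await producer
    print(f"exported {total} records")


asyncio.run(main())
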
101 changes: 101 additions & 0 deletions posthog/temporal/batch_exports/batch_exports.py
@@ -1,3 +1,4 @@
import asyncio
import collections.abc
import dataclasses
import datetime as dt
Expand Down Expand Up @@ -251,6 +252,106 @@ async def iter_records_from_model_view(
yield record_batch


class RecordBatchQueue(asyncio.Queue):
def __init__(self, max_size_bytes=0):
super().__init__(maxsize=max_size_bytes)
self._bytes_size = 0
self._schema_set = asyncio.Event()
self.record_batch_schema = None

def _get(self) -> pa.RecordBatch:
item = self._queue.popleft()

Check failure on line 263 in posthog/temporal/batch_exports/batch_exports.py (GitHub Actions / Python code quality checks): "RecordBatchQueue" has no attribute "_queue"
self._bytes_size -= item.get_total_buffer_size()
return item

def _put(self, item: pa.RecordBatch):
self._bytes_size += item.get_total_buffer_size()

if not self._schema_set.is_set():
self.set_schema(item)

self._queue.append(item)

Check failure on line 273 in posthog/temporal/batch_exports/batch_exports.py (GitHub Actions / Python code quality checks): "RecordBatchQueue" has no attribute "_queue"

def set_schema(self, record_batch: pa.RecordBatch):
self.record_batch_schema = record_batch.schema
self._schema_set.set()

def qsize(self):
"""Size in bytes of record batches in the queue."""
return self._bytes_size

async def get_schema(self) -> pa.Schema:
await self._schema_set.wait()
return self.record_batch_schema


def start_produce_batch_export_record_batches(
client: ClickHouseClient,
model_name: str,
is_backfill: bool,
team_id: int,
interval_start: str,
interval_end: str,
fields: list[BatchExportField] | None = None,
destination_default_fields: list[BatchExportField] | None = None,
**parameters,
):
if fields is None:
if destination_default_fields is None:
fields = default_fields()
else:
fields = destination_default_fields

if model_name == "persons":
view = SELECT_FROM_PERSONS_VIEW

else:
if parameters["exclude_events"]:
parameters["exclude_events"] = list(parameters["exclude_events"])
else:
parameters["exclude_events"] = []

if parameters["include_events"]:
parameters["include_events"] = list(parameters["include_events"])
else:
parameters["include_events"] = []

if str(team_id) in settings.UNCONSTRAINED_TIMESTAMP_TEAM_IDS:
query_template = SELECT_FROM_EVENTS_VIEW_UNBOUNDED
elif is_backfill:
query_template = SELECT_FROM_EVENTS_VIEW_BACKFILL
else:
query_template = SELECT_FROM_EVENTS_VIEW
lookback_days = settings.OVERRIDE_TIMESTAMP_TEAM_IDS.get(team_id, settings.DEFAULT_TIMESTAMP_LOOKBACK_DAYS)
parameters["lookback_days"] = lookback_days

if "_inserted_at" not in [field["alias"] for field in fields]:
control_fields = [BatchExportField(expression="_inserted_at", alias="_inserted_at")]
else:
control_fields = []

query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields)

view = query_template.substitute(fields=query_fields)

parameters["team_id"] = team_id
parameters["interval_start"] = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S")
parameters["interval_end"] = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S")
extra_query_parameters = parameters.pop("extra_query_parameters") or {}
parameters = {**parameters, **extra_query_parameters}

queue = RecordBatchQueue()
query_id = uuid.uuid4()
done_event = asyncio.Event()
produce_task = asyncio.create_task(
client.aproduce_query_as_arrow_record_batches(
view, queue=queue, done_event=done_event, query_parameters=parameters, query_id=str(query_id)
)
)

return queue, done_event, produce_task


def iter_records(
client: ClickHouseClient,
team_id: int,
Expand Down
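
For context on the new RecordBatchQueue above: it repurposes asyncio.Queue's maxsize/qsize to track bytes rather than item count, and exposes the schema of the first batch via an event so consumers can prepare destination tables before dequeuing anything. A rough usage sketch follows; it assumes this commit is applied and pyarrow is installed, and is not taken from the repository's tests.

import asyncio

import pyarrow as pa

from posthog.temporal.batch_exports.batch_exports import RecordBatchQueue


async def main() -> None:
    queue = RecordBatchQueue()
    batch = pa.RecordBatch.from_pydict({"event": ["pageview", "click"], "team_id": [1, 1]})

    await queue.put(batch)

    # qsize() reports the buffered size in bytes, not the number of queued batches.
    print(queue.qsize())

    # The schema becomes available as soon as the first batch is put, without consuming it.
    schema = await queue.get_schema()
    assert schema == batch.schema

    dequeued = await queue.get()
    assert dequeued is batch


asyncio.run(main())

The mypy failures flagged above come from accessing asyncio.Queue's private _queue attribute in the subclass; one common way to appease the checker (an assumption, not what this commit does) is to annotate the attribute on the subclass, for example "_queue: collections.deque", or to add a targeted "# type: ignore[attr-defined]" comment on those lines.
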
161 changes: 120 additions & 41 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
Expand Up @@ -3,9 +3,12 @@
import contextlib
import dataclasses
import datetime as dt
import functools
import json
import operator

import pyarrow as pa
import structlog
from django.conf import settings
from google.cloud import bigquery
from google.oauth2 import service_account
Expand All @@ -27,8 +30,8 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
iter_model_records,
start_batch_export_run,
start_produce_batch_export_record_batches,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
Expand All @@ -42,18 +45,19 @@
)
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
cast_record_batch_json_columns,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
should_resume_from_activity_heartbeat,
)

logger = structlog.get_logger()


def get_bigquery_fields_from_record_schema(
record_schema: pa.Schema, known_json_columns: list[str]
Expand All @@ -72,6 +76,9 @@ def get_bigquery_fields_from_record_schema(
bq_schema: list[bigquery.SchemaField] = []

for name in record_schema.names:
if name == "_inserted_at":
continue

pa_field = record_schema.field(name)

if pa.types.is_string(pa_field.type) or isinstance(pa_field.type, JsonType):
Expand Down Expand Up @@ -264,8 +271,13 @@ async def load_parquet_file(self, parquet_file, table, table_schema):
schema=table_schema,
)

load_job = self.load_table_from_file(parquet_file, table, job_config=job_config, rewind=True)
return await asyncio.to_thread(load_job.result)
await logger.adebug("Creating BigQuery load job for Parquet file '%s'", parquet_file)
load_job = await asyncio.to_thread(
self.load_table_from_file, parquet_file, table, job_config=job_config, rewind=True
)
await logger.adebug("Waiting for BigQuery load job for Parquet file '%s'", parquet_file)
result = await asyncio.to_thread(load_job.result)
return result

async def load_jsonl_file(self, jsonl_file, table, table_schema):
"""Execute a COPY FROM query with given connection to copy contents of jsonl_file."""
Expand All @@ -274,8 +286,14 @@ async def load_jsonl_file(self, jsonl_file, table, table_schema):
schema=table_schema,
)

load_job = self.load_table_from_file(jsonl_file, table, job_config=job_config, rewind=True)
return await asyncio.to_thread(load_job.result)
await logger.adebug("Creating BigQuery load job for JSONL file '%s'", jsonl_file)
load_job = await asyncio.to_thread(
self.load_table_from_file, jsonl_file, table, job_config=job_config, rewind=True
)

await logger.adebug("Waiting for BigQuery load job for JSONL file '%s'", jsonl_file)
result = await asyncio.to_thread(load_job.result)
return result


@contextlib.contextmanager
Expand Down Expand Up @@ -327,7 +345,9 @@ def bigquery_default_fields() -> list[BatchExportField]:
@activity.defn
async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> RecordsCompleted:
"""Activity streams data from ClickHouse to BigQuery."""
logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="BigQuery")
logger = await configure_temporal_worker_logger(
logger=structlog.get_logger(), team_id=inputs.team_id, destination="BigQuery"
)
await logger.ainfo(
"Batch exporting range %s - %s to BigQuery: %s.%s.%s",
inputs.data_interval_start,
Expand Down Expand Up @@ -357,24 +377,52 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model
if model is not None:
model_name = model.name
extra_query_parameters = model.schema["values"] if model.schema is not None else None
fields = model.schema["fields"] if model.schema is not None else None
else:
model_name = "events"
extra_query_parameters = None
fields = None
else:
model = inputs.batch_export_schema
schema = model = inputs.batch_export_schema
model_name = "custom"
extra_query_parameters = schema["values"] if schema is not None else {}
fields = schema["fields"] if schema is not None else None

records_iterator = iter_model_records(
queue, done_event, produce_task = start_produce_batch_export_record_batches(
client=client,
model=model,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
destination_default_fields=bigquery_default_fields(),
is_backfill=inputs.is_backfill,
extra_query_parameters=extra_query_parameters,
)

first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator)
if first_record_batch is None:
get_schema_task = asyncio.create_task(queue.get_schema())
wait_for_producer_done_task = asyncio.create_task(done_event.wait())

await asyncio.wait([get_schema_task, wait_for_producer_done_task], return_when=asyncio.FIRST_COMPLETED)

# The producer finishes only after it has both put batches on the queue and set the schema.
# So either both tasks completed, or the producer finished without putting anything on the queue.
if get_schema_task.done():
# In the first case, we'll land here.
# The schema is available, and the queue is not empty, so we can start the batch export.
record_batch_schema = get_schema_task.result()
elif wait_for_producer_done_task.done():
# In the second case, we'll land here.
# The schema is not available as the queue is empty.
# Since we finished producing with an empty queue, there is nothing to batch export.
return 0
else:
raise Exception("Unreachable")

if inputs.use_json_type is True:
json_type = "JSON"
Expand All @@ -383,8 +431,6 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
json_type = "STRING"
json_columns = []

first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns)

if model is None or (isinstance(model, BatchExportModel) and model.name == "events"):
schema = [

Check failure on line 435 in posthog/temporal/batch_exports/bigquery_batch_export.py (GitHub Actions / Python code quality checks): Incompatible types in assignment (expression has type "list[SchemaField]", variable has type "BatchExportSchema | None")
bigquery.SchemaField("uuid", "STRING"),
Expand All @@ -401,9 +447,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
]
else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)
schema = get_bigquery_fields_from_record_schema(record_batch_schema, known_json_columns=json_columns)

Check failure on line 450 in posthog/temporal/batch_exports/bigquery_batch_export.py (GitHub Actions / Python code quality checks): Incompatible types in assignment (expression has type "list[SchemaField]", variable has type "BatchExportSchema | None")

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()
Expand Down Expand Up @@ -446,41 +490,42 @@ async def flush_to_bigquery(
last: bool,
error: Exception | None,
):
table = bigquery_stage_table if requires_merge else bigquery_table
await logger.adebug(
"Loading %s records of size %s bytes",
"Loading %s records of size %s bytes to BigQuery table '%s'",
records_since_last_flush,
bytes_since_last_flush,
table,
)
table = bigquery_stage_table if requires_merge else bigquery_table

await bq_client.load_jsonl_file(local_results_file, table, schema)

await logger.adebug("Loading to BigQuery table '%s' finished", table)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

heartbeater.details = (str(last_inserted_at),)

record_schema = pa.schema(
# NOTE: For some reason, some record batches set fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why
# schemas differ between batches.
[
field.with_nullable(True)
for field in first_record_batch.select([field.name for field in schema]).schema
]
)
writer = JSONLBatchExportWriter(
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
flush_callable=flush_to_bigquery,
)
flush_tasks = []
while not queue.empty() or not done_event.is_set():
await logger.adebug("Starting record batch writer")
flush_start_event = asyncio.Event()
task = asyncio.create_task(
consume_batch_export_record_batches(
queue, done_event, flush_start_event, flush_to_bigquery, json_columns, logger
)
)

await flush_start_event.wait()

async with writer.open_temporary_file():
async for record_batch in records_iterator:
record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)
flush_tasks.append(task)

await writer.write_record_batch(record_batch)
await logger.adebug(
"Finished producing and consuming all record batches, now waiting on any pending flush tasks"
)
await asyncio.wait(flush_tasks)

records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks))

if requires_merge:
merge_key = (
Expand All @@ -494,7 +539,41 @@ async def flush_to_bigquery(
update_fields=schema,
)

return writer.records_total
return records_total


async def consume_batch_export_record_batches(
queue, done_event, flush_start_event, flush_to_bigquery, json_columns, logger
):
writer = JSONLBatchExportWriter(
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
flush_callable=flush_to_bigquery,
)

async with writer.open_temporary_file():
await logger.adebug("Starting record batch writing loop")
while True:
try:
record_batch = queue.get_nowait()
except asyncio.QueueEmpty:
if done_event.is_set():
await logger.adebug("Empty queue with no more events being produced, closing writer loop")
flush_start_event.set()
break
else:
await asyncio.sleep(1)
continue

record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)
await writer.write_record_batch(record_batch, flush=False)

if writer.should_flush():
await logger.adebug("Writer finished, ready to flush events")
flush_start_event.set()
break

await logger.adebug("Completed %s records", writer.records_total)
return writer.records_total


def get_batch_export_writer(
Expand Down
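
In the BigQuery activity above, whether there is anything to export is decided by racing queue.get_schema() against done_event.wait(): if the schema task finishes first, at least one batch was produced and the export can start; if the producer finished first with nothing queued, the activity returns 0. A self-contained sketch of that decision, with plain events standing in for the real awaitables:

import asyncio


async def decide(schema_ready: asyncio.Event, producer_done: asyncio.Event) -> str:
    """Wait until either the schema is known (data exists) or the producer finished empty."""
    get_schema_task = asyncio.create_task(schema_ready.wait())
    producer_done_task = asyncio.create_task(producer_done.wait())

    await asyncio.wait({get_schema_task, producer_done_task}, return_when=asyncio.FIRST_COMPLETED)

    if get_schema_task.done():
        # At least one batch was produced, so its schema is available and the export can start.
        outcome = "export"
    elif producer_done_task.done():
        # The producer finished without putting anything on the queue: nothing to export.
        outcome = "empty"
    else:
        raise RuntimeError("unreachable")

    # Cancel whichever wait is still pending so it does not linger.
    for task in (get_schema_task, producer_done_task):
        if not task.done():
            task.cancel()
    return outcome


async def main() -> None:
    schema_ready, producer_done = asyncio.Event(), asyncio.Event()
    schema_ready.set()  # simulate: the first record batch arrived
    print(await decide(schema_ready, producer_done))  # prints "export"


asyncio.run(main())
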
Diffs for the remaining 4 changed files are not shown.
