diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index 86c3c5dae99ef..16d4ccdacf0d0 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -336,12 +336,12 @@ def start_produce_batch_export_record_batches( view = SELECT_FROM_PERSONS_VIEW else: - if parameters["exclude_events"]: + if parameters.get("exclude_events", None): parameters["exclude_events"] = list(parameters["exclude_events"]) else: parameters["exclude_events"] = [] - if parameters["include_events"]: + if parameters.get("include_events", None): parameters["include_events"] = list(parameters["include_events"]) else: parameters["include_events"] = [] @@ -367,7 +367,7 @@ def start_produce_batch_export_record_batches( parameters["team_id"] = team_id parameters["interval_start"] = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S") parameters["interval_end"] = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S") - extra_query_parameters = parameters.pop("extra_query_parameters") or {} + extra_query_parameters = parameters.pop("extra_query_parameters", {}) or {} parameters = {**parameters, **extra_query_parameters} queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES) diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index dda307dda004a..b784a404d5ca3 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -2,15 +2,19 @@ import json import operator from random import randint +import asyncio import pytest from django.test import override_settings +import pyarrow as pa from posthog.batch_exports.service import BatchExportModel from posthog.temporal.batch_exports.batch_exports import ( get_data_interval, iter_model_records, iter_records, + start_produce_batch_export_record_batches, + RecordBatchQueue, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse @@ -404,3 +408,427 @@ def test_get_data_interval(interval, data_interval_end, expected): """Test get_data_interval returns the expected data interval tuple.""" result = get_data_interval(interval, data_interval_end) assert result == expected + + +async def get_record_batch_from_queue(queue, done_event): + while not queue.empty() or not done_event.is_set(): + try: + record_batch = queue.get_nowait() + except asyncio.QueueEmpty: + if done_event.is_set(): + break + else: + await asyncio.sleep(0.1) + continue + + return record_batch + return None + + +async def test_start_produce_batch_export_record_batches_uses_extra_query_parameters(clickhouse_client): + """Test start_produce_batch_export_record_batches uses a HogQL value.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X", "custom": 3}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + 
interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[ + {"expression": "JSONExtractInt(properties, %(hogql_val_0)s)", "alias": "custom_prop"}, + ], + extra_query_parameters={"hogql_val_0": "custom"}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + for expected, record in zip(events, records): + if expected["properties"] is None: + raise ValueError("Empty properties") + + assert record["custom_prop"] == expected["properties"]["custom"] + + +async def test_start_produce_batch_export_record_batches_can_flatten_properties(clickhouse_client): + """Test start_produce_batch_export_record_batches can flatten properties.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X", "custom-property": 3}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[ + {"expression": "event", "alias": "event"}, + {"expression": "JSONExtractString(properties, '$browser')", "alias": "browser"}, + {"expression": "JSONExtractString(properties, '$os')", "alias": "os"}, + {"expression": "JSONExtractInt(properties, 'custom-property')", "alias": "custom_prop"}, + ], + extra_query_parameters={"hogql_val_0": "custom"}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + all_expected = sorted(events, key=operator.itemgetter("event")) + all_record = sorted(records, key=operator.itemgetter("event")) + + for expected, record in zip(all_expected, all_record): + if expected["properties"] is None: + raise ValueError("Empty properties") + + assert record["browser"] == expected["properties"]["$browser"] + assert record["os"] == expected["properties"]["$os"] + assert record["custom_prop"] == expected["properties"]["custom-property"] + + +@pytest.mark.parametrize( + "field", + [ + {"expression": "event", "alias": "event_name"}, + {"expression": "team_id", "alias": "team"}, + {"expression": "timestamp", "alias": "time_the_stamp"}, + {"expression": "created_at", "alias": "creation_time"}, + ], +) +async def test_start_produce_batch_export_record_batches_with_single_field_and_alias(clickhouse_client, field): + """Test start_produce_batch_export_record_batches can return a single aliased field.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + 
start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + fields=[field], + extra_query_parameters={}, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + all_expected = sorted(events, key=operator.itemgetter(field["expression"])) + all_record = sorted(records, key=operator.itemgetter(field["alias"])) + + for expected, record in zip(all_expected, all_record): + assert len(record) == 2 + # Always set for progress tracking + assert record.get("_inserted_at", None) is not None + + result = record[field["alias"]] + expected_value = expected[field["expression"]] + + if isinstance(result, dt.datetime): + # Event generation function returns datetimes as strings. + expected_value = dt.datetime.fromisoformat(expected_value).replace(tzinfo=dt.UTC) + + assert result == expected_value + + +async def test_start_produce_batch_export_record_batches_ignores_timestamp_predicates(clickhouse_client): + """Test the rows returned ignore timestamp predicates when configured.""" + team_id = randint(1, 1000000) + + inserted_at = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + data_interval_end = inserted_at + dt.timedelta(hours=1) + + # Insert some data with timestamps a couple of years before inserted_at + timestamp_start = inserted_at - dt.timedelta(hours=24 * 365 * 2) + timestamp_end = inserted_at - dt.timedelta(hours=24 * 365) + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=timestamp_start, + end_time=timestamp_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + inserted_at=inserted_at, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=inserted_at.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert len(records) == 0 + + with override_settings(UNCONSTRAINED_TIMESTAMP_TEAM_IDS=[str(team_id)]): + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=inserted_at.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events) + + +async def test_start_produce_batch_export_record_batches_can_include_events(clickhouse_client): + """Test the rows 
returned can include events."""
+    team_id = randint(1, 1000000)
+    data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00")
+    data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
+
+    (events, _, _) = await generate_test_events_in_clickhouse(
+        client=clickhouse_client,
+        team_id=team_id,
+        start_time=data_interval_start,
+        end_time=data_interval_end,
+        count=10000,
+        count_outside_range=0,
+        count_other_team=0,
+        duplicate=True,
+        person_properties={"$browser": "Chrome", "$os": "Mac OS X"},
+    )
+
+    # Include the latter half of events.
+    include_events = (event["event"] for event in events[5000:])
+
+    queue, done_event, _ = start_produce_batch_export_record_batches(
+        client=clickhouse_client,
+        team_id=team_id,
+        is_backfill=False,
+        model_name="events",
+        interval_start=data_interval_start.isoformat(),
+        interval_end=data_interval_end.isoformat(),
+        include_events=include_events,
+    )
+
+    records = []
+    while not queue.empty() or not done_event.is_set():
+        record_batch = await get_record_batch_from_queue(queue, done_event)
+        if record_batch is None:
+            break
+
+        for record in record_batch.to_pylist():
+            records.append(record)
+
+    assert_records_match_events(records, events[5000:])
+
+
+async def test_start_produce_batch_export_record_batches_can_exclude_events(clickhouse_client):
+    """Test the rows returned can exclude events."""
+    team_id = randint(1, 1000000)
+    data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00")
+    data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
+
+    (events, _, _) = await generate_test_events_in_clickhouse(
+        client=clickhouse_client,
+        team_id=team_id,
+        start_time=data_interval_start,
+        end_time=data_interval_end,
+        count=10000,
+        count_outside_range=0,
+        count_other_team=0,
+        duplicate=True,
+        person_properties={"$browser": "Chrome", "$os": "Mac OS X"},
+    )
+
+    # Exclude the latter half of events.
+ exclude_events = (event["event"] for event in events[5000:]) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events[:5000]) + + +async def test_start_produce_batch_export_record_batches_handles_duplicates(clickhouse_client): + """Test the rows returned are de-duplicated.""" + team_id = randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=100, + count_outside_range=0, + count_other_team=0, + duplicate=True, + person_properties={"$browser": "Chrome", "$os": "Mac OS X"}, + ) + + queue, done_event, _ = start_produce_batch_export_record_batches( + client=clickhouse_client, + team_id=team_id, + is_backfill=False, + model_name="events", + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + ) + + records = [] + while not queue.empty() or not done_event.is_set(): + record_batch = await get_record_batch_from_queue(queue, done_event) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + + assert_records_match_events(records, events) + + +async def test_record_batch_queue_tracks_bytes(): + """Test `RecordBatchQueue` tracks bytes from `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_raises_queue_full(): + """Test `QueueFull` is raised when we put too many bytes.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + record_batch_size = record_batch.get_total_buffer_size() + + queue = RecordBatchQueue(max_size_bytes=record_batch_size) + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + with pytest.raises(asyncio.QueueFull): + queue.put_nowait(record_batch) + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_sets_schema(): + """Test `RecordBatchQueue` sets a schema from first `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + + assert queue._schema_set.is_set() + + schema = await queue.get_schema() + assert schema == record_batch.schema
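
Note on consuming the queue: the drain loop repeated in the tests above is the consumption pattern for the values unpacked as `queue, done_event, _` from `start_produce_batch_export_record_batches`. The sketch below is a minimal, self-contained illustration of that pattern, not PostHog code: `fake_producer` and the plain `asyncio.Queue` are stand-ins (assumptions so the example runs without ClickHouse or `RecordBatchQueue`), while `get_record_batch_from_queue` mirrors the helper added in this diff.

import asyncio

import pyarrow as pa


async def get_record_batch_from_queue(queue: asyncio.Queue, done_event: asyncio.Event):
    """Return the next record batch, or None once the producer is done and the queue is drained."""
    while not queue.empty() or not done_event.is_set():
        try:
            record_batch = queue.get_nowait()
        except asyncio.QueueEmpty:
            if done_event.is_set():
                break
            await asyncio.sleep(0.1)
            continue
        return record_batch
    return None


async def fake_producer(queue: asyncio.Queue, done_event: asyncio.Event) -> None:
    """Stand-in producer: put two small record batches, then signal completion."""
    for start in (0, 3):
        batch = pa.RecordBatch.from_pylist([{"n": start + i} for i in range(3)])
        await queue.put(batch)
    done_event.set()


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    done_event = asyncio.Event()
    producer = asyncio.create_task(fake_producer(queue, done_event))

    # Same pattern as the tests: keep draining until the queue is empty
    # and the producer has signalled that it is done.
    records = []
    while not queue.empty() or not done_event.is_set():
        record_batch = await get_record_batch_from_queue(queue, done_event)
        if record_batch is None:
            break
        records.extend(record_batch.to_pylist())

    await producer
    assert [record["n"] for record in records] == [0, 1, 2, 3, 4, 5]


if __name__ == "__main__":
    asyncio.run(main())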