diff --git a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py index 4fa17498d3..6cf3bf5d71 100644 --- a/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py +++ b/src/snowflake/snowpark/_internal/temp_table_auto_cleaner.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import logging +import threading import weakref from collections import defaultdict from typing import TYPE_CHECKING, Dict @@ -31,9 +32,12 @@ def __init__(self, session: "Session") -> None: # to its reference count for later temp table management # this dict will still be maintained even if the cleaner is stopped (`stop()` is called) self.ref_count_map: Dict[str, int] = defaultdict(int) + # Lock to protect the ref_count_map + self.lock = threading.RLock() def add(self, table: SnowflakeTable) -> None: - self.ref_count_map[table.name] += 1 + with self.lock: + self.ref_count_map[table.name] += 1 # the finalizer will be triggered when it gets garbage collected # and this table will be dropped finally _ = weakref.finalize(table, self._delete_ref_count, table.name) @@ -43,13 +47,15 @@ def _delete_ref_count(self, name: str) -> None: # pragma: no cover Decrements the reference count of a temporary table, and if the count reaches zero, puts this table in the queue for cleanup. 
""" - self.ref_count_map[name] -= 1 - if self.ref_count_map[name] == 0: + with self.lock: + self.ref_count_map[name] -= 1 + current_ref_count = self.ref_count_map[name] + if current_ref_count == 0: if self.session.auto_clean_up_temp_table_enabled: self.drop_table(name) - elif self.ref_count_map[name] < 0: + elif current_ref_count < 0: logging.debug( - f"Unexpected reference count {self.ref_count_map[name]} for table {name}" + f"Unexpected reference count {current_ref_count} for table {name}" ) def drop_table(self, name: str) -> None: # pragma: no cover @@ -89,9 +95,11 @@ def stop(self) -> None: @property def num_temp_tables_created(self) -> int: - return len(self.ref_count_map) + with self.lock: + return len(self.ref_count_map) @property def num_temp_tables_cleaned(self) -> int: # TODO SNOW-1662536: we may need a separate counter for the number of tables cleaned when parameter is enabled - return sum(v == 0 for v in self.ref_count_map.values()) + with self.lock: + return sum(v == 0 for v in self.ref_count_map.values()) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 3a4eacf25a..8e049c56a0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -831,10 +831,11 @@ def auto_clean_up_temp_table_enabled(self, value: bool) -> None: ) if value in [True, False]: - self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( - self._session_id, value - ) - self._auto_clean_up_temp_table_enabled = value + with self._lock: + self._conn._telemetry_client.send_auto_clean_up_temp_table_telemetry( + self._session_id, value + ) + self._auto_clean_up_temp_table_enabled = value else: raise ValueError( "value for auto_clean_up_temp_table_enabled must be True or False!" 
diff --git a/tests/integ/test_multithreading.py b/tests/integ/test_multithreading.py index 14b9c14578..a6d13584bc 100644 --- a/tests/integ/test_multithreading.py +++ b/tests/integ/test_multithreading.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import gc import hashlib import logging import os @@ -14,6 +15,7 @@ from snowflake.snowpark.session import Session from snowflake.snowpark.types import IntegerType +from tests.integ.test_temp_table_cleanup import wait_for_drop_table_sql_done try: import dateutil @@ -505,6 +507,45 @@ def finish(self): executor.submit(register_and_test_udaf, session, i) +@pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in local testing mode", + run=False, +) +def test_auto_temp_table_cleaner(session, caplog): + session._temp_table_auto_cleaner.ref_count_map.clear() + original_auto_clean_up_temp_table_enabled = session.auto_clean_up_temp_table_enabled + session.auto_clean_up_temp_table_enabled = True + + def create_temp_table(session_, thread_id): + df = session_.sql(f"select {thread_id} as A").cache_result() + table_name = df.table_name + del df + return table_name + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + table_names = [] + for i in range(10): + futures.append(executor.submit(create_temp_table, session, i)) + + for future in as_completed(futures): + table_names.append(future.result()) + + gc.collect() + wait_for_drop_table_sql_done(session, caplog, expect_drop=True) + + try: + for table_name in table_names: + assert session._temp_table_auto_cleaner.ref_count_map[table_name] == 0 + assert session._temp_table_auto_cleaner.num_temp_tables_created == 10 + assert session._temp_table_auto_cleaner.num_temp_tables_cleaned == 10 + finally: + session.auto_clean_up_temp_table_enabled = ( + original_auto_clean_up_temp_table_enabled + ) + + @pytest.mark.skipif( IS_LINUX or IS_WINDOWS, reason="Linux and 
Windows test show multiple active threads when no threadpool is enabled", diff --git a/tests/integ/test_temp_table_cleanup.py b/tests/integ/test_temp_table_cleanup.py index 59ad319e7d..e3d97e29ff 100644 --- a/tests/integ/test_temp_table_cleanup.py +++ b/tests/integ/test_temp_table_cleanup.py @@ -37,12 +37,14 @@ def setup(session): def wait_for_drop_table_sql_done(session: Session, caplog, expect_drop: bool) -> None: # Loop through captured logs and search for the pattern pattern = r"Dropping .* with query id ([0-9a-f\-]+)" + matches = [] for record in caplog.records: match = re.search(pattern, record.message) if match: query_id = match.group(1) - break - else: + matches.append(query_id) + + if len(matches) == 0: if expect_drop: pytest.fail("No drop table sql found in logs") else: @@ -50,9 +52,10 @@ def wait_for_drop_table_sql_done(session: Session, caplog, expect_drop: bool) -> return caplog.clear() - async_job = session.create_async_job(query_id) - # Wait for the async job to finish - _ = async_job.result() + for query_id in matches: + async_job = session.create_async_job(query_id) + # Wait for the async job to finish + _ = async_job.result() def test_basic(session, caplog):