diff --git a/pyathena/async_cursor.py b/pyathena/async_cursor.py
index 81624f2c..8c9b2b21 100644
--- a/pyathena/async_cursor.py
+++ b/pyathena/async_cursor.py
@@ -28,6 +28,7 @@ def __init__(
         retry_config,
         max_workers=(cpu_count() or 1) * 5,
         arraysize=CursorIterator.DEFAULT_FETCH_SIZE,
+        kill_on_interrupt=True,
     ):
         super(AsyncCursor, self).__init__(
             connection=connection,
@@ -40,6 +41,7 @@ def __init__(
             converter=converter,
             formatter=formatter,
             retry_config=retry_config,
+            kill_on_interrupt=kill_on_interrupt,
         )
         self._executor = ThreadPoolExecutor(max_workers=max_workers)
         self._arraysize = arraysize
diff --git a/pyathena/async_pandas_cursor.py b/pyathena/async_pandas_cursor.py
index 21e318c0..ed6400e0 100644
--- a/pyathena/async_pandas_cursor.py
+++ b/pyathena/async_pandas_cursor.py
@@ -26,6 +26,7 @@ def __init__(
         retry_config,
         max_workers=(cpu_count() or 1) * 5,
         arraysize=CursorIterator.DEFAULT_FETCH_SIZE,
+        kill_on_interrupt=True,
     ):
         super(AsyncPandasCursor, self).__init__(
             connection=connection,
@@ -40,6 +41,7 @@ def __init__(
             retry_config=retry_config,
             max_workers=max_workers,
             arraysize=arraysize,
+            kill_on_interrupt=kill_on_interrupt,
         )
 
     def _collect_result_set(self, query_id):
diff --git a/pyathena/common.py b/pyathena/common.py
index f1bf99d2..e5997afa 100644
--- a/pyathena/common.py
+++ b/pyathena/common.py
@@ -84,6 +84,7 @@ def __init__(
         converter,
         formatter,
         retry_config,
+        kill_on_interrupt,
         **kwargs
     ):
         super(BaseCursor, self).__init__(**kwargs)
@@ -97,6 +98,7 @@ def __init__(
         self._converter = converter
         self._formatter = formatter
         self._retry_config = retry_config
+        self._kill_on_interrupt = kill_on_interrupt
 
     @property
     def connection(self):
@@ -117,7 +119,7 @@ def _get_query_execution(self, query_id):
         else:
             return AthenaQueryExecution(response)
 
-    def _poll(self, query_id):
+    def __poll(self, query_id):
         while True:
             query_execution = self._get_query_execution(query_id)
             if query_execution.state in [
@@ -129,6 +131,26 @@ def _poll(self, query_id):
             else:
                 time.sleep(self._poll_interval)
 
+    def _poll(self, query_id):
+        """Poll the query via ``__poll``, optionally cancelling it on Ctrl-C.
+
+        If a ``KeyboardInterrupt`` arrives while polling and
+        ``kill_on_interrupt`` is enabled, the query is cancelled and polling
+        is resumed so the returned execution reflects the post-cancel state.
+        Otherwise the interrupt propagates to the caller unchanged.
+        """
+        try:
+            query_execution = self.__poll(query_id)
+        except KeyboardInterrupt:
+            if self._kill_on_interrupt:
+                _logger.warning("Query canceled by user.")
+                self._cancel(query_id)
+                query_execution = self.__poll(query_id)
+            else:
+                # Bare ``raise`` re-raises the active exception with its
+                # original traceback intact (``raise e`` does not on Python 2).
+                raise
+        return query_execution
+
     def _build_start_query_execution_request(
         self, query, work_group=None, s3_staging_dir=None
     ):
diff --git a/pyathena/connection.py b/pyathena/connection.py
index 8ba49695..5afd01a7 100644
--- a/pyathena/connection.py
+++ b/pyathena/connection.py
@@ -59,6 +59,7 @@ def __init__(
         formatter=None,
         retry_config=None,
         cursor_class=Cursor,
+        kill_on_interrupt=True,
         **kwargs
     ):
         self._kwargs = kwargs
@@ -106,6 +107,7 @@ def __init__(
         self._formatter = formatter if formatter else DefaultParameterFormatter()
         self._retry_config = retry_config if retry_config else RetryConfig()
         self.cursor_class = cursor_class
+        self.kill_on_interrupt = kill_on_interrupt
 
     def _assume_role(
         self, profile_name, region_name, role_arn, role_session_name, duration_seconds
@@ -171,6 +173,7 @@ def cursor(self, cursor=None, **kwargs):
             converter=converter,
             formatter=kwargs.pop("formatter", self._formatter),
             retry_config=kwargs.pop("retry_config", self._retry_config),
+            kill_on_interrupt=kwargs.pop("kill_on_interrupt", self.kill_on_interrupt),
             **kwargs
         )
 
diff --git a/pyathena/cursor.py b/pyathena/cursor.py
index 6eafa61f..bbe643bd 100644
--- a/pyathena/cursor.py
+++ b/pyathena/cursor.py
@@ -25,6 +25,7 @@ def __init__(
         converter,
         formatter,
         retry_config,
+        kill_on_interrupt=True,
         **kwargs
     ):
         super(Cursor, self).__init__(
@@ -38,6 +39,7 @@ def __init__(
             converter=converter,
             formatter=formatter,
             retry_config=retry_config,
+            kill_on_interrupt=kill_on_interrupt,
             **kwargs
         )
 
diff --git a/pyathena/pandas_cursor.py b/pyathena/pandas_cursor.py
index 4ae881ba..acf5a0b7 100644
--- a/pyathena/pandas_cursor.py
+++ b/pyathena/pandas_cursor.py
@@ -27,6 +27,7 @@ def __init__(
         converter,
         formatter,
         retry_config,
+        kill_on_interrupt=True,
         **kwargs
     ):
         super(PandasCursor, self).__init__(
@@ -40,6 +41,7 @@ def __init__(
             converter=converter,
             formatter=formatter,
             retry_config=retry_config,
+            kill_on_interrupt=kill_on_interrupt,
             **kwargs
         )
 