Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix spinner animation blocking user input in diagnostic tool #631

Merged
merged 1 commit into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions user_tools/src/spark_rapids_pytools/common/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,23 +379,25 @@ class ToolsSpinner:
A class to manage the spinner animation.
Reference: https://stackoverflow.com/a/66558182

:param in_debug_mode: Flag indicating if running in debug (verbose) mode. Defaults to False.
:param enabled: Flag indicating if the spinner is enabled. Defaults to True.
"""
in_debug_mode: bool = field(default=False, init=True)
pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...'), init=False)
enabled: bool = field(default=True, init=True)
pixel_spinner: PixelSpinner = field(default=PixelSpinner('Processing...', hide_cursor=False), init=False)
end: str = field(default='Processing Completed!', init=False)
timeout: float = field(default=0.1, init=False)
completed: bool = field(default=False, init=False)
spinner_thread: threading.Thread = field(default=None, init=False)
pause_event: threading.Event = field(default=threading.Event(), init=False)

def _spinner_animation(self):
while not self.completed:
self.pixel_spinner.next()
time.sleep(self.timeout)
while self.pause_event.is_set():
self.pause_event.wait(self.timeout)

def start(self):
# Don't start if in debug mode
if not self.in_debug_mode:
if self.enabled:
self.spinner_thread = threading.Thread(target=self._spinner_animation, daemon=True)
self.spinner_thread.start()
return self
Expand All @@ -404,6 +406,16 @@ def stop(self):
self.completed = True
print(f'\r\n{self.end}', flush=True)

def pause(self, insert_newline=False):
if self.enabled:
if insert_newline:
# Print a newline for visual separation
print()
self.pause_event.set()

def resume(self):
self.pause_event.clear()

def __enter__(self):
return self.start()

Expand Down
10 changes: 7 additions & 3 deletions user_tools/src/spark_rapids_pytools/rapids/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@ def _process_custom_args(self):

self.thread_num = thread_num
self.logger.debug('Set thread number as: %d', self.thread_num)

self.logger.warning('This operation will collect sensitive information from your cluster, '
'such as OS & HW info, Yarn/Spark configurations and log files etc.')
log_message = ('This operation will collect sensitive information from your cluster, '
'such as OS & HW info, Yarn/Spark configurations and log files etc.')
yes = self.wrapper_options.get('yes', False)
if yes:
self.logger.warning(log_message)
self.logger.info('Confirmed by command line option.')
else:
# Pause the spinner for user prompt
self.spinner.pause(insert_newline=True)
print(log_message)
user_input = input('Do you want to continue (yes/no): ')
if user_input.lower() not in ['yes', 'y']:
raise RuntimeError('User canceled the operation.')
self.spinner.resume()

def requires_cluster_connection(self) -> bool:
return True
Expand Down
5 changes: 4 additions & 1 deletion user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class RapidsTool(object):
name: str = field(default=None, init=False)
ctxt: ToolContext = field(default=None, init=False)
logger: Logger = field(default=None, init=False)
spinner: ToolsSpinner = field(default=None, init=False)

def pretty_name(self):
return self.name.capitalize()
Expand Down Expand Up @@ -272,7 +273,9 @@ def _verify_exec_cluster(self):
self._handle_non_running_exec_cluster(msg)

def launch(self):
with ToolsSpinner(in_debug_mode=ToolLogging.is_debug_mode_enabled()):
# Spinner should not be enabled in debug mode
enable_spinner = not ToolLogging.is_debug_mode_enabled()
with ToolsSpinner(enabled=enable_spinner) as self.spinner:
self._init_tool()
self._connect_to_execution_cluster()
self._process_arguments()
Expand Down
Loading