diff --git a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py index bcbdd12fe..094b289e6 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py +++ b/user_tools/src/spark_rapids_pytools/rapids/rapids_tool.py @@ -151,14 +151,18 @@ def _process_job_submission_args(self): @phase_banner('Process-Arguments') def _process_arguments(self): - # 0- process the output location - self._process_output_args() - # 1- process any arguments to be passed to the RAPIDS tool - self._process_rapids_args() - # 2- we need to process the arguments of the CLI - self._process_custom_args() - # 3- process submission arguments - self._process_job_submission_args() + try: + # 0- process the output location + self._process_output_args() + # 1- process any arguments to be passed to the RAPIDS tool + self._process_rapids_args() + # 2- we need to process the arguments of the CLI + self._process_custom_args() + # 3- process submission arguments + self._process_job_submission_args() + except Exception as ex: # pylint: disable=broad-except + self.logger.error('Failed in processing arguments') + raise ex @phase_banner('Initialization') def _init_tool(self): @@ -387,14 +391,21 @@ class RapidsJarTool(RapidsTool): """ def _process_jar_arg(self): + jar_path = '' tools_jar_url = self.wrapper_options.get('toolsJar') - if tools_jar_url is None: - tools_jar_url = self.ctxt.get_rapids_jar_url() - # download the jar - jar_path = self.ctxt.platform.storage.download_resource(tools_jar_url, - self.ctxt.get_local_work_dir(), - fail_ok=False, - create_dir=True) + try: + if tools_jar_url is None: + tools_jar_url = self.ctxt.get_rapids_jar_url() + # download the jar + self.logger.info('Downloading the tools jars %s', tools_jar_url) + jar_path = self.ctxt.platform.storage.download_resource(tools_jar_url, + self.ctxt.get_local_work_dir(), + fail_ok=False, + create_dir=True) + except Exception as e: # pylint: disable=broad-except + self.logger.exception('Exception occurred downloading jar %s', tools_jar_url) + raise e + self.logger.info('RAPIDS accelerator tools jar is downloaded to work_dir %s', jar_path) # get the jar file name jar_file_name = FSUtil.get_resource_name(jar_path) diff --git a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py index be350b67e..e7ba7ee80 100644 --- a/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py +++ b/user_tools/src/spark_rapids_pytools/rapids/tool_ctxt.py @@ -169,6 +169,7 @@ def get_local_work_dir(self) -> str: return self.get_local('depFolder') def get_rapids_jar_url(self) -> str: + self.logger.info('Fetching the Rapids Jar URL') # get the version from the package, instead of the yaml file # jar_version = self.get_value('sparkRapids', 'version') if self.is_fatwheel_mode(): @@ -176,6 +177,7 @@ def get_rapids_jar_url(self) -> str: matching_files = glob(offline_path_regex) if not matching_files: raise FileNotFoundError('In Fat Mode. No matching JAR files found.') + self.logger.info('Using jar from wheel file %s', matching_files[0]) return matching_files[0] mvn_base_url = self.get_value('sparkRapids', 'mvnUrl') jar_version = Utilities.get_latest_mvn_jar_from_metadata(mvn_base_url)