From cb5b89108e0217131a914315be698827a4d05c15 Mon Sep 17 00:00:00 2001 From: Muspi Merol Date: Fri, 6 Dec 2024 23:54:05 +0800 Subject: [PATCH] Refactor `NO_RELOAD` implementation (#10088) * Refactor `NO_RELOAD` implementation * add changeset --------- Co-authored-by: gradio-pr-bot --- .changeset/legal-geese-taste.md | 5 +++++ gradio/utils.py | 35 ++++++++++++++++++--------------- 2 files changed, 24 insertions(+), 16 deletions(-) create mode 100644 .changeset/legal-geese-taste.md diff --git a/.changeset/legal-geese-taste.md b/.changeset/legal-geese-taste.md new file mode 100644 index 0000000000000..000b4ebe448c7 --- /dev/null +++ b/.changeset/legal-geese-taste.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:Refactor `NO_RELOAD` implementation diff --git a/gradio/utils.py b/gradio/utils.py index e38c21a703d4f..4d371f41507b5 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -172,10 +172,21 @@ def swap_blocks(self, demo: Blocks): self.alert_change("reload") -NO_RELOAD = True +class DynamicBoolean(int): + def __init__(self, value: int): + self.value = bool(value) + def __bool__(self): + return self.value -def _remove_no_reload_codeblocks(file_path: str): + def set(self, value: int): + self.value = bool(value) + + +NO_RELOAD = DynamicBoolean(True) + + +def _remove_if_name_main_codeblock(file_path: str): """Parse the file, remove the gr.no_reload code blocks, and write the file back to disk. Parameters: @@ -187,16 +198,6 @@ def _remove_no_reload_codeblocks(file_path: str): tree = ast.parse(code) - def _is_gr_no_reload(expr: ast.AST) -> bool: - """Find with gr.no_reload context managers.""" - return ( - isinstance(expr, ast.If) - and isinstance(expr.test, ast.Attribute) - and isinstance(expr.test.value, ast.Name) - and expr.test.value.id == "gr" - and expr.test.attr == "NO_RELOAD" - ) - def _is_if_name_main(expr: ast.AST) -> bool: """Find the if __name__ == '__main__': block.""" return ( @@ -212,7 +213,7 @@ def _is_if_name_main(expr: ast.AST) -> bool: # Find the positions of the code blocks to load for node in ast.walk(tree): - if _is_gr_no_reload(node) or _is_if_name_main(node): + if _is_if_name_main(node): assert isinstance(node, ast.If) # noqa: S101 node.body = [ast.Pass(lineno=node.lineno, col_offset=node.col_offset)] @@ -236,6 +237,8 @@ def watchfn(reloader: SourceFileReloader): get_changes is taken from uvicorn's default file watcher. """ + NO_RELOAD.set(False) + # The thread running watchfn will be the thread reloading # the app. So we need to modify this thread_data attr here # so that subsequent calls to reload don't launch the app @@ -275,7 +278,7 @@ def iter_py_files() -> Iterator[Path]: # Need to import the module in this thread so that the # module is available in the namespace of this thread module = reloader.watch_module - no_reload_source_code = _remove_no_reload_codeblocks(str(reloader.demo_file)) + no_reload_source_code = _remove_if_name_main_codeblock(str(reloader.demo_file)) exec(no_reload_source_code, module.__dict__) sys.modules[reloader.watch_module_name] = module @@ -294,7 +297,7 @@ def iter_py_files() -> Iterator[Path]: # changes to be reflected in the main demo file. if changed.suffix == ".py": - changed_in_copy = _remove_no_reload_codeblocks(str(changed)) + changed_in_copy = _remove_if_name_main_codeblock(str(changed)) if changed != reloader.demo_file: changed_module = _find_module(changed) if changed_module: @@ -305,7 +308,7 @@ def iter_py_files() -> Iterator[Path]: if top_level_parent != changed_module: importlib.reload(top_level_parent) - changed_demo_file = _remove_no_reload_codeblocks( + changed_demo_file = _remove_if_name_main_codeblock( str(reloader.demo_file) ) exec(changed_demo_file, module.__dict__)