from glob import glob
from os.path import join as pathjoin


# Find files
def find_files(dir, ext, recursive=False):
    r"""Find files with a given extension in a directory

    Builds a glob pattern from the directory and the extension, and returns
    the paths that match it.

    Args:
        dir (str or list[str] or tuple[str]): Directory to search, given as a
            single path or as a sequence of path components (to be joined)
        ext (str): File extension, with or without the leading dot

    Kwargs:
        recursive (bool): Search subdirectories as well? Default is False

    Returns:
        list[str]: Matching paths in sorted order, each one prefixed with the
            directory as it was provided. Empty list if nothing matches.

    Raises:
        ValueError: If ext is empty

    Notes:
        - Matching is delegated to glob: wildcard characters in dir are
          interpreted as patterns, and names that begin with a dot are not
          matched by the "*" wildcard.
        - The argument name `dir` shadows the builtin of the same name; it is
          kept because callers may pass it by keyword.

    Examples:
        find_files(".", "py")
        find_files(["data", "raw"], ".csv")
        find_files(".", ".ipynb", recursive=True)
    """
    ## Parse arguments
    # Handle dir provided as a sequence of path components
    if isinstance(dir, (list, tuple)):
        dir = pathjoin(*dir)
    # Fail with a clear message; indexing an empty ext would otherwise
    # surface as an unexplained IndexError
    if not ext:
        raise ValueError("ext must be a non-empty string, e.g. '.py' or 'py'")
    # Ensure leading dot
    if not ext.startswith("."):
        ext = "." + ext

    ## Find the files
    # "**" matches zero or more nested directories, but only when glob is
    # called with recursive=True; so the top level is included as well
    if recursive:
        parts = (dir, "**", "*" + ext)
    else:
        parts = (dir, "*" + ext)

    # glob yields results in arbitrary order; sort for reproducible output
    return sorted(glob(pathjoin(*parts), recursive=recursive))