Skip to content

Commit

Permalink
Merge pull request #236 from zdelrosario/dev_find_files
Browse files Browse the repository at this point in the history
Dev find files
  • Loading branch information
zdelrosario authored Aug 1, 2024
2 parents 46b1abd + 601ee9f commit e1e5f89
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 5 deletions.
29 changes: 29 additions & 0 deletions grama/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"copy_meta",
"custom_formatwarning",
"lookup",
"find_files",
"hide_traceback",
"param_dist",
"pipe",
Expand All @@ -16,7 +17,9 @@
from functools import wraps
from inspect import signature
from numbers import Integral
from glob import glob
from numpy import empty
from os.path import join as pathjoin
from pandas import DataFrame, concat
from pandas.core.dtypes.common import is_object_dtype
from pandas._libs import (
Expand Down Expand Up @@ -433,3 +436,29 @@ def _hide_traceback(exc_tuple=None, filename=None, tb_offset=None,
return ipython._showtraceback(etype, value, ipython.InteractiveTB.get_exception_only(etype, value))

ipython.showtraceback = _hide_traceback

# Find files
def find_files(dir, ext, recursive=False):
r"""Find files in a given directory
Args:
dir (str or list[str]): Path as string or list of directories (to be joined)
ext (str): File extension
Kwargs:
recursive (bool): Search recursively?
"""
## Parse arguments
# Handle dir provided as list
if isinstance(dir, list):
dir = pathjoin(*dir)
# Ensure leading dot
if ext[0] != ".":
ext = "." + ext
## Find the files
if recursive:
files = glob(pathjoin(dir, "**", "*" + ext), recursive=True)
else:
files = glob(pathjoin(dir, "*" + ext), recursive=False)

return files
36 changes: 31 additions & 5 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import numpy as np
import pandas as pd
from scipy.stats import norm
import unittest

from os.path import join
from context import grama as gr
from context import models, data
from context import models

## Core function tests
##################################################
Expand All @@ -15,3 +12,32 @@ def setUp(self):
def test_pipe(self):
## Chain
res = self.md >> gr.ev_hybrid(df_det="nom") >> gr.tf_sobol()

class TestFindFiles(unittest.TestCase):
def test_find_files(self):
# Non-recursive
res = gr.find_files(".", ".py", recursive=False)
self.assertTrue(
join(".", "context.py") in res
)
self.assertTrue(
not join("longrun", "sp_convergence.ipynb") in res
)
# Handles missing ext dot
res = gr.find_files(".", "py", recursive=False)
self.assertTrue(
join(".", "context.py") in res
)
# No return
res = gr.find_files(".", ".exe", recursive=False)
self.assertTrue(len(res) == 0)
# Recursive
res = gr.find_files(".", ".ipynb", recursive=True)
self.assertTrue(
join(".", "longrun", "sp_convergence.ipynb") in res
)
# Accepts paths to join
res = gr.find_files([".", "longrun"], ".ipynb", recursive=False)
self.assertTrue(
join(".", "longrun", "sp_convergence.ipynb") in res
)

0 comments on commit e1e5f89

Please sign in to comment.