Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev find files #236

Merged
merged 4 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions grama/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"copy_meta",
"custom_formatwarning",
"lookup",
"find_files",
"hide_traceback",
"param_dist",
"pipe",
Expand All @@ -16,7 +17,9 @@
from functools import wraps
from inspect import signature
from numbers import Integral
from glob import glob
from numpy import empty
from os.path import join as pathjoin
from pandas import DataFrame, concat
from pandas.core.dtypes.common import is_object_dtype
from pandas._libs import (
Expand Down Expand Up @@ -433,3 +436,29 @@ def _hide_traceback(exc_tuple=None, filename=None, tb_offset=None,
return ipython._showtraceback(etype, value, ipython.InteractiveTB.get_exception_only(etype, value))

ipython.showtraceback = _hide_traceback

# Find files
def find_files(dir, ext, recursive=False):
r"""Find files in a given directory

Args:
dir (str or list[str]): Path as string or list of directories (to be joined)
ext (str): File extension

Kwargs:
recursive (bool): Search recursively?
"""
## Parse arguments
# Handle dir provided as list
if isinstance(dir, list):
dir = pathjoin(*dir)
# Ensure leading dot
if ext[0] != ".":
ext = "." + ext
## Find the files
if recursive:
files = glob(pathjoin(dir, "**", "*" + ext), recursive=True)
else:
files = glob(pathjoin(dir, "*" + ext), recursive=False)

return files
36 changes: 31 additions & 5 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import numpy as np
import pandas as pd
from scipy.stats import norm
import unittest

from os.path import join
from context import grama as gr
from context import models, data
from context import models

## Core function tests
##################################################
Expand All @@ -15,3 +12,32 @@ def setUp(self):
def test_pipe(self):
## Chain
res = self.md >> gr.ev_hybrid(df_det="nom") >> gr.tf_sobol()

class TestFindFiles(unittest.TestCase):
def test_find_files(self):
# Non-recursive
res = gr.find_files(".", ".py", recursive=False)
self.assertTrue(
join(".", "context.py") in res
)
self.assertTrue(
not join("longrun", "sp_convergence.ipynb") in res
)
# Handles missing ext dot
res = gr.find_files(".", "py", recursive=False)
self.assertTrue(
join(".", "context.py") in res
)
# No return
res = gr.find_files(".", ".exe", recursive=False)
self.assertTrue(len(res) == 0)
# Recursive
res = gr.find_files(".", ".ipynb", recursive=True)
self.assertTrue(
join(".", "longrun", "sp_convergence.ipynb") in res
)
# Accepts paths to join
res = gr.find_files([".", "longrun"], ".ipynb", recursive=False)
self.assertTrue(
join(".", "longrun", "sp_convergence.ipynb") in res
)
Loading