diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..4cab1f4 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Set the default behavior, in case people don't have core.autocrlf set. +* text=auto diff --git a/.gitignore b/.gitignore index cdfa95f..c02f7fb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ dist* *.egg-info *.vscode *.dvc* +*.idea* diff --git a/docs/make.bat b/docs/make.bat index 6247f7e..9534b01 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,35 +1,35 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/mltrace/__init__.py b/mltrace/__init__.py index 7ae7238..66a867f 100644 --- a/mltrace/__init__.py +++ b/mltrace/__init__.py @@ -4,9 +4,7 @@ register, backtrace, get_history, - get_components_with_owner, tag_component, - get_components_with_tag, log_component_run, create_random_ids, get_component_information, @@ -23,6 +21,9 @@ flag_output_id, unflag_output_id, review_flagged_outputs, + get_all_tags, + get_components, + unflag_all, ) __all__ = [ @@ -31,9 +32,7 @@ "register", "backtrace", "get_history", - "get_components_with_owner", "tag_component", - "get_components_with_tag", "log_component_run", "create_random_ids", "get_component_information", @@ -50,4 +49,7 @@ "flag_output_id", "unflag_output_id", "review_flagged_outputs", + "get_all_tags", + "get_components", + "unflag_all", ] diff --git a/mltrace/cli/cli.py b/mltrace/cli/cli.py index 3d2e35b..74692e7 100644 --- a/mltrace/cli/cli.py +++ b/mltrace/cli/cli.py @@ -2,6 +2,7 @@ import click + from mltrace import ( set_address, get_recent_run_ids, @@ -12,6 +13,9 @@ flag_output_id, unflag_output_id, review_flagged_outputs, + get_all_tags, + get_components, + unflag_all, ) import textwrap @@ -204,6 +208,7 @@ def show_res(res, indent, count, pos, need_stick): @click.group() def mltrace(): + # Pass pass @@ -223,6 +228,20 @@ def recent(limit: int, address: str = ""): show_info_card(id) +@mltrace.command("inspect") +@click.option("--address", help="Database server address") +@click.argument("component_run_id") +def inspect(component_run_id, address: str = ""): + """ + CLI to inspect a specific component run id. + """ + # Set address + if address and len(address) > 0: + set_address(address) + + show_info_card(component_run_id) + + @mltrace.command("history") @click.argument("component_name") @click.option("--limit", default=5, help="Limit of recent objects.") @@ -274,16 +293,31 @@ def flag(output_id: str, address: str = ""): @mltrace.command("unflag") -@click.argument("output_id") +@click.option("--output_id", help="Output ID to unflag") +@click.option("--all", is_flag=True, help="Add flag to unflag all") @click.option("--address", help="Database server address") -def unflag(output_id: str, address: str = ""): +def unflag(output_id: str = "", all: bool = False, address: str = ""): """ - Command to set the flag property of an output_id to false. + Command to set flag property of an output_id or all output_ids to false. """ + # Check if set --all and --output_id + if all and output_id: + raise click.ClickException("Can set either --all=True or specify an " + "--output_id. Cannot set both.") + + if not all and not output_id: + raise click.ClickException("Need to either set --all=True or specify " + "an --output_id to unflag.") + # Set address if address and len(address) > 0: set_address(address) - unflag_output_id(output_id) + + if all: + unflag_all() + + elif not all and output_id: + unflag_output_id(output_id) @mltrace.command("review") @@ -307,3 +341,44 @@ def review(limit: int = 5, address: str = ""): # Print component runs for component, count in component_counts[:limit]: show_info_card(component.id, count, len(outputs)) + + +@mltrace.command("components") +@click.option("--owner", help="Owner of components") +@click.option("--tag", help="Tag of components") +@click.option("--address", help="Database server address") +def components(owner: str = "", tag: str = "", address: str = ""): + """ + Command to list the components with options to filter by tag or owner. + """ + if address and len(address) > 0: + set_address(address) + + # Make return result + try: + result = get_components(tag, owner) + except RuntimeError: + raise click.ClickException("No components could be found with the " + "flags passed.") + + # Display components, one per line + for comp in result: + click.echo(f"Name: {comp.name}") + click.echo() + + +@mltrace.command("tags") +@click.option("--address", help="Database server address") +def tags(address: str = ""): + """ + Command to list all the tags currently used. + """ + # Set address + if address and len(address) > 0: + set_address(address) + + # Get all tags, automatically unique + all_tags = get_all_tags() + + click.echo(all_tags) + click.echo() diff --git a/mltrace/client.py b/mltrace/client.py index 5a7d6c3..372ef4d 100644 --- a/mltrace/client.py +++ b/mltrace/client.py @@ -30,6 +30,7 @@ def _set_address_helper(old_uri: str, address: str): Otherwise, DB_URI is set to {_db_uri}." ) + # --------------------- Database management functions ------------------- # @@ -57,7 +58,7 @@ def clean_db(): def create_component( - name: str, description: str, owner: str, tags: typing.List[str] = [] + name: str, description: str, owner: str, tags: typing.List[str] = [] ): """Creates a component entity in the database.""" store = Store(_db_uri) @@ -71,9 +72,9 @@ def tag_component(component_name: str, tags: typing.List[str]): def log_component_run( - component_run: ComponentRun, - set_dependencies_from_inputs=True, - staleness_threshold: int = (60 * 60 * 24 * 30), + component_run: ComponentRun, + set_dependencies_from_inputs=True, + staleness_threshold: int = (60 * 60 * 24 * 30), ): """Takes client-facing ComponentRun object and logs it to the DB.""" store = Store(_db_uri) @@ -143,13 +144,13 @@ def create_random_ids(num_outputs=1) -> typing.List[str]: def register( - component_name: str, - inputs: typing.List[str] = [], - outputs: typing.List[str] = [], - input_vars: typing.List[str] = [], - output_vars: typing.List[str] = [], - endpoint: bool = False, - staleness_threshold: int = (60 * 60 * 24 * 30), + component_name: str, + inputs: typing.List[str] = [], + outputs: typing.List[str] = [], + input_vars: typing.List[str] = [], + output_vars: typing.List[str] = [], + endpoint: bool = False, + staleness_threshold: int = (60 * 60 * 24 * 30), ): def actual_decorator(func): @functools.wraps(func) @@ -322,14 +323,19 @@ def unflag_output_id(output_id: str) -> bool: return store.set_io_pointer_flag(output_id, False) +def unflag_all(): + store = Store(_db_uri) + store.unflag_all() + + # ----------------- Basic retrieval functions ------------------- # def get_history( - component_name: str, - limit: int = 10, - date_lower: typing.Union[datetime, str] = datetime.min, - date_upper: typing.Union[datetime, str] = datetime.max, + component_name: str, + limit: int = 10, + date_lower: typing.Union[datetime, str] = datetime.min, + date_upper: typing.Union[datetime, str] = datetime.max, ) -> typing.List[ComponentRun]: """Returns a list of ComponentRuns that are part of the component's history.""" @@ -368,23 +374,6 @@ def get_history( return component_runs -def get_components_with_owner(owner: str) -> typing.List[Component]: - """Returns a list of all the components associated with the specified - order.""" - store = Store(_db_uri) - res = store.get_components_with_owner(owner) - - # Convert to client-facing Components - components = [] - for c in res: - tags = [tag.name for tag in c.tags] - d = copy.deepcopy(c.__dict__) - d.update({"tags": tags}) - components.append(Component.from_dictionary(d)) - - return components - - def get_component_information(component_name: str) -> Component: """Returns a Component with the name, info, owner, and tags.""" store = Store(_db_uri) @@ -421,10 +410,11 @@ def get_component_run_information(component_run_id: str) -> ComponentRun: return ComponentRun.from_dictionary(d) -def get_components_with_tag(tag: str) -> typing.List[Component]: - """Returns a list of components with the specified tag.""" +def get_components(tag="", owner="") -> typing.List[Component]: + """Returns all components with the specified owner and/or tag. + Else, returns all components.""" store = Store(_db_uri) - res = store.get_components_with_tag(tag) + res = store.get_components(tag=tag, owner=owner) # Convert to client-facing Components components = [] @@ -450,6 +440,13 @@ def get_io_pointer(io_pointer_id: str, create=True): return IOPointer.from_dictionary(iop.__dict__) +def get_all_tags() -> typing.List[str]: + store = Store(_db_uri) + res = store.get_all_tags() + tags = [t.name for t in res] + return tags + + # --------------- Complex retrieval functions ------------------ # diff --git a/mltrace/db/migrations/versions/a2cdf9aa818c_create_stale.py b/mltrace/db/migrations/versions/a2cdf9aa818c_create_stale.py index 675f779..45aef75 100644 --- a/mltrace/db/migrations/versions/a2cdf9aa818c_create_stale.py +++ b/mltrace/db/migrations/versions/a2cdf9aa818c_create_stale.py @@ -1,7 +1,7 @@ """create stale Revision ID: a2cdf9aa818c -Revises: 0a8485e5ba50 +Revises: None Create Date: 2021-05-18 14:05:25.540236 """ diff --git a/mltrace/db/store.py b/mltrace/db/store.py index 698b6dd..956dcb2 100644 --- a/mltrace/db/store.py +++ b/mltrace/db/store.py @@ -60,11 +60,11 @@ def __del__(self): self.session.close() def create_component( - self, - name: str, - description: str, - owner: str, - tags: typing.List[str] = [], + self, + name: str, + description: str, + owner: str, + tags: typing.List[str] = [], ): """Creates a component entity in the database if it does not already exist.""" @@ -90,9 +90,9 @@ def get_component(self, name: str) -> Component: """Retrieves component if exists.""" component = ( self.session.query(Component) - .outerjoin(Tag, Component.tags) - .filter(Component.name == name) - .first() + .outerjoin(Tag, Component.tags) + .filter(Component.name == name) + .first() ) return component @@ -101,17 +101,14 @@ def get_component_run(self, id: str) -> ComponentRun: """Retrieves component run if exists.""" component_run = ( self.session.query(ComponentRun) - .filter(ComponentRun.id == id) - .first() + .filter(ComponentRun.id == id) + .first() ) - # print(self.session.query(ComponentRun).subquery()) - print(component_run.outputs) - return component_run def add_tags_to_component( - self, component_name: str, tags: typing.List[str] + self, component_name: str, tags: typing.List[str] ): """Retreives existing component and adds tags.""" component = self.get_component(component_name) @@ -125,8 +122,21 @@ def add_tags_to_component( component.add_tags(tag_objects) self.session.commit() + def unflag_all(self): + """Unflags all IO Pointers and commits.""" + flagged_iop = ( + self.session.query(IOPointer) + .filter(IOPointer.flag.is_(True)) + .all() + ) + + for iop in flagged_iop: + iop.clear_flag() + + self.session.commit() + def initialize_empty_component_run( - self, component_name: str + self, component_name: str ) -> ComponentRun: """Initializes an empty run for the specified component. Does not commit to the database.""" @@ -149,15 +159,15 @@ def get_tag(self, name=str) -> Tag: return res[0] def get_io_pointers( - self, names: typing.List[str], pointer_type: PointerTypeEnum = None + self, names: typing.List[str], pointer_type: PointerTypeEnum = None ) -> typing.List[IOPointer]: """Creates io pointers around the specified path names. Retrieves existing io pointer if exists in DB, otherwise creates a new one with inferred pointer type.""" res = ( self.session.query(IOPointer) - .filter(IOPointer.name.in_(names)) - .all() + .filter(IOPointer.name.in_(names)) + .all() ) res_names = set([r.name for r in res]) need_to_add = set(names) - res_names @@ -177,7 +187,7 @@ def get_io_pointers( return res def get_io_pointer( - self, name: str, pointer_type: PointerTypeEnum = None, create=True + self, name: str, pointer_type: PointerTypeEnum = None, create=True ) -> IOPointer: """Creates an io pointer around the specified path. Retrieves existing io pointer if exists in DB, @@ -226,9 +236,9 @@ def delete_io_pointer(self, io_pointer: IOPointer): ) def commit_component_run( - self, - component_run: ComponentRun, - staleness_threshold: int = (60 * 60 * 24 * 30), + self, + component_run: ComponentRun, + staleness_threshold: int = (60 * 60 * 24 * 30), ): """Commits a fully initialized component run to the DB.""" status_dict = component_run.check_completeness() @@ -242,7 +252,7 @@ def commit_component_run( for dep in component_run.dependencies: # First case: there is over a month between component runs time_diff = ( - component_run.start_timestamp - dep.start_timestamp + component_run.start_timestamp - dep.start_timestamp ).total_seconds() if time_diff > staleness_threshold: days_diff = int(time_diff // (60 * 60 * 24)) @@ -317,10 +327,10 @@ def set_dependencies_from_inputs(self, component_run: ComponentRun): component_run.set_upstream(matches) def _traverse( - self, - node: ComponentRun, - depth: int, - node_list: typing.List[ComponentRun], + self, + node: ComponentRun, + depth: int, + node_list: typing.List[ComponentRun], ): # Add node to node_list as the step node_list.append((depth, node)) @@ -363,7 +373,7 @@ def _web_trace_helper(self, component_run_object: ComponentRun): res["childNodes"].append(out_dict) for dep in sorted( - component_run_object.dependencies, key=lambda x: x.id + component_run_object.dependencies, key=lambda x: x.id ): child_res = self._web_trace_helper(dep) res["childNodes"].append(child_res) @@ -374,10 +384,10 @@ def web_trace(self, output_id: str): """Prints list of ComponentRuns to display in the UI.""" component_run_objects = ( self.session.query(ComponentRun) - .outerjoin(IOPointer, ComponentRun.outputs) - .order_by(ComponentRun.start_timestamp.desc()) - .filter(IOPointer.name == output_id) - .all() + .outerjoin(IOPointer, ComponentRun.outputs) + .order_by(ComponentRun.start_timestamp.desc()) + .filter(IOPointer.name == output_id) + .all() ) if len(component_run_objects) == 0: @@ -395,10 +405,10 @@ def trace(self, output_id: str): component_run_object = ( self.session.query(ComponentRun) - .outerjoin(IOPointer, ComponentRun.outputs) - .order_by(ComponentRun.start_timestamp.desc()) - .filter(IOPointer.name == output_id) - .first() + .outerjoin(IOPointer, ComponentRun.outputs) + .order_by(ComponentRun.start_timestamp.desc()) + .filter(IOPointer.name == output_id) + .first() ) if component_run_object is None: @@ -412,11 +422,11 @@ def trace_batch(self, output_ids: typing.List[str]): pass def get_history( - self, - component_name: str, - limit: int = 10, - date_lower: typing.Union[datetime, str] = datetime.min, - date_upper: typing.Union[datetime, str] = datetime.max, + self, + component_name: str, + limit: int = 10, + date_lower: typing.Union[datetime, str] = datetime.min, + date_upper: typing.Union[datetime, str] = datetime.max, ) -> typing.List[ComponentRun]: """Gets lineage for the component, or a history of all its runs.""" history = ( @@ -435,37 +445,47 @@ def get_history( return history - def get_components_with_owner(self, owner: str) -> typing.List[Component]: + def get_components(self, tag: str = "", owner: str = ""): """Returns a list of all the components associated with the specified - order.""" - components = ( - self.session.query(Component) - .filter(Component.owner == owner) - .options(joinedload("tags")) - .all() - ) - - if len(components) == 0: - raise RuntimeError(f"Owner {owner} has no components.") - - return components - - def get_components_with_tag(self, tag: str) -> typing.List[Component]: - """Returns a list of all the components associated with that tag.""" - components = ( - self.session.query(Component) - .join(Tag, Component.tags) - .filter(Tag.name == tag) - .all() - ) + owner and/or tags.""" + if tag and owner: + components = ( + self.session.query(Component) + .join(Tag, Component.tags) + .filter( + and_( + Tag.name == tag, + Component.owner == owner, + ) + ) + .all() + ) + elif tag: + components = ( + self.session.query(Component) + .join(Tag, Component.tags) + .filter(Tag.name == tag) + .all() + ) + elif owner: + components = ( + self.session.query(Component) + .filter(Component.owner == owner) + .options(joinedload("tags")) + .all() + ) + else: + components = ( + self.session.query(Component).all() + ) if len(components) == 0: - raise RuntimeError(f"Tag {tag} has no components associated.") + raise RuntimeError(f"Search yielded no components.") return components def get_recent_run_ids( - self, limit: int = 50, last_run_id=None + self, limit: int = 50, last_run_id=None ) -> typing.List[str]: """Returns a list of recent component run IDs.""" @@ -473,8 +493,8 @@ def get_recent_run_ids( # Get start timestamp of last run id ts = ( self.session.query(ComponentRun) - .filter(ComponentRun.id == last_run_id) - .first() + .filter(ComponentRun.id == last_run_id) + .first() ).start_timestamp if not ts: raise RuntimeError( @@ -486,8 +506,8 @@ def get_recent_run_ids( map( lambda x: int(x[0]), self.session.query(ComponentRun.id) - .order_by(ComponentRun.start_timestamp.desc()) - .filter( + .order_by(ComponentRun.start_timestamp.desc()) + .filter( and_( ComponentRun.start_timestamp <= ts, ComponentRun.id != last_run_id, @@ -504,16 +524,16 @@ def get_recent_run_ids( map( lambda x: int(x[0]), self.session.query(ComponentRun.id) - .order_by(ComponentRun.start_timestamp.desc()) - .limit(limit) - .all(), + .order_by(ComponentRun.start_timestamp.desc()) + .limit(limit) + .all(), ) ) return runs def add_notes_to_component_run( - self, component_run_id: str, notes: str + self, component_run_id: str, notes: str ) -> str: """Retreives existing component and adds tags.""" component_run = self.get_component_run(component_run_id) @@ -548,7 +568,7 @@ def set_io_pointer_flag(self, output_id: str, value: bool): ) def review_flagged_outputs( - self, + self, ) -> typing.Tuple[ typing.List[str], typing.List[typing.Tuple[ComponentRun, int]] ]: @@ -556,8 +576,8 @@ def review_flagged_outputs( # Collate flagged outputs flagged_iops = ( self.session.query(IOPointer) - .filter(IOPointer.flag.is_(True)) - .all() + .filter(IOPointer.flag.is_(True)) + .all() ) flagged_output_ids = [iop.name for iop in flagged_iops] @@ -580,3 +600,6 @@ def review_flagged_outputs( # Return a list of the ComponentRuns in the order return flagged_output_ids, trace_nodes_counts + + def get_all_tags(self) -> typing.List[Tag]: + return self.session.query(Tag).all() diff --git a/tests/test_store.py b/tests/test_store.py index ea07c9b..8a55726 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -16,7 +16,7 @@ def testComponent(self): self.assertEqual(component.name, "test_component") # Retrieve components with owner - components = self.store.get_components_with_owner("shreya") + components = self.store.get_components(owner="shreya") self.assertEqual(1, len(components)) def testCompleteComponentRun(self):