Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Jan 13, 2023
1 parent df5d1b9 commit 46388ad
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions nvflare/app_opt/tracking/mlflow/mlflow_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@

class MLFlowReceiver(AnalyticsReceiver):
def __init__(
self,
tracking_uri: Optional[str] = None,
kwargs: Optional[dict] = None,
artifact_location: Optional[str] = None,
events=None,
buffer_flush_time_in_sec=1,
self,
tracking_uri: Optional[str] = None,
kwargs: Optional[dict] = None,
artifact_location: Optional[str] = None,
events=None,
buffer_flush_time_in_sec=1,
):
"""
MLFlowReceiver receives log events from client and deliver to the MLFLow.
Expand Down Expand Up @@ -105,9 +105,7 @@ def mlflow_setup(self, art_full_path, experiment_name, experiment_tags, run_name
mlflow_client, experiment_name, art_full_path, experiment_tags
)
tags = self.get_run_tags(self.kwargs)
run = mlflow_client.create_run(
experiment_id=self.experiment_id, run_name=run_name, tags=tags
)
run = mlflow_client.create_run(experiment_id=self.experiment_id, run_name=run_name, tags=tags)
self.run_ids[site.name] = run.info.run_id

def _init_buffer(self, sites):
Expand Down Expand Up @@ -151,11 +149,11 @@ def get_artifact_location(self, relative_path: str):
return root_log_dir

def _create_experiment(
self,
mlflow_client: MlflowClient,
experiment_name: str,
artifact_location: str,
experiment_tags: Optional[dict] = None,
self,
mlflow_client: MlflowClient,
experiment_name: str,
artifact_location: str,
experiment_tags: Optional[dict] = None,
) -> Optional[str]:
experiment_id = None
if experiment_name:
Expand All @@ -164,6 +162,7 @@ def _create_experiment(
self.logger.info(f"Experiment with name '{experiment_name}' does not exist. Creating a new experiment.")
try:
import pathlib

artifact_location_uri = pathlib.Path(artifact_location).as_uri()
experiment_id = mlflow_client.create_experiment(
name=experiment_name, artifact_location=artifact_location_uri, tags=experiment_tags
Expand Down

0 comments on commit 46388ad

Please sign in to comment.