diff --git a/cmd/metricscollector/v1alpha3/tfevent-metricscollector/main.py b/cmd/metricscollector/v1alpha3/tfevent-metricscollector/main.py index 73992d73f52..8d2d05ac7ef 100644 --- a/cmd/metricscollector/v1alpha3/tfevent-metricscollector/main.py +++ b/cmd/metricscollector/v1alpha3/tfevent-metricscollector/main.py @@ -39,7 +39,7 @@ def parse_options(): WaitOtherMainProcesses(completed_marked_dir=opt.dir_path) - mc = MetricsCollector(opt.metric_names.split(',')) + mc = MetricsCollector(opt.metric_names.split(';')) observation_log = mc.parse_file(opt.dir_path) channel = grpc.beta.implementations.insecure_channel( diff --git a/pkg/metricscollector/v1alpha3/tfevent-metricscollector/tfevent_loader.py b/pkg/metricscollector/v1alpha3/tfevent-metricscollector/tfevent_loader.py index 0531c580939..e0dd7b7dc46 100644 --- a/pkg/metricscollector/v1alpha3/tfevent-metricscollector/tfevent_loader.py +++ b/pkg/metricscollector/v1alpha3/tfevent-metricscollector/tfevent_loader.py @@ -12,6 +12,8 @@ # When the event file is under a directory(e.g. test dir), please specify "{{dirname}}/{{metrics name}}" # For example, in the TensorFlow official tutorial for mnist with summary (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py), # the "accuracy" metric is saved under "train" and "test" directories. So in Katib, please specify name of metrics as "train/accuracy" and "test/accuracy". + + class TFEventFileParser: def find_all_files(self, directory): for root, dirs, files in os.walk(directory): @@ -22,23 +24,24 @@ def find_all_files(self, directory): def parse_summary(self, tfefile, metrics): metric_logs = [] for summary in tf.train.summary_iterator(tfefile): - paths=tfefile.split("/") + paths = tfefile.split("/") for v in summary.summary.value: for m in metrics: tag = str(v.tag) if len(paths) >= 2 and len(m.split("/")) >= 2: - tag = str(paths[-2]+"/" + v.tag) + tag = str(paths[-2]+"/" + v.tag) if tag.startswith(m): ml = api_pb2.MetricLog( - time_stamp=rfc3339.rfc3339(datetime.fromtimestamp(summary.wall_time)), - metric=api_pb2.Metric( - name=m, - value=str(v.simple_value) - ) + time_stamp=rfc3339.rfc3339(datetime.fromtimestamp(summary.wall_time)), + metric=api_pb2.Metric( + name=m, + value=str(v.simple_value) ) + ) metric_logs.append(ml) return metric_logs + class MetricsCollector: def __init__(self, metric_names): self.logger = getLogger(__name__) @@ -59,6 +62,6 @@ def parse_file(self, directory): self.logger.info(f + " will be parsed.") mls.extend(self.parser.parse_summary(f, self.metrics)) except Exception as e: - self.logger.warning("Unexpected error: "+ str(e)) + self.logger.warning("Unexpected error: " + str(e)) continue return api_pb2.ObservationLog(metric_logs=mls)