Skip to content

Commit

Permalink
Support string metrics values in Controller (#1176)
Browse files Browse the repository at this point in the history
* Support strings metric values

* Fix tests
  • Loading branch information
andreyvelich authored May 5, 2020
1 parent 58e0764 commit 2d35d55
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 30 deletions.
4 changes: 2 additions & 2 deletions pkg/apis/controller/common/v1alpha3/common_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ type ParameterAssignment struct {
}

type Metric struct {
Name string `json:"name,omitempty"`
Value float64 `json:"value,omitempty"`
Name string `json:"name,omitempty"`
Value string `json:"value,omitempty"`
}

// +k8s:deepcopy-gen=true
Expand Down
27 changes: 18 additions & 9 deletions pkg/controller.v1alpha3/experiment/util/status_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
package util

import (
"strconv"

logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"

commonv1alpha3 "github.com/kubeflow/katib/pkg/apis/controller/common/v1alpha3"
Expand Down Expand Up @@ -77,28 +79,35 @@ func updateTrialsSummary(instance *experimentsv1alpha3.Experiment, trials *trial
sts.PendingTrialList = append(sts.PendingTrialList, trial.Name)
}

objectiveMetricValue := getObjectiveMetricValue(trial, objectiveMetricName)
if objectiveMetricValue == nil {
objectiveMetricValueStr := getObjectiveMetricValue(trial, objectiveMetricName)
if objectiveMetricValueStr == nil {
continue
}
objectiveMetricValue, err := strconv.ParseFloat(*objectiveMetricValueStr, 64)

// For string metrics values best trial is the latest
if err != nil {
bestTrialIndex = index
continue
}

//intialize vars to objective metric value of the first trial
//initialize vars to objective metric value of the first trial
if bestTrialIndex == -1 {
bestTrialValue = *objectiveMetricValue
bestTrialValue = objectiveMetricValue
bestTrialIndex = index
}

if objectiveType == commonv1alpha3.ObjectiveTypeMinimize {
if *objectiveMetricValue < bestTrialValue {
bestTrialValue = *objectiveMetricValue
if objectiveMetricValue < bestTrialValue {
bestTrialValue = objectiveMetricValue
bestTrialIndex = index
}
if instance.Spec.Objective.Goal != nil && bestTrialValue <= objectiveValueGoal {
isObjectiveGoalReached = true
}
} else if objectiveType == commonv1alpha3.ObjectiveTypeMaximize {
if *objectiveMetricValue > bestTrialValue {
bestTrialValue = *objectiveMetricValue
if objectiveMetricValue > bestTrialValue {
bestTrialValue = objectiveMetricValue
bestTrialIndex = index
}
if instance.Spec.Objective.Goal != nil && bestTrialValue >= objectiveValueGoal {
Expand Down Expand Up @@ -131,7 +140,7 @@ func updateTrialsSummary(instance *experimentsv1alpha3.Experiment, trials *trial
return isObjectiveGoalReached
}

func getObjectiveMetricValue(trial trialsv1alpha3.Trial, objectiveMetricName string) *float64 {
func getObjectiveMetricValue(trial trialsv1alpha3.Trial, objectiveMetricName string) *string {
if trial.Status.Observation == nil {
return nil
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func convertTrialObservation(observation *commonapiv1alpha3.Observation) *sugges
for _, m := range observation.Metrics {
resObservation.Metrics = append(resObservation.Metrics, &suggestionapi.Metric{
Name: m.Name,
Value: fmt.Sprintf("%f", m.Value),
Value: m.Value,
})
}
}
Expand Down
17 changes: 13 additions & 4 deletions pkg/controller.v1alpha3/trial/trial_controller_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,27 +171,36 @@ func isJobSucceeded(jobCondition *commonv1.JobCondition) bool {
return false
}

func getBestObjectiveMetricValue(metricLogs []*api_pb.MetricLog, objectiveType commonv1alpha3.ObjectiveType) *float64 {
func getBestObjectiveMetricValue(metricLogs []*api_pb.MetricLog, objectiveType commonv1alpha3.ObjectiveType) *string {
metricLogSize := len(metricLogs)
if metricLogSize == 0 {
return nil
}

bestObjectiveValue, _ := strconv.ParseFloat(metricLogs[0].Metric.Value, 64)
for _, metricLog := range metricLogs[1:] {
bestObjectiveValue, err := strconv.ParseFloat(metricLogs[0].Metric.Value, 64)
bestIndex := 0

if err != nil {
// If metrics are string values return the latest value
return &metricLogs[len(metricLogs)-1].Metric.Value
}

for idx, metricLog := range metricLogs[1:] {
objectiveMetricValue, _ := strconv.ParseFloat(metricLog.Metric.Value, 64)
if objectiveType == commonv1alpha3.ObjectiveTypeMinimize {
if objectiveMetricValue < bestObjectiveValue {
bestObjectiveValue = objectiveMetricValue
bestIndex = idx + 1
}
} else if objectiveType == commonv1alpha3.ObjectiveTypeMaximize {
if objectiveMetricValue > bestObjectiveValue {
bestObjectiveValue = objectiveMetricValue
bestIndex = idx + 1
}
}

}
return &bestObjectiveValue
return &metricLogs[bestIndex].Metric.Value
}

func needUpdateFinalizers(trial *trialsv1alpha3.Trial) (bool, []string) {
Expand Down
19 changes: 12 additions & 7 deletions test/e2e/v1alpha3/resume-e2e-experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"context"
"fmt"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"log"
"os"
"strconv"
"time"

"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"

appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
k8syaml "k8s.io/apimachinery/pkg/util/yaml"
Expand All @@ -27,7 +29,7 @@ const (
timeout = 30 * time.Minute
)

func verifyResult(exp *experimentsv1alpha3.Experiment) (*float64, error) {
func verifyResult(exp *experimentsv1alpha3.Experiment) (*string, error) {
if len(exp.Status.CurrentOptimalTrial.ParameterAssignments) == 0 {
return nil, fmt.Errorf("Best parameter assignments not updated in status")
}
Expand Down Expand Up @@ -126,11 +128,11 @@ func main() {
log.Fatal("Experiment run timed out")
}

metricVal, err := verifyResult(exp)
metricValStr, err := verifyResult(exp)
if err != nil {
log.Fatal(err)
}
if metricVal == nil {
if metricValStr == nil {
log.Fatal("Metric value in CurrentOptimalTrial not populated")
}

Expand All @@ -139,8 +141,11 @@ func main() {
if exp.Spec.Objective.Goal != nil {
goal = *exp.Spec.Objective.Goal
}
if (exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMinimize && *metricVal < goal) ||
(exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMaximize && *metricVal > goal) {

metricVal, err := strconv.ParseFloat(*metricValStr, 64)
if err != nil &&
((exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMinimize && metricVal < goal) ||
(exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMaximize && metricVal > goal)) {
log.Print("Objective Goal reached")
} else {

Expand Down
19 changes: 12 additions & 7 deletions test/e2e/v1alpha3/run-e2e-experiment.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"context"
"fmt"
"io/ioutil"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"
"log"
"os"
"strconv"
"time"

"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/types"

appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
k8syaml "k8s.io/apimachinery/pkg/util/yaml"
Expand All @@ -27,7 +29,7 @@ const (
timeout = 30 * time.Minute
)

func verifyResult(exp *experimentsv1alpha3.Experiment) (*float64, error) {
func verifyResult(exp *experimentsv1alpha3.Experiment) (*string, error) {
if len(exp.Status.CurrentOptimalTrial.ParameterAssignments) == 0 {
return nil, fmt.Errorf("Best parameter assignments not updated in status")
}
Expand Down Expand Up @@ -109,11 +111,11 @@ func main() {
log.Fatal("Experiment run timed out")
}

metricVal, err := verifyResult(exp)
metricValStr, err := verifyResult(exp)
if err != nil {
log.Fatal(err)
}
if metricVal == nil {
if metricValStr == nil {
log.Fatal("Metric value in CurrentOptimalTrial not populated")
}

Expand All @@ -122,8 +124,11 @@ func main() {
if exp.Spec.Objective.Goal != nil {
goal = *exp.Spec.Objective.Goal
}
if (exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMinimize && *metricVal < goal) ||
(exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMaximize && *metricVal > goal) {

metricVal, err := strconv.ParseFloat(*metricValStr, 64)
if err != nil &&
((exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMinimize && metricVal < goal) ||
(exp.Spec.Objective.Goal != nil && objectiveType == commonv1alpha3.ObjectiveTypeMaximize && metricVal > goal)) {
log.Print("Objective Goal reached")
} else {

Expand Down

0 comments on commit 2d35d55

Please sign in to comment.