Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add HyperBand #787

Merged
merged 20 commits into from
Sep 25, 2019
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions cmd/suggestion/hyperband/v1alpha3/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
FROM python:3

ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/hyperband/v1alpha3
RUN if [ "$(uname -m)" = "ppc64le" ]; then \
apt-get -y update && \
apt-get -y install gfortran libopenblas-dev liblapack-dev && \
pip install cython; \
fi
RUN GRPC_HEALTH_PROBE_VERSION=v0.3.0 && \
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \
chmod +x /bin/grpc_health_probe

ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/hyperband/v1alpha3
RUN pip install --no-cache-dir -r requirements.txt
ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python

ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/health/python

ENTRYPOINT ["python", "main.py"]
8 changes: 7 additions & 1 deletion cmd/suggestion/hyperband/v1alpha3/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import grpc
import time
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.apis.manager.health.python import health_pb2_grpc
from pkg.suggestion.v1alpha3.hyperband_service import HyperbandService
from concurrent import futures

_ONE_DAY_IN_SECONDS = 60 * 60 * 24
DEFAULT_PORT = "0.0.0.0:6789"


def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
api_pb2_grpc.add_SuggestionServicer_to_server(HyperbandService(), server)
service = HyperbandService()
api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
health_pb2_grpc.add_HealthServicer_to_server(service, server)

server.add_insecure_port(DEFAULT_PORT)
print("Listening...")
server.start()
Expand All @@ -19,5 +24,6 @@ def serve():
except KeyboardInterrupt:
server.stop(0)


if __name__ == "__main__":
serve()
10 changes: 5 additions & 5 deletions cmd/suggestion/hyperband/v1alpha3/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
grpcio
duecredit
grpcio==1.23.0
duecredit===0.7.0
cloudpickle==0.5.6
numpy>=1.13.3
scikit-learn>=0.19.0
scipy>=0.19.1
forestci
protobuf
googleapis-common-protos
forestci==0.3
protobuf==3.9.1
googleapis-common-protos==1.6.0
2 changes: 1 addition & 1 deletion examples/v1alpha3/hyperband-example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ spec:
- name: "eta"
value: "3"
- name: "r_l"
value: "9"
value: "2"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, or the validation will fail because the hyperband algorithm will validate parallelTrialCount according to the r_l

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How was it working in v1alpha2?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the e2e test, we set the parallel trial count to 2. With r_l set to 9, validation fails with the error that the parallel trial count is less than 9. In v1alpha2 we did not have an e2e test for hyperband, so we never hit this problem.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. Thanks for explanation.

maxFailedTrialCount: 9
parameters:
- name: --lr
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package suggestionclient

import (
common "github.com/kubeflow/katib/pkg/apis/controller/common/v1alpha3"
experimentsv1alpha3 "github.com/kubeflow/katib/pkg/apis/controller/experiments/v1alpha3"
suggestionsv1alpha3 "github.com/kubeflow/katib/pkg/apis/controller/suggestions/v1alpha3"
suggestionapi "github.com/kubeflow/katib/pkg/apis/manager/v1alpha3"
)

// appendAlgorithmSettingsFromSuggestion appends the algorithm settings
// in suggestion to Experiment.
// Algorithm settings in suggestion will overwrite the settings in experiment.
func appendAlgorithmSettingsFromSuggestion(experiment *experimentsv1alpha3.Experiment, algoSettingsInSuggestion *common.AlgorithmSpec) {
algoSettingsInExperiment := experiment.Spec.Algorithm
for _, setting := range algoSettingsInSuggestion.AlgorithmSettings {
gaocegege marked this conversation as resolved.
Show resolved Hide resolved
if index, found := contains(algoSettingsInExperiment, setting.Name); found {
// If the setting is found in Experiment, update it.
algoSettingsInExperiment.AlgorithmSettings[index].Value = setting.Value
} else {
// If not found, append it.
algoSettingsInExperiment.AlgorithmSettings = append(
algoSettingsInExperiment.AlgorithmSettings, setting)
}
}
}

func updateAlgorithmSettings(suggestion *suggestionsv1alpha3.Suggestion, algorithm *suggestionapi.AlgorithmSpec) {
algoSettingsInSuggestion := suggestion.Spec.AlgorithmSpec
for _, setting := range algorithm.AlgorithmSetting {
gaocegege marked this conversation as resolved.
Show resolved Hide resolved
if setting != nil {
if index, found := contains(algoSettingsInSuggestion, setting.Name); found {
// If the setting is found in Suggestion, update it.
algoSettingsInSuggestion.AlgorithmSettings[index].Value = setting.Value
} else {
// If not found, append it.
algoSettingsInSuggestion.AlgorithmSettings = append(algoSettingsInSuggestion.AlgorithmSettings, common.AlgorithmSetting{
Name: setting.Name,
Value: setting.Value,
})
}
}
}
}

// contains looks up the setting called name in algorithmSettings.
//
// It returns the index of the first setting whose Name equals name and true,
// or -1 and false when no such setting exists. A nil algorithmSettings is
// treated as an empty spec and yields (-1, false).
func contains(algorithmSettings *common.AlgorithmSpec,
	name string) (int, bool) {
	if algorithmSettings == nil {
		return -1, false
	}
	// Iterate by index so each setting is inspected in place rather than copied.
	for i := range algorithmSettings.AlgorithmSettings {
		if algorithmSettings.AlgorithmSettings[i].Name == name {
			return i, true
		}
	}
	return -1, false
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,20 +25,24 @@ var (
timeout = 60 * time.Second
)

// SuggestionClient is the interface used to communicate with algorithm services.
type SuggestionClient interface {
// SyncAssignments syncs assignments from the algorithm service for the
// given suggestion, using the experiment and its trials as input.
SyncAssignments(instance *suggestionsv1alpha3.Suggestion, e *experimentsv1alpha3.Experiment,
ts []trialsv1alpha3.Trial) error

// ValidateAlgorithmSettings validates if the algorithm specific
// configurations are valid.
ValidateAlgorithmSettings(instance *suggestionsv1alpha3.Suggestion, e *experimentsv1alpha3.Experiment) error
}

// General is the implementation for SuggestionClient.
// It declares no fields.
type General struct {
}

// New returns a SuggestionClient backed by a freshly allocated General.
func New() SuggestionClient {
	var client SuggestionClient = new(General)
	return client
}

// SyncAssignments syncs assignments from algorithm services.
func (g *General) SyncAssignments(
instance *suggestionsv1alpha3.Suggestion,
e *experimentsv1alpha3.Experiment,
Expand All @@ -60,8 +64,13 @@ func (g *General) SyncAssignments(
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()

// Algorithm settings in suggestion will overwrite the settings in experiment.
filledE := e.DeepCopy()
appendAlgorithmSettingsFromSuggestion(filledE, instance.Spec.AlgorithmSpec)
experiment := g.ConvertExperiment(e)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/g.ConvertExperiment(e)/g.ConvertExperiment(filledE)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, good find. I will update it when CI is finished. Want to see the result of CI first

Copy link
Member

@hougangliu hougangliu Sep 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about not introducing filledE, and update g.ConvertExperiment(e) to g.ConvertExperiment(e, instance.Spec.AlgorithmSpec)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the code here is more readable. Otherwise we would have to update convertAlgorithmSettings, which I think would be hard to read.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok for me


request := &suggestionapi.GetSuggestionsRequest{
Experiment: g.ConvertExperiment(e),
Experiment: experiment,
Trials: g.ConvertTrials(ts),
RequestNumber: int32(requestNum),
}
Expand All @@ -83,10 +92,11 @@ func (g *General) SyncAssignments(
})
}

// TODO(gaocegege): Set algorithm settings
updateAlgorithmSettings(instance, response.Algorithm)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spec is not updated in Reconcile

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, gotcha

return nil
}

// ValidateAlgorithmSettings validates if the algorithm specific configurations are valid.
func (g *General) ValidateAlgorithmSettings(instance *suggestionsv1alpha3.Suggestion, e *experimentsv1alpha3.Experiment) error {
logger := log.WithValues("Suggestion", types.NamespacedName{Name: instance.GetName(), Namespace: instance.GetNamespace()})
endpoint := fmt.Sprintf("%s:%d", instance.Name, consts.DefaultSuggestionPort)
Expand Down Expand Up @@ -151,6 +161,12 @@ func (g *General) ConvertExperiment(e *experimentsv1alpha3.Experiment) *suggesti
if e.Spec.NasConfig != nil {
res.Spec.NasConfig = convertNasConfig(e.Spec.NasConfig)
}
if e.Spec.ParallelTrialCount != nil {
res.Spec.ParallelTrialCount = *e.Spec.ParallelTrialCount
}
if e.Spec.MaxTrialCount != nil {
res.Spec.MaxTrialCount = *e.Spec.MaxTrialCount
}
return res
}

Expand Down
Loading