Population based training #1833
Merged
23 commits
467bea2 docs: update new algorithm service details (a9p)
d001104 feat: trial augmentation strategy (a9p)
e4755c6 feat: pbt suggestion service (a9p)
471ed19 feat: PbtTemplate and associated test image (a9p)
93d5312 feat: introduce annotation field to trial specifications (a9p)
45e6ada feat: trial assignment changes to support annotations from suggestion (a9p)
e8fd056 feat: update pbt suggestion to use new Annotation api (a9p)
c00ef34 chore: gofmt and black run across PBT changes (a9p)
6f46a64 feedback: remove tf summary export, change default print unit, reduce… (a9p)
dcd10b3 feedback: move PBT template to example. (a9p)
aa93c45 feedback: changes to inject_webhook and utils. (a9p)
c6eab46 feedback: change suggestion mutation mount variable name and add to c… (a9p)
be1834d feedback: Add trial_names to GetSuggestionsReply and change suggestio… (a9p)
cf94d88 feedback: removed unnecessary checks and moved to async pbt implement… (a9p)
7cc4326 feedback: update trial name override location and change annotations … (a9p)
274130d feedback: add pbt to github workflow (a9p)
4bb0800 feedback: move labels to ParameterAssignments in GetSuggestionsReply … (a9p)
6fbd3a8 feedback: remove operator changes (a9p)
18650e5 feedback: GHA updates (a9p)
3e0cfd1 feedback: new formatting changes (a9p)
6962ad7 feedback: add suggestion-pbt to gh-actions build-load.sh. (a9p)
5cf08b3 fix: missing pbt->simple-pbt name changes, add simple-pbt to update-i… (a9p)
d88187d feedback: add pointer to website from main readme for pbt (a9p)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
name: E2E Test with simple-pbt | ||
on: | ||
- pull_request | ||
|
||
env: | ||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} | ||
|
||
jobs: | ||
e2e: | ||
runs-on: ubuntu-20.04 | ||
timeout-minutes: 120 | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v2 | ||
|
||
- name: Setup Test Env | ||
uses: ./.github/workflows/template-setup-e2e-test | ||
with: | ||
kubernetes-version: ${{ matrix.kubernetes-version }} | ||
|
||
- name: Run e2e test with ${{ matrix.experiments }} experiments | ||
uses: ./.github/workflows/template-e2e-test | ||
with: | ||
experiments: ${{ matrix.experiments }} | ||
# Comma Delimited | ||
trial-images: simple-pbt | ||
|
||
strategy: | ||
fail-fast: false | ||
matrix: | ||
# Detail: https://hub.docker.com/r/kindest/node | ||
# TODO (tenzen-y): We need to consider running tests on more kubernetes versions. | ||
# kubernetes-version: ["v1.20.15", "v1.21.12", "v1.22.9", "v1.23.6", "v1.24.1"] | ||
kubernetes-version: ["v1.21.12", "v1.22.9", "v1.23.6"] | ||
# Comma Delimited | ||
experiments: ["simple-pbt"] | ||
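As a rough illustration of what the `strategy.matrix` in this workflow expands to, the sketch below enumerates the job combinations in Python. GitHub Actions creates one job per combination of matrix values; the `expand_matrix` helper is hypothetical and only mirrors that behavior for the values listed in the workflow.

```python
from itertools import product

# Matrix values copied from the workflow in this PR.
MATRIX = {
    "kubernetes-version": ["v1.21.12", "v1.22.9", "v1.23.6"],
    "experiments": ["simple-pbt"],
}


def expand_matrix(matrix):
    """Return one dict per job: the Cartesian product of the matrix values."""
    keys = list(matrix)
    return [dict(zip(keys, combo)) for combo in product(*(matrix[k] for k in keys))]


if __name__ == "__main__":
    for job in expand_matrix(MATRIX):
        print(job)
```

With three Kubernetes versions and a single experiments entry, this yields three e2e jobs per pull request.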
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
FROM python:3.9-slim | ||
|
||
ENV TARGET_DIR /opt/katib | ||
ENV SUGGESTION_DIR cmd/suggestion/pbt/v1beta1 | ||
ENV GRPC_HEALTH_PROBE_VERSION v0.4.6 | ||
|
||
RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \ | ||
apt-get -y update && \ | ||
apt-get -y install gfortran libopenblas-dev liblapack-dev wget && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/*; \ | ||
else \ | ||
apt-get -y update && \ | ||
apt-get -y install wget && \ | ||
apt-get clean && \ | ||
rm -rf /var/lib/apt/lists/*; \ | ||
fi | ||
RUN if [ "$(uname -m)" = "ppc64le" ]; then \ | ||
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \ | ||
elif [ "$(uname -m)" = "aarch64" ]; then \ | ||
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \ | ||
else \ | ||
wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \ | ||
fi && \ | ||
chmod +x /bin/grpc_health_probe | ||
|
||
ADD ./pkg/ ${TARGET_DIR}/pkg/ | ||
ADD ./${SUGGESTION_DIR}/ ${TARGET_DIR}/${SUGGESTION_DIR}/ | ||
WORKDIR ${TARGET_DIR}/${SUGGESTION_DIR} | ||
RUN pip install --no-cache-dir -r requirements.txt | ||
|
||
RUN chgrp -R 0 ${TARGET_DIR} \ | ||
&& chmod -R g+rwX ${TARGET_DIR} | ||
|
||
ENV PYTHONPATH ${TARGET_DIR}:${TARGET_DIR}/pkg/apis/manager/v1beta1/python:${TARGET_DIR}/pkg/apis/manager/health/python | ||
|
||
ENTRYPOINT ["python", "main.py"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright 2022 The Kubeflow Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import grpc | ||
import time | ||
from pkg.apis.manager.v1beta1.python import api_pb2_grpc | ||
from pkg.apis.manager.health.python import health_pb2_grpc | ||
from pkg.suggestion.v1beta1.pbt.service import PbtService | ||
from concurrent import futures | ||
|
||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24 | ||
DEFAULT_PORT = "0.0.0.0:6789" | ||
|
||
|
||
def serve(): | ||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) | ||
service = PbtService() | ||
api_pb2_grpc.add_SuggestionServicer_to_server(service, server) | ||
health_pb2_grpc.add_HealthServicer_to_server(service, server) | ||
|
||
server.add_insecure_port(DEFAULT_PORT) | ||
print("Listening...") | ||
server.start() | ||
try: | ||
while True: | ||
time.sleep(_ONE_DAY_IN_SECONDS) | ||
except KeyboardInterrupt: | ||
server.stop(0) | ||
|
||
|
||
if __name__ == "__main__": | ||
serve() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
grpcio==1.41.1 | ||
protobuf==3.19.1 | ||
googleapis-common-protos==1.53.0 | ||
numpy==1.22.2 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
apiVersion: kubeflow.org/v1beta1 | ||
kind: Experiment | ||
metadata: | ||
namespace: kubeflow | ||
name: simple-pbt | ||
spec: | ||
maxTrialCount: 2 | ||
parallelTrialCount: 2 | ||
maxFailedTrialCount: 3 | ||
resumePolicy: FromVolume | ||
objective: | ||
type: maximize | ||
goal: 0.99 | ||
objectiveMetricName: Validation-accuracy | ||
algorithm: | ||
algorithmName: pbt | ||
algorithmSettings: | ||
- name: suggestion_trial_dir | ||
value: /var/log/katib/checkpoints/ | ||
- name: n_population | ||
value: '40' | ||
- name: truncation_threshold | ||
value: '0.2' | ||
parameters: | ||
- name: lr | ||
parameterType: double | ||
feasibleSpace: | ||
min: '0.0001' | ||
max: '0.02' | ||
step: '0.0001' | ||
trialTemplate: | ||
primaryContainerName: training-container | ||
trialParameters: | ||
- name: learningRate | ||
description: Learning rate for training the model | ||
reference: lr | ||
trialSpec: | ||
apiVersion: batch/v1 | ||
kind: Job | ||
spec: | ||
template: | ||
spec: | ||
containers: | ||
- name: training-container | ||
image: docker.io/kubeflowkatib/simple-pbt:latest | ||
command: | ||
- "python3" | ||
- "/opt/pbt/pbt_test.py" | ||
- "--epochs=20" | ||
- "--lr=${trialParameters.learningRate}" | ||
- "--checkpoint=/var/log/katib/checkpoints/" | ||
restartPolicy: Never |
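For readers unfamiliar with the `truncation_threshold` setting, here is a minimal sketch of truncation selection as it is usually described for population based training: the worst-performing fraction of the population is replaced by copies of the best-performing fraction. It is not taken from the Katib `pbt` suggestion service added in this PR; the `truncation_split` helper and the trial names are hypothetical.

```python
def truncation_split(scores, truncation_threshold):
    """Split trials into (top, bottom) groups by objective value.

    scores: dict mapping a trial name to its objective (higher is better).
    truncation_threshold: fraction of the population placed in each group.
    """
    ranked = sorted(scores, key=scores.get, reverse=True)
    cutoff = int(len(ranked) * truncation_threshold)
    if cutoff == 0:
        return [], []
    return ranked[:cutoff], ranked[-cutoff:]


if __name__ == "__main__":
    # Eight hypothetical trials with a threshold of 0.25 (two per group).
    accuracies = [0.10, 0.80, 0.35, 0.60, 0.05, 0.90, 0.45, 0.20]
    scores = {f"trial-{i}": acc for i, acc in enumerate(accuracies)}
    top, bottom = truncation_split(scores, 0.25)
    print("exploit from:", top)
    print("replace:", bottom)
```

Under this sketch, the manifest's `n_population` of 40 and `truncation_threshold` of 0.2 would put 8 trials in each group.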
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
FROM python:3.9-slim | ||
|
||
ADD examples/v1beta1/trial-images/simple-pbt /opt/pbt | ||
WORKDIR /opt/pbt | ||
|
||
RUN python3 -m pip install -r requirements.txt | ||
|
||
RUN chgrp -R 0 /opt/pbt \ | ||
&& chmod -R g+rwX /opt/pbt | ||
|
||
ENTRYPOINT ["python3", "/opt/pbt/pbt_test.py"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
#!/usr/bin/env python | ||
|
||
|
||
# Implementation based on: | ||
# https://github.com/ray-project/ray/blob/master/python/ray/tune/examples/pbt_example.py | ||
|
||
import argparse | ||
import numpy as np | ||
import os | ||
import pickle | ||
import random | ||
import time | ||
|
||
# Ensure job runs for at least this long (secs) to allow metrics collector to | ||
# read PID correctly before cleanup | ||
_METRICS_COLLECTOR_SPAWN_LATENCY = 7 | ||
|
||
|
||
class PBTBenchmarkExample: | ||
"""Toy PBT problem for benchmarking adaptive learning rate. | ||
The goal is to optimize this trainable's accuracy. The accuracy increases | ||
fastest at the optimal lr, which is a function of the current accuracy. | ||
The optimal lr schedule for this problem is the triangle wave as follows. | ||
Note that many lr schedules for real models also follow this shape: | ||
best lr | ||
^ | ||
| /\ | ||
| / \ | ||
| / \ | ||
| / \ | ||
------------> accuracy | ||
In this problem, using PBT with a population of 2-4 is sufficient to | ||
roughly approximate this lr schedule. Higher population sizes will yield | ||
faster convergence. Training will not converge without PBT. | ||
""" | ||
|
||
def __init__(self, lr, checkpoint: str): | ||
self._lr = lr | ||
|
||
self._checkpoint_file = os.path.join(checkpoint, "training.ckpt") | ||
if os.path.exists(self._checkpoint_file): | ||
with open(self._checkpoint_file, "rb") as fin: | ||
checkpoint_data = pickle.load(fin) | ||
self._accuracy = checkpoint_data["accuracy"] | ||
self._step = checkpoint_data["step"] | ||
else: | ||
os.makedirs(checkpoint, exist_ok=True) | ||
self._step = 1 | ||
self._accuracy = 0.0 | ||
|
||
def save_checkpoint(self): | ||
with open(self._checkpoint_file, "wb") as fout: | ||
pickle.dump({"step": self._step, "accuracy": self._accuracy}, fout) | ||
|
||
def step(self): | ||
midpoint = 50 # lr starts decreasing after acc > midpoint | ||
q_tolerance = 3 # penalize exceeding lr by more than this multiple | ||
noise_level = 2 # add gaussian noise to the acc increase | ||
# triangle wave: | ||
# - start at 0.001 @ t=0, | ||
# - peak at 0.01 @ t=midpoint, | ||
# - end at 0.001 @ t=midpoint * 2, | ||
if self._accuracy < midpoint: | ||
optimal_lr = 0.01 * self._accuracy / midpoint | ||
else: | ||
optimal_lr = 0.01 - 0.01 * (self._accuracy - midpoint) / midpoint | ||
optimal_lr = min(0.01, max(0.001, optimal_lr)) | ||
|
||
# compute accuracy increase | ||
q_err = max(self._lr, optimal_lr) / ( | ||
min(self._lr, optimal_lr) + np.finfo(float).eps | ||
) | ||
if q_err < q_tolerance: | ||
self._accuracy += (1.0 / q_err) * random.random() | ||
elif self._lr > optimal_lr: | ||
self._accuracy -= (q_err - q_tolerance) * random.random() | ||
self._accuracy += noise_level * np.random.normal() | ||
self._accuracy = max(0, min(100, self._accuracy)) | ||
|
||
self._step += 1 | ||
|
||
def __repr__(self): | ||
return "epoch {}:\nlr={:0.4f}\nValidation-accuracy={:0.4f}".format( | ||
self._step, self._lr, self._accuracy / 100 | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
# Parse CLI arguments | ||
parser = argparse.ArgumentParser(description="PBT Basic Test") | ||
parser.add_argument( | ||
"--lr", type=float, default=0.0001, help="learning rate (default: 0.0001)" | ||
) | ||
parser.add_argument( | ||
"--epochs", type=int, default=20, help="number of epochs to train (default: 20)" | ||
) | ||
parser.add_argument( | ||
"--checkpoint", | ||
type=str, | ||
default="/var/log/katib/checkpoints/", | ||
help="checkpoint directory (resume and save)", | ||
) | ||
opt = parser.parse_args() | ||
|
||
benchmark = PBTBenchmarkExample(opt.lr, opt.checkpoint) | ||
|
||
start_time = time.time() | ||
for i in range(opt.epochs): | ||
benchmark.step() | ||
exec_time_thresh = time.time() - start_time - _METRICS_COLLECTOR_SPAWN_LATENCY | ||
if exec_time_thresh < 0: | ||
time.sleep(abs(exec_time_thresh)) | ||
benchmark.save_checkpoint() | ||
|
||
print(benchmark) |
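The target learning-rate schedule computed inside `step()` can be isolated as a small function. This is a sketch that copies the constants from the script (`midpoint = 50`, bounds of 0.001 and 0.01); the function name `optimal_lr` and its keyword parameters are hypothetical and not part of the PR.

```python
def optimal_lr(accuracy, midpoint=50.0, lr_min=0.001, lr_max=0.01):
    """Triangle wave over accuracy, clamped to [lr_min, lr_max].

    Rises linearly to lr_max at `midpoint`, then falls linearly again.
    """
    if accuracy < midpoint:
        lr = lr_max * accuracy / midpoint
    else:
        lr = lr_max - lr_max * (accuracy - midpoint) / midpoint
    return min(lr_max, max(lr_min, lr))


if __name__ == "__main__":
    for acc in (0, 25, 50, 75, 100):
        print(f"accuracy={acc:3d} -> optimal lr={optimal_lr(acc):.4f}")
```

Because the best learning rate moves as accuracy changes, a trial that keeps one fixed `--lr` falls behind at some point, which is the situation the docstring describes when it says training will not converge without PBT.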
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
numpy==1.22.2 |
We need to sync the `experiments` name with the manifest name. For example, if the experiment name is "simple-pbt", we need to rename this manifest from pbt.yaml to simple-pbt.yaml.
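The naming convention requested in this comment can be checked mechanically. The sketch below is an assumption about how such a check might look, not part of the PR or of Katib's test scripts: `missing_manifests` is a hypothetical helper that maps each experiment name to `<name>.yaml` and reports the names with no matching file.

```python
def missing_manifests(experiments, manifest_files):
    """Return experiment names that have no `<name>.yaml` manifest.

    experiments: comma-delimited string, as in the workflow's `experiments` input.
    manifest_files: iterable of manifest file names.
    """
    available = set(manifest_files)
    names = [name.strip() for name in experiments.split(",") if name.strip()]
    return [name for name in names if f"{name}.yaml" not in available]


if __name__ == "__main__":
    # Before the rename the manifest is pbt.yaml, so "simple-pbt" is reported.
    print(missing_manifests("simple-pbt", ["pbt.yaml"]))
    # After the rename nothing is missing.
    print(missing_manifests("simple-pbt", ["simple-pbt.yaml"]))
```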