Commit c23f9c6
1 parent 2d35d55

* First commit with darts
* Support darts in Katib
* Fix problems
* Modify darts example
* Change num nodes to 4

Showing 14 changed files with 1,066 additions and 1 deletion.
@@ -0,0 +1,26 @@
FROM python:3.6

RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
        apt-get -y update && \
        apt-get -y install gfortran libopenblas-dev liblapack-dev && \
        pip install cython; \
    fi

RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \
    if [ "$(uname -m)" = "ppc64le" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \
    elif [ "$(uname -m)" = "aarch64" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \
    else \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \
    fi && \
    chmod +x /bin/grpc_health_probe

ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nas/darts/v1alpha3
RUN pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/health/python

ENTRYPOINT ["python", "main.py"]
@@ -0,0 +1,30 @@
import grpc
from concurrent import futures
import time
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.apis.manager.health.python import health_pb2_grpc
from pkg.suggestion.v1alpha3.nas.darts.service import DartsService


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
DEFAULT_PORT = "0.0.0.0:6789"


def serve():
    print("Darts Suggestion Service")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    service = DartsService()
    api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
    health_pb2_grpc.add_HealthServicer_to_server(service, server)
    server.add_insecure_port(DEFAULT_PORT)
    print("Listening...")
    server.start()
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    serve()
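As a quick smoke test for the server above, a minimal client sketch can probe the health endpoint. This is not part of the commit: it assumes the generated health_pb2/health_pb2_grpc modules are importable (as the Dockerfile's PYTHONPATH arranges) and that the Katib health proto follows the standard gRPC health-checking Check(HealthCheckRequest) shape.

# Hypothetical smoke test; not part of this commit.
import grpc

from pkg.apis.manager.health.python import health_pb2, health_pb2_grpc


def check_health(target="localhost:6789"):
    # Insecure channel, matching server.add_insecure_port() in main.py
    with grpc.insecure_channel(target) as channel:
        stub = health_pb2_grpc.HealthStub(channel)
        response = stub.Check(health_pb2.HealthCheckRequest(), timeout=10)
        print("Health status:", response.status)


if __name__ == "__main__":
    check_health()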
@@ -0,0 +1,3 @@
grpcio==1.23.0
protobuf==3.9.1
googleapis-common-protos==1.6.0
@@ -0,0 +1,9 @@
ARG cuda_version=10.0
ARG cudnn_version=7
FROM pytorch/pytorch:1.0-cuda${cuda_version}-cudnn${cudnn_version}-runtime


ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/examples/v1alpha3/nas/darts-cnn-cifar10

ENTRYPOINT ["python3", "-u", "run_trial.py"]
@@ -0,0 +1,113 @@
import torch
import copy


class Architect:
    """Architect controls the architecture of a cell by computing gradients of the alphas."""

    def __init__(self, model, w_momentum, w_weight_decay):
        self.model = model
        self.v_model = copy.deepcopy(model)
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay

    def virtual_step(self, train_x, train_y, xi, w_optim):
        """
        Compute the unrolled weights w' (virtual step).

        Step process:
            1) forward
            2) calculate loss
            3) compute gradients (by backprop)
            4) update the virtual weights

        Args:
            xi: learning rate for the virtual gradient step (same as the weights lr)
            w_optim: weights optimizer
        """
        # Forward and calculate loss.
        # Loss for train with w: L_train(w)
        loss = self.model.loss(train_x, train_y)
        # Compute gradients
        gradients = torch.autograd.grad(loss, self.model.getWeights())

        # Do the virtual step (update the virtual weights).
        # The operations below do not need gradient tracking.
        with torch.no_grad():
            # The optimizer state is keyed by the original weight tensors, so the
            # original network's weights must be iterated alongside the virtual ones.
            for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))

            # Sync alphas
            for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
                va.copy_(a)

    def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
        """
        Compute the unrolled loss and backprop its gradients.

        Args:
            xi: learning rate for the virtual gradient step (same as the model lr)
            w_optim: weights optimizer - for the virtual step
        """
        # Do the virtual step (calculate w')
        self.virtual_step(train_x, train_y, xi, w_optim)

        # Calculate the unrolled loss.
        # Loss for validation with w': L_valid(w')
        loss = self.v_model.loss(valid_x, valid_y)

        # Calculate gradients
        v_alphas = tuple(self.v_model.getAlphas())
        v_weights = tuple(self.v_model.getWeights())
        v_grads = torch.autograd.grad(loss, v_alphas + v_weights)

        dalpha = v_grads[:len(v_alphas)]
        dws = v_grads[len(v_alphas):]

        hessian = self.compute_hessian(dws, train_x, train_y)

        # Update the final gradient = dalpha - xi * hessian
        with torch.no_grad():
            for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
                alpha.grad = da - xi * h

    def compute_hessian(self, dws, train_x, train_y):
        """
        dw = dw' { L_valid(w', alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_train(w+, alpha) } - dalpha { L_train(w-, alpha) }) / (2 * eps)
        eps = 0.01 / ||dw||
        """
        norm = torch.cat([dw.view(-1) for dw in dws]).norm()
        eps = 0.01 / norm

        # w+ = w + eps * dw
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                p += eps * dw

        loss = self.model.loss(train_x, train_y)
        # dalpha { L_train(w+, alpha) }
        dalpha_positive = torch.autograd.grad(loss, self.model.getAlphas())

        # w- = w - eps * dw
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                # The weights currently hold w+, so subtracting 2 * eps * dw yields w-.
                p -= 2. * eps * dw

        loss = self.model.loss(train_x, train_y)
        # dalpha { L_train(w-, alpha) }
        dalpha_negative = torch.autograd.grad(loss, self.model.getAlphas())

        # Recover the original w
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                p += eps * dw

        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_positive, dalpha_negative)]
        return hessian
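For orientation, here is a sketch of how Architect typically slots into a DARTS search loop: each step first updates the architecture parameters (alphas) through unrolled_backward, then updates the network weights on the training batch. The optimizer hyperparameters below are assumptions modeled on the DARTS paper, not values from this commit.

# Hypothetical one-step search loop wiring for Architect; assumes `model`
# is a NetworkCNN-like object exposing getWeights()/getAlphas()/loss().
import torch

w_optim = torch.optim.SGD(model.getWeights(), lr=0.025,
                          momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(model.getAlphas(), lr=3e-4,
                               betas=(0.5, 0.999), weight_decay=1e-3)
architect = Architect(model, w_momentum=0.9, w_weight_decay=3e-4)


def search_step(train_x, train_y, valid_x, valid_y, xi):
    # 1) Architecture step: unrolled_backward writes alpha.grad directly,
    #    so only the alpha optimizer step is needed afterwards.
    alpha_optim.zero_grad()
    architect.unrolled_backward(train_x, train_y, valid_x, valid_y, xi, w_optim)
    alpha_optim.step()

    # 2) Weight step: ordinary gradient descent on the training batch
    w_optim.zero_grad()
    loss = model.loss(train_x, train_y)
    loss.backward()
    w_optim.step()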
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from operations import FactorizedReduce, StdConv, MixedOp


class Cell(nn.Module):
    """Cell for search.
    Each edge is mixed and continuously relaxed.
    """

    def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
        """
        Args:
            num_nodes: number of intermediate cell nodes
            c_prev_prev: channels_out[k-2]
            c_prev: channels_out[k-1]
            c_cur: channels_in[k] (current)
            reduction_prev: flag for whether the previous cell is a reduction cell
            reduction_cur: flag for whether the current cell is a reduction cell
        """
        super(Cell, self).__init__()
        self.reduction_cur = reduction_cur
        self.num_nodes = num_nodes

        # If the previous cell is a reduction cell, the current input size does not
        # match the output size of cell[k-2], so output[k-2] must be reduced in preprocessing.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
        else:
            self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
        self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)

        # Generate the DAG from mixed operations
        self.dag_ops = nn.ModuleList()

        for i in range(self.num_nodes):
            self.dag_ops.append(nn.ModuleList())
            # Include the 2 input nodes
            for j in range(2 + i):
                # Reduction with stride = 2 applies only to edges from the input nodes
                stride = 2 if reduction_cur and j < 2 else 1
                op = MixedOp(c_cur, stride, search_space)
                self.dag_ops[i].append(op)

    def forward(self, s0, s1, w_dag):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
        for edges, w_list in zip(self.dag_ops, w_dag):
            state_cur = sum(edges[i](s, w) for i, (s, w) in enumerate(zip(states, w_list)))
            states.append(state_cur)

        state_out = torch.cat(states[2:], dim=1)
        return state_out


class NetworkCNN(nn.Module):

    def __init__(self, init_channels, input_channels, num_classes, num_layers, criterion, search_space):
        super(NetworkCNN, self).__init__()

        self.init_channels = init_channels
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.criterion = criterion

        # TODO: Algorithm settings?
        self.num_nodes = 4
        self.stem_multiplier = 3

        c_cur = self.stem_multiplier * self.init_channels

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, c_cur, 3, padding=1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # In the first cell, the stem is used for both s0 and s1.
        # c_prev_prev and c_prev - output channel sizes
        # c_cur - initial channel size
        c_prev_prev, c_prev, c_cur = c_cur, c_cur, self.init_channels

        self.cells = nn.ModuleList()

        reduction_prev = False
        for i in range(self.num_layers):
            # Layers at [1/3, 2/3] of the depth - reduction cells with doubled channels.
            # All others - normal cells.
            if i in [self.num_layers // 3, 2 * self.num_layers // 3]:
                c_cur *= 2
                reduction_cur = True
            else:
                reduction_cur = False

            cell = Cell(self.num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space)
            reduction_prev = reduction_cur
            self.cells.append(cell)

            c_cur_out = c_cur * self.num_nodes
            c_prev_prev, c_prev = c_prev, c_cur_out

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(c_prev, self.num_classes)

        # Initialize the alpha parameters
        num_ops = len(search_space.primitives)

        self.alpha_normal = nn.ParameterList()
        self.alpha_reduce = nn.ParameterList()

        for i in range(self.num_nodes):
            self.alpha_normal.append(nn.Parameter(1e-3 * torch.randn(i + 2, num_ops)))
            self.alpha_reduce.append(nn.Parameter(1e-3 * torch.randn(i + 2, num_ops)))

        # Set up the alphas list
        self.alphas = []
        for name, parameter in self.named_parameters():
            if "alpha" in name:
                self.alphas.append((name, parameter))

    def forward(self, x):
        # Continuous relaxation: softmax over the candidate operations on each edge
        weights_normal = [F.softmax(alpha, dim=-1) for alpha in self.alpha_normal]
        weights_reduce = [F.softmax(alpha, dim=-1) for alpha in self.alpha_reduce]

        s0 = s1 = self.stem(x)

        for cell in self.cells:
            weights = weights_reduce if cell.reduction_cur else weights_normal
            s0, s1 = s1, cell(s0, s1, weights)

        out = self.global_pooling(s1)

        # Flatten the output
        out = out.view(out.size(0), -1)

        logits = self.classifier(out)
        return logits

    def print_alphas(self):
        print("\n>>> Alphas Normal <<<")
        for alpha in self.alpha_normal:
            print(F.softmax(alpha, dim=-1))

        print("\n>>> Alphas Reduce <<<")
        for alpha in self.alpha_reduce:
            print(F.softmax(alpha, dim=-1))
        print("\n")

    def getWeights(self):
        return self.parameters()

    def getAlphas(self):
        for _, parameter in self.alphas:
            yield parameter

    def loss(self, x, y):
        logits = self.forward(x)
        return self.criterion(logits, y)

    def genotype(self, search_space):
        gene_normal = search_space.parse(self.alpha_normal, k=2)
        gene_reduce = search_space.parse(self.alpha_reduce, k=2)
        # Concatenate all intermediate nodes
        concat = range(2, 2 + self.num_nodes)

        return search_space.genotype(normal=gene_normal, normal_concat=concat,
                                     reduce=gene_reduce, reduce_concat=concat)
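To tie the pieces together, a minimal instantiation sketch for NetworkCNN on CIFAR-10-shaped inputs. The SearchSpace import and the init_channels/num_layers values are assumptions: the class only requires that search_space expose .primitives, .parse(), and .genotype(), as used above.

# Hypothetical wiring of NetworkCNN; not part of this commit.
import torch
import torch.nn as nn

from model import NetworkCNN
from search_space import SearchSpace  # assumed module/class name

criterion = nn.CrossEntropyLoss()
model = NetworkCNN(
    init_channels=16,    # common DARTS default; an assumption here
    input_channels=3,    # CIFAR-10 RGB images
    num_classes=10,      # CIFAR-10 classes
    num_layers=8,        # an assumption; reduction cells land at layers 2 and 5
    criterion=criterion,
    search_space=SearchSpace(),
)

x = torch.randn(2, 3, 32, 32)   # a batch of two CIFAR-10-sized images
logits = model(x)
print(logits.shape)             # expected: torch.Size([2, 10])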