DARTS Suggestion (#1175)
* First commit with darts

* Support darts in Katib

* Fix problems

* Modify darts example

* Change num nodes to 4
andreyvelich authored May 6, 2020
1 parent 2d35d55 commit c23f9c6
Showing 14 changed files with 1,066 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .dockerignore
@@ -2,7 +2,7 @@
.gitignore
docs
examples
-!examples/v1alpha3/nas/enas-cnn-cifar10
+!examples/v1alpha3/nas
manifests
pkg/ui/*/frontend/node_modules
pkg/ui/*/frontend/build
26 changes: 26 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/Dockerfile
@@ -0,0 +1,26 @@
FROM python:3.6

RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \
    apt-get -y update && \
    apt-get -y install gfortran libopenblas-dev liblapack-dev && \
    pip install cython; \
fi

RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \
    if [ "$(uname -m)" = "ppc64le" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \
    elif [ "$(uname -m)" = "aarch64" ]; then \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \
    else \
        wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \
    fi && \
    chmod +x /bin/grpc_health_probe

ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nas/darts/v1alpha3
RUN pip install --no-cache-dir -r requirements.txt

ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/health/python

ENTRYPOINT ["python", "main.py"]

30 changes: 30 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/main.py
@@ -0,0 +1,30 @@
import grpc
from concurrent import futures
import time
from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
from pkg.apis.manager.health.python import health_pb2_grpc
from pkg.suggestion.v1alpha3.nas.darts.service import DartsService


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
DEFAULT_PORT = "0.0.0.0:6789"


def serve():
    print("Darts Suggestion Service")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    service = DartsService()
    api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
    health_pb2_grpc.add_HealthServicer_to_server(service, server)
    server.add_insecure_port(DEFAULT_PORT)
    print("Listening...")
    server.start()
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    serve()
3 changes: 3 additions & 0 deletions cmd/suggestion/nas/darts/v1alpha3/requirements.txt
@@ -0,0 +1,3 @@
grpcio==1.23.0
protobuf==3.9.1
googleapis-common-protos==1.6.0
9 changes: 9 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile
@@ -0,0 +1,9 @@
ARG cuda_version=10.0
ARG cudnn_version=7
FROM pytorch/pytorch:1.0-cuda${cuda_version}-cudnn${cudnn_version}-runtime


ADD . /usr/src/app/github.com/kubeflow/katib
WORKDIR /usr/src/app/github.com/kubeflow/katib/examples/v1alpha3/nas/darts-cnn-cifar10

ENTRYPOINT ["python3", "-u", "run_trial.py"]
113 changes: 113 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/architect.py
@@ -0,0 +1,113 @@
import torch
import copy


class Architect():
    """Architect controls the cell architecture by computing gradients of the alphas.
    """

    def __init__(self, model, w_momentum, w_weight_decay):
        self.model = model
        self.v_model = copy.deepcopy(model)
        self.w_momentum = w_momentum
        self.w_weight_decay = w_weight_decay

    def virtual_step(self, train_x, train_y, xi, w_optim):
        """
        Compute the unrolled weights w' (virtual step).
        Step process:
        1) forward
        2) calculate loss
        3) compute gradients (by backprop)
        4) apply the gradient update
        Args:
            xi: learning rate for the virtual gradient step (same as the weights lr)
            w_optim: weights optimizer
        """

        # Forward and calculate loss
        # Loss for train with w: L_train(w)
        loss = self.model.loss(train_x, train_y)
        # Compute gradients
        gradients = torch.autograd.grad(loss, self.model.getWeights())

        # Do the virtual step (apply the gradient update)
        # The operations below do not need gradient tracking
        with torch.no_grad():
            # The optimizer state is keyed by the parameter objects themselves,
            # so the original network weights have to be iterated alongside the
            # virtual ones.
            for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
                m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))

            # Sync alphas
            for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
                va.copy_(a)
    def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
        """ Compute the unrolled loss and backpropagate its gradients
        Args:
            xi: learning rate for the virtual gradient step (same as the model lr)
            w_optim: weights optimizer - for the virtual step
        """
        # Do the virtual step (calculate w')
        self.virtual_step(train_x, train_y, xi, w_optim)

        # Calculate the unrolled loss
        # Loss for validation with w': L_valid(w')
        loss = self.v_model.loss(valid_x, valid_y)

        # Calculate gradients
        v_alphas = tuple(self.v_model.getAlphas())
        v_weights = tuple(self.v_model.getWeights())
        v_grads = torch.autograd.grad(loss, v_alphas + v_weights)

        dalpha = v_grads[:len(v_alphas)]
        dws = v_grads[len(v_alphas):]

        hessian = self.compute_hessian(dws, train_x, train_y)

        # Update final gradient = dalpha - xi * hessian
        with torch.no_grad():
            for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
                alpha.grad = da - xi * h

    def compute_hessian(self, dws, train_x, train_y):
        """
        dw = dw' { L_valid(w', alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha{ L_train(w+, alpha) } - dalpha{ L_train(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """

        norm = torch.cat([dw.view(-1) for dw in dws]).norm()
        eps = 0.01 / norm

        # w+ = w + eps * dw
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                p += eps * dw

        loss = self.model.loss(train_x, train_y)
        # dalpha { L_train(w+, alpha) }
        dalpha_positive = torch.autograd.grad(loss, self.model.getAlphas())

        # w- = w - eps * dw
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                # TODO (andreyvelich): Do we need this * 2.0 ?
                # Note: p currently sits at w+, so subtracting 2 * eps * dw lands at w-.
                p -= 2. * eps * dw

        loss = self.model.loss(train_x, train_y)
        # dalpha { L_train(w-, alpha) }
        dalpha_negative = torch.autograd.grad(loss, self.model.getAlphas())

        # Recover w
        with torch.no_grad():
            for p, dw in zip(self.model.getWeights(), dws):
                p += eps * dw

        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_positive, dalpha_negative)]
        return hessian
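For reference, the update implemented by unrolled_backward and compute_hessian above follows the second-order approximation from the DARTS paper, written out here as a sketch (the symbols match the docstrings; this derivation is not part of the committed files):

\nabla_\alpha L_{valid}\big(w - \xi \nabla_w L_{train}(w, \alpha),\, \alpha\big) \approx \nabla_\alpha L_{valid}(w', \alpha) - \xi\, \nabla^2_{\alpha, w} L_{train}(w, \alpha)\, \nabla_{w'} L_{valid}(w', \alpha),

where w' = w - \xi \nabla_w L_{train}(w, \alpha) is the virtual step, and the Hessian-vector product is approximated by the finite difference

\nabla^2_{\alpha, w} L_{train}(w, \alpha)\, \nabla_{w'} L_{valid}(w', \alpha) \approx \frac{\nabla_\alpha L_{train}(w^+, \alpha) - \nabla_\alpha L_{train}(w^-, \alpha)}{2\epsilon}, \qquad w^\pm = w \pm \epsilon\, \nabla_{w'} L_{valid}(w', \alpha), \quad \epsilon = \frac{0.01}{\lVert \nabla_{w'} L_{valid}(w', \alpha) \rVert}.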
172 changes: 172 additions & 0 deletions examples/v1alpha3/nas/darts-cnn-cifar10/model.py
@@ -0,0 +1,172 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from operations import FactorizedReduce, StdConv, MixedOp


class Cell(nn.Module):
    """ Cell for search
    Each edge is mixed and continuously relaxed.
    """

    def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
        """
        Args:
            num_nodes: number of intermediate cell nodes
            c_prev_prev: channels_out[k-2]
            c_prev: channels_out[k-1]
            c_cur: channels_in[k] (current)
            reduction_prev: flag indicating whether the previous cell is a reduction cell
            reduction_cur: flag indicating whether the current cell is a reduction cell
            search_space: search space providing the primitive operations for MixedOp
        """

        super(Cell, self).__init__()
        self.reduction_cur = reduction_cur
        self.num_nodes = num_nodes

        # If the previous cell is a reduction cell, the current input size does not match
        # the output size of cell[k-2], so output[k-2] has to be reduced by preprocessing.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
        else:
            self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
        self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)

        # Generate the DAG from mixed operations
        self.dag_ops = nn.ModuleList()

        for i in range(self.num_nodes):
            self.dag_ops.append(nn.ModuleList())
            # Include the 2 input nodes
            for j in range(2 + i):
                # Reduction with stride = 2 must be applied only to the input nodes
                stride = 2 if reduction_cur and j < 2 else 1
                op = MixedOp(c_cur, stride, search_space)
                self.dag_ops[i].append(op)

    def forward(self, s0, s1, w_dag):
        s0 = self.preprocess0(s0)
        s1 = self.preprocess1(s1)

        states = [s0, s1]
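        # w_dag is a list of length num_nodes; entry i holds the softmaxed alpha
        # weights of shape (2 + i, num_ops) for the edges feeding intermediate node i.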
        for edges, w_list in zip(self.dag_ops, w_dag):
            state_cur = sum(edges[i](s, w) for i, (s, w) in enumerate(zip(states, w_list)))
            states.append(state_cur)

        state_out = torch.cat(states[2:], dim=1)
        return state_out


class NetworkCNN(nn.Module):

    def __init__(self, init_channels, input_channels, num_classes, num_layers, criterion, search_space):
        super(NetworkCNN, self).__init__()

        self.init_channels = init_channels
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.criterion = criterion

        # TODO: Algorithm settings?
        self.num_nodes = 4
        self.stem_multiplier = 3

        c_cur = self.stem_multiplier * self.init_channels

        self.stem = nn.Sequential(
            nn.Conv2d(input_channels, c_cur, 3, padding=1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # In the first cell the stem output is used for both s0 and s1.
        # c_prev_prev and c_prev - output channel sizes
        # c_cur - init channel size
        c_prev_prev, c_prev, c_cur = c_cur, c_cur, self.init_channels

        self.cells = nn.ModuleList()

        reduction_prev = False
        for i in range(self.num_layers):
            # Layers at 1/3 and 2/3 of the depth are reduction cells with doubled channels;
            # all others are normal cells.
            if i in [self.num_layers // 3, 2 * self.num_layers // 3]:
                c_cur *= 2
                reduction_cur = True
            else:
                reduction_cur = False

            cell = Cell(self.num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space)
            reduction_prev = reduction_cur
            self.cells.append(cell)

            c_cur_out = c_cur * self.num_nodes
            c_prev_prev, c_prev = c_prev, c_cur_out

        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(c_prev, self.num_classes)

        # Initialize the alpha parameters
        num_ops = len(search_space.primitives)

        self.alpha_normal = nn.ParameterList()
        self.alpha_reduce = nn.ParameterList()

        for i in range(self.num_nodes):
            self.alpha_normal.append(nn.Parameter(1e-3 * torch.randn(i + 2, num_ops)))
            self.alpha_reduce.append(nn.Parameter(1e-3 * torch.randn(i + 2, num_ops)))

        # Set up the alphas list
        self.alphas = []
        for name, parameter in self.named_parameters():
            if "alpha" in name:
                self.alphas.append((name, parameter))

    def forward(self, x):

        weights_normal = [F.softmax(alpha, dim=-1) for alpha in self.alpha_normal]
        weights_reduce = [F.softmax(alpha, dim=-1) for alpha in self.alpha_reduce]

        s0 = s1 = self.stem(x)

        for cell in self.cells:
            weights = weights_reduce if cell.reduction_cur else weights_normal
            s0, s1 = s1, cell(s0, s1, weights)

        out = self.global_pooling(s1)

        # Flatten the output
        out = out.view(out.size(0), -1)

        logits = self.classifier(out)
        return logits

    def print_alphas(self):

        print("\n>>> Alphas Normal <<<")
        for alpha in self.alpha_normal:
            print(F.softmax(alpha, dim=-1))

        print("\n>>> Alphas Reduce <<<")
        for alpha in self.alpha_reduce:
            print(F.softmax(alpha, dim=-1))
        print("\n")

    def getWeights(self):
        return self.parameters()

    def getAlphas(self):
        for _, parameter in self.alphas:
            yield parameter

    def loss(self, x, y):
        logits = self.forward(x)
        return self.criterion(logits, y)

    def genotype(self, search_space):
        gene_normal = search_space.parse(self.alpha_normal, k=2)
        gene_reduce = search_space.parse(self.alpha_reduce, k=2)
        # Concatenate all intermediate nodes
        concat = range(2, 2 + self.num_nodes)

        return search_space.genotype(normal=gene_normal, normal_concat=concat,
                                     reduce=gene_reduce, reduce_concat=concat)
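To tie the pieces above together, here is a minimal, hypothetical usage sketch of one DARTS search step with NetworkCNN and Architect. It is not part of the committed files: the search_space object (with the primitives, parse, and genotype members referenced above) is assumed to come from elsewhere in this change, and the hyperparameter values are illustrative only.

import torch
import torch.nn as nn

from architect import Architect
from model import NetworkCNN

# search_space is assumed to be constructed elsewhere in this change and to
# expose the primitives/parse/genotype members used by model.py above.

criterion = nn.CrossEntropyLoss()
model = NetworkCNN(init_channels=16, input_channels=3, num_classes=10,
                   num_layers=8, criterion=criterion, search_space=search_space)
architect = Architect(model, w_momentum=0.9, w_weight_decay=3e-4)

# w_optim drives the weight (and virtual) updates, alpha_optim the alphas.
w_optim = torch.optim.SGD(model.getWeights(), lr=0.025, momentum=0.9, weight_decay=3e-4)
alpha_optim = torch.optim.Adam(model.getAlphas(), lr=3e-4, weight_decay=1e-3)

# Dummy CIFAR-10-shaped batches; real training would iterate over DataLoaders.
train_x, train_y = torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))
valid_x, valid_y = torch.randn(4, 3, 32, 32), torch.randint(0, 10, (4,))

# 1) Architecture step: unrolled gradient of L_valid with respect to the alphas.
alpha_optim.zero_grad()
architect.unrolled_backward(train_x, train_y, valid_x, valid_y, xi=0.025, w_optim=w_optim)
alpha_optim.step()

# 2) Weight step: plain SGD on L_train.
w_optim.zero_grad()
loss = model.loss(train_x, train_y)
loss.backward()
w_optim.step()

# Current discrete architecture induced by the alphas.
print(model.genotype(search_space))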