From c23f9c69bf5808da498d2970e90f241b7db2f323 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Wed, 6 May 2020 04:13:09 +0100 Subject: [PATCH] DARTS Suggestion (#1175) * First commit with darts * Support darts in Katib * Fix problems * Modify darts example * Change num nodes to 4 --- .dockerignore | 2 +- cmd/suggestion/nas/darts/v1alpha3/Dockerfile | 26 ++ cmd/suggestion/nas/darts/v1alpha3/main.py | 30 +++ .../nas/darts/v1alpha3/requirements.txt | 3 + .../v1alpha3/nas/darts-cnn-cifar10/Dockerfile | 9 + .../nas/darts-cnn-cifar10/architect.py | 113 +++++++++ .../v1alpha3/nas/darts-cnn-cifar10/model.py | 172 +++++++++++++ .../nas/darts-cnn-cifar10/operations.py | 166 +++++++++++++ .../nas/darts-cnn-cifar10/run_trial.py | 231 ++++++++++++++++++ .../nas/darts-cnn-cifar10/search_space.py | 50 ++++ .../v1alpha3/nas/darts-cnn-cifar10/utils.py | 70 ++++++ examples/v1alpha3/nas/darts-example-gpu.yaml | 85 +++++++ pkg/suggestion/v1alpha3/nas/darts/__init__.py | 0 pkg/suggestion/v1alpha3/nas/darts/service.py | 110 +++++++++ 14 files changed, 1066 insertions(+), 1 deletion(-) create mode 100644 cmd/suggestion/nas/darts/v1alpha3/Dockerfile create mode 100644 cmd/suggestion/nas/darts/v1alpha3/main.py create mode 100644 cmd/suggestion/nas/darts/v1alpha3/requirements.txt create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/architect.py create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/model.py create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/operations.py create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/run_trial.py create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/search_space.py create mode 100644 examples/v1alpha3/nas/darts-cnn-cifar10/utils.py create mode 100644 examples/v1alpha3/nas/darts-example-gpu.yaml create mode 100644 pkg/suggestion/v1alpha3/nas/darts/__init__.py create mode 100644 pkg/suggestion/v1alpha3/nas/darts/service.py diff --git a/.dockerignore b/.dockerignore index 6854764de2c..95feeba4c6f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,7 +2,7 @@ .gitignore docs examples -!examples/v1alpha3/nas/enas-cnn-cifar10 +!examples/v1alpha3/nas manifests pkg/ui/*/frontend/node_modules pkg/ui/*/frontend/build diff --git a/cmd/suggestion/nas/darts/v1alpha3/Dockerfile b/cmd/suggestion/nas/darts/v1alpha3/Dockerfile new file mode 100644 index 00000000000..6f640674fca --- /dev/null +++ b/cmd/suggestion/nas/darts/v1alpha3/Dockerfile @@ -0,0 +1,26 @@ +FROM python:3.6 + +RUN if [ "$(uname -m)" = "ppc64le" ] || [ "$(uname -m)" = "aarch64" ]; then \ + apt-get -y update && \ + apt-get -y install gfortran libopenblas-dev liblapack-dev && \ + pip install cython; \ + fi + +RUN GRPC_HEALTH_PROBE_VERSION=v0.3.1 && \ + if [ "$(uname -m)" = "ppc64le" ]; then \ + wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-ppc64le; \ + elif [ "$(uname -m)" = "aarch64" ]; then \ + wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-arm64; \ + else \ + wget -qO/bin/grpc_health_probe https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64; \ + fi && \ + chmod +x /bin/grpc_health_probe + +ADD . 
/usr/src/app/github.com/kubeflow/katib
+WORKDIR /usr/src/app/github.com/kubeflow/katib/cmd/suggestion/nas/darts/v1alpha3
+RUN pip install --no-cache-dir -r requirements.txt
+
+ENV PYTHONPATH /usr/src/app/github.com/kubeflow/katib:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/v1alpha3/python:/usr/src/app/github.com/kubeflow/katib/pkg/apis/manager/health/python
+
+ENTRYPOINT ["python", "main.py"]
+
diff --git a/cmd/suggestion/nas/darts/v1alpha3/main.py b/cmd/suggestion/nas/darts/v1alpha3/main.py
new file mode 100644
index 00000000000..b33b3116745
--- /dev/null
+++ b/cmd/suggestion/nas/darts/v1alpha3/main.py
@@ -0,0 +1,30 @@
+import grpc
+from concurrent import futures
+import time
+from pkg.apis.manager.v1alpha3.python import api_pb2_grpc
+from pkg.apis.manager.health.python import health_pb2_grpc
+from pkg.suggestion.v1alpha3.nas.darts.service import DartsService
+
+
+_ONE_DAY_IN_SECONDS = 60 * 60 * 24
+DEFAULT_PORT = "0.0.0.0:6789"
+
+
+def serve():
+    print("Darts Suggestion Service")
+    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+    service = DartsService()
+    api_pb2_grpc.add_SuggestionServicer_to_server(service, server)
+    health_pb2_grpc.add_HealthServicer_to_server(service, server)
+    server.add_insecure_port(DEFAULT_PORT)
+    print("Listening...")
+    server.start()
+    try:
+        while True:
+            time.sleep(_ONE_DAY_IN_SECONDS)
+    except KeyboardInterrupt:
+        server.stop(0)
+
+
+if __name__ == "__main__":
+    serve()
diff --git a/cmd/suggestion/nas/darts/v1alpha3/requirements.txt b/cmd/suggestion/nas/darts/v1alpha3/requirements.txt
new file mode 100644
index 00000000000..92bd5706e11
--- /dev/null
+++ b/cmd/suggestion/nas/darts/v1alpha3/requirements.txt
@@ -0,0 +1,3 @@
+grpcio==1.23.0
+protobuf==3.9.1
+googleapis-common-protos==1.6.0
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile b/examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile
new file mode 100644
index 00000000000..c7ee6579510
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/Dockerfile
@@ -0,0 +1,9 @@
+ARG cuda_version=10.0
+ARG cudnn_version=7
+FROM pytorch/pytorch:1.0-cuda${cuda_version}-cudnn${cudnn_version}-runtime
+
+
+ADD . /usr/src/app/github.com/kubeflow/katib
+WORKDIR /usr/src/app/github.com/kubeflow/katib/examples/v1alpha3/nas/darts-cnn-cifar10
+
+ENTRYPOINT ["python3", "-u", "run_trial.py"]
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/architect.py b/examples/v1alpha3/nas/darts-cnn-cifar10/architect.py
new file mode 100644
index 00000000000..65217706af3
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/architect.py
@@ -0,0 +1,113 @@
+import torch
+import copy
+
+
+class Architect():
+    """ Architect controls the architecture of a cell by computing gradients of the alphas
+    """
+
+    def __init__(self, model, w_momentum, w_weight_decay):
+        self.model = model
+        self.v_model = copy.deepcopy(model)
+        self.w_momentum = w_momentum
+        self.w_weight_decay = w_weight_decay
+
+    def virtual_step(self, train_x, train_y, xi, w_optim):
+        """
+        Compute unrolled weight w' (virtual step)
+        Step process:
+        1) forward
+        2) calculate loss
+        3) compute gradient (by backprop)
+        4) update the weights with the gradient
+
+        Args:
+            xi: learning rate for virtual gradient step (same as weights lr)
+            w_optim: weights optimizer
+        """
+
+        # Forward and calculate loss
+        # Loss for train with w. 
L_train(w)
+        loss = self.model.loss(train_x, train_y)
+        # Compute gradient
+        gradients = torch.autograd.grad(loss, self.model.getWeights())
+
+        # Do virtual step (update the virtual weights)
+        # The operations below do not need gradient tracking
+        with torch.no_grad():
+            # w_optim.state is keyed by parameter object (pointer), not by value, so the
+            # original network weights have to be iterated alongside the virtual ones.
+            for w, vw, g in zip(self.model.getWeights(), self.v_model.getWeights(), gradients):
+                m = w_optim.state[w].get("momentum_buffer", 0.) * self.w_momentum
+                vw.copy_(w - xi * (m + g + self.w_weight_decay * w))
+
+            # Sync alphas
+            for a, va in zip(self.model.getAlphas(), self.v_model.getAlphas()):
+                va.copy_(a)
+
+    def unrolled_backward(self, train_x, train_y, valid_x, valid_y, xi, w_optim):
+        """ Compute unrolled loss and backward its gradients
+        Args:
+            xi: learning rate for virtual gradient step (same as model lr)
+            w_optim: weights optimizer - for virtual step
+        """
+        # Do virtual step (calc w')
+        self.virtual_step(train_x, train_y, xi, w_optim)
+
+        # Calculate unrolled loss
+        # Loss for validation with w'. L_valid(w')
+        loss = self.v_model.loss(valid_x, valid_y)
+
+        # Calculate gradient
+        v_alphas = tuple(self.v_model.getAlphas())
+        v_weights = tuple(self.v_model.getWeights())
+        v_grads = torch.autograd.grad(loss, v_alphas + v_weights)
+
+        dalpha = v_grads[:len(v_alphas)]
+        dws = v_grads[len(v_alphas):]
+
+        hessian = self.compute_hessian(dws, train_x, train_y)
+
+        # Update final gradient = dalpha - xi * hessian
+        with torch.no_grad():
+            for alpha, da, h in zip(self.model.getAlphas(), dalpha, hessian):
+                alpha.grad = da - xi * h
+
+    def compute_hessian(self, dws, train_x, train_y):
+        """
+        dw = dw' { L_valid(w', alpha) }
+        w+ = w + eps * dw
+        w- = w - eps * dw
+        hessian = (dalpha{ L_train(w+, alpha) } - dalpha{ L_train(w-, alpha) }) / (2*eps)
+        eps = 0.01 / ||dw||
+        """
+
+        norm = torch.cat([dw.view(-1) for dw in dws]).norm()
+        eps = 0.01 / norm
+
+        # w+ = w + eps * dw
+        with torch.no_grad():
+            for p, dw in zip(self.model.getWeights(), dws):
+                p += eps * dw
+
+        loss = self.model.loss(train_x, train_y)
+        # dalpha { L_train(w+, alpha) }
+        dalpha_positive = torch.autograd.grad(loss, self.model.getAlphas())
+
+        # w- = w - eps * dw
+        with torch.no_grad():
+            for p, dw in zip(self.model.getWeights(), dws):
+                # Step from w+ down to w-, hence the factor of 2
+                p -= 2. * eps * dw
+
+        loss = self.model.loss(train_x, train_y)
+        # dalpha { L_train(w-, alpha) }
+        dalpha_negative = torch.autograd.grad(loss, self.model.getAlphas())
+
+        # Restore original w
+        with torch.no_grad():
+            for p, dw in zip(self.model.getWeights(), dws):
+                p += eps * dw
+
+        hessian = [(p-n) / (2. * eps) for p, n in zip(dalpha_positive, dalpha_negative)]
+        return hessian
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/model.py b/examples/v1alpha3/nas/darts-cnn-cifar10/model.py
new file mode 100644
index 00000000000..d17894b14c4
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/model.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from operations import FactorizedReduce, StdConv, MixedOp
+
+
+class Cell(nn.Module):
+    """ Cell for search
+    Each edge is mixed and continuously relaxed.
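+
+    A note on the relaxation (matching MixedOp.forward in operations.py): each
+    edge computes a weighted sum over all candidate operations,
+        o(x) = sum_i softmax(alpha)_i * op_i(x)
+    so the discrete choice of operation becomes differentiable in alpha.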
+    """
+
+    def __init__(self, num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space):
+        """
+        Args:
+            num_nodes: Number of intermediate cell nodes
+            c_prev_prev: channels_out[k-2]
+            c_prev: channels_out[k-1]
+            c_cur: channels_in[k] (current)
+            reduction_prev: whether the previous cell is a reduction cell
+            reduction_cur: whether the current cell is a reduction cell
+        """
+
+        super(Cell, self).__init__()
+        self.reduction_cur = reduction_cur
+        self.num_nodes = num_nodes
+
+        # If the previous cell is a reduction cell, the current input size does not match
+        # the output size of cell[k-2], so the output of cell[k-2] has to be reduced in preprocessing
+        if reduction_prev:
+            self.preprocess0 = FactorizedReduce(c_prev_prev, c_cur)
+        else:
+            self.preprocess0 = StdConv(c_prev_prev, c_cur, kernel_size=1, stride=1, padding=0)
+        self.preprocess1 = StdConv(c_prev, c_cur, kernel_size=1, stride=1, padding=0)
+
+        # Generate the DAG of mixed operations
+        self.dag_ops = nn.ModuleList()
+
+        for i in range(self.num_nodes):
+            self.dag_ops.append(nn.ModuleList())
+            # Node i has i+2 incoming edges: 2 input nodes plus i previous intermediate nodes
+            for j in range(2+i):
+                # Reduction with stride = 2 is applied only to edges from the two input nodes
+                stride = 2 if reduction_cur and j < 2 else 1
+                op = MixedOp(c_cur, stride, search_space)
+                self.dag_ops[i].append(op)
+
+    def forward(self, s0, s1, w_dag):
+        s0 = self.preprocess0(s0)
+        s1 = self.preprocess1(s1)
+
+        states = [s0, s1]
+        for edges, w_list in zip(self.dag_ops, w_dag):
+            state_cur = sum(edges[i](s, w) for i, (s, w) in enumerate(zip(states, w_list)))
+            states.append(state_cur)
+
+        state_out = torch.cat(states[2:], dim=1)
+        return state_out
+
+
+class NetworkCNN(nn.Module):
+
+    def __init__(self, init_channels, input_channels, num_classes, num_layers, criterion, search_space):
+        super(NetworkCNN, self).__init__()
+
+        self.init_channels = init_channels
+        self.num_classes = num_classes
+        self.num_layers = num_layers
+        self.criterion = criterion
+
+        # TODO: Algorithm settings?
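+        # These are fixed to the values used in the DARTS paper: each cell has
+        # 4 intermediate nodes, and the stem widens the input from input_channels
+        # to stem_multiplier * init_channels before the first cell.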
+        self.num_nodes = 4
+        self.stem_multiplier = 3
+
+        c_cur = self.stem_multiplier*self.init_channels
+
+        self.stem = nn.Sequential(
+            nn.Conv2d(input_channels, c_cur, 3, padding=1, bias=False),
+            nn.BatchNorm2d(c_cur)
+        )
+
+        # In the first cell, the stem is used for both s0 and s1
+        # c_prev_prev and c_prev are the output channel sizes,
+        # c_cur is the initial channel size
+        c_prev_prev, c_prev, c_cur = c_cur, c_cur, self.init_channels
+
+        self.cells = nn.ModuleList()
+
+        reduction_prev = False
+        for i in range(self.num_layers):
+            # Cells at 1/3 and 2/3 of the total layers are reduction cells
+            # with doubled channels; the others are normal cells
+            if i in [self.num_layers//3, 2*self.num_layers//3]:
+                c_cur *= 2
+                reduction_cur = True
+            else:
+                reduction_cur = False
+
+            cell = Cell(self.num_nodes, c_prev_prev, c_prev, c_cur, reduction_prev, reduction_cur, search_space)
+            reduction_prev = reduction_cur
+            self.cells.append(cell)
+
+            c_cur_out = c_cur * self.num_nodes
+            c_prev_prev, c_prev = c_prev, c_cur_out
+
+        self.global_pooling = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(c_prev, self.num_classes)
+
+        # Initialize alpha parameters
+        num_ops = len(search_space.primitives)
+
+        self.alpha_normal = nn.ParameterList()
+        self.alpha_reduce = nn.ParameterList()
+
+        for i in range(self.num_nodes):
+            self.alpha_normal.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))
+            self.alpha_reduce.append(nn.Parameter(1e-3*torch.randn(i+2, num_ops)))
+
+        # Set up the alphas list
+        self.alphas = []
+        for name, parameter in self.named_parameters():
+            if "alpha" in name:
+                self.alphas.append((name, parameter))
+
+    def forward(self, x):
+
+        weights_normal = [F.softmax(alpha, dim=-1) for alpha in self.alpha_normal]
+        weights_reduce = [F.softmax(alpha, dim=-1) for alpha in self.alpha_reduce]
+
+        s0 = s1 = self.stem(x)
+
+        for cell in self.cells:
+            weights = weights_reduce if cell.reduction_cur else weights_normal
+            s0, s1 = s1, cell(s0, s1, weights)
+
+        out = self.global_pooling(s1)
+
+        # Flatten the output
+        out = out.view(out.size(0), -1)
+
+        logits = self.classifier(out)
+        return logits
+
+    def print_alphas(self):
+
+        print("\n>>> Alphas Normal <<<")
+        for alpha in self.alpha_normal:
+            print(F.softmax(alpha, dim=-1))
+
+        print("\n>>> Alphas Reduce <<<")
+        for alpha in self.alpha_reduce:
+            print(F.softmax(alpha, dim=-1))
+        print("\n")
+
+    def getWeights(self):
+        return self.parameters()
+
+    def getAlphas(self):
+        for _, parameter in self.alphas:
+            yield parameter
+
+    def loss(self, x, y):
+        logits = self.forward(x)
+        return self.criterion(logits, y)
+
+    def genotype(self, search_space):
+        gene_normal = search_space.parse(self.alpha_normal, k=2)
+        gene_reduce = search_space.parse(self.alpha_reduce, k=2)
+        # Concatenate all intermediate nodes
+        concat = range(2, 2 + self.num_nodes)
+
+        return search_space.genotype(normal=gene_normal, normal_concat=concat,
+                                     reduce=gene_reduce, reduce_concat=concat)
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/operations.py b/examples/v1alpha3/nas/darts-cnn-cifar10/operations.py
new file mode 100644
index 00000000000..e9efe8b9e63
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/operations.py
@@ -0,0 +1,166 @@
+import torch.nn as nn
+import torch
+
+OPS = {
+    'none': lambda channels, stride: Zero(stride),
+    'avg_pooling_3x3': lambda channels, stride: PoolBN('avg', channels, kernel_size=3, stride=stride, padding=1),
+    'max_pooling_3x3': lambda channels, stride: PoolBN('max', channels, kernel_size=3, stride=stride, padding=1),
+    'skip_connection': lambda channels, stride: Identity() if stride == 1 else 
FactorizedReduce(channels, channels),
+    'separable_convolution_3x3': lambda channels, stride: SepConv(channels, kernel_size=3, stride=stride, padding=1),
+    'separable_convolution_5x5': lambda channels, stride: SepConv(channels, kernel_size=5, stride=stride, padding=2),
+    # 3x3 kernel -> 5x5 receptive field
+    'dilated_convolution_3x3': lambda channels, stride: DilConv(channels,
+                                                                kernel_size=3, stride=stride, padding=2, dilation=2),
+    # 5x5 kernel -> 9x9 receptive field
+    'dilated_convolution_5x5': lambda channels, stride: DilConv(channels,
+                                                                kernel_size=5, stride=stride, padding=4, dilation=2),
+}
+
+
+class Zero(nn.Module):
+    """
+    Zero operation
+    """
+
+    def __init__(self, stride):
+        super(Zero, self).__init__()
+        self.stride = stride
+
+    def forward(self, x):
+        if self.stride == 1:
+            return x * 0.
+        # Resize by stride
+        return x[:, :, ::self.stride, ::self.stride] * 0.
+
+
+class PoolBN(nn.Module):
+    """
+    Avg or max pooling followed by BatchNorm
+    """
+
+    def __init__(self, pool_type, channels, kernel_size, stride, padding):
+        super(PoolBN, self).__init__()
+        if pool_type == "avg":
+            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
+        elif pool_type == "max":
+            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
+
+        self.bn = nn.BatchNorm2d(channels, affine=False)
+        self.net = nn.Sequential(
+            self.pool,
+            self.bn
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class Identity(nn.Module):
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class FactorizedReduce(nn.Module):
+    """
+    Reduce feature map size by factorized pointwise convolutions (stride=2)
+    ReLU - Conv1 - Conv2 - BN
+    """
+
+    def __init__(self, c_in, c_out):
+        super(FactorizedReduce, self).__init__()
+        self.relu = nn.ReLU()
+        self.conv1 = nn.Conv2d(c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False)
+        self.conv2 = nn.Conv2d(c_in, c_out // 2, kernel_size=1, stride=2, padding=0, bias=False)
+        self.bn = nn.BatchNorm2d(c_out, affine=False)
+
+    def forward(self, x):
+
+        x = self.relu(x)
+        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
+        out = self.bn(out)
+
+        return out
+
+
+class StdConv(nn.Module):
+    """ Standard convolution
+    ReLU - Conv - BN
+    """
+
+    def __init__(self, c_in, c_out, kernel_size, stride, padding):
+        super(StdConv, self).__init__()
+        self.net = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(c_in, c_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
+            nn.BatchNorm2d(c_out, affine=False)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class DilConv(nn.Module):
+    """ (Dilated) depthwise separable conv
+    ReLU - (Dilated) depthwise separable - Pointwise - BN
+
+    If dilation == 2, 3x3 conv => 5x5 receptive field
+                      5x5 conv => 9x9 receptive field
+    """
+
+    def __init__(self, channels, kernel_size, stride, padding, dilation):
+        super(DilConv, self).__init__()
+
+        self.net = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(channels, channels, kernel_size, stride, padding, dilation=dilation, groups=channels, bias=False),
+            nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0, bias=False),
+            nn.BatchNorm2d(channels, affine=False)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class SepConv(nn.Module):
+    """ Depthwise separable conv
+    DilConv (dilation=1) * 2
+    """
+
+    def __init__(self, channels, kernel_size, stride, padding):
+        super(SepConv, self).__init__()
+        self.net = nn.Sequential(
+            DilConv(channels, kernel_size, stride=stride, padding=padding, 
dilation=1),
+            DilConv(channels, kernel_size, stride=1, padding=padding, dilation=1)
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+
+class MixedOp(nn.Module):
+    """ Mixed operation
+    """
+
+    def __init__(self, channels, stride, search_space):
+        super(MixedOp, self).__init__()
+        self.ops = nn.ModuleList()
+
+        for primitive in search_space.primitives:
+            op = OPS[primitive](channels, stride)
+            self.ops.append(op)
+
+    def forward(self, x, weights):
+        """
+        Args:
+            x: input
+            weights: weight for each operation
+        """
+        return sum(w * op(x) for w, op in zip(weights, self.ops))
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/run_trial.py b/examples/v1alpha3/nas/darts-cnn-cifar10/run_trial.py
new file mode 100644
index 00000000000..550cbc42cf3
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/run_trial.py
@@ -0,0 +1,231 @@
+import torch
+import torch.nn as nn
+
+import argparse
+import json
+import numpy as np
+
+from model import NetworkCNN
+from architect import Architect
+import utils
+from search_space import SearchSpace
+
+
+# TODO: Move to the algorithm settings
+w_lr = 0.025
+w_lr_min = 0.001
+w_momentum = 0.9
+w_weight_decay = 3e-4
+w_grad_clip = 5.
+
+alpha_lr = 3e-4
+alpha_weight_decay = 1e-3
+
+batch_size = 128
+num_workers = 4
+
+init_channels = 16
+
+print_step = 50
+
+
+def main():
+
+    parser = argparse.ArgumentParser(description='TrainingContainer')
+    parser.add_argument('--algorithm-settings', type=str, default="", help="algorithm settings")
+    parser.add_argument('--search-space', type=str, default="", help="search space for the neural architecture search")
+    parser.add_argument('--num-layers', type=str, default="", help="number of layers of the neural network")
+
+    args = parser.parse_args()
+
+    algorithm_settings = args.algorithm_settings.replace("\'", "\"")
+    algorithm_settings = json.loads(algorithm_settings)
+    print("Algorithm settings")
+    print("{}\n".format(algorithm_settings))
+    num_epochs = int(algorithm_settings["num_epoch"])
+
+    search_space = args.search_space.replace("\'", "\"")
+    search_space = json.loads(search_space)
+    search_space = SearchSpace(search_space)
+
+    num_layers = int(args.num_layers)
+    print("Number of layers {}\n".format(num_layers))
+
+    # Set GPU device
+    # Currently only the first available GPU is used
+    # TODO: Add multi GPU support
+    # TODO: Add functionality to select GPU
+    all_gpus = list(range(torch.cuda.device_count()))
+    if len(all_gpus) > 0:
+        device = torch.device("cuda")
+        torch.cuda.set_device(all_gpus[0])
+        np.random.seed(2)
+        torch.manual_seed(2)
+        torch.cuda.manual_seed_all(2)
+        torch.backends.cudnn.benchmark = True
+        print(">>> Use GPU for Training <<<")
+        print("Device ID: {}".format(torch.cuda.current_device()))
+        print("Device name: {}".format(torch.cuda.get_device_name(0)))
+        print("Device availability: {}\n".format(torch.cuda.is_available()))
+    else:
+        device = torch.device("cpu")
+        print(">>> Use CPU for Training <<<")
+
+    # Get dataset with meta information
+    # TODO: Add support for more datasets
+    input_channels, num_classes, train_data = utils.get_dataset()
+
+    criterion = nn.CrossEntropyLoss().to(device)
+
+    model = NetworkCNN(init_channels, input_channels, num_classes, num_layers, criterion, search_space)
+
+    model = model.to(device)
+
+    # Weights optimizer
+    w_optim = torch.optim.SGD(model.getWeights(), w_lr, momentum=w_momentum, weight_decay=w_weight_decay)
+
+    # Alphas optimizer
+    alpha_optim = torch.optim.Adam(model.getAlphas(), alpha_lr, betas=(0.5, 0.999), weight_decay=alpha_weight_decay)
+
+    # Split 
data into train/validation sets
+    num_train = len(train_data)
+    split = num_train // 2
+    indices = list(range(num_train))
+
+    train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
+    valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
+
+    train_loader = torch.utils.data.DataLoader(train_data,
+                                               batch_size=batch_size,
+                                               sampler=train_sampler,
+                                               num_workers=num_workers,
+                                               pin_memory=True)
+
+    valid_loader = torch.utils.data.DataLoader(train_data,
+                                               batch_size=batch_size,
+                                               sampler=valid_sampler,
+                                               num_workers=num_workers,
+                                               pin_memory=True)
+
+    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        w_optim,
+        num_epochs,
+        eta_min=w_lr_min)
+
+    architect = Architect(model, w_momentum, w_weight_decay)
+
+    # Start training
+    best_top1 = 0.
+
+    for epoch in range(num_epochs):
+        lr_scheduler.step()
+        lr = lr_scheduler.get_lr()[0]
+
+        model.print_alphas()
+
+        # Training
+        print(">>> Training")
+        train(train_loader, valid_loader, model, architect, w_optim,
+              alpha_optim, lr, epoch, num_epochs, device)
+
+        # Validation
+        print("\n>>> Validation")
+        cur_step = (epoch + 1) * len(train_loader)
+        top1 = validate(valid_loader, model, epoch, cur_step, num_epochs, device)
+
+        # Print genotype
+        genotype = model.genotype(search_space)
+        print("\nModel genotype = {}".format(genotype))
+
+        # Update best top1
+        if top1 > best_top1:
+            best_top1 = top1
+            best_genotype = genotype
+
+    print("Final best Prec@1 = {:.4%}".format(best_top1))
+    print("\nBest-Genotype={}".format(str(best_genotype).replace(" ", "")))
+
+
+def train(train_loader, valid_loader, model, architect, w_optim, alpha_optim, lr, epoch, num_epochs, device):
+    top1 = utils.AverageMeter()
+    top5 = utils.AverageMeter()
+    losses = utils.AverageMeter()
+    cur_step = epoch * len(train_loader)
+
+    model.train()
+    for step, ((train_x, train_y), (valid_x, valid_y)) in enumerate(zip(train_loader, valid_loader)):
+
+        train_x, train_y = train_x.to(device, non_blocking=True), train_y.to(device, non_blocking=True)
+        valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(device, non_blocking=True)
+
+        train_size = train_x.size(0)
+
+        # Phase 1. Architect step (Alpha)
+        alpha_optim.zero_grad()
+        architect.unrolled_backward(train_x, train_y, valid_x, valid_y, lr, w_optim)
+        alpha_optim.step()
+
+        # Phase 2. 
Child network step (W)
+        w_optim.zero_grad()
+        logits = model(train_x)
+        loss = model.criterion(logits, train_y)
+        loss.backward()
+
+        # Gradient clipping
+        nn.utils.clip_grad_norm_(model.getWeights(), w_grad_clip)
+        w_optim.step()
+
+        prec1, prec5 = utils.accuracy(logits, train_y, topk=(1, 5))
+
+        losses.update(loss.item(), train_size)
+        top1.update(prec1.item(), train_size)
+        top5.update(prec5.item(), train_size)
+
+        if step % print_step == 0 or step == len(train_loader) - 1:
+            print(
+                "Train: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
+                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
+                    epoch+1, num_epochs, step, len(train_loader)-1, losses=losses,
+                    top1=top1, top5=top5))
+
+        cur_step += 1
+
+    print("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch+1, num_epochs, top1.avg))
+
+
+def validate(valid_loader, model, epoch, cur_step, num_epochs, device):
+    top1 = utils.AverageMeter()
+    top5 = utils.AverageMeter()
+    losses = utils.AverageMeter()
+
+    model.eval()
+
+    with torch.no_grad():
+        for step, (valid_x, valid_y) in enumerate(valid_loader):
+            valid_x, valid_y = valid_x.to(device, non_blocking=True), valid_y.to(device, non_blocking=True)
+
+            valid_size = valid_x.size(0)
+
+            logits = model(valid_x)
+            loss = model.criterion(logits, valid_y)
+
+            prec1, prec5 = utils.accuracy(logits, valid_y, topk=(1, 5))
+            losses.update(loss.item(), valid_size)
+            top1.update(prec1.item(), valid_size)
+            top5.update(prec5.item(), valid_size)
+
+            if step % print_step == 0 or step == len(valid_loader) - 1:
+                print(
+                    "Validation: [{:2d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
+                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
+                        epoch+1, num_epochs, step, len(valid_loader)-1, losses=losses,
+                        top1=top1, top5=top5))
+
+    print("Validation: [{:2d}/{}] Final Prec@1 {:.4%}".format(epoch+1, num_epochs, top1.avg))
+
+    return top1.avg
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/search_space.py b/examples/v1alpha3/nas/darts-cnn-cifar10/search_space.py
new file mode 100644
index 00000000000..f1ca61eee01
--- /dev/null
+++ b/examples/v1alpha3/nas/darts-cnn-cifar10/search_space.py
@@ -0,0 +1,50 @@
+from collections import namedtuple
+import torch
+
+
+class SearchSpace():
+    def __init__(self, search_space):
+        self.primitives = search_space
+        self.primitives.append("none")
+
+        print(">>> All Primitives")
+        print("{}\n".format(self.primitives))
+        self.genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
+
+    def parse(self, alpha, k):
+        """
+        Parse continuous alpha to discrete gene.
+        alpha is ParameterList:
+        ParameterList [
+            Parameter(n_edges1, n_ops),
+            Parameter(n_edges2, n_ops),
+            ...
+        ]
+
+        gene is list:
+        [
+            [('node1_ops_1', node_idx), ..., ('node1_ops_k', node_idx)],
+            [('node2_ops_1', node_idx), ..., ('node2_ops_k', node_idx)],
+            ...
+        ]
+        Each node keeps its top k=2 incoming edges in the CNN case.
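+
+        For example (hypothetical weights): with primitives
+        ['max_pooling_3x3', 'skip_connection', 'none'] and k=2, a node whose
+        strongest non-'none' weights sit on edges 0 and 2 could yield the
+        node gene [('max_pooling_3x3', 0), ('skip_connection', 2)].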
+ """ + + gene = [] + assert self.primitives[-1] == 'none' # assume last PRIMITIVE is 'none' + + # 1) Convert the mixed op to discrete edge (single op) by choosing top-1 weight edge + # 2) Choose top-k edges per node by edge score (top-1 weight in edge) + for edges in alpha: + # edges: Tensor(n_edges, n_ops) + edge_max, primitive_indices = torch.topk(edges[:, :-1], 1) # ignore 'none' + topk_edge_values, topk_edge_indices = torch.topk(edge_max.view(-1), k) + node_gene = [] + for edge_idx in topk_edge_indices: + prim_idx = primitive_indices[edge_idx] + prim = self.primitives[prim_idx] + node_gene.append((prim, edge_idx.item())) + + gene.append(node_gene) + + return gene diff --git a/examples/v1alpha3/nas/darts-cnn-cifar10/utils.py b/examples/v1alpha3/nas/darts-cnn-cifar10/utils.py new file mode 100644 index 00000000000..b736ebbe5ea --- /dev/null +++ b/examples/v1alpha3/nas/darts-cnn-cifar10/utils.py @@ -0,0 +1,70 @@ +import torchvision.datasets as dset +import torchvision.transforms as transforms + + +class AverageMeter(): + """ Computes and stores the average and current value """ + + def __init__(self): + self.reset() + + def reset(self): + """ Reset all statistics """ + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + """ Update statistics """ + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """ Computes the precision@k for the specified values of k """ + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + + pred = pred.t() + # one-hot case + if target.ndimension() > 1: + target = target.max(1)[1] + + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(1.0 / batch_size)) + + return res + + +def get_dataset(): + dataset_cls = dset.CIFAR10 + num_classes = 10 + input_channels = 3 + + # Do preprocessing + MEAN = [0.49139968, 0.48215827, 0.44653124] + STD = [0.24703233, 0.24348505, 0.26158768] + transf = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip() + ] + + normalize = [ + transforms.ToTensor(), + transforms.Normalize(MEAN, STD) + ] + + train_transform = transforms.Compose(transf + normalize) + + train_data = dataset_cls(root="./data", train=True, download=True, transform=train_transform) + + return input_channels, num_classes, train_data diff --git a/examples/v1alpha3/nas/darts-example-gpu.yaml b/examples/v1alpha3/nas/darts-example-gpu.yaml new file mode 100644 index 00000000000..9ce615c344a --- /dev/null +++ b/examples/v1alpha3/nas/darts-example-gpu.yaml @@ -0,0 +1,85 @@ +apiVersion: "kubeflow.org/v1alpha3" +kind: Experiment +metadata: + namespace: kubeflow + name: darts-example-gpu +spec: + parallelTrialCount: 1 + maxTrialCount: 1 + maxFailedTrialCount: 1 + objective: + type: maximize + objectiveMetricName: Best-Genotype + metricsCollectorSpec: + collector: + kind: StdOut + source: + filter: + metricsFormat: + - "([\\w-]+)=(Genotype.*)" + algorithm: + algorithmName: darts + algorithmSettings: + - name: num_epoch + value: "3" + nasConfig: + graphConfig: + numLayers: 3 + operations: + - operationType: separable_convolution + parameters: + - name: filter_size + parameterType: categorical + feasibleSpace: + list: + - "3" + - operationType: dilated_convolution + parameters: + - name: filter_size + parameterType: categorical + feasibleSpace: + list: + - "3" + - "5" + - 
operationType: avg_pooling + parameters: + - name: filter_size + parameterType: categorical + feasibleSpace: + list: + - "3" + - operationType: max_pooling + parameters: + - name: filter_size + parameterType: categorical + feasibleSpace: + list: + - "3" + - operationType: skip_connection + trialTemplate: + goTemplate: + rawTemplate: |- + apiVersion: batch/v1 + kind: Job + metadata: + name: {{.Trial}} + namespace: {{.NameSpace}} + spec: + template: + spec: + containers: + - name: {{.Trial}} + image: docker.io/kubeflowkatib/darts-cnn-cifar10 + imagePullPolicy: Always + command: + - "python3" + - "run_trial.py" + {{- with .HyperParameters}} + {{- range .}} + - "--{{.Name}}=\"{{.Value}}\"" + {{- end}} + {{- end}} + resources: + limits: + nvidia.com/gpu: 1 + restartPolicy: Never diff --git a/pkg/suggestion/v1alpha3/nas/darts/__init__.py b/pkg/suggestion/v1alpha3/nas/darts/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pkg/suggestion/v1alpha3/nas/darts/service.py b/pkg/suggestion/v1alpha3/nas/darts/service.py new file mode 100644 index 00000000000..c5c9f247c19 --- /dev/null +++ b/pkg/suggestion/v1alpha3/nas/darts/service.py @@ -0,0 +1,110 @@ +import logging +from logging import getLogger, StreamHandler, INFO +import json + +from pkg.suggestion.v1alpha3.internal.base_health_service import HealthServicer +from pkg.apis.manager.v1alpha3.python import api_pb2 +from pkg.apis.manager.v1alpha3.python import api_pb2_grpc + + +class DartsService(api_pb2_grpc.SuggestionServicer, HealthServicer): + + def __init__(self): + super(DartsService, self).__init__() + self.is_first_run = True + + self.logger = getLogger(__name__) + FORMAT = '%(asctime)-15s Experiment %(experiment_name)s %(message)s' + logging.basicConfig(format=FORMAT) + handler = StreamHandler() + handler.setLevel(INFO) + self.logger.setLevel(INFO) + self.logger.addHandler(handler) + self.logger.propagate = False + + # TODO: Add validation + def ValidateAlgorithmSettings(self, request, context): + return api_pb2.ValidateAlgorithmSettingsReply() + + def GetSuggestions(self, request, context): + if self.is_first_run: + nas_config = request.experiment.spec.nas_config + num_layers = str(nas_config.graph_config.num_layers) + + search_space = get_search_space(nas_config.operations) + + settings_raw = request.experiment.spec.algorithm.algorithm_setting + algorithm_settings = get_algorithm_settings(settings_raw) + + search_space_json = json.dumps(search_space) + algorithm_settings_json = json.dumps(algorithm_settings) + + search_space_str = str(search_space_json).replace('\"', '\'') + algorithm_settings_str = str(algorithm_settings_json).replace('\"', '\'') + + self.is_first_run = False + + parameter_assignments = [] + for i in range(request.request_number): + + self.logger.info(">>> Generate new Darts Trial Job") + + self.logger.info(">>> Number of layers {}\n".format(num_layers)) + + self.logger.info(">>> Search Space") + self.logger.info("{}\n".format(search_space_str)) + + self.logger.info(">>> Algorithm Settings") + self.logger.info("{}\n\n".format(algorithm_settings_str)) + + parameter_assignments.append( + api_pb2.GetSuggestionsReply.ParameterAssignments( + assignments=[ + api_pb2.ParameterAssignment( + name="algorithm-settings", + value=algorithm_settings_str + ), + api_pb2.ParameterAssignment( + name="search-space", + value=search_space_str + ), + api_pb2.ParameterAssignment( + name="num-layers", + value=num_layers + ) + ] + ) + ) + + return api_pb2.GetSuggestionsReply(parameter_assignments=parameter_assignments) + 
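+# The reply above serializes everything as strings. For the darts-example-gpu.yaml
+# experiment in this PR, a hypothetical reply would carry assignments such as:
+#   algorithm-settings = "{'num_epoch': '3'}"
+#   search-space = "['separable_convolution_3x3', 'dilated_convolution_3x3', ...]"
+#   num-layers = "3"
+# Double quotes are swapped for single quotes so the JSON stays intact inside the
+# quoted command-line arguments of the trial template; run_trial.py swaps them
+# back before calling json.loads().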
+
+
+def get_search_space(operations):
+    search_space = []
+
+    for operation in list(operations.operation):
+        opt_type = operation.operation_type
+
+        if opt_type == "skip_connection":
+            search_space.append(opt_type)
+        else:
+            # Currently supports only one categorical parameter: the filter size
+            opt_spec = list(operation.parameter_specs.parameters)[0]
+            for filter_size in list(opt_spec.feasible_space.list):
+                search_space.append(opt_type+"_{}x{}".format(filter_size, filter_size))
+    return search_space
+
+
+# TODO: Add more algorithm settings
+def get_algorithm_settings(settings_raw):
+
+    algorithm_settings_default = {
+        "num_epoch": 50
+    }
+
+    for setting in settings_raw:
+        s_name = setting.name
+        s_value = setting.value
+        algorithm_settings_default[s_name] = s_value
+
+    return algorithm_settings_default
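+
+
+# Minimal usage sketch (hypothetical inputs): for an experiment whose operations
+# contain separable_convolution with filter_size 3 plus skip_connection,
+# get_search_space(...) returns ['separable_convolution_3x3', 'skip_connection'],
+# and get_algorithm_settings(...) with a num_epoch="3" setting returns
+# {'num_epoch': '3'}, the string value overriding the integer default of 50.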