diff --git a/benchmark/python/control_flow/foreach_rnn.py b/benchmark/python/control_flow/foreach_rnn.py deleted file mode 100644 index 4ce7a429ee9d..000000000000 --- a/benchmark/python/control_flow/foreach_rnn.py +++ /dev/null @@ -1,195 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import subprocess -import mxnet as mx -from mxnet import gluon -import time -import copy - -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) - -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) - self.cell = cell - - def hybrid_forward(self, F, inputs, states): - out, states = F.contrib.foreach(self.cell, inputs, states) - return out - -def benchmark_rnn(cell, rnn_data, states): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0) - layer0.initialize(ctx=ctx) - - # Hybridize - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1) - layer1.initialize(ctx=ctx) - - # Hybridize - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Hybridize - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("foreach_rnn") - symnet = mx.symbol.load('foreach_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0.backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1.backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3.backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2.backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for the backward of the foreach symbol - args_grad1 = {} - for key in args1.keys(): - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - ndim = 512 - seq_len = 100 - batch_sizes = [1, 32] - cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), - gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - if isinstance(cell, gluon.rnn.RNNCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.GRUCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, ndim), - ctx=mx.cpu(0)) - states = [] - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - states.append(mx.nd.normal(loc=0, scale=1, shape=(batch_size, ndim), - ctx=mx.cpu(0))) - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, - batch_size)) - benchmark_rnn(cell, rnn_data, states) diff --git a/benchmark/python/control_flow/while_loop_rnn.py b/benchmark/python/control_flow/while_loop_rnn.py deleted file mode 100644 index 42aaee5840dd..000000000000 --- a/benchmark/python/control_flow/while_loop_rnn.py +++ /dev/null @@ -1,213 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -# Code borrowed from ./benchmark/python/control_flow/foreach_rnn.py - -import subprocess -import mxnet as mx -from mxnet import gluon -import time -import copy - -def get_gpus(): - """ - return a list of GPUs - """ - try: - re = subprocess.check_output(["nvidia-smi", "-L"], universal_newlines=True) - except OSError: - return [] - return range(len([i for i in re.split('\n') if 'GPU' in i])) - -class TestRNNLayer(gluon.HybridBlock): - def __init__(self, cell, length, prefix=None, params=None): - super(TestRNNLayer, self).__init__(prefix=prefix, params=params) - self.length = length - self.cell = cell - - def hybrid_forward(self, F, inputs, states): - def _func(*states): - i = states[0] - s = states[1: ] - data = inputs.take(i).squeeze(axis=0) - out, new_s = self.cell(data, s) - new_s = [i + 1] + new_s - return out, new_s - out, states = F.contrib.while_loop( - cond=lambda i, *_: i < self.length, - func=_func, - loop_vars=states, - max_iterations=self.length, - ) - return out + states - -def benchmark_rnn(cell, rnn_data, states, length): - ctx = rnn_data.context - num_batches = 20 - - # Imperative - cell0 = copy.deepcopy(cell) - layer0 = TestRNNLayer(cell0, length) - layer0.initialize(ctx=ctx) - - # Hybrid-cell - cell1 = copy.deepcopy(cell) - cell1.hybridize() - layer1 = TestRNNLayer(cell1, length) - layer1.initialize(ctx=ctx) - - # Hybrid - cell2 = copy.deepcopy(cell) - layer2 = TestRNNLayer(cell2, length) - layer2.initialize(ctx=ctx) - layer2.hybridize() - layer2(rnn_data, states) - - # Static-hybrid-cell - cell3 = copy.deepcopy(cell) - cell3.hybridize(static_alloc=True) - layer3 = TestRNNLayer(cell3, length) - layer3.initialize(ctx=ctx) - - tic = time.time() - for i in range(num_batches): - res0 = layer0(rnn_data, states) - mx.nd.waitall() - print("Imperative inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res1 = layer1(rnn_data, states) - mx.nd.waitall() - print("Hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res3 = layer3(rnn_data, states) - mx.nd.waitall() - print("Static-hybrid-cell inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - res2 = layer2(rnn_data, states) - mx.nd.waitall() - print("Hybrid inference takes " + str(time.time() - tic)) - - layer2.export("while_loop_rnn") - symnet = mx.symbol.load('while_loop_rnn-symbol.json') - args1 = {} - params = layer2.collect_params() - for key in params.keys(): - args1[key] = params[key].data() - args1['data0'] = rnn_data - for i in range(len(states)): - args1['data' + str(i + 1)] = states[i] - exe = symnet.bind(ctx=ctx, args=args1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=False) - mx.nd.waitall() - print("Symbol inference takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res0 = layer0(rnn_data, states) - res0[0].backward() - mx.nd.waitall() - print("Imperative training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res1 = layer1(rnn_data, states) - res1[0].backward() - mx.nd.waitall() - print("Hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res3 = layer3(rnn_data, states) - res3[0].backward() - mx.nd.waitall() - print("Static-hybrid-cell training takes " + str(time.time() - tic)) - - tic = time.time() - for i in range(num_batches): - with mx.autograd.record(): - res2 = layer2(rnn_data, states) - res2[0].backward() - mx.nd.waitall() - print("Hybrid training takes " + str(time.time() - tic)) - - # gradients for the backward of the while_loop symbol - args_grad1 = {} - for key in args1.keys(): - if key != "data1": - args_grad1[key] = mx.nd.empty(args1[key].shape, ctx=ctx) - exe = symnet.bind(ctx=ctx, args=args1, args_grad=args_grad1) - tic = time.time() - for i in range(num_batches): - exe.forward(is_train=True) - exe.backward(res2) - mx.nd.waitall() - print("Symbol training takes " + str(time.time() - tic)) - print("") - -if __name__ == '__main__': - def _zeros(shape): - return mx.nd.zeros(shape=shape, ctx=mx.cpu(0)) - def _array(shape): - return mx.nd.normal(loc=0.0, scale=1.0, shape=shape, ctx=mx.cpu(0)) - ndim = 512 - seq_len = 100 - batch_sizes = [1, 32] - cells = [gluon.rnn.RNNCell(ndim, prefix='rnn_'), - gluon.rnn.GRUCell(ndim, prefix='rnn_'), - gluon.rnn.LSTMCell(ndim, prefix='rnn_')] - ctxs = [mx.cpu(0), mx.gpu(0)] - for cell in cells: - for ctx in ctxs: - for batch_size in batch_sizes: - if len(get_gpus()) == 0 and ctx == mx.gpu(0): - continue - if isinstance(cell, gluon.rnn.RNNCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - ] - if isinstance(cell, gluon.rnn.GRUCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - ] - elif isinstance(cell, gluon.rnn.LSTMCell): - rnn_data = _array((seq_len, batch_size, ndim)) - states = [ - _zeros((1, )), - _array((batch_size, ndim)), - _array((batch_size, ndim)), - ] - if ctx == mx.gpu(0): - dev = "GPU" - else: - dev = "CPU" - print("Benchmark {} in {} (batch size: {})".format(cell._alias(), dev, batch_size)) - benchmark_rnn(cell, rnn_data, states, seq_len) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index 0cf8724de301..80d8ef23b459 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` quantize foreach while_loop + ifelse ``` ## API Reference diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index ba43f2d6633c..96ce7987d800 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -54,6 +54,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` quantize foreach while_loop + ifelse ``` ## API Reference diff --git a/python/mxnet/ndarray/contrib.py b/python/mxnet/ndarray/contrib.py index b67cf5a55daf..b7b63c4e10e6 100644 --- a/python/mxnet/ndarray/contrib.py +++ b/python/mxnet/ndarray/contrib.py @@ -28,7 +28,7 @@ except ImportError: pass -__all__ = ["rand_zipfian", "foreach", "while_loop"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] # pylint: disable=line-too-long def rand_zipfian(true_classes, num_sampled, range_max, ctx=None): @@ -192,7 +192,6 @@ def check_input(inputs, in_type, msg): outputs = outputs[0] return (outputs, states) - def while_loop(cond, func, loop_vars, max_iterations=None): """Run a while loop with user-defined computation and loop condition. @@ -363,3 +362,97 @@ def _func_wrapper(loop_vars): [" Step %d, shape is %s" % (i, str(x.shape)) for i, x in enumerate(items)] )) return stacked_outputs, list(loop_vars) + +def ifelse(cond, then_func, else_func, inputs): + """Run a if-then-else using user-defined condition and computation + + This operator simulates a if-like branch which chooses to do one of + the two customized computations according to the specified condition. + + `inputs` is a list of NDArrays on which the condition and computations reply on. + + `cond` is a user-defined function, used as the if condition. + It consumes `inputs`, and produces a scalar MXNet NDArray, + indicating which branch of computation should be used. + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => NDArray`. + + `then_func` is a user-defined function, used as computation of the then branch. + It consumes `inputs`, and produces `outputs`. + The `then_func` is variadic, and its signature should be + `then_func(*loop_vars) => List[NDArray]`. + + `else_func` is a user-defined function, used as computation of the else branch. + It also consumes `inputs`, and produces `outputs`. + The `else_func` is variadic, and its signature should be + `else_func(*loop_vars) => List[NDArray]`. + + The `outputs` produces by `then_func` and `else_func` should have the same number + of elements, all of which should be in the same shape, of the same dtype and stype. + + This function returns a list of NDArrays, representing the computation result. + + Parameters + ---------- + cond: a Python function. + The branch condition. + then_func: a Python function. + The computation to be executed if `cond` is true. + else_func: a Python function. + The computation to be executed if `cond` is false. + inputs: list of NDArrays. + The variables fed to `cond`, `then_func` and `else_func`. + + Returns + ------- + outputs: a list of NDArrays, representing the result of computation. + + Examples + -------- + >>> cond = lambda a, b: a * b < 5 + >>> then_func = lambda a, b: (a + 5) * (b + 5) + >>> else_func = lambda a, b: (a - 5) * (b - 5) + >>> inputs = (mx.nd.array([1]), mx.nd.array([2])) + >>> outputs = mx.nd.contrib.ifelse(cond, then_func, else_func, inputs) + >>> outputs[0] + [42.] + + """ + def _to_python_scalar(inputs, type_, name): + """Converts "inputs", possibly typed mxnet NDArray, a numpy ndarray, other python types, + to the given type + """ + if hasattr(inputs, "asscalar"): + inputs = inputs.asscalar() + try: + inputs = type_(inputs) + except: + raise ValueError("Cannot convert %s to python %s" % (name, type_.__name__)) + return inputs + + def _to_ndarray_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet NDArray, a list of mxnet NDArray, + a tuple of mxnet NDArray, into a tuple of NDArray + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, ndarray.NDArray): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + for item in inputs: + if not isinstance(item, ndarray.NDArray): + raise ValueError("%s must be an NDArray, or a tuple or list of NDArrays" % (name, )) + return inputs + + inputs = _to_ndarray_tuple(inputs, "inputs") + if len(inputs) == 0: + raise ValueError("inputs should contain at least one element") + branch = _to_python_scalar(cond(*inputs), bool, "Return value of cond") + if branch: + outputs = then_func(*inputs) + outputs = _to_ndarray_tuple(outputs, "outputs of then_func") + else: + outputs = else_func(*inputs) + outputs = _to_ndarray_tuple(outputs, "outputs of else_func") + return list(outputs) diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 2c11921383c8..33932ba5ad94 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -34,7 +34,7 @@ from ..base import SymbolHandle, _as_list from ..attribute import AttrScope -__all__ = ["rand_zipfian", "foreach", "while_loop"] +__all__ = ["rand_zipfian", "foreach", "while_loop", "ifelse"] def rand_zipfian(true_classes, num_sampled, range_max): """Draw random samples from an approximately log-uniform or Zipfian distribution. @@ -556,3 +556,154 @@ def _union_inputs(*graphs): outputs = [result[i] for i in range(num_out_data)] final_loop_vars = [result[i] for i in range(num_out_data, num_outputs)] return outputs, final_loop_vars + +def ifelse(cond, then_func, else_func, inputs, name="ifelse"): + """Run a if-then-else using user-defined condition and computation + + This operator simulates a if-like branch which chooses to do one of + the two customized computations according to the specified condition. + + `inputs` is a list of Symbols on which the condition and computations reply on. + + `cond` is a user-defined function, used as the if condition. + It consumes `inputs`, and produces a scalar MXNet symbol, + indicating which branch of computation should be used. + The `cond` is variadic, and its signature should be + `cond(*loop_vars) => Symbol`. + + `then_func` is a user-defined function, used as computation of the then branch. + It consumes `inputs`, and produces `outputs`. + The `then_func` is variadic, and its signature should be + `then_func(*loop_vars) => List[Symbol]`. + + `else_func` is a user-defined function, used as computation of the else branch. + It also consumes `inputs`, and produces `outputs`. + The `else_func` is variadic, and its signature should be + `else_func(*loop_vars) => List[Symbol]`. + + The `outputs` produces by `then_func` and `else_func` should have the same number + of elements, all of which should be in the same shape, of the same dtype and stype. + + This function returns a list of symbols, representing the computation result. + + Parameters + ---------- + cond: a Python function. + The branch condition. + then_func: a Python function. + The computation to be executed if `cond` is true. + else_func: a Python function. + The computation to be executed if `cond` is false. + inputs: list of Symbols. + The variables fed to `cond`, `then_func` and `else_func`. + + Returns + ------- + outputs: a list of Symbols, representing the result of computation. + + Examples + -------- + >>> cond = lambda a, b: a * b < 5 + >>> then_func = lambda a, b: (a + 5) * (b + 5) + >>> else_func = lambda a, b: (a - 5) * (b - 5) + >>> inputs = (mx.sym.var('a'), mx.sym.var('b')) + >>> outputs = mx.sym.contrib.ifelse(cond, then_func, else_func, inputs) + """ + def _to_symbol_tuple(inputs, name): + """Converts "inputs", possibly a single mxnet Symbol, a list of mxnet Symbol, + a tuple of mxnet Symbol, into a tuple of Symbol + """ + if isinstance(inputs, list): + inputs = tuple(inputs) + if isinstance(inputs, Symbol): + inputs = (inputs, ) + if not isinstance(inputs, tuple): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + for item in inputs: + if not isinstance(item, Symbol): + raise ValueError("%s must be a Symbol, or a tuple or list of Symbol" % (name, )) + return inputs + + def _create_subgraph(graph_vars, graph_func, subgraph_name): + with AttrScope(__subgraph_name__=subgraph_name): + # create new variables with the same name, + # them feed them to the given func + new_graph_vars = [symbol.var(sym.name) for sym in graph_vars] + outputs = graph_func(*new_graph_vars) + outputs = _to_symbol_tuple(outputs, "outputs") + num_outputs = len(outputs) + # nnvm cut-graph does not allow inputs and outputs overlap + # so we calculate the name of inputs, and copy outputs once it overlaps with inputs + all_input_names = symbol.Group(outputs).list_inputs() + make_identity = lambda x: symbol.op.identity(x) if x.name in all_input_names else x + # group all outputs of graph_func + graph = symbol.Group(list(map(make_identity, outputs))) + return graph, num_outputs + + def _union_inputs(*graphs): + # Given a list of graphs, each whose inputs are either from input_vars or other variables. + # 1) calculate a list `inputs`, the union of their inputs. + # 2) for each graph, determine in which indices their inputs reside in `inputs` + # 3) for each variable in the input of `graph`, find which index it is + inputs = [] # List[Symbol], result of 1) + locs = [] # List[Tuple(List[Int], List[Int])], a list of tuples, + # where tuples are results of 2) and 3) + input_id_to_loc = {} # Dict[int, int], given id(sym), input_id_to_loc maps it + # to a `loc`, where inputs[loc] = sym + for graph in graphs: + # input_syms: all inputs to the `graph` + name_to_input_syms = {sym.name: sym for sym in _get_graph_inputs(graph)} + # some input_vars are inputs to `graph`, some are not + name_to_input_vars = {sym.name: sym for sym in inputs} + # other inputs to `graph` created by cut_graph + name_to_cut_g_syms = {sym.list_outputs()[0]: sym for sym in _cut_subgraph(graph)} + # collect arguments for each subgraph + input_locs = [] # results from the second step + for name in graph.list_inputs(): + assert name in name_to_input_syms # it should obviously hold + # name -> sym + if name in name_to_input_vars: + sym = name_to_input_vars[name] + elif name in name_to_cut_g_syms: + sym = name_to_cut_g_syms[name] + else: + sym = copy.deepcopy(name_to_input_syms[name]) + # do 2), and 1) is implicitly done + if id(sym) in input_id_to_loc: + loc = input_id_to_loc[id(sym)] + else: + loc = len(input_id_to_loc) + inputs.append(sym) + input_id_to_loc[id(sym)] = loc + input_locs.append(loc) + locs.append(input_locs) + return inputs, locs + inputs = _to_symbol_tuple(inputs, "inputs") + if len(inputs) == 0: + raise ValueError("loop_vars should contain at least one element") + # create graph for `cond' + cond_g, num_outputs = _create_subgraph(inputs, cond, name + "_cond") + if num_outputs != 1: + raise ValueError("cond should always produce a single output") + # create graph for `then` + then_g, then_num_outputs = _create_subgraph(inputs, then_func, name + "_then") + # create graph for `else` + else_g, else_num_outputs = _create_subgraph(inputs, else_func, name + "_else") + if then_num_outputs != else_num_outputs: + raise ValueError("Number of outputs differs between then-branch and else-branch") + # find symbols used in either cond_g or func_g + input_syms, (cond_input_locs, then_input_locs, else_input_locs) = \ + _union_inputs(cond_g, then_g, else_g) + result = symbol._internal._ifelse( + # [cond, then_g, else_g, *input_syms] + cond_g, + then_g, + else_g, + *input_syms, + cond_input_locs=cond_input_locs, + then_input_locs=then_input_locs, + else_input_locs=else_input_locs, + num_outputs=then_num_outputs + ) + result = _to_symbol_tuple(result, "result") + return list(result) diff --git a/src/operator/control_flow.cc b/src/operator/control_flow.cc index b00ed9b19d8c..261bd5070f7d 100644 --- a/src/operator/control_flow.cc +++ b/src/operator/control_flow.cc @@ -508,6 +508,18 @@ struct WhileLoopParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(func_var_locs) .describe("The locations of loop_vars among func's inputs."); } + template + bool sync_in_out(std::vector *in, + std::vector *out, + std::function is_empty) const { + for (int i = this->num_out_data; i < this->num_outputs; ++i) { + // each out->at(i) is a params, loop_var + T &x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]); + T &y = out->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; + } }; // struct WhileLoopParam DMLC_REGISTER_PARAMETER(WhileLoopParam); @@ -540,84 +552,8 @@ class WhileLoopState: public LoopState { } } } - template - static void extract_by_loc(const std::vector &array, - const nnvm::Tuple input_locs, - std::vector *out) { - out->clear(); - out->reserve(input_locs.ndim()); - for (dim_t i : input_locs) { - out->push_back(array[i]); - } - } - static bool is_shape_udf(const TShape &x) { - return x.ndim() == 0 || x.Size() == 0; - } - static bool is_stype_udf(const int &x) { - return x == exec::kBadStorageID; - } - static bool is_type_udf(const int &x) { - return x == -1; - } - template - static bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { - if (*x == *y || (x_empty && y_empty)) { - return true; - } - if (!x_empty && !y_empty) { - return false; - } - if (x_empty) { - *x = *y; - } - if (y_empty) { - *y = *x; - } - return true; - } - template - static bool sync_in_in(const nnvm::Tuple &input_locs, - std::vector *in, - std::vector *subg_in, - std::function is_empty) { - for (size_t i = 0; i < input_locs.ndim(); ++i) { - T &x = in->at(input_locs[i]); - T &y = subg_in->at(i); - fill_value(&x, &y, is_empty(x), is_empty(y)); - } - return true; - } - template - static bool sync_in_out(const WhileLoopParam& params, - std::vector *in, - std::vector *out, - std::function is_empty) { - for (int i = params.num_out_data; i < params.num_outputs; ++i) { - // each out->at(i) is a params, loop_var - T &x = in->at(params.func_input_locs[params.func_var_locs[i - params.num_out_data]]); - T &y = out->at(i); - fill_value(&x, &y, is_empty(x), is_empty(y)); - } - return true; - } }; -template -T _asscalar(const NDArray &a) { - CHECK_EQ(a.shape().Size(), 1U); - T data; - a.SyncCopyToCPU(&data, 1U); - return data; -} - -bool as_bool_scalar(const NDArray &a) { - MSHADOW_TYPE_SWITCH(a.dtype(), DType, { - return static_cast(_asscalar(a)); - }); - LOG(FATAL) << "Unknown dtype"; - return false; -} - static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, const OpContext& ctx, const std::vector& inputs, @@ -648,13 +584,13 @@ static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr, CHECK_EQ(params.max_iterations, outputs[i].shape()[0]); // construct inputs and outputs for cond std::vector cond_inputs, cond_outputs = {NDArray()}; - WhileLoopState::extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); std::vector cond_input_ptr, cond_output_ptr; to_ptr_vec(cond_inputs, &cond_input_ptr); to_ptr_vec(cond_outputs, &cond_output_ptr); // construct inputs and outputs for func std::vector func_inputs, func_outputs(outputs.size()); - WhileLoopState::extract_by_loc(inputs, params.func_input_locs, &func_inputs); + extract_by_loc(inputs, params.func_input_locs, &func_inputs); for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) { state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); if (!as_bool_scalar(*cond_output_ptr[0])) { @@ -716,8 +652,8 @@ static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr, } std::vector outputs; std::vector req; - WhileLoopState::extract_by_loc(_outputs, params.func_input_locs, &outputs); - WhileLoopState::extract_by_loc(_req, params.func_input_locs, &req); + extract_by_loc(_outputs, params.func_input_locs, &outputs); + extract_by_loc(_req, params.func_input_locs, &req); if (state.n_iterations == 0) { for (int i = params.num_out_data; i < params.num_outputs; ++i) { int j = params.func_var_locs[i - params.num_out_data]; @@ -796,7 +732,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, std::vector *out_shape) { using nnvm::ShapeVector; const WhileLoopParam& params = nnvm::get(attrs.parsed); - static const std::function is_udf = WhileLoopState::is_shape_udf; + static const std::function is_udf = is_shape_udf; // sanity checks CHECK_EQ(in_shape->size() + 2U, (size_t) params.num_args); CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); @@ -811,7 +747,7 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, // create subg_in ShapeVector subg_in; ShapeVector &subg_out = *_subg_out; - WhileLoopState::extract_by_loc(*in_shape, input_locs, &subg_in); + extract_by_loc(*in_shape, input_locs, &subg_in); // create an indexed graph nnvm::Graph g; g.outputs = subg->outputs; @@ -884,35 +820,35 @@ static bool WhileLoopShape(const nnvm::NodeAttrs& attrs, }; ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] ShapeVector func_out_shape(params.num_outputs); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, params.cond_input_locs, 0, false); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); bool succ_1 = infer_subg(attrs.subgraphs[1], &func_out_shape, \ params.func_input_locs, params.num_out_data, true); - CHECK(WhileLoopState::sync_in_out(params, in_shape, out_shape, is_udf)); + CHECK(params.sync_in_out(in_shape, out_shape, is_udf)); return succ_0 && succ_1; } static bool WhileLoopType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { const WhileLoopParam& params = nnvm::get(attrs.parsed); - static const std::function is_udf = WhileLoopState::is_type_udf; + static const std::function is_udf = is_type_udf; CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args); CHECK_EQ(out_type->size(), (size_t) params.num_outputs); CHECK_EQ(attrs.subgraphs.size(), 2U); CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); std::vector cond_in_type; std::vector func_in_type; - WhileLoopState::extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); - WhileLoopState::extract_by_loc(*in_type, params.func_input_locs, &func_in_type); + extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + extract_by_loc(*in_type, params.func_input_locs, &func_in_type); std::vector cond_out_type = {0}; - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); + CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type); - CHECK(WhileLoopState::sync_in_out(params, in_type, out_type, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); + CHECK(params.sync_in_out(in_type, out_type, is_udf)); + CHECK(sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf)); return succ_0 && succ_1; } @@ -922,28 +858,28 @@ static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { const WhileLoopParam& params = nnvm::get(attrs.parsed); - static const std::function is_udf = WhileLoopState::is_stype_udf; + static const std::function is_udf = is_stype_udf; CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args); CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); CHECK_EQ(attrs.subgraphs.size(), 2U); CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); std::vector cond_in_attrs; std::vector func_in_attrs; - WhileLoopState::extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); - WhileLoopState::extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); + extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs); std::vector cond_out_attrs = {kDefaultStorage}; DispatchMode cond_mode = DispatchMode::kUndefined; DispatchMode func_mode = DispatchMode::kUndefined; *dispatch_mode = DispatchMode::kFComputeEx; - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ &cond_mode, &cond_in_attrs, &cond_out_attrs); - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); + CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ &func_mode, &func_in_attrs, out_attrs); - CHECK(WhileLoopState::sync_in_out(params, in_attrs, out_attrs, is_udf)); - CHECK(WhileLoopState::sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); + CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf)); + CHECK(sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf)); return succ_0 && succ_1; } @@ -977,6 +913,342 @@ WhileLoopGradient(const nnvm::NodePtr& n, const std::vector& og return entries; } +struct IfelseParam : public dmlc::Parameter { + int num_args; + int num_outputs; + nnvm::Tuple cond_input_locs; + nnvm::Tuple then_input_locs; + nnvm::Tuple else_input_locs; + DMLC_DECLARE_PARAMETER(IfelseParam) { + DMLC_DECLARE_FIELD(num_args).set_lower_bound(3) + .describe("Number of input arguments, including cond, then and else as three symbol inputs."); + DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1) + .describe("The number of outputs of the subgraph."); + DMLC_DECLARE_FIELD(cond_input_locs) + .describe("The locations of cond's inputs in the given inputs."); + DMLC_DECLARE_FIELD(then_input_locs) + .describe("The locations of then's inputs in the given inputs."); + DMLC_DECLARE_FIELD(else_input_locs) + .describe("The locations of else's inputs in the given inputs."); + } +}; // struct IfelseParam + +DMLC_REGISTER_PARAMETER(IfelseParam); + +class IfelseState { + public: + IfelseParam params; + CachedOpPtr cond_op; + LoopState then_branch; + LoopState else_branch; + int branch_selection; // 1 if then branch; 0 if else branch; -1 if undefined + + IfelseState(const IfelseParam ¶ms, + const Symbol &cond, + const Symbol &then_sym, + const Symbol &else_sym): + params(params), + cond_op(LoopState::MakeSharedOp(cond)), + then_branch(then_sym), + else_branch(else_sym), + branch_selection(-1) { + } +}; + +static void IfelseComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // The argument `inputs' are loop_vars and other inputs + // loop_vars are stored in stored in `loop_vars_locs' + // The argument `outputs' are output and new_loop_vars + // [0: num_out_data) are outputs at each step. + // [num_out_data: ) are new_loop_vars + IfelseState &state = state_ptr.get_state(); + const IfelseParam& params = state.params; + // a helper function, converting std::vector to std::vector + const auto to_ptr_vec = [](std::vector &in, std::vector *out) { + out->clear(); + out->reserve(in.size()); + std::transform(std::begin(in), + std::end(in), + std::back_inserter(*out), + [](NDArray &a) {return &a;}); + }; + // sanity checks + CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), (size_t) params.num_outputs); + CHECK_EQ(outputs.size(), req.size()); + // construct inputs and outputs for cond + std::vector cond_inputs; + std::vector cond_outputs = {NDArray()}; + std::vector cond_input_ptr; + std::vector cond_output_ptr; + extract_by_loc(inputs, params.cond_input_locs, &cond_inputs); + to_ptr_vec(cond_inputs, &cond_input_ptr); + to_ptr_vec(cond_outputs, &cond_output_ptr); + int &branch_selection = state.branch_selection; + // run cond + state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr); + branch_selection = as_bool_scalar(*cond_output_ptr[0]); + // select the right branch + const nnvm::Tuple &func_input_locs = branch_selection + ? params.then_input_locs + : params.else_input_locs; + LoopState &loop_state = branch_selection + ? state.then_branch + : state.else_branch; + // extract inputs for the branch + std::vector func_inputs; + extract_by_loc(inputs, func_input_locs, &func_inputs); + loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad); +} + +static void IfelseGradComputeExCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& _req, + const std::vector& outputs) { + IfelseState &state = state_ptr.get_state(); + const IfelseParam& params = state.params; + // sanity checks + CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args); + CHECK_EQ(outputs.size(), _req.size()); + // select the right branch + int branch_selection = state.branch_selection; + CHECK_NE(branch_selection, -1); + const nnvm::Tuple &func_input_locs = branch_selection + ? params.then_input_locs + : params.else_input_locs; + LoopState &loop_state = branch_selection + ? state.then_branch + : state.else_branch; + // construct parameters + std::vector ograds(inputs.begin(), inputs.begin() + params.num_outputs); + std::vector req; + extract_by_loc(_req, func_input_locs, &req); + std::vector igrads; + extract_by_loc(outputs, func_input_locs, &igrads); + loop_state.Backward(0, ograds, req, igrads); + loop_state.Cleanup(); +} + +static bool IfelseShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using nnvm::ShapeVector; + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_shape_udf; + // sanity checks + CHECK_EQ(in_shape->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_shape->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + // infer shape for cond, then and else + auto infer_subg = [¶ms, in_shape, out_shape](std::shared_ptr subg, + ShapeVector *_subg_out, + const nnvm::Tuple &input_locs, + bool fill_out_shape) { + // create subg_in + ShapeVector subg_in; + ShapeVector &subg_out = *_subg_out; + extract_by_loc(*in_shape, input_locs, &subg_in); + // create an indexed graph + nnvm::Graph g; + g.outputs = subg->outputs; + const auto& idx = g.indexed_graph(); + // get input nodes + const auto &input_nids = idx.input_nodes(); + // sanity checks + CHECK_EQ(input_nids.size(), subg_in.size()); + CHECK_EQ(g.outputs.size(), subg_out.size()); + CHECK_EQ(idx.input_nodes().size(), subg_in.size()); + CHECK_EQ(idx.outputs().size(), subg_out.size()); + // create empty shapes for inference + ShapeVector shapes(idx.num_node_entries()); + // copy subg_in into shapes + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + shapes[eid] = subg_in[i]; + } + // copy subg_out into shapes + for (size_t i = 0; i < subg_out.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + shapes[eid] = subg_out[i]; + } + // copy done, call InferShape + g.attrs["shape"] = std::make_shared(std::move(shapes)); + g = exec::InferShape(std::move(g)); + // now `shapes' won't be used anymore, use new_shapes instead + const auto& new_shapes = g.GetAttr("shape"); + // copy subg_in back to in_shape + for (size_t i = 0; i < subg_in.size(); ++i) { + auto eid = idx.entry_id(input_nids[i], 0); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*in_shape, input_locs[i], g_out_shape); + } + if (!fill_out_shape) { + return true; + } + // copy subg_out back to out_shape + for (size_t i = 0; i < g.outputs.size(); ++i) { + auto eid = idx.entry_id(g.outputs[i]); + auto g_out_shape = new_shapes[eid]; + if (g_out_shape.ndim() == 0 || g_out_shape.Size() == 0) { + // when the shape is not fully inferred + continue; + } + SHAPE_ASSIGN_CHECK(*out_shape, i, g_out_shape); + } + return g.GetAttr("shape_num_unknown_nodes") == 0; + }; + ShapeVector cond_out_shape{TShape(1U)}; // this means: [(1, )] + ShapeVector then_out_shape(params.num_outputs); + ShapeVector else_out_shape(params.num_outputs); + bool succ_0 = infer_subg(attrs.subgraphs[0], &cond_out_shape, \ + params.cond_input_locs, false); + bool succ_1 = infer_subg(attrs.subgraphs[1], &then_out_shape, \ + params.then_input_locs, true); + bool succ_2 = infer_subg(attrs.subgraphs[2], &else_out_shape, \ + params.else_input_locs, true); + return succ_0 && succ_1 && succ_2; +} + +static bool IfelseType(const nnvm::NodeAttrs& attrs, + std::vector *in_type, + std::vector *out_type) { + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_type_udf; + CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_type->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + std::vector cond_in_type; + std::vector then_in_type; + std::vector else_in_type; + extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type); + extract_by_loc(*in_type, params.then_input_locs, &then_in_type); + extract_by_loc(*in_type, params.else_input_locs, &else_in_type); + std::vector cond_out_type = {0}; + bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type); + CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf)); + bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type); + CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf)); + bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type); + CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf)); + return succ_0 && succ_1 && succ_2; +} + +static bool IfelseStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + static const std::function is_udf = is_stype_udf; + CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args); + CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs); + CHECK_EQ(attrs.subgraphs.size(), 3U); + CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U); + CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size()); + std::vector cond_in_attrs; + std::vector then_in_attrs; + std::vector else_in_attrs; + extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs); + extract_by_loc(*in_attrs, params.then_input_locs, &then_in_attrs); + extract_by_loc(*in_attrs, params.else_input_locs, &else_in_attrs); + std::vector cond_out_attrs = {kDefaultStorage}; + DispatchMode cond_mode = DispatchMode::kUndefined; + DispatchMode then_mode = DispatchMode::kUndefined; + DispatchMode else_mode = DispatchMode::kUndefined; + *dispatch_mode = DispatchMode::kFComputeEx; + bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \ + &cond_mode, &cond_in_attrs, &cond_out_attrs); + CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf)); + bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \ + &then_mode, &then_in_attrs, out_attrs); + CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf)); + bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \ + &else_mode, &else_in_attrs, out_attrs); + CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf)); + return succ_0 && succ_1 && succ_2; +} + +static bool BackwardIfelseStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args); + CHECK_EQ(attrs.subgraphs.size(), 3U); + static const std::function is_udf = is_stype_udf; + auto sub_pass = [&](const std::shared_ptr &subg, const nnvm::Tuple &input_locs) { + // A. first construct subg_in_attrs + // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy) + std::vector subg_in_attrs; + size_t num_elts = params.num_outputs * 2 + input_locs.ndim(); + subg_in_attrs.reserve(num_elts); + // part 1. subg_bwd_out (copy) + subg_in_attrs.insert(subg_in_attrs.end(), + in_attrs->begin(), + in_attrs->begin() + params.num_outputs); + // part 2. subg_fwd_in (extract) + std::vector fwd_in(in_attrs->begin() + params.num_outputs, + in_attrs->begin() + params.num_outputs + params.num_args - 3); + std::vector subg_fwd_in; + extract_by_loc(fwd_in, input_locs, &subg_fwd_in); + subg_in_attrs.insert(subg_in_attrs.end(), + subg_fwd_in.begin(), + subg_fwd_in.end()); + // part 3. subg_fwd_out (copy) + subg_in_attrs.insert(subg_in_attrs.end(), + in_attrs->begin() + params.num_outputs + params.num_args - 3, + in_attrs->end()); + // check correctness of the number of elements + CHECK_EQ(subg_in_attrs.size(), num_elts); + // B. then we construct subg_out_attrs by extracting from out_attrs + std::vector subg_out_attrs; + extract_by_loc(*out_attrs, input_locs, &subg_out_attrs); + // then we construct the subgraph and do inference + CachedOp op(*subg, {}); + bool ret = op.BackwardStorageType(attrs, dev_mask, dispatch_mode, \ + &subg_in_attrs, &subg_out_attrs); + CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf)); + return ret; + }; + bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs); + bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs); + return succ_0 && succ_1; +} + +static OpStatePtr CreateIfelseState(const NodeAttrs& attrs, + Context ctx, + const std::vector& ishape, + const std::vector& itype) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return OpStatePtr::Create( + params, + *attrs.subgraphs[0], + *attrs.subgraphs[1], + *attrs.subgraphs[2]); +} + +static std::vector +IfelseGradient(const nnvm::NodePtr& n, const std::vector& ograds) { + ElemwiseGradUseInOut fgrad{"_backward_ifelse"}; + std::vector entries = fgrad(n, ograds); + entries[0].node->attrs.subgraphs = n->attrs.subgraphs; + return entries; +} + NNVM_REGISTER_OP(_foreach) .MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation") .set_attr_parser(ParamParser) @@ -1100,5 +1372,68 @@ NNVM_REGISTER_OP(_backward_while_loop) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU) .set_attr("FStatefulComputeEx", WhileLoopGradComputeExCPU); +NNVM_REGISTER_OP(_ifelse) +.MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation") +.set_attr_parser(ParamParser) +.set_attr("FInferStorageType", IfelseStorageType) +.set_num_inputs([](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs([](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_outputs; +}) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const IfelseParam& params = nnvm::get(attrs.parsed); + std::vector names; + names.reserve(params.num_args); + names.push_back("cond"); + names.push_back("then_branch"); + names.push_back("else_branch"); + for (int i = 3; i < params.num_args; ++i) + names.push_back("data" + std::to_string(i - 3)); + return names; +}) +.set_attr("FInputGraph", + [](const NodeAttrs& attrs) { + return std::vector{0, 1, 2}; +}) +.set_attr("FGradient", IfelseGradient) +.set_attr("FCreateOpState", CreateIfelseState) +.set_attr("FInferShape", IfelseShape) +.set_attr("FInferType", IfelseType) +.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FStatefulComputeEx", IfelseComputeExCPU) +.set_attr("key_var_num_args", "num_args") +.add_argument("cond", "Symbol", "Input graph for the condition.") +.add_argument("then_branch", "Symbol", "Input graph for the then branch.") +.add_argument("else_branch", "Symbol", "Input graph for the else branch.") +.add_argument("data", "NDArray-or-Symbol[]", + "The input arrays that include data arrays and states.") +.add_arguments(IfelseParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_ifelse) +.set_num_inputs([](const NodeAttrs& attrs){ + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_outputs * 2 + params.num_args - 3; +}) +.set_num_outputs([](const NodeAttrs& attrs){ + const IfelseParam& params = nnvm::get(attrs.parsed); + return params.num_args - 3; +}) +.set_attr("FExecType", [](const NodeAttrs& attrs) { + return ExecType::kSubgraphExec; +}) +.set_attr("FInferStorageType", BackwardIfelseStorageType) +.set_attr_parser(ParamParser) +.set_attr("TIsLayerOpBackward", true) +.set_attr("TIsBackward", true) +.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU) +.set_attr("FStatefulComputeEx", IfelseGradComputeExCPU); } // namespace op } // namespace mxnet diff --git a/src/operator/subgraph_op_common.cc b/src/operator/subgraph_op_common.cc index d845aa907d33..7a99aedb8602 100644 --- a/src/operator/subgraph_op_common.cc +++ b/src/operator/subgraph_op_common.cc @@ -161,6 +161,34 @@ bool InferSubgraphShape(const nnvm::Symbol &subgraph, return g.GetAttr("shape_num_unknown_nodes") == 0; } +template +T _asscalar(const NDArray &a) { + CHECK_EQ(a.shape().Size(), 1U); + T data; + a.SyncCopyToCPU(&data, 1U); + return data; +} + +bool as_bool_scalar(const NDArray &a) { + MSHADOW_TYPE_SWITCH(a.dtype(), DType, { + return static_cast(_asscalar(a)); + }); + LOG(FATAL) << "Unknown dtype"; + return false; +} + +bool is_shape_udf(const TShape &x) { + return x.ndim() == 0 || x.Size() == 0; +} + +bool is_stype_udf(const int &x) { + return x == exec::kBadStorageID; +} + +bool is_type_udf(const int &x) { + return x == -1; +} + LoopState::LoopState(const Symbol &g) { this->subgraph_sym = g; this->subgraph.outputs = g.outputs; diff --git a/src/operator/subgraph_op_common.h b/src/operator/subgraph_op_common.h index f73f09cd5c85..24983ae34632 100644 --- a/src/operator/subgraph_op_common.h +++ b/src/operator/subgraph_op_common.h @@ -57,6 +57,55 @@ bool InferSubgraphStorage(const nnvm::Symbol &subgraph, std::vector *in_attrs, std::vector *out_attrs); +bool as_bool_scalar(const NDArray &a); + +bool is_shape_udf(const TShape &x); + +bool is_stype_udf(const int &x); + +bool is_type_udf(const int &x); + +template +void extract_by_loc(const std::vector &array, + const nnvm::Tuple input_locs, + std::vector *out) { + out->clear(); + out->reserve(input_locs.ndim()); + for (dim_t i : input_locs) { + out->push_back(array[i]); + } +} + +template +bool fill_value(T *x, T *y, bool x_empty, bool y_empty) { + if (*x == *y || (x_empty && y_empty)) { + return true; + } + if (!x_empty && !y_empty) { + return false; + } + if (x_empty) { + *x = *y; + } + if (y_empty) { + *y = *x; + } + return true; +} + +template +bool sync_in_in(const nnvm::Tuple &input_locs, + std::vector *in, + std::vector *subg_in, + std::function is_empty) { + for (size_t i = 0; i < input_locs.ndim(); ++i) { + T &x = in->at(input_locs[i]); + T &y = subg_in->at(i); + fill_value(&x, &y, is_empty(x), is_empty(y)); + } + return true; +} + /* * This contains the states for running a loop and provides methods * of running the subgraph computation for an iteration. diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index 9dd5c4397bee..12694572bb7c 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -15,17 +15,16 @@ # specific language governing permissions and limitations # under the License. +import numpy as np import mxnet as mx from mxnet import gluon -import numpy as np -import copy -from numpy.testing import assert_allclose -import unittest -from mxnet.test_utils import almost_equal, default_context -from numpy.testing import assert_allclose as assert_almost_equal # This is more restrictive +from numpy.testing import assert_allclose, assert_array_equal +from mxnet.test_utils import * from mxnet.base import _as_list +from common import with_seed +@with_seed() def test_while_loop_simple_forward(): class _TestBlock(gluon.HybridBlock): @@ -244,13 +243,14 @@ def _zeros_like_dict(name_list): assert_almost_equal(imp_grad, sym_grad, rtol=1e-4, atol=1e-4) +@with_seed() def test_while_loop_for_foreach(): def make_true_cond(): - return lambda loop_vars, _: (loop_vars[0] < 1e200).prod() + return lambda loop_vars, _: (loop_vars[0] < 1e35).prod() def make_false_cond(): - return lambda loop_vars, _: (loop_vars[0] > 1e200).prod() + return lambda loop_vars, _: (loop_vars[0] > 1e35).prod() def make_for_cond(length): return lambda loop_vars, _: loop_vars[0] < length @@ -613,8 +613,8 @@ def step(loop, free): (1, ), # a (1, ), # b ], - max_iterations=23, - n_steps=23, + max_iterations=5, + n_steps=5, ) # Case 1.2.* case_1( @@ -626,8 +626,8 @@ def step(loop, free): (2, 3, 4), # a (2, 3, 4), # b ], - max_iterations=31, - n_steps=31, + max_iterations=3, + n_steps=3, ) # Case 1.3.* case_1( @@ -644,7 +644,7 @@ def step(loop, free): ) # Case 2.1.* case_2( - cond=make_for_cond(length=31), + cond=make_for_cond(length=5), loop_var_shapes=[ (1, ), # i (2, ), # s @@ -654,11 +654,11 @@ def step(loop, free): (2, ), # f_1 (3, 4, 5, 6), # f_2, unused ], - n_steps=31, + n_steps=5, ) # Case 2.2.* case_2( - cond=make_for_cond(length=25), + cond=make_for_cond(length=3), loop_var_shapes=[ (1, ), # i (2, ), # s @@ -668,12 +668,12 @@ def step(loop, free): (2, ), # f_1 (3, 4, 5, 6), # f_2, unused ], - n_steps=25, + n_steps=3, ) # Case 3.* case_3( - length=11, - cond=make_for_cond(length=11), + length=5, + cond=make_for_cond(length=5), loop_var_shapes=[ (1, ), # i (2, ), # s_0 @@ -685,7 +685,7 @@ def step(loop, free): (2, ), # f_0 (3, 4, 5, 6), # f_1, unused ], - n_steps=11, + n_steps=5, ) # Case 4.1.* case_4( @@ -784,6 +784,7 @@ def step(loop, free): ) +@with_seed() def test_while_loop_nested(): def _to_np_list(arrays): @@ -891,6 +892,7 @@ def _get_sym_result(is_train, args, args_grad, out_grad): assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) +@with_seed() def test_while_loop_rnn(): def _array(shape): return mx.nd.random.uniform(-1.0, 1.0, shape=shape) @@ -972,6 +974,162 @@ def _func(*states): y = y.asnumpy() assert_almost_equal(x, y, rtol=1e-4, atol=1e-4) +def _verify_ifelse(cond, then_func, else_func, input_var_shapes, free_var_shapes, is_train): + + def _create_symbol(prefix, i): + return mx.sym.var(prefix + str(i)) + + def _create_array(shape): + return mx.nd.random.uniform(-1.0, 1.0, shape=shape) + + def _to_numpy_list(arrays): + return [x.asnumpy() if x is not None else x for x in arrays] + + def _merge_dict(*dicts): + result = {} + for item in dicts: + result.update(item) + return result + + _input_syms = [_create_symbol("InputVar", i) for i, _ in enumerate(input_var_shapes)] + _free_syms = [_create_symbol("FreeVar", i) for i, _ in enumerate(free_var_shapes)] + _input_vars = [_create_array(x) for x in input_var_shapes] + _free_vars = [_create_array(x) for x in free_var_shapes] + _args_dict = _merge_dict( + {"InputVar" + str(i): x for i, x in enumerate(_input_vars)}, + {"FreeVar" + str(i): x for i, x in enumerate(_free_vars)}, + ) + + def _get_imperative_result(): + free_vars = [x.copy() for x in _free_vars] + input_vars = [x.copy() for x in _input_vars] + out_grads = [] + if is_train: + for var in free_vars + input_vars: + var.attach_grad() + with mx.autograd.record(train_mode=is_train): + outputs = mx.nd.contrib.ifelse( + cond=lambda *__input_vars: cond(__input_vars, free_vars), + then_func=lambda *__input_vars: then_func(__input_vars, free_vars), + else_func=lambda *__input_vars: else_func(__input_vars, free_vars), + inputs=input_vars, + ) + outputs = [x * 2 for x in outputs] + grads = [] + if is_train: + out_grads = [_create_array(x.shape) for x in outputs] + cat_out = mx.nd.concat(*[x.reshape(-1) for x in outputs], dim=0) + cat_out.backward(out_grad=mx.nd.concat(*[x.reshape(-1) for x in out_grads], dim=0)) + grads = [free_vars[i].grad for i, _ in enumerate(free_var_shapes)] \ + + [input_vars[i].grad for i, _ in enumerate(input_var_shapes)] + return _to_numpy_list(outputs), _to_numpy_list(grads), out_grads + + def _get_symbolic_result(out_grads): + outputs_sym = mx.sym.contrib.ifelse( + cond=lambda *__loop_vars: cond(__loop_vars, _free_syms), + then_func=lambda *__loop_vars: then_func(__loop_vars, _free_syms), + else_func=lambda *__loop_vars: else_func(__loop_vars, _free_syms), + inputs=_input_syms, + ) + outputs_sym = [x * 2 for x in outputs_sym] + outputs_sym = mx.sym.Group(outputs_sym) + executor = outputs_sym.bind( + ctx=default_context(), + args={name: _args_dict[name].copy() for name in outputs_sym.list_inputs()}, + args_grad=None if not is_train else _merge_dict( + {"InputVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(input_var_shapes)}, + {"FreeVar" + str(i): mx.nd.zeros(s) for i, s in enumerate(free_var_shapes)}, + ), + ) + outputs = executor.forward(is_train=is_train) + grads = [] + if is_train: + executor.backward(out_grads=out_grads) + grads = [executor.grad_dict.get("FreeVar" + str(i), None) for i, _ in enumerate(free_var_shapes)] \ + + [executor.grad_dict.get("InputVar" + str(i), None) for i, _ in enumerate(input_var_shapes)] + return _to_numpy_list(outputs), _to_numpy_list(grads) + + imp_outs, imp_grads, out_grads = _get_imperative_result() + sym_outs, sym_grads = _get_symbolic_result(out_grads) + for imp_out, sym_out in zip(imp_outs, sym_outs): + if imp_out is None or sym_out is None: + continue + assert_almost_equal(imp_out, sym_out, rtol=1e-5, atol=1e-5) + for imp_grad, sym_grad in zip(imp_grads, sym_grads): + if imp_grad is None or sym_grad is None: + continue + assert_almost_equal(imp_grad, sym_grad, rtol=1e-5, atol=1e-5) + + +@with_seed() +def test_ifelse(): + # whether there are free variables in three graphs + # whether these three graphs contain input_vars + # whether to use all input_vars + # which branch to choose + def run_case(cond_func, then_func, else_func, **params): + def make_cond(is_inverse): + def cond(inputs, free): + x = cond_func(inputs, free) + if is_inverse: + if isinstance(x, mx.sym.Symbol): + return mx.sym.logical_not(x) + else: + return mx.nd.logical_not(x) + return x + return cond + for is_train in [True, False]: + for is_inverse in [False, True]: + _verify_ifelse( + cond=make_cond(is_inverse), + then_func=then_func, + else_func=else_func, + is_train=is_train, + **params + ) + # Each function can + # 1. use_free_vars or not: T/F + # 2. use_input_vars or not: T/F + # 3. use_all_input_vars or not: T/F + # (a, b, c) are inputs, (d, e, f) are free_vars + cond_funcs = [ + lambda a, b, c, d, e, f: (a * b).sum() < 0.5, # F, T, F + lambda a, b, c, d, e, f: (a + b + c).sum() < 0.5, # F, T, T + lambda a, b, c, d, e, f: (d + e).sum() < 0.5, # T, F, F + lambda a, b, c, d, e, f: (d + e * a).sum() < 0.5, # T, T, F + lambda a, b, c, d, e, f: (d + e * a + b * c).sum() < 0.5, # T, T, T + ] + body_funcs = [ + lambda a, b, c, d, e, f: a * b, # F, T, F + lambda a, b, c, d, e, f: a * b * c, # F, T, T + lambda a, b, c, d, e, f: d * e, # T, F, F + lambda a, b, c, d, e, f: d * e * a, # T, T, F + lambda a, b, c, d, e, f: d * e * a * b * c, # T, T, T + # some extra tests + lambda a, b, c, d, e, f: b * c, + lambda a, b, c, d, e, f: a * c, + lambda a, b, c, d, e, f: (a + b) * c, + lambda a, b, c, d, e, f: c * (b - a), + ] + # enumerate all kinds of possible combinations + for cond_func in cond_funcs: + for then_func in body_funcs: + for else_func in body_funcs: + run_case( + cond_func=lambda x, y: cond_func(x[0], x[1], x[2], y[0], y[1], y[2]), + then_func=lambda x, y: then_func(x[0], x[1], x[2], y[0], y[1], y[2]), + else_func=lambda x, y: else_func(x[0], x[1], x[2], y[0], y[1], y[2]), + input_var_shapes=[ + (2, 3), + (2, 3), + (2, 3), + ], + free_var_shapes=[ + (2, 3), + (2, 3), + (2, 3), + ] + ) if __name__ == '__main__': import nose