diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index cc0b58df3f1b..6f9308b12cea 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -714,61 +714,32 @@ def create_op_by_mode(mode):
     return fused_op, stack_op, recurrent_block_prefix
 
 
-def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
-
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-    stack_layer.initialize()
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
     fused_layer.initialize()
 
     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l1_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l2_'))
+        for n in range(num_layers):
+            stack_layer.add(stack_op(hidden_size, prefix="l{}_".format(n)))
     stack_layer.initialize()
-
     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
 
 
-def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
 
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-    stack_layer.initialize()
-
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
     fused_layer.initialize()
 
     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'),
-                                                    stack_op(hidden_size, prefix='r1_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'),
-                                                    stack_op(hidden_size, prefix='r2_')))
-    stack_layer.initialize()
-
+        for n in range(num_layers):
+            stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix="l{}_".format(n)),
+                                                        stack_op(hidden_size, prefix="r{}_".format(n))))
+    stack_layer.initialize()
 
     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
 
 
@@ -777,10 +748,11 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
 def test_fused_lstm_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -788,10 +760,11 @@ def test_fused_lstm_layer():
 def test_fused_gru_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -799,10 +772,11 @@ def test_fused_gru_layer():
 def test_fused_rnnrelu_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
 
 
 @with_seed()
@@ -810,10 +784,11 @@ def test_fused_rnnrelu_layer():
 def test_fused_rnntanh_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
 
 
 def test_rnn_unroll_variant_length():
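Note (not part of the patch): a minimal sketch of how the new num_layers dimension expands the test matrix in each test_fused_*_layer test. It uses only the standard library, and the value lists are copied from the updated tests above.

    from itertools import product

    input_sizes = [8]
    hidden_sizes = [8, 16]
    num_layers = [1, 2, 3, 4]

    # Each fused-layer test now runs both the unidirectional and bidirectional
    # consistency checks once per (input_size, hidden_size, num_layers) combination.
    combos = list(product(input_sizes, hidden_sizes, num_layers))
    print(len(combos))  # 8 combinations per mode (lstm, gru, rnn_relu, rnn_tanh)
    print(combos[:3])   # [(8, 8, 1), (8, 8, 2), (8, 8, 3)]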