diff --git a/src/nodes/lstm.cc b/src/nodes/lstm.cc index 6a89d9d..5105e9a 100644 --- a/src/nodes/lstm.cc +++ b/src/nodes/lstm.cc @@ -102,6 +102,9 @@ void LSTM::print_activation(std::ostream &dst, const std::string &activation, co * The code is almost identical for forward and backwards nodes */ void LSTM::print_lstm_kernel(std::ostream &dst, bool forward) const { + const Tensor* B = get_B(); + const Tensor* P = get_P(); + int dir; // direction index into tensors that separate forward and backward (W,B,Y,...) int f_act; // indexes for the activation functions in activations[] int g_act; @@ -223,7 +226,7 @@ void LSTM::print_lstm_kernel(std::ostream &dst, bool forward) const INDT_2<< "for( int h=0; his_used() ) { + if( get_Y()->is_used() ) { INDT_3<< Y_snbh << "= " << Yh_dbh <<";" << std::endl; } INDT_2<< "}" << std::endl << std::endl; @@ -231,6 +234,15 @@ void LSTM::print_lstm_kernel(std::ostream &dst, bool forward) const void LSTM::print(std::ostream &dst) const { + const Tensor* X = get_X(); + const Tensor* W = get_W(); + const Tensor* R = get_R(); + const Tensor* B = get_B(); + const Tensor* sequence_lens = get_sequence_lens(); + const Tensor* initial_h = get_initial_h(); + const Tensor* initial_c = get_initial_c(); + const Tensor* P = get_P(); + INDT_1<< "/* LSTM " << std::endl; INDT_1<< " * inputs: " << std::endl; INDT_1<< " * X = " << X->cname() << std::endl; @@ -242,9 +254,9 @@ void LSTM::print(std::ostream &dst) const INDT_1<< " * initial_c = " << (initial_c?initial_c->cname():"") << std::endl; INDT_1<< " * P = " << (P?P->cname():"") << std::endl; INDT_1<< " * outputs: " << std::endl; - INDT_1<< " * Y = " << Y->cname() << std::endl; - INDT_1<< " * Y_h = " << Y_h->cname() << std::endl; - INDT_1<< " * Y_c = " << Y_c->cname() << std::endl; + INDT_1<< " * Y = " << get_Y()->cname() << std::endl; + INDT_1<< " * Y_h = " << get_Y_h()->cname() << std::endl; + INDT_1<< " * Y_c = " << get_Y_c()->cname() << std::endl; INDT_1<< " * attributes:" << std::endl; INDT_1<< " * activations: "; for( auto a : activations ) @@ -319,6 +331,8 @@ void LSTM::print(std::ostream &dst) const // Helper function for resolve(void) void LSTM::calculate_data_dimensions() { + const Tensor* X = get_X(); + const Tensor* W = get_W(); if( layout == 0 ) { seq_length = X->data_dim[0]; batch_size = X->data_dim[1]; @@ -371,41 +385,35 @@ void LSTM::resolve(void) if( hidden_size < 0 ) ERROR("Must provide hidden_size attribute!"); - X = inputs[0]; - register_input(X, "X"); - W = inputs[1]; - register_input(W, "W"); - R = inputs[2]; - register_input(R, "R"); + register_input(get_X(), "X"); + register_input(get_W(), "W"); + register_input(get_R(), "R"); //optional inputs. Trailing unprovided inputs can just be left out //but non-trailing, unprovided inputs MUST have an empty string as name // (guess that means tensors MAY NOT have an empty string as name?) - if( inputs.size() > 3 && inputs[3]->name != "" ) { - B = inputs[3]; - register_input(B, "B"); + if( get_B() ) { + register_input(get_B(), "B"); } - if( inputs.size() > 4 && inputs[4]->name != "" ) { - sequence_lens = inputs[4]; - register_input(sequence_lens, "sequence_lens"); + if( get_sequence_lens() ) { + register_input(get_sequence_lens(), "sequence_lens"); } - if( inputs.size() > 5 && inputs[5]->name != "" ) { - initial_h = inputs[5]; - register_input(initial_h, "initial_h"); + if( get_initial_h()) { + register_input(get_initial_h(), "initial_h"); } - if( inputs.size() > 6 && inputs[6]->name != "" ) { - initial_c = inputs[6]; - register_input(initial_c, "initial_c"); + if( get_initial_c()) { + register_input(get_initial_c(), "initial_c"); } - if( inputs.size() > 7 && inputs[7]->name != "" ) { - P = inputs[7]; - register_input(P, "P"); + if( get_P() ) { + register_input(get_P(), "P"); } calculate_data_dimensions(); - if( sequence_lens ) { + if( get_sequence_lens() ) { + const Tensor* sequence_lens = get_sequence_lens(); + if( static_cast(sequence_lens->rank()) != 1 ) ERROR("If providing sequence lengths, it must be a 1D tensor"); if( static_cast(sequence_lens->data_dim[0]) != batch_size ) @@ -419,8 +427,8 @@ void LSTM::resolve(void) // Generate output tensors. - Y = new Tensor; - Y->data_type = X->data_type; + Tensor *Y = new Tensor; + Y->data_type = get_X()->data_type; std::vector y_size; if( layout == 0 ) y_size = std::vector({ seq_length, num_directions, batch_size, hidden_size }); @@ -436,8 +444,8 @@ void LSTM::resolve(void) else ych_size = std::vector({ batch_size, num_directions, hidden_size }); - Y_h = new Tensor; - Y_h->data_type = X->data_type; + Tensor *Y_h = new Tensor; + Y_h->data_type = get_X()->data_type; Y_h->data_dim = ych_size; Y_h->isRecursive=true; Y_h->data_buffer = calloc(Y_h->data_num_elem(), Y_h->data_elem_size()); @@ -445,8 +453,8 @@ void LSTM::resolve(void) ERROR("Memory allocation failed"); Y_h->initialize = true; - Y_c = new Tensor; - Y_c->data_type = X->data_type; + Tensor *Y_c = new Tensor; + Y_c->data_type = get_X()->data_type; Y_c->data_dim = ych_size; Y_c->isRecursive=true; Y_c->data_buffer = calloc(Y_c->data_num_elem(), Y_c->data_elem_size()); diff --git a/src/nodes/lstm.h b/src/nodes/lstm.h index ec28276..dfe0d6d 100644 --- a/src/nodes/lstm.h +++ b/src/nodes/lstm.h @@ -26,35 +26,9 @@ class LSTM : public Node { clip = -1.0; hidden_size = -1; input_forget = 0; - X=NULL; - W=NULL; - R=NULL; - B=NULL; - sequence_lens=NULL; - initial_h=NULL; - initial_c=NULL; - P=NULL; - Y=NULL; - Y_h=NULL; - Y_c=NULL; layout=0; } - // inputs - const Tensor *X; - const Tensor *W; - const Tensor *R; - // optional inputs - const Tensor *B; - const Tensor *sequence_lens; - const Tensor *initial_h; - const Tensor *initial_c; - const Tensor *P; - // optional outputs - Tensor *Y; - Tensor *Y_h; - Tensor *Y_c; - // Attributes std::vector activation_alpha; std::vector activation_beta; @@ -78,6 +52,29 @@ class LSTM : public Node { float get_activation_alpha( const std::string &a); float get_activation_beta( const std::string &a); + const Tensor* get_X(void) const { return inputs[0]; } + const Tensor* get_W(void) const { return inputs[1]; } + const Tensor* get_R(void) const { return inputs[2]; } + const Tensor* get_Y(void) const { return outputs[0]; } + const Tensor* get_Y_h(void) const { return outputs[1]; } + const Tensor* get_Y_c(void) const { return outputs[2]; } + + // ONNX allows omitting optional inputs by either: + // - not give them at all + // - named with the empty string + const Tensor* get_optional(unsigned N) const + { + if( inputs.size() <= N ) + return nullptr; + if( inputs[N]->name == "" ) + return nullptr; + return inputs[N]; + } + const Tensor* get_B(void) const { return get_optional(3); } + const Tensor* get_sequence_lens(void) const { return get_optional(4); } + const Tensor* get_initial_h(void) const {return get_optional(5); } + const Tensor* get_initial_c(void) const {return get_optional(6); } + const Tensor* get_P(void) const {return get_optional(7); } void print_activation(std::ostream &dst, const std::string &activation, const std::string &var) const; void print_lstm_kernel(std::ostream &dst, bool forward) const;