
Commit

fix LSTM
kraiskil committed Jul 26, 2023
1 parent 57b8241 commit 168a87e
Showing 2 changed files with 63 additions and 58 deletions.
72 changes: 40 additions & 32 deletions src/nodes/lstm.cc
@@ -102,6 +102,9 @@ void LSTM::print_activation(std::ostream &dst, const std::string &activation, const std::string &var) const
  * The code is almost identical for forward and backwards nodes */
 void LSTM::print_lstm_kernel(std::ostream &dst, bool forward) const
 {
+	const Tensor* B = get_B();
+	const Tensor* P = get_P();
+
 	int dir; // direction index into tensors that separate forward and backward (W,B,Y,...)
 	int f_act; // indexes for the activation functions in activations[]
 	int g_act;
@@ -223,14 +226,23 @@ void LSTM::print_lstm_kernel(std::ostream &dst, bool forward) const
INDT_2<< "for( int h=0; h<hs; h++) {" << std::endl;
INDT_3<< Yh_dbh << " = ot[b][h] * ";
print_activation( dst, activations[h_act], Yc_dbh );
if( Y->is_used() ) {
if( get_Y()->is_used() ) {
INDT_3<< Y_snbh << "= " << Yh_dbh <<";" << std::endl;
}
INDT_2<< "}" << std::endl << std::endl;
}

void LSTM::print(std::ostream &dst) const
{
const Tensor* X = get_X();
const Tensor* W = get_W();
const Tensor* R = get_R();
const Tensor* B = get_B();
const Tensor* sequence_lens = get_sequence_lens();
const Tensor* initial_h = get_initial_h();
const Tensor* initial_c = get_initial_c();
const Tensor* P = get_P();

INDT_1<< "/* LSTM " << std::endl;
INDT_1<< " * inputs: " << std::endl;
INDT_1<< " * X = " << X->cname() << std::endl;
Expand All @@ -242,9 +254,9 @@ void LSTM::print(std::ostream &dst) const
INDT_1<< " * initial_c = " << (initial_c?initial_c->cname():"") << std::endl;
INDT_1<< " * P = " << (P?P->cname():"") << std::endl;
INDT_1<< " * outputs: " << std::endl;
INDT_1<< " * Y = " << Y->cname() << std::endl;
INDT_1<< " * Y_h = " << Y_h->cname() << std::endl;
INDT_1<< " * Y_c = " << Y_c->cname() << std::endl;
INDT_1<< " * Y = " << get_Y()->cname() << std::endl;
INDT_1<< " * Y_h = " << get_Y_h()->cname() << std::endl;
INDT_1<< " * Y_c = " << get_Y_c()->cname() << std::endl;
INDT_1<< " * attributes:" << std::endl;
INDT_1<< " * activations: ";
for( auto a : activations )
Expand Down Expand Up @@ -319,6 +331,8 @@ void LSTM::print(std::ostream &dst) const
 // Helper function for resolve(void)
 void LSTM::calculate_data_dimensions()
 {
+	const Tensor* X = get_X();
+	const Tensor* W = get_W();
 	if( layout == 0 ) {
 		seq_length = X->data_dim[0];
 		batch_size = X->data_dim[1];
@@ -371,41 +385,35 @@ void LSTM::resolve(void)
 	if( hidden_size < 0 )
 		ERROR("Must provide hidden_size attribute!");
 
-	X = inputs[0];
-	register_input(X, "X");
-	W = inputs[1];
-	register_input(W, "W");
-	R = inputs[2];
-	register_input(R, "R");
+	register_input(get_X(), "X");
+	register_input(get_W(), "W");
+	register_input(get_R(), "R");
 
 	//optional inputs. Trailing unprovided inputs can just be left out
 	//but non-trailing, unprovided inputs MUST have an empty string as name
 	// (guess that means tensors MAY NOT have an empty string as name?)
-	if( inputs.size() > 3 && inputs[3]->name != "" ) {
-		B = inputs[3];
-		register_input(B, "B");
+	if( get_B() ) {
+		register_input(get_B(), "B");
 	}
-	if( inputs.size() > 4 && inputs[4]->name != "" ) {
-		sequence_lens = inputs[4];
-		register_input(sequence_lens, "sequence_lens");
+	if( get_sequence_lens() ) {
+		register_input(get_sequence_lens(), "sequence_lens");
 	}
-	if( inputs.size() > 5 && inputs[5]->name != "" ) {
-		initial_h = inputs[5];
-		register_input(initial_h, "initial_h");
+	if( get_initial_h()) {
+		register_input(get_initial_h(), "initial_h");
 	}
-	if( inputs.size() > 6 && inputs[6]->name != "" ) {
-		initial_c = inputs[6];
-		register_input(initial_c, "initial_c");
+	if( get_initial_c()) {
+		register_input(get_initial_c(), "initial_c");
 	}
-	if( inputs.size() > 7 && inputs[7]->name != "" ) {
-		P = inputs[7];
-		register_input(P, "P");
+	if( get_P() ) {
+		register_input(get_P(), "P");
 	}
 
 
 	calculate_data_dimensions();
 
-	if( sequence_lens ) {
+	if( get_sequence_lens() ) {
+		const Tensor* sequence_lens = get_sequence_lens();
+
 		if( static_cast<int>(sequence_lens->rank()) != 1 )
 			ERROR("If providing sequence lengths, it must be a 1D tensor");
 		if( static_cast<int>(sequence_lens->data_dim[0]) != batch_size )
@@ -419,8 +427,8 @@
 
 	// Generate output tensors.
 
-	Y = new Tensor;
-	Y->data_type = X->data_type;
+	Tensor *Y = new Tensor;
+	Y->data_type = get_X()->data_type;
 	std::vector<int> y_size;
 	if( layout == 0 )
 		y_size = std::vector<int>({ seq_length, num_directions, batch_size, hidden_size });
@@ -436,17 +444,17 @@
 	else
 		ych_size = std::vector<int>({ batch_size, num_directions, hidden_size });
 
-	Y_h = new Tensor;
-	Y_h->data_type = X->data_type;
+	Tensor *Y_h = new Tensor;
+	Y_h->data_type = get_X()->data_type;
 	Y_h->data_dim = ych_size;
 	Y_h->isRecursive=true;
 	Y_h->data_buffer = calloc(Y_h->data_num_elem(), Y_h->data_elem_size());
 	if( Y_h->data_buffer == NULL )
 		ERROR("Memory allocation failed");
 	Y_h->initialize = true;
 
-	Y_c = new Tensor;
-	Y_c->data_type = X->data_type;
+	Tensor *Y_c = new Tensor;
+	Y_c->data_type = get_X()->data_type;
 	Y_c->data_dim = ych_size;
 	Y_c->isRecursive=true;
 	Y_c->data_buffer = calloc(Y_c->data_num_elem(), Y_c->data_elem_size());
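The core of the fix is visible above: LSTM no longer caches raw Tensor pointers (X, W, ..., Y_c) as members, but fetches them on demand through the get_*() accessors added in lstm.h below, so there is no per-member state to keep in sync with the generic inputs/outputs vectors. A minimal stand-alone sketch of that pattern, using simplified Tensor and Node stand-ins rather than the real onnx2c classes:

// Minimal sketch of the accessor pattern adopted here, with stand-in
// Tensor/Node types (not the real onnx2c classes).
#include <cassert>
#include <string>
#include <vector>

struct Tensor {
	std::string name;
};

struct Node {
	std::vector<const Tensor*> inputs;

	// Mandatory inputs sit at fixed ONNX-defined positions, so a getter
	// replaces a cached member pointer that could go stale.
	const Tensor* get_X(void) const { return inputs[0]; }

	// Optional inputs may be absent entirely, or present under the empty name.
	const Tensor* get_optional(unsigned N) const
	{
		if( inputs.size() <= N )
			return nullptr;
		if( inputs[N]->name == "" )
			return nullptr;
		return inputs[N];
	}
	const Tensor* get_B(void) const { return get_optional(3); }
};

int main()
{
	Tensor x{"x"};
	Node n;
	n.inputs = { &x };            // only the mandatory X is provided
	assert(n.get_X() == &x);
	assert(n.get_B() == nullptr); // trailing optional input B left out
	return 0;
}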
49 changes: 23 additions & 26 deletions src/nodes/lstm.h
@@ -26,35 +26,9 @@ class LSTM : public Node {
 		clip = -1.0;
 		hidden_size = -1;
 		input_forget = 0;
-		X=NULL;
-		W=NULL;
-		R=NULL;
-		B=NULL;
-		sequence_lens=NULL;
-		initial_h=NULL;
-		initial_c=NULL;
-		P=NULL;
-		Y=NULL;
-		Y_h=NULL;
-		Y_c=NULL;
 		layout=0;
 	}
 
-	// inputs
-	const Tensor *X;
-	const Tensor *W;
-	const Tensor *R;
-	// optional inputs
-	const Tensor *B;
-	const Tensor *sequence_lens;
-	const Tensor *initial_h;
-	const Tensor *initial_c;
-	const Tensor *P;
-	// optional outputs
-	Tensor *Y;
-	Tensor *Y_h;
-	Tensor *Y_c;
-
 	// Attributes
 	std::vector<float> activation_alpha;
 	std::vector<float> activation_beta;
@@ -78,6 +52,29 @@
 
 	float get_activation_alpha( const std::string &a);
 	float get_activation_beta( const std::string &a);
+	const Tensor* get_X(void) const { return inputs[0]; }
+	const Tensor* get_W(void) const { return inputs[1]; }
+	const Tensor* get_R(void) const { return inputs[2]; }
+	const Tensor* get_Y(void) const { return outputs[0]; }
+	const Tensor* get_Y_h(void) const { return outputs[1]; }
+	const Tensor* get_Y_c(void) const { return outputs[2]; }
+
+	// ONNX allows omitting optional inputs by either:
+	// - not give them at all
+	// - named with the empty string
+	const Tensor* get_optional(unsigned N) const
+	{
+		if( inputs.size() <= N )
+			return nullptr;
+		if( inputs[N]->name == "" )
+			return nullptr;
+		return inputs[N];
+	}
+	const Tensor* get_B(void) const { return get_optional(3); }
+	const Tensor* get_sequence_lens(void) const { return get_optional(4); }
+	const Tensor* get_initial_h(void) const {return get_optional(5); }
+	const Tensor* get_initial_c(void) const {return get_optional(6); }
+	const Tensor* get_P(void) const {return get_optional(7); }
+
 	void print_activation(std::ostream &dst, const std::string &activation, const std::string &var) const;
 	void print_lstm_kernel(std::ostream &dst, bool forward) const;
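The empty-name convention that get_optional() checks comes from the ONNX specification: a non-trailing optional input that is not provided must be listed under the empty string, so a slot can exist in inputs[] yet still count as absent. Continuing the stand-in types from the sketch above (tensor names are illustrative):

// B (index 3) and sequence_lens (index 4) are skipped via empty names,
// while initial_h (index 5) is provided; get_optional() must treat the
// first two as absent even though inputs.size() > 4.
Tensor x{"x"}, w{"w"}, r{"r"}, none{""}, h0{"h0"};
Node n;
n.inputs = { &x, &w, &r, &none, &none, &h0 };
assert(n.get_B() == nullptr);       // slot present, but empty-named
assert(n.get_optional(5) == &h0);   // initial_h found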
