Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support for 3d MatMul #44

Merged
merged 2 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ distributing this software causes anyone any harm.
Authors:
Kalle Raiskila
Robin van Emden
Youngsun Kong
Youngsun Kong
AUTOMATIC1111
112 changes: 87 additions & 25 deletions src/nodes/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,35 @@ class MatMul : public Node {
op_name = "MatMul";
}

std::string vecstr( const std::vector<int>& vec ) const
{
std::stringstream result;
result << "{ ";
std::copy( vec.begin(), vec.end(), std::ostream_iterator<int>( result, ", " ) );
result << "}";
return result.str();
}

virtual void print(std::ostream &dst) const override
{
const Tensor *A = get_input_tensor(0);
const Tensor *B = get_input_tensor(1);
std::string type = A->data_type_str();

if( A->data_dim.size() != 2 )
ERROR("Unimplemented: higher than 2D MatMul");
bool A_is_correct_size = A->data_dim.size() == 2 || A->data_dim.size() == 3;
bool B_is_correct_size = B->data_dim.size() == 2 || B->data_dim.size() == 3;
if ( !A_is_correct_size || !B_is_correct_size )
{
ERROR( std::string( "Unimplemented: MatMul with dimensions: A: " ) + vecstr( A->data_dim ) + ", B: " + vecstr( B->data_dim ) );
}

std::vector<int> A_dim( A->data_dim.begin() + A->data_dim.size() - 2, A->data_dim.end() );
std::vector<int> B_dim( B->data_dim.begin() + B->data_dim.size() - 2, B->data_dim.end() );

int32_t rows = A->data_dim[0];
int32_t cols = B->data_dim[1];
int32_t inner = A->data_dim[1];
int32_t inner2 = B->data_dim[0];
int32_t rows = A_dim[0];
int32_t cols = B_dim[1];
int32_t inner = A_dim[1];
int32_t inner2 = B_dim[0];
if( inner == 0 ) inner=1;

// TODO: handle the case of [N] * [Nx1] multiplication,
Expand All @@ -28,14 +44,37 @@ class MatMul : public Node {
if( inner != inner2 )
ERROR("MatMul input's inner dimensions don't match");

INDT_1 << "/* MatMul */" << std::endl;
INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_3 << "Y[r][c] = 0;" << std::endl;
INDT_3 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_4 << "Y[r][c] += A[r][i] * B[i][c];" << std::endl;
INDT_2 << "}" << std::endl;
bool A_is_2d = A->data_dim.size() == 2;
bool B_is_2d = B->data_dim.size() == 2;

if ( A_is_2d && B_is_2d )
{
INDT_1 << "/* MatMul */" << std::endl;
INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_3 << "Y[r][c] = 0;" << std::endl;
INDT_3 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_4 << "Y[r][c] += A[r][i] * B[i][c];" << std::endl;
INDT_2 << "}" << std::endl;
}
else
{
std::string A_txt = A_is_2d ? "A" : "A[n]";
std::string B_txt = B_is_2d ? "B" : "B[n]";

INDT_1 << "/* MatMul */" << std::endl;

INDT_1 << "for( uint32_t n=0; n<" << A->data_dim[0] << "; n++ ) {" << std::endl;

INDT_2 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_3 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_4 << "Y[n][r][c] = 0;" << std::endl;
INDT_4 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_5 << "Y[n][r][c] += " << A_txt << "[r][i] * " << B_txt << "[i][c];" << std::endl;
INDT_3 << "}" << std::endl;

INDT_1 << "}" << std::endl;
}
}
virtual void resolve(void) override
{
Expand All @@ -52,6 +91,23 @@ class MatMul : public Node {
result_dim(rows, cols);

Tensor *rv = new Tensor;

if ( A->data_dim.size() == 3 && B->data_dim.size() == 3 )
{
if ( A->data_dim[0] != B->data_dim[0] )
ERROR( std::string("MatMul input's dimensions don't match: A: ") + vecstr( A->data_dim ) + ", B: " + vecstr( B->data_dim ) );

rv->data_dim.push_back( A->data_dim[0] );
}
else if ( A->data_dim.size() == 3 && B->data_dim.size() == 2 )
{
rv->data_dim.push_back( A->data_dim[0] );
}
else if ( A->data_dim.size() == 2 && B->data_dim.size() == 3 )
{
rv->data_dim.push_back( B->data_dim[0] );
}

rv->data_dim.push_back(rows);
rv->data_dim.push_back(cols);
rv->data_type = A->data_type;
Expand All @@ -60,31 +116,37 @@ class MatMul : public Node {

void result_dim( int32_t &rows, int32_t &cols) const
{
const Tensor* A = get_input_tensor( 0 );
const Tensor* B = get_input_tensor( 1 );

std::vector<int> A_dim( A->data_dim.begin() + A->data_dim.size() - 2, A->data_dim.end() );
std::vector<int> B_dim( B->data_dim.begin() + B->data_dim.size() - 2, B->data_dim.end() );

// TODO: this is the check for vectors. Check equivalent for N-dimensons: N>2
if( get_input_tensor(0)->data_dim[1] != 0 && get_input_tensor(1)->data_dim[1] != 0 )
if ( A_dim[1] != 0 && B_dim[1] != 0 )
{
rows = get_input_tensor(0)->data_dim[0];
cols = get_input_tensor(1)->data_dim[1];
rows = A_dim[0];
cols = B_dim[1];
}
else if( get_input_tensor(0)->data_dim[1] == 0 && get_input_tensor(1)->data_dim[1] == 0 )
else if ( A_dim[1] == 0 && B_dim[1] == 0 )
{
ERROR("Bad input/unhandled: 2 vectors to MatMul");
ERROR( "Bad input/unhandled: 2 vectors to MatMul" );
}
else if( get_input_tensor(0)->data_dim[1] == 0 )
else if ( A_dim[1] == 0 )
{
cols = get_input_tensor(1)->data_dim[1];
if( get_input_tensor(0)->data_dim[0] == get_input_tensor(1)->data_dim[0] )
cols = B_dim[1];
if ( A_dim[0] == B_dim[0] )
rows = 1;
else
rows = get_input_tensor(0)->data_dim[0];
rows = A_dim[0];
}
else
{
rows = get_input_tensor(0)->data_dim[0];
if( get_input_tensor(0)->data_dim[1] == get_input_tensor(1)->data_dim[0] )
rows = A_dim[0];
if ( A_dim[1] == B_dim[0] )
cols = 1;
else
cols = get_input_tensor(1)->data_dim[0];
cols = B_dim[0];
}
}
};
Expand Down