Skip to content

Commit

Permalink
support for 3d MatMul (kraiskil#44)
Browse files Browse the repository at this point in the history
Fix by: AUTOMATIC1111
  • Loading branch information
AUTOMATIC1111 authored Apr 2, 2024
1 parent 3542d14 commit 316b67b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 26 deletions.
3 changes: 2 additions & 1 deletion LICENSE.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ distributing this software causes anyone any harm.
Authors:
Kalle Raiskila
Robin van Emden
Youngsun Kong
Youngsun Kong
AUTOMATIC1111
112 changes: 87 additions & 25 deletions src/nodes/matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,35 @@ class MatMul : public Node {
op_name = "MatMul";
}

std::string vecstr( const std::vector<int>& vec ) const
{
std::stringstream result;
result << "{ ";
std::copy( vec.begin(), vec.end(), std::ostream_iterator<int>( result, ", " ) );
result << "}";
return result.str();
}

virtual void print(std::ostream &dst) const override
{
const Tensor *A = get_input_tensor(0);
const Tensor *B = get_input_tensor(1);
std::string type = A->data_type_str();

if( A->data_dim.size() != 2 )
ERROR("Unimplemented: higher than 2D MatMul");
bool A_is_correct_size = A->data_dim.size() == 2 || A->data_dim.size() == 3;
bool B_is_correct_size = B->data_dim.size() == 2 || B->data_dim.size() == 3;
if ( !A_is_correct_size || !B_is_correct_size )
{
ERROR( std::string( "Unimplemented: MatMul with dimensions: A: " ) + vecstr( A->data_dim ) + ", B: " + vecstr( B->data_dim ) );
}

std::vector<int> A_dim( A->data_dim.begin() + A->data_dim.size() - 2, A->data_dim.end() );
std::vector<int> B_dim( B->data_dim.begin() + B->data_dim.size() - 2, B->data_dim.end() );

int32_t rows = A->data_dim[0];
int32_t cols = B->data_dim[1];
int32_t inner = A->data_dim[1];
int32_t inner2 = B->data_dim[0];
int32_t rows = A_dim[0];
int32_t cols = B_dim[1];
int32_t inner = A_dim[1];
int32_t inner2 = B_dim[0];
if( inner == 0 ) inner=1;

// TODO: handle the case of [N] * [Nx1] multiplication,
Expand All @@ -28,14 +44,37 @@ class MatMul : public Node {
if( inner != inner2 )
ERROR("MatMul input's inner dimensions don't match");

INDT_1 << "/* MatMul */" << std::endl;
INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_3 << "Y[r][c] = 0;" << std::endl;
INDT_3 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_4 << "Y[r][c] += A[r][i] * B[i][c];" << std::endl;
INDT_2 << "}" << std::endl;
bool A_is_2d = A->data_dim.size() == 2;
bool B_is_2d = B->data_dim.size() == 2;

if ( A_is_2d && B_is_2d )
{
INDT_1 << "/* MatMul */" << std::endl;
INDT_1 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_2 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_3 << "Y[r][c] = 0;" << std::endl;
INDT_3 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_4 << "Y[r][c] += A[r][i] * B[i][c];" << std::endl;
INDT_2 << "}" << std::endl;
}
else
{
std::string A_txt = A_is_2d ? "A" : "A[n]";
std::string B_txt = B_is_2d ? "B" : "B[n]";

INDT_1 << "/* MatMul */" << std::endl;

INDT_1 << "for( uint32_t n=0; n<" << A->data_dim[0] << "; n++ ) {" << std::endl;

INDT_2 << "for( uint32_t r=0; r<" << rows << "; r++ )" << std::endl;
INDT_3 << "for( uint32_t c=0; c<" << cols << "; c++ ) {" << std::endl;
INDT_4 << "Y[n][r][c] = 0;" << std::endl;
INDT_4 << "for( uint32_t i=0; i<" << inner << "; i++ )" << std::endl;
INDT_5 << "Y[n][r][c] += " << A_txt << "[r][i] * " << B_txt << "[i][c];" << std::endl;
INDT_3 << "}" << std::endl;

INDT_1 << "}" << std::endl;
}
}
virtual void resolve(void) override
{
Expand All @@ -52,6 +91,23 @@ class MatMul : public Node {
result_dim(rows, cols);

Tensor *rv = new Tensor;

if ( A->data_dim.size() == 3 && B->data_dim.size() == 3 )
{
if ( A->data_dim[0] != B->data_dim[0] )
ERROR( std::string("MatMul input's dimensions don't match: A: ") + vecstr( A->data_dim ) + ", B: " + vecstr( B->data_dim ) );

rv->data_dim.push_back( A->data_dim[0] );
}
else if ( A->data_dim.size() == 3 && B->data_dim.size() == 2 )
{
rv->data_dim.push_back( A->data_dim[0] );
}
else if ( A->data_dim.size() == 2 && B->data_dim.size() == 3 )
{
rv->data_dim.push_back( B->data_dim[0] );
}

rv->data_dim.push_back(rows);
rv->data_dim.push_back(cols);
rv->data_type = A->data_type;
Expand All @@ -60,31 +116,37 @@ class MatMul : public Node {

void result_dim( int32_t &rows, int32_t &cols) const
{
const Tensor* A = get_input_tensor( 0 );
const Tensor* B = get_input_tensor( 1 );

std::vector<int> A_dim( A->data_dim.begin() + A->data_dim.size() - 2, A->data_dim.end() );
std::vector<int> B_dim( B->data_dim.begin() + B->data_dim.size() - 2, B->data_dim.end() );

// TODO: this is the check for vectors. Check equivalent for N-dimensons: N>2
if( get_input_tensor(0)->data_dim[1] != 0 && get_input_tensor(1)->data_dim[1] != 0 )
if ( A_dim[1] != 0 && B_dim[1] != 0 )
{
rows = get_input_tensor(0)->data_dim[0];
cols = get_input_tensor(1)->data_dim[1];
rows = A_dim[0];
cols = B_dim[1];
}
else if( get_input_tensor(0)->data_dim[1] == 0 && get_input_tensor(1)->data_dim[1] == 0 )
else if ( A_dim[1] == 0 && B_dim[1] == 0 )
{
ERROR("Bad input/unhandled: 2 vectors to MatMul");
ERROR( "Bad input/unhandled: 2 vectors to MatMul" );
}
else if( get_input_tensor(0)->data_dim[1] == 0 )
else if ( A_dim[1] == 0 )
{
cols = get_input_tensor(1)->data_dim[1];
if( get_input_tensor(0)->data_dim[0] == get_input_tensor(1)->data_dim[0] )
cols = B_dim[1];
if ( A_dim[0] == B_dim[0] )
rows = 1;
else
rows = get_input_tensor(0)->data_dim[0];
rows = A_dim[0];
}
else
{
rows = get_input_tensor(0)->data_dim[0];
if( get_input_tensor(0)->data_dim[1] == get_input_tensor(1)->data_dim[0] )
rows = A_dim[0];
if ( A_dim[1] == B_dim[0] )
cols = 1;
else
cols = get_input_tensor(1)->data_dim[0];
cols = B_dim[0];
}
}
};
Expand Down

0 comments on commit 316b67b

Please sign in to comment.