Skip to content

Commit

Permalink
Fixed matmul/mul/reduce bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sorenlassen committed Jul 11, 2022
1 parent 02eadd2 commit 6dc9f50
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/Transform/ONNX/DecomposeEinsum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,7 @@ class Decomposer {
}
}

void reduce(Output& output) {
auto keep = otherSubscripts({&output});
void reduce(Output& output, const std::unordered_set<char> &keep) {
Axes axes;
for (size_t a = 0; a < output.size(); ++a) {
if (keep.count(output.subscripts[a]) == 0)
Expand Down Expand Up @@ -245,6 +244,7 @@ class Decomposer {
}

void transpose(Output& output, const Subscripts &transposedSubscripts) {
assert(output.subscripts.size() == transposedSubscripts.size());
if (output.subscripts == transposedSubscripts)
return;

Expand Down Expand Up @@ -297,8 +297,10 @@ class Decomposer {
output1.value = builder.create<ONNXMulOp>(
loc, output1.type(elementType), output1.value, output2.value)
.getResult();
if (reduceAtEnd)
reduce(output1);
if (reduceAtEnd) {
auto keep = otherSubscripts({&output1, &output2});
reduce(output1, keep);
}
remove(output2);
}

Expand Down Expand Up @@ -364,7 +366,8 @@ class Decomposer {

for (auto& output : outputs) {
diagonalize(output);
reduce(output);
auto keep = otherSubscripts({&output});
reduce(output, keep);
}

while (outputs.size() > 1) {
Expand Down

0 comments on commit 6dc9f50

Please sign in to comment.