Skip to content

Commit

Permalink
clang formatted everything
Browse files Browse the repository at this point in the history
  • Loading branch information
sorenlassen committed Jul 12, 2022
1 parent 2710e1d commit 04f3684
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 220 deletions.
62 changes: 31 additions & 31 deletions src/Dialect/ONNX/ONNXEinsumOpHelper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
#include "src/Dialect/ONNX/ONNXEinsumOpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

#include <stddef.h>
#include <stdint.h>
#include <map>
#include <regex>
#include <stddef.h>
#include <stdint.h>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -42,7 +42,7 @@ namespace {
// The arrow and output are optional.
#define EQUATION INPUTS "(?:" ARROW OUTPUT ")?"

const char* const equation_regex = EQUATION;
const char *const equation_regex = EQUATION;

} // namespace

Expand All @@ -52,12 +52,12 @@ LogicalResult verifyEquation(
std::cmatch match;
if (!std::regex_match(equation.begin(), equation.end(), match, pattern)) {
return emitErrorFn() << "invalid equation syntax";
}
}
assert(match.size() == 3);
size_t numEquationInputs = equation.count(',') + 1;
if (numEquationInputs != numInputs) {
return emitErrorFn() << "number of equation inputs " << numEquationInputs
<< " != number of actual inputs " << numInputs;
<< " != number of actual inputs " << numInputs;
}
// equation matched the regex - next check that any output satisfies the
// constaints that its subscripts cannot be repeated and each must occur in
Expand All @@ -72,14 +72,14 @@ LogicalResult verifyEquation(
StringRef output(outputGroup.first, outputGroup.length());
for (size_t p = 0; p < output.size(); ++p) {
char x = output[p]; // x must be in [ .A-Za-z] given regex OUTPUT
if (x >= 'A') { // tests whether x is a letter given x is in [ .A-Za-z]
if (x >= 'A') { // tests whether x is a letter given x is in [ .A-Za-z]
if (StringRef::npos != output.find(x, p + 1)) {
return emitErrorFn()
<< "subscript " << x << " appears multiple times in the output";
<< "subscript " << x << " appears multiple times in the output";
}
if (StringRef::npos == inputs.find(x)) {
return emitErrorFn()
<< "output subscript " << x << " doesn't appear in inputs";
<< "output subscript " << x << " doesn't appear in inputs";
}
}
}
Expand Down Expand Up @@ -111,16 +111,16 @@ namespace {
std::string inferEquationOutput(StringRef commaSeparatedInputs) {
std::map<char, int> counts;
for (char x : commaSeparatedInputs) {
if (x >= 'A') { // tests whether x is a letter, given x is in [ ,.A-Za-z]
if (x >= 'A') { // tests whether x is a letter, given x is in [ ,.A-Za-z]
counts[x] += 1; // counts[x] initializes to 0 if not yet mapped
}
}
std::string equationOutput = "...";
// iterate through sorted map in order, i.e. alphabetically and
// all upper case letters before lower case
for (const auto& entry : counts) { // entry == pair (x, count)
if (entry.second == 1) // one occurrence
equationOutput.push_back(entry.first);
for (const auto &entry : counts) { // entry == pair (x, count)
if (entry.second == 1) // one occurrence
equationOutput.push_back(entry.first);
}
return equationOutput;
}
Expand Down Expand Up @@ -189,9 +189,11 @@ FailureOr<Signature> inferSignature(
ONNXEinsumOpAdaptor operandAdaptor, ErrorFn emitErrorFn) {
auto equation = operandAdaptor.equation();
auto inputs = operandAdaptor.Inputs();
assert(succeeded(verifyEquation(equation, inputs.size(), emitErrorFn))); // precondition, TODO: remove this excessive check
assert(succeeded(verifyEquation(equation, inputs.size(),
emitErrorFn))); // precondition, TODO: remove this excessive check
assert(llvm::all_of(inputs, [](Value i) {
return i.getType().cast<ShapedType>().hasRank(); })); // precondition
return i.getType().cast<ShapedType>().hasRank();
})); // precondition
StringRef equationOutput, commaSeparatedInputs;
std::tie(commaSeparatedInputs, equationOutput) = equation.split('-');
std::string inferredOutput;
Expand Down Expand Up @@ -224,28 +226,25 @@ FailureOr<Signature> inferSignature(
auto letters = countLetters(equationInput);
if (!hasEllipsis(equationInput)) {
if (rank != letters) {
return emitErrorFn()
<< "number of equation input parameter subscripts " << letters
<< " != input type rank " << rank;
return emitErrorFn() << "number of equation input parameter subscripts "
<< letters << " != input type rank " << rank;
}
} else {
if (rank < letters) {
return emitErrorFn()
<< "number of equation input parameter subscripts " << letters
<< " exceeds input type rank " << rank;
return emitErrorFn() << "number of equation input parameter subscripts "
<< letters << " exceeds input type rank " << rank;
}
int64_t thisEllipsisRank = rank - letters;
if (thisEllipsisRank > kMaxEllipsisRank) {
return emitErrorFn()
<< "ellipsis rank exceeds maximum of " << kMaxEllipsisRank;
<< "ellipsis rank exceeds maximum of " << kMaxEllipsisRank;
}
if (ellipsisRank == -1) {
ellipsisRank = thisEllipsisRank;
} else {
if (ellipsisRank != thisEllipsisRank) {
return emitErrorFn()
<< "inputs disagree on ellipsis rank, "
<< ellipsisRank << " vs " << thisEllipsisRank;
return emitErrorFn() << "inputs disagree on ellipsis rank, "
<< ellipsisRank << " vs " << thisEllipsisRank;
}
}
}
Expand All @@ -265,23 +264,24 @@ FailureOr<Signature> inferSignature(
auto insertion = subscriptsToDims.emplace(s, d);
int64_t d0 = insertion.first->second; // == subscriptsToDims[s]
if (d0 != d) {
return emitErrorFn() << "subscript '" << s
<< "' has axes with different dim sizes " << d0 << ", " << d
<< " in the same input";
return emitErrorFn()
<< "subscript '" << s << "' has axes with different dim sizes "
<< d0 << ", " << d << " in the same input";
}
}
}
// Merge the subscripts with static dim sizes into the
// broadcast map shared by all inputs.
for (const auto& entry : subscriptsToDims) {
for (const auto &entry : subscriptsToDims) {
char s = entry.first;
int64_t d = entry.second;
if (d != 1) {
auto insertion = broadcast.emplace(s, d);
int64_t d0 = insertion.first->second; // == broadcast[s]
if (d0 != d) {
return emitErrorFn() << "subscript '" << s
<< "' has conflicting dim sizes " << d0 << ", " << d;
return emitErrorFn()
<< "subscript '" << s << "' has conflicting dim sizes " << d0
<< ", " << d;
}
}
}
Expand All @@ -293,7 +293,7 @@ FailureOr<Signature> inferSignature(
bool outputHasEllipsis = hasEllipsis(equationOutput);
if (ellipsisRank != 0 && !outputHasEllipsis) {
return emitErrorFn() << "output needs ellipsis because inputs have "
<< "non-empty ellipsis with rank " << ellipsisRank;
<< "non-empty ellipsis with rank " << ellipsisRank;
}
int64_t outputRank =
countLetters(equationOutput) + (outputHasEllipsis ? ellipsisRank : 0);
Expand Down
6 changes: 3 additions & 3 deletions src/Dialect/ONNX/ONNXEinsumOpHelper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@

#pragma once

#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "mlir/Support/LogicalResult.h"

#include <stdint.h>

namespace mlir {
class InFlightDiagnostic;
class ONNXEinsumOpAdaptor;
}
} // namespace mlir

namespace onnx_mlir {

Expand Down
8 changes: 5 additions & 3 deletions src/Dialect/ONNX/ONNXOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"

#include "src/Dialect/ONNX/ONNXEinsumOpHelper.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"
#include "src/Dialect/ONNX/ONNXOpsHelper.hpp"
#include "src/Dialect/ONNX/ONNXEinsumOpHelper.hpp"
#include "src/Dialect/ONNX/ShapeInference/ONNXShapeHelper.hpp"
#include "src/Support/Diagnostic.hpp"

Expand Down Expand Up @@ -4028,7 +4028,8 @@ LogicalResult ONNXEinsumOp::verify() {
return failure();
}

Type firstElementType = inputs[0].getType().cast<ShapedType>().getElementType();
Type firstElementType =
inputs[0].getType().cast<ShapedType>().getElementType();
for (Value input : inputs) {
ShapedType type = input.getType().cast<ShapedType>();
if (type.getElementType() != firstElementType) {
Expand All @@ -4051,7 +4052,8 @@ LogicalResult ONNXEinsumOp::inferShapes(
};
auto shape = einsum::inferOutputShape(operandAdaptor, errorFn);
assert(succeeded(shape) && "any failure should be caught in verify()");
auto elementType = getOperand(0).getType().cast<ShapedType>().getElementType();
auto elementType =
getOperand(0).getType().cast<ShapedType>().getElementType();
getResult().setType(RankedTensorType::get(*shape, elementType));
return success();
}
Expand Down
Loading

0 comments on commit 04f3684

Please sign in to comment.