From e4d81861717e44424c326c487a8dc65a233a53e9 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Mon, 5 Feb 2024 14:49:09 -0800 Subject: [PATCH 1/4] [OM] Support integer binary arithmetic in the Evaluator. This adds support for the Evaluator to evaluate integer binary arithmetic. This is defined in terms of the IntegerBinaryArithmeticOp interface, so the Evaluator can generically handle all ops that implement this interface in terms of the interface methods. At a high level, getOrCreateValue will create a partially evaluated value for the result of this op. Once the operands are fully evaluated, they are passed to the evaluateIntegerOperation interface method, and if successful, the result is used for the evaluator value. Because these are the first kind of operation using AttributeValues that doesn't necessarily make the Attribute immediately available, the AttributeValue has been extended to support construction in a partially evaluated state to be filled in later. The tests show example of straightforward evaluation where the operands are immediately available from constants, where the operands are references to object fields, and where the operands are references to object fields that are initially not evaluated. This handles many cases, but there is currently no detection for dataflow cycles through an integer binary arithmetic operation. The evaluator would need significant changes in the current setup to detect this on the fly, so currently assumes cycles are prevented earlier in the pipeline. --- .../circt/Dialect/OM/Evaluator/Evaluator.h | 23 ++- lib/Dialect/OM/Evaluator/Evaluator.cpp | 107 +++++++++++++- .../Dialect/OM/Evaluator/EvaluatorTests.cpp | 136 ++++++++++++++++++ 3 files changed, 260 insertions(+), 6 deletions(-) diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 303bca0971b1..3834e626c108 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -133,9 +133,17 @@ struct AttributeValue : EvaluatorValue { AttributeValue(Attribute attr) : AttributeValue(attr, mlir::UnknownLoc::get(attr.getContext())) {} AttributeValue(Attribute attr, Location loc) - : EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr) { + : EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr), + type(cast(attr).getType()) { markFullyEvaluated(); } + + // Constructors for partially evaluated AttributeValue. + AttributeValue(Type type) + : AttributeValue(mlir::UnknownLoc::get(type.getContext())) {} + AttributeValue(Type type, Location loc) + : EvaluatorValue(type.getContext(), Kind::Attr, loc), type(type) {} + Attribute getAttr() const { return attr; } template AttrTy getAs() const { @@ -145,13 +153,17 @@ struct AttributeValue : EvaluatorValue { return e->getKind() == Kind::Attr; } + // Set Attribute for partially evaluated case. + LogicalResult setAttr(Attribute attr); + // Finalize the value. - LogicalResult finalizeImpl() { return success(); } + LogicalResult finalizeImpl(); - Type getType() const { return attr.cast().getType(); } + Type getType() const { return type; } private: Attribute attr = {}; + Type type; }; // This perform finalization to `value`. @@ -452,6 +464,11 @@ struct Evaluator { FailureOr evaluateConstant(ConstantOp op, ActualParameters actualParams, Location loc); + + FailureOr + evaluateIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op, + ActualParameters actualParams, Location loc); + /// Instantiate an Object with its class name and actual parameters. FailureOr evaluateObjectInstance(StringAttr className, ActualParameters actualParams, diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index c8d8b7bbd1eb..15dd76c49900 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -56,9 +56,7 @@ LogicalResult circt::om::evaluator::EvaluatorValue::finalize() { Type circt::om::evaluator::EvaluatorValue::getType() const { return llvm::TypeSwitch(this) - .Case([](auto *attr) -> Type { - return cast(attr->getAttr()).getType(); - }) + .Case([](auto *attr) -> Type { return attr->getType(); }) .Case([](auto *object) { return object->getObjectType(); }) .Case([](auto *list) { return list->getListType(); }) .Case([](auto *map) { return map->getMapType(); }) @@ -130,6 +128,14 @@ FailureOr circt::om::Evaluator::getOrCreateValue( .Case([&](ConstantOp op) { return evaluateConstant(op, actualParams, loc); }) + .Case([&](IntegerBinaryArithmeticOp op) { + // Create a partially evaluated AttributeValue of + // om::IntegerType in case we need to delay evaluation. + evaluator::EvaluatorValuePtr result = + std::make_shared( + op.getResult().getType(), loc); + return success(result); + }) .Case([&](auto op) { // Create a reference value since the value pointed by object // field op is not created yet. @@ -339,6 +345,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams, .Case([&](ConstantOp op) { return evaluateConstant(op, actualParams, loc); }) + .Case([&](IntegerBinaryArithmeticOp op) { + return evaluateIntegerBinaryArithmetic(op, actualParams, loc); + }) .Case([&](ObjectOp op) { return evaluateObjectInstance(op, actualParams); }) @@ -394,6 +403,74 @@ circt::om::Evaluator::evaluateConstant(ConstantOp op, op.getValue(), loc)); } +// Evaluator dispatch function for integer binary arithmetic. +FailureOr +circt::om::Evaluator::evaluateIntegerBinaryArithmetic( + IntegerBinaryArithmeticOp op, ActualParameters actualParams, Location loc) { + // Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet. + auto handle = getOrCreateValue(op.getResult(), actualParams, loc); + + // If it's fully evaluated, we can return it. + if (handle.value()->isFullyEvaluated()) + return handle; + + // Evaluate operands if necessary, and return the partially evaluated value if + // they aren't ready. + auto lhsResult = evaluateValue(op.getLhs(), actualParams, loc); + if (failed(lhsResult)) + return lhsResult; + if (!lhsResult.value()->isFullyEvaluated()) + return handle; + + auto rhsResult = evaluateValue(op.getRhs(), actualParams, loc); + if (failed(rhsResult)) + return rhsResult; + if (!rhsResult.value()->isFullyEvaluated()) + return handle; + + // Extract the integer attributes. + auto extractAttr = [](evaluator::EvaluatorValue *value) { + return std::move( + llvm::TypeSwitch(value) + .Case([](evaluator::AttributeValue *val) { + return val->getAs(); + }) + .Case([](evaluator::ReferenceValue *val) { + return cast(val->getValue().get()) + ->getAs(); + })); + }; + + om::IntegerAttr lhs = extractAttr(lhsResult.value().get()); + om::IntegerAttr rhs = extractAttr(rhsResult.value().get()); + assert(lhs && rhs && + "expected om::IntegerAttr for IntegerBinaryArithmeticOp operands"); + + // Perform arbitrary precision signed integer binary arithmetic. + FailureOr result = op.evaluateIntegerOperation( + lhs.getValue().getAPSInt(), rhs.getValue().getAPSInt()); + + if (failed(result)) + return op->emitError("failed to evaluate integer operation"); + + // Package the result as a new om::IntegerAttr. + MLIRContext *ctx = op->getContext(); + auto resultAttr = + om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value())); + + // Finalize the op result value. + auto *handleValue = cast(handle.value().get()); + auto resultStatus = handleValue->setAttr(resultAttr); + if (failed(resultStatus)) + return resultStatus; + + auto finalizeStatus = handleValue->finalize(); + if (failed(finalizeStatus)) + return finalizeStatus; + + return handle; +} + /// Evaluator dispatch function for Object instances. FailureOr circt::om::Evaluator::createParametersFromOperands( @@ -778,3 +855,27 @@ void evaluator::PathValue::setBasepath(const BasePathValue &basepath) { path = PathAttr::get(path.getContext(), newPath); markFullyEvaluated(); } + +//===----------------------------------------------------------------------===// +// AttributeValue +//===----------------------------------------------------------------------===// + +LogicalResult circt::om::evaluator::AttributeValue::setAttr(Attribute attr) { + if (cast(attr).getType() != this->type) + return mlir::emitError(getLoc(), "cannot set AttributeValue of type ") + << this->type << " to Attribute " << attr; + if (isFullyEvaluated()) + return mlir::emitError( + getLoc(), + "cannot set AttributeValue that has already been fully evaluated"); + this->attr = attr; + markFullyEvaluated(); + return success(); +} + +LogicalResult circt::om::evaluator::AttributeValue::finalizeImpl() { + if (!isFullyEvaluated()) + return mlir::emitError( + getLoc(), "cannot finalize AttributeValue that is not fully evaluated"); + return success(); +} diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index 4559e8b471bc..640095dde765 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -613,4 +613,140 @@ TEST(EvaluatorTests, InstantiateCycle) { ASSERT_TRUE(failed(result)); } +TEST(EvaluatorTests, IntegerBinaryArithmeticAdd) { + StringRef mod = "om.class @IntegerBinaryArithmeticAdd() {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " %1 = om.constant #om.integer<2 : si3> : !om.integer" + " %2 = om.integer.add %0, %1 : !om.integer" + " om.class.field @result, %2 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticAdd"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) { + StringRef mod = "om.class @Class1() {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @Class2() {" + " %0 = om.constant #om.integer<2 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @IntegerBinaryArithmeticObjects() {" + " %0 = om.object @Class1() : () -> !om.class.type<@Class1>" + " %1 = om.object.field %0, [@value] : " + "(!om.class.type<@Class1>) -> !om.integer" + "" + " %2 = om.object @Class2() : () -> !om.class.type<@Class2>" + " %3 = om.object.field %2, [@value] : " + "(!om.class.type<@Class2>) -> !om.integer" + "" + " %5 = om.integer.add %1, %3 : !om.integer" + " om.class.field @result, %5 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticObjects"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticObjectsDelayed) { + StringRef mod = + "om.class @Class1(%input: !om.integer) {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + " om.class.field @input, %input : !om.integer" + "}" + "" + "om.class @Class2() {" + " %0 = om.constant #om.integer<2 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @IntegerBinaryArithmeticObjectsDelayed() {" + " %0 = om.object @Class1(%5) : (!om.integer) -> !om.class.type<@Class1>" + " %1 = om.object.field %0, [@value] : " + "(!om.class.type<@Class1>) -> !om.integer" + "" + " %2 = om.object @Class2() : () -> !om.class.type<@Class2>" + " %3 = om.object.field %2, [@value] : " + "(!om.class.type<@Class2>) -> !om.integer" + "" + " %5 = om.integer.add %1, %3 : !om.integer" + " om.class.field @result, %5 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticObjectsDelayed"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + } // namespace From 34bf6055ae45e3080a7f38c93fbe41bd7ccd6661 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 20 Feb 2024 10:29:22 -0700 Subject: [PATCH 2/4] Use getStrippedValue for fully evaluated references. Co-authored-by: Hideto Ueno --- lib/Dialect/OM/Evaluator/Evaluator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 15dd76c49900..eb8af9e02ba9 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -436,7 +436,7 @@ circt::om::Evaluator::evaluateIntegerBinaryArithmetic( return val->getAs(); }) .Case([](evaluator::ReferenceValue *val) { - return cast(val->getValue().get()) + return cast(val->getStrippedValue().get()) ->getAs(); })); }; From d3c2f571c38f27be7780c9d92fba3192357c70dc Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 20 Feb 2024 09:39:26 -0800 Subject: [PATCH 3/4] Format. --- lib/Dialect/OM/Evaluator/Evaluator.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index eb8af9e02ba9..8aa3dc73396f 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -436,7 +436,8 @@ circt::om::Evaluator::evaluateIntegerBinaryArithmetic( return val->getAs(); }) .Case([](evaluator::ReferenceValue *val) { - return cast(val->getStrippedValue().get()) + return cast( + val->getStrippedValue().get()) ->getAs(); })); }; From e761f69b879920d9142b42bdf82bae9129743e39 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 20 Feb 2024 09:40:05 -0800 Subject: [PATCH 4/4] getStrippedValue is a pointer. --- lib/Dialect/OM/Evaluator/Evaluator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index 8aa3dc73396f..61573c16c360 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -437,7 +437,7 @@ circt::om::Evaluator::evaluateIntegerBinaryArithmetic( }) .Case([](evaluator::ReferenceValue *val) { return cast( - val->getStrippedValue().get()) + val->getStrippedValue()->get()) ->getAs(); })); };