diff --git a/include/circt/Dialect/OM/Evaluator/Evaluator.h b/include/circt/Dialect/OM/Evaluator/Evaluator.h index 303bca0971b1..3834e626c108 100644 --- a/include/circt/Dialect/OM/Evaluator/Evaluator.h +++ b/include/circt/Dialect/OM/Evaluator/Evaluator.h @@ -133,9 +133,17 @@ struct AttributeValue : EvaluatorValue { AttributeValue(Attribute attr) : AttributeValue(attr, mlir::UnknownLoc::get(attr.getContext())) {} AttributeValue(Attribute attr, Location loc) - : EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr) { + : EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr), + type(cast(attr).getType()) { markFullyEvaluated(); } + + // Constructors for partially evaluated AttributeValue. + AttributeValue(Type type) + : AttributeValue(mlir::UnknownLoc::get(type.getContext())) {} + AttributeValue(Type type, Location loc) + : EvaluatorValue(type.getContext(), Kind::Attr, loc), type(type) {} + Attribute getAttr() const { return attr; } template AttrTy getAs() const { @@ -145,13 +153,17 @@ struct AttributeValue : EvaluatorValue { return e->getKind() == Kind::Attr; } + // Set Attribute for partially evaluated case. + LogicalResult setAttr(Attribute attr); + // Finalize the value. - LogicalResult finalizeImpl() { return success(); } + LogicalResult finalizeImpl(); - Type getType() const { return attr.cast().getType(); } + Type getType() const { return type; } private: Attribute attr = {}; + Type type; }; // This perform finalization to `value`. @@ -452,6 +464,11 @@ struct Evaluator { FailureOr evaluateConstant(ConstantOp op, ActualParameters actualParams, Location loc); + + FailureOr + evaluateIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op, + ActualParameters actualParams, Location loc); + /// Instantiate an Object with its class name and actual parameters. FailureOr evaluateObjectInstance(StringAttr className, ActualParameters actualParams, diff --git a/lib/Dialect/OM/Evaluator/Evaluator.cpp b/lib/Dialect/OM/Evaluator/Evaluator.cpp index c8d8b7bbd1eb..61573c16c360 100644 --- a/lib/Dialect/OM/Evaluator/Evaluator.cpp +++ b/lib/Dialect/OM/Evaluator/Evaluator.cpp @@ -56,9 +56,7 @@ LogicalResult circt::om::evaluator::EvaluatorValue::finalize() { Type circt::om::evaluator::EvaluatorValue::getType() const { return llvm::TypeSwitch(this) - .Case([](auto *attr) -> Type { - return cast(attr->getAttr()).getType(); - }) + .Case([](auto *attr) -> Type { return attr->getType(); }) .Case([](auto *object) { return object->getObjectType(); }) .Case([](auto *list) { return list->getListType(); }) .Case([](auto *map) { return map->getMapType(); }) @@ -130,6 +128,14 @@ FailureOr circt::om::Evaluator::getOrCreateValue( .Case([&](ConstantOp op) { return evaluateConstant(op, actualParams, loc); }) + .Case([&](IntegerBinaryArithmeticOp op) { + // Create a partially evaluated AttributeValue of + // om::IntegerType in case we need to delay evaluation. + evaluator::EvaluatorValuePtr result = + std::make_shared( + op.getResult().getType(), loc); + return success(result); + }) .Case([&](auto op) { // Create a reference value since the value pointed by object // field op is not created yet. @@ -339,6 +345,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams, .Case([&](ConstantOp op) { return evaluateConstant(op, actualParams, loc); }) + .Case([&](IntegerBinaryArithmeticOp op) { + return evaluateIntegerBinaryArithmetic(op, actualParams, loc); + }) .Case([&](ObjectOp op) { return evaluateObjectInstance(op, actualParams); }) @@ -394,6 +403,75 @@ circt::om::Evaluator::evaluateConstant(ConstantOp op, op.getValue(), loc)); } +// Evaluator dispatch function for integer binary arithmetic. +FailureOr +circt::om::Evaluator::evaluateIntegerBinaryArithmetic( + IntegerBinaryArithmeticOp op, ActualParameters actualParams, Location loc) { + // Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet. + auto handle = getOrCreateValue(op.getResult(), actualParams, loc); + + // If it's fully evaluated, we can return it. + if (handle.value()->isFullyEvaluated()) + return handle; + + // Evaluate operands if necessary, and return the partially evaluated value if + // they aren't ready. + auto lhsResult = evaluateValue(op.getLhs(), actualParams, loc); + if (failed(lhsResult)) + return lhsResult; + if (!lhsResult.value()->isFullyEvaluated()) + return handle; + + auto rhsResult = evaluateValue(op.getRhs(), actualParams, loc); + if (failed(rhsResult)) + return rhsResult; + if (!rhsResult.value()->isFullyEvaluated()) + return handle; + + // Extract the integer attributes. + auto extractAttr = [](evaluator::EvaluatorValue *value) { + return std::move( + llvm::TypeSwitch(value) + .Case([](evaluator::AttributeValue *val) { + return val->getAs(); + }) + .Case([](evaluator::ReferenceValue *val) { + return cast( + val->getStrippedValue()->get()) + ->getAs(); + })); + }; + + om::IntegerAttr lhs = extractAttr(lhsResult.value().get()); + om::IntegerAttr rhs = extractAttr(rhsResult.value().get()); + assert(lhs && rhs && + "expected om::IntegerAttr for IntegerBinaryArithmeticOp operands"); + + // Perform arbitrary precision signed integer binary arithmetic. + FailureOr result = op.evaluateIntegerOperation( + lhs.getValue().getAPSInt(), rhs.getValue().getAPSInt()); + + if (failed(result)) + return op->emitError("failed to evaluate integer operation"); + + // Package the result as a new om::IntegerAttr. + MLIRContext *ctx = op->getContext(); + auto resultAttr = + om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value())); + + // Finalize the op result value. + auto *handleValue = cast(handle.value().get()); + auto resultStatus = handleValue->setAttr(resultAttr); + if (failed(resultStatus)) + return resultStatus; + + auto finalizeStatus = handleValue->finalize(); + if (failed(finalizeStatus)) + return finalizeStatus; + + return handle; +} + /// Evaluator dispatch function for Object instances. FailureOr circt::om::Evaluator::createParametersFromOperands( @@ -778,3 +856,27 @@ void evaluator::PathValue::setBasepath(const BasePathValue &basepath) { path = PathAttr::get(path.getContext(), newPath); markFullyEvaluated(); } + +//===----------------------------------------------------------------------===// +// AttributeValue +//===----------------------------------------------------------------------===// + +LogicalResult circt::om::evaluator::AttributeValue::setAttr(Attribute attr) { + if (cast(attr).getType() != this->type) + return mlir::emitError(getLoc(), "cannot set AttributeValue of type ") + << this->type << " to Attribute " << attr; + if (isFullyEvaluated()) + return mlir::emitError( + getLoc(), + "cannot set AttributeValue that has already been fully evaluated"); + this->attr = attr; + markFullyEvaluated(); + return success(); +} + +LogicalResult circt::om::evaluator::AttributeValue::finalizeImpl() { + if (!isFullyEvaluated()) + return mlir::emitError( + getLoc(), "cannot finalize AttributeValue that is not fully evaluated"); + return success(); +} diff --git a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp index 4559e8b471bc..640095dde765 100644 --- a/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp +++ b/unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp @@ -613,4 +613,140 @@ TEST(EvaluatorTests, InstantiateCycle) { ASSERT_TRUE(failed(result)); } +TEST(EvaluatorTests, IntegerBinaryArithmeticAdd) { + StringRef mod = "om.class @IntegerBinaryArithmeticAdd() {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " %1 = om.constant #om.integer<2 : si3> : !om.integer" + " %2 = om.integer.add %0, %1 : !om.integer" + " om.class.field @result, %2 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticAdd"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) { + StringRef mod = "om.class @Class1() {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @Class2() {" + " %0 = om.constant #om.integer<2 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @IntegerBinaryArithmeticObjects() {" + " %0 = om.object @Class1() : () -> !om.class.type<@Class1>" + " %1 = om.object.field %0, [@value] : " + "(!om.class.type<@Class1>) -> !om.integer" + "" + " %2 = om.object @Class2() : () -> !om.class.type<@Class2>" + " %3 = om.object.field %2, [@value] : " + "(!om.class.type<@Class2>) -> !om.integer" + "" + " %5 = om.integer.add %1, %3 : !om.integer" + " om.class.field @result, %5 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticObjects"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + +TEST(EvaluatorTests, IntegerBinaryArithmeticObjectsDelayed) { + StringRef mod = + "om.class @Class1(%input: !om.integer) {" + " %0 = om.constant #om.integer<1 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + " om.class.field @input, %input : !om.integer" + "}" + "" + "om.class @Class2() {" + " %0 = om.constant #om.integer<2 : si3> : !om.integer" + " om.class.field @value, %0 : !om.integer" + "}" + "" + "om.class @IntegerBinaryArithmeticObjectsDelayed() {" + " %0 = om.object @Class1(%5) : (!om.integer) -> !om.class.type<@Class1>" + " %1 = om.object.field %0, [@value] : " + "(!om.class.type<@Class1>) -> !om.integer" + "" + " %2 = om.object @Class2() : () -> !om.class.type<@Class2>" + " %3 = om.object.field %2, [@value] : " + "(!om.class.type<@Class2>) -> !om.integer" + "" + " %5 = om.integer.add %1, %3 : !om.integer" + " om.class.field @result, %5 : !om.integer" + "}"; + + DialectRegistry registry; + registry.insert(); + + MLIRContext context(registry); + context.getOrLoadDialect(); + + OwningOpRef owning = + parseSourceString(mod, ParserConfig(&context)); + + Evaluator evaluator(owning.release()); + + auto result = evaluator.instantiate( + StringAttr::get(&context, "IntegerBinaryArithmeticObjectsDelayed"), {}); + + ASSERT_TRUE(succeeded(result)); + + auto fieldValue = llvm::cast(result.value().get()) + ->getField("result") + .value(); + + ASSERT_EQ(3, llvm::cast(fieldValue.get()) + ->getAs() + .getValue() + .getValue()); +} + } // namespace