Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OM] Support integer binary arithmetic in the Evaluator. #6711

Merged
merged 4 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions include/circt/Dialect/OM/Evaluator/Evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,17 @@ struct AttributeValue : EvaluatorValue {
AttributeValue(Attribute attr)
: AttributeValue(attr, mlir::UnknownLoc::get(attr.getContext())) {}
AttributeValue(Attribute attr, Location loc)
: EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr) {
: EvaluatorValue(attr.getContext(), Kind::Attr, loc), attr(attr),
type(cast<TypedAttr>(attr).getType()) {
markFullyEvaluated();
}

// Constructors for partially evaluated AttributeValue.
AttributeValue(Type type)
: AttributeValue(mlir::UnknownLoc::get(type.getContext())) {}
AttributeValue(Type type, Location loc)
: EvaluatorValue(type.getContext(), Kind::Attr, loc), type(type) {}

Attribute getAttr() const { return attr; }
template <typename AttrTy>
AttrTy getAs() const {
Expand All @@ -145,13 +153,17 @@ struct AttributeValue : EvaluatorValue {
return e->getKind() == Kind::Attr;
}

// Set Attribute for partially evaluated case.
LogicalResult setAttr(Attribute attr);

// Finalize the value.
LogicalResult finalizeImpl() { return success(); }
LogicalResult finalizeImpl();

Type getType() const { return attr.cast<TypedAttr>().getType(); }
Type getType() const { return type; }

private:
Attribute attr = {};
Type type;
};

// This perform finalization to `value`.
Expand Down Expand Up @@ -452,6 +464,11 @@ struct Evaluator {

FailureOr<EvaluatorValuePtr>
evaluateConstant(ConstantOp op, ActualParameters actualParams, Location loc);

FailureOr<EvaluatorValuePtr>
evaluateIntegerBinaryArithmetic(IntegerBinaryArithmeticOp op,
ActualParameters actualParams, Location loc);

/// Instantiate an Object with its class name and actual parameters.
FailureOr<EvaluatorValuePtr>
evaluateObjectInstance(StringAttr className, ActualParameters actualParams,
Expand Down
108 changes: 105 additions & 3 deletions lib/Dialect/OM/Evaluator/Evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ LogicalResult circt::om::evaluator::EvaluatorValue::finalize() {

Type circt::om::evaluator::EvaluatorValue::getType() const {
return llvm::TypeSwitch<const EvaluatorValue *, Type>(this)
.Case<AttributeValue>([](auto *attr) -> Type {
return cast<TypedAttr>(attr->getAttr()).getType();
})
.Case<AttributeValue>([](auto *attr) -> Type { return attr->getType(); })
.Case<ObjectValue>([](auto *object) { return object->getObjectType(); })
.Case<ListValue>([](auto *list) { return list->getListType(); })
.Case<MapValue>([](auto *map) { return map->getMapType(); })
Expand Down Expand Up @@ -130,6 +128,14 @@ FailureOr<evaluator::EvaluatorValuePtr> circt::om::Evaluator::getOrCreateValue(
.Case([&](ConstantOp op) {
return evaluateConstant(op, actualParams, loc);
})
.Case([&](IntegerBinaryArithmeticOp op) {
// Create a partially evaluated AttributeValue of
// om::IntegerType in case we need to delay evaluation.
evaluator::EvaluatorValuePtr result =
std::make_shared<evaluator::AttributeValue>(
op.getResult().getType(), loc);
return success(result);
})
.Case<ObjectFieldOp>([&](auto op) {
// Create a reference value since the value pointed by object
// field op is not created yet.
Expand Down Expand Up @@ -339,6 +345,9 @@ circt::om::Evaluator::evaluateValue(Value value, ActualParameters actualParams,
.Case([&](ConstantOp op) {
return evaluateConstant(op, actualParams, loc);
})
.Case([&](IntegerBinaryArithmeticOp op) {
return evaluateIntegerBinaryArithmetic(op, actualParams, loc);
})
.Case([&](ObjectOp op) {
return evaluateObjectInstance(op, actualParams);
})
Expand Down Expand Up @@ -394,6 +403,75 @@ circt::om::Evaluator::evaluateConstant(ConstantOp op,
op.getValue(), loc));
}

// Evaluator dispatch function for integer binary arithmetic.
FailureOr<EvaluatorValuePtr>
circt::om::Evaluator::evaluateIntegerBinaryArithmetic(
IntegerBinaryArithmeticOp op, ActualParameters actualParams, Location loc) {
// Get the op's EvaluatorValue handle, in case it hasn't been evaluated yet.
auto handle = getOrCreateValue(op.getResult(), actualParams, loc);

// If it's fully evaluated, we can return it.
if (handle.value()->isFullyEvaluated())
return handle;

// Evaluate operands if necessary, and return the partially evaluated value if
// they aren't ready.
auto lhsResult = evaluateValue(op.getLhs(), actualParams, loc);
if (failed(lhsResult))
return lhsResult;
if (!lhsResult.value()->isFullyEvaluated())
return handle;

auto rhsResult = evaluateValue(op.getRhs(), actualParams, loc);
if (failed(rhsResult))
return rhsResult;
if (!rhsResult.value()->isFullyEvaluated())
return handle;

// Extract the integer attributes.
auto extractAttr = [](evaluator::EvaluatorValue *value) {
return std::move(
llvm::TypeSwitch<evaluator::EvaluatorValue *, om::IntegerAttr>(value)
.Case([](evaluator::AttributeValue *val) {
return val->getAs<om::IntegerAttr>();
})
.Case([](evaluator::ReferenceValue *val) {
return cast<evaluator::AttributeValue>(
val->getStrippedValue()->get())
->getAs<om::IntegerAttr>();
}));
};

om::IntegerAttr lhs = extractAttr(lhsResult.value().get());
om::IntegerAttr rhs = extractAttr(rhsResult.value().get());
assert(lhs && rhs &&
"expected om::IntegerAttr for IntegerBinaryArithmeticOp operands");

// Perform arbitrary precision signed integer binary arithmetic.
FailureOr<APSInt> result = op.evaluateIntegerOperation(
lhs.getValue().getAPSInt(), rhs.getValue().getAPSInt());

if (failed(result))
return op->emitError("failed to evaluate integer operation");

// Package the result as a new om::IntegerAttr.
MLIRContext *ctx = op->getContext();
auto resultAttr =
om::IntegerAttr::get(ctx, mlir::IntegerAttr::get(ctx, result.value()));

// Finalize the op result value.
auto *handleValue = cast<evaluator::AttributeValue>(handle.value().get());
auto resultStatus = handleValue->setAttr(resultAttr);
if (failed(resultStatus))
return resultStatus;

auto finalizeStatus = handleValue->finalize();
if (failed(finalizeStatus))
return finalizeStatus;

return handle;
}

/// Evaluator dispatch function for Object instances.
FailureOr<circt::om::Evaluator::ActualParameters>
circt::om::Evaluator::createParametersFromOperands(
Expand Down Expand Up @@ -778,3 +856,27 @@ void evaluator::PathValue::setBasepath(const BasePathValue &basepath) {
path = PathAttr::get(path.getContext(), newPath);
markFullyEvaluated();
}

//===----------------------------------------------------------------------===//
// AttributeValue
//===----------------------------------------------------------------------===//

LogicalResult circt::om::evaluator::AttributeValue::setAttr(Attribute attr) {
if (cast<TypedAttr>(attr).getType() != this->type)
return mlir::emitError(getLoc(), "cannot set AttributeValue of type ")
<< this->type << " to Attribute " << attr;
if (isFullyEvaluated())
return mlir::emitError(
getLoc(),
"cannot set AttributeValue that has already been fully evaluated");
this->attr = attr;
markFullyEvaluated();
return success();
}

LogicalResult circt::om::evaluator::AttributeValue::finalizeImpl() {
if (!isFullyEvaluated())
return mlir::emitError(
getLoc(), "cannot finalize AttributeValue that is not fully evaluated");
return success();
}
136 changes: 136 additions & 0 deletions unittests/Dialect/OM/Evaluator/EvaluatorTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,4 +613,140 @@ TEST(EvaluatorTests, InstantiateCycle) {
ASSERT_TRUE(failed(result));
}

TEST(EvaluatorTests, IntegerBinaryArithmeticAdd) {
StringRef mod = "om.class @IntegerBinaryArithmeticAdd() {"
" %0 = om.constant #om.integer<1 : si3> : !om.integer"
" %1 = om.constant #om.integer<2 : si3> : !om.integer"
" %2 = om.integer.add %0, %1 : !om.integer"
" om.class.field @result, %2 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticAdd"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

ASSERT_EQ(3, llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

TEST(EvaluatorTests, IntegerBinaryArithmeticObjects) {
StringRef mod = "om.class @Class1() {"
" %0 = om.constant #om.integer<1 : si3> : !om.integer"
" om.class.field @value, %0 : !om.integer"
"}"
""
"om.class @Class2() {"
" %0 = om.constant #om.integer<2 : si3> : !om.integer"
" om.class.field @value, %0 : !om.integer"
"}"
""
"om.class @IntegerBinaryArithmeticObjects() {"
" %0 = om.object @Class1() : () -> !om.class.type<@Class1>"
" %1 = om.object.field %0, [@value] : "
"(!om.class.type<@Class1>) -> !om.integer"
""
" %2 = om.object @Class2() : () -> !om.class.type<@Class2>"
" %3 = om.object.field %2, [@value] : "
"(!om.class.type<@Class2>) -> !om.integer"
""
" %5 = om.integer.add %1, %3 : !om.integer"
" om.class.field @result, %5 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticObjects"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

ASSERT_EQ(3, llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

TEST(EvaluatorTests, IntegerBinaryArithmeticObjectsDelayed) {
StringRef mod =
"om.class @Class1(%input: !om.integer) {"
" %0 = om.constant #om.integer<1 : si3> : !om.integer"
" om.class.field @value, %0 : !om.integer"
" om.class.field @input, %input : !om.integer"
"}"
""
"om.class @Class2() {"
" %0 = om.constant #om.integer<2 : si3> : !om.integer"
" om.class.field @value, %0 : !om.integer"
"}"
""
"om.class @IntegerBinaryArithmeticObjectsDelayed() {"
" %0 = om.object @Class1(%5) : (!om.integer) -> !om.class.type<@Class1>"
" %1 = om.object.field %0, [@value] : "
"(!om.class.type<@Class1>) -> !om.integer"
""
" %2 = om.object @Class2() : () -> !om.class.type<@Class2>"
" %3 = om.object.field %2, [@value] : "
"(!om.class.type<@Class2>) -> !om.integer"
""
" %5 = om.integer.add %1, %3 : !om.integer"
" om.class.field @result, %5 : !om.integer"
"}";

DialectRegistry registry;
registry.insert<OMDialect>();

MLIRContext context(registry);
context.getOrLoadDialect<OMDialect>();

OwningOpRef<ModuleOp> owning =
parseSourceString<ModuleOp>(mod, ParserConfig(&context));

Evaluator evaluator(owning.release());

auto result = evaluator.instantiate(
StringAttr::get(&context, "IntegerBinaryArithmeticObjectsDelayed"), {});

ASSERT_TRUE(succeeded(result));

auto fieldValue = llvm::cast<evaluator::ObjectValue>(result.value().get())
->getField("result")
.value();

ASSERT_EQ(3, llvm::cast<evaluator::AttributeValue>(fieldValue.get())
->getAs<circt::om::IntegerAttr>()
.getValue()
.getValue());
}

} // namespace
Loading