Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CombToSMT] Make result of div-by-zero undefined #7025

Merged
merged 1 commit into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions integration_test/circt-lec/comb.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) {
// TODO

// comb.divs
// TODO
// RUN: circt-lec %s -c1=divs_unsafe -c2=divs_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS_UNSAFE
// RUN: circt-lec %s -c1=divs -c2=divs --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVS
// COMB_DIVS_UNSAFE: c1 != c2
// COMB_DIVS: c1 == c2

hw.module @divs_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.divs %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @divs(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.divs %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.divu
// TODO
// RUN: circt-lec %s -c1=divu_unsafe -c2=divu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU_UNSAFE
// RUN: circt-lec %s -c1=divu -c2=divu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_DIVU
// COMB_DIVU_UNSAFE: c1 != c2
// COMB_DIVU: c1 == c2

hw.module @divu_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.divu %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @divu(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.divu %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.extract
// TODO
Expand All @@ -71,10 +103,42 @@ hw.module @decomposedAnd(in %in1: i1, in %in2: i1, out out: i1) {
// TODO

// comb.mods
// TODO
// RUN: circt-lec %s -c1=mods_unsafe -c2=mods_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS_UNSAFE
// RUN: circt-lec %s -c1=mods -c2=mods --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODS
// COMB_MODS_UNSAFE: c1 != c2
// COMB_MODS: c1 == c2

hw.module @mods_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.mods %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @mods(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.mods %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.modu
// TODO
// RUN: circt-lec %s -c1=modu_unsafe -c2=modu_unsafe --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU_UNSAFE
// RUN: circt-lec %s -c1=modu -c2=modu --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MODU
// COMB_MODU_UNSAFE: c1 != c2
// COMB_MODU: c1 == c2

hw.module @modu_unsafe(in %in1: i32, in %in2: i32, out out: i32) {
%0 = comb.modu %in1, %in2 : i32
hw.output %0 : i32
}

hw.module @modu(in %in1: i32, in %in2: i32, out out: i32) {
%0 = hw.constant 0 : i32
%1 = comb.icmp eq %in2, %0 : i32
%2 = comb.modu %in1, %in2 : i32
%3 = comb.mux %1, %0, %2 : i32
hw.output %3 : i32
}

// comb.mul
// RUN: circt-lec %s -c1=mulBy2 -c2=addTwice --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_MUL
Expand Down
36 changes: 32 additions & 4 deletions lib/Conversion/CombToSMT/CombToSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,34 @@ struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
}
};

/// Lower the SourceOp to the TargetOp special-casing if the second operand is
/// zero to return a new symbolic value.
template <typename SourceOp, typename TargetOp>
struct DivisionOpConversion : OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
if (!type)
return failure();

auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
op.getResult().getType());
Value zero =
rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
Value division =
rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
return success();
}
};

/// Converts an operation with a variadic number of operands to a chain of
/// binary operations assuming left-associativity of the operation.
template <typename SourceOp, typename TargetOp>
Expand Down Expand Up @@ -236,10 +264,10 @@ void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
OneToOneOpConversion<ShlOp, smt::BVShlOp>,
OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
OneToOneOpConversion<DivSOp, smt::BVSDivOp>,
OneToOneOpConversion<DivUOp, smt::BVUDivOp>,
OneToOneOpConversion<ModSOp, smt::BVSRemOp>,
OneToOneOpConversion<ModUOp, smt::BVURemOp>,
DivisionOpConversion<DivSOp, smt::BVSDivOp>,
DivisionOpConversion<DivUOp, smt::BVUDivOp>,
DivisionOpConversion<ModSOp, smt::BVSRemOp>,
DivisionOpConversion<ModUOp, smt::BVURemOp>,
VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
Expand Down
49 changes: 22 additions & 27 deletions lib/Tools/circt-lec/ConstructLEC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,37 +110,32 @@ void ConstructLECPass::runOnOperation() {

builder.createBlock(&entryFunc.getBody());

Value areEquivalent;
if (moduleA == moduleB) {
// Trivially equivalent
areEquivalent =
builder.create<LLVM::ConstantOp>(loc, builder.getI1Type(), 1);
moduleA->erase();
} else {
auto lecOp = builder.create<verif::LogicEquivalenceCheckingOp>(loc);
areEquivalent = lecOp.getAreEquivalent();
auto *outputOpA = moduleA.getBodyBlock()->getTerminator();
auto *outputOpB = moduleB.getBodyBlock()->getTerminator();
lecOp.getFirstCircuit().takeBody(moduleA.getBody());
lecOp.getSecondCircuit().takeBody(moduleB.getBody());

moduleA->erase();
auto lecOp = builder.create<verif::LogicEquivalenceCheckingOp>(loc);
Value areEquivalent = lecOp.getAreEquivalent();
builder.cloneRegionBefore(moduleA.getBody(), lecOp.getFirstCircuit(),
lecOp.getFirstCircuit().end());
builder.cloneRegionBefore(moduleB.getBody(), lecOp.getSecondCircuit(),
lecOp.getSecondCircuit().end());

moduleA->erase();
if (moduleA != moduleB)
moduleB->erase();

{
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(outputOpA);
builder.create<verif::YieldOp>(loc, outputOpA->getOperands());
outputOpA->erase();
builder.setInsertionPoint(outputOpB);
builder.create<verif::YieldOp>(loc, outputOpB->getOperands());
outputOpB->erase();
}

sortTopologically(&lecOp.getFirstCircuit().front());
sortTopologically(&lecOp.getSecondCircuit().front());
{
auto *term = lecOp.getFirstCircuit().front().getTerminator();
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPoint(term);
builder.create<verif::YieldOp>(loc, term->getOperands());
term->erase();
term = lecOp.getSecondCircuit().front().getTerminator();
builder.setInsertionPoint(term);
builder.create<verif::YieldOp>(loc, term->getOperands());
term->erase();
}

sortTopologically(&lecOp.getFirstCircuit().front());
sortTopologically(&lecOp.getSecondCircuit().front());

Comment on lines +113 to +138
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change just to ensure a module isn't trivially found to be equivalent to itself if its output is undefined?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes exactly.

// TODO: we should find a more elegant way of reporting the result than
// already inserting some LLVM here
Value eqFormatString =
Expand Down
24 changes: 20 additions & 4 deletions test/Conversion/CombToSMT/comb-to-smt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,29 @@ func.func @test(%a0: !smt.bv<32>, %a1: !smt.bv<32>, %a2: !smt.bv<32>, %a3: !smt.
%arg4 = builtin.unrealized_conversion_cast %a4 : !smt.bv<1> to i1
%arg5 = builtin.unrealized_conversion_cast %a5 : !smt.bv<4> to i4

// CHECK: smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.sdiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%0 = comb.divs %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.udiv [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%1 = comb.divu %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.srem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.srem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%2 = comb.mods %arg0, %arg1 : i32
// CHECK-NEXT: smt.bv.urem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: [[ZERO:%.+]] = smt.bv.constant #smt.bv<0> : !smt.bv<32>
// CHECK-NEXT: [[IS_ZERO:%.+]] = smt.eq [[A1]], [[ZERO]] : !smt.bv<32>
// CHECK-NEXT: [[UNDEF:%.+]] = smt.declare_fun : !smt.bv<32>
// CHECK-NEXT: [[DIV:%.+]] = smt.bv.urem [[A0]], [[A1]] : !smt.bv<32>
// CHECK-NEXT: smt.ite [[IS_ZERO]], [[UNDEF]], [[DIV]] : !smt.bv<32>
%3 = comb.modu %arg0, %arg1 : i32

// CHECK-NEXT: [[NEG:%.+]] = smt.bv.neg [[A1]] : !smt.bv<32>
Expand Down
Loading