Skip to content

Commit

Permalink
[CombToArith] Fix coarsening of division by zero UB
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed May 4, 2024
1 parent 62fbb3a commit f443494
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
49 changes: 49 additions & 0 deletions integration_test/arcilator/JIT/div-by-zero.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit
// CHECK: divu = 0{{$}}
// CHECK: divs = 0{{$}}
// CHECK: modu = 0{{$}}
// CHECK: mods = 0{{$}}

hw.module @Baz(in %a: i8, in %b: i8, out divu: i8, out divs: i8, out modu: i8, out mods: i8) {
%zero = hw.constant 0 : i8
%is_zero = comb.icmp eq %zero, %b : i8

%0 = comb.divu %a, %b : i8
%1 = comb.mux %is_zero, %zero, %0 : i8

%2 = comb.divs %a, %b : i8
%3 = comb.mux %is_zero, %zero, %2 : i8

%4 = comb.modu %a, %b : i8
%5 = comb.mux %is_zero, %zero, %4 : i8

%6 = comb.mods %a, %b : i8
%7 = comb.mux %is_zero, %zero, %6 : i8

hw.output %1, %3, %5, %7 : i8, i8, i8, i8
}

func.func @main() {
%eight = arith.constant 8 : i8
%zero = arith.constant 0 : i8

arc.sim.instantiate @Baz as %model {
arc.sim.set_input %model, "a" = %eight : i8, !arc.sim.instance<@Baz>
arc.sim.set_input %model, "b" = %zero : i8, !arc.sim.instance<@Baz>

arc.sim.step %model : !arc.sim.instance<@Baz>

%divu = arc.sim.get_port %model, "divu" : i8, !arc.sim.instance<@Baz>
%divs = arc.sim.get_port %model, "divs" : i8, !arc.sim.instance<@Baz>
%modu = arc.sim.get_port %model, "modu" : i8, !arc.sim.instance<@Baz>
%mods = arc.sim.get_port %model, "mods" : i8, !arc.sim.instance<@Baz>

arc.sim.emit "divu", %divu : i8
arc.sim.emit "divs", %divs : i8
arc.sim.emit "modu", %modu : i8
arc.sim.emit "mods", %mods : i8
}

return
}
30 changes: 27 additions & 3 deletions lib/Conversion/CombToArith/CombToArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,30 @@ struct BinaryOpConversion : OpConversionPattern<SourceOp> {
}
};

/// Lowering for division operations that need to special-case zero-value
/// divisors to not run coarser UB than CIRCT defines.
template <typename SourceOp, typename TargetOp>
struct DivOpConversion : OpConversionPattern<SourceOp> {
using OpConversionPattern<SourceOp>::OpConversionPattern;
using OpAdaptor = typename SourceOp::Adaptor;

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value zero = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
Value one = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
Value isZero = rewriter.create<arith::CmpIOp>(loc, CmpIPredicate::eq,
adaptor.getRhs(), zero);
Value divisor =
rewriter.create<arith::SelectOp>(loc, isZero, one, adaptor.getRhs());
rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
return success();
}
};

/// Lower a comb::ReplicateOp operation to the LLVM dialect.
template <typename SourceOp, typename TargetOp>
struct VariadicOpConversion : OpConversionPattern<SourceOp> {
Expand Down Expand Up @@ -282,9 +306,9 @@ void circt::populateCombToArithConversionPatterns(
ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
LogicalShiftConversion<ShlOp, ShLIOp>,
LogicalShiftConversion<ShrUOp, ShRUIOp>,
BinaryOpConversion<SubOp, SubIOp>, BinaryOpConversion<DivSOp, DivSIOp>,
BinaryOpConversion<DivUOp, DivUIOp>, BinaryOpConversion<ModSOp, RemSIOp>,
BinaryOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
BinaryOpConversion<SubOp, SubIOp>, DivOpConversion<DivSOp, DivSIOp>,
DivOpConversion<DivUOp, DivUIOp>, DivOpConversion<ModSOp, RemSIOp>,
DivOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
VariadicOpConversion<XorOp, XOrIOp>>(converter, patterns.getContext());
Expand Down
16 changes: 12 additions & 4 deletions test/Conversion/CombToArith/comb-to-arith.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,21 @@ hw.module @test(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, in %
// CHECK-NEXT: %c42_i32 = arith.constant 42 : i32
%c42_i32 = hw.constant 42 : i32

// CHECK-NEXT: arith.divsi %arg0, %arg1 : i32
// CHECK: [[IS_ZERO:%.+]] = arith.cmpi eq, %arg1, %c0_i32{{.*}}
// CHECK-NEXT: [[DIVISOR:%.+]] = arith.select [[IS_ZERO]], %c1_i32{{.*}}, %arg1
// CHECK-NEXT: arith.divsi %arg0, [[DIVISOR]] : i32
%0 = comb.divs %arg0, %arg1 : i32
// CHECK-NEXT: arith.divui %arg0, %arg1 : i32
// CHECK: [[IS_ZERO:%.+]] = arith.cmpi eq, %arg1, %c0_i32{{.*}}
// CHECK-NEXT: [[DIVISOR:%.+]] = arith.select [[IS_ZERO]], %c1_i32{{.*}}, %arg1
// CHECK-NEXT: arith.divui %arg0, [[DIVISOR]] : i32
%1 = comb.divu %arg0, %arg1 : i32
// CHECK-NEXT: arith.remsi %arg0, %arg1 : i32
// CHECK: [[IS_ZERO:%.+]] = arith.cmpi eq, %arg1, %c0_i32{{.*}}
// CHECK-NEXT: [[DIVISOR:%.+]] = arith.select [[IS_ZERO]], %c1_i32{{.*}}, %arg1
// CHECK-NEXT: arith.remsi %arg0, [[DIVISOR]] : i32
%2 = comb.mods %arg0, %arg1 : i32
// CHECK-NEXT: arith.remui %arg0, %arg1 : i32
// CHECK: [[IS_ZERO:%.+]] = arith.cmpi eq, %arg1, %c0_i32{{.*}}
// CHECK-NEXT: [[DIVISOR:%.+]] = arith.select [[IS_ZERO]], %c1_i32{{.*}}, %arg1
// CHECK-NEXT: arith.remui %arg0, [[DIVISOR]] : i32
%3 = comb.modu %arg0, %arg1 : i32

// CHECK-NEXT: arith.subi %arg0, %arg1 : i32
Expand Down

0 comments on commit f443494

Please sign in to comment.