Skip to content

Commit

Permalink
[FIRRTL] Canoncializations of not( cmp ) (#6957)
Browse files Browse the repository at this point in the history
Simple peephole optimization.  Triggers on real designs.
  • Loading branch information
darthscsi authored Apr 29, 2024
1 parent 81789e2 commit 8acaeb5
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 12 deletions.
82 changes: 82 additions & 0 deletions include/circt/Dialect/FIRRTL/FIRRTLCanonicalization.td
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,42 @@ def NotNot : Pat <
(MoveNameHint $old, (AsUIntPrimOp $x)),
[]>;

// not(eq(x,y)) -> neq(x,y)
def NotEq : Pat <
(NotPrimOp:$old (EQPrimOp $x, $y)),
(MoveNameHint $old, (NEQPrimOp $x, $y)),
[]>;

// not(neq(x,y)) -> eq(x,y)
def NotNeq : Pat <
(NotPrimOp:$old (NEQPrimOp $x, $y)),
(MoveNameHint $old, (EQPrimOp $x, $y)),
[]>;

// not(leq(x,y)) -> gt(x,y)
def NotLeq : Pat <
(NotPrimOp:$old (LEQPrimOp $x, $y)),
(MoveNameHint $old, (GTPrimOp $x, $y)),
[]>;

// not(lt(x,y)) -> geq(x,y)
def NotLt : Pat <
(NotPrimOp:$old (LTPrimOp $x, $y)),
(MoveNameHint $old, (GEQPrimOp $x, $y)),
[]>;

// not(geq(x,y)) -> lt(x,y)
def NotGeq : Pat <
(NotPrimOp:$old (GEQPrimOp $x, $y)),
(MoveNameHint $old, (LTPrimOp $x, $y)),
[]>;

// not(gt(x,y)) -> leq(x,y)
def NotGt : Pat <
(NotPrimOp:$old (GTPrimOp $x, $y)),
(MoveNameHint $old, (LEQPrimOp $x, $y)),
[]>;

// connect(a:t,b:t) -> strictconnect(a,b)
def ConnectSameType : Pat<
(ConnectOp $dst, $src),
Expand Down Expand Up @@ -314,6 +350,8 @@ def AndOfPad : Pat <
(MoveNameHint $old, (PadPrimOp (AndPrimOp (TailPrimOp $x, (TypeWidthAdjust32 $x, $y)), $y), $n)),
[(KnownWidth $x), (UIntType $x), (EqualTypes $x, $pad)]>;



def AndOfAsSIntL : Pat<
(AndPrimOp:$old (AsSIntPrimOp $x), $y),
(MoveNameHint $old, (AndPrimOp $x, (AsUIntPrimOp $y))),
Expand Down Expand Up @@ -349,6 +387,13 @@ def OrOfPad : Pat <
(OrPrimOp (TailPrimOp $x, (TypeWidthAdjust32 $x, $y)), $y))),
[(KnownWidth $x), (UIntType $x), (EqualTypes $x, $pad)]>;

/// or(a, orr(x)) -> orr(cat(a, x))
def OrOrr : Pat <
(OrPrimOp:$old $a, (OrRPrimOp $x)),
(MoveNameHint $old, (OrRPrimOp (CatPrimOp (AsUIntPrimOp $a), (AsUIntPrimOp $x)))),
// No type extension in the Or
[(IntTypeWidthLTX<2> $a)]>;

// xor(x, 0) -> x, fold can't handle all cases
def XorOfZero : Pat <
(XorPrimOp:$old $x, (ConstantOp:$zcst $_)),
Expand Down Expand Up @@ -398,6 +443,18 @@ def OrRPadU : Pat <
(MoveNameHint $old, (OrRPrimOp $x)),
[(KnownWidth $x), (UIntTypes $x), (IntTypeWidthGEQ32 $pad, $x)]>;

/// orr(cat(y, orr(z))) -> orr(cat(y,z))
def OrRCatOrR_left : Pat <
(OrRPrimOp:$old (CatPrimOp $y, (OrRPrimOp $z))),
(MoveNameHint $old, (OrRPrimOp (CatPrimOp $y, (AsUIntPrimOp $z)))),
[]>;

/// orr(cat(orr(z),y)) -> orr(cat(z,y))
def OrRCatOrR_right : Pat <
(OrRPrimOp:$old (CatPrimOp (OrRPrimOp $z), $y)),
(MoveNameHint $old, (OrRPrimOp (CatPrimOp (AsUIntPrimOp $z), $y))),
[]>;

/// xorr(asSInt(x)) -> xorr(x)
def XorRasSInt : Pat <
(XorRPrimOp:$old (AsSIntPrimOp $x)),
Expand Down Expand Up @@ -428,6 +485,19 @@ def XorRPadU : Pat <
(MoveNameHint $old, (XorRPrimOp $x)),
[(KnownWidth $x), (UIntTypes $x), (IntTypeWidthGEQ32 $pad, $x)]>;

/// xorr(cat(y, xorr(z))) -> xorr(cat(y,z))
def XorRCatXorR_left : Pat <
(XorRPrimOp:$old (CatPrimOp $y, (XorRPrimOp $z))),
(MoveNameHint $old, (XorRPrimOp (CatPrimOp $y, (AsUIntPrimOp $z)))),
[]>;

/// xorr(cat(xorr(z),y)) -> xorr(cat(z,y))
def XorRCatXorR_right : Pat <
(XorRPrimOp:$old (CatPrimOp (XorRPrimOp $z), $y)),
(MoveNameHint $old, (XorRPrimOp (CatPrimOp (AsUIntPrimOp $z), $y))),
[]>;


/// andr(asSInt(x)) -> andr(x)
def AndRasSInt : Pat <
(AndRPrimOp:$old (AsSIntPrimOp $x)),
Expand Down Expand Up @@ -476,6 +546,18 @@ def AndRPadS : Pat <
(MoveNameHint $old, (AndRPrimOp $x)),
[(KnownWidth $x), (SIntTypes $x)]>;

/// andr(cat(y, andr(z))) -> andr(cat(y,z))
def AndRCatAndR_left : Pat <
(AndRPrimOp:$old (CatPrimOp $y, (AndRPrimOp $z))),
(MoveNameHint $old, (AndRPrimOp (CatPrimOp $y, (AsUIntPrimOp $z)))),
[]>;

/// andr(cat(andr(z),y)) -> andr(cat(z,y))
def AndRCatAndR_right : Pat <
(AndRPrimOp:$old (CatPrimOp (AndRPrimOp $z), $y)),
(MoveNameHint $old, (AndRPrimOp (CatPrimOp (AsUIntPrimOp $z), $y))),
[]>;

// bits(bits(x, ...), ...) -> bits(x, ...)
def BitsOfBits : Pat<
(BitsPrimOp:$old (BitsPrimOp $x, I32Attr:$iHigh, I32Attr:$iLow), I32Attr:$oHigh, I32Attr:$oLow),
Expand Down
19 changes: 12 additions & 7 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -603,8 +603,8 @@ OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad>(
context);
patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
patterns::OrOrr>(context);
}

OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
Expand Down Expand Up @@ -879,7 +879,6 @@ LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
.getResult();
}
}

return {};
});
}
Expand Down Expand Up @@ -1079,7 +1078,9 @@ OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {

void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::NotNot>(context);
results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
patterns::NotGt>(context);
}

OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1106,7 +1107,8 @@ void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
results
.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
patterns::AndRCatZeroL, patterns::AndRCatZeroR,
patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
}

OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1131,7 +1133,8 @@ OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
patterns::OrRCatZeroH, patterns::OrRCatZeroL,
patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
}

OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
Expand All @@ -1155,7 +1158,9 @@ OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
patterns::XorRCatZeroH, patterns::XorRCatZeroL,
patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
context);
}

//===----------------------------------------------------------------------===//
Expand Down
90 changes: 85 additions & 5 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ firrtl.module @Xor(in %in: !firrtl.uint<4>,
// CHECK-LABEL: firrtl.module @Not
firrtl.module @Not(in %in: !firrtl.uint<4>,
in %sin: !firrtl.sint<4>,
out %out1: !firrtl.uint<1>,
out %outu: !firrtl.uint<4>,
out %outs: !firrtl.uint<4>) {
%0 = firrtl.not %in : (!firrtl.uint<4>) -> !firrtl.uint<4>
Expand All @@ -330,6 +331,44 @@ firrtl.module @Not(in %in: !firrtl.uint<4>,
// CHECK: firrtl.strictconnect %outu, %in
// CHECK: %[[cast:.*]] = firrtl.asUInt %sin
// CHECK: firrtl.strictconnect %outs, %[[cast]]

%c5_ui4 = firrtl.constant 5 : !firrtl.uint<4>
%5 = firrtl.eq %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%6 = firrtl.not %5 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %6 : !firrtl.uint<1>
// CHECK: firrtl.neq
// CHECK-NEXT: firrtl.strictconnect

%7 = firrtl.neq %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%8 = firrtl.not %7 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %8 : !firrtl.uint<1>
// CHECK: firrtl.eq
// CHECK-NEXT: firrtl.strictconnect

%9 = firrtl.lt %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%10 = firrtl.not %9 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %10 : !firrtl.uint<1>
// CHECK: firrtl.geq
// CHECK-NEXT: firrtl.strictconnect

%11 = firrtl.leq %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%12 = firrtl.not %11 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %12 : !firrtl.uint<1>
// CHECK: firrtl.gt
// CHECK-NEXT: firrtl.strictconnect

%13 = firrtl.gt %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%14 = firrtl.not %13 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %14 : !firrtl.uint<1>
// CHECK: firrtl.leq
// CHECK-NEXT: firrtl.strictconnect

%15 = firrtl.geq %in, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%16 = firrtl.not %15 : (!firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out1, %16 : !firrtl.uint<1>
// CHECK: firrtl.lt
// CHECK-NEXT: firrtl.strictconnect

}

// CHECK-LABEL: firrtl.module @EQ
Expand Down Expand Up @@ -365,6 +404,14 @@ firrtl.module @EQ(in %in1: !firrtl.uint<1>,
// CHECK: [[ORR:%.+]] = firrtl.orr %in4
// CHECK-NEXT: firrtl.not [[ORR]]
// CHECK-NEXT: firrtl.strictconnect

%c5_ui4 = firrtl.constant 5 : !firrtl.uint<4>
%5 = firrtl.neq %in4, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%6 = firrtl.eq %5, %c0_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out, %6 : !firrtl.uint<1>
// CHECK: firrtl.eq
// CHECK-NEXT: firrtl.strictconnect

}

// CHECK-LABEL: firrtl.module @NEQ
Expand Down Expand Up @@ -393,6 +440,13 @@ firrtl.module @NEQ(in %in1: !firrtl.uint<1>,
// CHECK: [[ANDR:%.+]] = firrtl.andr %in4
// CHECK-NEXT: firrtl.not [[ANDR]]
// CHECK-NEXT: firrtl.strictconnect

%c5_ui4 = firrtl.constant 5 : !firrtl.uint<4>
%5 = firrtl.eq %in4, %c5_ui4 : (!firrtl.uint<4>, !firrtl.uint<4>) -> !firrtl.uint<1>
%6 = firrtl.neq %5, %c1_ui1 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
firrtl.strictconnect %out, %6 : !firrtl.uint<1>
// CHECK: firrtl.neq
// CHECK-NEXT: firrtl.strictconnect
}

// CHECK-LABEL: firrtl.module @Cat
Expand Down Expand Up @@ -732,11 +786,12 @@ firrtl.module @Tail(in %in4u: !firrtl.uint<4>,

// CHECK-LABEL: firrtl.module @Andr
firrtl.module @Andr(in %in0 : !firrtl.uint<0>, in %in1 : !firrtl.sint<2>,
in %in2 : !firrtl.uint<2>,
out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>,
out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>,
out %e: !firrtl.uint<1>, out %f: !firrtl.uint<1>,
out %g: !firrtl.uint<1>, in %h : !firrtl.uint<64>,
out %i: !firrtl.uint<1>) {
out %i: !firrtl.uint<1>, out %j: !firrtl.uint<1>) {
%invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
%c3_ui2 = firrtl.constant 3 : !firrtl.uint<2>
Expand Down Expand Up @@ -783,14 +838,22 @@ firrtl.module @Andr(in %in0 : !firrtl.uint<0>, in %in1 : !firrtl.sint<2>,
%11 = firrtl.pad %in1, 3 : (!firrtl.sint<2>) -> !firrtl.sint<3>
%12 = firrtl.andr %11 : (!firrtl.sint<3>) -> !firrtl.uint<1>
firrtl.strictconnect %i, %12 : !firrtl.uint<1>

// CHECK: %[[cat:.*]] = firrtl.cat %in2, %h
// CHECK-NEXT: firrtl.andr %[[cat]]
%13 = firrtl.andr %in2 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%14 = firrtl.cat %13, %h : (!firrtl.uint<1>, !firrtl.uint<64>) -> !firrtl.uint<65>
%15 = firrtl.andr %14 : (!firrtl.uint<65>) -> !firrtl.uint<1>
firrtl.strictconnect %j, %15 : !firrtl.uint<1>
}

// CHECK-LABEL: firrtl.module @Orr
firrtl.module @Orr(in %in0 : !firrtl.uint<0>,
firrtl.module @Orr(in %in0 : !firrtl.uint<0>, in %in2 : !firrtl.uint<2>,
out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>,
out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>,
out %e: !firrtl.uint<1>, out %f: !firrtl.uint<1>,
out %g: !firrtl.uint<1>, in %h : !firrtl.uint<64>) {
out %g: !firrtl.uint<1>, in %h : !firrtl.uint<64>,
out %j : !firrtl.uint<1>) {
%invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2>
%c0_ui2 = firrtl.constant 0 : !firrtl.uint<2>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
Expand Down Expand Up @@ -823,14 +886,23 @@ firrtl.module @Orr(in %in0 : !firrtl.uint<0>,
%9 = firrtl.cvt %8 : (!firrtl.uint<68>) -> !firrtl.sint<69>
%10 = firrtl.orr %9 : (!firrtl.sint<69>) -> !firrtl.uint<1>
firrtl.strictconnect %g, %10 : !firrtl.uint<1>

// CHECK: %[[cat:.*]] = firrtl.cat %in2, %h
// CHECK-NEXT: firrtl.orr %[[cat]]
%13 = firrtl.orr %in2 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%14 = firrtl.cat %13, %h : (!firrtl.uint<1>, !firrtl.uint<64>) -> !firrtl.uint<65>
%15 = firrtl.orr %14 : (!firrtl.uint<65>) -> !firrtl.uint<1>
firrtl.strictconnect %j, %15 : !firrtl.uint<1>

}

// CHECK-LABEL: firrtl.module @Xorr
firrtl.module @Xorr(in %in0 : !firrtl.uint<0>,
firrtl.module @Xorr(in %in0 : !firrtl.uint<0>, in %in2 : !firrtl.uint<2>,
out %a: !firrtl.uint<1>, out %b: !firrtl.uint<1>,
out %c: !firrtl.uint<1>, out %d: !firrtl.uint<1>,
out %e: !firrtl.uint<1>, out %f: !firrtl.uint<1>,
out %g: !firrtl.uint<1>, in %h : !firrtl.uint<64>) {
out %g: !firrtl.uint<1>, in %h : !firrtl.uint<64>,
out %j : !firrtl.uint<1>) {
%invalid_ui2 = firrtl.invalidvalue : !firrtl.uint<2>
%c3_ui2 = firrtl.constant 3 : !firrtl.uint<2>
%c2_ui2 = firrtl.constant 2 : !firrtl.uint<2>
Expand Down Expand Up @@ -864,6 +936,14 @@ firrtl.module @Xorr(in %in0 : !firrtl.uint<0>,
%9 = firrtl.cvt %8 : (!firrtl.uint<68>) -> !firrtl.sint<69>
%10 = firrtl.xorr %9 : (!firrtl.sint<69>) -> !firrtl.uint<1>
firrtl.strictconnect %g, %10 : !firrtl.uint<1>

// CHECK: %[[cat:.*]] = firrtl.cat %in2, %h
// CHECK-NEXT: firrtl.xorr %[[cat]]
%13 = firrtl.xorr %in2 : (!firrtl.uint<2>) -> !firrtl.uint<1>
%14 = firrtl.cat %13, %h : (!firrtl.uint<1>, !firrtl.uint<64>) -> !firrtl.uint<65>
%15 = firrtl.xorr %14 : (!firrtl.uint<65>) -> !firrtl.uint<1>
firrtl.strictconnect %j, %15 : !firrtl.uint<1>

}

// CHECK-LABEL: firrtl.module @Reduce
Expand Down

0 comments on commit 8acaeb5

Please sign in to comment.