Skip to content

Commit

Permalink
[flang] Handle special case for SHIFTA intrinsic
Browse files Browse the repository at this point in the history
This patch updates the lowering of the SHIFTA intrinsic to match
the behavior of gfortran. When the SHIFT value is equal to the
integer bit width, it is now handled as a special case.
This is needed because the operation used in the lowering
(`mlir::arith::ShRSIOp`) lowers to `ashr`, which is undefined when the
shift amount is equal to the bit width of the operand.

Before this patch we have the following results:

```
SHIFTA(  -1, 8) =  0
SHIFTA(  -2, 8) =  0
SHIFTA( -30, 8) =  0
SHIFTA( -31, 8) =  0
SHIFTA( -32, 8) =  0
SHIFTA( -33, 8) =  0
SHIFTA(-126, 8) =  0
SHIFTA(-127, 8) =  0
SHIFTA(-128, 8) =  0
```

While gfortran is giving this:

```
SHIFTA(  -1, 8) = -1
SHIFTA(  -2, 8) = -1
SHIFTA( -30, 8) = -1
SHIFTA( -31, 8) = -1
SHIFTA( -32, 8) = -1
SHIFTA( -33, 8) = -1
SHIFTA(-126, 8) = -1
SHIFTA(-127, 8) = -1
SHIFTA(-128, 8) = -1
```

With this patch flang and gfortran have the same behavior.

Reviewed By: jeanPerier

Differential Revision: https://reviews.llvm.org/D133104
  • Loading branch information
clementval committed Sep 1, 2022
1 parent 4f046bc commit 83ba00b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 32 deletions.
30 changes: 28 additions & 2 deletions flang/lib/Lower/IntrinsicCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,7 @@ struct IntrinsicLibrary {
llvm::ArrayRef<mlir::Value> args);
template <typename Shift>
mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef<mlir::Value>);
mlir::Value genSign(mlir::Type, llvm::ArrayRef<mlir::Value>);
fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
mlir::Value genSpacing(mlir::Type resultType,
Expand Down Expand Up @@ -958,7 +959,7 @@ static constexpr IntrinsicHandler handlers[]{
{"radix", asAddr, handleDynamicOptional}}},
/*isElemental=*/false},
{"set_exponent", &I::genSetExponent},
{"shifta", &I::genShift<mlir::arith::ShRSIOp>},
{"shifta", &I::genShiftA},
{"shiftl", &I::genShift<mlir::arith::ShLIOp>},
{"shiftr", &I::genShift<mlir::arith::ShRUIOp>},
{"sign", &I::genSign},
Expand Down Expand Up @@ -4015,7 +4016,7 @@ mlir::Value IntrinsicLibrary::genSetExponent(mlir::Type resultType,
fir::getBase(args[1])));
}

// SHIFTA, SHIFTL, SHIFTR
// SHIFTL, SHIFTR
template <typename Shift>
mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
Expand All @@ -4041,6 +4042,31 @@ mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType,
return builder.create<mlir::arith::SelectOp>(loc, outOfBounds, zero, shifted);
}

// SHIFTA
// Lowers SHIFTA(I, SHIFT) to an arithmetic right shift of args[0] by args[1]
// (converted to the result type), with a dedicated path for the case where
// SHIFT is equal to the bit width of the result type.
//
// NOTE: the operations are created in the same order that the FileCheck
// patterns in flang/test/Lower/Intrinsics/shifta.f90 expect; keep that order
// when modifying this function.
mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType,
                                        llvm::ArrayRef<mlir::Value> args) {
  const unsigned bitWidth = resultType.getIntOrFloatBitWidth();
  mlir::Value widthCst =
      builder.createIntegerConstant(loc, resultType, bitWidth);
  mlir::Value shiftAmount = builder.createConvert(loc, resultType, args[1]);
  mlir::Value isFullWidthShift = builder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::eq, shiftAmount, widthCst);

  // mlir::arith::ShRSIOp is lowered to `ashr`, and `ashr` is undefined when
  // the shift amount is equal to the element size. A shift by exactly the bit
  // width is therefore computed separately: every bit of the result is a copy
  // of the sign bit, i.e. -1 for a negative value and 0 otherwise.
  mlir::Value zeroCst = builder.createIntegerConstant(loc, resultType, 0);
  mlir::Value allOnes = builder.createIntegerConstant(loc, resultType, -1);
  mlir::Value isNegative = builder.create<mlir::arith::CmpIOp>(
      loc, mlir::arith::CmpIPredicate::slt, args[0], zeroCst);
  mlir::Value fullWidthResult = builder.create<mlir::arith::SelectOp>(
      loc, isNegative, allOnes, zeroCst);

  // Regular arithmetic shift; its value is only used when SHIFT differs from
  // the bit width.
  mlir::Value regularResult =
      builder.create<mlir::arith::ShRSIOp>(loc, args[0], shiftAmount);
  return builder.create<mlir::arith::SelectOp>(loc, isFullWidthShift,
                                               fullWidthResult, regularResult);
}

// SIGN
mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType,
llvm::ArrayRef<mlir::Value> args) {
Expand Down
65 changes: 35 additions & 30 deletions flang/test/Lower/Intrinsics/shifta.f90
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ subroutine shifta1_test(a, b, c)
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
c = shifta(a, b)
! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8
! CHECK: %[[C_0:.*]] = arith.constant 0 : i8
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8
! CHECK: %[[C0:.*]] = arith.constant 0 : i8
! CHECK: %[[CM1:.*]] = arith.constant -1 : i8
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8
end subroutine shifta1_test

! CHECK-LABEL: shifta2_test
Expand All @@ -32,13 +33,14 @@ subroutine shifta2_test(a, b, c)
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
c = shifta(a, b)
! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16
! CHECK: %[[C_0:.*]] = arith.constant 0 : i16
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16
! CHECK: %[[C0:.*]] = arith.constant 0 : i16
! CHECK: %[[CM1:.*]] = arith.constant -1 : i16
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16
end subroutine shifta2_test

! CHECK-LABEL: shifta4_test
Expand All @@ -52,12 +54,13 @@ subroutine shifta4_test(a, b, c)
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
c = shifta(a, b)
! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32
! CHECK: %[[C_0:.*]] = arith.constant 0 : i32
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32
! CHECK: %[[C0:.*]] = arith.constant 0 : i32
! CHECK: %[[CM1:.*]] = arith.constant -1 : i32
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32
end subroutine shifta4_test

! CHECK-LABEL: shifta8_test
Expand All @@ -71,13 +74,14 @@ subroutine shifta8_test(a, b, c)
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
c = shifta(a, b)
! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64
! CHECK: %[[C_0:.*]] = arith.constant 0 : i64
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64
! CHECK: %[[C0:.*]] = arith.constant 0 : i64
! CHECK: %[[CM1:.*]] = arith.constant -1 : i64
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64
end subroutine shifta8_test

! CHECK-LABEL: shifta16_test
Expand All @@ -91,11 +95,12 @@ subroutine shifta16_test(a, b, c)
! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref<i32>
c = shifta(a, b)
! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128
! CHECK: %[[C_0:.*]] = arith.constant 0 : i128
! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128
! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128
! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128
! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1
! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128
! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128
! CHECK: %[[C0:.*]] = arith.constant 0 : i128
! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128
! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128
! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128
! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128
! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128
end subroutine shifta16_test

0 comments on commit 83ba00b

Please sign in to comment.