From 83ba00bd2eb7ac13fe18edecb306c6c9e14d4d67 Mon Sep 17 00:00:00 2001 From: Valentin Clement Date: Thu, 1 Sep 2022 16:27:51 +0200 Subject: [PATCH] [flang] Handle special case for SHIFTA intrinsic This patch update the lowering of the shifta intrinsic to match the behvior of gfortran. When the SHIFT value is equal to the integer bitwidth then we handle it differently. This is due to the operation used in lowering (`mlir::arith::ShRSIOp`) that lowers to `ashr`. Before this patch we have the following results: ``` SHIFTA( -1, 8) = 0 SHIFTA( -2, 8) = 0 SHIFTA( -30, 8) = 0 SHIFTA( -31, 8) = 0 SHIFTA( -32, 8) = 0 SHIFTA( -33, 8) = 0 SHIFTA(-126, 8) = 0 SHIFTA(-127, 8) = 0 SHIFTA(-128, 8) = 0 ``` While gfortran is giving this: ``` SHIFTA( -1, 8) = -1 SHIFTA( -2, 8) = -1 SHIFTA( -30, 8) = -1 SHIFTA( -31, 8) = -1 SHIFTA( -32, 8) = -1 SHIFTA( -33, 8) = -1 SHIFTA(-126, 8) = -1 SHIFTA(-127, 8) = -1 SHIFTA(-128, 8) = -1 ``` With this patch flang and gfortran have the same behavior. Reviewed By: jeanPerier Differential Revision: https://reviews.llvm.org/D133104 --- flang/lib/Lower/IntrinsicCall.cpp | 30 +++++++++++- flang/test/Lower/Intrinsics/shifta.f90 | 65 ++++++++++++++------------ 2 files changed, 63 insertions(+), 32 deletions(-) diff --git a/flang/lib/Lower/IntrinsicCall.cpp b/flang/lib/Lower/IntrinsicCall.cpp index c5a11219852574..545c2d96a28f2d 100644 --- a/flang/lib/Lower/IntrinsicCall.cpp +++ b/flang/lib/Lower/IntrinsicCall.cpp @@ -557,6 +557,7 @@ struct IntrinsicLibrary { llvm::ArrayRef args); template mlir::Value genShift(mlir::Type resultType, llvm::ArrayRef); + mlir::Value genShiftA(mlir::Type resultType, llvm::ArrayRef); mlir::Value genSign(mlir::Type, llvm::ArrayRef); fir::ExtendedValue genSize(mlir::Type, llvm::ArrayRef); mlir::Value genSpacing(mlir::Type resultType, @@ -958,7 +959,7 @@ static constexpr IntrinsicHandler handlers[]{ {"radix", asAddr, handleDynamicOptional}}}, /*isElemental=*/false}, {"set_exponent", &I::genSetExponent}, - {"shifta", &I::genShift}, + {"shifta", &I::genShiftA}, {"shiftl", &I::genShift}, {"shiftr", &I::genShift}, {"sign", &I::genSign}, @@ -4015,7 +4016,7 @@ mlir::Value IntrinsicLibrary::genSetExponent(mlir::Type resultType, fir::getBase(args[1]))); } -// SHIFTA, SHIFTL, SHIFTR +// SHIFTL, SHIFTR template mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType, llvm::ArrayRef args) { @@ -4041,6 +4042,31 @@ mlir::Value IntrinsicLibrary::genShift(mlir::Type resultType, return builder.create(loc, outOfBounds, zero, shifted); } +// SHIFTA +mlir::Value IntrinsicLibrary::genShiftA(mlir::Type resultType, + llvm::ArrayRef args) { + unsigned bits = resultType.getIntOrFloatBitWidth(); + mlir::Value bitSize = builder.createIntegerConstant(loc, resultType, bits); + mlir::Value shift = builder.createConvert(loc, resultType, args[1]); + mlir::Value shiftEqBitSize = builder.create( + loc, mlir::arith::CmpIPredicate::eq, shift, bitSize); + + // Lowering of mlir::arith::ShRSIOp is using `ashr`. `ashr` is undefined when + // the shift amount is equal to the element size. + // So if SHIFT is equal to the bit width then it is handled as a special case. + mlir::Value zero = builder.createIntegerConstant(loc, resultType, 0); + mlir::Value minusOne = builder.createIntegerConstant(loc, resultType, -1); + mlir::Value valueIsNeg = builder.create( + loc, mlir::arith::CmpIPredicate::slt, args[0], zero); + mlir::Value specialRes = + builder.create(loc, valueIsNeg, minusOne, zero); + + mlir::Value shifted = + builder.create(loc, args[0], shift); + return builder.create(loc, shiftEqBitSize, specialRes, + shifted); +} + // SIGN mlir::Value IntrinsicLibrary::genSign(mlir::Type resultType, llvm::ArrayRef args) { diff --git a/flang/test/Lower/Intrinsics/shifta.f90 b/flang/test/Lower/Intrinsics/shifta.f90 index 311206b213d22c..92fda39e84bf87 100644 --- a/flang/test/Lower/Intrinsics/shifta.f90 +++ b/flang/test/Lower/Intrinsics/shifta.f90 @@ -12,13 +12,14 @@ subroutine shifta1_test(a, b, c) ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref c = shifta(a, b) ! CHECK: %[[C_BITS:.*]] = arith.constant 8 : i8 - ! CHECK: %[[C_0:.*]] = arith.constant 0 : i8 ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i8 - ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i8 - ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i8 - ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1 - ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8 - ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i8 + ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i8 + ! CHECK: %[[C0:.*]] = arith.constant 0 : i8 + ! CHECK: %[[CM1:.*]] = arith.constant -1 : i8 + ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i8 + ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i8 + ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i8 + ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i8 end subroutine shifta1_test ! CHECK-LABEL: shifta2_test @@ -32,13 +33,14 @@ subroutine shifta2_test(a, b, c) ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref c = shifta(a, b) ! CHECK: %[[C_BITS:.*]] = arith.constant 16 : i16 - ! CHECK: %[[C_0:.*]] = arith.constant 0 : i16 ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i16 - ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i16 - ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i16 - ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1 - ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16 - ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i16 + ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i16 + ! CHECK: %[[C0:.*]] = arith.constant 0 : i16 + ! CHECK: %[[CM1:.*]] = arith.constant -1 : i16 + ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i16 + ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i16 + ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i16 + ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i16 end subroutine shifta2_test ! CHECK-LABEL: shifta4_test @@ -52,12 +54,13 @@ subroutine shifta4_test(a, b, c) ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref c = shifta(a, b) ! CHECK: %[[C_BITS:.*]] = arith.constant 32 : i32 - ! CHECK: %[[C_0:.*]] = arith.constant 0 : i32 - ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_VAL]], %[[C_0]] : i32 - ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_VAL]], %[[C_BITS]] : i32 - ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1 - ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32 - ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i32 + ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_VAL]], %[[C_BITS]] : i32 + ! CHECK: %[[C0:.*]] = arith.constant 0 : i32 + ! CHECK: %[[CM1:.*]] = arith.constant -1 : i32 + ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i32 + ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i32 + ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_VAL]] : i32 + ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i32 end subroutine shifta4_test ! CHECK-LABEL: shifta8_test @@ -71,13 +74,14 @@ subroutine shifta8_test(a, b, c) ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref c = shifta(a, b) ! CHECK: %[[C_BITS:.*]] = arith.constant 64 : i64 - ! CHECK: %[[C_0:.*]] = arith.constant 0 : i64 ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i64 - ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i64 - ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i64 - ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1 - ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64 - ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i64 + ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i64 + ! CHECK: %[[C0:.*]] = arith.constant 0 : i64 + ! CHECK: %[[CM1:.*]] = arith.constant -1 : i64 + ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i64 + ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i64 + ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i64 + ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i64 end subroutine shifta8_test ! CHECK-LABEL: shifta16_test @@ -91,11 +95,12 @@ subroutine shifta16_test(a, b, c) ! CHECK: %[[B_VAL:.*]] = fir.load %[[B]] : !fir.ref c = shifta(a, b) ! CHECK: %[[C_BITS:.*]] = arith.constant 128 : i128 - ! CHECK: %[[C_0:.*]] = arith.constant 0 : i128 ! CHECK: %[[B_CONV:.*]] = fir.convert %[[B_VAL]] : (i32) -> i128 - ! CHECK: %[[UNDER:.*]] = arith.cmpi slt, %[[B_CONV]], %[[C_0]] : i128 - ! CHECK: %[[OVER:.*]] = arith.cmpi sge, %[[B_CONV]], %[[C_BITS]] : i128 - ! CHECK: %[[INVALID:.*]] = arith.ori %[[UNDER]], %[[OVER]] : i1 - ! CHECK: %[[SHIFT:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128 - ! CHECK: %[[RES:.*]] = arith.select %[[INVALID]], %[[C_0]], %[[SHIFT]] : i128 + ! CHECK: %[[SHIFT_IS_BITWIDTH:.*]] = arith.cmpi eq, %[[B_CONV]], %[[C_BITS]] : i128 + ! CHECK: %[[C0:.*]] = arith.constant 0 : i128 + ! CHECK: %[[CM1:.*]] = arith.constant {{.*}} : i128 + ! CHECK: %[[IS_NEG:.*]] = arith.cmpi slt, %[[A_VAL]], %[[C0]] : i128 + ! CHECK: %[[RES:.*]] = arith.select %[[IS_NEG]], %[[CM1]], %[[C0]] : i128 + ! CHECK: %[[SHIFTED:.*]] = arith.shrsi %[[A_VAL]], %[[B_CONV]] : i128 + ! CHECK: %{{.*}} = arith.select %[[SHIFT_IS_BITWIDTH]], %[[RES]], %[[SHIFTED]] : i128 end subroutine shifta16_test