Skip to content

Commit

Permalink
Handle review comments.
Browse files Browse the repository at this point in the history
Renamed `getTypeSizeAndAlignment` to `getTypeSizeAndAlignmentOrCrash`.
It makes it clear to user that function terminates the compilation if it
can't handle the given type.

The `getTypeSizeAndAlignment` now returns a pair wrapped into an
std::optional which can be used to detect error.
  • Loading branch information
abidh committed Jul 24, 2024
1 parent d3bbbc6 commit 9a3d9f9
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 37 deletions.
14 changes: 11 additions & 3 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -487,10 +487,18 @@ std::string getTypeAsString(mlir::Type ty, const KindMapping &kindMap,
/// target dependent type size inquiries in lowering. It would also not be
/// straightforward given the need for a kind map that would need to be
/// converted in terms of mlir::DataLayoutEntryKey.
std::pair<std::uint64_t, unsigned short> getTypeSizeAndAlignment(
mlir::Location loc, mlir::Type ty, const mlir::DataLayout &dl,
const fir::KindMapping &kindMap, bool *success = nullptr);

/// This variant terminates the compilation if an unsupported type is passed.
std::pair<std::uint64_t, unsigned short>
getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap);

/// This variant returns std::nullopt if an unsupported type is passed.
std::optional<std::pair<uint64_t, unsigned short>>
getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap);
} // namespace fir

#endif // FORTRAN_OPTIMIZER_DIALECT_FIRTYPE_H
10 changes: 5 additions & 5 deletions flang/lib/Optimizer/CodeGen/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
return byteOffset;
}
mlir::Type compType = component.second;
auto [compSize, compAlign] =
fir::getTypeSizeAndAlignment(loc, compType, getDataLayout(), kindMap);
auto [compSize, compAlign] = fir::getTypeSizeAndAlignmentOrCrash(
loc, compType, getDataLayout(), kindMap);
byteOffset = llvm::alignTo(byteOffset, compAlign);
ArgClass LoComp, HiComp;
classify(loc, compType, byteOffset, LoComp, HiComp);
Expand All @@ -452,8 +452,8 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
ArgClass &Hi) const {
mlir::Type eleTy = seqTy.getEleTy();
const std::uint64_t arraySize = seqTy.getConstantArraySize();
auto [eleSize, eleAlign] =
fir::getTypeSizeAndAlignment(loc, eleTy, getDataLayout(), kindMap);
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
loc, eleTy, getDataLayout(), kindMap);
std::uint64_t eleStorageSize = llvm::alignTo(eleSize, eleAlign);
for (std::uint64_t i = 0; i < arraySize; ++i) {
byteOffset = llvm::alignTo(byteOffset, eleAlign);
Expand Down Expand Up @@ -641,7 +641,7 @@ struct TargetX86_64 : public GenericTarget<TargetX86_64> {
mlir::Type ty) const {
CodeGenSpecifics::Marshalling marshal;
auto sizeAndAlign =
fir::getTypeSizeAndAlignment(loc, ty, getDataLayout(), kindMap);
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
// The stack is always 8 byte aligned (note 14 in 3.2.3).
unsigned short align =
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
Expand Down
63 changes: 40 additions & 23 deletions flang/lib/Optimizer/Dialect/FIRType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1393,61 +1393,78 @@ void FIROpsDialect::registerTypes() {
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
}

std::pair<std::uint64_t, unsigned short>
std::optional<std::pair<uint64_t, unsigned short>>
fir::getTypeSizeAndAlignment(mlir::Location loc, mlir::Type ty,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap, bool *success) {
const fir::KindMapping &kindMap) {
if (mlir::isa<mlir::IntegerType, mlir::FloatType, mlir::ComplexType>(ty)) {
llvm::TypeSize size = dl.getTypeSize(ty);
unsigned short alignment = dl.getTypeABIAlignment(ty);
return {size, alignment};
return std::pair{size, alignment};
}
if (auto firCmplx = mlir::dyn_cast<fir::ComplexType>(ty)) {
auto [floatSize, floatAlign] = getTypeSizeAndAlignment(
loc, firCmplx.getEleType(kindMap), dl, kindMap, success);
return {llvm::alignTo(floatSize, floatAlign) + floatSize, floatAlign};
auto result =
getTypeSizeAndAlignment(loc, firCmplx.getEleType(kindMap), dl, kindMap);
if (!result)
return result;
auto [floatSize, floatAlign] = *result;
return std::pair{llvm::alignTo(floatSize, floatAlign) + floatSize,
floatAlign};
}
if (auto real = mlir::dyn_cast<fir::RealType>(ty))
return getTypeSizeAndAlignment(loc, real.getFloatType(kindMap), dl, kindMap,
success);
return getTypeSizeAndAlignment(loc, real.getFloatType(kindMap), dl,
kindMap);

if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) {
auto [eleSize, eleAlign] =
getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap, success);
auto result = getTypeSizeAndAlignment(loc, seqTy.getEleTy(), dl, kindMap);
if (!result)
return result;
auto [eleSize, eleAlign] = *result;
std::uint64_t size =
llvm::alignTo(eleSize, eleAlign) * seqTy.getConstantArraySize();
return {size, eleAlign};
return std::pair{size, eleAlign};
}
if (auto recTy = mlir::dyn_cast<fir::RecordType>(ty)) {
std::uint64_t size = 0;
unsigned short align = 1;
for (auto component : recTy.getTypeList()) {
auto [compSize, compAlign] =
getTypeSizeAndAlignment(loc, component.second, dl, kindMap, success);
auto result = getTypeSizeAndAlignment(loc, component.second, dl, kindMap);
if (!result)
return result;
auto [compSize, compAlign] = *result;
size =
llvm::alignTo(size, compAlign) + llvm::alignTo(compSize, compAlign);
align = std::max(align, compAlign);
}
return {size, align};
return std::pair{size, align};
}
if (auto logical = mlir::dyn_cast<fir::LogicalType>(ty)) {
mlir::Type intTy = mlir::IntegerType::get(
logical.getContext(), kindMap.getLogicalBitsize(logical.getFKind()));
return getTypeSizeAndAlignment(loc, intTy, dl, kindMap, success);
return getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
}
if (auto character = mlir::dyn_cast<fir::CharacterType>(ty)) {
mlir::Type intTy = mlir::IntegerType::get(
character.getContext(),
kindMap.getCharacterBitsize(character.getFKind()));
auto [compSize, compAlign] =
getTypeSizeAndAlignment(loc, intTy, dl, kindMap, success);
auto result = getTypeSizeAndAlignment(loc, intTy, dl, kindMap);
if (!result)
return result;
auto [compSize, compAlign] = *result;
if (character.hasConstantLen())
compSize *= character.getLen();
return {compSize, compAlign};
}
if (success) {
*success = false;
return {0, 1};
return std::pair{compSize, compAlign};
}
TODO(loc, "computing size of a component");
return std::nullopt;
}

std::pair<std::uint64_t, unsigned short>
fir::getTypeSizeAndAlignmentOrCrash(mlir::Location loc, mlir::Type ty,
const mlir::DataLayout &dl,
const fir::KindMapping &kindMap) {
std::optional<std::pair<uint64_t, unsigned short>> result =
getTypeSizeAndAlignment(loc, ty, dl, kindMap);
if (result)
return *result;
TODO(loc, "computing size of a component");
}
8 changes: 3 additions & 5 deletions flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,10 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType(
llvm::SmallVector<mlir::LLVM::DINodeAttr> elements;
std::uint64_t offset = 0;
for (auto [fieldName, fieldTy] : Ty.getTypeList()) {
bool success = true;
auto [byteSize, byteAlign] =
fir::getTypeSizeAndAlignment(loc, fieldTy, *dl, kindMapping, &success);
if (!success)
auto result = fir::getTypeSizeAndAlignment(loc, fieldTy, *dl, kindMapping);
if (!result)
return genPlaceholderType(context);

auto [byteSize, byteAlign] = *result;
mlir::LLVM::DITypeAttr elemTy = convertType(fieldTy, fileAttr, scope, loc);
offset = llvm::alignTo(offset, byteAlign);
mlir::LLVM::DIDerivedTypeAttr tyAttr = mlir::LLVM::DIDerivedTypeAttr::get(
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ void LoopVersioningPass::runOnOperation() {
if (mlir::isa<mlir::FloatType>(elementType) ||
mlir::isa<mlir::IntegerType>(elementType) ||
mlir::isa<fir::ComplexType>(elementType)) {
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignment(
auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
arg.getLoc(), elementType, *dl, kindMap);
typeSize = llvm::alignTo(eleSize, eleAlign);
}
Expand Down

0 comments on commit 9a3d9f9

Please sign in to comment.