Skip to content

Commit

Permalink
[FIRRTL] Dedup: speed up handling of instances (#7815)
Browse files Browse the repository at this point in the history
Dedup tries to hash all modules in parallel.  To accomplish this, the
names of instantiated modules are not included as part of the structural
hash, but they are taken into account when checking if two modules are
the same.  This process involves comparing the instantiated children
modules of two modules if their hashes match.  This was implemented by
using an array attribute, to make comparisons quicker.

When a module or class has many thousands of instances underneath it, it
becomes impractical to build an array attribute with every child module.
Interning such a large ArrayAttr is incredibly slow and will eat up that
memory for the rest of the process.

Instead, we don't bother interning the instance arrays, and just keep
them as plain old vectors, which comes with the benefit of not eagerly
interning gigantic arrays.
  • Loading branch information
youngar authored Nov 14, 2024
1 parent f71d0fb commit bf43ca2
Showing 1 changed file with 25 additions and 27 deletions.
52 changes: 25 additions & 27 deletions lib/Dialect/FIRRTL/Transforms/Dedup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -89,9 +90,14 @@ struct ModuleInfo {
// SHA256 hash.
std::array<uint8_t, 32> structuralHash;
// Module names referred by instance op in the module.
mlir::ArrayAttr referredModuleNames;
std::vector<StringAttr> referredModuleNames;
};

/// Equality for ModuleInfo: two infos are equal only when both their
/// structural hashes and their lists of referred module names are equal.
static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
  // Differing structural hashes already rule out equality.
  if (!(lhs.structuralHash == rhs.structuralHash))
    return false;
  // Hashes agree; the result now depends only on the referred module names.
  return lhs.referredModuleNames == rhs.referredModuleNames;
}

/// Unique identifier for a value. All value sources are numbered by appearance,
/// and values are identified using this numbering (`index`) and an `offset`.
/// For BlockArgument's, this is the argument number.
Expand Down Expand Up @@ -146,11 +152,9 @@ struct StructuralHasher {
explicit StructuralHasher(const StructuralHasherSharedConstants &constants)
: constants(constants){};

std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>
getHashAndModuleNames(FModuleLike module) {
ModuleInfo getModuleInfo(FModuleLike module) {
update(&(*module));
auto hash = sha.final();
return {hash, referredModuleNames};
return {sha.final(), std::move(referredModuleNames)};
}

private:
Expand Down Expand Up @@ -359,7 +363,7 @@ struct StructuralHasher {
DenseMap<StringAttr, SymbolTarget> innerSymTargets;

// This keeps track of module names in order of appearance.
SmallVector<mlir::StringAttr> referredModuleNames;
std::vector<StringAttr> referredModuleNames;

// String constants.
const StructuralHasherSharedConstants &constants;
Expand Down Expand Up @@ -1595,13 +1599,13 @@ struct DenseMapInfo<ModuleInfo> {
static inline ModuleInfo getEmptyKey() {
std::array<uint8_t, 32> key;
std::fill(key.begin(), key.end(), ~0);
return {key, DenseMapInfo<mlir::ArrayAttr>::getEmptyKey()};
return {key, {}};
}

static inline ModuleInfo getTombstoneKey() {
std::array<uint8_t, 32> key;
std::fill(key.begin(), key.end(), ~0 - 1);
return {key, DenseMapInfo<mlir::ArrayAttr>::getTombstoneKey()};
return {key, {}};
}

static unsigned getHashValue(const ModuleInfo &val) {
Expand All @@ -1611,12 +1615,13 @@ struct DenseMapInfo<ModuleInfo> {
std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));

// Combine module names.
return llvm::hash_combine(hash, val.referredModuleNames);
return llvm::hash_combine(
hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
val.referredModuleNames.end()));
}

static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
return lhs.structuralHash == rhs.structuralHash &&
lhs.referredModuleNames == rhs.referredModuleNames;
return lhs == rhs;
}
};
} // namespace llvm
Expand Down Expand Up @@ -1659,9 +1664,7 @@ class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
return cast<FModuleLike>(*node->getModule());
}));

SmallVector<std::optional<
std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>>>
hashesAndModuleNames(modules.size());
SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
StructuralHasherSharedConstants hasherConstants(&getContext());

// Attribute name used to store dedup_group for this pass.
Expand Down Expand Up @@ -1708,7 +1711,7 @@ class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {

StructuralHasher hasher(hasherConstants);
// Calculate the hash of the module and referred module names.
hashesAndModuleNames[idx] = hasher.getHashAndModuleNames(module);
moduleInfos[idx] = hasher.getModuleInfo(module);
return success();
});

Expand All @@ -1717,9 +1720,9 @@ class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {

for (auto [i, module] : llvm::enumerate(modules)) {
auto moduleName = module.getModuleNameAttr();
auto &hashAndModuleNamesOpt = hashesAndModuleNames[i];
auto &maybeModuleInfo = moduleInfos[i];
// If the hash was not calculated, we need to skip it.
if (!hashAndModuleNamesOpt) {
if (!maybeModuleInfo) {
// We record it in the dedup map to help detect errors when the user
// marks the module as both NoDedup and MustDedup. We do not record this
// module in the hasher to make sure no other module dedups "into" this
Expand All @@ -1728,16 +1731,11 @@ class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
continue;
}

// Replace module names referred in the module with new names.
SmallVector<mlir::Attribute> names;
for (auto oldModuleName : hashAndModuleNamesOpt->second) {
auto newModuleName = dedupMap[oldModuleName];
names.push_back(newModuleName);
}
auto &moduleInfo = maybeModuleInfo.value();

// Create a module info to use it as a key.
ModuleInfo moduleInfo{hashAndModuleNamesOpt->first,
mlir::ArrayAttr::get(module.getContext(), names)};
// Replace module names referred in the module with new names.
for (auto &referredModule : moduleInfo.referredModuleNames)
referredModule = dedupMap[referredModule];

// Check if there is a module with the same hash.
auto it = moduleInfoToModule.find(moduleInfo);
Expand All @@ -1755,7 +1753,7 @@ class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
// Add the module to a new dedup group.
dedupMap[moduleName] = moduleName;
// Record the module info.
moduleInfoToModule[moduleInfo] = module;
moduleInfoToModule[std::move(moduleInfo)] = module;
}

// This part verifies that all modules marked by "MustDedup" have been
Expand Down

0 comments on commit bf43ca2

Please sign in to comment.