Skip to content

Commit

Permalink
Linker: Better type comparison for OpTypeArray and OpTypeForwardPoint…
Browse files Browse the repository at this point in the history
…er (#2580)

* Types: Avoid comparing IDs for in Type::IsSameImpl

When linking, we end up with duplicate types for imported and exported
types, that needs to be removed. The current code would reject valid
import/export pairs of symbols due to IDs mismatch, even if the types or
constants behind those ID were the same.

Enabled remaining type_match_test

Fixes #2442
  • Loading branch information
pierremoreau authored and s-perron committed May 29, 2019
1 parent 0125b28 commit e7866de
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 111 deletions.
19 changes: 11 additions & 8 deletions source/link/linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,23 @@
#include "source/opt/ir_loader.h"
#include "source/opt/pass_manager.h"
#include "source/opt/remove_duplicates_pass.h"
#include "source/opt/type_manager.h"
#include "source/spirv_target_env.h"
#include "source/util/make_unique.h"
#include "spirv-tools/libspirv.hpp"

namespace spvtools {
namespace {

using opt::IRContext;
using opt::Instruction;
using opt::IRContext;
using opt::Module;
using opt::Operand;
using opt::PassManager;
using opt::RemoveDuplicatesPass;
using opt::analysis::DecorationManager;
using opt::analysis::DefUseManager;
using opt::analysis::Type;
using opt::analysis::TypeManager;

// Stores various information about an imported or exported symbol.
struct LinkageSymbolInfo {
Expand Down Expand Up @@ -472,14 +474,15 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
opt::IRContext* context) {
spv_position_t position = {};

// Ensure th import and export types are the same.
const DefUseManager& def_use_manager = *context->get_def_use_mgr();
// Ensure the import and export types are the same.
const DecorationManager& decoration_manager = *context->get_decoration_mgr();
const TypeManager& type_manager = *context->get_type_mgr();
for (const auto& linking_entry : linkings_to_do) {
if (!RemoveDuplicatesPass::AreTypesEqual(
*def_use_manager.GetDef(linking_entry.imported_symbol.type_id),
*def_use_manager.GetDef(linking_entry.exported_symbol.type_id),
context))
Type* imported_symbol_type =
type_manager.GetType(linking_entry.imported_symbol.type_id);
Type* exported_symbol_type =
type_manager.GetType(linking_entry.exported_symbol.type_id);
if (!(*imported_symbol_type == *exported_symbol_type))
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
<< "Type mismatch on symbol \""
<< linking_entry.imported_symbol.name
Expand Down
89 changes: 53 additions & 36 deletions source/opt/remove_duplicates_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,35 +96,67 @@ bool RemoveDuplicatesPass::RemoveDuplicateTypes() const {
return modified;
}

analysis::TypeManager type_manager(context()->consumer(), context());

std::vector<Instruction*> visited_types;
std::vector<analysis::ForwardPointer> visited_forward_pointers;
std::vector<Instruction*> to_delete;
for (auto* i = &*context()->types_values_begin(); i; i = i->NextNode()) {
const bool is_i_forward_pointer = i->opcode() == SpvOpTypeForwardPointer;

// We only care about types.
if (!spvOpcodeGeneratesType((i->opcode())) &&
i->opcode() != SpvOpTypeForwardPointer) {
if (!spvOpcodeGeneratesType(i->opcode()) && !is_i_forward_pointer) {
continue;
}

// Is the current type equal to one of the types we have aready visited?
SpvId id_to_keep = 0u;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
if (AreTypesEqual(*i, *j, context())) {
id_to_keep = j->result_id();
break;
if (!is_i_forward_pointer) {
// Is the current type equal to one of the types we have already visited?
SpvId id_to_keep = 0u;
analysis::Type* i_type = type_manager.GetType(i->result_id());
assert(i_type);
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
for (auto j : visited_types) {
analysis::Type* j_type = type_manager.GetType(j->result_id());
assert(j_type);
if (*i_type == *j_type) {
id_to_keep = j->result_id();
break;
}
}
}

if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
if (id_to_keep == 0u) {
// This is a never seen before type, keep it around.
visited_types.emplace_back(i);
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
}
} else {
// The same type has already been seen before, remove this one.
context()->KillNamesAndDecorates(i->result_id());
context()->ReplaceAllUsesWith(i->result_id(), id_to_keep);
modified = true;
to_delete.emplace_back(i);
analysis::ForwardPointer i_type(
i->GetSingleWordInOperand(0u),
(SpvStorageClass)i->GetSingleWordInOperand(1u));
i_type.SetTargetPointer(
type_manager.GetType(i_type.target_id())->AsPointer());

// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
const bool found_a_match =
std::find(std::begin(visited_forward_pointers),
std::end(visited_forward_pointers),
i_type) != std::end(visited_forward_pointers);

if (!found_a_match) {
// This is a never seen before type, keep it around.
visited_forward_pointers.emplace_back(i_type);
} else {
// The same type has already been seen before, remove this one.
modified = true;
to_delete.emplace_back(i);
}
}
}

Expand All @@ -151,8 +183,8 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {

analysis::DecorationManager decoration_manager(context()->module());
for (auto* i = &*context()->annotation_begin(); i;) {
// Is the current decoration equal to one of the decorations we have aready
// visited?
// Is the current decoration equal to one of the decorations we have
// already visited?
bool already_visited = false;
// TODO(dneto0): Use a trie to avoid quadratic behaviour? Extract the
// ResultIdTrie from unify_const_pass.cpp for this.
Expand All @@ -177,20 +209,5 @@ bool RemoveDuplicatesPass::RemoveDuplicateDecorations() const {
return modified;
}

bool RemoveDuplicatesPass::AreTypesEqual(const Instruction& inst1,
const Instruction& inst2,
IRContext* context) {
if (inst1.opcode() != inst2.opcode()) return false;
if (!IsTypeInst(inst1.opcode())) return false;

const analysis::Type* type1 =
context->get_type_mgr()->GetType(inst1.result_id());
const analysis::Type* type2 =
context->get_type_mgr()->GetType(inst2.result_id());
if (type1 && type2 && *type1 == *type2) return true;

return false;
}

} // namespace opt
} // namespace spvtools
6 changes: 0 additions & 6 deletions source/opt/remove_duplicates_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@ class RemoveDuplicatesPass : public Pass {
const char* name() const override { return "remove-duplicates"; }
Status Process() override;

// TODO(pierremoreau): Move this function somewhere else (e.g. pass.h or
// within the type manager)
// Returns whether two types are equal, and have the same decorations.
static bool AreTypesEqual(const Instruction& inst1, const Instruction& inst2,
IRContext* context);

private:
// Remove duplicate capabilities from the module
//
Expand Down
53 changes: 45 additions & 8 deletions source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,13 @@ uint32_t TypeManager::GetId(const Type* type) const {
}

void TypeManager::AnalyzeTypes(const Module& module) {
// First pass through the types. Any types that reference a forward pointer
// First pass through the constants, as some will be needed when traversing
// the types in the next pass.
for (const auto* inst : module.GetConstants()) {
id_to_constant_inst_[inst->result_id()] = inst;
}

// Then pass through the types. Any types that reference a forward pointer
// (directly or indirectly) are incomplete, and are added to incomplete types.
for (const auto* inst : module.GetTypes()) {
RecordIfTypeDefinition(*inst);
Expand Down Expand Up @@ -154,7 +160,7 @@ void TypeManager::AnalyzeTypes(const Module& module) {

#ifndef NDEBUG
// Check if the type pool contains two types that are the same. This
// is an indication that the hashing and comparision are wrong. It
// is an indication that the hashing and comparison are wrong. It
// will cause a problem if the type pool gets resized and everything
// is rehashed.
for (auto& i : type_pool_) {
Expand Down Expand Up @@ -505,8 +511,15 @@ Type* TypeManager::RebuildType(const Type& type) {
case Type::kArray: {
const Array* array_ty = type.AsArray();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId());
if (array_ty->length_spec_id() != 0u)
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
array_ty->length_spec_id());
else
rebuilt_ty =
MakeUnique<Array>(RebuildType(*ele_ty), array_ty->LengthId(),
array_ty->length_constant_type(),
array_ty->length_constant_words());
break;
}
case Type::kRuntimeArray: {
Expand Down Expand Up @@ -636,15 +649,39 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
case SpvOpTypeSampledImage:
type = new SampledImage(GetType(inst.GetSingleWordInOperand(0)));
break;
case SpvOpTypeArray:
type = new Array(GetType(inst.GetSingleWordInOperand(0)),
inst.GetSingleWordInOperand(1));
case SpvOpTypeArray: {
const uint32_t length_id = inst.GetSingleWordInOperand(1);
const Instruction* length_constant_inst = id_to_constant_inst_[length_id];
assert(length_constant_inst);

// If it is a specialised constants, retrieve its SpecId.
uint32_t spec_id = 0u;
Type* length_type = nullptr;
Operand::OperandData length_words;
if (spvOpcodeIsSpecConstant(length_constant_inst->opcode())) {
context()->get_decoration_mgr()->ForEachDecoration(
length_id, SpvDecorationSpecId,
[&spec_id](const Instruction& decoration) {
assert(decoration.opcode() == SpvOpDecorate);
spec_id = decoration.GetSingleWordOperand(2u);
});
} else {
length_type = GetType(length_constant_inst->type_id());
length_words = length_constant_inst->GetOperand(2u).words;
}

if (spec_id != 0u)
type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
spec_id);
else
type = new Array(GetType(inst.GetSingleWordInOperand(0)), length_id,
length_type, length_words);
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
} break;
case SpvOpTypeRuntimeArray:
type = new RuntimeArray(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
Expand Down
2 changes: 2 additions & 0 deletions source/opt/type_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ class TypeManager {

IdToTypeMap id_to_incomplete_type_; // Maps ids to their type representations
// for incomplete types.

std::unordered_map<uint32_t, const Instruction*> id_to_constant_inst_;
};

} // namespace analysis
Expand Down
50 changes: 43 additions & 7 deletions source/opt/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,17 +383,46 @@ void SampledImage::GetExtraHashWords(
image_type_->GetHashWords(words, seen);
}

Array::Array(Type* type, uint32_t length_id)
: Type(kArray), element_type_(type), length_id_(length_id) {
Array::Array(Type* type, uint32_t length_id, uint32_t spec_id)
: Type(kArray),
element_type_(type),
length_id_(length_id),
length_spec_id_(spec_id),
length_constant_type_(nullptr),
length_constant_words_() {
assert(!type->AsVoid());
assert(spec_id != 0u);
}

Array::Array(Type* type, uint32_t length_id, const Type* constant_type,
Operand::OperandData constant_words)
: Type(kArray),
element_type_(type),
length_id_(length_id),
length_spec_id_(0u),
length_constant_type_(constant_type),
length_constant_words_(constant_words) {
assert(!type->AsVoid());
assert(constant_type && constant_type->AsInteger());
}

bool Array::IsSameImpl(const Type* that, IsSameCache* seen) const {
const Array* at = that->AsArray();
if (!at) return false;
return length_id_ == at->length_id_ &&
element_type_->IsSameImpl(at->element_type_, seen) &&
HasSameDecorations(that);
bool is_same = element_type_->IsSameImpl(at->element_type_, seen) &&
HasSameDecorations(that);
// If it is a specialized constant
if (length_spec_id_ != 0u) {
// ensure they have the same SpecId
is_same = is_same && length_spec_id_ == at->length_spec_id_;
} else {
// else, ensure they have the same length literal number.
is_same =
is_same &&
length_constant_type_->IsSameImpl(at->length_constant_type_, seen) &&
length_constant_words_ == at->length_constant_words_;
}
return is_same;
}

std::string Array::str() const {
Expand All @@ -405,7 +434,13 @@ std::string Array::str() const {
void Array::GetExtraHashWords(std::vector<uint32_t>* words,
std::unordered_set<const Type*>* seen) const {
element_type_->GetHashWords(words, seen);
words->push_back(length_id_);
if (length_spec_id_ != 0u) {
words->push_back(length_spec_id_);
} else {
length_constant_type_->GetHashWords(words, seen);
words->insert(words->end(), length_constant_words_.begin(),
length_constant_words_.end());
}
}

void Array::ReplaceElementType(const Type* type) { element_type_ = type; }
Expand Down Expand Up @@ -609,7 +644,8 @@ void Pipe::GetExtraHashWords(std::vector<uint32_t>* words,
bool ForwardPointer::IsSameImpl(const Type* that, IsSameCache*) const {
const ForwardPointer* fpt = that->AsForwardPointer();
if (!fpt) return false;
return target_id_ == fpt->target_id_ &&
return (pointer_ && fpt->pointer_ ? *pointer_ == *fpt->pointer_
: target_id_ == fpt->target_id_) &&
storage_class_ == fpt->storage_class_ && HasSameDecorations(that);
}

Expand Down
13 changes: 12 additions & 1 deletion source/opt/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <vector>

#include "source/latest_version_spirv_header.h"
#include "source/opt/instruction.h"
#include "spirv-tools/libspirv.h"

namespace spvtools {
Expand Down Expand Up @@ -356,12 +357,19 @@ class SampledImage : public Type {

class Array : public Type {
public:
Array(Type* element_type, uint32_t length_id);
Array(Type* element_type, uint32_t length_id, uint32_t spec_id);
Array(Type* element_type, uint32_t length_id, const Type* constant_type,
Operand::OperandData constant_words);
Array(const Array&) = default;

std::string str() const override;
const Type* element_type() const { return element_type_; }
uint32_t LengthId() const { return length_id_; }
uint32_t length_spec_id() const { return length_spec_id_; }
const Type* length_constant_type() const { return length_constant_type_; }
Operand::OperandData length_constant_words() const {
return length_constant_words_;
}

Array* AsArray() override { return this; }
const Array* AsArray() const override { return this; }
Expand All @@ -376,6 +384,9 @@ class Array : public Type {

const Type* element_type_;
uint32_t length_id_;
uint32_t length_spec_id_;
const Type* length_constant_type_;
Operand::OperandData length_constant_words_;
};

class RuntimeArray : public Type {
Expand Down
Loading

0 comments on commit e7866de

Please sign in to comment.