Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm] Add an IrSha1 observation space. #267

Merged
merged 2 commits into from
May 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler_gym/envs/llvm/service/Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ Benchmark::Benchmark(const std::string& name, const Bitcode& bitcode,
: context_(std::make_unique<llvm::LLVMContext>()),
module_(makeModuleOrDie(*context_, bitcode, name)),
baselineCosts_(baselineCosts),
hash_(getModuleHash(*module_)),
name_(name),
bitcodeSize_(bitcode.size()) {}

Expand All @@ -110,7 +109,6 @@ Benchmark::Benchmark(const std::string& name, std::unique_ptr<llvm::LLVMContext>
: context_(std::move(context)),
module_(std::move(module)),
baselineCosts_(baselineCosts),
hash_(getModuleHash(*module_)),
name_(name),
bitcodeSize_(bitcodeSize) {}

Expand All @@ -122,4 +120,6 @@ std::unique_ptr<Benchmark> Benchmark::clone(const fs::path& workingDirectory) co
return std::make_unique<Benchmark>(name(), bitcode, workingDirectory, baselineCosts());
}

// Hash the module in its current state. The hash is computed by calling
// getModuleHash() on every invocation; nothing is cached here.
BenchmarkHash Benchmark::module_hash() const {
  return getModuleHash(*module_);
}

} // namespace compiler_gym::llvm_service
12 changes: 4 additions & 8 deletions compiler_gym/envs/llvm/service/Benchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@

namespace compiler_gym::llvm_service {

// We identify benchmarks using a hash of the LLVM module, which is a
// 160 bits SHA1.
//
// NOTE(cummins): In the future when we extend this to support optimizing for
// performance, we would need this
// A 160-bit SHA1 hash that identifies an LLVM module.
using BenchmarkHash = llvm::ModuleHash;

using Bitcode = llvm::SmallString<0>;
Expand All @@ -47,6 +43,9 @@ class Benchmark {
// Make a copy of the benchmark.
std::unique_ptr<Benchmark> clone(const boost::filesystem::path& workingDirectory) const;

// Compute and return a SHA1 hash of the module.
BenchmarkHash module_hash() const;

// The name given to this benchmark at construction. Returned by reference;
// the reference is valid for as long as this Benchmark is alive.
const std::string& name() const {
  return name_;
}

inline const size_t bitcodeSize() const { return bitcodeSize_; }
Expand All @@ -66,8 +65,6 @@ class Benchmark {

// Read-only, non-owning pointer to the module held by this benchmark.
const llvm::Module* module_ptr() const {
  return module_.get();
}

// Return a copy of the stored module hash. The return type is intentionally
// not const-qualified: a const by-value return of a class type buys nothing
// and prevents the caller from moving out of the result.
inline BenchmarkHash hash() const { return hash_; }

// Replace the benchmark module with a new one. This is to enable
// out-of-process modification of the IR by serializing the benchmark to a
// file, modifying the file, then loading the modified file and updating the
Expand All @@ -81,7 +78,6 @@ class Benchmark {
std::unique_ptr<llvm::LLVMContext> context_;
std::unique_ptr<llvm::Module> module_;
const BaselineCosts baselineCosts_;
const BenchmarkHash hash_;
const std::string name_;
// The length of the bitcode string for this benchmark.
const size_t bitcodeSize_;
Expand Down
13 changes: 13 additions & 0 deletions compiler_gym/envs/llvm/service/LlvmSession.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fmt/format.h>
#include <glog/logging.h>

#include <iomanip>
#include <optional>
#include <subprocess/subprocess.hpp>

Expand Down Expand Up @@ -270,6 +271,18 @@ Status LlvmSession::getObservation(LlvmObservationSpace space, Observation* repl
reply->set_string_value(ir);
break;
}
case LlvmObservationSpace::IR_SHA1: {
  // Render the module hash as a fixed-width hex string: each integer word of
  // the hash is hex encoded, left-padded with zeros to its full width, and
  // the words are concatenated in order.
  const BenchmarkHash digest = benchmark().module_hash();
  std::stringstream hexStream;
  // The fill character and numeric base persist on the stream, so they are
  // set once. The field width resets after every insertion and must be
  // reapplied for each word.
  hexStream << std::setfill('0') << std::hex;
  for (uint32_t word : digest) {
    hexStream << std::setw(sizeof(BenchmarkHash::value_type) * 2) << word;
  }
  reply->set_string_value(hexStream.str());
  break;
}
case LlvmObservationSpace::BITCODE_FILE: {
// Generate an output path with 32 bits of randomness (8 random hex digits).
const auto outpath = fs::unique_path(workingDirectory_ / "module-%%%%%%%%.bc");
Expand Down
8 changes: 8 additions & 0 deletions compiler_gym/envs/llvm/service/ObservationSpaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::IR_SHA1: {
  // The observation is a hex-encoded SHA1 string. A 160-bit digest always
  // encodes to exactly 40 hex characters, so the minimum and maximum string
  // sizes are the same.
  constexpr int kSha1HexLength = 40;
  auto* sizeRange = space.mutable_string_size_range();
  sizeRange->mutable_min()->set_value(kSha1HexLength);
  sizeRange->mutable_max()->set_value(kSha1HexLength);
  // Advertised as deterministic and platform independent.
  space.set_deterministic(true);
  space.set_platform_dependent(false);
  break;
}
case LlvmObservationSpace::BITCODE_FILE: {
ScalarRange pathLength;
space.mutable_string_size_range()->mutable_min()->set_value(0);
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/envs/llvm/service/ObservationSpaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum class LlvmObservationSpace {
// The entire LLVM module as an IR string. This allows the user to do its own
// feature extraction.
IR,
// The 40-digit hex SHA1 checksum of the LLVM module.
IR_SHA1,
// Write the bitcode to a file. Returns a string, which is the path of the
// written file.
BITCODE_FILE,
Expand Down
19 changes: 19 additions & 0 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_observation_spaces(env: LlvmEnv):

assert set(env.observation.spaces.keys()) == {
"Ir",
"IrSha1",
"BitcodeFile",
"InstCount",
"InstCountDict",
Expand Down Expand Up @@ -79,6 +80,24 @@ def test_ir_observation_space(env: LlvmEnv):
assert not space.platform_dependent


def test_ir_sha1_observation_space(env: LlvmEnv):
    """Check the IrSha1 space metadata and that its observation is a 40-character string."""
    env.reset("cbench-v1/crc32")
    space_name = "IrSha1"
    sha1_space = env.observation.spaces[space_name]

    # The space must advertise a fixed-length string sequence.
    assert isinstance(sha1_space.space, Sequence)
    assert sha1_space.space.dtype == str
    assert sha1_space.space.size_range == (40, 40)

    # The observed value must match what the space advertises.
    digest: str = env.observation[space_name]
    print(digest)  # For debugging in case of error.
    assert isinstance(digest, str)
    assert len(digest) == 40
    assert sha1_space.space.contains(digest)

    assert sha1_space.deterministic
    assert not sha1_space.platform_dependent


def test_bitcode_observation_space(env: LlvmEnv):
env.reset("cbench-v1/crc32")
key = "BitcodeFile"
Expand Down