Skip to content

Commit

Permalink
[llvm] Add an IrSha1 observation space.
Browse files Browse the repository at this point in the history
This adds a new `IrSha1` observation space that is a 40-digit SHA1
checksum of the current module state.
  • Loading branch information
ChrisCummins committed May 13, 2021
1 parent ea37391 commit 1099afe
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler_gym/envs/llvm/service/Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,6 @@ std::unique_ptr<Benchmark> Benchmark::clone(const fs::path& workingDirectory) co
return std::make_unique<Benchmark>(name(), bitcode, workingDirectory, baselineCosts());
}

BenchmarkHash Benchmark::module_hash() const { return getModuleHash(*module_); }

} // namespace compiler_gym::llvm_service
3 changes: 3 additions & 0 deletions compiler_gym/envs/llvm/service/Benchmark.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class Benchmark {
// Make a copy of the benchmark.
std::unique_ptr<Benchmark> clone(const boost::filesystem::path& workingDirectory) const;

// Compute and return a SHA1 hash of the module.
BenchmarkHash module_hash() const;

inline const std::string& name() const { return name_; }

inline const size_t bitcodeSize() const { return bitcodeSize_; }
Expand Down
13 changes: 13 additions & 0 deletions compiler_gym/envs/llvm/service/LlvmSession.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fmt/format.h>
#include <glog/logging.h>

#include <iomanip>
#include <optional>
#include <subprocess/subprocess.hpp>

Expand Down Expand Up @@ -270,6 +271,18 @@ Status LlvmSession::getObservation(LlvmObservationSpace space, Observation* repl
reply->set_string_value(ir);
break;
}
case LlvmObservationSpace::IR_SHA1: {
std::stringstream ss;
const BenchmarkHash hash = benchmark().module_hash();
// Hex encode, zero pad, and concatenate the unsigned integers that
// contain the hash.
for (uint32_t val : hash) {
ss << std::setfill('0') << std::setw(sizeof(BenchmarkHash::value_type) * 2) << std::hex
<< val;
}
reply->set_string_value(ss.str());
break;
}
case LlvmObservationSpace::BITCODE_FILE: {
// Generate an output path with 16 bits of randomness.
const auto outpath = fs::unique_path(workingDirectory_ / "module-%%%%%%%%.bc");
Expand Down
8 changes: 8 additions & 0 deletions compiler_gym/envs/llvm/service/ObservationSpaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ std::vector<ObservationSpace> getLlvmObservationSpaceList() {
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::IR_SHA1: {
ScalarRange sha1Size;
space.mutable_string_size_range()->mutable_min()->set_value(40);
space.mutable_string_size_range()->mutable_max()->set_value(40);
space.set_deterministic(true);
space.set_platform_dependent(false);
break;
}
case LlvmObservationSpace::BITCODE_FILE: {
ScalarRange pathLength;
space.mutable_string_size_range()->mutable_min()->set_value(0);
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/envs/llvm/service/ObservationSpaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ enum class LlvmObservationSpace {
// The entire LLVM module as an IR string. This allows the user to do its own
// feature extraction.
IR,
// The 40-digit hex SHA1 checksum of the LLVM module.
IR_SHA1,
// Write the bitcode to a file. Returns a string, which is the path of the
// written file.
BITCODE_FILE,
Expand Down
19 changes: 19 additions & 0 deletions tests/llvm/observation_spaces_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_observation_spaces(env: LlvmEnv):

assert set(env.observation.spaces.keys()) == {
"Ir",
"IrSha1",
"BitcodeFile",
"InstCount",
"InstCountDict",
Expand Down Expand Up @@ -79,6 +80,24 @@ def test_ir_observation_space(env: LlvmEnv):
assert not space.platform_dependent


def test_ir_sha1_observation_space(env: LlvmEnv):
env.reset("cbench-v1/crc32")
key = "IrSha1"
space = env.observation.spaces[key]
assert isinstance(space.space, Sequence)
assert space.space.dtype == str
assert space.space.size_range == (40, 40)

value: str = env.observation[key]
print(value) # For debugging in case of error.
assert isinstance(value, str)
assert len(value) == 40
assert space.space.contains(value)

assert space.deterministic
assert not space.platform_dependent


def test_bitcode_observation_space(env: LlvmEnv):
env.reset("cbench-v1/crc32")
key = "BitcodeFile"
Expand Down

0 comments on commit 1099afe

Please sign in to comment.