Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[serialization] Add support to serialize to memory, and a basic serialization tutorial #7760

Merged
merged 11 commits into from
Sep 28, 2023
36 changes: 34 additions & 2 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,15 @@ class Deserializer {
: external_params(external_params) {
}

// Deserialize a pipeline from the given filename
Pipeline deserialize(const std::string &filename);

// Deserialize a pipeline from the given input stream
Pipeline deserialize(std::istream &in);

// Deserialize a pipeline from the given buffer of bytes
Pipeline deserialize(const std::vector<uint8_t> &data);

private:
// Helper function to deserialize a homogenous vector from a flatbuffer vector,
// does not apply to union types like Stmt and Expr or enum types like MemoryType
Expand Down Expand Up @@ -445,6 +450,8 @@ void Deserializer::deserialize_function(const Serialize::Func *function, Functio
output_buffer = it->second;
} else if (auto it = parameters_in_pipeline.find(output_buffer_name); it != parameters_in_pipeline.end()) {
output_buffer = it->second;
} else if (!output_buffer_name.empty()) {
user_error << "unknown output buffer used in pipeline '" << output_buffer_name << "'\n";
}
output_buffers.push_back(output_buffer);
}
Expand Down Expand Up @@ -514,6 +521,8 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto alignment = deserialize_modulus_remainder(store_stmt->alignment());
return Store::make(name, value, index, param, predicate, alignment);
Expand Down Expand Up @@ -771,6 +780,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto alignment = deserialize_modulus_remainder(load_expr->alignment());
const auto type = deserialize_type(load_expr->type());
Expand Down Expand Up @@ -820,6 +831,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
const auto type = deserialize_type(call_expr->type());
return Call::make(type, name, args, call_type, func_ptr, value_index, image, param);
Expand All @@ -834,6 +847,8 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
auto image_name = deserialize_string(variable_expr->image_name());
Buffer<> image;
Expand Down Expand Up @@ -1031,6 +1046,8 @@ PrefetchDirective Deserializer::deserialize_prefetch_directive(const Serialize::
Parameter param;
if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
} else if (!param_name.empty()) {
user_error << "unknown parameter used in pipeline '" << param_name << "'\n";
}
auto hl_prefetch_directive = PrefetchDirective();
hl_prefetch_directive.name = name;
Expand Down Expand Up @@ -1199,6 +1216,8 @@ ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serializ
image_param = it->second;
} else if (auto it = parameters_in_pipeline.find(image_param_name); it != parameters_in_pipeline.end()) {
image_param = it->second;
} else if (!image_param_name.empty()) {
user_error << "unknown image parameter used in pipeline '" << image_param_name << "'\n";
}
return ExternFuncArgument(image_param);
}
Expand Down Expand Up @@ -1294,9 +1313,12 @@ Pipeline Deserializer::deserialize(std::istream &in) {
in.seekg(0, std::ios::end);
int size = in.tellg();
in.seekg(0, std::ios::beg);
std::vector<char> data(size);
in.read(data.data(), size);
std::vector<uint8_t> data(size);
in.read((char *)data.data(), size);
return deserialize(data);
}

Pipeline Deserializer::deserialize(const std::vector<uint8_t> &data) {
const auto *pipeline_obj = Serialize::GetPipeline(data.data());
if (pipeline_obj == nullptr) {
user_warning << "deserialized pipeline is empty\n";
Expand Down Expand Up @@ -1375,6 +1397,11 @@ Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Para
return deserializer.deserialize(in);
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
return deserializer.deserialize(buffer);
}

} // namespace Halide

#else // WITH_SERIALIZATION
Expand All @@ -1391,6 +1418,11 @@ Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Para
return Pipeline();
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

} // namespace Halide

#endif // WITH_SERIALIZATION
20 changes: 14 additions & 6 deletions src/Deserialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,24 @@

namespace Halide {

/**
* Deserialize a Halide pipeline from a file.
* filename should always end in .hlpipe suffix.
* external_params is an optional map, all parameters in the map
* will be treated as external parameters so won't be deserialized.
*/
/// @brief Deserialize a Halide pipeline from a file.
/// @param filename The location of the file to deserialize. Must use .hlpipe extension.
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params);

/// @brief Deserialize a Halide pipeline from an input stream.
/// @param in The input stream to read from containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params);

/// @brief Deserialize a Halide pipeline from a byte buffer containing a serizalized pipeline in binary format
/// @param data The data buffer containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::vector<uint8_t> &data, const std::map<std::string, Parameter> &external_params);

} // namespace Halide

#endif
49 changes: 47 additions & 2 deletions src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ class Serializer {
public:
Serializer() = default;

// Serialize the given pipeline into the given filename
void serialize(const Pipeline &pipeline, const std::string &filename);

// Serialize the given pipeline into given the data buffer
derek-gerstmann marked this conversation as resolved.
Show resolved Hide resolved
void serialize(const Pipeline &pipeline, std::vector<uint8_t> &data);

const std::map<std::string, Parameter> &get_external_parameters() const {
return external_parameters;
}
Expand Down Expand Up @@ -1388,7 +1392,7 @@ void Serializer::build_function_mappings(const std::map<std::string, Function> &
}
}

void Serializer::serialize(const Pipeline &pipeline, const std::string &filename) {
void Serializer::serialize(const Pipeline &pipeline, std::vector<uint8_t> &result) {
FlatBufferBuilder builder(1024);

// extract the DAG, unwrap function from Funcs
Expand Down Expand Up @@ -1453,17 +1457,46 @@ void Serializer::serialize(const Pipeline &pipeline, const std::string &filename

uint8_t *buf = builder.GetBufferPointer();
int size = builder.GetSize();

if (buf != nullptr && size > 0) {
derek-gerstmann marked this conversation as resolved.
Show resolved Hide resolved
result.clear();
result.reserve(size);
result.insert(result.begin(), buf, buf + size);
} else {
user_error << "failed to serialize pipeline!\n";
}
}

void Serializer::serialize(const Pipeline &pipeline, const std::string &filename) {
std::vector<uint8_t> data;
derek-gerstmann marked this conversation as resolved.
Show resolved Hide resolved
serialize(pipeline, data);
std::ofstream out(filename, std::ios::out | std::ios::binary);
if (!out) {
user_error << "failed to open file " << filename << "\n";
exit(1);
}
out.write((char *)(buf), size);
out.write((char *)(data.data()), data.size());
out.close();
}

} // namespace Internal

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data) {
Internal::Serializer serializer;
serializer.serialize(pipeline, data);
}

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params) {
Internal::Serializer serializer;
serializer.serialize(pipeline, data);
params = serializer.get_external_parameters();
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename) {
Internal::Serializer serializer;
serializer.serialize(pipeline, filename);
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params) {
Internal::Serializer serializer;
serializer.serialize(pipeline, filename);
Expand All @@ -1476,6 +1509,18 @@ void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, s

namespace Halide {

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params) {
user_error << "Serialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}
Expand Down
21 changes: 21 additions & 0 deletions src/Serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,27 @@

namespace Halide {

/// @brief Serialize a Halide pipeline into the given data buffer.
/// @param pipeline The Halide pipeline to serialize.
/// @param data The data buffer to store the serialized Halide pipeline into. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data);

/// @brief Serialize a Halide pipeline into the given data buffer.
/// @param pipeline The Halide pipeline to serialize.
/// @param data The data buffer to store the serialized Halide pipeline into. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, std::vector<uint8_t> &data, std::map<std::string, Parameter> &params);

/// @brief Serialize a Halide pipeline into the given filename.
/// @param pipeline The Halide pipeline to serialize.
/// @param filename The location of the file to write into to store the serialized pipeline. Any existing contents will be destroyed.
void serialize_pipeline(const Pipeline &pipeline, const std::string &filename);

/// @brief Serialize a Halide pipeline into the given filename.
/// @param pipeline The Halide pipeline to serialize.
/// @param filename The location of the file to write into to store the serialized pipeline. Any existing contents will be destroyed.
/// @param params Map of named parameters which will get populated during serialization (can be used to bind external parameters to objects in the pipeline by name).
void serialize_pipeline(const Pipeline &pipeline, const std::string &filename, std::map<std::string, Parameter> &params);

} // namespace Halide
Expand Down
3 changes: 2 additions & 1 deletion tutorial/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -210,5 +210,6 @@ if (TARGET Halide::Mullapudi2016)
set_tests_properties(tutorial_lesson_21_auto_scheduler_run PROPERTIES LABELS "tutorial;multithreaded")
endif ()

# Lesson 22
# Lessons 22-23
add_tutorial(lesson_22_jit_performance.cpp)
add_tutorial(lesson_23_serialization.cpp WITH_IMAGE_IO)
Loading