Skip to content

Commit

Permalink
[serialization] Serialize stub definitions of external parameters. (h…
Browse files Browse the repository at this point in the history
…alide#7926)

* Serialize stub definitions of external parameters.
Add deserialize_parameter methods to allow the user to only deserialize
the mapping of external parameters (and remap them to their own user
parameters) prior to deserializing the full pipeline definition.

* Clang tidy/format pass

---------

Co-authored-by: Derek Gerstmann <dgerstmann@adobe.com>
  • Loading branch information
2 people authored and ardier committed Mar 3, 2024
1 parent 42305e7 commit 53f0816
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 29 deletions.
153 changes: 136 additions & 17 deletions src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class Deserializer {
public:
Deserializer() = default;

explicit Deserializer(const std::map<std::string, Parameter> &external_params)
: external_params(external_params) {
explicit Deserializer(const std::map<std::string, Parameter> &user_params)
: user_params(user_params) {
}

// Deserialize a pipeline from the given filename
Expand All @@ -36,6 +36,16 @@ class Deserializer {
// Deserialize a pipeline from the given buffer of bytes
Pipeline deserialize(const std::vector<uint8_t> &data);

// Deserialize just the unbound external parameters that need to be defined for the pipeline from the given filename
// (so they can be remapped and overridden with user parameters prior to deserializing the pipeline)
std::map<std::string, Parameter> deserialize_parameters(const std::string &filename);

// Deserialize just the unbound external parameters that need to be defined for the pipeline from the given input stream
std::map<std::string, Parameter> deserialize_parameters(std::istream &in);

// Deserialize just the unbound external parameters that need to be defined for the pipeline from the given buffer of bytes
std::map<std::string, Parameter> deserialize_parameters(const std::vector<uint8_t> &data);

private:
// Helper function to deserialize a homogenous vector from a flatbuffer vector,
// does not apply to union types like Stmt and Expr or enum types like MemoryType
Expand Down Expand Up @@ -63,6 +73,9 @@ class Deserializer {
std::map<std::string, Buffer<>> buffers_in_pipeline;

// External parameters that are not deserialized but will be used in the pipeline
std::map<std::string, Parameter> user_params;

// Default external parameters that were created during deserialization
std::map<std::string, Parameter> external_params;

MemoryType deserialize_memory_type(Serialize::MemoryType memory_type);
Expand Down Expand Up @@ -139,6 +152,8 @@ class Deserializer {

Parameter deserialize_parameter(const Serialize::Parameter *parameter);

Parameter deserialize_external_parameter(const Serialize::ExternalParameter *external_parameter);

ExternFuncArgument deserialize_extern_func_argument(const Serialize::ExternFuncArgument *extern_func_argument);

std::map<std::string, FunctionPtr> deserialize_wrapper_refs(const flatbuffers::Vector<flatbuffers::Offset<Serialize::WrapperRef>> *wrappers);
Expand Down Expand Up @@ -457,12 +472,15 @@ void Deserializer::deserialize_function(const Serialize::Func *function, Functio
deserialize_vector<Serialize::Definition, Definition>(function->updates(),
&Deserializer::deserialize_definition);
const std::string debug_file = deserialize_string(function->debug_file());

std::vector<Parameter> output_buffers;
output_buffers.reserve(function->output_buffers_names()->size());
for (const auto &output_buffer_name_serialized : *function->output_buffers_names()) {
auto output_buffer_name = deserialize_string(output_buffer_name_serialized);
Parameter output_buffer;
if (auto it = external_params.find(output_buffer_name); it != external_params.end()) {
if (auto it = user_params.find(output_buffer_name); it != user_params.end()) {
output_buffer = it->second;
} else if (auto it = external_params.find(output_buffer_name); it != external_params.end()) {
output_buffer = it->second;
} else if (auto it = parameters_in_pipeline.find(output_buffer_name); it != parameters_in_pipeline.end()) {
output_buffer = it->second;
Expand Down Expand Up @@ -534,7 +552,9 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt)
const auto index = deserialize_expr(store_stmt->index_type(), store_stmt->index());
const auto param_name = deserialize_string(store_stmt->param_name());
Parameter param;
if (auto it = external_params.find(param_name); it != external_params.end()) {
if (auto it = user_params.find(param_name); it != user_params.end()) {
param = it->second;
} else if (auto it = external_params.find(param_name); it != external_params.end()) {
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
Expand Down Expand Up @@ -799,7 +819,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
}
const auto param_name = deserialize_string(load_expr->param_name());
Parameter param;
if (auto it = external_params.find(param_name); it != external_params.end()) {
if (auto it = user_params.find(param_name); it != user_params.end()) {
param = it->second;
} else if (auto it = external_params.find(param_name); it != external_params.end()) {
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
Expand Down Expand Up @@ -850,7 +872,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
}
const auto param_name = deserialize_string(call_expr->param_name());
Parameter param;
if (auto it = external_params.find(param_name); it != external_params.end()) {
if (auto it = user_params.find(param_name); it != user_params.end()) {
param = it->second;
} else if (auto it = external_params.find(param_name); it != external_params.end()) {
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
Expand All @@ -866,7 +890,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr)
const auto type = deserialize_type(variable_expr->type());
const auto param_name = deserialize_string(variable_expr->param_name());
Parameter param;
if (auto it = external_params.find(param_name); it != external_params.end()) {
if (auto it = user_params.find(param_name); it != user_params.end()) {
param = it->second;
} else if (auto it = external_params.find(param_name); it != external_params.end()) {
param = it->second;
} else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) {
param = it->second;
Expand Down Expand Up @@ -1224,6 +1250,15 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame
}
}

Parameter Deserializer::deserialize_external_parameter(const Serialize::ExternalParameter *external_parameter) {
user_assert(external_parameter != nullptr);
const bool is_buffer = external_parameter->is_buffer();
const auto type = deserialize_type(external_parameter->type());
const int dimensions = external_parameter->dimensions();
const std::string name = deserialize_string(external_parameter->name());
return Parameter(type, is_buffer, dimensions, name);
}

ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serialize::ExternFuncArgument *extern_func_argument) {
user_assert(extern_func_argument != nullptr);
const auto arg_type = deserialize_extern_func_argument_type(extern_func_argument->arg_type());
Expand All @@ -1249,7 +1284,9 @@ ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serializ
} else {
const auto image_param_name = deserialize_string(extern_func_argument->image_param_name());
Parameter image_param;
if (auto it = external_params.find(image_param_name); it != external_params.end()) {
if (auto it = user_params.find(image_param_name); it != user_params.end()) {
image_param = it->second;
} else if (auto it = external_params.find(image_param_name); it != external_params.end()) {
image_param = it->second;
} else if (auto it = parameters_in_pipeline.find(image_param_name); it != parameters_in_pipeline.end()) {
image_param = it->second;
Expand Down Expand Up @@ -1397,6 +1434,13 @@ Pipeline Deserializer::deserialize(const std::vector<uint8_t> &data) {
parameters_in_pipeline[param.name()] = param;
}

const std::vector<Parameter> parameters_external =
deserialize_vector<Serialize::ExternalParameter, Parameter>(pipeline_obj->external_parameters(),
&Deserializer::deserialize_external_parameter);
for (const auto &param : parameters_external) {
external_params[param.name()] = param;
}

std::vector<Func> funcs;
for (size_t i = 0; i < pipeline_obj->funcs()->size(); ++i) {
deserialize_function(pipeline_obj->funcs()->Get(i), functions[i]);
Expand Down Expand Up @@ -1427,44 +1471,119 @@ Pipeline Deserializer::deserialize(const std::vector<uint8_t> &data) {
return Pipeline(output_funcs, requirements);
}

std::map<std::string, Parameter> Deserializer::deserialize_parameters(const std::string &filename) {
std::map<std::string, Parameter> empty;
std::ifstream in(filename, std::ios::binary | std::ios::in);
if (!in) {
user_error << "failed to open file " << filename << "\n";
return empty;
}
std::map<std::string, Parameter> params = deserialize_parameters(in);
if (!in.good()) {
user_error << "failed to deserialize from file " << filename << " properly\n";
return empty;
}
in.close();
return params;
}

std::map<std::string, Parameter> Deserializer::deserialize_parameters(std::istream &in) {
std::map<std::string, Parameter> empty;
if (!in) {
user_error << "failed to open input stream\n";
return empty;
}
in.seekg(0, std::ios::end);
int size = in.tellg();
in.seekg(0, std::ios::beg);
std::vector<uint8_t> data(size);
in.read((char *)data.data(), size);
return deserialize_parameters(data);
}

std::map<std::string, Parameter> Deserializer::deserialize_parameters(const std::vector<uint8_t> &data) {
std::map<std::string, Parameter> external_parameters_by_name;
const auto *pipeline_obj = Serialize::GetPipeline(data.data());
if (pipeline_obj == nullptr) {
user_warning << "deserialized pipeline is empty\n";
return external_parameters_by_name;
}

const std::vector<Parameter> external_parameters =
deserialize_vector<Serialize::ExternalParameter, Parameter>(pipeline_obj->external_parameters(),
&Deserializer::deserialize_external_parameter);

for (const auto &param : external_parameters) {
external_parameters_by_name[param.name()] = param;
}
return external_parameters_by_name;
}

} // namespace Internal

Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &user_params) {
Internal::Deserializer deserializer(user_params);
return deserializer.deserialize(filename);
}

Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &user_params) {
Internal::Deserializer deserializer(user_params);
return deserializer.deserialize(in);
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
Internal::Deserializer deserializer(external_params);
Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &user_params) {
Internal::Deserializer deserializer(user_params);
return deserializer.deserialize(buffer);
}

std::map<std::string, Parameter> deserialize_parameters(const std::string &filename) {
Internal::Deserializer deserializer;
return deserializer.deserialize_parameters(filename);
}

std::map<std::string, Parameter> deserialize_parameters(std::istream &in) {
Internal::Deserializer deserializer;
return deserializer.deserialize_parameters(in);
}

std::map<std::string, Parameter> deserialize_parameters(const std::vector<uint8_t> &buffer) {
Internal::Deserializer deserializer;
return deserializer.deserialize_parameters(buffer);
}

} // namespace Halide

#else // WITH_SERIALIZATION

namespace Halide {

Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params) {
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &user_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params) {
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &user_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &external_params) {
Pipeline deserialize_pipeline(const std::vector<uint8_t> &buffer, const std::map<std::string, Parameter> &user_params) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
return Pipeline();
}

std::map<std::string, Parameter> deserialize_parameters(const std::string &filename) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

std::map<std::string, Parameter> deserialize_parameters(std::istream &in) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

std::map<std::string, Parameter> deserialize_parameters(const std::vector<uint8_t> &buffer) {
user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON.";
}

} // namespace Halide

#endif // WITH_SERIALIZATION
34 changes: 28 additions & 6 deletions src/Deserialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,43 @@ namespace Halide {

/// @brief Deserialize a Halide pipeline from a file.
/// @param filename The location of the file to deserialize. Must use .hlpipe extension.
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &external_params);
Pipeline deserialize_pipeline(const std::string &filename, const std::map<std::string, Parameter> &user_params);

/// @brief Deserialize a Halide pipeline from an input stream.
/// @param in The input stream to read from containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &external_params);
Pipeline deserialize_pipeline(std::istream &in, const std::map<std::string, Parameter> &user_params);

/// @brief Deserialize a Halide pipeline from a byte buffer containing a serizalized pipeline in binary format
/// @param data The data buffer containing a serialized Halide pipeline
/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead).
/// @return Returns a newly constructed deserialized Pipeline object/
Pipeline deserialize_pipeline(const std::vector<uint8_t> &data, const std::map<std::string, Parameter> &external_params);
Pipeline deserialize_pipeline(const std::vector<uint8_t> &data, const std::map<std::string, Parameter> &user_params);

/// @brief Deserialize the extenal parameters for the Halide pipeline from a file.
/// This method allows a minimal deserialization of just the external pipeline parameters, so they can be
/// remapped and overridden with user parameters prior to deserializing the pipeline definition.
/// @param filename The location of the file to deserialize. Must use .hlpipe extension.
/// @return Returns a map containing the names and description of external parameters referenced in the pipeline
std::map<std::string, Parameter> deserialize_parameters(const std::string &filename);

/// @brief Deserialize the extenal parameters for the Halide pipeline from input stream.
/// This method allows a minimal deserialization of just the external pipeline parameters, so they can be
/// remapped and overridden with user parameters prior to deserializing the pipeline definition.
/// @param in The input stream to read from containing a serialized Halide pipeline
/// @return Returns a map containing the names and description of external parameters referenced in the pipeline
std::map<std::string, Parameter> deserialize_parameters(std::istream &in);

/// @brief Deserialize the extenal parameters for the Halide pipeline from a byte buffer containing a serialized
/// pipeline in binary format. This method allows a minimal deserialization of just the external pipeline
/// parameters, so they can be remapped and overridden with user parameters prior to deserializing the
/// pipeline definition.
/// @param data The data buffer containing a serialized Halide pipeline
/// @return Returns a map containing the names and description of external parameters referenced in the pipeline
std::map<std::string, Parameter> deserialize_parameters(const std::vector<uint8_t> &data);

} // namespace Halide

Expand Down
Loading

0 comments on commit 53f0816

Please sign in to comment.