From 976ea0b49515a5a92ebd40c191852100c235ed51 Mon Sep 17 00:00:00 2001 From: Derek Gerstmann Date: Mon, 27 Nov 2023 16:55:41 -0800 Subject: [PATCH] [serialization] Serialize stub definitions of external parameters. (#7926) * Serialize stub definitions of external parameters. Add deserialize_parameter methods to allow the user to only deserialize the mapping of external parameters (and remap them to their own user parameters) prior to deserializing the full pipeline definition. * Clang tidy/format pass --------- Co-authored-by: Derek Gerstmann --- src/Deserialization.cpp | 153 ++++++++++++++++++++++++--- src/Deserialization.h | 34 ++++-- src/Serialization.cpp | 20 +++- src/halide_ir.fbs | 8 ++ tutorial/lesson_23_serialization.cpp | 8 +- 5 files changed, 194 insertions(+), 29 deletions(-) diff --git a/src/Deserialization.cpp b/src/Deserialization.cpp index bb19cf82c9aa..9923e9d1c89c 100644 --- a/src/Deserialization.cpp +++ b/src/Deserialization.cpp @@ -23,8 +23,8 @@ class Deserializer { public: Deserializer() = default; - explicit Deserializer(const std::map &external_params) - : external_params(external_params) { + explicit Deserializer(const std::map &user_params) + : user_params(user_params) { } // Deserialize a pipeline from the given filename @@ -36,6 +36,16 @@ class Deserializer { // Deserialize a pipeline from the given buffer of bytes Pipeline deserialize(const std::vector &data); + // Deserialize just the unbound external parameters that need to be defined for the pipeline from the given filename + // (so they can be remapped and overridden with user parameters prior to deserializing the pipeline) + std::map deserialize_parameters(const std::string &filename); + + // Deserialize just the unbound external parameters that need to be defined for the pipeline from the given input stream + std::map deserialize_parameters(std::istream &in); + + // Deserialize just the unbound external parameters that need to be defined for the pipeline from the given buffer of bytes + std::map deserialize_parameters(const std::vector &data); + private: // Helper function to deserialize a homogenous vector from a flatbuffer vector, // does not apply to union types like Stmt and Expr or enum types like MemoryType @@ -63,6 +73,9 @@ class Deserializer { std::map> buffers_in_pipeline; // External parameters that are not deserialized but will be used in the pipeline + std::map user_params; + + // Default external parameters that were created during deserialization std::map external_params; MemoryType deserialize_memory_type(Serialize::MemoryType memory_type); @@ -139,6 +152,8 @@ class Deserializer { Parameter deserialize_parameter(const Serialize::Parameter *parameter); + Parameter deserialize_external_parameter(const Serialize::ExternalParameter *external_parameter); + ExternFuncArgument deserialize_extern_func_argument(const Serialize::ExternFuncArgument *extern_func_argument); std::map deserialize_wrapper_refs(const flatbuffers::Vector> *wrappers); @@ -457,12 +472,15 @@ void Deserializer::deserialize_function(const Serialize::Func *function, Functio deserialize_vector(function->updates(), &Deserializer::deserialize_definition); const std::string debug_file = deserialize_string(function->debug_file()); + std::vector output_buffers; output_buffers.reserve(function->output_buffers_names()->size()); for (const auto &output_buffer_name_serialized : *function->output_buffers_names()) { auto output_buffer_name = deserialize_string(output_buffer_name_serialized); Parameter output_buffer; - if (auto it = external_params.find(output_buffer_name); it != external_params.end()) { + if (auto it = user_params.find(output_buffer_name); it != user_params.end()) { + output_buffer = it->second; + } else if (auto it = external_params.find(output_buffer_name); it != external_params.end()) { output_buffer = it->second; } else if (auto it = parameters_in_pipeline.find(output_buffer_name); it != parameters_in_pipeline.end()) { output_buffer = it->second; @@ -534,7 +552,9 @@ Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt) const auto index = deserialize_expr(store_stmt->index_type(), store_stmt->index()); const auto param_name = deserialize_string(store_stmt->param_name()); Parameter param; - if (auto it = external_params.find(param_name); it != external_params.end()) { + if (auto it = user_params.find(param_name); it != user_params.end()) { + param = it->second; + } else if (auto it = external_params.find(param_name); it != external_params.end()) { param = it->second; } else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) { param = it->second; @@ -799,7 +819,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr) } const auto param_name = deserialize_string(load_expr->param_name()); Parameter param; - if (auto it = external_params.find(param_name); it != external_params.end()) { + if (auto it = user_params.find(param_name); it != user_params.end()) { + param = it->second; + } else if (auto it = external_params.find(param_name); it != external_params.end()) { param = it->second; } else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) { param = it->second; @@ -850,7 +872,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr) } const auto param_name = deserialize_string(call_expr->param_name()); Parameter param; - if (auto it = external_params.find(param_name); it != external_params.end()) { + if (auto it = user_params.find(param_name); it != user_params.end()) { + param = it->second; + } else if (auto it = external_params.find(param_name); it != external_params.end()) { param = it->second; } else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) { param = it->second; @@ -866,7 +890,9 @@ Expr Deserializer::deserialize_expr(Serialize::Expr type_code, const void *expr) const auto type = deserialize_type(variable_expr->type()); const auto param_name = deserialize_string(variable_expr->param_name()); Parameter param; - if (auto it = external_params.find(param_name); it != external_params.end()) { + if (auto it = user_params.find(param_name); it != user_params.end()) { + param = it->second; + } else if (auto it = external_params.find(param_name); it != external_params.end()) { param = it->second; } else if (auto it = parameters_in_pipeline.find(param_name); it != parameters_in_pipeline.end()) { param = it->second; @@ -1224,6 +1250,15 @@ Parameter Deserializer::deserialize_parameter(const Serialize::Parameter *parame } } +Parameter Deserializer::deserialize_external_parameter(const Serialize::ExternalParameter *external_parameter) { + user_assert(external_parameter != nullptr); + const bool is_buffer = external_parameter->is_buffer(); + const auto type = deserialize_type(external_parameter->type()); + const int dimensions = external_parameter->dimensions(); + const std::string name = deserialize_string(external_parameter->name()); + return Parameter(type, is_buffer, dimensions, name); +} + ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serialize::ExternFuncArgument *extern_func_argument) { user_assert(extern_func_argument != nullptr); const auto arg_type = deserialize_extern_func_argument_type(extern_func_argument->arg_type()); @@ -1249,7 +1284,9 @@ ExternFuncArgument Deserializer::deserialize_extern_func_argument(const Serializ } else { const auto image_param_name = deserialize_string(extern_func_argument->image_param_name()); Parameter image_param; - if (auto it = external_params.find(image_param_name); it != external_params.end()) { + if (auto it = user_params.find(image_param_name); it != user_params.end()) { + image_param = it->second; + } else if (auto it = external_params.find(image_param_name); it != external_params.end()) { image_param = it->second; } else if (auto it = parameters_in_pipeline.find(image_param_name); it != parameters_in_pipeline.end()) { image_param = it->second; @@ -1397,6 +1434,13 @@ Pipeline Deserializer::deserialize(const std::vector &data) { parameters_in_pipeline[param.name()] = param; } + const std::vector parameters_external = + deserialize_vector(pipeline_obj->external_parameters(), + &Deserializer::deserialize_external_parameter); + for (const auto ¶m : parameters_external) { + external_params[param.name()] = param; + } + std::vector funcs; for (size_t i = 0; i < pipeline_obj->funcs()->size(); ++i) { deserialize_function(pipeline_obj->funcs()->Get(i), functions[i]); @@ -1427,44 +1471,119 @@ Pipeline Deserializer::deserialize(const std::vector &data) { return Pipeline(output_funcs, requirements); } +std::map Deserializer::deserialize_parameters(const std::string &filename) { + std::map empty; + std::ifstream in(filename, std::ios::binary | std::ios::in); + if (!in) { + user_error << "failed to open file " << filename << "\n"; + return empty; + } + std::map params = deserialize_parameters(in); + if (!in.good()) { + user_error << "failed to deserialize from file " << filename << " properly\n"; + return empty; + } + in.close(); + return params; +} + +std::map Deserializer::deserialize_parameters(std::istream &in) { + std::map empty; + if (!in) { + user_error << "failed to open input stream\n"; + return empty; + } + in.seekg(0, std::ios::end); + int size = in.tellg(); + in.seekg(0, std::ios::beg); + std::vector data(size); + in.read((char *)data.data(), size); + return deserialize_parameters(data); +} + +std::map Deserializer::deserialize_parameters(const std::vector &data) { + std::map external_parameters_by_name; + const auto *pipeline_obj = Serialize::GetPipeline(data.data()); + if (pipeline_obj == nullptr) { + user_warning << "deserialized pipeline is empty\n"; + return external_parameters_by_name; + } + + const std::vector external_parameters = + deserialize_vector(pipeline_obj->external_parameters(), + &Deserializer::deserialize_external_parameter); + + for (const auto ¶m : external_parameters) { + external_parameters_by_name[param.name()] = param; + } + return external_parameters_by_name; +} + } // namespace Internal -Pipeline deserialize_pipeline(const std::string &filename, const std::map &external_params) { - Internal::Deserializer deserializer(external_params); +Pipeline deserialize_pipeline(const std::string &filename, const std::map &user_params) { + Internal::Deserializer deserializer(user_params); return deserializer.deserialize(filename); } -Pipeline deserialize_pipeline(std::istream &in, const std::map &external_params) { - Internal::Deserializer deserializer(external_params); +Pipeline deserialize_pipeline(std::istream &in, const std::map &user_params) { + Internal::Deserializer deserializer(user_params); return deserializer.deserialize(in); } -Pipeline deserialize_pipeline(const std::vector &buffer, const std::map &external_params) { - Internal::Deserializer deserializer(external_params); +Pipeline deserialize_pipeline(const std::vector &buffer, const std::map &user_params) { + Internal::Deserializer deserializer(user_params); return deserializer.deserialize(buffer); } +std::map deserialize_parameters(const std::string &filename) { + Internal::Deserializer deserializer; + return deserializer.deserialize_parameters(filename); +} + +std::map deserialize_parameters(std::istream &in) { + Internal::Deserializer deserializer; + return deserializer.deserialize_parameters(in); +} + +std::map deserialize_parameters(const std::vector &buffer) { + Internal::Deserializer deserializer; + return deserializer.deserialize_parameters(buffer); +} + } // namespace Halide #else // WITH_SERIALIZATION namespace Halide { -Pipeline deserialize_pipeline(const std::string &filename, const std::map &external_params) { +Pipeline deserialize_pipeline(const std::string &filename, const std::map &user_params) { user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; return Pipeline(); } -Pipeline deserialize_pipeline(std::istream &in, const std::map &external_params) { +Pipeline deserialize_pipeline(std::istream &in, const std::map &user_params) { user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; return Pipeline(); } -Pipeline deserialize_pipeline(const std::vector &buffer, const std::map &external_params) { +Pipeline deserialize_pipeline(const std::vector &buffer, const std::map &user_params) { user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; return Pipeline(); } +std::map deserialize_parameters(const std::string &filename) { + user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; +} + +std::map deserialize_parameters(std::istream &in) { + user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; +} + +std::map deserialize_parameters(const std::vector &buffer) { + user_error << "Deserialization is not supported in this build of Halide; try rebuilding with WITH_SERIALIZATION=ON."; +} + } // namespace Halide #endif // WITH_SERIALIZATION diff --git a/src/Deserialization.h b/src/Deserialization.h index 82f7c8e7217b..b4b3844303c0 100644 --- a/src/Deserialization.h +++ b/src/Deserialization.h @@ -9,21 +9,43 @@ namespace Halide { /// @brief Deserialize a Halide pipeline from a file. /// @param filename The location of the file to deserialize. Must use .hlpipe extension. -/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). +/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). /// @return Returns a newly constructed deserialized Pipeline object/ -Pipeline deserialize_pipeline(const std::string &filename, const std::map &external_params); +Pipeline deserialize_pipeline(const std::string &filename, const std::map &user_params); /// @brief Deserialize a Halide pipeline from an input stream. /// @param in The input stream to read from containing a serialized Halide pipeline -/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). +/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). /// @return Returns a newly constructed deserialized Pipeline object/ -Pipeline deserialize_pipeline(std::istream &in, const std::map &external_params); +Pipeline deserialize_pipeline(std::istream &in, const std::map &user_params); /// @brief Deserialize a Halide pipeline from a byte buffer containing a serizalized pipeline in binary format /// @param data The data buffer containing a serialized Halide pipeline -/// @param external_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). +/// @param user_params Map of named input/output parameters to bind with the resulting pipeline (used to avoid deserializing specific objects and enable the use of externally defined ones instead). /// @return Returns a newly constructed deserialized Pipeline object/ -Pipeline deserialize_pipeline(const std::vector &data, const std::map &external_params); +Pipeline deserialize_pipeline(const std::vector &data, const std::map &user_params); + +/// @brief Deserialize the extenal parameters for the Halide pipeline from a file. +/// This method allows a minimal deserialization of just the external pipeline parameters, so they can be +/// remapped and overridden with user parameters prior to deserializing the pipeline definition. +/// @param filename The location of the file to deserialize. Must use .hlpipe extension. +/// @return Returns a map containing the names and description of external parameters referenced in the pipeline +std::map deserialize_parameters(const std::string &filename); + +/// @brief Deserialize the extenal parameters for the Halide pipeline from input stream. +/// This method allows a minimal deserialization of just the external pipeline parameters, so they can be +/// remapped and overridden with user parameters prior to deserializing the pipeline definition. +/// @param in The input stream to read from containing a serialized Halide pipeline +/// @return Returns a map containing the names and description of external parameters referenced in the pipeline +std::map deserialize_parameters(std::istream &in); + +/// @brief Deserialize the extenal parameters for the Halide pipeline from a byte buffer containing a serialized +/// pipeline in binary format. This method allows a minimal deserialization of just the external pipeline +/// parameters, so they can be remapped and overridden with user parameters prior to deserializing the +/// pipeline definition. +/// @param data The data buffer containing a serialized Halide pipeline +/// @return Returns a map containing the names and description of external parameters referenced in the pipeline +std::map deserialize_parameters(const std::vector &data); } // namespace Halide diff --git a/src/Serialization.cpp b/src/Serialization.cpp index c85eaa15e1aa..e29c7e053179 100644 --- a/src/Serialization.cpp +++ b/src/Serialization.cpp @@ -127,6 +127,8 @@ class Serializer { Offset serialize_parameter(FlatBufferBuilder &builder, const Parameter ¶meter); + Offset serialize_external_parameter(FlatBufferBuilder &builder, const Parameter ¶meter); + Offset serialize_extern_func_argument(FlatBufferBuilder &builder, const ExternFuncArgument &extern_func_argument); Offset serialize_buffer(FlatBufferBuilder &builder, Buffer<> buffer); @@ -1351,6 +1353,14 @@ Offset Serializer::serialize_parameter(FlatBufferBuilder & } } +Offset Serializer::serialize_external_parameter(FlatBufferBuilder &builder, const Parameter ¶meter) { + const auto type_serialized = serialize_type(builder, parameter.type()); + const int dimensions = parameter.dimensions(); + const auto name_serialized = serialize_string(builder, parameter.name()); + const bool is_buffer = parameter.is_buffer(); + return Serialize::CreateExternalParameter(builder, is_buffer, type_serialized, dimensions, name_serialized); +} + Offset Serializer::serialize_extern_func_argument(FlatBufferBuilder &builder, const ExternFuncArgument &extern_func_argument) { const auto arg_type_serialized = serialize_extern_func_argument_type(extern_func_argument.arg_type); if (extern_func_argument.arg_type == ExternFuncArgument::ArgType::UndefinedArg) { @@ -1472,12 +1482,19 @@ void Serializer::serialize(const Pipeline &pipeline, std::vector &resul std::vector> parameters_serialized; parameters_serialized.reserve(parameters_in_pipeline.size()); for (const auto ¶m : parameters_in_pipeline) { - // we only serialize internal parameters with the pipeline + // we only serialize the definitions of internal parameters with the pipeline if (external_parameters.find(param.first) == external_parameters.end()) { parameters_serialized.push_back(serialize_parameter(builder, param.second)); } } + // Serialize only the metadata describing external parameters (to allow the to be created with defaults upon deserialization) + std::vector> external_parameters_serialized; + external_parameters_serialized.reserve(external_parameters.size()); + for (const auto ¶m : external_parameters) { + external_parameters_serialized.push_back(serialize_external_parameter(builder, param.second)); + } + std::vector> buffers_serialized; buffers_serialized.reserve(buffers_in_pipeline.size()); for (auto &buffer : buffers_in_pipeline) { @@ -1491,6 +1508,7 @@ void Serializer::serialize(const Pipeline &pipeline, std::vector &resul builder.CreateVector(requirements_serialized), builder.CreateVector(func_names_in_order_serialized), builder.CreateVector(parameters_serialized), + builder.CreateVector(external_parameters_serialized), builder.CreateVector(buffers_serialized)); builder.Finish(pipeline_obj); diff --git a/src/halide_ir.fbs b/src/halide_ir.fbs index f3d27e83a62a..479e488b6739 100644 --- a/src/halide_ir.fbs +++ b/src/halide_ir.fbs @@ -640,6 +640,13 @@ table Parameter { scalar_estimate: Expr; } +table ExternalParameter { + is_buffer: bool; + type: Type; + dimensions: int32; + name: string; +} + enum ExternFuncArgumentType: ubyte { UndefinedArg, FuncArg, @@ -701,6 +708,7 @@ table Pipeline { requirements: [Stmt]; func_names_in_order: [string]; parameters: [Parameter]; + external_parameters: [ExternalParameter]; buffers: [Buffer]; } diff --git a/tutorial/lesson_23_serialization.cpp b/tutorial/lesson_23_serialization.cpp index a01de5f916fd..f383debbcb7f 100644 --- a/tutorial/lesson_23_serialization.cpp +++ b/tutorial/lesson_23_serialization.cpp @@ -108,11 +108,9 @@ int main(int argc, char **argv) { { // Lets do the same thing again ... construct a new pipeline from scratch by deserializing the file we wrote to disk - // FIXME: We shouldn't have to populate the params ... but passing an empty map triggers an error in deserialize - // for a missing input param - std::map params; - ImageParam input(UInt(8), 3, "input"); - params.insert({"input", input.parameter()}); + // First we can deserialize the external parameters (useful in the event we want to remap them + // and replace the definitions with our own user parameter definitions) + std::map params = deserialize_parameters("blur.hlpipe"); // Now deserialize the pipeline from file Pipeline blur_pipeline = deserialize_pipeline("blur.hlpipe", params);