-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[API change] Based on user feedback, we have removed distinction betw…
…een the graph and plan objects. With the new API, plan remains embedded in the graph and all operations are performed on the graph object. Previously, ``` REQUIRE(graph.validate().is_good()); REQUIRE(graph.build_operation_graph(handle).is_good()); auto plans = graph.get_execution_plan_list({fe::HeurMode_t::A}); REQUIRE(plans.check_support(handle).is_good()); REQUIRE(graph.set_execution_plans(plans).is_good()); ``` Now, ``` REQUIRE(graph.validate().is_good()); REQUIRE(graph.build_operation_graph(handle).is_good()); REQUIRE(graph.create_execution_plans({fe::HeurMode_t::A}).is_good()); REQUIRE(graph.check_support(handle).is_good()); REQUIRE(graph.build_plans(handle).is_good()); ``` Also, with this change the following new API have been introduced on the graph class. ``` error_t build_plans(cudnnHandle_t const &handle, BuildPlanPolicy_t const policy = BuildPlanPolicy_t::HEURISTICS_CHOICE, bool const do_multithreaded_builds = false); Graph & deselect_workspace_greater_than(int64_t const workspace); Graph & deselect_behavior_notes(std::vector<BehaviorNote_t> const ¬es); Graph & deselect_numeric_notes(std::vector<NumericalNote_t> const ¬es); int64_t get_workspace_size() const int64_t get_autotune_workspace_size() const; error_t autotune(cudnnHandle_t handle, std::unordered_map<std::shared_ptr<Tensor_attributes>, void *> variants, void *workspace, void *user_impl = nullptr); ``` [API change] Removes the implicit `validate` call made in `build_operation_graph`. Now, the expectation is that the user explicitly calls `validate` on the graph before calling `build_operation_graph`. This helps the user distinguish errors between malformed graphs and error occuring due to lowering into cudnn. [API change] Return error codes from the graph API have now been marked `nodiscard`. [New API] Have added a new `graph::key() -> int64_t` as an API that returns a hash on the graph object. This can be used as key for graph caching. Eg. of this usage is shown in the samples. [New API] Have added new python API `create_handle`, `destroy_handle`, `set_stream`, `get_stream` to allow custom handle and stream management on the graph object. [New functionality] sdpa backward can now compute dbias if the fprop had a bias operation. This functionality was added in cudnn 8.9.6. [Enhancement] There is a extension in behavior of `CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT`. This is documented in `docs/operation/Attention.md` [Enhancement] Have added better error checks to make sure all the tensors of the node have been created. This prevents unexpected segmentation faults seen earlier. [Bug Fix] Fix issues in instancenorm, which had caused invalid memory access earlier. [Enhancement] Have moved the v0.9 API samples to `samples/legacy_samples` folder for better organization.
- Loading branch information
Showing
103 changed files
with
5,521 additions
and
5,791 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#pragma once | ||
|
||
#include "../cudnn_frontend_utils.h" | ||
|
||
namespace cudnn_frontend::detail { | ||
|
||
class Context { | ||
DataType_t compute_data_type = DataType_t::NOT_SET; | ||
DataType_t intermediate_data_type = DataType_t::NOT_SET; | ||
DataType_t io_data_type = DataType_t::NOT_SET; | ||
|
||
public: | ||
Context& | ||
set_intermediate_data_type(DataType_t const type) { | ||
intermediate_data_type = type; | ||
return *this; | ||
} | ||
|
||
Context& | ||
set_io_data_type(DataType_t const type) { | ||
io_data_type = type; | ||
return *this; | ||
} | ||
|
||
Context& | ||
set_compute_data_type(DataType_t const type) { | ||
compute_data_type = type; | ||
return *this; | ||
} | ||
|
||
DataType_t | ||
get_io_data_type() const { | ||
return io_data_type; | ||
} | ||
|
||
DataType_t | ||
get_intermediate_data_type() const { | ||
return intermediate_data_type; | ||
} | ||
|
||
DataType_t | ||
get_compute_data_type() const { | ||
return compute_data_type; | ||
} | ||
|
||
Context& | ||
fill_missing_properties(Context const& global_context) { | ||
if (get_compute_data_type() == DataType_t::NOT_SET) { | ||
set_compute_data_type(global_context.get_compute_data_type()); | ||
} | ||
if (get_intermediate_data_type() == DataType_t::NOT_SET) { | ||
set_intermediate_data_type(global_context.get_intermediate_data_type()); | ||
} | ||
if (get_io_data_type() == DataType_t::NOT_SET) { | ||
set_io_data_type(global_context.get_io_data_type()); | ||
} | ||
return *this; | ||
} | ||
}; | ||
|
||
} // namespace cudnn_frontend::detail |
Oops, something went wrong.