From 4d09902436418564e94ac9ffcdd5bb4a681abcbc Mon Sep 17 00:00:00 2001
From: mcw-anasuya
Date: Wed, 11 Sep 2024 11:14:12 +0000
Subject: [PATCH] #12375: Add qid and optional tensor output to ttnn.gelu_bw
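
ttnn.gelu_bw can now be launched on a specific command queue and can write
its gradient into a caller-provided tensor instead of allocating a new one.
The op returns its gradients as optional tensors; when `input_grad` is not
supplied, it is allocated internally via zeros_like of `grad`.

Illustrative usage (a minimal sketch mirroring the new unit tests below; the
device handle, tensor contents, and memory config are placeholders):

    import torch
    import ttnn

    # `device` is assumed to be an already-opened ttnn device
    grad_tensor = ttnn.from_torch(
        torch.randn([1, 1, 32, 32], dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor = ttnn.from_torch(
        torch.randn([1, 1, 32, 32], dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT, device=device)
    input_grad = ttnn.from_torch(
        torch.zeros([1, 1, 32, 32], dtype=torch.bfloat16),
        layout=ttnn.TILE_LAYOUT, device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG)

    # the result is written into the preallocated input_grad on queue 0
    ttnn.gelu_bw(grad_tensor, input_tensor, approximate="tanh",
                 input_grad=input_grad, queue_id=0)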
---
 .../operations/backward/test_backward_gelu.py | 64 +++++++++++++++++++
 .../device/binary_backward_op.cpp             | 21 +++++-
 .../device/unary_backward_op.cpp              | 33 +++++++---
 .../device/unary_backward_op.hpp              | 10 ---
 .../eltwise/unary_backward/unary_backward.hpp | 21 +++++-
 .../unary_backward/unary_backward_pybind.hpp  | 60 ++++++++++++++++-
 6 files changed, 186 insertions(+), 23 deletions(-)

diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py
index 39c62a42e13e..a496eb27df2f 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py
@@ -55,3 +55,67 @@ def test_bw_gelu_default(input_shapes, device):
 
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize(
+    "approximate",
+    (
+        "none",
+        "tanh",
+    ),
+)
+def test_bw_gelu_opt_output(input_shapes, approximate, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    input_grad = ttnn.from_torch(
+        input_grad, ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
+    )
+
+    cq_id = 0
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.gelu_bw(grad_tensor, input_tensor, approximate=approximate, input_grad=input_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+
+    tt_output_tensor_on_device = [input_grad]
+
+    golden_function = ttnn.get_golden_function(ttnn.gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_gelu_default_opt_output(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.gelu_bw(grad_tensor, input_tensor, input_grad=input_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+
+    tt_output_tensor_on_device = [input_grad]
+
+    golden_function = ttnn.get_golden_function(ttnn.gelu_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
index c0c2535e1582..af683c5fb395 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/device/binary_backward_op.cpp
@@ -438,7 +438,15 @@ std::vector<Tensor> ExecuteBackwardBiasGelu::invoke(
     TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect approximation type (expected 'none', 'tanh')");
     std::vector<Tensor> grad_tensor;
     Tensor input = ttnn::add(input_a, input_b);
-    grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate, output_mem_config);
+    // grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate, output_mem_config);
+    std::vector<std::optional<Tensor>> gelu_result = ttnn::gelu_bw(grad, input, approximate, output_mem_config);
+    for (const auto& opt_tensor : gelu_result) {
+        if (opt_tensor.has_value()) {
+            grad_tensor.push_back(opt_tensor.value());
+        } else {
+            throw std::runtime_error("Gelu backward returned an empty tensor.");
+        }
+    }
     grad_tensor.emplace_back(grad_tensor[0]);
     return grad_tensor;
 }
@@ -448,7 +456,16 @@ std::vector<Tensor> ExecuteBackwardBiasGelu::invoke(
     std::vector<Tensor> grad_tensor;
     TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect rounding mode (expected 'none' or 'tanh')");
     Tensor input = ttnn::add(input_tensor, bias);
-    grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate);
+    // grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate);
+    std::vector<std::optional<Tensor>> gelu_result = ttnn::gelu_bw(grad, input, approximate, output_mem_config);
+    for (const auto& opt_tensor : gelu_result) {
+        if (opt_tensor.has_value()) {
+            grad_tensor.push_back(opt_tensor.value());
+        } else {
+            throw std::runtime_error("Gelu backward returned an empty tensor.");
+        }
+    }
+
     return grad_tensor;
 }
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index 7ad0d9905b8b..736c6951a9db 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -1323,9 +1323,18 @@ std::vector<Tensor> _deg2rad_bw(const Tensor& grad, const Tensor& input, const s
 }
 
-std::vector<Tensor> _gelu_bw(
-    const Tensor& grad, const Tensor& input, string approximate, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
+std::vector<std::optional<Tensor>> ExecuteUnaryBackwardGelu::invoke(
+    uint8_t queue_id,
+    const Tensor& grad,
+    const Tensor& input,
+    string approximate,
+    const std::optional<MemoryConfig>& output_mem_config,
+    std::optional<Tensor> input_grad) {
+    std::vector<std::optional<Tensor>> result;
+    if(!input_grad.has_value()){
+        input_grad = ttnn::zeros_like(grad);
+    }
+
     auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
     TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect approximate mode (expected 'None', 'tanh')");
@@ -1354,19 +1363,27 @@ std::vector<Tensor> _gelu_bw(
             std::nullopt,
             output_memory_config);
-        Tensor grad_a = ttnn::multiply(grad, (ttnn::add(left_derivative, right_derivative)), std::nullopt, output_memory_config);
-        grad_tensor.emplace_back(grad_a);
+        ttnn::multiply(queue_id, grad, (ttnn::add(left_derivative, right_derivative)), std::nullopt, output_memory_config, input_grad);
+        result.push_back(input_grad);
     } else {
         float kAlpha = M_SQRT1_2;
         float kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
         Tensor cdf = ttnn::multiply((ttnn::add(ttnn::erf(ttnn::multiply(input, kAlpha, std::nullopt, output_memory_config)), 1, std::nullopt, output_memory_config)), 0.5);
         Tensor pdf = ttnn::multiply(ttnn::exp(ttnn::multiply(ttnn::multiply(input, input), -0.5), false, output_memory_config), kBeta, std::nullopt, output_memory_config);
-        Tensor grad_a = ttnn::multiply(grad, (ttnn::add(cdf, ttnn::multiply(input, pdf))));
-        grad_tensor.emplace_back(grad_a);
+        ttnn::multiply(queue_id, grad, ttnn::add(cdf, ttnn::multiply(input, pdf)), std::nullopt, output_memory_config, input_grad);
+        result.push_back(input_grad);
     }
-    return grad_tensor;
+    return result;
+}
+
+std::vector<std::optional<Tensor>> ExecuteUnaryBackwardGelu::invoke(
+    const Tensor& grad,
+    const Tensor& input,
+    string approximate,
+    const std::optional<MemoryConfig>& output_mem_config) {
+    return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config);
 }
 
 std::vector<Tensor> _repeat_bw(
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index b642543358c2..2db73545e487 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -81,7 +81,6 @@ enum class UnaryBackwardOpType {
     ERF_BW,
     DEG2RAD_BW,
     POLYGAMMA_BW,
-    GELU_BW,
     REPEAT_BW,
     PROD_BW,
 };
@@ -162,8 +161,6 @@ std::vector<Tensor> _clamp_bw( const Tensor& grad, const Tensor& input, std::opt
 std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 
-std::vector<Tensor> _gelu_bw( const Tensor& grad, const Tensor& input, string approximate = "none", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-
 std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::Shape& shape, const std::optional<MemoryConfig>& output_mem_config);
 
 std::vector<Tensor> _prod_bw( const Tensor& grad, const Tensor& input, bool all_dimensions = true, int64_t dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -628,13 +625,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<UnaryBackwardOpType::GELU_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, string approximate, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _gelu_bw(grad, input, approximate, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const tt::tt_metal::Shape& shape, const std::optional<MemoryConfig>& output_mem_config ) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 66dfad36f735..e8d24d7591d4 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -251,6 +251,24 @@ struct ExecuteUnaryBackwardAbs {
 
 };
 
+struct ExecuteUnaryBackwardGelu{
+    static std::vector<std::optional<Tensor>> invoke(
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        string parameter_a,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+
+    static std::vector<std::optional<Tensor>> invoke(
+        uint8_t queue_id,
+        const Tensor &grad_tensor_arg,
+        const Tensor &input_tensor_arg,
+        string parameter_a,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        std::optional<Tensor> input_grad = std::nullopt);
+
+};
+
+
 } // operations::unary
 
 constexpr auto threshold_bw = ttnn::register_operation<
@@ -538,8 +556,7 @@ constexpr auto rdiv_bw = ttnn::register_operation<
 
 constexpr auto gelu_bw = ttnn::register_operation<
     "ttnn::gelu_bw",
-    operations::unary_backward::ExecuteUnaryBackwardStringDefault<
-        operations::unary_backward::UnaryBackwardOpType::GELU_BW>>();
+    operations::unary_backward::ExecuteUnaryBackwardGelu>();
 
 constexpr auto repeat_bw = ttnn::register_operation<
     "ttnn::repeat_bw",
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
index 4843f24b5f75..375c8684c393 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward_pybind.hpp
@@ -857,6 +857,64 @@ Keyword args:
         py::arg("memory_config") = std::nullopt});
 }
 
+template <typename unary_backward_operation_t>
+void bind_unary_backward_gelu(
+    py::module& module,
+    const unary_backward_operation_t& operation,
+    const std::string& parameter_name_a,
+    const std::string& parameter_a_doc,
+    string parameter_a_value,
+    std::string_view description) {
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor: ttnn.Tensor, {2}: string, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
+
+{5}
+
+Args:
+    * :attr:`grad_tensor`
+    * :attr:`input_tensor`
+
+Keyword args:
+    * :attr:`{2}` (string): {3} , Default value = {4}
+    * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+    * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
+    * :attr:`queue_id` (Optional[uint8]): command queue id
+
+Example:
+
+    >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+    >>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+    >>> output = {1}(grad_tensor, input, {2} = {4})
+)doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        parameter_name_a,
+        parameter_a_doc,
+        parameter_a_value,
+        description);
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        ttnn::pybind_overload_t{
+            [](const unary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               string parameter_a,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::optional<ttnn::Tensor>& input_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                return self(queue_id, grad_tensor, input_tensor, parameter_a, memory_config, input_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor"),
+            py::kw_only(),
+            py::arg(parameter_name_a.c_str()) = parameter_a_value,
+            py::arg("memory_config") = std::nullopt,
+            py::arg("input_grad") = std::nullopt,
+            py::arg("queue_id") = ttnn::DefaultQueueId});
+}
+
 } // namespace detail
 
 void py_module(py::module& module) {
@@ -977,7 +1035,7 @@ void py_module(py::module& module) {
         "Shape",
         R"doc(Performs backward operations for repeat on :attr:`input_tensor_a` or :attr:`input_tensor`, with given :attr:`grad_tensor` using given :attr:`shape`.)doc");
 
-    detail::bind_unary_backward_string_default(
+    detail::bind_unary_backward_gelu(
        module,
        ttnn::gelu_bw,
        "approximate",