diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
index 1d40c60a1c8..b42841ad356 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
@@ -29,3 +29,24 @@ def test_bw_celu(input_shapes, device):
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
 
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_celu_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
+
+    tt_output_tensor_on_device = ttnn.celu_bw(grad_tensor, input_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.celu_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+
+    assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
index 59783b57e84..8114d3022c1 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
@@ -36,3 +36,23 @@ def test_bw_elu(input_shapes, alpha, device):
     golden_tensor = golden_function(grad_data, in_data, alpha)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_elu_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)
+
+    tt_output_tensor_on_device = ttnn.elu_bw(grad_tensor, input_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.elu_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py b/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py
index 3c386939a0a..5700e708382 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py
@@ -28,3 +28,24 @@ def test_bw_hardshrink(input_shapes, lambd, device):
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
 
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_hardshrink_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    tt_output_tensor_on_device = ttnn.hardshrink_bw(grad_tensor, input_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.hardshrink_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py b/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py
index 5941f89709d..5c33b4f1664 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py
@@ -27,3 +27,23 @@ def test_bw_leaky_relu(input_shapes, negative_slope, device):
     golden_tensor = golden_function(grad_data, in_data, negative_slope)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_leaky_relu_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)
+
+    tt_output_tensor_on_device = ttnn.leaky_relu_bw(grad_tensor, input_tensor)
+
+    golden_function = ttnn.get_golden_function(ttnn.leaky_relu_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py b/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py
index 4ee3ccf1aec..5546090dcf7 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py
@@ -32,3 +32,21 @@ def test_bw_logiteps(input_shapes, eps, device):
     golden_tensor = golden_function(grad_data, in_data, eps)
     comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_logiteps_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -2, 2, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
+    tt_output_tensor_on_device = ttnn.logiteps_bw(grad_tensor, input_tensor)
+    golden_function = ttnn.get_golden_function(ttnn.logiteps_bw)
+    golden_tensor = golden_function(grad_data, in_data)
+    comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py b/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py
index 815246a959c..bf3651a8fac 100644
--- a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py
+++ b/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py
@@ -35,3 +35,28 @@ def test_bw_softshrink(input_shapes, lambd, device):
     comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
 
     assert comp_pass
+
+
+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+def test_bw_softshrink_default(input_shapes, device):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
+    in_data.retain_grad()
+
+    pyt_y = torch.nn.functional.softshrink(in_data)
+
+    tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)
+
+    pyt_y.backward(gradient=grad_data)
+
+    golden_tensor = [in_data.grad]
+
+    comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
+    assert comp_pass
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
index b302c227f17..986fe5fa1c5 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.cpp
@@ -101,7 +101,7 @@ std::vector<Tensor> ExecuteUnaryBackwardSoftplus::invoke(
     return grad_tensor;
 }
 
-std::vector<Tensor> _rdiv_bw(
+std::vector<Tensor> ExecuteUnaryBackwardRdiv::invoke(
     const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     TT_FATAL((round_mode == "None" || round_mode == "trunc" || round_mode == "floor"), "Incorrect rounding mode (expected 'None', 'trunc', or 'floor')");
@@ -591,7 +591,7 @@ std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const st
     return grad_tensor;
 }
 
-std::vector<Tensor> _hardshrink_bw(
+std::vector<Tensor> ExecuteUnaryBackwardHardshrink::invoke(
     const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor hardshrink_result = ttnn::hardshrink(input_tensor, lambd, output_mem_config);
@@ -603,7 +603,7 @@ std::vector<Tensor> _hardshrink_bw(
 
 // softshrink
 // result: torch.where(self < -lambd, grad, torch.where(self > lambd, grad, torch.tensor(0.0)))
-std::vector<Tensor> _softshrink_bw(
+std::vector<Tensor> ExecuteUnaryBackwardSoftshrink::invoke(
     const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor result = ttnn::where(
@@ -622,7 +622,7 @@ std::vector<Tensor> _softshrink_bw(
 
 // Leaky_Relu
 // result: torch.where(self > 0, grad_output, grad_output * negative_slope)
-std::vector<Tensor> _leaky_relu_bw(
+std::vector<Tensor> ExecuteUnaryBackwardLeakyRelu::invoke(
     const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = where(
@@ -634,7 +634,7 @@ std::vector<Tensor> _leaky_relu_bw(
 
 // ELU
 // result : grad * (torch.where(input >= 0, 1, alpha * torch.exp(input)))
-std::vector<Tensor> _elu_bw(
+std::vector<Tensor> ExecuteUnaryBackwardElu::invoke(
     const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor grad_result = where(
@@ -649,7 +649,7 @@ std::vector<Tensor> _elu_bw(
 
 // Celu
 // result: torch.where((input > 0), grad, grad * torch.exp(input / alpha))
-std::vector<Tensor> _celu_bw(
+std::vector<Tensor> ExecuteUnaryBackwardCelu::invoke(
     const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     Tensor div_result = ttnn::multiply(
@@ -1074,7 +1074,7 @@ std::vector<Tensor> _cosh_bw(const Tensor& grad, const Tensor& input, const std:
 // # grad_output / (self * (1.0 - self)),
 // # self.new_full((), float("nan")),
 // # )
-std::vector<Tensor> _logiteps_bw(
+std::vector<Tensor> ExecuteUnaryBackwardLogiteps::invoke(
     const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     float low, high;
@@ -1414,7 +1414,7 @@ std::vector<std::optional<Tensor>> ExecuteUnaryBackwardGelu::invoke(
     return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config, input_grad);
 }
 
-std::vector<Tensor> _repeat_bw(
+std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
     const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config) {
     std::vector<Tensor> grad_tensor;
     auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
index eb2bb4927ea..3f35b6f67eb 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/device/unary_backward_op.hpp
@@ -14,7 +14,6 @@ namespace ttnn::operations::unary_backward {
 
 enum class UnaryBackwardOpType {
     DIV_BW,
-    RDIV_BW,
     MULTIGAMMALN_BW,
     ADD_BW,
     EQ_BW,
@@ -36,11 +35,6 @@ enum class UnaryBackwardOpType {
     SIGMOID_BW,
     RELU_BW,
     LOGIT_BW,
-    HARDSHRINK_BW,
-    SOFTSHRINK_BW,
-    LEAKY_RELU_BW,
-    ELU_BW,
-    CELU_BW,
     RPOW_BW,
     FLOOR_BW,
     ROUND_BW,
@@ -61,7 +55,6 @@ enum class UnaryBackwardOpType {
     CEIL_BW,
     SOFTSIGN_BW,
     COSH_BW,
-    LOGITEPS_BW,
     LOG2_BW,
     SIGN_BW,
     DIV_NO_NAN_BW,
@@ -73,7 +66,6 @@ enum class UnaryBackwardOpType {
     ERF_BW,
     DEG2RAD_BW,
     POLYGAMMA_BW,
-    REPEAT_BW,
 };
 
 std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
@@ -132,16 +124,6 @@ std::vector<Tensor> _log_bw( const Tensor& grad, const Tensor& input, const std:
 
 std::vector<Tensor> _add_bw( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
 std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _hardshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _softshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _leaky_relu_bw( const Tensor& grad, const Tensor& input, float negative_slope = 0.01, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _elu_bw( const Tensor& grad, const Tensor& input, float alpha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _celu_bw( const Tensor& grad, const Tensor& input, float aplha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-std::vector<Tensor> _logiteps_bw( const Tensor& grad, const Tensor& input, float eps = 0.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-
-std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
-
-std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);
 Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);
 
 // OpHandler struct template
@@ -414,48 +396,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<UnaryBackwardOpType::HARDSHRINK_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _hardshrink_bw(grad, input, lambd, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::SOFTSHRINK_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _softshrink_bw(grad, input, lambd, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::LEAKY_RELU_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _leaky_relu_bw(grad, input, negative_slope, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::ELU_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _elu_bw(grad, input, alpha, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::CELU_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _celu_bw(grad, input, alpha, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::LOGITEPS_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _logiteps_bw(grad, input, eps, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -540,20 +480,6 @@ struct OpHandler {
     }
 };
 
-template <>
-struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _rdiv_bw(grad, input, scalar, round_mode, output_mem_config);
-    }
-};
-
-template <>
-struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _repeat_bw(grad, input, shape, output_mem_config);
-    }
-};
-
 template <>
 struct OpHandler {
     static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
index 9eb523f312c..7579e60a811 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -82,16 +82,13 @@ struct ExecuteUnaryBackward##op_name { \
         const std::optional<MemoryConfig> &memory_config = std::nullopt); \
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardFloatWithDefault {
-    static std::vector<Tensor> invoke(
-        const Tensor &grad_tensor_arg,
-        const Tensor &input_tensor_arg,
-        float parameter_a,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, output_memory_config);
-    }
+#define DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(op_name) \
+struct ExecuteUnaryBackward##op_name { \
+    static std::vector<Tensor> invoke( \
+        const Tensor &grad_tensor_arg, \
+        const Tensor &input_tensor_arg, \
+        float parameter_a, \
+        const std::optional<MemoryConfig> &memory_config = std::nullopt); \
 };
 
 template <UnaryBackwardOpType unary_backward_op_type>
@@ -129,17 +126,13 @@ struct ExecuteUnaryBackwardClamp {
         const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardFloatStringDefault {
+struct ExecuteUnaryBackwardRdiv {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
         const Tensor &input_tensor_arg,
         float parameter_a,
         string parameter_b,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, parameter_b, output_memory_config);
-    }
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 template <UnaryBackwardOpType unary_backward_op_type>
@@ -154,16 +147,12 @@ struct ExecuteUnaryBackwardStringDefault {
     }
 };
 
-template <UnaryBackwardOpType unary_backward_op_type>
-struct ExecuteUnaryBackwardShape {
+struct ExecuteUnaryBackwardRepeat {
     static std::vector<Tensor> invoke(
         const Tensor &grad_tensor_arg,
         const Tensor &input_tensor_arg,
         const tt::tt_metal::LegacyShape &parameter_a,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-        auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
-        return OpHandler<unary_backward_op_type>::handle(grad_tensor_arg, input_tensor_arg, parameter_a, output_memory_config);
-    }
+        const std::optional<MemoryConfig> &memory_config = std::nullopt);
 };
 
 struct ExecuteUnaryBackwardPow {
@@ -315,6 +304,13 @@ struct ExecuteUnaryBackwardGelu{
 
 DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Softplus)
 DEFINE_UNARY_BACKWARD_OPERATION_WITH_2_DEFAULT_FLOATS(Hardtanh)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(Hardshrink)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(Softshrink)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(LeakyRelu)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(Elu)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(Celu)
+DEFINE_UNARY_BACKWARD_OPERATION_WITH_1_DEFAULT_FLOAT(Logiteps)
+
 } // operations::unary
 
 constexpr auto threshold_bw = ttnn::register_operation<
@@ -544,35 +540,13 @@ constexpr auto erfc_bw = ttnn::register_operation<
     operations::unary_backward::ExecuteUnaryBackwardOp<
        operations::unary_backward::UnaryBackwardOpType::ERFC_BW>>();
 
-constexpr auto hardshrink_bw = ttnn::register_operation<
-    "ttnn::hardshrink_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::HARDSHRINK_BW>>();
-
-constexpr auto softshrink_bw = ttnn::register_operation<
-    "ttnn::softshrink_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::SOFTSHRINK_BW>>();
-
-constexpr auto leaky_relu_bw = ttnn::register_operation<
-    "ttnn::leaky_relu_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::LEAKY_RELU_BW>>();
-
-constexpr auto elu_bw = ttnn::register_operation<
-    "ttnn::elu_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::ELU_BW>>();
-
-constexpr auto celu_bw = ttnn::register_operation<
-    "ttnn::celu_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::CELU_BW>>();
-
-constexpr auto logiteps_bw = ttnn::register_operation<
-    "ttnn::logiteps_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatWithDefault<
-        operations::unary_backward::UnaryBackwardOpType::LOGITEPS_BW>>();
+// Tensor + Float(Default)
+constexpr auto logiteps_bw = ttnn::register_operation<"ttnn::logiteps_bw", operations::unary_backward::ExecuteUnaryBackwardLogiteps>();
+constexpr auto celu_bw = ttnn::register_operation<"ttnn::celu_bw", operations::unary_backward::ExecuteUnaryBackwardCelu>();
+constexpr auto elu_bw = ttnn::register_operation<"ttnn::elu_bw", operations::unary_backward::ExecuteUnaryBackwardElu>();
+constexpr auto leaky_relu_bw = ttnn::register_operation<"ttnn::leaky_relu_bw", operations::unary_backward::ExecuteUnaryBackwardLeakyRelu>();
+constexpr auto softshrink_bw = ttnn::register_operation<"ttnn::softshrink_bw", operations::unary_backward::ExecuteUnaryBackwardSoftshrink>();
+constexpr auto hardshrink_bw = ttnn::register_operation<"ttnn::hardshrink_bw", operations::unary_backward::ExecuteUnaryBackwardHardshrink>();
 
 constexpr auto clamp_bw = ttnn::register_operation<
     "ttnn::clamp_bw",
@@ -582,10 +556,10 @@ constexpr auto clamp_bw = ttnn::register_operation<
 constexpr auto hardtanh_bw = ttnn::register_operation<"ttnn::hardtanh_bw", operations::unary_backward::ExecuteUnaryBackwardHardtanh>();
 
 constexpr auto softplus_bw = ttnn::register_operation<"ttnn::softplus_bw", operations::unary_backward::ExecuteUnaryBackwardSoftplus>();
+
 constexpr auto rdiv_bw = ttnn::register_operation<
     "ttnn::rdiv_bw",
-    operations::unary_backward::ExecuteUnaryBackwardFloatStringDefault<
-        operations::unary_backward::UnaryBackwardOpType::RDIV_BW>>();
+    operations::unary_backward::ExecuteUnaryBackwardRdiv>();
 
 constexpr auto gelu_bw = ttnn::register_operation<
     "ttnn::gelu_bw",
@@ -593,8 +567,7 @@ constexpr auto gelu_bw = ttnn::register_operation<
 
 constexpr auto repeat_bw = ttnn::register_operation<
     "ttnn::repeat_bw",
-    operations::unary_backward::ExecuteUnaryBackwardShape<
-        operations::unary_backward::UnaryBackwardOpType::REPEAT_BW>>();
+    operations::unary_backward::ExecuteUnaryBackwardRepeat>();
 
 constexpr auto pow_bw = ttnn::register_operation<
     "ttnn::pow_bw",
diff --git a/ttnn/ttnn/operations/unary_backward.py b/ttnn/ttnn/operations/unary_backward.py
index dccd88a5362..2a58f12757a 100644
--- a/ttnn/ttnn/operations/unary_backward.py
+++ b/ttnn/ttnn/operations/unary_backward.py
@@ -24,22 +24,39 @@ def _golden_function_unary_backward(torch_op, grad_tensor, input_tensor, *args,
     return golden_tensor
 
 
-def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
+def _golden_function_div_no_nan(torch_op, grad_tensor, input_tensor, alpha, *args, **kwargs):
+    pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
+    input_tensor.retain_grad()
+    pyt_y.backward(gradient=grad_tensor)
+    golden_tensor = [input_tensor.grad]
+    golden_tensor[0] = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
+    return golden_tensor
+
+
+def _golden_function_unary_backward_with_float(torch_op, grad_tensor, input_tensor, alpha=None, *args, **kwargs):
     if torch_op == "leaky_relu":
-        pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha, inplace=False)
+        if alpha != None:
+            pyt_y = torch.nn.functional.leaky_relu(input_tensor, negative_slope=alpha)
+        else:
+            pyt_y = torch.nn.functional.leaky_relu(input_tensor)
     elif torch_op == "elu":
-        pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
+        if alpha != None:
+            pyt_y = torch.nn.functional.elu(input_tensor, alpha=alpha)
+        else:
+            pyt_y = torch.nn.functional.elu(input_tensor)
     elif torch_op == "celu":
-        pyt_y = torch.nn.functional.celu(input_tensor, alpha)
-    elif torch_op == "div_no_nan":
-        pyt_y = torch.where(torch.tensor(alpha) == 0, torch.zeros_like(input_tensor), torch.div(input_tensor, alpha))
+        if alpha != None:
+            pyt_y = torch.nn.functional.celu(input_tensor, alpha)
+        else:
+            pyt_y = torch.nn.functional.celu(input_tensor)
     else:
-        pyt_y = torch_op(input_tensor, alpha)
+        if alpha != None:
+            pyt_y = torch_op(input_tensor, alpha)
+        else:
+            pyt_y = torch_op(input_tensor)
     input_tensor.retain_grad()
     pyt_y.backward(gradient=grad_tensor)
     golden_tensor = [input_tensor.grad]
-    if torch_op == "div_no_nan":
-        golden_tensor[0] = torch.where(torch.isnan(golden_tensor[0]), torch.zeros_like(input_tensor), golden_tensor[0])
     return golden_tensor
 
 
@@ -146,35 +163,35 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.hardshrink_bw,
-    golden_function=lambda grad, input, *args, **kwargs: _golden_function_unary_backward(
-        torch.hardshrink, grad, input, *args, **kwargs
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.hardshrink, grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.softshrink_bw,
-    golden_function=lambda grad, input, *args, **kwargs: _golden_function_unary_backward(
-        torch.softshrink, grad, input, *args, **kwargs
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
+        torch.softshrink, grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.leaky_relu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
        "leaky_relu", grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.elu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
        "elu", grad, input, alpha, *args, **kwargs
     ),
 )
 
 ttnn.attach_golden_function(
     ttnn.celu_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
        "celu", grad, input, alpha, *args, **kwargs
     ),
 )
@@ -188,7 +205,7 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.logiteps_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha=None, *args, **kwargs: _golden_function_unary_backward_with_float(
        torch.logit, grad, input, alpha, *args, **kwargs
     ),
 )
@@ -216,7 +233,7 @@ def _golden_function_backward_with_reverse_string(
 
 ttnn.attach_golden_function(
     ttnn.div_no_nan_bw,
-    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_unary_backward_with_float(
+    golden_function=lambda grad, input, alpha, *args, **kwargs: _golden_function_div_no_nan(
        "div_no_nan", grad, input, alpha, *args, **kwargs
     ),
 )
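
Usage sketch (illustration only, appended after the patch): with this change the scalar argument is optional on both the ttnn backward op and its attached golden function, which is exactly what the new *_default tests exercise. A minimal example of calling one of the ops without the scalar, reusing the helpers from the tests above; the utility_funcs import path is an assumption (it is not shown in this diff):

    import torch
    import ttnn
    # Assumed import path for the shared test helpers used by the backward tests.
    from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc


    def run_celu_bw_with_default_alpha(device):
        # Mirrors test_bw_celu_default: no alpha is passed, so the device op uses its
        # registered default and the golden path falls back to torch.nn.functional.celu(input).
        in_data, input_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -100, -1, device, True)
        grad_data, grad_tensor = data_gen_with_range(torch.Size([1, 1, 32, 32]), -10, -1, device, True)

        tt_output = ttnn.celu_bw(grad_tensor, input_tensor)
        golden = ttnn.get_golden_function(ttnn.celu_bw)(grad_data, in_data)
        assert compare_pcc(tt_output, golden)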