#12867: Cleanup 9 Unary Backward ops #12920

Merged
merged 4 commits on Sep 22, 2024
21 changes: 21 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_celu.py
@@ -29,3 +29,24 @@ def test_bw_celu(input_shapes, device):
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_celu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.celu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.celu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)

assert comp_pass
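For context, a minimal PyTorch sketch of what the default-parameter golden path is expected to compute for celu_bw, assuming the standard CELU derivative with alpha = 1.0 (the same expression noted in the C++ implementation further down); the helper name celu_bw_reference is illustrative, not part of ttnn:

import torch

def celu_bw_reference(grad, x, alpha=1.0):
    # dCELU/dx is 1 for x > 0 and exp(x / alpha) otherwise,
    # so the incoming gradient is passed through or scaled accordingly.
    return torch.where(x > 0, grad, grad * torch.exp(x / alpha))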
20 changes: 20 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_elu.py
@@ -36,3 +36,23 @@ def test_bw_elu(input_shapes, alpha, device):
golden_tensor = golden_function(grad_data, in_data, alpha)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_elu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device, True)

tt_output_tensor_on_device = ttnn.elu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.elu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
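Similarly, a hedged reference sketch for the elu_bw default case, assuming the usual ELU derivative with alpha = 1.0 (matching the comment above ExecuteUnaryBackwardElu::invoke in the C++ diff below); elu_bw_reference is an illustrative name only:

import torch

def elu_bw_reference(grad, x, alpha=1.0):
    # dELU/dx is 1 for x >= 0 and alpha * exp(x) otherwise.
    return grad * torch.where(x >= 0, torch.ones_like(x), alpha * torch.exp(x))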
@@ -28,3 +28,24 @@ def test_bw_hardshrink(input_shapes, lambd, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_hardshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

tt_output_tensor_on_device = ttnn.hardshrink_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.hardshrink_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
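For the hardshrink default case, a small sketch of the expected golden behaviour, assuming the standard hardshrink derivative with lambd = 0.5; the function name is hypothetical:

import torch

def hardshrink_bw_reference(grad, x, lambd=0.5):
    # hardshrink passes x through where |x| > lambd and zeroes it otherwise,
    # so the gradient is grad inside that region and 0 elsewhere.
    return torch.where(x.abs() > lambd, grad, torch.zeros_like(grad))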
@@ -27,3 +27,23 @@ def test_bw_leaky_relu(input_shapes, negative_slope, device):
golden_tensor = golden_function(grad_data, in_data, negative_slope)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_leaky_relu_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, -1, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, -1, device, True)

tt_output_tensor_on_device = ttnn.leaky_relu_bw(grad_tensor, input_tensor)

golden_function = ttnn.get_golden_function(ttnn.leaky_relu_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
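A hedged sketch of the leaky_relu_bw default golden computation, assuming negative_slope = 0.01 (the expression quoted above ExecuteUnaryBackwardLeakyRelu::invoke in the C++ diff below); leaky_relu_bw_reference is illustrative:

import torch

def leaky_relu_bw_reference(grad, x, negative_slope=0.01):
    # Gradient passes through for x > 0 and is scaled by negative_slope otherwise.
    return torch.where(x > 0, grad, grad * negative_slope)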
@@ -32,3 +32,21 @@ def test_bw_logiteps(input_shapes, eps, device):
golden_tensor = golden_function(grad_data, in_data, eps)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_logiteps_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -2, 2, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
tt_output_tensor_on_device = ttnn.logiteps_bw(grad_tensor, input_tensor)
golden_function = ttnn.get_golden_function(ttnn.logiteps_bw)
golden_tensor = golden_function(grad_data, in_data)
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
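For logiteps_bw with the default eps, a rough reference sketch; this assumes the PyTorch logit-with-eps backward, where the gradient is grad / (x * (1 - x)) inside [eps, 1 - eps] and the out-of-range fill (zero or NaN) follows the op's own convention, so treat the boundary handling as an assumption:

import torch

def logiteps_bw_reference(grad, x, eps=0.0):
    # Inside [eps, 1 - eps] the derivative of logit is 1 / (x * (1 - x)).
    in_range = (x >= eps) & (x <= 1.0 - eps)
    return torch.where(in_range, grad / (x * (1.0 - x)), torch.full_like(x, float("nan")))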
@@ -35,3 +35,28 @@ def test_bw_softshrink(input_shapes, lambd, device):

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_softshrink_default(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -20, 20, device)
in_data.retain_grad()

pyt_y = torch.nn.functional.softshrink(in_data)

tt_output_tensor_on_device = ttnn.softshrink_bw(grad_tensor, input_tensor)

pyt_y.backward(gradient=grad_data)

golden_tensor = [in_data.grad]

comp_pass = compare_results(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
@@ -101,7 +101,7 @@ std::vector<Tensor> ExecuteUnaryBackwardSoftplus::invoke(
return grad_tensor;
}

std::vector<Tensor> _rdiv_bw(
std::vector<Tensor> ExecuteUnaryBackwardRdiv::invoke(
const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
TT_FATAL((round_mode == "None" || round_mode == "trunc" || round_mode == "floor"), "Incorrect rounding mode (expected 'None', 'trunc', or 'floor')");
@@ -591,7 +591,7 @@ std::vector<Tensor> _square_bw(const Tensor& grad, const Tensor& input, const st
return grad_tensor;
}

std::vector<Tensor> _hardshrink_bw(
std::vector<Tensor> ExecuteUnaryBackwardHardshrink::invoke(
const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor hardshrink_result = ttnn::hardshrink(input_tensor, lambd, output_mem_config);
@@ -603,7 +603,7 @@ std::vector<Tensor> _hardshrink_bw(

// softshrink
// result: torch.where(self < -lambd, grad, torch.where(self > lambd, grad, torch.tensor(0.0)))
std::vector<Tensor> _softshrink_bw(
std::vector<Tensor> ExecuteUnaryBackwardSoftshrink::invoke(
const Tensor& grad, const Tensor& input_tensor, float lambd, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor result = ttnn::where(
@@ -622,7 +622,7 @@ std::vector<Tensor> _softshrink_bw(

// Leaky_Relu
// result: torch.where(self > 0, grad_output, grad_output * negative_slope)
std::vector<Tensor> _leaky_relu_bw(
std::vector<Tensor> ExecuteUnaryBackwardLeakyRelu::invoke(
const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = where(
@@ -634,7 +634,7 @@ std::vector<Tensor> _leaky_relu_bw(

// ELU
// result : grad * (torch.where(input >= 0, 1, alpha * torch.exp(input)))
std::vector<Tensor> _elu_bw(
std::vector<Tensor> ExecuteUnaryBackwardElu::invoke(
const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor grad_result = where(
@@ -649,7 +649,7 @@ std::vector<Tensor> _elu_bw(

// Celu
// result: torch.where((input > 0), grad, grad * torch.exp(input / alpha))
std::vector<Tensor> _celu_bw(
std::vector<Tensor> ExecuteUnaryBackwardCelu::invoke(
const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
Tensor div_result = ttnn::multiply(
@@ -1074,7 +1074,7 @@ std::vector<Tensor> _cosh_bw(const Tensor& grad, const Tensor& input, const std:
// # grad_output / (self * (1.0 - self)),
// # self.new_full((), float("nan")),
// # )
std::vector<Tensor> _logiteps_bw(
std::vector<Tensor> ExecuteUnaryBackwardLogiteps::invoke(
const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
float low, high;
@@ -1414,7 +1414,7 @@ std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config, input_grad);
}

std::vector<Tensor> _repeat_bw(
std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
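The first hunk in this file moves _rdiv_bw into ExecuteUnaryBackwardRdiv::invoke. As a hedged reminder of what that op differentiates, here is a minimal PyTorch-style sketch for round_mode "None" only (rdiv computes scalar / input, so the derivative with respect to input is -scalar / input**2); the "trunc" and "floor" modes are not covered and the helper name is illustrative:

import torch

def rdiv_bw_reference(grad, x, scalar):
    # d/dx (scalar / x) = -scalar / x**2, applied to the incoming gradient.
    return grad * (-scalar) / (x * x)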
@@ -14,7 +14,6 @@ namespace ttnn::operations::unary_backward {

enum class UnaryBackwardOpType {
DIV_BW,
RDIV_BW,
MULTIGAMMALN_BW,
ADD_BW,
EQ_BW,
@@ -36,11 +35,6 @@ enum class UnaryBackwardOpType {
SIGMOID_BW,
RELU_BW,
LOGIT_BW,
HARDSHRINK_BW,
SOFTSHRINK_BW,
LEAKY_RELU_BW,
ELU_BW,
CELU_BW,
RPOW_BW,
FLOOR_BW,
ROUND_BW,
@@ -61,7 +55,6 @@ enum class UnaryBackwardOpType {
CEIL_BW,
SOFTSIGN_BW,
COSH_BW,
LOGITEPS_BW,
LOG2_BW,
SIGN_BW,
DIV_NO_NAN_BW,
@@ -73,7 +66,6 @@ enum class UnaryBackwardOpType {
ERF_BW,
DEG2RAD_BW,
POLYGAMMA_BW,
REPEAT_BW,
};

std::vector<Tensor> _acos_bw( const Tensor& grad, const Tensor& input, const std::optional<MemoryConfig>& output_mem_config);
@@ -132,16 +124,6 @@ std::vector<Tensor> _log_bw( const Tensor& grad, const Tensor& input, const std:
std::vector<Tensor> _add_bw( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _eq_bw( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _hardshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _softshrink_bw( const Tensor& grad, const Tensor& input, float lambd = 0.5, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _leaky_relu_bw( const Tensor& grad, const Tensor& input, float negative_slope = 0.01, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _elu_bw( const Tensor& grad, const Tensor& input, float alpha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _celu_bw( const Tensor& grad, const Tensor& input, float aplha = 1.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
std::vector<Tensor> _logiteps_bw( const Tensor& grad, const Tensor& input, float eps = 0.0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config);
Tensor change_layout_to_tile(const Tensor& temp, const MemoryConfig& output_mem_config);

// OpHandler struct template
@@ -414,48 +396,6 @@ struct OpHandler<UnaryBackwardOpType::POLYGAMMA_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::HARDSHRINK_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
return _hardshrink_bw(grad, input, lambd, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::SOFTSHRINK_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float lambd, const std::optional<MemoryConfig>& output_mem_config ) {
return _softshrink_bw(grad, input, lambd, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::LEAKY_RELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float negative_slope, const std::optional<MemoryConfig>& output_mem_config ) {
return _leaky_relu_bw(grad, input, negative_slope, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::ELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
return _elu_bw(grad, input, alpha, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::CELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
return _celu_bw(grad, input, alpha, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::LOGITEPS_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float eps, const std::optional<MemoryConfig>& output_mem_config ) {
return _logiteps_bw(grad, input, eps, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::GT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {
@@ -540,20 +480,6 @@ struct OpHandler<UnaryBackwardOpType::SUB_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float scalar, string round_mode, const std::optional<MemoryConfig>& output_mem_config ) {
return _rdiv_bw(grad, input, scalar, round_mode, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const tt::tt_metal::LegacyShape& shape, const std::optional<MemoryConfig>& output_mem_config ) {
return _repeat_bw(grad, input, shape, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::EQ_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, float other, const std::optional<MemoryConfig>& output_mem_config ) {