Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#12164: Add queue_id and optional output tensors to backward ops #12255

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,52 @@ def test_bw_rsub(input_shapes, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_rsub_opt(input_shapes, device, are_required_outputs):
    """Verify rsub_bw with optional preallocated output tensors on an explicit queue.

    Asserts that the op allocates no new device buffers (the gradients are
    written into the preallocated tensors) and that every requested gradient
    matches the golden PyTorch result via PCC comparison.
    """
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)

    grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 5, device)

    # Preallocate outputs only for the gradients this parametrization requests.
    input_grad = None
    other_grad = None
    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    cq_id = 0

    # Snapshot buffer pages so we can prove the op performed no new allocations.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.rsub_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
        queue_id=cq_id,
    )
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
    tt_output_tensor_on_device = [input_grad, other_grad]

    golden_function = ttnn.get_golden_function(ttnn.rsub_bw)
    golden_tensor = golden_function(grad_data, in_data, other_data)

    # Compare only the gradients that were requested.
    for required, tt_out, golden in zip(are_required_outputs, tt_output_tensor_on_device, golden_tensor):
        if required:
            assert compare_pcc([tt_out], [golden])
75 changes: 75 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,78 @@ def test_bw_unary_sub(input_shapes, scalar, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_sub_opt(input_shapes, device, are_required_outputs):
    """Verify sub_bw with optional preallocated output tensors on an explicit queue.

    Asserts that the op allocates no new device buffers (the gradients are
    written into the preallocated tensors) and that every requested gradient
    matches the golden PyTorch result via PCC comparison.
    """
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

    # Preallocate outputs only for the gradients this parametrization requests.
    input_grad = None
    other_grad = None
    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    cq_id = 0

    # Snapshot buffer pages so we can prove the op performed no new allocations.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.sub_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
        queue_id=cq_id,
    )
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
    tt_output_tensor_on_device = [input_grad, other_grad]

    golden_function = ttnn.get_golden_function(ttnn.sub_bw)
    golden_tensor = golden_function(grad_data, in_data, other_data)

    # Compare only the gradients that were requested.
    for required, tt_out, golden in zip(are_required_outputs, tt_output_tensor_on_device, golden_tensor):
        if required:
            assert compare_pcc([tt_out], [golden])


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("scalar", [0.05, 1.0, 0.5, 0.12, 0.0, -0.05, -1.0, -0.5, -0.12])
def test_bw_sub_scalar_opt_output(input_shapes, scalar, device):
    """sub_bw against a scalar, writing into a preallocated input-gradient tensor."""
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)

    # Preallocated destination for the computed gradient.
    _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)

    queue = 0
    buffer_pages = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.sub_bw(grad_tensor, input_tensor, scalar, input_grad=input_grad, queue_id=queue)
    # The op must reuse the preallocated tensor rather than allocate new buffers.
    assert len(buffer_pages) == len(ttnn._ttnn.reports.get_buffer_pages())

    device_outputs = [input_grad]
    reference = ttnn.get_golden_function(ttnn.sub_bw)(grad_data, in_data, scalar)

    assert compare_pcc(device_outputs, reference)
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,51 @@ def test_bw_subalpha_default(input_shapes, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_subalpha_opt_output(input_shapes, device, are_required_outputs):
    """Verify subalpha_bw with optional preallocated output tensors on an explicit queue.

    Asserts that the op allocates no new device buffers (the gradients are
    written into the preallocated tensors) and that every requested gradient
    matches the golden PyTorch result via PCC comparison.
    """
    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

    # Preallocate outputs only for the gradients this parametrization requests.
    input_grad = None
    other_grad = None
    if are_required_outputs[0]:
        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
    if are_required_outputs[1]:
        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

    cq_id = 0
    # Snapshot buffer pages so we can prove the op performed no new allocations.
    pages_before = ttnn._ttnn.reports.get_buffer_pages()
    ttnn.subalpha_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=are_required_outputs,
        input_grad=input_grad,
        other_grad=other_grad,
        queue_id=cq_id,
    )
    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

    tt_output_tensor_on_device = [input_grad, other_grad]

    golden_function = ttnn.get_golden_function(ttnn.subalpha_bw)
    golden_tensor = golden_function(grad_data, in_data, other_data)

    # Compare only the gradients that were requested.
    for required, tt_out, golden in zip(are_required_outputs, tt_output_tensor_on_device, golden_tensor):
        if required:
            assert compare_pcc([tt_out], [golden])
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,36 @@ struct ExecuteBackwardAdd {
};

struct ExecuteBackwardSub {
static std::vector<Tensor> invoke(
static std::vector<std::optional<Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
float scalar,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);

static std::vector<std::optional<Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt,
std::optional<Tensor> other_grad = std::nullopt);

static std::vector<std::optional<Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
float scalar,
const std::optional<MemoryConfig> &memory_config = std::nullopt);

static std::vector<Tensor> invoke(

static std::vector<std::optional<Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt);

static std::vector<ComplexTensor> invoke(
Expand Down Expand Up @@ -264,10 +284,51 @@ struct ExecuteAddalphaBW {
std::optional<Tensor> input_b_grad = std::nullopt);
};

// Backward op for subalpha. Produces up to two optional gradient tensors
// (one per input operand); an entry of `are_required_outputs` selects whether
// the corresponding gradient is computed. Preallocated destinations may be
// supplied via `input_grad` / `other_grad` — TODO confirm they are written
// in place when provided.
struct ExecuteBackwardSubAlpha {
// Overload with an explicit command-queue id and optional preallocated outputs.
static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt,
std::optional<Tensor> other_grad = std::nullopt);

// Convenience overload without a queue id or preallocated outputs.
static std::vector<std::optional<ttnn::Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt);

};

// Backward op for rsub. Produces up to two optional gradient tensors
// (one per input operand); an entry of `are_required_outputs` selects whether
// the corresponding gradient is computed. Preallocated destinations may be
// supplied via `input_grad` / `other_grad` — TODO confirm they are written
// in place when provided.
struct ExecuteBackwardRsub {
// Overload with an explicit command-queue id and optional preallocated outputs.
static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt,
std::optional<Tensor> other_grad = std::nullopt);

// Convenience overload without a queue id or preallocated outputs.
static std::vector<std::optional<ttnn::Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
const std::optional<MemoryConfig> &memory_config = std::nullopt);

};

} // operations::binary

constexpr auto atan2_bw = ttnn::register_operation<"ttnn::atan2_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::ATAN2_BW>>();
constexpr auto rsub_bw = ttnn::register_operation<"ttnn::rsub_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::RSUB_BW>>();
constexpr auto xlogy_bw = ttnn::register_operation<"ttnn::xlogy_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::XLOGY_BW>>();
constexpr auto hypot_bw = ttnn::register_operation<"ttnn::hypot_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::HYPOT_BW>>();
constexpr auto ldexp_bw = ttnn::register_operation<"ttnn::ldexp_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::LDEXP_BW>>();
Expand All @@ -278,7 +339,13 @@ constexpr auto min_bw = ttnn::register_operation<"ttnn::min_bw", operations::bin
constexpr auto max_bw = ttnn::register_operation<"ttnn::max_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::MAX_BW>>();


constexpr auto subalpha_bw = ttnn::register_operation<"ttnn::subalpha_bw", operations::binary_backward::ExecuteBinaryBackwardFloatDefault<operations::binary_backward::BinaryBackwardOpType::SUBALPHA_BW>>();
constexpr auto subalpha_bw = ttnn::register_operation<
"ttnn::subalpha_bw",
operations::binary_backward::ExecuteBackwardSubAlpha>();

constexpr auto rsub_bw = ttnn::register_operation<
"ttnn::rsub_bw",
operations::binary_backward::ExecuteBackwardRsub>();

constexpr auto concat_bw = ttnn::register_operation<"ttnn::concat_bw", operations::binary_backward::ExecuteBinaryBackwardIntDefault<operations::binary_backward::BinaryBackwardOpType::CONCAT_BW>>();

Expand Down
Loading
Loading