#12375: Add qid and optional tensor output to ttnn.gelu_bw
mcw-anasuya committed Sep 11, 2024
1 parent fc94626 commit 4dcb7d8
Showing 6 changed files with 175 additions and 24 deletions.
64 changes: 64 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py
@@ -55,3 +55,67 @@ def test_bw_gelu_default(input_shapes, device):

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize(
"approximate",
(
"none",
"tanh",
),
)
def test_bw_gelu_opt_output(input_shapes, approximate, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
input_grad = ttnn.from_torch(
input_grad, ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
)

cq_id = 0
pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.gelu_bw(grad_tensor, input_tensor, approximate=approximate, input_grad=input_grad, queue_id=cq_id)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

tt_output_tensor_on_device = [input_grad]

golden_function = ttnn.get_golden_function(ttnn.gelu_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_gelu_default_opt_output(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -5, 5, device)
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)

cq_id = 0
pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.gelu_bw(grad_tensor, input_tensor, input_grad=input_grad, queue_id=cq_id)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

tt_output_tensor_on_device = [input_grad]

golden_function = ttnn.get_golden_function(ttnn.gelu_bw)
golden_tensor = golden_function(grad_data, in_data)

comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass
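
A minimal usage sketch of the pattern the two new tests exercise: a preallocated input_grad handed back into ttnn.gelu_bw together with an explicit command queue id. The shapes, device handle, and memory config below are illustrative, not part of the change:

import torch
import ttnn

device = ttnn.open_device(device_id=0)

# Illustrative tensors; any tile-aligned bfloat16 shape works the same way.
grad = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16),
                       layout=ttnn.TILE_LAYOUT, device=device)
inp = ttnn.from_torch(torch.randn(1, 1, 32, 32, dtype=torch.bfloat16),
                      layout=ttnn.TILE_LAYOUT, device=device)
input_grad = ttnn.from_torch(torch.zeros(1, 1, 32, 32, dtype=torch.bfloat16),
                             layout=ttnn.TILE_LAYOUT, device=device,
                             memory_config=ttnn.L1_MEMORY_CONFIG)

# The result is written into the preallocated input_grad; the tests above
# check that no new buffer pages appear after the call.
ttnn.gelu_bw(grad, inp, approximate="tanh", input_grad=input_grad, queue_id=0)

ttnn.close_device(device)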
@@ -438,8 +438,10 @@ std::vector<Tensor> ExecuteBackwardBiasGelu::invoke(
TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect approximation type (expected 'none', 'tanh')");
std::vector<Tensor> grad_tensor;
Tensor input = ttnn::add(input_a, input_b);
grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate, output_mem_config);
grad_tensor.emplace_back(grad_tensor[0]);
std::vector<std::optional<Tensor>> gelu_result = ttnn::gelu_bw(grad, input, approximate, output_mem_config);
if (gelu_result[0].has_value()) {
grad_tensor.push_back(gelu_result[0].value());
}
return grad_tensor;
}
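
For orientation: bias-GELU backward delegates to gelu_bw on input_a + input_b, since d(x + b)/dx = d(x + b)/db = 1, so the same gradient serves both arguments. A hedged PyTorch sketch of that relationship (reference only, not the device code):

import torch

def bias_gelu_bw_reference(grad, x, bias, approximate="none"):
    # Gradient of gelu(x + bias) w.r.t. x and w.r.t. bias is the same tensor,
    # because the inner derivative of (x + bias) is 1 for both arguments.
    inp = (x + bias).detach().requires_grad_(True)
    out = torch.nn.functional.gelu(inp, approximate=approximate)
    (input_grad,) = torch.autograd.grad(out, inp, grad_outputs=grad)
    return input_grad, input_grad.clone()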

@@ -448,7 +450,10 @@ std::vector<Tensor> ExecuteBackwardBiasGelu::invoke(
std::vector<Tensor> grad_tensor;
TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect rounding mode (expected 'none' or 'tanh')");
Tensor input = ttnn::add(input_tensor, bias);
grad_tensor = ttnn::gelu_bw(grad, input, approximate = approximate);
std::vector<std::optional<Tensor>> gelu_result = ttnn::gelu_bw(grad, input, approximate, output_mem_config);
if (gelu_result[0].has_value()) {
grad_tensor.push_back(gelu_result[0].value());
}
return grad_tensor;
}

@@ -1323,9 +1323,18 @@ std::vector<Tensor> _deg2rad_bw(const Tensor& grad, const Tensor& input, const s
}


std::vector<Tensor> _gelu_bw(
const Tensor& grad, const Tensor& input, string approximate, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
uint8_t queue_id,
const Tensor& grad,
const Tensor& input,
string approximate,
const std::optional<MemoryConfig>& output_mem_config,
std::optional<Tensor> input_grad) {
std::vector<std::optional<Tensor>> result;
if(!input_grad.has_value()){
input_grad = ttnn::zeros_like(grad);
}

auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
TT_FATAL((approximate == "none" || approximate == "tanh") && "Incorrect approximate mode (expected 'none' or 'tanh')");

@@ -1354,19 +1363,27 @@ std::vector<Tensor> _gelu_bw(
std::nullopt,
output_memory_config);

Tensor grad_a = ttnn::multiply(grad, (ttnn::add(left_derivative, right_derivative)), std::nullopt, output_memory_config);
grad_tensor.emplace_back(grad_a);
ttnn::multiply(queue_id, grad, (ttnn::add(left_derivative, right_derivative)), std::nullopt, output_memory_config, input_grad);
result.push_back(input_grad);
} else {
float kAlpha = M_SQRT1_2;
float kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5;
Tensor cdf =
ttnn::multiply((ttnn::add(ttnn::erf(ttnn::multiply(input, kAlpha, std::nullopt, output_memory_config)), 1, std::nullopt, output_memory_config)), 0.5);
Tensor pdf = ttnn::multiply(ttnn::exp(ttnn::multiply(ttnn::multiply(input, input), -0.5), false, output_memory_config), kBeta, std::nullopt, output_memory_config);
Tensor grad_a = ttnn::multiply(grad, (ttnn::add(cdf, ttnn::multiply(input, pdf))));
grad_tensor.emplace_back(grad_a);
ttnn::multiply(queue_id, grad, ttnn::add(cdf, ttnn::multiply(input, pdf)), std::nullopt, output_memory_config, input_grad);
result.push_back(input_grad);
}

return grad_tensor;
return result;
}

std::vector<std::optional<ttnn::Tensor>> ExecuteUnaryBackwardGelu::invoke(
const Tensor& grad,
const Tensor& input,
string approximate,
const std::optional<MemoryConfig>& output_mem_config) {
return ExecuteUnaryBackwardGelu::invoke(DefaultQueueId, grad, input, approximate, output_mem_config);
}
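
For reference, a hedged PyTorch sketch of the math the two branches above implement. kAlpha = 1/sqrt(2) and kBeta = 1/sqrt(2*pi) are the constants of the exact (erf-based) gradient, grad * (Phi(x) + x * phi(x)); the tanh branch, whose left_derivative/right_derivative lines are collapsed in this hunk, follows the standard tanh approximation. Function and variable names here are illustrative, not part of the change:

import math
import torch

def gelu_bw_reference(grad, x, approximate="none"):
    # Reference formulas only; not the ttnn device implementation.
    if approximate == "none":
        k_alpha = 1.0 / math.sqrt(2.0)              # matches kAlpha (M_SQRT1_2)
        k_beta = 1.0 / math.sqrt(2.0 * math.pi)     # matches kBeta (M_2_SQRTPI * M_SQRT1_2 * 0.5)
        cdf = 0.5 * (1.0 + torch.erf(x * k_alpha))  # standard normal CDF
        pdf = k_beta * torch.exp(-0.5 * x * x)      # standard normal PDF
        return grad * (cdf + x * pdf)
    # approximate == "tanh": gelu(x) ~ 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    c = math.sqrt(2.0 / math.pi)
    inner = c * (x + 0.044715 * x**3)
    t = torch.tanh(inner)
    left_derivative = 0.5 * (1.0 + t)
    right_derivative = 0.5 * x * (1.0 - t * t) * c * (1.0 + 3.0 * 0.044715 * x**2)
    return grad * (left_derivative + right_derivative)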

std::vector<Tensor> _repeat_bw(
@@ -81,7 +81,6 @@ enum class UnaryBackwardOpType {
ERF_BW,
DEG2RAD_BW,
POLYGAMMA_BW,
GELU_BW,
REPEAT_BW,
PROD_BW,
};
@@ -162,8 +161,6 @@ std::vector<Tensor> _clamp_bw( const Tensor& grad, const Tensor& input, std::opt

std::vector<Tensor> _rdiv_bw( const Tensor& grad, const Tensor& input, float scalar, string round_mode = "None", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _gelu_bw( const Tensor& grad, const Tensor& input, string approximate = "none", const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const tt::tt_metal::Shape& shape, const std::optional<MemoryConfig>& output_mem_config);

std::vector<Tensor> _prod_bw( const Tensor& grad, const Tensor& input, bool all_dimensions = true, int64_t dim = 0, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);
@@ -628,13 +625,6 @@ struct OpHandler<UnaryBackwardOpType::RDIV_BW> {
}
};

template <>
struct OpHandler<UnaryBackwardOpType::GELU_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, string approximate, const std::optional<MemoryConfig>& output_mem_config ) {
return _gelu_bw(grad, input, approximate, output_mem_config);
}
};

template <>
struct OpHandler<UnaryBackwardOpType::REPEAT_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const tt::tt_metal::Shape& shape, const std::optional<MemoryConfig>& output_mem_config ) {
21 changes: 19 additions & 2 deletions ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.hpp
@@ -251,6 +251,24 @@ struct ExecuteUnaryBackwardAbs {
};


struct ExecuteUnaryBackwardGelu{
static std::vector<std::optional<ttnn::Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
string parameter_a,
const std::optional<MemoryConfig> &memory_config = std::nullopt);

static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_arg,
string parameter_a,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
std::optional<Tensor> input_grad = std::nullopt);

};


} // operations::unary

constexpr auto threshold_bw = ttnn::register_operation<
@@ -538,8 +556,7 @@ constexpr auto rdiv_bw = ttnn::register_operation<

constexpr auto gelu_bw = ttnn::register_operation<
"ttnn::gelu_bw",
operations::unary_backward::ExecuteUnaryBackwardStringDefault<
operations::unary_backward::UnaryBackwardOpType::GELU_BW>>();
operations::unary_backward::ExecuteUnaryBackwardGelu>();

constexpr auto repeat_bw = ttnn::register_operation<
"ttnn::repeat_bw",
@@ -857,6 +857,64 @@ Keyword args:
py::arg("memory_config") = std::nullopt});
}

template <typename unary_backward_operation_t>
void bind_unary_backward_gelu(
py::module& module,
const unary_backward_operation_t& operation,
const std::string& parameter_name_a,
const std::string& parameter_a_doc,
string parameter_a_value,
std::string_view description) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor: ttnn.Tensor, {2}: string, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{5}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor`
Keyword args:
* :attr:`{2}` (string): {3} , Default value = {4}
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
* :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
* :attr:`queue_id` (Optional[uint8]): command queue id
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, input, {2} = {4})
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
parameter_name_a,
parameter_a_doc,
parameter_a_value,
description);
bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const unary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
string parameter_a,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& input_grad,
const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
return self(queue_id, grad_tensor, input_tensor, parameter_a, memory_config, input_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor"),
py::kw_only(),
py::arg(parameter_name_a.c_str()) = parameter_a_value,
py::arg("memory_config") = std::nullopt,
py::arg("input_grad") = std::nullopt,
py::arg("queue_id") = ttnn::DefaultQueueId});
}
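
Extending the docstring example above with the two keyword arguments this binding adds; the preallocated tensor name and the values are illustrative:

>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> input = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> preallocated = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 0), dtype=torch.bfloat16)), device)
>>> output = ttnn.gelu_bw(grad_tensor, input, approximate="none", input_grad=preallocated, queue_id=0)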

} // namespace detail

void py_module(py::module& module) {
@@ -977,7 +1035,7 @@ void py_module(py::module& module) {
"Shape",
R"doc(Performs backward operations for repeat on :attr:`input_tensor_a` or :attr:`input_tensor`, with given :attr:`grad_tensor` using given :attr:`shape`.)doc");

detail::bind_unary_backward_string_default(
detail::bind_unary_backward_gelu(
module,
ttnn::gelu_bw,
"approximate",