Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gpu kernel for new api : linalg.lstsq #38621

Merged
merged 11 commits into from
Jan 10, 2022
Merged

Add gpu kernel for new api : linalg.lstsq #38621

merged 11 commits into from
Jan 10, 2022

Conversation

haohongxiang
Copy link
Contributor

@haohongxiang haohongxiang commented Dec 30, 2021

PR types

Function optimization

PR changes

OPs

Describe

Add gpu kernel for new api : linalg.lstsq
You can see the implementation of cpu kernel in PR38585
The docs_cn of paddle.linalg.lstsq is merged by docs-PR4174

1、usage example:

import paddle
paddle.set_device("cpu")
x = paddle.to_tensor([[1, 3], [3, 2], [5, 6.]])
y = paddle.to_tensor([[3, 4, 6], [5, 3, 4], [1, 2, 1.]])
results = paddle.linalg.lstsq(x, y, driver="gelsd")

# solution
print(results[0]) 
# [[ 0.78350395, -0.22165027, -0.62371236],
# [-0.11340097,  0.78866047,  1.14948535]]

# residuals
print(results[1])
# [19.81443405, 10.43814468, 30.56185532])

# rank
print(results[2])
# 2

# singular values
print(results[3])
# [9.03455734, 1.54167950]

2、docs_cn:

image

3、docs_en:

image

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

dingjiaweiww
dingjiaweiww previously approved these changes Jan 6, 2022
should be one of float32, float64.
y (Tensor): A tensor with shape ``(*, M, K)`` , the data type of the input Tensor ``y``
should be one of float32, float64.
rcond(float, optional): A float pointing number used to determine the effective rank of ``x``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要说明默认值。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已加

.. code-block:: python

import paddle
import numpy as np
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这行冗余 可以删除

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删

auto tmp_x = dito.Transpose(new_x);
auto tmp_y = dito.Transpose(new_y);
framework::TensorCopy(tmp_x, new_x.place(), &new_x);
framework::TensorCopy(tmp_y, new_y.place(), &new_y);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need this copy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for comments. It's not necessary so I removed it.

auto slice_b = dito.Slice(trans_b, {-2}, {0}, {min_mn});
auto tmp_r = dito.TrilTriu(slice_r, 0, false);
framework::TensorCopy(tmp_r, new_x.place(), &new_x);
framework::TensorCopy(slice_b, new_y.place(), &new_y);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for comments. It's not necessary so I removed it.

Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for API docs

Copy link
Member

@ForFishes ForFishes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ForFishes ForFishes merged commit 405103d into PaddlePaddle:develop Jan 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants