Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for sample_weights in LinearRegression #4428

Merged
merged 18 commits into from
Mar 10, 2022

Conversation

lowener
Copy link
Contributor

@lowener lowener commented Dec 6, 2021

Closes #4031.
Scikit-learn is rescaling the data (here) to take into account the sample_weight parameter.

@github-actions github-actions bot added CUDA/C++ Cython / Python Cython or Python issue labels Dec 6, 2021
Copy link
Member

@dantegd dantegd left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One super small docstring note, otherwise it looks great

cpp/include/cuml/linear_model/glm.hpp Outdated Show resolved Hide resolved
@github-actions
Copy link

This PR has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this PR if it is no longer required. Otherwise, please respond with a comment indicating any updates. This PR will be labeled inactive-90d if there is no activity in the next 60 days.

@lowener lowener marked this pull request as ready for review January 28, 2022 13:51
@lowener lowener requested review from a team as code owners January 28, 2022 13:51
@lowener lowener added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Jan 28, 2022
@lowener lowener changed the base branch from branch-22.02 to branch-22.04 January 31, 2022 10:58
@lowener
Copy link
Contributor Author

lowener commented Jan 31, 2022

CI Failure on cuml/gpu/cuda/11.0/driver-450/python/3.8/centos7:

cuml/test/test_api.py::test_fit_function[IncrementalPCA] [e34af1b13a84:1329 :0:1329] Caught signal 11 (Segmentation fault: Sent by the kernel at address (nil))
14:12:13 ==== backtrace (tid:   1329) ====
14:12:13  0  /opt/conda/envs/rapids/lib/python3.8/site-packages/ucp/_libs/../../../../libucs.so.0(ucs_handle_error+0x155) [0x7fb21f2493f5]
14:12:13  1  /opt/conda/envs/rapids/lib/python3.8/site-packages/ucp/_libs/../../../../libucs.so.0(+0x2d791) [0x7fb21f249791]
14:12:13  2  /opt/conda/envs/rapids/lib/python3.8/site-packages/ucp/_libs/../../../../libucs.so.0(+0x2d962) [0x7fb21f249962]
14:12:13  3  /usr/lib64/libc.so.6(+0x36400) [0x7fb2c658e400]
14:12:13  4  /opt/conda/envs/rapids/lib/python3.8/site-packages/cupy_backends/cuda/libs/../../../../../libcusolver.so.10(+0x557d83) [0x7fb078691d83]
14:12:13  5  /opt/conda/envs/rapids/lib/python3.8/site-packages/cupy_backends/cuda/libs/../../../../../libcusolver.so.10(cusolverDnDestroy+0x22) [0x7fb078487292]
14:12:13  6  /workspace/python/cuml/raft/common/handle.cpython-38-x86_64-linux-gnu.so(_ZN4raft8handle_tD1Ev+0x321) [0x7fb2c058f921]
14:12:13  7  /workspace/python/cuml/raft/common/handle.cpython-38-x86_64-linux-gnu.so(+0x28ab1) [0x7fb2c058aab1]
14:12:13  8  /opt/conda/envs/rapids/bin/python(+0xfab48) [0x55e9a8d45b48]
14:12:13  9  /opt/conda/envs/rapids/bin/python(+0xfc7bd) [0x55e9a8d477bd]
...
14:12:13 =================================
14:12:13 Fatal Python error: Segmentation fault

@lowener
Copy link
Contributor Author

lowener commented Jan 31, 2022

rerun tests

@lowener lowener added the 3 - Ready for Review Ready for review by team label Jan 31, 2022
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes look great. Just very minor things.

@@ -62,7 +67,8 @@ void olsFit(const raft::handle_t& handle,
bool fit_intercept,
bool normalize,
cudaStream_t stream,
int algo = 0)
int algo = 0,
math_t* sample_weight = nullptr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something you need to change in this PR, but it would be nice to start adopting std::optional for arguments like these. Becomes self-documenting as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree so I started to make the change but I saw that std::optional support on Cython is very recent and will be added to version 0.30. For the moment we're using version 0.29 so the adoption of std::optional on Python-facing functions should maybe wait a bit more.

cpp/src_prims/stats/weighted_mean.cuh Outdated Show resolved Hide resolved
python/cuml/test/test_linear_model.py Outdated Show resolved Hide resolved
@lowener lowener requested a review from a team as a code owner February 18, 2022 00:10
@github-actions github-actions bot added the CMake label Feb 18, 2022
rapids-bot bot pushed a commit to rapidsai/raft that referenced this pull request Mar 2, 2022
@github-actions github-actions bot removed the CMake label Mar 2, 2022
@lowener lowener requested a review from cjnolet March 10, 2022 15:01
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@dantegd
Copy link
Member

dantegd commented Mar 10, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit fc94e5f into rapidsai:branch-22.04 Mar 10, 2022
@lowener lowener deleted the 22.02-linear-weight branch March 10, 2022 17:10
vimarsh6739 pushed a commit to vimarsh6739/cuml that referenced this pull request Oct 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
3 - Ready for Review Ready for review by team CUDA/C++ Cython / Python Cython or Python issue improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[FEA] Sample weights for Linear Regression
3 participants