Support sparse input for SVC and SVR #5273
…ute, allow ExpandedL2 distance compute when applicable
…th single instance of indices
…evice_csr_matrix_view
Thanks Malte for addressing the issues. I have only a few small comments left.
@tfeher, thanks for reviewing. I have pushed an update that addresses your review suggestions.
Added "breaking" label because of the changes to the C++ API.
It's looking much better. I did a somewhat brief skim over the changes so I could provide feedback more quickly. I'll do a more thorough review next week, but so far I see only minor things.
 */
template <typename math_t>
void svcPredictSparse(const raft::handle_t& handle,
                      int* indptr,
Since we accept such a limited set of types for this, we could probably eventually use the raft::csr_matrix_view, but raw pointers from the Python->C++ hand-off are fine too, since we really have not started porting over any of our other C++ APIs to accept mdspan directly yet.
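For context on the raw-pointer hand-off: `indptr`, together with the `indices` and `data` arrays that appear later in this PR, are the three arrays of the CSR (compressed sparse row) layout. The following is a minimal pure-Python sketch of that layout, not the cuML or RAFT implementation; the helper names `dense_to_csr` and `csr_row` are hypothetical and exist only for illustration.

```python
from typing import List, Tuple


def dense_to_csr(rows: List[List[float]]) -> Tuple[List[int], List[int], List[float]]:
    """Convert a dense row-major matrix into CSR arrays (indptr, indices, data)."""
    indptr = [0]   # indptr[i]..indptr[i + 1] delimits the non-zeros of row i
    indices = []   # column index of each stored non-zero
    data = []      # value of each stored non-zero
    for row in rows:
        for col, value in enumerate(row):
            if value != 0.0:
                indices.append(col)
                data.append(value)
        indptr.append(len(indices))
    return indptr, indices, data


def csr_row(indptr: List[int], indices: List[int], data: List[float], i: int):
    """Return the (column, value) pairs stored for row i."""
    start, end = indptr[i], indptr[i + 1]
    return list(zip(indices[start:end], data[start:end]))


dense = [
    [1.0, 0.0, 2.0],
    [0.0, 0.0, 0.0],  # an empty row repeats the previous indptr entry
    [0.0, 3.0, 0.0],
]
indptr, indices, data = dense_to_csr(dense)
```

A view type such as the `raft::csr_matrix_view` mentioned above would bundle these three arrays and the matrix shape into one object, whereas the raw-pointer signature passes them separately.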
MLCommon::Matrix::Matrix<math_t>* x_ws_matrix = nullptr;

// matrix l2 norm for RBF kernels
rmm::device_uvector<math_t> matrix_l2;
At some point, we'll be replacing all occurrences w/ the mdarray but for now we can keep these using RMM directly.
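As background on why a row-wise L2 norm buffer helps RBF kernels: the squared distance can be expanded as ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x·y, so the norms can be computed once per row and reused for every pair. The sketch below is a pure-Python illustration of that identity on sparse rows, under the assumption that this is what `matrix_l2` is used for; it is not the RAFT kernel code, and `sparse_dot` and `rbf_kernel_expanded` are hypothetical names.

```python
import math
from typing import Dict, List

SparseRow = Dict[int, float]  # column index -> non-zero value


def sparse_dot(a: SparseRow, b: SparseRow) -> float:
    """Dot product of two sparse rows, iterating over the shorter one."""
    if len(a) > len(b):
        a, b = b, a
    return sum(value * b.get(col, 0.0) for col, value in a.items())


def rbf_kernel_expanded(rows_x: List[SparseRow], rows_y: List[SparseRow],
                        gamma: float) -> List[List[float]]:
    """RBF kernel matrix using the expanded form of the squared L2 distance."""
    # Squared row norms are computed once and reused for all pairs.
    norms_x = [sparse_dot(r, r) for r in rows_x]
    norms_y = [sparse_dot(r, r) for r in rows_y]
    kernel = []
    for i, x in enumerate(rows_x):
        kernel_row = []
        for j, y in enumerate(rows_y):
            sq_dist = norms_x[i] + norms_y[j] - 2.0 * sparse_dot(x, y)
            sq_dist = max(sq_dist, 0.0)  # guard against negative rounding error
            kernel_row.append(math.exp(-gamma * sq_dist))
        kernel.append(kernel_row)
    return kernel


x_rows = [{0: 1.0, 2: 2.0}]
y_rows = [{0: 1.0, 2: 2.0}, {1: 3.0}]
K = rbf_kernel_expanded(x_rows, y_rows, gamma=0.5)
```

The expanded form only needs a sparse dot product per pair, which is why it suits sparse inputs; the clamp to zero covers cases where floating-point cancellation produces a slightly negative distance.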
    model_d.support_matrix.indices = <int*><uintptr_t>self.support_vectors_.indices.ptr
    model_d.support_matrix.data = <double*><uintptr_t>self.support_vectors_.data.ptr
else:
    model_d.support_matrix.data = <double*><uintptr_t>self.support_vectors_.ptr
This looks like a copy of the block above. Can we consolidate these, maybe into their own function? Maybe something like configure_support_matrix()
Same as above: we need to distinguish between C++ data types
Thanks Malte for the update! I missed two issues in my previous review; please fix these. Otherwise the PR looks good to me.
Thanks @tfeher for the review. I have applied your suggestions.
Thanks @cjnolet for the early feedback.
Thanks Malte for fixing the issues. LGTM.
Changes look great! Thanks again for these changes, @mfoerste4!
/merge
This PR adds support for sparse input to SVR and SVC. Both 'fit' and 'predict' can be called with sparse data that is compatible with or convertible to SparseCumlArray. Support vectors in the model may also be stored as sparse data and can be retrieved as such.
This PR requires rapidsai/raft#1296 to provide sparse kernel computations.
Corresponding issue: #2197