Skip to content

Commit

Permalink
Fix indexing of PCA to use safer types (rapidsai#4255)
Browse files Browse the repository at this point in the history
Related to rapidsai#4105.

This PR fixes the types used for indexing and sizes in the PCA/TSVD C++ code.

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#4255
  • Loading branch information
lowener authored Nov 5, 2021
1 parent 5cf3cdd commit 95835b0
Show file tree
Hide file tree
Showing 13 changed files with 236 additions and 222 deletions.
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ set(CUML_BRANCH_VERSION_raft "${CUML_VERSION_MAJOR}.${CUML_VERSION_MINOR}")
find_and_configure_raft(VERSION ${CUML_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${CUML_BRANCH_VERSION_raft}
)
)
14 changes: 7 additions & 7 deletions cpp/include/cuml/decomposition/params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,24 @@ enum class solver : int {

class params {
public:
int n_rows;
int n_cols;
std::size_t n_rows;
std::size_t n_cols;
int gpu_id = 0;
};

class paramsSolver : public params {
public:
// math_t tol = 0.0;
float tol = 0.0;
int n_iterations = 15;
int verbose = 0;
float tol = 0.0;
std::uint32_t n_iterations = 15;
int verbose = 0;
};

template <typename enum_solver = solver>
class paramsTSVDTemplate : public paramsSolver {
public:
int n_components = 1;
enum_solver algorithm = enum_solver::COV_EIG_DQ;
std::size_t n_components = 1;
enum_solver algorithm = enum_solver::COV_EIG_DQ;
};

/**
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/cuml/decomposition/pca_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void fit(raft::handle_t& handle,
*/
void fit_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::floatData_t** input,
MLCommon::Matrix::floatData_t** trans_input,
float* components,
Expand All @@ -100,7 +100,7 @@ void fit_transform(raft::handle_t& handle,

void fit_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::doubleData_t** input,
MLCommon::Matrix::doubleData_t** trans_input,
double* components,
Expand All @@ -127,7 +127,7 @@ void fit_transform(raft::handle_t& handle,
*/
void transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<float>** input,
float* components,
MLCommon::Matrix::Data<float>** trans_input,
Expand All @@ -138,7 +138,7 @@ void transform(raft::handle_t& handle,

void transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<double>** input,
double* components,
MLCommon::Matrix::Data<double>** trans_input,
Expand All @@ -162,7 +162,7 @@ void transform(raft::handle_t& handle,
*/
void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<float>** trans_input,
float* components,
MLCommon::Matrix::Data<float>** input,
Expand All @@ -173,7 +173,7 @@ void inverse_transform(raft::handle_t& handle,

void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<double>** trans_input,
double* components,
MLCommon::Matrix::Data<double>** input,
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/cuml/decomposition/sign_flip_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ void sign_flip(raft::handle_t& handle,
std::vector<MLCommon::Matrix::Data<float>*>& input_data,
MLCommon::Matrix::PartDescriptor& input_desc,
float* components,
int n_components,
std::size_t n_components,
cudaStream_t* streams,
int n_stream);
std::uint32_t n_stream);

void sign_flip(raft::handle_t& handle,
std::vector<MLCommon::Matrix::Data<double>*>& input_data,
MLCommon::Matrix::PartDescriptor& input_desc,
double* components,
int n_components,
std::size_t n_components,
cudaStream_t* streams,
int n_stream);
std::uint32_t n_stream);

}; // end namespace opg
}; // end namespace PCA
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/cuml/decomposition/tsvd_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace opg {
*/
void fit(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::floatData_t** input,
float* components,
float* singular_vals,
Expand All @@ -46,7 +46,7 @@ void fit(raft::handle_t& handle,

void fit(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::doubleData_t** input,
double* components,
double* singular_vals,
Expand Down Expand Up @@ -104,7 +104,7 @@ void fit_transform(raft::handle_t& handle,
*/
void transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<float>** input,
float* components,
MLCommon::Matrix::Data<float>** trans_input,
Expand All @@ -113,7 +113,7 @@ void transform(raft::handle_t& handle,

void transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<double>** input,
double* components,
MLCommon::Matrix::Data<double>** trans_input,
Expand All @@ -133,7 +133,7 @@ void transform(raft::handle_t& handle,
*/
void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<float>** trans_input,
float* components,
MLCommon::Matrix::Data<float>** input,
Expand All @@ -142,7 +142,7 @@ void inverse_transform(raft::handle_t& handle,

void inverse_transform(raft::handle_t& handle,
MLCommon::Matrix::RankSizePair** rank_sizes,
size_t n_parts,
std::uint32_t n_parts,
MLCommon::Matrix::Data<double>** trans_input,
double* components,
MLCommon::Matrix::Data<double>** input,
Expand Down
28 changes: 18 additions & 10 deletions cpp/src/pca/pca.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ void truncCompExpVars(const raft::handle_t& handle,
math_t* components,
math_t* explained_var,
math_t* explained_var_ratio,
const paramsTSVDTemplate<enum_solver> prms,
const paramsTSVDTemplate<enum_solver>& prms,
cudaStream_t stream)
{
size_t len = prms.n_cols * prms.n_cols;
auto len = prms.n_cols * prms.n_cols;
rmm::device_uvector<math_t> components_all(len, stream);
rmm::device_uvector<math_t> explained_var_all(prms.n_cols, stream);
rmm::device_uvector<math_t> explained_var_ratio_all(prms.n_cols, stream);
Expand All @@ -55,10 +55,18 @@ void truncCompExpVars(const raft::handle_t& handle,
components_all.data(), prms.n_cols, components, prms.n_components, prms.n_cols, stream);
raft::matrix::ratio(
handle, explained_var_all.data(), explained_var_ratio_all.data(), prms.n_cols, stream);
raft::matrix::truncZeroOrigin(
explained_var_all.data(), prms.n_cols, explained_var, prms.n_components, 1, stream);
raft::matrix::truncZeroOrigin(
explained_var_ratio_all.data(), prms.n_cols, explained_var_ratio, prms.n_components, 1, stream);
raft::matrix::truncZeroOrigin(explained_var_all.data(),
prms.n_cols,
explained_var,
prms.n_components,
std::size_t(1),
stream);
raft::matrix::truncZeroOrigin(explained_var_ratio_all.data(),
prms.n_cols,
explained_var_ratio,
prms.n_components,
std::size_t(1),
stream);
}

/**
Expand Down Expand Up @@ -97,12 +105,12 @@ void pcaFit(const raft::handle_t& handle,
ASSERT(prms.n_components > 0,
"Parameter n_components: number of components cannot be less than one");

int n_components = prms.n_components;
auto n_components = prms.n_components;
if (n_components > prms.n_cols) n_components = prms.n_cols;

raft::stats::mean(mu, input, prms.n_cols, prms.n_rows, true, false, stream);

size_t len = prms.n_cols * prms.n_cols;
auto len = prms.n_cols * prms.n_cols;
rmm::device_uvector<math_t> cov(len, stream);

Stats::cov(handle, cov.data(), input, mu, prms.n_cols, prms.n_rows, true, false, true, stream);
Expand Down Expand Up @@ -202,7 +210,7 @@ void pcaInverseTransform(const raft::handle_t& handle,
ASSERT(prms.n_components > 0,
"Parameter n_components: number of components cannot be less than one");

std::size_t components_len = static_cast<std::size_t>(prms.n_cols * prms.n_components);
auto components_len = prms.n_cols * prms.n_components;
rmm::device_uvector<math_t> components_copy{components_len, stream};
raft::copy(components_copy.data(), components, prms.n_cols * prms.n_components, stream);

Expand Down Expand Up @@ -262,7 +270,7 @@ void pcaTransform(const raft::handle_t& handle,
ASSERT(prms.n_components > 0,
"Parameter n_components: number of components cannot be less than one");

std::size_t components_len = static_cast<std::size_t>(prms.n_cols * prms.n_components);
auto components_len = prms.n_cols * prms.n_components;
rmm::device_uvector<math_t> components_copy{components_len, stream};
raft::copy(components_copy.data(), components, prms.n_cols * prms.n_components, stream);

Expand Down
Loading

0 comments on commit 95835b0

Please sign in to comment.