Skip to content

Commit

Permalink
Check for invalid data. (#6742)
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis authored Mar 4, 2021
1 parent a9b4a95 commit f20074e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 6 deletions.
17 changes: 11 additions & 6 deletions src/data/data.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -898,11 +898,12 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
return max_columns;
}
std::vector<std::vector<uint64_t>> max_columns_vector(nthread);
dmlc::OMPException exc;
dmlc::OMPException exec;
std::atomic<bool> valid{true};
// First-pass over the batch counting valid elements
#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
Expand All @@ -912,7 +913,10 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
for (size_t i = begin; i < end; ++i) {
auto line = batch.GetLine(i);
for (auto j = 0ull; j < line.Size(); j++) {
auto element = line.GetElement(j);
data::COOTuple const& element = line.GetElement(j);
if (!std::isinf(missing) && std::isinf(element.value)) {
valid = false;
}
const size_t key = element.row_idx - base_rowid;
CHECK_GE(key, builder_base_row_offset);
max_columns_local =
Expand All @@ -927,7 +931,8 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
CHECK(valid) << "Input data contains `inf` or `nan`";
for (const auto & max : max_columns_vector) {
max_columns = std::max(max_columns, max[0]);
}
Expand All @@ -938,7 +943,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread

#pragma omp parallel num_threads(nthread)
{
exc.Run([&]() {
exec.Run([&]() {
int tid = omp_get_thread_num();
size_t begin = tid*thread_size;
size_t end = tid != (nthread-1) ? (tid+1)*thread_size : batch_size;
Expand All @@ -954,7 +959,7 @@ uint64_t SparsePage::Push(const AdapterBatchT& batch, float missing, int nthread
}
});
}
exc.Rethrow();
exec.Rethrow();
omp_set_num_threads(nthread_original);

return max_columns;
Expand Down
8 changes: 8 additions & 0 deletions tests/cpp/data/test_simple_dmatrix.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -116,6 +116,14 @@ TEST(SimpleDMatrix, MissingData) {
CHECK_EQ(dmat->Info().num_nonzero_, 2);
dmat.reset(new data::SimpleDMatrix(&adapter, 1.0, 1));
CHECK_EQ(dmat->Info().num_nonzero_, 1);

{
data[1] = std::numeric_limits<float>::infinity();
data::DenseAdapter adapter(data.data(), data.size(), 1);
EXPECT_THROW(data::SimpleDMatrix dmat(
&adapter, std::numeric_limits<float>::quiet_NaN(), -1),
dmlc::Error);
}
}

TEST(SimpleDMatrix, EmptyRow) {
Expand Down

0 comments on commit f20074e

Please sign in to comment.