Skip to content

Commit

Permalink
[Misc.] Avoid calling IsPinned in the coo/csr constructor from ever…
Browse files Browse the repository at this point in the history
…y sampling process (dmlc#6568)
  • Loading branch information
chang-l authored and DominikaJedynak committed Mar 12, 2024
1 parent f52dfe3 commit b015963
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
16 changes: 10 additions & 6 deletions include/dgl/aten/coo.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ struct COOMatrix {
data(darr),
row_sorted(rsorted),
col_sorted(csorted) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}

Expand Down Expand Up @@ -134,6 +129,15 @@ struct COOMatrix {
aten::IsNullArray(data);
}

// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
(aten::IsNullArray(col) || col.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}

/** @brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == row->ctx) return *this;
Expand All @@ -151,7 +155,7 @@ struct COOMatrix {
num_rows, num_cols, row.PinMemory(), col.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
col_sorted);
CHECK(new_coo.is_pinned)
CHECK(new_coo.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a COO "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
Expand Down
16 changes: 10 additions & 6 deletions include/dgl/aten/csr.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ struct CSRMatrix {
indices(iarr),
data(darr),
sorted(sorted_flag) {
if (!IsEmpty()) {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
}
CheckValidity();
}

Expand Down Expand Up @@ -128,6 +123,15 @@ struct CSRMatrix {
aten::IsNullArray(data);
}

// Check and update the internal flag is_pinned.
// This function will initialize a cuda context.
inline bool CheckIfPinnedInCUDA() {
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
(aten::IsNullArray(indices) || indices.IsPinned()) &&
(aten::IsNullArray(data) || data.IsPinned());
return is_pinned;
}

/** @brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DGLContext& ctx) const {
if (ctx == indptr->ctx) return *this;
Expand All @@ -143,7 +147,7 @@ struct CSRMatrix {
auto new_csr = CSRMatrix(
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
CHECK(new_csr.is_pinned)
CHECK(new_csr.CheckIfPinnedInCUDA())
<< "An internal DGL error has occured while trying to pin a CSR "
"matrix. Please file a bug at "
"'https://github.com/dmlc/dgl/issues' "
Expand Down

0 comments on commit b015963

Please sign in to comment.