Skip to content

Commit

Permalink
[Sparse] Add compact operator (dmlc#6352)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <ubuntu@ip-172-31-24-117.ap-northeast-1.compute.internal>
  • Loading branch information
2 people authored and Ubuntu committed Nov 27, 2023
1 parent d3cb636 commit 892ab4a
Show file tree
Hide file tree
Showing 7 changed files with 175 additions and 88 deletions.
3 changes: 1 addition & 2 deletions dgl_sparse/include/sparse/matrix_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#ifndef SPARSE_MATRIX_OPS_H_
#define SPARSE_MATRIX_OPS_H_

#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>

#include <tuple>
Expand Down Expand Up @@ -47,7 +46,7 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
*/
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
torch::Tensor leading_indices);
const torch::optional<torch::Tensor>& leading_indices);

} // namespace sparse
} // namespace dgl
Expand Down
65 changes: 0 additions & 65 deletions dgl_sparse/src/macro.h

This file was deleted.

123 changes: 115 additions & 8 deletions dgl_sparse/src/matrix_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
#include <sparse/matrix_ops.h>
#include <torch/script.h>

#include "./macro.h"
#include "./matrix_ops_impl.h"

namespace dgl {
namespace sparse {

Expand Down Expand Up @@ -58,12 +55,122 @@ std::tuple<std::shared_ptr<COO>, torch::Tensor, torch::Tensor> COOIntersection(
return {ret_coo, lhs_indices, rhs_indices};
}

/** @brief Return the reverted mapping of a permutation. */
static torch::Tensor RevertPermutation(const torch::Tensor& perm) {
auto rev_tensor = torch::empty_like(perm);
rev_tensor.index_put_(
{perm}, torch::arange(0, perm.numel(), rev_tensor.options()));
return rev_tensor;
}

/**
* @brief Compute the compact indices of row indices and leading indices. Return
* the compacted indices and the original row indices of compacted indices.
*
* @param row The row indices.
* @param leading_indices The leading indices.
*
* @return A tuple of compact indices, original indices.
*/
static std::tuple<torch::Tensor, torch::Tensor> CompactIndices(
const torch::Tensor& row,
const torch::optional<torch::Tensor>& leading_indices) {
torch::Tensor sorted, sort_indices, uniqued, unique_reverse_indices, counts;
// 1. Sort leading indices and row indices in ascending order.
int64_t n_leading_indices = 0;
if (leading_indices.has_value()) {
n_leading_indices = leading_indices.value().numel();
std::tie(sorted, sort_indices) =
torch::cat({leading_indices.value(), row}).sort();
} else {
std::tie(sorted, sort_indices) = row.sort();
}
// 2. Reverse sort indices.
auto sort_rev_indices = RevertPermutation(sort_indices);
// 3. Unique the sorted array.
std::tie(uniqued, unique_reverse_indices, counts) =
torch::unique_consecutive(sorted, true);
auto reverse_indices = unique_reverse_indices.index({sort_rev_indices});
auto n_uniqued = uniqued.numel();

// 4. Relabel the indices and map the inverse array to the original array.
auto split_indices = torch::full({n_uniqued}, -1, reverse_indices.options());

split_indices.index_put_(
{reverse_indices.slice(0, 0, n_leading_indices)},
torch::arange(0, n_leading_indices, split_indices.options()));

split_indices.index_put_(
{(split_indices == -1).nonzero().view(-1)},
torch::arange(n_leading_indices, n_uniqued, split_indices.options()));
// 5. Decode the indices to get the compact indices.
auto new_row = split_indices.index({reverse_indices.slice(
0, n_leading_indices, n_leading_indices + row.numel())});
return {new_row, uniqued.index({RevertPermutation(split_indices)})};
}

static std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactCOO(
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
const torch::optional<torch::Tensor>& leading_indices) {
torch::Tensor row, col;
auto coo = mat->COOTensors();
if (dim == 0)
std::tie(row, col) = coo;
else
std::tie(col, row) = coo;

torch::Tensor new_row, uniqued;
std::tie(new_row, uniqued) = CompactIndices(row, leading_indices);

if (dim == 0) {
auto ret = SparseMatrix::FromCOO(
torch::stack({new_row, col}, 0), mat->value(),
std::vector<int64_t>{uniqued.numel(), mat->shape()[1]});
return {ret, uniqued};
} else {
auto ret = SparseMatrix::FromCOO(
torch::stack({col, new_row}, 0), mat->value(),
std::vector<int64_t>{mat->shape()[0], uniqued.numel()});
return {ret, uniqued};
}
}

static std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactCSR(
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
const torch::optional<torch::Tensor>& leading_indices) {
std::shared_ptr<CSR> csr;
if (dim == 0)
csr = mat->CSCPtr();
else
csr = mat->CSRPtr();

torch::Tensor new_indices, uniqued;
std::tie(new_indices, uniqued) =
CompactIndices(csr->indices, leading_indices);

auto ret_value = mat->value();
if (csr->value_indices.has_value())
ret_value = mat->value().index_select(0, csr->value_indices.value());
if (dim == 0) {
auto ret = SparseMatrix::FromCSC(
csr->indptr, new_indices, ret_value,
std::vector<int64_t>{uniqued.numel(), mat->shape()[1]});
return {ret, uniqued};
} else {
auto ret = SparseMatrix::FromCSR(
csr->indptr, new_indices, ret_value,
std::vector<int64_t>{mat->shape()[0], uniqued.numel()});
return {ret, uniqued};
}
}

std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> Compact(
const c10::intrusive_ptr<SparseMatrix>& mat, uint64_t dim,
torch::Tensor leading_indices) {
DGL_SPARSE_COO_SWITCH(mat->COOPtr(), XPU, IdType, "Compact", {
return CompactImpl<XPU, IdType>(mat, dim, leading_indices);
});
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
const torch::optional<torch::Tensor>& leading_indices) {
if (mat->HasCOO()) {
return CompactCOO(mat, dim, leading_indices);
}
return CompactCSR(mat, dim, leading_indices);
}

} // namespace sparse
Expand Down
16 changes: 5 additions & 11 deletions dgl_sparse/src/matrix_ops_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,15 @@
#define DGL_SPARSE_MATRIX_OPS_IMPL_H_

#include <sparse/sparse_format.h>
#include <sparse/sparse_matrix.h>

#include <tuple>
#include <vector>

namespace dgl {
namespace sparse {

template <c10::DeviceType XPU, typename IdType>
std::tuple<c10::intrusive_ptr<SparseMatrix>, torch::Tensor> CompactImpl(
const c10::intrusive_ptr<SparseMatrix>& mat, int64_t dim,
torch::Tensor leading_indices) {
// Place holder only.
return {mat, leading_indices};
}
#include "./utils.h"

} // namespace sparse
namespace dgl {
namespace sparse {} // namespace sparse
} // namespace dgl

#endif // DGL_SPARSE_MATRIX_OPS_IMPL_H_
4 changes: 3 additions & 1 deletion dgl_sparse/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// clang-format on

#include <sparse/elementwise_op.h>
#include <sparse/matrix_ops.h>
#include <sparse/reduction.h>
#include <sparse/sddmm.h>
#include <sparse/softmax.h>
Expand Down Expand Up @@ -54,7 +55,8 @@ TORCH_LIBRARY(dgl_sparse, m) {
.def("spmm", &SpMM)
.def("sddmm", &SDDMM)
.def("softmax", &Softmax)
.def("spspmm", &SpSpMM);
.def("spspmm", &SpSpMM)
.def("compact", &Compact);
}

} // namespace sparse
Expand Down
5 changes: 4 additions & 1 deletion python/dgl/sparse/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,10 @@ def compact(
>>> print(original_rows)
torch.Tensor([1, 2, 0])
"""
raise NotImplementedError
mat, idx = torch.ops.dgl_sparse.compact(
self.c_sparse_matrix, dim, leading_indices
)
return SparseMatrix(mat), idx


def spmatrix(
Expand Down
47 changes: 47 additions & 0 deletions tests/python/pytorch/sparse/test_matrix_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import backend as F
import pytest
import torch

from .utils import (
rand_coo,
rand_csc,
rand_csr,
rand_diag,
sparse_matrix_to_dense,
)


@pytest.mark.parametrize(
"create_func", [rand_diag, rand_csr, rand_csc, rand_coo]
)
@pytest.mark.parametrize("dim", [0, 1])
@pytest.mark.parametrize("index", [None, (1, 3), (4, 0, 2)])
def test_compact(create_func, dim, index):
ctx = F.ctx()
shape = (5, 5)
ans_idx = []
if index is not None:
ans_idx = list(dict.fromkeys(index))
index = torch.tensor(index).to(ctx)

A = create_func(shape, 8, ctx)

A_compact, ret_id = A.compact(dim, index)
A_compact_dense = sparse_matrix_to_dense(A_compact)

A_dense = sparse_matrix_to_dense(A)

for i in range(shape[dim]):
if dim == 0:
row = list(A_dense[i, :].nonzero().reshape(-1))
else:
row = list(A_dense[:, i].nonzero().reshape(-1))
if (i not in list(ans_idx)) and len(row) > 0:
ans_idx.append(i)
if len(ans_idx):
ans_idx = torch.tensor(ans_idx).to(ctx)
A_dense_select = sparse_matrix_to_dense(A.index_select(dim, ans_idx))

assert A_compact_dense.shape == A_dense_select.shape
assert torch.allclose(A_compact_dense, A_dense_select)
assert torch.allclose(ans_idx, ret_id)

0 comments on commit 892ab4a

Please sign in to comment.