Skip to content

Commit

Permalink
[Lang] Support sparse matrix datatype and storage format configuration (
Browse files Browse the repository at this point in the history
#4673)

* Add sparse matrix datatype configuration

* create sparse matrix with datatype in Python

* sparse solver takes as sparse matrix with datatype parameters

* operator overloading with bug

* fix operator overloading bugs

* Add more operator overloading functions

* EigenSparseMatrix operator overloading

* improve

* Clang-tidy

* add more datatype EigenSparseMatrix

* get/set element bug fix

* Bugfix:sparse matrix shape configuration

* improve sparse matrix test cases

* Update tests/python/test_sparse_matrix.py

Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>

* improve

* Update taichi/program/sparse_matrix.h

Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>

Co-authored-by: Yi Xu <xy_xuyi@foxmail.com>
Co-authored-by: taichiCourse01 <tgc01@taichi.graphics>
  • Loading branch information
3 people authored May 9, 2022
1 parent b485bde commit 736ebd5
Show file tree
Hide file tree
Showing 9 changed files with 496 additions and 214 deletions.
60 changes: 36 additions & 24 deletions misc/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@

n = 8

K = ti.linalg.SparseMatrixBuilder(n, n, max_num_triplets=100)
f = ti.linalg.SparseMatrixBuilder(n, 1, max_num_triplets=100)
K = ti.linalg.SparseMatrixBuilder(n,
n,
max_num_triplets=100,
dtype=ti.f32,
storage_format='col_major')
f = ti.linalg.SparseMatrixBuilder(n,
1,
max_num_triplets=100,
dtype=ti.f32,
storage_format='col_major')


@ti.kernel
def fill(A: ti.types.sparse_matrix_builder(),
b: ti.types.sparse_matrix_builder(), interval: ti.i32):
for i in range(n):
if i > 0:
A[i - 1, i] += -1.0
A[i, i] += 1
A[i - 1, i] += -2.0
A[i, i] += 1.0
if i < n - 1:
A[i + 1, i] += -1.0
A[i, i] += 1.0
Expand All @@ -33,32 +41,36 @@ def fill(A: ti.types.sparse_matrix_builder(),
print(">>>> A = K.build()")
print(A)

print(">>>> Summation: C = A + A")
C = A + A
print(">>>> Summation: B = A + A")
B = A + A
print(B)

print(">>>> Summation: B += A")
B += A
print(B)

print(">>>> Subtraction: C = B - A")
C = B - A
print(C)

print(">>>> Subtraction: C -= A")
C -= A
print(C)

print(">>>> Subtraction: D = A - A")
D = A - A
print(">>>> Multiplication with a scalar on the right: D = A * 3.0")
D = A * 3.0
print(D)

print(">>>> Multiplication with a scalar on the right: E = A * 3.0")
E = A * 3.0
print(E)
print(">>>> Multiplication with a scalar on the left: D = 3.0 * A")
D = 3.0 * A
print(D)

print(">>>> Multiplication with a scalar on the left: E = 3.0 * A")
E = 3.0 * A
print(">>>> Transpose: E = D.transpose()")
E = D.transpose()
print(E)

print(">>>> Transpose: F = A.transpose()")
F = A.transpose()
print(">>>> Matrix multiplication: F= E @ A")
F = E @ A
print(F)

print(">>>> Matrix multiplication: G = E @ A")
G = E @ A
print(G)

print(">>>> Element-wise multiplication: H = E * A")
H = E * A
print(H)

print(f">>>> Element Access: A[0,0] = {A[0,0]}")
print(f">>>> Element Access: F[0,0] = {F[0,0]}")
41 changes: 35 additions & 6 deletions python/taichi/linalg/sparse_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,32 @@ class SparseMatrix:
m (int): the second dimension of a sparse matrix.
sm (SparseMatrix): another sparse matrix that will be built from.
"""
def __init__(self, n=None, m=None, sm=None, dtype=f32):
def __init__(self,
n=None,
m=None,
sm=None,
dtype=f32,
storage_format="col_major"):
if sm is None:
self.n = n
self.m = m if m else n
self.matrix = get_runtime().prog.create_sparse_matrix(n, m)
self.matrix = get_runtime().prog.create_sparse_matrix(
n, m, dtype, storage_format)
else:
self.n = sm.num_rows()
self.m = sm.num_cols()
self.matrix = sm

def __iadd__(self, other):
"""Addition operation for sparse matrix.
Returns:
The result sparse matrix of the addition.
"""
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
self.matrix += other.matrix
return self

def __add__(self, other):
"""Addition operation for sparse matrix.
Expand All @@ -35,6 +51,16 @@ def __add__(self, other):
sm = self.matrix + other.matrix
return SparseMatrix(sm=sm)

def __isub__(self, other):
"""Subtraction operation for sparse matrix.
Returns:
The result sparse matrix of the subtraction.
"""
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
self.matrix -= other.matrix
return self

def __sub__(self, other):
"""Subtraction operation for sparse matrix.
Expand All @@ -54,7 +80,7 @@ def __mul__(self, other):
The result of multiplication.
"""
if isinstance(other, float):
sm = self.matrix * other
sm = other * self.matrix
return SparseMatrix(sm=sm)
if isinstance(other, SparseMatrix):
assert self.n == other.n and self.m == other.m, f"Dimension mismatch between sparse matrices ({self.n}, {self.m}) and ({other.n}, {other.m})"
Expand All @@ -72,7 +98,7 @@ def __rmul__(self, other):
The result of multiplication.
"""
if isinstance(other, float):
sm = other * self.matrix
sm = self.matrix * other
return SparseMatrix(sm=sm)

return None
Expand Down Expand Up @@ -135,18 +161,21 @@ class SparseMatrixBuilder:
num_rows (int): the first dimension of a sparse matrix.
num_cols (int): the second dimension of a sparse matrix.
max_num_triplets (int): the maximum number of triplets.
dtype (ti.dtype): the data type of the sparse matrix.
storage_format (str): the storage format of the sparse matrix.
"""
def __init__(self,
num_rows=None,
num_cols=None,
max_num_triplets=0,
dtype=f32):
dtype=f32,
storage_format="col_major"):
self.num_rows = num_rows
self.num_cols = num_cols if num_cols else num_rows
self.dtype = dtype
if num_rows is not None:
self.ptr = get_runtime().prog.create_sparse_matrix_builder(
num_rows, num_cols, max_num_triplets, dtype)
num_rows, num_cols, max_num_triplets, dtype, storage_format)

def _get_addr(self):
"""Get the address of the sparse matrix"""
Expand Down
152 changes: 76 additions & 76 deletions taichi/program/sparse_matrix.cpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,53 @@
#include "taichi/program/sparse_matrix.h"

#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>

#include "Eigen/Dense"
#include "Eigen/SparseLU"

#define BUILD(TYPE) \
{ \
using T = Eigen::Triplet<float##TYPE>; \
std::vector<T> *triplets = static_cast<std::vector<T> *>(triplets_adr); \
matrix_.setFromTriplets(triplets->begin(), triplets->end()); \
}

#define MAKE_MATRIX(TYPE, STORAGE) \
{ \
Pair("f" #TYPE, #STORAGE), \
[](int rows, int cols, DataType dt) -> std::unique_ptr<SparseMatrix> { \
using FC = Eigen::SparseMatrix<float##TYPE, Eigen::STORAGE>; \
return std::make_unique<EigenSparseMatrix<FC>>(rows, cols, dt); \
} \
}

namespace {
using Pair = std::pair<std::string, std::string>;
struct key_hash {
std::size_t operator()(const Pair &k) const {
auto h1 = std::hash<std::string>{}(k.first);
auto h2 = std::hash<std::string>{}(k.second);
return h1 ^ h2;
}
};
} // namespace

namespace taichi {
namespace lang {

SparseMatrixBuilder::SparseMatrixBuilder(int rows,
int cols,
int max_num_triplets,
DataType dtype)
DataType dtype,
const std::string &storage_format)
: rows_(rows),
cols_(cols),
max_num_triplets_(max_num_triplets),
dtype_(dtype) {
dtype_(dtype),
storage_format_(storage_format) {
auto element_size = data_type_size(dtype);
TI_ASSERT((element_size == 4 || element_size == 8));
data_base_ptr_ =
Expand Down Expand Up @@ -50,116 +82,84 @@ void SparseMatrixBuilder::print_triplets() {
}

template <typename T, typename G>
SparseMatrix SparseMatrixBuilder::build_template() {
void SparseMatrixBuilder::build_template(std::unique_ptr<SparseMatrix> &m) {
using V = Eigen::Triplet<T>;
std::vector<V> triplets;
T *data = reinterpret_cast<T *>(data_base_ptr_.get());
for (int i = 0; i < num_triplets_; i++) {
triplets.push_back(V(((G *)data)[i * 3], ((G *)data)[i * 3 + 1],
taichi_union_cast<T>(data[i * 3 + 2])));
}
SparseMatrix sm(rows_, cols_);
sm.get_matrix().setFromTriplets(triplets.begin(), triplets.end());
m->build_triplets(static_cast<void *>(&triplets));
clear();
return sm;
}

SparseMatrix SparseMatrixBuilder::build() {
std::unique_ptr<SparseMatrix> SparseMatrixBuilder::build() {
TI_ASSERT(built_ == false);
built_ = true;
auto sm = make_sparse_matrix(rows_, cols_, dtype_, storage_format_);
auto element_size = data_type_size(dtype_);
switch (element_size) {
case 4:
return build_template<float32, int32>();
build_template<float32, int32>(sm);
break;
case 8:
return build_template<float64, int64>();
build_template<float64, int64>(sm);
break;
default:
TI_ERROR("Unsupported sparse matrix data type!");
break;
}
return sm;
}

void SparseMatrixBuilder::clear() {
built_ = false;
num_triplets_ = 0;
}

SparseMatrix::SparseMatrix(Eigen::SparseMatrix<float32> &matrix) {
this->matrix_ = matrix;
}

SparseMatrix::SparseMatrix(int rows, int cols) : matrix_(rows, cols) {
}

const std::string SparseMatrix::to_string() const {
template <class EigenMatrix>
const std::string EigenSparseMatrix<EigenMatrix>::to_string() const {
Eigen::IOFormat clean_fmt(4, 0, ", ", "\n", "[", "]");
// Note that the code below first converts the sparse matrix into a dense one.
// https://stackoverflow.com/questions/38553335/how-can-i-print-in-console-a-formatted-sparse-matrix-with-eigen
std::ostringstream ostr;
ostr << Eigen::MatrixXf(matrix_).format(clean_fmt);
ostr << Eigen::MatrixXf(matrix_.template cast<float>()).format(clean_fmt);
return ostr.str();
}

const int SparseMatrix::num_rows() const {
return matrix_.rows();
}
const int SparseMatrix::num_cols() const {
return matrix_.cols();
}

Eigen::SparseMatrix<float32> &SparseMatrix::get_matrix() {
return matrix_;
}

const Eigen::SparseMatrix<float32> &SparseMatrix::get_matrix() const {
return matrix_;
}

SparseMatrix operator+(const SparseMatrix &sm1, const SparseMatrix &sm2) {
Eigen::SparseMatrix<float32> res(sm1.matrix_ + sm2.matrix_);
return SparseMatrix(res);
}

SparseMatrix operator-(const SparseMatrix &sm1, const SparseMatrix &sm2) {
Eigen::SparseMatrix<float32> res(sm1.matrix_ - sm2.matrix_);
return SparseMatrix(res);
}

SparseMatrix operator*(float scale, const SparseMatrix &sm) {
Eigen::SparseMatrix<float32> res(scale * sm.matrix_);
return SparseMatrix(res);
}

SparseMatrix operator*(const SparseMatrix &sm, float scale) {
return scale * sm;
}

SparseMatrix operator*(const SparseMatrix &sm1, const SparseMatrix &sm2) {
Eigen::SparseMatrix<float32> res(sm1.matrix_.cwiseProduct(sm2.matrix_));
return SparseMatrix(res);
}

SparseMatrix SparseMatrix::matmul(const SparseMatrix &sm) {
Eigen::SparseMatrix<float32> res(matrix_ * sm.matrix_);
return SparseMatrix(res);
}

Eigen::VectorXf SparseMatrix::mat_vec_mul(
const Eigen::Ref<const Eigen::VectorXf> &b) {
return matrix_ * b;
}

SparseMatrix SparseMatrix::transpose() {
Eigen::SparseMatrix<float32> res(matrix_.transpose());
return SparseMatrix(res);
}

float32 SparseMatrix::get_element(int row, int col) {
return matrix_.coeff(row, col);
template <class EigenMatrix>
void EigenSparseMatrix<EigenMatrix>::build_triplets(void *triplets_adr) {
std::string sdtype = taichi::lang::data_type_name(dtype_);
if (sdtype == "f32") {
BUILD(32)
} else if (sdtype == "f64") {
BUILD(64)
} else {
TI_ERROR("Unsupported sparse matrix data type {}!", sdtype);
}
}

void SparseMatrix::set_element(int row, int col, float32 value) {
matrix_.coeffRef(row, col) = value;
std::unique_ptr<SparseMatrix> make_sparse_matrix(
int rows,
int cols,
DataType dt,
const std::string &storage_format = "col_major") {
using func_type = std::unique_ptr<SparseMatrix> (*)(int, int, DataType);
static const std::unordered_map<Pair, func_type, key_hash> map = {
MAKE_MATRIX(32, ColMajor), MAKE_MATRIX(32, RowMajor),
MAKE_MATRIX(64, ColMajor), MAKE_MATRIX(64, RowMajor)};
std::unordered_map<std::string, std::string> format_map = {
{"col_major", "ColMajor"}, {"row_major", "RowMajor"}};
std::string tdt = taichi::lang::data_type_name(dt);
Pair key = std::make_pair(tdt, format_map.at(storage_format));
auto it = map.find(key);
if (it != map.end()) {
auto func = map.at(key);
return func(rows, cols, dt);
} else
TI_ERROR("Unsupported sparse matrix data type: {}, storage format: {}", tdt,
storage_format);
}

} // namespace lang
Expand Down
Loading

0 comments on commit 736ebd5

Please sign in to comment.