Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor and simplify BLAS gemm call further #1421

Merged
merged 6 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ keywords = ["array", "data-structure", "multidimensional", "matrix", "blas"]
categories = ["data-structures", "science"]

exclude = ["docgen/images/*"]
resolver = "2"

[lib]
name = "ndarray"
Expand Down
4 changes: 2 additions & 2 deletions crates/blas-mock-tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ doc = false
doctest = false

[dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }
cblas-sys = { workspace = true }

[dev-dependencies]
ndarray = { workspace = true, features = ["approx", "blas"] }
ndarray-gen = { workspace = true }
itertools = { workspace = true }
16 changes: 10 additions & 6 deletions crates/blas-tests/tests/oper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul;
use ndarray::Order;
use ndarray::{Data, Ix, LinalgScalar};
use ndarray_gen::array_builder::ArrayBuilder;
use ndarray_gen::array_builder::ElementGenerator;

use approx::assert_relative_eq;
use defmac::defmac;
Expand Down Expand Up @@ -230,7 +231,6 @@ fn gen_mat_mul()
let sizes = vec![
(4, 4, 4),
(8, 8, 8),
(10, 10, 10),
(8, 8, 1),
(1, 10, 10),
(10, 1, 10),
Expand All @@ -241,19 +241,23 @@ fn gen_mat_mul()
(4, 17, 3),
(17, 3, 22),
(19, 18, 2),
(16, 17, 15),
(15, 16, 17),
(67, 63, 62),
(67, 50, 62),
];
let strides = &[1, 2, -1, -2];
let cf_order = [Order::C, Order::F];
let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard];

// test different strides and memory orders
for (&s1, &s2) in iproduct!(strides, strides) {
for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) {
for &(m, k, n) in &sizes {
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5;
println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3);
let a = ArrayBuilder::new((m, k))
.memory_order(ord1)
.generator(gen)
.build()
* 0.5;
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();

Expand Down
17 changes: 8 additions & 9 deletions crates/ndarray-gen/src/array_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pub struct ArrayBuilder<D: Dimension>
pub enum ElementGenerator
{
Sequential,
Checkerboard,
Zero,
}

Expand Down Expand Up @@ -64,16 +65,14 @@ where D: Dimension
pub fn build<T>(self) -> Array<T, D>
where T: Num + Clone
{
let mut current = T::zero();
let zero = T::zero();
let size = self.dim.size();
let use_zeros = self.generator == ElementGenerator::Zero;
Array::from_iter((0..size).map(|_| {
let ret = current.clone();
if !use_zeros {
current = ret.clone() + T::one();
}
ret
}))
(match self.generator {
ElementGenerator::Sequential =>
Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)),
ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()),
ElementGenerator::Zero => Array::zeros(size),
})
.into_shape_with_order((self.dim, self.memory_order))
.unwrap()
}
Expand Down
1 change: 1 addition & 0 deletions scripts/cross-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ QC_FEAT=--features=ndarray-rand/quickcheck

cross build -v --features="$FEATURES" $QC_FEAT --target=$TARGET
cross test -v --no-fail-fast --features="$FEATURES" $QC_FEAT --target=$TARGET
cross test -v -p blas-mock-tests
2 changes: 1 addition & 1 deletion scripts/makechangelog.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# Will produce some duplicates for PRs integrated using rebase,
# but those will not occur with current merge queue.

git log --first-parent --pretty="format:%H" "$@" | while read commit_sha
git log --first-parent --pretty="tformat:%H" "$@" | while IFS= read -r commit_sha
do
gh api "/repos/:owner/:repo/commits/${commit_sha}/pulls" \
-q ".[] | \"- \(.title) by [@\(.user.login)](\(.user.html_url)) [#\(.number)](\(.html_url))\""
Expand Down
Loading