Skip to content

Commit

Permalink
blas: Refactor and simplify gemm call further
Browse files Browse the repository at this point in the history
Further clarify transpose logic by putting it into BlasOrder methods.
  • Loading branch information
bluss committed Aug 9, 2024
1 parent 7226d39 commit 453eae3
Showing 1 changed file with 57 additions and 67 deletions.
124 changes: 57 additions & 67 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use libc::c_int;
#[cfg(feature = "blas")]
use cblas_sys as blas_sys;
#[cfg(feature = "blas")]
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT};
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};

/// len of vector before we use blas
#[cfg(feature = "blas")]
Expand Down Expand Up @@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
// Compute A B -> C
// We require for BLAS compatibility that:
// A, B, C are contiguous (stride=1) in their fastest dimension,
// but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
// but they can be either row major/"c" or col major/"f".
//
// The "normal case" is CblasRowMajor for cblas.
// Select CblasRowMajor, CblasColMajor to fit C's memory order.
// Select CblasRowMajor / CblasColMajor to fit C's memory order.
//
// Apply transpose to A, B as needed if they differ from the normal case.
// Apply transpose to A, B as needed if they differ from the row major case.
// If C is CblasColMajor then transpose both A, B (again!)

let (a_layout, a_axis, b_layout, b_axis, c_layout) =
match (get_blas_compatible_layout(a),
get_blas_compatible_layout(b),
get_blas_compatible_layout(c))
let (a_layout, b_layout, c_layout) =
if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
(get_blas_compatible_layout(a),
get_blas_compatible_layout(b),
get_blas_compatible_layout(c))
{
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => {
(a_layout, a_layout.lead_axis(),
b_layout, b_layout.lead_axis(), c_layout)
},
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => {
// CblasColMajor is the "other case"
// Mark a, b as having layouts opposite of what they were detected as, which
// ends up with the correct transpose setting w.r.t col major
(a_layout.opposite(), a_layout.lead_axis(),
b_layout.opposite(), b_layout.lead_axis(), c_layout)
},
_ => break 'blas_block,
(a_layout, b_layout, c_layout)
} else {
break 'blas_block;
};

let a_trans = a_layout.to_cblas_transpose();
let lda = blas_stride(&a, a_axis);
let cblas_layout = c_layout.to_cblas_layout();
let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
let lda = blas_stride(&a, a_layout);

let b_trans = b_layout.to_cblas_transpose();
let ldb = blas_stride(&b, b_axis);
let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
let ldb = blas_stride(&b, b_layout);

let ldc = blas_stride(&c, c_layout.lead_axis());
let ldc = blas_stride(&c, c_layout);

macro_rules! gemm_scalar_cast {
(f32, $var:ident) => {
Expand All @@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
// Where Op is notrans/trans/conjtrans
unsafe {
blas_sys::$gemm(
c_layout.to_cblas_layout(),
cblas_layout,
a_trans,
b_trans,
m as blas_index, // m, rows of Op(a)
Expand Down Expand Up @@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
// may be arbitrary.
let a_trans = CblasNoTrans;

let (a_stride, cblas_layout) = match layout {
MemoryOrder::C => {
(a.strides()[0].max(k as isize) as blas_index,
CBLAS_LAYOUT::CblasRowMajor)
}
MemoryOrder::F => {
(a.strides()[1].max(m as isize) as blas_index,
CBLAS_LAYOUT::CblasColMajor)
}
};
let a_stride = blas_stride(&a, layout);
let cblas_layout = layout.to_cblas_layout();

// Low addr in memory pointers required for x, y
let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
Expand Down Expand Up @@ -835,61 +820,66 @@ where
#[cfg(feature = "blas")]
#[derive(Copy, Clone)]
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
enum MemoryOrder
enum BlasOrder
{
C,
F,
}

#[cfg(feature = "blas")]
impl MemoryOrder
impl BlasOrder
{
#[inline]
/// Axis of leading stride (opposite of contiguous axis)
fn lead_axis(self) -> usize
fn transpose(self) -> Self
{
match self {
MemoryOrder::C => 0,
MemoryOrder::F => 1,
Self::C => Self::F,
Self::F => Self::C,
}
}

/// Get opposite memory order
#[inline]
fn opposite(self) -> Self
/// Axis of leading stride (opposite of contiguous axis)
fn get_blas_lead_axis(self) -> usize
{
match self {
MemoryOrder::C => MemoryOrder::F,
MemoryOrder::F => MemoryOrder::C,
Self::C => 0,
Self::F => 1,
}
}

fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE
fn to_cblas_layout(self) -> CBLAS_LAYOUT
{
match self {
MemoryOrder::C => CblasNoTrans,
MemoryOrder::F => CblasTrans,
Self::C => CBLAS_LAYOUT::CblasRowMajor,
Self::F => CBLAS_LAYOUT::CblasColMajor,
}
}

fn to_cblas_layout(self) -> CBLAS_LAYOUT
/// When using cblas_sgemm (etc) with C matrix using `for_layout`,
/// how should this `self` matrix be transposed
fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
{
match self {
MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor,
MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor,
let effective_order = match for_layout {
CBLAS_LAYOUT::CblasRowMajor => self,
CBLAS_LAYOUT::CblasColMajor => self.transpose(),
};

match effective_order {
Self::C => CblasNoTrans,
Self::F => CblasTrans,
}
}
}

#[cfg(feature = "blas")]
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
{
let (m, n) = dim.into_pattern();
let s0 = stride[0] as isize;
let s1 = stride[1] as isize;
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
MemoryOrder::C => (s1, s0, m, n),
MemoryOrder::F => (s0, s1, n, m),
BlasOrder::C => (s1, s0, m, n),
BlasOrder::F => (s0, s1, n, m),
};

if !(inner_stride == 1 || outer_dim == 1) {
Expand Down Expand Up @@ -920,13 +910,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool

/// Get BLAS compatible layout if any (C or F, preferring the former)
#[cfg(feature = "blas")]
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<MemoryOrder>
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
where S: Data
{
if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) {
Some(MemoryOrder::C)
} else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) {
Some(MemoryOrder::F)
if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) {
Some(BlasOrder::C)
} else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) {
Some(BlasOrder::F)
} else {
None
}
Expand All @@ -937,10 +927,10 @@ where S: Data
///
/// Return leading stride (lda, ldb, ldc) of array
#[cfg(feature = "blas")]
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, axis: usize) -> blas_index
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
where S: Data
{
debug_assert!(axis <= 1);
let axis = order.get_blas_lead_axis();
let other_axis = 1 - axis;
let len_this = a.shape()[axis];
let len_other = a.shape()[other_axis];
Expand Down Expand Up @@ -968,7 +958,7 @@ where
if !same_type::<A, S::Elem>() {
return false;
}
is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
}

#[cfg(test)]
Expand All @@ -982,7 +972,7 @@ where
if !same_type::<A, S::Elem>() {
return false;
}
is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
}

#[cfg(test)]
Expand Down Expand Up @@ -1096,7 +1086,7 @@ mod blas_tests
if stride < N {
assert_eq!(get_blas_compatible_layout(&m), None);
} else {
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
}
}
}
Expand Down

0 comments on commit 453eae3

Please sign in to comment.