CuTe Layouts

Layout

This document describes Layout, CuTe's core abstraction. A Layout maps from one or more logical coordinate spaces to a physical index space.

Layouts present a common interface to multidimensional array access that abstracts away the details of how the array's elements are organized in memory. This lets users write algorithms that access multidimensional arrays generically, so that the layout can change without the users' code needing to change.

CuTe also provides an "algebra of Layouts." Layouts can be combined and manipulated to construct more complicated layouts and to partition them across other layouts. This helps with tasks such as partitioning a layout of data over a layout of threads.

Layouts and Tensors

Any of the Layouts discussed in this section can be composed with data (a pointer or an array) to create a Tensor. The Layout is responsible for defining the valid coordinate space(s), and therefore the logical shape of the data, and for mapping those coordinates into an index space. An index is precisely the offset used to access the array of data.

For details on Tensor, please refer to the Tensor section of the tutorial.

Shapes and Strides

A Layout is a pair of Shape and Stride. Both Shape and Stride are IntTuple types.

IntTuple

An IntTuple is an integer or a tuple of IntTuples. This means that IntTuples can be arbitrarily nested. Operations defined on IntTuples include the following.

  • get<I>(IntTuple): The Ith element of the IntTuple. Note that get<0> is defined for integer IntTuples.

  • rank(IntTuple): The number of elements in an IntTuple. An int has rank 1, and a tuple has rank equal to its number of elements (its tuple_size).

  • depth(IntTuple): The number of hierarchical IntTuples. An int has depth 0, a tuple has depth 1, a tuple that contains a tuple has depth 2, etc.

  • size(IntTuple): The product of all elements of the IntTuple.

We write IntTuples with parentheses to denote the hierarchy. For example, 6, (2), (4,3), and (3,(6,2),8) are all IntTuples.

Layout

A Layout is then a pair of IntTuples. The first defines the abstract shape of the layout and the second defines the strides, which map from coordinates within the shape to the index space.

Because a Layout is a pair of IntTuples, we can define many similar operations on Layouts, including:

  • get<I>(Layout): The Ith sub-layout of the Layout.

  • rank(Layout): The number of modes in a Layout.

  • depth(Layout): The depth of the Layout's shape, defined as for IntTuples: an int has depth 0, a tuple has depth 1, a tuple that contains a tuple has depth 2, etc.

  • shape(Layout): The shape of the Layout.

  • stride(Layout): The stride of the Layout.

  • size(Layout): The logical extent of the Layout. Equivalent to size(shape(Layout)).

Hierarchical access functions

IntTuples, and thus Layouts, can be arbitrarily nested. For convenience, we define versions of some of the above functions that take a sequence of integers instead of a single integer. This makes it possible to access elements inside a nested IntTuple or Layout. For example, we permit get<I...>(x), where I... here and throughout this section is a "C++ parameter pack" that denotes zero or more (integer) template arguments. That is, get<I0,I1,...,IN>(x) is equivalent to get<IN>(...(get<I1>(get<I0>(x)))...), where the ellipses are pseudocode and not actual C++ syntax. These hierarchical access functions include the following.

  • rank<I...>(x) := rank(get<I...>(x)). The rank of the I...th element of x.

  • depth<I...>(x) := depth(get<I...>(x)). The depth of the I...th element of x.

  • size<I...>(x) := size(get<I...>(x)). The size of the I...th element of x.

Vector examples

We can now define a vector as any Shape and Stride pair with rank 1. For example, the Layout

Shape:  (8)
Stride: (1)

defines a contiguous 8-element vector. With a stride of (2) instead, the eight elements are stored at positions 0, 2, 4, $\dots$, 14.

By the above definition, we also interpret

Shape:  ((4,2))
Stride: ((1,4))

as a vector, since its shape is rank 1. The inner shape describes a 4x2 layout of data in column-major order, but the extra pair of parentheses suggests that we can interpret those two modes as a single 1-D 8-element vector instead. Given these strides, the elements are also contiguous.

Matrix examples

Generalizing, we define a matrix as any Shape and Stride pair with rank 2. For example,

Shape:  (4,2)
Stride: (1,4)
  0   4
  1   5
  2   6
  3   7

is a 4x2 column-major matrix, and

Shape:  (4,2)
Stride: (2,1)
  0   1
  2   3
  4   5
  6   7

is a 4x2 row-major matrix.

Each mode of a matrix can also be split into a multi-index, as in the vector example. This lets us express layouts beyond plain row-major and column-major. For example,

Shape:  ((2,2),2)
Stride: ((4,1),2)
  0   2
  4   6
  1   3
  5   7

is also logically 4x2, with a stride of 2 along each row but a multi-stride of (4,1) down each column. Since this layout is logically 4x2, like the column-major and row-major examples above, we can still use 2-D coordinates to index into it.

Constructing a Layout

A Layout can be constructed in many different ways. It can include any combination of compile-time (static) integers or run-time (dynamic) integers.

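auto layout_8s = make_layout(Int<8>{});   // static shape (8), stride defaults to (1)
auto layout_8d = make_layout(8);          // dynamic shape (8), stride defaults to (1)

auto layout_2sx4s = make_layout(make_shape(Int<2>{},Int<4>{}));  // static (2,4), column-major stride (1,2)
auto layout_2sx4d = make_layout(make_shape(Int<2>{},4));         // static 2, dynamic 4, column-major stride (1,2)

// shape (2,(2,2)) with stride (4,(2,1))
auto layout_2x4 = make_layout(make_shape (2, make_shape (2,2)),
                              make_stride(4, make_stride(2,1)));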

Using a Layout

The fundamental use of a Layout is to map between logical coordinate space(s) and index space. For example, to print an arbitrary rank-2 layout, we can write the function

template <class Shape, class Stride>
void print2D(Layout<Shape,Stride> const& layout)
{
  for (int m = 0; m < size<0>(layout); ++m) {
    for (int n = 0; n < size<1>(layout); ++n) {
      printf("%3d  ", layout(m,n));
    }
    printf("\n");
  }
}

which produces the following output for the above examples.

> print2D(layout_2sx4s)
  0   2   4   6
  1   3   5   7
> print2D(layout_2sx4d)
  0   2   4   6
  1   3   5   7
> print2D(layout_2x4)
  0   2   1   3
  4   6   5   7

The multi-index in the second mode of layout_2x4 is handled as expected, and the layout is still interpreted as a rank-2 layout.

Note that for layout_2x4, we're using a 1-D coordinate for a 2-D multi-index in the second mode. In fact, we can generalize this and treat all of the above layouts as 1-D layouts. For instance, the following print1D function

template <class Shape, class Stride>
void print1D(Layout<Shape,Stride> const& layout)
{
  // Iterate over the layout's 1-D coordinate space
  for (int i = 0; i < size(layout); ++i) {
    printf("%3d  ", layout(i));
  }
  printf("\n");
}

produces the following output for the above examples.

> print1D(layout_8s)
  0   1   2   3   4   5   6   7
> print1D(layout_8d)
  0   1   2   3   4   5   6   7
> print1D(layout_2sx4s)
  0   1   2   3   4   5   6   7
> print1D(layout_2sx4d)
  0   1   2   3   4   5   6   7
> print1D(layout_2x4)
  0   4   2   6   1   5   3   7

This shows explicitly that all of the layouts are simply folded views of an 8-element array.

Summary

  • The Shape of a Layout defines its coordinate space(s).

    • Every Layout has a 1-D coordinate space. This can be used to iterate in a "generalized-column-major" order.

    • Every Layout has an R-D coordinate space, where R is the rank of the layout. These spaces are ordered colexicographically (reading right to left, instead of "lexicographically," which reads left to right). The enumeration of that order corresponds to the 1-D coordinates above.

    • Every Layout has an h-D coordinate space where h is "hierarchical." These are ordered colexicographically and the enumeration of that order corresponds to the 1-D coordinates above. An h-D coordinate is congruent to the Shape so that each element of the coordinate has a corresponding element of the Shape.

  • The Stride of a Layout maps coordinates to indices.

    • In general, this could be any function from 1-D coordinates (integers) to indices (integers).

    • In CuTe we use an inner product of the h-D coordinates with the Stride elements.