How to correctly use cute::copy to transfer 4bit data ? #1867
-
I found that in cute, the underlying storage of #include <cute/tensor.hpp>
using namespace cute;
template <typename Config>
__global__ void test_int4b_kernel(int4b_t *a_ptr, int m, int n) {
using SmemLayoutA = typename Config::SmemLayoutA;
using G2SCopyA = typename Config::G2SCopyA;
constexpr int kTileM = Config::kTileM;
constexpr int kTileN = Config::kTileN;
extern __shared__ int8_t shm_data[];
int4b_t *Ashm = (int4b_t *)shm_data;
auto A = make_tensor(make_gmem_ptr(a_ptr), make_shape(m, n),
make_stride(n, Int<1>{}));
int idx = threadIdx.x;
int ix = blockIdx.x;
auto gA =
local_tile(A, make_tile(Int<kTileM>{}, Int<kTileN>{}), make_coord(ix, _));
auto sA = make_tensor(make_smem_ptr(Ashm), SmemLayoutA{});
G2SCopyA g2s_tiled_copy_a;
auto g2s_thr_copy_a = g2s_tiled_copy_a.get_slice(idx);
auto tAgA_copy = g2s_thr_copy_a.partition_S(gA); // (CPY, CPY_M, CPY_K, k)
auto tAsA_copy = g2s_thr_copy_a.partition_D(sA); // (CPY, CPY_M, CPY_K)
clear(tAsA_copy);
cute::copy(g2s_tiled_copy_a, tAgA_copy(_, _, _, 0), tAsA_copy);
cp_async_fence();
cp_async_wait<0>();
__syncthreads();
if (thread0()) {
print("gA: ");
print_tensor(gA);
print("\nsA: ");
print_tensor(sA);
}
}
template <int kTileM_ = 128, int kTileN_ = 128>
struct TestConfig {
// tile configuration
static constexpr int kThreadNum = 32;
static constexpr int kTileM = kTileM_;
static constexpr int kTileN = kTileN_;
using SmemLayoutAtom =
decltype(make_layout(make_shape(Int<16>{}, Int<kTileN>{}),
make_stride(Int<kTileN>{}, Int<1>{})));
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtom{}, make_shape(Int<kTileM>{}, Int<kTileN>{})));
using g2s_copy_op = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
using g2s_copy_traits = Copy_Traits<g2s_copy_op>;
using g2s_copy_atom = Copy_Atom<g2s_copy_traits, int4b_t>;
using G2SCopyA =
decltype(make_tiled_copy(g2s_copy_atom{},
make_layout(make_shape(Int<16>{}, Int<2>{}),
make_stride(Int<2>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<32>{}))));
static constexpr int kShmSize = cute::cosize(SmemLayoutA{}) / 2;
};
void test_int4b() {
using namespace cute;
constexpr int M = 128;
constexpr int N = 64;
TestConfig<M, N> test_config;
std::vector<int8_t> v(M * M / 2);
for (int i = 0; i < M * M / 2; ++i) {
int t = i % 8;
v[i] = (t & 0x0f) | ((t & 0x0f) << 4);
}
int4b_t *v_d;
cudaMalloc(&v_d, M * M / 2);
cudaMemcpy(v_d, v.data(), M * M / 2, cudaMemcpyHostToDevice);
dim3 block = test_config.kThreadNum;
dim3 grid((M + test_config.kTileM - 1) / test_config.kTileM);
int shm_size = test_config.kShmSize;
auto partition_kernel = test_int4b_kernel<decltype(test_config)>;
cudaFuncSetAttribute(partition_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
partition_kernel<<<grid, block, shm_size>>>(v_d, M, M);
cudaFree(v_d);
} However, after I modified
Why are the print results of sA and gA different? |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment 1 reply
-
This is a C/C++ thing, not really a CuTe thing. In C/C++, You can create a packed CuTe tensor by specifying the logical data type you're working with: Tensor mA = make_tensor(make_gmem_ptr<uint4b_t>(my_ptr), my_layout_of_4b); which creates a "packed" pointer from Tensor sA = make_tensor(make_smem_ptr<uint4b_t>(my_s_ptr), my_layout_of_4b);
Tensor rA = make_tensor<uint4b_t>(my_layout_of_4b); |
Beta Was this translation helpful? Give feedback.
This is a C/C++ thing, not really a CuTe thing. In C/C++,
int4b_t*
is a pointer toint4b_t
s with underlying storage ofint8_t
. CuTe cannot assume that this pointer means "packed" (2xint4b_t
within eachint8_t
) safely. This is the reason whyarray_subbyte.data()
has been removed from CuTe'sarray_subbyte
container (But apparently not CUTLASS's) -- it is dangerous and error-prone to use these naked pointers that don't mean what you think they mean.You can create a packed CuTe tensor by specifying the logical data type you're working with:
which creates a "packed" pointer from
my_ptr
. Similarly with rmem or smem: