[QST] Question about tiledcopy with swizzle layout #1947

Open · ssiu opened this issue Nov 16, 2024 · 1 comment

ssiu commented Nov 16, 2024

Hi, I am trying to learn about tiledcopy involving swizzle layouts and am currently running into some confusion. My code is here https://github.com/ssiu/cuda/blob/master/cutlass/tiled_copy_swizzle.cu

Basically, I am trying to copy an 8 x 8 tensor (g_in, initialized with the values 0-63) to another 8 x 8 tensor (g_out), using a single warp to copy the 64 elements. The tiledcopy is

    TiledCopy tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, T>{},
                                           Layout<Shape<_4,_8>, Stride<_1,_4>>{},
                                           Layout<Shape<_2,_1>>{});

which looks like this:

[Image: diagram of the TiledCopy thread/value layout]
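For reference, here is a rough host-side sketch of which thread copies which element of the 8 x 8 tile, as I read the layouts: the thread layout (4,8):(1,4) places thread t at coordinate (t % 4, t / 4), and the 2x1 value layout hands each thread a contiguous pair of rows. The helper owning_thread is hypothetical, not CuTe API:

    #include <cstdio>

    // Hypothetical model: which warp lane copies element (i, j) of the 8x8
    // tile, given the thread layout (4,8):(1,4) and the 2x1 value layout.
    int owning_thread(int i, int j) {
        return (i / 2) * 1 + j * 4;  // thr coord (i/2, j) -> id via stride (1,4)
    }

    int main() {
        for (int i = 0; i < 8; i++) {
            for (int j = 0; j < 8; j++)
                printf("T%-2d ", owning_thread(i, j));
            printf("\n");
        }
    }

This is consistent with the tg_in/ts_out slices printed in the reply below, where thread 0 owns elements (0,0) and (1,0).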

If we define the layout of g_in to be row major and the layout of g_out to be column major:

    auto in_layout = make_layout(make_shape(Int<8>{}, Int<8>{}),
                                 make_stride(Int<8>{}, Int<1>{}));

    auto out_layout = make_layout(make_shape(Int<8>{}, Int<8>{}),
                                  make_stride(Int<1>{}, Int<8>{}));

then the tiledcopy is just a transpose operation, which is expected:

[Image: g_out is the transpose of g_in]
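In plain offset arithmetic (a minimal sketch, not CuTe code): the copy writes g_out(i,j) = g_in(i,j), and because the two layouts map the same logical coordinate to different linear offsets, the memory image of g_out ends up being the transpose of g_in:

    #include <cstdio>

    int main() {
        int in[64], out[64];
        for (int k = 0; k < 64; k++) in[k] = k;  // g_in: 0..63 in memory order

        // Logical copy out(i, j) = in(i, j):
        //   in  is (8,8):(8,1) -> linear offset 8*i + j
        //   out is (8,8):(1,8) -> linear offset i + 8*j
        for (int i = 0; i < 8; i++)
            for (int j = 0; j < 8; j++)
                out[i + 8 * j] = in[8 * i + j];

        // Printed in raw memory order, out is the transpose of in.
        for (int k = 0; k < 64; k++)
            printf("%2d%c", out[k], (k % 8 == 7) ? '\n' : ' ');
    }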

However, if we define the layout of g_out to be a swizzle layout

    auto out_layout = composition(Swizzle<1,1,1>{},
                                  make_layout(make_shape(Int<8>{}, Int<8>{}),
                                              make_stride(Int<1>{}, Int<8>{})));

which looks like this:

[Image: the Swizzle<1,1,1> layout]
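For concreteness, Swizzle<1,1,1> (1 mask bit, 1 base bit, shift 1) XORs bit 2 of the linear offset into bit 1. Below is a minimal host-side model of the resulting offset table, written from my understanding of CuTe's Swizzle definition (swizzle_1_1_1 is a hypothetical helper, not CuTe API):

    #include <cstdio>

    // Model of CuTe's Swizzle<1,1,1>: XOR bit 2 of the offset into bit 1
    // (bit 0, the "base" bit, is untouched).
    int swizzle_1_1_1(int offset) {
        return offset ^ ((offset & 4) >> 1);
    }

    int main() {
        // Offsets of Sw<1,1,1> o (8,8):(1,8), indexed by (row, column).
        for (int i = 0; i < 8; i++) {
            for (int j = 0; j < 8; j++)
                printf("%2d ", swizzle_1_1_1(i + 8 * j));
            printf("\n");
        }
    }

Within every column, rows 4-7 land at offsets 6, 7, 4, 5 (plus the column base) instead of 4, 5, 6, 7.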

then the copy produces:

[Image: the actual contents of g_out after the copy]

I was expecting g_out to look exactly like the swizzle layout shown above. Am I doing something wrong?

Thanks!


ccecka commented Nov 16, 2024

Thank you for the excellent reproducer and explanation!

Are you concerned about the extra 0s in the output or that the first two columns look like they are column-major?

When I execute your program with the Swizzled Layout, I do not get those extra 0s. I've included my full build command:

$ make scratch_bug
/usr/local/cuda/bin/nvcc -ccbin=g++ -O3 -std=c++17 -ftemplate-backtrace-limit=0 -arch sm_70 -uumn --expt-extended-lambda --expt-relaxed-constexpr --use_fast_math -Xptxas -v --compiler-options "-O3 -std=c++17 -ftemplate-backtrace-limit=0 -Wall -Wno-unused-local-typedefs -Wno-strict-aliasing -Wno-unused-function -Wno-format-security -Wno-unknown-pragmas -Wno-psabi" -I. -I/usr/local/cuda/include -I/home/ccecka/Desktop/mnt/ccecka_nvresearch/kernel_store/cutlass/include -I/home/ccecka/Desktop/mnt/ccecka_nvresearch/kernel_store/cutlass/tools/util/include -I/home/ccecka/Desktop/mnt/ccecka_nvresearch/kernel_store/cutlass/test -I/home/ccecka/Desktop/mnt/ccecka_nvresearch/kernel_store/cutlass/examples -o scratch_bug scratch_bug.cu -L/usr/local/cuda/lib64 
ptxas info    : 113 bytes gmem, 312 bytes cmem[4]
ptxas info    : Compiling entry function '_ZN3cub17CUB_200302_700_NS11EmptyKernelIvEEvv' for 'sm_70'
ptxas info    : Function properties for _ZN3cub17CUB_200302_700_NS11EmptyKernelIvEEvv
    0 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 4 registers, 352 bytes cmem[0]
ptxas info    : Compiling entry function '_Z9mm_kernelIfN4cute6LayoutINS0_5tupleIJNS0_1CILi8EEES4_EEENS2_IJS4_NS3_ILi1EEEEEEEENS0_14ComposedLayoutINS0_7SwizzleILi1ELi1ELi1EEENS3_ILi0EEENS1_IS5_NS2_IJS6_S4_EEEEEEENS0_9TiledCopyINS0_9Copy_AtomIJNS0_39AutoVectorizingCopyWithAssumedAlignmentILi8EEEfEEENS1_INS2_IJNS3_ILi32EEENS3_ILi2EEEEEENS2_IJSM_S6_EEEEES5_EEEvPT_T0_SS_T1_T2_' for 'sm_70'
ptxas info    : Function properties for _Z9mm_kernelIfN4cute6LayoutINS0_5tupleIJNS0_1CILi8EEES4_EEENS2_IJS4_NS3_ILi1EEEEEEEENS0_14ComposedLayoutINS0_7SwizzleILi1ELi1ELi1EEENS3_ILi0EEENS1_IS5_NS2_IJS6_S4_EEEEEEENS0_9TiledCopyINS0_9Copy_AtomIJNS0_39AutoVectorizingCopyWithAssumedAlignmentILi8EEEfEEENS1_INS2_IJNS3_ILi32EEENS3_ILi2EEEEEENS2_IJSM_S6_EEEEES5_EEEvPT_T0_SS_T1_T2_
    24 bytes stack frame, 0 bytes spill stores, 0 bytes spill loads
ptxas info    : Used 24 registers, 24 bytes cumulative stack size, 378 bytes cmem[0]
$ ./scratch_bug
Using device 0: NVIDIA TITAN V  (SM70, 80 SMs)
Sw<1,1,1> o _0 o (_8,_8):(_1,_8)
       0    1    2    3    4    5    6    7 
    +----+----+----+----+----+----+----+----+
 0  |  0 |  8 | 16 | 24 | 32 | 40 | 48 | 56 |
    +----+----+----+----+----+----+----+----+
 1  |  1 |  9 | 17 | 25 | 33 | 41 | 49 | 57 |
    +----+----+----+----+----+----+----+----+
 2  |  2 | 10 | 18 | 26 | 34 | 42 | 50 | 58 |
    +----+----+----+----+----+----+----+----+
 3  |  3 | 11 | 19 | 27 | 35 | 43 | 51 | 59 |
    +----+----+----+----+----+----+----+----+
 4  |  6 | 14 | 22 | 30 | 38 | 46 | 54 | 62 |
    +----+----+----+----+----+----+----+----+
 5  |  7 | 15 | 23 | 31 | 39 | 47 | 55 | 63 |
    +----+----+----+----+----+----+----+----+
 6  |  4 | 12 | 20 | 28 | 36 | 44 | 52 | 60 |
    +----+----+----+----+----+----+----+----+
 7  |  5 | 13 | 21 | 29 | 37 | 45 | 53 | 61 |
    +----+----+----+----+----+----+----+----+
  g_in : gmem_ptr[32b](0x7d1149a00000) o (_8,_8):(_8,_1)
 g_out : gmem_ptr[32b](0x7d1149a00200) o Sw<1,1,1> o _0 o (_8,_8):(_1,_8)
 tg_in : gmem_ptr[32b](0x7d1149a00000) o ((_1,_2),_1,_1):((_0,_8),_0,_0)
ts_out : gmem_ptr[32b](0x7d1149a00200) o ((_1,_2),_1,_1):((_0,_1),_0,_0)
g_in  : 
 0  1  2  3  4  5  6  7 
 8  9 10 11 12 13 14 15 
16 17 18 19 20 21 22 23 
24 25 26 27 28 29 30 31 
32 33 34 35 36 37 38 39 
40 41 42 43 44 45 46 47 
48 49 50 51 52 53 54 55 
56 57 58 59 60 61 62 63 
==========
g_out : 
 0  8 16 24 48 56 32 40 
 1  9 17 25 49 57 33 41 
 2 10 18 26 50 58 34 42 
 3 11 19 27 51 59 35 43 
 4 12 20 28 52 60 36 44 
 5 13 21 29 53 61 37 45 
 6 14 22 30 54 62 38 46 
 7 15 23 31 55 63 39 47

As you can see, you do get what initially looks like column-major output, but the columns appear to be permuted. That's the effect of the Swizzle. It looks like this because your printing function prints the raw array in memory order, with no consideration of the layout:

    printf("g_out : \n");
    for (int i = 0; i < 8; i++){
        for (int j=0;j<8;j++){
            printf("%2.0f ", h_out[i*8+j]);
        }
        printf("\n");
    }
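To see exactly how that memory image arises, here is a minimal host-side model (assuming CuTe's Swizzle<1,1,1> semantics, i.e. XOR bit 2 of the linear offset into bit 1) that performs the same logical copy through the swizzled layout and then prints raw memory; it reproduces the g_out printout above:

    #include <cstdio>

    // Hypothetical model of CuTe's Swizzle<1,1,1>.
    int swizzle(int offset) { return offset ^ ((offset & 4) >> 1); }

    int main() {
        int in[64], out[64];
        for (int k = 0; k < 64; k++) in[k] = k;  // g_in: 0..63 in memory order

        // Logical copy out(i, j) = in(i, j):
        //   in  is (8,8):(8,1)             -> offset 8*i + j
        //   out is Sw<1,1,1> o (8,8):(1,8) -> offset swizzle(i + 8*j)
        for (int i = 0; i < 8; i++)
            for (int j = 0; j < 8; j++)
                out[swizzle(i + 8 * j)] = in[8 * i + j];

        // Raw memory order: a transpose whose printed columns 4..7 are
        // permuted by the swizzle -- the same pattern as g_out above.
        for (int k = 0; k < 64; k++)
            printf("%2d%c", out[k], (k % 8 == 7) ? '\n' : ' ');
    }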

If you replace those printing loops with cute::print_tensor(Tensor), you'll see the logical view of your output rather than the physical view:

    print_tensor(make_tensor(h_in.data(), in_layout));
    print_tensor(make_tensor(h_out.data(), out_layout));
==================
ptr[32b](0x58ae29c31d50) o (_8,_8):(_8,_1):
  0.00e+00  1.00e+00  2.00e+00  3.00e+00  4.00e+00  5.00e+00  6.00e+00  7.00e+00
  8.00e+00  9.00e+00  1.00e+01  1.10e+01  1.20e+01  1.30e+01  1.40e+01  1.50e+01
  1.60e+01  1.70e+01  1.80e+01  1.90e+01  2.00e+01  2.10e+01  2.20e+01  2.30e+01
  2.40e+01  2.50e+01  2.60e+01  2.70e+01  2.80e+01  2.90e+01  3.00e+01  3.10e+01
  3.20e+01  3.30e+01  3.40e+01  3.50e+01  3.60e+01  3.70e+01  3.80e+01  3.90e+01
  4.00e+01  4.10e+01  4.20e+01  4.30e+01  4.40e+01  4.50e+01  4.60e+01  4.70e+01
  4.80e+01  4.90e+01  5.00e+01  5.10e+01  5.20e+01  5.30e+01  5.40e+01  5.50e+01
  5.60e+01  5.70e+01  5.80e+01  5.90e+01  6.00e+01  6.10e+01  6.20e+01  6.30e+01
ptr[32b](0x58ae29c32730) o Sw<1,1,1> o _0 o (_8,_8):(_1,_8):
  0.00e+00  1.00e+00  2.00e+00  3.00e+00  4.00e+00  5.00e+00  6.00e+00  7.00e+00
  8.00e+00  9.00e+00  1.00e+01  1.10e+01  1.20e+01  1.30e+01  1.40e+01  1.50e+01
  1.60e+01  1.70e+01  1.80e+01  1.90e+01  2.00e+01  2.10e+01  2.20e+01  2.30e+01
  2.40e+01  2.50e+01  2.60e+01  2.70e+01  2.80e+01  2.90e+01  3.00e+01  3.10e+01
  3.20e+01  3.30e+01  3.40e+01  3.50e+01  3.60e+01  3.70e+01  3.80e+01  3.90e+01
  4.00e+01  4.10e+01  4.20e+01  4.30e+01  4.40e+01  4.50e+01  4.60e+01  4.70e+01
  4.80e+01  4.90e+01  5.00e+01  5.10e+01  5.20e+01  5.30e+01  5.40e+01  5.50e+01
  5.60e+01  5.70e+01  5.80e+01  5.90e+01  6.00e+01  6.10e+01  6.20e+01  6.30e+01
