
[QST] SmemCopyAtom and MMA_Atom for fp32? #1842

Open

vickyandpiggy opened this issue Sep 26, 2024 · 7 comments

vickyandpiggy commented Sep 26, 2024

What is your question?
Hello, I am developing a full-precision attention backward kernel using CUTLASS, and I am stuck on how to use the ldmatrix and mma instructions for fp32.

My GEMM works on fp32 matrices, i.e. the datatypes of D/A/B/C are all fp32. But the structs provided in mma_sm80.hpp take half-precision or mixed-precision inputs, so I am confused about how to do things correctly in full precision. Here is my current setup for the MMA, smem, and gmem. Is there a way to use SM75_U32x4_LDSM_N and one of the mma instructions in my case?

  // MMA: SIMT FMA atom tiled over a 16x8x1 thread layout (128 threads)
  using TiledMma = TiledMMA<MMA_Atom<UniversalFMA<float, float, float>>, Layout<Shape<Int<16>, Int<8>, _1>>>;

  // Smem: swizzled 16x32 row-major fp32 tile; plain copy atom for smem -> register
  using SmemLayoutAtom = decltype(
    composition(Swizzle<3,3,3>{},
                Layout<Shape < _16,_32>,
                       Stride<_32, _1>>{}));
  using SmemCopyAtom = Copy_Atom<DefaultCopy, float>;

  // Gmem: 128-bit cp.async, 16x8 threads, each thread copying 4 contiguous floats
  using GmemTiledCopy = decltype(
    make_tiled_copy(Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<cute::uint128_t>, float>{},
                    Layout<Shape <_16,_8>,
                           Stride< _8,_1>>{},
                    Layout<Shape < _1,_4>>{}));

@vickyandpiggy (Author)

Could anyone help? Many thanks!

@vickyandpiggy (Author)

It seems to me that the mma instructions do not support fp32 for multiplicands A/B, per https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-data-types. So can I use ldmatrix alone to accelerate the copy from smem to registers? Or is there a better practice for full precision?

@cloudhan

This comment has been minimized.

vickyandpiggy (Author) commented Oct 10, 2024

The NVIDIA tensor core does not natively support A/B with fp32 inputs, so it is not possible.

Alternatives are:

  1. use the .tf32 variants with reduced precision (a sketch of this follows below);
  2. if the reduced precision is not acceptable, try https://arxiv.org/pdf/2203.03341;
  3. try AMD cards.

The ldmatrix instruction is also tightly coupled with the tensor core fragment layout; maybe you can use the .tf32 version if the layouts match. But it may not speed up your matrix loading by itself; you still need to take care of L2-friendly prefetching and of eliminating bank conflicts manually. It is all about how to correctly and performantly use CuTe.

Thank you so much for the suggestions! I got it :)
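
For concreteness, here is a minimal sketch of option 1 in CuTe, assuming the stock SM80 TF32 atom from cute/arch/mma_sm80.hpp. The 2x2x1 warp tiling and the LDSM copy atom are illustrative choices (not taken from this thread), and whether this LDSM variant fits depends on your shared-memory layout:

  // Option 1 sketch: C/D accumulate in fp32; A/B are consumed as tf32
  // (rounded from fp32), so only the A/B mantissas lose precision.
  // SM80_16x8x8_F32TF32TF32F32_TN is the m16n8k8 tf32 MMA in cute/arch/mma_sm80.hpp.
  using TiledMmaTF32 = TiledMMA<
      MMA_Atom<SM80_16x8x8_F32TF32TF32F32_TN>,
      Layout<Shape<_2,_2,_1>>>;              // 4 warps = 128 threads, as in the question

  // With a tensor-core atom, ldmatrix becomes usable for the operand loads;
  // validate this choice against the smem layout before relying on it.
  using SmemCopyAtomA = Copy_Atom<SM75_U32x4_LDSM_N, cute::tfloat32_t>;

Accumulation (C/D) stays in fp32 either way; only the A/B operands are rounded to tf32.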

thakkarV (Collaborator) commented Oct 10, 2024

The ldmatrix instruction is also tightly coupled with the tensor core fragment layout

This is not necessarily true. You can in principle copy to arbitrary layouts using LDSM provided the partitioning is valid.

The NVIDIA tensor core does not natively support A/B with fp32 inputs

This is also irrelevant. @vickyandpiggy is trying to use SIMT cores for the matmul itself. In this case, you can still totally use LDSM provided the smem layout is legal to partition with LDSM.

@thakkarV (Collaborator)

@vickyandpiggy Please do not be discouraged.

Is there a way to use SM75_U32x4_LDSM_N and one of the mma instructions in my case?

What have you tried? What does the kernel look like so far? By the way, for the SIMT cores the throughput is low enough that it should not matter whether you use ld.shared or ldmatrix; you should still be able to achieve peak throughput.
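
To make the ld.shared route concrete, here is a minimal sketch of the smem-to-register path with the DefaultCopy atom and the UniversalFMA TiledMma from the question. The helper and the tensor names (sA, sB, acc) are illustrative, not from the thread, and the double buffering / prefetching a real mainloop needs is omitted:

  #include <cute/tensor.hpp>
  using namespace cute;

  // Illustrative helper: copy A/B tiles from shared memory into register
  // fragments with plain shared-memory loads, then issue the SIMT FMAs.
  // sA is a (BLK_M, BLK_K) smem tensor, sB is (BLK_N, BLK_K), and acc is this
  // thread's partitioned accumulator (MMA, MMA_M, MMA_N).
  template <class SmemTensorA, class SmemTensorB, class TiledMma_, class AccTensor>
  __device__ void gemm_from_smem(SmemTensorA const& sA, SmemTensorB const& sB,
                                 TiledMma_ tiled_mma, AccTensor& acc, int tid)
  {
    auto thr_mma = tiled_mma.get_thread_slice(tid);

    // Register fragments shaped by the MMA partitioning
    Tensor tCrA = thr_mma.partition_fragment_A(sA);   // (MMA, MMA_M, MMA_K)
    Tensor tCrB = thr_mma.partition_fragment_B(sB);   // (MMA, MMA_N, MMA_K)

    // smem -> register copies built from the same DefaultCopy atom as the question
    auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<DefaultCopy, float>{}, tiled_mma);
    auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(tid);
    Tensor tCsA            = smem_thr_copy_A.partition_S(sA);
    Tensor tCrA_view       = smem_thr_copy_A.retile_D(tCrA);

    auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom<DefaultCopy, float>{}, tiled_mma);
    auto smem_thr_copy_B   = smem_tiled_copy_B.get_thread_slice(tid);
    Tensor tCsB            = smem_thr_copy_B.partition_S(sB);
    Tensor tCrB_view       = smem_thr_copy_B.retile_D(tCrB);

    CUTE_UNROLL
    for (int k = 0; k < size<2>(tCrA); ++k) {
      copy(smem_tiled_copy_A, tCsA(_,_,k), tCrA_view(_,_,k));
      copy(smem_tiled_copy_B, tCsB(_,_,k), tCrB_view(_,_,k));
      gemm(tiled_mma, tCrA(_,_,k), tCrB(_,_,k), acc);  // SIMT FFMAs
    }
  }

Note that with this UniversalFMA tiling each thread consumes only one A value per MMA step, while an LDSM atom produces four 32-bit registers per thread, so make_tiled_copy_A with SM75_U32x4_LDSM_N would not partition this particular TiledMma without enlarging the per-thread value tile; as noted above, plain ld.shared is enough to reach the SIMT peak anyway.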


github-actions bot commented Nov 9, 2024

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.
