Generic backward implementation with CUTLASS (facebookresearch#371)
* Old bw code

* P100: gradV working

* gk/gq working (at least for small values of M, and on P100/f16)

* Further restrict supported values for bw

* Fix storage into smem for Simt

* More tooling for print/debug

* Remove tests we don't need for now

* Tests pass on P100 :D

* 4 warps per block

* Restriction on q length

* Use tensorcores on V100 for f16

* Support dynamic smem for bw

* Handle alignment and different dtype/arch

* Fix NaNs by initializing shared memory

* bw.py

* Fix launch bounds

* Faster 'computeDi'

* minus_lse can operate on arrays

* Output number of regs used etc...

* Code cleanup

* Hackfix for alignment check during forward

* zFill to avoid NaNs in Sm80 + fix launch bounds

* More code cleanup

* clang-format

* Fix tests

* Add benchmark for K=64

Co-authored-by: danthe3rd <danthe3rd@users.noreply.github.com>
Co-authored-by: danthe3rd <danthe3rd>
danthe3rd and danthe3rd authored Aug 1, 2022
1 parent cd8e3d2 commit 1d4a0ce
Showing 12 changed files with 1,920 additions and 229 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -170,7 +170,7 @@ def get_extensions():
         nvcc_flags = nvcc_flags.split(" ")
         cuda_version = get_cuda_version(CUDA_HOME)
         if cuda_version >= 1102:
-            nvcc_flags += ["--threads", "4"]
+            nvcc_flags += ["--threads", "4", "--ptxas-options=-v"]
         extra_compile_args["nvcc"] = nvcc_flags
         if (
             cuda_version >= 1100
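For context, --ptxas-options=-v makes ptxas print per-kernel resource usage (registers, shared memory, spills) at build time, which is what the "Output number of regs used etc..." item in the commit message refers to; nvcc only accepts --threads from CUDA 11.2 onwards, hence the version gate. A minimal standalone sketch of the same logic, with placeholder values for the base flags and detected CUDA version (these placeholders are not the actual values computed in setup.py):

# Sketch only: base_flags and cuda_version are illustrative placeholders.
base_flags = ["-O3"]
cuda_version = 1160  # e.g. CUDA 11.6, in the encoding returned by get_cuda_version

nvcc_flags = list(base_flags)
if cuda_version >= 1102:
    # CUDA >= 11.2: compile with 4 parallel nvcc threads and make ptxas
    # report registers / shared memory / spills for every kernel.
    nvcc_flags += ["--threads", "4", "--ptxas-options=-v"]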
17 changes: 11 additions & 6 deletions tests/test_mem_eff_attention.py
@@ -17,14 +17,18 @@


 def assert_allclose(
-    out: torch.Tensor, ref: torch.Tensor, msg: str = "failed", **kwargs
+    out: torch.Tensor,
+    ref: torch.Tensor,
+    msg: str = "failed",
+    atol: float = 1e-8,
+    rtol: float = 1e-5,
 ) -> None:
-    flatten_diff = (out - ref).abs().flatten()
+    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
     max_pos = flatten_diff.argmax()
-    assert torch.allclose(out, ref, **kwargs), (
+    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
         f"{msg}: max_diff={flatten_diff.max()}: "
         f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} "
-        f"/ atol={kwargs.get('atol', 1e-8)}"
+        f"/ atol={atol}, rtol={rtol}"
     )

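For reference, the updated helper assembled from the hunk above into a standalone, runnable form (a sketch; only torch is required and the failure-message format is unchanged):

import torch


def assert_allclose(
    out: torch.Tensor,
    ref: torch.Tensor,
    msg: str = "failed",
    atol: float = 1e-8,
    rtol: float = 1e-5,
) -> None:
    # Signed violation of the combined tolerance |out - ref| <= atol + rtol * |ref|;
    # the most positive entry is the worst offender and is reported on failure.
    flatten_diff = ((out - ref).abs() - atol - ref.abs() * rtol).flatten()
    max_pos = flatten_diff.argmax()
    assert torch.allclose(out, ref, rtol=rtol, atol=atol), (
        f"{msg}: max_diff={flatten_diff.max()}: "
        f"out={out.flatten()[max_pos]} and ref={ref.flatten()[max_pos]} "
        f"/ atol={atol}, rtol={rtol}"
    )


# Example: passes, since the difference is within atol.
assert_allclose(torch.ones(3), torch.ones(3) + 1e-6, atol=1e-4)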
@@ -172,7 +176,7 @@ def test_logsumexp(
         (query.float() / k_len**0.5) @ key.float().transpose(-2, -1)
     ).logsumexp(-1)

-    assert_allclose(lse, ref_lse, atol=2e-4)
+    assert_allclose(lse[:, : ref_lse.shape[1]], ref_lse, atol=2e-4)


 @pytest.mark.parametrize(
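The reference here is the row-wise log-sum-exp of the scaled attention scores; slicing lse to ref_lse.shape[1] allows for the kernel storing logsumexp with padding past the true query length (an assumption based on the shape handling in this test). A small sketch of the reference computation, with illustrative shapes:

import torch

batch, q_len, kv_len, k_len = 2, 16, 16, 32  # illustrative shapes only
query = torch.randn(batch, q_len, k_len)
key = torch.randn(batch, kv_len, k_len)

# Row-wise logsumexp of the scaled scores Q @ K^T / sqrt(k_len),
# matching the test's reference; result has shape (batch, q_len).
ref_lse = (
    (query.float() / k_len**0.5) @ key.float().transpose(-2, -1)
).logsumexp(-1)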
@@ -204,6 +208,7 @@ def test_memory_efficient_attention_backward(
     dtype,
     op: xformers.ops.MemoryEfficientAttentionOp,
 ):
+    torch.manual_seed(batch_size * q_len + kv_len * k_len)
     scale = 3
     query = torch.randn((batch_size, q_len, k_len), device=device, dtype=dtype) * scale
     key = torch.randn((batch_size, kv_len, k_len), device=device, dtype=dtype) * scale
@@ -252,7 +257,7 @@ def test_memory_efficient_attention_backward(
     atol = 2e-4 + 2e-6 * k_len * kv_len * math.sqrt(batch_size) * math.sqrt(q_len)
     rtol = 1e-8
     if dtype is torch.half:
-        atol = 3e-2
+        atol = 4e-2
         rtol = 1e-2
     if dtype is torch.bfloat16:
         # I've seen (out=0.29 / ref=0.03)
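The backward test scales its absolute tolerance with the reduction sizes and loosens it for half precision (3e-2 becomes 4e-2 in this commit). A sketch of that selection logic as a helper; the function name is hypothetical, and the bfloat16 branch is truncated in the hunk above, so it is omitted:

import math

import torch


def backward_tolerances(dtype, batch_size, q_len, kv_len, k_len):
    # float32 baseline: allowed error grows with the sizes of the reductions involved.
    atol = 2e-4 + 2e-6 * k_len * kv_len * math.sqrt(batch_size) * math.sqrt(q_len)
    rtol = 1e-8
    if dtype is torch.half:
        atol = 4e-2  # loosened from 3e-2 by this commit
        rtol = 1e-2
    return atol, rtol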
4 changes: 3 additions & 1 deletion xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -52,7 +52,7 @@ def ref_attention(q, k, v, attn_bias=None, p=0.0):
 device = torch.device("cuda")

 NUM_THREADS = [1] if device.type == "cuda" else [1, 40]
-SHAPES = list(itertools.product([32, 256], [128, 512, 1024], [16, 32, 128]))
+SHAPES = list(itertools.product([32, 256], [128, 512, 1024], [16, 32, 64, 128]))
 SHAPES = list(set(SHAPES))
 SHAPES.sort()

@@ -92,6 +92,7 @@ def benchmark_forward(shape, num_threads: int, attn_bias_type, dtype):
         attn_bias_type=attn_bias_type,
         has_dropout=False,
         kv_len=M,
+        q_len=M,
     )
     try:
         op = dispatch.op if FORCE_OP is None else FORCE_OP
@@ -166,6 +167,7 @@ def benchmark_backward(shape, num_threads: int, attn_bias_type, dtype):
         attn_bias_type=attn_bias_type,
         has_dropout=False,
         kv_len=M,
+        q_len=M,
     )
     try:
         op = dispatch.op if FORCE_OP is None else FORCE_OP
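The SHAPES grid in the first hunk above is the Cartesian product of batch sizes, sequence lengths, and head dimensions, so adding 64 (the "Add benchmark for K=64" commit item) grows it from 18 to 24 (B, M, K) configurations, and the new q_len=M arguments make the benchmark cases use equal query and key/value lengths. A quick sketch of the resulting grid:

import itertools

# (batch size, sequence length M, head dim K) combinations after this change.
SHAPES = sorted(set(itertools.product([32, 256], [128, 512, 1024], [16, 32, 64, 128])))
assert len(SHAPES) == 24  # 2 * 3 * 4; previously 2 * 3 * 3 = 18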
2 changes: 2 additions & 0 deletions xformers/components/attention/csrc/attention.cpp
@@ -7,6 +7,8 @@ TORCH_LIBRARY_FRAGMENT(xformers, m) {
       "xformers::efficient_attention_forward_generic(Tensor query, Tensor key, Tensor value, bool compute_logsumexp, Tensor? attn_bias, float p) -> (Tensor, Tensor, int, int)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::efficient_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "xformers::efficient_attention_backward_generic(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor logsumexp, Tensor output, Tensor? attn_bias, float p, int rng_seed, int rng_offset) -> (Tensor, Tensor, Tensor)"));
   m.def(TORCH_SELECTIVE_SCHEMA(
       "xformers::_temp_dropout(Tensor out, float p) -> Tensor"));
 }
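Once the compiled extension that registers these schemas is loaded, the new operator is reachable through the PyTorch dispatcher. Below is a minimal sketch of such a call; the shapes, dtype, and the interpretation of the two integers returned by the forward op (taken here to be the RNG seed/offset consumed by the backward) are assumptions, and xformers' higher-level wrappers, not this raw call, are the intended public interface:

import torch
import xformers.ops  # assumed to load the extension registering the ops above

B, M, K = 2, 16, 32  # illustrative batch size / sequence length / head dim
query = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
key = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
value = torch.randn(B, M, K, device="cuda", dtype=torch.float16)

# Forward generic kernel: returns (Tensor, Tensor, int, int) per the schema above.
out, lse, rng_seed, rng_offset = torch.ops.xformers.efficient_attention_forward_generic(
    query, key, value, True, None, 0.0
)

grad_out = torch.randn_like(out)
# Operator added by this commit: gradients w.r.t. query, key and value.
grad_q, grad_k, grad_v = torch.ops.xformers.efficient_attention_backward_generic(
    grad_out, query, key, value, lse, out, None, 0.0, rng_seed, rng_offset
)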
