Commit

add max triton v3 kernel tests.
Maximilian Beck committed Sep 26, 2024
1 parent d38ced8 commit 6e09dbf
Showing 3 changed files with 85 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
# args: [ --autofix, --no-sort-keys ]
# - id: name-tests-test # TODO this does not allow for other files in tests/ folder
# args: [ --pytest-test-first ]
- id: no-commit-to-branch
# - id: no-commit-to-branch
- id: trailing-whitespace
- repo: https://github.com/asottile/setup-cfg-fmt
rev: v2.5.0
26 changes: 23 additions & 3 deletions README.md
@@ -26,7 +26,7 @@ def mlstm_interface(
torch.Tensor: matH outputs (no n and m values, no last states)
tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: matH, (matC_last, vecN_last, scaM_last)
"""

pass

```
@@ -38,14 +38,34 @@ The mLSTM repo contains the following kernel variants:
- `parallel`: parallel kernels like flash attention (quadratic)
- `recurrent`: recurrent kernels (mostly for inference) (linear)

Not all variants support all features of the interface. Only the chunkwise and recurrent kernels support passing initial states and returning the last states.
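
The full interface signature is not shown in this diff, but the docstring above fixes the two possible return forms. The sketch below is a minimal, hypothetical illustration of how a caller handles both forms; the stand-in function, its argument names, and the state shapes are assumptions for this example, not code from the repo.

```python
import torch


def dummy_mlstm_kernel(q, k, v, i, f, return_last_states=False):
    # Stand-in for a chunkwise/recurrent kernel; it only mimics the return
    # structure from the docstring above (matH, optionally the last states).
    # Argument names and state shapes are assumptions for this sketch.
    B, NH, S, DHQK = q.shape
    DHHV = v.shape[-1]
    matH = torch.zeros(B, NH, S, DHHV)
    if not return_last_states:
        return matH
    matC_last = torch.zeros(B, NH, DHQK, DHHV)  # assumed memory-cell shape
    vecN_last = torch.zeros(B, NH, DHQK)        # assumed normalizer shape
    scaM_last = torch.zeros(B, NH)              # assumed max-state shape
    return matH, (matC_last, vecN_last, scaM_last)


q = k = torch.randn(1, 1, 32, 16)
v = torch.randn(1, 1, 32, 16)
i = f = torch.randn(1, 1, 32)  # gate pre-activations (unused in this stub)

# Without last states: all variants return matH only.
matH = dummy_mlstm_kernel(q, k, v, i, f)

# With last states: only chunkwise and recurrent kernels support this.
matH, (matC_last, vecN_last, scaM_last) = dummy_mlstm_kernel(
    q, k, v, i, f, return_last_states=True
)
```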

### Kernel naming

#### External names of kernel functions in chunkwise, parallel and recurrent modules:
- Python function: `mlstm_[recurrent|parallel|chunkwise]_[specifier]_[triton|torch]_[[autograd|ownbw]]`
- Registry name (within module): `[specifier]_[triton|torch]_[[autograd|ownbw]]`
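
As a worked example of this convention (an illustrative sketch only; the helper functions below are hypothetical and not part of the repo):

```python
def external_name(mode: str, specifier: str, backend: str, backward: str = "") -> str:
    # External Python function name, e.g.
    # external_name("chunkwise", "max", "triton") -> "mlstm_chunkwise_max_triton"
    parts = ["mlstm", mode, specifier, backend] + ([backward] if backward else [])
    return "_".join(parts)


def registry_name(specifier: str, backend: str, backward: str = "") -> str:
    # Registry name within a module, e.g.
    # registry_name("stable", "torch", "autograd") -> "stable_torch_autograd"
    parts = [specifier, backend] + ([backward] if backward else [])
    return "_".join(parts)
```

For instance, `mlstm_chunkwise_max_triton`, imported in the tests below, follows this pattern with `chunkwise` as the mode, `max` as the specifier, and `triton` as the backend.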


## Running the unit tests

The unit tests cross-check the different kernel implementations for numerical deviations across different dtypes.
You can run all of them with the following command:
```bash
pytest -s tests/test_mlstm/
```

The `-s` flag disables log capturing, so the results are shown directly on the command line.
Each test logs its outputs to a new folder in the `test_outputs/` directory, named with the current timestamp.

Example:
Each test starts with a line like
`Test chunkwise-triton target=max_triton_v3 vs. baseline=parallel_stable_ag with S=4096, B=1, NH=1, DHQK=16, DHHV=16, DTYPE=torch.float32`.

This run compares the chunkwise Triton kernel `max_triton_v3` against the `parallel_stable_ag` baseline, with `max_triton_v3` running in float32. The errors are measured against the baseline in the same dtype (here float32) and additionally against a float64 baseline.
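
For orientation, the parameters in that line map to tensor dimensions roughly as in the sketch below; the (B, NH, S, head dim) layout is an assumption here, since the kernel signatures are not shown in this diff.

```python
import torch

# Dimensions from the example log line above.
S, B, NH, DHQK, DHHV = 4096, 1, 1, 16, 16
dtype = torch.float32

# Assumed input layout: batch, heads, sequence, head dimension.
q = torch.randn(B, NH, S, DHQK, dtype=dtype)  # queries
k = torch.randn(B, NH, S, DHQK, dtype=dtype)  # keys
v = torch.randn(B, NH, S, DHHV, dtype=dtype)  # values

# The target kernel runs in `dtype`; its output is compared against the baseline
# in the same dtype and, with add_fp64_baseline=True, against a float64 baseline.
```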



---
---
---
@@ -64,5 +84,5 @@ Not all variants support all features of the interface. Only the chunkwise and r

## Questions about Nsight Systems & Nsight Compute

- How can I organize the workflow efficiently in a project?
- How can I compare to baselines efficiently?
61 changes: 61 additions & 0 deletions tests/test_mlstm/test_chunkwise/test_chunkwise_triton_combined.py
@@ -15,6 +15,7 @@
)

from mlstm_kernels.mlstm.chunkwise import mlstm_chunkwise_max_triton
from mlstm_kernels.mlstm.chunkwise import mlstm_chunkwise_max_triton_v3

from ..test_params import final_combinations

@@ -54,6 +55,36 @@ def test_chunkwise_triton_vs_stable_torch(
            add_fp64_baseline=True,
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
    @pytest.mark.xfail(reason="Fails due to numerical instability")
    @pytest.mark.parametrize(
        ["S", "B", "NH", "DHQK", "DHHV", "target_dtype"], final_combinations
    )
    def test_chunkwise_triton_max_v3_vs_stable_torch(
        self, test_session_folder, S, B, NH, DHQK, DHHV, target_dtype
    ):
        print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}")
        template_test_parallel_interface(
            baseline_fn=mlstm_parallel_stable_torch_autograd,
            target_fn=mlstm_chunkwise_max_triton_v3,
            baseline_name="parallel_stable_ag",
            target_name="max_triton_v3",
            S=S,
            B=B,
            NH=NH,
            DHQK=DHQK,
            DHHV=DHHV,
            dtype=getattr(torch, target_dtype),
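            # loose tolerances for the forward (fw) and forward+backward (fwbw)
            # error checks against the baseline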
            atol_fw=1.0,  # 3.0
            rtol_fw=1.0,
            atol_fwbw=1.5,  # 3.5
            rtol_fwbw=1.0,
            vmax=1.0,
            test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX,
            save_dir=str(test_session_folder),
            add_fp64_baseline=True,
        )


class TestChunkwiseTritonAgVsUnstableTorchLong:
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
@@ -85,3 +116,33 @@ def test_chunkwise_triton_vs_unstable_torch(
            save_dir=str(test_session_folder),
            add_fp64_baseline=True,
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No GPU available.")
    @pytest.mark.xfail(reason="Fails due to numerical instability")
    @pytest.mark.parametrize(
        ["S", "B", "NH", "DHQK", "DHHV", "target_dtype"], final_combinations
    )
    def test_chunkwise_triton_max_v3_vs_unstable_torch(
        self, test_session_folder, S, B, NH, DHQK, DHHV, target_dtype
    ):
        print(f"S{S}B{B}NH{NH}DHQK{DHQK}DHHV{DHHV}")
        template_test_parallel_interface(
            baseline_fn=mlstm_parallel_torch_autograd,
            target_fn=mlstm_chunkwise_max_triton_v3,
            baseline_name="parallel_unstable_ag",
            target_name="max_triton_v3",
            S=S,
            B=B,
            NH=NH,
            DHQK=DHQK,
            DHHV=DHHV,
            dtype=getattr(torch, target_dtype),
            atol_fw=1.0,  # 3.0
            rtol_fw=1.0,
            atol_fwbw=1.5,  # 3.5
            rtol_fwbw=1.0,
            vmax=1.0,
            test_folder_name_prefix=TEST_FOLDER_NAME_PREFIX,
            save_dir=str(test_session_folder),
            add_fp64_baseline=True,
        )
