Bump torch to <2.4.1 (#145)
* bump torch to <2.5 (#142)

* bump torch to <2.5 (#143)

* bump torch to <2.4.1 (#144)

* bump torch (#146)

* install from git, not pypi

* Update setup.py

Co-authored-by: Saaketh Narayan <narayan.saaketh@gmail.com>

* no type checking in `kernel.py` (#147)

---------

Co-authored-by: Saaketh Narayan <narayan.saaketh@gmail.com>
eitanturok and snarayan21 authored Aug 30, 2024
1 parent 5b2650a commit 35abddf
Showing 4 changed files with 52 additions and 110 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/pr-gpu.yaml
@@ -21,14 +21,14 @@ jobs:
fail-fast: false
matrix:
include:
- - name: "python3.11-pytorch2.3.1-gpus1"
+ - name: "python3.11-pytorch2.4.0-gpus1"
gpu_num: 1
python_version: 3.11
- container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
- - name: "python3.11-pytorch2.3.1-gpus2"
+ container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
+ - name: "python3.11-pytorch2.4.0-gpus2"
gpu_num: 2
python_version: 3.11
- container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
+ container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
steps:
- name: Run PR GPU tests
uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2
144 changes: 43 additions & 101 deletions megablocks/backend/kernels.py
@@ -1,27 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
- from typing import Any, Optional

import torch
import triton
import triton.language as tl


- def assert_is_tensor(x: torch.Tensor, ndim: int):
+ def assert_is_tensor(x, ndim):
if x.ndim != ndim:
raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


- def assert_is_matrix(x: torch.Tensor):
+ def assert_is_matrix(x):
assert_is_tensor(x, 2)


- def assert_is_vector(x: torch.Tensor):
+ def assert_is_vector(x):
if x.ndim != 1:
raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


- def assert_equal(a: Any, b: Any):
+ def assert_equal(a, b):
if a != b:
raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)

@@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any):
)
@triton.jit
def _padded_copy(
- a: torch.Tensor,
- b: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Any,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -105,15 +104,7 @@ def _padded_copy(
offsets += BLOCK_X


- def padded_gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ):
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -129,7 +120,7 @@ def padded_gather(

# NOTE: Because of the padding, the output size is dynamic.
# We load the final padded bin bound to get the output rows.
- output_rows = int(padded_bins[-1].cpu().item())
+ output_rows = padded_bins[-1].cpu().item()
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
_padded_copy[(indices.shape[0],)](
x,
@@ -147,14 +138,7 @@ def padded_gather(
return out


- def gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ):
+ def gather(x, indices, bin_ids, weights, bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -186,15 +170,7 @@ def gather(
return out


- def padded_scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ) -> torch.Tensor:
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -227,14 +203,7 @@ def padded_scatter(
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


- def scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ) -> torch.Tensor:
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


@@ -257,13 +226,13 @@ def scatter(
)
@triton.jit
def _padded_copy_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- wgrad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -307,15 +276,7 @@ def _padded_copy_wgrad(
tl.store(wgrad, out)


- def padded_scatter_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ):
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_matrix(grad)
@@ -342,14 +303,7 @@ def padded_scatter_wgrad(
return out


- def scatter_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- top_k: int,
- ):
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


@@ -370,13 +324,13 @@ def scatter_wgrad(
)
@triton.jit
def _binned_copy(
- a: torch.Tensor,
- b: torch.Tensor,
- num_experts: int,
- expert_capacity: int,
- indices: torch.Tensor,
- weights, #: Optional[torch.Tensor],
- bins: torch.Tensor,
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -435,14 +389,7 @@ def _binned_copy(
offsets += BLOCK_X


- def binned_gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- expert_capacity: int,
- top_k: int,
- ):
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -454,6 +401,7 @@ def binned_gather(

num_experts = bins.shape[0]
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

_binned_copy[(num_experts, expert_capacity)](
x,
out,
@@ -470,13 +418,7 @@ def binned_gather(
return out


- def binned_scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ):
+ def binned_scatter(x, indices, weights, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_vector(indices)
@@ -524,13 +466,13 @@ def binned_scatter(
)
@triton.jit
def _binned_copy_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- wgrad: torch.Tensor,
- num_experts: int,
- expert_capacity: int,
- indices: torch.Tensor,
- bins: torch.Tensor,
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -576,7 +518,7 @@ def _binned_copy_wgrad(
tl.store(wgrad, out)


- def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int):
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_matrix(grad)
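With the annotations gone, the expected shapes of the routing arguments are documented only by the assert_* checks at the top of kernels.py. As a rough sketch (not part of this commit, and using an assumed block size and a toy routing assignment), the bins / padded_bins metadata that padded_gather and padded_scatter consume can be built along these lines:

import torch

# Illustrative sketch only: toy routing metadata of the kind the padded kernels expect.
# The block size and the bin_ids assignment below are assumptions, not values from this commit.
num_experts, block_size = 4, 128
bin_ids = torch.tensor([0, 0, 1, 2, 2, 2, 3, 3])  # expert id per (token, expert) assignment, sorted by expert

tokens_per_expert = torch.bincount(bin_ids, minlength=num_experts)

# bins: inclusive cumulative sum of tokens per expert.
bins = torch.cumsum(tokens_per_expert, dim=0)

# padded_bins: the same cumulative sum after rounding each expert's count up to a
# multiple of the block size, so every expert's slice starts on a block boundary.
padded = ((tokens_per_expert + block_size - 1) // block_size) * block_size
padded_bins = torch.cumsum(padded, dim=0)

print(bins.tolist(), padded_bins.tolist())  # [2, 3, 6, 8] [128, 256, 384, 512]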
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -3,7 +3,7 @@

# build requirements
[build-system]
- requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"]
+ requires = ["setuptools < 70.0.0", "torch >= 2.4.0, < 2.4.1"]
build-backend = "setuptools.build_meta"

# Pytest
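For downstream consumers of the new torch pin above, a minimal sketch (not part of this change) of asserting that the installed torch falls inside the allowed range; packaging is already a runtime dependency in setup.py below:

import torch
from packaging.version import Version

# Minimal sanity check against the new pin torch>=2.4.0,<2.4.1 (sketch, not shipped code).
v = Version(torch.__version__)
assert Version('2.4.0') <= v < Version('2.4.1'), f'megablocks expects torch 2.4.0, found {v}'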
8 changes: 4 additions & 4 deletions setup.py
@@ -62,15 +62,15 @@
install_requires = [
'numpy>=1.21.5,<2.1.0',
'packaging>=21.3.0,<24.2',
- 'torch>=2.3.0,<2.4',
+ 'torch>=2.4.0,<2.4.1',
'triton>=2.1.0',
- 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301',
+ 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@v0.7.1',
]

extra_deps = {}

extra_deps['gg'] = [
- 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb',
+ 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@v0.1.6',
]

extra_deps['dev'] = [
@@ -83,7 +83,7 @@
]

extra_deps['testing'] = [
- 'mosaicml>=0.22.0',
+ 'mosaicml>=0.24.1',
]

extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}})
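One note on the unchanged final line above: extra_deps['all'] collects every extra except testing. A toy illustration of that comprehension (the dict contents here are an invented subset; only the last two lines mirror the real code):

# Toy illustration of the extras aggregation in setup.py; the dict is an invented subset.
extra_deps = {
    'gg': ['grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@v0.1.6'],
    'dev': ['pytest'],
    'testing': ['mosaicml>=0.24.1'],
}
extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}})
print(sorted(extra_deps['all']))  # gg and dev entries; 'testing' is excluded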
