Skip to content

Commit

Permalink
Fix actions (#31)
Browse files Browse the repository at this point in the history
* Fix actions

* change to py39

* Fix black

* More fixes

* Fixes

* Fix torch version

* Fix cudnn
  • Loading branch information
pkufool committed Aug 24, 2023
1 parent 285ad4d commit 6a4b834
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 37 deletions.
10 changes: 10 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[flake8]
show-source=true
statistics=true
max-line-length=80

exclude =
.git,
.github,
setup.py,
build,
Empty file modified .github/scripts/install_cuda.sh
100644 → 100755
Empty file.
Empty file modified .github/scripts/install_cudnn.sh
100644 → 100755
Empty file.
5 changes: 3 additions & 2 deletions .github/scripts/install_torch.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

torch=$TORCH_VERSION
cuda=$CUDA_VERSION
echo "torch version: $torch"
echo "cuda version: $cuda"

case ${torch} in
1.5.*)
case ${cuda} in
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/run_tests_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
torch: ["1.13.1"]
torchaudio: ["0.13.1"]
python-version: ["3.11"]
torch: ["1.12.1"]
torchaudio: ["0.12.1"]
python-version: ["3.9"]
build_type: ["Release", "Debug"]

steps:
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
run: |
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }} -f https://download.pytorch.org/whl/cpu/torch_stable.html
python3 -c "import torch; print('torch version:', torch.__version__)"
python3 -m torch.utils.collect_env
Expand All @@ -92,7 +92,7 @@ jobs:
run: |
python3 -m pip install -qq --upgrade pip
python3 -m pip install -qq torch==${{ matrix.torch }}
python3 -m pip install -qq torch==${{ matrix.torchaudio }}
python3 -m pip install -qq torchaudio==${{ matrix.torchaudio }}
python3 -c "import torch; print('torch version:', torch.__version__)"
python3 -m torch.utils.collect_env
Expand Down
9 changes: 4 additions & 5 deletions .github/workflows/run_tests_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
cuda: ["11.7"]
torch: ["1.13.1"]
torchaudio: ["1.13.0"]
python-version: ["3.11"]
cuda: ["11.6"]
torch: ["1.12.1"]
python-version: ["3.9"]
build_type: ["Release", "Debug"]

steps:
Expand Down Expand Up @@ -103,7 +102,7 @@ jobs:
env:
cuda: ${{ matrix.cuda }}
run: |
./scripts/github_actions/install_cudnn.sh
./.github/scripts/install_cudnn.sh
- name: Configure CMake
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/style_check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ jobs:
shell: bash
working-directory: ${{github.workspace}}
run: |
black --check --diff .
black -l 80 --check --diff .
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ endif()

cmake_minimum_required(VERSION 3.8 FATAL_ERROR)

set(CMAKE_DISABLE_FIND_PACKAGE_MKL TRUE)
set(languages CXX)
set(_FT_WITH_CUDA ON)

Expand Down
8 changes: 6 additions & 2 deletions fast_rnnt/python/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ endif()
pybind11_add_module(_fast_rnnt ${fast_rnnt_srcs})
target_link_libraries(_fast_rnnt PRIVATE mutual_information_core)

if(UNIX AND NOT APPLE)
if(APPLE)
target_link_libraries(_fast_rnnt
PRIVATE
${TORCH_DIR}/lib/libtorch_python.dylib
)
elseif(UNIX)
target_link_libraries(_fast_rnnt
PRIVATE
${PYTHON_LIBRARY}
${TORCH_DIR}/lib/libtorch_python.so
)
endif()

2 changes: 0 additions & 2 deletions fast_rnnt/python/fast_rnnt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,3 @@
from .rnnt_loss import rnnt_loss_pruned
from .rnnt_loss import rnnt_loss_simple
from .rnnt_loss import rnnt_loss_smoothed


18 changes: 11 additions & 7 deletions fast_rnnt/python/fast_rnnt/mutual_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def forward(
if return_grad or px.requires_grad or py.requires_grad:
ans_grad = torch.ones(B, device=px.device, dtype=px.dtype)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px, py, boundary, p, ans_grad)
px, py, boundary, p, ans_grad
)
ctx.save_for_backward(px_grad, py_grad)
assert len(pxy_grads) == 2
pxy_grads[0] = px_grad
Expand Down Expand Up @@ -290,8 +291,9 @@ def mutual_information_recursion(
px, py = px.contiguous(), py.contiguous()

pxy_grads = [None, None]
scores = MutualInformationRecursionFunction.apply(px, py, pxy_grads,
boundary, return_grad)
scores = MutualInformationRecursionFunction.apply(
px, py, pxy_grads, boundary, return_grad
)
px_grad, py_grad = pxy_grads
return (scores, (px_grad, py_grad)) if return_grad else scores

Expand Down Expand Up @@ -388,16 +390,18 @@ def joint_mutual_information_recursion(
p = torch.empty(B, S + 1, T + 1, device=px_tot.device, dtype=px_tot.dtype)

# note, tot_probs is without grad.
tot_probs = _fast_rnnt.mutual_information_forward(px_tot, py_tot, boundary, p)
tot_probs = _fast_rnnt.mutual_information_forward(
px_tot, py_tot, boundary, p
)

# this is a kind of "fake gradient" that we use, in effect to compute
# occupation probabilities. The backprop will work regardless of the
# actual derivative w.r.t. the total probs.
ans_grad = torch.ones(B, device=px_tot.device, dtype=px_tot.dtype)

(px_grad,
py_grad) = _fast_rnnt.mutual_information_backward(px_tot, py_tot, boundary, p,
ans_grad)
(px_grad, py_grad) = _fast_rnnt.mutual_information_backward(
px_tot, py_tot, boundary, p, ans_grad
)

px_grad = px_grad.reshape(1, B, -1)
py_grad = py_grad.reshape(1, B, -1)
Expand Down
24 changes: 15 additions & 9 deletions fast_rnnt/python/fast_rnnt/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_rnnt_logprobs(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down Expand Up @@ -291,7 +291,9 @@ def rnnt_loss_simple(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
Expand Down Expand Up @@ -495,7 +497,9 @@ def rnnt_loss(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
Expand Down Expand Up @@ -770,9 +774,7 @@ def do_rnnt_pruning(
lm_pruning = torch.gather(
lm,
dim=1,
index=ranges.reshape(B, T * s_range, 1).expand(
(B, T * s_range, C)
),
index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, C)),
).reshape(B, T, s_range, C)
return am_pruning, lm_pruning

Expand Down Expand Up @@ -1057,7 +1059,9 @@ def rnnt_loss_pruned(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
Expand Down Expand Up @@ -1248,7 +1252,7 @@ def get_rnnt_logprobs_smoothed(
am.transpose(1, 2), # (B, C, T)
dim=1,
index=symbols.unsqueeze(2).expand(B, S, T),
) # (B, S, T)
) # (B, S, T)

if rnnt_type == "regular":
px_am = torch.cat(
Expand Down Expand Up @@ -1413,7 +1417,9 @@ def rnnt_loss_smoothed(
T = T0 if rnnt_type != "regular" else T0 - 1
if boundary is None:
offset = torch.tensor(
(T - 1) / 2, dtype=px.dtype, device=px.device,
(T - 1) / 2,
dtype=px.dtype,
device=px.device,
).expand(B, 1, 1)
else:
offset = (boundary[:, 3] - 1) / 2
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def build_extension(self, ext: setuptools.extension.Extension):
cmake_args = "-DCMAKE_BUILD_TYPE=Release -DFT_BUILD_TESTS=OFF"

if make_args == "" and system_make_args == "":
make_args = ' -j '
make_args = " -j "

if "PYTHON_EXECUTABLE" not in cmake_args:
print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
Expand Down Expand Up @@ -89,17 +89,17 @@ def get_package_version():
latest_version = latest_version.strip('"')
return latest_version


def get_requirements():
with open("requirements.txt", encoding="utf8") as f:
requirements = f.read().splitlines()

return requirements


package_name = "fast_rnnt"

with open(
"fast_rnnt/python/fast_rnnt/__init__.py", "a"
) as f:
with open("fast_rnnt/python/fast_rnnt/__init__.py", "a") as f:
f.write(f"__version__ = '{get_package_version()}'\n")

setuptools.setup(
Expand Down

0 comments on commit 6a4b834

Please sign in to comment.