more flexible disabling of autocast for fsq
lucidrains committed Jul 6, 2024
1 parent be92f79 commit 16ec0a8
Showing 2 changed files with 19 additions and 14 deletions.
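
The change itself is small: FSQ gains a force_quantization_f32 flag, and the quantization step in forward now runs under either autocast-disabled or a no-op context, depending on that flag. Below is a minimal sketch of that pattern, separate from the repository's code (the CPU autocast device and the toy rounding step are only for illustration; force_f32 mirrors the new flag):

from contextlib import nullcontext
from functools import partial

import torch

force_f32 = True  # plays the role of the new `force_quantization_f32` flag

# choose the context once: disable autocast around the sensitive step when
# forcing full precision, otherwise use a do-nothing context
quantization_context = partial(torch.autocast, "cpu", enabled = False) if force_f32 else nullcontext

x = torch.randn(2, 4)

with torch.autocast("cpu", dtype = torch.bfloat16):  # outer mixed-precision region
    with quantization_context():                      # autocast is suspended in here when force_f32 is True
        z = x.float() if x.dtype not in (torch.float32, torch.float64) else x
        codes = (z * 4).round() / 4                   # toy stand-in for the quantization step

Keeping the choice of context manager in one expression avoids branching the quantization code itself, which is why the diff below wraps the whole block in a single with quantization_context():.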
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "vector-quantize-pytorch"
-version = "1.15.0"
+version = "1.15.1"
 description = "Vector Quantization - Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
31 changes: 18 additions & 13 deletions vector_quantize_pytorch/finite_scalar_quantization.py
@@ -4,7 +4,8 @@
 """

 from __future__ import annotations
-from functools import wraps
+from functools import wraps, partial
+from contextlib import nullcontext
 from typing import List, Tuple

 import torch
@@ -61,6 +62,7 @@ def __init__(
         channel_first: bool = False,
         projection_has_bias: bool = True,
         return_indices = True,
+        force_quantization_f32 = True
     ):
         super().__init__()
         _levels = torch.tensor(levels, dtype=int32)
@@ -99,6 +101,7 @@ def __init__(
         self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)

         self.allowed_dtypes = allowed_dtypes
+        self.force_quantization_f32 = force_quantization_f32

     def bound(self, z, eps: float = 1e-3):
         """ Bound `z`, an array of shape (..., d). """
@@ -166,7 +169,6 @@ def forward(self, z):
         c - number of codebook dim
         """

-        orig_dtype = z.dtype
         is_img_or_video = z.ndim >= 4
         need_move_channel_last = is_img_or_video or self.channel_first

@@ -182,25 +184,28 @@

         z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)

-        # make sure allowed dtype before quantizing
+        # whether to force quantization step to be full precision or not

-        if z.dtype not in self.allowed_dtypes:
-            z = z.float()
+        force_f32 = self.force_quantization_f32
+        quantization_context = partial(autocast, enabled = False) if force_f32 else nullcontext

-        codes = self.quantize(z)
+        with quantization_context():
+            orig_dtype = z.dtype

-        # returning indices could be optional
+            if force_f32 and orig_dtype not in self.allowed_dtypes:
+                z = z.float()

-        indices = None
+            codes = self.quantize(z)

-        if self.return_indices:
-            indices = self.codes_to_indices(codes)
+            # returning indices could be optional

-        codes = rearrange(codes, 'b n c d -> b n (c d)')
+            indices = None

-        # cast codes back to original dtype
-
+            if self.return_indices:
+                indices = self.codes_to_indices(codes)
+
+            codes = rearrange(codes, 'b n c d -> b n (c d)')

         if codes.dtype != orig_dtype:
             codes = codes.type(orig_dtype)

         # project out
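
After this change, the flag is just another constructor argument. A hedged usage sketch, assuming version 1.15.1 or later (the levels and input shape follow the project README; the autocast region here is illustrative):

import torch
from vector_quantize_pytorch import FSQ

# True (the default) keeps the quantization step in full precision even inside an
# autocast region; pass False to let it follow the ambient autocast behavior instead
quantizer = FSQ(levels = [8, 5, 5, 5], force_quantization_f32 = True)

x = torch.randn(1, 1024, 4)

with torch.autocast("cpu", dtype = torch.bfloat16):
    xhat, indices = quantizer(x)  # quantization itself still runs in float32

assert xhat.shape == x.shape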