Commit

address #162
lucidrains committed Sep 25, 2024
1 parent 492e666 commit b5cb143
Showing 4 changed files with 45 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.17.4"
version = "1.17.5"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
5 changes: 4 additions & 1 deletion tests/test_readme.py
@@ -65,10 +65,12 @@ def test_vq_mask():
@pytest.mark.parametrize('implicit_neural_codebook', (True, False))
@pytest.mark.parametrize('use_cosine_sim', (True, False))
@pytest.mark.parametrize('train', (True, False))
@pytest.mark.parametrize('shared_codebook', (True, False))
def test_residual_vq(
implicit_neural_codebook,
use_cosine_sim,
train
train,
shared_codebook
):
from vector_quantize_pytorch import ResidualVQ

@@ -78,6 +80,7 @@ def test_residual_vq(
codebook_size = 128,
implicit_neural_codebook = implicit_neural_codebook,
use_cosine_sim = use_cosine_sim,
shared_codebook = shared_codebook
)

x = torch.randn(1, 256, 32)
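The parametrization above exercises the new shared_codebook flag end to end. For orientation, a minimal usage sketch of what the test constructs (num_quantizers and the exact return signature are assumptions based on the package README, not shown in this hunk):

import torch
from vector_quantize_pytorch import ResidualVQ

# residual VQ in which every quantizer layer reuses one and the same codebook
residual_vq = ResidualVQ(
    dim = 32,
    num_quantizers = 8,        # assumed for illustration; the test's value lies outside this hunk
    codebook_size = 128,
    shared_codebook = True     # the option this commit wires up for EMA codebooks
)

x = torch.randn(1, 256, 32)
quantized, indices, commit_loss = residual_vq(x)   # quantized: (1, 256, 32)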
12 changes: 12 additions & 0 deletions vector_quantize_pytorch/residual_vq.py
@@ -137,6 +137,11 @@ def __init__(
ema_update = False
)

if shared_codebook:
vq_kwargs.update(
manual_ema_update = True
)

self.layers = ModuleList([VectorQuantize(dim = codebook_dim, codebook_dim = codebook_dim, accept_image_fmap = accept_image_fmap, **vq_kwargs) for _ in range(num_quantizers)])

assert all([not vq.has_projections for vq in self.layers])
@@ -157,6 +162,8 @@ def __init__(

# sharing codebook logic

self.shared_codebook = shared_codebook

if not shared_codebook:
return

@@ -349,6 +356,11 @@ def forward(
all_indices.append(embed_indices)
all_losses.append(loss)

# if shared codebook, update ema only at end

if self.shared_codebook:
first(self.layers)._codebook.update_ema()

# project out, if needed

quantized_out = self.project_out(quantized_out)
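Taken together, the two hunks above implement a deferred EMA update: when the codebook is shared, every layer is constructed with manual_ema_update = True, so each layer's forward pass only accumulates the EMA statistics (cluster sizes and summed vectors), and the single shared codebook is written once after the last quantizer has run. A schematic sketch of that control flow (the surrounding class and layer API are simplified stand-ins, not the library's actual classes):

class SharedCodebookRVQSketch:
    # Illustration only: defer the EMA codebook write until all residual
    # quantizer layers have accumulated statistics on the shared codebook.

    def __init__(self, layers, shared_codebook):
        self.layers = layers                    # all layers reference the same _codebook
        self.shared_codebook = shared_codebook

    def forward(self, x):
        residual = x
        quantized_out = 0.

        for layer in self.layers:
            # with manual_ema_update = True, this call updates cluster_size and
            # embed_avg in place but leaves the embedding weights untouched
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

        if self.shared_codebook:
            # one EMA write for the whole stack, mirroring
            # first(self.layers)._codebook.update_ema() in the diff
            self.layers[0]._codebook.update_ema()

        return quantized_out

Without the deferral, each layer would rewrite the shared embeddings mid-pass, so later quantizers in the same forward call would be matching against a codebook that had already shifted.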
41 changes: 28 additions & 13 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
@@ -280,6 +280,7 @@ def __init__(
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
ema_update = True,
manual_ema_update = False,
affine_param = False,
sync_affine_param = False,
affine_param_batch_decay = 0.99,
@@ -290,6 +291,7 @@

self.decay = decay
self.ema_update = ema_update
self.manual_ema_update = manual_ema_update

init_fn = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(num_codebooks, codebook_size, dim)
@@ -458,6 +460,12 @@ def expire_codes_(self, batch_samples):
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)

def update_ema(self):
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed_normalized)

@autocast('cuda', enabled = False)
def forward(
self,
@@ -551,11 +559,9 @@ def forward(

ema_inplace(self.embed_avg.data, embed_sum, self.decay)

cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
if not self.manual_ema_update:
self.update_ema()
self.expire_codes_(x)

if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -582,11 +588,14 @@ def __init__(
gumbel_sample = gumbel_sample,
sample_codebook_temp = 1.,
ema_update = True,
manual_ema_update = False
):
super().__init__()
self.transform_input = l2norm

self.ema_update = ema_update
self.manual_ema_update = manual_ema_update

self.decay = decay

if not kmeans_init:
@@ -671,6 +680,14 @@ def expire_codes_(self, batch_samples):
batch_samples = rearrange(batch_samples, 'h ... d -> h (...) d')
self.replace(batch_samples, batch_mask = expired_codes)

def update_ema(self):
cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
embed_normalized = l2norm(embed_normalized)

self.embed.data.copy_(embed_normalized)

@autocast('cuda', enabled = False)
def forward(
self,
@@ -746,13 +763,9 @@ def forward(

ema_inplace(self.embed_avg.data, embed_sum, self.decay)

cluster_size = laplace_smoothing(self.cluster_size, self.codebook_size, self.eps) * self.cluster_size.sum(dim = -1, keepdim = True)

embed_normalized = self.embed_avg / rearrange(cluster_size, '... -> ... 1')
embed_normalized = l2norm(embed_normalized)

self.embed.data.copy_(embed_normalized)
self.expire_codes_(x)
if not self.manual_ema_update:
self.update_ema()
self.expire_codes_(x)

if needs_codebook_dim:
quantize, embed_ind = map(lambda t: rearrange(t, '1 ... -> ...'), (quantize, embed_ind))
@@ -802,6 +815,7 @@ def __init__(
sync_codebook = None,
sync_affine_param = False,
ema_update = True,
manual_ema_update = False,
learnable_codebook = False,
in_place_codebook_optimizer: Callable[..., Optimizer] = None, # Optimizer used to update the codebook embedding if using learnable_codebook
affine_param = False,
@@ -881,7 +895,8 @@ def __init__(
learnable_codebook = has_codebook_orthogonal_loss or learnable_codebook,
sample_codebook_temp = sample_codebook_temp,
gumbel_sample = gumbel_sample_fn,
ema_update = ema_update
ema_update = ema_update,
manual_ema_update = manual_ema_update
)

if affine_param:
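The new update_ema methods simply factor out the normalization that previously ran inline at the end of forward: Laplace-smooth the running cluster counts, then divide the running sum of assigned vectors by those counts to obtain the new code vectors (the cosine-sim variant additionally l2-normalizes the result). A standalone sketch of that computation with plain tensors; the exact form of laplace_smoothing is an assumption about the library's helper:

import torch

def laplace_smoothing(cluster_size, num_codes, eps):
    # assumed helper: additive smoothing, renormalized to sum to 1 over codes
    return (cluster_size + eps) / (cluster_size.sum(dim = -1, keepdim = True) + num_codes * eps)

def update_ema(embed_avg, cluster_size, eps = 1e-5):
    # embed_avg:    (num_codebooks, codebook_size, dim)  EMA of summed assigned vectors
    # cluster_size: (num_codebooks, codebook_size)       EMA of assignment counts
    num_codes = cluster_size.shape[-1]

    # smoothed counts, rescaled back to the original total, so unused codes never divide by zero
    smoothed = laplace_smoothing(cluster_size, num_codes, eps) * cluster_size.sum(dim = -1, keepdim = True)

    # each code becomes the (smoothed) running mean of the vectors assigned to it
    return embed_avg / smoothed.unsqueeze(-1)

embed_avg = torch.randn(1, 128, 32)
cluster_size = torch.rand(1, 128) * 10.
new_embed = update_ema(embed_avg, cluster_size)    # (1, 128, 32)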
