Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Van Niekerk committed May 9, 2022
1 parent a4674a0 commit edce489
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 96 deletions.
24 changes: 4 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ units = kmeans.predict(x.squeeze().cpu().numpy())

**Step 1**: Download and extract the [LibriSpeech](https://www.openslr.org/12) corpus.

**Step 2**: Encode LibriSpeech using the HuBERT-Discrete model and `encode.py` script (setting `--layer=7`):
**Step 2**: Encode LibriSpeech using the HuBERT-Discrete model and `encode.py` script:

```
usage: encode.py [-h] [--extension EXTENSION] [--model {hubert_soft,hubert_discrete}] [--layer LAYER] in-dir out-dir
usage: encode.py [-h] [--extension EXTENSION] [--model {hubert_soft,hubert_discrete}] in-dir out-dir
Encode an audio dataset.
Expand All @@ -73,31 +73,15 @@ optional arguments:
extension of the audio files.
--model {hubert_soft,hubert_discrete}
available models
--layer LAYER the selected transformer layer (defaults to the last layer)
```

for example:

```
python encode.py path/to/LibriSpeech path/to/LibriSpeech/
python encode.py path/to/LibriSpeech/wavs path/to/LibriSpeech/units --model hubert_discrete
```

**Step 3**: Discretize the extracted features using the k-means checkpoint and `discretize.py` script:

```
usage: discretize.py [-h] in-dir out-dir
Discretize HuBERT features.
positional arguments:
in-dir path to the dataset directory.
out-dir path to the output directory.
optional arguments:
-h, --help show this help message and exit
```

**Step 5**: Train the HuBERT-Soft model using the `train.py` script:
**Step 3**: Train the HuBERT-Soft model using the `train.py` script:

```
usage: train.py [-h] [--resume RESUME] [--warmstart] [--mask] [--alpha ALPHA] dataset-dir checkpoint-dir
Expand Down
42 changes: 0 additions & 42 deletions discretize.py

This file was deleted.

15 changes: 2 additions & 13 deletions encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from tqdm import tqdm

import torch
import torch.nn.functional as F
import torchaudio
from torchaudio.functional import resample

Expand All @@ -22,17 +21,13 @@ def encode_dataset(args):
wav, sr = torchaudio.load(in_path)
wav = resample(wav, sr, 16000)
wav = wav.unsqueeze(0).cuda()
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))

# Extract hubert features from the args.layer transformer layer
with torch.inference_mode():
x, _ = hubert.encode(wav, layer=args.layer)
if args.layer is None:
x = hubert.proj(x)
units = hubert.units(wav)

out_path = args.out_dir / in_path.relative_to(args.in_dir)
out_path.parent.mkdir(parents=True, exist_ok=True)
np.save(out_path.with_suffix(".npy"), x.squeeze(0).cpu().numpy())
np.save(out_path.with_suffix(".npy"), units.squeeze().cpu().numpy())


if __name__ == "__main__":
Expand Down Expand Up @@ -61,11 +56,5 @@ def encode_dataset(args):
choices=["hubert_soft", "hubert_discrete"],
default="hubert_soft",
)
parser.add_argument(
"--layer",
help="the selected transformer layer (defaults to the last layer)",
default=None,
type=int,
)
args = parser.parse_args()
encode_dataset(args)
9 changes: 8 additions & 1 deletion hubert/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
from .model import Hubert, hubert_discrete, hubert_soft, kmeans100
from .model import (
Hubert,
HubertDiscrete,
HubertSoft,
hubert_discrete,
hubert_soft,
kmeans100,
)
64 changes: 44 additions & 20 deletions hubert/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class Hubert(nn.Module):
def __init__(self, num_label_embeddings: int = 100, mask=True):
def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
super().__init__()
self._mask = mask
self.feature_extractor = FeatureExtractor()
Expand Down Expand Up @@ -69,6 +69,28 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return logits, mask


class HubertSoft(Hubert):
def __init__(self):
super().__init__()

def units(self, wav: torch.Tensor) -> torch.Tensor:
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
x, _ = self.encode(wav)
return self.proj(x)


class HubertDiscrete(Hubert):
def __init__(self, kmeans):
super().__init__()
self.kmeans = kmeans

def units(self, wav: torch.Tensor) -> torch.LongTensor:
wav = F.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
x, _ = self.encode(wav, layer=7)
x = self.kmeans.predict(x.squeeze().cpu().numpy())
return torch.tensor(x, dtype=torch.long, device=wav.device)


class FeatureExtractor(nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -204,43 +226,45 @@ def _compute_mask(
return mask


def _hubert(
name: str,
num_label_embeddings: int,
pretrained: bool = True,
progress: bool = True,
) -> Hubert:
hubert = Hubert(num_label_embeddings)
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
hubert.load_state_dict(checkpoint)
hubert.eval()
return hubert


def hubert_discrete(
pretrained: bool = True,
progress: bool = True,
) -> Hubert:
) -> HubertDiscrete:
r"""HuBERT-Discrete from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
Args:
pretrained (bool): load pretrained weights into the model
progress (bool): show progress bar when downloading model
"""
return _hubert("hubert-discrete", 504, pretrained, progress)
kmeans = kmeans100(pretrained=pretrained, progress=progress)
hubert = HubertDiscrete(kmeans)
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
URLS["hubert-discrete"], progress=progress
)
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
hubert.load_state_dict(checkpoint)
hubert.eval()
return hubert


def hubert_soft(
pretrained: bool = True,
progress: bool = True,
) -> Hubert:
) -> HubertSoft:
r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
Args:
pretrained (bool): load pretrained weights into the model
progress (bool): show progress bar when downloading model
"""
return _hubert("hubert-soft", 100, pretrained, progress)
hubert = HubertSoft()
if pretrained:
checkpoint = torch.hub.load_state_dict_from_url(
URLS["hubert-soft"], progress=progress
)
consume_prefix_in_state_dict_if_present(checkpoint, "module.")
hubert.load_state_dict(checkpoint)
hubert.eval()
return hubert


def _kmeans(
Expand Down

0 comments on commit edce489

Please sign in to comment.