Skip to content

Commit

Permalink
code formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Spijkervet committed Mar 19, 2021
1 parent 0e0ab5d commit 957ad02
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 30 deletions.
5 changes: 1 addition & 4 deletions clmr/datasets/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ class AUDIO(Dataset):
_ext_audio = ".wav"

def __init__(
self,
root: str,
src_ext_audio: str = ".wav",
n_classes: int = 1,
self, root: str, src_ext_audio: str = ".wav", n_classes: int = 1,
) -> None:
super(AUDIO, self).__init__(root)

Expand Down
19 changes: 17 additions & 2 deletions clmr/modules/linear_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,23 @@ def configure_optimizers(self):
lr=self.hparams.finetuner_learning_rate,
weight_decay=self.hparams.weight_decay,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="min",
factor=0.1,
patience=5,
threshold=0.0001,
threshold_mode="rel",
cooldown=0,
min_lr=0,
eps=1e-08,
verbose=False,
)
if scheduler:
return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Valid/loss"}
return {
"optimizer": optimizer,
"lr_scheduler": scheduler,
"monitor": "Valid/loss",
}
else:
return {"optimizer": optimizer}
17 changes: 4 additions & 13 deletions linear_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,7 @@
)

contrastive_test_dataset = ContrastiveDataset(
test_dataset,
input_shape=(1, args.audio_length),
transform=None,
test_dataset, input_shape=(1, args.audio_length), transform=None,
)

train_loader = DataLoader(
Expand Down Expand Up @@ -102,22 +100,15 @@
cl.freeze()

module = LinearEvaluation(
args,
cl.encoder,
hidden_dim=n_features,
output_dim=train_dataset.n_classes,
args, cl.encoder, hidden_dim=n_features, output_dim=train_dataset.n_classes,
)


if args.finetuner_checkpoint_path:
state_dict = load_finetuner_checkpoint(args.finetuner_checkpoint_path)
module.model.load_state_dict(state_dict)
else:
early_stop_callback = EarlyStopping(
monitor='Valid/loss',
patience=10,
verbose=False,
mode='min'
monitor="Valid/loss", patience=10, verbose=False, mode="min"
)

trainer = Trainer.from_argparse_args(
Expand All @@ -126,7 +117,7 @@
"runs", name="CLMRv2-eval-{}".format(args.dataset)
),
max_epochs=args.finetuner_max_epochs,
callbacks=[early_stop_callback]
callbacks=[early_stop_callback],
)
trainer.fit(module, train_loader, valid_loader)

Expand Down
7 changes: 2 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
RandomApply(
[
PitchShift(
n_samples=args.audio_length,
sample_rate=args.sample_rate,
n_samples=args.audio_length, sample_rate=args.sample_rate,
)
],
p=args.transforms_pitch,
Expand Down Expand Up @@ -156,9 +155,7 @@
test_dataset = get_dataset(args.dataset, args.dataset_dir, subset="test")

contrastive_test_dataset = ContrastiveDataset(
test_dataset,
input_shape=(1, args.audio_length),
transform=None,
test_dataset, input_shape=(1, args.audio_length), transform=None,
)

device = "cuda:0" if args.gpus else "cpu"
Expand Down
4 changes: 1 addition & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,5 @@ def run(self):
"Programming Language :: Python :: Implementation :: PyPy",
],
# $ setup.py publish support.
cmdclass={
"upload": UploadCommand,
},
cmdclass={"upload": UploadCommand,},
)
4 changes: 1 addition & 3 deletions tests/test_spectogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def test_audioset(self):

spec_transform = nn.Sequential(
torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
n_mels=n_mels,
sample_rate=sample_rate, n_fft=n_fft, n_mels=n_mels,
),
torchaudio.transforms.AmplitudeToDB(stype=stype, top_db=top_db),
)
Expand Down

0 comments on commit 957ad02

Please sign in to comment.