Unofficial PyTorch implementation of the Gold Loss Correction (GLC) from *Using Trusted Data to Train Deep Networks on Labels Corrupted by Severe Noise* (NeurIPS 2018).
See `example.ipynb` for a walkthrough on MNIST.
from torch.utils.data import DataLoader

from datasets import GoldCorrectionDataset
from glc import CorrectionGenerator, GoldCorrectionLossFunction

# Placeholders you must supply yourself:
#   trn_ds  -- your training dataset
#   trainer -- the object holding the model you train on `untrusted_dataset`

# Build the correction generator. With simulate=True, both the trusted
# (clean) and untrusted (corrupted) datasets are derived from `trn_ds`.
# NOTE: randomization_strength presumably controls how strongly the simulated
# labels are corrupted -- see example.ipynb to confirm.
c_gen = CorrectionGenerator(
    simulate=True, dataset=trn_ds, randomization_strength=1.0
)

# In simulate mode, fetch both the clean (trusted) and corrupted
# (untrusted) datasets.
trusted_dataset, untrusted_dataset = c_gen.fetch_datasets()

# Step 1: train the model on `untrusted_dataset` (training loop not shown).

# Step 2: generate the label-correction matrix from the model trained in step 1.
# NOTE: the second argument (32) is presumably a batch size -- confirm in glc.py.
label_correction_matrix = c_gen.generate_correction_matrix(trainer.model, 32)

# Step 3: wrap the trusted and untrusted datasets together so that a single
# DataLoader iterates over both.
gold_ds = GoldCorrectionDataset(trusted_dataset, untrusted_dataset)
gold_dl = DataLoader(gold_ds, batch_size=32, shuffle=True)

# Step 4: build the modified (gold-corrected) loss from the correction matrix.
gold_loss = GoldCorrectionLossFunction(label_correction_matrix)

# Step 5: train the model on `gold_dl` (which iterates `gold_ds`) with
# `gold_loss` until convergence (training loop not shown).