Skip to content

Commit

Permalink
new dataset pipeline, test_engine_train_new_dataset_pipeline
Browse files Browse the repository at this point in the history
Seems to work. At least for this simple case.
rwth-i6#292
  • Loading branch information
albertz authored and Spotlight0xff committed Sep 5, 2020
1 parent b740abf commit 2cf9d0c
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tests/test_TFEngine.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,36 @@ def test_engine_train():
engine.finalize()



def test_engine_train_new_dataset_pipeline():
from GeneratingDataset import DummyDataset
seq_len = 5
n_data_dim = 2
n_classes_dim = 3
train_data = DummyDataset(input_dim=n_data_dim, output_dim=n_classes_dim, num_seqs=5, seq_len=seq_len)
train_data.init_seq_order(epoch=1)
cv_data = DummyDataset(input_dim=n_data_dim, output_dim=n_classes_dim, num_seqs=3, seq_len=seq_len)
cv_data.init_seq_order(epoch=1)

config = Config()
config.update({
"model": "%s/model" % _get_tmp_dir(),
"num_outputs": n_classes_dim,
"num_inputs": n_data_dim,
"network": {"output": {"class": "softmax", "loss": "ce"}},
"start_epoch": 1,
"num_epochs": 2,
"max_seqs": 2,
"dataset_pipeline": True
})
_cleanup_old_models(config)
engine = Engine(config=config)
engine.init_train_from_config(config=config, train_data=train_data, dev_data=cv_data, eval_data=None)
assert engine.dataset_provider
engine.train()
engine.finalize()


def test_engine_train_uneven_batches():
rnd = numpy.random.RandomState(42)
from GeneratingDataset import StaticDataset
Expand Down

0 comments on commit 2cf9d0c

Please sign in to comment.