diff --git a/plsc/entry.py b/plsc/entry.py index 748ef946944f5..b9b76fd17b51e 100644 --- a/plsc/entry.py +++ b/plsc/entry.py @@ -18,6 +18,7 @@ import errno import json import os +import math import shutil import subprocess import sys @@ -142,7 +143,7 @@ def __init__(self): self.log_period = 200 self.input_info = [{'name': 'image', - 'shape': [-1, 3, 224, 224], + 'shape': [-1, 3, 112, 112], 'dtype': 'float32'}, {'name': 'label', 'shape':[-1, 1], @@ -957,9 +958,8 @@ def train(self): self.load_checkpoint(executor=exe, main_program=origin_prog) if self.train_reader is None: - train_reader = paddle.batch(reader.arc_train( - self.dataset_dir, self.num_classes), - batch_size=self.train_batch_size) + train_reader = reader.arc_train( + self.dataset_dir, self.num_classes) else: train_reader = self.train_reader