Skip to content

Commit

Permalink
Merge pull request #446 from QiJune/format_py_code_2nd
Browse files Browse the repository at this point in the history
format python code in python directory
  • Loading branch information
reyoung authored Nov 12, 2016
2 parents ef5e483 + a1ba3f4 commit 58e1b3b
Show file tree
Hide file tree
Showing 54 changed files with 3,498 additions and 2,926 deletions.
1 change: 0 additions & 1 deletion python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

35 changes: 19 additions & 16 deletions python/paddle/trainer/PyDataProvider2.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@
import functools
import itertools

logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
logging.basicConfig(format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")


class SequenceType(object):
Expand Down Expand Up @@ -132,8 +131,10 @@ def __init__(self, generator, input_order):
def __call__(self, obj, filename):
for item in self.generator(obj, filename):
if isinstance(item, dict):
yield [item.get(input_name, None) for input_name in
self.input_order]
yield [
item.get(input_name, None)
for input_name in self.input_order
]
else:
yield item

Expand Down Expand Up @@ -162,8 +163,8 @@ def __call__(self, obj, filename):
yield items
except AssertionError as e:
self.logger.warning(
"Item (%s) is not fit the input type with error %s"
% (repr(item), repr(e)))
"Item (%s) is not fit the input type with error %s" %
(repr(item), repr(e)))

if self.check_fail_continue:
continue
Expand Down Expand Up @@ -202,13 +203,17 @@ def loop_check(callback, item):
callback(each)


def provider(input_types=None, should_shuffle=None, pool_size=-1,
def provider(input_types=None,
should_shuffle=None,
pool_size=-1,
min_pool_size=-1,
can_over_batch_size=True,
calc_batch_size=None,
cache=CacheType.NO_CACHE,
check=False, check_fail_continue=False,
init_hook=None, **kwargs):
check=False,
check_fail_continue=False,
init_hook=None,
**kwargs):
"""
Provider decorator. Use it to turn a function into a PyDataProvider2 object.
In this function, the user only needs to get each sample for some train/test
Expand Down Expand Up @@ -318,9 +323,9 @@ def __init__(self, file_list, **kwargs):
"Could not recognize should_shuffle (%s), "
"just use default value of should_shuffle."
" Please set should_shuffle to bool value or "
"something in %s" % (
repr(self.should_shuffle),
repr(true_table + false_table)))
"something in %s" %
(repr(self.should_shuffle),
repr(true_table + false_table)))
self.should_shuffle = None

self.pool_size = pool_size
Expand Down Expand Up @@ -351,8 +356,7 @@ def __init__(self, file_list, **kwargs):
self.generator = InputOrderWrapper(self.generator,
self.input_order)
if self.check:
self.generator = CheckWrapper(self.generator,
self.slots,
self.generator = CheckWrapper(self.generator, self.slots,
check_fail_continue,
self.logger)

Expand All @@ -368,4 +372,3 @@ def deserialize_args(args):
:return:
"""
return cPickle.loads(args)

35 changes: 22 additions & 13 deletions python/paddle/trainer/PyDataProviderWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module provides a wrapper (decorator) to wrap a data processing method into a
PyDataProvider. Some examples are shown `here <data_provider/python_case.html>`_.
Expand Down Expand Up @@ -47,6 +46,7 @@

import io


class SlotType(object): # Just a hint for user.
pass

Expand Down Expand Up @@ -83,6 +83,7 @@ class SparseNonValueSlot(SlotType):
- **SubSeq**: [[[int, int, ...], [int, ....], ...] , \
[[int, int, ...], [int, ....], ...] , ...]
"""

def __init__(self, dim):
"""
:param dim: slot dimension
Expand Down Expand Up @@ -294,8 +295,9 @@ def reset(self):
fn = "%s_%d" % (self.profile_filename, self.profile_count)
sortby = "cumulative"
with open(fn, "w") as f:
pstats.Stats(self.profiler, stream=f).sort_stats(
sortby).print_stats()
pstats.Stats(
self.profiler,
stream=f).sort_stats(sortby).print_stats()
self.logger.info("saving profile to file %s" % fn)
self.profile_count += 1
self.logger.info("resetting profile")
Expand Down Expand Up @@ -453,9 +455,10 @@ def writeDataStream(dat, data_callback):
seq_stream.flush()
subseq_stream.flush()

return "".join([self.int_packer.pack(current_batch_size),
data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()])
return "".join([
self.int_packer.pack(current_batch_size), data_bytes.getvalue(),
seq_bytes.getvalue(), subseq_bytes.getvalue()
])

finally:
data_stream.close()
Expand Down Expand Up @@ -516,7 +519,7 @@ def __prepareData(self, batch_size, ret_list):
self.data_pool[idx])
idx -= 1

ret_list += self.data_pool[self.data_pool_idx: idx + 1]
ret_list += self.data_pool[self.data_pool_idx:idx + 1]

# For speed, just advance the left index instead of actually deleting the data.
self.data_pool_idx = idx + 1
Expand All @@ -537,8 +540,8 @@ def fillPool(self):
if self.max_pool_size == 0:
for i in xrange(min(self.file_count, len(self.generators))):
self.data_pool += list(self.generators[i])
self.generators = self.generators[
min(self.file_count, len(self.generators)):]
self.generators = self.generators[min(self.file_count,
len(self.generators)):]
self.max_pool_size = len(self.data_pool)
else:
while len(self.data_pool) < self.max_pool_size and len(
Expand All @@ -562,9 +565,15 @@ def default_init_hook(cls, *args, **kwargs):
del cls, args, kwargs


def provider(slots=None, use_seq=False, should_shuffle=True, pool_size=1,
can_over_batch_size=True, calc_batch_size=lambda data: 1,
debug=False, init_hook=default_init_hook, profile_filename=None):
def provider(slots=None,
use_seq=False,
should_shuffle=True,
pool_size=1,
can_over_batch_size=True,
calc_batch_size=lambda data: 1,
debug=False,
init_hook=default_init_hook,
profile_filename=None):
"""
The decorator for PyDataProvider. Users should use this to create a Provider class.
Users only need to be concerned with how to read samples from a file.
Expand Down Expand Up @@ -663,7 +672,7 @@ class Cls(GeneralPyDataProvider):
def __init__(self, *file_list, **kwargs):
logging.basicConfig(
format="[%(levelname)s %(asctime)s %(filename)s:%(lineno)s]"
" %(message)s")
" %(message)s")

self.logger = logging.getLogger("")
if debug:
Expand Down
1 change: 0 additions & 1 deletion python/paddle/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading

0 comments on commit 58e1b3b

Please sign in to comment.