Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify parallel part for data processing and fix abnormal exit. #58

Merged
merged 1 commit into from
Dec 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions data_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,10 +290,7 @@ def reader():

reader, cleanup_callback = xmap_readers_mp(
lambda instance: self.process_utterance(instance["audio_filepath"], instance["text"]),
reader,
self._num_threads,
4096,
order=True)
reader, self._num_threads, 4096)

# register callback to main process
atexit.register(cleanup_callback)
Expand Down
68 changes: 41 additions & 27 deletions data_utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from Queue import Queue
from threading import Thread
from multiprocessing import Process, Manager
from multiprocessing import Process, Manager, Value
from paddle.v2.dataset.common import md5file


Expand Down Expand Up @@ -101,40 +101,35 @@ def xmap_readers_mp(mapper, reader, process_num, buffer_size, order=False):
:type process_num: int
:param buffer_size: Maximal buffer size.
:type buffer_size: int
:param order: Preserve the order of samples from the given reader.
:type order: bool
:return: The wrapped reader
:rtype: callable
:return: The wrapped reader and the cleanup callback
:rtype: tuple
"""
end_flag = XmapEndSignal()

# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
for sample in reader():
in_queue.put(sample)
in_queue.put(end_flag)
read_workers = []
handle_workers = []
flush_workers = []

read_exit_flag = Value('i', 0)
handle_exit_flag = Value('i', 0)
flush_exit_flag = Value('i', 0)

# define a worker to read samples from reader to in_queue with order flag
def order_read_worker(reader, in_queue):
for order_id, sample in enumerate(reader()):
if read_exit_flag.value == 1: break
in_queue.put((order_id, sample))
in_queue.put(end_flag)

# define a worker to handle samples from in_queue by mapper and put results
# to out_queue
def handle_worker(in_queue, out_queue, mapper):
sample = in_queue.get()
while not isinstance(sample, XmapEndSignal):
out_queue.put(mapper(sample))
sample = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# the reading worker should not exit until all handling workers have exited
while handle_exit_flag.value == 0 or read_exit_flag.value == 0:
time.sleep(0.001)

# define a worker to handle samples from in_queue by mapper and put results
# to out_queue with order
def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
while not isinstance(ins, XmapEndSignal):
if handle_exit_flag.value == 1: break
order_id, sample = ins
result = mapper(sample)
while order_id != out_order[0]:
Expand All @@ -144,22 +139,39 @@ def order_handle_worker(in_queue, out_queue, mapper, out_order):
ins = in_queue.get()
in_queue.put(end_flag)
out_queue.put(end_flag)
# wait for exit of flushing worker
while flush_exit_flag.value == 0 or handle_exit_flag.value == 0:
time.sleep(0.001)
read_exit_flag.value = 1
handle_exit_flag.value = 1

# define a thread worker to flush samples from Manager.Queue to Queue
# for acceleration
def flush_worker(in_queue, out_queue):
finish = 0
while finish < process_num:
while finish < process_num and flush_exit_flag.value == 0:
sample = in_queue.get()
if isinstance(sample, XmapEndSignal):
finish += 1
else:
out_queue.put(sample)
out_queue.put(end_flag)
handle_exit_flag.value = 1
flush_exit_flag.value = 1

def cleanup():
# kill all subprocesses and threads
os._exit(0)
# first exit flushing workers
flush_exit_flag.value = 1
for w in flush_workers:
w.join()
# next exit handling workers
handle_exit_flag.value = 1
for w in handle_workers:
w.join()
# last exit reading workers
read_exit_flag.value = 1
for w in read_workers:
w.join()

def xreader():
# prepare shared memory
Expand All @@ -169,27 +181,29 @@ def xreader():
out_order = manager.list([0])

# start a read worker in a process
target = order_read_worker if order else read_worker
target = order_read_worker
p = Process(target=target, args=(reader, in_queue))
p.daemon = True
p.start()
read_workers.append(p)

# start handle_workers with multiple processes
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
target = order_handle_worker
args = (in_queue, out_queue, mapper, out_order)
workers = [
Process(target=target, args=args) for _ in xrange(process_num)
]
for w in workers:
w.daemon = True
w.start()
handle_workers.append(w)

# start a thread to read data from slow Manager.Queue
flush_queue = Queue(buffer_size)
t = Thread(target=flush_worker, args=(out_queue, flush_queue))
t.daemon = True
t.start()
flush_workers.append(t)

# get results
sample = flush_queue.get()
Expand Down