-
Notifications
You must be signed in to change notification settings - Fork 103
/
deepdist.py
127 lines (106 loc) · 4.11 KB
/
deepdist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import copy
import cPickle as pickle
from multiprocessing import Process
import socket
import sys
from threading import Thread
import time
import urllib2
import urlparse

from rwlock import RWLock
"""Lightning-Fast Deep Learning on Spark
"""
class DeepDist:
def __init__(self, model, master='127.0.0.1:5000', min_updates=0, max_updates=4096):
"""DeepDist - Distributed deep learning.
:param model: provide a model that can be trained in parallel on the workers
"""
self.model = model
self.lock = RWLock()
self.descent = lambda model, gradient: model
self.master = master
self.state = 'serving'
self.served = 0
self.received = 0
#self.server = None
self.pmodel = None
self.min_updates = min_updates
self.max_updates = max_updates
def __enter__(self):
Thread(target=self.start).start()
# self.server = Process(target=self.start)
# self.server.start()
return self
def __exit__(self, type, value, traceback):
# self.server.terminate()
pass # need to shut down server here
def start(self):
from flask import Flask, request
app = Flask(__name__)
@app.route('/')
def index():
return 'DeepDist'
@app.route('/model', methods=['GET', 'POST', 'PUT'])
def model_flask():
i = 0
while (self.state != 'serving' or self.served >= self.max_updates) and (i < 1000):
time.sleep(1)
i += 1
# pickle on first read
pmodel = None
self.lock.acquire_read()
if not self.pmodel:
self.lock.release()
self.lock.acquire_write()
if not self.pmodel:
self.pmodel = pickle.dumps(self.model, -1)
self.served += 1
pmodel = self.pmodel
self.lock.release()
else:
self.served += 1
pmodel = self.pmodel
self.lock.release()
return pmodel
@app.route('/update', methods=['GET', 'POST', 'PUT'])
def update_flask():
gradient = pickle.loads(request.data)
self.lock.acquire_write()
if self.min_updates <= self.served:
state = 'receiving'
self.received += 1
self.descent(self.model, gradient)
if self.received >= self.served and self.min_updates <= self.received:
self.received = 0
self.served = 0
self.state = 'serving'
self.pmodel = None
self.lock.release()
return 'OK'
print 'Listening to 0.0.0.0:5000...'
app.run(host='0.0.0.0', debug=True, threaded=True, use_reloader=False)
def train(self, rdd, gradient, descent):
master = self.master # will be pickled
if master == None:
master = rdd.ctx._conf.get('spark.master')
if master.startswith('local['):
master = 'localhost:5000'
else:
if master.startswith('spark://'):
master = '%s:5000' % urlparse.urlparse(master).netloc.split(':')[0]
else:
master = '%s:5000' % master.split(':')[0]
print '\n*** Master: %s\n' % master
self.descent = descent
def mapPartitions(data):
return [send_gradient(gradient(fetch_model(master=master), data), master=master)]
return rdd.mapPartitions(mapPartitions).collect()
def fetch_model(master='localhost:5000'):
    """Download the current model from the master and unpickle it.

    NOTE: trusts the server completely - pickle.loads on the response.
    """
    url = 'http://%s/model' % master
    req = urllib2.Request(url, headers={'Content-Type': 'application/deepdist'})
    payload = urllib2.urlopen(req).read()
    return pickle.loads(payload)
def send_gradient(gradient, master='localhost:5000'):
    """POST a pickled gradient to the master; return the server's reply.

    Falsy gradients are skipped and reported as 'EMPTY' without any I/O.
    """
    if not gradient:
        return 'EMPTY'
    body = pickle.dumps(gradient, -1)
    req = urllib2.Request('http://%s/update' % master, body,
                          headers={'Content-Type': 'application/deepdist'})
    return urllib2.urlopen(req).read()