Skip to content

Commit

Permalink
Distributed LDA in different network segments without broadcast (#782)
Browse files Browse the repository at this point in the history
  • Loading branch information
menshikh-iv authored and tmylk committed Jul 13, 2016
1 parent 9c1d9b7 commit 6a289fe
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 78 deletions.
75 changes: 44 additions & 31 deletions gensim/models/lda_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@


from __future__ import with_statement
import os, sys, logging, threading, time
import argparse
import os
import sys
import logging
import threading
import time
from six import iteritems, itervalues

try:
from Queue import Queue
except ImportError:
from queue import Queue
import Pyro4
from gensim import utils
from gensim.models.lda_worker import LDA_WORKER_PREFIX


logger = logging.getLogger("gensim.models.lda_dispatcher")
Expand All @@ -39,6 +46,7 @@
# so this is really just a hack, see http://bugs.python.org/issue1360
HUGE_TIMEOUT = 365 * 24 * 60 * 60 # one year

LDA_DISPATCHER_PREFIX = 'gensim.lda_dispatcher'


class Dispatcher(object):
Expand All @@ -48,15 +56,16 @@ class Dispatcher(object):
There should never be more than one dispatcher running at any one time.
"""

def __init__(self, maxsize=MAX_JOBS_QUEUE):
def __init__(self, maxsize=MAX_JOBS_QUEUE, ns_conf={}):
"""
Note that the constructor does not fully initialize the dispatcher;
use the `initialize()` function to populate it with workers etc.
"""
self.maxsize = maxsize
self.callback = None # a pyro proxy to this object (unknown at init time, but will be set later)
self.ns_conf = ns_conf


@Pyro4.expose
def initialize(self, **model_params):
"""
`model_params` are parameters used to initialize individual workers (gets
Expand All @@ -69,10 +78,9 @@ def initialize(self, **model_params):

# locate all available workers and store their proxies, for subsequent RMI calls
self.workers = {}
import Pyro4
with utils.getNS() as ns:
self.callback = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher') # = self
for name, uri in iteritems(ns.list(prefix='gensim.lda_worker')):
with utils.getNS(**self.ns_conf) as ns:
self.callback = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX])
for name, uri in iteritems(ns.list(prefix=LDA_WORKER_PREFIX)):
try:
worker = Pyro4.Proxy(uri)
workerid = len(self.workers)
Expand All @@ -87,27 +95,27 @@ def initialize(self, **model_params):
if not self.workers:
raise RuntimeError('no workers found; run some lda_worker scripts on your machines first!')


@Pyro4.expose
def getworkers(self):
"""
Return pyro URIs of all registered workers.
"""
return [worker._pyroUri for worker in itervalues(self.workers)]


@Pyro4.expose
def getjob(self, worker_id):
logger.info("worker #%i requesting a new job" % worker_id)
job = self.jobs.get(block=True, timeout=1)
logger.info("worker #%i got a new job (%i left)" % (worker_id, self.jobs.qsize()))
return job


@Pyro4.expose
def putjob(self, job):
self._jobsreceived += 1
self.jobs.put(job, block=True, timeout=HUGE_TIMEOUT)
logger.info("added a new job (len(queue)=%i items)" % self.jobs.qsize())


@Pyro4.expose
def getstate(self):
"""
Merge states from across all workers and return the result.
Expand All @@ -126,7 +134,7 @@ def getstate(self):
logger.info("sending out merged state")
return result


@Pyro4.expose
def reset(self, state):
"""
Initialize all workers for a new EM iteration.
Expand All @@ -138,7 +146,7 @@ def reset(self, state):
self._jobsdone = 0
self._jobsreceived = 0


@Pyro4.expose
@Pyro4.oneway
@utils.synchronous('lock_update')
def jobdone(self, workerid):
Expand Down Expand Up @@ -172,25 +180,30 @@ def exit(self):
#endclass Dispatcher



def main():
logging.basicConfig(format = '%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
logger.info("running %s" % " ".join(sys.argv))

program = os.path.basename(sys.argv[0])
# make sure we have enough cmd line parameters
if len(sys.argv) < 1:
print(globals()["__doc__"] % locals())
sys.exit(1)

if len(sys.argv) < 2:
maxsize = MAX_JOBS_QUEUE
else:
maxsize = int(sys.argv[1])
utils.pyro_daemon('gensim.lda_dispatcher', Dispatcher(maxsize=maxsize))

logger.info("finished running %s" % program)

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--maxsize", help="How many jobs (=chunks of N documents) "
"to keep 'pre-fetched' in a queue (default: %(default)s)",
type=int, default=MAX_JOBS_QUEUE)
parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None)
parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int)
parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)",
action='store_const', default=True, const=False)
parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None)
parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel",
const=logging.INFO, default=logging.WARNING)
args = parser.parse_args()

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel)
logger.info("running %s", " ".join(sys.argv))

ns_conf = {"broadcast": args.no_broadcast,
"host": args.host,
"port": args.port,
"hmac_key": args.hmac}
utils.pyro_daemon(LDA_DISPATCHER_PREFIX, Dispatcher(maxsize=args.maxsize, ns_conf=ns_conf), ns_conf=ns_conf)

logger.info("finished running %s", " ".join(sys.argv))


if __name__ == '__main__':
Expand Down
54 changes: 32 additions & 22 deletions gensim/models/lda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
USAGE: %(program)s
Worker ("slave") process used in computing distributed LDA. Run this script \
Worker ("slave") process used in computing distributed LDA. Run this script \
on every node in your cluster. If you wish, you may even run it multiple times \
on a single machine, to make better use of multiple cores (just beware that \
memory footprint increases accordingly).
Expand All @@ -17,9 +15,13 @@


from __future__ import with_statement
import os, sys, logging
import os
import sys
import logging
import threading
import tempfile
import argparse

try:
import Queue
except ImportError:
Expand All @@ -34,13 +36,14 @@
# periodically save intermediate models after every SAVE_DEBUG updates (0 for never)
SAVE_DEBUG = 0

LDA_WORKER_PREFIX = 'gensim.lda_worker'


class Worker(object):
def __init__(self):
self.model = None


@Pyro4.expose
def initialize(self, myid, dispatcher, **model_params):
self.lock_update = threading.Lock()
self.jobsdone = 0 # how many jobs has this worker completed?
Expand All @@ -50,7 +53,7 @@ def initialize(self, myid, dispatcher, **model_params):
logger.info("initializing worker #%s" % myid)
self.model = ldamodel.LdaModel(**model_params)


@Pyro4.expose
@Pyro4.oneway
def requestjob(self):
"""
Expand Down Expand Up @@ -84,7 +87,7 @@ def processjob(self, job):
self.model.save(fname)
logger.info("finished processing job #%i" % (self.jobsdone - 1))


@Pyro4.expose
@utils.synchronous('lock_update')
def getstate(self):
logger.info("worker #%i returning its state after %s jobs" %
Expand All @@ -95,7 +98,7 @@ def getstate(self):
self.finished = True
return result


@Pyro4.expose
@utils.synchronous('lock_update')
def reset(self, state):
assert state is not None
Expand All @@ -113,21 +116,28 @@ def exit(self):
#endclass Worker



def main():
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
logger.info("running %s" % " ".join(sys.argv))

program = os.path.basename(sys.argv[0])
# make sure we have enough cmd line parameters
if len(sys.argv) < 1:
print(globals()["__doc__"] % locals())
sys.exit(1)

utils.pyro_daemon('gensim.lda_worker', Worker(), random_suffix=True)

logger.info("finished running %s" % program)

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None)
parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int)
parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)",
action='store_const', default=True, const=False)
parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None)
parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel",
const=logging.INFO, default=logging.WARNING)
args = parser.parse_args()

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel)
logger.info("running %s", " ".join(sys.argv))

ns_conf = {"broadcast": args.no_broadcast,
"host": args.host,
"port": args.port,
"hmac_key": args.hmac}

utils.pyro_daemon(LDA_WORKER_PREFIX, Worker(), random_suffix=True, ns_conf=ns_conf)

logger.info("finished running %s", " ".join(sys.argv))


if __name__ == '__main__':
Expand Down
17 changes: 9 additions & 8 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
distributed=False, chunksize=2000, passes=1, update_every=1,
alpha='symmetric', eta=None, decay=0.5, offset=1.0,
eval_every=10, iterations=50, gamma_threshold=0.001,
minimum_probability=0.01, random_state=None):
minimum_probability=0.01, random_state=None, ns_conf={}):
"""
If given, start training from the iterable `corpus` straight away. If not given,
the model is left untrained (presumably because you want to call `update()` manually).
Expand Down Expand Up @@ -325,13 +325,14 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
# set up distributed version
try:
import Pyro4
dispatcher = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher')
logger.debug("looking for dispatcher at %s" % str(dispatcher._pyroUri))
dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics,
chunksize=chunksize, alpha=alpha, eta=eta, distributed=False)
self.dispatcher = dispatcher
self.numworkers = len(dispatcher.getworkers())
logger.info("using distributed version with %i workers" % self.numworkers)
with utils.getNS(**ns_conf) as ns:
from gensim.models.lda_dispatcher import LDA_DISPATCHER_PREFIX
self.dispatcher = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX])
logger.debug("looking for dispatcher at %s" % str(self.dispatcher._pyroUri))
self.dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics,
chunksize=chunksize, alpha=alpha, eta=eta, distributed=False)
self.numworkers = len(self.dispatcher.getworkers())
logger.info("using distributed version with %i workers" % self.numworkers)
except Exception as err:
logger.error("failed to initialize distributed LDA (%s)", err)
raise RuntimeError("failed to initialize distributed LDA (%s)" % err)
Expand Down
23 changes: 6 additions & 17 deletions gensim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,29 +971,18 @@ def upload_chunked(server, docs, chunksize=1000, preprocess=None):
start = end


def getNS():
def getNS(host=None, port=None, broadcast=True, hmac_key=None):
"""
Return a Pyro name server proxy. If there is no name server running,
start one on 0.0.0.0 (all interfaces), as a background process.
Return a Pyro name server proxy.
"""
import Pyro4
try:
return Pyro4.locateNS()
return Pyro4.locateNS(host, port, broadcast, hmac_key)
except Pyro4.errors.NamingError:
logger.info("Pyro name server not found; starting a new one")
os.system("python -m Pyro4.naming -n 0.0.0.0 &")
# TODO: spawn a proper daemon ala http://code.activestate.com/recipes/278731/ ?
# like this, if there's an error somewhere, we'll never know... (and the loop
# below will block). And it probably doesn't work on windows, either.
while True:
try:
return Pyro4.locateNS()
except:
pass
raise RuntimeError("Pyro name server not found")


def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None):
def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None, ns_conf={}):
"""
Register object with name server (which must already be running; a RuntimeError
is raised if it cannot be located) and block until the daemon is terminated. The object is registered under
Expand All @@ -1003,7 +992,7 @@ def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None):
if random_suffix:
name += '.' + hex(random.randint(0, 0xffffff))[2:]
import Pyro4
with getNS() as ns:
with getNS(**ns_conf) as ns:
with Pyro4.Daemon(ip or get_my_ip(), port or 0) as daemon:
# register server for remote access
uri = daemon.register(obj, name)
Expand Down

0 comments on commit 6a289fe

Please sign in to comment.