From 6a289fefd72f038c8cc14826f63624950f5de1f8 Mon Sep 17 00:00:00 2001 From: Menshikh Ivan Date: Wed, 13 Jul 2016 21:22:38 +0500 Subject: [PATCH] Distributed LDA in different network segments without broadcast (#782) --- gensim/models/lda_dispatcher.py | 75 +++++++++++++++++++-------------- gensim/models/lda_worker.py | 54 ++++++++++++++---------- gensim/models/ldamodel.py | 17 ++++---- gensim/utils.py | 23 +++------- 4 files changed, 91 insertions(+), 78 deletions(-) diff --git a/gensim/models/lda_dispatcher.py b/gensim/models/lda_dispatcher.py index 91da236121..484e497b5f 100755 --- a/gensim/models/lda_dispatcher.py +++ b/gensim/models/lda_dispatcher.py @@ -15,14 +15,21 @@ from __future__ import with_statement -import os, sys, logging, threading, time +import argparse +import os +import sys +import logging +import threading +import time from six import iteritems, itervalues + try: from Queue import Queue except ImportError: from queue import Queue import Pyro4 from gensim import utils +from gensim.models.lda_worker import LDA_WORKER_PREFIX logger = logging.getLogger("gensim.models.lda_dispatcher") @@ -39,6 +46,7 @@ # so this is really just a hack, see http://bugs.python.org/issue1360 HUGE_TIMEOUT = 365 * 24 * 60 * 60 # one year +LDA_DISPATCHER_PREFIX = 'gensim.lda_dispatcher' class Dispatcher(object): @@ -48,15 +56,16 @@ class Dispatcher(object): There should never be more than one dispatcher running at any one time. """ - def __init__(self, maxsize=MAX_JOBS_QUEUE): + def __init__(self, maxsize=MAX_JOBS_QUEUE, ns_conf={}): """ Note that the constructor does not fully initialize the dispatcher; use the `initialize()` function to populate it with workers etc. """ self.maxsize = maxsize self.callback = None # a pyro proxy to this object (unknown at init time, but will be set later) + self.ns_conf = ns_conf - + @Pyro4.expose def initialize(self, **model_params): """ `model_params` are parameters used to initialize individual workers (gets @@ -69,10 +78,9 @@ def initialize(self, **model_params): # locate all available workers and store their proxies, for subsequent RMI calls self.workers = {} - import Pyro4 - with utils.getNS() as ns: - self.callback = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher') # = self - for name, uri in iteritems(ns.list(prefix='gensim.lda_worker')): + with utils.getNS(**self.ns_conf) as ns: + self.callback = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX]) + for name, uri in iteritems(ns.list(prefix=LDA_WORKER_PREFIX)): try: worker = Pyro4.Proxy(uri) workerid = len(self.workers) @@ -87,27 +95,27 @@ def initialize(self, **model_params): if not self.workers: raise RuntimeError('no workers found; run some lda_worker scripts on your machines first!') - + @Pyro4.expose def getworkers(self): """ Return pyro URIs of all registered workers. """ return [worker._pyroUri for worker in itervalues(self.workers)] - + @Pyro4.expose def getjob(self, worker_id): logger.info("worker #%i requesting a new job" % worker_id) job = self.jobs.get(block=True, timeout=1) logger.info("worker #%i got a new job (%i left)" % (worker_id, self.jobs.qsize())) return job - + @Pyro4.expose def putjob(self, job): self._jobsreceived += 1 self.jobs.put(job, block=True, timeout=HUGE_TIMEOUT) logger.info("added a new job (len(queue)=%i items)" % self.jobs.qsize()) - + @Pyro4.expose def getstate(self): """ Merge states from across all workers and return the result. @@ -126,7 +134,7 @@ def getstate(self): logger.info("sending out merged state") return result - + @Pyro4.expose def reset(self, state): """ Initialize all workers for a new EM iterations. @@ -138,7 +146,7 @@ def reset(self, state): self._jobsdone = 0 self._jobsreceived = 0 - + @Pyro4.expose @Pyro4.oneway @utils.synchronous('lock_update') def jobdone(self, workerid): @@ -172,25 +180,30 @@ def exit(self): #endclass Dispatcher - def main(): - logging.basicConfig(format = '%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) - logger.info("running %s" % " ".join(sys.argv)) - - program = os.path.basename(sys.argv[0]) - # make sure we have enough cmd line parameters - if len(sys.argv) < 1: - print(globals()["__doc__"] % locals()) - sys.exit(1) - - if len(sys.argv) < 2: - maxsize = MAX_JOBS_QUEUE - else: - maxsize = int(sys.argv[1]) - utils.pyro_daemon('gensim.lda_dispatcher', Dispatcher(maxsize=maxsize)) - - logger.info("finished running %s" % program) - + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--maxsize", help="How many jobs (=chunks of N documents) " + "to keep 'pre-fetched' in a queue (default: %(default)s)", + type=int, default=MAX_JOBS_QUEUE) + parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None) + parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int) + parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)", + action='store_const', default=True, const=False) + parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None) + parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel", + const=logging.INFO, default=logging.WARNING) + args = parser.parse_args() + + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel) + logger.info("running %s", " ".join(sys.argv)) + + ns_conf = {"broadcast": args.no_broadcast, + "host": args.host, + "port": args.port, + "hmac_key": args.hmac} + utils.pyro_daemon(LDA_DISPATCHER_PREFIX, Dispatcher(maxsize=args.maxsize, ns_conf=ns_conf), ns_conf=ns_conf) + + logger.info("finished running %s", " ".join(sys.argv)) if __name__ == '__main__': diff --git a/gensim/models/lda_worker.py b/gensim/models/lda_worker.py index 8b3f7b84bf..fbae4c0fff 100755 --- a/gensim/models/lda_worker.py +++ b/gensim/models/lda_worker.py @@ -5,9 +5,7 @@ # Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html """ -USAGE: %(program)s - - Worker ("slave") process used in computing distributed LDA. Run this script \ +Worker ("slave") process used in computing distributed LDA. Run this script \ on every node in your cluster. If you wish, you may even run it multiple times \ on a single machine, to make better use of multiple cores (just beware that \ memory footprint increases accordingly). @@ -17,9 +15,13 @@ from __future__ import with_statement -import os, sys, logging +import os +import sys +import logging import threading import tempfile +import argparse + try: import Queue except ImportError: @@ -34,13 +36,14 @@ # periodically save intermediate models after every SAVE_DEBUG updates (0 for never) SAVE_DEBUG = 0 +LDA_WORKER_PREFIX = 'gensim.lda_worker' class Worker(object): def __init__(self): self.model = None - + @Pyro4.expose def initialize(self, myid, dispatcher, **model_params): self.lock_update = threading.Lock() self.jobsdone = 0 # how many jobs has this worker completed? @@ -50,7 +53,7 @@ def initialize(self, myid, dispatcher, **model_params): logger.info("initializing worker #%s" % myid) self.model = ldamodel.LdaModel(**model_params) - + @Pyro4.expose @Pyro4.oneway def requestjob(self): """ @@ -84,7 +87,7 @@ def processjob(self, job): self.model.save(fname) logger.info("finished processing job #%i" % (self.jobsdone - 1)) - + @Pyro4.expose @utils.synchronous('lock_update') def getstate(self): logger.info("worker #%i returning its state after %s jobs" % @@ -95,7 +98,7 @@ def getstate(self): self.finished = True return result - + @Pyro4.expose @utils.synchronous('lock_update') def reset(self, state): assert state is not None @@ -113,21 +116,28 @@ def exit(self): #endclass Worker - def main(): - logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) - logger.info("running %s" % " ".join(sys.argv)) - - program = os.path.basename(sys.argv[0]) - # make sure we have enough cmd line parameters - if len(sys.argv) < 1: - print(globals()["__doc__"] % locals()) - sys.exit(1) - - utils.pyro_daemon('gensim.lda_worker', Worker(), random_suffix=True) - - logger.info("finished running %s" % program) - + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None) + parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int) + parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)", + action='store_const', default=True, const=False) + parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None) + parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel", + const=logging.INFO, default=logging.WARNING) + args = parser.parse_args() + + logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel) + logger.info("running %s", " ".join(sys.argv)) + + ns_conf = {"broadcast": args.no_broadcast, + "host": args.host, + "port": args.port, + "hmac_key": args.hmac} + + utils.pyro_daemon(LDA_WORKER_PREFIX, Worker(), random_suffix=True, ns_conf=ns_conf) + + logger.info("finished running %s", " ".join(sys.argv)) if __name__ == '__main__': diff --git a/gensim/models/ldamodel.py b/gensim/models/ldamodel.py index 8456e94a02..ee5263f13b 100755 --- a/gensim/models/ldamodel.py +++ b/gensim/models/ldamodel.py @@ -216,7 +216,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, distributed=False, chunksize=2000, passes=1, update_every=1, alpha='symmetric', eta=None, decay=0.5, offset=1.0, eval_every=10, iterations=50, gamma_threshold=0.001, - minimum_probability=0.01, random_state=None): + minimum_probability=0.01, random_state=None, ns_conf={}): """ If given, start training from the iterable `corpus` straight away. If not given, the model is left untrained (presumably because you want to call `update()` manually). @@ -325,13 +325,14 @@ def __init__(self, corpus=None, num_topics=100, id2word=None, # set up distributed version try: import Pyro4 - dispatcher = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher') - logger.debug("looking for dispatcher at %s" % str(dispatcher._pyroUri)) - dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics, - chunksize=chunksize, alpha=alpha, eta=eta, distributed=False) - self.dispatcher = dispatcher - self.numworkers = len(dispatcher.getworkers()) - logger.info("using distributed version with %i workers" % self.numworkers) + with utils.getNS(**ns_conf) as ns: + from gensim.models.lda_dispatcher import LDA_DISPATCHER_PREFIX + self.dispatcher = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX]) + logger.debug("looking for dispatcher at %s" % str(self.dispatcher._pyroUri)) + self.dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics, + chunksize=chunksize, alpha=alpha, eta=eta, distributed=False) + self.numworkers = len(self.dispatcher.getworkers()) + logger.info("using distributed version with %i workers" % self.numworkers) except Exception as err: logger.error("failed to initialize distributed LDA (%s)", err) raise RuntimeError("failed to initialize distributed LDA (%s)" % err) diff --git a/gensim/utils.py b/gensim/utils.py index 4699cad1f2..2867247e3e 100644 --- a/gensim/utils.py +++ b/gensim/utils.py @@ -971,29 +971,18 @@ def upload_chunked(server, docs, chunksize=1000, preprocess=None): start = end -def getNS(): +def getNS(host=None, port=None, broadcast=True, hmac_key=None): """ - Return a Pyro name server proxy. If there is no name server running, - start one on 0.0.0.0 (all interfaces), as a background process. - + Return a Pyro name server proxy. """ import Pyro4 try: - return Pyro4.locateNS() + return Pyro4.locateNS(host, port, broadcast, hmac_key) except Pyro4.errors.NamingError: - logger.info("Pyro name server not found; starting a new one") - os.system("python -m Pyro4.naming -n 0.0.0.0 &") - # TODO: spawn a proper daemon ala http://code.activestate.com/recipes/278731/ ? - # like this, if there's an error somewhere, we'll never know... (and the loop - # below will block). And it probably doesn't work on windows, either. - while True: - try: - return Pyro4.locateNS() - except: - pass + raise RuntimeError("Pyro name server not found") -def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None): +def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None, ns_conf={}): """ Register object with name server (starting the name server if not running yet) and block until the daemon is terminated. The object is registered under @@ -1003,7 +992,7 @@ def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None): if random_suffix: name += '.' + hex(random.randint(0, 0xffffff))[2:] import Pyro4 - with getNS() as ns: + with getNS(**ns_conf) as ns: with Pyro4.Daemon(ip or get_my_ip(), port or 0) as daemon: # register server for remote access uri = daemon.register(obj, name)