Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed lda options #782

Merged
merged 5 commits into from
Jul 13, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions gensim/models/lda_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,21 @@


from __future__ import with_statement
import os, sys, logging, threading, time
import argparse
Copy link
Owner

@piskvorky piskvorky Jul 12, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is py2.7 only. @tmylk I don't think we can drop support for py2.6 yet... is this import safe?

If it's triggered only on importing lda_dispatcher.py, it's probably fine... but we don't want py2.7+ imports in "core" gensim (at import gensim).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked: this is triggered only when importing lda_dispatcher.py or lda_worker.py.
There is a backport of argparse in setup.py for Python < 2.7 (proof).

import os
import sys
import logging
import threading
import time
from six import iteritems, itervalues

try:
from Queue import Queue
except ImportError:
from queue import Queue
import Pyro4
from gensim import utils
from gensim.models.lda_worker import LDA_WORKER_PREFIX


logger = logging.getLogger("gensim.models.lda_dispatcher")
Expand All @@ -39,6 +46,7 @@
# so this is really just a hack, see http://bugs.python.org/issue1360
HUGE_TIMEOUT = 365 * 24 * 60 * 60 # one year

LDA_DISPATCHER_PREFIX = 'gensim.lda_dispatcher'


class Dispatcher(object):
Expand All @@ -48,15 +56,16 @@ class Dispatcher(object):
There should never be more than one dispatcher running at any one time.
"""

def __init__(self, maxsize=MAX_JOBS_QUEUE):
def __init__(self, maxsize=MAX_JOBS_QUEUE, ns_conf=None):
    """
    Store the dispatcher's configuration.

    Note that the constructor does not fully initialize the dispatcher;
    use the `initialize()` function to populate it with workers etc.

    `maxsize` is kept as-is in `self.maxsize`.

    `ns_conf` is a dict of keyword arguments that `initialize()` later unpacks
    into `utils.getNS()`. `None` (the default) means "no extra options" and is
    stored as a fresh empty dict.
    """
    self.maxsize = maxsize
    self.callback = None  # a pyro proxy to this object (unknown at init time, but will be set later)
    # A `{}` default in the signature would be one dict shared by every Dispatcher
    # created without `ns_conf`; build a new one per instance instead.
    self.ns_conf = {} if ns_conf is None else ns_conf


@Pyro4.expose
def initialize(self, **model_params):
"""
`model_params` are parameters used to initialize individual workers (gets
Expand All @@ -69,10 +78,9 @@ def initialize(self, **model_params):

# locate all available workers and store their proxies, for subsequent RMI calls
self.workers = {}
import Pyro4
with utils.getNS() as ns:
self.callback = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher') # = self
for name, uri in iteritems(ns.list(prefix='gensim.lda_worker')):
with utils.getNS(**self.ns_conf) as ns:
self.callback = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX])
for name, uri in iteritems(ns.list(prefix=LDA_WORKER_PREFIX)):
try:
worker = Pyro4.Proxy(uri)
workerid = len(self.workers)
Expand All @@ -87,27 +95,27 @@ def initialize(self, **model_params):
if not self.workers:
raise RuntimeError('no workers found; run some lda_worker scripts on your machines first!')


@Pyro4.expose
def getworkers(self):
    """
    Collect and return the pyro URI of every worker proxy held in `self.workers`.
    """
    uris = []
    for proxy in itervalues(self.workers):
        uris.append(proxy._pyroUri)
    return uris


@Pyro4.expose
def getjob(self, worker_id):
    """
    Take the next job off `self.jobs` and return it to worker `worker_id`.

    The call waits at most one second for a job to become available.
    NOTE(review): `self.jobs` is created outside this view; assuming it is a
    standard `Queue`, an empty queue raises its `Empty` exception after the
    timeout, and that exception propagates to the caller -- confirm.
    """
    # Pass the values as logging arguments so the message is only formatted
    # when the INFO level is actually enabled.
    logger.info("worker #%i requesting a new job", worker_id)
    job = self.jobs.get(block=True, timeout=1)
    logger.info("worker #%i got a new job (%i left)", worker_id, self.jobs.qsize())
    return job


@Pyro4.expose
def putjob(self, job):
    """
    Count `job` as received and append it to the `self.jobs` queue.

    The put blocks while the queue is full, using HUGE_TIMEOUT (one year)
    as the timeout instead of blocking without one.
    """
    self._jobsreceived += 1
    self.jobs.put(job, block=True, timeout=HUGE_TIMEOUT)
    # Lazy logging arguments: the message is formatted only if INFO is enabled.
    logger.info("added a new job (len(queue)=%i items)", self.jobs.qsize())


@Pyro4.expose
def getstate(self):
"""
Merge states from across all workers and return the result.
Expand All @@ -126,7 +134,7 @@ def getstate(self):
logger.info("sending out merged state")
return result


@Pyro4.expose
def reset(self, state):
"""
Initialize all workers for a new EM iterations.
Expand All @@ -138,7 +146,7 @@ def reset(self, state):
self._jobsdone = 0
self._jobsreceived = 0


@Pyro4.expose
@Pyro4.oneway
@utils.synchronous('lock_update')
def jobdone(self, workerid):
Expand Down Expand Up @@ -172,25 +180,30 @@ def exit(self):
#endclass Dispatcher



def main():
# Command-line entry point: configures logging, reads the options and hands a
# Dispatcher instance to utils.pyro_daemon().
# NOTE(review): both the positional sys.argv handling and the argparse handling
# appear in this span -- confirm that only one of them is meant to stay.
logging.basicConfig(format = '%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
logger.info("running %s" % " ".join(sys.argv))

program = os.path.basename(sys.argv[0])
# make sure we have enough cmd line parameters
# NOTE(review): sys.argv normally holds at least the script name, so this check
# looks like it can never fire -- confirm.
if len(sys.argv) < 1:
print(globals()["__doc__"] % locals())
sys.exit(1)

# an optional first positional argument overrides the default queue size
if len(sys.argv) < 2:
maxsize = MAX_JOBS_QUEUE
else:
maxsize = int(sys.argv[1])
utils.pyro_daemon('gensim.lda_dispatcher', Dispatcher(maxsize=maxsize))

logger.info("finished running %s" % program)

# the module docstring doubles as the --help description
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--maxsize", help="How many jobs (=chunks of N documents) "
"to keep 'pre-fetched' in a queue (default: %(default)s)",
type=int, default=MAX_JOBS_QUEUE)
parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None)
parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int)
# NOTE: argparse stores this option as args.no_broadcast, but the value is the
# *enable* flag: True by default, False only when --no-broadcast is given.
parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)",
action='store_const', default=True, const=False)
parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None)
# -v switches the log level from WARNING (the default) to INFO
parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel",
const=logging.INFO, default=logging.WARNING)
args = parser.parse_args()

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel)
logger.info("running %s", " ".join(sys.argv))

# name-server options; the same dict is given to the Dispatcher and to pyro_daemon
ns_conf = {"broadcast": args.no_broadcast,
"host": args.host,
"port": args.port,
"hmac_key": args.hmac}
# registers the dispatcher under LDA_DISPATCHER_PREFIX
utils.pyro_daemon(LDA_DISPATCHER_PREFIX, Dispatcher(maxsize=args.maxsize, ns_conf=ns_conf), ns_conf=ns_conf)

logger.info("finished running %s", " ".join(sys.argv))


if __name__ == '__main__':
Expand Down
54 changes: 32 additions & 22 deletions gensim/models/lda_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
USAGE: %(program)s

Worker ("slave") process used in computing distributed LDA. Run this script \
Worker ("slave") process used in computing distributed LDA. Run this script \
on every node in your cluster. If you wish, you may even run it multiple times \
on a single machine, to make better use of multiple cores (just beware that \
memory footprint increases accordingly).
Expand All @@ -17,9 +15,13 @@


from __future__ import with_statement
import os, sys, logging
import os
import sys
import logging
import threading
import tempfile
import argparse

try:
import Queue
except ImportError:
Expand All @@ -34,13 +36,14 @@
# periodically save intermediate models after every SAVE_DEBUG updates (0 for never)
SAVE_DEBUG = 0

LDA_WORKER_PREFIX = 'gensim.lda_worker'


class Worker(object):
def __init__(self):
# No model exists yet; initialize() assigns an LdaModel to self.model.
self.model = None


@Pyro4.expose
def initialize(self, myid, dispatcher, **model_params):
self.lock_update = threading.Lock()
self.jobsdone = 0 # how many jobs has this worker completed?
Expand All @@ -50,7 +53,7 @@ def initialize(self, myid, dispatcher, **model_params):
logger.info("initializing worker #%s" % myid)
self.model = ldamodel.LdaModel(**model_params)


@Pyro4.expose
@Pyro4.oneway
def requestjob(self):
"""
Expand Down Expand Up @@ -84,7 +87,7 @@ def processjob(self, job):
self.model.save(fname)
logger.info("finished processing job #%i" % (self.jobsdone - 1))


@Pyro4.expose
@utils.synchronous('lock_update')
def getstate(self):
logger.info("worker #%i returning its state after %s jobs" %
Expand All @@ -95,7 +98,7 @@ def getstate(self):
self.finished = True
return result


@Pyro4.expose
@utils.synchronous('lock_update')
def reset(self, state):
assert state is not None
Expand All @@ -113,21 +116,28 @@ def exit(self):
#endclass Worker



def main():
# Command-line entry point: configures logging, reads the options and hands a
# Worker instance to utils.pyro_daemon().
# NOTE(review): both the legacy sys.argv handling and the argparse handling
# appear in this span -- confirm that only one of them is meant to stay.
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
logger.info("running %s" % " ".join(sys.argv))

program = os.path.basename(sys.argv[0])
# make sure we have enough cmd line parameters
# NOTE(review): sys.argv normally holds at least the script name, so this check
# looks like it can never fire -- confirm.
if len(sys.argv) < 1:
print(globals()["__doc__"] % locals())
sys.exit(1)

utils.pyro_daemon('gensim.lda_worker', Worker(), random_suffix=True)

logger.info("finished running %s" % program)

# the module docstring doubles as the --help description
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--host", help="Nameserver hostname (default: %(default)s)", default=None)
parser.add_argument("--port", help="Nameserver port (default: %(default)s)", default=None, type=int)
# NOTE: argparse stores this option as args.no_broadcast, but the value is the
# *enable* flag: True by default, False only when --no-broadcast is given.
parser.add_argument("--no-broadcast", help="Disable broadcast (default: %(default)s)",
action='store_const', default=True, const=False)
parser.add_argument("--hmac", help="Nameserver hmac key (default: %(default)s)", default=None)
# -v switches the log level from WARNING (the default) to INFO
parser.add_argument('-v', '--verbose', help='Verbose flag', action='store_const', dest="loglevel",
const=logging.INFO, default=logging.WARNING)
args = parser.parse_args()

logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=args.loglevel)
logger.info("running %s", " ".join(sys.argv))

# name-server options handed to pyro_daemon as ns_conf
ns_conf = {"broadcast": args.no_broadcast,
"host": args.host,
"port": args.port,
"hmac_key": args.hmac}

# random_suffix=True asks pyro_daemon to extend the registered name
# (presumably so several workers can register side by side -- confirm)
utils.pyro_daemon(LDA_WORKER_PREFIX, Worker(), random_suffix=True, ns_conf=ns_conf)

logger.info("finished running %s", " ".join(sys.argv))


if __name__ == '__main__':
Expand Down
17 changes: 9 additions & 8 deletions gensim/models/ldamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
distributed=False, chunksize=2000, passes=1, update_every=1,
alpha='symmetric', eta=None, decay=0.5, offset=1.0,
eval_every=10, iterations=50, gamma_threshold=0.001,
minimum_probability=0.01, random_state=None):
minimum_probability=0.01, random_state=None, ns_conf={}):
"""
If given, start training from the iterable `corpus` straight away. If not given,
the model is left untrained (presumably because you want to call `update()` manually).
Expand Down Expand Up @@ -325,13 +325,14 @@ def __init__(self, corpus=None, num_topics=100, id2word=None,
# set up distributed version
try:
import Pyro4
dispatcher = Pyro4.Proxy('PYRONAME:gensim.lda_dispatcher')
logger.debug("looking for dispatcher at %s" % str(dispatcher._pyroUri))
dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics,
chunksize=chunksize, alpha=alpha, eta=eta, distributed=False)
self.dispatcher = dispatcher
self.numworkers = len(dispatcher.getworkers())
logger.info("using distributed version with %i workers" % self.numworkers)
with utils.getNS(**ns_conf) as ns:
from gensim.models.lda_dispatcher import LDA_DISPATCHER_PREFIX
self.dispatcher = Pyro4.Proxy(ns.list(prefix=LDA_DISPATCHER_PREFIX)[LDA_DISPATCHER_PREFIX])
logger.debug("looking for dispatcher at %s" % str(self.dispatcher._pyroUri))
self.dispatcher.initialize(id2word=self.id2word, num_topics=self.num_topics,
chunksize=chunksize, alpha=alpha, eta=eta, distributed=False)
self.numworkers = len(self.dispatcher.getworkers())
logger.info("using distributed version with %i workers" % self.numworkers)
except Exception as err:
logger.error("failed to initialize distributed LDA (%s)", err)
raise RuntimeError("failed to initialize distributed LDA (%s)" % err)
Expand Down
23 changes: 6 additions & 17 deletions gensim/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,29 +971,18 @@ def upload_chunked(server, docs, chunksize=1000, preprocess=None):
start = end


def getNS():
def getNS(host=None, port=None, broadcast=True, hmac_key=None):
"""
Return a Pyro name server proxy. If there is no name server running,
start one on 0.0.0.0 (all interfaces), as a background process.

Return a Pyro name server proxy.
"""
import Pyro4
try:
return Pyro4.locateNS()
return Pyro4.locateNS(host, port, broadcast, hmac_key)
except Pyro4.errors.NamingError:
logger.info("Pyro name server not found; starting a new one")
os.system("python -m Pyro4.naming -n 0.0.0.0 &")
# TODO: spawn a proper daemon ala http://code.activestate.com/recipes/278731/ ?
# like this, if there's an error somewhere, we'll never know... (and the loop
# below will block). And it probably doesn't work on windows, either.
while True:
try:
return Pyro4.locateNS()
except:
pass
raise RuntimeError("Pyro name server not found")


def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None):
def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None, ns_conf={}):
"""
Register object with name server (starting the name server if not running
yet) and block until the daemon is terminated. The object is registered under
Expand All @@ -1003,7 +992,7 @@ def pyro_daemon(name, obj, random_suffix=False, ip=None, port=None):
if random_suffix:
name += '.' + hex(random.randint(0, 0xffffff))[2:]
import Pyro4
with getNS() as ns:
with getNS(**ns_conf) as ns:
with Pyro4.Daemon(ip or get_my_ip(), port or 0) as daemon:
# register server for remote access
uri = daemon.register(obj, name)
Expand Down