diff --git a/bin/build_docker_images.sh b/bin/build_docker_images.sh
index aec8cc2d2..aeae1ada5 100755
--- a/bin/build_docker_images.sh
+++ b/bin/build_docker_images.sh
@@ -258,6 +258,7 @@ build_images () {
create_image pyspark-container PySparkContainerDockerfile $public
create_image tf_cifar_container TensorFlowCifarDockerfile $public
create_image tf-container TensorFlowDockerfile $public
+ create_image pytorch-container PyTorchContainerDockerfile $public
}
diff --git a/clipper_admin/clipper_admin/deployers/cloudpickle.py b/clipper_admin/clipper_admin/deployers/cloudpickle.py
deleted file mode 100644
index 7870076ff..000000000
--- a/clipper_admin/clipper_admin/deployers/cloudpickle.py
+++ /dev/null
@@ -1,878 +0,0 @@
-"""
-This class is defined to override standard pickle functionality
-The goals of it follow:
--Serialize lambdas and nested functions to compiled byte code
--Deal with main module correctly
--Deal with other non-serializable objects
-It does not include an unpickler, as standard python unpickling suffices.
-This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
-`_.
-Copyright (c) 2012, Regents of the University of California.
-Copyright (c) 2009 `PiCloud, Inc. `_.
-All rights reserved.
-Redistribution and use in source and binary forms, with or without
-modification, are permitted provided that the following conditions
-are met:
- * Redistributions of source code must retain the above copyright
- notice, this list of conditions and the following disclaimer.
- * Redistributions in binary form must reproduce the above copyright
- notice, this list of conditions and the following disclaimer in the
- documentation and/or other materials provided with the distribution.
- * Neither the name of the University of California, Berkeley nor the
- names of its contributors may be used to endorse or promote
- products derived from this software without specific prior written
- permission.
-THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
-TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
-LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
-NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
-SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-"""
-from __future__ import print_function
-
-import dis
-from functools import partial
-import imp
-import io
-import itertools
-import opcode
-import operator
-import pickle
-import struct
-import sys
-import traceback
-import types
-import weakref
-
-if sys.version < '3':
- from pickle import Pickler
- try:
- from cStringIO import StringIO
- except ImportError:
- from StringIO import StringIO
- PY3 = False
-else:
- types.ClassType = type
- from pickle import _Pickler as Pickler
- from io import BytesIO as StringIO
- PY3 = True
-
-#relevant opcodes
-STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
-DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
-LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
-GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
-HAVE_ARGUMENT = dis.HAVE_ARGUMENT
-EXTENDED_ARG = dis.EXTENDED_ARG
-
-
-def islambda(func):
-    return getattr(func, '__name__') == '<lambda>'
-
-
-_BUILTIN_TYPE_NAMES = {}
-for k, v in types.__dict__.items():
- if type(v) is type:
- _BUILTIN_TYPE_NAMES[v] = k
-
-
-def _builtin_type(name):
- return getattr(types, name)
-
-
-if sys.version_info < (3, 4):
-
- def _walk_global_ops(code):
- """
- Yield (opcode, argument number) tuples for all
- global-referencing instructions in *code*.
- """
- code = getattr(code, 'co_code', b'')
- if not PY3:
- code = map(ord, code)
-
- n = len(code)
- i = 0
- extended_arg = 0
- while i < n:
- op = code[i]
- i += 1
- if op >= HAVE_ARGUMENT:
- oparg = code[i] + code[i + 1] * 256 + extended_arg
- extended_arg = 0
- i += 2
- if op == EXTENDED_ARG:
- extended_arg = oparg * 65536
- if op in GLOBAL_OPS:
- yield op, oparg
-
-else:
-
- def _walk_global_ops(code):
- """
- Yield (opcode, argument number) tuples for all
- global-referencing instructions in *code*.
- """
- for instr in dis.get_instructions(code):
- op = instr.opcode
- if op in GLOBAL_OPS:
- yield op, instr.arg
-
-
-class CloudPickler(Pickler):
-
- dispatch = Pickler.dispatch.copy()
-
- def __init__(self, file, protocol=None):
- Pickler.__init__(self, file, protocol)
- # set of modules to unpickle
- self.modules = set()
- # map ids to dictionary. used to ensure that functions can share global env
- self.globals_ref = {}
-
- def dump(self, obj):
- self.inject_addons()
- try:
- return Pickler.dump(self, obj)
- except RuntimeError as e:
- if 'recursion' in e.args[0]:
- msg = """Could not pickle object as excessively deep recursion required."""
- raise pickle.PicklingError(msg)
-
- def save_memoryview(self, obj):
- """Fallback to save_string"""
- Pickler.save_string(self, str(obj))
-
- def save_buffer(self, obj):
- """Fallback to save_string"""
- Pickler.save_string(self, str(obj))
-
- if PY3:
- dispatch[memoryview] = save_memoryview
- else:
- dispatch[buffer] = save_buffer
-
- def save_unsupported(self, obj):
- raise pickle.PicklingError(
- "Cannot pickle objects of type %s" % type(obj))
-
- dispatch[types.GeneratorType] = save_unsupported
-
- # itertools objects do not pickle!
- for v in itertools.__dict__.values():
- if type(v) is type:
- dispatch[v] = save_unsupported
-
- def save_module(self, obj):
- """
- Save a module as an import
- """
- mod_name = obj.__name__
- # If module is successfully found then it is not a dynamically created module
- try:
- _find_module(mod_name)
- is_dynamic = False
- except ImportError:
- is_dynamic = True
-
- self.modules.add(obj)
- if is_dynamic:
- self.save_reduce(
- dynamic_subimport, (obj.__name__, vars(obj)), obj=obj)
- else:
- self.save_reduce(subimport, (obj.__name__, ), obj=obj)
-
- dispatch[types.ModuleType] = save_module
-
- def save_codeobject(self, obj):
- """
- Save a code object
- """
- if PY3:
- args = (obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
- obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
- obj.co_names, obj.co_varnames, obj.co_filename,
- obj.co_name, obj.co_firstlineno, obj.co_lnotab,
- obj.co_freevars, obj.co_cellvars)
- else:
- args = (obj.co_argcount, obj.co_nlocals, obj.co_stacksize,
- obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
- obj.co_varnames, obj.co_filename, obj.co_name,
- obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
- obj.co_cellvars)
- self.save_reduce(types.CodeType, args, obj=obj)
-
- dispatch[types.CodeType] = save_codeobject
-
- def save_function(self, obj, name=None):
- """ Registered with the dispatch to handle all function types.
- Determines what kind of function obj is (e.g. lambda, defined at
- interactive prompt, etc) and handles the pickling appropriately.
- """
- write = self.write
-
- if name is None:
- name = obj.__name__
- modname = pickle.whichmodule(obj, name)
- # print('which gives %s %s %s' % (modname, obj, name))
- try:
- themodule = sys.modules[modname]
- except KeyError:
- # eval'd items such as namedtuple give invalid items for their function __module__
- modname = '__main__'
-
- if modname == '__main__':
- themodule = None
-
- if themodule:
- self.modules.add(themodule)
- if getattr(themodule, name, None) is obj:
- return self.save_global(obj, name)
-
- # a builtin_function_or_method which comes in as an attribute of some
- # object (e.g., object.__new__, itertools.chain.from_iterable) will end
- # up with modname "__main__" and so end up here. But these functions
- # have no __code__ attribute in CPython, so the handling for
- # user-defined functions below will fail.
- # So we pickle them here using save_reduce; have to do it differently
- # for different python versions.
- if not hasattr(obj, '__code__'):
- if PY3:
- if sys.version_info < (3, 4):
- raise pickle.PicklingError("Can't pickle %r" % obj)
- else:
- rv = obj.__reduce_ex__(self.proto)
- else:
- if hasattr(obj, '__self__'):
- rv = (getattr, (obj.__self__, name))
- else:
- raise pickle.PicklingError("Can't pickle %r" % obj)
- return Pickler.save_reduce(self, obj=obj, *rv)
-
- # if func is lambda, def'ed at prompt, is in main, or is nested, then
- # we'll pickle the actual function object rather than simply saving a
- # reference (as is done in default pickler), via save_function_tuple.
- if (islambda(obj) or
-            getattr(obj.__code__, 'co_filename', None) == '<stdin>' or
- themodule is None):
- self.save_function_tuple(obj)
- return
- else:
- # func is nested
- klass = getattr(themodule, name, None)
- if klass is None or klass is not obj:
- self.save_function_tuple(obj)
- return
-
- if obj.__dict__:
- # essentially save_reduce, but workaround needed to avoid recursion
- self.save(_restore_attr)
- write(pickle.MARK + pickle.GLOBAL + modname + '\n' + name + '\n')
- self.memoize(obj)
- self.save(obj.__dict__)
- write(pickle.TUPLE + pickle.REDUCE)
- else:
- write(pickle.GLOBAL + modname + '\n' + name + '\n')
- self.memoize(obj)
-
- dispatch[types.FunctionType] = save_function
-
- def save_function_tuple(self, func):
- """ Pickles an actual func object.
- A func comprises: code, globals, defaults, closure, and dict. We
- extract and save these, injecting reducing functions at certain points
- to recreate the func object. Keep in mind that some of these pieces
- can contain a ref to the func itself. Thus, a naive save on these
- pieces could trigger an infinite loop of save's. To get around that,
- we first create a skeleton func object using just the code (this is
- safe, since this won't contain a ref to the func), and memoize it as
- soon as it's created. The other stuff can then be filled in later.
- """
- if is_tornado_coroutine(func):
- self.save_reduce(
- _rebuild_tornado_coroutine, (func.__wrapped__, ), obj=func)
- return
-
- save = self.save
- write = self.write
-
- code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(
- func)
-
- save(_fill_function) # skeleton function updater
- write(pickle.MARK) # beginning of tuple that _fill_function expects
-
- # create a skeleton function object and memoize it
- save(_make_skel_func)
- save((code, closure, base_globals))
- write(pickle.REDUCE)
- self.memoize(func)
-
- # save the rest of the func data needed by _fill_function
- save(f_globals)
- save(defaults)
- save(dct)
- write(pickle.TUPLE)
- write(pickle.REDUCE) # applies _fill_function on the tuple
-
- _extract_code_globals_cache = (weakref.WeakKeyDictionary()
- if sys.version_info >= (2, 7) and
- not hasattr(sys, "pypy_version_info") else
- {})
-
- @classmethod
- def extract_code_globals(cls, co):
- """
- Find all globals names read or written to by codeblock co
- """
- out_names = cls._extract_code_globals_cache.get(co)
- if out_names is None:
- try:
- names = co.co_names
- except AttributeError:
- # PyPy "builtin-code" object
- out_names = set()
- else:
- out_names = set(names[oparg]
- for op, oparg in _walk_global_ops(co))
-
- # see if nested function have any global refs
- if co.co_consts:
- for const in co.co_consts:
- if type(const) is types.CodeType:
- out_names |= cls.extract_code_globals(const)
-
- cls._extract_code_globals_cache[co] = out_names
-
- return out_names
-
- def extract_func_data(self, func):
- """
- Turn the function into a tuple of data necessary to recreate it:
- code, globals, defaults, closure, dict
- """
- code = func.__code__
-
- # extract all global ref's
- func_global_refs = self.extract_code_globals(code)
-
- # process all variables referenced by global environment
- f_globals = {}
- for var in func_global_refs:
- if var in func.__globals__:
- f_globals[var] = func.__globals__[var]
-
- # defaults requires no processing
- defaults = func.__defaults__
-
- # process closure
- closure = [c.cell_contents
- for c in func.__closure__] if func.__closure__ else []
-
- # save the dict
- dct = func.__dict__
-
- base_globals = self.globals_ref.get(id(func.__globals__), {})
- self.globals_ref[id(func.__globals__)] = base_globals
-
- return (code, f_globals, defaults, closure, dct, base_globals)
-
- def save_builtin_function(self, obj):
- if obj.__module__ == "__builtin__":
- return self.save_global(obj)
- return self.save_function(obj)
-
- dispatch[types.BuiltinFunctionType] = save_builtin_function
-
- def save_global(self, obj, name=None, pack=struct.pack):
- if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
- if obj in _BUILTIN_TYPE_NAMES:
- return self.save_reduce(
- _builtin_type, (_BUILTIN_TYPE_NAMES[obj], ), obj=obj)
-
- if name is None:
- name = obj.__name__
-
- modname = getattr(obj, "__module__", None)
- if modname is None:
- modname = pickle.whichmodule(obj, name)
-
- if modname == '__main__':
- themodule = None
- else:
- __import__(modname)
- themodule = sys.modules[modname]
- self.modules.add(themodule)
-
- if hasattr(themodule, name) and getattr(themodule, name) is obj:
- return Pickler.save_global(self, obj, name)
-
- typ = type(obj)
- if typ is not obj and isinstance(obj, (type, types.ClassType)):
- d = dict(obj.__dict__) # copy dict proxy to a dict
- if not isinstance(d.get('__dict__', None), property):
- # don't extract dict that are properties
- d.pop('__dict__', None)
- d.pop('__weakref__', None)
-
- # hack as __new__ is stored differently in the __dict__
- new_override = d.get('__new__', None)
- if new_override:
- d['__new__'] = obj.__new__
-
- self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
- else:
- raise pickle.PicklingError("Can't pickle %r" % obj)
-
- dispatch[type] = save_global
- dispatch[types.ClassType] = save_global
-
- def save_instancemethod(self, obj):
- # Memoization rarely is ever useful due to python bounding
- if obj.__self__ is None:
- self.save_reduce(getattr, (obj.im_class, obj.__name__))
- else:
- if PY3:
- self.save_reduce(
- types.MethodType, (obj.__func__, obj.__self__), obj=obj)
- else:
- self.save_reduce(
- types.MethodType, (obj.__func__, obj.__self__,
- obj.__self__.__class__),
- obj=obj)
-
- dispatch[types.MethodType] = save_instancemethod
-
- def save_inst(self, obj):
- """Inner logic to save instance. Based off pickle.save_inst
- Supports __transient__"""
- cls = obj.__class__
-
- memo = self.memo
- write = self.write
- save = self.save
-
- if hasattr(obj, '__getinitargs__'):
- args = obj.__getinitargs__()
- len(args) # XXX Assert it's a sequence
- pickle._keep_alive(args, memo)
- else:
- args = ()
-
- write(pickle.MARK)
-
- if self.bin:
- save(cls)
- for arg in args:
- save(arg)
- write(pickle.OBJ)
- else:
- for arg in args:
- save(arg)
- write(pickle.INST + cls.__module__ + '\n' + cls.__name__ + '\n')
-
- self.memoize(obj)
-
- try:
- getstate = obj.__getstate__
- except AttributeError:
- stuff = obj.__dict__
- #remove items if transient
- if hasattr(obj, '__transient__'):
- transient = obj.__transient__
- stuff = stuff.copy()
- for k in list(stuff.keys()):
- if k in transient:
- del stuff[k]
- else:
- stuff = getstate()
- pickle._keep_alive(stuff, memo)
- save(stuff)
- write(pickle.BUILD)
-
- if not PY3:
- dispatch[types.InstanceType] = save_inst
-
- def save_property(self, obj):
- # properties not correctly saved in python
- self.save_reduce(
- property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj)
-
- dispatch[property] = save_property
-
- def save_classmethod(self, obj):
- try:
- orig_func = obj.__func__
- except AttributeError: # Python 2.6
- orig_func = obj.__get__(None, object)
- if isinstance(obj, classmethod):
- orig_func = orig_func.__func__ # Unbind
- self.save_reduce(type(obj), (orig_func, ), obj=obj)
-
- dispatch[classmethod] = save_classmethod
- dispatch[staticmethod] = save_classmethod
-
- def save_itemgetter(self, obj):
- """itemgetter serializer (needed for namedtuple support)"""
-
- class Dummy:
- def __getitem__(self, item):
- return item
-
- items = obj(Dummy())
- if not isinstance(items, tuple):
- items = (items, )
- return self.save_reduce(operator.itemgetter, items)
-
- if type(operator.itemgetter) is type:
- dispatch[operator.itemgetter] = save_itemgetter
-
- def save_attrgetter(self, obj):
- """attrgetter serializer"""
-
- class Dummy(object):
- def __init__(self, attrs, index=None):
- self.attrs = attrs
- self.index = index
-
- def __getattribute__(self, item):
- attrs = object.__getattribute__(self, "attrs")
- index = object.__getattribute__(self, "index")
- if index is None:
- index = len(attrs)
- attrs.append(item)
- else:
- attrs[index] = ".".join([attrs[index], item])
- return type(self)(attrs, index)
-
- attrs = []
- obj(Dummy(attrs))
- return self.save_reduce(operator.attrgetter, tuple(attrs))
-
- if type(operator.attrgetter) is type:
- dispatch[operator.attrgetter] = save_attrgetter
-
- def save_reduce(self,
- func,
- args,
- state=None,
- listitems=None,
- dictitems=None,
- obj=None):
- """Modified to support __transient__ on new objects
- Change only affects protocol level 2 (which is always used by PiCloud"""
- # Assert that args is a tuple or None
- if not isinstance(args, tuple):
- raise pickle.PicklingError("args from reduce() should be a tuple")
-
- # Assert that func is callable
- if not hasattr(func, '__call__'):
- raise pickle.PicklingError("func from reduce should be callable")
-
- save = self.save
- write = self.write
-
- # Protocol 2 special case: if func's name is __newobj__, use NEWOBJ
- if self.proto >= 2 and getattr(func, "__name__", "") == "__newobj__":
- #Added fix to allow transient
- cls = args[0]
- if not hasattr(cls, "__new__"):
- raise pickle.PicklingError(
- "args[0] from __newobj__ args has no __new__")
- if obj is not None and cls is not obj.__class__:
- raise pickle.PicklingError(
- "args[0] from __newobj__ args has the wrong class")
- args = args[1:]
- save(cls)
-
- #Don't pickle transient entries
- if hasattr(obj, '__transient__'):
- transient = obj.__transient__
- state = state.copy()
-
- for k in list(state.keys()):
- if k in transient:
- del state[k]
-
- save(args)
- write(pickle.NEWOBJ)
- else:
- save(func)
- save(args)
- write(pickle.REDUCE)
-
- if obj is not None:
- self.memoize(obj)
-
- # More new special cases (that work with older protocols as
- # well): when __reduce__ returns a tuple with 4 or 5 items,
- # the 4th and 5th item should be iterators that provide list
- # items and dict items (as (key, value) tuples), or None.
-
- if listitems is not None:
- self._batch_appends(listitems)
-
- if dictitems is not None:
- self._batch_setitems(dictitems)
-
- if state is not None:
- save(state)
- write(pickle.BUILD)
-
- def save_partial(self, obj):
- """Partial objects do not serialize correctly in python2.x -- this fixes the bugs"""
- self.save_reduce(_genpartial, (obj.func, obj.args, obj.keywords))
-
- if sys.version_info < (2, 7): # 2.7 supports partial pickling
- dispatch[partial] = save_partial
-
- def save_file(self, obj):
- """Save a file"""
- try:
- import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute
- except ImportError:
- import io as pystringIO
-
- if not hasattr(obj, 'name') or not hasattr(obj, 'mode'):
- raise pickle.PicklingError(
- "Cannot pickle files that do not map to an actual file")
- if obj is sys.stdout:
- return self.save_reduce(getattr, (sys, 'stdout'), obj=obj)
- if obj is sys.stderr:
- return self.save_reduce(getattr, (sys, 'stderr'), obj=obj)
- if obj is sys.stdin:
- raise pickle.PicklingError("Cannot pickle standard input")
- if obj.closed:
- raise pickle.PicklingError("Cannot pickle closed files")
- if hasattr(obj, 'isatty') and obj.isatty():
- raise pickle.PicklingError(
- "Cannot pickle files that map to tty objects")
- if 'r' not in obj.mode and '+' not in obj.mode:
- raise pickle.PicklingError(
- "Cannot pickle files that are not opened for reading: %s" %
- obj.mode)
-
- name = obj.name
-
- retval = pystringIO.StringIO()
-
- try:
- # Read the whole file
- curloc = obj.tell()
- obj.seek(0)
- contents = obj.read()
- obj.seek(curloc)
- except IOError:
- raise pickle.PicklingError(
- "Cannot pickle file %s as it cannot be read" % name)
- retval.write(contents)
- retval.seek(curloc)
-
- retval.name = name
- self.save(retval)
- self.memoize(obj)
-
- def save_ellipsis(self, obj):
- self.save_reduce(_gen_ellipsis, ())
-
- def save_not_implemented(self, obj):
- self.save_reduce(_gen_not_implemented, ())
-
- if PY3:
- dispatch[io.TextIOWrapper] = save_file
- else:
- dispatch[file] = save_file
-
- dispatch[type(Ellipsis)] = save_ellipsis
- dispatch[type(NotImplemented)] = save_not_implemented
- """Special functions for Add-on libraries"""
-
- def inject_addons(self):
- """Plug in system. Register additional pickling functions if modules already loaded"""
- pass
-
-
-# Tornado support
-
-
-def is_tornado_coroutine(func):
- """
- Return whether *func* is a Tornado coroutine function.
- Running coroutines are not supported.
- """
- if 'tornado.gen' not in sys.modules:
- return False
- gen = sys.modules['tornado.gen']
- if not hasattr(gen, "is_coroutine_function"):
- # Tornado version is too old
- return False
- return gen.is_coroutine_function(func)
-
-
-def _rebuild_tornado_coroutine(func):
- from tornado import gen
- return gen.coroutine(func)
-
-
-# Shorthands for legacy support
-
-
-def dump(obj, file, protocol=2):
- CloudPickler(file, protocol).dump(obj)
-
-
-def dumps(obj, protocol=2):
- file = StringIO()
-
- cp = CloudPickler(file, protocol)
- cp.dump(obj)
-
- return file.getvalue()
-
-
-# including pickles unloading functions in this namespace
-load = pickle.load
-loads = pickle.loads
-
-
-#hack for __import__ not working as desired
-def subimport(name):
- __import__(name)
- return sys.modules[name]
-
-
-def dynamic_subimport(name, vars):
- mod = imp.new_module(name)
- mod.__dict__.update(vars)
- sys.modules[name] = mod
- return mod
-
-
-# restores function attributes
-def _restore_attr(obj, attr):
- for key, val in attr.items():
- setattr(obj, key, val)
- return obj
-
-
-def _get_module_builtins():
- return pickle.__builtins__
-
-
-def print_exec(stream):
- ei = sys.exc_info()
- traceback.print_exception(ei[0], ei[1], ei[2], None, stream)
-
-
-def _modules_to_main(modList):
- """Force every module in modList to be placed into main"""
- if not modList:
- return
-
- main = sys.modules['__main__']
- for modname in modList:
- if type(modname) is str:
- try:
- mod = __import__(modname)
- except Exception as e:
- sys.stderr.write(
- 'warning: could not import %s\n. '
- 'Your function may unexpectedly error due to this import failing;'
- 'A version mismatch is likely. Specific error was:\n' %
- modname)
- print_exec(sys.stderr)
- else:
- setattr(main, mod.__name__, mod)
-
-
-#object generators:
-def _genpartial(func, args, kwds):
- if not args:
- args = ()
- if not kwds:
- kwds = {}
- return partial(func, *args, **kwds)
-
-
-def _gen_ellipsis():
- return Ellipsis
-
-
-def _gen_not_implemented():
- return NotImplemented
-
-
-def _fill_function(func, globals, defaults, dict):
- """ Fills in the rest of function data into the skeleton function object
- that were created via _make_skel_func().
- """
- func.__globals__.update(globals)
- func.__defaults__ = defaults
- func.__dict__ = dict
-
- return func
-
-
-def _make_cell(value):
- return (lambda: value).__closure__[0]
-
-
-def _reconstruct_closure(values):
- return tuple([_make_cell(v) for v in values])
-
-
-def _make_skel_func(code, closures, base_globals=None):
- """ Creates a skeleton function object that contains just the provided
- code and the correct number of cells in func_closure. All other
- func attributes (e.g. func_globals) are empty.
- """
- closure = _reconstruct_closure(closures) if closures else None
-
- if base_globals is None:
- base_globals = {}
- base_globals['__builtins__'] = __builtins__
-
- return types.FunctionType(code, base_globals, None, None, closure)
-
-
-def _find_module(mod_name):
- """
- Iterate over each part instead of calling imp.find_module directly.
- This function is able to find submodules (e.g. sickit.tree)
- """
- path = None
- for part in mod_name.split('.'):
- if path is not None:
- path = [path]
- file, path, description = imp.find_module(part, path)
- return file, path, description
-
-
-"""Constructors for 3rd party libraries
-Note: These can never be renamed due to client compatibility issues"""
-
-
-def _getobject(modname, attribute):
- mod = __import__(modname, fromlist=[attribute])
- return mod.__dict__[attribute]
-
-
-""" Use copy_reg to extend global pickle definitions """
-
-if sys.version_info < (3, 4):
- method_descriptor = type(str.upper)
-
- def _reduce_method_descriptor(obj):
- return (getattr, (obj.__objclass__, obj.__name__))
-
- try:
- import copy_reg as copyreg
- except ImportError:
- import copyreg
- copyreg.pickle(method_descriptor, _reduce_method_descriptor)
diff --git a/clipper_admin/clipper_admin/deployers/deployer_utils.py b/clipper_admin/clipper_admin/deployers/deployer_utils.py
index 54c950186..f2186ec78 100644
--- a/clipper_admin/clipper_admin/deployers/deployer_utils.py
+++ b/clipper_admin/clipper_admin/deployers/deployer_utils.py
@@ -1,7 +1,7 @@
from __future__ import print_function, with_statement, absolute_import
import logging
-from .cloudpickle import CloudPickler
+from cloudpickle import CloudPickler
from .module_dependency import ModuleDependencyAnalyzer
from ..clipper_admin import CLIPPER_TEMP_DIR
import six
@@ -22,6 +22,13 @@
logger = logging.getLogger(__name__)
+def serialize_object(obj):
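+    """Pickle ``obj`` with cloudpickle at pickle protocol 2 and return the
+    serialized payload as a string."""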
+ s = six.StringIO()
+ c = CloudPickler(s, 2)
+ c.dump(obj)
+ return s.getvalue()
+
+
def save_python_function(name, func):
predict_fname = "func.pkl"
environment_fname = "environment.yml"
diff --git a/clipper_admin/clipper_admin/deployers/pytorch.py b/clipper_admin/clipper_admin/deployers/pytorch.py
new file mode 100644
index 000000000..dd3e63844
--- /dev/null
+++ b/clipper_admin/clipper_admin/deployers/pytorch.py
@@ -0,0 +1,191 @@
+from __future__ import print_function, with_statement, absolute_import
+import shutil
+import torch
+import logging
+import re
+import os
+import json
+
+from ..version import __version__
+from ..clipper_admin import ClipperException
+from .deployer_utils import save_python_function, serialize_object
+
+logger = logging.getLogger(__name__)
+
+PYTORCH_WEIGHTS_RELATIVE_PATH = "pytorch_weights.pkl"
+PYTORCH_MODEL_RELATIVE_PATH = "pytorch_model.pkl"
+
+
+def create_endpoint(
+ clipper_conn,
+ name,
+ input_type,
+ func,
+ pytorch_model,
+ default_output="None",
+ version=1,
+ slo_micros=3000000,
+ labels=None,
+ registry=None,
+ base_image="clipper/pytorch-container:{}".format(__version__),
+ num_replicas=1):
+    """Registers an app and deploys the provided predict function with a
+    PyTorch model as a Clipper model.
+ Parameters
+ ----------
+ clipper_conn : :py:meth:`clipper_admin.ClipperConnection`
+ A ``ClipperConnection`` object connected to a running Clipper cluster.
+ name : str
+ The name to be assigned to both the registered application and deployed model.
+ input_type : str
+ The input_type to be associated with the registered app and deployed model.
+ One of "integers", "floats", "doubles", "bytes", or "strings".
+ func : function
+ The prediction function. Any state associated with the function will be
+ captured via closure capture and pickled with Cloudpickle.
+ pytorch_model : pytorch model object
+ The PyTorch model to save.
+ default_output : str, optional
+ The default output for the application. The default output will be returned whenever
+ an application is unable to receive a response from a model within the specified
+ query latency SLO (service level objective). The reason the default output was returned
+ is always provided as part of the prediction response object. Defaults to "None".
+ version : str, optional
+ The version to assign this model. Versions must be unique on a per-model
+ basis, but may be re-used across different models.
+ slo_micros : int, optional
+ The query latency objective for the application in microseconds.
+ This is the processing latency between Clipper receiving a request
+ and sending a response. It does not account for network latencies
+ before a request is received or after a response is sent.
+ If Clipper cannot process a query within the latency objective,
+ the default output is returned. Therefore, it is recommended that
+ the SLO not be set aggressively low unless absolutely necessary.
+ 100000 (100ms) is a good starting value, but the optimal latency objective
+ will vary depending on the application.
+ labels : list(str), optional
+ A list of strings annotating the model. These are ignored by Clipper
+ and used purely for user annotations.
+ registry : str, optional
+ The Docker container registry to push the freshly built model to. Note
+        that if you are running Clipper on Kubernetes, this registry must be accessible
+ to the Kubernetes cluster in order to fetch the container from the registry.
+ base_image : str, optional
+ The base Docker image to build the new model image from. This
+ image should contain all code necessary to run a Clipper model
+ container RPC client.
+ num_replicas : int, optional
+ The number of replicas of the model to create. The number of replicas
+ for a model can be changed at any time with
+ :py:meth:`clipper.ClipperConnection.set_num_replicas`.
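+    Example
+    -------
+    A minimal usage sketch (assumes a Clipper cluster is already running
+    locally):
+
+    import numpy as np
+    import torch
+    from torch import nn
+    from torch.autograd import Variable
+    from clipper_admin import ClipperConnection, DockerContainerManager
+    from clipper_admin.deployers.pytorch import create_endpoint
+
+    clipper_conn = ClipperConnection(DockerContainerManager())
+    clipper_conn.connect()
+
+    model = nn.Linear(1, 1)
+
+    def predict(model, inputs):
+        # batch the query inputs into a tensor before running the model
+        batch = Variable(torch.from_numpy(np.array(inputs, dtype=np.float32)))
+        pred = model(batch).data.numpy()
+        return [str(x) for x in pred]
+
+    # registers the app, deploys the model, and links the two in one call
+    create_endpoint(
+        clipper_conn,
+        name="example",
+        input_type="doubles",
+        func=predict,
+        pytorch_model=model)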
+ """
+
+ clipper_conn.register_application(name, input_type, default_output,
+ slo_micros)
+ deploy_pytorch_model(clipper_conn, name, version, input_type, func,
+ pytorch_model, base_image, labels, registry,
+ num_replicas)
+
+ clipper_conn.link_model_to_app(name, name)
+
+
+def deploy_pytorch_model(
+ clipper_conn,
+ name,
+ version,
+ input_type,
+ func,
+ pytorch_model,
+ base_image="clipper/pytorch-container:{}".format(__version__),
+ labels=None,
+ registry=None,
+ num_replicas=1):
+ """Deploy a Python function with a PyTorch model.
+ Parameters
+ ----------
+ clipper_conn : :py:meth:`clipper_admin.ClipperConnection`
+ A ``ClipperConnection`` object connected to a running Clipper cluster.
+ name : str
+ The name to be assigned to both the registered application and deployed model.
+ version : str
+ The version to assign this model. Versions must be unique on a per-model
+ basis, but may be re-used across different models.
+ input_type : str
+ The input_type to be associated with the registered app and deployed model.
+ One of "integers", "floats", "doubles", "bytes", or "strings".
+ func : function
+ The prediction function. Any state associated with the function will be
+ captured via closure capture and pickled with Cloudpickle.
+ pytorch_model : pytorch model object
+        The PyTorch model to save.
+ base_image : str, optional
+ The base Docker image to build the new model image from. This
+ image should contain all code necessary to run a Clipper model
+ container RPC client.
+ labels : list(str), optional
+ A list of strings annotating the model. These are ignored by Clipper
+ and used purely for user annotations.
+ registry : str, optional
+ The Docker container registry to push the freshly built model to. Note
+        that if you are running Clipper on Kubernetes, this registry must be accessible
+ to the Kubernetes cluster in order to fetch the container from the registry.
+ num_replicas : int, optional
+ The number of replicas of the model to create. The number of replicas
+ for a model can be changed at any time with
+ :py:meth:`clipper.ClipperConnection.set_num_replicas`.
+ Example
+ -------
+
+ from clipper_admin import ClipperConnection, DockerContainerManager
+ from clipper_admin.deployers.pytorch import deploy_pytorch_model
+    import numpy as np
+    import torch
+    from torch import nn
+    from torch.autograd import Variable
+
+ clipper_conn = ClipperConnection(DockerContainerManager())
+
+ # Connect to an already-running Clipper cluster
+ clipper_conn.connect()
+
+ model = nn.Linear(1,1)
+
+    # define a predict function that batches the query inputs into a tensor,
+    # runs the model, and stringifies its outputs
+    def predict(model, inputs):
+        batch = Variable(torch.from_numpy(np.array(inputs, dtype=np.float32)))
+        pred = model(batch).data.numpy()
+        return [str(x) for x in pred]
+
+ deploy_pytorch_model(
+ clipper_conn,
+ name="example",
+        version=1,
+ input_type="doubles",
+ func=predict,
+ pytorch_model=model)
+ """
+
+ serialization_dir = save_python_function(name, func)
+
+ # save Torch model
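+    # the model object is cloudpickled (class definition and code) and its
+    # weights are saved separately as a state_dict; the container unpickles
+    # the model first and then restores the weights via load_state_dict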
+ torch_weights_save_loc = os.path.join(serialization_dir,
+ PYTORCH_WEIGHTS_RELATIVE_PATH)
+
+ torch_model_save_loc = os.path.join(serialization_dir,
+ PYTORCH_MODEL_RELATIVE_PATH)
+
+ try:
+ torch.save(pytorch_model.state_dict(), torch_weights_save_loc)
+ serialized_model = serialize_object(pytorch_model)
+ with open(torch_model_save_loc, "w") as serialized_model_file:
+ serialized_model_file.write(serialized_model)
+
+ except Exception as e:
+ logger.warn("Error saving torch model: %s" % e)
+
+ logger.info("Torch model saved")
+
+ # Deploy model
+ clipper_conn.build_and_deploy_model(name, version, input_type,
+ serialization_dir, base_image, labels,
+ registry, num_replicas)
+
+ # Remove temp files
+ shutil.rmtree(serialization_dir)
diff --git a/clipper_admin/setup.py b/clipper_admin/setup.py
index f6e540cdc..428bfde7e 100644
--- a/clipper_admin/setup.py
+++ b/clipper_admin/setup.py
@@ -26,12 +26,8 @@
package_data={'clipper_admin': ['*.txt', '*/*.yaml']},
keywords=['clipper', 'prediction', 'model', 'management'],
install_requires=[
- 'requests',
- 'subprocess32',
- 'pyyaml',
- 'docker',
- 'kubernetes',
- 'six',
+ 'requests', 'subprocess32', 'pyyaml', 'docker', 'kubernetes', 'six',
+ 'cloudpickle>=0.5.2'
],
extras_require={
'PySpark': ['pyspark'],
diff --git a/containers/python/pyspark_container.py b/containers/python/pyspark_container.py
index 939b9533c..c8834259b 100644
--- a/containers/python/pyspark_container.py
+++ b/containers/python/pyspark_container.py
@@ -5,9 +5,9 @@
import json
import numpy as np
+import cloudpickle
# sys.path.append(os.path.abspath("/lib/"))
-from clipper_admin.deployers import cloudpickle
import pyspark
from pyspark import SparkConf, SparkContext
diff --git a/containers/python/python_closure_container.py b/containers/python/python_closure_container.py
index fc6693bc2..9dbd61460 100644
--- a/containers/python/python_closure_container.py
+++ b/containers/python/python_closure_container.py
@@ -3,8 +3,7 @@
import os
import sys
import numpy as np
-
-from clipper_admin.deployers import cloudpickle
+import cloudpickle
IMPORT_ERROR_RETURN_CODE = 3
@@ -46,7 +45,7 @@ def predict_strings(self, inputs):
if __name__ == "__main__":
- print("Starting PythonContainer container")
+ print("Starting Python Closure container")
try:
model_name = os.environ["CLIPPER_MODEL_NAME"]
except KeyError:
diff --git a/containers/python/pytorch_container.py b/containers/python/pytorch_container.py
new file mode 100644
index 000000000..8ec8b1c95
--- /dev/null
+++ b/containers/python/pytorch_container.py
@@ -0,0 +1,117 @@
+from __future__ import print_function
+import rpc
+import os
+import sys
+import json
+
+import numpy as np
+import cloudpickle
+import torch
+import importlib
+from torch import nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+
+IMPORT_ERROR_RETURN_CODE = 3
+
+PYTORCH_WEIGHTS_RELATIVE_PATH = "pytorch_weights.pkl"
+PYTORCH_MODEL_RELATIVE_PATH = "pytorch_model.pkl"
+
+
+def load_predict_func(file_path):
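+    """Load the cloudpickled prediction function saved by the clipper_admin deployer."""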
+ with open(file_path, 'r') as serialized_func_file:
+ return cloudpickle.load(serialized_func_file)
+
+
+def load_pytorch_model(model_path, weights_path):
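+    """Unpickle the saved model object, then restore its trained weights from
+    the serialized state_dict."""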
+ with open(model_path, 'r') as serialized_model_file:
+ model = cloudpickle.load(serialized_model_file)
+
+ model.load_state_dict(torch.load(weights_path))
+ return model
+
+
+class PyTorchContainer(rpc.ModelContainerBase):
+ def __init__(self, path, input_type):
+ self.input_type = rpc.string_to_input_type(input_type)
+ modules_folder_path = "{dir}/modules/".format(dir=path)
+ sys.path.append(os.path.abspath(modules_folder_path))
+ predict_fname = "func.pkl"
+ predict_path = "{dir}/{predict_fname}".format(
+ dir=path, predict_fname=predict_fname)
+ self.predict_func = load_predict_func(predict_path)
+
+ torch_model_path = os.path.join(path, PYTORCH_MODEL_RELATIVE_PATH)
+ torch_weights_path = os.path.join(path, PYTORCH_WEIGHTS_RELATIVE_PATH)
+ self.model = load_pytorch_model(torch_model_path, torch_weights_path)
+
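+    # All input types are handled identically: delegate to the user-supplied
+    # predict function and return the predictions as strings.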
+ def predict_ints(self, inputs):
+ preds = self.predict_func(self.model, inputs)
+ return [str(p) for p in preds]
+
+ def predict_floats(self, inputs):
+ preds = self.predict_func(self.model, inputs)
+ return [str(p) for p in preds]
+
+ def predict_doubles(self, inputs):
+ preds = self.predict_func(self.model, inputs)
+ return [str(p) for p in preds]
+
+ def predict_bytes(self, inputs):
+ preds = self.predict_func(self.model, inputs)
+ return [str(p) for p in preds]
+
+ def predict_strings(self, inputs):
+ preds = self.predict_func(self.model, inputs)
+ return [str(p) for p in preds]
+
+
+if __name__ == "__main__":
+    print("Starting PyTorch container")
+ try:
+ model_name = os.environ["CLIPPER_MODEL_NAME"]
+ except KeyError:
+ print(
+ "ERROR: CLIPPER_MODEL_NAME environment variable must be set",
+ file=sys.stdout)
+ sys.exit(1)
+ try:
+ model_version = os.environ["CLIPPER_MODEL_VERSION"]
+ except KeyError:
+ print(
+ "ERROR: CLIPPER_MODEL_VERSION environment variable must be set",
+ file=sys.stdout)
+ sys.exit(1)
+
+ ip = "127.0.0.1"
+ if "CLIPPER_IP" in os.environ:
+ ip = os.environ["CLIPPER_IP"]
+ else:
+ print("Connecting to Clipper on localhost")
+
+ port = 7000
+ if "CLIPPER_PORT" in os.environ:
+ port = int(os.environ["CLIPPER_PORT"])
+ else:
+ print("Connecting to Clipper with default port: {port}".format(
+ port=port))
+
+ input_type = "doubles"
+ if "CLIPPER_INPUT_TYPE" in os.environ:
+ input_type = os.environ["CLIPPER_INPUT_TYPE"]
+ else:
+ print("Using default input type: doubles")
+
+ model_path = os.environ["CLIPPER_MODEL_PATH"]
+
+    print("Initializing PyTorch function container")
+ sys.stdout.flush()
+ sys.stderr.flush()
+
+ try:
+ model = PyTorchContainer(model_path, input_type)
+ rpc_service = rpc.RPCService()
+ rpc_service.start(model, ip, port, model_name, model_version,
+ input_type)
+ except ImportError:
+ sys.exit(IMPORT_ERROR_RETURN_CODE)
diff --git a/containers/python/pytorch_container_entry.sh b/containers/python/pytorch_container_entry.sh
new file mode 100755
index 000000000..8f174514b
--- /dev/null
+++ b/containers/python/pytorch_container_entry.sh
@@ -0,0 +1,15 @@
+#!/usr/bin/env sh
+
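+# pytorch_container.py exits with this code when an import fails, signalling
+# that the model's Python dependencies still need to be installed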
+IMPORT_ERROR_RETURN_CODE=3
+
+echo "Attempting to run PyTorch container without installing dependencies"
+echo "Contents of /model"
+ls /model/
+/bin/bash -c "exec python /container/pytorch_container.py"
+if [ $? -eq $IMPORT_ERROR_RETURN_CODE ]; then
+  echo "Running PyTorch container without installing dependencies failed"
+ echo "Will install dependencies and try again"
+ conda install -y --file /model/conda_dependencies.txt
+ pip install -r /model/pip_dependencies.txt
+ /bin/bash -c "exec python /container/pytorch_container.py"
+fi
diff --git a/containers/python/tf_container.py b/containers/python/tf_container.py
index 4670a9c4d..4305f5784 100644
--- a/containers/python/tf_container.py
+++ b/containers/python/tf_container.py
@@ -3,10 +3,9 @@
import os
import sys
import tensorflow as tf
+import cloudpickle
import glob
-from clipper_admin.deployers import cloudpickle
-
def load_predict_func(file_path):
with open(file_path, 'r') as serialized_func_file:
diff --git a/dockerfiles/ClipperTestsDockerfile b/dockerfiles/ClipperTestsDockerfile
index 90c8f6934..885835c5f 100644
--- a/dockerfiles/ClipperTestsDockerfile
+++ b/dockerfiles/ClipperTestsDockerfile
@@ -17,7 +17,7 @@ RUN echo 'export PATH=/opt/conda/bin:$PATH' > /etc/profile.d/conda.sh \
ENV PATH "/opt/conda/bin:$PATH"
RUN conda install -y libgcc pyzmq
-RUN pip install requests subprocess32 scikit-learn numpy pyyaml docker kubernetes pyspark tensorflow
+RUN pip install requests subprocess32 scikit-learn numpy pyyaml docker kubernetes pyspark tensorflow cloudpickle==0.5.2
# Install maven
ARG MAVEN_VERSION=3.5.0
diff --git a/dockerfiles/PyClosureContainerDockerfile b/dockerfiles/PyClosureContainerDockerfile
index c5bbfe90a..050ebc194 100644
--- a/dockerfiles/PyClosureContainerDockerfile
+++ b/dockerfiles/PyClosureContainerDockerfile
@@ -3,10 +3,9 @@ FROM clipper/py-rpc:${CODE_VERSION}
COPY clipper_admin/clipper_admin/python_container_conda_deps.txt /lib/
RUN conda install -y --file /lib/python_container_conda_deps.txt
+RUN conda install -c anaconda cloudpickle=0.5.2
COPY containers/python/python_closure_container.py containers/python/python_closure_container_entry.sh /container/
-COPY clipper_admin/ /lib/clipper_admin
-RUN pip install /lib/clipper_admin
CMD ["/container/python_closure_container_entry.sh"]
diff --git a/dockerfiles/PySparkContainerDockerfile b/dockerfiles/PySparkContainerDockerfile
index a5a48d95b..ce11f837f 100644
--- a/dockerfiles/PySparkContainerDockerfile
+++ b/dockerfiles/PySparkContainerDockerfile
@@ -7,12 +7,10 @@ RUN echo deb http://ftp.de.debian.org/debian jessie-backports main >> /etc/apt/s
&& apt-get update --fix-missing \
&& apt-get install -yqq -t jessie-backports openjdk-8-jdk \
&& conda install -y --file /lib/python_container_conda_deps.txt \
- && pip install pyspark
+ && pip install pyspark \
+ && conda install -c anaconda cloudpickle=0.5.2
COPY containers/python/pyspark_container.py containers/python/pyspark_container_entry.sh /container/
-COPY VERSION.txt /lib/
-COPY clipper_admin/ /lib/clipper_admin
-RUN pip install /lib/clipper_admin
CMD ["/container/pyspark_container_entry.sh"]
diff --git a/dockerfiles/PyTorchContainerDockerfile b/dockerfiles/PyTorchContainerDockerfile
new file mode 100644
index 000000000..d7662ce5a
--- /dev/null
+++ b/dockerfiles/PyTorchContainerDockerfile
@@ -0,0 +1,18 @@
+ARG CODE_VERSION
+FROM clipper/py-rpc:${CODE_VERSION}
+
+COPY clipper_admin/clipper_admin/python_container_conda_deps.txt /lib/
+
+RUN echo deb http://ftp.de.debian.org/debian jessie-backports main >> /etc/apt/sources.list \
+ && apt-get update --fix-missing \
+ && apt-get install -yqq -t jessie-backports openjdk-8-jdk \
+ && conda install -y --file /lib/python_container_conda_deps.txt \
+ && conda install pytorch torchvision -c pytorch \
+ && conda install -c anaconda cloudpickle=0.5.2
+
+
+COPY containers/python/pytorch_container.py containers/python/pytorch_container_entry.sh /container/
+
+CMD ["/container/pytorch_container_entry.sh"]
+
+# vim: set filetype=dockerfile:
diff --git a/dockerfiles/TensorFlowDockerfile b/dockerfiles/TensorFlowDockerfile
index c464f63b0..c873e6320 100644
--- a/dockerfiles/TensorFlowDockerfile
+++ b/dockerfiles/TensorFlowDockerfile
@@ -4,11 +4,10 @@ FROM clipper/py-rpc:${CODE_VERSION}
COPY clipper_admin/clipper_admin/python_container_conda_deps.txt /lib/
RUN conda install -y --file /lib/python_container_conda_deps.txt
-RUN conda install tensorflow
+RUN conda install tensorflow \
+ && conda install -c anaconda cloudpickle=0.5.2
COPY containers/python/tf_container.py containers/python/tf_container_entry.sh /container/
-COPY clipper_admin/ /lib/clipper_admin
-RUN pip install /lib/clipper_admin
CMD ["/container/tf_container_entry.sh"]
diff --git a/integration-tests/deploy_pytorch_models.py b/integration-tests/deploy_pytorch_models.py
new file mode 100644
index 000000000..03491438a
--- /dev/null
+++ b/integration-tests/deploy_pytorch_models.py
@@ -0,0 +1,219 @@
+from __future__ import absolute_import, print_function
+import os
+import sys
+import requests
+import json
+import numpy as np
+import time
+import logging
+
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+
+sys.path.insert(0, os.path.abspath('%s/util_direct_import/' % cur_dir))
+
+from util_package import mock_module_in_package as mmip
+import mock_module as mm
+
+import torch
+from torch.utils.data import DataLoader
+import torch.utils.data as data
+from PIL import Image
+from torch import nn, optim
+from torch.autograd import Variable
+from torchvision import transforms
+import torch.nn.functional as F
+
+from test_utils import (create_docker_connection, BenchmarkException, headers,
+ log_clipper_state)
+cur_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath("%s/../clipper_admin" % cur_dir))
+
+from clipper_admin.deployers.pytorch import deploy_pytorch_model, create_endpoint
+
+logging.basicConfig(
+ format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
+ datefmt='%y-%m-%d:%H:%M:%S',
+ level=logging.INFO)
+
+logger = logging.getLogger(__name__)
+
+app_name = "pytorch-test"
+model_name = "pytorch-model"
+
+
+def normalize(x):
+ return x.astype(np.double) / 255.0
+
+
+def objective(y, pos_label):
+ # prediction objective
+ if y == pos_label:
+ return 1
+ else:
+ return 0
+
+
+def parsedata(train_path, pos_label):
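+    # each row of the training CSV is a label followed by the 784 pixel values
+    # of a 28x28 image; labels are binarized against pos_label via objective()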
+ trainData = np.genfromtxt(train_path, delimiter=',', dtype=int)
+ records = trainData[:, 1:]
+ labels = trainData[:, :1]
+ transformedlabels = [objective(ele, pos_label) for ele in labels]
+ return (records, transformedlabels)
+
+
+def predict(model, xs):
+ preds = model(xs)
+ preds = [preds.data.numpy().tolist()[0]]
+ return [str(p) for p in preds]
+
+
+def deploy_and_test_model(clipper_conn,
+ model,
+ version,
+ link_model=False,
+ predict_fn=predict):
+ deploy_pytorch_model(clipper_conn, model_name, version, "integers",
+ predict_fn, model)
+
+ time.sleep(5)
+
+ if link_model:
+ clipper_conn.link_model_to_app(app_name, model_name)
+ time.sleep(5)
+
+ test_model(clipper_conn, app_name, version)
+
+
+def test_model(clipper_conn, app, version):
+ time.sleep(25)
+ num_preds = 25
+ num_defaults = 0
+ addr = clipper_conn.get_query_addr()
+ for i in range(num_preds):
+ response = requests.post(
+ "http://%s/%s/predict" % (addr, app),
+ headers=headers,
+ data=json.dumps({
+ 'input': get_test_point()
+ }))
+ result = response.json()
+ if response.status_code == requests.codes.ok and result["default"]:
+ num_defaults += 1
+ elif response.status_code != requests.codes.ok:
+ print(result)
+ raise BenchmarkException(response.text)
+
+ if num_defaults > 0:
+ print("Error: %d/%d predictions were default" % (num_defaults,
+ num_preds))
+ if num_defaults > num_preds / 2:
+ raise BenchmarkException("Error querying APP %s, MODEL %s:%d" %
+ (app, model_name, version))
+
+
+# Define a simple NN model
+class BasicNN(nn.Module):
+ def __init__(self):
+ super(BasicNN, self).__init__()
+ self.net = nn.Linear(28 * 28, 2)
+
+ def forward(self, x):
+ if type(x) == np.ndarray:
+ x = torch.from_numpy(x)
+ x = x.float()
+ x = Variable(x)
+ x = x.view(1, 1, 28, 28)
+ x = x / 255.0
+ batch_size = x.size(0)
+ x = x.view(batch_size, -1)
+ output = self.net(x.float())
+ return F.softmax(output)
+
+
+def train(model):
+ model.train()
+ optimizer = optim.SGD(model.parameters(), lr=0.001)
+ for epoch in range(10):
+ for i, data in enumerate(train_loader, 1):
+ image, j = data
+ optimizer.zero_grad()
+ output = model(image)
+ loss = F.cross_entropy(
+ output, Variable(torch.LongTensor([train_y[i - 1]])))
+ loss.backward()
+ optimizer.step()
+ return model
+
+
+def get_test_point():
+ return [np.random.randint(255) for _ in range(784)]
+
+
+#Define a dataloader to read data
+class TrainingDataset(data.Dataset):
+ def __init__(self, data, label):
+ self.imgs = data
+ self.classes = label
+
+ def __getitem__(self, index):
+ img = self.imgs[index]
+ label = self.classes[index]
+ img = torch.Tensor(img)
+ return img, torch.Tensor(label)
+
+
+if __name__ == "__main__":
+ pos_label = 3
+ try:
+ clipper_conn = create_docker_connection(
+ cleanup=True, start_clipper=True)
+
+ train_path = os.path.join(cur_dir, "data/train.data")
+ train_x, train_y = parsedata(train_path, pos_label)
+ train_x = normalize(train_x)
+ train_loader = TrainingDataset(train_x, train_y)
+
+ try:
+ clipper_conn.register_application(app_name, "integers",
+ "default_pred", 100000)
+ time.sleep(1)
+
+ addr = clipper_conn.get_query_addr()
+ response = requests.post(
+ "http://%s/%s/predict" % (addr, app_name),
+ headers=headers,
+ data=json.dumps({
+ 'input': get_test_point()
+ }))
+ result = response.json()
+ if response.status_code != requests.codes.ok:
+ print("Error: %s" % response.text)
+ raise BenchmarkException("Error creating app %s" % app_name)
+
+ version = 1
+
+ model = BasicNN()
+ nn_model = train(model)
+
+ deploy_and_test_model(
+ clipper_conn, nn_model, version, link_model=True)
+
+ app_and_model_name = "easy-register-app-model"
+ create_endpoint(clipper_conn, app_and_model_name, "integers",
+ predict, nn_model)
+ test_model(clipper_conn, app_and_model_name, 1)
+
+ except BenchmarkException as e:
+ log_clipper_state(clipper_conn)
+ logger.exception("BenchmarkException")
+ clipper_conn = create_docker_connection(
+ cleanup=True, start_clipper=False)
+ sys.exit(1)
+ else:
+ clipper_conn = create_docker_connection(
+ cleanup=True, start_clipper=False)
+ except Exception as e:
+ logger.exception("Exception")
+ clipper_conn = create_docker_connection(
+ cleanup=True, start_clipper=False)
+ sys.exit(1)