ONNX support, bug fix, user friendly APIs (#10)
* bumped version through version.py file

* gitignore for built files

* Readme.md linked to example repo

* assets for onnx tests

* onnx support

* user friendly APIs

* tests for onnx support and new user friendly APIs

* minor nit

* testing against redisai edge

* pandas to test requirements
Sherin Thomas authored and lantiga committed Jun 23, 2019
1 parent ab579ae commit 6d4174c
Showing 13 changed files with 276 additions and 129 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -8,7 +8,7 @@ jobs:
  build:
    docker:
      - image: circleci/python:3.7.1
-     - image: redisai/redisai:latest
+     - image: redisai/redisai:edge

    working_directory: ~/repo
3 changes: 2 additions & 1 deletion .gitignore
@@ -1,4 +1,5 @@
.project
.pydevproject
*.pyc
-.venv/
+.venv/
+redisai.egg-info
1 change: 1 addition & 0 deletions README.md
@@ -20,5 +20,6 @@
$ pip install redisai
```

+The [RedisAI example repo](https://github.com/RedisAI/redisai-examples) shows a few examples built with redisai-py under the `python_client` section.
21 changes: 21 additions & 0 deletions redisai/__init__.py
@@ -1 +1,22 @@
+from .version import __version__
from .client import (Client, Tensor, BlobTensor, DType, Device, Backend)
+
+
+def save_model(*args, **kwargs):
+    """
+    Importing inside to avoid loading TF/PyTorch/ONNX
+    into the scope unnecessarily. This function wraps the
+    internal save-model utility to make it user friendly.
+    """
+    from .model import Model
+    Model.save(*args, **kwargs)
+
+
+def load_model(*args, **kwargs):
+    """
+    Importing inside to avoid loading TF/PyTorch/ONNX
+    into the scope unnecessarily. This function wraps the
+    internal load-model utility to make it user friendly.
+    """
+    from .model import Model
+    return Model.load(*args, **kwargs)
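These two helpers are the new user-facing entry points. A minimal usage sketch, assuming PyTorch is installed and a RedisAI server runs on localhost:6379 (the module, key names, and file name here are illustrative, not from the commit):

```python
# Sketch only: trace a toy module, save it natively, then load the
# raw binary back for modelset.
import torch
import redisai

net = torch.jit.trace(torch.nn.Linear(2, 1), torch.ones(1, 2))
redisai.save_model(net, 'linear.pt')    # dispatched to the Torch serializer

blob = redisai.load_model('linear.pt')  # returns the binary graph data
con = redisai.Client()
con.modelset('linear', redisai.Backend.torch, redisai.Device.cpu, blob)
```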
41 changes: 20 additions & 21 deletions redisai/client.py
@@ -15,31 +15,31 @@


class Device(Enum):
-    cpu = 'cpu'
-    gpu = 'gpu'
+    cpu = 'CPU'
+    gpu = 'GPU'


class Backend(Enum):
-    tf = 'tf'
-    torch = 'torch'
-    onnx = 'ort'
+    tf = 'TF'
+    torch = 'TORCH'
+    onnx = 'ONNX'


class DType(Enum):
-    float = 'float'
-    double = 'double'
-    int8 = 'int8'
-    int16 = 'int16'
-    int32 = 'int32'
-    int64 = 'int64'
-    uint8 = 'uint8'
-    uint16 = 'uint16'
-    uint32 = 'uint32'
-    uint64 = 'uint64'
+    float = 'FLOAT'
+    double = 'DOUBLE'
+    int8 = 'INT8'
+    int16 = 'INT16'
+    int32 = 'INT32'
+    int64 = 'INT64'
+    uint8 = 'UINT8'
+    uint16 = 'UINT16'
+    uint32 = 'UINT32'
+    uint64 = 'UINT64'

    # aliases
-    float32 = 'float'
-    float64 = 'double'
+    float32 = 'FLOAT'
+    float64 = 'DOUBLE'


def _str_or_strlist(v):
@@ -54,7 +54,7 @@ def _convert_to_num(dt, arr):
        if isinstance(obj, list):
            _convert_to_num(obj)
        else:
-            if dt in (DType.float, DType.double):
+            if dt in (DType.float.value, DType.double.value):
                arr[ix] = float(obj)
            else:
                arr[ix] = int(obj)
@@ -159,10 +159,9 @@ def to_numpy(self):

    @staticmethod
    def _to_numpy_type(t):
-        t = t.lower()
        mm = {
-            'float': 'float32',
-            'double': 'float64'
+            'FLOAT': 'float32',
+            'DOUBLE': 'float64'
        }
        if t in mm:
            return mm[t]
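The switch to uppercase enum values matters because the values are passed verbatim as RedisAI command arguments (e.g. `AI.TENSORSET t FLOAT 2 VALUES 2 3` expects `FLOAT`, not `float`). A small sketch, assuming a RedisAI server on localhost and the `Tensor(dtype, shape, value)` constructor this client exposes:

```python
# DType.float now carries the literal 'FLOAT' the server expects.
from redisai import Client, DType, Tensor

con = Client()  # assumes RedisAI on localhost:6379
con.tensorset('t', Tensor(DType.float, [2], [2, 3]))
print(con.tensorget('t').value)  # [2.0, 3.0]
```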
88 changes: 57 additions & 31 deletions redisai/model.py
@@ -1,26 +1,34 @@
import pickle
import os
import warnings

from .client import Device, Backend
+import sys

try:
    import tensorflow as tf
except (ModuleNotFoundError, ImportError):
-    pass  # that's Okey if you don't have TF
+    pass

try:
    import torch
except (ModuleNotFoundError, ImportError):
-    pass  # it's Okey if you don't have PT either
+    pass
+
+try:
+    import onnx
+except (ModuleNotFoundError, ImportError):
+    pass
+
+try:
+    import skl2onnx
+    import sklearn
+except (ModuleNotFoundError, ImportError):
+    pass


class Model:

    __slots__ = ['graph', 'backend', 'device', 'inputs', 'outputs']
-    def __init__(self, path, device=Device.cpu, inputs=None, outputs=None):
+
+    def __init__(self, path, device=None, inputs=None, outputs=None):
        """
        Declare a model suitable for passing to modelset
        :param path: Filepath from where the stored model can be read
@@ -37,9 +45,9 @@ def __init__(self, path, device=Device.cpu, inputs=None, outputs=None):
        raise NotImplementedError('Instance creation is not implemented yet')

    @classmethod
-    def save(cls, obj, path: str, input=None, output=None, as_native=True):
+    def save(cls, obj, path: str, input=None, output=None, as_native=True, prototype=None):
        """
-        Infer the backend (TF/PyTorch) by inspecting the class hierarchy
+        Infer the backend (TF/PyTorch/ONNX) by inspecting the class hierarchy
        and calls the appropriate serialization utility. It is essentially a
        wrapper over the serialization mechanism of each backend
        :param path: Path to which the graph/model will be saved
@@ -54,15 +62,25 @@ def save(cls, obj, path: str, input=None, output=None, as_native=True):
        mechanism if True. If False, a custom saving utility will be called,
        which saves other information required for modelset. Defaults to True
        """
-        if issubclass(type(obj), tf.Session):
+        if 'tensorflow' in sys.modules and issubclass(type(obj), tf.Session):
            cls._save_tf_graph(obj, path, output, as_native)
-        elif issubclass(type(type(obj)), torch.jit.ScriptMeta):
+        elif 'torch' in sys.modules and issubclass(
+                type(type(obj)), torch.jit.ScriptMeta):
            # TODO Is there a better way to check this
-            cls._save_pt_graph(obj, path, as_native)
+            cls._save_torch_graph(obj, path, as_native)
+        elif 'onnx' in sys.modules and issubclass(
+                type(obj), onnx.onnx_ONNX_RELEASE_ml_pb2.ModelProto):
+            cls._save_onnx_graph(obj, path, as_native)
+        elif 'skl2onnx' in sys.modules and issubclass(
+                type(obj), sklearn.base.BaseEstimator):
+            cls._save_sklearn_graph(obj, path, as_native, prototype)
        else:
-            raise TypeError(('Invalid Object. '
-                             'Need traced graph or scripted graph from PyTorch or '
-                             'Session object from Tensorflow'))
+            message = ("Could not find the required dependency to export the graph object. "
+                       "`save_model` relies on the serialization mechanism provided by the"
+                       " supported backends such as Tensorflow, PyTorch, ONNX or skl2onnx. "
+                       "Please install the package required for serializing your graph. "
+                       "For more information, check out the redisai-py documentation.")
+            raise RuntimeError(message)

    @classmethod
    def _save_tf_graph(cls, sess, path, output, as_native):
@@ -81,10 +99,10 @@ def _save_tf_graph(cls, sess, path, output, as_native):
            raise NotImplementedError('Saving non-native graph is not supported yet')

    @classmethod
-    def _save_pt_graph(cls, graph, path, as_native):
+    def _save_torch_graph(cls, graph, path, as_native):
        # TODO how to handle the cpu/gpu
        if as_native:
-            if graph.training == True:
+            if graph.training is True:
                warnings.warn(
                    'Graph is in training mode. Converting to evaluation mode')
                graph.eval()
@@ -93,25 +111,33 @@ def _save_pt_graph(cls, graph, path, as_native):
        else:
            raise NotImplementedError('Saving non-native graph is not supported yet')

-    @staticmethod
-    def _get_filled_dict(graph, backend, input=None, output=None):
-        return {
-            'graph': graph,
-            'backend': backend,
-            'input': input,
-            'output': output}
+    @classmethod
+    def _save_onnx_graph(cls, graph, path, as_native):
+        if as_native:
+            with open(path, 'wb') as f:
+                f.write(graph.SerializeToString())
+        else:
+            raise NotImplementedError('Saving non-native graph is not supported yet')

-    @staticmethod
-    def _write_custom_model(outdict, path):
-        with open(path, 'wb') as file:
-            pickle.dump(outdict, file)
+    @classmethod
+    def _save_sklearn_graph(cls, graph, path, as_native, prototype):
+        if not as_native:
+            raise NotImplementedError('Saving non-native graph is not supported yet')
+        if hasattr(prototype, 'shape') and hasattr(prototype, 'dtype'):
+            datatype = skl2onnx.common.data_types.guess_data_type(prototype)
+            serialized = skl2onnx.convert_sklearn(graph, initial_types=datatype)
+            cls._save_onnx_graph(serialized, path, as_native)
+        else:
+            raise TypeError(
+                "Serializing a scikit-learn model requires the shape and dtype"
+                " of the input data, which will be inferred from the `prototype`"
+                " parameter. It has to be a valid `numpy.ndarray` with the shape of your input")

    @classmethod
-    def load(cls, path:str):
+    def load(cls, path: str):
        """
        Return the binary data if saved with `as_native` otherwise return the dict
-        that contains binary graph/model on `graph` key. Check `_get_filled_dict`
-        for more details.
+        that contains the binary graph/model on the `graph` key (not implemented yet).
        :param path: File path from where the native model or the rai models are saved
        """
        with open(path, 'rb') as f:
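For the new scikit-learn branch, `prototype` is a representative input array; `skl2onnx.common.data_types.guess_data_type` reads its shape and dtype to build the ONNX input signature. A hedged sketch (the model, data, and file name are illustrative; scikit-learn and skl2onnx must be installed):

```python
# Illustrative only: train a toy regressor and export it through skl2onnx.
import numpy as np
from sklearn.linear_model import LinearRegression
from redisai import save_model

X = np.random.rand(100, 13).astype(np.float32)
y = np.random.rand(100).astype(np.float32)
reg = LinearRegression().fit(X, y)

# one representative row: shape (1, 13), dtype float32
save_model(reg, 'regression.onnx', prototype=X[:1])
```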
5 changes: 5 additions & 0 deletions redisai/version.py
@@ -0,0 +1,5 @@
+# Store the version here so:
+# 1) we don't load dependencies by storing it in __init__.py
+# 2) we can import it in setup.py for the same reason
+# 3) we can import it into your module
+__version__ = '0.3.0'
3 changes: 2 additions & 1 deletion setup.py
@@ -2,10 +2,11 @@
#!/usr/bin/env python
from setuptools import setup, find_packages

+exec(open('redisai/version.py').read())

setup(
    name='redisai',
-    version='0.2.0',
+    version=__version__,  # comes from redisai/version.py

    description='RedisAI Python Client',
    url='http://github.com/RedisAI/redisai-py',
5 changes: 4 additions & 1 deletion test-requirements.txt
@@ -1,3 +1,6 @@
numpy
torch
-tensorflow
+tensorflow
+onnx
+skl2onnx
+pandas
37 changes: 29 additions & 8 deletions test/test.py
@@ -2,6 +2,7 @@
import numpy as np
import os.path
from redisai import Client, DType, Backend, Device, Tensor, BlobTensor
+from redisai import load_model
from redis.exceptions import ResponseError


@@ -44,14 +45,11 @@ def test_numpy_tensor(self):
        self.assertEqual([2, 3], values)

    def test_run_tf_model(self):
-        model = os.path.join(MODEL_DIR, 'graph.pb')
-        bad_model = os.path.join(MODEL_DIR, 'pt-minimal.pt')
+        model_path = os.path.join(MODEL_DIR, 'graph.pb')
+        bad_model_path = os.path.join(MODEL_DIR, 'pt-minimal.pt')

-        with open(model, 'rb') as f:
-            model_pb = f.read()
-
-        with open(bad_model, 'rb') as f:
-            wrong_model_pb = f.read()
+        model_pb = load_model(model_path)
+        wrong_model_pb = load_model(bad_model_path)

        con = self.get_client()
        con.modelset('m', Backend.tf, Device.cpu, model_pb,
@@ -96,5 +94,28 @@ def bar(a, b):
        tensor = con.tensorget('c')
        self.assertEqual([4, 6], tensor.value)

+    def test_run_onnxml_model(self):
+        mlmodel_path = os.path.join(MODEL_DIR, 'boston.onnx')
+        onnxml_model = load_model(mlmodel_path)
+        con = self.get_client()
+        con.modelset("onnx_model", Backend.onnx, Device.cpu, onnxml_model)
+        tensor = BlobTensor.from_numpy(np.ones((1, 13), dtype=np.float32))
+        con.tensorset("input", tensor)
+        con.modelrun("onnx_model", ["input"], ["output"])
+        outtensor = con.tensorget("output")
+        self.assertEqual(int(outtensor.value[0]), 24)
+
+    def test_run_onnxdl_model(self):
+        # A PyTorch model exported to ONNX that squares its input
+        dlmodel_path = os.path.join(MODEL_DIR, 'findsquare.onnx')
+        onnxdl_model = load_model(dlmodel_path)
+        con = self.get_client()
+        con.modelset("onnx_model", Backend.onnx, Device.cpu, onnxdl_model)
+        tensor = BlobTensor.from_numpy(np.array((2, 3), dtype=np.float32))
+        con.tensorset("input", tensor)
+        con.modelrun("onnx_model", ["input"], ["output"])
+        outtensor = con.tensorget("output")
+        self.assertEqual(outtensor.value, [4.0, 9.0])


-# TODO: image/blob tests; more numpy tests..
+# TODO: image/blob tests; more numpy tests..
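For context, a `findsquare.onnx`-style fixture could plausibly be generated with `torch.onnx.export`; this is a hypothetical sketch, since the commit ships the test assets pre-built under `MODEL_DIR`:

```python
# Hypothetical: export a module that squares its input to ONNX.
import torch


class FindSquare(torch.nn.Module):
    def forward(self, x):
        return x * x


dummy = torch.ones(2, dtype=torch.float32)
torch.onnx.export(FindSquare(), dummy, 'findsquare.onnx')
```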