
Commit

pipeline
hhsecond committed May 28, 2020
1 parent 9ffe297 commit 168c4a5
Showing 3 changed files with 199 additions and 94 deletions.
252 changes: 161 additions & 91 deletions redisai/client.py
@@ -3,6 +3,7 @@
import warnings

from redis import StrictRedis
from redis.client import Pipeline as RedisPipeline
import numpy as np

from . import command_builder as builder
@@ -12,78 +13,6 @@
processor = Processor()


def enable_debug(f):
@wraps(f)
def wrapper(*args):
print(*args)
return f(*args)
return wrapper


class Dag:
def __init__(self, load, persist, executor, readonly=False):
self.result_processors = []
if readonly:
if persist:
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
"have PERSISTing values")
self.commands = ['AI.DAGRUN_RO']
else:
self.commands = ['AI.DAGRUN']
if load:
if not isinstance(load, (list, tuple)):
self.commands += ["LOAD", 1, load]
else:
self.commands += ["LOAD", len(load), *load]
if persist:
if not isinstance(persist, (list, tuple)):
self.commands += ["PERSIST", 1, persist, '|>']
else:
self.commands += ["PERSIST", len(persist), *persist, '|>']
elif load:
self.commands.append('|>')
self.executor = executor

def tensorset(self,
key: AnyStr,
tensor: Union[np.ndarray, list, tuple],
shape: Sequence[int] = None,
dtype: str = None) -> Any:
args = builder.tensorset(key, tensor, shape, dtype)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(bytes.decode)
return self

def tensorget(self,
key: AnyStr, as_numpy: bool = True,
meta_only: bool = False) -> Any:
args = builder.tensorget(key, as_numpy, meta_only)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(partial(processor.tensorget,
as_numpy=as_numpy,
meta_only=meta_only))
return self

def modelrun(self,
key: AnyStr,
inputs: Union[AnyStr, List[AnyStr]],
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
args = builder.modelrun(key, inputs, outputs)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(bytes.decode)
return self

def run(self):
results = self.executor(*self.commands)
out = []
for res, fn in zip(results, self.result_processors):
out.append(fn(res))
return out


class Client(StrictRedis):
"""
Redis client built specifically for the RedisAI module. It takes all the necessary
@@ -96,20 +25,47 @@ class Client(StrictRedis):
debug : bool
If debug mode is ON, then each command that is sent to the server is
printed to the terminal
enable_postprocess : bool
Flag to enable post-processing. If enabled, all bytestring returns are
recursively converted to Python strings and key-value pairs are converted
to dictionaries. Note that this flag does not apply to the pipeline()
function, since a pipeline may contain native Redis commands alongside
RedisAI commands.
Example
-------
>>> from redisai import Client
>>> con = Client(host='localhost', port=6379)
"""
def __init__(self, debug=False, *args, **kwargs):
def __init__(self, debug=False, enable_postprocess=True, *args, **kwargs):
super().__init__(*args, **kwargs)
if debug:
self.execute_command = enable_debug(super().execute_command)
self.enable_postprocess = enable_postprocess
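# --- Illustrative sketch, not part of this commit: what enable_postprocess changes.
# Assumes a Redis server with the RedisAI module loaded on localhost:6379.
import numpy as np
from redisai import Client

con = Client(host='localhost', port=6379)            # enable_postprocess=True by default
print(con.tensorset('x', np.array([1.0, 2.0])))      # postprocessed reply, e.g. 'OK'

raw = Client(host='localhost', port=6379, enable_postprocess=False)
print(raw.tensorset('x', np.array([1.0, 2.0])))      # raw server reply, e.g. b'OK'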

def pipeline(self, transaction: bool = True, shard_hint=None) -> 'Pipeline':
"""
It follows the same pipeline implementation as the native redis client, but also
enables access to RedisAI operations. This function is experimental in the
current release.
Example
-------
>>> pipe = con.pipeline(transaction=False)
>>> pipe = pipe.set('nativeKey', 1)
>>> pipe = pipe.tensorset('redisaiKey', np.array([1, 2]))
>>> pipe.execute()
[True, b'OK']
"""
return Pipeline(self.enable_postprocess,
self.connection_pool,
self.response_callbacks,
transaction=transaction, shard_hint=shard_hint)
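# --- Illustrative sketch, not part of this commit: queueing native Redis and
# RedisAI commands in one pipeline. Postprocessing is disabled inside pipelines,
# so AI replies other than tensorget come back as raw bytes. Key names and values
# are made up; assumes a RedisAI-enabled server on localhost:6379.
import numpy as np
from redisai import Client

con = Client(host='localhost', port=6379)
pipe = con.pipeline(transaction=False)
pipe.set('nativeKey', 1)                         # plain redis-py command
pipe.tensorset('redisaiKey', np.array([1, 2]))   # RedisAI command queued alongside it
pipe.tensorget('redisaiKey')                     # the only reply that gets postprocessed
print(pipe.execute())                            # e.g. [True, b'OK', array([1, 2])]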

def dag(self, load: Sequence = None, persist: Sequence = None,
readonly: bool = False) -> Dag:
""" It returns a DAG object on which other DAG-allowed operations can be called. For
readonly: bool = False) -> 'Dag':
"""
It returns a DAG object on which other DAG-allowed operations can be called. For
more details about DAG in RedisAI, refer to the RedisAI documentation.
Parameters
@@ -141,7 +97,7 @@ def dag(self, load: Sequence = None, persist: Sequence = None,
>>> # You can even chain the operations
>>> result = dag.tensorset(**akwargs).modelrun(**bkwargs).tensorget(**ckwargs).run()
"""
return Dag(load, persist, self.execute_command, readonly)
return Dag(load, persist, self.execute_command, readonly, self.enable_postprocess)

def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
"""
@@ -168,7 +124,7 @@ def loadbackend(self, identifier: AnyStr, path: AnyStr) -> str:
"""
args = builder.loadbackend(identifier, path)
res = self.execute_command(*args)
return processor.loadbackend(res)
return res if not self.enable_postprocess else processor.loadbackend(res)

def modelset(self,
key: AnyStr,
@@ -227,7 +183,7 @@ def modelset(self,
args = builder.modelset(key, backend, device, data,
batch, minbatch, tag, inputs, outputs)
res = self.execute_command(*args)
return processor.modelset(res)
return res if not self.enable_postprocess else processor.modelset(res)

def modelget(self, key: AnyStr, meta_only=False) -> dict:
"""
@@ -253,7 +209,7 @@ def modelget(self, key: AnyStr, meta_only=False) -> dict:
"""
args = builder.modelget(key, meta_only)
res = self.execute_command(*args)
return processor.modelget(res)
return res if not self.enable_postprocess else processor.modelget(res)

def modeldel(self, key: AnyStr) -> str:
"""
@@ -276,7 +232,7 @@ def modeldel(self, key: AnyStr) -> str:
"""
args = builder.modeldel(key)
res = self.execute_command(*args)
return processor.modeldel(res)
return res if not self.enable_postprocess else processor.modeldel(res)

def modelrun(self,
key: AnyStr,
@@ -318,7 +274,7 @@ def modelrun(self,
"""
args = builder.modelrun(key, inputs, outputs)
res = self.execute_command(*args)
return processor.modelrun(res)
return res if not self.enable_postprocess else processor.modelrun(res)

def modelscan(self) -> List[List[AnyStr]]:
"""
@@ -340,7 +296,7 @@ def modelscan(self) -> List[List[AnyStr]]:
"in the future without any notice", UserWarning)
args = builder.modelscan()
res = self.execute_command(*args)
return processor.modelscan(res)
return res if not self.enable_postprocess else processor.modelscan(res)

def tensorset(self,
key: AnyStr,
@@ -376,7 +332,7 @@ def tensorset(self,
"""
args = builder.tensorset(key, tensor, shape, dtype)
res = self.execute_command(*args)
return processor.tensorset(res)
return res if not self.enable_postprocess else processor.tensorset(res)

def tensorget(self,
key: AnyStr, as_numpy: bool = True,
@@ -412,7 +368,8 @@ def tensorget(self,
"""
args = builder.tensorget(key, as_numpy, meta_only)
res = self.execute_command(*args)
return processor.tensorget(res, as_numpy, meta_only)
return res if not self.enable_postprocess else processor.tensorget(res,
as_numpy, meta_only)

def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
"""
@@ -456,7 +413,7 @@ def scriptset(self, key: AnyStr, device: str, script: str, tag: AnyStr = None) -> str:
"""
args = builder.scriptset(key, device, script, tag)
res = self.execute_command(*args)
return processor.scriptset(res)
return res if not self.enable_postprocess else processor.scriptset(res)

def scriptget(self, key: AnyStr, meta_only=False) -> dict:
"""
@@ -481,7 +438,7 @@ def scriptget(self, key: AnyStr, meta_only=False) -> dict:
"""
args = builder.scriptget(key, meta_only)
res = self.execute_command(*args)
return processor.scriptget(res)
return res if not self.enable_postprocess else processor.scriptget(res)

def scriptdel(self, key: AnyStr) -> str:
"""
@@ -504,7 +461,7 @@ def scriptdel(self, key: AnyStr) -> str:
"""
args = builder.scriptdel(key)
res = self.execute_command(*args)
return processor.scriptdel(res)
return res if not self.enable_postprocess else processor.scriptdel(res)

def scriptrun(self,
key: AnyStr,
@@ -540,7 +497,7 @@ def scriptrun(self,
"""
args = builder.scriptrun(key, function, inputs, outputs)
res = self.execute_command(*args)
return processor.scriptrun(res)
return res if not self.enable_postprocess else processor.scriptrun(res)

def scriptscan(self) -> List[List[AnyStr]]:
"""
@@ -561,7 +518,7 @@ def scriptscan(self) -> List[List[AnyStr]]:
"in the future without any notice", UserWarning)
args = builder.scriptscan()
res = self.execute_command(*args)
return processor.scriptscan(res)
return res if not self.enable_postprocess else processor.scriptscan(res)

def infoget(self, key: AnyStr) -> dict:
"""
@@ -590,7 +547,7 @@ def infoget(self, key: AnyStr) -> dict:
"""
args = builder.infoget(key)
res = self.execute_command(*args)
return processor.infoget(res)
return res if not self.enable_postprocess else processor.infoget(res)

def inforeset(self, key: AnyStr) -> str:
"""
@@ -613,4 +570,117 @@ def inforeset(self, key: AnyStr) -> str:
"""
args = builder.inforeset(key)
res = self.execute_command(*args)
return processor.inforeset(res)
return res if not self.enable_postprocess else processor.inforeset(res)


class Pipeline(RedisPipeline, Client):
def __init__(self, enable_postprocess, *args, **kwargs):
warnings.warn("Pipeling AI commands through this client is experimental.",
UserWarning)
self.enable_postprocess = False
if enable_postprocess:
warnings.warn("Postprocessing is enabled but not allowed in pipelines."
"Disable postprocessing to remove this warning.", UserWarning)
self.tensorget_processors = []
super().__init__(*args, **kwargs)

def dag(self, *args, **kwargs):
raise RuntimeError("Pipeline object doesn't allow DAG creation currently")

def tensorget(self, key, as_numpy=True, meta_only=False):
self.tensorget_processors.append(partial(processor.tensorget,
as_numpy=as_numpy,
meta_only=meta_only))
return super().tensorget(key, as_numpy, meta_only)

def _execute_transaction(self, *args, **kwargs):
res = super()._execute_transaction(*args, **kwargs)
for i in range(len(res)):
# a TENSORGET reply is a list of at least 4 elements, even when meta_only=True
if isinstance(res[i], list) and len(res[i]) >= 4:
res[i] = self.tensorget_processors.pop(0)(res[i])
return res

def _execute_pipeline(self, *args, **kwargs):
res = super()._execute_pipeline(*args, **kwargs)
for i in range(len(res)):
# a TENSORGET reply is a list of at least 4 elements, even when meta_only=True
if isinstance(res[i], list) and len(res[i]) >= 4:
res[i] = self.tensorget_processors.pop(0)(res[i])
return res
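# --- Illustrative note, not part of this commit: why the length check above works.
# A raw AI.TENSORGET reply is a flat field/value list (layout assumed here for
# illustration), e.g.
#   [b'dtype', b'INT64', b'shape', [2], b'values', [1, 2]]   # full reply
#   [b'dtype', b'INT64', b'shape', [2]]                       # META-only reply
# Even the META-only form has at least 4 elements, so any queued reply that is a
# list of length >= 4 is treated as a tensorget reply and handed to the next
# queued tensorget processor.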


class Dag:
def __init__(self, load, persist, executor, readonly=False, postprocess=True):
self.result_processors = []
self.enable_postprocess = postprocess
if readonly:
if persist:
raise RuntimeError("READONLY requests cannot write (duh!) and should not "
"have PERSISTing values")
self.commands = ['AI.DAGRUN_RO']
else:
self.commands = ['AI.DAGRUN']
if load:
if not isinstance(load, (list, tuple)):
self.commands += ["LOAD", 1, load]
else:
self.commands += ["LOAD", len(load), *load]
if persist:
if not isinstance(persist, (list, tuple)):
self.commands += ["PERSIST", 1, persist, '|>']
else:
self.commands += ["PERSIST", len(persist), *persist, '|>']
elif load:
self.commands.append('|>')
self.executor = executor

def tensorset(self,
key: AnyStr,
tensor: Union[np.ndarray, list, tuple],
shape: Sequence[int] = None,
dtype: str = None) -> Any:
args = builder.tensorset(key, tensor, shape, dtype)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(bytes.decode)
return self

def tensorget(self,
key: AnyStr, as_numpy: bool = True,
meta_only: bool = False) -> Any:
args = builder.tensorget(key, as_numpy, meta_only)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(partial(processor.tensorget,
as_numpy=as_numpy,
meta_only=meta_only))
return self

def modelrun(self,
key: AnyStr,
inputs: Union[AnyStr, List[AnyStr]],
outputs: Union[AnyStr, List[AnyStr]]) -> Any:
args = builder.modelrun(key, inputs, outputs)
self.commands.extend(args)
self.commands.append("|>")
self.result_processors.append(bytes.decode)
return self

def run(self):
results = self.executor(*self.commands)
if self.enable_postprocess:
out = []
for res, fn in zip(results, self.result_processors):
out.append(fn(res))
else:
out = results
return out
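# --- Illustrative sketch, not part of this commit: the flat command list a chained
# DAG builds before run() hands it to the executor. Key and model names are made up.
#
#   dag = con.dag(load=['a', 'b'], persist=['out'])
#   dag.modelrun('mymodel', inputs=['a', 'b'], outputs=['out']).tensorget('out')
#
# would produce, roughly:
#
#   ['AI.DAGRUN', 'LOAD', 2, 'a', 'b', 'PERSIST', 1, 'out', '|>',
#    'AI.MODELRUN', 'mymodel', 'INPUTS', 'a', 'b', 'OUTPUTS', 'out', '|>',
#    'AI.TENSORGET', 'out', ..., '|>']
#
# run() then pairs each reply with the result processor queued for that step
# (bytes.decode for tensorset/modelrun, processor.tensorget for tensorget).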


def enable_debug(f):
@wraps(f)
def wrapper(*args):
print(*args)
return f(*args)
return wrapper