Skip to content

Commit

Permalink
feat: impl predicate->onnx ir compilation whole process.
Browse files Browse the repository at this point in the history
  • Loading branch information
lovelynewlife committed Sep 21, 2023
1 parent 9598e70 commit 39874ba
Show file tree
Hide file tree
Showing 14 changed files with 778 additions and 255 deletions.
108 changes: 60 additions & 48 deletions examples/pandas/prediction_query.ipynb

Large diffs are not rendered by default.

194 changes: 176 additions & 18 deletions onnxoptimizer/query/onnx/compile.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
from onnx.compose import merge_graphs, merge_models

from onnxoptimizer.query.pandas.core.computation.ops import ONNXPredicate, ONNXFuncNode, BinOp, Term, UnaryOp
import numpy as np
import onnx.numpy_helper
from onnx import ValueInfoProto
from onnx.compose import merge_models

from onnxoptimizer.query.onnx.joint import merge_project_models_wrap
from onnxoptimizer.query.onnx.joint.fragment import ModelFragment, OpModelFragment, TermModelFragment
from onnxoptimizer.query.onnx.joint.merge_expr import merge_models_binary_wrap
from onnxoptimizer.query.pandas.core.computation.ops import (ONNXPredicate, ONNXFuncNode,
BinOp, Term, UnaryOp, Constant)
from onnxoptimizer.query.types.mapper import numpy_onnx_tensor_type_map

ONNX_CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=")
_onnx_cmp_ops_nodes = (
Expand Down Expand Up @@ -51,27 +59,56 @@
_onnx_unary_ops_map = dict(zip(ONNX_UNARY_OPS_SYMS, _onnx_unary_ops_nodes))


def merge_fragment_with_unary_op(
        m: ModelFragment,
        op: OpModelFragment,
):
    """Chain a unary-operator fragment onto the output of another fragment.

    The default output of ``m`` is wired to the default input of ``op`` via
    ``onnx.compose.merge_models``.

    Args:
        m: Fragment whose default output feeds the operator.
        op: Operator fragment that consumes that output.

    Returns:
        An ``OpModelFragment`` wrapping the merged model and carrying the
        operator fragment's ``return_type`` and ``op``.
    """
    # Single (producer output name, consumer input name) connection.
    wiring = [(m.get_default_output().name, op.get_default_input().name)]
    merged = merge_models(m.model, op.model, io_map=wiring)
    return OpModelFragment(merged, op.return_type, op.op)


class ONNXPredicateCompiler:
    """Compile a predicate expression tree into one ONNX model fragment.

    ``compile`` dispatches on the node's class name to the matching
    ``compile_<ClassName>`` method. Each of those returns a fragment wrapping
    an ONNX model; operator nodes build a one-node graph and merge it with
    the fragments compiled from their operands.
    """

    helper = onnx.helper
    numpy_helper = onnx.numpy_helper

    # ONNX operator type used to expose an initializer as a graph output.
    IDENTITY = "Identity"

    def __init__(self, env):
        """Create a compiler bound to ``env`` with a fresh temp-name counter."""
        self.env = env
        # Shared by get_temp_name() and get_temp_name_by_value(), so every
        # generated name is unique within this compiler instance.
        self.temps = 0

    def compile_save(self, node, path=None):
        """Compile ``node`` and write the resulting ONNX model to ``path``.

        Args:
            node: Root of the expression tree to compile.
            path: Destination handed to ``onnx.save``. Required in practice;
                the ``None`` default is kept only for signature compatibility.

        Raises:
            ValueError: If ``path`` is ``None``.
        """
        # Fail before doing any compilation work instead of letting
        # onnx.save choke on a missing destination afterwards.
        if path is None:
            raise ValueError("compile_save() requires a destination path")

        fragment = self.compile(node)
        onnx.save(fragment.model, path)

    def compile(self, node) -> ModelFragment:
        """Compile ``node`` by dispatching on its class name.

        Raises:
            AttributeError: If no ``compile_<ClassName>`` method exists for
                the node's type.
        """
        compile_method = f"compile_{type(node).__name__}"
        compile_func = getattr(self, compile_method)

        return compile_func(node)

    def get_temp_name_by_value(self, value):
        """Return a unique name of the form ``<type name of value>_<counter>``."""
        name = f"{type(value).__name__}_{self.temps}"
        self.temps += 1

        return name

    def get_temp_name(self):
        """Return a unique name of the form ``temp_<counter>``."""
        name = f"temp_{self.temps}"
        self.temps += 1

        return name

    @staticmethod
    def _prefixed_copy(tensor, prefix):
        """Return a copy of ``tensor`` whose name is ``prefix + tensor.name``."""
        renamed = ValueInfoProto()
        renamed.CopyFrom(tensor)
        renamed.name = prefix + tensor.name
        return renamed

    @staticmethod
    def _tensor_dims(value_info):
        """Return the dims of ``value_info`` as a list, or ``None`` if unset.

        Each entry is the dim's integer value, its symbolic name, or ``None``
        when the dim carries neither.
        """
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            return None

        dims = []
        for dim in tensor_type.shape.dim:
            # dim_value and dim_param live in the oneof group named "value".
            which = dim.WhichOneof("value")
            dims.append(getattr(dim, which) if which else None)
        return dims

    def compile_ONNXPredicate(self, node: ONNXPredicate):
        """Compile a predicate node exactly like a binary operator node."""
        return self.compile_BinOp(node)

    def compile_ONNXFuncNode(self, node: ONNXFuncNode) -> ModelFragment:
        """Wrap the node's own model and return type in a ``ModelFragment``."""
        return ModelFragment(node.model, node.return_type)

    def compile_BinOp(self, node: BinOp):
        """Compile a binary operator and merge it with both operand fragments.

        Builds a one-node graph whose two inputs are prefixed copies of the
        operands' default outputs, then hands both operand models and that
        graph to ``merge_models_binary_wrap`` together with the prefixes.

        Returns:
            An ``OpModelFragment`` carrying the merged model, the node's
            return type and its operator.
        """
        lhs = node.lhs
        rhs = node.rhs
        op = node.op

        lhs_ir = self.compile(lhs)
        rhs_ir = self.compile(rhs)

        # A distinct prefix per side keeps the operand names apart; the same
        # prefixes are passed to the merge helper below.
        lhs_prefix = self.get_temp_name()
        rhs_prefix = self.get_temp_name()

        lhs_input_tensor = self._prefixed_copy(lhs_ir.get_default_output(), lhs_prefix)
        rhs_input_tensor = self._prefixed_copy(rhs_ir.get_default_output(), rhs_prefix)

        output_type = numpy_onnx_tensor_type_map[node.return_type.type]

        output_name = self.get_temp_name()

        # Try delegate shape inference job to onnx
        output_tensor = self.helper.make_tensor_value_info(
            name=output_name,
            elem_type=output_type,
            shape=[None]
        )

        op_node = self.helper.make_node(
            op_type=_onnx_binary_ops_dict[op],
            inputs=[lhs_input_tensor.name, rhs_input_tensor.name],
            outputs=[output_name]
        )

        partial_graph = self.helper.make_graph(
            nodes=[op_node],
            name=self.get_temp_name(),
            inputs=[lhs_input_tensor, rhs_input_tensor],
            outputs=[output_tensor]
        )

        partial_model = self.helper.make_model(partial_graph)

        partial_model = merge_models_binary_wrap(
            lhs_ir.model,
            rhs_ir.model,
            partial_model,
            prefix1=lhs_prefix,
            prefix2=rhs_prefix
        )

        return OpModelFragment(partial_model, node.return_type, op)

    def compile_UnaryOp(self, node: UnaryOp):
        """Compile a unary operator and chain it onto its operand's fragment.

        The operator graph reuses the operand's default output as its input
        and declares an output with the same dims and the node's return type.

        Returns:
            The ``OpModelFragment`` produced by
            ``merge_fragment_with_unary_op``.
        """
        op = node.op
        operand = self.compile(node.operand)

        output_name = self.get_temp_name()

        input_tensor = operand.get_default_output()
        input_name = input_tensor.name

        output_type = numpy_onnx_tensor_type_map[node.return_type.type]

        # make_tensor_value_info takes a sequence of dims, not the
        # TensorShapeProto stored on the input, so convert explicitly.
        output_shape = self._tensor_dims(input_tensor)

        output_tensor = self.helper.make_tensor_value_info(
            name=output_name,
            elem_type=output_type,
            shape=output_shape
        )

        op_node = self.helper.make_node(
            op_type=_onnx_unary_ops_map[op],
            inputs=[input_name],
            outputs=[output_name]
        )

        partial_graph = self.helper.make_graph(
            nodes=[op_node],
            name=self.get_temp_name(),
            inputs=[input_tensor],
            outputs=[output_tensor]
        )

        partial_model = self.helper.make_model(partial_graph)

        onnx.checker.check_model(partial_model)

        temp_frag = OpModelFragment(partial_model, node.return_type, op)

        return merge_fragment_with_unary_op(operand, temp_frag)

    def compile_as_initializer(self, node):
        """Compile ``node.value`` into a model that outputs it as a constant.

        The value is converted with ``np.array``, stored as a graph
        initializer and routed through an ``Identity`` node so that it
        appears as the graph's only output; the graph has no inputs.

        Returns:
            A ``TermModelFragment`` carrying the model and the numpy dtype of
            the converted value.
        """
        output_name = self.get_temp_name()
        value_name = self.get_temp_name_by_value(node.value)

        np_value = np.array(node.value)
        constant_value = self.numpy_helper.from_array(np_value, name=value_name)

        identity_node = self.helper.make_node(
            self.IDENTITY,
            inputs=[value_name],
            outputs=[output_name]
        )

        term_output = self.helper.make_tensor_value_info(
            name=output_name,
            elem_type=constant_value.data_type,
            shape=np_value.shape
        )

        partial_graph = self.helper.make_graph(
            nodes=[identity_node],
            name=self.get_temp_name(),
            inputs=[],
            outputs=[term_output],
            initializer=[constant_value]
        )

        partial_model = self.helper.make_model(partial_graph)

        onnx.checker.check_model(partial_model)

        return TermModelFragment(partial_model, np_value.dtype)

    def compile_Term(self, node: Term):
        """Compile a term by embedding its value as an initializer."""
        return self.compile_as_initializer(node)

    def compile_Constant(self, node: Constant):
        """Compile a constant by embedding its value as an initializer."""
        return self.compile_as_initializer(node)
21 changes: 18 additions & 3 deletions onnxoptimizer/query/onnx/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,25 @@ def __init__(self, model_obj: ModelObject):

model_data = self.model_obj.model.SerializeToString()
self.infer_session = ort.InferenceSession(model_data)

self.labels_map = {
elem.name: elem for elem in self.infer_session.get_outputs()
if elem.name.endswith("label") or elem.name.endswith("variable")
}
self.probabilities_map = {
elem.name: elem for elem in self.infer_session.get_outputs()
if elem.name.endswith("probability")
}

self.infer_input = {}

def return_type(self, which=None):
    """Return the ``type`` attribute of one entry in ``self.labels_map``.

    Args:
        which: Name of the label output to look up. When ``None``, the map
            must hold exactly one entry and that entry is used.

    Returns:
        The ``type`` attribute of the selected entry.

    Raises:
        ValueError: If ``which`` is ``None`` and the map does not hold
            exactly one entry.
        KeyError: If ``which`` is given but is not a key of the map.
    """
    if which is not None:
        return self.labels_map[which].type

    # An explicit raise (rather than ``assert``) keeps this check active
    # under ``python -O`` and avoids silently picking an arbitrary label.
    if len(self.labels_map) != 1:
        raise ValueError(
            "return_type() needs an explicit output name unless there is "
            f"exactly one label output; found {len(self.labels_map)}"
        )
    return next(iter(self.labels_map.values())).type

def set_infer_input(self, **kwargs):
    """Replace ``self.infer_input`` with the given keyword arguments.

    The arguments are stored as a dict keyed by argument name; any
    previously stored inputs are discarded.
    """
    # NOTE(review): presumably read back by ``__call__`` when running the
    # session -- confirm against the part of the class not shown here.
    self.infer_input = kwargs

Expand All @@ -27,10 +44,8 @@ def __call__(self):
}
session = self.infer_session

labels = [elem.name for elem in session.get_outputs() if elem.name.endswith("label") or elem.name.endswith("variable")]
probabilities = [elem.name for elem in session.get_outputs() if elem.name.endswith("probability")]

label_out = []
labels = list(self.labels_map.keys())
for elem in labels:
label_out.append(elem.replace("output_label", "").replace("variable", ""))

Expand Down
6 changes: 3 additions & 3 deletions onnxoptimizer/query/onnx/joint/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from onnxoptimizer.query.onnx.joint.expr_compose import merge_project_models
from onnxoptimizer.query.onnx.joint.expr_compose import merge_project_graphs
from onnxoptimizer.query.onnx.joint.merge_expr import merge_project_models_wrap
from onnxoptimizer.query.onnx.joint.merge_expr import merge_project_graphs

__all__ = [
"merge_project_models",
"merge_project_models_wrap",
"merge_project_graphs"
]
Loading

0 comments on commit 39874ba

Please sign in to comment.