Skip to content

Commit

Permalink
refactor: a unified and decoupled model udf expr optimizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
lovelynewlife committed Sep 23, 2023
1 parent 977fcc2 commit bf8e8a1
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 39 deletions.
41 changes: 6 additions & 35 deletions onnxoptimizer/query/pandas/core/computation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from onnxoptimizer.query.pandas.core.computation.visitor import (
PARSERS,
)
from onnxoptimizer.query.pandas.optimization.optimizer import MultiModelExprOptimizer
from onnxoptimizer.query.pandas.optimization.optimizer import ModelExprOptimizer


def _check_engine(engine: str | None) -> str:
Expand Down Expand Up @@ -242,7 +242,7 @@ def pandas_eval(
resolvers=resolvers,
target=target,
)
# TODO: Handle modified target and resolvers.
# TODO: May need to handle modified target and resolvers.
# cannot eval two consecutive exprs:
# the second expr cannot reference the first expr's assigner
expr_to_eval = []
Expand All @@ -258,48 +258,19 @@ def pandas_eval(
# Optimization Phase
#################

optimizer = MultiModelExprOptimizer()
optimizer = ModelExprOptimizer(env, engine, level)

expr_remain = []
if enable_opt:
expr_to_opt = []
assigners = []

for e2e in expr_to_eval:
if isinstance(e2e.terms, ONNXFuncNode) or isinstance(e2e.terms, ONNXPredicate):
expr_to_opt.append(e2e)
assigners.append(e2e.assigner)
else:
expr_remain.append(e2e)

if len(expr_to_opt) < 2:
if len(expr_to_opt) > 0:
onnx_compiler = ONNXPredicateCompiler(env)
compiled = onnx_compiler.compile(expr_to_opt[0].terms)

model_obj = ModelObject(compiled.model_partial)
compiled_term = ModelContext(model_obj)
if compiled.external_input is not None:
compiled_term.set_infer_input(**compiled.external_input)
compiled_expr = ComposedExpr(engine, env, level, compiled_term)
expr_remain.append(compiled_expr)
else:
expr_remain.extend(expr_to_opt)
else:
# do optimize phase
fused_term = optimizer.optimize(expr_to_opt)
composed_expr = ComposedExpr(engine, env, level, fused_term, assigners)
expr_remain.append(composed_expr)

expr_final_eval = optimizer.optimize(expr_to_eval)
else:
expr_remain = expr_to_eval
expr_final_eval = expr_to_eval

#################
# Evaluation Phase
#################

# evaluate un-optimized expr
for e2e in expr_remain:
for e2e in expr_final_eval:
# get our (possibly passed-in) scope
env = ensure_scope(
level + 1,
Expand Down
67 changes: 63 additions & 4 deletions onnxoptimizer/query/pandas/optimization/optimizer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,52 @@
import onnx

import onnxoptimizer
from onnxoptimizer.query.onnx.compile_expr import ONNXPredicateCompiler
from onnxoptimizer.query.onnx.joint import merge_project_models_wrap
from onnxoptimizer.query.onnx.context import ModelContext
from onnxoptimizer.query.onnx.model import ModelObject
from onnxoptimizer.query.pandas.core.computation.expr import ComposedExpr
from onnxoptimizer.query.pandas.core.computation.ops import ONNXFuncNode, ONNXPredicate


class MultiModelExprOptimizer:
def __init__(self):
self.rules = []
class ModelExprOptimizer:
def __init__(self, env, engine, level):
    """Keep the evaluation context and create the predicate compiler.

    ``env``, ``engine`` and ``level`` are stored unchanged on the instance;
    the other methods of this class pass them on to ``ComposedExpr`` when
    they build optimized expressions.
    """
    # Handle to the ``onnxoptimizer`` package itself.
    self.model_optimizer = onnxoptimizer

    # Evaluation context, stored verbatim.
    self.env, self.engine, self.level = env, engine, level

    # Compiler used by ``_optimize_predicate``; built from the same env.
    self.predicate_compiler = ONNXPredicateCompiler(self.env)

def optimize(self, expr_list):
    """Classify ``expr_list`` and return the expressions to evaluate.

    Every input expression lands in exactly one of three groups:

    * ``terms`` is an ``ONNXFuncNode`` and ``assigner`` is not ``None``:
      candidate for multi-expression fusion;
    * ``terms`` is an ``ONNXPredicate``: handed to ``_optimize_predicate``;
    * anything else: returned untouched.

    The returned list contains the untouched expressions first, then the
    fusion candidates (fused into the result of ``_optimize_multi_expr``
    when there are at least two of them, otherwise as they are), then the
    optimized predicates.
    """
    passthrough = []
    fusable = []
    assigners = []
    predicates = []

    for e2e in expr_list:
        if isinstance(e2e.terms, ONNXFuncNode) and e2e.assigner is not None:
            fusable.append(e2e)
            assigners.append(e2e.assigner)
        # ``elif`` (not a second ``if``): otherwise a fusion candidate also
        # falls through to the ``else`` branch and is returned twice.
        elif isinstance(e2e.terms, ONNXPredicate):
            predicates.append(e2e)
        else:
            passthrough.append(e2e)

    expr_opted = passthrough

    if len(fusable) < 2:
        # Fusion needs at least two expressions; keep 0 or 1 as they are.
        expr_opted.extend(fusable)
    else:
        expr_opted.extend(self._optimize_multi_expr(fusable, assigners))

    expr_opted.extend(self._optimize_predicate(predicates))

    return expr_opted

def _optimize_multi_expr(self, expr_list, assigners):
models = []
all_inputs = {}
for expr in expr_list:
Expand All @@ -38,4 +73,28 @@ def optimize(self, expr_list):

model_context.set_infer_input(**all_inputs)

return model_context
composed_expr = ComposedExpr(self.engine, self.env, self.level,
model_context, assigners)

return [composed_expr]

def _optimize_predicate(self, expr_list):
    """Compile each predicate expression into a ``ComposedExpr``.

    For every expression, ``terms`` is compiled with
    ``self.predicate_compiler``; the compiled ``model_partial`` is wrapped
    in a ``ModelObject`` and then a ``ModelContext``, whose inference
    inputs are set from ``compiled.external_input`` (or to nothing when
    that is ``None``).  Returns the resulting expressions in input order.
    """
    results = []

    for predicate_expr in expr_list:
        compiled = self.predicate_compiler.compile(predicate_expr.terms)
        context = ModelContext(ModelObject(compiled.model_partial))

        # Always call set_infer_input, with no keywords if there are no
        # external inputs.
        infer_kwargs = (
            {} if compiled.external_input is None
            else compiled.external_input
        )
        context.set_infer_input(**infer_kwargs)

        results.append(
            ComposedExpr(self.engine, self.env, self.level, context)
        )

    return results



0 comments on commit bf8e8a1

Please sign in to comment.