Skip to content

Commit

Permalink
refactor: remove redundant code of value assign, much better query ev…
Browse files Browse the repository at this point in the history
…al phases.
  • Loading branch information
lovelynewlife committed Sep 15, 2023
1 parent 91eb689 commit 06eb26f
Show file tree
Hide file tree
Showing 9 changed files with 969 additions and 884 deletions.
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.17
dev-ever
15 changes: 10 additions & 5 deletions onnxoptimizer/query/onnx/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@ def __call__(self):
}
session = self.infer_session

labels = [elem.name for elem in session.get_outputs()
if elem.name.endswith("output_label") or elem.name.endswith("variable")]
labels = [elem.name for elem in session.get_outputs()]
probabilities = [elem.name for elem in session.get_outputs() if elem.name.endswith("output_probability")]

label_out = []
for elem in labels:
label_out.append(elem.replace("output_label", "").replace("variable", ""))

infer_res = session.run(labels, infer_batch)
res = {}
infer_result = {}

for i in range(len(labels)):
res[labels[i]] = infer_res[i]
return res[labels[0]]
infer_result[labels[i]] = infer_res[i]

return infer_result[label_out[0]]


class MultiModelContext:
Expand Down
4 changes: 2 additions & 2 deletions onnxoptimizer/query/pandas/api/patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any

import pandas
import pandas as pd
from pandas import DataFrame
from pandas.util._validators import validate_bool_kwarg

Expand All @@ -9,11 +10,10 @@

LEVEL_OFFSET_1 = 1


@callable_patch(pandas)
def predict_eval(expr: str,
parser: str = "pandas",
engine: str | None = 'python',
engine: str | None = "onnxruntime",
local_dict: Any = None,
global_dict: Any = None,
resolvers: Any = (),
Expand Down
7 changes: 7 additions & 0 deletions onnxoptimizer/query/pandas/core/computation/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,11 @@
else:
NUMEXPR_VERSION = None

# Probe for the optional onnxruntime dependency. The result is compared
# against None below, so a failed probe is expected to yield None rather than
# raise (NOTE(review): confirm errors="warn" semantics of the helper).
ort = import_optional_dependency("onnxruntime", errors="warn")
ONNXRUNTIME_INSTALLED = ort is not None
# Only read the version attribute when the module was actually imported.
ONNXRUNTIME_VERSION = ort.__version__ if ONNXRUNTIME_INSTALLED else None

# Export the onnxruntime flags alongside the numexpr ones so the declared
# public API matches what other modules import from here.
__all__ = [
    "NUMEXPR_INSTALLED",
    "NUMEXPR_VERSION",
    "ONNXRUNTIME_INSTALLED",
    "ONNXRUNTIME_VERSION",
]
15 changes: 15 additions & 0 deletions onnxoptimizer/query/pandas/core/computation/engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,22 @@ def _evaluate(self) -> None:
pass


class ONNXEngine(AbstractEngine):
    """
    Engine that evaluates an expression in an ONNX context.
    """

    has_neg_frac = True

    def evaluate(self):
        """Call ``self.expr`` and return its result unchanged."""
        result = self.expr()
        return result

    def _evaluate(self) -> None:
        """Do nothing; ``evaluate`` above calls the expression directly."""
        return None


# Registry mapping each ``engine`` name to the engine class implementing it.
ENGINES: dict[str, type[AbstractEngine]] = {
    "numexpr": NumExprEngine,
    "python": PythonEngine,
    "onnxruntime": ONNXEngine,
}
132 changes: 31 additions & 101 deletions onnxoptimizer/query/pandas/core/computation/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
from pandas.core.dtypes.common import is_extension_array_dtype

from onnxoptimizer.query.onnx.context import MultiModelContext
from onnxoptimizer.query.pandas.core.computation.check import ONNXRUNTIME_INSTALLED
from onnxoptimizer.query.pandas.core.computation.engines import ENGINES
from onnxoptimizer.query.pandas.core.computation.expr import (
from onnxoptimizer.query.pandas.core.computation.visitor import (
PARSERS,
Expr,
)
from onnxoptimizer.query.pandas.core.computation.expr import Expr, _assign_value, ComposedExpr
from onnxoptimizer.query.pandas.core.computation.parsing import tokenize_string
from onnxoptimizer.query.pandas.core.computation.scope import ensure_scope
from pandas.core.generic import NDFrame
Expand Down Expand Up @@ -65,8 +66,14 @@ def _check_engine(engine: str | None) -> str:
# Could potentially be done on engine instantiation
if engine == "numexpr" and not NUMEXPR_INSTALLED:
raise ImportError(
"'numexpr' is not installed or an unsupported version. Cannot use "
"engine='numexpr' for query/eval if 'numexpr' is not installed"
f"'{engine}' is not installed or an unsupported version. Cannot use "
f"engine='{engine}' for query/eval if '{engine}' is not installed"
)

if engine == "onnxruntime" and not ONNXRUNTIME_INSTALLED:
raise ImportError(
f"'{engine}' is not installed or an unsupported version. Cannot use "
f"engine='{engine}' for query/eval if '{engine}' is not installed"
)

return engine
Expand Down Expand Up @@ -240,6 +247,9 @@ def pandas_eval(
resolvers=resolvers,
target=target,
)
# TODO: Handle modified target and resolvers.
# Two consecutive expressions cannot currently be evaluated together:
# the second expression cannot reference the name assigned by the first.

parsed_expr = Expr(expr, engine=engine, parser=parser, env=env)

Expand All @@ -250,86 +260,40 @@ def pandas_eval(
#################
# TODO: optimization phase
expr_remain = []
res = {}
if enable_opt:
expr_to_opt = []
assigners = []

for e2e in expr_to_eval:
if isinstance(e2e.terms, ONNXEvalNode):
expr_to_opt.append(e2e)
assigners.append(e2e.assigner)
else:
expr_remain.append(e2e)

if len(expr_to_opt) < 2:
expr_remain.extend(expr_to_opt)
else:
# do optimize eval
fused_expr = MultiModelContext(expr_to_opt)
res.update(fused_expr())
# do optimize phase
fused_term = MultiModelContext(expr_to_opt)
env = ensure_scope(
level + 1,
global_dict=global_dict,
local_dict=local_dict,
resolvers=resolvers,
target=target,
)
composed_expr = ComposedExpr(engine, env, level, fused_term, assigners)
expr_remain.append(composed_expr)

else:
expr_remain = expr_to_eval

#################
# Evaluation Phase
#################
env = ensure_scope(
level + 1,
global_dict=global_dict,
local_dict=local_dict,
resolvers=resolvers,
target=target,
)

for assigner, ret in res.items():
if env.target is not None and assigner is not None:
target_modified = True

# if returning a copy, copy only on the first assignment
if not inplace and first_expr:
try:
target = env.target
if isinstance(target, NDFrame):
target = target.copy(deep=None)
else:
target = target.copy()
except AttributeError as err:
raise ValueError("Cannot return a copy of the target") from err
else:
target = env.target

# TypeError is most commonly raised (e.g. int, list), but you
# get IndexError if you try to do this assignment on np.ndarray.
# we will ignore numpy warnings here; e.g. if trying
# to use a non-numeric indexer
try:
with warnings.catch_warnings(record=True):
# TODO: Filter the warnings we actually care about here.
if inplace and isinstance(target, NDFrame):
target.loc[:, assigner] = ret
else:
target[ # pyright: ignore[reportGeneralTypeIssues]
assigner
] = ret
except (TypeError, IndexError) as err:
raise ValueError("Cannot assign expression output to target") from err

if not resolvers:
resolvers = ({assigner: ret},)
else:
# existing resolver needs updated to handle
# case of mutating existing column in copy
for resolver in resolvers:
if assigner in resolver:
resolver[assigner] = ret
break
else:
resolvers += ({assigner: ret},)

ret = None
first_expr = False
env.target = target

# evaluate un-optimized exprs
# evaluate un-optimized expr
for e2e in expr_remain:
# get our (possibly passed-in) scope
env = ensure_scope(
Expand All @@ -346,15 +310,7 @@ def pandas_eval(
eng = ENGINES[engine]
eng_inst = eng(e2e)

# Temporary engine numexpr fallback check
# TODO: move it ahead, avoid eval many times
try:
ret = eng_inst.evaluate()
except ValueError:
engine = 'python'
eng = ENGINES[engine]
eng_inst = eng(e2e)
ret = eng_inst.evaluate()
ret = eng_inst.evaluate()

if e2e.assigner is None:
if multi_line:
Expand Down Expand Up @@ -384,33 +340,7 @@ def pandas_eval(
else:
target = env.target

# TypeError is most commonly raised (e.g. int, list), but you
# get IndexError if you try to do this assignment on np.ndarray.
# we will ignore numpy warnings here; e.g. if trying
# to use a non-numeric indexer
try:
with warnings.catch_warnings(record=True):
# TODO: Filter the warnings we actually care about here.
if inplace and isinstance(target, NDFrame):
target.loc[:, assigner] = ret
else:
target[ # pyright: ignore[reportGeneralTypeIssues]
assigner
] = ret
except (TypeError, IndexError) as err:
raise ValueError("Cannot assign expression output to target") from err

if not resolvers:
resolvers = ({assigner: ret},)
else:
# existing resolver needs updated to handle
# case of mutating existing column in copy
for resolver in resolvers:
if assigner in resolver:
resolver[assigner] = ret
break
else:
resolvers += ({assigner: ret},)
target, resolvers = e2e.assign_value(target, resolvers, inplace)

ret = None
first_expr = False
Expand Down
Loading

0 comments on commit 06eb26f

Please sign in to comment.