diff --git a/examples/pandas/prediction_query.ipynb b/examples/pandas/prediction_query.ipynb index 02da37b..982e81a 100644 --- a/examples/pandas/prediction_query.ipynb +++ b/examples/pandas/prediction_query.ipynb @@ -17,12 +17,14 @@ "from sklearn.preprocessing import StandardScaler\n", "import matplotlib.pyplot as plt\n", "\n", + "DATA_DIR = \"/home/uw1/MLquery/reference/snippets/py_onnx/expedia\"\n", + "\n", "sklearn.set_config(display='diagram')\n", "\n", "# 表路径\n", - "path1 = \"/home/uw1/snippets/py_onnx/expedia/data/S_listings.csv\"\n", - "path2 = \"/home/uw1/snippets/py_onnx/expedia/data/R1_hotels.csv\"\n", - "path3 = \"/home/uw1/snippets/py_onnx/expedia/data/R2_searches.csv\"\n", + "path1 = f\"{DATA_DIR}/data/S_listings.csv\"\n", + "path2 = f\"{DATA_DIR}/data/R1_hotels.csv\"\n", + "path3 = f\"{DATA_DIR}/data/R2_searches.csv\"\n", "# 读取csv表\n", "S_listings = pd.read_csv(path1)\n", "R1_hotels = pd.read_csv(path2)\n", @@ -61,8 +63,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:20:56.293096349Z", - "start_time": "2023-09-14T14:20:52.905516721Z" + "end_time": "2023-09-20T08:29:56.470642927Z", + "start_time": "2023-09-20T08:29:54.276610261Z" } }, "id": "initial_id" @@ -74,8 +76,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:20:56.329993615Z", - "start_time": "2023-09-14T14:20:56.298925022Z" + "end_time": "2023-09-20T08:29:56.501436572Z", + "start_time": "2023-09-20T08:29:56.475026791Z" } }, "outputs": [ @@ -123,8 +125,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:39.069314907Z", - "start_time": "2023-09-14T14:20:56.336077806Z" + "end_time": "2023-09-20T08:31:25.788872355Z", + "start_time": "2023-09-20T08:29:56.501779222Z" } }, "outputs": [ @@ -134,8 +136,8 @@ "text": [ "[Pipeline] ............ (step 1 of 1) Processing scaler, total= 0.1s\n", "[ColumnTransformer] ..... (1 of 2) Processing numerical, total= 0.1s\n", - "[Pipeline] ............ (step 1 of 1) Processing onehot, total= 1.5s\n", - "[ColumnTransformer] ... (2 of 2) Processing categorical, total= 1.5s\n", + "[Pipeline] ............ (step 1 of 1) Processing onehot, total= 1.2s\n", + "[ColumnTransformer] ... (2 of 2) Processing categorical, total= 1.2s\n", "Training done.\n" ] } @@ -153,8 +155,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:39.071695911Z", - "start_time": "2023-09-14T14:22:39.069023811Z" + "end_time": "2023-09-20T08:31:25.797897803Z", + "start_time": "2023-09-20T08:31:25.789729348Z" } }, "outputs": [ @@ -182,8 +184,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:45.287430337Z", - "start_time": "2023-09-14T14:22:39.069700163Z" + "end_time": "2023-09-20T08:31:31.204254992Z", + "start_time": "2023-09-20T08:31:25.790294144Z" } }, "outputs": [ @@ -193,8 +195,8 @@ "text": [ "[Pipeline] ............ (step 1 of 1) Processing scaler, total= 0.1s\n", "[ColumnTransformer] ..... (1 of 2) Processing numerical, total= 0.1s\n", - "[Pipeline] ............ (step 1 of 1) Processing onehot, total= 1.5s\n", - "[ColumnTransformer] ... (2 of 2) Processing categorical, total= 1.5s\n", + "[Pipeline] ............ (step 1 of 1) Processing onehot, total= 1.2s\n", + "[ColumnTransformer] ... (2 of 2) Processing categorical, total= 1.2s\n", "Training done.\n" ] } @@ -207,22 +209,22 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 27, "id": "57098bcb814bc88c", "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:45.386794834Z", - "start_time": "2023-09-14T14:22:45.286981108Z" + "end_time": "2023-09-20T08:47:49.288583215Z", + "start_time": "2023-09-20T08:47:49.261947235Z" } }, "outputs": [ { "data": { - "text/plain": " prop_location_score1 prop_location_score2 prop_log_historical_price \\\n0 2.20 0.0472 0.00 \n1 2.56 0.0221 0.00 \n2 2.56 0.0863 0.00 \n3 2.20 0.0104 0.00 \n4 0.69 0.0138 0.00 \n... ... ... ... \n119995 3.09 0.2478 0.00 \n119996 3.04 0.0000 0.00 \n119997 0.00 0.0872 0.00 \n119998 2.48 0.0817 4.11 \n119999 2.83 0.0026 4.08 \n\n price_usd orig_destination_distance prop_review_score \\\n0 80.0 86.59 3.5 \n1 75.0 82.51 4.5 \n2 89.0 91.15 4.0 \n3 40.0 87.17 2.5 \n4 65.0 85.21 4.0 \n... ... ... ... \n119995 76.0 0.00 4.5 \n119996 93.0 0.00 0.0 \n119997 143.0 0.00 4.5 \n119998 42.0 0.00 3.5 \n119999 30.0 0.00 2.0 \n\n avg_bookings_usd stdev_bookings_usd position prop_country_id ... \\\n0 172.605500 86.693494 '1' '219' ... \n1 277.470000 77.690000 '1' '219' ... \n2 124.185714 62.019917 '1' '219' ... \n3 103.603333 83.719086 '1' '219' ... \n4 113.760000 30.520000 '1' '219' ... \n... ... ... ... ... ... \n119995 246.620000 0.000000 '0' '219' ... \n119996 252.000000 84.000000 '0' '219' ... \n119997 247.537561 142.940988 '1' '219' ... \n119998 142.761429 103.587524 '0' '219' ... \n119999 270.600000 0.000000 '1' '219' ... \n\n site_id visitor_location_country_id srch_destination_id \\\n0 '5' '219' '13233' \n1 '5' '219' '13233' \n2 '5' '219' '13233' \n3 '5' '219' '13233' \n4 '5' '219' '13233' \n... ... ... ... \n119995 '5' '219' '16823' \n119996 '5' '219' '16823' \n119997 '5' '219' '16823' \n119998 '5' '219' '16823' \n119999 '5' '219' '16823' \n\n srch_length_of_stay srch_booking_window srch_adults_count \\\n0 2 0 2 \n1 2 0 2 \n2 2 0 2 \n3 2 0 2 \n4 2 0 2 \n... ... ... ... \n119995 1 21 2 \n119996 1 21 2 \n119997 1 21 2 \n119998 2 30 2 \n119999 2 30 2 \n\n srch_children_count srch_room_count srch_saturday_night_bool \\\n0 0 1 1 \n1 0 1 1 \n2 0 1 1 \n3 0 1 1 \n4 0 1 1 \n... ... ... ... \n119995 0 1 1 \n119996 0 1 1 \n119997 0 1 1 \n119998 0 1 1 \n119999 0 1 1 \n\n random_bool \n0 0 \n1 0 \n2 0 \n3 0 \n4 0 \n... ... \n119995 0 \n119996 0 \n119997 0 \n119998 1 \n119999 1 \n\n[120000 rows x 28 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
prop_location_score1prop_location_score2prop_log_historical_priceprice_usdorig_destination_distanceprop_review_scoreavg_bookings_usdstdev_bookings_usdpositionprop_country_id...site_idvisitor_location_country_idsrch_destination_idsrch_length_of_staysrch_booking_windowsrch_adults_countsrch_children_countsrch_room_countsrch_saturday_night_boolrandom_bool
02.200.04720.0080.086.593.5172.60550086.693494'1''219'...'5''219''13233'2020110
12.560.02210.0075.082.514.5277.47000077.690000'1''219'...'5''219''13233'2020110
22.560.08630.0089.091.154.0124.18571462.019917'1''219'...'5''219''13233'2020110
32.200.01040.0040.087.172.5103.60333383.719086'1''219'...'5''219''13233'2020110
40.690.01380.0065.085.214.0113.76000030.520000'1''219'...'5''219''13233'2020110
..................................................................
1199953.090.24780.0076.00.004.5246.6200000.000000'0''219'...'5''219''16823'12120110
1199963.040.00000.0093.00.000.0252.00000084.000000'0''219'...'5''219''16823'12120110
1199970.000.08720.00143.00.004.5247.537561142.940988'1''219'...'5''219''16823'12120110
1199982.480.08174.1142.00.003.5142.761429103.587524'0''219'...'5''219''16823'23020111
1199992.830.00264.0830.00.002.0270.6000000.000000'1''219'...'5''219''16823'23020111
\n

120000 rows × 28 columns

\n
" + "text/plain": " prop_location_score1 prop_location_score2 prop_log_historical_price \\\n0 2.20 0.0472 0.00 \n1 2.56 0.0221 0.00 \n2 2.56 0.0863 0.00 \n3 2.20 0.0104 0.00 \n4 0.69 0.0138 0.00 \n... ... ... ... \n49995 3.22 0.1841 4.11 \n49996 3.26 0.1802 4.24 \n49997 2.77 0.1969 4.91 \n49998 3.18 0.1785 3.84 \n49999 2.83 0.1616 4.07 \n\n price_usd orig_destination_distance prop_review_score \\\n0 80.0 86.59 3.5 \n1 75.0 82.51 4.5 \n2 89.0 91.15 4.0 \n3 40.0 87.17 2.5 \n4 65.0 85.21 4.0 \n... ... ... ... \n49995 65.0 208.62 3.0 \n49996 70.0 208.55 3.5 \n49997 104.0 208.24 4.0 \n49998 40.0 208.64 2.5 \n49999 65.0 208.39 4.0 \n\n avg_bookings_usd stdev_bookings_usd position prop_country_id ... \\\n0 172.605500 86.693494 '1' '219' ... \n1 277.470000 77.690000 '1' '219' ... \n2 124.185714 62.019917 '1' '219' ... \n3 103.603333 83.719086 '1' '219' ... \n4 113.760000 30.520000 '1' '219' ... \n... ... ... ... ... ... \n49995 116.355500 92.122719 '0' '219' ... \n49996 75.948333 27.501123 '0' '219' ... \n49997 171.820000 71.153756 '0' '219' ... \n49998 96.950000 105.050405 '0' '219' ... \n49999 82.274737 46.771863 '1' '219' ... \n\n site_id visitor_location_country_id srch_destination_id \\\n0 '5' '219' '13233' \n1 '5' '219' '13233' \n2 '5' '219' '13233' \n3 '5' '219' '13233' \n4 '5' '219' '13233' \n... ... ... ... \n49995 '5' '219' '21382' \n49996 '5' '219' '21382' \n49997 '5' '219' '21382' \n49998 '5' '219' '21382' \n49999 '5' '219' '21382' \n\n srch_length_of_stay srch_booking_window srch_adults_count \\\n0 2 0 2 \n1 2 0 2 \n2 2 0 2 \n3 2 0 2 \n4 2 0 2 \n... ... ... ... \n49995 1 8 1 \n49996 1 8 1 \n49997 1 8 1 \n49998 1 8 1 \n49999 1 8 1 \n\n srch_children_count srch_room_count srch_saturday_night_bool random_bool \n0 0 1 1 0 \n1 0 1 1 0 \n2 0 1 1 0 \n3 0 1 1 0 \n4 0 1 1 0 \n... ... ... ... ... \n49995 2 1 1 0 \n49996 2 1 1 0 \n49997 2 1 1 0 \n49998 2 1 1 0 \n49999 2 1 1 0 \n\n[50000 rows x 28 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
prop_location_score1prop_location_score2prop_log_historical_priceprice_usdorig_destination_distanceprop_review_scoreavg_bookings_usdstdev_bookings_usdpositionprop_country_id...site_idvisitor_location_country_idsrch_destination_idsrch_length_of_staysrch_booking_windowsrch_adults_countsrch_children_countsrch_room_countsrch_saturday_night_boolrandom_bool
02.200.04720.0080.086.593.5172.60550086.693494'1''219'...'5''219''13233'2020110
12.560.02210.0075.082.514.5277.47000077.690000'1''219'...'5''219''13233'2020110
22.560.08630.0089.091.154.0124.18571462.019917'1''219'...'5''219''13233'2020110
32.200.01040.0040.087.172.5103.60333383.719086'1''219'...'5''219''13233'2020110
40.690.01380.0065.085.214.0113.76000030.520000'1''219'...'5''219''13233'2020110
..................................................................
499953.220.18414.1165.0208.623.0116.35550092.122719'0''219'...'5''219''21382'1812110
499963.260.18024.2470.0208.553.575.94833327.501123'0''219'...'5''219''21382'1812110
499972.770.19694.91104.0208.244.0171.82000071.153756'0''219'...'5''219''21382'1812110
499983.180.17853.8440.0208.642.596.950000105.050405'0''219'...'5''219''21382'1812110
499992.830.16164.0765.0208.394.082.27473746.771863'1''219'...'5''219''21382'1812110
\n

50000 rows × 28 columns

\n
" }, - "execution_count": 6, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -230,20 +232,20 @@ "source": [ "from onnxoptimizer.query.pandas import model_udf\n", "\n", - "predict_df = X[:120000]\n", + "predict_df = X[:50000]\n", "\n", "predict_df" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 33, "id": "9c16abf732937869", "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:45.655020077Z", - "start_time": "2023-09-14T14:22:45.392257311Z" + "end_time": "2023-09-20T08:48:14.096934811Z", + "start_time": "2023-09-20T08:48:14.057414246Z" } }, "outputs": [ @@ -280,22 +282,22 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 34, "id": "56c69e2d0fca0d29", "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:47.350480927Z", - "start_time": "2023-09-14T14:22:45.439294187Z" + "end_time": "2023-09-20T08:48:16.338626086Z", + "start_time": "2023-09-20T08:48:14.838467649Z" } }, "outputs": [ { "data": { - "text/plain": " result_lr result_linear\n0 0 0\n1 0 0\n2 0 0\n3 0 0\n4 0 0\n... ... ...\n119995 0 0\n119996 0 0\n119997 0 0\n119998 0 0\n119999 0 0\n\n[120000 rows x 2 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
result_lrresult_linear
000
100
200
300
400
.........
11999500
11999600
11999700
11999800
11999900
\n

120000 rows × 2 columns

\n
" + "text/plain": " result_lr result_linear\n0 0 0\n1 0 0\n2 0 0\n3 0 0\n4 0 0\n... ... ...\n49995 0 0\n49996 0 0\n49997 0 0\n49998 1 0\n49999 0 0\n\n[50000 rows x 2 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
result_lrresult_linear
000
100
200
300
400
.........
4999500
4999600
4999700
4999810
4999900
\n

50000 rows × 2 columns

\n
" }, - "execution_count": 8, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -313,23 +315,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 35, "id": "56add5579a69412e", "metadata": { "collapsed": false, "scrolled": true, "ExecuteTime": { - "end_time": "2023-09-14T14:22:50.498272243Z", - "start_time": "2023-09-14T14:22:47.349342849Z" + "end_time": "2023-09-20T08:48:20.799196650Z", + "start_time": "2023-09-20T08:48:18.603184048Z" } }, "outputs": [ { "data": { - "text/plain": " result_lr result_linear\n0 0 0\n1 0 0\n2 0 0\n3 0 0\n4 0 0\n... ... ...\n119995 0 0\n119996 0 0\n119997 0 0\n119998 0 0\n119999 0 0\n\n[120000 rows x 2 columns]", - "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
result_lrresult_linear
000
100
200
300
400
.........
11999500
11999600
11999700
11999800
11999900
\n

120000 rows × 2 columns

\n
" + "text/plain": " result_lr result_linear\n0 0 0\n1 0 0\n2 0 0\n3 0 0\n4 0 0\n... ... ...\n49995 0 0\n49996 0 0\n49997 0 0\n49998 1 0\n49999 0 0\n\n[50000 rows x 2 columns]", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
result_lrresult_linear
000
100
200
300
400
.........
4999500
4999600
4999700
4999810
4999900
\n

50000 rows × 2 columns

\n
" }, - "execution_count": 9, + "execution_count": 35, "metadata": {}, "output_type": "execute_result" } @@ -345,21 +347,21 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 36, "id": "867637f8372b663a", "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-09-14T14:22:50.503248165Z", - "start_time": "2023-09-14T14:22:50.498632239Z" + "end_time": "2023-09-20T08:48:21.910501710Z", + "start_time": "2023-09-20T08:48:21.906785142Z" } }, "outputs": [ { "data": { - "text/plain": "(1.900327659008326, 3.1426765380019788)" + "text/plain": "(1.520371841994347, 2.186767206003424)" }, - "execution_count": 10, + "execution_count": 36, "metadata": {}, "output_type": "execute_result" } @@ -370,19 +372,19 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 37, "id": "6cd071ae-e0fc-40a7-a17a-f2d5d0a2469e", "metadata": { "ExecuteTime": { - "end_time": "2023-09-14T14:22:50.611925720Z", - "start_time": "2023-09-14T14:22:50.505370666Z" + "end_time": "2023-09-20T08:48:23.934956586Z", + "start_time": "2023-09-20T08:48:23.875832769Z" } }, "outputs": [ { "data": { "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAdAUlEQVR4nO3df4xV5Z348c9F4Q7UmVF0meHHKKQaBBGGopQB08ENLcuyXSbZuMZsAmXRRIurLo1sp9uIpdmMW4PSpKzUNMrutiyudcUs/mRx0SjjKsi4glVjVxlaZ0btwgyM7YDM+f7ReN35yiCXX48zvF7JSbznPs89zyW5d96eOXNvLsuyLAAAEhmQegEAwOlNjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJnpl7A0eju7o533303SktLI5fLpV4OAHAUsiyLffv2xYgRI2LAgN7Pf/SJGHn33Xejqqoq9TIAgGOwe/fuGDVqVK/394kYKS0tjYjfP5mysrLEqwEAjkZHR0dUVVUVfo73pk/EyMe/mikrKxMjANDHfNYlFi5gBQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkdWbqBQCcCqO//WjqJcDn1jt3zE16fGdGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkioqRu65556YOHFilJWVRVlZWdTU1MTjjz9+xDkPPvhgXHzxxVFSUhKXXnppPPbYY8e1YACgfykqRkaNGhV33HFHbNu2LbZu3Rp/+Id/GPPmzYudO3cedvyWLVvimmuuiUWLFsX27dujrq4u6urqYseOHSdk8QBA35fLsiw7ngcYOnRo3HnnnbFo0aJP3Xf11VdHZ2dnbNiwobBv2rRpUV1dHatXrz7qY3R0dER5eXm0t7dHWVnZ8SwXOE351l7o3cn61t6j/fl9zNeMHDp0KNatWxednZ1RU1Nz2DGNjY0xa9asHvtmz54djY2NR3zsrq6u6Ojo6LEBAP1T0THy6quvxllnnRX5fD6uv/76ePjhh2P8+PGHHdva2hoVFRU99lVUVERra+sRj9HQ0BDl5eWFraqqqthlAgB9RNExMnbs2Ghqaor/+q//ihtuuCEWLFgQr7322gldVH19fbS3txe23bt3n9DHBwA+P84sdsKgQYPiwgsvjIiIKVOmxEsvvRQ//OEP48c//vGnxlZWVkZbW1uPfW1tbVFZWXnEY+Tz+cjn88UuDQDog477c0a6u7ujq6vrsPfV1NTEpk2beuzbuHFjr9eYAACnn6LOjNTX18ecOXPi/PPPj3379sXatWtj8+bN8eSTT0ZExPz582PkyJHR0NAQERE333xz1NbWxooVK2Lu3Lmxbt262Lp1a9x7770n/pkAAH1SUTHy3nvvxfz586OlpSXKy8tj4sSJ8eSTT8ZXv/rViIhobm6OAQM+Odkyffr0WLt2bXz3u9+N73znO3HRRRfF+vXrY8KECSf2WQAAfdZxf87IqeBzRoDj5XNGoHd99nNGAABOBDECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSKipGGhoa4vLLL4/S0tIYNmxY1NXVxRtvvHHEOWvWrIlcLtdjKykpOa5FAwD9R1Ex8swzz8TixYvjhRdeiI0bN8bBgwfja1/7WnR2dh5xXllZWbS0tBS2Xbt2HdeiAYD+48xiBj/xxBM9bq9ZsyaGDRsW27Zti6985Su9zsvlclFZWXlsKwQA+rXjumakvb09IiKGDh16xHH79++PCy64IKqqqmLevHmxc+fOI47v6uqKjo6OHhsA0D8dc4x0d3fHLbfcEjNmzIgJEyb0Om7s2LFx3333xSOPPBI//elPo7u7O6ZPnx6/+tWvep3T0NAQ5eXlha2qqupYlwkAfM7lsizLjmXiDTfcEI8//ng899xzMWrUqKOed/DgwRg3blxcc8018f3vf/+wY7q6uqKrq6twu6OjI6qqqqK9vT3KysqOZbnAaW70tx9NvQT43Hrnjrkn5XE7OjqivLz8M39+F3XNyMduvPHG2LBhQzz77LNFhUhExMCBA2Py5Mnx1ltv9Tomn89HPp8/lqUBAH1MUb+mybIsbrzxxnj44Yfj6aefjjFjxhR9wEOHDsWrr74aw4cPL3ouAND/FHVmZPHixbF27dp45JFHorS0NFpbWyMiory8PAYPHhwREfPnz4+RI0dGQ0NDREQsX748pk2bFhdeeGHs3bs37rzzzti1a1dce+21J/ipAAB9UVExcs8990RExMyZM3vsv//+++Mb3/hGREQ0NzfHgAGfnHDZs2dPXHfdddHa2hrnnHNOTJkyJbZs2RLjx48/vpUDAP3CMV/Aeiod7QUwAL1xASv0LvUFrL6bBgBISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkioqRhoaGuPzyy6O0tDSGDRsWdXV18cYbb3zmvAcffDAuvvjiKCkpiUsvvTQee+yxY14wANC/FBUjzzzzTCxevDheeOGF2LhxYxw8eDC+9rWvRWdnZ69ztmzZEtdcc00sWrQotm/fHnV1dVFXVxc7duw47sUDAH1fLsuy7Fgnv//++zFs2LB45pln4itf+cphx1x99dXR2dkZGzZsKOybNm1aVFdXx+rVq4/qOB0dHVFeXh7t7e1RVlZ2rMsFTmOjv/1o6iXA59Y7d8w9KY97tD+/j+uakfb29oiIGDp0aK9jGhsbY9asWT32zZ49OxobG3ud09XVFR0dHT02AKB/OuYY6e7ujltuuSVmzJgREyZM6HVca2trVFRU9NhXUVERra2tvc5paGiI8vLywlZVVXWsywQAPueOOUYWL14cO3bsiHXr1p3I9URERH19fbS3txe23bt3n/BjAACfD2cey6Qbb7wxNmzYEM8++2yMGjXqiGMrKyujra2tx762traorKzsdU4+n498Pn8sSwMA+piizoxkWRY33nhjPPzww/H000/HmDFjPnNOTU1NbNq0qce+jRs3Rk1NTXErBQD6paLOjCxevDjWrl0bjzzySJSWlhau+ygvL4/BgwdHRMT8+fNj5MiR0dDQEBERN998c9TW1saKFSti7ty5sW7duti6dWvce++9J/ipAAB9UVFnRu65555ob2+PmTNnxvDhwwvbAw88UBjT3NwcLS0thdvTp0+PtWvXxr333huTJk2Kn//857F+/fojXvQKAJw+ijozcjQfSbJ58+ZP7bvqqqviqquuKuZQAMBpwnfTAABJiREAICkxAgAkJUYAgKSO6UPP+hNfngVHdrK+QAvgY86MAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABIqugYefbZZ+PrX/96jBgxInK5XKxfv/6I4zdv3hy5XO5TW2tr67GuGQDoR4qOkc7Ozpg0aVKsWrWqqHlvvPFGtLS0FLZhw4YVe2gAoB86s9gJc+bMiTlz5hR9oGHDhsXZZ59d9DwAoH87ZdeMVFdXx/Dhw+OrX/1qPP/880cc29XVFR0dHT02AKB/OukxMnz48Fi9enU89NBD8dBDD0VVVVXMnDkzXn755V7nNDQ0RHl5eWGrqqo62csEABIp+tc0xRo7dmyMHTu2cHv69Onxy1/+Mu6+++7453/+58POqa+vjyVLlhRud3R0CBIA6KdOeowcztSpU+O5557r9f58Ph/5fP4UrggASCXJ54w0NTXF8OHDUxwaAPicKfrMyP79++Ott94q3H777bejqakphg4dGueff37U19fHr3/96/inf/qniIhYuXJljBkzJi655JL43e9+Fz/5yU/i6aefjqeeeurEPQsAoM8qOka2bt0aV155ZeH2x9d2LFiwINasWRMtLS3R3NxcuP/AgQPxrW99K37961/HkCFDYuLEifEf//EfPR4DADh9FR0jM2fOjCzLer1/zZo1PW4vXbo0li5dWvTCAIDTg++mAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACCpomPk2Wefja9//esxYsSIyOVysX79+s+cs3nz5vjSl74U+Xw+LrzwwlizZs0xLBUA6I+KjpHOzs6YNGlSrFq16qjGv/322zF37ty48soro6mpKW655Za49tpr48knnyx6sQBA/3NmsRPmzJkTc+bMOerxq1evjjFjxsSKFSsiImLcuHHx3HPPxd133x2zZ88u9vAAQD9z0q8ZaWxsjFmzZvXYN3v27GhsbOx1TldXV3R0dPTYAID+6aTHSGtra1RUVPTYV1FRER0dHfHb3/72sHMaGhqivLy8sFVVVZ3sZQIAiXwu/5qmvr4+2tvbC9vu3btTLwkAOEmKvmakWJWVldHW1tZjX1tbW5SVlcXgwYMPOyefz0c+nz/ZSwMAPgdO+pmRmpqa2LRpU499GzdujJqampN9aACgDyg6Rvbv3x9NTU3R1NQUEb//092mpqZobm6OiN//imX+/PmF8ddff338z//8TyxdujRef/31+Id/+If413/91/jrv/7rE/MMAIA+regY2bp1a0yePDkmT54cERFLliyJyZMnx2233RYRES0tLYUwiYgYM2ZMPProo7Fx48aYNGlSrFixIn7yk5/4s14AICKO4ZqRmTNnRpZlvd5/uE9XnTlzZmzfvr3YQwEAp4HP5V/TAACnDzECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSOqYYWbVqVYwePTpKSkriy1/+crz44ou9jl2zZk3kcrkeW0lJyTEvGADoX4qOkQceeCCWLFkSy5Yti5dffjkmTZoUs2fPjvfee6/XOWVlZdHS0lLYdu3adVyLBgD6j6Jj5K677orrrrsuFi5cGOPHj4/Vq1fHkCFD4r777ut1Ti6Xi8rKysJWUVFxXIsGAPqPomLkwIEDsW3btpg1a9YnDzBgQMyaNSsaGxt7nbd///644IILoqqqKubNmxc7d+484nG6urqio6OjxwYA9E9FxcgHH3wQhw4d+tSZjYqKimhtbT3snLFjx8Z9990XjzzySPz0pz+N7u7umD59evzqV7/q9TgNDQ1RXl5e2KqqqopZJgDQh5z0v6apqamJ+fPnR3V1ddTW1sa//du/xR/8wR/Ej3/8417n1NfXR3t7e2HbvXv3yV4mAJDImcUMPu+88+KMM86Itra2Hvvb2tqisrLyqB5j4MCBMXny5Hjrrbd6HZPP5yOfzxezNACgjyrqzMigQYNiypQpsWnTpsK+7u7u2LRpU9TU1BzVYxw6dCheffXVGD58eHErBQD6paLOjERELFmyJBYsWBCXXXZZTJ06NVauXBmdnZ2xcOHCiIiYP39+jBw5MhoaGiIiYvny5TFt2rS48MILY+/evXHnnXfGrl274tprrz2xzwQA6JOKjpGrr7463n///bjtttuitbU1qqur44knnihc1Nrc3BwDBnxywmXPnj1x3XXXRWtra5xzzjkxZcqU2LJlS4wfP/7EPQsAoM/KZVmWpV7EZ+no6Ijy8vJob2+PsrKyE/rYo7/96Al9POhv3rljbuolnBBe69C7k/U6P9qf376bBgBISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkjilGVq1aFaNHj46SkpL48pe/HC+++OIRxz/44INx8cUXR0lJSVx66aXx2GOPHdNiAYD+p+gYeeCBB2LJkiWxbNmyePnll2PSpEkxe/bseO+99w47fsuWLXHNNdfEokWLYvv27VFXVxd1dXWxY8eO4148AND3FR0jd911V1x33XWxcOHCGD9+fKxevTqGDBkS991332HH//CHP4w/+qM/iltvvTXGjRsX3//+9+NLX/pS/OhHPzruxQMAfd+ZxQw+cOBAbNu2Lerr6wv7BgwYELNmzYrGxsbDzmlsbIwlS5b02Dd79uxYv359r8fp6uqKrq6uwu329vaIiOjo6ChmuUelu+vDE/6Y0J+cjNddCl7r0LuT9Tr/+HGzLDviuKJi5IMPPohDhw5FRUVFj/0VFRXx+uuvH3ZOa2vrYce3trb2epyGhob43ve+96n9VVVVxSwXOAHKV6ZeAXCynezX+b59+6K8vLzX+4uKkVOlvr6+x9mU7u7u+N///d8499xzI5fLJVwZJ1NHR0dUVVXF7t27o6ysLPVygJPEa/30kWVZ7Nu3L0aMGHHEcUXFyHnnnRdnnHFGtLW19djf1tYWlZWVh51TWVlZ1PiIiHw+H/l8vse+s88+u5il0oeVlZV5g4LTgNf66eFIZ0Q+VtQFrIMGDYopU6bEpk2bCvu6u7tj06ZNUVNTc9g5NTU1PcZHRGzcuLHX8QDA6aXoX9MsWbIkFixYEJdddllMnTo1Vq5cGZ2dnbFw4cKIiJg/f36MHDkyGhoaIiLi5ptvjtra2lixYkXMnTs31q1bF1u3bo177733xD4TAKBPKjpGrr766nj//ffjtttui9bW1qiuro4nnniicJFqc3NzDBjwyQmX6dOnx9q1a+O73/1ufOc734mLLroo1q9fHxMmTDhxz4J+IZ/Px7Jlyz71Kzqgf/Fa5/+Xyz7r720AAE4i300DACQlRgCApMQIAJCUGAEAkhIj9Blr1qzx4XfQh9x+++1RXV19So61efPmyOVysXfv3lNyPE4sMQIAJCVGOGW6urripptuimHDhkVJSUlcccUV8dJLL0XEJ/9X8+ijj8bEiROjpKQkpk2bFjt27Cjcv3Dhwmhvb49cLhe5XC5uv/32hM8G+reZM2fGTTfdFEuXLo2hQ4dGZWXlp15zzc3NMW/evDjrrLOirKws/vzP/7zw9R9r1qyJ733ve/HKK68UXrNr1qw57LG6u7tj+fLlMWrUqMjn84XPr/rYO++8E7lcLtatWxfTp0+PkpKSmDBhQjzzzDOF+6+88sqIiDjnnHMil8vFN77xjRP+b8JJlMEpctNNN2UjRozIHnvssWznzp3ZggULsnPOOSf7zW9+k/3nf/5nFhHZuHHjsqeeeir77//+7+xP/uRPstGjR2cHDhzIurq6spUrV2ZlZWVZS0tL1tLSku3bty/1U4J+q7a2NisrK8tuv/327M0338z+8R//McvlctlTTz2VZVmWHTp0KKuurs6uuOKKbOvWrdkLL7yQTZkyJautrc2yLMs+/PDD7Fvf+lZ2ySWXFF6zH3744WGPddddd2VlZWXZv/zLv2Svv/56tnTp0mzgwIHZm2++mWVZlr399ttZRGSjRo3Kfv7zn2evvfZadu2112alpaXZBx98kH300UfZQw89lEVE9sYbb2QtLS3Z3r17T8m/EyeGGOGU2L9/fzZw4MDsZz/7WWHfgQMHshEjRmQ/+MEPCjGybt26wv2/+c1vssGDB2cPPPBAlmVZdv/992fl5eWneulwWqqtrc2uuOKKHvsuv/zy7G/+5m+yLMuyp556KjvjjDOy5ubmwv07d+7MIiJ78cUXsyzLsmXLlmWTJk36zGONGDEi+7u/+7tPHeub3/xmlmWfxMgdd9xRuP/gwYPZqFGjsr//+7/PsiwrvIfs2bOn6OdKen5Nwynxy1/+Mg4ePBgzZswo7Bs4cGBMnTo1fvGLXxT2/d8vUBw6dGiMHTu2x/3AqTNx4sQet4cPHx7vvfdeRET84he/iKqqqqiqqircP378+Dj77LOLes12dHTEu+++2+O9ISJixowZn3qc//v+cOaZZ8Zll13m/aGfECMAHNbAgQN73M7lctHd3Z1oNfRnYoRT4otf/GIMGjQonn/++cK+gwcPxksvvRTjx48v7HvhhRcK/71nz5548803Y9y4cRERMWjQoDh06NCpWzTQq3HjxsXu3btj9+7dhX2vvfZa7N27t/CaPprXbFlZWYwYMaLHe0NExPPPP9/jvSGi5/vDRx99FNu2bevx/hAR3iP6qKK/tReOxRe+8IW44YYb4tZbb42hQ4fG+eefHz/4wQ/iww8/jEWLFsUrr7wSERHLly+Pc889NyoqKuJv//Zv47zzzou6urqIiBg9enTs378/Nm3aFJMmTYohQ4bEkCFDEj4rOH3NmjUrLr300viLv/iLWLlyZXz00UfxzW9+M2pra+Oyyy6LiN+/Zt9+++1oamqKUaNGRWlp6WG/qffWW2+NZcuWxRe/+MWorq6O+++/P5qamuJnP/tZj3GrVq2Kiy66KMaNGxd333137NmzJ/7yL/8yIiIuuOCCyOVysWHDhvjjP/7jGDx4cJx11lkn/x+CEyP1RSucPn77299mf/VXf5Wdd955WT6fz2bMmFG40O3ji8/+/d//PbvkkkuyQYMGZVOnTs1eeeWVHo9x/fXXZ+eee24WEdmyZcsSPAs4PdTW1mY333xzj33z5s3LFixYULi9a9eu7E//9E+zL3zhC1lpaWl21VVXZa2trYX7f/e732V/9md/lp199tlZRGT333//YY916NCh7Pbbb89GjhyZDRw4MJs0aVL2+OOPF+7/+ALWtWvXZlOnTs0GDRqUjR8/Pnv66ad7PM7y5cuzysrKLJfL9Vgnn3+5LMuytDkEv/8ckSuvvDL27NnjU1aBHt55550YM2ZMbN++/ZR9oiunlmtGAICkxAgAkJRf0wAASTkzAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJPX/AI2FCb/3e24IAAAAAElFTkSuQmCC" + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYfklEQVR4nO3df2xV9f348ddFoQVti8hoKVQhahD80SKIFswKCRtjzEm2OGOWwPBH4tSpw8jGtgiyLHUuKEvGxoxR9kMGUzfM1DkYDo1aoyh1w5/RoWVK649BC6gF6fn+Ybz79CtVLlLeUh6P5CTec877nvclObdPT0/vzWVZlgUAQCK9Uk8AADi0iREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEjq8NQT2BsdHR3x+uuvR0lJSeRyudTTAQD2QpZlsW3btqisrIxevbq+/nFQxMjrr78eVVVVqacBAOyDTZs2xdChQ7vcflDESElJSUR88GJKS0sTzwYA2BttbW1RVVWV/znelYMiRj781UxpaakYAYCDzCfdYuEGVgAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUoenngDAgTDs+/emngJ8Zr1y/bSkx3dlBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABIqqAYqa+vj9NPPz1KSkpi0KBBMX369HjhhRc+cdwdd9wRJ554YhQXF8cpp5wS99133z5PGADoWQqKkQcffDAuu+yyeOyxx2L16tWxa9eu+OIXvxg7duzocsyjjz4a559/flx44YWxfv36mD59ekyfPj02bNjwqScPABz8clmWZfs6+M0334xBgwbFgw8+GJ///Of3uM95550XO3bsiHvuuSe/7swzz4yamppYsmTJXh2nra0tysrKorW1NUpLS/d1usAhbNj37009BfjMeuX6ad3yvHv78/tT3TPS2toaEREDBgzocp+GhoaYPHlyp3VTpkyJhoaGLse0t7dHW1tbpwUA6Jn2OUY6OjriqquuigkTJsTJJ5/c5X7Nzc1RXl7eaV15eXk0Nzd3Oaa+vj7KysryS1VV1b5OEwD4jNvnGLnssstiw4YNsXz58v05n4iImDt3brS2tuaXTZs27fdjAACfDYfvy6DLL7887rnnnnjooYdi6NChH7tvRUVFtLS0dFrX0tISFRUVXY4pKiqKoqKifZkaAHCQKejKSJZlcfnll8ef//zneOCBB2L48OGfOKa2tjbWrFnTad3q1aujtra2sJkCAD1SQVdGLrvssli2bFncfffdUVJSkr/vo6ysLPr27RsRETNmzIghQ4ZEfX19RERceeWVUVdXFwsXLoxp06bF8uXLY926dXHzzTfv55cCAByMCroy8qtf/SpaW1tj4sSJMXjw4PyyYsWK/D5NTU2xefPm/OPx48fHsmXL4uabb47q6uq48847Y+XKlR970ysAcOgo6MrI3nwkydq1az+y7txzz41zzz23kEMBAIcI300DACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABIquAYeeihh+Lss8+OysrKyOVysXLlyo/df+3atZHL5T6yNDc37+ucAYAe5PBCB+zYsSOqq6vjggsuiK997Wt7Pe6FF16I0tLS/ONBgwYVeuhuMez796aeAnymvXL9tNRTAHq4gmNk6tSpMXXq1IIPNGjQoOjfv3/B4wCAnu2A3TNSU1MTgwcPji984QvxyCOPfOy+7e3t0dbW1mkBAHqmbo+RwYMHx5IlS+Kuu+6Ku+66K6qqqmLixInx1FNPdTmmvr4+ysrK8ktVVVV3TxMASKTgX9MUasSIETFixIj84/Hjx8fLL78cN910U/zud7/b45i5c+fG7Nmz84/b2toECQD0UN0eI3sybty4ePjhh7vcXlRUFEVFRQdwRgBAKkk+Z6SxsTEGDx6c4tAAwGdMwVdGtm/fHi+99FL+8caNG6OxsTEGDBgQxxxzTMydOzdee+21+O1vfxsREYsWLYrhw4fHSSedFO+9917ccsst8cADD8SqVav236sAAA5aBcfIunXrYtKkSfnHH97bMXPmzFi6dGls3rw5mpqa8tt37twZV199dbz22mvRr1+/OPXUU+Pvf/97p+cAAA5dBcfIxIkTI8uyLrcvXbq00+M5c+bEnDlzCp4YAHBo8N00AEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACRVcIw89NBDcfbZZ0dlZWXkcrlYuXLlJ45Zu3ZtnHbaaVFUVBTHH398LF26dB+mCgD0RAXHyI4dO6K6ujoWL168V/tv3Lgxpk2bFpMmTYrGxsa46qqr4qKLLoq//e1vBU8WAOh5Di90wNSpU2Pq1Kl7vf+SJUti+PDhsXDhwoiIGDlyZDz88MNx0003xZQpUwo9PADQw3T7PSMNDQ0xefLkTuumTJkSDQ0NXY5pb2+Ptra2TgsA0DN1e4w0NzdHeXl5p3Xl5eXR1tYW77777h7H1NfXR1lZWX6pqqrq7mkCAIl8Jv+aZu7cudHa2ppfNm3alHpKAEA3KfiekUJVVFRES0tLp3UtLS1RWloaffv23eOYoqKiKCoq6u6pAQCfAd1+ZaS2tjbWrFnTad3q1aujtra2uw8NABwECo6R7du3R2NjYzQ2NkbEB3+629jYGE1NTRHxwa9YZsyYkd//kksuiX//+98xZ86ceP755+OXv/xl/PGPf4zvfve7++cVAAAHtYJjZN26dTF69OgYPXp0RETMnj07Ro8eHddee21ERGzevDkfJhERw4cPj3vvvTdWr14d1dXVsXDhwrjlllv8WS8AEBH7cM/IxIkTI8uyLrfv6dNVJ06cGOvXry/0UADAIeAz+dc0AMChQ4wAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkNQ+xcjixYtj2LBhUVxcHGeccUY8/vjjXe67dOnSyOVynZbi4uJ9njAA0LMUHCMrVqyI2bNnx7x58+Kpp56K6urqmDJlSrzxxhtdjiktLY3Nmzfnl1dfffVTTRoA6DkKjpEbb7wxLr744pg1a1aMGjUqlixZEv369Ytbb721yzG5XC4qKiryS3l5+aeaNADQcxQUIzt37ownn3wyJk+e/L8n6NUrJk+eHA0NDV2O2759exx77LFRVVUV55xzTjzzzDMfe5z29vZoa2vrtAAAPVNBMfLWW2/F7t27P3Jlo7y8PJqbm/c4ZsSIEXHrrbfG3XffHb///e+jo6Mjxo8fH//5z3+6PE59fX2UlZXll6qqqkKmCQAcRLr9r2lqa2tjxowZUVNTE3V1dfGnP/0pPve5z8Wvf/3rLsfMnTs3Wltb88umTZu6e5oAQCKHF7LzwIED47DDDouWlpZO61taWqKiomKvnqN3794xevToeOmll7rcp6ioKIqKigqZGgBwkCroykifPn1izJgxsWbNmvy6jo6OWLNmTdTW1u7Vc+zevTv+9a9/xeDBgwubKQDQIxV0ZSQiYvbs2TFz5swYO3ZsjBs3LhYtWhQ7duyIWbNmRUTEjBkzYsiQIVFfXx8REQsWLIgzzzwzjj/++Ni6dWv87Gc/i1dffTUuuuii/ftKAICDUsExct5558Wbb74Z1157bTQ3N0dNTU3cf//9+Ztam5qaolev/11w2bJlS1x88cXR3NwcRx11VIwZMyYeffTRGDVq1P57FQDAQSuXZVmWehKfpK2tLcrKyqK1tTVKS0v363MP+/69+/X5oKd55fppqaewXzjXoWvddZ7v7c9v300DACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFJiBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASEqMAABJiREAICkxAgAkJUYAgKTECACQlBgBAJISIwBAUmIEAEhKjAAASYkRACApMQIAJCVGAICkxAgAkJQYAQCSEiMAQFL7FCOLFy+OYcOGRXFxcZxxxhnx+OOPf+z+d9xxR5x44olRXFwcp5xyStx33337NFkAoOcpOEZWrFgRs2fPjnnz5sVTTz0V1dXVMWXKlHjjjTf2uP+jjz4a559/flx44YWxfv36mD59ekyfPj02bNjwqScPABz8Co6RG2+8MS6++OKYNWtWjBo1KpYsWRL9+vWLW2+9dY/7//znP48vfelLcc0118TIkSPjxz/+cZx22mnxi1/84lNPHgA4+B1eyM47d+6MJ598MubOnZtf16tXr5g8eXI0NDTscUxDQ0PMnj2707opU6bEypUruzxOe3t7tLe35x+3trZGRERbW1sh090rHe3v7PfnhJ6kO867FJzr0LXuOs8/fN4syz52v4Ji5K233ordu3dHeXl5p/Xl5eXx/PPP73FMc3PzHvdvbm7u8jj19fVx3XXXfWR9VVVVIdMF9oOyRalnAHS37j7Pt23bFmVlZV1uLyhGDpS5c+d2uprS0dER//3vf+Poo4+OXC6XcGZ0p7a2tqiqqopNmzZFaWlp6ukA3cS5fujIsiy2bdsWlZWVH7tfQTEycODAOOyww6KlpaXT+paWlqioqNjjmIqKioL2j4goKiqKoqKiTuv69+9fyFQ5iJWWlnqDgkOAc/3Q8HFXRD5U0A2sffr0iTFjxsSaNWvy6zo6OmLNmjVRW1u7xzG1tbWd9o+IWL16dZf7AwCHloJ/TTN79uyYOXNmjB07NsaNGxeLFi2KHTt2xKxZsyIiYsaMGTFkyJCor6+PiIgrr7wy6urqYuHChTFt2rRYvnx5rFu3Lm6++eb9+0oAgINSwTFy3nnnxZtvvhnXXnttNDc3R01NTdx///35m1SbmpqiV6//XXAZP358LFu2LH70ox/FD37wgzjhhBNi5cqVcfLJJ++/V0GPUFRUFPPmzfvIr+iAnsW5zv8vl33S39sAAHQj300DACQlRgCApMQIAJCUGAEAkhIjHDSWLl3qw+/gIDJ//vyoqak5IMdau3Zt5HK52Lp16wE5HvuXGAEAkhIjHDDt7e1xxRVXxKBBg6K4uDjOOuuseOKJJyLif/9Xc++998app54axcXFceaZZ8aGDRvy22fNmhWtra2Ry+Uil8vF/PnzE74a6NkmTpwYV1xxRcyZMycGDBgQFRUVHznnmpqa4pxzzokjjzwySktL4xvf+Eb+6z+WLl0a1113XTz99NP5c3bp0qV7PFZHR0csWLAghg4dGkVFRfnPr/rQK6+8ErlcLpYvXx7jx4+P4uLiOPnkk+PBBx/Mb580aVJERBx11FGRy+XiW9/61n7/N6EbZXCAXHHFFVllZWV23333Zc8880w2c+bM7Kijjsrefvvt7B//+EcWEdnIkSOzVatWZf/85z+zr3zlK9mwYcOynTt3Zu3t7dmiRYuy0tLSbPPmzdnmzZuzbdu2pX5J0GPV1dVlpaWl2fz587MXX3wx+81vfpPlcrls1apVWZZl2e7du7OamprsrLPOytatW5c99thj2ZgxY7K6urosy7LsnXfeya6++urspJNOyp+z77zzzh6PdeONN2alpaXZH/7wh+z555/P5syZk/Xu3Tt78cUXsyzLso0bN2YRkQ0dOjS78847s2effTa76KKLspKSkuytt97K3n///eyuu+7KIiJ74YUXss2bN2dbt249IP9O7B9ihANi+/btWe/evbPbb789v27nzp1ZZWVldsMNN+RjZPny5fntb7/9dta3b99sxYoVWZZl2W233ZaVlZUd6KnDIamuri4766yzOq07/fTTs+9973tZlmXZqlWrssMOOyxramrKb3/mmWeyiMgef/zxLMuybN68eVl1dfUnHquysjL7yU9+8pFjXXrppVmW/S9Grr/++vz2Xbt2ZUOHDs1++tOfZlmW5d9DtmzZUvBrJT2/puGAePnll2PXrl0xYcKE/LrevXvHuHHj4rnnnsuv+79foDhgwIAYMWJEp+3AgXPqqad2ejx48OB44403IiLiueeei6qqqqiqqspvHzVqVPTv37+gc7atrS1ef/31Tu8NERETJkz4yPP83/eHww8/PMaOHev9oYcQIwDsUe/evTs9zuVy0dHRkWg29GRihAPiuOOOiz59+sQjjzySX7dr16544oknYtSoUfl1jz32WP6/t2zZEi+++GKMHDkyIiL69OkTu3fvPnCTBro0cuTI2LRpU2zatCm/7tlnn42tW7fmz+m9OWdLS0ujsrKy03tDRMQjjzzS6b0hovP7w/vvvx9PPvlkp/eHiPAecZAq+Ft7YV8cccQR8e1vfzuuueaaGDBgQBxzzDFxww03xDvvvBMXXnhhPP300xERsWDBgjj66KOjvLw8fvjDH8bAgQNj+vTpERExbNiw2L59e6xZsyaqq6ujX79+0a9fv4SvCg5dkydPjlNOOSW++c1vxqJFi+L999+PSy+9NOrq6mLs2LER8cE5u3HjxmhsbIyhQ4dGSUnJHr+p95prrol58+bFcccdFzU1NXHbbbdFY2Nj3H777Z32W7x4cZxwwgkxcuTIuOmmm2LLli1xwQUXRETEscceG7lcLu6555748pe/HH379o0jjzyy+/8h2D9S37TCoePdd9/NvvOd72QDBw7MioqKsgkTJuRvdPvw5rO//OUv2UknnZT16dMnGzduXPb00093eo5LLrkkO/roo7OIyObNm5fgVcChoa6uLrvyyis7rTvnnHOymTNn5h+/+uqr2Ve/+tXsiCOOyEpKSrJzzz03a25uzm9/7733sq9//etZ//79s4jIbrvttj0ea/fu3dn8+fOzIUOGZL17986qq6uzv/71r/ntH97AumzZsmzcuHFZnz59slGjRmUPPPBAp+dZsGBBVlFRkeVyuU7z5LMvl2VZljaH4IPPEZk0aVJs2bLFp6wCnbzyyisxfPjwWL9+/QH7RFcOLPeMAABJiREAICm/pgEAknJlBABISowAAEmJEQAgKTECACQlRgCApMQIAJCUGAEAkhIjAEBSYgQASOr/AS8/VcIOgBY3AAAAAElFTkSuQmCC" }, "metadata": {}, "output_type": "display_data" @@ -397,6 +399,16 @@ "plt.bar(x, y)\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + }, + "id": "cb091d06bff277b0" } ], "metadata": { diff --git a/onnxoptimizer/query/onnx/compile.py b/onnxoptimizer/query/onnx/compile.py index 4dc5573..c1971f0 100644 --- a/onnxoptimizer/query/onnx/compile.py +++ b/onnxoptimizer/query/onnx/compile.py @@ -1,6 +1,14 @@ -from onnx.compose import merge_graphs, merge_models - -from onnxoptimizer.query.pandas.core.computation.ops import ONNXPredicate, ONNXFuncNode, BinOp, Term, UnaryOp +import numpy as np +import onnx.numpy_helper +from onnx import ValueInfoProto +from onnx.compose import merge_models + +from onnxoptimizer.query.onnx.joint import merge_project_models_wrap +from onnxoptimizer.query.onnx.joint.fragment import ModelFragment, OpModelFragment, TermModelFragment +from onnxoptimizer.query.onnx.joint.merge_expr import merge_models_binary_wrap +from onnxoptimizer.query.pandas.core.computation.ops import (ONNXPredicate, ONNXFuncNode, + BinOp, Term, UnaryOp, Constant) +from onnxoptimizer.query.types.mapper import numpy_onnx_tensor_type_map ONNX_CMP_OPS_SYMS = (">", "<", ">=", "<=", "==", "!=") _onnx_cmp_ops_nodes = ( @@ -51,27 +59,56 @@ _onnx_unary_ops_map = dict(zip(ONNX_UNARY_OPS_SYMS, _onnx_unary_ops_nodes)) +def merge_fragment_with_unary_op( + m: ModelFragment, + op: OpModelFragment, +): + model = m.model + op_model = op.model + + model = merge_models(model, op_model, io_map=[ + (m.get_default_output().name, op.get_default_input().name) + ]) + return OpModelFragment(model, op.return_type, op.op) + + class ONNXPredicateCompiler: + helper = onnx.helper + numpy_helper = onnx.numpy_helper + + IDENTITY = "Identity" + def __init__(self, env): self.env = env + self.temps = 0 + + def compile_save(self, node, path=None): + model = self.compile(node) + onnx.save(model.model, path) - def compile(self, node): + def compile(self, node) -> ModelFragment: compile_method = f"compile_{type(node).__name__}" compile_func = getattr(self, compile_method) return compile_func(node) - def compile_ONNXPredicate(self, node: ONNXPredicate): - lhs = node.lhs - rhs = node.rhs - op = node.op + def get_temp_name_by_value(self, value): + name = f"{type(value).__name__}_{self.temps}" + self.temps += 1 - lhs_ir = self.compile(lhs) - rhs_ir = self.compile(rhs) - print("compile predicate:", op) + return name - def compile_ONNXFuncNode(self, node: ONNXFuncNode): - print("compile func node:", node.name) + def get_temp_name(self): + name = f"temp_{self.temps}" + self.temps += 1 + + return name + + def compile_ONNXPredicate(self, node: ONNXPredicate): + return self.compile_BinOp(node) + + def compile_ONNXFuncNode(self, node: ONNXFuncNode) -> ModelFragment: + return ModelFragment(node.model, node.return_type) def compile_BinOp(self, node: BinOp): lhs = node.lhs @@ -80,15 +117,136 @@ def compile_BinOp(self, node: BinOp): lhs_ir = self.compile(lhs) rhs_ir = self.compile(rhs) - print("compile normal bin op:", op) + + lhs_prefix = self.get_temp_name() + rhs_prefix = self.get_temp_name() + + lhs_tensor = lhs_ir.get_default_output() + lhs_name = lhs_prefix + lhs_tensor.name + + rhs_tensor = rhs_ir.get_default_output() + rhs_name = rhs_prefix + rhs_tensor.name + + lhs_input_tensor = ValueInfoProto() + lhs_input_tensor.CopyFrom(lhs_tensor) + lhs_input_tensor.name = lhs_name + + rhs_input_tensor = ValueInfoProto() + rhs_input_tensor.CopyFrom(rhs_tensor) + rhs_input_tensor.name = rhs_name + + output_type = numpy_onnx_tensor_type_map[node.return_type.type] + + output_name = self.get_temp_name() + + # Try delegate shape inference job to onnx + output_tensor = self.helper.make_tensor_value_info( + name=output_name, + elem_type=output_type, + shape=[None] + ) + + op_node = self.helper.make_node( + op_type=_onnx_binary_ops_dict[op], + inputs=[lhs_name, rhs_name], + outputs=[output_name] + ) + + partial_graph = self.helper.make_graph( + nodes=[op_node], + name=self.get_temp_name(), + inputs=[lhs_input_tensor, rhs_input_tensor], + outputs=[output_tensor] + ) + + partial_model = self.helper.make_model(partial_graph) + + partial_model = merge_models_binary_wrap( + lhs_ir.model, + rhs_ir.model, + partial_model, + prefix1=lhs_prefix, + prefix2=rhs_prefix + ) + + partial_frag = OpModelFragment(partial_model, node.return_type, op) + + return partial_frag def compile_UnaryOp(self, node: UnaryOp): op = node.op operand = self.compile(node.operand) - print("compile normal unary op:", op) + + output_name = self.get_temp_name() + + input_tensor = operand.get_default_output() + input_name = input_tensor.name + + output_type = numpy_onnx_tensor_type_map[node.return_type.type] + + # Should I do shape inference here? + output_shape = input_tensor.type.tensor_type.shape + + output_tensor = self.helper.make_tensor_value_info( + name=output_name, + elem_type=output_type, + shape=output_shape + ) + + op_node = self.helper.make_node( + op_type=_onnx_unary_ops_map[op], + inputs=[input_name], + outputs=[output_name] + ) + + partial_graph = self.helper.make_graph( + nodes=[op_node], + name=self.get_temp_name(), + inputs=[input_tensor], + outputs=[output_tensor] + ) + + partial_model = self.helper.make_model(partial_graph) + + onnx.checker.check_model(partial_model) + + temp_frag = OpModelFragment(partial_model, node.return_type, op) + + merge_frag = merge_fragment_with_unary_op(operand, temp_frag) + + return merge_frag + + def compile_as_initializer(self, node): + output_name = self.get_temp_name() + value_name = self.get_temp_name_by_value(node.value) + + np_value = np.array(node.value) + constant_value = self.numpy_helper.from_array(np_value, name=value_name) + + identity_node = self.helper.make_node(self.IDENTITY, + inputs=[value_name], + outputs=[output_name]) + + term_output = self.helper.make_tensor_value_info(name=output_name, + elem_type=constant_value.data_type, + shape=np_value.shape) + + partial_graph = self.helper.make_graph( + nodes=[identity_node], + name=self.get_temp_name(), + inputs=[], + outputs=[term_output], + initializer=[constant_value] + ) + + partial_model = self.helper.make_model(partial_graph) + + onnx.checker.check_model(partial_model) + + return TermModelFragment(partial_model, np_value.dtype) def compile_Term(self, node: Term): - print("compile term node:", node.name) + return self.compile_as_initializer(node) - def compile_Constant(self, node: Term): - print("compile constant node:", node.name) + def compile_Constant(self, node: Constant): + return self.compile_as_initializer(node) diff --git a/onnxoptimizer/query/onnx/context.py b/onnxoptimizer/query/onnx/context.py index 52e50f5..729f3f6 100644 --- a/onnxoptimizer/query/onnx/context.py +++ b/onnxoptimizer/query/onnx/context.py @@ -10,8 +10,25 @@ def __init__(self, model_obj: ModelObject): model_data = self.model_obj.model.SerializeToString() self.infer_session = ort.InferenceSession(model_data) + + self.labels_map = { + elem.name: elem for elem in self.infer_session.get_outputs() + if elem.name.endswith("label") or elem.name.endswith("variable") + } + self.probabilities_map = { + elem.name: elem for elem in self.infer_session.get_outputs() + if elem.name.endswith("probability") + } + self.infer_input = {} + def return_type(self, which=None): + if which is None: + assert len(self.labels_map) == 1 + return next(iter(self.labels_map.values())).type + else: + return self.labels_map[which].type + def set_infer_input(self, **kwargs): self.infer_input = kwargs @@ -27,10 +44,8 @@ def __call__(self): } session = self.infer_session - labels = [elem.name for elem in session.get_outputs() if elem.name.endswith("label") or elem.name.endswith("variable")] - probabilities = [elem.name for elem in session.get_outputs() if elem.name.endswith("probability")] - label_out = [] + labels = list(self.labels_map.keys()) for elem in labels: label_out.append(elem.replace("output_label", "").replace("variable", "")) diff --git a/onnxoptimizer/query/onnx/joint/__init__.py b/onnxoptimizer/query/onnx/joint/__init__.py index 746b65b..533c3ba 100644 --- a/onnxoptimizer/query/onnx/joint/__init__.py +++ b/onnxoptimizer/query/onnx/joint/__init__.py @@ -1,7 +1,7 @@ -from onnxoptimizer.query.onnx.joint.expr_compose import merge_project_models -from onnxoptimizer.query.onnx.joint.expr_compose import merge_project_graphs +from onnxoptimizer.query.onnx.joint.merge_expr import merge_project_models_wrap +from onnxoptimizer.query.onnx.joint.merge_expr import merge_project_graphs __all__ = [ - "merge_project_models", + "merge_project_models_wrap", "merge_project_graphs" ] diff --git a/onnxoptimizer/query/onnx/joint/expr_compose.py b/onnxoptimizer/query/onnx/joint/expr_compose.py deleted file mode 100644 index 76f8f42..0000000 --- a/onnxoptimizer/query/onnx/joint/expr_compose.py +++ /dev/null @@ -1,164 +0,0 @@ -from typing import MutableMapping, Optional - -from onnx import ModelProto, helper, checker, GraphProto - -from onnxoptimizer.query.onnx.joint.model_utils import add_prefix_model - - -def merge_project_graphs( - g1: GraphProto, - g2: GraphProto, - name: Optional[str] = None, - doc_string: Optional[str] = None, -) -> GraphProto: - if type(g1) is not GraphProto: - raise ValueError("g1 argument is not an ONNX graph") - if type(g2) is not GraphProto: - raise ValueError("g2 argument is not an ONNX graph") - - g = GraphProto() - - g.node.extend(g1.node) - g.node.extend(g2.node) - - input_names = set() - for e in g1.input: - if e.name not in input_names: - g.input.append(e) - input_names.add(e.name) - for e in g2.input: - if e.name not in input_names: - g.input.append(e) - input_names.add(e.name) - - g.output.extend(g1.output) - g.output.extend(g2.output) - - g.initializer.extend(g1.initializer) - g.initializer.extend(g2.initializer) - - g.sparse_initializer.extend(g1.sparse_initializer) - g.sparse_initializer.extend(g2.sparse_initializer) - - g.value_info.extend(g1.value_info) - g.value_info.extend(g2.value_info) - - g.name = name if name is not None else "_".join([g1.name, g2.name]) - - if doc_string is None: - doc_string = ( - f"Graph combining {g1.name} and {g2.name}\n" - + g1.name - + "\n" - + g1.doc_string - + "\n" - + g2.name - + "\n" - + g2.doc_string - ) - g.doc_string = doc_string - - return g - - -def merge_project_models( - m1: ModelProto, - m2: ModelProto, - prefix1: Optional[str] = None, - prefix2: Optional[str] = None, - name: Optional[str] = None, - doc_string: Optional[str] = None, - producer_name: Optional[str] = "onnx.expr_compose.merge_models", - producer_version: Optional[str] = "1.0", - domain: Optional[str] = "", - model_version: Optional[int] = 1, ) -> ModelProto: - if type(m1) is not ModelProto: - raise ValueError("m1 argument is not an ONNX model") - if type(m2) is not ModelProto: - raise ValueError("m2 argument is not an ONNX model") - - if m1.ir_version != m2.ir_version: - raise ValueError( - f"IR version mismatch {m1.ir_version} != {m2.ir_version}." - " Both models should have the same IR version" - ) - ir_version = m1.ir_version - - opset_import_map: MutableMapping[str, int] = {} - opset_imports = list(m1.opset_import) + list(m2.opset_import) - - for entry in opset_imports: - if entry.domain in opset_import_map: - found_version = opset_import_map[entry.domain] - if entry.version != found_version: - # raise ValueError( - # "Can't merge two models with different operator set ids for a given domain. " - # f"Got: {m1.opset_import} and {m2.opset_import}" - # ) - opset_import_map[entry.domain] = max(int(entry.version), int(found_version)) - else: - opset_import_map[entry.domain] = entry.version - - # Prefixing names in the graph if requested, adjusting io_map accordingly - if prefix1 or prefix2: - if prefix1: - m1_copy = ModelProto() - m1_copy.CopyFrom(m1) - m1 = m1_copy - m1 = add_prefix_model(m1, prefix=prefix1) - if prefix2: - m2_copy = ModelProto() - m2_copy.CopyFrom(m2) - m2 = m2_copy - m2 = add_prefix_model(m2, prefix=prefix2) - - graph = merge_project_graphs( - m1.graph, - m2.graph, - name=name, - doc_string=doc_string, - ) - model = helper.make_model( - graph, - producer_name=producer_name, - producer_version=producer_version, - domain=domain, - model_version=model_version, - opset_imports=opset_imports, - ir_version=ir_version, - ) - - # Merging model metadata props - model_props = {} - for meta_entry in m1.metadata_props: - model_props[meta_entry.key] = meta_entry.value - for meta_entry in m2.metadata_props: - if meta_entry.key in model_props: - value = model_props[meta_entry.key] - if value != meta_entry.value: - raise ValueError( - "Can't merge models with different values for the same model metadata property." - f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}." - ) - else: - model_props[meta_entry.key] = meta_entry.value - helper.set_model_props(model, model_props) - - # Merging functions - function_overlap = list( - {f.name for f in m1.functions} & {f.name for f in m2.functions} - ) - if function_overlap: - raise ValueError( - "Can't merge models with overlapping local function names." - " Found in both graphs: " + ", ".join(function_overlap) - ) - model.functions.MergeFrom(m1.functions) - model.functions.MergeFrom(m2.functions) - - # walk around bug from sklearn-onnx converter - # zipmap will not pass the check - # see: https://github.com/onnx/sklearn-onnx/issues/858 - checker.check_model(model, full_check=False) - - return model diff --git a/onnxoptimizer/query/onnx/joint/fragment.py b/onnxoptimizer/query/onnx/joint/fragment.py new file mode 100644 index 0000000..d01b11a --- /dev/null +++ b/onnxoptimizer/query/onnx/joint/fragment.py @@ -0,0 +1,62 @@ +import numpy as np +from onnx import ModelProto + +from onnxoptimizer.query.types.mapper import numpy_onnx_tensor_type_map + + +class ModelFragment: + def __init__(self, model_partial: ModelProto, return_type): + self.model_partial = model_partial + self.return_type = return_type + + @property + def model(self): + return self.model_partial + + @property + def model_inputs(self): + return list(self.model_partial.graph.input) + + @property + def model_outputs(self): + return list(self.model_partial.graph.output) + + def return_tensor_type(self): + return numpy_onnx_tensor_type_map[self.return_type] + + def get_outputs_endswith(self, suffix: str): + outputs = [] + + for elem in self.model_outputs: + if elem.name.endswith(suffix): + outputs.append(elem) + + return outputs + + def get_default_output(self): + maybe_output = self.get_outputs_endswith("_label") + maybe_output.extend(self.get_outputs_endswith("_variable")) + return maybe_output[0] + + def get_default_input(self): + return self.model_inputs[0] + + +class OpModelFragment(ModelFragment): + def __init__(self, model_partial: ModelProto, return_type, op: str): + super().__init__(model_partial, return_type) + self.op = op + + def get_default_output(self): + return self.model_outputs[0] + + def get_default_input(self): + return self.model_inputs[0] + + +class TermModelFragment(ModelFragment): + def get_default_output(self): + return self.model_outputs[0] + + def get_default_input(self): + return self.model_inputs[0] diff --git a/onnxoptimizer/query/onnx/joint/merge_expr.py b/onnxoptimizer/query/onnx/joint/merge_expr.py new file mode 100644 index 0000000..8f98a4a --- /dev/null +++ b/onnxoptimizer/query/onnx/joint/merge_expr.py @@ -0,0 +1,416 @@ +from typing import MutableMapping, Optional + +import onnx +from onnx import ModelProto, helper, checker, GraphProto + +from onnxoptimizer.query.onnx.joint.model_utils import add_prefix_model + + +def merge_project_graphs( + g1: GraphProto, + g2: GraphProto, + name: Optional[str] = None, + doc_string: Optional[str] = None, +) -> GraphProto: + if type(g1) is not GraphProto: + raise ValueError("g1 argument is not an ONNX graph") + if type(g2) is not GraphProto: + raise ValueError("g2 argument is not an ONNX graph") + + g = GraphProto() + + g.node.extend(g1.node) + g.node.extend(g2.node) + + input_names = set() + for e in g1.input: + if e.name not in input_names: + g.input.append(e) + input_names.add(e.name) + for e in g2.input: + if e.name not in input_names: + g.input.append(e) + input_names.add(e.name) + + g.output.extend(g1.output) + g.output.extend(g2.output) + + g.initializer.extend(g1.initializer) + g.initializer.extend(g2.initializer) + + g.sparse_initializer.extend(g1.sparse_initializer) + g.sparse_initializer.extend(g2.sparse_initializer) + + g.value_info.extend(g1.value_info) + g.value_info.extend(g2.value_info) + + g.name = name if name is not None else "_".join([g1.name, g2.name]) + + if doc_string is None: + doc_string = ( + f"Graph combining {g1.name} and {g2.name}\n" + + g1.name + + "\n" + + g1.doc_string + + "\n" + + g2.name + + "\n" + + g2.doc_string + ) + g.doc_string = doc_string + + return g + + +def merge_project_models( + m1: ModelProto, + m2: ModelProto, + prefix1: Optional[str] = None, + prefix2: Optional[str] = None, + name: Optional[str] = None, + doc_string: Optional[str] = None, + producer_name: Optional[str] = "onnx.expr_compose.merge_models", + producer_version: Optional[str] = "1.0", + domain: Optional[str] = "", + model_version: Optional[int] = 1, ) -> ModelProto: + if type(m1) is not ModelProto: + raise ValueError("m1 argument is not an ONNX model") + if type(m2) is not ModelProto: + raise ValueError("m2 argument is not an ONNX model") + + if m1.ir_version != m2.ir_version: + raise ValueError( + f"IR version mismatch {m1.ir_version} != {m2.ir_version}." + " Both models should have the same IR version" + ) + ir_version = m1.ir_version + + opset_import_map: MutableMapping[str, int] = {} + opset_imports = list(m1.opset_import) + list(m2.opset_import) + + for entry in opset_imports: + if entry.domain in opset_import_map: + found_version = opset_import_map[entry.domain] + if entry.version != found_version: + # raise ValueError( + # "Can't merge two models with different operator set ids for a given domain. " + # f"Got: {m1.opset_import} and {m2.opset_import}" + # ) + opset_import_map[entry.domain] = max(int(entry.version), int(found_version)) + else: + opset_import_map[entry.domain] = entry.version + + # Prefixing names in the graph if requested, adjusting io_map accordingly + if prefix1 or prefix2: + if prefix1: + m1_copy = ModelProto() + m1_copy.CopyFrom(m1) + m1 = m1_copy + m1 = add_prefix_model(m1, prefix=prefix1) + if prefix2: + m2_copy = ModelProto() + m2_copy.CopyFrom(m2) + m2 = m2_copy + m2 = add_prefix_model(m2, prefix=prefix2) + + graph = merge_project_graphs( + m1.graph, + m2.graph, + name=name, + doc_string=doc_string, + ) + model = helper.make_model( + graph, + producer_name=producer_name, + producer_version=producer_version, + domain=domain, + model_version=model_version, + opset_imports=opset_imports, + ir_version=ir_version, + ) + + # Merging model metadata props + model_props = {} + for meta_entry in m1.metadata_props: + model_props[meta_entry.key] = meta_entry.value + for meta_entry in m2.metadata_props: + if meta_entry.key in model_props: + value = model_props[meta_entry.key] + if value != meta_entry.value: + raise ValueError( + "Can't merge models with different values for the same model metadata property." + f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}." + ) + else: + model_props[meta_entry.key] = meta_entry.value + helper.set_model_props(model, model_props) + + # Merging functions + function_overlap = list( + {f.name for f in m1.functions} & {f.name for f in m2.functions} + ) + if function_overlap: + raise ValueError( + "Can't merge models with overlapping local function names." + " Found in both graphs: " + ", ".join(function_overlap) + ) + model.functions.MergeFrom(m1.functions) + model.functions.MergeFrom(m2.functions) + + # walk around bug from sklearn-onnx converter + # zipmap will not pass the check + # see: https://github.com/onnx/sklearn-onnx/issues/858 + checker.check_model(model, full_check=False) + + return model + + +def merge_project_models_wrap( + m1: ModelProto, + m2: ModelProto, + prefix1: Optional[str] = None, + prefix2: Optional[str] = None, + name: Optional[str] = None, + doc_string: Optional[str] = None, + producer_name: Optional[str] = "onnx.expr_compose.merge_models", + producer_version: Optional[str] = "1.0", + domain: Optional[str] = "", + model_version: Optional[int] = 1, ) -> ModelProto: + if m1.ir_version != m2.ir_version: + ir_version = max(m1.ir_version, m2.ir_version) + m1.ir_version = ir_version + m2.ir_version = ir_version + + return merge_project_models( + m1=m1, + m2=m2, + prefix1=prefix1, + prefix2=prefix2, + name=name, + doc_string=doc_string, + producer_name=producer_name, + producer_version=producer_version, + domain=domain, + model_version=model_version + ) + + +def merge_graphs_binary( + g1: GraphProto, + g2: GraphProto, + gop: GraphProto, + name: Optional[str] = None, + doc_string: Optional[str] = None, +) -> GraphProto: + if type(g1) is not GraphProto: + raise ValueError("g1 argument is not an ONNX graph") + if type(g2) is not GraphProto: + raise ValueError("g2 argument is not an ONNX graph") + if type(gop) is not GraphProto: + raise ValueError("gop argument is not an ONNX graph") + + g = GraphProto() + + g.node.extend(g1.node) + g.node.extend(g2.node) + g.node.extend(gop.node) + + input_names = set() + for e in g1.input: + if e.name not in input_names: + g.input.append(e) + input_names.add(e.name) + for e in g2.input: + if e.name not in input_names: + g.input.append(e) + input_names.add(e.name) + + g.output.extend(gop.output) + + g.initializer.extend(g1.initializer) + g.initializer.extend(g2.initializer) + g.initializer.extend(gop.initializer) + + g.sparse_initializer.extend(g1.sparse_initializer) + g.sparse_initializer.extend(g2.sparse_initializer) + g.sparse_initializer.extend(gop.sparse_initializer) + + g.value_info.extend(g1.value_info) + g.value_info.extend(g2.value_info) + g.value_info.extend(gop.value_info) + + g.name = name if name is not None else "_".join([g1.name, g2.name, gop.name]) + + if doc_string is None: + doc_string = ( + f"Graph combining {g1.name} and {g2.name}\n" + + g1.name + + "\n" + + g1.doc_string + + "\n" + + g2.name + + "\n" + + g2.doc_string + ) + g.doc_string = doc_string + + return g + + +def merge_models_binary( + m1: ModelProto, + m2: ModelProto, + mop: ModelProto, + prefix1: Optional[str] = None, + prefix2: Optional[str] = None, + name: Optional[str] = None, + doc_string: Optional[str] = None, + producer_name: Optional[str] = "onnx.compose.merge_models", + producer_version: Optional[str] = "1.0", + domain: Optional[str] = "", + model_version: Optional[int] = 1, +): + if type(m1) is not ModelProto: + raise ValueError("m1 argument is not an ONNX model") + if type(m2) is not ModelProto: + raise ValueError("m2 argument is not an ONNX model") + if type(mop) is not ModelProto: + raise ValueError("mop argument is not an ONNX model") + + if m1.ir_version == m2.ir_version == mop.ir_version: + ir_version = m1.ir_version + else: + raise ValueError( + f"IR version mismatch {m1.ir_version} != {m2.ir_version}." + " Both models should have the same IR version" + ) + + opset_import_map: MutableMapping[str, int] = {} + opset_imports = list(m1.opset_import) + list(m2.opset_import) + + for entry in opset_imports: + if entry.domain in opset_import_map: + found_version = opset_import_map[entry.domain] + if entry.version != found_version: + # raise ValueError( + # "Can't merge two models with different operator set ids for a given domain. " + # f"Got: {m1.opset_import} and {m2.opset_import}" + # ) + opset_import_map[entry.domain] = max(int(entry.version), int(found_version)) + else: + opset_import_map[entry.domain] = entry.version + + # Prefixing names in the graph if requested, adjusting io_map accordingly + if prefix1 or prefix2: + if prefix1: + m1_copy = ModelProto() + m1_copy.CopyFrom(m1) + m1 = m1_copy + m1 = add_prefix_model(m1, prefix=prefix1) + if prefix2: + m2_copy = ModelProto() + m2_copy.CopyFrom(m2) + m2 = m2_copy + m2 = add_prefix_model(m2, prefix=prefix2) + + graph = merge_graphs_binary( + m1.graph, + m2.graph, + mop.graph, + name=name, + doc_string=doc_string, + ) + model = helper.make_model( + graph, + producer_name=producer_name, + producer_version=producer_version, + domain=domain, + model_version=model_version, + opset_imports=opset_imports, + ir_version=ir_version, + ) + + # Merging model metadata props + model_props = {} + for meta_entry in m1.metadata_props: + model_props[meta_entry.key] = meta_entry.value + + for meta_entry in m2.metadata_props: + if meta_entry.key in model_props: + value = model_props[meta_entry.key] + if value != meta_entry.value: + raise ValueError( + "Can't merge models with different values for the same model metadata property." + f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}." + ) + else: + model_props[meta_entry.key] = meta_entry.value + + for meta_entry in mop.metadata_props: + if meta_entry.key in model_props: + value = model_props[meta_entry.key] + if value != meta_entry.value: + raise ValueError( + "Can't merge models with different values for the same model metadata property." + f" Found: property = {meta_entry.key}, with values {value} and {meta_entry.value}." + ) + else: + model_props[meta_entry.key] = meta_entry.value + helper.set_model_props(model, model_props) + + # Merging functions + function_overlap = list( + {f.name for f in m1.functions} & {f.name for f in m2.functions} & + {f.name for f in mop.functions} + ) + if function_overlap: + raise ValueError( + "Can't merge models with overlapping local function names." + " Found in both graphs: " + ", ".join(function_overlap) + ) + model.functions.MergeFrom(m1.functions) + model.functions.MergeFrom(m2.functions) + model.functions.MergeFrom(mop.functions) + + onnx.save(model.SerializeToString(), "temp3.onnx") + + # walk around bug from sklearn-onnx converter + # zipmap will not pass the check + # see: https://github.com/onnx/sklearn-onnx/issues/858 + checker.check_model(model, full_check=False) + + return model + + +def merge_models_binary_wrap( + m1: ModelProto, + m2: ModelProto, + mop: ModelProto, + prefix1: Optional[str] = None, + prefix2: Optional[str] = None, + name: Optional[str] = None, + doc_string: Optional[str] = None, + producer_name: Optional[str] = "onnx.compose.merge_models", + producer_version: Optional[str] = "1.0", + domain: Optional[str] = "", + model_version: Optional[int] = 1, +): + if not (m1.ir_version == m2.ir_version == mop.ir_version): + ir_version = max(m1.ir_version, m2.ir_version, mop.ir_version) + m1.ir_version = ir_version + m2.ir_version = ir_version + mop.ir_version = ir_version + + return merge_models_binary( + m1=m1, + m2=m2, + mop=mop, + prefix1=prefix1, + prefix2=prefix2, + name=name, + doc_string=doc_string, + producer_name=producer_name, + producer_version=producer_version, + domain=domain, + model_version=model_version + ) diff --git a/onnxoptimizer/query/onnx/model.py b/onnxoptimizer/query/onnx/model.py index e23b4ce..c8adfe3 100644 --- a/onnxoptimizer/query/onnx/model.py +++ b/onnxoptimizer/query/onnx/model.py @@ -4,7 +4,7 @@ from sklearn.pipeline import Pipeline import onnxoptimizer -from onnxoptimizer.query.types.mapper import numpy_onnx_type_map +from onnxoptimizer.query.types.mapper import input_numpy_onnx_type_map class ModelObject: @@ -14,7 +14,7 @@ def __init__(self, pipeline: str | Pipeline | ModelProto, schema=None): if type(pipeline) == str: self.model = onnx.load_model(pipeline) elif type(pipeline) == Pipeline: - init_types = [(k, numpy_onnx_type_map[v]) for k, v in self.schema.items()] + init_types = [(k, input_numpy_onnx_type_map[v]) for k, v in self.schema.items()] self.model = skl2onnx.to_onnx(pipeline, initial_types=init_types) else: self.model = pipeline diff --git a/onnxoptimizer/query/pandas/api/patch.py b/onnxoptimizer/query/pandas/api/patch.py index 8b1459d..cbb0efe 100644 --- a/onnxoptimizer/query/pandas/api/patch.py +++ b/onnxoptimizer/query/pandas/api/patch.py @@ -10,6 +10,7 @@ LEVEL_OFFSET_1 = 1 + @callable_patch(pandas) def predict_eval(expr: str, parser: str = "pandas", @@ -23,7 +24,7 @@ def predict_eval(expr: str, enable_opt: bool = True): return pandas_eval(expr, parser=parser, engine=engine, local_dict=local_dict, global_dict=global_dict, - resolvers=resolvers, level=level+LEVEL_OFFSET_1, + resolvers=resolvers, level=level + LEVEL_OFFSET_1, target=target, inplace=inplace, enable_opt=enable_opt) diff --git a/onnxoptimizer/query/pandas/core/computation/eval.py b/onnxoptimizer/query/pandas/core/computation/eval.py index f9a16c9..e5646d1 100644 --- a/onnxoptimizer/query/pandas/core/computation/eval.py +++ b/onnxoptimizer/query/pandas/core/computation/eval.py @@ -253,7 +253,7 @@ def pandas_eval( expr_to_eval.append(parsed_expr) predicate_compiler = ONNXPredicateCompiler(env) - predicate_compiler.compile(expr_to_eval[0].terms) + predicate_compiler.compile_save(expr_to_eval[0].terms, "./temp_rs.onnx") ################# # Optimization Phase diff --git a/onnxoptimizer/query/pandas/core/computation/ops.py b/onnxoptimizer/query/pandas/core/computation/ops.py index b36ff38..d9aff2b 100644 --- a/onnxoptimizer/query/pandas/core/computation/ops.py +++ b/onnxoptimizer/query/pandas/core/computation/ops.py @@ -31,6 +31,7 @@ result_type_many, ) from onnxoptimizer.query.pandas.core.computation.scope import DEFAULT_GLOBALS +from onnxoptimizer.query.types.mapper import onnx_type_str_numpy_map if TYPE_CHECKING: from collections.abc import ( @@ -648,8 +649,8 @@ def evaluate(self, *args, **kwargs): @property def return_type(self): - # TODO: add return type inference of ort call. - return np.ndarray + onnx_type_str = self.model_context.return_type() + return onnx_type_str_numpy_map[onnx_type_str] @property def local_name(self) -> str: @@ -657,7 +658,8 @@ def local_name(self) -> str: @property def type(self): - return np.ndarray + onnx_type_str = self.model_context.return_type() + return onnx_type_str_numpy_map[onnx_type_str] @property def is_scalar(self): @@ -677,7 +679,7 @@ def __init__(self, op, lhs, rhs): @property def return_type(self): - return np.bool_ + return np.dtype('bool') def __call__(self, env): # call eval in parsing time will diff --git a/onnxoptimizer/query/pandas/optimization/optimizer.py b/onnxoptimizer/query/pandas/optimization/optimizer.py index 266ecf0..820f1f5 100644 --- a/onnxoptimizer/query/pandas/optimization/optimizer.py +++ b/onnxoptimizer/query/pandas/optimization/optimizer.py @@ -1,5 +1,7 @@ +import onnx + import onnxoptimizer -from onnxoptimizer.query.onnx.joint import merge_project_models +from onnxoptimizer.query.onnx.joint import merge_project_models_wrap from onnxoptimizer.query.onnx.context import ModelContext from onnxoptimizer.query.onnx.model import ModelObject @@ -22,11 +24,11 @@ def optimize(self, expr_list): model0_prefix, model0_model = models[0] model1_prefix, model1_model = models[1] - model_fused = merge_project_models(model0_model, model1_model, model0_prefix, model1_prefix) + model_fused = merge_project_models_wrap(model0_model, model1_model, model0_prefix, model1_prefix) for i in range(2, len(models)): model_prefix, model_model = models[i] - model_fused = merge_project_models(model_fused, model_model, "", model_prefix) + model_fused = merge_project_models_wrap(model_fused, model_model, "", model_prefix) model_fused = self.model_optimizer.optimize(model_fused, fixed_point=True) diff --git a/onnxoptimizer/query/types/mapper.py b/onnxoptimizer/query/types/mapper.py index 9b7ebda..6409ec8 100644 --- a/onnxoptimizer/query/types/mapper.py +++ b/onnxoptimizer/query/types/mapper.py @@ -1,6 +1,7 @@ import numpy as np +from onnx import TensorProto from onnxconverter_common import (Int64TensorType, Int32TensorType, - FloatTensorType, StringTensorType) + FloatTensorType, StringTensorType, BooleanTensorType) numpy_type_map = { np.float32: np.float32, @@ -10,10 +11,25 @@ np.object_: str, } -numpy_onnx_type_map = { +input_numpy_onnx_type_map = { np.int64: Int64TensorType([None, 1]), np.int32: Int32TensorType([None, 1]), np.float64: FloatTensorType([None, 1]), np.float32: FloatTensorType([None, 1]), np.object_: StringTensorType([None, 1]), + np.bool_: BooleanTensorType([None, 1]) +} + +numpy_onnx_tensor_type_map = { + np.int64: TensorProto.INT64, + np.int32: TensorProto.INT32, + np.float64: TensorProto.FLOAT, + np.float32: TensorProto.FLOAT, + np.object_: TensorProto.STRING, + np.bool_: TensorProto.BOOL, +} + +onnx_type_str_numpy_map = { + 'tensor(int64)': np.dtype("int64"), + 'tensor(float)': np.dtype("float") } diff --git a/onnxoptimizer/test/pandas_query_test.py b/onnxoptimizer/test/pandas_query_test.py index 130e47d..90d1ae5 100644 --- a/onnxoptimizer/test/pandas_query_test.py +++ b/onnxoptimizer/test/pandas_query_test.py @@ -6,13 +6,15 @@ from onnxoptimizer.query.pandas.api import model_udf +DATA_DIR = "/home/uw1/MLquery/reference/snippets/py_onnx/expedia" + class TestEval(unittest.TestCase): @classmethod def setUpClass(cls) -> None: - path1 = "/home/uw1/snippets/py_onnx/expedia/data/S_listings.csv" - path2 = "/home/uw1/snippets/py_onnx/expedia/data/R1_hotels.csv" - path3 = "/home/uw1/snippets/py_onnx/expedia/data/R2_searches.csv" + path1 = f"{DATA_DIR}/data/S_listings.csv" + path2 = f"{DATA_DIR}/data/R1_hotels.csv" + path3 = f"{DATA_DIR}/data/R2_searches.csv" S_listings = pd.read_csv(path1) R1_hotels = pd.read_csv(path2) @@ -92,13 +94,14 @@ def say_hello(a): res = df.predict_eval('''new=@say_hello(a=age) new2=@say_hello(a=age)''', engine='python') b = 3 - pd.predict_eval("b + 1 > 1") + # pd.predict_eval("b + 1 > 1") + df.predict_filter("sin(age) > 15") print(res) def test_predict_end2end(self): batch = self.df.iloc[: 4096, :] - @model_udf("/home/uw1/snippets/py_onnx/expedia/expedia.onnx") + @model_udf(f"{DATA_DIR}/expedia_lr.onnx") def expedia_infer(infer_df): return infer_df.to_dict(orient="series") @@ -112,7 +115,7 @@ def expedia_infer(infer_df): def test_predict_filter_eval(self): batch = self.df.iloc[: 4096, :] - @model_udf("/home/uw1/snippets/py_onnx/expedia/expedia.onnx") + @model_udf(f"{DATA_DIR}/expedia_lr.onnx") def expedia_infer(infer_df): return infer_df.to_dict(orient="series")