Skip to content

Commit

Permalink
Enable Py func replacement for now, lower/upper_bin
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Jul 10, 2023
1 parent 951a784 commit 68c765e
Show file tree
Hide file tree
Showing 18 changed files with 431 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
package io.deephaven.integrations.python;

import io.deephaven.engine.util.PythonObjectWrapper;
import io.deephaven.util.annotations.ScriptApi;
import org.jpy.PyObject;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
*/
package io.deephaven.integrations.python;

import io.deephaven.engine.util.PythonObjectWrapper;
import io.deephaven.util.annotations.ScriptApi;
import org.jpy.PyObject;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.deephaven.integrations.python;

import io.deephaven.engine.util.PythonObjectWrapper;
import io.deephaven.util.QueryConstants;
import org.jpy.PyObject;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,12 @@
import io.deephaven.engine.context.QueryScope;
import io.deephaven.engine.table.impl.MatchPair;
import io.deephaven.engine.table.impl.ShiftedColumnsFactory;
import io.deephaven.engine.util.AbstractScriptSession;
import io.deephaven.engine.util.PyCallableWrapper.ColumnChunkArgument;
import io.deephaven.engine.util.PyCallableWrapper.ConstantChunkArgument;
import io.deephaven.engine.util.PyCallableWrapper;
import io.deephaven.engine.util.PythonDeephavenSession;
import io.deephaven.engine.util.PythonScopeJpyImpl;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
import io.deephaven.util.type.TypeUtils;
Expand All @@ -109,6 +112,8 @@
import io.deephaven.vector.Vector;
import org.apache.commons.lang3.ClassUtils;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.jpy.PyInputMode;
import org.jpy.PyListWrapper;
import org.jpy.PyObject;

import java.lang.reflect.Array;
Expand Down Expand Up @@ -465,6 +470,66 @@ private Class<?> findNestedClass(Class<?> enclosingClass, String nestedClassName
return m.get(nestedClassName);
}

private boolean pyToJavaReplaced(final Class<?> scope, final MethodCallExpr n) {
final String methodName = n.getNameAsString();
final QueryScope queryScope = ExecutionContext.getContext().getQueryScope();

if (scope == null) {
final Class<?> methodClass = variables.get(methodName);
if (methodClass == PyCallableWrapper.class) {
final Object paramValueRaw = queryScope.readParamValue(methodName, null);
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
try {
String javaMethodName = pyCallableWrapper.getAttribute("_j_simple_name", String.class);
n.setName(javaMethodName);
int nargs = pyCallableWrapper.getAttribute("_nargs", int.class);
PyListWrapper defArgList = (PyListWrapper) pyCallableWrapper.getAttribute("_def_args").asList();
int defArgsToAdd = nargs - n.getArguments().size();
if (defArgsToAdd > defArgList.size()) {
throw new IllegalArgumentException("Missing args for " + methodName);
}
if (defArgsToAdd < 0) {
// do nothing, could be valid if the java method has a vararg
}
// Since the func call follows the Java syntax, we don't do name matching for the keyword args
for (int i = defArgList.size() - (nargs - n.getArguments().size()); i < defArgList.size(); i++) {
Object v = PythonScopeJpyImpl.convert(defArgList.get(i));
n.addArgument(v.getClass() == String.class ? "\"" + v.toString() + "\"" : v.toString());
}
return true;
} catch (IllegalArgumentException iae) {
throw iae;
} catch (RuntimeException e) {
// Not a Java replaceable callable, can safely ignore
}
}
} else if (scope == org.jpy.PyObject.class || scope == PyCallableWrapper.class) {
String pythonExpr = n.getScope().get().toString() + "." + methodName;
PythonDeephavenSession pds =
(PythonDeephavenSession) ((AbstractScriptSession.UnsynchronizedScriptSessionQueryScope) queryScope)
.scriptSession();
Map<String, PyObject> scopeVars = pds.getVariablesRaw();
PyObject pyobj;
try {
pyobj = PyObject.executeCode(pythonExpr, PyInputMode.EXPRESSION, scopeVars, null);
} catch (RuntimeException e) {
throw new RuntimeException("Cannot find Python callable: " + pythonExpr);
}
try {
Object obj = PythonScopeJpyImpl.convert(pyobj);
if (obj.getClass() == (PyCallableWrapper.class)) {
String javaMethodName = ((PyCallableWrapper) obj).getAttribute("_j_simple_name", String.class);
n.setScope(null);
n.setName(javaMethodName);
return true;
}
} catch (Exception e) {
// Not a Java replaceable callable, can safely ignore
}
}
return false;
}

private Method getMethod(final Class<?> scope, final String methodName, final Class<?>[] paramTypes,
final Class<?>[][] parameterizedTypes) {
final ArrayList<CandidateExecutable<Method>> acceptableMethods = new ArrayList<>();
Expand Down Expand Up @@ -1140,6 +1205,16 @@ public Class<?> visit(NameExpr n, VisitArgs printer) {
* throw them to 'findClass()'. Many details are not relevant here. For example, field access is handled by a
* different method: visit(FieldAccessExpr, StringBuilder).
*/
Map<String, String> pyConstantsMap = Map.of(
"True", "true",
"False", "false",
"None", "null");
final String name = n.getNameAsString();
String jConstant = pyConstantsMap.get(name);
if (jConstant != null) {
printer.append(jConstant);
return name.equals("None") ? NULL_CLASS : boolean.class;
}
printer.append(n.getNameAsString());

Class<?> ret = variables.get(n.getNameAsString());
Expand Down Expand Up @@ -1865,6 +1940,13 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
return result;
}).orElse(null);

if (pyToJavaReplaced(scope, n)) {
if (scope != null) {
scope = null;
innerPrinter.reset();
}
}

Expression[] expressions = n.getArguments() == null ? new Expression[0]
: n.getArguments().toArray(new Expression[0]);

Expand All @@ -1873,7 +1955,6 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
Class<?>[][] parameterizedTypes = getParameterizedTypes(expressions);

Method method = getMethod(scope, n.getNameAsString(), expressionTypes, parameterizedTypes);

Class<?>[] argumentTypes = method.getParameterTypes();

// now do some parameter conversions...
Expand Down Expand Up @@ -1913,6 +1994,7 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
}
} else { // Groovy or Java method call
printer.append(innerPrinter);
n.setName(method.getName());
printer.append(n.getNameAsString());
}

Expand Down Expand Up @@ -2538,6 +2620,12 @@ public boolean hasStringBuilder() {
return builder != null;
}

public void reset() {
if (hasStringBuilder()) {
builder.setLength(0);
}
}

/**
* Convenience method: forwards argument to 'builder' if 'builder' is not null
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending
*/
package io.deephaven.integrations.python;
package io.deephaven.engine.util;

import io.deephaven.base.FileUtils;
import io.deephaven.base.verify.Assert;
Expand All @@ -10,13 +10,7 @@
import io.deephaven.engine.exceptions.CancellationException;
import io.deephaven.engine.context.QueryScope;
import io.deephaven.engine.updategraph.UpdateGraph;
import io.deephaven.engine.util.AbstractScriptSession;
import io.deephaven.engine.util.PythonEvaluator;
import io.deephaven.engine.util.PythonEvaluatorJpy;
import io.deephaven.engine.util.PythonScope;
import io.deephaven.engine.util.ScriptFinder;
import io.deephaven.engine.util.ScriptSession;
import io.deephaven.integrations.python.PythonDeephavenSession.PythonSnapshot;
import io.deephaven.engine.util.PythonDeephavenSession.PythonSnapshot;
import io.deephaven.engine.util.scripts.ScriptPathLoader;
import io.deephaven.engine.util.scripts.ScriptPathLoaderState;
import io.deephaven.internal.log.LoggerFactory;
Expand All @@ -40,14 +34,18 @@
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.AbstractMap;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;

import static java.util.stream.Collectors.toMap;

/**
* A ScriptSession that uses a JPy cpython interpreter internally.
* <p>
Expand Down Expand Up @@ -207,6 +205,12 @@ public Map<String, Object> getVariables() {
return outMap;
}

public Map<String, PyObject> getVariablesRaw() {
return scope.getEntriesRaw()
.map(e -> new AbstractMap.SimpleImmutableEntry<>(scope.convertStringKey(e.getKey()), e.getValue()))
.collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
}

protected static class PythonSnapshot implements Snapshot, SafeCloseable {

private final PyDictWrapper dict;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/**
* Copyright (c) 2016-2022 Deephaven Data Labs and Patent Pending
*/
package io.deephaven.integrations.python;
package io.deephaven.engine.util;

import org.jpy.PyModule;
import org.jpy.PyObject;
Expand Down
Loading

0 comments on commit 68c765e

Please sign in to comment.