From 9e60ab4084e30dd4189d9065757c37bff11fc82d Mon Sep 17 00:00:00 2001 From: Nate Bauernfeind Date: Tue, 2 Nov 2021 19:05:17 -0500 Subject: [PATCH] Enable casting PyObject's to primitives; provide type param info for py lists (#1490) Also did this for String/boolean/Boolean: from deephaven import TableTools x = ['0', '1', '42', None] y = TableTools.emptyTable(1).update("X = (String)x[i % x.size()]") print(y.getDefinition()) Note that after some discussion we preferred to NPE when casting to boolean and the value is a None. We also decided that a String cast should throw an exception if not already a String (instead of converting via PyObject::str). Fixes #1009 --- .../lang/DBLanguageFunctionGenerator.java | 116 +++++++++++++++++- .../tables/lang/DBLanguageFunctionUtil.java | 111 +++++++++++++++++ .../db/tables/lang/DBLanguageParser.java | 40 ++++-- .../io/deephaven/db/tables/select/Param.java | 88 +++++++++---- .../db/v2/select/AbstractConditionFilter.java | 12 +- .../db/v2/select/ConditionFilter.java | 2 +- .../db/v2/select/DhFormulaColumn.java | 4 +- .../db/v2/select/codegen/FormulaAnalyzer.java | 13 +- .../db/tables/lang/TestDBLanguageParser.java | 49 +++++++- 9 files changed, 393 insertions(+), 42 deletions(-) diff --git a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionGenerator.java b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionGenerator.java index 6c27edd660a..cd5001917cb 100644 --- a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionGenerator.java +++ b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionGenerator.java @@ -6,6 +6,7 @@ import io.deephaven.util.type.TypeUtils; import com.github.javaparser.ast.expr.BinaryExpr; +import org.jpy.PyObject; import java.io.*; import java.text.*; @@ -319,7 +320,8 @@ public static void main(String args[]) { buf.append("package io.deephaven.db.tables.lang;\n\n"); - buf.append("import io.deephaven.util.QueryConstants;\n\n"); + buf.append("import io.deephaven.util.QueryConstants;\n"); + buf.append("import org.jpy.PyObject;\n\n"); buf.append("@SuppressWarnings({\"unused\", \"WeakerAccess\", \"SimplifiableIfStatement\"})\n"); buf.append("public final class DBLanguageFunctionUtil {\n\n"); @@ -558,6 +560,118 @@ public static void main(String args[]) { append(buf, castFromObjFormatter, BinaryExpr.Operator.PLUS, c, Object.class); } + // Special casts for PyObject to primitive + buf.append(" public static int intPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_INT;\n"); + buf.append(" }\n"); + buf.append(" return o.getIntValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static double doublePyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_DOUBLE;\n"); + buf.append(" }\n"); + buf.append(" return o.getDoubleValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static long longPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_LONG;\n"); + buf.append(" }\n"); + buf.append(" return o.getLongValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static float floatPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_FLOAT;\n"); + buf.append(" }\n"); + buf.append(" return (float) o.getDoubleValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static char charPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_CHAR;\n"); + buf.append(" }\n"); + buf.append(" return (char) o.getIntValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static byte bytePyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_BYTE;\n"); + buf.append(" }\n"); + buf.append(" return (byte) o.getIntValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static short shortPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return QueryConstants.NULL_SHORT;\n"); + buf.append(" }\n"); + buf.append(" return (short) o.getIntValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static String doStringPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return null;\n"); + buf.append(" }\n"); + buf.append(" return o.getStringValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static boolean booleanPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" throw new NullPointerException(\"Provided value is unexpectedly null;"); + buf.append(" cannot cast to boolean\");\n"); + buf.append(" }\n"); + buf.append(" return o.getBooleanValue();\n"); + buf.append(" }\n\n"); + + buf.append(" public static Boolean doBooleanPyCast(Object a) {\n"); + buf.append(" if (a != null && !(a instanceof PyObject)) {\n"); + buf.append(" throw new IllegalArgumentException(\"Provided value is not a PyObject\");\n"); + buf.append(" }\n"); + buf.append(" PyObject o = (PyObject) a;\n"); + buf.append(" if (o == null || o.isNone()) {\n"); + buf.append(" return null;\n"); + buf.append(" }\n"); + buf.append(" return o.getBooleanValue();\n"); + buf.append(" }\n\n"); + // ------------------------------------------------------------------------------------------------------------------------------------------------------------------ classes = new Class[] {int.class, double.class, long.class, float.class, char.class, byte.class, short.class}; diff --git a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionUtil.java b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionUtil.java index 157c4887f08..6a01d6d64be 100644 --- a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionUtil.java +++ b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageFunctionUtil.java @@ -6,6 +6,7 @@ package io.deephaven.db.tables.lang; import io.deephaven.util.QueryConstants; +import org.jpy.PyObject; @SuppressWarnings({"unused", "WeakerAccess", "SimplifiableIfStatement"}) public final class DBLanguageFunctionUtil { @@ -19289,6 +19290,116 @@ public static short shortCast(Object a) { return a == null ? QueryConstants.NULL_SHORT : (short) a; } + public static int intPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_INT; + } + return o.getIntValue(); + } + + public static double doublePyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_DOUBLE; + } + return o.getDoubleValue(); + } + + public static long longPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_LONG; + } + return o.getLongValue(); + } + + public static float floatPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_FLOAT; + } + return (float) o.getDoubleValue(); + } + + public static char charPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_CHAR; + } + return (char) o.getIntValue(); + } + + public static byte bytePyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_BYTE; + } + return (byte) o.getIntValue(); + } + + public static short shortPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return QueryConstants.NULL_SHORT; + } + return (short) o.getIntValue(); + } + + public static String doStringPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return null; + } + return o.getStringValue(); + } + + public static boolean booleanPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + throw new NullPointerException("Provided value is unexpectedly null; cannot cast to boolean"); + } + return o.getBooleanValue(); + } + + public static Boolean doBooleanPyCast(Object a) { + if (a != null && !(a instanceof PyObject)) { + throw new IllegalArgumentException("Provided value is not a PyObject"); + } + PyObject o = (PyObject) a; + if (o == null || o.isNone()) { + return null; + } + return o.getBooleanValue(); + } + public static int negate(int a) { return a == QueryConstants.NULL_INT ? QueryConstants.NULL_INT : -a; } diff --git a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageParser.java b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageParser.java index f9b64786142..4004bdbdfdf 100644 --- a/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageParser.java +++ b/DB/src/main/java/io/deephaven/db/tables/lang/DBLanguageParser.java @@ -1003,11 +1003,22 @@ public Class visit(ArrayAccessExpr n, VisitArgs printer) { Class paramType = n.getIndex().accept(this, printer); printer.append(')'); - if (DbArray.class.isAssignableFrom(type) && (n.getName() instanceof NameExpr)) { - Class ret = variableParameterizedTypes.get(((NameExpr) n.getName()).getNameAsString())[0]; + // We'll check for a known component type if this a NameExpr. + if (n.getName() instanceof NameExpr) { + Class[] classes = variableParameterizedTypes.get(((NameExpr) n.getName()).getNameAsString()); - if (ret != null) { - return ret; + if (classes != null) { + Class ret = null; + + if (classes.length == 1) { + ret = classes[0]; // scenario 1: this is a list-like type + } else if (classes.length == 2) { + ret = classes[1]; // scenario 2: this is a map-like type + } + + if (ret != null) { + return ret; + } } } @@ -1166,14 +1177,27 @@ else if (fromBoxedType) { * Now actually print the cast. For casts to primitives (except boolean), we use special null-safe functions * (e.g. intCast()) to perform the cast. * - * There is no "booleanCast()" function. + * There is no "booleanCast()" function. However, we do try to cast to String and Boolean from PyObjects. * * There are also no special functions for the identity conversion -- e.g. "intCast(int)" */ - if (toPrimitive && !ret.equals(boolean.class) && !ret.equals(exprType)) { + final boolean isPyUpgrade = + ((ret.equals(boolean.class) || ret.equals(Boolean.class) || ret.equals(String.class)) + && exprType.equals(PyObject.class)); + + if ((toPrimitive && !ret.equals(boolean.class) && !ret.equals(exprType)) || isPyUpgrade) { // Casting to a primitive, except booleans and the identity conversion + if (!toPrimitive) { + // these methods look like `doStringPyCast` and `doBooleanPyCast` + printer.append("do"); + } printer.append(ret.getSimpleName()); - printer.append("Cast("); + + if (exprType != NULL_CLASS && isAssignableFrom(PyObject.class, exprType)) { + printer.append("PyCast("); + } else { + printer.append("Cast("); + } /* * When unboxing to a wider type, do an unboxing conversion followed by a widening conversion. See table @@ -1488,6 +1512,8 @@ public Class visit(FieldAccessExpr n, VisitArgs printer) { // The to-be-cast expr is a Python object field accessor final String clsName = printer.pythonCastContext.getSimpleName(); printer.append(", " + clsName + ".class"); + // Let's advertise to the caller the cast context type + ret = printer.pythonCastContext; } printer.append(')'); } else { diff --git a/DB/src/main/java/io/deephaven/db/tables/select/Param.java b/DB/src/main/java/io/deephaven/db/tables/select/Param.java index d3ed1544825..21d5d216fe2 100644 --- a/DB/src/main/java/io/deephaven/db/tables/select/Param.java +++ b/DB/src/main/java/io/deephaven/db/tables/select/Param.java @@ -9,18 +9,16 @@ import groovy.lang.Closure; import java.lang.reflect.Modifier; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; public class Param { - public static final Param[] ZERO_LENGTH_PARAM_ARRAY = new Param[0]; + public static final Param[] ZERO_LENGTH_PARAM_ARRAY = new Param[0]; private final String name; private final T value; @@ -38,29 +36,69 @@ public Param(String name, T value) { this.value = value; } - public Class getDeclaredType() { - final Class type = value == null ? Object.class - : value instanceof Enum ? ((Enum) value).getDeclaringClass() - // in newer versions of groovy, our closures will be subtypes that evade the logic in - // getDeclaredType - // (they will return a null Class#getCanonicalName b/c they are dynamic classes). - : value instanceof Closure ? Closure.class - : value.getClass(); + public Class getDeclaredClass() { + Type declaredType = getDeclaredType(); + Class cls = classFromType(declaredType); + + if (cls == null) { + throw new IllegalStateException("Unexpected declared type of type '" + + declaredType.getClass().getCanonicalName() + "'"); + } + + return cls; + } + + public static Class classFromType(final Type declaredType) { + if (declaredType instanceof Class) { + return (Class) declaredType; + } + if (declaredType instanceof ParameterizedType) { + return (Class) ((ParameterizedType) declaredType).getRawType(); + } + return null; + } + + public Type getDeclaredType() { + // in newer versions of groovy, our closures will be subtypes that evade the logic in getDeclaredType + // (they will return a null Class#getCanonicalName b/c they are dynamic classes). + final Class type; + if (value == null) { + type = Object.class; + } else if (value instanceof Enum) { + type = ((Enum) value).getDeclaringClass(); + } else if (value instanceof Closure) { + type = Closure.class; + } else { + type = value.getClass(); + } return getDeclaredType(type); } - protected static Class getDeclaredType(Class type) { - OUTER: while (type != Object.class) { + protected static Type getDeclaredType(final Class origType) { + Class type = origType; + while (type != Object.class) { if (Modifier.isPublic(type.getModifiers()) && !type.isAnonymousClass()) { break; } - Class[] interfaces = type.getInterfaces(); - for (Class iface : interfaces) { - if (iface.getMethods().length > 0) { - type = iface; - break OUTER; + + Type[] interfaces = type.getGenericInterfaces(); + for (Type ityp : interfaces) { + Class iface = null; + if (ityp instanceof Class) { + iface = (Class) ityp; + } else if (ityp instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) ityp; + Type rawType = pt.getRawType(); + if (rawType instanceof Class) { + iface = (Class) rawType; + } + } + + if (iface != null && Modifier.isPublic(iface.getModifiers()) && iface.getMethods().length > 0) { + return ityp; } } + type = type.getSuperclass(); } @@ -68,24 +106,20 @@ protected static Class getDeclaredType(Class type) { } public String getDeclaredTypeName() { - return getDeclaredType().getCanonicalName(); + return getDeclaredClass().getCanonicalName(); } public String getPrimitiveTypeNameIfAvailable() { if (value == null) { return getDeclaredTypeName(); } - Class type = getDeclaredType(); + Class type = getDeclaredClass(); if (io.deephaven.util.type.TypeUtils.isBoxedType(type)) { return TypeUtils.getUnboxedType(type).getCanonicalName(); } return getDeclaredTypeName(); } - protected static String getDeclaredTypeName(Class type) { - return getDeclaredType(type).getCanonicalName(); - } - /** * Get a map from binary name to declared type for the dynamic classes referenced by an array of param classes. * @@ -100,7 +134,7 @@ public static Map> expandParameterClasses(final List> private static void visitParameterClass(final Map> found, Class cls) { while (cls.isArray()) { - cls = getDeclaredType(cls.getComponentType()); + cls = classFromType(cls.getComponentType()); } final String name = cls.getName(); diff --git a/DB/src/main/java/io/deephaven/db/v2/select/AbstractConditionFilter.java b/DB/src/main/java/io/deephaven/db/v2/select/AbstractConditionFilter.java index b2b381a18f5..43601ffbee7 100644 --- a/DB/src/main/java/io/deephaven/db/v2/select/AbstractConditionFilter.java +++ b/DB/src/main/java/io/deephaven/db/v2/select/AbstractConditionFilter.java @@ -17,6 +17,8 @@ import org.jetbrains.annotations.NotNull; import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.net.MalformedURLException; import java.util.*; @@ -84,7 +86,15 @@ public synchronized void init(TableDefinition tableDefinition) { final QueryScope queryScope = QueryScope.getScope(); for (Param param : queryScope.getParams(queryScope.getParamNames())) { possibleParams.put(param.getName(), param); - possibleVariables.put(param.getName(), param.getDeclaredType()); + possibleVariables.put(param.getName(), param.getDeclaredClass()); + Type declaredType = param.getDeclaredType(); + if (declaredType instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) declaredType; + Class[] paramTypes = Arrays.stream(pt.getActualTypeArguments()) + .map(Param::classFromType) + .toArray(Class[]::new); + possibleVariableParameterizedTypes.put(param.getName(), paramTypes); + } } Class compType; diff --git a/DB/src/main/java/io/deephaven/db/v2/select/ConditionFilter.java b/DB/src/main/java/io/deephaven/db/v2/select/ConditionFilter.java index c21da03e5a9..10e44e5e25d 100644 --- a/DB/src/main/java/io/deephaven/db/v2/select/ConditionFilter.java +++ b/DB/src/main/java/io/deephaven/db/v2/select/ConditionFilter.java @@ -383,7 +383,7 @@ protected void generateFilterCode(TableDefinition tableDefinition, DBTimeUtils.R addParamClass.accept(column.getComponentType()); } for (final Param param : params) { - addParamClass.accept(param.getDeclaredType()); + addParamClass.accept(param.getDeclaredClass()); } filterKernelClass = CompilerTools.compile("GeneratedFilterKernel", this.classBody = classBody.toString(), diff --git a/DB/src/main/java/io/deephaven/db/v2/select/DhFormulaColumn.java b/DB/src/main/java/io/deephaven/db/v2/select/DhFormulaColumn.java index d281f4274a8..380bc3893f2 100644 --- a/DB/src/main/java/io/deephaven/db/v2/select/DhFormulaColumn.java +++ b/DB/src/main/java/io/deephaven/db/v2/select/DhFormulaColumn.java @@ -669,7 +669,7 @@ private List visitFormulaParameters( if (paramLambda != null) { for (int ii = 0; ii < params.length; ++ii) { final Param p = params[ii]; - final ParamParameter pp = new ParamParameter(ii, p.getName(), p.getDeclaredType(), + final ParamParameter pp = new ParamParameter(ii, p.getName(), p.getDeclaredClass(), p.getDeclaredTypeName()); addIfNotNull(results, paramLambda.apply(pp)); } @@ -692,7 +692,7 @@ private JavaKernelBuilder.Result invokeKernelBuilder() { final Map> arrayDict = makeNameToTypeDict(sd.arrays, columnSources); final Map> allParamDict = new HashMap<>(); for (final Param param : params) { - allParamDict.put(param.getName(), param.getDeclaredType()); + allParamDict.put(param.getName(), param.getDeclaredClass()); } final Map> paramDict = new HashMap<>(); for (final String p : sd.params) { diff --git a/DB/src/main/java/io/deephaven/db/v2/select/codegen/FormulaAnalyzer.java b/DB/src/main/java/io/deephaven/db/v2/select/codegen/FormulaAnalyzer.java index f8445482937..9910a8e31a1 100644 --- a/DB/src/main/java/io/deephaven/db/v2/select/codegen/FormulaAnalyzer.java +++ b/DB/src/main/java/io/deephaven/db/v2/select/codegen/FormulaAnalyzer.java @@ -16,6 +16,8 @@ import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; import java.util.*; public class FormulaAnalyzer { @@ -101,7 +103,16 @@ public static DBLanguageParser.Result getCompiledFormula(Map param : queryScope.getParams(queryScope.getParamNames())) { - possibleVariables.put(param.getName(), param.getDeclaredType()); + possibleVariables.put(param.getName(), param.getDeclaredClass()); + + Type declaredType = param.getDeclaredType(); + if (declaredType instanceof ParameterizedType) { + ParameterizedType pt = (ParameterizedType) declaredType; + Class[] paramTypes = Arrays.stream(pt.getActualTypeArguments()) + .map(Param::classFromType) + .toArray(Class[]::new); + possibleVariableParameterizedTypes.put(param.getName(), paramTypes); + } } for (ColumnDefinition columnDefinition : availableColumns.values()) { diff --git a/DB/src/test/java/io/deephaven/db/tables/lang/TestDBLanguageParser.java b/DB/src/test/java/io/deephaven/db/tables/lang/TestDBLanguageParser.java index c6476fdc7f3..8d2da3c057e 100644 --- a/DB/src/test/java/io/deephaven/db/tables/lang/TestDBLanguageParser.java +++ b/DB/src/test/java/io/deephaven/db/tables/lang/TestDBLanguageParser.java @@ -19,6 +19,7 @@ import org.apache.commons.text.StringEscapeUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.jetbrains.annotations.NotNull; +import org.jpy.PyObject; import org.junit.Before; import java.awt.*; @@ -109,6 +110,7 @@ public void setUp() throws Exception { variables.put("myDBDateTime", DBDateTime.class); variables.put("myTable", Table.class); + variables.put("myPyObject", PyObject.class); variables.put("ExampleQuantity", int.class); variables.put("ExampleQuantity2", double.class); @@ -117,6 +119,7 @@ public void setUp() throws Exception { variables.put("ExampleStr", String.class); variableParameterizedTypes = new HashMap<>(); + variableParameterizedTypes.put("myArrayList", new Class[] {Long.class}); variableParameterizedTypes.put("myHashMap", new Class[] {Integer.class, Double.class}); variableParameterizedTypes.put("myDBArray", new Class[] {Double.class}); } @@ -834,6 +837,48 @@ public void testPrimitiveAndBoxedToObjectCasts() throws Exception { } } + public void testPyObjectToPrimitiveCasts() throws Exception { + String expression = "(int)myPyObject"; + String resultExpression = "intPyCast(myPyObject)"; + check(expression, resultExpression, int.class, new String[] {"myPyObject"}); + + expression = "(double)myPyObject"; + resultExpression = "doublePyCast(myPyObject)"; + check(expression, resultExpression, double.class, new String[] {"myPyObject"}); + + expression = "(long)myPyObject"; + resultExpression = "longPyCast(myPyObject)"; + check(expression, resultExpression, long.class, new String[] {"myPyObject"}); + + expression = "(float)myPyObject"; + resultExpression = "floatPyCast(myPyObject)"; + check(expression, resultExpression, float.class, new String[] {"myPyObject"}); + + expression = "(char)myPyObject"; + resultExpression = "charPyCast(myPyObject)"; + check(expression, resultExpression, char.class, new String[] {"myPyObject"}); + + expression = "(byte)myPyObject"; + resultExpression = "bytePyCast(myPyObject)"; + check(expression, resultExpression, byte.class, new String[] {"myPyObject"}); + + expression = "(short)myPyObject"; + resultExpression = "shortPyCast(myPyObject)"; + check(expression, resultExpression, short.class, new String[] {"myPyObject"}); + + expression = "(String)myPyObject"; + resultExpression = "doStringPyCast(myPyObject)"; + check(expression, resultExpression, String.class, new String[] {"myPyObject"}); + + expression = "(boolean)myPyObject"; + resultExpression = "booleanPyCast(myPyObject)"; + check(expression, resultExpression, boolean.class, new String[] {"myPyObject"}); + + expression = "(Boolean)myPyObject"; + resultExpression = "doBooleanPyCast(myPyObject)"; + check(expression, resultExpression, Boolean.class, new String[] {"myPyObject"}); + } + public void testVariables() throws Exception { String expression = "1+myInt"; String resultExpression = "plus(1, myInt)"; @@ -945,11 +990,11 @@ public void testArrayOperatorOverloading() throws Exception { expression = "myArrayList[15]"; resultExpression = "myArrayList.get(15)"; - check(expression, resultExpression, Object.class, new String[] {"myArrayList"}); + check(expression, resultExpression, Long.class, new String[] {"myArrayList"}); expression = "myHashMap[\"test\"]"; resultExpression = "myHashMap.get(\"test\")"; - check(expression, resultExpression, Object.class, new String[] {"myHashMap"}); + check(expression, resultExpression, Double.class, new String[] {"myHashMap"}); expression = "myIntArray==myDoubleArray"; resultExpression = "eqArray(myIntArray, myDoubleArray)";