diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagator.java b/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagator.java index f8348c8da3..81aa9d37ed 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagator.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagator.java @@ -22,6 +22,7 @@ * THE SOFTWARE. */ +import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.LinkedHashMultimap; @@ -125,11 +126,26 @@ class StreamNullabilityPropagator extends BaseNoOpHandler { private final Map observableCallToInnerMethodOrLambda = new LinkedHashMap<>(); - // Maps collect calls in the observable call chain to the relevant inner methods or lambdas. + @AutoValue + abstract static class CollectRecordAndInnerMethod { + + static CollectRecordAndInnerMethod create( + CollectLikeMethodRecord collectlikeMethodRecord, Tree innerMethodOrLambda) { + return new AutoValue_StreamNullabilityPropagator_CollectRecordAndInnerMethod( + collectlikeMethodRecord, innerMethodOrLambda); + } + + abstract CollectLikeMethodRecord getCollectLikeMethodRecord(); + + abstract Tree getInnerMethodOrLambda(); + } + + // Maps collect calls in the observable call chain to the relevant (collect record, inner method + // or lambda) pairs. // We need a Multimap here since there may be multiple relevant methods / lambdas. - // E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => {l1,l2} - private final Multimap collectCallToInnerMethodsOrLambdas = - LinkedHashMultimap.create(); + // E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => (record for toMap, {l1,l2}) + private final Multimap + collectCallToRecordsAndInnerMethodsOrLambdas = LinkedHashMultimap.create(); // Map from map or collect method (or lambda) to corresponding previous filter method (e.g. // B.apply => A.filter for the map example above, or l1 => A.filter and l2 => A.filter for the @@ -183,7 +199,7 @@ public void onMatchTopLevelClass( this.filterMethodOrLambdaSet.clear(); this.observableOuterCallInChain.clear(); this.observableCallToInnerMethodOrLambda.clear(); - this.collectCallToInnerMethodsOrLambdas.clear(); + this.collectCallToRecordsAndInnerMethodsOrLambdas.clear(); this.mapOrCollectRecordToFilterMap.clear(); this.filterToNSMap.clear(); this.bodyToMethodOrLambda.clear(); @@ -237,10 +253,14 @@ public void onMatchMethodInvocation( observableCallToInnerMethodOrLambda.put(tree, argTree); } } else { - CollectLikeMethodRecord collectlikeMethodRecord = - streamType.getCollectlikeMethodRecord(methodSymbol); - if (collectlikeMethodRecord != null && methodSymbol.getParameters().length() == 1) { - handleCollectCall(tree, collectlikeMethodRecord); + if (methodSymbol.getParameters().length() == 1) { + for (CollectLikeMethodRecord collectlikeMethodRecord : + streamType.getCollectlikeMethodRecords(methodSymbol)) { + boolean handled = handleCollectCall(tree, collectlikeMethodRecord); + if (handled) { + break; + } + } } } } @@ -249,13 +269,15 @@ public void onMatchMethodInvocation( /** * Handles a call to a collect-like method. If the argument to the method is supported, updates - * the {@link #collectCallToInnerMethodsOrLambdas} map appropriately. + * the {@link #collectCallToRecordsAndInnerMethodsOrLambdas} map appropriately. * * @param collectInvocationTree The MethodInvocationTree representing the call to the collect-like * method. * @param collectlikeMethodRecord The record representing the collect-like method. + * @return true if the argument to the collect method was a call to the factory method in the + * record, false otherwise. */ - private void handleCollectCall( + private boolean handleCollectCall( MethodInvocationTree collectInvocationTree, CollectLikeMethodRecord collectlikeMethodRecord) { ExpressionTree argTree = collectInvocationTree.getArguments().get(0); if (argTree instanceof MethodInvocationTree) { @@ -282,14 +304,21 @@ private void handleCollectCall( handleMapOrCollectAnonClassBody( collectlikeMethodRecord, anonClassBody, - t -> collectCallToInnerMethodsOrLambdas.put(collectInvocationTree, t)); + t -> + collectCallToRecordsAndInnerMethodsOrLambdas.put( + collectInvocationTree, + CollectRecordAndInnerMethod.create(collectlikeMethodRecord, t))); } } else if (factoryMethodArg instanceof LambdaExpressionTree) { - collectCallToInnerMethodsOrLambdas.put(collectInvocationTree, factoryMethodArg); + collectCallToRecordsAndInnerMethodsOrLambdas.put( + collectInvocationTree, + CollectRecordAndInnerMethod.create(collectlikeMethodRecord, factoryMethodArg)); } } + return true; } } + return false; } private void buildObservableCallChain(MethodInvocationTree tree) { @@ -326,19 +355,15 @@ private void handleChainFromFilter( mapOrCollectRecordToFilterMap.put( observableCallToInnerMethodOrLambda.get(outerCallInChain), record); } - } else if (collectCallToInnerMethodsOrLambdas.containsKey(outerCallInChain)) { - Symbol.MethodSymbol collectMethod = ASTHelpers.getSymbol(outerCallInChain); - CollectLikeMethodRecord collectlikeMethodRecord = - streamType.getCollectlikeMethodRecord(collectMethod); - if (collectlikeMethodRecord != null) { - // Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas - for (Tree innerMethodOrLambda : - collectCallToInnerMethodsOrLambdas.get(outerCallInChain)) { - MapOrCollectMethodToFilterInstanceRecord record = - new MapOrCollectMethodToFilterInstanceRecord( - collectlikeMethodRecord, filterMethodOrLambda); - mapOrCollectRecordToFilterMap.put(innerMethodOrLambda, record); - } + } else if (collectCallToRecordsAndInnerMethodsOrLambdas.containsKey(outerCallInChain)) { + // Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas + for (CollectRecordAndInnerMethod collectRecordAndInnerMethod : + collectCallToRecordsAndInnerMethodsOrLambdas.get(outerCallInChain)) { + MapOrCollectMethodToFilterInstanceRecord record = + new MapOrCollectMethodToFilterInstanceRecord( + collectRecordAndInnerMethod.getCollectLikeMethodRecord(), filterMethodOrLambda); + mapOrCollectRecordToFilterMap.put( + collectRecordAndInnerMethod.getInnerMethodOrLambda(), record); } } } while (outerCallInChain != null diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagatorFactory.java b/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagatorFactory.java index 985458359e..0ae2400fdf 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagatorFactory.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/StreamNullabilityPropagatorFactory.java @@ -74,6 +74,13 @@ public static StreamNullabilityPropagator getJavaStreamNullabilityPropagator() { ImmutableSet.of(0, 1), "apply", ImmutableSet.of(0)) + .withCollectMethodFromSignature( + "collect(java.util.stream.Collector)", + "java.util.stream.Collectors", + "groupingBy(java.util.function.Function)", + ImmutableSet.of(0), + "apply", + ImmutableSet.of(0)) // List of methods of java.util.stream.Stream through which we just propagate the // nullability information of the last call, e.g. m() in // Observable.filter(...).m().map(...) means the diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamModelBuilder.java b/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamModelBuilder.java index ea9e45a8a5..be13dcfec5 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamModelBuilder.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamModelBuilder.java @@ -23,6 +23,7 @@ */ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.errorprone.predicates.TypePredicate; import com.google.errorprone.predicates.type.DescendantOf; @@ -48,7 +49,7 @@ public class StreamModelBuilder { private ImmutableSet.Builder filterMethodSimpleNames; private ImmutableMap.Builder mapMethodSigToRecord; private ImmutableMap.Builder mapMethodSimpleNameToRecord; - private ImmutableMap.Builder collectMethodSigToRecord; + private ImmutableMultimap.Builder collectMethodSigToRecords; private ImmutableSet.Builder passthroughMethodSigs; private ImmutableSet.Builder passthroughMethodSimpleNames; @@ -75,7 +76,7 @@ private void finalizeOpenStreamTypeRecord() { filterMethodSimpleNames.build(), mapMethodSigToRecord.build(), mapMethodSimpleNameToRecord.build(), - collectMethodSigToRecord.build(), + collectMethodSigToRecords.build(), passthroughMethodSigs.build(), passthroughMethodSimpleNames.build())); } @@ -109,7 +110,7 @@ private void initializeBuilders() { this.filterMethodSimpleNames = ImmutableSet.builder(); this.mapMethodSigToRecord = ImmutableMap.builder(); this.mapMethodSimpleNameToRecord = ImmutableMap.builder(); - this.collectMethodSigToRecord = ImmutableMap.builder(); + this.collectMethodSigToRecords = ImmutableMultimap.builder(); this.passthroughMethodSigs = ImmutableSet.builder(); this.passthroughMethodSimpleNames = ImmutableSet.builder(); } @@ -200,7 +201,7 @@ public StreamModelBuilder withCollectMethodFromSignature( ImmutableSet argsToCollectorFactoryMethod, String innerMethodName, ImmutableSet argsFromStream) { - this.collectMethodSigToRecord.put( + this.collectMethodSigToRecords.put( collectMethodSig, CollectLikeMethodRecord.create( collectorFactoryMethodClass, diff --git a/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamTypeRecord.java b/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamTypeRecord.java index c03ebce421..437d9e0d4b 100644 --- a/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamTypeRecord.java +++ b/nullaway/src/main/java/com/uber/nullaway/handlers/stream/StreamTypeRecord.java @@ -23,13 +23,14 @@ */ import static com.uber.nullaway.NullabilityUtil.castToNonNull; +import com.google.common.collect.ImmutableCollection; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.errorprone.VisitorState; import com.google.errorprone.predicates.TypePredicate; import com.sun.tools.javac.code.Symbol; import com.sun.tools.javac.code.Type; -import javax.annotation.Nullable; /** An immutable model describing a class from a stream-based API such as RxJava. */ public class StreamTypeRecord { @@ -47,7 +48,7 @@ public class StreamTypeRecord { private final ImmutableMap mapMethodSigToRecord; private final ImmutableMap mapMethodSimpleNameToRecord; - private final ImmutableMap collectMethodSigToRecord; + private final ImmutableMultimap collectMethodSigToRecords; // List of methods of java.util.stream.Stream through which we just propagate the nullability // information of the last call, e.g. m() in Observable.filter(...).m().map(...) means the @@ -64,7 +65,7 @@ public StreamTypeRecord( ImmutableSet filterMethodSimpleNames, ImmutableMap mapMethodSigToRecord, ImmutableMap mapMethodSimpleNameToRecord, - ImmutableMap collectMethodSigToRecord, + ImmutableMultimap collectMethodSigToRecords, ImmutableSet passthroughMethodSigs, ImmutableSet passthroughMethodSimpleNames) { this.typePredicate = typePredicate; @@ -72,7 +73,7 @@ public StreamTypeRecord( this.filterMethodSimpleNames = filterMethodSimpleNames; this.mapMethodSigToRecord = mapMethodSigToRecord; this.mapMethodSimpleNameToRecord = mapMethodSimpleNameToRecord; - this.collectMethodSigToRecord = collectMethodSigToRecord; + this.collectMethodSigToRecords = collectMethodSigToRecords; this.passthroughMethodSigs = passthroughMethodSigs; this.passthroughMethodSimpleNames = passthroughMethodSimpleNames; } @@ -101,9 +102,9 @@ record = return record; } - @Nullable - public CollectLikeMethodRecord getCollectlikeMethodRecord(Symbol.MethodSymbol methodSymbol) { - return collectMethodSigToRecord.get(methodSymbol.toString()); + public ImmutableCollection getCollectlikeMethodRecords( + Symbol.MethodSymbol methodSymbol) { + return collectMethodSigToRecords.get(methodSymbol.toString()); } public boolean isPassthroughMethod(Symbol.MethodSymbol methodSymbol) { diff --git a/nullaway/src/test/java/com/uber/nullaway/NullAwayFrameworkTests.java b/nullaway/src/test/java/com/uber/nullaway/NullAwayFrameworkTests.java index dd7ba0bea8..0b0de7a248 100644 --- a/nullaway/src/test/java/com/uber/nullaway/NullAwayFrameworkTests.java +++ b/nullaway/src/test/java/com/uber/nullaway/NullAwayFrameworkTests.java @@ -92,6 +92,37 @@ public void streamSupportCollectorsToMap() { .doTest(); } + @Test + public void streamSupportCollectorsGroupingBy() { + defaultCompilationHelper + .addSourceLines( + "Test.java", + "package com.uber;", + "import java.util.*;", + "import java.util.stream.*;", + "import java.util.function.Function;", + "import javax.annotation.Nullable;", + "class Test {", + " static class Foo {", + " @Nullable String bar;", + " String baz = \"baz\";", + " }", + " Map> testNegative() {", + " List foos = new ArrayList<>();", + " return foos.stream()", + " .filter(foo -> foo.bar != null)", + " .collect(Collectors.groupingBy(foo -> foo.bar.length()));", + " }", + " Map> testPositive1() {", + " List foos = new ArrayList<>();", + " return foos.stream()", + " // BUG: Diagnostic contains: dereferenced expression foo.bar is @Nullable", + " .collect(Collectors.groupingBy(foo -> foo.bar.length()));", + " }", + "}") + .doTest(); + } + @Test public void supportObjectsIsNull() { defaultCompilationHelper