Skip to content

Commit

Permalink
Collectors.toMap handling for streams (#938)
Browse files Browse the repository at this point in the history
Fixes #934

The key new thing with the support here is we have further nesting.
Rather than a `map` method, where the relevant lambda is passed
directly:
```java
stream.filter(foo -> foo.bar != null).map(foo -> foo.bar.baz)
```
In this case we have a `collect` call, which gets as its argument the
result of `Collectors.toMap`, and the relevant lambdas are passed to
`toMap`:
```java
stream
  .filter(foo -> foo.bar != null)
  .collect(Collectors.toMap(foo -> foo.bar.baz, foo -> foo.bar.other))
```
Supporting this requires some new types of logic in our streams handler
(particularly because there are multiple relevant lambdas for a single
`collect` call). We do also handle anonymous inner classes.
  • Loading branch information
msridhar authored Mar 27, 2024
1 parent ce892d7 commit 76f0f77
Show file tree
Hide file tree
Showing 9 changed files with 468 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@
* THE SOFTWARE.
*/

import com.google.auto.value.AutoValue;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import com.google.errorprone.VisitorState;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ClassTree;
Expand All @@ -48,14 +51,17 @@
import com.uber.nullaway.dataflow.AccessPathElement;
import com.uber.nullaway.dataflow.AccessPathNullnessAnalysis;
import com.uber.nullaway.dataflow.NullnessStore;
import com.uber.nullaway.handlers.stream.MaplikeMethodRecord;
import com.uber.nullaway.handlers.stream.MaplikeToFilterInstanceRecord;
import com.uber.nullaway.handlers.stream.CollectLikeMethodRecord;
import com.uber.nullaway.handlers.stream.MapLikeMethodRecord;
import com.uber.nullaway.handlers.stream.MapOrCollectLikeMethodRecord;
import com.uber.nullaway.handlers.stream.MapOrCollectMethodToFilterInstanceRecord;
import com.uber.nullaway.handlers.stream.StreamTypeRecord;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import javax.lang.model.element.Element;
import javax.lang.model.element.ElementKind;
import javax.lang.model.element.ExecutableElement;
Expand Down Expand Up @@ -92,6 +98,13 @@ class StreamNullabilityPropagator extends BaseNoOpHandler {
* 'Observable.filter'). In general, for observable.a().b().c(), c is the outer call of b and b the outer call
* of a in the chain.
*
* We also support collect-like methods, which take a collector factory method as an argument, e.g.:
*
* stream.filter(...).collect(Collectors.toMap(l1, l2)) (where l1 and l2 are lambdas)
*
* For such scenarios, the lambdas l1 and l2 (or the named method in the equivalent anonymous class) serve
* an equivalent role to the map methods discussed above.
*
* This class works by building the following maps which keep enough state outside of the standard dataflow
* analysis for us to figure out what's going on:
*
Expand All @@ -113,9 +126,32 @@ class StreamNullabilityPropagator extends BaseNoOpHandler {
private final Map<MethodInvocationTree, Tree> observableCallToInnerMethodOrLambda =
new LinkedHashMap<>();

// Map from map method (or lambda) to corresponding previous filter method (e.g. B.apply =>
// A.filter)
private final Map<Tree, MaplikeToFilterInstanceRecord> mapToFilterMap = new LinkedHashMap<>();
@AutoValue
abstract static class CollectRecordAndInnerMethod {

static CollectRecordAndInnerMethod create(
CollectLikeMethodRecord collectlikeMethodRecord, Tree innerMethodOrLambda) {
return new AutoValue_StreamNullabilityPropagator_CollectRecordAndInnerMethod(
collectlikeMethodRecord, innerMethodOrLambda);
}

abstract CollectLikeMethodRecord getCollectLikeMethodRecord();

abstract Tree getInnerMethodOrLambda();
}

// Maps collect calls in the observable call chain to the relevant (collect record, inner method
// or lambda) pairs.
// We need a Multimap here since there may be multiple relevant methods / lambdas.
// E.g.: stream.filter(...).collect(Collectors.toMap(l1, l2)) => (record for toMap, {l1,l2})
private final Multimap<MethodInvocationTree, CollectRecordAndInnerMethod>
collectCallToRecordsAndInnerMethodsOrLambdas = LinkedHashMultimap.create();

// Map from map or collect method (or lambda) to corresponding previous filter method (e.g.
// B.apply => A.filter for the map example above, or l1 => A.filter and l2 => A.filter for the
// collect example above)
private final Map<Tree, MapOrCollectMethodToFilterInstanceRecord> mapOrCollectRecordToFilterMap =
new LinkedHashMap<>();

/*
* Note that the above methods imply a diagram like the following:
Expand Down Expand Up @@ -163,7 +199,8 @@ public void onMatchTopLevelClass(
this.filterMethodOrLambdaSet.clear();
this.observableOuterCallInChain.clear();
this.observableCallToInnerMethodOrLambda.clear();
this.mapToFilterMap.clear();
this.collectCallToRecordsAndInnerMethodsOrLambdas.clear();
this.mapOrCollectRecordToFilterMap.clear();
this.filterToNSMap.clear();
this.bodyToMethodOrLambda.clear();
this.returnToEnclosingMethodOrLambda.clear();
Expand Down Expand Up @@ -204,19 +241,91 @@ public void onMatchMethodInvocation(
ClassTree annonClassBody = ((NewClassTree) argTree).getClassBody();
// Ensure that this `new B() ...` has a custom class body, otherwise, we skip for now.
if (annonClassBody != null) {
MaplikeMethodRecord methodRecord = streamType.getMaplikeMethodRecord(methodSymbol);
handleMapAnonClass(methodRecord, tree, annonClassBody);
MapLikeMethodRecord methodRecord = streamType.getMaplikeMethodRecord(methodSymbol);
handleMapOrCollectAnonClassBody(
methodRecord,
annonClassBody,
t -> observableCallToInnerMethodOrLambda.put(tree, t));
}
} else if (argTree instanceof LambdaExpressionTree) {
observableCallToInnerMethodOrLambda.put(tree, argTree);
} else if (argTree instanceof MemberReferenceTree) {
observableCallToInnerMethodOrLambda.put(tree, argTree);
}
} else {
if (methodSymbol.getParameters().length() == 1) {
// We can have multiple CollectLikeMethodRecords for a single collect method, reflecting
// the different possible collector factory methods whose result may be passed to a
// collect call. At a single collect call site, at most one of these records will be
// relevant. So, we loop through them all, but break out of the loop as soon as we find
// one that matches.
for (CollectLikeMethodRecord collectlikeMethodRecord :
streamType.getCollectlikeMethodRecords(methodSymbol)) {
boolean handled = handleCollectCall(tree, collectlikeMethodRecord);
if (handled) {
break;
}
}
}
}
}
}
}

/**
* Handles a call to a collect-like method. If the argument to the method is supported, updates
* the {@link #collectCallToRecordsAndInnerMethodsOrLambdas} map appropriately.
*
* @param collectInvocationTree The MethodInvocationTree representing the call to the collect-like
* method.
* @param collectlikeMethodRecord The record representing the collect-like method.
* @return true if the argument to the collect method was a call to the factory method in the
* record, false otherwise.
*/
private boolean handleCollectCall(
MethodInvocationTree collectInvocationTree, CollectLikeMethodRecord collectlikeMethodRecord) {
ExpressionTree argTree = collectInvocationTree.getArguments().get(0);
if (argTree instanceof MethodInvocationTree) {
// the argument passed to the collect method. We check if this is a call to the collector
// factory method from the record
MethodInvocationTree collectInvokeArg = (MethodInvocationTree) argTree;
Symbol.MethodSymbol collectInvokeArgSymbol = ASTHelpers.getSymbol(collectInvokeArg);
if (collectInvokeArgSymbol
.owner
.getQualifiedName()
.contentEquals(collectlikeMethodRecord.collectorFactoryMethodClass())
&& collectInvokeArgSymbol
.toString()
.equals(collectlikeMethodRecord.collectorFactoryMethodSignature())) {
List<? extends ExpressionTree> arguments = collectInvokeArg.getArguments();
for (int ind : collectlikeMethodRecord.argsToCollectorFactoryMethod()) {
ExpressionTree factoryMethodArg = arguments.get(ind);
// TODO eventually, support method references, though this is likely only useful in
// JSpecify mode with generics checking
if (factoryMethodArg instanceof NewClassTree) {
ClassTree anonClassBody = ((NewClassTree) factoryMethodArg).getClassBody();
// Ensure that this `new B() ...` has a custom class body, otherwise, we skip for now.
if (anonClassBody != null) {
handleMapOrCollectAnonClassBody(
collectlikeMethodRecord,
anonClassBody,
t ->
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, t)));
}
} else if (factoryMethodArg instanceof LambdaExpressionTree) {
collectCallToRecordsAndInnerMethodsOrLambdas.put(
collectInvocationTree,
CollectRecordAndInnerMethod.create(collectlikeMethodRecord, factoryMethodArg));
}
}
return true;
}
}
return false;
}

private void buildObservableCallChain(MethodInvocationTree tree) {
ExpressionTree methodSelect = tree.getMethodSelect();
if (methodSelect instanceof MemberSelectTree) {
Expand All @@ -242,13 +351,24 @@ private void handleChainFromFilter(
// Check for a map method (which might be a pass-through method or the first method after a
// pass-through chain)
if (observableCallToInnerMethodOrLambda.containsKey(outerCallInChain)) {
// Update mapToFilterMap
// Update mapOrCollectRecordToFilterMap
Symbol.MethodSymbol mapMethod = ASTHelpers.getSymbol(outerCallInChain);
if (streamType.isMapMethod(mapMethod)) {
MaplikeToFilterInstanceRecord record =
new MaplikeToFilterInstanceRecord(
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
streamType.getMaplikeMethodRecord(mapMethod), filterMethodOrLambda);
mapToFilterMap.put(observableCallToInnerMethodOrLambda.get(outerCallInChain), record);
mapOrCollectRecordToFilterMap.put(
observableCallToInnerMethodOrLambda.get(outerCallInChain), record);
}
} else if (collectCallToRecordsAndInnerMethodsOrLambdas.containsKey(outerCallInChain)) {
// Update mapOrCollectRecordToFilterMap for all relevant methods / lambdas
for (CollectRecordAndInnerMethod collectRecordAndInnerMethod :
collectCallToRecordsAndInnerMethodsOrLambdas.get(outerCallInChain)) {
MapOrCollectMethodToFilterInstanceRecord record =
new MapOrCollectMethodToFilterInstanceRecord(
collectRecordAndInnerMethod.getCollectLikeMethodRecord(), filterMethodOrLambda);
mapOrCollectRecordToFilterMap.put(
collectRecordAndInnerMethod.getInnerMethodOrLambda(), record);
}
}
} while (outerCallInChain != null
Expand Down Expand Up @@ -280,22 +400,24 @@ private void handleFilterLambda(
handleChainFromFilter(streamType, observableDotFilter, lambdaTree, state);
}

private void handleMapAnonClass(
MaplikeMethodRecord methodRecord,
MethodInvocationTree observableDotMap,
ClassTree annonClassBody) {
for (Tree t : annonClassBody.getMembers()) {
/**
* If the relevant inner method from the method record is found in the class body, the consumer is
* called with the corresponding MethodTree.
*/
private void handleMapOrCollectAnonClassBody(
MapOrCollectLikeMethodRecord methodRecord, ClassTree anonClassBody, Consumer<Tree> consumer) {
for (Tree t : anonClassBody.getMembers()) {
if (t instanceof MethodTree
&& ((MethodTree) t).getName().toString().equals(methodRecord.getInnerMethodName())) {
observableCallToInnerMethodOrLambda.put(observableDotMap, t);
&& ((MethodTree) t).getName().toString().equals(methodRecord.innerMethodName())) {
consumer.accept(t);
}
}
}

@Override
public void onMatchMethod(
NullAway analysis, MethodTree tree, VisitorState state, Symbol.MethodSymbol methodSymbol) {
if (mapToFilterMap.containsKey(tree)) {
if (mapOrCollectRecordToFilterMap.containsKey(tree)) {
bodyToMethodOrLambda.put(tree.getBody(), tree);
}
}
Expand All @@ -314,7 +436,7 @@ public void onMatchLambdaExpression(
AccessPathNullnessAnalysis nullnessAnalysis = analysis.getNullnessAnalysis(state);
nullnessAnalysis.forceRunOnMethod(state.getPath(), state.context);
}
if (mapToFilterMap.containsKey(tree)) {
if (mapOrCollectRecordToFilterMap.containsKey(tree)) {
bodyToMethodOrLambda.put(tree.getBody(), tree);
}
}
Expand All @@ -325,7 +447,8 @@ public void onMatchMethodReference(
MemberReferenceTree tree,
VisitorState state,
Symbol.MethodSymbol methodSymbol) {
MaplikeToFilterInstanceRecord callInstanceRecord = mapToFilterMap.get(tree);
MapOrCollectMethodToFilterInstanceRecord callInstanceRecord =
mapOrCollectRecordToFilterMap.get(tree);
if (callInstanceRecord != null && ((JCTree.JCMemberReference) tree).kind.isUnbound()) {
// Unbound method reference, check if we know the corresponding path to be NonNull from the
// previous filter.
Expand Down Expand Up @@ -421,13 +544,15 @@ public NullnessStore.Builder onDataflowInitialStore(
return nullnessBuilder;
}
assert (tree instanceof MethodTree || tree instanceof LambdaExpressionTree);
MaplikeToFilterInstanceRecord callInstanceRecord = mapToFilterMap.get(tree);
MapOrCollectMethodToFilterInstanceRecord callInstanceRecord =
mapOrCollectRecordToFilterMap.get(tree);
if (callInstanceRecord != null) {
// Plug Nullness info from filter method into entry to map method.
Tree filterTree = callInstanceRecord.getFilter();
assert (filterTree instanceof MethodTree || filterTree instanceof LambdaExpressionTree);
MaplikeMethodRecord mapMR = callInstanceRecord.getMaplikeMethodRecord();
for (int argIdx : mapMR.getArgsFromStream()) {
MapOrCollectLikeMethodRecord methodRecord =
callInstanceRecord.getMapOrCollectLikeMethodRecord();
for (int argIdx : methodRecord.argsFromStream()) {
LocalVariableNode filterLocalName;
LocalVariableNode mapLocalName;
if (filterTree instanceof MethodTree) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,29 @@ public static StreamNullabilityPropagator getJavaStreamNullabilityPropagator() {
"accept",
ImmutableSet.of(0))
.withMapMethodAllFromName("flatMap", "apply", ImmutableSet.of(0))
// Names and relevant arguments of all the methods of java.util.stream.Stream that
// behave like .collect(...) for the purposes of this checker
.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"java.util.stream.Collectors",
"<T,K,U>toMap(java.util.function.Function<? super T,? extends K>,java.util.function.Function<? super T,? extends U>)",
ImmutableSet.of(0, 1),
"apply",
ImmutableSet.of(0))
.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"java.util.stream.Collectors",
"<T,K>groupingBy(java.util.function.Function<? super T,? extends K>)",
ImmutableSet.of(0),
"apply",
ImmutableSet.of(0))
.withCollectMethodFromSignature(
"<R,A>collect(java.util.stream.Collector<? super T,A,R>)",
"com.google.common.collect.ImmutableMap",
"<T,K,V>toImmutableMap(java.util.function.Function<? super T,? extends K>,java.util.function.Function<? super T,? extends V>)",
ImmutableSet.of(0, 1),
"apply",
ImmutableSet.of(0))
// List of methods of java.util.stream.Stream through which we just propagate the
// nullability information of the last call, e.g. m() in
// Observable.filter(...).m().map(...) means the
Expand Down
Loading

0 comments on commit 76f0f77

Please sign in to comment.