Skip to content

Commit

Permalink
Support dict select union
Browse files Browse the repository at this point in the history
Starlark has recentely added support for union operations over dictionaries (bazelbuild/starlark#215). The syntax is already supported in bazel as of bazelbuild#14540, but the same operation with selects of dictionaries is still usupported.

Related issue: bazelbuild#12457

Closes bazelbuild#15075.

PiperOrigin-RevId: 469827107
Change-Id: If82cfaf577db41efc2e9b55af47b0b1710badc10
  • Loading branch information
AlessandroPatti authored and aiuto committed Oct 12, 2022
1 parent 2938965 commit 3d99eec
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
import net.starlark.java.syntax.TokenKind;

/**
* An attribute value consisting of a concatenation of native types and selects, e.g:
* An attribute value consisting of a concatenation (via the {@code +} operator for lists or the
* {@code |} operator for dicts) of native types and selects, e.g:
*
* <pre>
* rule(
Expand Down Expand Up @@ -121,7 +122,7 @@ static SelectorList of(Iterable<?> values) throws EvalException {
}
if (!canConcatenate(getNativeType(firstValue), getNativeType(value))) {
throw Starlark.errorf(
"'+' operator applied to incompatible types (%s, %s)",
"Cannot combine incompatible types (%s, %s)",
getTypeName(firstValue), getTypeName(value));
}
}
Expand All @@ -142,7 +143,11 @@ static SelectorList concat(Object x, Object y) throws EvalException {
@Override
@Nullable
public SelectorList binaryOp(TokenKind op, Object that, boolean thisLeft) throws EvalException {
if (op == TokenKind.PLUS) {
if (getNativeType(that).equals(Dict.class)) {
if (op == TokenKind.PIPE) {
return thisLeft ? concat(this, that) : concat(that, this);
}
} else if (op == TokenKind.PLUS) {
return thisLeft ? concat(this, that) : concat(that, this);
}
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ public String toString() {
@Override
@Nullable
public SelectorList binaryOp(TokenKind op, Object that, boolean thisLeft) throws EvalException {
if (op == TokenKind.PLUS) {
return thisLeft ? SelectorList.concat(this, that) : SelectorList.concat(that, this);
}
return null;
return SelectorList.of(this).binaryOp(op, that, thisLeft);
}

@Override
Expand Down
18 changes: 15 additions & 3 deletions src/main/java/com/google/devtools/build/lib/packages/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Set;
import java.util.logging.Level;
import javax.annotation.Nullable;
import net.starlark.java.eval.Dict;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Printer;
import net.starlark.java.eval.Sequence;
Expand Down Expand Up @@ -213,9 +214,11 @@ public LabelClass getLabelClass() {

/**
* Implementation of concatenation for this type, as if by {@code elements[0] + ... +
* elements[n-1]}). Returns null to indicate concatenation isn't supported. This method exists to
* support deferred additions {@code select + T} for catenable types T such as string, int, and
* list.
* elements[n-1]}) for scalars or lists, or {@code elements[0] | ... | elements[n-1]} for dicts.
* Returns null to indicate concatenation isn't supported.
*
* <p>This method exists to support deferred additions {@code select + T} for catenable types T
* such as string, int, list, and deferred unions {@code select | T} for map types T.
*/
public T concat(Iterable<T> elements) {
return null;
Expand Down Expand Up @@ -567,6 +570,15 @@ public Map<KeyT, ValueT> convert(Object x, Object what, LabelConverter labelConv
return ImmutableMap.copyOf(result);
}

@Override
public Map<KeyT, ValueT> concat(Iterable<Map<KeyT, ValueT>> iterable) {
Dict.Builder<KeyT, ValueT> output = new Dict.Builder<>();
for (Map<KeyT, ValueT> map : iterable) {
output.putAll(map);
}
return output.buildImmutable();
}

@Override
public Map<KeyT, ValueT> getDefaultValue() {
return empty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,26 @@ public void noMatchCustomErrorMessage() throws Exception {
@Test
public void nativeTypeConcatenatedWithSelect() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = {':always': 'always'} | select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" })",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand All @@ -828,12 +847,37 @@ public void nativeTypeConcatenatedWithSelect() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));
}

@Test
public void selectConcatenatedWithNativeType() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" }) | {':always': 'always'}",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand All @@ -856,12 +900,40 @@ public void selectConcatenatedWithNativeType() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));
}

@Test
public void selectConcatenatedWithSelect() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" }) | select({",
" '//conditions:a': {':a2': 'a2'},",
" '//conditions:b': {':b2': 'b2'},",
" })",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand Down Expand Up @@ -891,6 +963,58 @@ public void selectConcatenatedWithSelect() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libb.jar", "bin java/foo/libb2.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar", "bin java/foo/liba2.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libb.jar", "bin java/foo/libb2.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar", "bin java/foo/liba2.jar"));
}

@Test
public void dictsWithSameKey() throws Exception {
writeConfigRules();
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" outputs = []",
" for target, value in ctx.attr.deps.items():",
" output = ctx.actions.declare_file(target.label.name + value)",
" ctx.actions.write(content = value, output = output)",
" outputs.append(output)",
" return [DefaultInfo(files=depset(outputs))]",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" }) | select({",
" '//conditions:a': {':a': 'a2'},",
" })",
")",
"java_library(",
" name = 'a',",
" srcs = ['a.java']",
")",
"filegroup(",
" name = 'group',",
" srcs = [':mytarget'],",
")");

checkRule(
"//java/foo:group",
"srcs",
ImmutableList.of("--foo=a"),
/*expected:*/ ImmutableList.of("bin java/foo/aa2"),
/*not expected:*/ ImmutableList.of("bin java/foo/aa"));
}

@Test
Expand Down Expand Up @@ -923,7 +1047,7 @@ public void concatenationWithDifferentTypes() throws Exception {

reporter.removeHandler(failFastHandler);
assertThrows(NoSuchTargetException.class, () -> getTarget("//java/foo:binary"));
assertContainsEvent("'+' operator applied to incompatible types");
assertContainsEvent("Cannot combine incompatible types");
}

@Test
Expand Down Expand Up @@ -962,7 +1086,7 @@ public void selectsWithGlobsWrongType() throws Exception {

reporter.removeHandler(failFastHandler);
assertThrows(NoSuchTargetException.class, () -> getTarget("//foo:binary"));
assertContainsEvent("'+' operator applied to incompatible types");
assertContainsEvent("Cannot combine incompatible types");
}

@Test
Expand Down Expand Up @@ -1152,8 +1276,7 @@ public void noneValuesWithMultipleSelectsMixedValues() throws Exception {
reporter.removeHandler(failFastHandler);
useConfiguration("--define", "mode=a");
assertThat(getConfiguredTarget("//a:gen")).isNull();
assertContainsEvent(
"'+' operator applied to incompatible types (select of string, select of NoneType)");
assertContainsEvent("Cannot combine incompatible types (select of string, select of NoneType)");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,50 @@ public void testSelectorList() throws Exception {
.entrySet());
}

@Test
public void testSelectorDict() throws Exception {
Object selector1 =
new SelectorValue(
ImmutableMap.of(
"//conditions:a",
ImmutableMap.of("//a:a", "a"),
"//conditions:b",
ImmutableMap.of("//b:b", "b")),
"");
Object selector2 =
new SelectorValue(
ImmutableMap.of(
"//conditions:c",
ImmutableMap.of("//c:c", "c"),
"//conditions:d",
ImmutableMap.of("//d:d", "d")),
"");
BuildType.SelectorList<Map<Label, String>> selectorList =
new BuildType.SelectorList<>(
ImmutableList.of(selector1, selector2),
null,
labelConverter,
BuildType.LABEL_KEYED_STRING_DICT);

assertThat(selectorList.getOriginalType()).isEqualTo(BuildType.LABEL_KEYED_STRING_DICT);
assertThat(selectorList.getKeyLabels())
.containsExactly(
Label.parseAbsolute("//conditions:a", ImmutableMap.of()),
Label.parseAbsolute("//conditions:b", ImmutableMap.of()),
Label.parseAbsolute("//conditions:c", ImmutableMap.of()),
Label.parseAbsolute("//conditions:d", ImmutableMap.of()));

List<Selector<Map<Label, String>>> selectors = selectorList.getSelectors();
assertThat(selectors.get(0).getEntries().entrySet())
.containsExactlyElementsIn(
ImmutableMap.of(
Label.parseAbsolute("//conditions:a", ImmutableMap.of()),
ImmutableMap.of(Label.create("@//a", "a"), "a"),
Label.parseAbsolute("//conditions:b", ImmutableMap.of()),
ImmutableMap.of(Label.create("@//b", "b"), "b"))
.entrySet());
}

@Test
public void testSelectorListMixedTypes() throws Exception {
Object selector1 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,49 @@ public void testPlus() throws Exception {

@Test
public void testPlusIncompatibleType() throws Exception {

assertFails(
"select({'foo': ['FOO'], 'bar': ['BAR']}) + 1",
"'+' operator applied to incompatible types (select of list, int)");
"Cannot combine incompatible types (select of list, int)");
assertFails(
"select({'foo': ['FOO']}) + select({'bar': 2})",
"'+' operator applied to incompatible types (select of list, select of int)");
"Cannot combine incompatible types (select of list, select of int)");

assertFails(
"select({'foo': ['FOO']}) + select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (select of list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) + select({'foo': ['FOO']})",
"Cannot combine incompatible types (select of dict, select of list)");
assertFails(
"['FOO'] + select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) + ['FOO']",
"Cannot combine incompatible types (select of dict, list)");
assertFails(
"select({'foo': ['FOO']}) + {'a': 'a'}", "unsupported binary operation: select + dict");
assertFails(
"{'a': 'a'} + select({'foo': ['FOO']})", "unsupported binary operation: dict + select");
}

@Test
public void testUnionIncompatibleType() throws Exception {
assertFails(
"select({'foo': ['FOO']}) | select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (select of list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) | select({'foo': ['FOO']})",
"Cannot combine incompatible types (select of dict, select of list)");
assertFails(
"['FOO'] | select({'bar': {'a': 'a'}})", "unsupported binary operation: list | select");
assertFails(
"select({'bar': {'a': 'a'}}) | ['FOO']", "unsupported binary operation: select | list");
assertFails(
"select({'foo': ['FOO']}) | {'a': 'a'}",
"Cannot combine incompatible types (select of list, dict)");
assertFails(
"{'a': 'a'} | select({'foo': ['FOO']})",
"Cannot combine incompatible types (dict, select of list)");
}
}

0 comments on commit 3d99eec

Please sign in to comment.