Skip to content

Commit

Permalink
Merge pull request #4 from AppliedIntuition/ankit/add-dict-union
Browse files Browse the repository at this point in the history
Support dict select union
  • Loading branch information
ankit-agarwal-ai authored Dec 1, 2022
2 parents c854a38 + ee3dbe6 commit ccbe820
Show file tree
Hide file tree
Showing 6 changed files with 238 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
import net.starlark.java.syntax.TokenKind;

/**
* An attribute value consisting of a concatenation of native types and selects, e.g:
* An attribute value consisting of a concatenation (via the {@code +} operator for lists or the
* {@code |} operator for dicts) of native types and selects, e.g:
*
* <pre>
* rule(
Expand Down Expand Up @@ -113,7 +114,11 @@ static SelectorList concat(Object x, Object y) throws EvalException {

@Override
public SelectorList binaryOp(TokenKind op, Object that, boolean thisLeft) throws EvalException {
if (op == TokenKind.PLUS) {
if (getNativeType(that).equals(Dict.class)) {
if (op == TokenKind.PIPE) {
return thisLeft ? concat(this, that) : concat(that, this);
}
} else if (op == TokenKind.PLUS) {
return thisLeft ? concat(this, that) : concat(that, this);
}
return null;
Expand Down Expand Up @@ -141,7 +146,7 @@ static SelectorList of(Iterable<?> values) throws EvalException {
}
if (!canConcatenate(getNativeType(firstValue), getNativeType(value))) {
throw Starlark.errorf(
"'+' operator applied to incompatible types (%s, %s)",
"Cannot combine incompatible types (%s, %s)",
getTypeName(firstValue), getTypeName(value));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ public String toString() {

@Override
public SelectorList binaryOp(TokenKind op, Object that, boolean thisLeft) throws EvalException {
if (op == TokenKind.PLUS) {
return thisLeft ? SelectorList.concat(this, that) : SelectorList.concat(that, this);
}
return null;
return SelectorList.of(this).binaryOp(op, that, thisLeft);
}

@Override
Expand Down
18 changes: 15 additions & 3 deletions src/main/java/com/google/devtools/build/lib/packages/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.util.Set;
import java.util.logging.Level;
import javax.annotation.Nullable;
import net.starlark.java.eval.Dict;
import net.starlark.java.eval.EvalException;
import net.starlark.java.eval.Printer;
import net.starlark.java.eval.Sequence;
Expand Down Expand Up @@ -208,9 +209,11 @@ public LabelClass getLabelClass() {

/**
* Implementation of concatenation for this type, as if by {@code elements[0] + ... +
* elements[n-1]}). Returns null to indicate concatenation isn't supported. This method exists to
* support deferred additions {@code select + T} for catenable types T such as string, int, and
* list.
* elements[n-1]}) for scalars or lists, or {@code elements[0] | ... | elements[n-1]} for dicts.
* Returns null to indicate concatenation isn't supported.
*
* <p>This method exists to support deferred additions {@code select + T} for catenable types T
* such as string, int, list, and deferred unions {@code select | T} for map types T.
*/
public T concat(Iterable<T> elements) {
return null;
Expand Down Expand Up @@ -568,6 +571,15 @@ public Map<KeyT, ValueT> convert(Object x, Object what, Object context)
return ImmutableMap.copyOf(result);
}

@Override
public Map<KeyT, ValueT> concat(Iterable<Map<KeyT, ValueT>> iterable) {
Dict.Builder<KeyT, ValueT> output = new Dict.Builder<>();
for (Map<KeyT, ValueT> map : iterable) {
output.putAll(map);
}
return output.buildImmutable();
}

@Override
public Map<KeyT, ValueT> getDefaultValue() {
return empty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,26 @@ public void noMatchCustomErrorMessage() throws Exception {
@Test
public void nativeTypeConcatenatedWithSelect() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = {':always': 'always'} | select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" })",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand All @@ -831,12 +850,37 @@ public void nativeTypeConcatenatedWithSelect() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));
}

@Test
public void selectConcatenatedWithNativeType() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" }) | {':always': 'always'}",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand All @@ -859,12 +903,40 @@ public void selectConcatenatedWithNativeType() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libalways.jar", "bin java/foo/libb.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar"));
}

@Test
public void selectConcatenatedWithSelect() throws Exception {
writeConfigRules();
scratch.file("java/foo/BUILD",
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" return []",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" '//conditions:b': {':b': 'b'},",
" }) | select({",
" '//conditions:a': {':a2': 'a2'},",
" '//conditions:b': {':b2': 'b2'},",
" })",
")",
"java_binary(",
" name = 'binary',",
" srcs = ['binary.java'],",
Expand Down Expand Up @@ -894,6 +966,58 @@ public void selectConcatenatedWithSelect() throws Exception {
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libb.jar", "bin java/foo/libb2.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar", "bin java/foo/liba2.jar"));

checkRule(
"//java/foo:mytarget",
"--foo=b",
/*expected:*/ ImmutableList.of("bin java/foo/libb.jar", "bin java/foo/libb2.jar"),
/*not expected:*/ ImmutableList.of("bin java/foo/liba.jar", "bin java/foo/liba2.jar"));
}

@Test
public void dictsWithSameKey() throws Exception {
writeConfigRules();
scratch.file(
"java/foo/rule.bzl",
"def _rule_impl(ctx):",
" outputs = []",
" for target, value in ctx.attr.deps.items():",
" output = ctx.actions.declare_file(target.label.name + value)",
" ctx.actions.write(content = value, output = output)",
" outputs.append(output)",
" return [DefaultInfo(files=depset(outputs))]",
"myrule = rule(",
" implementation = _rule_impl,",
" attrs = {",
" 'deps': attr.label_keyed_string_dict()",
" },",
")");
scratch.file(
"java/foo/BUILD",
"load(':rule.bzl', 'myrule')",
"myrule(",
" name = 'mytarget',",
" deps = select({",
" '//conditions:a': {':a': 'a'},",
" }) | select({",
" '//conditions:a': {':a': 'a2'},",
" })",
")",
"java_library(",
" name = 'a',",
" srcs = ['a.java']",
")",
"filegroup(",
" name = 'group',",
" srcs = [':mytarget'],",
")");

checkRule(
"//java/foo:group",
"srcs",
ImmutableList.of("--foo=a"),
/*expected:*/ ImmutableList.of("bin java/foo/aa2"),
/*not expected:*/ ImmutableList.of("bin java/foo/aa"));
}

@Test
Expand Down Expand Up @@ -926,7 +1050,7 @@ public void concatenationWithDifferentTypes() throws Exception {

reporter.removeHandler(failFastHandler);
assertThrows(NoSuchTargetException.class, () -> getTarget("//java/foo:binary"));
assertContainsEvent("'+' operator applied to incompatible types");
assertContainsEvent("Cannot combine incompatible types");
}

@Test
Expand Down Expand Up @@ -965,7 +1089,7 @@ public void selectsWithGlobsWrongType() throws Exception {

reporter.removeHandler(failFastHandler);
assertThrows(NoSuchTargetException.class, () -> getTarget("//foo:binary"));
assertContainsEvent("'+' operator applied to incompatible types");
assertContainsEvent("Cannot combine incompatible types");
}

@Test
Expand Down Expand Up @@ -1155,8 +1279,7 @@ public void noneValuesWithMultipleSelectsMixedValues() throws Exception {
reporter.removeHandler(failFastHandler);
useConfiguration("--define", "mode=a");
assertThat(getConfiguredTarget("//a:gen")).isNull();
assertContainsEvent(
"'+' operator applied to incompatible types (select of string, select of NoneType)");
assertContainsEvent("Cannot combine incompatible types (select of string, select of NoneType)");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,50 @@ public void testSelectorList() throws Exception {
.entrySet());
}

@Test
public void testSelectorDict() throws Exception {
Object selector1 =
new SelectorValue(
ImmutableMap.of(
"//conditions:a",
ImmutableMap.of("//a:a", "a"),
"//conditions:b",
ImmutableMap.of("//b:b", "b")),
"");
Object selector2 =
new SelectorValue(
ImmutableMap.of(
"//conditions:c",
ImmutableMap.of("//c:c", "c"),
"//conditions:d",
ImmutableMap.of("//d:d", "d")),
"");
BuildType.SelectorList<Map<Label, String>> selectorList =
new BuildType.SelectorList<>(
ImmutableList.of(selector1, selector2),
null,
labelConverter,
BuildType.LABEL_KEYED_STRING_DICT);

assertThat(selectorList.getOriginalType()).isEqualTo(BuildType.LABEL_KEYED_STRING_DICT);
assertThat(selectorList.getKeyLabels())
.containsExactly(
Label.parseAbsolute("//conditions:a", ImmutableMap.of()),
Label.parseAbsolute("//conditions:b", ImmutableMap.of()),
Label.parseAbsolute("//conditions:c", ImmutableMap.of()),
Label.parseAbsolute("//conditions:d", ImmutableMap.of()));

List<Selector<Map<Label, String>>> selectors = selectorList.getSelectors();
assertThat(selectors.get(0).getEntries().entrySet())
.containsExactlyElementsIn(
ImmutableMap.of(
Label.parseAbsolute("//conditions:a", ImmutableMap.of()),
ImmutableMap.of(Label.create("@//a", "a"), "a"),
Label.parseAbsolute("//conditions:b", ImmutableMap.of()),
ImmutableMap.of(Label.create("@//b", "b"), "b"))
.entrySet());
}

@Test
public void testSelectorListMixedTypes() throws Exception {
Object selector1 =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,49 @@ public void testPlus() throws Exception {

@Test
public void testPlusIncompatibleType() throws Exception {

assertFails(
"select({'foo': ['FOO'], 'bar': ['BAR']}) + 1",
"'+' operator applied to incompatible types (select of list, int)");
"Cannot combine incompatible types (select of list, int)");
assertFails(
"select({'foo': ['FOO']}) + select({'bar': 2})",
"'+' operator applied to incompatible types (select of list, select of int)");
"Cannot combine incompatible types (select of list, select of int)");

assertFails(
"select({'foo': ['FOO']}) + select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (select of list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) + select({'foo': ['FOO']})",
"Cannot combine incompatible types (select of dict, select of list)");
assertFails(
"['FOO'] + select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) + ['FOO']",
"Cannot combine incompatible types (select of dict, list)");
assertFails(
"select({'foo': ['FOO']}) + {'a': 'a'}", "unsupported binary operation: select + dict");
assertFails(
"{'a': 'a'} + select({'foo': ['FOO']})", "unsupported binary operation: dict + select");
}

@Test
public void testUnionIncompatibleType() throws Exception {
assertFails(
"select({'foo': ['FOO']}) | select({'bar': {'a': 'a'}})",
"Cannot combine incompatible types (select of list, select of dict)");
assertFails(
"select({'bar': {'a': 'a'}}) | select({'foo': ['FOO']})",
"Cannot combine incompatible types (select of dict, select of list)");
assertFails(
"['FOO'] | select({'bar': {'a': 'a'}})", "unsupported binary operation: list | select");
assertFails(
"select({'bar': {'a': 'a'}}) | ['FOO']", "unsupported binary operation: select | list");
assertFails(
"select({'foo': ['FOO']}) | {'a': 'a'}",
"Cannot combine incompatible types (select of list, dict)");
assertFails(
"{'a': 'a'} | select({'foo': ['FOO']})",
"Cannot combine incompatible types (dict, select of list)");
}
}

0 comments on commit ccbe820

Please sign in to comment.