Skip to content

Commit

Permalink
Implement flag_group in the new rule-based toolchain.
Browse files Browse the repository at this point in the history
BEGIN_PUBLIC
Implement flag_group in the new rule-based toolchain.
END_PUBLIC

PiperOrigin-RevId: 622107179
Change-Id: I9e1971e279f313ce85537c899bcf80860616f8b7
  • Loading branch information
Googler authored and copybara-github committed Apr 5, 2024
1 parent 5467790 commit 6e490f7
Show file tree
Hide file tree
Showing 6 changed files with 458 additions and 35 deletions.
61 changes: 29 additions & 32 deletions cc/toolchains/args.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,50 @@
# limitations under the License.
"""All providers for rule-based bazel toolchain config."""

load("//cc:cc_toolchain_config_lib.bzl", "flag_group")
load("//cc/toolchains/impl:args_utils.bzl", "validate_nested_args")
load(
"//cc/toolchains/impl:collect.bzl",
"collect_action_types",
"collect_files",
"collect_provider",
)
load(
"//cc/toolchains/impl:nested_args.bzl",
"NESTED_ARGS_ATTRS",
"args_wrapper_macro",
"nested_args_provider_from_ctx",
)
load(
":cc_toolchain_info.bzl",
"ActionTypeSetInfo",
"ArgsInfo",
"ArgsListInfo",
"BuiltinVariablesInfo",
"FeatureConstraintInfo",
"NestedArgsInfo",
)

visibility("public")

def _cc_args_impl(ctx):
if not ctx.attr.args and not ctx.attr.env:
fail("cc_args requires at least one of args and env")

actions = collect_action_types(ctx.attr.actions)
files = collect_files(ctx.attr.data)
requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)

if not ctx.attr.args and not ctx.attr.nested and not ctx.attr.env:
fail("cc_args requires at least one of args, nested, and env")

nested = None
if ctx.attr.args:
# TODO: This is temporary until cc_nested_args is implemented.
nested = NestedArgsInfo(
if ctx.attr.args or ctx.attr.nested:
nested = nested_args_provider_from_ctx(ctx)
validate_nested_args(
variables = ctx.attr._variables[BuiltinVariablesInfo].variables,
nested_args = nested,
actions = actions.to_list(),
label = ctx.label,
nested = tuple(),
iterate_over = None,
files = files,
requires_types = {},
legacy_flag_group = flag_group(flags = ctx.attr.args),
)
files = nested.files
else:
files = collect_files(ctx.attr.data)

requires = collect_provider(ctx.attr.requires_any_of, FeatureConstraintInfo)

args = ArgsInfo(
label = ctx.label,
Expand All @@ -72,7 +79,7 @@ def _cc_args_impl(ctx):
),
]

cc_args = rule(
_cc_args = rule(
implementation = _cc_args_impl,
attrs = {
"actions": attr.label_list(
Expand All @@ -82,21 +89,6 @@ cc_args = rule(
See @rules_cc//cc/toolchains/actions:all for valid options.
""",
),
"args": attr.string_list(
doc = """Arguments that should be added to the command-line.
These are evaluated in order, with earlier args appearing earlier in the
invocation of the underlying tool.
""",
),
"data": attr.label_list(
allow_files = True,
doc = """Files required to add this argument to the command-line.
For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
),
"env": attr.string_dict(
doc = "Environment variables to be added to the command-line.",
Expand All @@ -108,7 +100,10 @@ directory as additional files.
If omitted, this flag set will be enabled unconditionally.
""",
),
},
"_variables": attr.label(
default = "//cc/toolchains/variables:variables",
),
} | NESTED_ARGS_ATTRS,
provides = [ArgsInfo],
doc = """Declares a list of arguments bound to a set of actions.
Expand All @@ -121,3 +116,5 @@ Examples:
)
""",
)

cc_args = lambda **kwargs: args_wrapper_macro(rule = _cc_args, **kwargs)
93 changes: 92 additions & 1 deletion cc/toolchains/impl/args_utils.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""."""
"""Helper functions for working with args."""

load(":variables.bzl", "get_type")

visibility([
"//cc/toolchains",
"//tests/rule_based_toolchain/...",
])

def get_action_type(args_list, action_type):
"""Returns the corresponding entry in ArgsListInfo.by_action.
Expand All @@ -28,3 +35,87 @@ def get_action_type(args_list, action_type):
return args

return struct(action = action_type, args = tuple(), files = depset([]))

def validate_nested_args(*, nested_args, variables, actions, label, fail = fail):
"""Validates the typing for an nested_args invocation.
Args:
nested_args: (NestedArgsInfo) The nested_args to validate
variables: (Dict[str, VariableInfo]) A mapping from variable name to
the metadata (variable type and valid actions).
actions: (List[ActionTypeInfo]) The actions we require these variables
to be valid for.
label: (Label) The label of the rule we're currently validating.
Used for error messages.
fail: The fail function. Use for testing only.
"""
stack = [(nested_args, {})]

for _ in range(9999999):
if not stack:
break
nested_args, overrides = stack.pop()
if nested_args.iterate_over != None or nested_args.unwrap_options:
# Make sure we don't keep using the same object.
overrides = dict(**overrides)

if nested_args.iterate_over != None:
type = get_type(
name = nested_args.iterate_over,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] == "list":
# Rewrite the type of the thing we iterate over from a List[T]
# to a T.
overrides[nested_args.iterate_over] = type["elements"]
elif type["name"] == "option" and type["elements"]["name"] == "list":
# Rewrite Option[List[T]] to T.
overrides[nested_args.iterate_over] = type["elements"]["elements"]
else:
fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (nested_args.iterate_over, type["repr"]))

# 1) Validate variables marked with after_option_unwrap = False.
# 2) Unwrap Option[T] to T as required.
# 3) Validate variables marked with after_option_unwrap = True.
for after_option_unwrap in [False, True]:
for var_name, requirements in nested_args.requires_types.items():
for requirement in requirements:
if requirement.after_option_unwrap == after_option_unwrap:
type = get_type(
name = var_name,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] not in requirement.valid_types:
fail("{msg}, but {var_name} has type {type}".format(
var_name = var_name,
msg = requirement.msg,
type = type["repr"],
))

# Only unwrap the options after the first iteration of this loop.
if not after_option_unwrap:
for var in nested_args.unwrap_options:
type = get_type(
name = var,
variables = variables,
overrides = overrides,
actions = actions,
args_label = label,
nested_label = nested_args.label,
fail = fail,
)
if type["name"] == "option":
overrides[var] = type["elements"]

for child in nested_args.nested:
stack.append((child, overrides))
124 changes: 123 additions & 1 deletion cc/toolchains/impl/nested_args.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
"""Helper functions for working with args."""

load("@bazel_skylib//lib:structs.bzl", "structs")
load("//cc:cc_toolchain_config_lib.bzl", "flag_group", "variable_with_value")
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo")
load("//cc/toolchains:cc_toolchain_info.bzl", "NestedArgsInfo", "VariableInfo")
load(":collect.bzl", "collect_files", "collect_provider")

visibility([
"//cc/toolchains",
Expand Down Expand Up @@ -48,6 +50,126 @@ cc_args(
iterate_over = "//toolchains/variables:foo_list",
"""

# @unsorted-dict-items.
NESTED_ARGS_ATTRS = {
"args": attr.string_list(
doc = """json-encoded arguments to be added to the command-line.
Usage:
cc_args(
...,
args = ["--foo", format_arg("%s", "//cc/toolchains/variables:foo")]
)
This is equivalent to flag_group(flags = ["--foo", "%{foo}"])
Mutually exclusive with nested.
""",
),
"nested": attr.label_list(
providers = [NestedArgsInfo],
doc = """nested_args that should be added on the command-line.
Mutually exclusive with args.""",
),
"data": attr.label_list(
allow_files = True,
doc = """Files required to add this argument to the command-line.
For example, a flag that sets the header directory might add the headers in that
directory as additional files.
""",
),
"variables": attr.label_list(
providers = [VariableInfo],
doc = "Variables to be used in substitutions",
),
"iterate_over": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.iterate_over"),
"requires_not_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_available"),
"requires_none": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_not_available"),
"requires_true": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_true"),
"requires_false": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_false"),
"requires_equal": attr.label(providers = [VariableInfo], doc = "Replacement for flag_group.expand_if_equal"),
"requires_equal_value": attr.string(),
}

def args_wrapper_macro(*, name, rule, args = [], **kwargs):
"""Invokes a rule by converting args to attributes.
Args:
name: (str) The name of the target.
rule: (rule) The rule to invoke. Either cc_args or cc_nested_args.
args: (List[str|Formatted]) A list of either strings, or function calls
from format.bzl. For example:
["--foo", format_arg("--sysroot=%s", "//cc/toolchains/variables:sysroot")]
**kwargs: kwargs to pass through into the rule invocation.
"""
out_args = []
vars = []
if type(args) != "list":
fail("Args must be a list in %s" % native.package_relative_label(name))
for arg in args:
if type(arg) == "string":
out_args.append(raw_string(arg))
elif getattr(arg, "format_type") == "format_arg":
arg = structs.to_dict(arg)
if arg["value"] == None:
out_args.append(arg)
else:
var = arg.pop("value")

# Swap the variable from a label to an index. This allows us to
# actually get the providers in a rule.
out_args.append(struct(value = len(vars), **arg))
vars.append(var)
else:
fail("Invalid type of args in %s. Expected either a string or format_args(format_string, variable_label), got value %r" % (native.package_relative_label(name), arg))

rule(
name = name,
args = [json.encode(arg) for arg in out_args],
variables = vars,
**kwargs
)

def _var(target):
if target == None:
return None
return target[VariableInfo].name

# TODO: Consider replacing this with a subrule in the future. However, maybe not
# for a long time, since it'll break compatibility with all bazel versions < 7.
def nested_args_provider_from_ctx(ctx):
"""Gets the nested args provider from a rule that has NESTED_ARGS_ATTRS.
Args:
ctx: The rule context
Returns:
NestedArgsInfo
"""
variables = collect_provider(ctx.attr.variables, VariableInfo)
args = []
for arg in ctx.attr.args:
arg = json.decode(arg)
if "value" in arg:
if arg["value"] != None:
arg["value"] = variables[arg["value"]]
args.append(struct(**arg))

return nested_args_provider(
label = ctx.label,
args = args,
nested = collect_provider(ctx.attr.nested, NestedArgsInfo),
files = collect_files(ctx.attr.data),
iterate_over = _var(ctx.attr.iterate_over),
requires_not_none = _var(ctx.attr.requires_not_none),
requires_none = _var(ctx.attr.requires_none),
requires_true = _var(ctx.attr.requires_true),
requires_false = _var(ctx.attr.requires_false),
requires_equal = _var(ctx.attr.requires_equal),
requires_equal_value = ctx.attr.requires_equal_value,
)

def raw_string(s):
"""Constructs metadata for creating a raw string.
Expand Down
Loading

0 comments on commit 6e490f7

Please sign in to comment.