Skip to content

Commit

Permalink
Add and demonstrate a utility function for testing @rules.
Browse files Browse the repository at this point in the history
  • Loading branch information
stuhood committed Apr 5, 2018
1 parent fa3e30b commit f62743a
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 14 deletions.
26 changes: 12 additions & 14 deletions tests/python/pants_test/engine/test_build_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,30 +11,28 @@
from pants.build_graph.address import Address
from pants.engine.addressable import (Exactly, SubclassesOf, addressable, addressable_dict,
addressable_list)
from pants.engine.build_files import ResolvedTypeMismatchError, create_graph_rules
from pants.engine.fs import create_fs_rules
from pants.engine.build_files import (ResolvedTypeMismatchError, create_graph_rules,
parse_address_family)
from pants.engine.fs import Dir, FileContent, FilesContent, PathGlobs, create_fs_rules
from pants.engine.mapper import AddressMapper, ResolveError
from pants.engine.nodes import Return, Throw
from pants.engine.parser import SymbolTable
from pants.engine.struct import HasProducts, Struct, StructWithDeps
from pants_test.engine.examples.parsers import (JsonParser, PythonAssignmentsParser,
PythonCallbacksParser)
from pants_test.engine.scheduler_test_base import SchedulerTestBase
from pants_test.engine.util import Target, run_rule


class Target(Struct, HasProducts):
class ParseAddressFamilyTest(unittest.TestCase):
def test_empty(self):
"""Test that parsing an empty BUILD file results in an empty AddressFamily."""
address_mapper = AddressMapper(JsonParser(TestTable()))
af = run_rule(parse_address_family, address_mapper, Dir('/dev/null'), {
(FilesContent, PathGlobs): lambda _: FilesContent([FileContent('/dev/null/BUILD', '')])
})
self.assertEquals(len(af.objects_by_name), 0)

def __init__(self, name=None, configurations=None, **kwargs):
super(Target, self).__init__(name=name, **kwargs)
self.configurations = configurations

@property
def products(self):
return self.configurations

@addressable_list(SubclassesOf(Struct))
def configurations(self):
pass


class ApacheThriftConfiguration(StructWithDeps):
Expand Down
62 changes: 62 additions & 0 deletions tests/python/pants_test/engine/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,80 @@
unicode_literals, with_statement)

import re
from types import GeneratorType

from pants.binaries.binary_util import BinaryUtilPrivate
from pants.engine.addressable import SubclassesOf, addressable_list
from pants.engine.native import Native
from pants.engine.parser import SymbolTable
from pants.engine.rules import RuleIndex
from pants.engine.scheduler import WrappedNativeScheduler
from pants.engine.selectors import Get
from pants.engine.struct import HasProducts, Struct
from pants_test.option.util.fakes import create_options_for_optionables
from pants_test.subsystem.subsystem_util import init_subsystem


def run_rule(rule, *args):
"""A test helper function that runs an @rule with a set of arguments and Get providers.
An @rule named `my_rule` that takes one argument and makes no `Get` requests can be invoked
like so (although you could also just invoke it directly):
```
return_value = run_rule(my_rule, arg1)
```
In the case of an @rule that makes Get requests, things get more interesting: an extra argument
is required that represents a dict mapping (product, subject) type pairs to one argument functions
that take a subject value and return a product value.
So in the case of an @rule named `my_co_rule` that takes one argument and makes Get requests
for product and subject types (Listing, Dir), the invoke might look like:
```
return_value = run_rule(my_co_rule, arg1, {(Listing, Dir): lambda x: Listing(..)})
```
:returns: The return value of the completed @rule.
"""

task_rule = getattr(rule, '_rule', None)
if task_rule is None:
raise TypeError('Expected to receive a decorated `@rule`; got: {}'.format(rule))

gets_len = len(task_rule.input_gets)

if len(args) != len(task_rule.input_selectors) + (1 if gets_len else 0):
raise ValueError('Rule expected to receive arguments of the form: {}; got: {}'.format(
task_rule.input_selectors, args))

args, get_providers = (args[:-1], args[-1]) if gets_len > 0 else (args, {})
if gets_len != len(get_providers):
raise ValueError('Rule expected to receive Get providers for {}; got: {}'.format(
task_rule.input_gets, get_providers))

res = rule(*args)
if not isinstance(res, GeneratorType):
return res

def get(product, subject):
provider = get_providers.get((product, type(subject)))
if provider is None:
raise AssertionError('Rule requested: Get{}, which cannot be satisfied.'.format(
(product, type(subject), subject)))
return provider(subject)

rule_coroutine = res
rule_input = None
while True:
res = rule_coroutine.send(rule_input)
if isinstance(res, Get):
rule_input = get(res.product, res.subject)
elif type(res) in (tuple, list):
rule_input = [get(g.product, g.subject) for g in res]
else:
return res


def init_native():
"""Initialize and return a `Native` instance."""
init_subsystem(BinaryUtilPrivate.Factory)
Expand Down

0 comments on commit f62743a

Please sign in to comment.