class First(Aggregation):
    """
    Returns the first value that passes a test.

    If the test is omitted, the aggregation will return the first value in
    the column, whatever it is (including a null).

    If a test is given and no values pass it, the aggregation will raise a
    :class:`ValueError`.

    :param column_name:
        The name of the column to check.
    :param test:
        An optional function that takes a value and returns a truthy or falsy
        result. When omitted, no filtering is applied.
    """
    def __init__(self, column_name, test=None):
        self._column_name = column_name
        self._test = test

    def get_aggregate_data_type(self, table):
        """
        Return the data type of the inspected column; the aggregate is one of
        the column's own values, so it shares the column's type.
        """
        return table.columns[self._column_name].data_type

    def validate(self, table):
        """
        Check that at least one value passes the test.

        Nothing is checked when no test was supplied.

        :raises ValueError:
            If a test was supplied and no value in the column passes it.
        """
        if self._test is None:
            return

        data = table.columns[self._column_name].values()

        # any() stops at the first passing value instead of materializing
        # every match just to find out whether one exists.
        if not any(self._test(d) for d in data):
            raise ValueError('No values pass the given test.')

    def run(self, table):
        """
        Return the first value in the column, or the first value passing the
        test when one was supplied.

        This method does not rely on :meth:`validate` having been called.

        :raises ValueError:
            If a test was supplied and no value in the column passes it.
        :raises IndexError:
            If no test was supplied and the column values are an empty
            built-in sequence (position 0 is indexed directly).
        """
        data = table.columns[self._column_name].values()

        if self._test is None:
            return data[0]

        # An explicit loop (rather than next() on a generator) lets a passing
        # None/falsy value be returned as-is and avoids leaking a bare
        # StopIteration, which an enclosing generator or for-loop would
        # silently swallow.
        for value in data:
            if self._test(value):
                return value

        raise ValueError('No values pass the given test.')