Skip to content

Commit

Permalink
Improve taxscales.py typing
Browse files Browse the repository at this point in the history
Merge pull request #929 from openfisca/improve-taxscales-typing
  • Loading branch information
Mauko Quiroga authored Jan 3, 2020
2 parents 6ccddcb + a58783c commit 7c26bc4
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 13 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 34.6.4 [#929](https://github.com/openfisca/openfisca-core/pull/929)

#### Documentation

- Add more explicit typing annotations for `ndarrays`.

### 34.6.3 [#928](https://github.com/openfisca/openfisca-core/pull/928)

#### Technical changes
Expand Down
38 changes: 26 additions & 12 deletions openfisca_core/taxscales.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import copy
import itertools
import logging
Expand Down Expand Up @@ -42,7 +44,7 @@ def __init__(
class_name: str,
method_name: str,
arg_name: str,
arg_value: ndarray
arg_value: Union[List, ndarray]
) -> None:
message = [
f"'{class_name}.{method_name}' can't be run with an empty '{arg_name}':\n",
Expand Down Expand Up @@ -107,7 +109,7 @@ def __repr__(self) -> Any:
f"{self.__class__.__name__}",
)

def calc(self, _tax_base: ndarray, _right: bool) -> Any:
def calc(self, _tax_base: Union[ndarray[int], ndarray[float]], _right: bool) -> Any:
raise NotImplementedError(
"Method 'calc' is not implemented for "
f"{self.__class__.__name__}",
Expand All @@ -127,7 +129,7 @@ def multiply_thresholds(

def bracket_indices(
self,
tax_base: ndarray,
tax_base: Union[ndarray[int], ndarray[float]],
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> Any:
Expand Down Expand Up @@ -245,10 +247,10 @@ def multiply_thresholds(

def bracket_indices(
self,
tax_base: ndarray,
tax_base: Union[ndarray[int], ndarray[float]],
factor: float = 1.0,
round_decimals: Optional[int] = None,
) -> ndarray:
) -> ndarray[int]:
"""
Compute the relevant bracket indices for the given tax bases.
Expand Down Expand Up @@ -340,7 +342,11 @@ def add_bracket(self, threshold: int, amount: Union[int, float]) -> None:
self.thresholds.insert(i, threshold)
self.amounts.insert(i, amount)

def calc(self, tax_base: ndarray, right: bool = False) -> ndarray:
def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
right: bool = False,
) -> ndarray[float]:
guarded_thresholds = array([-inf] + self.thresholds + [inf])
bracket_indices = digitize(tax_base, guarded_thresholds, right = right)
guarded_amounts = array([0] + self.amounts + [0])
Expand All @@ -360,15 +366,23 @@ class MarginalAmountTaxScale(SingleAmountTaxScale):
containing the input.
"""

def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray:
def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
_right: bool = False,
) -> ndarray[float]:
base1 = tile(tax_base, (len(self.thresholds), 1)).T
thresholds1 = tile(hstack((self.thresholds, inf)), (len(tax_base), 1))
a = max_(min_(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)
return dot(self.amounts, a.T > 0)


class LinearAverageRateTaxScale(AbstractRateTaxScale):
def calc(self, tax_base: ndarray, _right: bool = False) -> ndarray:
def calc(
self,
tax_base: Union[ndarray[int], ndarray[float]],
_right: bool = False,
) -> ndarray[float]:
if len(self.rates) == 1:
return tax_base * self.rates[0]

Expand Down Expand Up @@ -446,10 +460,10 @@ def add_tax_scale(self, tax_scale: AbstractRateTaxScale) -> None:

def calc(
self,
tax_base: ndarray,
tax_base: Union[ndarray[int], ndarray[float]],
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray:
) -> ndarray[float]:
"""
Compute the tax amount for the given tax bases by applying the taxscale.
Expand Down Expand Up @@ -515,10 +529,10 @@ def combine_bracket(

def marginal_rates(
self,
tax_base: ndarray,
tax_base: Union[ndarray[int], ndarray[float]],
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray:
) -> ndarray[float]:
"""
Compute the marginal tax rates relevant for the given tax bases.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

setup(
name = 'OpenFisca-Core',
version = '34.6.3',
version = '34.6.4',
author = 'OpenFisca Team',
author_email = 'contact@openfisca.org',
classifiers = [
Expand Down

0 comments on commit 7c26bc4

Please sign in to comment.