Skip to content

Commit

Permalink
Improve taxscale method descriptions
Browse files Browse the repository at this point in the history
  • Loading branch information
Mauko Quiroga committed Dec 27, 2019
1 parent 002aa4c commit 89c0d85
Showing 1 changed file with 117 additions and 52 deletions.
169 changes: 117 additions & 52 deletions openfisca_core/taxscales.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
# -*- coding: utf-8 -*-


from bisect import bisect_left, bisect_right
import copy
import logging
import itertools
import os

import numpy as np
from numpy import maximum as max_, minimum as min_
from numpy import (
around,
array,
digitize,
dot,
finfo,
float as float_,
hstack,
inf,
maximum as max_,
minimum as min_,
ndarray,
ones,
outer,
round as round_,
tile,
)

from typing import Optional

from openfisca_core.commons import empty_clone
from openfisca_core.tools import indent
Expand Down Expand Up @@ -51,31 +65,37 @@ def __repr__(self):
def calc(self, base):
raise NotImplementedError('Method "calc" is not implemented for {}'.format(self.__class__.__name__))

def compute_bracket_index(self, base, factor = 1, round_base_decimals = None):
def compute_bracket_index(
self,
tax_bases: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray:
"""
Compute the relevant bracket for the given tax bases
:param numpy.ndarray base: Numpy array of the tax bases.
:param float factor: A numerical factor to apply to the thresholds of the tax scale, defaults to 1
:param int round_base_decimals: Decimals to keep when rounding thresholds, defaults to None (no rounding)
:returns: An integer numpy.ndarray with the relevant bracket indices for the given tax bases.
:rtype: numpy.ndarray
Compute the relevant bracket for the given tax bases.
For instance:
:param tax_bases: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scales.
:param round_base_decimals: Decimals to keep when rounding thresholds.
>>>
>>>
:returns: Int array with relevant bracket indices for the given tax bases.
>>> marginal_tax_scale = MarginalRateTaxScale()
>>> marginal_tax_scale.add_bracket(0, 0)
>>> marginal_tax_scale.add_bracket(100, 0.1)
>>> tax_bases = array([0, 150])
>>> marginal_tax_scale.compute_bracket_index(tax_bases)
[0, 1]
"""
base1 = tile(tax_bases, (len(self.thresholds), 1)).T # type: ignore
factor = ones(len(tax_bases)) * factor

# finfo(float_).eps is used to avoid nan = 0 * inf creation
thresholds1 = outer(factor + finfo(float_).eps, array(self.thresholds + [inf])) # type: ignore

base1 = np.tile(base, (len(self.thresholds), 1)).T
if isinstance(factor, (float, int)):
factor = np.ones(len(base)) * factor
# np.finfo(np.float).eps is used to avoid np.nan = 0 * np.inf creation
thresholds1 = np.outer(factor + np.finfo(np.float).eps, np.array(self.thresholds + [np.inf]))
if round_base_decimals is not None:
thresholds1 = np.round(thresholds1, round_base_decimals)
thresholds1 = round_(thresholds1, round_base_decimals)

return (base1 - thresholds1[:, :-1] >= 0).sum(axis = 1) - 1

def copy(self):
Expand Down Expand Up @@ -132,15 +152,15 @@ def multiply_thresholds(self, factor, decimals = None, inplace = True, new_name
assert new_name is None
for i, threshold in enumerate(self.thresholds):
if decimals is not None:
self.thresholds[i] = np.around(threshold * factor, decimals = decimals)
self.thresholds[i] = around(threshold * factor, decimals = decimals)
else:
self.thresholds[i] = threshold * factor
return self

new_tax_scale = self.__class__(new_name or self.name, option = self.option, unit = self.unit)
for threshold, rate in zip(self.thresholds, self.rates):
if decimals is not None:
new_tax_scale.thresholds.append(np.around(threshold * factor, decimals = decimals))
new_tax_scale.thresholds.append(around(threshold * factor, decimals = decimals))
else:
new_tax_scale.thresholds.append(threshold * factor)

Expand Down Expand Up @@ -176,9 +196,9 @@ def add_bracket(self, threshold, amount):
self.amounts.insert(i, amount)

def calc(self, base, right=False):
guarded_thresholds = np.array([-np.inf] + self.thresholds + [np.inf])
bracket_indices = np.digitize(base, guarded_thresholds, right=right)
guarded_amounts = np.array([0] + self.amounts + [0])
guarded_thresholds = array([-inf] + self.thresholds + [inf])
bracket_indices = digitize(base, guarded_thresholds, right=right)
guarded_amounts = array([0] + self.amounts + [0])
return guarded_amounts[bracket_indices - 1]


Expand All @@ -189,27 +209,27 @@ class MarginalAmountTaxScale(SingleAmountTaxScale):
'''

def calc(self, base):
base1 = np.tile(base, (len(self.thresholds), 1)).T
thresholds1 = np.tile(np.hstack((self.thresholds, np.inf)), (len(base), 1))
base1 = tile(base, (len(self.thresholds), 1)).T
thresholds1 = tile(hstack((self.thresholds, inf)), (len(base), 1))
a = max_(min_(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)
return np.dot(self.amounts, a.T > 0)
return dot(self.amounts, a.T > 0)


class LinearAverageRateTaxScale(AbstractRateTaxScale):
def calc(self, base):
if len(self.rates) == 1:
return base * self.rates[0]

tiled_base = np.tile(base, (len(self.thresholds) - 1, 1)).T
tiled_thresholds = np.tile(self.thresholds, (len(base), 1))
tiled_base = tile(base, (len(self.thresholds) - 1, 1)).T
tiled_thresholds = tile(self.thresholds, (len(base), 1))
bracket_dummy = (tiled_base >= tiled_thresholds[:, :-1]) * (tiled_base < tiled_thresholds[:, 1:])
rates_array = np.array(self.rates)
thresholds_array = np.array(self.thresholds)
rates_array = array(self.rates)
thresholds_array = array(self.thresholds)
rate_slope = (rates_array[1:] - rates_array[:-1]) / (thresholds_array[1:] - thresholds_array[:-1])
average_rate_slope = np.dot(bracket_dummy, rate_slope.T)
average_rate_slope = dot(bracket_dummy, rate_slope.T)

bracket_average_start_rate = np.dot(bracket_dummy, rates_array[:-1])
bracket_threshold = np.dot(bracket_dummy, thresholds_array[:-1])
bracket_average_start_rate = dot(bracket_dummy, rates_array[:-1])
bracket_threshold = dot(bracket_dummy, thresholds_array[:-1])
log.info("bracket_average_start_rate : {}".format(bracket_average_start_rate))
log.info("average_rate_slope: {}".format(average_rate_slope))
return base * (bracket_average_start_rate + (base - bracket_threshold) * average_rate_slope)
Expand All @@ -236,21 +256,46 @@ def add_tax_scale(self, tax_scale):
self.combine_bracket(rate, threshold_low, threshold_high)
self.combine_bracket(tax_scale.rates[-1], tax_scale.thresholds[-1]) # Pour traiter le dernier threshold

def calc(self, base, factor = 1, round_base_decimals = None):
base1 = np.tile(base, (len(self.thresholds), 1)).T
if isinstance(factor, (float, int)):
factor = np.ones(len(base)) * factor
# np.finfo(np.float).eps is used to avoid np.nan = 0 * np.inf creation
thresholds1 = np.outer(factor + np.finfo(np.float).eps, np.array(self.thresholds + [np.inf]))
def calc(
self,
tax_bases: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray:
"""
Compute the tax amount for the given tax bases by applying the taxscale.
:param tax_bases: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scale.
:param round_base_decimals: Decimals to keep when rounding thresholds.
:returns: Float array with tax amount for the given tax bases.
>>> marginal_tax_scale = MarginalRateTaxScale()
>>> marginal_tax_scale.add_bracket(0, 0)
>>> marginal_tax_scale.add_bracket(100, 0.1)
>>> tax_bases = array([0, 150])
>>> marginal_tax_scale.calc(tax_bases)
[0.0, 5.0]
"""

base1 = tile(tax_bases, (len(self.thresholds), 1)).T # type: ignore
factor = ones(len(tax_bases)) * factor

# finfo(float_).eps is used to avoid nan = 0 * inf creation
thresholds1 = outer(factor + finfo(float_).eps, array(self.thresholds + [inf])) # type: ignore

if round_base_decimals is not None:
thresholds1 = np.round(thresholds1, round_base_decimals)
thresholds1 = round_(thresholds1, round_base_decimals)

a = max_(min_(base1, thresholds1[:, 1:]) - thresholds1[:, :-1], 0)

if round_base_decimals is None:
return np.dot(self.rates, a.T)
return dot(self.rates, a.T)
else:
r = np.tile(self.rates, (len(base), 1))
b = np.round(a, round_base_decimals)
return np.round(r * b, round_base_decimals).sum(axis = 1)
r = tile(self.rates, (len(tax_bases), 1))
b = round_(a, round_base_decimals)
return round_(r * b, round_base_decimals).sum(axis = 1)

def combine_bracket(self, rate, threshold_low = 0, threshold_high = False):
# Insert threshold_low and threshold_high without modifying rates
Expand All @@ -272,9 +317,29 @@ def combine_bracket(self, rate, threshold_low = 0, threshold_high = False):
self.add_bracket(self.thresholds[i], rate)
i += 1

def compute_marginal_rate(self, base, factor = 1, round_base_decimals = None):
return np.array(self.rates)[
self.compute_bracket_index(base, factor, round_base_decimals)
def compute_marginal_rate(
self, tax_bases: ndarray,
factor: float = 1.0,
round_base_decimals: Optional[int] = None,
) -> ndarray:
"""
Compute the marginal tax rate relevant for the given tax bases.
:param base: Array of the tax bases.
:param factor: Factor to apply to the thresholds of the tax scale.
:param round_base_decimals: Decimals to keep when rounding thresholds.
:returns: Float array with relevant marginal tax rate for the given tax bases.
>>> marginal_tax_scale = MarginalRateTaxScale()
>>> marginal_tax_scale.add_bracket(0, 0)
>>> marginal_tax_scale.add_bracket(100, 0.1)
>>> tax_bases = array([0, 150])
>>> marginal_tax_scale.compute_marginal_rate(tax_bases)
[0.0, 0.1]
"""
return array(self.rates)[
self.compute_bracket_index(tax_bases, factor, round_base_decimals)
]

def inverse(self):
Expand Down

0 comments on commit 89c0d85

Please sign in to comment.