Skip to content

Commit

Permalink
Test and improve param cloning
Browse files Browse the repository at this point in the history
  • Loading branch information
fpagnoux committed Sep 22, 2019
1 parent 6101ea2 commit a74e878
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 46 deletions.
98 changes: 53 additions & 45 deletions openfisca_core/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


import copy
from typing import Iterable, Optional
from typing import Iterable, Optional, Dict, List, Union
import logging
import os
import sys
Expand Down Expand Up @@ -134,12 +134,12 @@ class Parameter(object):
"""

def __init__(self, name, data, file_path = None):
self.name = name
self.file_path = file_path
self.name: str = name
self.file_path: str = file_path
_validate_parameter(self, data, data_type = dict)
self.description = None
self.metadata = {}
self.documentation = None
self.description: str = None
self.metadata: Dict = {}
self.documentation: str = None
self.values_history = self # Only for backward compatibility

# Normal parameter declaration: the values are declared under the 'values' key: parse the description and metadata.
Expand Down Expand Up @@ -179,7 +179,7 @@ def __init__(self, name, data, file_path = None):
value_at_instant = ParameterAtInstant(value_name, instant_str, data = instant_info, file_path = self.file_path, metadata = self.metadata)
values_list.append(value_at_instant)

self.values_list = values_list
self.values_list: List[ParameterAtInstant] = values_list

def __repr__(self):
return os.linesep.join([
Expand All @@ -193,17 +193,12 @@ def __call__(self, instant):
return self.get_at_instant(instant)

def clone(self):
new = empty_clone(self)
new_dict = new.__dict__

for key, value in self.__dict__.items():
if key in ['description', 'documentation', 'file_path', 'name', 'values_list']:
new_dict[key] = value
else:
new_dict[key] = copy.deepcopy(value)
clone = empty_clone(self)
clone.__dict__ = self.__dict__.copy()

new_dict['metadata'] = copy.deepcopy(self.metadata)
return new
clone.metadata = copy.deepcopy(self.metadata)
clone.values_list = [parameter_at_instant.clone() for parameter_at_instant in self.values_list]
return clone

def get_at_instant(self, instant):
instant = str(periods.instant(instant))
Expand Down Expand Up @@ -295,18 +290,18 @@ def __init__(self, name, instant_str, data = None, file_path = None, metadata =
:param string instant_str: Date of the value in the format `YYYY-MM-DD`.
:param dict data: Data, usually loaded from a YAML file.
"""
self.name = name
self.instant_str = instant_str
self.file_path = file_path
self.metadata = {}
self.name: str = name
self.instant_str: str = instant_str
self.file_path: str = file_path
self.metadata: Dict = {}

# Accept { 2015-01-01: 4000 }
if not isinstance(data, dict) and isinstance(data, ALLOWED_PARAM_TYPES):
self.value = data
return

self.validate(data)
self.value = data['value']
self.value: float = data['value']

if metadata is not None:
self.metadata.update(metadata) # Inherit metadata from Parameter
Expand Down Expand Up @@ -334,6 +329,12 @@ def __eq__(self, other):
def __repr__(self):
return "ParameterAtInstant({})".format({self.instant_str: self.value})

def clone(self):
clone = empty_clone(self)
clone.__dict__ = self.__dict__.copy()
clone.metadata = copy.deepcopy(self.metadata)
return clone


class ParameterNode(object):
"""
Expand Down Expand Up @@ -373,12 +374,12 @@ def __init__(self, name = "", directory_path = None, data = None, file_path = No
>>> node = ParameterNode('benefits', directory_path = '/path/to/country_package/parameters/benefits')
"""
self.name = name
self.children = {}
self.description = None
self.documentation = None
self.file_path = None
self.metadata = {}
self.name: str = name
self.children: Dict[str, Union[ParameterNode, Parameter, Scale]] = {}
self.description: str = None
self.documentation: str = None
self.file_path: str = None
self.metadata: Dict = {}

if directory_path:
self.file_path = directory_path
Expand Down Expand Up @@ -475,20 +476,18 @@ def get_descendants(self):
yield from child.get_descendants()

def clone(self):
new = empty_clone(self)
new_dict = new.__dict__

for key, value in self.__dict__.items():
if key not in ('children', 'metadata'):
new_dict[key] = value
clone = empty_clone(self)
clone.__dict__ = self.__dict__.copy()

children = dict()
new_dict['children'] = children
for child, node in self.children.items():
new_dict[child] = children[child] = node.clone()
clone.metadata = copy.deepcopy(self.metadata)
clone.children = {
key: child.clone()
for key, child in self.children.items()
}
for child_key, child in clone.children.items():
setattr(clone, child_key, child)

new_dict['metadata'] = copy.deepcopy(self.metadata)
return new
return clone


class ParameterNodeAtInstant(object):
Expand Down Expand Up @@ -718,11 +717,11 @@ def __init__(self, name, data, file_path):
:param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically.
:param file_path: File the parameter was loaded from.
"""
self.name = name
self.file_path = file_path
self.name: str = name
self.file_path: str = file_path
_validate_parameter(self, data, data_type = dict, allowed_keys = self._allowed_keys)
self.description = data.get('description')
self.metadata = {}
self.description: str = data.get('description')
self.metadata: Dict = {}
_set_backward_compatibility_metadata(self, data)
self.metadata.update(data.get('metadata', {}))

Expand All @@ -738,7 +737,7 @@ def __init__(self, name, data, file_path):
bracket_name = _compose_name(name, item_name = i)
bracket = Bracket(name = bracket_name, data = bracket_data, file_path = file_path)
brackets.append(bracket)
self.brackets = brackets
self.brackets: List[Bracket] = brackets

def __call__(self, instant):
return self.get_at_instant(instant)
Expand Down Expand Up @@ -808,6 +807,15 @@ def __repr__(self):
def get_descendants(self):
return iter(())

def clone(self):
clone = empty_clone(self)
clone.__dict__ = self.__dict__.copy()

clone.brackets = [bracket.clone() for bracket in self.brackets]
clone.metadata = copy.deepcopy(self.metadata)

return clone


class Bracket(ParameterNode):
"""
Expand Down
51 changes: 50 additions & 1 deletion tests/core/parameter_validation/test_parameter_clone.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

from ..test_countries import tax_benefit_system

import os
from openfisca_core.parameters import ParameterNode
Expand All @@ -18,3 +18,52 @@ def test_clone():
assert id(clone) != id(parameters)
assert id(clone.node1) != id(parameters.node1)
assert id(clone.node1.param) != id(parameters.node1.param)


def test_clone_parameter():

param = tax_benefit_system.parameters.taxes.income_tax_rate
clone = param.clone()

assert clone is not param
assert clone.values_list is not param.values_list
assert clone.values_list[0] is not param.values_list[0]

assert clone.values_list == param.values_list


def test_clone_parameter_node():
node = tax_benefit_system.parameters.taxes
clone = node.clone()

assert clone is not node
assert clone.income_tax_rate is not node.income_tax_rate
assert clone.children['income_tax_rate'] is not node.children['income_tax_rate']


def test_clone_scale():
scale = tax_benefit_system.parameters.taxes.social_security_contribution
clone = scale.clone()

assert clone.brackets[0] is not scale.brackets[0]
assert clone.brackets[0].rate is not scale.brackets[0].rate


def test_deep_edit():
parameters = tax_benefit_system.parameters
clone = parameters.clone()

param = parameters.taxes.income_tax_rate
clone_param = clone.taxes.income_tax_rate

original_value = param.values_list[0].value
clone_param.values_list[0].value = 100
assert param.values_list[0].value == original_value

scale = parameters.taxes.social_security_contribution
clone_scale = clone.taxes.social_security_contribution

original_scale_value = scale.brackets[0].rate.values_list[0].value
clone_scale.brackets[0].rate.values_list[0].value = 10

assert scale.brackets[0].rate.values_list[0].value == original_scale_value

0 comments on commit a74e878

Please sign in to comment.