Skip to content

Commit

Permalink
Format with Black.
Browse files (browse the repository at this point in the history)
  • Loading branch information
m3nu committed Aug 21, 2019
1 parent 2f7d47a commit 3525e4a
Show file tree
Hide file tree
Showing 13 changed files with 286 additions and 236 deletions.
120 changes: 64 additions & 56 deletions src/invoice2data/extract/invoice_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@
logger = logging.getLogger(__name__)

OPTIONS_DEFAULT = {
'remove_whitespace': False,
'remove_accents': False,
'lowercase': False,
'currency': 'EUR',
'date_formats': [],
'languages': [],
'decimal_separator': '.',
'replace': [], # example: see templates/fr/fr.free.mobile.yml
"remove_whitespace": False,
"remove_accents": False,
"lowercase": False,
"currency": "EUR",
"date_formats": [],
"languages": [],
"decimal_separator": ".",
"replace": [], # example: see templates/fr/fr.free.mobile.yml
}

PLUGIN_MAPPING = {'lines': lines, 'tables': tables}
PLUGIN_MAPPING = {"lines": lines, "tables": tables}


class InvoiceTemplate(OrderedDict):
Expand Down Expand Up @@ -53,111 +53,115 @@ def __init__(self, *args, **kwargs):
# Merge template-specific options with defaults
self.options = OPTIONS_DEFAULT.copy()

for lang in self.options['languages']:
assert len(lang) == 2, 'lang code must have 2 letters'
for lang in self.options["languages"]:
assert len(lang) == 2, "lang code must have 2 letters"

if 'options' in self:
self.options.update(self['options'])
if "options" in self:
self.options.update(self["options"])

# Set issuer, if it doesn't exist.
if 'issuer' not in self.keys():
self['issuer'] = self['keywords'][0]
if "issuer" not in self.keys():
self["issuer"] = self["keywords"][0]

def prepare_input(self, extracted_str):
"""
Input raw string and do transformations, as set in template file.
"""

# Remove whitespace
if self.options['remove_whitespace']:
optimized_str = re.sub(' +', '', extracted_str)
if self.options["remove_whitespace"]:
optimized_str = re.sub(" +", "", extracted_str)
else:
optimized_str = extracted_str

# Remove accents
if self.options['remove_accents']:
if self.options["remove_accents"]:
optimized_str = unidecode(optimized_str)

# convert to lower case
if self.options['lowercase']:
if self.options["lowercase"]:
optimized_str = optimized_str.lower()

# specific replace
for replace in self.options['replace']:
assert len(replace) == 2, 'A replace should be a list of 2 items'
for replace in self.options["replace"]:
assert len(replace) == 2, "A replace should be a list of 2 items"
optimized_str = optimized_str.replace(replace[0], replace[1])

return optimized_str

def matches_input(self, optimized_str):
"""See if string matches keywords set in template file"""

if all([keyword in optimized_str for keyword in self['keywords']]):
logger.debug('Matched template %s', self['template_name'])
if all([keyword in optimized_str for keyword in self["keywords"]]):
logger.debug("Matched template %s", self["template_name"])
return True

def parse_number(self, value):
assert (
value.count(self.options['decimal_separator']) < 2
), 'Decimal separator cannot be present several times'
value.count(self.options["decimal_separator"]) < 2
), "Decimal separator cannot be present several times"
# replace decimal separator by a |
amount_pipe = value.replace(self.options['decimal_separator'], '|')
amount_pipe = value.replace(self.options["decimal_separator"], "|")
# remove all possible thousands separators
amount_pipe_no_thousand_sep = re.sub(r'[.,\s]', '', amount_pipe)
amount_pipe_no_thousand_sep = re.sub(r"[.,\s]", "", amount_pipe)
# put dot as decimal sep
return float(amount_pipe_no_thousand_sep.replace('|', '.'))
return float(amount_pipe_no_thousand_sep.replace("|", "."))

def parse_date(self, value):
"""Parses date and returns date after parsing"""
res = dateparser.parse(
value, date_formats=self.options['date_formats'], languages=self.options['languages']
value,
date_formats=self.options["date_formats"],
languages=self.options["languages"],
)
logger.debug("result of date parsing=%s", res)
return res

def coerce_type(self, value, target_type):
if target_type == 'int':
if target_type == "int":
if not value.strip():
return 0
return int(self.parse_number(value))
elif target_type == 'float':
elif target_type == "float":
if not value.strip():
return 0.0
return float(self.parse_number(value))
elif target_type == 'date':
elif target_type == "date":
return self.parse_date(value)
assert False, 'Unknown type'
assert False, "Unknown type"

def extract(self, optimized_str):
"""
Given a template file and a string, extract matching data fields.
"""

logger.debug('START optimized_str ========================')
logger.debug("START optimized_str ========================")
logger.debug(optimized_str)
logger.debug('END optimized_str ==========================')
logger.debug("END optimized_str ==========================")
logger.debug(
'Date parsing: languages=%s date_formats=%s',
self.options['languages'],
self.options['date_formats'],
"Date parsing: languages=%s date_formats=%s",
self.options["languages"],
self.options["date_formats"],
)
logger.debug('Float parsing: decimal separator=%s', self.options['decimal_separator'])
logger.debug("keywords=%s", self['keywords'])
logger.debug(
"Float parsing: decimal separator=%s", self.options["decimal_separator"]
)
logger.debug("keywords=%s", self["keywords"])
logger.debug(self.options)

# Try to find data for each field.
output = {}
output['issuer'] = self['issuer']
output["issuer"] = self["issuer"]

for k, v in self['fields'].items():
if k.startswith('static_'):
for k, v in self["fields"].items():
if k.startswith("static_"):
logger.debug("field=%s | static value=%s", k, v)
output[k.replace('static_', '')] = v
output[k.replace("static_", "")] = v
else:
logger.debug("field=%s | regexp=%s", k, v)

sum_field = False
if k.startswith('sum_amount') and type(v) is list:
if k.startswith("sum_amount") and type(v) is list:
k = k[4:] # remove 'sum_' prefix
sum_field = True
# Fields can have multiple expressions
Expand All @@ -174,12 +178,14 @@ def extract(self, optimized_str):
res_find = re.findall(v, optimized_str)
if res_find:
logger.debug("res_find=%s", res_find)
if k.startswith('date') or k.endswith('date'):
if k.startswith("date") or k.endswith("date"):
output[k] = self.parse_date(res_find[0])
if not output[k]:
logger.error("Date parsing failed on date '%s'", res_find[0])
logger.error(
"Date parsing failed on date '%s'", res_find[0]
)
return None
elif k.startswith('amount'):
elif k.startswith("amount"):
if sum_field:
output[k] = 0
for amount_to_parse in res_find:
Expand All @@ -195,30 +201,32 @@ def extract(self, optimized_str):
else:
logger.warning("regexp for field %s didn't match", k)

output['currency'] = self.options['currency']
output["currency"] = self.options["currency"]

# Run plugins:
for plugin_keyword, plugin_func in PLUGIN_MAPPING.items():
if plugin_keyword in self.keys():
plugin_func.extract(self, optimized_str, output)

# If required fields were found, return output, else log error.
if 'required_fields' not in self.keys():
required_fields = ['date', 'amount', 'invoice_number', 'issuer']
if "required_fields" not in self.keys():
required_fields = ["date", "amount", "invoice_number", "issuer"]
else:
required_fields = []
for v in self['required_fields']:
for v in self["required_fields"]:
required_fields.append(v)

if set(required_fields).issubset(output.keys()):
output['desc'] = 'Invoice from %s' % (self['issuer'])
output["desc"] = "Invoice from %s" % (self["issuer"])
logger.debug(output)
return output
else:
fields = list(set(output.keys()))
logger.error(
'Unable to match all required fields. '
'The required fields are: {0}. '
'Output contains the following fields: {1}.'.format(required_fields, fields)
"Unable to match all required fields. "
"The required fields are: {0}. "
"Output contains the following fields: {1}.".format(
required_fields, fields
)
)
return None
26 changes: 15 additions & 11 deletions src/invoice2data/extract/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import codecs
import chardet

logging.getLogger('chardet').setLevel(logging.WARNING)
logging.getLogger("chardet").setLevel(logging.WARNING)


# borrowed from http://stackoverflow.com/a/21912744
Expand All @@ -31,7 +31,9 @@ def construct_mapping(loader, node):
loader.flatten_mapping(node)
return object_pairs_hook(loader.construct_pairs(node))

OrderedLoader.add_constructor(yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping
)

return yaml.load(stream, OrderedLoader)

Expand Down Expand Up @@ -77,23 +79,25 @@ def read_templates(folder=None):
output = []

if folder is None:
folder = pkg_resources.resource_filename(__name__, 'templates')
folder = pkg_resources.resource_filename(__name__, "templates")

for path, subdirs, files in os.walk(folder):
for name in sorted(files):
if name.endswith('.yml'):
with open(os.path.join(path, name), 'rb') as f:
encoding = chardet.detect(f.read())['encoding']
with codecs.open(os.path.join(path, name), encoding=encoding) as template_file:
if name.endswith(".yml"):
with open(os.path.join(path, name), "rb") as f:
encoding = chardet.detect(f.read())["encoding"]
with codecs.open(
os.path.join(path, name), encoding=encoding
) as template_file:
tpl = ordered_load(template_file.read())
tpl['template_name'] = name
tpl["template_name"] = name

# Test if all required fields are in template:
assert 'keywords' in tpl.keys(), 'Missing keywords field.'
assert "keywords" in tpl.keys(), "Missing keywords field."

# Keywords as list, if only one.
if type(tpl['keywords']) is not list:
tpl['keywords'] = [tpl['keywords']]
if type(tpl["keywords"]) is not list:
tpl["keywords"] = [tpl["keywords"]]

output.append(InvoiceTemplate(tpl))
return output
Loading

0 comments on commit 3525e4a

Please sign in to comment.