Skip to content

Commit

Permalink
mypy: stricter settings (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshData authored May 10, 2024
2 parents 4691a62 + 380e44e commit a9a8a62
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test_and_build.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name: Tests

on: [push]
on: [push, pull_request]

jobs:
build:
Expand Down
12 changes: 8 additions & 4 deletions email_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

# Export the main method, helper methods, and the public data types.
from .exceptions_types import ValidatedEmail, EmailNotValidError, \
EmailSyntaxError, EmailUndeliverableError
Expand All @@ -9,12 +11,14 @@
"EmailSyntaxError", "EmailUndeliverableError",
"caching_resolver", "__version__"]


def caching_resolver(*args, **kwargs):
# Lazy load `deliverability` as it is slow to import (due to dns.resolver)
if TYPE_CHECKING:
from .deliverability import caching_resolver
else:
def caching_resolver(*args, **kwargs):
# Lazy load `deliverability` as it is slow to import (due to dns.resolver)
from .deliverability import caching_resolver

return caching_resolver(*args, **kwargs)
return caching_resolver(*args, **kwargs)


# These global attributes are a part of the library's API and can be
Expand Down
7 changes: 4 additions & 3 deletions email_validator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@
import json
import os
import sys
from typing import Any, Dict, Optional

from .validate_email import validate_email
from .validate_email import validate_email, _Resolver
from .deliverability import caching_resolver
from .exceptions_types import EmailNotValidError


def main(dns_resolver=None):
def main(dns_resolver: Optional[_Resolver] = None) -> None:
# The dns_resolver argument is for tests.

# Set options from environment variables.
options = {}
options: Dict[str, Any] = {}
for varname in ('ALLOW_SMTPUTF8', 'ALLOW_QUOTED_LOCAL', 'ALLOW_DOMAIN_LITERAL',
'GLOBALLY_DELIVERABLE', 'CHECK_DELIVERABILITY', 'TEST_ENVIRONMENT'):
if varname in os.environ:
Expand Down
24 changes: 15 additions & 9 deletions email_validator/deliverability.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Any, Dict
from typing import Any, List, Optional, Tuple, TypedDict

import ipaddress

Expand All @@ -8,17 +8,24 @@
import dns.exception


def caching_resolver(*, timeout: Optional[int] = None, cache=None, dns_resolver=None):
def caching_resolver(*, timeout: Optional[int] = None, cache: Any = None, dns_resolver: Optional[dns.resolver.Resolver] = None) -> dns.resolver.Resolver:
if timeout is None:
from . import DEFAULT_TIMEOUT
timeout = DEFAULT_TIMEOUT
resolver = dns_resolver or dns.resolver.Resolver()
resolver.cache = cache or dns.resolver.LRUCache() # type: ignore
resolver.lifetime = timeout # type: ignore # timeout, in seconds
resolver.cache = cache or dns.resolver.LRUCache()
resolver.lifetime = timeout # timeout, in seconds
return resolver


def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Optional[int] = None, dns_resolver=None):
# Dict shape used as the return annotation of validate_email_deliverability().
# The functional TypedDict form is needed here because the key
# "unknown-deliverability" contains a hyphen and therefore cannot be written
# as an attribute in the class-based TypedDict syntax.
# total=False makes every key optional, so readers of this dict must not
# assume that any particular key is present.
DeliverabilityInfo = TypedDict("DeliverabilityInfo", {
"mx": List[Tuple[int, str]],  # presumably (priority, host) pairs taken from MX records -- confirm
"mx_fallback_type": Optional[str],  # presumably the record type used when no MX record exists -- confirm
"unknown-deliverability": str,  # NOTE(review): meaning is not visible in this hunk -- confirm against the function body
}, total=False)


def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Optional[int] = None, dns_resolver: Optional[dns.resolver.Resolver] = None) -> DeliverabilityInfo:
# Check that the domain resolves to an MX record. If there is no MX record,
# try an A or AAAA record which is a deprecated fallback for deliverability.
# Raises an EmailUndeliverableError on failure. On success, returns a dict
Expand All @@ -36,7 +43,7 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option
elif timeout is not None:
raise ValueError("It's not valid to pass both timeout and dns_resolver.")

deliverability_info: Dict[str, Any] = {}
deliverability_info: DeliverabilityInfo = {}

try:
try:
Expand Down Expand Up @@ -69,9 +76,9 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option
# https://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
# https://www.iana.org/assignments/iana-ipv6-special-registry/iana-ipv6-special-registry.xhtml
# (Issue #134.)
def is_global_addr(ipaddr):
def is_global_addr(address: Any) -> bool:
try:
ipaddr = ipaddress.ip_address(ipaddr)
ipaddr = ipaddress.ip_address(address)
except ValueError:
return False
return ipaddr.is_global
Expand Down Expand Up @@ -115,7 +122,6 @@ def is_global_addr(ipaddr):
for rec in response:
value = b"".join(rec.strings)
if value.startswith(b"v=spf1 "):
deliverability_info["spf"] = value.decode("ascii", errors='replace')
if value == b"v=spf1 -all":
raise EmailUndeliverableError(f"The domain name {domain_i18n} does not send email.")
except dns.resolver.NoAnswer:
Expand Down
29 changes: 12 additions & 17 deletions email_validator/exceptions_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple, Union


class EmailNotValidError(ValueError):
Expand All @@ -24,7 +24,7 @@ class ValidatedEmail:
"""The email address that was passed to validate_email. (If passed as bytes, this will be a string.)"""
original: str

"""The normalized email address, which should always be used in preferance to the original address.
"""The normalized email address, which should always be used in preference to the original address.
The normalized address converts an IDNA ASCII domain name to Unicode, if possible, and performs
Unicode normalization on the local part and on the domain (if originally Unicode). It is the
concatenation of the local_part and domain attributes, separated by an @-sign."""
Expand Down Expand Up @@ -56,39 +56,34 @@ class ValidatedEmail:

"""If a deliverability check is performed and if it succeeds, a list of (priority, domain)
tuples of MX records specified in the DNS for the domain."""
mx: list
mx: List[Tuple[int, str]]

"""If no MX records are actually specified in DNS and instead are inferred, through an obsolete
mechanism, from A or AAAA records, the value is the type of DNS record used instead (`A` or `AAAA`)."""
mx_fallback_type: str
mx_fallback_type: Optional[str]

"""The display name in the original input text, unquoted and unescaped, or None."""
display_name: str
display_name: Optional[str]

"""Tests use this constructor."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

def __repr__(self):
# Debugging representation: only the normalized address is shown, wrapped as
# "<ValidatedEmail ...>". It reads self.normalized, so calling repr() on an
# instance where that attribute was never set raises AttributeError.
def __repr__(self) -> str:
return f"<ValidatedEmail {self.normalized}>"

"""For backwards compatibility, support old field names."""
def __getattr__(self, key):
def __getattr__(self, key: str) -> str:
if key == "original_email":
return self.original
if key == "email":
return self.normalized
raise AttributeError(key)

@property
def email(self):
def email(self) -> str:
warnings.warn("ValidatedEmail.email is deprecated and will be removed, use ValidatedEmail.normalized instead", DeprecationWarning)
return self.normalized

"""For backwards compatibility, some fields are also exposed through a dict-like interface. Note
that some of the names changed when they became attributes."""
def __getitem__(self, key):
def __getitem__(self, key: str) -> Union[Optional[str], bool, List[Tuple[int, str]]]:
warnings.warn("dict-like access to the return value of validate_email is deprecated and may not be supported in the future.", DeprecationWarning, stacklevel=2)
if key == "email":
return self.normalized
Expand All @@ -109,7 +104,7 @@ def __getitem__(self, key):
raise KeyError()

"""Tests use this."""
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, ValidatedEmail):
return False
return (
Expand All @@ -127,7 +122,7 @@ def __eq__(self, other):
)

"""This helps producing the README."""
def as_constructor(self):
def as_constructor(self) -> str:
return "ValidatedEmail(" \
+ ",".join(f"\n {key}={repr(getattr(self, key))}"
for key in ('normalized', 'local_part', 'domain',
Expand All @@ -139,7 +134,7 @@ def as_constructor(self):
+ ")"

"""Convenience method for accessing ValidatedEmail as a dict"""
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
d = self.__dict__
if d.get('domain_address'):
d['domain_address'] = repr(d['domain_address'])
Expand Down
44 changes: 31 additions & 13 deletions email_validator/syntax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .exceptions_types import EmailSyntaxError
from .exceptions_types import EmailSyntaxError, ValidatedEmail
from .rfc_constants import EMAIL_MAX_LENGTH, LOCAL_PART_MAX_LENGTH, DOMAIN_MAX_LENGTH, \
DOT_ATOM_TEXT, DOT_ATOM_TEXT_INTL, ATEXT_RE, ATEXT_INTL_DOT_RE, ATEXT_HOSTNAME_INTL, QTEXT_INTL, \
DNS_LABEL_LENGTH_LIMIT, DOT_ATOM_TEXT_HOSTNAME, DOMAIN_NAME_REGEX, DOMAIN_LITERAL_CHARS
Expand All @@ -7,10 +7,10 @@
import unicodedata
import idna # implements IDNA 2008; Python's codec is only IDNA 2003
import ipaddress
from typing import Optional
from typing import Optional, Tuple, TypedDict, Union


def split_email(email):
def split_email(email: str) -> Tuple[Optional[str], str, str, bool]:
# Return the display name, unescaped local part, and domain part
# of the address, and whether the local part was quoted. If no
# display name was present and angle brackets do not surround
Expand Down Expand Up @@ -46,7 +46,7 @@ def split_email(email):
# We assume the input string is already stripped of leading and
# trailing CFWS.

def split_string_at_unquoted_special(text, specials):
def split_string_at_unquoted_special(text: str, specials: Tuple[str, ...]) -> Tuple[str, str]:
# Split the string at the first character in specials (an @-sign
# or left angle bracket) that does not occur within quotes.
inside_quote = False
Expand Down Expand Up @@ -77,7 +77,7 @@ def split_string_at_unquoted_special(text, specials):

return left_part, right_part

def unquote_quoted_string(text):
def unquote_quoted_string(text: str) -> Tuple[str, bool]:
# Remove surrounding quotes and unescape escaped backslashes
# and quotes. Escapes are parsed liberally. I think only
# backslashes and quotes can be escaped but we'll allow anything
Expand Down Expand Up @@ -155,15 +155,15 @@ def unquote_quoted_string(text):
return display_name, local_part, domain_part, is_quoted_local_part


def get_length_reason(addr, utf8=False, limit=EMAIL_MAX_LENGTH):
def get_length_reason(addr: str, utf8: bool = False, limit: int = EMAIL_MAX_LENGTH) -> str:
    """Build the parenthesized "N characters too many" note for a length error.

    addr is the text being measured, limit is the maximum permitted length
    (EMAIL_MAX_LENGTH unless overridden), and utf8 requests the "at least "
    qualifier in front of the count.

    Returns a string such as "(3 characters too many)" or
    "(at least 1 character too many)".
    """
    # Number of characters beyond the limit. There is no guard here for input
    # that is within the limit, so callers are expected to have checked that.
    diff = len(addr) - limit

    # NOTE(review): the qualifier presumably reflects that the encoded form
    # may be longer than the character count -- confirm with the callers.
    if utf8:
        prefix = "at least "
    else:
        prefix = ""

    # Pluralize only for counts of two or more.
    suffix = "" if diff <= 1 else "s"

    return f"({prefix}{diff} character{suffix} too many)"


def safe_character_display(c):
def safe_character_display(c: str) -> str:
# Return safely displayable characters in quotes.
if c == '\\':
return f"\"{c}\"" # can't use repr because it escapes it
Expand All @@ -180,8 +180,14 @@ def safe_character_display(c):
return unicodedata.name(c, h)


# Dict shape used as the return annotation of validate_email_local_part().
# All three keys are required (a class-based TypedDict is total by default).
class LocalPartValidationResult(TypedDict):
local_part: str
ascii_local_part: Optional[str]  # may be None; presumably when the local part has no ASCII form -- confirm
smtputf8: bool  # presumably whether the address requires the SMTPUTF8 extension -- confirm


def validate_email_local_part(local: str, allow_smtputf8: bool = True, allow_empty_local: bool = False,
quoted_local_part: bool = False):
quoted_local_part: bool = False) -> LocalPartValidationResult:
"""Validates the syntax of the local part of an email address."""

if len(local) == 0:
Expand Down Expand Up @@ -345,7 +351,7 @@ def validate_email_local_part(local: str, allow_smtputf8: bool = True, allow_emp
raise EmailSyntaxError("The email address contains invalid characters before the @-sign.")


def check_unsafe_chars(s, allow_space=False):
def check_unsafe_chars(s: str, allow_space: bool = False) -> None:
# Check for unsafe characters or characters that would make the string
# invalid or non-sensible Unicode.
bad_chars = set()
Expand Down Expand Up @@ -397,7 +403,7 @@ def check_unsafe_chars(s, allow_space=False):
+ ", ".join(safe_character_display(c) for c in sorted(bad_chars)) + ".")


def check_dot_atom(label, start_descr, end_descr, is_hostname):
def check_dot_atom(label: str, start_descr: str, end_descr: str, is_hostname: bool) -> None:
# RFC 5322 3.2.3
if label.endswith("."):
raise EmailSyntaxError(end_descr.format("period"))
Expand All @@ -416,7 +422,12 @@ def check_dot_atom(label, start_descr, end_descr, is_hostname):
raise EmailSyntaxError("An email address cannot have a period and a hyphen next to each other.")


def validate_email_domain_name(domain, test_environment=False, globally_deliverable=True):
# Dict shape used as the return annotation of validate_email_domain_name().
# Both keys are required (a class-based TypedDict is total by default).
class DomainNameValidationResult(TypedDict):
ascii_domain: str  # presumably the IDNA ASCII form of the domain -- confirm
domain: str  # presumably the normalized (possibly Unicode) form of the domain -- confirm


def validate_email_domain_name(domain: str, test_environment: bool = False, globally_deliverable: bool = True) -> DomainNameValidationResult:
"""Validates the syntax of the domain part of an email address."""

# Check for invalid characters before normalization.
Expand Down Expand Up @@ -580,7 +591,7 @@ def validate_email_domain_name(domain, test_environment=False, globally_delivera
}


def validate_email_length(addrinfo):
def validate_email_length(addrinfo: ValidatedEmail) -> None:
# If the email address has an ASCII representation, then we assume it may be
# transmitted in ASCII (we can't assume SMTPUTF8 will be used on all hops to
# the destination) and the length limit applies to ASCII characters (which is
Expand Down Expand Up @@ -621,11 +632,18 @@ def validate_email_length(addrinfo):
raise EmailSyntaxError(f"The email address is too long {reason}.")


def validate_email_domain_literal(domain_literal):
# Dict shape used as the return annotation of validate_email_domain_literal().
# Both keys are required (a class-based TypedDict is total by default).
class DomainLiteralValidationResult(TypedDict):
domain_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]  # an address object from the ipaddress module
domain: str  # presumably the normalized text of the domain literal -- confirm


def validate_email_domain_literal(domain_literal: str) -> DomainLiteralValidationResult:
# This is obscure domain-literal syntax. Parse it and return
# a compressed/normalized address.
# RFC 5321 4.1.3 and RFC 5322 3.4.1.

addr: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]

# Try to parse the domain literal as an IPv4 address.
# There is no tag for IPv4 addresses, so we can never
# be sure if the user intends an IPv4 address.
Expand Down
Loading

0 comments on commit a9a8a62

Please sign in to comment.