Skip to content

Commit

Permalink
Added PK[T] annotations and update() without WHERE clause.
Browse files Browse the repository at this point in the history
  • Loading branch information
whdev1 committed Feb 3, 2022
1 parent 0976eff commit db0c4e1
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 9 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
setup(
name = 'targa',
packages = ['targa', 'targa.errors'],
version = '1.0.0',
version = '1.0.2',
license='MIT',
description = 'A lightweight async Python library for MySQL queries and modeling.',
author = 'whdev1',
author_email = 'whdev1@protonmail.com',
url = 'https://github.com/whdev1/targa',
download_url = 'https://github.com/whdev1/targa/archive/refs/tags/v1.0.0.tar.gz',
download_url = 'https://github.com/whdev1/targa/archive/refs/tags/v1.0.2.tar.gz',
keywords = ['Targa', 'SQL', 'MySQL', 'async'],
install_requires=['aiomysql'],
classifiers=[
Expand Down
1 change: 1 addition & 0 deletions targa/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .database import Database
from .keys import PK
from .model import Model
26 changes: 23 additions & 3 deletions targa/database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .keys import _PK
from .model import Model
import aiomysql
from .errors import InitializationError, MySQLErrors, SubstError
Expand Down Expand Up @@ -198,7 +199,7 @@ async def query(self, query: str, *substitutions, _ensure_conn: bool = True) ->
) for row in rows
]

async def update(self, model_inst: Model, where_clause: str) -> None:
async def update(self, model_inst: Model, where_clause: str = None) -> None:
"""
Updates the specified model in the remote database using an UPDATE statement
which includes the specified WHERE clause.
Expand All @@ -207,13 +208,32 @@ async def update(self, model_inst: Model, where_clause: str) -> None:
model_inst: Model
The model instance that should be updated in the remote database.
where_clause: str
The WHERE clause to include in the SQL statement.
where_clause: str = None
The WHERE clause to include in the SQL statement. Optional if a primary key
was annotated in this model.
Returns:
Nothing
"""

# if no WHERE clause was provided, generate one based on a provided primary key annotation
if not where_clause:
# derive a list of annotation types
annotation_types: List[type] = [type(x[1]) for x in model_inst.__annotations__.items()]

# get the index of the first primary key annotation
pk_field_index: int = annotation_types.index(_PK)

# check that a primary key annotation was actually found
pk_field: str
if pk_field_index >= 0:
pk_field = list(model_inst.__annotations__.keys())[pk_field_index]
else:
raise KeyError('A WHERE clause is required if a primary key was not annotated')

# if it was, generate the WHERE clause
where_clause = f"WHERE {pk_field} = '{self._connection.escape_string(model_inst.__dict__[pk_field])}'"

# ensure that a connection is established
await self._ensure_connection()

Expand Down
32 changes: 32 additions & 0 deletions targa/keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
class _PK:
"""
This class stands in for the _PK[T] annotation and contains
the type specified by the user.
"""

_type: type = None

def __init__(self, _type: type) -> None:
"""
Constructs a new _PK instance to be used in an annotation.
Not intended to be called directly; use a PK[T] annotation instead.
Parameters:
_type: type
The type that this key expects.
Returns:
Nothing.
"""

self._type = _type

class _PK_Factory:
"""
Factory for _PK instances. Provides the PK[T] annotation syntax.
"""

def __getitem__(self, _type: type) -> _PK:
return _PK(_type)

PK = _PK_Factory()
15 changes: 11 additions & 4 deletions targa/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Iterable, Optional
from .keys import _PK

class Model:
def __init__(self, **kwargs) -> None:
Expand All @@ -22,18 +23,24 @@ def __init__(self, **kwargs) -> None:

# loop over all of the fields defined in the derived class
for field in self.__annotations__.keys():
# get the anticipated type of the field based on its annotation
expected_type = self.__annotations__[field]

# ensure that the field was provided in the constructor and that it doesn't have
# a default value already provided
if field not in kwargs.keys() and field not in self.__dict__.keys():
raise AttributeError(
f"No value provided for field '{field}' of model {self.__class__.__name__}."
)

# type check the field based on its annotation
expected_type = self.__annotations__[field]

# check for a PK[T] annotation and unwrap one if necessary
if isinstance(expected_type, _PK):
# extract the type from the PK annotation
expected_type = expected_type._type

# check if the provided object is of the expected type
if kwargs[field].__class__ != expected_type:
# check for an Optional[x] or Union[x, None] annotation
# check for an Optional[T] or Union[T, None] annotation
if hasattr(expected_type, '__args__') and expected_type.__args__[-1] == type(None):
# if this an Optional typing, check if the provided object is None. if not,
# it is invalid
Expand Down

0 comments on commit db0c4e1

Please sign in to comment.