
Db query #9

Open
wants to merge 10 commits into master
9 changes: 7 additions & 2 deletions labml_db/driver/__init__.py
@@ -1,6 +1,6 @@
from typing import List, Type, TYPE_CHECKING, Optional
from typing import List, Type, TYPE_CHECKING, Optional, Tuple

from ..types import ModelDict
from ..types import ModelDict, QueryDict, SortDict

if TYPE_CHECKING:
from .. import Serializer, Model
@@ -28,3 +28,8 @@ def msave_dict(self, key: List[str], data: List[ModelDict]):

def get_all(self) -> List[str]:
raise NotImplementedError

def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict],
randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[
List[Tuple[str, ModelDict]], int]:
raise NotImplementedError
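
The base class only declares the contract here: drivers that do not override search() keep raising NotImplementedError, and implementations are expected to return the matching (key, ModelDict) pairs together with a total match count. As a rough, hypothetical sketch (not part of this PR), a backend without native query support could satisfy the filter semantics by post-filtering loaded dicts in Python:

```python
# Hypothetical helper, for illustration only: applies QueryDict filters
# (value, equal-flag pairs) to already-loaded dicts and returns the
# (key, ModelDict) pairs plus the total count, mirroring the search() contract.
from typing import List, Optional, Tuple

from labml_db.types import ModelDict, QueryDict


def filter_in_python(rows: List[Tuple[str, ModelDict]],
                     filters: Optional[QueryDict],
                     limit: Optional[int] = None) -> Tuple[List[Tuple[str, ModelDict]], int]:
    matched = []
    for key, data in rows:
        keep = True
        for prop, (value, equal) in (filters or {}).items():
            # equal=True keeps rows where the property equals the value,
            # equal=False keeps rows where it differs.
            if (data.get(prop) == value) != equal:
                keep = False
                break
        if keep:
            matched.append((key, data))
    total = len(matched)
    if limit is not None:
        matched = matched[:limit]
    return matched, total
```
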
60 changes: 56 additions & 4 deletions labml_db/driver/mongo.py
@@ -1,12 +1,13 @@
from typing import List, Type, TYPE_CHECKING, Optional, Dict
from collections import OrderedDict
from typing import List, Type, TYPE_CHECKING, Optional, Dict, Tuple, OrderedDict

from labml_db.serializer.utils import encode_keys, decode_keys
import pymongo

from ..serializer.utils import encode_keys, decode_keys
from . import DbDriver
from ..types import ModelDict
from ..types import ModelDict, QueryDict, SortDict

if TYPE_CHECKING:
import pymongo
from ..model import Model


@@ -70,3 +71,54 @@ def get_all(self) -> List[str]:
cur = self._collection.find(projection=['_id'])
keys = [self._to_key(d['_id']) for d in cur]
return keys

def search(self, text_query: Optional[str], filters: Optional[QueryDict], sort: Optional[SortDict],
randomize: bool = False, limit: Optional[int] = None, sort_by_text_score: bool = False) -> Tuple[
List[Tuple[str, ModelDict]], int]:
pipeline = []

match = dict()
if filters:
for property_name, item in filters.items():
value, equal = item
if equal:
match[property_name] = value
else:
match[property_name] = {'$ne': value}
if text_query:
match['$text'] = {'$search': text_query}
if len(match) > 0:
pipeline.append({'$match': match})

if randomize:
pipeline.append({'$facet': {'data': [{'$sample': {'size': limit}}], 'count': [{'$count': 'count'}]}})
else:
sort_query = OrderedDict()
if sort_by_text_score:
sort_query['score'] = {'$meta': 'textScore'}
if sort is not None and len(sort) > 0:
for k, v in sort:
sort_query[k] = pymongo.ASCENDING if v else pymongo.DESCENDING

if len(sort_query) > 0:
pipeline.append({'$sort': sort_query})

if limit:
pipeline.append({'$facet': {'data': [{'$limit': limit}], 'count': [{'$count': 'count'}]}})

cursor = self._collection.aggregate(pipeline)
res = []
count = 0
if limit:
for item in cursor:
for c in item['count']:
count += c['count']
for d in item['data']:
res.append((d['_id'], self._load_data(d)))
else:
for d in cursor:
res.append((d['_id'], self._load_data(d)))

count = len(res)

return res, count
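
The Mongo driver translates a search into a single aggregation call: an optional $match stage carrying the property filters and the $text query, then either a $sample stage (when randomizing) or a $sort stage, and finally a $facet stage that returns one page of documents alongside the total match count. Roughly, a call like the one sketched below produces this pipeline (the field names are made up for illustration, and $text requires a text index on the collection):

```python
# Illustrative only: the pipeline built for a hypothetical call like
#   driver.search(text_query='transformer',
#                 filters={'status': ('completed', True)},
#                 sort=[('created_at', False)],
#                 limit=10)
import pymongo

pipeline = [
    # property filters plus full-text search; '$text' needs a text index
    {'$match': {'status': 'completed',
                '$text': {'$search': 'transformer'}}},
    # the SortDict entry ('created_at', False) maps to descending order
    {'$sort': {'created_at': pymongo.DESCENDING}},
    # one facet returns the limited page, the other the total match count
    {'$facet': {'data': [{'$limit': 10}],
                'count': [{'$count': 'count'}]}},
]
```
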
32 changes: 28 additions & 4 deletions labml_db/model.py
@@ -1,9 +1,9 @@
import copy
import warnings
from typing import Generic, Union, Any
from typing import Generic, Union, Any, Tuple
from typing import TypeVar, List, Dict, Type, Set, Optional, _GenericAlias, TYPE_CHECKING

from .types import Primitive, ModelDict
from .types import Primitive, ModelDict, QueryDict, SortDict

if TYPE_CHECKING:
from .driver import DbDriver
@@ -192,6 +192,10 @@ def __init__(self, key: Optional[str] = None, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)

for k, v in self._defaults.items():
if k not in kwargs:
setattr(self, k, v)

def __init_subclass__(cls, **kwargs):
if cls.__name__ in Model.__models:
warnings.warn(f"{cls.__name__} already used")
@@ -308,8 +312,8 @@ def from_dict_transform(cls, data: ModelDict) -> Dict[str, Any]:
def to_dict(self) -> ModelDict:
values = {}
for k, v in self._values.items():
if k not in self._defaults or self._defaults[k] != v:
values[k] = v
# TODO: exclude defaults from the saved data based on a flag
values[k] = v
values = self.to_dict_transform(values)
return values

@@ -340,3 +344,23 @@ def __repr__(self):
kv = [f'{k}={repr(v)}' for k, v in self._values.items()]
kv = ', '.join(kv)
return f'{self.__class__.__name__}({kv})'

@classmethod
def search(cls, text_query: Optional[str] = None, filters: Optional[QueryDict] = None,
sort: Optional[SortDict] = None, randomize: bool = False, limit: Optional[int] = None,
sort_by_text_score: bool = False) -> Tuple[List[_KT], int]:
if sort is not None and len(sort) > 0 and randomize:
raise ValueError('Cannot have both randomize and sort criteria')
if limit is not None and limit <= 0:
raise ValueError('Limit should be higher than 0')
if randomize and not limit:
raise ValueError('A limit should be provided when results are randomized')
if sort_by_text_score and not text_query:
raise ValueError("Cannot search by text score when there's no text query")
if randomize and sort_by_text_score:
raise ValueError('Cannot have both randomize and sort by text score')

db_driver = Model.__db_drivers[cls.__name__]
data, total_count = db_driver.search(text_query=text_query, filters=filters, sort=sort, randomize=randomize,
limit=limit, sort_by_text_score=sort_by_text_score)
return [Model._to_model(k, d) for k, d in data], total_count
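
Model.search() validates the argument combinations (randomize excludes both sort criteria and text-score sorting, a positive limit is required and mandatory when randomizing, and sort_by_text_score needs a text_query), then delegates to the driver registered for the model class and converts the returned dicts back into model instances. A hypothetical usage sketch, assuming a Run model and a driver that implements search(), such as the Mongo driver above:

```python
# Hypothetical example; the Run model and its fields are made up, and the
# class is assumed to be configured with a driver that implements search().
from labml_db import Model


class Run(Model['Run']):
    name: str
    status: str
    created_at: float


# Newest completed runs matching the text query, at most 20 results.
runs, total = Run.search(text_query='mnist',
                         filters={'status': ('completed', True)},
                         sort=[('created_at', False)],
                         limit=20)

# A random sample of non-failed runs; a limit is mandatory when randomizing.
sample, _ = Run.search(filters={'status': ('failed', False)},
                       randomize=True,
                       limit=5)
```
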
1 change: 1 addition & 0 deletions labml_db/serializer/json.py
@@ -10,6 +10,7 @@ class JsonSerializer(Serializer):
file_extension = 'json'

def to_string(self, data: ModelDict) -> str:
assert data
return json.dumps(encode_keys(data))

def from_string(self, data: Optional[str]) -> Optional[ModelDict]:
4 changes: 2 additions & 2 deletions labml_db/serializer/utils.py
@@ -1,7 +1,7 @@
from typing import Dict

from labml_db import Key
from labml_db.types import Primitive
from .. import Key
from ..types import Primitive


def encode_key(key: Key) -> Dict[str, str]:
1 change: 1 addition & 0 deletions labml_db/serializer/yaml.py
@@ -10,6 +10,7 @@ class YamlSerializer(Serializer):

def to_string(self, data: ModelDict) -> str:
import yaml
assert data
return yaml.dump(encode_keys(data), default_flow_style=False)

def from_string(self, data: Optional[str]) -> Optional[ModelDict]:
6 changes: 4 additions & 2 deletions labml_db/types.py
@@ -1,5 +1,7 @@
from typing import List, Dict, Union
from typing import List, Dict, Union, Tuple

Primitive = Union[Dict[str, 'Primitive'], List['Primitive'], int, str, float, bool, None]
ModelDict = Dict[str, Primitive]

# {Property: (value, equal/not_equal)}
QueryDict = Dict[str, Tuple[Union[List['Primitive'], int, str, float, bool], bool]]
SortDict = List[Tuple[str, bool]]
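
The boolean in each QueryDict entry selects equality versus inequality, and the boolean in each SortDict entry selects ascending versus descending order, matching how the Mongo driver interprets them. A small illustrative example (property names are hypothetical):

```python
from labml_db.types import QueryDict, SortDict

# keep documents where status == 'completed' and owner != 'guest'
filters: QueryDict = {'status': ('completed', True),
                      'owner': ('guest', False)}

# order by created_at descending (False), then by name ascending (True)
sort: SortDict = [('created_at', False), ('name', True)]
```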