Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add commit_many feature #385

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/frameworks/test_motor_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,24 @@ async def do_test():

loop.run_until_complete(do_test())

def test_update_many(self, loop, classroom_model):
Teacher = classroom_model.Teacher

async def do_test():
john = Teacher(name='John Buck', has_apple=False)
await john.commit()
jane = Teacher(name='Jane Buck', has_apple=False)
await jane.commit()
query = {"name": {"$regex": ".*Buck"}}
result = await Teacher(has_apple=True).commit_many(query)
assert result.modified_count == 2
await john.reload()
assert john.has_apple
await jane.reload()
assert jane.has_apple

loop.run_until_complete(do_test())

def test_replace(self, loop, classroom_model):
Student = classroom_model.Student

Expand Down
14 changes: 14 additions & 0 deletions tests/frameworks/test_pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,20 @@ def test_update(self, classroom_model):
with pytest.raises(exceptions.NotCreatedError):
Student(name='Joe').commit(conditions={'name': 'dummy'})

def test_update_many(self, classroom_model):
Teacher = classroom_model.Teacher
john = Teacher(name='John Buck', has_apple=False)
john.commit()
jane = Teacher(name='Jane Buck', has_apple=False)
jane.commit()
query = {"name": {"$regex": ".*Buck$"}}
result = Teacher(has_apple=True).commit_many(query)
assert result.modified_count == 2
john.reload()
assert john.has_apple
jane.reload()
assert jane.has_apple

def test_replace(self, classroom_model):
Student = classroom_model.Student
john = Student(name='John Doe', birthday=dt.datetime(1995, 12, 12))
Expand Down
15 changes: 15 additions & 0 deletions tests/frameworks/test_txmongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ def test_update(self, classroom_model):
with pytest.raises(exceptions.NotCreatedError):
yield Student(name='Joe').commit(conditions={'name': 'dummy'})

@pytest_inlineCallbacks
def test_update_many(self, classroom_model):
Teacher = classroom_model.Teacher
john = Teacher(name='John Buck', has_apple=False)
yield john.commit()
jane = Teacher(name='Jane Buck', has_apple=False)
yield jane.commit()
query = {"name": {"$regex": ".*Buck$"}}
result = yield Teacher(has_apple=True).commit_many(query)
assert result.modified_count == 2
yield john.reload()
assert john.has_apple
yield jane.reload()
assert jane.has_apple

@pytest_inlineCallbacks
def test_replace(self, classroom_model):
Student = classroom_model.Student
Expand Down
63 changes: 63 additions & 0 deletions tests/test_data_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,69 @@ class MySchema(BaseSchema):
d.get('e')['b'] = 4
assert_equal_order(d.to_mongo()['e'], {'a': 1, 'b': 4, 'c': 3})

def test_update_many_string(self):
class MySchema(BaseSchema):
field_a = fields.StringField()

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': "new_value"})
payload = d.to_mongo_update_many()
assert payload == {'$set': {'field_a': 'new_value'}}

def test_update_many_dict(self):
class MySchema(BaseSchema):
field_a = fields.DictField()

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': {'sub_field_a': 'new_value'}})
payload = d.to_mongo_update_many()
assert payload == {'$set': {'field_a.sub_field_a': 'new_value'}}

def test_update_many_deep_dict(self):
class MySchema(BaseSchema):
field_a = fields.DictField()

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': {'sub_field_a': {'sub_sub_field_a': 'new_value'}}})
payload = d.to_mongo_update_many()
assert payload == {'$set': {'field_a.sub_field_a.sub_sub_field_a': 'new_value'}}

def test_update_many_list(self):
class MySchema(BaseSchema):
field_a = fields.ListField(fields.StringField())

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': ['new_value']})
payload = d.to_mongo_update_many()
assert payload == {'$push': {'field_a': 'new_value'}}

def test_update_many_list_replace(self):
class MySchema(BaseSchema):
field_a = fields.ListField(fields.StringField())

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': ['new_value']})
payload = d.to_mongo_update_many(replace_arrays=True)
assert payload == {'$set': {'field_a': ['new_value']}}

def test_update_many_dict_list(self):
class MySchema(BaseSchema):
field_a = fields.DictField()

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': {'sub_field_a': ['new_value']}})
payload = d.to_mongo_update_many()
assert payload == {'$push': {'field_a.sub_field_a': 'new_value'}}

def test_update_many_dict_list_replace(self):
class MySchema(BaseSchema):
field_a = fields.DictField()

DataProxy = data_proxy_factory('My', MySchema())
d = DataProxy({'field_a': {'sub_field_a': ['new_value']}})
payload = d.to_mongo_update_many(replace_arrays=True)
assert payload == {'$set': {'field_a.sub_field_a': ['new_value']}}


class TestNonStrictDataProxy(BaseTest):

Expand Down
65 changes: 65 additions & 0 deletions umongo/data_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,18 @@ def _to_mongo_update(self):
mongo_data['$unset'] = {k: "" for k in unset_data}
return mongo_data or None

def to_mongo_update_many(self, replace_arrays=False):
mongo_fields = self.MongoFieldChanges(replace_arrays=replace_arrays)
for name in self.get_modified_fields():
field = self._fields[name]
name = field.attribute or name
val = field.serialize_to_mongo(self._data[name])
if val is ma.missing:
mongo_fields.unset_field_value(name)
else:
mongo_fields.set_field_value(name, val)
return mongo_fields.mongo_data

def from_mongo(self, data):
self._data = {}
for key, val in data.items():
Expand Down Expand Up @@ -191,6 +203,59 @@ def keys(self):
def values(self):
return self._data.values()

# helper class to compute update_many field names and operators
class MongoFieldChanges:
def __init__(self, replace_arrays=False):
self._replace_arrays = replace_arrays
self._set_data = {}
self._push_data = {}
self._unset_data = []

@property
def mongo_data(self):
_mongo_data = {}
if self._set_data:
_mongo_data['$set'] = self._set_data
if self._push_data:
_mongo_data['$push'] = self._push_data
if self._unset_data:
_mongo_data['$unset'] = {k: "" for k in self._unset_data}
return _mongo_data or None

def set_field_value(self, field_name, value):
field_type = type(value)
if field_type is list:
self._set_list_data(field_name, value)
elif field_type is dict:
self._set_dict_data(field_name, value)
else:
self._set_data[field_name] = value

def unset_field_value(self, field):
self._unset_data.append(field)

def _set_dict_data(self, name: str, val: dict):
for key, value in val.items():
set_name = f'{name}.{key}'
if type(value) is dict:
self._set_dict_data(set_name, value)
elif type(value) is list:
self._set_list_data(set_name, value)
else:
self._set_data[set_name] = value

def _set_list_data(self, name: str, val: list):
if self._replace_arrays:
# replacing the list value - use the $set operator
self._set_data[name] = val
else:
# adding items in val to list - if the length of the list is 1, then just $push the value
if len(val) == 1:
self._push_data[name] = val[0]
else:
# use the $each operator with the value so that each value is added to the field's array
self._push_data[name] = {'$each': val}


class BaseNonStrictDataProxy(BaseDataProxy):
"""
Expand Down
24 changes: 24 additions & 0 deletions umongo/frameworks/motor_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,30 @@ async def commit(self, io_validate_all=False, conditions=None, replace=False):
self._data.clear_modified()
return ret

async def commit_many(self, conditions=None, replace_arrays=False):
"""
Commit changes to multiple documents per conditions.
:param conditions: Only perform commit if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`umongo.exceptions.UpdateError` if the
conditions are not satisfied.
:param replace_arrays: False (default) to add array elements to document using a $push operator
or True to replace fields containing arrays with the supplied value(s) using a $set operator
:return: A :class:`pymongo.results.UpdateResult`
"""
query = conditions or {}
# pre_update can provide additional query filter and/or
# modify the fields' values
additional_filter = await self.__coroutined_pre_update()
if additional_filter:
query.update(map_query(additional_filter, self.schema.fields))
await self.io_validate(validate_all=False)
payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays)
ret = await self.collection.update_many(query, payload, session=SESSION.get())
await self.__coroutined_post_update(ret)
self._data.clear_modified()
return ret

async def delete(self, conditions=None):
"""
Alias of :meth:`remove` to enforce default api.
Expand Down
24 changes: 24 additions & 0 deletions umongo/frameworks/pymongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,30 @@ def commit(self, io_validate_all=False, conditions=None, replace=False):
self._data.clear_modified()
return ret

def commit_many(self, conditions=None, replace_arrays=False):
"""
Commit changes to multiple documents per conditions.
:param conditions: Only perform commit if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`umongo.exceptions.UpdateError` if the
conditions are not satisfied.
:param replace_arrays: False (default) to add array elements to document or True to replace fields
containing arrays with the supplied value(s)
:return: A :class:`pymongo.results.UpdateResult`
"""
query = conditions or {}
# pre_update can provide additional query filter and/or
# modify the fields' values
additional_filter = self.pre_update()
if additional_filter:
query.update(map_query(additional_filter, self.schema.fields))
self.io_validate(validate_all=False)
payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays)
ret = self.collection.update_many(query, payload, session=SESSION.get())
self.post_update(ret)
self._data.clear_modified()
return ret

def delete(self, conditions=None):
"""
Remove the document from database.
Expand Down
25 changes: 25 additions & 0 deletions umongo/frameworks/txmongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,31 @@ def commit(self, io_validate_all=False, conditions=None, replace=False):
self._data.clear_modified()
return ret

@inlineCallbacks
def commit_many(self, conditions=None, replace_arrays=False):
"""
Commit changes to multiple documents per conditions.
:param conditions: Only perform commit if matching record in db
satisfies condition(s) (e.g. version number).
Raises :class:`umongo.exceptions.UpdateError` if the
conditions are not satisfied.
:param replace_arrays: False (default) to add array elements to document or True to replace fields
containing arrays with the supplied value(s)
:return: A :class:`pymongo.results.UpdateResult`
"""
query = conditions or {}
# pre_update can provide additional query filter and/or
# modify the fields' values
additional_filter = yield maybeDeferred(self.pre_update)
if additional_filter:
query.update(map_query(additional_filter, self.schema.fields))
yield self.io_validate(validate_all=False)
payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays)
ret = yield self.collection.update_many(query, payload)
yield maybeDeferred(self.post_update, ret)
self._data.clear_modified()
return ret

@inlineCallbacks
def delete(self, conditions=None):
"""
Expand Down