diff --git a/tests/frameworks/test_motor_asyncio.py b/tests/frameworks/test_motor_asyncio.py index 2e2e049..f709650 100644 --- a/tests/frameworks/test_motor_asyncio.py +++ b/tests/frameworks/test_motor_asyncio.py @@ -99,6 +99,24 @@ async def do_test(): loop.run_until_complete(do_test()) + def test_update_many(self, loop, classroom_model): + Teacher = classroom_model.Teacher + + async def do_test(): + john = Teacher(name='John Buck', has_apple=False) + await john.commit() + jane = Teacher(name='Jane Buck', has_apple=False) + await jane.commit() + query = {"name": {"$regex": ".*Buck"}} + result = await Teacher(has_apple=True).commit_many(query) + assert result.modified_count == 2 + await john.reload() + assert john.has_apple + await jane.reload() + assert jane.has_apple + + loop.run_until_complete(do_test()) + def test_replace(self, loop, classroom_model): Student = classroom_model.Student diff --git a/tests/frameworks/test_pymongo.py b/tests/frameworks/test_pymongo.py index bed4a92..1ab24ba 100644 --- a/tests/frameworks/test_pymongo.py +++ b/tests/frameworks/test_pymongo.py @@ -74,6 +74,20 @@ def test_update(self, classroom_model): with pytest.raises(exceptions.NotCreatedError): Student(name='Joe').commit(conditions={'name': 'dummy'}) + def test_update_many(self, classroom_model): + Teacher = classroom_model.Teacher + john = Teacher(name='John Buck', has_apple=False) + john.commit() + jane = Teacher(name='Jane Buck', has_apple=False) + jane.commit() + query = {"name": {"$regex": ".*Buck$"}} + result = Teacher(has_apple=True).commit_many(query) + assert result.modified_count == 2 + john.reload() + assert john.has_apple + jane.reload() + assert jane.has_apple + def test_replace(self, classroom_model): Student = classroom_model.Student john = Student(name='John Doe', birthday=dt.datetime(1995, 12, 12)) diff --git a/tests/frameworks/test_txmongo.py b/tests/frameworks/test_txmongo.py index 98c4b04..9aa0434 100644 --- a/tests/frameworks/test_txmongo.py +++ b/tests/frameworks/test_txmongo.py @@ -101,6 +101,21 @@ def test_update(self, classroom_model): with pytest.raises(exceptions.NotCreatedError): yield Student(name='Joe').commit(conditions={'name': 'dummy'}) + @pytest_inlineCallbacks + def test_update_many(self, classroom_model): + Teacher = classroom_model.Teacher + john = Teacher(name='John Buck', has_apple=False) + yield john.commit() + jane = Teacher(name='Jane Buck', has_apple=False) + yield jane.commit() + query = {"name": {"$regex": ".*Buck$"}} + result = yield Teacher(has_apple=True).commit_many(query) + assert result.modified_count == 2 + yield john.reload() + assert john.has_apple + yield jane.reload() + assert jane.has_apple + @pytest_inlineCallbacks def test_replace(self, classroom_model): Student = classroom_model.Student diff --git a/tests/test_data_proxy.py b/tests/test_data_proxy.py index 95d2abb..139c6ca 100644 --- a/tests/test_data_proxy.py +++ b/tests/test_data_proxy.py @@ -435,6 +435,69 @@ class MySchema(BaseSchema): d.get('e')['b'] = 4 assert_equal_order(d.to_mongo()['e'], {'a': 1, 'b': 4, 'c': 3}) + def test_update_many_string(self): + class MySchema(BaseSchema): + field_a = fields.StringField() + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': "new_value"}) + payload = d.to_mongo_update_many() + assert payload == {'$set': {'field_a': 'new_value'}} + + def test_update_many_dict(self): + class MySchema(BaseSchema): + field_a = fields.DictField() + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': {'sub_field_a': 'new_value'}}) + payload = d.to_mongo_update_many() + assert payload == {'$set': {'field_a.sub_field_a': 'new_value'}} + + def test_update_many_deep_dict(self): + class MySchema(BaseSchema): + field_a = fields.DictField() + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': {'sub_field_a': {'sub_sub_field_a': 'new_value'}}}) + payload = d.to_mongo_update_many() + assert payload == {'$set': {'field_a.sub_field_a.sub_sub_field_a': 'new_value'}} + + def test_update_many_list(self): + class MySchema(BaseSchema): + field_a = fields.ListField(fields.StringField()) + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': ['new_value']}) + payload = d.to_mongo_update_many() + assert payload == {'$push': {'field_a': 'new_value'}} + + def test_update_many_list_replace(self): + class MySchema(BaseSchema): + field_a = fields.ListField(fields.StringField()) + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': ['new_value']}) + payload = d.to_mongo_update_many(replace_arrays=True) + assert payload == {'$set': {'field_a': ['new_value']}} + + def test_update_many_dict_list(self): + class MySchema(BaseSchema): + field_a = fields.DictField() + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': {'sub_field_a': ['new_value']}}) + payload = d.to_mongo_update_many() + assert payload == {'$push': {'field_a.sub_field_a': 'new_value'}} + + def test_update_many_dict_list_replace(self): + class MySchema(BaseSchema): + field_a = fields.DictField() + + DataProxy = data_proxy_factory('My', MySchema()) + d = DataProxy({'field_a': {'sub_field_a': ['new_value']}}) + payload = d.to_mongo_update_many(replace_arrays=True) + assert payload == {'$set': {'field_a.sub_field_a': ['new_value']}} + class TestNonStrictDataProxy(BaseTest): diff --git a/umongo/data_proxy.py b/umongo/data_proxy.py index b7351e3..3b8028f 100644 --- a/umongo/data_proxy.py +++ b/umongo/data_proxy.py @@ -54,6 +54,18 @@ def _to_mongo_update(self): mongo_data['$unset'] = {k: "" for k in unset_data} return mongo_data or None + def to_mongo_update_many(self, replace_arrays=False): + mongo_fields = self.MongoFieldChanges(replace_arrays=replace_arrays) + for name in self.get_modified_fields(): + field = self._fields[name] + name = field.attribute or name + val = field.serialize_to_mongo(self._data[name]) + if val is ma.missing: + mongo_fields.unset_field_value(name) + else: + mongo_fields.set_field_value(name, val) + return mongo_fields.mongo_data + def from_mongo(self, data): self._data = {} for key, val in data.items(): @@ -191,6 +203,59 @@ def keys(self): def values(self): return self._data.values() + # helper class to compute update_many field names and operators + class MongoFieldChanges: + def __init__(self, replace_arrays=False): + self._replace_arrays = replace_arrays + self._set_data = {} + self._push_data = {} + self._unset_data = [] + + @property + def mongo_data(self): + _mongo_data = {} + if self._set_data: + _mongo_data['$set'] = self._set_data + if self._push_data: + _mongo_data['$push'] = self._push_data + if self._unset_data: + _mongo_data['$unset'] = {k: "" for k in self._unset_data} + return _mongo_data or None + + def set_field_value(self, field_name, value): + field_type = type(value) + if field_type is list: + self._set_list_data(field_name, value) + elif field_type is dict: + self._set_dict_data(field_name, value) + else: + self._set_data[field_name] = value + + def unset_field_value(self, field): + self._unset_data.append(field) + + def _set_dict_data(self, name: str, val: dict): + for key, value in val.items(): + set_name = f'{name}.{key}' + if type(value) is dict: + self._set_dict_data(set_name, value) + elif type(value) is list: + self._set_list_data(set_name, value) + else: + self._set_data[set_name] = value + + def _set_list_data(self, name: str, val: list): + if self._replace_arrays: + # replacing the list value - use the $set operator + self._set_data[name] = val + else: + # adding items in val to list - if the length of the list is 1, then just $push the value + if len(val) == 1: + self._push_data[name] = val[0] + else: + # use the $each operator with the value so that each value is added to the field's array + self._push_data[name] = {'$each': val} + class BaseNonStrictDataProxy(BaseDataProxy): """ diff --git a/umongo/frameworks/motor_asyncio.py b/umongo/frameworks/motor_asyncio.py index c944c14..905553c 100644 --- a/umongo/frameworks/motor_asyncio.py +++ b/umongo/frameworks/motor_asyncio.py @@ -205,6 +205,30 @@ async def commit(self, io_validate_all=False, conditions=None, replace=False): self._data.clear_modified() return ret + async def commit_many(self, conditions=None, replace_arrays=False): + """ + Commit changes to multiple documents per conditions. + :param conditions: Only perform commit if matching record in db + satisfies condition(s) (e.g. version number). + Raises :class:`umongo.exceptions.UpdateError` if the + conditions are not satisfied. + :param replace_arrays: False (default) to add array elements to document using a $push operator + or True to replace fields containing arrays with the supplied value(s) using a $set operator + :return: A :class:`pymongo.results.UpdateResult` + """ + query = conditions or {} + # pre_update can provide additional query filter and/or + # modify the fields' values + additional_filter = await self.__coroutined_pre_update() + if additional_filter: + query.update(map_query(additional_filter, self.schema.fields)) + await self.io_validate(validate_all=False) + payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays) + ret = await self.collection.update_many(query, payload, session=SESSION.get()) + await self.__coroutined_post_update(ret) + self._data.clear_modified() + return ret + async def delete(self, conditions=None): """ Alias of :meth:`remove` to enforce default api. diff --git a/umongo/frameworks/pymongo.py b/umongo/frameworks/pymongo.py index 6feb53d..eb43d89 100644 --- a/umongo/frameworks/pymongo.py +++ b/umongo/frameworks/pymongo.py @@ -153,6 +153,30 @@ def commit(self, io_validate_all=False, conditions=None, replace=False): self._data.clear_modified() return ret + def commit_many(self, conditions=None, replace_arrays=False): + """ + Commit changes to multiple documents per conditions. + :param conditions: Only perform commit if matching record in db + satisfies condition(s) (e.g. version number). + Raises :class:`umongo.exceptions.UpdateError` if the + conditions are not satisfied. + :param replace_arrays: False (default) to add array elements to document or True to replace fields + containing arrays with the supplied value(s) + :return: A :class:`pymongo.results.UpdateResult` + """ + query = conditions or {} + # pre_update can provide additional query filter and/or + # modify the fields' values + additional_filter = self.pre_update() + if additional_filter: + query.update(map_query(additional_filter, self.schema.fields)) + self.io_validate(validate_all=False) + payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays) + ret = self.collection.update_many(query, payload, session=SESSION.get()) + self.post_update(ret) + self._data.clear_modified() + return ret + def delete(self, conditions=None): """ Remove the document from database. diff --git a/umongo/frameworks/txmongo.py b/umongo/frameworks/txmongo.py index 5b44010..92fde90 100644 --- a/umongo/frameworks/txmongo.py +++ b/umongo/frameworks/txmongo.py @@ -109,6 +109,31 @@ def commit(self, io_validate_all=False, conditions=None, replace=False): self._data.clear_modified() return ret + @inlineCallbacks + def commit_many(self, conditions=None, replace_arrays=False): + """ + Commit changes to multiple documents per conditions. + :param conditions: Only perform commit if matching record in db + satisfies condition(s) (e.g. version number). + Raises :class:`umongo.exceptions.UpdateError` if the + conditions are not satisfied. + :param replace_arrays: False (default) to add array elements to document or True to replace fields + containing arrays with the supplied value(s) + :return: A :class:`pymongo.results.UpdateResult` + """ + query = conditions or {} + # pre_update can provide additional query filter and/or + # modify the fields' values + additional_filter = yield maybeDeferred(self.pre_update) + if additional_filter: + query.update(map_query(additional_filter, self.schema.fields)) + yield self.io_validate(validate_all=False) + payload = self._data.to_mongo_update_many(replace_arrays=replace_arrays) + ret = yield self.collection.update_many(query, payload) + yield maybeDeferred(self.post_update, ret) + self._data.clear_modified() + return ret + @inlineCallbacks def delete(self, conditions=None): """