Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
amorey committed Apr 9, 2024
1 parent 5f2d7be commit f15e6c6
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 22 deletions.
30 changes: 13 additions & 17 deletions starlette_wtf/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def __init__(self, request: StarletteRequest, *args, **kwargs):
"""
# cache request
self._request = request

# for WTForms CSRF handling
if hasattr(request.state, 'csrf_config'):
config = request.state.csrf_config
Expand All @@ -88,7 +88,6 @@ def __init__(self, request: StarletteRequest, *args, **kwargs):

super().__init__(*args, **kwargs)


@classmethod
async def from_formdata(cls, request: StarletteRequest, formdata=_Auto,
**kwargs):
Expand All @@ -111,11 +110,10 @@ async def from_formdata(cls, request: StarletteRequest, formdata=_Auto,
formdata = await get_formdata(request)
else:
formdata = None

# return new instance
return cls(request, formdata=formdata, **kwargs)


async def _validate_async(self, validator, field):
"""Execute async validator
"""
Expand All @@ -126,7 +124,6 @@ async def _validate_async(self, validator, field):
return False
return True


async def validate(self, extra_validators=None):
"""Overload :meth:`validate` to handle custom async validators
"""
Expand All @@ -136,40 +133,39 @@ async def validate(self, extra_validators=None):
extra = {}

async_validators = {}

# use extra validators to check for StopValidation errors
completed = []
def record_status(form, field):
completed.append(form._prefix + field.name)


def record_status(_, field):
completed.append(field.name)

for name, field in self._fields.items():
func = getattr(self.__class__, f"async_validate_{name}", None)
if func:
async_validators[name] = (func, field)
async_validators[field.name] = (func, field)
extra.setdefault(name, []).append(record_status)

# execute non-async validators
success = super().validate(extra_validators=extra)

# execute async validators
tasks = [self._validate_async(*async_validators[name]) for name in \
tasks = [self._validate_async(*async_validators[field_name]) for field_name in
completed]
async_results = await asyncio.gather(*tasks)

# check results
if False in async_results:
success = False

return success



def is_submitted(self):
"""Consider the form submitted if there is an active request and
the method is ``POST``, ``PUT``, ``PATCH``, or ``DELETE``.
"""
return self._request.method in SUBMIT_METHODS



async def validate_on_submit(self, extra_validators=None):
"""Call :meth:`validate` only if the form is submitted.
This is a shortcut for ``form.is_submitted() and form.validate()``.
Expand Down
56 changes: 56 additions & 0 deletions tests/test_asyncvalidators.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,59 @@ async def index(request):
app.add_route('/', methods=['POST'], route=index)

client.post('/', data={'field1': 'xxx1', 'field2': 'xxx2'})


def test_async_validator_with_prefix_success(app, client, FormWithAsyncValidators):
async def index(request):
form = await FormWithAsyncValidators.from_formdata(
request,
prefix="myprefix-",
)
assert form.field1.data == 'value1'
assert form.field2.data == 'value2'

# validate and check again
success = await form.validate()
assert success == True

# check values and errors
assert form.field1.data == 'value1'
assert 'field1' not in form.errors

assert form.field2.data == 'value2'
assert 'field2' not in form.errors

return PlainTextResponse()

app.add_route('/', methods=['POST'], route=index)

client.post('/', data={'myprefix-field1': 'value1', 'myprefix-field2': 'value2'})


def test_async_validator_with_prefix_error(app, client, FormWithAsyncValidators):
async def index(request):
form = await FormWithAsyncValidators.from_formdata(
request,
prefix="myprefix-",
)
assert form.field1.data == 'xxx1'
assert form.field2.data == 'xxx2'

# validate and check again
success = await form.validate()
assert success == False
assert form.field1.data == 'xxx1'
assert form.field2.data == 'xxx2'

# check errors
assert len(form.errors['field1']) == 1
assert form.errors['field1'][0] == 'Field value is incorrect.'

assert len(form.errors['field2']) == 1
assert form.errors['field2'][0] == 'Field value is incorrect.'

return PlainTextResponse()

app.add_route('/', methods=['POST'], route=index)

client.post('/', data={'myprefix-field1': 'xxx1', 'myprefix-field2': 'xxx2'})
16 changes: 11 additions & 5 deletions tests/test_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,15 +156,21 @@ async def index(request):
assert client.post('/', data={'name': 'value'}).text == 'False'


def test_form_with_prefix(app, client, BasicForm):
def test_validate_form_with_prefix(app, client, BasicForm):
async def index(request):
form = await BasicForm.from_formdata(
request,
prefix="myprefix-",
)
assert form.name.data == 'x'
return PlainTextResponse()

app.add_route('/', methods=['POST'], route=index)
if await form.validate_on_submit():
assert request.method == 'POST'

client.post('/', data={'myprefix-name': 'x'})
return PlainTextResponse(str('name' in form.errors))

app.add_route('/', methods=['GET', 'POST'], route=index)

assert client.get('/').text == 'False'
assert client.post('/').text == 'True'
assert client.post('/', data={'name': 'value'}).text == 'True'
assert client.post('/', data={'myprefix-name': 'value'}).text == 'False'

0 comments on commit f15e6c6

Please sign in to comment.