Skip to content

Commit

Permalink
fix: various iterator and generator bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
nickelpro committed Oct 8, 2024
1 parent bba9d1d commit 1ee19d8
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/util/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ void init_gPO() {
gPO.meth = PyUnicode_FromString("REQUEST_METHOD");
gPO.wsgi_ver = PyTuple_Pack(2, PyLong_FromLong(1), PyLong_FromLong(0));
gPO.wsgi_input = PyUnicode_FromString("wsgi.input");
gPO.close = PyUnicode_FromString("close");
gPO.velocem_caps = PyUnicode_FromString("velocem.captures");
#define HTTP_METHOD(c, n) PyUnicode_FromString(#n),
gPO.methods = {
Expand Down
1 change: 1 addition & 0 deletions src/util/Constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ struct GlobalPythonObjects {
PyObject* meth;
PyObject* wsgi_ver;
PyObject* wsgi_input;
PyObject* close;
PyObject* velocem_caps;
std::array<PyObject*, 47> methods;
};
Expand Down
22 changes: 16 additions & 6 deletions src/util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define PY_SSIZE_T_CLEAN
#include <Python.h>

#include "Constants.hpp"

namespace velocem {

void unpack_unicode(PyObject* str, const char** base, Py_ssize_t* len,
Expand Down Expand Up @@ -120,12 +122,20 @@ void replace_key(PyObject* dict, PyObject* oldK, PyObject* newK) {
}

void close_iterator(PyObject* iter) {
PyObject* close {PyObject_GetAttrString(iter, "close")};
if(close) {
PyObject* ret {PyObject_CallNoArgs(close)};
Py_XDECREF(ret);
Py_DECREF(close);
}
if(!PyObject_HasAttr(iter, gPO.close))
return;

PyObject* close {PyObject_GetAttr(iter, gPO.close)};
if(!close)
throw std::runtime_error {"Python GetAttr error"};

PyObject* ret {PyObject_CallNoArgs(close)};
Py_DECREF(close);

if(!ret)
throw std::runtime_error {"Python close error"};

Py_DECREF(ret);
}

} // namespace velocem
13 changes: 8 additions & 5 deletions src/wsgi/App.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,13 +461,12 @@ WSGIAppRet* WSGIApp::run(WSGIRequest* req, int http_minor, int meth,
if(!iter) [[unlikely]]
throw std::runtime_error {"Python function call error"};

if(!status_) [[unlikely]] {
PyErr_SetString(PyExc_RuntimeError, "start_response() not called");
throw std::runtime_error {"WSGI application error"};
}

if(PyGen_Check(iter)) {
PyObject* first {prime_generator(iter)};
if(!status_) [[unlikely]] {
PyErr_SetString(PyExc_RuntimeError, "start_response() not called");
throw std::runtime_error {"WSGI application error"};
}
ret->conlen = build_headers(ret->buf, keepalive);

// Once we've built the headers you don't get to change them anymore
Expand All @@ -479,6 +478,10 @@ WSGIAppRet* WSGIApp::run(WSGIRequest* req, int http_minor, int meth,
insert_body_generator(ret->buf, writebuf_, iter, first, *ret->conlen);

} else {
if(!status_) [[unlikely]] {
PyErr_SetString(PyExc_RuntimeError, "start_response() not called");
throw std::runtime_error {"WSGI application error"};
}
ret->conlen = build_headers(ret->buf, keepalive);
in_handle = false;

Expand Down
78 changes: 78 additions & 0 deletions test/apps/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,84 @@ def echo(environ, start_response):
return environ['wsgi.input'].read()


@router.get('/list')
def list_(environ, start_response):
start_response('200 OK', [])
return [b'Hello', b' ', b'World']


@router.get('/tuple')
def tuple_(environ, start_response):
start_response('200 OK', [])
return (b'Hello', b' ', b'World')


@router.get('/iterator')
def iter_(environ, start_response):
class Iter:
def __init__(self):
self.val = 0

def __iter__(self):
return self

def __next__(self):
self.val += 1
match self.val:
case 1:
return b'Hello'
case 2:
return b' '
case 3:
return b'World'
case 4:
raise StopIteration

start_response('200 OK', [])
return Iter()


@router.get('/generator')
def gen(environ, start_response):
start_response('200 OK', [])
yield b'Hello'
yield b' '
yield b'World'


called_close_count = 0


@router.get('/call_close')
def call_close(environ, start_response):
start_response('200 OK', [])

class Iter:
def __init__(self):
self.val = 0

def __iter__(self):
return self

def __next__(self):
if self.val:
raise StopIteration
self.val = 1
return b'Hello World'

def close(self):
global called_close_count
called_close_count += 1

return Iter()


@router.get('/called_close')
def called_close(environ, start_response):
start_response('200 OK', [])
return f'{called_close_count}'.encode('ascii')


@router.get('/no_start_response')
def no_start_response(environ, start_response):
return b'Hello World'
Expand Down
31 changes: 29 additions & 2 deletions test/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,38 @@ def f(resp):
run_req_test(f)


def check_hello(resp):
assert resp.read() == b'Hello World'


def test_hello_world(wsgi_server):
run_req_test(check_hello, endpoint='/hello')


def test_list(wsgi_server):
run_req_test(check_hello, endpoint='/list')


def test_tuple(wsgi_server):
run_req_test(check_hello, endpoint='/tuple')


def test_iterator(wsgi_server):
run_req_test(check_hello, endpoint='/iterator')


def test_generator(wsgi_server):
run_req_test(check_hello, endpoint='/generator')


def test_call_close(wsgi_server):
run_req_test(check_hello, endpoint='/call_close')

def f(resp):
assert resp.read() == b'Hello World'
val = int(resp.read().decode('ascii'))
assert val == 10

run_req_test(f, endpoint='/hello')
run_req_test(f, endpoint='/called_close')


def test_echo(wsgi_server):
Expand Down

0 comments on commit 1ee19d8

Please sign in to comment.