Skip to content

Commit

Permalink
pythongh-105858: Improve AST node constructors
Browse files Browse the repository at this point in the history
Demonstration:

>>> ast.FunctionDef.__annotations__
{'name': <class 'str'>, 'args': <class 'ast.arguments'>, 'body': list[ast.stmt], 'decorator_list': list[ast.expr], 'returns': ast.expr | None, 'type_comment': str | None, 'type_params': list[ast.type_param]}
>>> ast.FunctionDef()
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'name'. This will become an error in Python 3.15.
<stdin>:1: DeprecationWarning: FunctionDef.__init__ missing 1 required positional argument: 'args'. This will become an error in Python 3.15.
<ast.FunctionDef object at 0x101959460>
>>> node = ast.FunctionDef(name="foo", args=ast.arguments())
>>> node.decorator_list
[]
>>> ast.FunctionDef(whatever="you want", name="x", args=ast.arguments())
<stdin>:1: DeprecationWarning: FunctionDef.__init__ got an unexpected keyword argument 'whatever'. Support for arbitrary keyword arguments is deprecated and will be removed in Python 3.15.
<ast.FunctionDef object at 0x1019581f0>

Known problems:
- Subclasses of AST nodes don't work properly, because we don't look up __annotations__ on the
  right class.
- Unpickling throws DeprecationWarnings, probably because of how we construct the unpickled
  object.

Need to think more about how to handle those cases.
  • Loading branch information
JelleZijlstra committed Jun 17, 2023
1 parent 34e93d3 commit b1b73d9
Show file tree
Hide file tree
Showing 4 changed files with 3,458 additions and 11 deletions.
38 changes: 35 additions & 3 deletions Lib/test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,17 @@ def test_arguments(self):
x = ast.arguments()
self.assertEqual(x._fields, ('posonlyargs', 'args', 'vararg', 'kwonlyargs',
'kw_defaults', 'kwarg', 'defaults'))

with self.assertRaises(AttributeError):
x.args
self.assertEqual(x.__annotations__, {
'posonlyargs': list[ast.arg],
'args': list[ast.arg],
'vararg': ast.arg | None,
'kwonlyargs': list[ast.arg],
'kw_defaults': list[ast.expr],
'kwarg': ast.arg | None,
'defaults': list[ast.expr],
})

self.assertEqual(x.args, [])
self.assertIsNone(x.vararg)

x = ast.arguments(*range(1, 8))
Expand Down Expand Up @@ -573,15 +581,22 @@ def test_classattrs_deprecated(self):

self.assertEqual([str(w.message) for w in wlog], [
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
"Constant.__init__ missing 1 required positional argument: 'value'. This will become "
'an error in Python 3.15.',
'Attribute n is deprecated and will be removed in Python 3.14; use value instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
'Attribute n is deprecated and will be removed in Python 3.14; use value instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
"Constant.__init__ missing 1 required positional argument: 'value'. This will become "
'an error in Python 3.15.',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
'Attribute n is deprecated and will be removed in Python 3.14; use value instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
"Constant.__init__ got an unexpected keyword argument 'foo'. Support for "
'arbitrary keyword arguments is deprecated and will be removed in Python '
'3.15.',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
'Attribute n is deprecated and will be removed in Python 3.14; use value instead',
'ast.Num is deprecated and will be removed in Python 3.14; use ast.Constant instead',
Expand Down Expand Up @@ -2828,6 +2843,23 @@ def visit_Call(self, node: ast.Call):
self.assertASTTransformation(PrintToLog, code, expected)


class ASTConstructorTests(unittest.TestCase):
"""Test the autogenerated constructors for AST nodes."""

def test_FunctionDef(self):
args = ast.arguments()
self.assertEqual(args.args, [])
self.assertEqual(args.posonlyargs, [])
with self.assertWarnsRegex(DeprecationWarning,
r"FunctionDef\.__init__ missing 1 required positional argument: 'name'"):
node = ast.FunctionDef(args=args)
self.assertFalse(hasattr(node, "name"))
self.assertEqual(node.decorator_list, [])
node = ast.FunctionDef(name='foo', args=args)
self.assertEqual(node.name, 'foo')
self.assertEqual(node.decorator_list, [])


@support.cpython_only
class ModuleStateTests(unittest.TestCase):
# bpo-41194, bpo-41261, bpo-41631: The _ast module uses a global state.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Improve the constructors for :mod:`ast` nodes. Arguments of list types now
default to an empty list if omitted. AST nodes now have an
``__annotations__`` attribute with the expected types of their attributes.
Passing unrecognized extra arguments to AST nodes is deprecated and will
become an error in Python 3.15. Omitting a required argument to an AST node
is deprecated and will become an error in Python 3.15. Patch by Jelle
Zijlstra.
163 changes: 159 additions & 4 deletions Parser/asdl_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
MAX_COL = 80
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"

builtin_type_to_c_type = {
"identifier": "PyUnicode_Type",
"string": "PyUnicode_Type",
"int": "PyLong_Type",
"constant": "PyBaseObject_Type",
}

def get_c_type(name):
"""Return a string for the C name of the type.
Expand Down Expand Up @@ -764,6 +771,65 @@ def visitConstructor(self, cons, name):
self.emit("};",0)


class AnnotationsVisitor(PickleVisitor):
def visitModule(self, mod):
self.file.write(textwrap.dedent('''
static int
add_ast_annotations(struct ast_state *state)
{
bool cond;
'''))
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
return 1;
}
'''))

def visitProduct(self, prod, name):
self.emit_annotations(name, prod.fields)

def visitSum(self, sum, name):
for t in sum.types:
self.visitConstructor(t, name)

def visitConstructor(self, cons, name):
self.emit_annotations(cons.name, cons.fields)

def emit_annotations(self, name, fields):
self.emit(f"PyObject *{name}_annotations = PyDict_New();", 1)
self.emit(f"if (!{name}_annotations) return 0;", 1)
for field in fields:
self.emit("{", 1)
if field.type in builtin_type_to_c_type:
self.emit(f"PyObject *type = (PyObject *)&{builtin_type_to_c_type[field.type]};", 2)
else:
self.emit(f"PyObject *type = state->{field.type}_type;", 2)
if field.opt:
self.emit("type = _Py_union_type_or(type, Py_None);", 2)
self.emit("cond = type != NULL;", 2)
self.emit_annotations_error(name, 2)
elif field.seq:
self.emit("type = Py_GenericAlias((PyObject *)&PyList_Type, type);", 2)
self.emit("cond = type != NULL;", 2)
self.emit_annotations_error(name, 2)
else:
self.emit("Py_INCREF(type);", 2)
self.emit(f"cond = PyDict_SetItemString({name}_annotations, \"{field.name}\", type) == 0;", 2)
self.emit("Py_DECREF(type);", 2)
self.emit_annotations_error(name, 2)
self.emit("}", 1)
self.emit(f'cond = PyObject_SetAttrString(state->{name}_type, "__annotations__", {name}_annotations) == 0;', 1)
self.emit(f"Py_DECREF({name}_annotations);", 1)
self.emit("if (!cond) return 0;", 1)

def emit_annotations_error(self, name, depth):
self.emit("if (!cond) {", depth)
self.emit(f"Py_DECREF({name}_annotations);", depth + 1)
self.emit("return 0;", depth + 1)
self.emit("}", depth)


class PyTypesVisitor(PickleVisitor):

def visitModule(self, mod):
Expand Down Expand Up @@ -812,7 +878,7 @@ def visitModule(self, mod):
Py_ssize_t i, numfields = 0;
int res = -1;
PyObject *key, *value, *fields;
PyObject *key, *value, *fields, *remaining_fields = NULL;
if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
goto cleanup;
}
Expand All @@ -821,6 +887,13 @@ def visitModule(self, mod):
if (numfields == -1) {
goto cleanup;
}
remaining_fields = PySet_New(fields);
}
else {
remaining_fields = PySet_New(NULL);
}
if (remaining_fields == NULL) {
goto cleanup;
}
res = 0; /* if no error occurs, this stays 0 to the end */
Expand All @@ -840,6 +913,11 @@ def visitModule(self, mod):
goto cleanup;
}
res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
if (PySet_Discard(remaining_fields, name) < 0) {
res = -1;
Py_DECREF(name);
goto cleanup;
}
Py_DECREF(name);
if (res < 0) {
goto cleanup;
Expand All @@ -852,29 +930,101 @@ def visitModule(self, mod):
if (contains == -1) {
res = -1;
goto cleanup;
} else if (contains == 1) {
Py_ssize_t p = PySequence_Index(fields, key);
}
else if (contains == 1) {
int p = PySet_Discard(remaining_fields, key);
if (p == -1) {
res = -1;
goto cleanup;
}
if (p < PyTuple_GET_SIZE(args)) {
if (p == 0) {
PyErr_Format(PyExc_TypeError,
"%.400s got multiple values for argument '%U'",
Py_TYPE(self)->tp_name, key);
res = -1;
goto cleanup;
}
}
else if (
PyUnicode_CompareWithASCIIString(key, "lineno") != 0 &&
PyUnicode_CompareWithASCIIString(key, "col_offset") != 0 &&
PyUnicode_CompareWithASCIIString(key, "end_lineno") != 0 &&
PyUnicode_CompareWithASCIIString(key, "end_col_offset") != 0
) {
if (PyErr_WarnFormat(
PyExc_DeprecationWarning, 1,
"%.400s.__init__ got an unexpected keyword argument '%U'. "
"Support for arbitrary keyword arguments is deprecated "
"and will be removed in Python 3.15.",
Py_TYPE(self)->tp_name, key
) < 0) {
goto cleanup;
}
}
res = PyObject_SetAttr(self, key, value);
if (res < 0) {
goto cleanup;
}
}
}
Py_ssize_t size = PySet_Size(remaining_fields);
PyObject *annotations = NULL, *remaining_list = NULL;
if (size > 0) {
if (!_PyObject_LookupAttr((PyObject*)Py_TYPE(self), &_Py_ID(__annotations__), &annotations)) {
res = -1;
goto cleanup;
}
remaining_list = PySequence_List(remaining_fields);
if (!remaining_list) {
goto set_remaining_cleanup;
}
for (Py_ssize_t i = 0; i < size; i++) {
PyObject *name = PyList_GET_ITEM(remaining_list, i);
PyObject *type = PyDict_GetItemWithError(annotations, name);
if (!type) {
if (!PyErr_Occurred()) {
PyErr_SetObject(PyExc_KeyError, name);
}
goto set_remaining_cleanup;
}
if (_PyUnion_Check(type)) {
// optional field
// do nothing, we'll have set a None default on the class
}
else if (Py_IS_TYPE(type, &Py_GenericAliasType)) {
// list field
PyObject *empty = PyList_New(0);
if (!empty) {
goto set_remaining_cleanup;
}
res = PyObject_SetAttr(self, name, empty);
Py_DECREF(empty);
if (res < 0) {
goto set_remaining_cleanup;
}
}
else {
// simple field (e.g., identifier)
if (PyErr_WarnFormat(
PyExc_DeprecationWarning, 1,
"%.400s.__init__ missing 1 required positional argument: '%U'. "
"This will become an error in Python 3.15.",
Py_TYPE(self)->tp_name, name
) < 0) {
goto cleanup;
}
}
}
}
cleanup:
Py_XDECREF(fields);
Py_XDECREF(remaining_fields);
return res;
set_remaining_cleanup:
Py_XDECREF(remaining_list);
Py_XDECREF(annotations);
res = -1;
goto cleanup;
}
/* Pickling support */
Expand Down Expand Up @@ -1123,6 +1273,9 @@ def visitModule(self, mod):
for dfn in mod.dfns:
self.visit(dfn)
self.file.write(textwrap.dedent('''
if (!add_ast_annotations(state)) {
return 0;
}
state->recursion_depth = 0;
state->recursion_limit = 0;
state->initialized = 1;
Expand Down Expand Up @@ -1542,6 +1695,7 @@ def generate_module_def(mod, metadata, f, internal_h):
#include "pycore_ceval.h" // _Py_EnterRecursiveCall
#include "pycore_interp.h" // _PyInterpreterState.ast
#include "pycore_pystate.h" // _PyInterpreterState_GET()
#include "pycore_unionobject.h" // _Py_union_type_or
#include "structmember.h"
#include <stddef.h>
Expand Down Expand Up @@ -1651,6 +1805,7 @@ def write_source(mod, metadata, f, internal_h_file):
v = ChainOfVisitors(
SequenceConstructorVisitor(f),
PyTypesDeclareVisitor(f),
AnnotationsVisitor(f),
PyTypesVisitor(f),
Obj2ModPrototypeVisitor(f),
FunctionVisitor(f),
Expand Down
Loading

0 comments on commit b1b73d9

Please sign in to comment.