Skip to content

Commit

Permalink
[red-knot] Infer Unknown for the loop var in async for loops (#13243
Browse files Browse the repository at this point in the history
)
  • Loading branch information
AlexWaygood authored Sep 4, 2024
1 parent 0512428 commit e965f9c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,7 @@ where
ForStmtDefinitionNodeRef {
iterable: &node.iter,
target: name_node,
is_async: node.is_async,
},
);
}
Expand Down
22 changes: 16 additions & 6 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) is_async: bool,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -206,12 +207,15 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::AugmentedAssignment(augmented_assignment) => {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::For(ForStmtDefinitionNodeRef { iterable, target }) => {
DefinitionKind::For(ForStmtDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
iterable,
target,
is_async,
}) => DefinitionKind::For(ForStmtDefinitionKind {
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
is_async,
}),
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
iterable,
target,
Expand Down Expand Up @@ -265,6 +269,7 @@ impl DefinitionNodeRef<'_> {
Self::For(ForStmtDefinitionNodeRef {
iterable: _,
target,
is_async: _,
}) => target.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::Parameter(node) => match node {
Expand Down Expand Up @@ -388,6 +393,7 @@ impl WithItemDefinitionKind {
pub struct ForStmtDefinitionKind {
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
is_async: bool,
}

impl ForStmtDefinitionKind {
Expand All @@ -398,6 +404,10 @@ impl ForStmtDefinitionKind {
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn is_async(&self) -> bool {
self.is_async
}
}

#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
Expand Down
69 changes: 66 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_for_statement_definition(
for_statement_definition.target(),
for_statement_definition.iterable(),
for_statement_definition.is_async(),
definition,
);
}
Expand Down Expand Up @@ -1045,6 +1046,7 @@ impl<'db> TypeInferenceBuilder<'db> {
&mut self,
target: &ast::ExprName,
iterable: &ast::Expr,
is_async: bool,
definition: Definition<'db>,
) {
let expression = self.index.expression(iterable);
Expand All @@ -1054,9 +1056,14 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));

let loop_var_value_ty = iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self);
let loop_var_value_ty = if is_async {
// TODO(Alex): async iterables/iterators!
Type::Unknown
} else {
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self)
};

self.types
.expressions
Expand Down Expand Up @@ -3026,6 +3033,62 @@ mod tests {
Ok(())
}

/// This tests that we understand that `async` for loops
/// do not work according to the synchronous iteration protocol
#[test]
fn invalid_async_for_loop() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
async def foo():
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
return Iterator()
async for x in Iterator():
pass
",
)?;

// TODO(Alex) async iterables/iterators!
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");

Ok(())
}

#[test]
fn basic_async_for_loop() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
async def foo():
class IntAsyncIterator:
async def __anext__(self) -> int:
return 42
class IntAsyncIterable:
def __aiter__(self) -> IntAsyncIterator:
return IntAsyncIterator()
async for x in IntAsyncIterable():
pass
",
)?;

// TODO(Alex) async iterables/iterators!
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");

Ok(())
}

#[test]
fn class_constructor_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit e965f9c

Please sign in to comment.