diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs
index 33a22a378d555..92921eaa11af7 100644
--- a/crates/red_knot_python_semantic/src/types.rs
+++ b/crates/red_knot_python_semantic/src/types.rs
@@ -380,6 +380,11 @@ impl<'db> Type<'db> {
         }
     }
 
+    /// Return the type of the builtin `str` symbol (looked up through `builtins_symbol_ty`).
+    pub fn builtin_str(db: &'db dyn Db) -> Self {
+        builtins_symbol_ty(db, "str")
+    }
+
     pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
         match self {
             Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
@@ -721,6 +726,89 @@ impl<'db> Type<'db> {
             Type::Tuple(_) => builtins_symbol_ty(db, "tuple"),
         }
     }
+
+    /// Return the type of the string this type converts to, as it would be produced by the
+    /// `__str__` method at runtime.
+    ///
+    /// Integer and boolean literals defer to [`Type::repr`]; string literals and `LiteralString`
+    /// are returned unchanged; any other type currently yields an instance of `str`.
+    ///
+    /// Note: this method is used in the builtins `format`, `print`, `str.format` and `f-strings`.
+    #[must_use]
+    pub fn str(&self, db: &'db dyn Db) -> Type<'db> {
+        match self {
+            Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
+            Type::StringLiteral(_) | Type::LiteralString => *self,
+            // TODO: handle more complex types
+            _ => Type::builtin_str(db).to_instance(db),
+        }
+    }
+
+    /// Return the type of the string representation of this type, as it would be produced by the
+    /// `__repr__` method at runtime.
+    ///
+    /// String literals are quoted and escaped following Python's rules (see `python_str_repr`);
+    /// when the exact text cannot be computed the result is widened to `LiteralString`.
+    #[must_use]
+    pub fn repr(&self, db: &'db dyn Db) -> Type<'db> {
+        match self {
+            Type::IntLiteral(number) => Type::StringLiteral(StringLiteralType::new(db, {
+                number.to_string().into_boxed_str()
+            })),
+            Type::BooleanLiteral(true) => {
+                Type::StringLiteral(StringLiteralType::new(db, "True".into()))
+            }
+            Type::BooleanLiteral(false) => {
+                Type::StringLiteral(StringLiteralType::new(db, "False".into()))
+            }
+            Type::StringLiteral(literal) => match python_str_repr(literal.value(db)) {
+                Some(repr) => {
+                    Type::StringLiteral(StringLiteralType::new(db, repr.into_boxed_str()))
+                }
+                // Non-ASCII text: the exact rendering depends on Unicode printability.
+                None => Type::LiteralString,
+            },
+            Type::LiteralString => Type::LiteralString,
+            // TODO: handle more complex types
+            _ => Type::builtin_str(db).to_instance(db),
+        }
+    }
+}
+
+/// Render `value` the way Python's `repr()` renders a `str`.
+///
+/// Double quotes are used when the text contains a single quote but no double quote; otherwise
+/// single quotes are used and embedded single quotes are escaped. Backslashes, `\n`, `\r` and
+/// `\t` get their short escapes; other ASCII control characters are rendered as `\xNN`.
+///
+/// Returns `None` when `value` contains non-ASCII characters: Python keeps or escapes those
+/// depending on Unicode printability data that is not available here.
+fn python_str_repr(value: &str) -> Option<String> {
+    let quote = if value.contains('\'') && !value.contains('"') {
+        '"'
+    } else {
+        '\''
+    };
+
+    let mut repr = String::with_capacity(value.len() + 2);
+    repr.push(quote);
+    for c in value.chars() {
+        match c {
+            '\\' => repr.push_str("\\\\"),
+            '\n' => repr.push_str("\\n"),
+            '\r' => repr.push_str("\\r"),
+            '\t' => repr.push_str("\\t"),
+            c if c == quote => {
+                repr.push('\\');
+                repr.push(c);
+            }
+            ' '..='~' => repr.push(c),
+            c if c.is_ascii() => repr.push_str(&format!("\\x{:02x}", u32::from(c))),
+            _ => return None,
+        }
+    }
+    repr.push(quote);
+    Some(repr)
 }
 
 impl<'db> From<&Type<'db>> for Type<'db> {
@@ -1198,12 +1286,13 @@ mod tests {
 
     /// A test representation of a type that can be transformed unambiguously into a real Type,
     /// given a db.
-    #[derive(Debug)]
+    #[derive(Debug, Clone)]
     enum Ty {
         Never,
         Unknown,
         Any,
         IntLiteral(i64),
+        BoolLiteral(bool),
         StringLiteral(&'static str),
         LiteralString,
         BytesLiteral(&'static str),
@@ -1222,6 +1311,7 @@ mod tests {
                 Ty::StringLiteral(s) => {
                     Type::StringLiteral(StringLiteralType::new(db, (*s).into()))
                 }
+                Ty::BoolLiteral(b) => Type::BooleanLiteral(b),
                 Ty::LiteralString => Type::LiteralString,
                 Ty::BytesLiteral(s) => {
                     Type::BytesLiteral(BytesLiteralType::new(db, s.as_bytes().into()))
@@ -1331,4 +1421,31 @@ mod tests {
         let db = setup_db();
         assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
     }
+
+    #[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
+    #[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
+    #[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
+    #[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("ab'cd"))] // no quotes
+    #[test_case(Ty::LiteralString, Ty::LiteralString)]
+    #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
+    fn has_correct_str(ty: Ty, expected: Ty) {
+        let db = setup_db();
+
+        assert_eq!(ty.into_type(&db).str(&db), expected.into_type(&db));
+    }
+
+    #[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
+    #[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
+    #[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
+    #[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("\"ab'cd\""))] // double quotes
+    #[test_case(Ty::StringLiteral("a'b\"c"), Ty::StringLiteral("'a\\'b\"c'"); "both quotes")]
+    #[test_case(Ty::StringLiteral("a\nb"), Ty::StringLiteral("'a\\nb'"); "newline")]
+    #[test_case(Ty::StringLiteral("é"), Ty::LiteralString; "non ascii")]
+    #[test_case(Ty::LiteralString, Ty::LiteralString)]
+    #[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
+    fn has_correct_repr(ty: Ty, expected: Ty) {
+        let db = setup_db();
+
+        assert_eq!(ty.into_type(&db).repr(&db), expected.into_type(&db));
+    }
 }
diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs
index 82532c9c70dba..32484fe2bf604 100644
--- a/crates/red_knot_python_semantic/src/types/infer.rs
+++ b/crates/red_knot_python_semantic/src/types/infer.rs
@@ -1653,50 +1653,74 @@ impl<'db> TypeInferenceBuilder<'db> {
     fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> {
         let ast::ExprFString { range: _, value } = fstring;
 
+        let mut collector = StringPartsCollector::new();
         for part in value {
+            // Visit every part, even when the result is already known to be `str`, so that all
+            // sub-expressions get inferred. The `collector` stops buffering text as soon as the
+            // result can no longer be a string literal.
             match part {
-                ast::FStringPart::Literal(_) => {
-                    // TODO string literal type
+                ast::FStringPart::Literal(literal) => {
+                    collector.push_str(&literal.value);
                 }
                 ast::FStringPart::FString(fstring) => {
-                    let ast::FString {
-                        range: _,
-                        elements,
-                        flags: _,
-                    } = fstring;
-                    for element in elements {
-                        self.infer_fstring_element(element);
-                    }
-                }
-            }
-        }
-
-        // TODO str type
-        Type::Unknown
-    }
-
-    fn infer_fstring_element(&mut self, element: &ast::FStringElement) {
-        match element {
-            ast::FStringElement::Literal(_) => {
-                // TODO string literal type
-            }
-            ast::FStringElement::Expression(expr_element) => {
-                let ast::FStringExpressionElement {
-                    range: _,
-                    expression,
-                    debug_text: _,
-                    conversion: _,
-                    format_spec,
-                } = expr_element;
-                self.infer_expression(expression);
-
-                if let Some(format_spec) = format_spec {
-                    for spec_element in &format_spec.elements {
-                        self.infer_fstring_element(spec_element);
-                    }
-                }
-            }
-        }
+                    for element in &fstring.elements {
+                        match element {
+                            ast::FStringElement::Expression(expression) => {
+                                let ast::FStringExpressionElement {
+                                    range: _,
+                                    expression,
+                                    debug_text,
+                                    conversion,
+                                    format_spec,
+                                } = expression;
+                                let ty = self.infer_expression(expression);
+
+                                // Replacement fields nested inside the format spec
+                                // (`{x:{width}}`) are expressions too and must be inferred.
+                                if let Some(format_spec) = format_spec {
+                                    self.infer_fstring_format_spec(format_spec);
+                                }
+
+                                // TODO: handle format specifiers by calling a method
+                                // (`Type::format`?) that handles the `__format__` method.
+                                // Conversion flags should be handled before calling `__format__`.
+                                // https://docs.python.org/3/library/string.html#format-string-syntax
+                                //
+                                // A self-documenting field (`{x=}`) also prepends the source text
+                                // and switches to `repr`, so it is not folded into a literal.
+                                if debug_text.is_some()
+                                    || !conversion.is_none()
+                                    || format_spec.is_some()
+                                {
+                                    collector.add_expression();
+                                } else if let Type::StringLiteral(literal) = ty.str(self.db) {
+                                    collector.push_str(literal.value(self.db));
+                                } else {
+                                    collector.add_expression();
+                                }
+                            }
+                            ast::FStringElement::Literal(literal) => {
+                                collector.push_str(&literal.value);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+        collector.ty(self.db)
+    }
+
+    /// Infer the expressions of the replacement fields nested in an f-string format spec
+    /// (e.g. `width` in `f"{x:{width}}"`), recursing into their own format specs.
+    fn infer_fstring_format_spec(&mut self, format_spec: &ast::FStringFormatSpec) {
+        for element in &format_spec.elements {
+            if let ast::FStringElement::Expression(nested) = element {
+                self.infer_expression(&nested.expression);
+                if let Some(nested_spec) = &nested.format_spec {
+                    self.infer_fstring_format_spec(nested_spec);
+                }
+            }
+        }
     }
 
     fn infer_ellipsis_literal_expression(
@@ -2659,6 +2683,63 @@ enum ModuleNameResolutionError {
     TooManyDots,
 }
 
+/// Accumulates the parts of a formatted string while its type is being inferred.
+///
+/// The resulting type is a string literal as long as every part is known and the concatenation
+/// stays within `TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE`; past that limit it becomes
+/// `LiteralString`.
+///
+/// If any part is an expression whose string value is not known statically, the result is an
+/// instance of `builtins.str`.
+struct StringPartsCollector {
+    /// The text gathered so far, or `None` once the result can no longer be a string literal.
+    concatenated: Option<String>,
+    /// Whether a part with an unknown string value has been seen.
+    expression: bool,
+}
+
+impl StringPartsCollector {
+    /// Create a collector holding an empty string and no expression.
+    fn new() -> Self {
+        Self {
+            concatenated: Some(String::new()),
+            expression: false,
+        }
+    }
+
+    /// Append a known piece of text, giving up on the literal once it would grow past
+    /// `TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE`.
+    fn push_str(&mut self, literal: &str) {
+        if let Some(mut concatenated) = self.concatenated.take() {
+            if concatenated.len().saturating_add(literal.len())
+                <= TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE
+            {
+                concatenated.push_str(literal);
+                self.concatenated = Some(concatenated);
+            } else {
+                self.concatenated = None;
+            }
+        }
+    }
+
+    /// Record a part whose string value is unknown; the result can then only be `str`.
+    fn add_expression(&mut self) {
+        self.concatenated = None;
+        self.expression = true;
+    }
+
+    /// Produce the inferred type of the whole formatted string.
+    fn ty(self, db: &dyn Db) -> Type {
+        if self.expression {
+            Type::builtin_str(db).to_instance(db)
+        } else if let Some(concatenated) = self.concatenated {
+            Type::StringLiteral(StringLiteralType::new(db, concatenated.into_boxed_str()))
+        } else {
+            Type::LiteralString
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
 
@@ -3593,6 +3674,91 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn fstring_expression() -> anyhow::Result<()> {
+        let mut db = setup_db();
+
+        db.write_dedented(
+            "src/a.py",
+            "
+            x = 0
+            y = str()
+            z = False
+
+            a = f'hello'
+            b = f'h {x}'
+            c = 'one ' f'single ' f'literal'
+            d = 'first ' f'second({b})' f' third'
+            e = f'-{y}-'
+            f = f'-{y}-' f'--' '--'
+            g = f'{z} == {False} is {True}'
+            ",
+        )?;
+
+        assert_public_ty(&db, "src/a.py", "a", "Literal[\"hello\"]");
+        assert_public_ty(&db, "src/a.py", "b", "Literal[\"h 0\"]");
+        assert_public_ty(&db, "src/a.py", "c", "Literal[\"one single literal\"]");
+        assert_public_ty(&db, "src/a.py", "d", "Literal[\"first second(h 0) third\"]");
+        assert_public_ty(&db, "src/a.py", "e", "str");
+        assert_public_ty(&db, "src/a.py", "f", "str");
+        assert_public_ty(&db, "src/a.py", "g", "Literal[\"False == False is True\"]");
+
+        Ok(())
+    }
+
+    #[test]
+    fn fstring_expression_with_conversion_flags() -> anyhow::Result<()> {
+        let mut db = setup_db();
+
+        db.write_dedented(
+            "src/a.py",
+            "
+            string = 'hello'
+            a = f'{string!r}'
+            ",
+        )?;
+
+        assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["'hello'"]`
+
+        Ok(())
+    }
+
+    #[test]
+    fn fstring_expression_with_format_specifier() -> anyhow::Result<()> {
+        let mut db = setup_db();
+
+        db.write_dedented(
+            "src/a.py",
+            "
+            width = 2
+            a = f'{1:02}'
+            b = f'{1:{width}}'
+            ",
+        )?;
+
+        assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["01"]`
+        assert_public_ty(&db, "src/a.py", "b", "str");
+
+        Ok(())
+    }
+
+    #[test]
+    fn fstring_expression_with_debug_text() -> anyhow::Result<()> {
+        let mut db = setup_db();
+
+        db.write_dedented(
+            "src/a.py",
+            "
+            x = 0
+            a = f'{x=}'
+            ",
+        )?;
+
+        assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["x=0"]`
+
+        Ok(())
+    }
+
     #[test]
     fn basic_call_expression() -> anyhow::Result<()> {
         let mut db = setup_db();