Skip to content

Commit

Permalink
[red-knot] support fstring expressions (#13511)
Browse files Browse the repository at this point in the history
<!--
Thank you for contributing to Ruff! To help us out with reviewing,
please consider the following:

- Does this pull request include a summary of the change? (See below.)
- Does this pull request include a descriptive title?
- Does this pull request include references to any relevant issues?
-->

## Summary

Implement inference for `f-string`, contributes to #12701.

### First Implementation

When looking at the way `mypy` handles things, I noticed the following:
- No variables (e.g. `f"hello"`) ⇒ `LiteralString`
- Any variable (e.g. `f"number {1}"`) ⇒ `str`

My first commit (1ba5d0f) implements
exactly this logic, except that we deal with string literals just like
`infer_string_literal_expression` (if below `MAX_STRING_LITERAL_SIZE`,
show `Literal["exact string"]`)

### Second Implementation

My second commit (90326ce) pushes
things a bit further to handle cases where the expressions within the
`f-string` are all literal values (string representation known
statically).

Here's an example of when this could happen in code:
```python
BASE_URL = "https://httpbin.org"
VERSION = "v1"
endpoint = f"{BASE_URL}/{VERSION}/post"  # Literal["https://httpbin.org/v1/post"]
```
As this can be slightly more costly (additional allocations), I don't
know if we want this feature.

## Test Plan

- Added a test `fstring_expression` covering all cases I can think of

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
  • Loading branch information
Slyces and carljm authored Sep 27, 2024
1 parent f3e464e commit 1639488
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 37 deletions.
70 changes: 69 additions & 1 deletion crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ impl<'db> Type<'db> {
}
}

/// Return the type bound to the name `str` in the builtins scope.
///
/// This is the symbol's own type as resolved by `builtins_symbol_ty`; callers in this file
/// that need "an instance of `str`" follow it with `.to_instance(db)`.
pub fn builtin_str(db: &'db dyn Db) -> Self {
builtins_symbol_ty(db, "str")
}

pub fn is_stdlib_symbol(&self, db: &'db dyn Db, module_name: &str, name: &str) -> bool {
match self {
Type::Class(class) => class.is_stdlib_symbol(db, module_name, name),
Expand Down Expand Up @@ -721,6 +725,44 @@ impl<'db> Type<'db> {
Type::Tuple(_) => builtins_symbol_ty(db, "tuple"),
}
}

/// Infer the type produced by converting a value of this type to a string, as its `__str__`
/// method would.
///
/// Types whose string form is known statically yield a string literal type; anything else
/// is inferred as an instance of `builtins.str`. Where `__str__` does not differ from
/// `__repr__`, the result is delegated to [`Type::repr`].
///
/// Note: this is the conversion Python applies in the builtins `format`, `print`,
/// `str.format` and in `f-strings`.
#[must_use]
pub fn str(&self, db: &'db dyn Db) -> Type<'db> {
    match self {
        // A string converts to itself, whether or not its exact value is known.
        Type::StringLiteral(_) | Type::LiteralString => *self,
        // Integer and boolean literals share one textual form for `str` and `repr`.
        Type::IntLiteral(_) | Type::BooleanLiteral(_) => self.repr(db),
        // TODO: handle more complex types
        _ => Type::builtin_str(db).to_instance(db),
    }
}

/// Infer the type produced by calling `repr()` on a value of this type, i.e. the result of
/// its `__repr__` method at runtime.
///
/// Integer, boolean and string literals yield a string literal type holding the rendered
/// text, `LiteralString` stays `LiteralString`, and every other type is inferred as an
/// instance of `builtins.str`.
#[must_use]
pub fn repr(&self, db: &'db dyn Db) -> Type<'db> {
    // Compute the rendered text for the statically known cases; bail out early otherwise.
    let rendered: Box<str> = match self {
        Type::IntLiteral(number) => number.to_string().into_boxed_str(),
        Type::BooleanLiteral(true) => "True".into(),
        Type::BooleanLiteral(false) => "False".into(),
        // NOTE(review): the value is always wrapped in single quotes and escaped with Rust's
        // `escape_default`, which escapes both quote characters and all non-ASCII text.
        // Python's `repr` picks the quote style based on the content and keeps printable
        // non-ASCII characters, so the two can disagree for such strings -- confirm whether
        // that matters to callers.
        Type::StringLiteral(literal) => {
            format!("'{}'", literal.value(db).escape_default()).into()
        }
        Type::LiteralString => return Type::LiteralString,
        // TODO: handle more complex types
        _ => return Type::builtin_str(db).to_instance(db),
    };

    Type::StringLiteral(StringLiteralType::new(db, rendered))
}
}

impl<'db> From<&Type<'db>> for Type<'db> {
Expand Down Expand Up @@ -1198,12 +1240,13 @@ mod tests {

/// A test representation of a type that can be transformed unambiguously into a real Type,
/// given a db.
#[derive(Debug)]
#[derive(Debug, Clone)]
enum Ty {
Never,
Unknown,
Any,
IntLiteral(i64),
BoolLiteral(bool),
StringLiteral(&'static str),
LiteralString,
BytesLiteral(&'static str),
Expand All @@ -1222,6 +1265,7 @@ mod tests {
Ty::StringLiteral(s) => {
Type::StringLiteral(StringLiteralType::new(db, (*s).into()))
}
Ty::BoolLiteral(b) => Type::BooleanLiteral(b),
Ty::LiteralString => Type::LiteralString,
Ty::BytesLiteral(s) => {
Type::BytesLiteral(BytesLiteralType::new(db, s.as_bytes().into()))
Expand Down Expand Up @@ -1331,4 +1375,28 @@ mod tests {
let db = setup_db();
assert_eq!(ty.into_type(&db).bool(&db), Truthiness::Ambiguous);
}

// Each case pairs an input type with the type `Type::str` must infer for it.
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("ab'cd"))] // no quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_str(ty: Ty, expected: Ty) {
    let db = setup_db();

    let inferred = ty.into_type(&db).str(&db);
    let wanted = expected.into_type(&db);

    assert_eq!(inferred, wanted);
}

// Each case pairs an input type with the type `Type::repr` must infer for it.
#[test_case(Ty::IntLiteral(1), Ty::StringLiteral("1"))]
#[test_case(Ty::BoolLiteral(true), Ty::StringLiteral("True"))]
#[test_case(Ty::BoolLiteral(false), Ty::StringLiteral("False"))]
#[test_case(Ty::StringLiteral("ab'cd"), Ty::StringLiteral("'ab\\'cd'"))] // single quotes
#[test_case(Ty::LiteralString, Ty::LiteralString)]
#[test_case(Ty::BuiltinInstance("int"), Ty::BuiltinInstance("str"))]
fn has_correct_repr(ty: Ty, expected: Ty) {
    let db = setup_db();

    let inferred = ty.into_type(&db).repr(&db);
    let wanted = expected.into_type(&db);

    assert_eq!(inferred, wanted);
}
}
184 changes: 148 additions & 36 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1653,50 +1653,50 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_fstring_expression(&mut self, fstring: &ast::ExprFString) -> Type<'db> {
let ast::ExprFString { range: _, value } = fstring;

let mut collector = StringPartsCollector::new();
for part in value {
// Make sure we iterate through every part to infer all sub-expressions. The `collector`
// struct ensures we don't allocate unnecessary strings.
match part {
ast::FStringPart::Literal(_) => {
// TODO string literal type
ast::FStringPart::Literal(literal) => {
collector.push_str(&literal.value);
}
ast::FStringPart::FString(fstring) => {
let ast::FString {
range: _,
elements,
flags: _,
} = fstring;
for element in elements {
self.infer_fstring_element(element);
}
}
}
}

// TODO str type
Type::Unknown
}

fn infer_fstring_element(&mut self, element: &ast::FStringElement) {
match element {
ast::FStringElement::Literal(_) => {
// TODO string literal type
}
ast::FStringElement::Expression(expr_element) => {
let ast::FStringExpressionElement {
range: _,
expression,
debug_text: _,
conversion: _,
format_spec,
} = expr_element;
self.infer_expression(expression);

if let Some(format_spec) = format_spec {
for spec_element in &format_spec.elements {
self.infer_fstring_element(spec_element);
for element in &fstring.elements {
match element {
ast::FStringElement::Expression(expression) => {
let ast::FStringExpressionElement {
range: _,
expression,
debug_text: _,
conversion,
format_spec,
} = expression;
let ty = self.infer_expression(expression);

// TODO: handle format specifiers by calling a method
// (`Type::format`?) that handles the `__format__` method.
// Conversion flags should be handled before calling `__format__`.
// https://docs.python.org/3/library/string.html#format-string-syntax
if !conversion.is_none() || format_spec.is_some() {
collector.add_expression();
} else {
if let Type::StringLiteral(literal) = ty.str(self.db) {
collector.push_str(literal.value(self.db));
} else {
collector.add_expression();
}
}
}
ast::FStringElement::Literal(literal) => {
collector.push_str(&literal.value);
}
}
}
}
}
}
collector.ty(self.db)
}

fn infer_ellipsis_literal_expression(
Expand Down Expand Up @@ -2659,6 +2659,53 @@ enum ModuleNameResolutionError {
TooManyDots,
}

/// Accumulates the pieces of a formatted string while it is being inferred.
///
/// The resulting type is:
/// - a string literal, when every piece was known and the concatenation stayed within
///   `TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE`;
/// - `LiteralString`, when every piece was known but the concatenation grew too large;
/// - an instance of `builtins.str`, as soon as one expression had a representation that is
///   unknown at compile time.
struct StringPartsCollector {
    /// Text gathered so far; `None` once the exact value is no longer being tracked.
    concatenated: Option<String>,
    /// Set when at least one expression with an unknown representation was seen.
    expression: bool,
}

impl StringPartsCollector {
    /// Start with an empty, fully known string and no opaque expression.
    fn new() -> Self {
        Self {
            concatenated: Some(String::new()),
            expression: false,
        }
    }

    /// Append a statically known piece of text.
    ///
    /// The piece is only stored while the exact value is still tracked and the combined
    /// length stays within the size limit; otherwise the tracked text is dropped for good.
    fn push_str(&mut self, literal: &str) {
        if let Some(buffer) = self.concatenated.as_mut() {
            let combined_len = buffer.len().saturating_add(literal.len());
            if combined_len <= TypeInferenceBuilder::MAX_STRING_LITERAL_SIZE {
                buffer.push_str(literal);
                return;
            }
        }
        // Either nothing was tracked any more or the limit was exceeded.
        self.concatenated = None;
    }

    /// Record an expression whose string representation is not known statically.
    fn add_expression(&mut self) {
        self.expression = true;
        self.concatenated = None;
    }

    /// Consume the collector and produce the inferred type of the whole formatted string.
    fn ty(self, db: &dyn Db) -> Type {
        match (self.expression, self.concatenated) {
            (true, _) => Type::builtin_str(db).to_instance(db),
            (false, Some(text)) => {
                Type::StringLiteral(StringLiteralType::new(db, text.into_boxed_str()))
            }
            (false, None) => Type::LiteralString,
        }
    }
}

#[cfg(test)]
mod tests {

Expand Down Expand Up @@ -3593,6 +3640,71 @@ mod tests {
Ok(())
}

// Covers f-strings made of plain text, literal-valued expressions, implicit concatenation
// with ordinary strings, and expressions whose value is not known statically.
#[test]
fn fstring_expression() -> anyhow::Result<()> {
    let mut db = setup_db();

    db.write_dedented(
        "src/a.py",
        "
x = 0
y = str()
z = False
a = f'hello'
b = f'h {x}'
c = 'one ' f'single ' f'literal'
d = 'first ' f'second({b})' f' third'
e = f'-{y}-'
f = f'-{y}-' f'--' '--'
g = f'{z} == {False} is {True}'
",
    )?;

    // (symbol name, expected display of its public type)
    let expectations = [
        ("a", "Literal[\"hello\"]"),
        ("b", "Literal[\"h 0\"]"),
        ("c", "Literal[\"one single literal\"]"),
        ("d", "Literal[\"first second(h 0) third\"]"),
        ("e", "str"),
        ("f", "str"),
        ("g", "Literal[\"False == False is True\"]"),
    ];
    for (symbol, expected) in expectations {
        assert_public_ty(&db, "src/a.py", symbol, expected);
    }

    Ok(())
}

// Pins the current result for a replacement field that carries a conversion flag (`!r`):
// the whole f-string is expected to be inferred as `str`.
#[test]
fn fstring_expression_with_conversion_flags() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
string = 'hello'
a = f'{string!r}'
",
)?;

// The trailing comment records the more precise type a future implementation could infer.
assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["'hello'"]`

Ok(())
}

// Pins the current result for a replacement field that carries a format specifier (`:02`):
// the whole f-string is expected to be inferred as `str`.
#[test]
fn fstring_expression_with_format_specifier() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
a = f'{1:02}'
",
)?;

// The trailing comment records the more precise type a future implementation could infer.
assert_public_ty(&db, "src/a.py", "a", "str"); // Should be `Literal["01"]`

Ok(())
}

#[test]
fn basic_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 1639488

Please sign in to comment.