From a714572fc29286d83733b9f53f188e0f2a8196a9 Mon Sep 17 00:00:00 2001 From: Charlie Marsh Date: Wed, 16 Aug 2023 21:08:55 -0400 Subject: [PATCH] Introduce ExpressionRef --- crates/ruff_python_ast/src/expression.rs | 291 ++++++++++++++++++ crates/ruff_python_ast/src/lib.rs | 2 + .../src/expression/expr_bin_op.rs | 13 +- .../src/expression/expr_bool_op.rs | 12 +- .../src/expression/mod.rs | 20 +- .../src/expression/parentheses.rs | 8 +- 6 files changed, 317 insertions(+), 29 deletions(-) create mode 100644 crates/ruff_python_ast/src/expression.rs diff --git a/crates/ruff_python_ast/src/expression.rs b/crates/ruff_python_ast/src/expression.rs new file mode 100644 index 0000000000000..c76b093faba4d --- /dev/null +++ b/crates/ruff_python_ast/src/expression.rs @@ -0,0 +1,291 @@ +use ruff_text_size::TextRange; + +use crate::node::AnyNodeRef; +use crate::{self as ast, Expr, Ranged}; + +/// Unowned pendant to [`ast::Expr`] that stores a reference instead of a owned value. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum ExpressionRef<'a> { + BoolOp(&'a ast::ExprBoolOp), + NamedExpr(&'a ast::ExprNamedExpr), + BinOp(&'a ast::ExprBinOp), + UnaryOp(&'a ast::ExprUnaryOp), + Lambda(&'a ast::ExprLambda), + IfExp(&'a ast::ExprIfExp), + Dict(&'a ast::ExprDict), + Set(&'a ast::ExprSet), + ListComp(&'a ast::ExprListComp), + SetComp(&'a ast::ExprSetComp), + DictComp(&'a ast::ExprDictComp), + GeneratorExp(&'a ast::ExprGeneratorExp), + Await(&'a ast::ExprAwait), + Yield(&'a ast::ExprYield), + YieldFrom(&'a ast::ExprYieldFrom), + Compare(&'a ast::ExprCompare), + Call(&'a ast::ExprCall), + FormattedValue(&'a ast::ExprFormattedValue), + FString(&'a ast::ExprFString), + Constant(&'a ast::ExprConstant), + Attribute(&'a ast::ExprAttribute), + Subscript(&'a ast::ExprSubscript), + Starred(&'a ast::ExprStarred), + Name(&'a ast::ExprName), + List(&'a ast::ExprList), + Tuple(&'a ast::ExprTuple), + Slice(&'a ast::ExprSlice), + IpyEscapeCommand(&'a ast::ExprIpyEscapeCommand), +} + +impl<'a> From<&'a Box> for ExpressionRef<'a> { + fn from(value: &'a Box) -> Self { + ExpressionRef::from(value.as_ref()) + } +} + +impl<'a> From<&'a Expr> for ExpressionRef<'a> { + fn from(value: &'a Expr) -> Self { + match value { + Expr::BoolOp(value) => ExpressionRef::BoolOp(value), + Expr::NamedExpr(value) => ExpressionRef::NamedExpr(value), + Expr::BinOp(value) => ExpressionRef::BinOp(value), + Expr::UnaryOp(value) => ExpressionRef::UnaryOp(value), + Expr::Lambda(value) => ExpressionRef::Lambda(value), + Expr::IfExp(value) => ExpressionRef::IfExp(value), + Expr::Dict(value) => ExpressionRef::Dict(value), + Expr::Set(value) => ExpressionRef::Set(value), + Expr::ListComp(value) => ExpressionRef::ListComp(value), + Expr::SetComp(value) => ExpressionRef::SetComp(value), + Expr::DictComp(value) => ExpressionRef::DictComp(value), + Expr::GeneratorExp(value) => ExpressionRef::GeneratorExp(value), + Expr::Await(value) => ExpressionRef::Await(value), + Expr::Yield(value) => ExpressionRef::Yield(value), + Expr::YieldFrom(value) => ExpressionRef::YieldFrom(value), + Expr::Compare(value) => ExpressionRef::Compare(value), + Expr::Call(value) => ExpressionRef::Call(value), + Expr::FormattedValue(value) => ExpressionRef::FormattedValue(value), + Expr::FString(value) => ExpressionRef::FString(value), + Expr::Constant(value) => ExpressionRef::Constant(value), + Expr::Attribute(value) => ExpressionRef::Attribute(value), + Expr::Subscript(value) => ExpressionRef::Subscript(value), + Expr::Starred(value) => ExpressionRef::Starred(value), + Expr::Name(value) => ExpressionRef::Name(value), + Expr::List(value) => ExpressionRef::List(value), + Expr::Tuple(value) => ExpressionRef::Tuple(value), + Expr::Slice(value) => ExpressionRef::Slice(value), + Expr::IpyEscapeCommand(value) => ExpressionRef::IpyEscapeCommand(value), + } + } +} + +impl<'a> From<&'a ast::ExprBoolOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprBoolOp) -> Self { + Self::BoolOp(value) + } +} +impl<'a> From<&'a ast::ExprNamedExpr> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprNamedExpr) -> Self { + Self::NamedExpr(value) + } +} +impl<'a> From<&'a ast::ExprBinOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprBinOp) -> Self { + Self::BinOp(value) + } +} +impl<'a> From<&'a ast::ExprUnaryOp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprUnaryOp) -> Self { + Self::UnaryOp(value) + } +} +impl<'a> From<&'a ast::ExprLambda> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprLambda) -> Self { + Self::Lambda(value) + } +} +impl<'a> From<&'a ast::ExprIfExp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprIfExp) -> Self { + Self::IfExp(value) + } +} +impl<'a> From<&'a ast::ExprDict> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprDict) -> Self { + Self::Dict(value) + } +} +impl<'a> From<&'a ast::ExprSet> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSet) -> Self { + Self::Set(value) + } +} +impl<'a> From<&'a ast::ExprListComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprListComp) -> Self { + Self::ListComp(value) + } +} +impl<'a> From<&'a ast::ExprSetComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSetComp) -> Self { + Self::SetComp(value) + } +} +impl<'a> From<&'a ast::ExprDictComp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprDictComp) -> Self { + Self::DictComp(value) + } +} +impl<'a> From<&'a ast::ExprGeneratorExp> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprGeneratorExp) -> Self { + Self::GeneratorExp(value) + } +} +impl<'a> From<&'a ast::ExprAwait> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprAwait) -> Self { + Self::Await(value) + } +} +impl<'a> From<&'a ast::ExprYield> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprYield) -> Self { + Self::Yield(value) + } +} +impl<'a> From<&'a ast::ExprYieldFrom> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprYieldFrom) -> Self { + Self::YieldFrom(value) + } +} +impl<'a> From<&'a ast::ExprCompare> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprCompare) -> Self { + Self::Compare(value) + } +} +impl<'a> From<&'a ast::ExprCall> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprCall) -> Self { + Self::Call(value) + } +} +impl<'a> From<&'a ast::ExprFormattedValue> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprFormattedValue) -> Self { + Self::FormattedValue(value) + } +} +impl<'a> From<&'a ast::ExprFString> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprFString) -> Self { + Self::FString(value) + } +} +impl<'a> From<&'a ast::ExprConstant> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprConstant) -> Self { + Self::Constant(value) + } +} +impl<'a> From<&'a ast::ExprAttribute> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprAttribute) -> Self { + Self::Attribute(value) + } +} +impl<'a> From<&'a ast::ExprSubscript> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSubscript) -> Self { + Self::Subscript(value) + } +} +impl<'a> From<&'a ast::ExprStarred> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprStarred) -> Self { + Self::Starred(value) + } +} +impl<'a> From<&'a ast::ExprName> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprName) -> Self { + Self::Name(value) + } +} +impl<'a> From<&'a ast::ExprList> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprList) -> Self { + Self::List(value) + } +} +impl<'a> From<&'a ast::ExprTuple> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprTuple) -> Self { + Self::Tuple(value) + } +} +impl<'a> From<&'a ast::ExprSlice> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprSlice) -> Self { + Self::Slice(value) + } +} +impl<'a> From<&'a ast::ExprIpyEscapeCommand> for ExpressionRef<'a> { + fn from(value: &'a ast::ExprIpyEscapeCommand) -> Self { + Self::IpyEscapeCommand(value) + } +} + +impl<'a> From> for AnyNodeRef<'a> { + fn from(value: ExpressionRef<'a>) -> Self { + match value { + ExpressionRef::BoolOp(expression) => AnyNodeRef::ExprBoolOp(expression), + ExpressionRef::NamedExpr(expression) => AnyNodeRef::ExprNamedExpr(expression), + ExpressionRef::BinOp(expression) => AnyNodeRef::ExprBinOp(expression), + ExpressionRef::UnaryOp(expression) => AnyNodeRef::ExprUnaryOp(expression), + ExpressionRef::Lambda(expression) => AnyNodeRef::ExprLambda(expression), + ExpressionRef::IfExp(expression) => AnyNodeRef::ExprIfExp(expression), + ExpressionRef::Dict(expression) => AnyNodeRef::ExprDict(expression), + ExpressionRef::Set(expression) => AnyNodeRef::ExprSet(expression), + ExpressionRef::ListComp(expression) => AnyNodeRef::ExprListComp(expression), + ExpressionRef::SetComp(expression) => AnyNodeRef::ExprSetComp(expression), + ExpressionRef::DictComp(expression) => AnyNodeRef::ExprDictComp(expression), + ExpressionRef::GeneratorExp(expression) => AnyNodeRef::ExprGeneratorExp(expression), + ExpressionRef::Await(expression) => AnyNodeRef::ExprAwait(expression), + ExpressionRef::Yield(expression) => AnyNodeRef::ExprYield(expression), + ExpressionRef::YieldFrom(expression) => AnyNodeRef::ExprYieldFrom(expression), + ExpressionRef::Compare(expression) => AnyNodeRef::ExprCompare(expression), + ExpressionRef::Call(expression) => AnyNodeRef::ExprCall(expression), + ExpressionRef::FormattedValue(expression) => AnyNodeRef::ExprFormattedValue(expression), + ExpressionRef::FString(expression) => AnyNodeRef::ExprFString(expression), + ExpressionRef::Constant(expression) => AnyNodeRef::ExprConstant(expression), + ExpressionRef::Attribute(expression) => AnyNodeRef::ExprAttribute(expression), + ExpressionRef::Subscript(expression) => AnyNodeRef::ExprSubscript(expression), + ExpressionRef::Starred(expression) => AnyNodeRef::ExprStarred(expression), + ExpressionRef::Name(expression) => AnyNodeRef::ExprName(expression), + ExpressionRef::List(expression) => AnyNodeRef::ExprList(expression), + ExpressionRef::Tuple(expression) => AnyNodeRef::ExprTuple(expression), + ExpressionRef::Slice(expression) => AnyNodeRef::ExprSlice(expression), + ExpressionRef::IpyEscapeCommand(expression) => { + AnyNodeRef::ExprIpyEscapeCommand(expression) + } + } + } +} + +impl Ranged for ExpressionRef<'_> { + fn range(&self) -> TextRange { + match self { + ExpressionRef::BoolOp(expression) => expression.range(), + ExpressionRef::NamedExpr(expression) => expression.range(), + ExpressionRef::BinOp(expression) => expression.range(), + ExpressionRef::UnaryOp(expression) => expression.range(), + ExpressionRef::Lambda(expression) => expression.range(), + ExpressionRef::IfExp(expression) => expression.range(), + ExpressionRef::Dict(expression) => expression.range(), + ExpressionRef::Set(expression) => expression.range(), + ExpressionRef::ListComp(expression) => expression.range(), + ExpressionRef::SetComp(expression) => expression.range(), + ExpressionRef::DictComp(expression) => expression.range(), + ExpressionRef::GeneratorExp(expression) => expression.range(), + ExpressionRef::Await(expression) => expression.range(), + ExpressionRef::Yield(expression) => expression.range(), + ExpressionRef::YieldFrom(expression) => expression.range(), + ExpressionRef::Compare(expression) => expression.range(), + ExpressionRef::Call(expression) => expression.range(), + ExpressionRef::FormattedValue(expression) => expression.range(), + ExpressionRef::FString(expression) => expression.range(), + ExpressionRef::Constant(expression) => expression.range(), + ExpressionRef::Attribute(expression) => expression.range(), + ExpressionRef::Subscript(expression) => expression.range(), + ExpressionRef::Starred(expression) => expression.range(), + ExpressionRef::Name(expression) => expression.range(), + ExpressionRef::List(expression) => expression.range(), + ExpressionRef::Tuple(expression) => expression.range(), + ExpressionRef::Slice(expression) => expression.range(), + ExpressionRef::IpyEscapeCommand(expression) => expression.range(), + } + } +} diff --git a/crates/ruff_python_ast/src/lib.rs b/crates/ruff_python_ast/src/lib.rs index 3fb4c5f170059..ac615c12803aa 100644 --- a/crates/ruff_python_ast/src/lib.rs +++ b/crates/ruff_python_ast/src/lib.rs @@ -5,6 +5,7 @@ pub mod all; pub mod call_path; pub mod comparable; pub mod docstrings; +mod expression; pub mod hashable; pub mod helpers; pub mod identifier; @@ -20,6 +21,7 @@ pub mod types; pub mod visitor; pub mod whitespace; +pub use expression::*; pub use nodes::*; pub trait Ranged { diff --git a/crates/ruff_python_formatter/src/expression/expr_bin_op.rs b/crates/ruff_python_formatter/src/expression/expr_bin_op.rs index 26071e293d774..c8e7cd51daf27 100644 --- a/crates/ruff_python_formatter/src/expression/expr_bin_op.rs +++ b/crates/ruff_python_formatter/src/expression/expr_bin_op.rs @@ -1,13 +1,13 @@ use std::iter; +use smallvec::SmallVec; + +use ruff_formatter::{format_args, write, FormatOwnedWithRule, FormatRefWithRule}; +use ruff_python_ast::node::AnyNodeRef; use ruff_python_ast::{ Constant, Expr, ExprAttribute, ExprBinOp, ExprConstant, ExprUnaryOp, Operator, StringConstant, UnaryOp, }; -use smallvec::SmallVec; - -use ruff_formatter::{format_args, write, FormatOwnedWithRule, FormatRefWithRule}; -use ruff_python_ast::node::{AnyNodeRef, AstNode}; use crate::comments::{trailing_comments, trailing_node_comments, SourceComment}; use crate::expression::expr_constant::ExprConstantLayout; @@ -73,10 +73,7 @@ impl FormatNodeRule for FormatExprBinOp { let binary_chain: SmallVec<[&ExprBinOp; 4]> = iter::successors(Some(item), |parent| { parent.left.as_bin_op_expr().and_then(|bin_expression| { - if is_expression_parenthesized( - bin_expression.as_any_node_ref(), - source, - ) { + if is_expression_parenthesized(bin_expression.into(), source) { None } else { Some(bin_expression) diff --git a/crates/ruff_python_formatter/src/expression/expr_bool_op.rs b/crates/ruff_python_formatter/src/expression/expr_bool_op.rs index 9449c6662ada2..8c818ffdfc047 100644 --- a/crates/ruff_python_formatter/src/expression/expr_bool_op.rs +++ b/crates/ruff_python_formatter/src/expression/expr_bool_op.rs @@ -1,12 +1,13 @@ +use ruff_formatter::{write, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions}; +use ruff_python_ast::node::AnyNodeRef; +use ruff_python_ast::{BoolOp, Expr, ExprBoolOp}; + use crate::comments::leading_comments; use crate::expression::parentheses::{ in_parentheses_only_group, in_parentheses_only_soft_line_break_or_space, NeedsParentheses, OptionalParentheses, }; use crate::prelude::*; -use ruff_formatter::{write, FormatOwnedWithRule, FormatRefWithRule, FormatRuleWithOptions}; -use ruff_python_ast::node::{AnyNodeRef, AstNode}; -use ruff_python_ast::{BoolOp, Expr, ExprBoolOp}; use super::parentheses::is_expression_parenthesized; @@ -95,10 +96,7 @@ impl Format> for FormatValue<'_> { fn fmt(&self, f: &mut PyFormatter) -> FormatResult<()> { match self.value { Expr::BoolOp(bool_op) - if !is_expression_parenthesized( - bool_op.as_any_node_ref(), - f.context().source(), - ) => + if !is_expression_parenthesized(bool_op.into(), f.context().source()) => { // Mark chained boolean operations e.g. `x and y or z` and avoid creating a new group write!(f, [bool_op.format().with_options(BoolOpLayout::Chained)]) diff --git a/crates/ruff_python_formatter/src/expression/mod.rs b/crates/ruff_python_formatter/src/expression/mod.rs index a617c91328776..0991be5fc2357 100644 --- a/crates/ruff_python_formatter/src/expression/mod.rs +++ b/crates/ruff_python_formatter/src/expression/mod.rs @@ -6,7 +6,7 @@ use ruff_formatter::{ use ruff_python_ast as ast; use ruff_python_ast::node::AnyNodeRef; use ruff_python_ast::visitor::preorder::{walk_expr, PreorderVisitor}; -use ruff_python_ast::{Expr, Operator}; +use ruff_python_ast::{Expr, ExpressionRef, Operator}; use crate::builders::parenthesize_if_expands; use crate::context::{NodeLevel, WithNodeLevel}; @@ -472,7 +472,7 @@ impl<'input> PreorderVisitor<'input> for CanOmitOptionalParenthesesVisitor<'inpu self.last = Some(expr); // Rule only applies for non-parenthesized expressions. - if is_expression_parenthesized(AnyNodeRef::from(expr), self.context.source()) { + if is_expression_parenthesized(expr.into(), self.context.source()) { self.any_parenthesized_expressions = true; } else { self.visit_subexpression(expr); @@ -526,12 +526,12 @@ pub enum CallChainLayout { } impl CallChainLayout { - pub(crate) fn from_expression(mut expr: AnyNodeRef, source: &str) -> Self { + pub(crate) fn from_expression(mut expr: ExpressionRef, source: &str) -> Self { let mut attributes_after_parentheses = 0; loop { match expr { - AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, .. }) => { - expr = AnyNodeRef::from(value.as_ref()); + ExpressionRef::Attribute(ast::ExprAttribute { value, .. }) => { + expr = ExpressionRef::from(value.as_ref()); // ``` // f().g // ^^^ value @@ -554,9 +554,9 @@ impl CallChainLayout { // ^^^^^^^^^^ expr // ^^^^ value // ``` - AnyNodeRef::ExprCall(ast::ExprCall { func: inner, .. }) - | AnyNodeRef::ExprSubscript(ast::ExprSubscript { value: inner, .. }) => { - expr = AnyNodeRef::from(inner.as_ref()); + ExpressionRef::Call(ast::ExprCall { func: inner, .. }) + | ExpressionRef::Subscript(ast::ExprSubscript { value: inner, .. }) => { + expr = ExpressionRef::from(inner.as_ref()); } _ => { // We to format the following in fluent style: @@ -586,7 +586,7 @@ impl CallChainLayout { /// formatting pub(crate) fn apply_in_node<'a>( self, - item: impl Into>, + item: impl Into>, f: &mut PyFormatter, ) -> CallChainLayout { match self { @@ -627,7 +627,7 @@ fn has_parentheses(expr: &Expr, context: &PyFormatContext) -> Option bool { +pub(crate) fn is_expression_parenthesized(expr: ExpressionRef, contents: &str) -> bool { // First test if there's a closing parentheses because it tends to be cheaper. if matches!( first_non_trivia_token(expr.end(), contents), @@ -378,7 +378,7 @@ impl Format> for FormatEmptyParenthesized<'_> { #[cfg(test)] mod tests { - use ruff_python_ast::node::AnyNodeRef; + use ruff_python_ast::ExpressionRef; use ruff_python_parser::parse_expression; use crate::expression::parentheses::is_expression_parenthesized; @@ -388,7 +388,7 @@ mod tests { let expression = r#"(b().c("")).d()"#; let expr = parse_expression(expression, "").unwrap(); assert!(!is_expression_parenthesized( - AnyNodeRef::from(&expr), + ExpressionRef::from(&expr), expression )); }