Skip to content

Commit

Permalink
refactor: remove unneed mut for session context (#11864)
Browse files Browse the repository at this point in the history
* doc: remove mut from session context docstring

* refactor: remove unnecessary mut for session context

* refactor: remove more unused mut
  • Loading branch information
sunng87 authored Aug 8, 2024
1 parent 60d1d3a commit d0a1d30
Show file tree
Hide file tree
Showing 14 changed files with 65 additions and 62 deletions.
6 changes: 2 additions & 4 deletions datafusion-cli/examples/cli-session-context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl CliSessionContext for MyUnionerContext {
#[tokio::main]
/// Runs the example.
pub async fn main() {
let mut my_ctx = MyUnionerContext::default();
let my_ctx = MyUnionerContext::default();

let mut print_options = PrintOptions {
format: datafusion_cli::print_format::PrintFormat::Automatic,
Expand All @@ -91,7 +91,5 @@ pub async fn main() {
color: true,
};

exec_from_repl(&mut my_ctx, &mut print_options)
.await
.unwrap();
exec_from_repl(&my_ctx, &mut print_options).await.unwrap();
}
2 changes: 1 addition & 1 deletion datafusion-cli/src/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ mod tests {
use datafusion::prelude::SessionContext;

fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
let mut ctx = SessionContext::new();
let ctx = SessionContext::new();
ctx.register_catalog_list(Arc::new(DynamicFileCatalog::new(
ctx.state().catalog_list().clone(),
ctx.state_weak_ref(),
Expand Down
2 changes: 1 addition & 1 deletion datafusion-cli/src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub enum OutputFormat {
impl Command {
pub async fn execute(
&self,
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
print_options: &mut PrintOptions,
) -> Result<()> {
match self {
Expand Down
16 changes: 8 additions & 8 deletions datafusion-cli/src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use tokio::signal;

/// run and execute SQL statements and commands, against a context with the given print options
pub async fn exec_from_commands(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
commands: Vec<String>,
print_options: &PrintOptions,
) -> Result<()> {
Expand All @@ -62,7 +62,7 @@ pub async fn exec_from_commands(

/// run and execute SQL statements and commands from a file, against a context with the given print options
pub async fn exec_from_lines(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
reader: &mut BufReader<File>,
print_options: &PrintOptions,
) -> Result<()> {
Expand Down Expand Up @@ -102,7 +102,7 @@ pub async fn exec_from_lines(
}

pub async fn exec_from_files(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
files: Vec<String>,
print_options: &PrintOptions,
) -> Result<()> {
Expand All @@ -121,7 +121,7 @@ pub async fn exec_from_files(

/// run and execute SQL statements and commands against a context with the given print options
pub async fn exec_from_repl(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
print_options: &mut PrintOptions,
) -> rustyline::Result<()> {
let mut rl = Editor::new()?;
Expand Down Expand Up @@ -204,7 +204,7 @@ pub async fn exec_from_repl(
}

pub(super) async fn exec_and_print(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
print_options: &PrintOptions,
sql: String,
) -> Result<()> {
Expand Down Expand Up @@ -300,7 +300,7 @@ fn config_file_type_from_str(ext: &str) -> Option<ConfigFileType> {
}

async fn create_plan(
ctx: &mut dyn CliSessionContext,
ctx: &dyn CliSessionContext,
statement: Statement,
) -> Result<LogicalPlan, DataFusionError> {
let mut plan = ctx.session_state().statement_to_plan(statement).await?;
Expand Down Expand Up @@ -473,7 +473,7 @@ mod tests {
"cos://bucket/path/file.parquet",
"gcs://bucket/path/file.parquet",
];
let mut ctx = SessionContext::new();
let ctx = SessionContext::new();
let task_ctx = ctx.task_ctx();
let dialect = &task_ctx.session_config().options().sql_parser.dialect;
let dialect = dialect_from_str(dialect).ok_or_else(|| {
Expand All @@ -488,7 +488,7 @@ mod tests {
let statements = DFParser::parse_sql_with_dialect(&sql, dialect.as_ref())?;
for statement in statements {
//Should not fail
let mut plan = create_plan(&mut ctx, statement).await?;
let mut plan = create_plan(&ctx, statement).await?;
if let LogicalPlan::Copy(copy_to) = &mut plan {
assert_eq!(copy_to.output_url, location);
assert_eq!(copy_to.file_type.get_ext(), "parquet".to_string());
Expand Down
10 changes: 5 additions & 5 deletions datafusion-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ async fn main_inner() -> Result<()> {

let runtime_env = create_runtime_env(rt_config.clone())?;

let mut ctx =
let ctx =
SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env));
ctx.refresh_catalogs().await?;
// install dynamic catalog provider that knows how to open files
Expand Down Expand Up @@ -212,20 +212,20 @@ async fn main_inner() -> Result<()> {

if commands.is_empty() && files.is_empty() {
if !rc.is_empty() {
exec::exec_from_files(&mut ctx, rc, &print_options).await?;
exec::exec_from_files(&ctx, rc, &print_options).await?;
}
// TODO maybe we can have thiserror for cli but for now let's keep it simple
return exec::exec_from_repl(&mut ctx, &mut print_options)
return exec::exec_from_repl(&ctx, &mut print_options)
.await
.map_err(|e| DataFusionError::External(Box::new(e)));
}

if !files.is_empty() {
exec::exec_from_files(&mut ctx, files, &print_options).await?;
exec::exec_from_files(&ctx, files, &print_options).await?;
}

if !commands.is_empty() {
exec::exec_from_commands(&mut ctx, commands, &print_options).await?;
exec::exec_from_commands(&ctx, commands, &print_options).await?;
}

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ async fn main() -> Result<()> {
let dir_a = prepare_example_data()?;
let dir_b = prepare_example_data()?;

let mut ctx = SessionContext::new();
let ctx = SessionContext::new();
let state = ctx.state();
let catlist = Arc::new(CustomCatalogProviderList::new());

Expand Down
14 changes: 7 additions & 7 deletions datafusion/core/benches/filter_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use futures::executor::block_on;
use std::sync::Arc;
use tokio::runtime::Runtime;

async fn query(ctx: &mut SessionContext, sql: &str) {
async fn query(ctx: &SessionContext, sql: &str) {
let rt = Runtime::new().unwrap();

// execute the query
Expand Down Expand Up @@ -70,25 +70,25 @@ fn criterion_benchmark(c: &mut Criterion) {
let batch_size = 4096; // 2^12

c.bench_function("filter_array", |b| {
let mut ctx = create_context(array_len, batch_size).unwrap();
b.iter(|| block_on(query(&mut ctx, "select f32, f64 from t where f32 >= f64")))
let ctx = create_context(array_len, batch_size).unwrap();
b.iter(|| block_on(query(&ctx, "select f32, f64 from t where f32 >= f64")))
});

c.bench_function("filter_scalar", |b| {
let mut ctx = create_context(array_len, batch_size).unwrap();
let ctx = create_context(array_len, batch_size).unwrap();
b.iter(|| {
block_on(query(
&mut ctx,
&ctx,
"select f32, f64 from t where f32 >= 250 and f64 > 250",
))
})
});

c.bench_function("filter_scalar in list", |b| {
let mut ctx = create_context(array_len, batch_size).unwrap();
let ctx = create_context(array_len, batch_size).unwrap();
b.iter(|| {
block_on(query(
&mut ctx,
&ctx,
"select f32, f64 from t where f32 in (10, 20, 30, 40)",
))
})
Expand Down
14 changes: 7 additions & 7 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1550,7 +1550,7 @@ impl DataFrame {
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// # use datafusion_common::ScalarValue;
/// let mut ctx = SessionContext::new();
/// let ctx = SessionContext::new();
/// # ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
/// let results = ctx
/// .sql("SELECT a FROM example WHERE b = $1")
Expand Down Expand Up @@ -2649,8 +2649,8 @@ mod tests {

#[tokio::test]
async fn registry() -> Result<()> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
let ctx = SessionContext::new();
register_aggregate_csv(&ctx, "aggregate_test_100").await?;

// declare the udf
let my_fn: ScalarFunctionImplementation =
Expand Down Expand Up @@ -2783,8 +2783,8 @@ mod tests {

/// Create a logical plan from a SQL query
async fn create_plan(sql: &str) -> Result<LogicalPlan> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
let ctx = SessionContext::new();
register_aggregate_csv(&ctx, "aggregate_test_100").await?;
Ok(ctx.sql(sql).await?.into_unoptimized_plan())
}

Expand Down Expand Up @@ -3147,9 +3147,9 @@ mod tests {
"datafusion.sql_parser.enable_ident_normalization".to_owned(),
"false".to_owned(),
)]))?;
let mut ctx = SessionContext::new_with_config(config);
let ctx = SessionContext::new_with_config(config);
let name = "aggregate_test_100";
register_aggregate_csv(&mut ctx, name).await?;
register_aggregate_csv(&ctx, name).await?;
let df = ctx.table(name);

let df = df
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/src/dataframe/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,14 @@ mod tests {
async fn write_parquet_with_small_rg_size() -> Result<()> {
// This test verifies writing a parquet file with small rg size
// relative to datafusion.execution.batch_size does not panic
let mut ctx = SessionContext::new_with_config(
SessionConfig::from_string_hash_map(HashMap::from_iter(
let ctx = SessionContext::new_with_config(SessionConfig::from_string_hash_map(
HashMap::from_iter(
[("datafusion.execution.batch_size", "10")]
.iter()
.map(|(s1, s2)| (s1.to_string(), s2.to_string())),
))?,
);
register_aggregate_csv(&mut ctx, "aggregate_test_100").await?;
),
)?);
register_aggregate_csv(&ctx, "aggregate_test_100").await?;
let test_df = ctx.table("aggregate_test_100").await?;

let output_path = "file://local/test.parquet";
Expand Down
13 changes: 8 additions & 5 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ where
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let mut ctx = SessionContext::new();
/// let ctx = SessionContext::new();
/// ctx.register_csv("example", "tests/data/example.csv", CsvReadOptions::new()).await?;
/// let results = ctx
/// .sql("SELECT a, min(b) FROM example GROUP BY a LIMIT 100")
Expand Down Expand Up @@ -369,7 +369,7 @@ impl SessionContext {
/// # use datafusion_execution::object_store::ObjectStoreUrl;
/// let object_store_url = ObjectStoreUrl::parse("file://").unwrap();
/// let object_store = object_store::local::LocalFileSystem::new();
/// let mut ctx = SessionContext::new();
/// let ctx = SessionContext::new();
/// // All files with the file:// url prefix will be read from the local file system
/// ctx.register_object_store(object_store_url.as_ref(), Arc::new(object_store));
/// ```
Expand Down Expand Up @@ -452,7 +452,7 @@ impl SessionContext {
/// # use datafusion::{error::Result, assert_batches_eq};
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let mut ctx = SessionContext::new();
/// let ctx = SessionContext::new();
/// ctx
/// .sql("CREATE TABLE foo (x INTEGER)")
/// .await?
Expand Down Expand Up @@ -480,7 +480,7 @@ impl SessionContext {
/// # use datafusion::physical_plan::collect;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
/// let mut ctx = SessionContext::new();
/// let ctx = SessionContext::new();
/// let options = SQLOptions::new()
/// .with_allow_ddl(false);
/// let err = ctx.sql_with_options("CREATE TABLE foo (x INTEGER)", options)
Expand Down Expand Up @@ -1357,7 +1357,7 @@ impl SessionContext {
}

/// Register [`CatalogProviderList`] in [`SessionState`]
pub fn register_catalog_list(&mut self, catalog_list: Arc<dyn CatalogProviderList>) {
pub fn register_catalog_list(&self, catalog_list: Arc<dyn CatalogProviderList>) {
self.state.write().register_catalog_list(catalog_list)
}

Expand Down Expand Up @@ -1386,15 +1386,18 @@ impl FunctionRegistry for SessionContext {
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
self.state.read().udwf(name)
}

fn register_udf(&mut self, udf: Arc<ScalarUDF>) -> Result<Option<Arc<ScalarUDF>>> {
self.state.write().register_udf(udf)
}

fn register_udaf(
&mut self,
udaf: Arc<AggregateUDF>,
) -> Result<Option<Arc<AggregateUDF>>> {
self.state.write().register_udaf(udaf)
}

fn register_udwf(&mut self, udwf: Arc<WindowUDF>) -> Result<Option<Arc<WindowUDF>>> {
self.state.write().register_udwf(udwf)
}
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/test_util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ pub fn aggr_test_schema() -> SchemaRef {

/// Register session context for the aggregate_test_100.csv file
pub async fn register_aggregate_csv(
ctx: &mut SessionContext,
ctx: &SessionContext,
table_name: &str,
) -> Result<()> {
let schema = aggr_test_schema();
Expand All @@ -128,8 +128,8 @@ pub async fn register_aggregate_csv(

/// Create a table from the aggregate_test_100.csv file with the specified name
pub async fn test_table_with_name(name: &str) -> Result<DataFrame> {
let mut ctx = SessionContext::new();
register_aggregate_csv(&mut ctx, name).await?;
let ctx = SessionContext::new();
register_aggregate_csv(&ctx, name).await?;
ctx.table(name).await
}

Expand Down
Loading

0 comments on commit d0a1d30

Please sign in to comment.