Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: insert into sqlite tables #2745

Merged
merged 9 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/datasources/src/sqlite/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use datafusion::arrow::array::{
StringBuilder,
Time64MicrosecondBuilder,
TimestampMicrosecondBuilder,
UInt64Builder,
};
use datafusion::arrow::datatypes::{DataType, SchemaRef, TimeUnit};
use datafusion::arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -41,6 +42,12 @@ impl Converter {
DataType::Boolean => {
Box::new(BooleanBuilder::with_capacity(RECORD_BATCH_CAPACITY))
}
DataType::UInt64 => {
// sqlite can't produce this type, but for
// consistency with other glaredb count
// interfaces, we may still be asked to build
// UInt64 arrays, so support it here.
Box::new(UInt64Builder::with_capacity(RECORD_BATCH_CAPACITY))
}
DataType::Int64 => Box::new(Int64Builder::with_capacity(RECORD_BATCH_CAPACITY)),
DataType::Float64 => {
Box::new(Float64Builder::with_capacity(RECORD_BATCH_CAPACITY))
Expand Down
176 changes: 174 additions & 2 deletions crates/datasources/src/sqlite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use datafusion::execution::TaskContext;
use datafusion::logical_expr::{BinaryExpr, TableProviderFilterPushDown, TableType};
use datafusion::physical_expr::PhysicalSortExpr;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
execute_stream,
DisplayAs,
DisplayFormatType,
ExecutionPlan,
Expand All @@ -27,13 +29,15 @@ use datafusion::physical_plan::{
Statistics,
};
use datafusion::prelude::Expr;
use datafusion::scalar::ScalarValue;
use datafusion_ext::errors::ExtensionError;
use datafusion_ext::functions::VirtualLister;
use datafusion_ext::metrics::DataSourceMetricsStreamAdapter;
use futures::{StreamExt, TryStreamExt};

use self::errors::Result;
use self::wrapper::SqliteAsyncClient;
use crate::common::util;
use crate::common::util::{self, COUNT_SCHEMA};

type DataFusionResult<T> = Result<T, DataFusionError>;

Expand Down Expand Up @@ -286,6 +290,26 @@ impl TableProvider for SqliteTableProvider {
metrics: ExecutionPlanMetricsSet::new(),
}))
}

tychoish marked this conversation as resolved.
Show resolved Hide resolved

/// Plan an INSERT of `input`'s rows into the underlying sqlite table.
///
/// Overwrite semantics (`INSERT OVERWRITE`) are not supported and
/// produce an execution error; otherwise this wraps the input plan in
/// a [`SqliteInsertExec`] sink node.
async fn insert_into(
    &self,
    _state: &SessionState,
    input: Arc<dyn ExecutionPlan>,
    overwrite: bool,
) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
    match overwrite {
        true => Err(DataFusionError::Execution("cannot overwrite".to_string())),
        false => Ok(Arc::new(SqliteInsertExec {
            input,
            table: self.table.to_string(),
            state: self.state.clone(),
            schema: self.schema.clone(),
            metrics: ExecutionPlanMetricsSet::new(),
        })),
    }
}
}

#[derive(Debug)]
Expand Down Expand Up @@ -421,7 +445,6 @@ fn write_expr(expr: &Expr, schema: &Schema, buf: &mut String) -> Result<bool> {
if should_skip_binary_expr(binary, schema)? {
return Ok(false);
}

if !write_expr(binary.left.as_ref(), schema, buf)? {
return Ok(false);
}
Expand Down Expand Up @@ -459,3 +482,152 @@ fn should_skip_binary_expr(expr: &BinaryExpr, schema: &Schema) -> Result<bool> {
// Skip if we're trying to do any kind of binary op with text column
Ok(is_not_supported_dt(&expr.left, schema)? || is_not_supported_dt(&expr.right, schema)?)
}

/// Physical plan node that inserts the rows produced by `input` into a
/// sqlite table. Built by `SqliteTableProvider::insert_into`; see the
/// `ExecutionPlan` impl for how the insert is executed.
#[derive(Debug)]
pub struct SqliteInsertExec {
    // Name of the destination sqlite table (interpolated into the
    // generated INSERT statement).
    table: String,
    // Plan producing the record batches to insert.
    input: Arc<dyn ExecutionPlan>,
    // Shared access state; `execute` clones its async sqlite client.
    state: SqliteAccessState,
    // Schema captured from the table provider at plan time.
    schema: SchemaRef,
    // Runtime metrics, surfaced via `ExecutionPlan::metrics`.
    metrics: ExecutionPlanMetricsSet,
}

/// Sink node: executes `input`, converts each record batch into one
/// `INSERT INTO … VALUES …` statement, and runs it against sqlite. The
/// output stream carries the per-statement count batches returned by
/// the client.
impl ExecutionPlan for SqliteInsertExec {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Schema of the destination table, as captured at plan time.
    ///
    /// NOTE(review): the stream returned from `execute` uses
    /// `COUNT_SCHEMA`, not this schema — confirm downstream consumers
    /// expect the table schema from this accessor.
    fn schema(&self) -> SchemaRef {
        self.schema.clone()
    }

    // The insert always runs as a single partition; `execute` rejects
    // any partition index other than 0.
    fn output_partitioning(&self) -> Partitioning {
        Partitioning::UnknownPartitioning(1)
    }

    fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
        None
    }

    // The input plan is intentionally not exposed as a child; see
    // `with_new_children`, which refuses non-empty replacements.
    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
        Vec::new()
    }

    /// Because `children` reports none, only an empty replacement list
    /// is accepted; anything else is an execution error.
    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
        if children.is_empty() {
            Ok(self)
        } else {
            Err(DataFusionError::Execution(
                "cannot replace children for SqliteInsertExec".to_string(),
            ))
        }
    }

    /// Run the insert. Returns a stream of count record batches (one
    /// per input batch) with `COUNT_SCHEMA`, wrapped in the metrics
    /// adapter.
    fn execute(
        &self,
        partition: usize,
        ctx: Arc<TaskContext>,
    ) -> DataFusionResult<SendableRecordBatchStream> {
        // Single-partition plan; any other index is a planner bug.
        if partition != 0 {
            return Err(DataFusionError::Execution(format!(
                "invalid partition: {partition}"
            )));
        }
        // Clone everything the async stream needs so it is 'static.
        let input = self.input.clone();
        let table = self.table.clone();
        let schema = self.schema.clone();
        let client = self.state.client.clone();

        let stream = futures::stream::once(async move {
            // Comma-joined column list for the INSERT statement's
            // column clause. NOTE(review): the Arc looks unnecessary —
            // the owned String moved into the closure would suffice.
            let column_names = Arc::new(
                schema
                    .fields
                    .into_iter()
                    .map(|field| field.name().to_owned())
                    .collect::<Vec<String>>()
                    .join(","),
            );
            // TODO: run this entire operation in the scope of a
            // transaction.

            // Each incoming record batch becomes a single INSERT
            // statement, which is run immediately. As a result the
            // returned count may have multiple rows, one per input
            // batch. Multiple input batches seem unlikely in typical
            // use, and multiple count rows are arguably correct
            // anyway (and certainly preferable to the early-2024
            // handling, which built one very large statement and sent
            // it to the database in a single call).
            Ok::<_, DataFusionError>(
                execute_stream(input, ctx)?
                    .map(move |input_batch| {
                        let batch = input_batch?;

                        // An empty batch would yield invalid VALUES
                        // SQL; reject it. NOTE(review): this errors
                        // the whole insert if the input legitimately
                        // produces an empty batch — consider skipping
                        // instead; confirm intended semantics.
                        if batch.num_rows() == 0 {
                            return Err(DataFusionError::Execution(
                                "cannot insert empty value".to_string(),
                            ));
                        }

                        let mut stmt = String::default();
                        write!(
                            &mut stmt,
                            "INSERT INTO {} ({}) VALUES ",
                            table, column_names,
                        )?;

                        // Render each row as a parenthesized tuple of
                        // SQL literals.
                        for row_idx in 0..batch.num_rows() {
                            let mut row_values = Vec::with_capacity(batch.num_columns());
                            for column in batch.columns() {
                                let mut buf = String::default();
                                // Presumably encode_literal_to_text
                                // handles sqlite quoting/escaping of
                                // the literal — verify against
                                // common::util.
                                util::encode_literal_to_text(
                                    util::Datasource::Sqlite,
                                    &mut buf,
                                    &ScalarValue::try_from_array(column.as_ref(), row_idx)
                                        .map_err(|e| DataFusionError::Execution(e.to_string()))?,
                                )
                                .map_err(|e| DataFusionError::Execution(e.to_string()))?;
                                row_values.push(buf);
                            }

                            write!(
                                &mut stmt,
                                "{}({})",
                                if row_idx == 0 { "" } else { ", " },
                                row_values.join(",")
                            )?;
                        }
                        write!(&mut stmt, ";")?;
                        // Run the statement; the client yields count
                        // batches, flattened into the output stream.
                        Ok(client.query(COUNT_SCHEMA.clone(), stmt))
                    })
                    .try_flatten(),
            )
        })
        .try_flatten();

        Ok(Box::pin(DataSourceMetricsStreamAdapter::new(
            RecordBatchStreamAdapter::new(COUNT_SCHEMA.clone(), Box::pin(stream)),
            partition,
            &self.metrics,
        )))
    }

    // No statistics are tracked for inserts.
    fn statistics(&self) -> DataFusionResult<Statistics> {
        Ok(Statistics::new_unknown(self.schema().as_ref()))
    }

    fn metrics(&self) -> Option<MetricsSet> {
        Some(self.metrics.clone_inner())
    }
}

impl DisplayAs for SqliteInsertExec {
    /// Render this node's name in EXPLAIN output; all display format
    /// types share the same representation.
    fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("SqliteInsertExec")
    }
}
3 changes: 1 addition & 2 deletions crates/datasources/src/sqlite/wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::path::PathBuf;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_sqlite::rusqlite;
use async_sqlite::rusqlite::types::Value;
use async_sqlite::rusqlite::{self, OpenFlags};
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::error::DataFusionError;
Expand All @@ -30,7 +30,6 @@ impl fmt::Debug for SqliteAsyncClient {
impl SqliteAsyncClient {
pub async fn new(path: PathBuf) -> Result<Self> {
let inner = async_sqlite::ClientBuilder::new()
.flags(OpenFlags::SQLITE_OPEN_READ_ONLY)
.path(&path)
.open()
.await?;
Expand Down
2 changes: 1 addition & 1 deletion justfile
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ venv:
pytest *args:
{{VENV_BIN}}/poetry -C tests install --no-root
{{VENV_BIN}}/poetry -C tests lock --no-update
{{VENV_BIN}}/poetry -C tests run pytest -n auto -v --rootdir={{invocation_directory()}}/tests {{ if args == "" {'tests'} else {args} }}
{{VENV_BIN}}/poetry -C tests run pytest -v --rootdir={{invocation_directory()}}/tests {{ if args == "" {'tests'} else {args} }}

# private helpers below
# ---------------------
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/glaredb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def glaredb_path() -> list[pathlib.Path]:
def binary_path(glaredb_path: list[pathlib.Path]) -> pathlib.Path:
return glaredb_path[0] if glaredb_path[0].exists() else glaredb_path[1]


@pytest.fixture
def glaredb_connection(
binary_path: list[pathlib.Path],
Expand Down
18 changes: 17 additions & 1 deletion tests/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ ruff = "0.1.14"
dbt-core = "^1.7.7"
dbt-postgres = "^1.7.7"
pytest-xdist = "^3.5.0"
pysqlite3-binary = "^0.5.2.post3"

[build-system]
requires = ["poetry-core"]
Expand Down
1 change: 1 addition & 0 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def test_dbt_glaredb_external_postgres(

assert result == 5


def test_dbt_join_csv_with_table(
glaredb_connection: psycopg2.extensions.connection,
dbt_project_path: pathlib.Path,
Expand Down
47 changes: 47 additions & 0 deletions tests/test_sqlite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import os

import pytest
import sqlite3

import psycopg2
import psycopg2.extras

from tests.fixtures.glaredb import glaredb_connection, glaredb_path, binary_path


def test_inserts(
    glaredb_connection: psycopg2.extensions.connection,
    tmp_path_factory: pytest.TempPathFactory,
):
    """Round-trip an INSERT through glaredb into an external sqlite table."""
    db_path = tmp_path_factory.mktemp(basename="sqlite-inserts").joinpath("insertdb")

    # Create the target table directly via sqlite and confirm it starts empty.
    setup_conn = sqlite3.connect(db_path)
    setup_cur = setup_conn.cursor()
    setup_cur.execute("create table insertable (a, b, c)")
    setup_cur.execute("select * from insertable")
    assert len(setup_cur.fetchall()) == 0
    setup_cur.close()
    setup_conn.commit()
    setup_conn.close()

    # Attach the sqlite table in glaredb, insert through glaredb, and
    # read the rows back through glaredb.
    with glaredb_connection.cursor() as cursor:
        cursor.execute(
            "create external table einsertable from sqlite "
            f"options (location = '{db_path}', table = 'insertable')"
        )
        cursor.execute("alter table einsertable set access_mode to read_write")
        cursor.execute("insert into einsertable values (1, 2, 3), (4, 5, 6);")

        cursor.execute("select * from einsertable;")
        assert len(cursor.fetchall()) == 2

    # Confirm the inserted rows are visible directly in the sqlite file.
    verify_conn = sqlite3.connect(db_path)
    verify_cur = verify_conn.cursor()
    verify_cur.execute("select * from insertable;")
    assert len(verify_cur.fetchall()) == 2

    verify_cur.close()
    verify_conn.close()