Skip to content

Commit

Permalink
cleanup object store scans
Browse files Browse the repository at this point in the history
  • Loading branch information
scsmithr committed Jan 21, 2024
1 parent 2417212 commit faf8108
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 97 deletions.
1 change: 1 addition & 0 deletions crates/datafusion_ext/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ where
}
}

/// An identifier value, stored as an owned `String`.
// NOTE(review): presumably names a catalog object (e.g. a credentials
// object) passed as a function argument — confirm against callers.
#[derive(Debug, Clone)]
pub struct IdentValue(String);

impl IdentValue {
Expand Down
203 changes: 112 additions & 91 deletions crates/sqlbuiltins/src/functions/table/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,31 +224,16 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
args: &[FuncParamValue],
_parent: RuntimePreference,
) -> Result<RuntimePreference> {
let mut args = args.iter();
let url_arg = args.next().unwrap().to_owned();

let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
vec![url_arg.try_into()?]
} else {
url_arg.try_into()?
};

let mut urls = urls.iter().map(|url| match url.datasource_url_type() {
let urls = self.urls_from_args(args)?;
// All urls are of the same type, just need to get the runtime from the
// first.
Ok(match urls.first().unwrap().datasource_url_type() {
DatasourceUrlType::File => RuntimePreference::Local,
DatasourceUrlType::Http => RuntimePreference::Remote,
DatasourceUrlType::Gcs => RuntimePreference::Remote,
DatasourceUrlType::S3 => RuntimePreference::Remote,
DatasourceUrlType::Azure => RuntimePreference::Remote,
});
let first = urls.next().unwrap();

if urls.all(|url| std::mem::discriminant(&url) == std::mem::discriminant(&first)) {
Ok(first)
} else {
Err(ExtensionError::String(
"cannot mix different types of urls".to_owned(),
))
}
})
}

async fn create_provider(
Expand All @@ -257,24 +242,8 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
args: Vec<FuncParamValue>,
mut opts: HashMap<String, FuncParamValue>,
) -> Result<Arc<dyn TableProvider>> {
if args.is_empty() {
return Err(ExtensionError::InvalidNumArgs);
}

let mut args = args.into_iter();
let url_arg = args.next().unwrap().to_owned();

let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
vec![url_arg.try_into()?]
} else {
url_arg.try_into()?
};

if urls.is_empty() {
return Err(ExtensionError::String(
"at least one url expected".to_owned(),
));
}
let urls = self.urls_from_args(&args)?;
let creds_ident = self.credentials_from_args(&args)?;

// Read in user provided options and use them to construct the format.
let mut format = Opts::read_options(&opts)?;
Expand All @@ -287,10 +256,7 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
Some(cmp.parse::<FileCompressionType>()?)
}
None => {
let path = urls
.first()
.ok_or_else(|| ExtensionError::String("at least one url expected".to_string()))?
.path();
let path = urls.first().expect("non-empty urls").path();
let path = std::path::Path::new(path.as_ref());
path.extension()
.and_then(|ext| ext.to_string_lossy().as_ref().parse().ok())
Expand All @@ -308,7 +274,7 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
(Arc<dyn ObjStoreAccess>, Vec<DatasourceUrl>),
> = HashMap::new();
for source_url in urls {
let access = get_store_access(ctx, &source_url, args.clone(), opts.clone())?;
let access = get_store_access(ctx, &source_url, creds_ident.as_ref(), opts.clone())?;
let base_url = access
.base_url()
.map_err(|e| ExtensionError::Access(Box::new(e)))?;
Expand Down Expand Up @@ -340,6 +306,59 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
}
}

impl<Opts> ObjScanTableFunc<Opts> {
    /// Get data source urls from the function arguments.
    ///
    /// The urls are read from the first positional argument, which may be
    /// either a single url or a list of urls.
    ///
    /// The returned vec is guaranteed to have all urls be of the same data
    /// source type, and will contain at least one url.
    ///
    /// # Errors
    ///
    /// Returns an error if no arguments were provided, if the first argument
    /// cannot be converted into one or more urls, if the resulting list of
    /// urls is empty, or if the urls are not all of the same data source type.
    fn urls_from_args(&self, args: &[FuncParamValue]) -> Result<Vec<DatasourceUrl>> {
        let url_arg = args.first().cloned().ok_or_else(|| {
            ExtensionError::String("Expected at least one argument.".to_string())
        })?;

        // A lone url is accepted as shorthand for a one-element list. Anything
        // else must be convertible into a list of urls.
        let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
            vec![url_arg.try_into()?]
        } else {
            url_arg.try_into()?
        };

        // Checking for emptiness and grabbing the reference type in one step
        // avoids a separate `unwrap`.
        let expected_type = match urls.first() {
            Some(url) => url.datasource_url_type(),
            None => {
                return Err(ExtensionError::String(
                    "Expected at least one url.".to_string(),
                ))
            }
        };

        if urls
            .iter()
            .any(|url| url.datasource_url_type() != expected_type)
        {
            return Err(ExtensionError::String(
                "Cannot mix different types of urls.".to_string(),
            ));
        }

        Ok(urls)
    }

    /// Try to pull an identifier for credentials out of arguments.
    ///
    /// Credential identifiers are expected to be the second argument to the
    /// function. Returns `Ok(None)` if no second argument was provided.
    ///
    /// # Errors
    ///
    /// Returns `ExtensionError::InvalidNumArgs` if more than two positional
    /// arguments were provided, or a conversion error if the second argument
    /// is not a valid identifier.
    fn credentials_from_args(&self, args: &[FuncParamValue]) -> Result<Option<IdentValue>> {
        // Only the url(s) and an optional credentials identifier are
        // meaningful. Reject anything extra instead of silently ignoring it.
        if args.len() > 2 {
            return Err(ExtensionError::InvalidNumArgs);
        }

        args.get(1)
            .map(|arg| IdentValue::try_from(arg.clone()))
            .transpose()
    }
}

/// Gets a table provider for the files at location.
///
/// If the file is detected to be local, the table provider will be wrapped in a
Expand All @@ -359,61 +378,24 @@ async fn get_table_provider(
Ok(prov)
}

/// Gets an object store accessor for the provided url.
///
/// If the object store requires credentials, `creds_ident` can be provided to
/// look up saved credentials in the catalog. Otherwise individual values (access
/// keys, etc.) will be pulled out of `opts`.
fn get_store_access(
ctx: &dyn TableFuncContextProvider,
source_url: &DatasourceUrl,
mut args: vec::IntoIter<FuncParamValue>,
creds_ident: Option<&IdentValue>,
mut opts: HashMap<String, FuncParamValue>,
) -> Result<Arc<dyn ObjStoreAccess>> {
let access: Arc<dyn ObjStoreAccess> = match args.len() {
0 => {
// Raw credentials or No credentials
match source_url.datasource_url_type() {
DatasourceUrlType::Http => create_http_store_access(source_url)?,
DatasourceUrlType::File => create_local_store_access(ctx)?,
DatasourceUrlType::Gcs => {
let service_account_key = opts
.remove("service_account_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_gcs_table_provider(source_url, service_account_key)?
}
DatasourceUrlType::S3 => {
let access_key_id = opts
.remove("access_key_id")
.map(FuncParamValue::try_into)
.transpose()?;

let secret_access_key = opts
.remove("secret_access_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_s3_store_access(source_url, &mut opts, access_key_id, secret_access_key)?
}
DatasourceUrlType::Azure => {
let access_key = opts
.remove("access_key")
.map(FuncParamValue::try_into)
.transpose()?;
let account = opts
.remove("account_name")
.map(FuncParamValue::try_into)
.transpose()?;

create_azure_store_access(source_url, account, access_key)?
}
}
}
1 => {
// Credentials object
let creds: IdentValue = args.next().unwrap().try_into()?;
let access: Arc<dyn ObjStoreAccess> = match creds_ident {
Some(ident) => {
let creds = ctx
.get_session_catalog()
.resolve_credentials(creds.as_str())
.resolve_credentials(ident.as_str())
.ok_or(ExtensionError::String(format!(
"missing credentials object: {creds}"
"missing credentials object: {ident}"
)))?;

match source_url.datasource_url_type() {
Expand Down Expand Up @@ -472,7 +454,46 @@ fn get_store_access(
}
}
}
_ => return Err(ExtensionError::InvalidNumArgs),
None => {
// Raw credentials or No credentials
match source_url.datasource_url_type() {
DatasourceUrlType::Http => create_http_store_access(source_url)?,
DatasourceUrlType::File => create_local_store_access(ctx)?,
DatasourceUrlType::Gcs => {
let service_account_key = opts
.remove("service_account_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_gcs_table_provider(source_url, service_account_key)?
}
DatasourceUrlType::S3 => {
let access_key_id = opts
.remove("access_key_id")
.map(FuncParamValue::try_into)
.transpose()?;

let secret_access_key = opts
.remove("secret_access_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_s3_store_access(source_url, &mut opts, access_key_id, secret_access_key)?
}
DatasourceUrlType::Azure => {
let access_key = opts
.remove("access_key")
.map(FuncParamValue::try_into)
.transpose()?;
let account = opts
.remove("account_name")
.map(FuncParamValue::try_into)
.transpose()?;

create_azure_store_access(source_url, account, access_key)?
}
}
}
};

Ok(access)
Expand Down
1 change: 1 addition & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ rpc-tests:
'sqllogictests/cast/*' \
'sqllogictests/cte/*' \
'sqllogictests/functions/arrow_cast' \
'sqllogictests/functions/csv_scan' \
'sqllogictests/functions/delta_scan' \
'sqllogictests/functions/generate_series' \
'sqllogictests/functions/version' \
Expand Down
2 changes: 1 addition & 1 deletion testdata/sqllogictests/functions/csv_scan.slt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ select count(*) from csv_scan([
----
204

statement error at least one url expected
statement error Expected at least one url.
select * from csv_scan([]);

# Glob patterns not supported on HTTP
Expand Down
2 changes: 1 addition & 1 deletion testdata/sqllogictests/functions/json_scan.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ select count(*) from ndjson_scan([
----
204

statement error at least one url expected
statement error Expected at least one url.
select * from ndjson_scan([]);

# Glob patterns not supported on HTTP
Expand Down
2 changes: 1 addition & 1 deletion testdata/sqllogictests/functions/read_csv.slt
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ select count(*) from read_csv([
----
204

statement error at least one url expected
statement error Expected at least one url.
select * from read_csv([]);

# Glob patterns not supported on HTTP
Expand Down
6 changes: 3 additions & 3 deletions testdata/sqllogictests/functions/read_json.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ select count(*) from read_ndjson([
----
204

statement error at least one url expected
statement error Expected at least one url.
select * from read_ndjson([]);

# Glob patterns not supported on HTTP

statement error Unexpected status code '404 Not Found'
statement error Unexpected status code '404 Not Found'
select * from read_ndjson(
'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson'
);
Expand All @@ -49,4 +49,4 @@ statement error Note that globbing is not supported for HTTP.
select * from read_ndjson(
'https://raw.githubusercontent.com/GlareDB/glaredb/main/testdata/sqllogictests_datasources_common/data/*.ndjson'
);

0 comments on commit faf8108

Please sign in to comment.