Skip to content

Commit

Permalink
chore: Get most of the rpc tests working (#2472)
Browse files Browse the repository at this point in the history
  • Loading branch information
scsmithr authored Jan 22, 2024
1 parent aeecfc3 commit 9eab113
Show file tree
Hide file tree
Showing 11 changed files with 315 additions and 445 deletions.
1 change: 1 addition & 0 deletions crates/datafusion_ext/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ where
}
}

#[derive(Debug, Clone)]
pub struct IdentValue(String);

impl IdentValue {
Expand Down
2 changes: 1 addition & 1 deletion crates/protogen/src/sqlexec/physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ pub struct AnalyzeExec {
pub struct ExecutionPlanExtension {
#[prost(
oneof = "ExecutionPlanExtensionType",
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31"
tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32"
)]
pub inner: Option<ExecutionPlanExtensionType>,
}
Expand Down
203 changes: 112 additions & 91 deletions crates/sqlbuiltins/src/functions/table/object_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,31 +224,16 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
args: &[FuncParamValue],
_parent: RuntimePreference,
) -> Result<RuntimePreference> {
let mut args = args.iter();
let url_arg = args.next().unwrap().to_owned();

let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
vec![url_arg.try_into()?]
} else {
url_arg.try_into()?
};

let mut urls = urls.iter().map(|url| match url.datasource_url_type() {
let urls = self.urls_from_args(args)?;
// All urls are of the same type, just need to get the runtime from the
// first.
Ok(match urls.first().unwrap().datasource_url_type() {
DatasourceUrlType::File => RuntimePreference::Local,
DatasourceUrlType::Http => RuntimePreference::Remote,
DatasourceUrlType::Gcs => RuntimePreference::Remote,
DatasourceUrlType::S3 => RuntimePreference::Remote,
DatasourceUrlType::Azure => RuntimePreference::Remote,
});
let first = urls.next().unwrap();

if urls.all(|url| std::mem::discriminant(&url) == std::mem::discriminant(&first)) {
Ok(first)
} else {
Err(ExtensionError::String(
"cannot mix different types of urls".to_owned(),
))
}
})
}

async fn create_provider(
Expand All @@ -257,24 +242,8 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
args: Vec<FuncParamValue>,
mut opts: HashMap<String, FuncParamValue>,
) -> Result<Arc<dyn TableProvider>> {
if args.is_empty() {
return Err(ExtensionError::InvalidNumArgs);
}

let mut args = args.into_iter();
let url_arg = args.next().unwrap().to_owned();

let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
vec![url_arg.try_into()?]
} else {
url_arg.try_into()?
};

if urls.is_empty() {
return Err(ExtensionError::String(
"at least one url expected".to_owned(),
));
}
let urls = self.urls_from_args(&args)?;
let creds_ident = self.credentials_from_args(&args)?;

// Read in user provided options and use them to construct the format.
let mut format = Opts::read_options(&opts)?;
Expand All @@ -287,10 +256,7 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
Some(cmp.parse::<FileCompressionType>()?)
}
None => {
let path = urls
.first()
.ok_or_else(|| ExtensionError::String("at least one url expected".to_string()))?
.path();
let path = urls.first().expect("non-empty urls").path();
let path = std::path::Path::new(path.as_ref());
path.extension()
.and_then(|ext| ext.to_string_lossy().as_ref().parse().ok())
Expand All @@ -308,7 +274,7 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
(Arc<dyn ObjStoreAccess>, Vec<DatasourceUrl>),
> = HashMap::new();
for source_url in urls {
let access = get_store_access(ctx, &source_url, args.clone(), opts.clone())?;
let access = get_store_access(ctx, &source_url, creds_ident.as_ref(), opts.clone())?;
let base_url = access
.base_url()
.map_err(|e| ExtensionError::Access(Box::new(e)))?;
Expand Down Expand Up @@ -340,6 +306,59 @@ impl<Opts: OptionReader> TableFunc for ObjScanTableFunc<Opts> {
}
}

impl<Opts> ObjScanTableFunc<Opts> {
/// Get data source urls form the function arguments.
///
/// The returned vec is guaranteed to have all urls be of the same data
/// source type, and will contain at least one url.
fn urls_from_args(&self, args: &[FuncParamValue]) -> Result<Vec<DatasourceUrl>> {
let mut args = args.iter();
let url_arg = match args.next() {
Some(arg) => arg.to_owned(),
None => {
return Err(ExtensionError::String(
"Expected at least one argument.".to_string(),
))
}
};

// TODO: wtf?
let urls: Vec<DatasourceUrl> = if url_arg.is_valid::<DatasourceUrl>() {
vec![url_arg.try_into()?]
} else {
url_arg.try_into()?
};

if urls.is_empty() {
return Err(ExtensionError::String(
"Expected at least one url.".to_string(),
));
}

let first = urls.first().unwrap();
if !urls
.iter()
.all(|url| url.datasource_url_type() == first.datasource_url_type())
{
return Err(ExtensionError::String(
"Cannot mix different types of urls.".to_string(),
));
}

Ok(urls)
}

/// Try to pull an identifier for credentials out of arguments.
///
/// Credential identifiers are expected to be the second argument to the
/// function.
fn credentials_from_args(&self, args: &[FuncParamValue]) -> Result<Option<IdentValue>> {
args.get(1)
.map(|arg| IdentValue::try_from(arg.clone()))
.transpose()
}
}

/// Gets a table provider for the files at location.
///
/// If the file is detected to be local, the table provider will be wrapped in a
Expand All @@ -359,61 +378,24 @@ async fn get_table_provider(
Ok(prov)
}

/// Get's an object store accessor for the provided url.
///
/// If the object store requires credentials, `creds_ident` can be provided to
/// lookup saved credentials in the catalog. Otherwise individual values (access
/// keys, etc) will be pulled out of `opts`.
fn get_store_access(
ctx: &dyn TableFuncContextProvider,
source_url: &DatasourceUrl,
mut args: vec::IntoIter<FuncParamValue>,
creds_ident: Option<&IdentValue>,
mut opts: HashMap<String, FuncParamValue>,
) -> Result<Arc<dyn ObjStoreAccess>> {
let access: Arc<dyn ObjStoreAccess> = match args.len() {
0 => {
// Raw credentials or No credentials
match source_url.datasource_url_type() {
DatasourceUrlType::Http => create_http_store_access(source_url)?,
DatasourceUrlType::File => create_local_store_access(ctx)?,
DatasourceUrlType::Gcs => {
let service_account_key = opts
.remove("service_account_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_gcs_table_provider(source_url, service_account_key)?
}
DatasourceUrlType::S3 => {
let access_key_id = opts
.remove("access_key_id")
.map(FuncParamValue::try_into)
.transpose()?;

let secret_access_key = opts
.remove("secret_access_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_s3_store_access(source_url, &mut opts, access_key_id, secret_access_key)?
}
DatasourceUrlType::Azure => {
let access_key = opts
.remove("access_key")
.map(FuncParamValue::try_into)
.transpose()?;
let account = opts
.remove("account_name")
.map(FuncParamValue::try_into)
.transpose()?;

create_azure_store_access(source_url, account, access_key)?
}
}
}
1 => {
// Credentials object
let creds: IdentValue = args.next().unwrap().try_into()?;
let access: Arc<dyn ObjStoreAccess> = match creds_ident {
Some(ident) => {
let creds = ctx
.get_session_catalog()
.resolve_credentials(creds.as_str())
.resolve_credentials(ident.as_str())
.ok_or(ExtensionError::String(format!(
"missing credentials object: {creds}"
"missing credentials object: {ident}"
)))?;

match source_url.datasource_url_type() {
Expand Down Expand Up @@ -472,7 +454,46 @@ fn get_store_access(
}
}
}
_ => return Err(ExtensionError::InvalidNumArgs),
None => {
// Raw credentials or No credentials
match source_url.datasource_url_type() {
DatasourceUrlType::Http => create_http_store_access(source_url)?,
DatasourceUrlType::File => create_local_store_access(ctx)?,
DatasourceUrlType::Gcs => {
let service_account_key = opts
.remove("service_account_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_gcs_table_provider(source_url, service_account_key)?
}
DatasourceUrlType::S3 => {
let access_key_id = opts
.remove("access_key_id")
.map(FuncParamValue::try_into)
.transpose()?;

let secret_access_key = opts
.remove("secret_access_key")
.map(FuncParamValue::try_into)
.transpose()?;

create_s3_store_access(source_url, &mut opts, access_key_id, secret_access_key)?
}
DatasourceUrlType::Azure => {
let access_key = opts
.remove("access_key")
.map(FuncParamValue::try_into)
.transpose()?;
let account = opts
.remove("account_name")
.map(FuncParamValue::try_into)
.transpose()?;

create_azure_store_access(source_url, account, access_key)?
}
}
}
};

Ok(access)
Expand Down
Loading

0 comments on commit 9eab113

Please sign in to comment.