Skip to content

Commit

Permalink
refactor(core): simplify validation
Browse files Browse the repository at this point in the history
If we double-request URLs, that's probably better than being way too complex.
  • Loading branch information
gadomski committed Sep 23, 2024
1 parent 71c6fbf commit ba2c4dc
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 110 deletions.
2 changes: 1 addition & 1 deletion core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Added

- Validation, from the (now defunct) **stac-validate** ([#434](https://github.com/stac-utils/stac-rs/pull/434))
- Validation, from the (now defunct) **stac-validate** ([#434](https://github.com/stac-utils/stac-rs/pull/434), [#438](https://github.com/stac-utils/stac-rs/pull/438))

## [0.10.1] - 2024-09-20

Expand Down
3 changes: 1 addition & 2 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ object-store-all = [
"object-store-http",
]
reqwest = ["dep:reqwest"]
validate = ["dep:jsonschema", "dep:reqwest", "dep:tokio", "dep:tracing"]
validate = ["dep:jsonschema", "dep:reqwest", "dep:tokio"]
validate-blocking = ["validate", "tokio/rt"]

[dependencies]
Expand All @@ -68,7 +68,6 @@ reqwest = { version = "0.12", optional = true, features = ["json", "blocking"] }
serde = { version = "1", features = ["derive"] }
serde_json = { version = "1", features = ["preserve_order"] }
thiserror = "1"
tracing = { version = "0.1", optional = true }
tokio = { version = "1", optional = true }
url = "2"

Expand Down
16 changes: 0 additions & 16 deletions core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,6 @@ pub enum Error {
#[cfg(feature = "validate")]
TokioJoin(#[from] tokio::task::JoinError),

/// [tokio::sync::mpsc::error::SendError]
#[error(transparent)]
#[cfg(feature = "validate")]
TokioSend(
#[from]
tokio::sync::mpsc::error::SendError<(
url::Url,
tokio::sync::oneshot::Sender<crate::Result<std::sync::Arc<serde_json::Value>>>,
)>,
),

/// [tokio::sync::oneshot::error::RecvError]
#[error(transparent)]
#[cfg(feature = "validate")]
TokioRecv(#[from] tokio::sync::oneshot::error::RecvError),

/// [std::num::TryFromIntError]
#[error(transparent)]
TryFromInt(#[from] std::num::TryFromIntError),
Expand Down
114 changes: 23 additions & 91 deletions core/src/validate/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,10 @@ use std::{
pin::Pin,
sync::{Arc, Mutex},
};
use tokio::{
sync::{
mpsc::{error::TryRecvError, Receiver, Sender},
oneshot::Sender as OneshotSender,
RwLock,
},
task::JoinSet,
};
use tokio::{sync::RwLock, task::JoinSet};
use url::Url;

const SCHEMA_BASE: &str = "https://schemas.stacspec.org";
const BUFFER: usize = 10;

/// A cloneable structure for validating STAC.
#[derive(Clone, Debug)]
Expand All @@ -31,7 +23,7 @@ pub struct Validator {
cache: Arc<std::sync::RwLock<HashMap<Url, Arc<Value>>>>,
schemas: Arc<RwLock<HashMap<Url, Arc<JsonschemaValidator>>>>,
urls: Arc<Mutex<HashSet<Url>>>,
sender: Sender<(Url, OneshotSender<Result<Arc<Value>>>)>,
client: Client,
}

struct Resolver {
Expand Down Expand Up @@ -60,14 +52,12 @@ impl Validator {
};
let mut validation_options = JsonschemaValidator::options();
let _ = validation_options.with_resolver(resolver);
let (sender, receiver) = tokio::sync::mpsc::channel(BUFFER);
drop(tokio::spawn(async move { get_urls(receiver).await }));
Validator {
schemas: Arc::new(RwLock::new(schemas(&validation_options))),
validation_options,
cache,
urls,
sender,
client: Client::new(),
}
}

Expand Down Expand Up @@ -171,7 +161,7 @@ impl Validator {
}
}
while let Some(result) = join_set.join_next().await {
result??;
let _ = result??;
}
let object = if let Value::Object(o) = Arc::into_inner(value).unwrap() {
o
Expand Down Expand Up @@ -227,9 +217,7 @@ impl Validator {
return Ok(schema.clone());
}
}
let (sender, receiver) = tokio::sync::oneshot::channel();
self.sender.send((url.clone(), sender)).await?;
let value = receiver.await??;
let value = self.resolve(url.clone()).await?;
let schema = self
.validation_options
.build(&value)
Expand All @@ -242,15 +230,27 @@ impl Validator {
Ok(schema)
}

async fn resolve(&self, url: Url) -> Result<()> {
let (sender, receiver) = tokio::sync::oneshot::channel();
self.sender.send((url.clone(), sender)).await?;
let value = receiver.await??;
async fn resolve(&self, url: Url) -> Result<Arc<Value>> {
{
let cache = self.cache.read().unwrap();
if let Some(value) = cache.get(&url) {
return Ok(value.clone());
}
}
let value: Value = self
.client
.get(url.clone())
.send()
.await?
.error_for_status()?
.json()
.await?;
let value = Arc::new(value);
{
let mut cache = self.cache.write().unwrap();
let _ = cache.insert(url, value);
let _ = cache.insert(url, value.clone());
}
Ok(())
Ok(value)
}
}

Expand Down Expand Up @@ -414,74 +414,6 @@ fn cache() -> HashMap<Url, Arc<Value>> {
cache
}

/// Runs the URL-fetching loop that serves requests arriving on `receiver`.
///
/// Each request is a `(Url, OneshotSender)` pair. If the url has already been
/// fetched by this loop, the cached value is sent back immediately. Otherwise
/// the sender is queued under that url and, for the first request of a given
/// url only, a task is spawned to download it. When a download finishes its
/// result is delivered to the queued senders.
///
/// The loop polls both channels with `try_recv` and yields to the runtime
/// when neither has a message; it does not await on the channels.
///
/// Returns `Ok(())` once either channel reports that it is disconnected.
///
/// A failed delivery on a oneshot sender (the requester dropped its receiver,
/// e.g. because its task was cancelled) is ignored rather than unwrapped, so
/// one abandoned request cannot panic this loop and break all later requests.
async fn get_urls(mut receiver: Receiver<(Url, OneshotSender<Result<Arc<Value>>>)>) -> Result<()> {
    // Values fetched so far, keyed by url.
    let mut cache: HashMap<Url, Arc<Value>> = HashMap::new();
    // Requesters waiting on an in-flight download, keyed by url.
    let mut gets: HashMap<Url, Vec<OneshotSender<Result<Arc<Value>>>>> = HashMap::new();
    let client = Client::new();
    // Spawned download tasks report their results back over this channel.
    let (local_sender, mut local_receiver) = tokio::sync::mpsc::channel(BUFFER);
    loop {
        match receiver.try_recv() {
            Err(TryRecvError::Disconnected) => return Ok(()),
            // No new requests: check for finished downloads instead.
            Err(TryRecvError::Empty) => match local_receiver.try_recv() {
                Err(TryRecvError::Disconnected) => return Ok(()),
                Err(TryRecvError::Empty) => tokio::task::yield_now().await,
                Ok((url, result)) => {
                    let mut senders = gets
                        .remove(&url)
                        .expect("all sent values should be in gets");
                    match result {
                        Ok(value) => {
                            let value = Arc::<Value>::new(value);
                            let _ = cache.insert(url, value.clone());
                            for sender in senders {
                                // The requester may have gone away; that is not an error here.
                                let _ = sender.send(Ok(value.clone()));
                            }
                        }
                        Err(err) => {
                            // The error is moved into a single waiter. The remaining
                            // senders are dropped when `senders` goes out of scope.
                            let _ = senders
                                .pop()
                                .expect("there should be at least one sender")
                                .send(Err(err));
                        }
                    };
                }
            },
            Ok((url, sender)) => {
                if let Some(value) = cache.get(&url) {
                    // The requester may have gone away; that is not an error here.
                    let _ = sender.send(Ok(value.clone()));
                } else {
                    gets.entry(url.clone())
                        .or_insert_with(|| {
                            // First request for this url: start the download.
                            tracing::debug!("getting url: {}", url);
                            let local_sender = local_sender.clone();
                            let client = client.clone();
                            drop(tokio::spawn(async move {
                                match get(client, url.clone()).await {
                                    Ok(value) => local_sender.send((url, Ok(value))).await,
                                    Err(err) => local_sender.send((url, Err(err))).await,
                                }
                            }));
                            Vec::new()
                        })
                        .push(sender);
                }
            }
        }
    }
}

/// Downloads `url` with `client` and decodes the response body as JSON.
///
/// # Errors
///
/// Returns an error if the request fails, if `error_for_status` rejects the
/// response status, or if the body cannot be decoded as JSON.
async fn get(client: Client, url: Url) -> Result<Value> {
    let response = client.get(url).send().await?.error_for_status()?;
    let value = response.json::<Value>().await?;
    Ok(value)
}

#[cfg(test)]
mod tests {
use super::Validator;
Expand Down

0 comments on commit ba2c4dc

Please sign in to comment.