Skip to content

Commit

Permalink
Input: Add ytdl search (#210)
Browse files Browse the repository at this point in the history
* Add ytdl search

* fix fmt

* Remove compose, add tests, return AuxMetadata

* fix parsing of AuxMetadata and better test

* Fix playability of `YoutubeDl::new_search`

Refactors such that parsing of (ND)JSON is handled in only one location
now, which allows us to greatly simplify the actual `search` method. The
main change is that any `new_search` is now instantly playable.

---------

Co-authored-by: Kyle Simpson <kyleandrew.simpson@gmail.com>
  • Loading branch information
cycle-five and FelixMcFelix authored Dec 12, 2023
1 parent 873aeae commit d681b71
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 25 deletions.
18 changes: 7 additions & 11 deletions examples/serenity/voice/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,15 +278,7 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
},
};

if !url.starts_with("http") {
check_msg(
msg.channel_id
.say(&ctx.http, "Must provide a valid URL")
.await,
);

return Ok(());
}
let do_search = !url.starts_with("http");

let guild_id = msg.guild_id.unwrap();

Expand All @@ -305,8 +297,12 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult {
if let Some(handler_lock) = manager.get(guild_id) {
let mut handler = handler_lock.lock().await;

let src = YoutubeDl::new(http_client, url);
let _ = handler.play_input(src.into());
let mut src = if do_search {
YoutubeDl::new_search(http_client, url)
} else {
YoutubeDl::new(http_client, url)
};
let _ = handler.play_input(src.clone().into());

check_msg(msg.channel_id.say(&ctx.http, "Playing song").await);
} else {
Expand Down
127 changes: 113 additions & 14 deletions src/input/sources/ytdl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ use tokio::process::Command;

const YOUTUBE_DL_COMMAND: &str = "yt-dlp";

#[derive(Clone, Debug)]
enum QueryType {
Url(String),
Search(String),
}

/// A lazily instantiated call to download a file, finding its URL via youtube-dl.
///
/// By default, this uses yt-dlp and is backed by an [`HttpRequest`]. This handler
Expand All @@ -30,7 +36,7 @@ pub struct YoutubeDl {
program: &'static str,
client: Client,
metadata: Option<AuxMetadata>,
url: String,
query: QueryType,
}

impl YoutubeDl {
Expand All @@ -52,14 +58,63 @@ impl YoutubeDl {
program,
client,
metadata: None,
url,
query: QueryType::Url(url),
}
}

/// Creates a request to search youtube for an optionally specified number of videos matching `query`,
/// using "yt-dlp".
#[must_use]
pub fn new_search(client: Client, query: String) -> Self {
Self::new_search_ytdl_like(YOUTUBE_DL_COMMAND, client, query)
}

/// Creates a request to search youtube for an optionally specified number of videos matching `query`,
/// using `program`.
#[must_use]
pub fn new_search_ytdl_like(program: &'static str, client: Client, query: String) -> Self {
Self {
program,
client,
metadata: None,
query: QueryType::Search(query),
}
}

async fn query(&mut self) -> Result<Output, AudioStreamError> {
/// Runs a search for the given query, returning a list of up to `n_results`
/// possible matches which are `AuxMetadata` objects containing a valid URL.
///
/// Returns up to 5 matches by default.
pub async fn search(
&mut self,
n_results: Option<usize>,
) -> Result<Vec<AuxMetadata>, AudioStreamError> {
let n_results = n_results.unwrap_or(5);

Ok(match &self.query {
// Safer to just return the metadata for the pointee if possible
QueryType::Url(_) => vec![self.aux_metadata().await?],
QueryType::Search(_) => self
.query(n_results)
.await?
.into_iter()
.map(|v| v.as_aux_metadata())
.collect(),
})
}

async fn query(&mut self, n_results: usize) -> Result<Vec<Output>, AudioStreamError> {
let new_query;
let query_str = match &self.query {
QueryType::Url(url) => url,
QueryType::Search(query) => {
new_query = format!("ytsearch{n_results}:{query}");
&new_query
},
};
let ytdl_args = [
"-j",
&self.url,
query_str,
"-f",
"ba[abr>0][vcodec=none]/best",
"--no-playlist",
Expand All @@ -77,14 +132,35 @@ impl YoutubeDl {
})
})?;

// NOTE: must be mut for simd-json.
#[allow(clippy::unnecessary_mut_passed)]
let stdout: Output = crate::json::from_slice(&mut output.stdout[..])
if !output.status.success() {
return Err(AudioStreamError::Fail(
format!(
"{} failed with non-zero status code: {}",
self.program,
std::str::from_utf8(&output.stderr[..]).unwrap_or("<no error message>")
)
.into(),
));
}

// NOTE: must be split_mut for simd-json.
let out = output
.stdout
.split_mut(|&b| b == b'\n')
.filter_map(|x| (!x.is_empty()).then(|| crate::json::from_slice(x)))
.collect::<Result<Vec<Output>, _>>()
.map_err(|e| AudioStreamError::Fail(Box::new(e)))?;

self.metadata = Some(stdout.as_aux_metadata());
let meta = out
.first()
.ok_or_else(|| {
AudioStreamError::Fail(format!("no results found for '{query_str}'").into())
})?
.as_aux_metadata();

Ok(stdout)
self.metadata = Some(meta);

Ok(out)
}
}

Expand All @@ -103,11 +179,13 @@ impl Compose for YoutubeDl {
async fn create_async(
&mut self,
) -> Result<AudioStream<Box<dyn MediaSource>>, AudioStreamError> {
let stdout = self.query().await?;
// panic safety: `query` should have ensured > 0 results if `Ok`
let mut results = self.query(1).await?;
let result = results.swap_remove(0);

let mut headers = HeaderMap::default();

if let Some(map) = stdout.http_headers {
if let Some(map) = result.http_headers {
headers.extend(map.iter().filter_map(|(k, v)| {
Some((
HeaderName::from_bytes(k.as_bytes()).ok()?,
Expand All @@ -118,9 +196,9 @@ impl Compose for YoutubeDl {

let mut req = HttpRequest {
client: self.client.clone(),
request: stdout.url,
request: result.url,
headers,
content_length: stdout.filesize,
content_length: result.filesize,
};

req.create_async().await
Expand All @@ -135,7 +213,7 @@ impl Compose for YoutubeDl {
return Ok(meta.clone());
}

self.query().await?;
self.query(1).await?;

self.metadata.clone().ok_or_else(|| {
let msg: Box<dyn Error + Send + Sync + 'static> =
Expand Down Expand Up @@ -185,4 +263,25 @@ mod tests {

assert!(ytdl.aux_metadata().await.is_err());
}

#[tokio::test]
#[ntest::timeout(20_000)]
async fn ytdl_search_plays() {
let mut ytdl = YoutubeDl::new_search(Client::new(), "cloudkicker 94 days".into());
let res = ytdl.search(Some(1)).await;

let res = res.unwrap();
assert_eq!(res.len(), 1);

track_plays_passthrough(move || ytdl).await;
}

#[tokio::test]
#[ntest::timeout(20_000)]
async fn ytdl_search_3() {
let mut ytdl = YoutubeDl::new_search(Client::new(), "test".into());
let res = ytdl.search(Some(3)).await;

assert_eq!(res.unwrap().len(), 3);
}
}

0 comments on commit d681b71

Please sign in to comment.