diff --git a/Cargo.lock b/Cargo.lock index 09eec364..67b00e08 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1307,6 +1307,7 @@ version = "6.0.0" dependencies = [ "anyhow", "async-stream", + "async-trait", "axum 0.7.5", "backoff", "base64 0.21.7", diff --git a/crates/librqbit/Cargo.toml b/crates/librqbit/Cargo.toml index 03ec3572..955e8aaf 100644 --- a/crates/librqbit/Cargo.toml +++ b/crates/librqbit/Cargo.toml @@ -31,7 +31,12 @@ sha1w = { path = "../sha1w", default-features = false, package = "librqbit-sha1- dht = { path = "../dht", package = "librqbit-dht", version = "5.0.4" } librqbit-upnp = { path = "../upnp", version = "0.1.0" } -tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tokio = { version = "1", features = [ + "macros", + "rt-multi-thread", + "fs", + "io-util", +] } axum = { version = "0.7.4" } tower-http = { version = "0.5", features = ["cors", "trace"] } tokio-stream = "0.1" @@ -79,6 +84,7 @@ memmap2 = { version = "0.9.4" } lru = { version = "0.12.3", optional = true } mime_guess = { version = "2.0.5", default-features = false } tokio-socks = "0.5.2" +async-trait = "0.1.81" [build-dependencies] anyhow = "1" diff --git a/crates/librqbit/examples/custom_storage.rs b/crates/librqbit/examples/custom_storage.rs index f85fd379..2be41c97 100644 --- a/crates/librqbit/examples/custom_storage.rs +++ b/crates/librqbit/examples/custom_storage.rs @@ -71,7 +71,7 @@ async fn main() -> anyhow::Result<()> { Default::default(), SessionOptions { disable_dht_persistence: true, - persistence: false, + persistence: None, listen_port_range: None, enable_upnp_port_forwarding: false, ..Default::default() diff --git a/crates/librqbit/src/api.rs b/crates/librqbit/src/api.rs index 4a253aef..b23b8956 100644 --- a/crates/librqbit/src/api.rs +++ b/crates/librqbit/src/api.rs @@ -96,39 +96,43 @@ impl Api { .per_peer_stats_snapshot(filter)) } - pub fn api_torrent_action_pause(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_pause(&self, idx: TorrentId) -> Result { let handle = self.mgr_handle(idx)?; - handle - .pause() + self.session() + .pause(&handle) + .await .context("error pausing torrent") .with_error_status_code(StatusCode::BAD_REQUEST)?; Ok(Default::default()) } - pub fn api_torrent_action_start(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_start(&self, idx: TorrentId) -> Result { let handle = self.mgr_handle(idx)?; self.session .unpause(&handle) + .await .context("error unpausing torrent") .with_error_status_code(StatusCode::BAD_REQUEST)?; Ok(Default::default()) } - pub fn api_torrent_action_forget(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_forget(&self, idx: TorrentId) -> Result { self.session .delete(idx, false) + .await .context("error forgetting torrent")?; Ok(Default::default()) } - pub fn api_torrent_action_delete(&self, idx: TorrentId) -> Result { + pub async fn api_torrent_action_delete(&self, idx: TorrentId) -> Result { self.session .delete(idx, true) + .await .context("error deleting torrent with files")?; Ok(Default::default()) } - pub fn api_torrent_action_update_only_files( + pub async fn api_torrent_action_update_only_files( &self, idx: TorrentId, only_files: &HashSet, @@ -136,6 +140,7 @@ impl Api { let handle = self.mgr_handle(idx)?; self.session .update_only_files(&handle, only_files) + .await .context("error updating only_files")?; Ok(Default::default()) } diff --git a/crates/librqbit/src/http_api.rs b/crates/librqbit/src/http_api.rs index bb696353..a300d24c 100644 --- 
a/crates/librqbit/src/http_api.rs +++ b/crates/librqbit/src/http_api.rs @@ -368,28 +368,28 @@ impl HttpApi { State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_pause(idx).map(axum::Json) + state.api_torrent_action_pause(idx).await.map(axum::Json) } async fn torrent_action_start( State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_start(idx).map(axum::Json) + state.api_torrent_action_start(idx).await.map(axum::Json) } async fn torrent_action_forget( State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_forget(idx).map(axum::Json) + state.api_torrent_action_forget(idx).await.map(axum::Json) } async fn torrent_action_delete( State(state): State, Path(idx): Path, ) -> Result { - state.api_torrent_action_delete(idx).map(axum::Json) + state.api_torrent_action_delete(idx).await.map(axum::Json) } #[derive(Deserialize)] @@ -404,6 +404,7 @@ impl HttpApi { ) -> Result { state .api_torrent_action_update_only_files(idx, &req.only_files.into_iter().collect()) + .await .map(axum::Json) } diff --git a/crates/librqbit/src/lib.rs b/crates/librqbit/src/lib.rs index 8990e0e0..1f66a3bd 100644 --- a/crates/librqbit/src/lib.rs +++ b/crates/librqbit/src/lib.rs @@ -39,6 +39,7 @@ mod peer_connection; mod peer_info_reader; mod read_buf; mod session; +mod session_persistence; mod spawn_utils; pub mod storage; mod stream_connect; @@ -53,7 +54,7 @@ pub use dht; pub use peer_connection::PeerConnectionOptions; pub use session::{ AddTorrent, AddTorrentOptions, AddTorrentResponse, ListOnlyResponse, Session, SessionOptions, - SUPPORTED_SCHEMES, + SessionPersistenceConfig, SUPPORTED_SCHEMES, }; pub use spawn_utils::spawn as librqbit_spawn; pub use torrent_state::{ diff --git a/crates/librqbit/src/session.rs b/crates/librqbit/src/session.rs index 7526002b..70afda01 100644 --- a/crates/librqbit/src/session.rs +++ b/crates/librqbit/src/session.rs @@ -1,12 +1,10 @@ use std::{ - any::TypeId, borrow::Cow, collections::{HashMap, HashSet}, - io::{BufReader, BufWriter, Read}, + io::Read, net::SocketAddr, path::{Path, PathBuf}, - str::FromStr, - sync::Arc, + sync::{atomic::AtomicUsize, Arc}, time::Duration, }; @@ -15,6 +13,9 @@ use crate::{ merge_streams::merge_streams, peer_connection::PeerConnectionOptions, read_buf::ReadBuf, + session_persistence::{ + json::JsonSessionPersistenceStore, BoxSessionPersistenceStore, SessionPersistenceStore, + }, spawn_utils::BlockingSpawner, storage::{ filesystem::FilesystemStorageFactory, BoxStorageFactory, StorageFactoryExt, TorrentStorage, @@ -27,7 +28,7 @@ use crate::{ ManagedTorrentInfo, }; use anyhow::{bail, Context}; -use bencode::{bencode_serialize_to_writer, BencodeDeserializer}; +use bencode::bencode_serialize_to_writer; use buffers::{ByteBuf, ByteBufOwned, ByteBufT}; use bytes::Bytes; use clone_to_owned::CloneToOwned; @@ -48,7 +49,7 @@ use librqbit_core::{ }; use parking_lot::RwLock; use peer_binary_protocol::Handshake; -use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde::{Deserialize, Serialize}; use tokio::net::{TcpListener, TcpStream}; use tokio_stream::StreamExt; use tokio_util::sync::{CancellationToken, DropGuard}; @@ -80,155 +81,22 @@ fn torrent_from_bytes(bytes: Bytes) -> anyhow::Result { #[derive(Default)] pub struct SessionDatabase { - next_id: TorrentId, torrents: HashMap, } impl SessionDatabase { - fn add_torrent( - &mut self, - torrent: ManagedTorrentHandle, - preferred_id: Option, - ) -> TorrentId { - match preferred_id { - Some(id) if self.torrents.contains_key(&id) => { - 
warn!("id {id} already present in DB, ignoring \"preferred_id\" parameter"); - } - Some(id) => { - self.torrents.insert(id, torrent); - self.next_id = id.max(self.next_id).wrapping_add(1); - return id; - } - _ => {} - } - let idx = self.next_id; - self.torrents.insert(idx, torrent); - self.next_id += 1; - idx - } - - fn serialize(&self) -> SerializedSessionDatabase { - SerializedSessionDatabase { - torrents: self - .torrents - .iter() - // We don't support serializing / deserializing of other storage types. - .filter(|(_, torrent)| { - torrent - .storage_factory - .is_type_id(TypeId::of::()) - }) - .map(|(id, torrent)| { - ( - *id, - SerializedTorrent { - trackers: torrent - .info() - .trackers - .iter() - .map(|u| u.to_string()) - .collect(), - info_hash: torrent.info_hash().as_string(), - // TODO: this could take up too much space / time / resources to write on interval. - // Store this outside the JSON file - // - // torrent_bytes: torrent.info.torrent_bytes.clone(), - torrent_bytes: Bytes::new(), - info: torrent.info().info.clone(), - only_files: torrent.only_files().clone(), - is_paused: torrent - .with_state(|s| matches!(s, ManagedTorrentState::Paused(_))), - output_folder: torrent.info().options.output_folder.clone(), - }, - ) - }) - .collect(), - } + fn add_torrent(&mut self, torrent: ManagedTorrentHandle, id: TorrentId) { + self.torrents.insert(id, torrent); } } -#[derive(Serialize, Deserialize)] -struct SerializedTorrent { - info_hash: String, - #[serde( - serialize_with = "serialize_torrent", - deserialize_with = "deserialize_torrent" - )] - info: TorrentMetaV1Info, - #[serde( - serialize_with = "serialize_torrent_bytes", - deserialize_with = "deserialize_torrent_bytes", - default - )] - torrent_bytes: Bytes, - trackers: HashSet, - output_folder: PathBuf, - only_files: Option>, - is_paused: bool, -} - -fn serialize_torrent( - t: &TorrentMetaV1Info, - serializer: S, -) -> Result -where - S: Serializer, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::ser::Error; - let mut writer = Vec::new(); - bencode_serialize_to_writer(t, &mut writer).map_err(S::Error::custom)?; - let s = general_purpose::STANDARD_NO_PAD.encode(&writer); - s.serialize(serializer) -} - -fn deserialize_torrent<'de, D>(deserializer: D) -> Result, D::Error> -where - D: Deserializer<'de>, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::de::Error; - let s = String::deserialize(deserializer)?; - let b = general_purpose::STANDARD_NO_PAD - .decode(s) - .map_err(D::Error::custom)?; - TorrentMetaV1Info::::deserialize(&mut BencodeDeserializer::new_from_buf(&b)) - .map_err(D::Error::custom) -} - -fn serialize_torrent_bytes(t: &Bytes, serializer: S) -> Result -where - S: Serializer, -{ - use base64::{engine::general_purpose, Engine as _}; - let s = general_purpose::STANDARD_NO_PAD.encode(t); - s.serialize(serializer) -} - -fn deserialize_torrent_bytes<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, -{ - use base64::{engine::general_purpose, Engine as _}; - use serde::de::Error; - let s = String::deserialize(deserializer)?; - let b = general_purpose::STANDARD_NO_PAD - .decode(s) - .map_err(D::Error::custom)?; - Ok(b.into()) -} - -#[derive(Serialize, Deserialize)] -struct SerializedSessionDatabase { - torrents: HashMap, -} - pub struct Session { peer_id: Id20, dht: Option, - persistence_filename: PathBuf, + persistence: Option>, peer_opts: PeerConnectionOptions, spawner: BlockingSpawner, + next_id: AtomicUsize, db: RwLock, output_folder: PathBuf, @@ -425,7 +293,6 @@ 
pub fn read_local_file_including_stdin(filename: &str) -> anyhow::Result pub enum AddTorrent<'a> { Url(Cow<'a, str>), TorrentFileBytes(Bytes), - TorrentInfo(Box), } impl<'a> AddTorrent<'a> { @@ -458,11 +325,22 @@ impl<'a> AddTorrent<'a> { match self { Self::Url(s) => s.into_owned().into_bytes().into(), Self::TorrentFileBytes(b) => b, - Self::TorrentInfo(..) => unimplemented!(), } } } +pub enum SessionPersistenceConfig { + /// The filename for persistence. By default uses an OS-specific folder. + Json { folder: Option }, +} + +impl SessionPersistenceConfig { + pub fn default_json_persistence_folder() -> anyhow::Result { + let dir = get_configuration_directory("session")?; + Ok(dir.data_dir().to_owned()) + } +} + #[derive(Default)] pub struct SessionOptions { /// Turn on to disable DHT. @@ -476,9 +354,7 @@ pub struct SessionOptions { /// Turn on to dump session contents into a file periodically, so that on next start /// all remembered torrents will continue where they left off. - pub persistence: bool, - /// The filename for persistence. By default uses an OS-specific folder. - pub persistence_filename: Option, + pub persistence: Option, /// The peer ID to use. If not specified, a random one will be generated. pub peer_id: Option, @@ -557,11 +433,6 @@ impl Session { Self::new_with_opts(default_output_folder, SessionOptions::default()) } - pub fn default_persistence_filename() -> anyhow::Result { - let dir = get_configuration_directory("session")?; - Ok(dir.data_dir().join("session.json")) - } - pub fn cancellation_token(&self) -> &CancellationToken { &self.cancellation_token } @@ -576,15 +447,16 @@ impl Session { let peer_id = opts.peer_id.unwrap_or_else(generate_peer_id); let token = CancellationToken::new(); - let (tcp_listener, tcp_listen_port) = if let Some(port_range) = opts.listen_port_range { - let (l, p) = create_tcp_listener(port_range) - .await - .context("error listening on TCP")?; - info!("Listening on 0.0.0.0:{p} for incoming peer connections"); - (Some(l), Some(p)) - } else { - (None, None) - }; + let (tcp_listener, tcp_listen_port) = + if let Some(port_range) = opts.listen_port_range.clone() { + let (l, p) = create_tcp_listener(port_range) + .await + .context("error listening on TCP")?; + info!("Listening on 0.0.0.0:{p} for incoming peer connections"); + (Some(l), Some(p)) + } else { + (None, None) + }; let dht = if opts.disable_dht { None @@ -606,11 +478,31 @@ impl Session { Some(dht) }; let peer_opts = opts.peer_opts.unwrap_or_default(); - let persistence_filename = match opts.persistence_filename { - Some(filename) => filename, - None if !opts.persistence => PathBuf::new(), - None => Self::default_persistence_filename()?, - }; + + async fn persistence_factory( + opts: &SessionOptions, + ) -> anyhow::Result> { + match &opts.persistence { + Some(SessionPersistenceConfig::Json { folder }) => { + let folder = match folder.as_ref() { + Some(f) => f.clone(), + None => SessionPersistenceConfig::default_json_persistence_folder()?, + }; + + Ok(Some(Box::new( + JsonSessionPersistenceStore::new(folder) + .await + .context("error initializing JsonSessionPersistenceStore")?, + ))) + } + None => Ok(None), + } + } + + let persistence = persistence_factory(&opts) + .await + .context("error initializing session persistence store")?; + let spawner = BlockingSpawner::default(); let (disk_write_tx, disk_write_rx) = opts @@ -646,12 +538,13 @@ impl Session { let stream_connector = Arc::new(StreamConnector::from(proxy_config)); let session = Arc::new(Self { - persistence_filename, + persistence, 
peer_id, dht, peer_opts, spawner, output_folder: default_output_folder, + next_id: AtomicUsize::new(0), db: RwLock::new(Default::default()), _cancellation_token_drop_guard: token.clone().drop_guard(), cancellation_token: token, @@ -688,18 +581,34 @@ impl Session { } } - if opts.persistence { - info!( - "will use {:?} for session persistence", - session.persistence_filename - ); - if let Some(parent) = session.persistence_filename.parent() { - std::fs::create_dir_all(parent).with_context(|| { - format!("couldn't create directory {:?} for session storage", parent) - })?; + if let Some(persistence) = session.persistence.as_ref() { + info!("will use {persistence:?} for session persistence"); + + let mut ps = persistence.stream_all().await?; + let mut added_all = false; + let mut futs = FuturesUnordered::new(); + + while !added_all || !futs.is_empty() { + tokio::select! { + Some(res) = futs.next(), if !futs.is_empty() => { + if let Err(e) = res { + error!("error adding torrent to session: {e:?}"); + } + }, + st = ps.next(), if !added_all => { + if let Some(st) = st { + let (id, st) = st?; + let span = error_span!("add_torrent", info_hash=?st.info_hash()); + let (add_torrent, mut opts) = st.into_add_torrent()?; + opts.preferred_id = Some(id); + let fut = session.add_torrent(add_torrent, Some(opts)).instrument(span); + futs.push(fut); + } else { + added_all = true; + } + }, + } } - let persistence_task = session.clone().task_persistence(); - session.spawn(error_span!("session_persistence"), persistence_task); } Ok(session) @@ -707,29 +616,6 @@ impl Session { .boxed() } - async fn task_persistence(self: Arc) -> anyhow::Result<()> { - // Populate initial from the state filename - if let Err(e) = self.populate_from_stored().await { - error!("could not populate session from stored file: {:?}", e); - } - - let session = Arc::downgrade(&self); - drop(self); - - loop { - tokio::time::sleep(Duration::from_secs(10)).await; - let session = match session.upgrade() { - Some(s) => s, - None => break, - }; - if let Err(e) = session.dump_to_disk() { - error!("error dumping session to disk: {:?}", e); - } - } - - Ok(()) - } - async fn check_incoming_connection( &self, addr: SocketAddr, @@ -868,102 +754,6 @@ impl Session { tokio::time::sleep(Duration::from_secs(1)).await; } - async fn populate_from_stored(self: &Arc) -> anyhow::Result<()> { - let mut rdr = match std::fs::File::open(&self.persistence_filename) { - Ok(f) => BufReader::new(f), - Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(()), - Err(e) => { - return Err(e).context(format!( - "error opening session file {:?}", - self.persistence_filename - )) - } - }; - let db: SerializedSessionDatabase = - serde_json::from_reader(&mut rdr).context("error deserializing session database")?; - let mut futures = Vec::new(); - for (id, storrent) in db.torrents.into_iter() { - let trackers: Vec = storrent - .trackers - .into_iter() - .map(|t| ByteBufOwned::from(t.into_bytes())) - .collect(); - - let torrent_bytes = storrent.torrent_bytes; - - let add_torrent = if !torrent_bytes.is_empty() { - AddTorrent::TorrentFileBytes(torrent_bytes) - } else { - let info_hash = Id20::from_str(&storrent.info_hash)?; - debug!(?info_hash, "torrent added before 6.1.0, need to readd"); - let info = TorrentMetaV1Owned { - announce: trackers.first().cloned(), - announce_list: vec![trackers], - info: storrent.info, - comment: None, - created_by: None, - encoding: None, - publisher: None, - publisher_url: None, - creation_date: None, - info_hash, - }; - 
AddTorrent::TorrentInfo(Box::new(info)) - }; - - futures.push({ - let session = self.clone(); - async move { - session - .add_torrent( - add_torrent, - Some(AddTorrentOptions { - paused: storrent.is_paused, - output_folder: Some( - storrent - .output_folder - .to_str() - .context("broken path")? - .to_owned(), - ), - only_files: storrent.only_files, - overwrite: true, - preferred_id: Some(id), - ..Default::default() - }), - ) - .await - .map_err(|e| { - error!("error adding torrent from stored session: {:?}", e); - e - }) - } - }); - } - futures::future::join_all(futures).await; - Ok(()) - } - - fn dump_to_disk(&self) -> anyhow::Result<()> { - let tmp_filename = format!("{}.tmp", self.persistence_filename.to_str().unwrap()); - let mut tmp = BufWriter::new( - std::fs::OpenOptions::new() - .create(true) - .truncate(true) - .write(true) - .open(&tmp_filename) - .with_context(|| format!("error opening {:?}", tmp_filename))?, - ); - let serialized = self.db.read().serialize(); - serde_json::to_writer(&mut tmp, &serialized).context("error serializing")?; - drop(tmp); - - std::fs::rename(&tmp_filename, &self.persistence_filename) - .context("error renaming persistence file")?; - trace!(filename=?self.persistence_filename, "wrote persistence"); - Ok(()) - } - /// Run a callback given the currently managed torrents. pub fn with_torrents( &self, @@ -1073,15 +863,6 @@ impl Session { AddTorrent::TorrentFileBytes(bytes) => torrent_from_bytes(bytes) .context("error decoding torrent")? - , - AddTorrent::TorrentInfo(t) => { - // TODO: remove this branch entirely - ParsedTorrentFile{ - info: *t, - info_bytes: Default::default(), - torrent_bytes: Default::default(), - } - }, }; let trackers = torrent.info @@ -1214,7 +995,17 @@ impl Session { })); } + let id = if let Some(id) = opts.preferred_id { + id + } else if let Some(p) = self.persistence.as_ref() { + p.next_id().await? + } else { + self.next_id + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + }; + let mut builder = ManagedTorrentBuilder::new( + id, info, info_hash, torrent_bytes, @@ -1250,19 +1041,26 @@ impl Session { builder.peer_read_write_timeout(t); } - let (managed_torrent, id) = { + let managed_torrent = { let mut g = self.db.write(); - if let Some((id, handle)) = g.torrents.iter().find(|(_, t)| t.info_hash() == info_hash) - { - return Ok(AddTorrentResponse::AlreadyManaged(*id, handle.clone())); + if let Some((id, handle)) = g.torrents.iter().find_map(|(eid, t)| { + if t.info_hash() == info_hash || *eid == id { + Some((*eid, t.clone())) + } else { + None + } + }) { + return Ok(AddTorrentResponse::AlreadyManaged(id, handle)); } - let next_id = g.torrents.len(); - let managed_torrent = - builder.build(error_span!(parent: None, "torrent", id = next_id))?; - let id = g.add_torrent(managed_torrent.clone(), opts.preferred_id); - (managed_torrent, id) + let managed_torrent = builder.build(error_span!(parent: None, "torrent", id))?; + g.add_torrent(managed_torrent.clone(), id); + managed_torrent }; + if let Some(p) = self.persistence.as_ref() { + p.store(id, &managed_torrent).await?; + } + // Merge "initial_peers" and "peer_rx" into one stream. 
let peer_rx = merge_two_optional_streams( if !initial_peers.is_empty() { @@ -1289,7 +1087,7 @@ impl Session { self.db.read().torrents.get(&id).cloned() } - pub fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { + pub async fn delete(&self, id: TorrentId, delete_files: bool) -> anyhow::Result<()> { let removed = self .db .write() @@ -1301,6 +1099,12 @@ impl Session { debug!("error pausing torrent before deletion: {e:?}") } + if let Some(p) = self.persistence.as_ref() { + if let Err(e) = p.delete(id).await { + error!(error=?e, "error deleting torrent from database"); + } + } + let storage = removed .with_state_mut(|s| match s.take() { ManagedTorrentState::Initializing(p) => p.files.take().ok(), @@ -1374,7 +1178,21 @@ impl Session { Ok(merge_two_optional_streams(dht_rx, peer_rx)) } - pub fn unpause(self: &Arc, handle: &ManagedTorrentHandle) -> anyhow::Result<()> { + async fn try_update_persistence_metadata(&self, handle: &ManagedTorrentHandle) { + if let Some(p) = self.persistence.as_ref() { + if let Err(e) = p.update_metadata(handle.id(), handle).await { + warn!(storage=?p, error=?e, "error updating metadata") + } + } + } + + pub async fn pause(&self, handle: &ManagedTorrentHandle) -> anyhow::Result<()> { + handle.pause()?; + self.try_update_persistence_metadata(handle).await; + Ok(()) + } + + pub async fn unpause(self: &Arc, handle: &ManagedTorrentHandle) -> anyhow::Result<()> { let peer_rx = self.make_peer_rx( handle.info_hash(), handle.info().trackers.clone().into_iter().collect(), @@ -1382,15 +1200,17 @@ impl Session { handle.info().options.force_tracker_interval, )?; handle.start(peer_rx, false, self.cancellation_token.child_token())?; + self.try_update_persistence_metadata(handle).await; Ok(()) } - pub fn update_only_files( + pub async fn update_only_files( self: &Arc, handle: &ManagedTorrentHandle, only_files: &HashSet, ) -> anyhow::Result<()> { handle.update_only_files(only_files)?; + self.try_update_persistence_metadata(handle).await; Ok(()) } diff --git a/crates/librqbit/src/session_persistence/json.rs b/crates/librqbit/src/session_persistence/json.rs new file mode 100644 index 00000000..c8438ddd --- /dev/null +++ b/crates/librqbit/src/session_persistence/json.rs @@ -0,0 +1,231 @@ +use std::{any::TypeId, collections::HashMap, path::PathBuf}; + +use crate::{ + session::TorrentId, storage::filesystem::FilesystemStorageFactory, + torrent_state::ManagedTorrentHandle, ManagedTorrentState, +}; +use anyhow::{bail, Context}; +use async_trait::async_trait; +use futures::{stream::BoxStream, StreamExt}; +use itertools::Itertools; +use librqbit_core::Id20; +use serde::{Deserialize, Serialize}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tracing::{trace, warn}; + +use super::{SerializedTorrent, SessionPersistenceStore}; + +#[derive(Serialize, Deserialize, Default)] +struct SerializedSessionDatabase { + torrents: HashMap, +} + +pub struct JsonSessionPersistenceStore { + output_folder: PathBuf, + db_filename: PathBuf, + db_content: tokio::sync::RwLock, +} + +impl std::fmt::Debug for JsonSessionPersistenceStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSON database: {:?}", self.output_folder) + } +} + +impl JsonSessionPersistenceStore { + pub async fn new(output_folder: PathBuf) -> anyhow::Result { + let db_filename = output_folder.join("session.json"); + tokio::fs::create_dir_all(&output_folder) + .await + .with_context(|| { + format!( + "couldn't create directory {:?} for session storage", + output_folder + ) + })?; + 
+ let db = match tokio::fs::File::open(&db_filename).await { + Ok(f) => { + let mut buf = Vec::new(); + let mut rdr = tokio::io::BufReader::new(f); + rdr.read_to_end(&mut buf).await?; + + serde_json::from_reader(&buf[..]).context("error deserializing session database")? + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Default::default(), + Err(e) => { + return Err(e).context(format!("error opening session file {:?}", db_filename)) + } + }; + + Ok(Self { + db_filename, + output_folder, + db_content: tokio::sync::RwLock::new(db), + }) + } + + async fn flush(&self) -> anyhow::Result<()> { + let tmp_filename = format!("{}.tmp", self.db_filename.to_str().unwrap()); + let mut tmp = tokio::fs::OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(&tmp_filename) + .await + .with_context(|| format!("error opening {:?}", tmp_filename))?; + + let mut buf = Vec::new(); + serde_json::to_writer(&mut buf, &*self.db_content.read().await) + .context("error serializing")?; + tmp.write_all(&buf) + .await + .with_context(|| format!("error writing {tmp_filename:?}"))?; + + tokio::fs::rename(&tmp_filename, &self.db_filename) + .await + .context("error renaming persistence file")?; + trace!(filename=?self.db_filename, "wrote persistence"); + Ok(()) + } + + fn torrent_bytes_filename(&self, info_hash: &Id20) -> PathBuf { + self.output_folder.join(format!("{:?}.torrent", info_hash)) + } + + async fn update_db( + &self, + id: TorrentId, + torrent: &ManagedTorrentHandle, + write_torrent_file: bool, + ) -> anyhow::Result<()> { + if !torrent + .storage_factory + .is_type_id(TypeId::of::()) + { + bail!("storages other than FilesystemStorageFactory are not supported"); + } + + let st = SerializedTorrent { + trackers: torrent + .info() + .trackers + .iter() + .map(|u| u.to_string()) + .collect(), + info_hash: torrent.info_hash(), + // we don't serialize this here, but to a file instead. 
+ torrent_bytes: Default::default(), + only_files: torrent.only_files().clone(), + is_paused: torrent.with_state(|s| matches!(s, ManagedTorrentState::Paused(_))), + output_folder: torrent.info().options.output_folder.clone(), + }; + + if write_torrent_file && !torrent.info().torrent_bytes.is_empty() { + let torrent_bytes_file = self.torrent_bytes_filename(&torrent.info_hash()); + match tokio::fs::OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(&torrent_bytes_file) + .await + { + Ok(mut f) => { + if let Err(e) = f.write_all(&torrent.info().torrent_bytes).await { + warn!(error=?e, file=?torrent_bytes_file, "error writing torrent bytes") + } + } + Err(e) => { + warn!(error=?e, file=?torrent_bytes_file, "error opening torrent bytes file") + } + } + } + + self.db_content.write().await.torrents.insert(id, st); + self.flush().await?; + + Ok(()) + } +} + +#[async_trait] +impl SessionPersistenceStore for JsonSessionPersistenceStore { + async fn next_id(&self) -> anyhow::Result { + Ok(self + .db_content + .read() + .await + .torrents + .keys() + .copied() + .max() + .map(|max| max + 1) + .unwrap_or(0)) + } + + async fn delete(&self, id: TorrentId) -> anyhow::Result<()> { + if let Some(t) = self.db_content.write().await.torrents.remove(&id) { + self.flush().await?; + let tf = self.torrent_bytes_filename(&t.info_hash); + if let Err(e) = tokio::fs::remove_file(&tf).await { + warn!(error=?e, filename=?tf, "error removing torrent file"); + } + } + + Ok(()) + } + + async fn get(&self, id: TorrentId) -> anyhow::Result { + let mut st = self + .db_content + .read() + .await + .torrents + .get(&id) + .cloned() + .context("no torrent found")?; + let mut buf = Vec::new(); + let torrent_bytes_filename = self.torrent_bytes_filename(&st.info_hash); + let mut torrent_bytes_file = match tokio::fs::File::open(&torrent_bytes_filename).await { + Ok(f) => f, + Err(e) => { + warn!(error=?e, filename=?torrent_bytes_filename, "error opening torrent bytes file"); + return Ok(st); + } + }; + if let Err(e) = torrent_bytes_file.read_to_end(&mut buf).await { + warn!(error=?e, filename=?torrent_bytes_filename, "error reading torrent bytes file"); + } else { + st.torrent_bytes = buf.into(); + } + return Ok(st); + } + + async fn stream_all( + &self, + ) -> anyhow::Result>> { + let all_ids = self + .db_content + .read() + .await + .torrents + .keys() + .copied() + .collect_vec(); + Ok(futures::stream::iter(all_ids) + .then(move |id| async move { self.get(id).await.map(move |st| (id, st)) }) + .boxed()) + } + + async fn store(&self, id: TorrentId, torrent: &ManagedTorrentHandle) -> anyhow::Result<()> { + self.update_db(id, torrent, true).await + } + + async fn update_metadata( + &self, + id: TorrentId, + torrent: &ManagedTorrentHandle, + ) -> anyhow::Result<()> { + self.update_db(id, torrent, false).await + } +} diff --git a/crates/librqbit/src/session_persistence/mod.rs b/crates/librqbit/src/session_persistence/mod.rs new file mode 100644 index 00000000..c574d3c6 --- /dev/null +++ b/crates/librqbit/src/session_persistence/mod.rs @@ -0,0 +1,93 @@ +pub mod json; + +use std::{collections::HashSet, path::PathBuf}; + +use anyhow::Context; +use async_trait::async_trait; +use bytes::Bytes; +use futures::stream::BoxStream; +use librqbit_core::magnet::Magnet; +use librqbit_core::Id20; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +use crate::{ + session::TorrentId, torrent_state::ManagedTorrentHandle, AddTorrent, AddTorrentOptions, +}; + +#[derive(Serialize, Deserialize, Clone)] +pub struct 
SerializedTorrent { + #[serde( + serialize_with = "serialize_info_hash", + deserialize_with = "deserialize_info_hash" + )] + info_hash: Id20, + #[serde(skip)] + torrent_bytes: Bytes, + trackers: HashSet, + output_folder: PathBuf, + only_files: Option>, + is_paused: bool, +} + +impl SerializedTorrent { + pub fn info_hash(&self) -> &Id20 { + &self.info_hash + } + pub fn into_add_torrent(self) -> anyhow::Result<(AddTorrent<'static>, AddTorrentOptions)> { + let add_torrent = if !self.torrent_bytes.is_empty() { + AddTorrent::TorrentFileBytes(self.torrent_bytes) + } else { + let magnet = + Magnet::from_id20(self.info_hash, self.trackers.into_iter().collect()).to_string(); + AddTorrent::from_url(magnet) + }; + + let opts = AddTorrentOptions { + paused: self.is_paused, + output_folder: Some( + self.output_folder + .to_str() + .context("broken path")? + .to_owned(), + ), + only_files: self.only_files, + overwrite: true, + ..Default::default() + }; + + Ok((add_torrent, opts)) + } +} + +// TODO: make this info_hash first, ID-second. +#[async_trait] +pub trait SessionPersistenceStore: core::fmt::Debug + Send + Sync { + async fn next_id(&self) -> anyhow::Result; + async fn store(&self, id: TorrentId, torrent: &ManagedTorrentHandle) -> anyhow::Result<()>; + async fn delete(&self, id: TorrentId) -> anyhow::Result<()>; + async fn get(&self, id: TorrentId) -> anyhow::Result; + async fn update_metadata( + &self, + id: TorrentId, + torrent: &ManagedTorrentHandle, + ) -> anyhow::Result<()>; + async fn stream_all( + &self, + ) -> anyhow::Result>>; +} + +pub type BoxSessionPersistenceStore = Box; + +fn serialize_info_hash(id: &Id20, serializer: S) -> Result +where + S: Serializer, +{ + id.as_string().serialize(serializer) +} + +fn deserialize_info_hash<'de, D>(deserializer: D) -> Result +where + D: Deserializer<'de>, +{ + Id20::deserialize(deserializer) +} diff --git a/crates/librqbit/src/tests/e2e.rs b/crates/librqbit/src/tests/e2e.rs index dedc2e85..6485a55c 100644 --- a/crates/librqbit/src/tests/e2e.rs +++ b/crates/librqbit/src/tests/e2e.rs @@ -66,8 +66,7 @@ async fn test_e2e_download() { disable_dht: true, disable_dht_persistence: true, dht_config: None, - persistence: false, - persistence_filename: None, + persistence: None, peer_id: Some(peer_id), peer_opts: None, listen_port_range: Some(15100..17000), @@ -150,8 +149,7 @@ async fn test_e2e_download() { disable_dht: true, disable_dht_persistence: true, dht_config: None, - persistence: false, - persistence_filename: None, + persistence: None, listen_port_range: None, enable_upnp_port_forwarding: false, ..Default::default() @@ -230,7 +228,7 @@ async fn test_e2e_download() { } info!("handle is completed"); - session.delete(id, false).unwrap(); + session.delete(id, false).await.unwrap(); info!("deleted handle"); diff --git a/crates/librqbit/src/tests/e2e_stream.rs b/crates/librqbit/src/tests/e2e_stream.rs index a2ab08eb..61ec7410 100644 --- a/crates/librqbit/src/tests/e2e_stream.rs +++ b/crates/librqbit/src/tests/e2e_stream.rs @@ -28,7 +28,7 @@ async fn e2e_stream() -> anyhow::Result<()> { crate::SessionOptions { disable_dht: true, peer_id: Some(TestPeerMetadata::good().as_peer_id()), - persistence: false, + persistence: None, listen_port_range: Some(16001..16100), enable_upnp_port_forwarding: false, ..Default::default() @@ -72,7 +72,7 @@ async fn e2e_stream() -> anyhow::Result<()> { client_dir.path().into(), crate::SessionOptions { disable_dht: true, - persistence: false, + persistence: None, peer_id: Some(TestPeerMetadata::good().as_peer_id()), 
listen_port_range: None, enable_upnp_port_forwarding: false, diff --git a/crates/librqbit/src/torrent_state/mod.rs b/crates/librqbit/src/torrent_state/mod.rs index 7537dbc3..09e7389f 100644 --- a/crates/librqbit/src/torrent_state/mod.rs +++ b/crates/librqbit/src/torrent_state/mod.rs @@ -36,6 +36,7 @@ use tracing::warn; use crate::chunk_tracker::ChunkTracker; use crate::file_info::FileInfo; +use crate::session::TorrentId; use crate::spawn_utils::BlockingSpawner; use crate::storage::BoxStorageFactory; use crate::stream_connect::StreamConnector; @@ -114,6 +115,8 @@ pub struct ManagedTorrentInfo { } pub struct ManagedTorrent { + pub id: TorrentId, + // TODO: merge ManagedTorrent and ManagedTorrentInfo pub info: Arc, pub(crate) storage_factory: BoxStorageFactory, @@ -122,6 +125,10 @@ pub struct ManagedTorrent { } impl ManagedTorrent { + pub fn id(&self) -> TorrentId { + self.id + } + pub fn info(&self) -> &ManagedTorrentInfo { &self.info } @@ -344,7 +351,7 @@ impl ManagedTorrent { } /// Pause the torrent if it's live. - pub fn pause(&self) -> anyhow::Result<()> { + pub(crate) fn pause(&self) -> anyhow::Result<()> { let mut g = self.locked.write(); match &g.state { ManagedTorrentState::Live(live) => { @@ -501,6 +508,7 @@ impl ManagedTorrent { } pub(crate) struct ManagedTorrentBuilder { + id: TorrentId, info: TorrentMetaV1Info, output_folder: PathBuf, info_hash: Id20, @@ -521,6 +529,7 @@ pub(crate) struct ManagedTorrentBuilder { impl ManagedTorrentBuilder { pub fn new( + id: usize, info: TorrentMetaV1Info, info_hash: Id20, torrent_bytes: Bytes, @@ -529,6 +538,7 @@ impl ManagedTorrentBuilder { storage_factory: BoxStorageFactory, ) -> Self { Self { + id, info, info_hash, torrent_bytes, @@ -641,6 +651,7 @@ impl ManagedTorrentBuilder { self.storage_factory.create_and_init(&info)?, )); Ok(Arc::new(ManagedTorrent { + id: self.id, locked: RwLock::new(ManagedTorrentLocked { state: ManagedTorrentState::Initializing(initializing), only_files: self.only_files, diff --git a/crates/rqbit/src/main.rs b/crates/rqbit/src/main.rs index 782c370e..e5799b3c 100644 --- a/crates/rqbit/src/main.rs +++ b/crates/rqbit/src/main.rs @@ -13,7 +13,7 @@ use librqbit::{ }, tracing_subscriber_config_utils::{init_logging, InitLoggingOptions}, AddTorrent, AddTorrentOptions, AddTorrentResponse, Api, ListOnlyResponse, - PeerConnectionOptions, Session, SessionOptions, TorrentStatsState, + PeerConnectionOptions, Session, SessionOptions, SessionPersistenceConfig, TorrentStatsState, }; use size_format::SizeFormatterBinary as SF; use tracing::{error, error_span, info, trace_span, warn}; @@ -132,9 +132,13 @@ struct ServerStartOptions { long = "disable-persistence", help = "Disable server persistence. It will not read or write its state to disk." )] + + /// Disable session persistence. disable_persistence: bool, - #[arg(long = "persistence-filename")] - persistence_filename: Option, + + /// The folder to store session data in. By default uses OS specific folder. + #[arg(long = "persistence-folder")] + persistence_folder: Option, } #[derive(Parser)] @@ -297,8 +301,7 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { disable_dht_persistence: opts.disable_dht_persistence, dht_config: None, // This will be overriden by "server start" below if needed. 
- persistence: false, - persistence_filename: None, + persistence: None, peer_id: None, peer_opts: Some(PeerConnectionOptions { connect_timeout: Some(opts.peer_connect_timeout), @@ -389,9 +392,11 @@ async fn async_main(opts: Opts) -> anyhow::Result<()> { match &opts.subcommand { SubCommand::Server(server_opts) => match &server_opts.subcommand { ServerSubcommand::Start(start_opts) => { - sopts.persistence = !start_opts.disable_persistence; - sopts.persistence_filename = - start_opts.persistence_filename.clone().map(PathBuf::from); + if !start_opts.disable_persistence { + sopts.persistence = Some(SessionPersistenceConfig::Json { + folder: start_opts.persistence_folder.clone().map(PathBuf::from), + }) + } let session = Session::new_with_opts(PathBuf::from(&start_opts.output_folder), sopts) diff --git a/desktop/src-tauri/Cargo.lock b/desktop/src-tauri/Cargo.lock index 14f6de26..cc2ad56e 100644 --- a/desktop/src-tauri/Cargo.lock +++ b/desktop/src-tauri/Cargo.lock @@ -1847,6 +1847,7 @@ version = "6.0.0" dependencies = [ "anyhow", "async-stream", + "async-trait", "axum", "backoff", "base64 0.21.7", @@ -1897,6 +1898,7 @@ name = "librqbit-bencode" version = "2.2.3" dependencies = [ "anyhow", + "bytes", "librqbit-buffers", "librqbit-clone-to-owned", "librqbit-sha1-wrapper", @@ -1907,6 +1909,7 @@ dependencies = [ name = "librqbit-buffers" version = "3.0.1" dependencies = [ + "bytes", "librqbit-clone-to-owned", "serde", ] @@ -1914,12 +1917,16 @@ dependencies = [ [[package]] name = "librqbit-clone-to-owned" version = "2.2.1" +dependencies = [ + "bytes", +] [[package]] name = "librqbit-core" version = "3.9.0" dependencies = [ "anyhow", + "bytes", "data-encoding", "directories", "hex 0.4.3", @@ -1942,6 +1949,7 @@ version = "5.0.4" dependencies = [ "anyhow", "backoff", + "bytes", "chrono", "dashmap", "futures", @@ -1969,6 +1977,7 @@ dependencies = [ "bincode", "bitvec", "byteorder", + "bytes", "librqbit-bencode", "librqbit-buffers", "librqbit-clone-to-owned", diff --git a/desktop/src-tauri/src/config.rs b/desktop/src-tauri/src/config.rs index cef1abab..17058cd9 100644 --- a/desktop/src-tauri/src/config.rs +++ b/desktop/src-tauri/src/config.rs @@ -1,10 +1,10 @@ use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - path::PathBuf, + path::{Path, PathBuf}, time::Duration, }; -use librqbit::{dht::PersistentDht, Session}; +use librqbit::dht::PersistentDht; use serde::{Deserialize, Serialize}; use serde_with::serde_as; @@ -49,14 +49,35 @@ impl Default for RqbitDesktopConfigTcpListen { #[serde(default)] pub struct RqbitDesktopConfigPersistence { pub disable: bool, + + #[serde(default)] + pub folder: PathBuf, + + /// Deprecated, but keeping for backwards compat for serialized / deserialized config. 
+ #[serde(default)] pub filename: PathBuf, } +impl RqbitDesktopConfigPersistence { + pub(crate) fn fix_backwards_compat(&mut self) { + if self.folder != Path::new("") { + return; + } + if self.filename != Path::new("") { + if let Some(parent) = self.filename.parent() { + self.folder = parent.to_owned(); + } + } + } +} + impl Default for RqbitDesktopConfigPersistence { fn default() -> Self { + let folder = librqbit::SessionPersistenceConfig::default_json_persistence_folder().unwrap(); Self { disable: false, - filename: Session::default_persistence_filename().unwrap(), + folder, + filename: PathBuf::new(), } } } diff --git a/desktop/src-tauri/src/main.rs b/desktop/src-tauri/src/main.rs index 2b03d637..36f6b531 100644 --- a/desktop/src-tauri/src/main.rs +++ b/desktop/src-tauri/src/main.rs @@ -21,6 +21,7 @@ use librqbit::{ dht::PersistentDhtConfig, tracing_subscriber_config_utils::{init_logging, InitLoggingOptions, InitLoggingResult}, AddTorrent, AddTorrentOptions, Api, ApiError, PeerConnectionOptions, Session, SessionOptions, + SessionPersistenceConfig, }; use parking_lot::RwLock; use serde::Serialize; @@ -42,7 +43,9 @@ struct State { fn read_config(path: &str) -> anyhow::Result { let rdr = BufReader::new(File::open(path)?); - Ok(serde_json::from_reader(rdr)?) + let mut config: RqbitDesktopConfig = serde_json::from_reader(rdr)?; + config.persistence.fix_backwards_compat(); + Ok(config) } fn write_config(path: &str, config: &RqbitDesktopConfig) -> anyhow::Result<()> { @@ -65,6 +68,17 @@ async fn api_from_config( init_logging: &InitLoggingResult, config: &RqbitDesktopConfig, ) -> anyhow::Result { + let persistence = if config.persistence.disable { + None + } else { + Some(SessionPersistenceConfig::Json { + folder: if config.persistence.folder == Path::new("") { + None + } else { + Some(config.persistence.folder.clone()) + }, + }) + }; let session = Session::new_with_opts( config.default_download_location.clone(), SessionOptions { @@ -74,8 +88,7 @@ async fn api_from_config( config_filename: Some(config.dht.persistence_filename.clone()), ..Default::default() }), - persistence: !config.persistence.disable, - persistence_filename: Some(config.persistence.filename.clone()), + persistence, peer_opts: Some(PeerConnectionOptions { connect_timeout: Some(config.peer_opts.connect_timeout), read_write_timeout: Some(config.peer_opts.read_write_timeout), @@ -266,7 +279,7 @@ async fn torrent_action_delete( state: tauri::State<'_, State>, id: usize, ) -> Result { - state.api()?.api_torrent_action_delete(id) + state.api()?.api_torrent_action_delete(id).await } #[tauri::command] @@ -274,7 +287,7 @@ async fn torrent_action_pause( state: tauri::State<'_, State>, id: usize, ) -> Result { - state.api()?.api_torrent_action_pause(id) + state.api()?.api_torrent_action_pause(id).await } #[tauri::command] @@ -282,7 +295,7 @@ async fn torrent_action_forget( state: tauri::State<'_, State>, id: usize, ) -> Result { - state.api()?.api_torrent_action_forget(id) + state.api()?.api_torrent_action_forget(id).await } #[tauri::command] @@ -290,7 +303,7 @@ async fn torrent_action_start( state: tauri::State<'_, State>, id: usize, ) -> Result { - state.api()?.api_torrent_action_start(id) + state.api()?.api_torrent_action_start(id).await } #[tauri::command] @@ -302,6 +315,7 @@ async fn torrent_action_configure( state .api()? 
.api_torrent_action_update_only_files(id, &only_files.into_iter().collect()) + .await } #[tauri::command] diff --git a/desktop/src/configuration.tsx b/desktop/src/configuration.tsx index a0157a50..9936b5b0 100644 --- a/desktop/src/configuration.tsx +++ b/desktop/src/configuration.tsx @@ -16,7 +16,7 @@ interface RqbitDesktopConfigTcpListen { interface RqbitDesktopConfigPersistence { disable: boolean; - filename: PathLike; + folder: PathLike; } interface RqbitDesktopConfigPeerOpts { diff --git a/desktop/src/configure.tsx b/desktop/src/configure.tsx index 55a1012b..acd225ec 100644 --- a/desktop/src/configure.tsx +++ b/desktop/src/configure.tsx @@ -130,7 +130,7 @@ export const ConfigModal: React.FC<{ }; const handleToggleChange: React.ChangeEventHandler = ( - e, + e ) => { const name: string = e.target.name; const [mainField, subField] = name.split(".", 2); @@ -166,7 +166,7 @@ export const ConfigModal: React.FC<{ text: "Error saving configuration", details: e, }); - }, + } ); }; @@ -292,10 +292,10 @@ export const ConfigModal: React.FC<{ />
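The diff above replaces the old periodic `dump_to_disk` loop and the `persistence_filename` option with a pluggable `SessionPersistenceStore`, configured through `SessionOptions::persistence` and shipped with a JSON implementation. Below is a minimal usage sketch of the new API as it appears in this change; the session and download folders and the magnet link are placeholders, and the `AddTorrentResponse::Added` variant is assumed from the existing librqbit API rather than shown in this diff.

```rust
use std::path::PathBuf;

use anyhow::Context;
use librqbit::{
    AddTorrent, AddTorrentResponse, Session, SessionOptions, SessionPersistenceConfig,
};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let opts = SessionOptions {
        // Store session state as JSON (plus one .torrent file per torrent) in this folder.
        // Passing `folder: None` falls back to
        // SessionPersistenceConfig::default_json_persistence_folder().
        persistence: Some(SessionPersistenceConfig::Json {
            folder: Some(PathBuf::from("/tmp/rqbit-session")), // placeholder path
        }),
        ..Default::default()
    };

    // With a persistence store configured, previously stored torrents are re-added
    // (via stream_all / into_add_torrent) before this future resolves.
    let session = Session::new_with_opts(PathBuf::from("/tmp/rqbit-downloads"), opts)
        .await
        .context("error creating session")?;

    let response = session
        .add_torrent(
            AddTorrent::from_url("magnet:?xt=urn:btih:..."), // placeholder magnet link
            None,
        )
        .await?;

    let (id, handle) = match response {
        AddTorrentResponse::Added(id, handle)
        | AddTorrentResponse::AlreadyManaged(id, handle) => (id, handle),
        _ => anyhow::bail!("expected the torrent to be added"),
    };

    // Pause, unpause and delete are now async on Session because they also update
    // the persistence store (Session::pause wraps ManagedTorrent::pause, which this
    // diff makes pub(crate)).
    session.pause(&handle).await?;
    session.unpause(&handle).await?;
    session.delete(id, false).await?;

    Ok(())
}
```

Note the design choice visible in the diff: torrent IDs now come from `SessionPersistenceStore::next_id()` when a store is configured (or an in-memory `AtomicUsize` otherwise), so identifiers stay stable across restarts instead of being derived from the current map length.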