From 2859e3ab00f09fbcc4ae259b773fb6ae539ee375 Mon Sep 17 00:00:00 2001 From: Rebecca Turner Date: Thu, 10 Aug 2023 11:32:02 -0700 Subject: [PATCH] Move `Ghci` out of the `Arc>` In #11 I removed the `Weak` pointers from the ghci task structs, favoring the `mpsc` channels for communication over locking a mutex. This means I can make several simplifications, in particular turning `Ghci::reload` into a real method that uses `std::mem::replace`. --- src/ghci/mod.rs | 55 ++++++++++++++++++++----------------------------- src/watcher.rs | 8 +++---- 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/src/ghci/mod.rs b/src/ghci/mod.rs index cf75769e..4d52ca3f 100644 --- a/src/ghci/mod.rs +++ b/src/ghci/mod.rs @@ -99,7 +99,7 @@ impl Ghci { error_path: Option, setup_commands: Vec, test_command: Option, - ) -> miette::Result>> { + ) -> miette::Result { let start_instant = Instant::now(); let mut child = { @@ -132,7 +132,7 @@ impl Ghci { let stderr_handle = task::spawn(async { Ok(()) }); let stdin_handle = task::spawn(async { Ok(()) }); - let ret = Arc::new(Mutex::new(Ghci { + let mut ret = Ghci { command: command_arc, process: child, stdout: stdout_handle, @@ -144,7 +144,7 @@ impl Ghci { error_path: error_path.clone(), setup_commands: setup_commands.clone(), test_command, - })); + }; // Three tasks for my three beautiful streams. let stdout = task::spawn( @@ -185,7 +185,6 @@ impl Ghci { // Now, replace the `JoinHandle`s with the actual values. { - let mut ret = ret.lock().await; ret.stdout = stdout; ret.stderr = stderr; ret.stdin = stdin; @@ -222,16 +221,15 @@ impl Ghci { let span = tracing::debug_span!("Start-of-session sync"); let _enter = span.enter(); // Sync up for any prompts. - let mut guard = ret.lock().await; - guard.sync().await?; + ret.sync().await?; // Get the initial list of loaded modules. - guard.refresh_modules().await?; + ret.refresh_modules().await?; } tracing::info!("ghci started in {:.2?}", start_instant.elapsed()); // Run the user-provided test command, if any. - ret.lock().await.test().await?; + ret.test().await?; Ok(ret) } @@ -240,10 +238,7 @@ impl Ghci { /// /// This may fully restart the `ghci` process. #[instrument(skip_all, level = "debug")] - pub async fn reload( - this: Arc>, - events: Vec, - ) -> miette::Result>> { + pub async fn reload(&mut self, events: Vec) -> miette::Result<()> { // TODO: This method is pretty big -- we should break it up. // Once we know which paths were modified and which paths were removed, we can combine @@ -254,7 +249,6 @@ impl Ghci { let mut needs_reload = Vec::new(); let mut add = Vec::new(); { - let guard = this.lock().await; for event in events { match event { FileEvent::Remove(path) => { @@ -270,7 +264,7 @@ impl Ghci { break; } FileEvent::Modify(path) => { - if guard.modules.contains_source_path(&path)? { + if self.modules.contains_source_path(&path)? { // We can `:reload` paths `ghci` already has loaded. tracing::debug!(?path, "Needs reload"); needs_reload.push(path); @@ -290,16 +284,15 @@ impl Ghci { format_bulleted_list(&needs_restart) ); // TODO: Probably also need a restart hook / `.cabal` hook / similar. - let mut guard = this.lock().await; - guard.stop().await?; - let command = guard.command.clone(); - return Self::new( - command, - guard.error_path.clone(), - guard.setup_commands.clone(), - guard.test_command.clone(), + self.stop().await?; + let new = Self::new( + self.command.clone(), + self.error_path.clone(), + self.setup_commands.clone(), + self.test_command.clone(), ) - .await; + .await?; + let _ = std::mem::replace(self, new); } let needs_add_or_reload = !add.is_empty() || !needs_reload.is_empty(); @@ -311,7 +304,7 @@ impl Ghci { format_bulleted_list(&add) ); for path in add { - let add_result = this.lock().await.add_module(path).await?; + let add_result = self.add_module(path).await?; if let Some(CompilationResult::Err) = add_result { compilation_failed = true; } @@ -324,9 +317,7 @@ impl Ghci { format_bulleted_list(&needs_reload) ); let (sender, receiver) = oneshot::channel(); - let guard = this.lock().await; - guard - .stdin_channel + self.stdin_channel .send(StdinEvent::Reload(sender)) .await .into_diagnostic()?; @@ -342,12 +333,10 @@ impl Ghci { } else { // If we loaded or reloaded any modules, we should run tests. let (sender, receiver) = oneshot::channel(); - let guard = this.lock().await; - guard - .stdin_channel + self.stdin_channel .send(StdinEvent::Test { sender, - test_command: guard.test_command.clone(), + test_command: self.test_command.clone(), }) .await .into_diagnostic()?; @@ -355,9 +344,9 @@ impl Ghci { } } - this.lock().await.sync().await?; + self.sync().await?; - Ok(this) + Ok(()) } /// Sync the input and output streams of this `ghci` session. This will block until all input diff --git a/src/watcher.rs b/src/watcher.rs index b6dc7e73..da2ed6bb 100644 --- a/src/watcher.rs +++ b/src/watcher.rs @@ -6,7 +6,6 @@ use std::time::Duration; use camino::Utf8PathBuf; use tokio::runtime::Handle; -use tokio::sync::Mutex; use tokio::task::block_in_place; use tokio::task::JoinHandle; use tracing::instrument; @@ -39,7 +38,7 @@ pub struct Watcher { impl Watcher { /// Create a new [`Watcher`] from a [`Ghci`] session. pub fn new( - ghci: Arc>, + ghci: Ghci, watch: &[Utf8PathBuf], debounce: Duration, poll: Option, @@ -70,9 +69,8 @@ impl Watcher { } } -#[derive(Clone)] struct ActionHandler { - ghci: Arc>, + ghci: Ghci, } impl ActionHandler { @@ -94,7 +92,7 @@ impl ActionHandler { let events = file_events_from_action(&action)?; if !events.is_empty() { - self.ghci = Ghci::reload(self.ghci.clone(), events).await?; + self.ghci.reload(events).await?; } Ok(())