Skip to content

Commit

Permalink
Platform specific directory handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Nukesor committed Oct 15, 2020
1 parent c86ba90 commit 5962230
Show file tree
Hide file tree
Showing 11 changed files with 158 additions and 67 deletions.
20 changes: 2 additions & 18 deletions client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@ use std::env::{current_dir, vars};
use std::io::{self, Write as std_Write};

use anyhow::{bail, Context, Result};
use async_std::net::TcpStream;
#[cfg(not(windows))]
use async_std::os::unix::net::UnixStream;
use log::error;

use pueue::message::*;
use pueue::platform::socket::*;
use pueue::protocol::*;
use pueue::settings::Settings;

Expand Down Expand Up @@ -58,21 +56,7 @@ impl Client {
}
};

let mut socket: Box<dyn GenericSocket> = if let Some(socket_path) = unix_socket_path {
let stream = UnixStream::connect(socket_path).await?;
Box::new(stream)
} else {
// Don't allow anything else than loopback until we have proper crypto
// let address = format!("{}:{}", address, port);
let address = format!("127.0.0.1:{}", port.unwrap());

// Connect to socket
let socket = TcpStream::connect(&address)
.await
.context("Failed to connect to the daemon. Did you start it?")?;

Box::new(socket)
};
let mut socket = get_client(unix_socket_path, port).await?;

// Send the secret to the daemon
let secret = settings.shared.secret.clone().into_bytes();
Expand Down
11 changes: 1 addition & 10 deletions daemon/socket.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
use std::sync::mpsc::Sender;

use anyhow::{bail, Result};
use async_std::net::TcpListener;
#[cfg(not(windows))]
use async_std::os::unix::net::UnixListener;
use async_std::task;
use log::{debug, info, warn};

Expand Down Expand Up @@ -38,13 +35,7 @@ pub async fn accept_incoming(sender: Sender<Message>, state: SharedState, opt: O
}
};

let listener: Box<dyn GenericListener> = if let Some(socket_path) = unix_socket_path {
Box::new(UnixListener::bind(socket_path).await?)
} else {
let port = port.unwrap();
let address = format!("127.0.0.1:{}", port);
Box::new(TcpListener::bind(address).await?)
};
let listener = get_listener(unix_socket_path, port).await?;

loop {
// Poll if we have a new incoming connection.
Expand Down
20 changes: 12 additions & 8 deletions shared/platform/linux/directories.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use users::{get_current_uid, get_user_by_uid};
pub fn get_unix_socket_path() -> Result<String> {
// Get the user and their username
let user = get_user_by_uid(get_current_uid())
.ok_or(anyhow!("Couldn't find username for current user"))?;
.ok_or_else(|| anyhow!("Couldn't find username for current user"))?;
let username = user.name().to_string_lossy();

// Create the socket in the default /tmp/ directory
let path = Path::new("/tmp/").join(format!("pueue_{}.socket", username));
Ok(path.to_string_lossy().into())
// Create the socket in the default pueue path
let pueue_path = PathBuf::from(default_pueue_path()?);
let path = pueue_path.join(format!("pueue_{}.socket", username));
Ok(path
.to_str()
.ok_or_else(|| anyhow!("Failed to parse log path (Weird characters?)"))?
.to_string())
}

fn get_home_dir() -> Result<PathBuf> {
Expand All @@ -35,25 +39,25 @@ pub fn default_pueue_path() -> Result<String> {
let path = get_home_dir()?.join(".local/share/pueue");
Ok(path
.to_str()
.ok_or(anyhow!("Failed to parse log path (Weird characters?)"))?
.ok_or_else(|| anyhow!("Failed to parse log path (Weird characters?)"))?
.to_string())
}

#[cfg(test)]
mod tests {
use super::*;

use std::fs::{remove_file, File};
use std::fs::{create_dir_all, remove_file, File};
use std::io::prelude::*;

use anyhow::Result;

#[test]
fn test_create_unix_socket() -> Result<()> {
let path = get_unix_socket_path()?;
create_dir_all(default_pueue_path()?)?;

// If pueue is currently running on the system,
// simply accept that we found the correct path
// If pueue is currently running on the system, simply accept that we found the correct path
if PathBuf::from(&path).exists() {
return Ok(());
}
Expand Down
17 changes: 17 additions & 0 deletions shared/platform/macos/directories.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
use std::path::{Path, PathBuf};

use anyhow::{anyhow, Result};
use users::{get_current_uid, get_user_by_uid};

/// Get the default unix socket path for the current user
pub fn get_unix_socket_path() -> Result<String> {
// Get the user and their username
let user = get_user_by_uid(get_current_uid())
.ok_or(anyhow!("Couldn't find username for current user"))?;
let username = user.name().to_string_lossy();

// Create the socket in the default pueue path
let pueue_path = PathBuf::from(default_pueue_path()?);
let path = pueue_path.join(format!("pueue_{}.socket", username));
Ok(path
.to_str()
.ok_or(anyhow!("Failed to parse log path (Weird characters?)"))?
.to_string())
}

fn get_home_dir() -> Result<PathBuf> {
dirs::home_dir().ok_or_else(|| anyhow!("Couldn't resolve home dir"))
Expand Down
15 changes: 15 additions & 0 deletions shared/platform/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
/// Linux specific stuff
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
pub mod linux;
/// MacOs specific stuff
#[cfg(target_os = "macos")]
pub mod macos;
/// Shared unix stuff
#[cfg(not(target_os = "windows"))]
pub mod unix;
/// Windows specific stuff
#[cfg(target_os = "windows")]
pub mod windows;

/// Shared unix stuff for sockets
#[cfg(not(target_os = "windows"))]
pub use self::unix::socket;

/// Windows specific socket stuff
#[cfg(target_os = "windows")]
pub use self::windows::socket;

// The next block is platform specific directory functions
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
pub use self::linux::directories;

Expand Down
1 change: 1 addition & 0 deletions shared/platform/unix/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod socket;
66 changes: 66 additions & 0 deletions shared/platform/unix/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use anyhow::{Context, Result};
use async_std::io::{Read, Write};
use async_std::net::{TcpListener, TcpStream};
use async_std::os::unix::net::{UnixListener, UnixStream};
use async_trait::async_trait;

pub trait GenericSocket: Read + Write + Unpin + Send + Sync {}
pub type SocketBox = Box<dyn GenericSocket>;

#[async_trait]
pub trait GenericListener: Sync + Send {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>>;
}

#[async_trait]
impl GenericListener for TcpListener {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>> {
let (socket, _) = self.accept().await?;
Ok(Box::new(socket))
}
}

#[async_trait]
impl GenericListener for UnixListener {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>> {
let (socket, _) = self.accept().await?;
Ok(Box::new(socket))
}
}

impl GenericSocket for TcpStream {}
impl GenericSocket for UnixStream {}

pub async fn get_client(
unix_socket_path: Option<String>,
port: Option<String>,
) -> Result<Box<dyn GenericSocket>> {
if let Some(socket_path) = unix_socket_path {
let stream = UnixStream::connect(socket_path).await?;
return Ok(Box::new(stream));
}

// Don't allow anything else than loopback until we have proper crypto
// let address = format!("{}:{}", address, port);
let address = format!("127.0.0.1:{}", port.unwrap());

// Connect to socket
let socket = TcpStream::connect(&address)
.await
.context("Failed to connect to the daemon. Did you start it?")?;

Ok(Box::new(socket))
}

pub async fn get_listener(
unix_socket_path: Option<String>,
port: Option<String>,
) -> Result<Box<dyn GenericListener>> {
if let Some(socket_path) = unix_socket_path {
return Ok(Box::new(UnixListener::bind(socket_path).await?));
}

let port = port.unwrap();
let address = format!("127.0.0.1:{}", port);
Ok(Box::new(TcpListener::bind(address).await?))
}
11 changes: 11 additions & 0 deletions shared/platform/windows/directories.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@ use std::path::{Path, PathBuf};

use anyhow::{anyhow, Result};

/// Get the default unix socket path for the current user
pub fn get_unix_socket_path() -> Result<String> {
// Create the socket in the default pueue path
let pueue_path = PathBuf::from(default_pueue_path()?);
let path = pueue_path.join("pueue.socket");
Ok(path
.to_str()
.ok_or(anyhow!("Failed to parse log path (Weird characters?)"))?
.to_string())
}

fn get_home_dir() -> Result<PathBuf> {
dirs::home_dir().ok_or_else(|| anyhow!("Couldn't resolve home dir"))
}
Expand Down
1 change: 1 addition & 0 deletions shared/platform/windows/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod directories;
pub mod socket;
31 changes: 31 additions & 0 deletions shared/platform/windows/socket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use anyhow::Result;
use async_std::io::{Read, Write};
use async_std::net::{TcpListener, TcpStream};
use async_trait::async_trait;

pub trait GenericSocket: Read + Write + Unpin + Send + Sync {}
pub type SocketBox = Box<dyn GenericSocket>;

#[async_trait]
pub trait GenericListener: Sync + Send {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>>;
}

#[async_trait]
impl GenericListener for TcpListener {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>> {
let (socket, _) = self.accept().await?;
Ok(Box::new(socket))
}
}

impl GenericSocket for TcpStream {}

pub async fn get_listener(
unix_socket_path: Option<String>,
port: Option<String>,
) -> Result<Box<dyn GenericListener>> {
let port = port.unwrap();
let address = format!("127.0.0.1:{}", port);
Ok(Box::new(TcpListener::bind(address).await?))
}
32 changes: 1 addition & 31 deletions shared/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,13 @@
use std::io::Cursor;

use anyhow::{Context, Result};
use async_std::io::{Read, Write};
use async_std::net::{TcpListener, TcpStream};
#[cfg(not(windows))]
use async_std::os::unix::net::{UnixListener, UnixStream};
use async_std::prelude::*;
use async_trait::async_trait;
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use log::debug;

use crate::message::*;

#[async_trait]
pub trait GenericListener: Sync + Send {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>>;
}

#[async_trait]
impl GenericListener for TcpListener {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>> {
let (socket, _) = self.accept().await?;
Ok((Box::new(socket)))
}
}

#[cfg(not(windows))]
#[async_trait]
impl GenericListener for UnixListener {
async fn accept<'a>(&'a self) -> Result<Box<dyn GenericSocket>> {
let (socket, _) = self.accept().await?;
Ok((Box::new(socket)))
}
}

pub trait GenericSocket: Read + Write + Unpin + Send + Sync {}
pub type SocketBox = Box<dyn GenericSocket>;
impl GenericSocket for TcpStream {}
impl GenericSocket for UnixStream {}
pub use crate::platform::socket::*;

/// Convenience wrapper around send_bytes.
/// Deserialize a message and feed the bytes into send_bytes.
Expand Down

0 comments on commit 5962230

Please sign in to comment.