diff --git a/src/cli/init.rs b/src/cli/init.rs index 496c75c70..0b5f58d3b 100644 --- a/src/cli/init.rs +++ b/src/cli/init.rs @@ -1,7 +1,9 @@ use crate::{config::get_default_author, consts}; +use anyhow::anyhow; use clap::Parser; use minijinja::{context, Environment}; use rattler_conda_types::Platform; +use std::io::{Error, ErrorKind}; use std::{fs, path::PathBuf}; /// Creates a new project @@ -40,7 +42,7 @@ const GITIGNORE_TEMPLATE: &str = r#"# pixi environments pub async fn execute(args: Args) -> anyhow::Result<()> { let env = Environment::new(); - let dir = args.path; + let dir = get_dir(args.path)?; let manifest_path = dir.join(consts::PROJECT_MANIFEST); let gitignore_path = dir.join(".gitignore"); @@ -53,7 +55,15 @@ pub async fn execute(args: Args) -> anyhow::Result<()> { fs::create_dir_all(&dir).ok(); // Write pixi.toml - let name = dir.file_name().unwrap().to_string_lossy(); + let name = dir + .file_name() + .ok_or_else(|| { + anyhow!( + "Cannot get file or directory name from the path: {}", + dir.to_string_lossy() + ) + })? + .to_string_lossy(); let version = "0.1.0"; let author = get_default_author(); let channels = if args.channels.is_empty() { @@ -93,3 +103,53 @@ pub async fn execute(args: Args) -> anyhow::Result<()> { Ok(()) } + +fn get_dir(path: PathBuf) -> Result { + if path.components().count() == 1 { + Ok(std::env::current_dir().unwrap_or_default().join(path)) + } else { + path.canonicalize().map_err(|e| match e.kind() { + ErrorKind::NotFound => Error::new( + ErrorKind::NotFound, + format!( + "Cannot find '{}' please make sure the folder is reachable", + path.to_string_lossy() + ), + ), + _ => Error::new( + ErrorKind::InvalidInput, + "Cannot canonicalize the given path", + ), + }) + } +} + +#[cfg(test)] +mod tests { + use crate::cli::init::get_dir; + use std::path::PathBuf; + + #[test] + fn test_get_name() { + assert_eq!( + get_dir(PathBuf::from(".")).unwrap(), + std::env::current_dir().unwrap() + ); + assert_eq!( + get_dir(PathBuf::from("test_folder")).unwrap(), + std::env::current_dir().unwrap().join("test_folder") + ); + assert_eq!( + get_dir(std::env::current_dir().unwrap()).unwrap(), + PathBuf::from(std::env::current_dir().unwrap().canonicalize().unwrap()) + ); + } + + #[test] + fn test_get_name_panic() { + match get_dir(PathBuf::from("invalid/path")) { + Ok(_) => panic!("Expected error, but got OK"), + Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::NotFound), + } + } +}