From c2b0f5804c15e3dfb0d700683f995a8f5351b2a9 Mon Sep 17 00:00:00 2001 From: Bugen Zhao Date: Tue, 2 Apr 2024 16:01:49 +0800 Subject: [PATCH] fix: correctly implement default subcommand for `single-node` (#16080) Signed-off-by: Bugen Zhao --- src/cmd_all/src/bin/risingwave.rs | 155 ++++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 40 deletions(-) diff --git a/src/cmd_all/src/bin/risingwave.rs b/src/cmd_all/src/bin/risingwave.rs index 1bbfe2265c316..c95e7a61db4d3 100644 --- a/src/cmd_all/src/bin/risingwave.rs +++ b/src/cmd_all/src/bin/risingwave.rs @@ -12,13 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +#![feature(assert_matches)] #![cfg_attr(coverage, feature(coverage_attribute))] +use std::ffi::OsString; use std::str::FromStr; -use anyhow::Result; -use clap::error::ErrorKind; -use clap::{command, ArgMatches, Args, Command, FromArgMatches}; +use clap::error::Result as ClapResult; +use clap::{command, ArgMatches, Args, Command, CommandFactory, FromArgMatches}; use risingwave_cmd::{compactor, compute, ctl, frontend, meta}; use risingwave_cmd_all::{SingleNodeOpts, StandaloneOpts}; use risingwave_common::git_sha; @@ -86,7 +87,7 @@ const VERSION: &str = { }; /// Component to launch. -#[derive(Clone, Copy, EnumIter, EnumString, Display, IntoStaticStr)] +#[derive(Debug, Clone, Copy, EnumIter, EnumString, Display, IntoStaticStr)] #[strum(serialize_all = "snake_case")] enum Component { Compute, @@ -147,7 +148,8 @@ impl Component { Component::Frontend => FrontendOpts::augment_args(cmd), Component::Compactor => CompactorOpts::augment_args(cmd), Component::Ctl => CtlOpts::augment_args(cmd), - Component::Playground => cmd, + Component::Playground => cmd + .about("Shortcut for `single-node --in-memory`, should not be used in production"), Component::Standalone => StandaloneOpts::augment_args(cmd), Component::SingleNode => SingleNodeOpts::augment_args(cmd), } @@ -157,22 +159,20 @@ impl Component { fn commands() -> Vec { Self::iter() .map(|c| { - let is_playground = matches!(c, Component::Playground); let name: &'static str = c.into(); let command = Command::new(name).visible_aliases(c.aliases()); - let command = if is_playground { - command.hide(true) - } else { - command - }; c.augment_args(command) }) .collect() } } -#[cfg_attr(coverage, coverage(off))] -fn main() -> Result<()> { +/// Parse the given arguments and return the component and its matches. +fn parse_args(args: I) -> ClapResult<(Component, ArgMatches)> +where + I: IntoIterator, + T: Into + Clone, +{ let risingwave = || { command!(BINARY_NAME) .about("All-in-one executable for components of RisingWave") @@ -188,39 +188,38 @@ fn main() -> Result<()> { risingwave() .subcommand_value_name("COMPONENT") .subcommand_help_heading("Components") - .subcommand_required(true) - .subcommands(Component::commands()), + .subcommands(Component::commands()) + // Make single node the "default subcommand" + .args_conflicts_with_subcommands(true) + .args(SingleNodeOpts::command().get_arguments()), ); - let matches = match command.try_get_matches() { - Ok(m) => m, - Err(e) - if e.kind() == ErrorKind::MissingSubcommand - || e.kind() == ErrorKind::UnknownArgument => - { - // `$ ./risingwave` - // NOTE(kwannoel): This is a hack to make `risingwave` - // work as an alias of `risingwave single-process`. - // If invocation is not a multicall and there's no subcommand, - // we will try to invoke it as a single node. - let command = Component::SingleNode.augment_args(risingwave()); - let matches = command.get_matches(); - Component::SingleNode.start(&matches); - return Ok(()); - } - Err(e) => { - e.exit(); - } + let matches = command.try_get_matches_from(args)?; + let multicall = matches.subcommand().unwrap(); + + let (component_name, matches) = if multicall.0 == BINARY_NAME { + // This is not a multicall. Match argv[1] as a component. + (multicall.1) + .subcommand() + // If there's no subcommand, it must be single node ("default subcommand"). + .unwrap_or_else(|| (Component::SingleNode.into(), multicall.1)) + } else { + multicall }; - let multicall = matches.subcommand().unwrap(); - let argv_1 = multicall.1.subcommand(); - let (component_name, matches) = argv_1.unwrap_or(multicall); + let component = Component::from_str(component_name).unwrap(); // always succeeds + let matches = matches.clone(); + + Ok((component, matches)) +} - let component = Component::from_str(component_name)?; - component.start(matches); +#[cfg_attr(coverage, coverage(off))] +fn main() { + let (component, matches) = parse_args(std::env::args_os()) + .map_err(|e| e.exit()) + .unwrap(); - Ok(()) + component.start(&matches); } fn standalone(opts: StandaloneOpts) { @@ -243,3 +242,79 @@ fn single_node(opts: SingleNodeOpts) { risingwave_rt::init_risingwave_logger(settings); risingwave_rt::main_okk(risingwave_cmd_all::standalone(opts)).unwrap(); } + +#[cfg(test)] +mod tests { + use std::assert_matches::assert_matches; + + use clap::error::ErrorKind; + + use super::{parse_args, Component}; + + #[test] + fn test_basic() { + let (c, _) = + parse_args(["./risingwave", "meta", "--advertise-addr", "1.2.3.4:5678"]).unwrap(); + assert_matches!(c, Component::Meta); + } + + #[test] + fn test_multicall() { + let (c, _) = parse_args(["./meta-node", "--advertise-addr", "1.2.3.4:5678"]).unwrap(); + assert_matches!(c, Component::Meta); + } + + #[test] + fn test_missing_sub_subcommand() { + let e = parse_args(["./risingwave", "ctl"]).unwrap_err(); + assert_matches!( + e.kind(), + ErrorKind::DisplayHelpOnMissingArgumentOrSubcommand + ); + } + + #[test] + fn test_sub_expected_subcommand_but_got_unknown_arg() { + let e = parse_args(["./risingwave", "ctl", "--foo"]).unwrap_err(); + assert_matches!(e.kind(), ErrorKind::UnknownArgument); + } + + #[test] + fn test_issue_16065() { + let e = parse_args([ + "./risingwave", + "ctl", + "meta", + "unregister-worker", + "my-bad-arg", + ]) + .unwrap_err(); + assert_matches!(e.kind(), ErrorKind::UnknownArgument); + assert!(e.to_string().contains("my-bad-arg"), "{e}") + } + + #[test] + fn test_default_subcommand_single_node() { + let (c, _) = parse_args(["./risingwave"]).unwrap(); + assert_matches!(c, Component::SingleNode); + } + + #[test] + fn test_default_subcommand_single_node_with_args() { + let (c, _) = parse_args(["./risingwave", "--in-memory"]).unwrap(); + assert_matches!(c, Component::SingleNode); + } + + #[test] + fn test_default_subcommand_single_node_with_unknown_args() { + let e = parse_args(["./risingwave", "--foo"]).unwrap_err(); + assert_matches!(e.kind(), ErrorKind::UnknownArgument); + } + + #[test] + fn test_default_subcommand_single_node_with_other_explicit_subcommand() { + let e = parse_args(["./risingwave", "--in-memory", "ctl"]).unwrap_err(); + assert_matches!(e.kind(), ErrorKind::ArgumentConflict); + assert!(e.to_string().contains("cannot be used with"), "{e}"); + } +}