Skip to content

Commit

Permalink
properly match provider being used with provider arguments so that kw…
Browse files Browse the repository at this point in the history
…args are correctly filtered
  • Loading branch information
hjoaquim committed May 10, 2024
1 parent e8a6fe5 commit 8e9b228
Showing 1 changed file with 13 additions and 3 deletions.
16 changes: 13 additions & 3 deletions cli/openbb_cli/argparse_translator/argparse_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __init__(
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
self.provider_parameters: List[str] = []
self.provider_parameters: Dict[str, List[str]] = {}

self._parser = argparse.ArgumentParser(
prog=func.__name__,
Expand All @@ -218,6 +218,7 @@ def __init__(

if custom_argument_groups:
for group in custom_argument_groups:
self.provider_parameters[group.name] = []
argparse_group = self._parser.add_argument_group(group.name)
for argument in group.arguments:
self._handle_argument_in_groups(argument, argparse_group)
Expand Down Expand Up @@ -278,7 +279,8 @@ def _update_providers(
if f"--{argument.name}" not in self._parser._option_string_actions:
kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
group.add_argument(f"--{argument.name}", **kwargs)
self.provider_parameters.append(argument.name)
if group.title in self.provider_parameters:
self.provider_parameters[group.title].append(argument.name)

else:
kwargs = argument.model_dump(exclude={"name"}, exclude_none=True)
Expand Down Expand Up @@ -582,11 +584,19 @@ def execute_func(
kwargs = self._unflatten_args(vars(parsed_args))
kwargs = self._update_with_custom_types(kwargs)

provider = kwargs.get("provider")
provider_args = []
if provider and provider in self.provider_parameters:
provider_args = self.provider_parameters[provider]
else:
for args in self.provider_parameters.values():
provider_args.extend(args)

# remove kwargs that doesn't match the signature or provider parameters
kwargs = {
key: value
for key, value in kwargs.items()
if key in self.signature.parameters or key in self.provider_parameters
if key in self.signature.parameters or key in provider_args
}

return self.func(**kwargs)
Expand Down

0 comments on commit 8e9b228

Please sign in to comment.