mirror of
https://github.com/koloideal/Argenta.git
synced 2026-06-10 18:15:28 +03:00
refactor and optimize argspace
This commit is contained in:
@@ -18,26 +18,48 @@ class ArgSpace:
|
||||
def __init__(self, all_arguments: list[InputArgument]) -> None:
|
||||
self.all_arguments = all_arguments
|
||||
|
||||
self._name_object_paired_args: dict[str, InputArgument] = {}
|
||||
self._type_object_paired_args: dict[type[BaseArgument], list[InputArgument]] = {
|
||||
BooleanArgument: [],
|
||||
ValueArgument: []
|
||||
}
|
||||
|
||||
self._setup_getters()
|
||||
|
||||
@classmethod
|
||||
def from_namespace(
|
||||
cls, namespace: Namespace, processed_args: list[ValueArgument | BooleanArgument]
|
||||
cls,
|
||||
namespace: Namespace,
|
||||
processed_args: list[ValueArgument | BooleanArgument]
|
||||
) -> Self:
|
||||
name_type_paired_args: dict[str, type[BaseArgument]] = {arg.name: type(arg) for arg in processed_args}
|
||||
return cls(
|
||||
[
|
||||
InputArgument(name=name, value=value, founder_class=name_type_paired_args[name])
|
||||
for name, value in vars(namespace).items()
|
||||
]
|
||||
)
|
||||
name_type_paired_processed_args: dict[str, type[BaseArgument]] = {
|
||||
arg.name: type(arg) for arg in processed_args
|
||||
}
|
||||
parsed_arguments: list[InputArgument] = []
|
||||
|
||||
for name, value in vars(namespace).items():
|
||||
parsed_arguments.append(
|
||||
InputArgument(
|
||||
name=name,
|
||||
value=value,
|
||||
founder_class=name_type_paired_processed_args[name]
|
||||
)
|
||||
)
|
||||
|
||||
return cls(parsed_arguments)
|
||||
|
||||
def _setup_getters(self):
|
||||
if not self.all_arguments:
|
||||
return
|
||||
for input_arg in self.all_arguments:
|
||||
self._name_object_paired_args[input_arg.name] = input_arg
|
||||
self._type_object_paired_args[input_arg.founder_class].append(input_arg)
|
||||
|
||||
def get_by_name(self, name: str) -> InputArgument | None:
|
||||
for arg in self.all_arguments:
|
||||
if arg.name == name:
|
||||
return arg
|
||||
return None
|
||||
return self._name_object_paired_args.get(name)
|
||||
|
||||
def get_by_type(self, arg_type: type[BaseArgument]) -> list[InputArgument] | list[Never]:
|
||||
return [arg for arg in self.all_arguments if arg.founder_class is arg_type]
|
||||
return self._type_object_paired_args.get(arg_type, [])
|
||||
|
||||
|
||||
class ArgParser:
|
||||
@@ -75,7 +97,10 @@ class ArgParser:
|
||||
for arg in processed_args:
|
||||
if isinstance(arg, BooleanArgument):
|
||||
_ = self._core.add_argument(
|
||||
arg.string_entity, action=arg.action, help=arg.help, deprecated=arg.is_deprecated
|
||||
arg.string_entity,
|
||||
action=arg.action,
|
||||
help=arg.help,
|
||||
deprecated=arg.is_deprecated
|
||||
)
|
||||
else:
|
||||
_ = self._core.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user