diff --git a/mock/mock_app/main.py b/mock/mock_app/main.py index 84664dd..bdf981b 100644 --- a/mock/mock_app/main.py +++ b/mock/mock_app/main.py @@ -4,15 +4,26 @@ from argenta import App, Orchestrator from argenta.app import PredefinedMessages, DynamicDividingLine, AutoCompleter from argenta.orchestrator import ArgParser from argenta.orchestrator.argparser import BooleanArgument, ValueArgument +from dishka import Provider, provide, Scope # type: ignore -arg_parser: ArgParser = ArgParser(processed_args=[BooleanArgument(name="repeat", is_deprecated=True), - ValueArgument(name="required", is_required=True)]) +class temProvider(Provider): + @provide(scope=Scope.APP) + def get_apace(self) -> int: + return 1234 + +arg_parser: ArgParser = ArgParser( + processed_args=[ + BooleanArgument(name="repeat", is_deprecated=True), + ValueArgument(name="required", is_required=True), + ] +) app: App = App( dividing_line=DynamicDividingLine(), autocompleter=AutoCompleter(), ) -orchestrator: Orchestrator = Orchestrator() +orchestrator: Orchestrator = Orchestrator(arg_parser, custom_providers=[temProvider()]) + def main(): app.include_router(work_router) @@ -20,7 +31,7 @@ def main(): app.add_message_on_startup(PredefinedMessages.USAGE) app.add_message_on_startup(PredefinedMessages.AUTOCOMPLETE) app.add_message_on_startup(PredefinedMessages.HELP) - + orchestrator.start_polling(app) diff --git a/mock/mock_app/routers.py b/mock/mock_app/routers.py index 5d36e1f..7501fa7 100644 --- a/mock/mock_app/routers.py +++ b/mock/mock_app/routers.py @@ -1,18 +1,20 @@ from argenta.command import Command, PredefinedFlags, Flags, Flag, PossibleValues from argenta.response import Response from argenta import Router +from argenta.di import FromDishka -work_router: Router = Router(title="Work points:") +work_router: Router = Router(title="Work points:", disable_redirect_stdout=True) -flag = Flag('csdv', possible_values=PossibleValues.NEITHER) +flag = Flag("csdv", possible_values=PossibleValues.NEITHER) @work_router.command( - Command("get", + Command( + "get", description="Get Help", aliases=["help", "Get_help"], - flags=Flags([PredefinedFlags.PORT, PredefinedFlags.HOST]) + flags=Flags([PredefinedFlags.PORT, PredefinedFlags.HOST]), ) ) def command_help(response: Response): @@ -21,6 +23,5 @@ def command_help(response: Response): @work_router.command("run") -def command_start_solving(response: Response): - print(response.status) - print(response.input_flags.flags) +def command_start_solving(response: Response, argspace: FromDishka[int]): + print(argspace) diff --git a/pyproject.toml b/pyproject.toml index b63175f..9fd14fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "argenta" -version = "1.1.1" +version = "1.1.2" description = "Python library for building modular CLI applications" authors = [{ name = "kolo", email = "kolo.is.main@gmail.com" }] requires-python = ">=3.11" @@ -10,6 +10,7 @@ dependencies = [ "rich (>=14.0.0,<15.0.0)", "art (>=6.4,<7.0)", "pyreadline3>=3.5.4; sys_platform == 'win32'", + "dishka>=1.7.2", ] [tool.ruff] @@ -25,6 +26,9 @@ exclude = [ [tool.pyright] typeCheckingMode = "strict" +[tool.mypy] +disable_error_code = "import-untyped" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" @@ -36,4 +40,3 @@ dev = [ "ruff>=0.12.12", "wemake-python-styleguide>=0.17.0", ] - diff --git a/src/argenta/__init__.py b/src/argenta/__init__.py index 8f4f423..6685918 100644 --- a/src/argenta/__init__.py +++ b/src/argenta/__init__.py @@ -1,5 +1,6 @@ __all__ = ["App", "Orchestrator", "Router"] -from argenta.orchestrator.entity import Orchestrator, App +from argenta.orchestrator.entity import Orchestrator +from argenta.app.models import App from argenta.router.entity import Router diff --git a/src/argenta/app/models.py b/src/argenta/app/models.py index c8821dd..676cf15 100644 --- a/src/argenta/app/models.py +++ b/src/argenta/app/models.py @@ -3,7 +3,6 @@ import re from contextlib import redirect_stdout from typing import Never, TypeAlias -from argenta.orchestrator.argparser.entity import ArgSpace from art import text2art # pyright: ignore[reportMissingTypeStubs, reportUnknownVariableType] from rich.console import Console from rich.markup import escape @@ -32,18 +31,21 @@ Matches: TypeAlias = list[str] | list[Never] class BaseApp: - def __init__(self, *, prompt: str, - initial_message: str, - farewell_message: str, - exit_command: Command, - system_router_title: str | None, - ignore_command_register: bool, - dividing_line: StaticDividingLine | DynamicDividingLine, - repeat_command_groups: bool, - override_system_messages: bool, - autocompleter: AutoCompleter, - print_func: Printer, - argspace: ArgSpace | None = None) -> None: + def __init__( + self, + *, + prompt: str, + initial_message: str, + farewell_message: str, + exit_command: Command, + system_router_title: str | None, + ignore_command_register: bool, + dividing_line: StaticDividingLine | DynamicDividingLine, + repeat_command_groups: bool, + override_system_messages: bool, + autocompleter: AutoCompleter, + print_func: Printer, + ) -> None: self._prompt: str = prompt self._print_func: Printer = print_func self._exit_command: Command = exit_command @@ -53,27 +55,44 @@ class BaseApp: self._repeat_command_groups_description: bool = repeat_command_groups self._override_system_messages: bool = override_system_messages self._autocompleter: AutoCompleter = autocompleter - self._argspace: ArgSpace | None = argspace self._farewell_message: str = farewell_message self._initial_message: str = initial_message - self._description_message_gen: DescriptionMessageGenerator = lambda command, description: f"{command} *=*=* {description}" - self._registered_routers: RegisteredRouters = RegisteredRouters() + self._description_message_gen: DescriptionMessageGenerator = ( + lambda command, description: f"{command} *=*=* {description}" + ) + self.registered_routers: RegisteredRouters = RegisteredRouters() self._messages_on_startup: list[str] = [] self._matching_lower_triggers_with_routers: dict[str, Router] = {} self._matching_default_triggers_with_routers: dict[str, Router] = {} - self._current_matching_triggers_with_routers: dict[str, Router] = self._matching_lower_triggers_with_routers if self._ignore_command_register else self._matching_default_triggers_with_routers + self._current_matching_triggers_with_routers: dict[str, Router] = ( + self._matching_lower_triggers_with_routers + if self._ignore_command_register + else self._matching_default_triggers_with_routers + ) - self._incorrect_input_syntax_handler: NonStandardBehaviorHandler[str] = lambda _: print_func(f"Incorrect flag syntax: {_}") - self._repeated_input_flags_handler: NonStandardBehaviorHandler[str] = lambda _: print_func(f"Repeated input flags: {_}") - self._empty_input_command_handler: EmptyCommandHandler = lambda: print_func("Empty input command") - self._unknown_command_handler: NonStandardBehaviorHandler[InputCommand] = lambda _: print_func(f"Unknown command: {_.trigger}") - self._exit_command_handler: NonStandardBehaviorHandler[Response] = lambda _: print_func(self._farewell_message) + self._incorrect_input_syntax_handler: NonStandardBehaviorHandler[str] = ( + lambda _: print_func(f"Incorrect flag syntax: {_}") + ) + self._repeated_input_flags_handler: NonStandardBehaviorHandler[str] = ( + lambda _: print_func(f"Repeated input flags: {_}") + ) + self._empty_input_command_handler: EmptyCommandHandler = lambda: print_func( + "Empty input command" + ) + self._unknown_command_handler: NonStandardBehaviorHandler[InputCommand] = ( + lambda _: print_func(f"Unknown command: {_.trigger}") + ) + self._exit_command_handler: NonStandardBehaviorHandler[Response] = ( + lambda _: print_func(self._farewell_message) + ) - def set_description_message_pattern(self, _: DescriptionMessageGenerator, /) -> None: + def set_description_message_pattern( + self, _: DescriptionMessageGenerator, / + ) -> None: """ Public. Sets the output pattern of the available commands :param _: output pattern of the available commands @@ -81,7 +100,9 @@ class BaseApp: """ self._description_message_gen = _ - def set_incorrect_input_syntax_handler(self, _: NonStandardBehaviorHandler[str], /) -> None: + def set_incorrect_input_syntax_handler( + self, _: NonStandardBehaviorHandler[str], / + ) -> None: """ Public. Sets the handler for incorrect flags when entering a command :param _: handler for incorrect flags when entering a command @@ -89,7 +110,9 @@ class BaseApp: """ self._incorrect_input_syntax_handler = _ - def set_repeated_input_flags_handler(self, _: NonStandardBehaviorHandler[str], /) -> None: + def set_repeated_input_flags_handler( + self, _: NonStandardBehaviorHandler[str], / + ) -> None: """ Public. Sets the handler for repeated flags when entering a command :param _: handler for repeated flags when entering a command @@ -97,7 +120,9 @@ class BaseApp: """ self._repeated_input_flags_handler = _ - def set_unknown_command_handler(self, _: NonStandardBehaviorHandler[InputCommand], /) -> None: + def set_unknown_command_handler( + self, _: NonStandardBehaviorHandler[InputCommand], / + ) -> None: """ Public. Sets the handler for unknown commands when entering a command :param _: handler for unknown commands when entering a command @@ -113,7 +138,9 @@ class BaseApp: """ self._empty_input_command_handler = _ - def set_exit_command_handler(self, _: NonStandardBehaviorHandler[Response], /) -> None: + def set_exit_command_handler( + self, _: NonStandardBehaviorHandler[Response], / + ) -> None: """ Public. Sets the handler for exit command when entering a command :param _: handler for exit command when entering a command @@ -126,7 +153,7 @@ class BaseApp: Private. Prints the description of the available commands :return: None """ - for registered_router in self._registered_routers: + for registered_router in self.registered_routers: if registered_router.title: self._print_func(registered_router.title) for command_handler in registered_router.command_handlers: @@ -167,14 +194,18 @@ class BaseApp: length=max_length_line, is_override=self._override_system_messages ) ) - - elif isinstance(self._dividing_line, StaticDividingLine): # pyright: ignore[reportUnnecessaryIsInstance] + + elif isinstance(self._dividing_line, StaticDividingLine): # pyright: ignore[reportUnnecessaryIsInstance] self._print_func( - self._dividing_line.get_full_static_line(is_override=self._override_system_messages) + self._dividing_line.get_full_static_line( + is_override=self._override_system_messages + ) ) print(text.strip("\n")) self._print_func( - self._dividing_line.get_full_static_line(is_override=self._override_system_messages) + self._dividing_line.get_full_static_line( + is_override=self._override_system_messages + ) ) else: @@ -189,13 +220,9 @@ class BaseApp: trigger = command.trigger exit_trigger = self._exit_command.trigger if self._ignore_command_register: - if ( - trigger.lower() == exit_trigger.lower() - ): + if trigger.lower() == exit_trigger.lower(): return True - elif trigger.lower() in [ - x.lower() for x in self._exit_command.aliases - ]: + elif trigger.lower() in [x.lower() for x in self._exit_command.aliases]: return True else: if trigger == exit_trigger: @@ -212,16 +239,18 @@ class BaseApp: """ input_command_trigger = command.trigger if self._ignore_command_register: - if input_command_trigger.lower() in list(self._current_matching_triggers_with_routers.keys()): + if input_command_trigger.lower() in list( + self._current_matching_triggers_with_routers.keys() + ): return False else: - if input_command_trigger in list(self._current_matching_triggers_with_routers.keys()): + if input_command_trigger in list( + self._current_matching_triggers_with_routers.keys() + ): return False return True - def _error_handler( - self, error: InputCommandException, raw_command: str - ) -> None: + def _error_handler(self, error: InputCommandException, raw_command: str) -> None: """ Private. Handles parsing errors of the entered command :param error: error being handled @@ -246,13 +275,13 @@ class BaseApp: def _(response: Response) -> None: self._exit_command_handler(response) - if system_router not in self._registered_routers.registered_routers: + if system_router not in self.registered_routers.registered_routers: system_router.command_register_ignore = self._ignore_command_register - self._registered_routers.add_registered_router(system_router) + self.registered_routers.add_registered_router(system_router) def _most_similar_command(self, unknown_command: str) -> str | None: all_commands = list(self._current_matching_triggers_with_routers.keys()) - + matches_startswith_unknown_command: Matches = sorted( cmd for cmd in all_commands if cmd.startswith(unknown_command) ) @@ -275,26 +304,36 @@ class BaseApp: :return: None """ self._prompt = f"[italic dim bold]{self._prompt}" - self._initial_message = ("\n" + f"[bold red]{text2art(self._initial_message, font='tarty1')}" + "\n") + self._initial_message = ( + "\n" + f"[bold red]{text2art(self._initial_message, font='tarty1')}" + "\n" + ) self._farewell_message = ( - "[bold red]\n\n" + - str(text2art(self._farewell_message, font="chanky")) + # pyright: ignore[reportUnknownArgumentType] - "\n[/bold red]\n" + - "[red i]github.com/koloideal/Argenta[/red i] | [red bold i]made by kolo[/red bold i]\n" + "[bold red]\n\n" + + str(text2art(self._farewell_message, font="chanky")) # pyright: ignore[reportUnknownArgumentType] + + "\n[/bold red]\n" + + "[red i]github.com/koloideal/Argenta[/red i] | [red bold i]made by kolo[/red bold i]\n" ) self._description_message_gen = lambda command, description: ( f"[bold red]{escape('[' + command + ']')}[/bold red] " f"[blue dim]*=*=*[/blue dim] " f"[bold yellow italic]{escape(description)}" ) - self._incorrect_input_syntax_handler = lambda raw_command: self._print_func(f"[red bold]Incorrect flag syntax: {escape(raw_command)}") - self._repeated_input_flags_handler = lambda raw_command: self._print_func(f"[red bold]Repeated input flags: {escape(raw_command)}") - self._empty_input_command_handler = lambda: self._print_func("[red bold]Empty input command") + self._incorrect_input_syntax_handler = lambda raw_command: self._print_func( + f"[red bold]Incorrect flag syntax: {escape(raw_command)}" + ) + self._repeated_input_flags_handler = lambda raw_command: self._print_func( + f"[red bold]Repeated input flags: {escape(raw_command)}" + ) + self._empty_input_command_handler = lambda: self._print_func( + "[red bold]Empty input command" + ) def unknown_command_handler(command: InputCommand) -> None: cmd_trg: str = command.trigger mst_sim_cmd: str | None = self._most_similar_command(cmd_trg) - first_part_of_text = f"[red]Unknown command:[/red] [blue]{escape(cmd_trg)}[/blue]" + first_part_of_text = ( + f"[red]Unknown command:[/red] [blue]{escape(cmd_trg)}[/blue]" + ) second_part_of_text = ( ("[red], most similar:[/red] " + ("[blue]" + mst_sim_cmd + "[/blue]")) if mst_sim_cmd @@ -311,21 +350,27 @@ class BaseApp: """ self._setup_system_router() - for router_entity in self._registered_routers: + for router_entity in self.registered_routers: router_triggers = router_entity.triggers router_aliases = router_entity.aliases combined = router_triggers + router_aliases for trigger in combined: self._matching_default_triggers_with_routers[trigger] = router_entity - self._matching_lower_triggers_with_routers[trigger.lower()] = router_entity + self._matching_lower_triggers_with_routers[trigger.lower()] = ( + router_entity + ) - self._autocompleter.initial_setup(list(self._current_matching_triggers_with_routers.keys())) + self._autocompleter.initial_setup( + list(self._current_matching_triggers_with_routers.keys()) + ) seen = {} for item in list(self._current_matching_triggers_with_routers.keys()): if item in seen: - Console().print(f"\n[b red]WARNING:[/b red] Overlapping trigger or alias: [b blue]{item}[/b blue]") + Console().print( + f"\n[b red]WARNING:[/b red] Overlapping trigger or alias: [b blue]{item}[/b blue]" + ) else: seen[item] = True @@ -352,7 +397,8 @@ DEFAULT_EXIT_COMMAND: Command = Command("Q", description="Exit command") class App(BaseApp): def __init__( - self, *, + self, + *, prompt: str = "What do you want to do?\n\n", initial_message: str = "Argenta\n", farewell_message: str = "\nSee you\n", @@ -395,12 +441,11 @@ class App(BaseApp): print_func=print_func, ) - def run_polling(self, argspace: ArgSpace | None) -> None: + def run_polling(self) -> None: """ Private. Starts the user input processing cycle :return: None """ - self._argspace = argspace self._pre_cycle_setup() while True: if self._repeat_command_groups_description: @@ -409,7 +454,9 @@ class App(BaseApp): raw_command: str = Console().input(self._prompt) try: - input_command: InputCommand = InputCommand.parse(raw_command=raw_command) + input_command: InputCommand = InputCommand.parse( + raw_command=raw_command + ) except InputCommandException as error: with redirect_stdout(io.StringIO()) as stderr: self._error_handler(error, raw_command) @@ -419,7 +466,9 @@ class App(BaseApp): if self._is_exit_command(input_command): system_router.finds_appropriate_handler(input_command) - self._autocompleter.exit_setup(list(self._current_matching_triggers_with_routers.keys())) + self._autocompleter.exit_setup( + list(self._current_matching_triggers_with_routers.keys()) + ) return if self._is_unknown_command(input_command): @@ -429,18 +478,40 @@ class App(BaseApp): self._print_framed_text(stdout_res) continue - processing_router = self._current_matching_triggers_with_routers[input_command.trigger.lower()] + processing_router = self._current_matching_triggers_with_routers[ + input_command.trigger.lower() + ] if processing_router.disable_redirect_stdout: if isinstance(self._dividing_line, StaticDividingLine): - self._print_func(self._dividing_line.get_full_static_line(is_override=self._override_system_messages)) + self._print_func( + self._dividing_line.get_full_static_line( + is_override=self._override_system_messages + ) + ) processing_router.finds_appropriate_handler(input_command) - self._print_func(self._dividing_line.get_full_static_line(is_override=self._override_system_messages)) + self._print_func( + self._dividing_line.get_full_static_line( + is_override=self._override_system_messages + ) + ) else: dividing_line_unit_part: str = self._dividing_line.get_unit_part() - self._print_func(StaticDividingLine(dividing_line_unit_part).get_full_static_line(is_override=self._override_system_messages)) + self._print_func( + StaticDividingLine( + dividing_line_unit_part + ).get_full_static_line( + is_override=self._override_system_messages + ) + ) processing_router.finds_appropriate_handler(input_command) - self._print_func(StaticDividingLine(dividing_line_unit_part).get_full_static_line(is_override=self._override_system_messages)) + self._print_func( + StaticDividingLine( + dividing_line_unit_part + ).get_full_static_line( + is_override=self._override_system_messages + ) + ) else: with redirect_stdout(io.StringIO()) as stdout: processing_router.finds_appropriate_handler(input_command) @@ -455,7 +526,7 @@ class App(BaseApp): :return: None """ router.command_register_ignore = self._ignore_command_register - self._registered_routers.add_registered_router(router) + self.registered_routers.add_registered_router(router) def include_routers(self, *routers: Router) -> None: """ diff --git a/src/argenta/di/__init__.py b/src/argenta/di/__init__.py new file mode 100644 index 0000000..04d5216 --- /dev/null +++ b/src/argenta/di/__init__.py @@ -0,0 +1,2 @@ +from argenta.di.integration import inject as inject +from argenta.di.integration import FromDishka as FromDishka diff --git a/src/argenta/di/integration.py b/src/argenta/di/integration.py new file mode 100644 index 0000000..a4001ec --- /dev/null +++ b/src/argenta/di/integration.py @@ -0,0 +1,45 @@ +__all__ = ["inject", "setup_dishka", "FromDishka"] + +from typing import Any, Callable, TypeVar + +from dishka import Container, FromDishka +from dishka.integrations.base import wrap_injection, is_dishka_injected + +from argenta.response import Response +from argenta.app import App + + +T = TypeVar("T") + + +def inject(func: Callable[..., T]) -> Callable[..., T]: + return wrap_injection( + func=func, + is_async=False, + container_getter=_get_container_from_response, + ) + + +def setup_dishka(app: App, *, auto_inject: bool = False) -> None: + if auto_inject: + _auto_inject_handlers(app) + + +def _get_container_from_response( + args: tuple[Any, ...], kwargs: dict[str, Any] +) -> Container: + for arg in args: + if isinstance(arg, Response): + if hasattr(arg, "_dishka_container"): + return arg._dishka_container # pyright: ignore[reportPrivateUsage] + break + + raise RuntimeError("dishka container not found in Response") + + +def _auto_inject_handlers(app: App) -> None: + for router in app.registered_routers: + for command_handler in router.command_handlers: + if not is_dishka_injected(command_handler.handler_as_func): + injected_handler = inject(command_handler.handler_as_func) + command_handler.handler_as_func = injected_handler diff --git a/src/argenta/di/providers.py b/src/argenta/di/providers.py new file mode 100644 index 0000000..43e404e --- /dev/null +++ b/src/argenta/di/providers.py @@ -0,0 +1,14 @@ +from argenta.orchestrator.argparser import ArgParser +from dishka import Provider, provide, Scope + +from argenta.orchestrator.argparser.entity import ArgSpace + + +class SystemProvider(Provider): + def __init__(self, arg_parser: ArgParser): + super().__init__() + self._arg_parser: ArgParser = arg_parser + + @provide(scope=Scope.APP) + def get_argspace(self) -> ArgSpace: + return self._arg_parser.parse_args() diff --git a/src/argenta/orchestrator/entity.py b/src/argenta/orchestrator/entity.py index 3e7143e..7934e56 100644 --- a/src/argenta/orchestrator/entity.py +++ b/src/argenta/orchestrator/entity.py @@ -1,19 +1,28 @@ -from argenta.app.models import App +from argenta.app import App +from argenta.response import Response + from argenta.orchestrator.argparser import ArgParser -from argenta.orchestrator.argparser.entity import ArgSpace +from argenta.di.integration import setup_dishka +from argenta.di.providers import SystemProvider + +from dishka import Provider, make_container DEFAULT_ARGPARSER: ArgParser = ArgParser(processed_args=[]) class Orchestrator: - def __init__(self, arg_parser: ArgParser = DEFAULT_ARGPARSER): + def __init__(self, arg_parser: ArgParser = DEFAULT_ARGPARSER, + custom_providers: list[Provider] = [], + auto_inject_handlers: bool = True): """ Public. An orchestrator and configurator that defines the behavior of an integrated system, one level higher than the App :param arg_parser: Cmd argument parser and configurator at startup :return: None """ self._arg_parser: ArgParser = arg_parser + self._custom_providers: list[Provider] = custom_providers + self._auto_inject_handlers: bool = auto_inject_handlers def start_polling(self, app: App) -> None: """ @@ -21,5 +30,8 @@ class Orchestrator: :param app: a running application :return: None """ - parsed_argspace: ArgSpace = self._arg_parser.parse_args() - app.run_polling(argspace=parsed_argspace) + container = make_container(SystemProvider(self._arg_parser), *self._custom_providers) + Response.patch_by_container(container) + setup_dishka(app, auto_inject=self._auto_inject_handlers) + + app.run_polling() diff --git a/src/argenta/response/entity.py b/src/argenta/response/entity.py index 8ecd209..3e10371 100644 --- a/src/argenta/response/entity.py +++ b/src/argenta/response/entity.py @@ -1,4 +1,5 @@ -from typing import Literal +from dishka import Container + from argenta.command.flag.flags.models import InputFlags from argenta.response.status import ResponseStatus @@ -7,7 +8,7 @@ EMPTY_INPUT_FLAGS: InputFlags = InputFlags() class Response: - __slots__: tuple[Literal['status', 'input_flags'], ...] = ("status", "input_flags") + _dishka_container: Container def __init__( self, @@ -21,3 +22,7 @@ class Response: """ self.status: ResponseStatus = status self.input_flags: InputFlags = input_flags + + @classmethod + def patch_by_container(cls, container: Container) -> None: + cls._dishka_container = container diff --git a/src/argenta/router/command_handler/entity.py b/src/argenta/router/command_handler/entity.py index a9290fb..ac01d71 100644 --- a/src/argenta/router/command_handler/entity.py +++ b/src/argenta/router/command_handler/entity.py @@ -6,13 +6,13 @@ from argenta.response import Response class CommandHandler: - def __init__(self, handler_as_func: Callable[[Response], None], handled_command: Command): + def __init__(self, handler_as_func: Callable[..., None], handled_command: Command): """ Private. Entity of the model linking the handler and the command being processed :param handler: the handler being called :param handled_command: the command being processed """ - self.handler_as_func: Callable[[Response], None] = handler_as_func + self.handler_as_func: Callable[..., None] = handler_as_func self.handled_command: Command = handled_command def handling(self, response: Response) -> None: @@ -30,7 +30,9 @@ class CommandHandlers: Private. The model that unites all CommandHandler of the routers :param command_handlers: list of CommandHandlers for register """ - self.command_handlers: list[CommandHandler] = command_handlers if command_handlers else [] + self.command_handlers: list[CommandHandler] = ( + command_handlers if command_handlers else [] + ) def add_handler(self, command_handler: CommandHandler) -> None: """ diff --git a/src/argenta/router/entity.py b/src/argenta/router/entity.py index 34a9870..841c6d0 100644 --- a/src/argenta/router/entity.py +++ b/src/argenta/router/entity.py @@ -6,25 +6,23 @@ from argenta.command import Command, InputCommand from argenta.command.flag import ValidationStatus from argenta.response import Response, ResponseStatus from argenta.router.command_handler.entity import CommandHandlers, CommandHandler -from argenta.command.flag.flags import ( - Flags, - InputFlags -) +from argenta.command.flag.flags import Flags, InputFlags from argenta.router.exceptions import ( RepeatedFlagNameException, - TooManyTransferredArgsException, RequiredArgumentNotPassedException, TriggerContainSpacesException, ) -HandlerFunc: TypeAlias = Callable[[Response], None] +HandlerFunc: TypeAlias = Callable[..., None] class Router: def __init__( - self, *, title: str | None = "Default title", - disable_redirect_stdout: bool = False + self, + *, + title: str | None = "Default title", + disable_redirect_stdout: bool = False, ): """ Public. Directly configures and manages handlers @@ -58,7 +56,6 @@ class Router: def decorator(func: HandlerFunc) -> HandlerFunc: _validate_func_args(func) self.command_handlers.add_handler(CommandHandler(func, redefined_command)) - return func return decorator @@ -91,7 +88,9 @@ class Router: handle_command = command_handler.handled_command if handle_command.registered_flags.flags: if input_command_flags.flags: - response: Response = _structuring_input_flags(handle_command, input_command_flags) + response: Response = _structuring_input_flags( + handle_command, input_command_flags + ) command_handler.handling(response) else: response = Response(ResponseStatus.ALL_FLAGS_VALID) @@ -102,7 +101,9 @@ class Router: for input_flag in input_command_flags: input_flag.status = ValidationStatus.UNDEFINED undefined_flags.add_flag(input_flag) - response = Response(ResponseStatus.UNDEFINED_FLAGS, input_flags=undefined_flags) + response = Response( + ResponseStatus.UNDEFINED_FLAGS, input_flags=undefined_flags + ) command_handler.handling(response) else: response = Response(ResponseStatus.ALL_FLAGS_VALID) @@ -137,14 +138,17 @@ class CommandDecorator: self.router: Router = router_instance self.command: Command = command - def __call__(self, handler_func: Callable[[Response], None]) -> Callable[[Response], None]: + def __call__(self, handler_func: Callable[..., None]) -> Callable[..., None]: _validate_func_args(handler_func) - self.router.command_handlers.add_handler(CommandHandler(handler_func, self.command)) + self.router.command_handlers.add_handler( + CommandHandler(handler_func, self.command) + ) return handler_func -def _structuring_input_flags(handled_command: Command, - input_flags: InputFlags) -> Response: +def _structuring_input_flags( + handled_command: Command, input_flags: InputFlags +) -> Response: """ Private. Validates flags of input command :param handled_command: entity of the handled command @@ -154,45 +158,42 @@ def _structuring_input_flags(handled_command: Command, invalid_value_flags, undefined_flags = False, False for flag in input_flags: - flag_status: ValidationStatus = (handled_command.validate_input_flag(flag)) + flag_status: ValidationStatus = handled_command.validate_input_flag(flag) flag.status = flag_status if flag_status == ValidationStatus.INVALID: invalid_value_flags = True elif flag_status == ValidationStatus.UNDEFINED: undefined_flags = True - status = ResponseStatus.from_flags(has_invalid_value_flags=invalid_value_flags, - has_undefined_flags=undefined_flags) - - return Response( - status=status, - input_flags=input_flags + status = ResponseStatus.from_flags( + has_invalid_value_flags=invalid_value_flags, has_undefined_flags=undefined_flags ) -def _validate_func_args(func: Callable[[Response], None]) -> None: + return Response(status=status, input_flags=input_flags) + + +def _validate_func_args(func: Callable[..., None]) -> None: """ Private. Validates the arguments of the handler :param func: entity of the handler func :return: None if func is valid else raise exception """ transferred_args = getfullargspec(func).args - if len(transferred_args) > 1: - raise TooManyTransferredArgsException() - elif len(transferred_args) == 0: + if len(transferred_args) == 0: raise RequiredArgumentNotPassedException() - transferred_arg: str = transferred_args[0] + response_arg: str = transferred_args[0] func_annotations: dict[str, None] = get_annotations(func) - arg_annotation = func_annotations.get(transferred_arg) + response_arg_annotation = func_annotations.get(response_arg) - if arg_annotation is not None: - if arg_annotation is not Response: + if response_arg_annotation is not None: + if response_arg_annotation is not Response: source_line: int = getsourcelines(func)[1] Console().print( - f'\nFile "{getsourcefile(func)}", line {source_line}\n[b red]WARNING:[/b red] [i]The typehint ' + - f"of argument([green]{transferred_arg}[/green]) passed to the handler must be [/i][bold blue]{Response}[/bold blue]," + - f" [i]but[/i] [bold blue]{arg_annotation}[/bold blue] [i]is specified[/i]", + f'\nFile "{getsourcefile(func)}", line {source_line}\n[b red]WARNING:[/b red] [i]The typehint ' + + f"of argument([green]{response_arg}[/green]) passed to the handler must be [/i][bold blue]{Response}[/bold blue]," + + f" [i]but[/i] [bold blue]{response_arg_annotation}[/bold blue] [i]is specified[/i]", highlight=False, ) diff --git a/src/argenta/router/exceptions.py b/src/argenta/router/exceptions.py index fe45870..161f08f 100644 --- a/src/argenta/router/exceptions.py +++ b/src/argenta/router/exceptions.py @@ -5,24 +5,17 @@ class RepeatedFlagNameException(Exception): """ Private. Raised when a repeated flag name is registered """ + @override def __str__(self) -> str: return "Repeated registered flag names in register command" -class TooManyTransferredArgsException(Exception): - """ - Private. Raised when too many arguments are passed - """ - @override - def __str__(self) -> str: - return "Too many transferred arguments" - - class RequiredArgumentNotPassedException(Exception): """ Private. Raised when a required argument is not passed """ + @override def __str__(self) -> str: return "Required argument not passed" @@ -32,6 +25,7 @@ class TriggerContainSpacesException(Exception): """ Private. Raised when there is a space in the trigger being registered """ + @override def __str__(self) -> str: return "Command trigger cannot contain spaces" diff --git a/tests/unit_tests/test_router.py b/tests/unit_tests/test_router.py index 4417812..54593de 100644 --- a/tests/unit_tests/test_router.py +++ b/tests/unit_tests/test_router.py @@ -7,7 +7,6 @@ from argenta.command import Command from argenta.router.entity import _structuring_input_flags, _validate_command, _validate_func_args # pyright: ignore[reportPrivateUsage] from argenta.router.exceptions import (TriggerContainSpacesException, RepeatedFlagNameException, - TooManyTransferredArgsException, RequiredArgumentNotPassedException) import unittest @@ -79,12 +78,6 @@ class TestRouter(unittest.TestCase): with self.assertRaises(RequiredArgumentNotPassedException): _validate_func_args(handler) # pyright: ignore[reportArgumentType] - def test_validate_incorrect_func_args2(self): - def handler(args, kwargs): # pyright: ignore[reportMissingParameterType, reportUnknownParameterType] - pass - with self.assertRaises(TooManyTransferredArgsException): - _validate_func_args(handler) # pyright: ignore[reportArgumentType] - def test_get_router_aliases(self): router = Router() @router.command(Command('some', aliases=['test', 'case'])) @@ -108,12 +101,3 @@ class TestRouter(unittest.TestCase): def handler(response: Response): # pyright: ignore[reportUnusedFunction] pass self.assertListEqual(router.aliases, []) - - - - - - - - -