This commit is contained in:
2025-03-11 20:44:06 +03:00
parent 5a6fc1d8ca
commit d30515c1a2
20 changed files with 83 additions and 77 deletions
+13 -11
View File
@@ -1,16 +1,15 @@
from .params.flag.entity import Flag
from .params.flag.flags_group.entity import FlagsGroup
from argenta.command.flag.entity import Flag
from argenta.command.flag.flags_group import FlagsGroup
from .exceptions import (UnprocessedInputFlagException,
RepeatedInputFlagsException,
EmptyInputCommandException)
from typing import Generic, TypeVar
from typing import Generic, TypeVar, cast, Literal
CommandType = TypeVar('CommandType')
T = TypeVar('T')
class Command(Generic[T]):
class Command(Generic[CommandType]):
def __init__(self, trigger: str,
description: str = None,
flags: Flag | FlagsGroup = None):
@@ -57,7 +56,7 @@ class Command(Generic[T]):
return self._input_flags
@staticmethod
def parse_input_command(raw_command: str) -> 'Command[T]':
def parse_input_command(raw_command: str) -> 'Command[CommandType]':
if not raw_command:
raise EmptyInputCommandException()
list_of_tokens = raw_command.split()
@@ -67,7 +66,7 @@ class Command(Generic[T]):
flags: FlagsGroup = FlagsGroup()
current_flag_name = None
current_flag_value = None
for _ in list_of_tokens:
for k, _ in enumerate(list_of_tokens):
if _.startswith('-'):
flag_prefix_last_symbol_index = _.rfind('-')
if current_flag_name or len(_) < 2 or len(_[:flag_prefix_last_symbol_index]) > 3:
@@ -79,12 +78,15 @@ class Command(Generic[T]):
raise UnprocessedInputFlagException()
else:
current_flag_value = _
if current_flag_name and current_flag_value:
if current_flag_name:
if not len(list_of_tokens) == k+1:
if not list_of_tokens[k+1].startswith('-'):
continue
flag_prefix_last_symbol_index = current_flag_name.rfind('-')
flag_prefix = current_flag_name[:flag_prefix_last_symbol_index+1]
flag_name = current_flag_name[flag_prefix_last_symbol_index+1:]
input_flag = Flag(flag_name=flag_name,
flag_prefix=flag_prefix)
flag_prefix=cast(Literal['-', '--', '---'], flag_prefix))
input_flag.set_value(current_flag_value)
all_flags = [x.get_string_entity() for x in flags.get_flags()]