From 260171086bdec3ff69d4d90d91d9ef296b734a72 Mon Sep 17 00:00:00 2001 From: kolo Date: Sun, 4 Jan 2026 16:05:19 +0300 Subject: [PATCH] commit --- .../ca107b03ddf8_add_indexes_and_fk.py | 43 +++++++++ .../application/bot/admin_dialogs/users.py | 3 +- .../application/bot/creator_dialogs/users.py | 3 +- .../bot/middlewares/reject_not_creator.py | 1 - .../application/bot/shared_dialogs/tests.py | 95 ++++++++++++++++++- .../application/bot/user_dialogs/deeplink.py | 18 +++- .../application/bot/user_dialogs/take_test.py | 18 +++- .../infrastructure/database/dao/user.py | 7 ++ src/trudex/infrastructure/database/models.py | 17 ++-- .../infrastructure/database/repo/test.py | 22 +++++ .../database/repo/test_attempt.py | 18 ++-- src/trudex/infrastructure/di.py | 5 + src/trudex/infrastructure/utils/broadcast.py | 25 ++++- .../infrastructure/utils/rate_limiter.py | 57 +++++++++++ 14 files changed, 302 insertions(+), 30 deletions(-) create mode 100644 alembic/versions/ca107b03ddf8_add_indexes_and_fk.py create mode 100644 src/trudex/infrastructure/utils/rate_limiter.py diff --git a/alembic/versions/ca107b03ddf8_add_indexes_and_fk.py b/alembic/versions/ca107b03ddf8_add_indexes_and_fk.py new file mode 100644 index 0000000..980f226 --- /dev/null +++ b/alembic/versions/ca107b03ddf8_add_indexes_and_fk.py @@ -0,0 +1,43 @@ +"""add_indexes_and_fk + +Revision ID: ca107b03ddf8 +Revises: 40f5317720a4 +Create Date: 2026-01-04 15:32:14.881408 + +""" +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +revision: str = 'ca107b03ddf8' +down_revision: str | None = '40f5317720a4' +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_index(op.f('ix_options_question_id'), 'options', ['question_id'], unique=False) + op.create_index(op.f('ix_questions_test_id'), 'questions', ['test_id'], unique=False) + op.create_index(op.f('ix_test_attempts_test_id'), 'test_attempts', ['test_id'], unique=False) + op.create_foreign_key(None, 'test_attempts', 'users', ['user_id'], ['id']) + op.create_index(op.f('ix_user_answers_attempt_id'), 'user_answers', ['attempt_id'], unique=False) + op.create_index(op.f('ix_user_answers_question_id'), 'user_answers', ['question_id'], unique=False) + op.create_index(op.f('ix_users_group'), 'users', ['group'], unique=False) + op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_users_username'), table_name='users') + op.drop_index(op.f('ix_users_group'), table_name='users') + op.drop_index(op.f('ix_user_answers_question_id'), table_name='user_answers') + op.drop_index(op.f('ix_user_answers_attempt_id'), table_name='user_answers') + op.drop_constraint(None, 'test_attempts', type_='foreignkey') + op.drop_index(op.f('ix_test_attempts_test_id'), table_name='test_attempts') + op.drop_index(op.f('ix_questions_test_id'), table_name='questions') + op.drop_index(op.f('ix_options_question_id'), table_name='options') + # ### end Alembic commands ### diff --git a/src/trudex/application/bot/admin_dialogs/users.py b/src/trudex/application/bot/admin_dialogs/users.py index c1b67ba..b2cc96a 100644 --- a/src/trudex/application/bot/admin_dialogs/users.py +++ b/src/trudex/application/bot/admin_dialogs/users.py @@ -68,8 +68,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog user = None if text.startswith("@"): username = text[1:] - all_users = await user_dao.get_all() - user = next((u for u in all_users if u.username == username), None) + user = await user_dao.get_by_username(username) elif text.isdigit(): user = await user_dao.get_by_id(int(text)) diff --git a/src/trudex/application/bot/creator_dialogs/users.py b/src/trudex/application/bot/creator_dialogs/users.py index 6158593..f71c027 100644 --- a/src/trudex/application/bot/creator_dialogs/users.py +++ b/src/trudex/application/bot/creator_dialogs/users.py @@ -137,8 +137,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog user = None if text.startswith("@"): username = text[1:] - all_users = await user_dao.get_all() - user = next((u for u in all_users if u.username == username), None) + user = await user_dao.get_by_username(username) elif text.isdigit(): user = await user_dao.get_by_id(int(text)) diff --git a/src/trudex/application/bot/middlewares/reject_not_creator.py b/src/trudex/application/bot/middlewares/reject_not_creator.py index 8fded29..86aadab 100644 --- a/src/trudex/application/bot/middlewares/reject_not_creator.py +++ b/src/trudex/application/bot/middlewares/reject_not_creator.py @@ -30,7 +30,6 @@ class RejectNotCreatorMiddleware(BaseMiddleware): if user_id == config.bot.creator_id: return await handler(event, data) - await event.answer("У вас нет доступа к панели создателя.") return return await handler(event, data) diff --git a/src/trudex/application/bot/shared_dialogs/tests.py b/src/trudex/application/bot/shared_dialogs/tests.py index e97f320..021098e 100644 --- a/src/trudex/application/bot/shared_dialogs/tests.py +++ b/src/trudex/application/bot/shared_dialogs/tests.py @@ -1,17 +1,19 @@ import asyncio import functools +import json from datetime import date, datetime, time from aiogram import Bot from aiogram.types import BufferedInputFile, CallbackQuery, Message from aiogram_dialog import Dialog, DialogManager, StartMode, Window from aiogram_dialog.widgets.input import MessageInput -from aiogram_dialog.widgets.kbd import Button, Calendar, Column, ScrollingGroup, Select +from aiogram_dialog.widgets.kbd import Button, Calendar, Column, Row, ScrollingGroup, Select from aiogram_dialog.widgets.text import Const, Format from dishka import FromDishka from dishka.integrations.aiogram_dialog import inject from trudex.application.bot.shared_dialogs.states import SharedCreateTestSG, SharedTestsSG +from trudex.domain.schemas import QuestionType from trudex.infrastructure.database.dao.group import GroupDAO from trudex.infrastructure.database.dao.test import TestDAO from trudex.infrastructure.database.repo.test import TestRepository @@ -199,11 +201,16 @@ async def get_attempt_detail( "📋 Ответы:\n", ] + # Загружаем все вопросы с опциями за один запрос + question_ids = [answer.question_id for answer in answers] + questions_map = await test_repo.get_questions_with_options_by_ids(question_ids) + for i, answer in enumerate(answers, 1): - question, options = await test_repo.get_question_with_options(answer.question_id) - if not question: + question_data = questions_map.get(answer.question_id) + if not question_data: continue + question, options = question_data correct_options = [opt for opt in options if opt.is_correct] correct_texts = [opt.text for opt in correct_options] @@ -221,6 +228,87 @@ async def get_attempt_detail( return {"attempt_info": "\n".join(lines)} +@inject +async def on_export_test( + _callback: CallbackQuery, + _button: Button, + manager: DialogManager, + test_repo: FromDishka[TestRepository], +) -> None: + test_id = manager.dialog_data.get("selected_test_id") + + if not test_id: + await _callback.answer("❌ Тест не найден") + return + + assert _callback.message is not None + await _callback.answer("⏳ Экспортирую тест...") + + test, questions_with_options = await test_repo.get_full_test(test_id) + + if not test: + await _callback.message.answer("❌ Тест не найден") + return + + export_data: dict = { + "title": test.title, + "description": test.description, + "password": test.password, + "attempts": test.attempts, + "expires_at": test.expires_at.isoformat() if test.expires_at else None, + "for_group": test.for_group, + "questions": [], + } + + questions_list: list = export_data["questions"] + + for question, options in questions_with_options: + question_data: dict = { + "question_type": question.question_type.value, + "question": question.text, + } + + if question.question_type == QuestionType.INPUT: + correct_options = [o for o in options if o.is_correct] + if correct_options: + question_data["correct_answer"] = correct_options[0].text + else: + question_data["answers"] = [ + {"option": o.text, "is_correct": o.is_correct} + for o in options + ] + + questions_list.append(question_data) + + json_str = json.dumps(export_data, ensure_ascii=False, indent=2) + + created_str = test.created_at.strftime("%d.%m.%Y %H:%M") if test.created_at else "—" + updated_str = test.updated_at.strftime("%d.%m.%Y %H:%M") if test.updated_at else "—" + questions_count = len(questions_with_options) + + comment_header = f"""// ═══════════════════════════════════════════════════════════════ +// ЭКСПОРТ ТЕСТА: {test.title} +// ═══════════════════════════════════════════════════════════════ +// +// ❓ Вопросов: {questions_count} +// 📅 Создан: {created_str} +// 🔄 Обновлён: {updated_str} +// +// ═══════════════════════════════════════════════════════════════ + +""" + + full_content = comment_header + json_str + + safe_title = "".join(c if c.isalnum() or c in "-_" else "_" for c in test.title)[:50] + filename = f"{safe_title}.json" + + await _callback.message.answer_document( + document=BufferedInputFile(full_content.encode("utf-8"), filename=filename), + caption=f"📤 Экспорт теста: {test.title}", + ) + + @inject async def on_share_test(_callback: CallbackQuery, _button: Button, manager: DialogManager, config: FromDishka[Config], bot_inst: FromDishka[Bot]): test_id = manager.dialog_data.get("selected_test_id") @@ -462,6 +550,7 @@ shared_tests_dialog = Dialog( ), Button(Const("📊 Статистика"), id="statistics", on_click=on_statistics), Button(Const("🔗 Поделиться"), id="share", on_click=on_share_test), + Button(Const("📤 Экспорт"), id="export", on_click=on_export_test), Button(Const("✏️ Изменить"), id="edit_menu", on_click=on_edit_menu), Button(Const("◀️ Назад"), id="back", on_click=on_back_to_list), ), diff --git a/src/trudex/application/bot/user_dialogs/deeplink.py b/src/trudex/application/bot/user_dialogs/deeplink.py index 1cd907d..b990ae5 100644 --- a/src/trudex/application/bot/user_dialogs/deeplink.py +++ b/src/trudex/application/bot/user_dialogs/deeplink.py @@ -11,6 +11,7 @@ from trudex.infrastructure.database.dao.test import TestDAO from trudex.infrastructure.database.models import QuestionType from trudex.infrastructure.database.repo.test import TestRepository from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository +from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter @inject @@ -60,6 +61,7 @@ async def on_start_deeplink_test( test_dao: FromDishka[TestDAO], test_repo: FromDishka[TestRepository], attempt_repo: FromDishka[TestAttemptRepository], + rate_limiter: FromDishka[PasswordRateLimiter], ): assert _callback.from_user is not None @@ -89,6 +91,12 @@ async def on_start_deeplink_test( await attempt_repo.attempt_dao.delete(active_attempt.id) if test.password: + # Проверяем rate limit перед показом экрана ввода пароля + allowed, wait_time = await rate_limiter.check(user_id) + if not allowed: + minutes = int(wait_time // 60) + 1 + await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True) + return await manager.switch_to(UserDeeplinkSG.password_input) else: await start_test_without_password(manager, test_repo, attempt_repo, test_id, user_id) @@ -141,6 +149,7 @@ async def on_deeplink_password_input( test_dao: FromDishka[TestDAO], test_repo: FromDishka[TestRepository], attempt_repo: FromDishka[TestAttemptRepository], + rate_limiter: FromDishka[PasswordRateLimiter], ): assert message.from_user is not None @@ -164,7 +173,14 @@ async def on_deeplink_password_input( manager, test_repo, attempt_repo, test_id, message.from_user.id ) else: - await message.answer("❌ Неверный пароль") + # Проверяем rate limit при неверном пароле + allowed, wait_time = await rate_limiter.check(message.from_user.id) + if not allowed: + minutes = int(wait_time // 60) + 1 + await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.") + await manager.start(UserMenuSG.main, mode=StartMode.RESET_STACK) + else: + await message.answer("❌ Неверный пароль") async def on_back_to_menu(_callback: CallbackQuery, _button: Button, manager: DialogManager): diff --git a/src/trudex/application/bot/user_dialogs/take_test.py b/src/trudex/application/bot/user_dialogs/take_test.py index c1b6e6f..e618084 100644 --- a/src/trudex/application/bot/user_dialogs/take_test.py +++ b/src/trudex/application/bot/user_dialogs/take_test.py @@ -12,6 +12,7 @@ from trudex.infrastructure.database.dao.test import TestDAO from trudex.infrastructure.database.dao.user_answer import UserAnswerDAO from trudex.infrastructure.database.repo.test import TestRepository from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository +from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter from trudex.infrastructure.utils.timezone import now_msk_naive @@ -32,6 +33,7 @@ async def on_start_test( test_dao: FromDishka[TestDAO], test_repo: FromDishka[TestRepository], attempt_repo: FromDishka[TestAttemptRepository], + rate_limiter: FromDishka[PasswordRateLimiter], ): assert _callback.from_user is not None test_id = manager.dialog_data.get("selected_test_id") @@ -66,6 +68,12 @@ async def on_start_test( await attempt_repo.attempt_dao.delete(active_attempt.id) if test.password: + # Проверяем rate limit перед показом экрана ввода пароля + allowed, wait_time = await rate_limiter.check(user_id) + if not allowed: + minutes = int(wait_time // 60) + 1 + await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True) + return await manager.start(UserTestSG.password_input, mode=StartMode.NORMAL, data={"test_id": test_id}) else: _, questions = await test_repo.get_test_with_questions(test_id) @@ -100,6 +108,7 @@ async def on_password_input( test_dao: FromDishka[TestDAO], test_repo: FromDishka[TestRepository], attempt_repo: FromDishka[TestAttemptRepository], + rate_limiter: FromDishka[PasswordRateLimiter], ): assert message.from_user is not None start_data = manager.start_data or {} @@ -137,7 +146,14 @@ async def on_password_input( await manager.switch_to(first_state) else: - await message.answer("❌ Неверный пароль") + # Проверяем rate limit при неверном пароле + allowed, wait_time = await rate_limiter.check(message.from_user.id) + if not allowed: + minutes = int(wait_time // 60) + 1 + await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.") + await manager.done() + else: + await message.answer("❌ Неверный пароль") @inject diff --git a/src/trudex/infrastructure/database/dao/user.py b/src/trudex/infrastructure/database/dao/user.py index 8b88842..3914f2c 100644 --- a/src/trudex/infrastructure/database/dao/user.py +++ b/src/trudex/infrastructure/database/dao/user.py @@ -27,6 +27,13 @@ class UserDAO: model = result.scalar_one_or_none() return UserDTO(model).to_domain() if model else None + async def get_by_username(self, username: str) -> DomainUser | None: + result = await self.session.execute( + select(User).where(User.username == username) + ) + model = result.scalar_one_or_none() + return UserDTO(model).to_domain() if model else None + async def get_all(self) -> list[DomainUser]: result = await self.session.execute( select(User).order_by(User.created_at.desc()) diff --git a/src/trudex/infrastructure/database/models.py b/src/trudex/infrastructure/database/models.py index 2d18dca..a780cc5 100644 --- a/src/trudex/infrastructure/database/models.py +++ b/src/trudex/infrastructure/database/models.py @@ -16,11 +16,11 @@ class User(Base): __tablename__ = "users" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - username: Mapped[str | None] = mapped_column(String(32)) + username: Mapped[str | None] = mapped_column(String(32), index=True) first_name: Mapped[str] = mapped_column(String(64)) last_name: Mapped[str | None] = mapped_column(String(64)) name: Mapped[str | None] = mapped_column(String(128)) - group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999")) + group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999"), index=True) is_admin: Mapped[bool] = mapped_column(default=False) name_updated_at: Mapped[datetime | None] = mapped_column(default=None) group_updated_at: Mapped[datetime | None] = mapped_column(default=None) @@ -70,7 +70,7 @@ class Question(Base): __tablename__ = "questions" id: Mapped[int] = mapped_column(primary_key=True) - test_id: Mapped[int] = mapped_column(ForeignKey("tests.id")) + test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True) text: Mapped[str] = mapped_column(Text) position: Mapped[int] = mapped_column(Integer, default=0) question_type: Mapped[QuestionType] = mapped_column(default=QuestionType.SINGLE) @@ -88,7 +88,7 @@ class Option(Base): __tablename__ = "options" id: Mapped[int] = mapped_column(primary_key=True) - question_id: Mapped[int] = mapped_column(ForeignKey("questions.id")) + question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True) text: Mapped[str] = mapped_column(String(255)) is_correct: Mapped[bool] = mapped_column(default=False) explanation: Mapped[str | None] = mapped_column(Text) @@ -101,13 +101,14 @@ class TestAttempt(Base): __tablename__ = "test_attempts" id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(BigInteger, index=True) - test_id: Mapped[int] = mapped_column(ForeignKey("tests.id")) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.id"), index=True) + test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True) started_at: Mapped[datetime] = mapped_column(server_default=func.now()) finished_at: Mapped[datetime | None] = mapped_column(default=None) score: Mapped[int] = mapped_column(Integer, default=0) is_passed: Mapped[bool] = mapped_column(default=False) + user: Mapped["User"] = relationship() test: Mapped["Test"] = relationship() answers: Mapped[list["UserAnswer"]] = relationship( back_populates="attempt", @@ -120,8 +121,8 @@ class UserAnswer(Base): __tablename__ = "user_answers" id: Mapped[int] = mapped_column(primary_key=True) - attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id")) - question_id: Mapped[int] = mapped_column(ForeignKey("questions.id")) + attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id"), index=True) + question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True) selected_option_id: Mapped[int | None] = mapped_column(ForeignKey("options.id"), default=None) text_answer: Mapped[str | None] = mapped_column(Text, default=None) is_correct: Mapped[bool] = mapped_column(default=False) diff --git a/src/trudex/infrastructure/database/repo/test.py b/src/trudex/infrastructure/database/repo/test.py index 521b913..b2dcae1 100644 --- a/src/trudex/infrastructure/database/repo/test.py +++ b/src/trudex/infrastructure/database/repo/test.py @@ -122,6 +122,28 @@ class TestRepository: count = result.scalar_one() return count + async def get_questions_with_options_by_ids( + self, question_ids: list[int] + ) -> dict[int, tuple[Question, list[Option]]]: + """Загружает вопросы с опциями по списку ID за один запрос.""" + if not question_ids: + return {} + + result = await self.session.execute( + select(QuestionModel) + .where(QuestionModel.id.in_(question_ids)) + .options(selectinload(QuestionModel.options)) + ) + question_models = list(result.scalars().all()) + + questions_dict: dict[int, tuple[Question, list[Option]]] = {} + for qm in question_models: + question = QuestionDTO(qm).to_domain() + options = [OptionDTO(o).to_domain() for o in qm.options] + questions_dict[qm.id] = (question, options) + + return questions_dict + async def duplicate_test(self, test_id: int, new_title: str) -> Test | None: test, questions_with_options = await self.get_full_test(test_id) if not test: diff --git a/src/trudex/infrastructure/database/repo/test_attempt.py b/src/trudex/infrastructure/database/repo/test_attempt.py index 4842e4e..56ae2ce 100644 --- a/src/trudex/infrastructure/database/repo/test_attempt.py +++ b/src/trudex/infrastructure/database/repo/test_attempt.py @@ -155,18 +155,16 @@ class TestAttemptRepository: return [UserAnswerDTO(model).to_domain() for model in models] async def get_question_statistics(self, question_id: int) -> dict[str, int]: - total_result = await self.session.execute( - select(func.count(UserAnswerModel.id)) + result = await self.session.execute( + select( + func.count(UserAnswerModel.id).label("total"), + func.sum(func.cast(UserAnswerModel.is_correct, func.Integer)).label("correct") + ) .where(UserAnswerModel.question_id == question_id) ) - total = total_result.scalar_one() - - correct_result = await self.session.execute( - select(func.count(UserAnswerModel.id)) - .where(UserAnswerModel.question_id == question_id) - .where(UserAnswerModel.is_correct == True) - ) - correct = correct_result.scalar_one() + row = result.one() + total = row.total or 0 + correct = row.correct or 0 return { "total_answers": total, diff --git a/src/trudex/infrastructure/di.py b/src/trudex/infrastructure/di.py index e4706d3..4b23fe0 100644 --- a/src/trudex/infrastructure/di.py +++ b/src/trudex/infrastructure/di.py @@ -18,12 +18,17 @@ from trudex.infrastructure.database.repo.test_attempt import TestAttemptReposito from trudex.infrastructure.database.repo.user import UserRepository from trudex.infrastructure.scheduling.tasks import deactivate_expired_tests from trudex.infrastructure.utils.config import Config +from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter class DatabaseProvider(Provider): @provide(scope=Scope.APP) def get_session_maker(self, config: Config) -> async_sessionmaker[AsyncSession]: return new_session_maker(config.database.url) + + @provide(scope=Scope.APP) + def get_password_rate_limiter(self) -> PasswordRateLimiter: + return PasswordRateLimiter() @provide(scope=Scope.REQUEST) async def get_session( diff --git a/src/trudex/infrastructure/utils/broadcast.py b/src/trudex/infrastructure/utils/broadcast.py index a30c35f..0059c81 100644 --- a/src/trudex/infrastructure/utils/broadcast.py +++ b/src/trudex/infrastructure/utils/broadcast.py @@ -3,7 +3,13 @@ import logging from dataclasses import dataclass from aiogram import Bot -from aiogram.exceptions import TelegramBadRequest, TelegramForbiddenError +from aiogram.exceptions import ( + TelegramAPIError, + TelegramBadRequest, + TelegramForbiddenError, + TelegramNetworkError, + TelegramRetryAfter, +) from trudex.infrastructure.database.dao.user import UserDAO @@ -28,14 +34,29 @@ async def broadcast_message(bot: Bot, message_id: int, chat_id: int, user_dao: U try: await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id) success += 1 + except TelegramRetryAfter as e: + logger.warning("Rate limited, waiting %d seconds", e.retry_after) + await asyncio.sleep(e.retry_after) + # Retry after waiting + try: + await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id) + success += 1 + except TelegramAPIError: + failed += 1 except TelegramForbiddenError: logger.debug("Broadcast failed (forbidden): user_id=%d", user.id) failed += 1 except TelegramBadRequest as e: logger.debug("Broadcast failed (bad request): user_id=%d, error=%s", user.id, e) failed += 1 + except TelegramNetworkError as e: + logger.warning("Network error during broadcast: user_id=%d, error=%s", user.id, e) + failed += 1 + except TelegramAPIError as e: + logger.warning("Telegram API error during broadcast: user_id=%d, error=%s", user.id, e) + failed += 1 - await asyncio.sleep(0.1) + await asyncio.sleep(0.05) logger.info("Broadcast completed: success=%d, failed=%d, total=%d", success, failed, len(users)) return BroadcastStats(success=success, failed=failed, total=len(users)) diff --git a/src/trudex/infrastructure/utils/rate_limiter.py b/src/trudex/infrastructure/utils/rate_limiter.py new file mode 100644 index 0000000..06daf03 --- /dev/null +++ b/src/trudex/infrastructure/utils/rate_limiter.py @@ -0,0 +1,57 @@ +import asyncio +import time +from dataclasses import dataclass + + +@dataclass +class UserBucket: + tokens: float + last_updated: float + + +class RateLimiter: + def __init__(self, rate: int, period: int): + self.rate = rate + self.period = period + self.fill_rate = rate / period + self.buckets: dict[int, UserBucket] = {} + self._lock = asyncio.Lock() + + async def check(self, user_id: int) -> tuple[bool, float]: + async with self._lock: + now = time.time() + + if user_id not in self.buckets: + self.buckets[user_id] = UserBucket( + tokens=self.rate - 1, + last_updated=now + ) + return True, 0.0 + + bucket = self.buckets[user_id] + + elapsed = now - bucket.last_updated + added_tokens = elapsed * self.fill_rate + bucket.tokens = min(self.rate, bucket.tokens + added_tokens) + bucket.last_updated = now + + if bucket.tokens >= 1: + bucket.tokens -= 1 + return True, 0.0 + else: + wait_time = (1 - bucket.tokens) / self.fill_rate + return False, wait_time + + async def cleanup(self) -> None: + async with self._lock: + full_buckets = [ + user_id for user_id, bucket in self.buckets.items() + if bucket.tokens >= self.rate + ] + for user_id in full_buckets: + del self.buckets[user_id] + + +class PasswordRateLimiter(RateLimiter): + def __init__(self): + super().__init__(rate=5, period=3600)