This commit is contained in:
2026-01-04 16:05:19 +03:00
parent f46a0ac45b
commit 260171086b
14 changed files with 302 additions and 30 deletions
@@ -0,0 +1,43 @@
"""add_indexes_and_fk
Revision ID: ca107b03ddf8
Revises: 40f5317720a4
Create Date: 2026-01-04 15:32:14.881408
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
revision: str = 'ca107b03ddf8'
down_revision: str | None = '40f5317720a4'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(op.f('ix_options_question_id'), 'options', ['question_id'], unique=False)
op.create_index(op.f('ix_questions_test_id'), 'questions', ['test_id'], unique=False)
op.create_index(op.f('ix_test_attempts_test_id'), 'test_attempts', ['test_id'], unique=False)
op.create_foreign_key(None, 'test_attempts', 'users', ['user_id'], ['id'])
op.create_index(op.f('ix_user_answers_attempt_id'), 'user_answers', ['attempt_id'], unique=False)
op.create_index(op.f('ix_user_answers_question_id'), 'user_answers', ['question_id'], unique=False)
op.create_index(op.f('ix_users_group'), 'users', ['group'], unique=False)
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_username'), table_name='users')
op.drop_index(op.f('ix_users_group'), table_name='users')
op.drop_index(op.f('ix_user_answers_question_id'), table_name='user_answers')
op.drop_index(op.f('ix_user_answers_attempt_id'), table_name='user_answers')
op.drop_constraint(None, 'test_attempts', type_='foreignkey')
op.drop_index(op.f('ix_test_attempts_test_id'), table_name='test_attempts')
op.drop_index(op.f('ix_questions_test_id'), table_name='questions')
op.drop_index(op.f('ix_options_question_id'), table_name='options')
# ### end Alembic commands ###
@@ -68,8 +68,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog
user = None user = None
if text.startswith("@"): if text.startswith("@"):
username = text[1:] username = text[1:]
all_users = await user_dao.get_all() user = await user_dao.get_by_username(username)
user = next((u for u in all_users if u.username == username), None)
elif text.isdigit(): elif text.isdigit():
user = await user_dao.get_by_id(int(text)) user = await user_dao.get_by_id(int(text))
@@ -137,8 +137,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog
user = None user = None
if text.startswith("@"): if text.startswith("@"):
username = text[1:] username = text[1:]
all_users = await user_dao.get_all() user = await user_dao.get_by_username(username)
user = next((u for u in all_users if u.username == username), None)
elif text.isdigit(): elif text.isdigit():
user = await user_dao.get_by_id(int(text)) user = await user_dao.get_by_id(int(text))
@@ -30,7 +30,6 @@ class RejectNotCreatorMiddleware(BaseMiddleware):
if user_id == config.bot.creator_id: if user_id == config.bot.creator_id:
return await handler(event, data) return await handler(event, data)
await event.answer("У вас нет доступа к панели создателя.")
return return
return await handler(event, data) return await handler(event, data)
@@ -1,17 +1,19 @@
import asyncio import asyncio
import functools import functools
import json
from datetime import date, datetime, time from datetime import date, datetime, time
from aiogram import Bot from aiogram import Bot
from aiogram.types import BufferedInputFile, CallbackQuery, Message from aiogram.types import BufferedInputFile, CallbackQuery, Message
from aiogram_dialog import Dialog, DialogManager, StartMode, Window from aiogram_dialog import Dialog, DialogManager, StartMode, Window
from aiogram_dialog.widgets.input import MessageInput from aiogram_dialog.widgets.input import MessageInput
from aiogram_dialog.widgets.kbd import Button, Calendar, Column, ScrollingGroup, Select from aiogram_dialog.widgets.kbd import Button, Calendar, Column, Row, ScrollingGroup, Select
from aiogram_dialog.widgets.text import Const, Format from aiogram_dialog.widgets.text import Const, Format
from dishka import FromDishka from dishka import FromDishka
from dishka.integrations.aiogram_dialog import inject from dishka.integrations.aiogram_dialog import inject
from trudex.application.bot.shared_dialogs.states import SharedCreateTestSG, SharedTestsSG from trudex.application.bot.shared_dialogs.states import SharedCreateTestSG, SharedTestsSG
from trudex.domain.schemas import QuestionType
from trudex.infrastructure.database.dao.group import GroupDAO from trudex.infrastructure.database.dao.group import GroupDAO
from trudex.infrastructure.database.dao.test import TestDAO from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.repo.test import TestRepository from trudex.infrastructure.database.repo.test import TestRepository
@@ -199,11 +201,16 @@ async def get_attempt_detail(
"<b>📋 Ответы:</b>\n", "<b>📋 Ответы:</b>\n",
] ]
# Загружаем все вопросы с опциями за один запрос
question_ids = [answer.question_id for answer in answers]
questions_map = await test_repo.get_questions_with_options_by_ids(question_ids)
for i, answer in enumerate(answers, 1): for i, answer in enumerate(answers, 1):
question, options = await test_repo.get_question_with_options(answer.question_id) question_data = questions_map.get(answer.question_id)
if not question: if not question_data:
continue continue
question, options = question_data
correct_options = [opt for opt in options if opt.is_correct] correct_options = [opt for opt in options if opt.is_correct]
correct_texts = [opt.text for opt in correct_options] correct_texts = [opt.text for opt in correct_options]
@@ -221,6 +228,87 @@ async def get_attempt_detail(
return {"attempt_info": "\n".join(lines)} return {"attempt_info": "\n".join(lines)}
@inject
async def on_export_test(
_callback: CallbackQuery,
_button: Button,
manager: DialogManager,
test_repo: FromDishka[TestRepository],
) -> None:
test_id = manager.dialog_data.get("selected_test_id")
if not test_id:
await _callback.answer("❌ Тест не найден")
return
assert _callback.message is not None
await _callback.answer("⏳ Экспортирую тест...")
test, questions_with_options = await test_repo.get_full_test(test_id)
if not test:
await _callback.message.answer("❌ Тест не найден")
return
export_data: dict = {
"title": test.title,
"description": test.description,
"password": test.password,
"attempts": test.attempts,
"expires_at": test.expires_at.isoformat() if test.expires_at else None,
"for_group": test.for_group,
"questions": [],
}
questions_list: list = export_data["questions"]
for question, options in questions_with_options:
question_data: dict = {
"question_type": question.question_type.value,
"question": question.text,
}
if question.question_type == QuestionType.INPUT:
correct_options = [o for o in options if o.is_correct]
if correct_options:
question_data["correct_answer"] = correct_options[0].text
else:
question_data["answers"] = [
{"option": o.text, "is_correct": o.is_correct}
for o in options
]
questions_list.append(question_data)
json_str = json.dumps(export_data, ensure_ascii=False, indent=2)
created_str = test.created_at.strftime("%d.%m.%Y %H:%M") if test.created_at else ""
updated_str = test.updated_at.strftime("%d.%m.%Y %H:%M") if test.updated_at else ""
questions_count = len(questions_with_options)
comment_header = f"""// ═══════════════════════════════════════════════════════════════
// ЭКСПОРТ ТЕСТА: {test.title}
// ═══════════════════════════════════════════════════════════════
//
// ❓ Вопросов: {questions_count}
// 📅 Создан: {created_str}
// 🔄 Обновлён: {updated_str}
//
// ═══════════════════════════════════════════════════════════════
"""
full_content = comment_header + json_str
safe_title = "".join(c if c.isalnum() or c in "-_" else "_" for c in test.title)[:50]
filename = f"{safe_title}.json"
await _callback.message.answer_document(
document=BufferedInputFile(full_content.encode("utf-8"), filename=filename),
caption=f"📤 <b>Экспорт теста:</b> {test.title}",
)
@inject @inject
async def on_share_test(_callback: CallbackQuery, _button: Button, manager: DialogManager, config: FromDishka[Config], bot_inst: FromDishka[Bot]): async def on_share_test(_callback: CallbackQuery, _button: Button, manager: DialogManager, config: FromDishka[Config], bot_inst: FromDishka[Bot]):
test_id = manager.dialog_data.get("selected_test_id") test_id = manager.dialog_data.get("selected_test_id")
@@ -462,6 +550,7 @@ shared_tests_dialog = Dialog(
), ),
Button(Const("📊 Статистика"), id="statistics", on_click=on_statistics), Button(Const("📊 Статистика"), id="statistics", on_click=on_statistics),
Button(Const("🔗 Поделиться"), id="share", on_click=on_share_test), Button(Const("🔗 Поделиться"), id="share", on_click=on_share_test),
Button(Const("📤 Экспорт"), id="export", on_click=on_export_test),
Button(Const("✏️ Изменить"), id="edit_menu", on_click=on_edit_menu), Button(Const("✏️ Изменить"), id="edit_menu", on_click=on_edit_menu),
Button(Const("◀️ Назад"), id="back", on_click=on_back_to_list), Button(Const("◀️ Назад"), id="back", on_click=on_back_to_list),
), ),
@@ -11,6 +11,7 @@ from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.models import QuestionType from trudex.infrastructure.database.models import QuestionType
from trudex.infrastructure.database.repo.test import TestRepository from trudex.infrastructure.database.repo.test import TestRepository
from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
@inject @inject
@@ -60,6 +61,7 @@ async def on_start_deeplink_test(
test_dao: FromDishka[TestDAO], test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository], test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository], attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
): ):
assert _callback.from_user is not None assert _callback.from_user is not None
@@ -89,6 +91,12 @@ async def on_start_deeplink_test(
await attempt_repo.attempt_dao.delete(active_attempt.id) await attempt_repo.attempt_dao.delete(active_attempt.id)
if test.password: if test.password:
# Проверяем rate limit перед показом экрана ввода пароля
allowed, wait_time = await rate_limiter.check(user_id)
if not allowed:
minutes = int(wait_time // 60) + 1
await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True)
return
await manager.switch_to(UserDeeplinkSG.password_input) await manager.switch_to(UserDeeplinkSG.password_input)
else: else:
await start_test_without_password(manager, test_repo, attempt_repo, test_id, user_id) await start_test_without_password(manager, test_repo, attempt_repo, test_id, user_id)
@@ -141,6 +149,7 @@ async def on_deeplink_password_input(
test_dao: FromDishka[TestDAO], test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository], test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository], attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
): ):
assert message.from_user is not None assert message.from_user is not None
@@ -164,7 +173,14 @@ async def on_deeplink_password_input(
manager, test_repo, attempt_repo, test_id, message.from_user.id manager, test_repo, attempt_repo, test_id, message.from_user.id
) )
else: else:
await message.answer("Неверный пароль") # Проверяем rate limit при неверном пароле
allowed, wait_time = await rate_limiter.check(message.from_user.id)
if not allowed:
minutes = int(wait_time // 60) + 1
await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.")
await manager.start(UserMenuSG.main, mode=StartMode.RESET_STACK)
else:
await message.answer("❌ Неверный пароль")
async def on_back_to_menu(_callback: CallbackQuery, _button: Button, manager: DialogManager): async def on_back_to_menu(_callback: CallbackQuery, _button: Button, manager: DialogManager):
@@ -12,6 +12,7 @@ from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.dao.user_answer import UserAnswerDAO from trudex.infrastructure.database.dao.user_answer import UserAnswerDAO
from trudex.infrastructure.database.repo.test import TestRepository from trudex.infrastructure.database.repo.test import TestRepository
from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
from trudex.infrastructure.utils.timezone import now_msk_naive from trudex.infrastructure.utils.timezone import now_msk_naive
@@ -32,6 +33,7 @@ async def on_start_test(
test_dao: FromDishka[TestDAO], test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository], test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository], attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
): ):
assert _callback.from_user is not None assert _callback.from_user is not None
test_id = manager.dialog_data.get("selected_test_id") test_id = manager.dialog_data.get("selected_test_id")
@@ -66,6 +68,12 @@ async def on_start_test(
await attempt_repo.attempt_dao.delete(active_attempt.id) await attempt_repo.attempt_dao.delete(active_attempt.id)
if test.password: if test.password:
# Проверяем rate limit перед показом экрана ввода пароля
allowed, wait_time = await rate_limiter.check(user_id)
if not allowed:
minutes = int(wait_time // 60) + 1
await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True)
return
await manager.start(UserTestSG.password_input, mode=StartMode.NORMAL, data={"test_id": test_id}) await manager.start(UserTestSG.password_input, mode=StartMode.NORMAL, data={"test_id": test_id})
else: else:
_, questions = await test_repo.get_test_with_questions(test_id) _, questions = await test_repo.get_test_with_questions(test_id)
@@ -100,6 +108,7 @@ async def on_password_input(
test_dao: FromDishka[TestDAO], test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository], test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository], attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
): ):
assert message.from_user is not None assert message.from_user is not None
start_data = manager.start_data or {} start_data = manager.start_data or {}
@@ -137,7 +146,14 @@ async def on_password_input(
await manager.switch_to(first_state) await manager.switch_to(first_state)
else: else:
await message.answer("Неверный пароль") # Проверяем rate limit при неверном пароле
allowed, wait_time = await rate_limiter.check(message.from_user.id)
if not allowed:
minutes = int(wait_time // 60) + 1
await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.")
await manager.done()
else:
await message.answer("❌ Неверный пароль")
@inject @inject
@@ -27,6 +27,13 @@ class UserDAO:
model = result.scalar_one_or_none() model = result.scalar_one_or_none()
return UserDTO(model).to_domain() if model else None return UserDTO(model).to_domain() if model else None
async def get_by_username(self, username: str) -> DomainUser | None:
result = await self.session.execute(
select(User).where(User.username == username)
)
model = result.scalar_one_or_none()
return UserDTO(model).to_domain() if model else None
async def get_all(self) -> list[DomainUser]: async def get_all(self) -> list[DomainUser]:
result = await self.session.execute( result = await self.session.execute(
select(User).order_by(User.created_at.desc()) select(User).order_by(User.created_at.desc())
+9 -8
View File
@@ -16,11 +16,11 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True) id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
username: Mapped[str | None] = mapped_column(String(32)) username: Mapped[str | None] = mapped_column(String(32), index=True)
first_name: Mapped[str] = mapped_column(String(64)) first_name: Mapped[str] = mapped_column(String(64))
last_name: Mapped[str | None] = mapped_column(String(64)) last_name: Mapped[str | None] = mapped_column(String(64))
name: Mapped[str | None] = mapped_column(String(128)) name: Mapped[str | None] = mapped_column(String(128))
group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999")) group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999"), index=True)
is_admin: Mapped[bool] = mapped_column(default=False) is_admin: Mapped[bool] = mapped_column(default=False)
name_updated_at: Mapped[datetime | None] = mapped_column(default=None) name_updated_at: Mapped[datetime | None] = mapped_column(default=None)
group_updated_at: Mapped[datetime | None] = mapped_column(default=None) group_updated_at: Mapped[datetime | None] = mapped_column(default=None)
@@ -70,7 +70,7 @@ class Question(Base):
__tablename__ = "questions" __tablename__ = "questions"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id")) test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True)
text: Mapped[str] = mapped_column(Text) text: Mapped[str] = mapped_column(Text)
position: Mapped[int] = mapped_column(Integer, default=0) position: Mapped[int] = mapped_column(Integer, default=0)
question_type: Mapped[QuestionType] = mapped_column(default=QuestionType.SINGLE) question_type: Mapped[QuestionType] = mapped_column(default=QuestionType.SINGLE)
@@ -88,7 +88,7 @@ class Option(Base):
__tablename__ = "options" __tablename__ = "options"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id")) question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True)
text: Mapped[str] = mapped_column(String(255)) text: Mapped[str] = mapped_column(String(255))
is_correct: Mapped[bool] = mapped_column(default=False) is_correct: Mapped[bool] = mapped_column(default=False)
explanation: Mapped[str | None] = mapped_column(Text) explanation: Mapped[str | None] = mapped_column(Text)
@@ -101,13 +101,14 @@ class TestAttempt(Base):
__tablename__ = "test_attempts" __tablename__ = "test_attempts"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, index=True) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.id"), index=True)
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id")) test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True)
started_at: Mapped[datetime] = mapped_column(server_default=func.now()) started_at: Mapped[datetime] = mapped_column(server_default=func.now())
finished_at: Mapped[datetime | None] = mapped_column(default=None) finished_at: Mapped[datetime | None] = mapped_column(default=None)
score: Mapped[int] = mapped_column(Integer, default=0) score: Mapped[int] = mapped_column(Integer, default=0)
is_passed: Mapped[bool] = mapped_column(default=False) is_passed: Mapped[bool] = mapped_column(default=False)
user: Mapped["User"] = relationship()
test: Mapped["Test"] = relationship() test: Mapped["Test"] = relationship()
answers: Mapped[list["UserAnswer"]] = relationship( answers: Mapped[list["UserAnswer"]] = relationship(
back_populates="attempt", back_populates="attempt",
@@ -120,8 +121,8 @@ class UserAnswer(Base):
__tablename__ = "user_answers" __tablename__ = "user_answers"
id: Mapped[int] = mapped_column(primary_key=True) id: Mapped[int] = mapped_column(primary_key=True)
attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id")) attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id"), index=True)
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id")) question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True)
selected_option_id: Mapped[int | None] = mapped_column(ForeignKey("options.id"), default=None) selected_option_id: Mapped[int | None] = mapped_column(ForeignKey("options.id"), default=None)
text_answer: Mapped[str | None] = mapped_column(Text, default=None) text_answer: Mapped[str | None] = mapped_column(Text, default=None)
is_correct: Mapped[bool] = mapped_column(default=False) is_correct: Mapped[bool] = mapped_column(default=False)
@@ -122,6 +122,28 @@ class TestRepository:
count = result.scalar_one() count = result.scalar_one()
return count return count
async def get_questions_with_options_by_ids(
self, question_ids: list[int]
) -> dict[int, tuple[Question, list[Option]]]:
"""Загружает вопросы с опциями по списку ID за один запрос."""
if not question_ids:
return {}
result = await self.session.execute(
select(QuestionModel)
.where(QuestionModel.id.in_(question_ids))
.options(selectinload(QuestionModel.options))
)
question_models = list(result.scalars().all())
questions_dict: dict[int, tuple[Question, list[Option]]] = {}
for qm in question_models:
question = QuestionDTO(qm).to_domain()
options = [OptionDTO(o).to_domain() for o in qm.options]
questions_dict[qm.id] = (question, options)
return questions_dict
async def duplicate_test(self, test_id: int, new_title: str) -> Test | None: async def duplicate_test(self, test_id: int, new_title: str) -> Test | None:
test, questions_with_options = await self.get_full_test(test_id) test, questions_with_options = await self.get_full_test(test_id)
if not test: if not test:
@@ -155,18 +155,16 @@ class TestAttemptRepository:
return [UserAnswerDTO(model).to_domain() for model in models] return [UserAnswerDTO(model).to_domain() for model in models]
async def get_question_statistics(self, question_id: int) -> dict[str, int]: async def get_question_statistics(self, question_id: int) -> dict[str, int]:
total_result = await self.session.execute( result = await self.session.execute(
select(func.count(UserAnswerModel.id)) select(
func.count(UserAnswerModel.id).label("total"),
func.sum(func.cast(UserAnswerModel.is_correct, func.Integer)).label("correct")
)
.where(UserAnswerModel.question_id == question_id) .where(UserAnswerModel.question_id == question_id)
) )
total = total_result.scalar_one() row = result.one()
total = row.total or 0
correct_result = await self.session.execute( correct = row.correct or 0
select(func.count(UserAnswerModel.id))
.where(UserAnswerModel.question_id == question_id)
.where(UserAnswerModel.is_correct == True)
)
correct = correct_result.scalar_one()
return { return {
"total_answers": total, "total_answers": total,
+5
View File
@@ -18,6 +18,7 @@ from trudex.infrastructure.database.repo.test_attempt import TestAttemptReposito
from trudex.infrastructure.database.repo.user import UserRepository from trudex.infrastructure.database.repo.user import UserRepository
from trudex.infrastructure.scheduling.tasks import deactivate_expired_tests from trudex.infrastructure.scheduling.tasks import deactivate_expired_tests
from trudex.infrastructure.utils.config import Config from trudex.infrastructure.utils.config import Config
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
class DatabaseProvider(Provider): class DatabaseProvider(Provider):
@@ -25,6 +26,10 @@ class DatabaseProvider(Provider):
def get_session_maker(self, config: Config) -> async_sessionmaker[AsyncSession]: def get_session_maker(self, config: Config) -> async_sessionmaker[AsyncSession]:
return new_session_maker(config.database.url) return new_session_maker(config.database.url)
@provide(scope=Scope.APP)
def get_password_rate_limiter(self) -> PasswordRateLimiter:
return PasswordRateLimiter()
@provide(scope=Scope.REQUEST) @provide(scope=Scope.REQUEST)
async def get_session( async def get_session(
self, session_maker: async_sessionmaker[AsyncSession] self, session_maker: async_sessionmaker[AsyncSession]
+23 -2
View File
@@ -3,7 +3,13 @@ import logging
from dataclasses import dataclass from dataclasses import dataclass
from aiogram import Bot from aiogram import Bot
from aiogram.exceptions import TelegramBadRequest, TelegramForbiddenError from aiogram.exceptions import (
TelegramAPIError,
TelegramBadRequest,
TelegramForbiddenError,
TelegramNetworkError,
TelegramRetryAfter,
)
from trudex.infrastructure.database.dao.user import UserDAO from trudex.infrastructure.database.dao.user import UserDAO
@@ -28,14 +34,29 @@ async def broadcast_message(bot: Bot, message_id: int, chat_id: int, user_dao: U
try: try:
await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id) await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id)
success += 1 success += 1
except TelegramRetryAfter as e:
logger.warning("Rate limited, waiting %d seconds", e.retry_after)
await asyncio.sleep(e.retry_after)
# Retry after waiting
try:
await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id)
success += 1
except TelegramAPIError:
failed += 1
except TelegramForbiddenError: except TelegramForbiddenError:
logger.debug("Broadcast failed (forbidden): user_id=%d", user.id) logger.debug("Broadcast failed (forbidden): user_id=%d", user.id)
failed += 1 failed += 1
except TelegramBadRequest as e: except TelegramBadRequest as e:
logger.debug("Broadcast failed (bad request): user_id=%d, error=%s", user.id, e) logger.debug("Broadcast failed (bad request): user_id=%d, error=%s", user.id, e)
failed += 1 failed += 1
except TelegramNetworkError as e:
logger.warning("Network error during broadcast: user_id=%d, error=%s", user.id, e)
failed += 1
except TelegramAPIError as e:
logger.warning("Telegram API error during broadcast: user_id=%d, error=%s", user.id, e)
failed += 1
await asyncio.sleep(0.1) await asyncio.sleep(0.05)
logger.info("Broadcast completed: success=%d, failed=%d, total=%d", success, failed, len(users)) logger.info("Broadcast completed: success=%d, failed=%d, total=%d", success, failed, len(users))
return BroadcastStats(success=success, failed=failed, total=len(users)) return BroadcastStats(success=success, failed=failed, total=len(users))
@@ -0,0 +1,57 @@
import asyncio
import time
from dataclasses import dataclass
@dataclass
class UserBucket:
tokens: float
last_updated: float
class RateLimiter:
def __init__(self, rate: int, period: int):
self.rate = rate
self.period = period
self.fill_rate = rate / period
self.buckets: dict[int, UserBucket] = {}
self._lock = asyncio.Lock()
async def check(self, user_id: int) -> tuple[bool, float]:
async with self._lock:
now = time.time()
if user_id not in self.buckets:
self.buckets[user_id] = UserBucket(
tokens=self.rate - 1,
last_updated=now
)
return True, 0.0
bucket = self.buckets[user_id]
elapsed = now - bucket.last_updated
added_tokens = elapsed * self.fill_rate
bucket.tokens = min(self.rate, bucket.tokens + added_tokens)
bucket.last_updated = now
if bucket.tokens >= 1:
bucket.tokens -= 1
return True, 0.0
else:
wait_time = (1 - bucket.tokens) / self.fill_rate
return False, wait_time
async def cleanup(self) -> None:
async with self._lock:
full_buckets = [
user_id for user_id, bucket in self.buckets.items()
if bucket.tokens >= self.rate
]
for user_id in full_buckets:
del self.buckets[user_id]
class PasswordRateLimiter(RateLimiter):
def __init__(self):
super().__init__(rate=5, period=3600)