This commit is contained in:
2026-01-04 16:05:19 +03:00
parent f46a0ac45b
commit 260171086b
14 changed files with 302 additions and 30 deletions
@@ -0,0 +1,43 @@
"""add_indexes_and_fk
Revision ID: ca107b03ddf8
Revises: 40f5317720a4
Create Date: 2026-01-04 15:32:14.881408
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
revision: str = 'ca107b03ddf8'
down_revision: str | None = '40f5317720a4'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_index(op.f('ix_options_question_id'), 'options', ['question_id'], unique=False)
op.create_index(op.f('ix_questions_test_id'), 'questions', ['test_id'], unique=False)
op.create_index(op.f('ix_test_attempts_test_id'), 'test_attempts', ['test_id'], unique=False)
op.create_foreign_key(None, 'test_attempts', 'users', ['user_id'], ['id'])
op.create_index(op.f('ix_user_answers_attempt_id'), 'user_answers', ['attempt_id'], unique=False)
op.create_index(op.f('ix_user_answers_question_id'), 'user_answers', ['question_id'], unique=False)
op.create_index(op.f('ix_users_group'), 'users', ['group'], unique=False)
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_username'), table_name='users')
op.drop_index(op.f('ix_users_group'), table_name='users')
op.drop_index(op.f('ix_user_answers_question_id'), table_name='user_answers')
op.drop_index(op.f('ix_user_answers_attempt_id'), table_name='user_answers')
op.drop_constraint(None, 'test_attempts', type_='foreignkey')
op.drop_index(op.f('ix_test_attempts_test_id'), table_name='test_attempts')
op.drop_index(op.f('ix_questions_test_id'), table_name='questions')
op.drop_index(op.f('ix_options_question_id'), table_name='options')
# ### end Alembic commands ###
@@ -68,8 +68,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog
user = None
if text.startswith("@"):
username = text[1:]
all_users = await user_dao.get_all()
user = next((u for u in all_users if u.username == username), None)
user = await user_dao.get_by_username(username)
elif text.isdigit():
user = await user_dao.get_by_id(int(text))
@@ -137,8 +137,7 @@ async def on_user_input(message: Message, _widget: MessageInput, manager: Dialog
user = None
if text.startswith("@"):
username = text[1:]
all_users = await user_dao.get_all()
user = next((u for u in all_users if u.username == username), None)
user = await user_dao.get_by_username(username)
elif text.isdigit():
user = await user_dao.get_by_id(int(text))
@@ -30,7 +30,6 @@ class RejectNotCreatorMiddleware(BaseMiddleware):
if user_id == config.bot.creator_id:
return await handler(event, data)
await event.answer("У вас нет доступа к панели создателя.")
return
return await handler(event, data)
@@ -1,17 +1,19 @@
import asyncio
import functools
import json
from datetime import date, datetime, time
from aiogram import Bot
from aiogram.types import BufferedInputFile, CallbackQuery, Message
from aiogram_dialog import Dialog, DialogManager, StartMode, Window
from aiogram_dialog.widgets.input import MessageInput
from aiogram_dialog.widgets.kbd import Button, Calendar, Column, ScrollingGroup, Select
from aiogram_dialog.widgets.kbd import Button, Calendar, Column, Row, ScrollingGroup, Select
from aiogram_dialog.widgets.text import Const, Format
from dishka import FromDishka
from dishka.integrations.aiogram_dialog import inject
from trudex.application.bot.shared_dialogs.states import SharedCreateTestSG, SharedTestsSG
from trudex.domain.schemas import QuestionType
from trudex.infrastructure.database.dao.group import GroupDAO
from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.repo.test import TestRepository
@@ -199,11 +201,16 @@ async def get_attempt_detail(
"<b>📋 Ответы:</b>\n",
]
# Загружаем все вопросы с опциями за один запрос
question_ids = [answer.question_id for answer in answers]
questions_map = await test_repo.get_questions_with_options_by_ids(question_ids)
for i, answer in enumerate(answers, 1):
question, options = await test_repo.get_question_with_options(answer.question_id)
if not question:
question_data = questions_map.get(answer.question_id)
if not question_data:
continue
question, options = question_data
correct_options = [opt for opt in options if opt.is_correct]
correct_texts = [opt.text for opt in correct_options]
@@ -221,6 +228,87 @@ async def get_attempt_detail(
return {"attempt_info": "\n".join(lines)}
@inject
async def on_export_test(
_callback: CallbackQuery,
_button: Button,
manager: DialogManager,
test_repo: FromDishka[TestRepository],
) -> None:
test_id = manager.dialog_data.get("selected_test_id")
if not test_id:
await _callback.answer("❌ Тест не найден")
return
assert _callback.message is not None
await _callback.answer("⏳ Экспортирую тест...")
test, questions_with_options = await test_repo.get_full_test(test_id)
if not test:
await _callback.message.answer("❌ Тест не найден")
return
export_data: dict = {
"title": test.title,
"description": test.description,
"password": test.password,
"attempts": test.attempts,
"expires_at": test.expires_at.isoformat() if test.expires_at else None,
"for_group": test.for_group,
"questions": [],
}
questions_list: list = export_data["questions"]
for question, options in questions_with_options:
question_data: dict = {
"question_type": question.question_type.value,
"question": question.text,
}
if question.question_type == QuestionType.INPUT:
correct_options = [o for o in options if o.is_correct]
if correct_options:
question_data["correct_answer"] = correct_options[0].text
else:
question_data["answers"] = [
{"option": o.text, "is_correct": o.is_correct}
for o in options
]
questions_list.append(question_data)
json_str = json.dumps(export_data, ensure_ascii=False, indent=2)
created_str = test.created_at.strftime("%d.%m.%Y %H:%M") if test.created_at else ""
updated_str = test.updated_at.strftime("%d.%m.%Y %H:%M") if test.updated_at else ""
questions_count = len(questions_with_options)
comment_header = f"""// ═══════════════════════════════════════════════════════════════
// ЭКСПОРТ ТЕСТА: {test.title}
// ═══════════════════════════════════════════════════════════════
//
// ❓ Вопросов: {questions_count}
// 📅 Создан: {created_str}
// 🔄 Обновлён: {updated_str}
//
// ═══════════════════════════════════════════════════════════════
"""
full_content = comment_header + json_str
safe_title = "".join(c if c.isalnum() or c in "-_" else "_" for c in test.title)[:50]
filename = f"{safe_title}.json"
await _callback.message.answer_document(
document=BufferedInputFile(full_content.encode("utf-8"), filename=filename),
caption=f"📤 <b>Экспорт теста:</b> {test.title}",
)
@inject
async def on_share_test(_callback: CallbackQuery, _button: Button, manager: DialogManager, config: FromDishka[Config], bot_inst: FromDishka[Bot]):
test_id = manager.dialog_data.get("selected_test_id")
@@ -462,6 +550,7 @@ shared_tests_dialog = Dialog(
),
Button(Const("📊 Статистика"), id="statistics", on_click=on_statistics),
Button(Const("🔗 Поделиться"), id="share", on_click=on_share_test),
Button(Const("📤 Экспорт"), id="export", on_click=on_export_test),
Button(Const("✏️ Изменить"), id="edit_menu", on_click=on_edit_menu),
Button(Const("◀️ Назад"), id="back", on_click=on_back_to_list),
),
@@ -11,6 +11,7 @@ from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.models import QuestionType
from trudex.infrastructure.database.repo.test import TestRepository
from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
@inject
@@ -60,6 +61,7 @@ async def on_start_deeplink_test(
test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
):
assert _callback.from_user is not None
@@ -89,6 +91,12 @@ async def on_start_deeplink_test(
await attempt_repo.attempt_dao.delete(active_attempt.id)
if test.password:
# Проверяем rate limit перед показом экрана ввода пароля
allowed, wait_time = await rate_limiter.check(user_id)
if not allowed:
minutes = int(wait_time // 60) + 1
await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True)
return
await manager.switch_to(UserDeeplinkSG.password_input)
else:
await start_test_without_password(manager, test_repo, attempt_repo, test_id, user_id)
@@ -141,6 +149,7 @@ async def on_deeplink_password_input(
test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
):
assert message.from_user is not None
@@ -164,7 +173,14 @@ async def on_deeplink_password_input(
manager, test_repo, attempt_repo, test_id, message.from_user.id
)
else:
await message.answer("Неверный пароль")
# Проверяем rate limit при неверном пароле
allowed, wait_time = await rate_limiter.check(message.from_user.id)
if not allowed:
minutes = int(wait_time // 60) + 1
await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.")
await manager.start(UserMenuSG.main, mode=StartMode.RESET_STACK)
else:
await message.answer("❌ Неверный пароль")
async def on_back_to_menu(_callback: CallbackQuery, _button: Button, manager: DialogManager):
@@ -12,6 +12,7 @@ from trudex.infrastructure.database.dao.test import TestDAO
from trudex.infrastructure.database.dao.user_answer import UserAnswerDAO
from trudex.infrastructure.database.repo.test import TestRepository
from trudex.infrastructure.database.repo.test_attempt import TestAttemptRepository
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
from trudex.infrastructure.utils.timezone import now_msk_naive
@@ -32,6 +33,7 @@ async def on_start_test(
test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
):
assert _callback.from_user is not None
test_id = manager.dialog_data.get("selected_test_id")
@@ -66,6 +68,12 @@ async def on_start_test(
await attempt_repo.attempt_dao.delete(active_attempt.id)
if test.password:
# Проверяем rate limit перед показом экрана ввода пароля
allowed, wait_time = await rate_limiter.check(user_id)
if not allowed:
minutes = int(wait_time // 60) + 1
await _callback.answer(f"⏳ Слишком много попыток. Подождите {minutes} мин.", show_alert=True)
return
await manager.start(UserTestSG.password_input, mode=StartMode.NORMAL, data={"test_id": test_id})
else:
_, questions = await test_repo.get_test_with_questions(test_id)
@@ -100,6 +108,7 @@ async def on_password_input(
test_dao: FromDishka[TestDAO],
test_repo: FromDishka[TestRepository],
attempt_repo: FromDishka[TestAttemptRepository],
rate_limiter: FromDishka[PasswordRateLimiter],
):
assert message.from_user is not None
start_data = manager.start_data or {}
@@ -137,7 +146,14 @@ async def on_password_input(
await manager.switch_to(first_state)
else:
await message.answer("Неверный пароль")
# Проверяем rate limit при неверном пароле
allowed, wait_time = await rate_limiter.check(message.from_user.id)
if not allowed:
minutes = int(wait_time // 60) + 1
await message.answer(f"❌ Неверный пароль\n⏳ Слишком много попыток. Подождите {minutes} мин.")
await manager.done()
else:
await message.answer("❌ Неверный пароль")
@inject
@@ -27,6 +27,13 @@ class UserDAO:
model = result.scalar_one_or_none()
return UserDTO(model).to_domain() if model else None
async def get_by_username(self, username: str) -> DomainUser | None:
result = await self.session.execute(
select(User).where(User.username == username)
)
model = result.scalar_one_or_none()
return UserDTO(model).to_domain() if model else None
async def get_all(self) -> list[DomainUser]:
result = await self.session.execute(
select(User).order_by(User.created_at.desc())
+9 -8
View File
@@ -16,11 +16,11 @@ class User(Base):
__tablename__ = "users"
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
username: Mapped[str | None] = mapped_column(String(32))
username: Mapped[str | None] = mapped_column(String(32), index=True)
first_name: Mapped[str] = mapped_column(String(64))
last_name: Mapped[str | None] = mapped_column(String(64))
name: Mapped[str | None] = mapped_column(String(128))
group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999"))
group: Mapped[int | None] = mapped_column(CheckConstraint("group >= 1000 AND group <= 9999"), index=True)
is_admin: Mapped[bool] = mapped_column(default=False)
name_updated_at: Mapped[datetime | None] = mapped_column(default=None)
group_updated_at: Mapped[datetime | None] = mapped_column(default=None)
@@ -70,7 +70,7 @@ class Question(Base):
__tablename__ = "questions"
id: Mapped[int] = mapped_column(primary_key=True)
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"))
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True)
text: Mapped[str] = mapped_column(Text)
position: Mapped[int] = mapped_column(Integer, default=0)
question_type: Mapped[QuestionType] = mapped_column(default=QuestionType.SINGLE)
@@ -88,7 +88,7 @@ class Option(Base):
__tablename__ = "options"
id: Mapped[int] = mapped_column(primary_key=True)
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"))
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True)
text: Mapped[str] = mapped_column(String(255))
is_correct: Mapped[bool] = mapped_column(default=False)
explanation: Mapped[str | None] = mapped_column(Text)
@@ -101,13 +101,14 @@ class TestAttempt(Base):
__tablename__ = "test_attempts"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, index=True)
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"))
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.id"), index=True)
test_id: Mapped[int] = mapped_column(ForeignKey("tests.id"), index=True)
started_at: Mapped[datetime] = mapped_column(server_default=func.now())
finished_at: Mapped[datetime | None] = mapped_column(default=None)
score: Mapped[int] = mapped_column(Integer, default=0)
is_passed: Mapped[bool] = mapped_column(default=False)
user: Mapped["User"] = relationship()
test: Mapped["Test"] = relationship()
answers: Mapped[list["UserAnswer"]] = relationship(
back_populates="attempt",
@@ -120,8 +121,8 @@ class UserAnswer(Base):
__tablename__ = "user_answers"
id: Mapped[int] = mapped_column(primary_key=True)
attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id"))
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"))
attempt_id: Mapped[int] = mapped_column(ForeignKey("test_attempts.id"), index=True)
question_id: Mapped[int] = mapped_column(ForeignKey("questions.id"), index=True)
selected_option_id: Mapped[int | None] = mapped_column(ForeignKey("options.id"), default=None)
text_answer: Mapped[str | None] = mapped_column(Text, default=None)
is_correct: Mapped[bool] = mapped_column(default=False)
@@ -122,6 +122,28 @@ class TestRepository:
count = result.scalar_one()
return count
async def get_questions_with_options_by_ids(
self, question_ids: list[int]
) -> dict[int, tuple[Question, list[Option]]]:
"""Загружает вопросы с опциями по списку ID за один запрос."""
if not question_ids:
return {}
result = await self.session.execute(
select(QuestionModel)
.where(QuestionModel.id.in_(question_ids))
.options(selectinload(QuestionModel.options))
)
question_models = list(result.scalars().all())
questions_dict: dict[int, tuple[Question, list[Option]]] = {}
for qm in question_models:
question = QuestionDTO(qm).to_domain()
options = [OptionDTO(o).to_domain() for o in qm.options]
questions_dict[qm.id] = (question, options)
return questions_dict
async def duplicate_test(self, test_id: int, new_title: str) -> Test | None:
test, questions_with_options = await self.get_full_test(test_id)
if not test:
@@ -155,18 +155,16 @@ class TestAttemptRepository:
return [UserAnswerDTO(model).to_domain() for model in models]
async def get_question_statistics(self, question_id: int) -> dict[str, int]:
total_result = await self.session.execute(
select(func.count(UserAnswerModel.id))
result = await self.session.execute(
select(
func.count(UserAnswerModel.id).label("total"),
func.sum(func.cast(UserAnswerModel.is_correct, func.Integer)).label("correct")
)
.where(UserAnswerModel.question_id == question_id)
)
total = total_result.scalar_one()
correct_result = await self.session.execute(
select(func.count(UserAnswerModel.id))
.where(UserAnswerModel.question_id == question_id)
.where(UserAnswerModel.is_correct == True)
)
correct = correct_result.scalar_one()
row = result.one()
total = row.total or 0
correct = row.correct or 0
return {
"total_answers": total,
+5
View File
@@ -18,12 +18,17 @@ from trudex.infrastructure.database.repo.test_attempt import TestAttemptReposito
from trudex.infrastructure.database.repo.user import UserRepository
from trudex.infrastructure.scheduling.tasks import deactivate_expired_tests
from trudex.infrastructure.utils.config import Config
from trudex.infrastructure.utils.rate_limiter import PasswordRateLimiter
class DatabaseProvider(Provider):
@provide(scope=Scope.APP)
def get_session_maker(self, config: Config) -> async_sessionmaker[AsyncSession]:
return new_session_maker(config.database.url)
@provide(scope=Scope.APP)
def get_password_rate_limiter(self) -> PasswordRateLimiter:
return PasswordRateLimiter()
@provide(scope=Scope.REQUEST)
async def get_session(
+23 -2
View File
@@ -3,7 +3,13 @@ import logging
from dataclasses import dataclass
from aiogram import Bot
from aiogram.exceptions import TelegramBadRequest, TelegramForbiddenError
from aiogram.exceptions import (
TelegramAPIError,
TelegramBadRequest,
TelegramForbiddenError,
TelegramNetworkError,
TelegramRetryAfter,
)
from trudex.infrastructure.database.dao.user import UserDAO
@@ -28,14 +34,29 @@ async def broadcast_message(bot: Bot, message_id: int, chat_id: int, user_dao: U
try:
await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id)
success += 1
except TelegramRetryAfter as e:
logger.warning("Rate limited, waiting %d seconds", e.retry_after)
await asyncio.sleep(e.retry_after)
# Retry after waiting
try:
await bot.copy_message(chat_id=user.id, from_chat_id=chat_id, message_id=message_id)
success += 1
except TelegramAPIError:
failed += 1
except TelegramForbiddenError:
logger.debug("Broadcast failed (forbidden): user_id=%d", user.id)
failed += 1
except TelegramBadRequest as e:
logger.debug("Broadcast failed (bad request): user_id=%d, error=%s", user.id, e)
failed += 1
except TelegramNetworkError as e:
logger.warning("Network error during broadcast: user_id=%d, error=%s", user.id, e)
failed += 1
except TelegramAPIError as e:
logger.warning("Telegram API error during broadcast: user_id=%d, error=%s", user.id, e)
failed += 1
await asyncio.sleep(0.1)
await asyncio.sleep(0.05)
logger.info("Broadcast completed: success=%d, failed=%d, total=%d", success, failed, len(users))
return BroadcastStats(success=success, failed=failed, total=len(users))
@@ -0,0 +1,57 @@
import asyncio
import time
from dataclasses import dataclass
@dataclass
class UserBucket:
tokens: float
last_updated: float
class RateLimiter:
def __init__(self, rate: int, period: int):
self.rate = rate
self.period = period
self.fill_rate = rate / period
self.buckets: dict[int, UserBucket] = {}
self._lock = asyncio.Lock()
async def check(self, user_id: int) -> tuple[bool, float]:
async with self._lock:
now = time.time()
if user_id not in self.buckets:
self.buckets[user_id] = UserBucket(
tokens=self.rate - 1,
last_updated=now
)
return True, 0.0
bucket = self.buckets[user_id]
elapsed = now - bucket.last_updated
added_tokens = elapsed * self.fill_rate
bucket.tokens = min(self.rate, bucket.tokens + added_tokens)
bucket.last_updated = now
if bucket.tokens >= 1:
bucket.tokens -= 1
return True, 0.0
else:
wait_time = (1 - bucket.tokens) / self.fill_rate
return False, wait_time
async def cleanup(self) -> None:
async with self._lock:
full_buckets = [
user_id for user_id, bucket in self.buckets.items()
if bucket.tokens >= self.rate
]
for user_id in full_buckets:
del self.buckets[user_id]
class PasswordRateLimiter(RateLimiter):
def __init__(self):
super().__init__(rate=5, period=3600)