diff --git a/src/dutylog/application/__main__.py b/src/dutylog/application/__main__.py index cac735f..1f2f585 100644 --- a/src/dutylog/application/__main__.py +++ b/src/dutylog/application/__main__.py @@ -5,9 +5,12 @@ from aiogram import Bot, Dispatcher from aiogram.client.default import DefaultBotProperties from aiogram.enums import ParseMode from aiogram_dialog import setup_dialogs +from dishka import make_async_container +from dishka.integrations.aiogram import setup_dishka from src.dutylog.application.bot.user_handlers import router as user_router from src.dutylog.application.bot.user_dialogs import main_menu_dialog +from src.dutylog.infrastructure.ioc import ConfigProvider, DatabaseProvider, DAOProvider from src.dutylog.infrastructure.utils.config import load_config @@ -22,13 +25,21 @@ async def main(): ) dp = Dispatcher() + container = make_async_container( + ConfigProvider(), + DatabaseProvider(), + DAOProvider(), + ) + dp.include_router(user_router) dp.include_router(main_menu_dialog) setup_dialogs(dp) + setup_dishka(container, dp) await dp.start_polling(bot) if __name__ == "__main__": asyncio.run(main()) + diff --git a/src/dutylog/infrastructure/ioc.py b/src/dutylog/infrastructure/ioc.py index aeba2e7..6505abe 100644 --- a/src/dutylog/infrastructure/ioc.py +++ b/src/dutylog/infrastructure/ioc.py @@ -1,7 +1,38 @@ -from sqlalchemy.ext.asyncio import AsyncSession +from collections.abc import AsyncIterable -from src.dutylog.infrastructure.database.dao.users_dao import UsersDAO +from dishka import Provider, Scope, provide +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker + +from dutylog.infrastructure.database.config import create_engine, create_session_maker +from dutylog.infrastructure.database.dao.users_dao import UsersDAO +from dutylog.infrastructure.utils.config import Config, load_config -def get_users_dao(session: AsyncSession) -> UsersDAO: - return UsersDAO(session) +class ConfigProvider(Provider): + @provide(scope=Scope.APP) + def get_config(self) -> Config: + return load_config() + + +class DatabaseProvider(Provider): + @provide(scope=Scope.APP) + def get_engine(self, config: Config) -> AsyncEngine: + return create_engine(config.database.url) + + @provide(scope=Scope.APP) + def get_session_maker(self, engine: AsyncEngine) -> async_sessionmaker[AsyncSession]: + return create_session_maker(engine) + + @provide(scope=Scope.REQUEST) + async def get_session( + self, session_maker: async_sessionmaker[AsyncSession] + ) -> AsyncIterable[AsyncSession]: + async with session_maker() as session: + yield session + + +class DAOProvider(Provider): + @provide(scope=Scope.REQUEST) + def get_users_dao(self, session: AsyncSession) -> UsersDAO: + return UsersDAO(session) +