Preliminary and final additional tasks implemented

This commit is contained in:
Davte 2019-07-12 14:19:15 +02:00
parent c00fce62aa
commit 34c2135962

View File

@ -79,6 +79,8 @@ class Bot(TelegramBot, ObjectWithDatabase):
TelegramBot.__init__(self, token) TelegramBot.__init__(self, token)
ObjectWithDatabase.__init__(self, database_url=database_url) ObjectWithDatabase.__init__(self, database_url=database_url)
self._path = None self._path = None
self.preliminary_tasks = []
self.final_tasks = []
self._offset = 0 self._offset = 0
self._hostname = hostname self._hostname = hostname
self._certificate = certificate self._certificate = certificate
@ -1712,6 +1714,21 @@ class Bot(TelegramBot, ObjectWithDatabase):
return await self.routing_table[key](value) return await self.routing_table[key](value)
logging.error(f"Unknown type of update.\n{update}") logging.error(f"Unknown type of update.\n{update}")
def additional_task(self, when='BEFORE', *args, **kwargs):
"""Add a task before at app start or cleanup.
Decorate an async function to have it awaited `BEFORE` or `AFTER` main
loop.
"""
when = when[0].lower()
def additional_task_decorator(task):
if when == 'b':
self.preliminary_tasks.append(task(*args, **kwargs))
elif when == 'a':
self.final_tasks.append(task(*args, **kwargs))
return additional_task_decorator
@classmethod @classmethod
async def start_app(cls): async def start_app(cls):
"""Start running `aiohttp.web.Application`. """Start running `aiohttp.web.Application`.
@ -1730,6 +1747,9 @@ class Bot(TelegramBot, ObjectWithDatabase):
async def stop_app(cls): async def stop_app(cls):
"""Close bot sessions and cleanup.""" """Close bot sessions and cleanup."""
for bot in cls.bots: for bot in cls.bots:
await asyncio.gather(
*bot.final_tasks
)
await bot.close_sessions() await bot.close_sessions()
await cls.runner.cleanup() await cls.runner.cleanup()
@ -1759,6 +1779,18 @@ class Bot(TelegramBot, ObjectWithDatabase):
cls.local_host = local_host cls.local_host = local_host
if port is not None: if port is not None:
cls.port = port cls.port = port
try:
cls.loop.run_until_complete(
asyncio.gather(
*[
preliminary_task
for bot in cls.bots
for preliminary_task in bot.preliminary_tasks
]
)
)
except Exception as e:
logging.error(f"{e}", exc_info=True)
for bot in cls.bots: for bot in cls.bots:
bot.setup() bot.setup()
asyncio.ensure_future(cls.start_app()) asyncio.ensure_future(cls.start_app())