From 0a29d2bad35b87888fc6569e9ea37efc73cb4308 Mon Sep 17 00:00:00 2001 From: Mike Bloy Date: Sun, 7 Mar 2021 13:47:30 -0600 Subject: [PATCH] create basic handler --- setup.cfg | 4 ++ src/hasskiosk/__init__.py | 2 + src/hasskiosk/mqtt.py | 102 +++++++++++++++----------------------- src/hasskiosk/runner.py | 16 ------ tests/test_config.py | 4 +- tox.ini | 1 + 6 files changed, 49 insertions(+), 80 deletions(-) diff --git a/setup.cfg b/setup.cfg index 7ab2a3e..1e80a9f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,3 +22,7 @@ install_requires = [options.packages.find] where=src + +[options.entry_points] +console_scripts = + hasskiosk = hasskiosk:run diff --git a/src/hasskiosk/__init__.py b/src/hasskiosk/__init__.py index 70337f3..ec1b423 100644 --- a/src/hasskiosk/__init__.py +++ b/src/hasskiosk/__init__.py @@ -1,7 +1,9 @@ """Home Assistant Kiosk.""" from ._version import version as __version__ +from .runner import run __all__ = [ + "run", "__version__", ] diff --git a/src/hasskiosk/mqtt.py b/src/hasskiosk/mqtt.py index 2fbbba0..b295901 100644 --- a/src/hasskiosk/mqtt.py +++ b/src/hasskiosk/mqtt.py @@ -1,74 +1,52 @@ """Manage mqtt connections.""" import asyncio -import logging -from typing import Any, Dict +from contextlib import AsyncExitStack +from typing import Set -from gmqtt import Client +from asyncio_mqtt import Client -class MQTT: - """MQTT manager. Wrapper around paho.mqtt.client.Client.""" +class MQTTManager: + """MQTT manager class.""" - def __init__(self, config: Dict[str, Any]): - """Init MQTT. + def __init__(self, hostname: str, port: int, username: str, password: str): + """Initialize with the following data. Arguments: - config: a config object returned by config.read_config() + hostname: MQTT host to connect to + port: port to use for connecting + username: authentication username + password: authentication password """ - mqtt = config["mqtt"] - self._host = mqtt["host"] - self._port = mqtt["port"] - self._keepalive = mqtt["keepalive"] - self._username = mqtt["username"] - self._password = mqtt["password"] - self._client_id = config["client_id"] - self._topics: Dict[str, Any] = dict() - self._subscriptions = mqtt["subscriptions"] - self._client = Client(client_id=self._client_id, clean_session=True) + super().__init__() + self._hostname = hostname + self._port = port + self._username = username + self._password = password + self._tasks: Set[asyncio.Task] = set() + self._mqtt: Client = None - async def connect(self): - """Connect to the client and log the connection.""" - logger = logging.getLogger(__name__) - logger.info( - "connecting to MQTT at %s:%s with client_id %s", - self._host, - self._port, - self._client_id, - ) - self._client.set_auth_credentials(self._username, self._password) - self._client.on_connect = self.on_connect - self._client.on_disconnect = self.on_disconnect - self._client.on_subscribe = self.on_subscribe - self._client.on_message = self.on_message - await self._client.connect(self._host, self._port, self._keepalive) + async def run(self): + """MQTT async runner.""" + async with AsyncExitStack() as ctx: + self._tasks = set() + ctx.push_async_callback(self._cancel_tasks) + self._mqtt = Client( + hostname=self._host, + port=self._port, + username=self._username, + password=self._password, + clean_session=True, + ) + await self.enter_async_context(self._mqtt) - async def disconnect(self): - """Wrapper around client disconnect.""" - await self._client.disconnect() - - def on_connect(self, client: Client, flags: int, rc: int): - """Callback method for the client connection.""" - logger = logging.getLogger(__name__) - logger.info("client %s connected with result code %s", client, rc) - for sub in self._subscriptions: - client.subscribe(sub, qos=0) - - @staticmethod - def on_disconnect(client: Client, packet, exc=None): - """Callback for disconnections.""" - logger = logging.getLogger(__name__) - logger.info("disconnected from broker: %s", client) - - @staticmethod - def on_subscribe(client: Client, mid: int, qos: int, properties): - """Callback for subscriptions.""" - logger = logging.getLogger(__name__) - logger.info("Subscribed to topic(s) with mid %s and qos %s", mid, qos) - - @staticmethod - def on_message(client: Client, topic, payload: bytes, qos, properties): - """Callback for message handling.""" - logger = logging.getLogger(__name__) - message = payload.decode() - logger.info("Recieved message '%s' on topic %s", message, topic) + async def _cancel_tasks(self): + for task in self._tasks: + if task.done(): + continue + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/src/hasskiosk/runner.py b/src/hasskiosk/runner.py index 901e5b3..8e6b5a1 100644 --- a/src/hasskiosk/runner.py +++ b/src/hasskiosk/runner.py @@ -40,11 +40,6 @@ async def main(config: Dict[str, Any]): messages = await exit_stack.enter_async_context(manager) tasks.add(asyncio.create_task(handler(messages))) - other_messages = await exit_stack.enter_async_context( - mqtt.unfiltered_messages() - ) - tasks.add(asyncio.create_task(dead_letter_handler(other_messages))) - await mqtt.subscribe(config["mqtt"]["subscription"]) await asyncio.gather(*tasks) @@ -61,17 +56,6 @@ async def screen_state_mqtt_handler(messages): ) -async def dead_letter_handler(messages): - """Logger for uncaught messages.""" - log = logging.getLogger(__name__) - async for message in messages: - log.info( - "unfiltered message on topic %s: %s", - message.topic, - message.payload.decode(), - ) - - async def cancel_tasks(tasks): """Cancel tasks on shutdown.""" for task in tasks: diff --git a/tests/test_config.py b/tests/test_config.py index a0d6633..c494b9d 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -22,7 +22,7 @@ def mock_env(monkeypatch): @pytest.fixture def logging_config() -> Dict[str, Any]: - """logging configuration fixture.""" + """Logging configuration fixture.""" config = { "version": 1, "disable_existing_loggers": False, @@ -69,7 +69,7 @@ def test_read_config(): "port": 1883, "keepalive": 60, "subscription": "home/test/#", - "screen_state_topic": "#/presence", + "screen_state_topic": "home/+/presence", } diff --git a/tox.ini b/tox.ini index 75b6978..d3eef28 100644 --- a/tox.ini +++ b/tox.ini @@ -103,6 +103,7 @@ include = .tox/**/hasskiosk/ omit = **/hasskiosk/_version.py + **/hasskiosk/__main__.py setup.py tests/*