create basic handler

This commit is contained in:
Mike Bloy 2021-03-07 13:47:30 -06:00
parent 0063951aeb
commit 0a29d2bad3
6 changed files with 49 additions and 80 deletions

View File

@ -22,3 +22,7 @@ install_requires =
[options.packages.find]
where=src
[options.entry_points]
console_scripts =
hasskiosk = hasskiosk:run

View File

@ -1,7 +1,9 @@
"""Home Assistant Kiosk."""
from ._version import version as __version__
from .runner import run
__all__ = [
"run",
"__version__",
]

View File

@ -1,74 +1,52 @@
"""Manage mqtt connections."""
import asyncio
import logging
from typing import Any, Dict
from contextlib import AsyncExitStack
from typing import Set
from gmqtt import Client
from asyncio_mqtt import Client
class MQTT:
"""MQTT manager. Wrapper around paho.mqtt.client.Client."""
class MQTTManager:
"""MQTT manager class."""
def __init__(self, config: Dict[str, Any]):
"""Init MQTT.
def __init__(self, hostname: str, port: int, username: str, password: str):
"""Initialize with the following data.
Arguments:
config: a config object returned by config.read_config()
hostname: MQTT host to connect to
port: port to use for connecting
username: authentication username
password: authentication password
"""
mqtt = config["mqtt"]
self._host = mqtt["host"]
self._port = mqtt["port"]
self._keepalive = mqtt["keepalive"]
self._username = mqtt["username"]
self._password = mqtt["password"]
self._client_id = config["client_id"]
self._topics: Dict[str, Any] = dict()
self._subscriptions = mqtt["subscriptions"]
self._client = Client(client_id=self._client_id, clean_session=True)
super().__init__()
self._hostname = hostname
self._port = port
self._username = username
self._password = password
self._tasks: Set[asyncio.Task] = set()
self._mqtt: Client = None
async def connect(self):
"""Connect to the client and log the connection."""
logger = logging.getLogger(__name__)
logger.info(
"connecting to MQTT at %s:%s with client_id %s",
self._host,
self._port,
self._client_id,
)
self._client.set_auth_credentials(self._username, self._password)
self._client.on_connect = self.on_connect
self._client.on_disconnect = self.on_disconnect
self._client.on_subscribe = self.on_subscribe
self._client.on_message = self.on_message
await self._client.connect(self._host, self._port, self._keepalive)
async def run(self):
"""MQTT async runner."""
async with AsyncExitStack() as ctx:
self._tasks = set()
ctx.push_async_callback(self._cancel_tasks)
self._mqtt = Client(
hostname=self._host,
port=self._port,
username=self._username,
password=self._password,
clean_session=True,
)
await self.enter_async_context(self._mqtt)
async def disconnect(self):
"""Wrapper around client disconnect."""
await self._client.disconnect()
def on_connect(self, client: Client, flags: int, rc: int):
"""Callback method for the client connection."""
logger = logging.getLogger(__name__)
logger.info("client %s connected with result code %s", client, rc)
for sub in self._subscriptions:
client.subscribe(sub, qos=0)
@staticmethod
def on_disconnect(client: Client, packet, exc=None):
"""Callback for disconnections."""
logger = logging.getLogger(__name__)
logger.info("disconnected from broker: %s", client)
@staticmethod
def on_subscribe(client: Client, mid: int, qos: int, properties):
"""Callback for subscriptions."""
logger = logging.getLogger(__name__)
logger.info("Subscribed to topic(s) with mid %s and qos %s", mid, qos)
@staticmethod
def on_message(client: Client, topic, payload: bytes, qos, properties):
"""Callback for message handling."""
logger = logging.getLogger(__name__)
message = payload.decode()
logger.info("Recieved message '%s' on topic %s", message, topic)
async def _cancel_tasks(self):
for task in self._tasks:
if task.done():
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@ -40,11 +40,6 @@ async def main(config: Dict[str, Any]):
messages = await exit_stack.enter_async_context(manager)
tasks.add(asyncio.create_task(handler(messages)))
other_messages = await exit_stack.enter_async_context(
mqtt.unfiltered_messages()
)
tasks.add(asyncio.create_task(dead_letter_handler(other_messages)))
await mqtt.subscribe(config["mqtt"]["subscription"])
await asyncio.gather(*tasks)
@ -61,17 +56,6 @@ async def screen_state_mqtt_handler(messages):
)
async def dead_letter_handler(messages):
"""Logger for uncaught messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"unfiltered message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def cancel_tasks(tasks):
"""Cancel tasks on shutdown."""
for task in tasks:

View File

@ -22,7 +22,7 @@ def mock_env(monkeypatch):
@pytest.fixture
def logging_config() -> Dict[str, Any]:
"""logging configuration fixture."""
"""Logging configuration fixture."""
config = {
"version": 1,
"disable_existing_loggers": False,
@ -69,7 +69,7 @@ def test_read_config():
"port": 1883,
"keepalive": 60,
"subscription": "home/test/#",
"screen_state_topic": "#/presence",
"screen_state_topic": "home/+/presence",
}

View File

@ -103,6 +103,7 @@ include =
.tox/**/hasskiosk/
omit =
**/hasskiosk/_version.py
**/hasskiosk/__main__.py
setup.py
tests/*