create basic handler

This commit is contained in:
Mike Bloy 2021-03-07 13:47:30 -06:00
parent 0063951aeb
commit 0a29d2bad3
6 changed files with 49 additions and 80 deletions

View File

@ -22,3 +22,7 @@ install_requires =
[options.packages.find] [options.packages.find]
where=src where=src
[options.entry_points]
console_scripts =
hasskiosk = hasskiosk:run

View File

@ -1,7 +1,9 @@
"""Home Assistant Kiosk.""" """Home Assistant Kiosk."""
from ._version import version as __version__ from ._version import version as __version__
from .runner import run
__all__ = [ __all__ = [
"run",
"__version__", "__version__",
] ]

View File

@ -1,74 +1,52 @@
"""Manage mqtt connections.""" """Manage mqtt connections."""
import asyncio import asyncio
import logging from contextlib import AsyncExitStack
from typing import Any, Dict from typing import Set
from gmqtt import Client from asyncio_mqtt import Client
class MQTT: class MQTTManager:
"""MQTT manager. Wrapper around paho.mqtt.client.Client.""" """MQTT manager class."""
def __init__(self, config: Dict[str, Any]): def __init__(self, hostname: str, port: int, username: str, password: str):
"""Init MQTT. """Initialize with the following data.
Arguments: Arguments:
config: a config object returned by config.read_config() hostname: MQTT host to connect to
port: port to use for connecting
username: authentication username
password: authentication password
""" """
mqtt = config["mqtt"] super().__init__()
self._host = mqtt["host"] self._hostname = hostname
self._port = mqtt["port"] self._port = port
self._keepalive = mqtt["keepalive"] self._username = username
self._username = mqtt["username"] self._password = password
self._password = mqtt["password"] self._tasks: Set[asyncio.Task] = set()
self._client_id = config["client_id"] self._mqtt: Client = None
self._topics: Dict[str, Any] = dict()
self._subscriptions = mqtt["subscriptions"]
self._client = Client(client_id=self._client_id, clean_session=True)
async def connect(self): async def run(self):
"""Connect to the client and log the connection.""" """MQTT async runner."""
logger = logging.getLogger(__name__) async with AsyncExitStack() as ctx:
logger.info( self._tasks = set()
"connecting to MQTT at %s:%s with client_id %s", ctx.push_async_callback(self._cancel_tasks)
self._host, self._mqtt = Client(
self._port, hostname=self._host,
self._client_id, port=self._port,
username=self._username,
password=self._password,
clean_session=True,
) )
self._client.set_auth_credentials(self._username, self._password) await self.enter_async_context(self._mqtt)
self._client.on_connect = self.on_connect
self._client.on_disconnect = self.on_disconnect
self._client.on_subscribe = self.on_subscribe
self._client.on_message = self.on_message
await self._client.connect(self._host, self._port, self._keepalive)
async def disconnect(self): async def _cancel_tasks(self):
"""Wrapper around client disconnect.""" for task in self._tasks:
await self._client.disconnect() if task.done():
continue
def on_connect(self, client: Client, flags: int, rc: int): task.cancel()
"""Callback method for the client connection.""" try:
logger = logging.getLogger(__name__) await task
logger.info("client %s connected with result code %s", client, rc) except asyncio.CancelledError:
for sub in self._subscriptions: pass
client.subscribe(sub, qos=0)
@staticmethod
def on_disconnect(client: Client, packet, exc=None):
"""Callback for disconnections."""
logger = logging.getLogger(__name__)
logger.info("disconnected from broker: %s", client)
@staticmethod
def on_subscribe(client: Client, mid: int, qos: int, properties):
"""Callback for subscriptions."""
logger = logging.getLogger(__name__)
logger.info("Subscribed to topic(s) with mid %s and qos %s", mid, qos)
@staticmethod
def on_message(client: Client, topic, payload: bytes, qos, properties):
"""Callback for message handling."""
logger = logging.getLogger(__name__)
message = payload.decode()
logger.info("Recieved message '%s' on topic %s", message, topic)

View File

@ -40,11 +40,6 @@ async def main(config: Dict[str, Any]):
messages = await exit_stack.enter_async_context(manager) messages = await exit_stack.enter_async_context(manager)
tasks.add(asyncio.create_task(handler(messages))) tasks.add(asyncio.create_task(handler(messages)))
other_messages = await exit_stack.enter_async_context(
mqtt.unfiltered_messages()
)
tasks.add(asyncio.create_task(dead_letter_handler(other_messages)))
await mqtt.subscribe(config["mqtt"]["subscription"]) await mqtt.subscribe(config["mqtt"]["subscription"])
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
@ -61,17 +56,6 @@ async def screen_state_mqtt_handler(messages):
) )
async def dead_letter_handler(messages):
"""Logger for uncaught messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"unfiltered message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def cancel_tasks(tasks): async def cancel_tasks(tasks):
"""Cancel tasks on shutdown.""" """Cancel tasks on shutdown."""
for task in tasks: for task in tasks:

View File

@ -22,7 +22,7 @@ def mock_env(monkeypatch):
@pytest.fixture @pytest.fixture
def logging_config() -> Dict[str, Any]: def logging_config() -> Dict[str, Any]:
"""logging configuration fixture.""" """Logging configuration fixture."""
config = { config = {
"version": 1, "version": 1,
"disable_existing_loggers": False, "disable_existing_loggers": False,
@ -69,7 +69,7 @@ def test_read_config():
"port": 1883, "port": 1883,
"keepalive": 60, "keepalive": 60,
"subscription": "home/test/#", "subscription": "home/test/#",
"screen_state_topic": "#/presence", "screen_state_topic": "home/+/presence",
} }

View File

@ -103,6 +103,7 @@ include =
.tox/**/hasskiosk/ .tox/**/hasskiosk/
omit = omit =
**/hasskiosk/_version.py **/hasskiosk/_version.py
**/hasskiosk/__main__.py
setup.py setup.py
tests/* tests/*