add working proof of concept using asyncio-mqtt

This commit is contained in:
Mike Bloy 2021-02-28 23:48:35 -06:00
parent 20b92734e1
commit bed27b83dc
7 changed files with 181 additions and 7 deletions

View File

@ -17,8 +17,8 @@ package_dir =
=src
packages = find:
install_requires =
asyncio_mqtt
environs
paho-mqtt
[options.packages.find]
where=src

View File

@ -0,0 +1,6 @@
"""Main module runner."""
from .runner import run
if __name__ == "__main__":
run()

View File

@ -1,6 +1,8 @@
"""Configuration management from environment."""
import logging.config
import os
import socket
from logging import getLogger
from typing import Any, Dict
@ -46,10 +48,6 @@ def read_config() -> Dict[str, Any]:
},
},
}
with env.prefixed("TOPIC_"):
config["topics"] = {
"presence": env("PRESENCE"),
}
with env.prefixed("MQTT_"):
config["mqtt"] = {
"host": env("HOST"),
@ -57,7 +55,13 @@ def read_config() -> Dict[str, Any]:
"username": env("USERNAME"),
"password": env("PASSWORD"),
"keepalive": env.int("KEEPALIVE", 60),
"subscription": env("SUBSCRIBE_TOPIC"),
"screen_state_topic": env.str("SCREEN_STATE_TOPIC", "home/+/presence"),
}
hostname = socket.gethostname()
pid = os.getpid()
sysname = config["sysname"]
config["client_id"] = f"{sysname}-{hostname}-{pid}"
return config

74
src/hasskiosk/mqtt.py Normal file
View File

@ -0,0 +1,74 @@
"""Manage mqtt connections."""
import asyncio
import logging
from typing import Any, Dict
from gmqtt import Client
class MQTT:
"""MQTT manager. Wrapper around paho.mqtt.client.Client."""
def __init__(self, config: Dict[str, Any]):
"""Init MQTT.
Arguments:
config: a config object returned by config.read_config()
"""
mqtt = config["mqtt"]
self._host = mqtt["host"]
self._port = mqtt["port"]
self._keepalive = mqtt["keepalive"]
self._username = mqtt["username"]
self._password = mqtt["password"]
self._client_id = config["client_id"]
self._topics: Dict[str, Any] = dict()
self._subscriptions = mqtt["subscriptions"]
self._client = Client(client_id=self._client_id, clean_session=True)
async def connect(self):
"""Connect to the client and log the connection."""
logger = logging.getLogger(__name__)
logger.info(
"connecting to MQTT at %s:%s with client_id %s",
self._host,
self._port,
self._client_id,
)
self._client.set_auth_credentials(self._username, self._password)
self._client.on_connect = self.on_connect
self._client.on_disconnect = self.on_disconnect
self._client.on_subscribe = self.on_subscribe
self._client.on_message = self.on_message
await self._client.connect(self._host, self._port, self._keepalive)
async def disconnect(self):
"""Wrapper around client disconnect."""
await self._client.disconnect()
def on_connect(self, client: Client, flags: int, rc: int):
"""Callback method for the client connection."""
logger = logging.getLogger(__name__)
logger.info("client %s connected with result code %s", client, rc)
for sub in self._subscriptions:
client.subscribe(sub, qos=0)
@staticmethod
def on_disconnect(client: Client, packet, exc=None):
"""Callback for disconnections."""
logger = logging.getLogger(__name__)
logger.info("disconnected from broker: %s", client)
@staticmethod
def on_subscribe(client: Client, mid: int, qos: int, properties):
"""Callback for subscriptions."""
logger = logging.getLogger(__name__)
logger.info("Subscribed to topic(s) with mid %s and qos %s", mid, qos)
@staticmethod
def on_message(client: Client, topic, payload: bytes, qos, properties):
"""Callback for message handling."""
logger = logging.getLogger(__name__)
message = payload.decode()
logger.info("Recieved message '%s' on topic %s", message, topic)

84
src/hasskiosk/runner.py Normal file
View File

@ -0,0 +1,84 @@
"""Runner and daemon management."""
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any, Dict, Set
from asyncio_mqtt import Client
from .config import configure_logging, read_config
def run():
"""Run the daemon."""
config = read_config()
configure_logging(config)
asyncio.run(main(config))
async def main(config: Dict[str, Any]):
"""Setup and run the async tasks."""
async with AsyncExitStack() as exit_stack:
tasks: Set[asyncio.Task] = set()
exit_stack.push_async_callback(cancel_tasks, tasks)
mqtt = Client(
hostname=config["mqtt"]["host"],
port=config["mqtt"]["port"],
username=config["mqtt"]["username"],
password=config["mqtt"]["password"],
clean_session=True,
)
await exit_stack.enter_async_context(mqtt)
topic_handlers = (
(config["mqtt"]["screen_state_topic"], screen_state_mqtt_handler),
)
for topic, handler in topic_handlers:
manager = mqtt.filtered_messages(topic)
messages = await exit_stack.enter_async_context(manager)
tasks.add(asyncio.create_task(handler(messages)))
other_messages = await exit_stack.enter_async_context(
mqtt.unfiltered_messages()
)
tasks.add(asyncio.create_task(dead_letter_handler(other_messages)))
await mqtt.subscribe(config["mqtt"]["subscription"])
await asyncio.gather(*tasks)
async def screen_state_mqtt_handler(messages):
"""Screen state handler, reacts on presence messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"screen sate message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def dead_letter_handler(messages):
"""Logger for uncaught messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"unfiltered message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def cancel_tasks(tasks):
"""Cancel tasks on shutdown."""
for task in tasks:
if task.done():
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@ -1,6 +1,8 @@
"""Tests for the configuration management."""
import logging
import os
import socket
from typing import Any, Dict
import pytest
@ -12,7 +14,7 @@ from hasskiosk.config import configure_logging, read_config
@pytest.fixture(autouse=True)
def mock_env(monkeypatch):
"""Environment mock to test values."""
monkeypatch.setenv("HASSKIOSK_TOPIC_PRESENCE", "home/test/presence")
monkeypatch.setenv("HASSKIOSK_MQTT_SUBSCRIBE_TOPIC", "home/test/#")
monkeypatch.setenv("HASSKIOSK_MQTT_HOST", "ha.example.com")
monkeypatch.setenv("HASSKIOSK_MQTT_USERNAME", "testymctesterson")
monkeypatch.setenv("HASSKIOSK_MQTT_PASSWORD", "hunter2")
@ -55,14 +57,19 @@ def logging_config() -> Dict[str, Any]:
def test_read_config():
"""Test the read_config function."""
config = read_config()
hostname = socket.gethostname()
pid = os.getpid()
assert config["version"] == __version__
assert config["sysname"] == "hasskiosk"
assert config["client_id"] == f"hasskiosk-{hostname}-{pid}"
assert config["mqtt"] == {
"host": "ha.example.com",
"username": "testymctesterson",
"password": "hunter2",
"port": 1883,
"keepalive": 60,
"subscription": "home/test/#",
"screen_state_topic": "#/presence",
}

View File

@ -54,7 +54,6 @@ commands =
recreate=True
deps =
bpython
mypy
python-language-server
rope
{[testenv:py37]deps}