add working proof of concept using asyncio-mqtt

This commit is contained in:
Mike Bloy 2021-02-28 23:48:35 -06:00
parent 20b92734e1
commit bed27b83dc
7 changed files with 181 additions and 7 deletions

View File

@ -17,8 +17,8 @@ package_dir =
=src =src
packages = find: packages = find:
install_requires = install_requires =
asyncio_mqtt
environs environs
paho-mqtt
[options.packages.find] [options.packages.find]
where=src where=src

View File

@ -0,0 +1,6 @@
"""Main module runner."""
from .runner import run
if __name__ == "__main__":
run()

View File

@ -1,6 +1,8 @@
"""Configuration management from environment.""" """Configuration management from environment."""
import logging.config import logging.config
import os
import socket
from logging import getLogger from logging import getLogger
from typing import Any, Dict from typing import Any, Dict
@ -46,10 +48,6 @@ def read_config() -> Dict[str, Any]:
}, },
}, },
} }
with env.prefixed("TOPIC_"):
config["topics"] = {
"presence": env("PRESENCE"),
}
with env.prefixed("MQTT_"): with env.prefixed("MQTT_"):
config["mqtt"] = { config["mqtt"] = {
"host": env("HOST"), "host": env("HOST"),
@ -57,7 +55,13 @@ def read_config() -> Dict[str, Any]:
"username": env("USERNAME"), "username": env("USERNAME"),
"password": env("PASSWORD"), "password": env("PASSWORD"),
"keepalive": env.int("KEEPALIVE", 60), "keepalive": env.int("KEEPALIVE", 60),
"subscription": env("SUBSCRIBE_TOPIC"),
"screen_state_topic": env.str("SCREEN_STATE_TOPIC", "home/+/presence"),
} }
hostname = socket.gethostname()
pid = os.getpid()
sysname = config["sysname"]
config["client_id"] = f"{sysname}-{hostname}-{pid}"
return config return config

74
src/hasskiosk/mqtt.py Normal file
View File

@ -0,0 +1,74 @@
"""Manage mqtt connections."""
import asyncio
import logging
from typing import Any, Dict
from gmqtt import Client
class MQTT:
"""MQTT manager. Wrapper around paho.mqtt.client.Client."""
def __init__(self, config: Dict[str, Any]):
"""Init MQTT.
Arguments:
config: a config object returned by config.read_config()
"""
mqtt = config["mqtt"]
self._host = mqtt["host"]
self._port = mqtt["port"]
self._keepalive = mqtt["keepalive"]
self._username = mqtt["username"]
self._password = mqtt["password"]
self._client_id = config["client_id"]
self._topics: Dict[str, Any] = dict()
self._subscriptions = mqtt["subscriptions"]
self._client = Client(client_id=self._client_id, clean_session=True)
async def connect(self):
"""Connect to the client and log the connection."""
logger = logging.getLogger(__name__)
logger.info(
"connecting to MQTT at %s:%s with client_id %s",
self._host,
self._port,
self._client_id,
)
self._client.set_auth_credentials(self._username, self._password)
self._client.on_connect = self.on_connect
self._client.on_disconnect = self.on_disconnect
self._client.on_subscribe = self.on_subscribe
self._client.on_message = self.on_message
await self._client.connect(self._host, self._port, self._keepalive)
async def disconnect(self):
"""Wrapper around client disconnect."""
await self._client.disconnect()
def on_connect(self, client: Client, flags: int, rc: int):
"""Callback method for the client connection."""
logger = logging.getLogger(__name__)
logger.info("client %s connected with result code %s", client, rc)
for sub in self._subscriptions:
client.subscribe(sub, qos=0)
@staticmethod
def on_disconnect(client: Client, packet, exc=None):
"""Callback for disconnections."""
logger = logging.getLogger(__name__)
logger.info("disconnected from broker: %s", client)
@staticmethod
def on_subscribe(client: Client, mid: int, qos: int, properties):
"""Callback for subscriptions."""
logger = logging.getLogger(__name__)
logger.info("Subscribed to topic(s) with mid %s and qos %s", mid, qos)
@staticmethod
def on_message(client: Client, topic, payload: bytes, qos, properties):
"""Callback for message handling."""
logger = logging.getLogger(__name__)
message = payload.decode()
logger.info("Recieved message '%s' on topic %s", message, topic)

84
src/hasskiosk/runner.py Normal file
View File

@ -0,0 +1,84 @@
"""Runner and daemon management."""
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any, Dict, Set
from asyncio_mqtt import Client
from .config import configure_logging, read_config
def run():
"""Run the daemon."""
config = read_config()
configure_logging(config)
asyncio.run(main(config))
async def main(config: Dict[str, Any]):
"""Setup and run the async tasks."""
async with AsyncExitStack() as exit_stack:
tasks: Set[asyncio.Task] = set()
exit_stack.push_async_callback(cancel_tasks, tasks)
mqtt = Client(
hostname=config["mqtt"]["host"],
port=config["mqtt"]["port"],
username=config["mqtt"]["username"],
password=config["mqtt"]["password"],
clean_session=True,
)
await exit_stack.enter_async_context(mqtt)
topic_handlers = (
(config["mqtt"]["screen_state_topic"], screen_state_mqtt_handler),
)
for topic, handler in topic_handlers:
manager = mqtt.filtered_messages(topic)
messages = await exit_stack.enter_async_context(manager)
tasks.add(asyncio.create_task(handler(messages)))
other_messages = await exit_stack.enter_async_context(
mqtt.unfiltered_messages()
)
tasks.add(asyncio.create_task(dead_letter_handler(other_messages)))
await mqtt.subscribe(config["mqtt"]["subscription"])
await asyncio.gather(*tasks)
async def screen_state_mqtt_handler(messages):
"""Screen state handler, reacts on presence messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"screen sate message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def dead_letter_handler(messages):
"""Logger for uncaught messages."""
log = logging.getLogger(__name__)
async for message in messages:
log.info(
"unfiltered message on topic %s: %s",
message.topic,
message.payload.decode(),
)
async def cancel_tasks(tasks):
"""Cancel tasks on shutdown."""
for task in tasks:
if task.done():
continue
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

View File

@ -1,6 +1,8 @@
"""Tests for the configuration management.""" """Tests for the configuration management."""
import logging import logging
import os
import socket
from typing import Any, Dict from typing import Any, Dict
import pytest import pytest
@ -12,7 +14,7 @@ from hasskiosk.config import configure_logging, read_config
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mock_env(monkeypatch): def mock_env(monkeypatch):
"""Environment mock to test values.""" """Environment mock to test values."""
monkeypatch.setenv("HASSKIOSK_TOPIC_PRESENCE", "home/test/presence") monkeypatch.setenv("HASSKIOSK_MQTT_SUBSCRIBE_TOPIC", "home/test/#")
monkeypatch.setenv("HASSKIOSK_MQTT_HOST", "ha.example.com") monkeypatch.setenv("HASSKIOSK_MQTT_HOST", "ha.example.com")
monkeypatch.setenv("HASSKIOSK_MQTT_USERNAME", "testymctesterson") monkeypatch.setenv("HASSKIOSK_MQTT_USERNAME", "testymctesterson")
monkeypatch.setenv("HASSKIOSK_MQTT_PASSWORD", "hunter2") monkeypatch.setenv("HASSKIOSK_MQTT_PASSWORD", "hunter2")
@ -55,14 +57,19 @@ def logging_config() -> Dict[str, Any]:
def test_read_config(): def test_read_config():
"""Test the read_config function.""" """Test the read_config function."""
config = read_config() config = read_config()
hostname = socket.gethostname()
pid = os.getpid()
assert config["version"] == __version__ assert config["version"] == __version__
assert config["sysname"] == "hasskiosk" assert config["sysname"] == "hasskiosk"
assert config["client_id"] == f"hasskiosk-{hostname}-{pid}"
assert config["mqtt"] == { assert config["mqtt"] == {
"host": "ha.example.com", "host": "ha.example.com",
"username": "testymctesterson", "username": "testymctesterson",
"password": "hunter2", "password": "hunter2",
"port": 1883, "port": 1883,
"keepalive": 60, "keepalive": 60,
"subscription": "home/test/#",
"screen_state_topic": "#/presence",
} }

View File

@ -54,7 +54,6 @@ commands =
recreate=True recreate=True
deps = deps =
bpython bpython
mypy
python-language-server python-language-server
rope rope
{[testenv:py37]deps} {[testenv:py37]deps}