Source code for mqttools.broker

import asyncio
import logging
import re
import threading
from collections import defaultdict

from .common import CF_FIXED_HEADER
from .common import MAXIMUM_PACKET_SIZE
from .common import ConnectReasonCode
from .common import ControlPacketType
from .common import DisconnectReasonCode
from .common import MalformedPacketError
from .common import PayloadReader
from .common import PropertyIds
from .common import SubackReasonCode
from .common import UnsubackReasonCode
from .common import format_packet
from .common import format_packet_compact
from .common import pack_connack
from .common import pack_disconnect
from .common import pack_pingresp
from .common import pack_publish
from .common import pack_suback
from .common import pack_unsuback
from .common import unpack_connect
from .common import unpack_disconnect
from .common import unpack_publish
from .common import unpack_subscribe
from .common import unpack_unsubscribe

LOGGER = logging.getLogger(__name__)


# ToDo: Session expiry. Keep alive.


class ConnectError(Exception):
    pass


class DisconnectError(Exception):
    pass


class ProtocolError(Exception):
    pass


class NotRunningError(Exception):
    pass


class Session(object):

    def __init__(self, client_id):
        self.client_id = client_id
        self.subscriptions = set()
        self.wildcard_subscriptions = set()
        self.expiry_interval = 0
        self.client = None
        self.maximum_packet_size = MAXIMUM_PACKET_SIZE
        self.will_topic = None
        self.will_message = None

    def clean(self):
        self.subscriptions = set()
        self.wildcard_subscriptions = set()
        self.expiry_interval = 0
        self.client = None
        self.maximum_packet_size = MAXIMUM_PACKET_SIZE
        self.will_topic = None
        self.will_message = None


def is_wildcards_in_topic(topic):
    return '#' in topic or '+' in topic


def compile_wildcards_topic(topic):
    pattern = topic.replace('+', '[^/]*')
    pattern = pattern.replace('/#', '.*')
    pattern = pattern.replace('#', '.*')
    pattern = '^' + pattern + '$'

    return re.compile(pattern)


class Client(object):

    def __init__(self, broker, reader, writer):
        self._broker = broker
        self._reader = reader
        self._writer = writer
        self._session = None
        self._disconnect_reason = DisconnectReasonCode.UNSPECIFIED_ERROR

    async def serve_forever(self):
        addr = self._writer.get_extra_info('peername')

        self.log_info('Serving client %s:%d.', addr[0], addr[1])

        try:
            packet_type, _, payload = await self.read_packet()

            if packet_type == ControlPacketType.CONNECT:
                self.on_connect(payload)
            else:
                raise ConnectError()

            await self.reader_loop()
        except ConnectError:
            pass
        except DisconnectError:
            self._disconnect_reason = DisconnectReasonCode.NORMAL_DISCONNECTION
        except asyncio.IncompleteReadError:
            self.log_debug('Client connection lost.')
        except Exception as e:
            self.log_debug('Reader task stopped by %r.', e)

            if isinstance(e, MalformedPacketError):
                self._disconnect_reason = DisconnectReasonCode.MALFORMED_PACKET
            elif isinstance(e, ProtocolError):
                self._disconnect_reason = DisconnectReasonCode.PROTOCOL_ERROR

            if self._session is not None:
                self.disconnect()

        if self._session is not None:
            self._session.client = None

            if self._session.will_topic is not None:
                if not self.is_normal_disconnection():
                    self._broker.publish(self._session.will_topic,
                                         self._session.will_message,
                                         {})

                    if self._session.will_retain:
                        self._broker.add_retained_message(self._session.will_topic,
                                                          self._session.will_message)

            if self._session.expiry_interval == 0:
                self._broker.remove_session(self._session.client_id)

        self.log_info('Closing client %r.', addr)

    async def reader_loop(self):
        while True:
            packet_type, flags, payload = await self.read_packet()

            if packet_type == ControlPacketType.PUBLISH:
                self.on_publish(payload, flags)
            elif packet_type == ControlPacketType.SUBSCRIBE:
                self.on_subscribe(payload)
            elif packet_type == ControlPacketType.UNSUBSCRIBE:
                self.on_unsubscribe(payload)
            elif packet_type == ControlPacketType.PINGREQ:
                self.on_pingreq()
            elif packet_type == ControlPacketType.DISCONNECT:
                self.on_disconnect(payload)
            else:
                raise ProtocolError()

    async def read_packet(self):
        buf = await self._reader.readexactly(1)
        packet_type, flags = CF_FIXED_HEADER.unpack(buf)
        size = 0
        multiplier = 1
        byte = 0x80

        while (byte & 0x80) == 0x80:
            buf += await self._reader.readexactly(1)
            byte = buf[-1]
            size += ((byte & 0x7f) * multiplier)
            multiplier <<= 7

        data = await self._reader.readexactly(size)

        if LOGGER.isEnabledFor(logging.DEBUG):
            for line in format_packet('Received', buf + data):
                self.log_debug(line)
        elif LOGGER.isEnabledFor(logging.INFO):
            self.log_info(format_packet_compact('Received', buf + data))

        return packet_type, flags, PayloadReader(data)

    def on_connect(self, payload):
        (client_id,
         clean_start,
         will_topic,
         will_message,
         will_retain,
         keep_alive_s,
         properties,
         username,
         password) = unpack_connect(payload)
        self._session, session_present = self._broker.get_session(
            client_id,
            clean_start)

        if session_present:
            self.log_info(
                'Session resumed with %d simple and %d wildcard '
                'subscriptions.',
                len(self._session.subscriptions),
                len(self._session.wildcard_subscriptions))

        self._session.client = self
        reason = ConnectReasonCode.SUCCESS

        if PropertyIds.AUTHENTICATION_METHOD in properties:
            reason = ConnectReasonCode.BAD_AUTHENTICATION_METHOD

        if PropertyIds.MAXIMUM_PACKET_SIZE in properties:
            maximum_packet_size = properties[PropertyIds.MAXIMUM_PACKET_SIZE]
            self._session.maximum_packet_size = maximum_packet_size

        self._session.will_topic = will_topic
        self._session.will_message = will_message
        self._session.will_retain = will_retain

        if PropertyIds.SESSION_EXPIRY_INTERVAL in properties:
            session_expiry_interval = properties[PropertyIds.SESSION_EXPIRY_INTERVAL]
            self._session.expiry_interval = session_expiry_interval

        if (username is not None) or (password is not None):
            reason = ConnectReasonCode.BAD_USER_NAME_OR_PASSWORD

        self._write_packet(pack_connack(
            session_present,
            reason,
            {
                PropertyIds.MAXIMUM_QOS: 0,
                PropertyIds.WILDCARD_SUBSCRIPTION_AVAILABLE: 0,
                PropertyIds.SHARED_SUBSCRIPTION_AVAILABLE: 0
            }))

        if reason != ConnectReasonCode.SUCCESS:
            raise ConnectError()

        self.log_info('Client connected.')

    def on_publish(self, payload, flags):
        topic, message, properties = unpack_publish(payload, (flags >> 1) & 3)

        if is_wildcards_in_topic(topic):
            raise MalformedPacketError(f'Invalid topic {topic} in publish.')

        if flags & 1:
            if message:
                self._broker.add_retained_message(topic, message)
            else:
                self._broker.remove_retained_message(topic)

        self._broker.publish(topic, message, properties)

    def on_subscribe(self, payload):
        packet_identifier, _, subscriptions = unpack_subscribe(payload)
        reasons = bytearray()
        retained_messages = []

        for topic, _ in subscriptions:
            if is_wildcards_in_topic(topic):
                if topic not in self._session.wildcard_subscriptions:
                    self._session.wildcard_subscriptions.add(topic)
                    self._broker.add_wildcard_subscriber(topic, self._session)

                retained_messages += list(
                    self._broker.find_retained_messages_wildcards(topic))
            else:
                if topic not in self._session.subscriptions:
                    self._session.subscriptions.add(topic)
                    self._broker.add_subscriber(topic, self._session)

                retained_message = self._broker.find_retained_message(topic)

                if retained_message:
                    retained_messages.append(retained_message)

            reason = SubackReasonCode.GRANTED_QOS_0
            reasons.append(reason)

        self._write_packet(pack_suback(packet_identifier, reasons))

        for topic, message in retained_messages:
            self.publish(topic, message, True, {})

    def on_unsubscribe(self, payload):
        packet_identifier, topics = unpack_unsubscribe(payload)
        reasons = bytearray()

        for topic in topics:
            reason = UnsubackReasonCode.NO_SUBSCRIPTION_EXISTED

            if is_wildcards_in_topic(topic):
                if topic in self._session.wildcard_subscriptions:
                    self._session.wildcard_subscriptions.remove(topic)
                    self._broker.remove_wildcard_subscriber(topic, self._session)
                    reason = UnsubackReasonCode.SUCCESS
            elif topic in self._session.subscriptions:
                self._session.subscriptions.remove(topic)
                self._broker.remove_subscriber(topic, self._session)
                reason = UnsubackReasonCode.SUCCESS

            reasons.append(reason)

        self._write_packet(pack_unsuback(packet_identifier, reasons))

    def on_pingreq(self):
        self._write_packet(pack_pingresp())

    def on_disconnect(self, payload):
        unpack_disconnect(payload)

        raise DisconnectError()

    def publish(self, topic, message, retain, properties):
        self._write_packet(pack_publish(topic, message, retain, properties))

    def disconnect(self):
        self._write_packet(pack_disconnect(self._disconnect_reason))

    def _send_prefix(self, message):
        if len(message) <= self._session.maximum_packet_size:
            return 'Sending'
        else:
            return 'Not sending'

    def _write_packet(self, message):
        if LOGGER.isEnabledFor(logging.DEBUG):
            for line in format_packet(self._send_prefix(message), message):
                self.log_debug(line)
        elif LOGGER.isEnabledFor(logging.INFO):
            self.log_info(format_packet_compact(self._send_prefix(message),
                                                message))

        if len(message) <= self._session.maximum_packet_size:
            self._writer.write(message)

    def log_debug(self, fmt, *args):
        if LOGGER.isEnabledFor(logging.DEBUG):
            if self._session is None:
                LOGGER.debug(fmt, *args)
            else:
                LOGGER.debug(f'{self._session.client_id} {fmt}', *args)

    def log_info(self, fmt, *args):
        if LOGGER.isEnabledFor(logging.INFO):
            if self._session is None:
                LOGGER.info(fmt, *args)
            else:
                LOGGER.info(f'{self._session.client_id} {fmt}', *args)

    def is_normal_disconnection(self):
        return self._disconnect_reason == DisconnectReasonCode.NORMAL_DISCONNECTION


class Server:

    def __init__(self, serve_client, address):
        self.serve_client = serve_client

        if len(address) == 3:
            self._ssl = address[2]
        else:
            self._ssl = None

        self._host = address[0]
        self._port = address[1]
        self.server = None
        self.ready = asyncio.Event()

    async def serve_forever(self):
        try:
            self.server = await asyncio.start_server(self.serve_client,
                                                     self._host,
                                                     self._port,
                                                     ssl=self._ssl)
        except OSError as e:
            LOGGER.warning('%s', e)
            raise

        self.ready.set()
        server_address = self.server.sockets[0].getsockname()

        LOGGER.info('Listening for clients on %s.', server_address)

        async with self.server:
            await self.server.serve_forever()


[docs]class Broker(object): """A limited MQTT version 5.0 broker. `addresses` is a list of ``(host, port)`` and ``(host, port, ssl)`` tuples. It may also be the host string or one of the tuples. The broker will listen for clients on all given addresses. ``ssl`` is an SSL context passed to `asyncio.start_server()` as `ssl`. Create a broker and serve clients: >>> broker = Broker('localhost') >>> await broker.serve_forever() """ def __init__(self, addresses): if isinstance(addresses, str): addresses = (addresses, 1883) if isinstance(addresses, tuple): addresses = [addresses] self._sessions = {} self._subscribers = defaultdict(list) self._wildcard_subscribers = [] self._servers = [] for address in addresses: self._servers.append(Server(self.serve_client, address)) self._client_tasks = set() self._retained_messages = {} async def getsockname(self, index=0): server = self._servers[index] await server.ready.wait() return server.server.sockets[0].getsockname()
[docs] async def serve_forever(self): """Setup a listener socket and forever serve clients. This coroutine only ends if cancelled by the user. """ try: await asyncio.gather( *[server.serve_forever() for server in self._servers]) except asyncio.CancelledError: # Cancel all client tasks as the TCP server leaves them # running. for client_task in self._client_tasks: client_task.cancel() self._client_tasks = set() raise
async def serve_client(self, reader, writer): current_task = asyncio.current_task() self._client_tasks.add(current_task) client = Client(self, reader, writer) try: await client.serve_forever() finally: try: self._client_tasks.remove(current_task) except KeyError: pass def add_subscriber(self, topic, session): topic_sessions = self._subscribers[topic] if session not in topic_sessions: topic_sessions.append(session) def remove_subscriber(self, topic, session): topic_sessions = self._subscribers[topic] if session in topic_sessions: del topic_sessions[topic_sessions.index(session)] def add_wildcard_subscriber(self, topic, session): re_topic = compile_wildcards_topic(topic) self._wildcard_subscribers.append((topic, session, re_topic)) def remove_wildcard_subscriber(self, topic, session): for index, subscriber in enumerate(self._wildcard_subscribers): if topic == subscriber[0] and session == subscriber[1]: del self._wildcard_subscribers[index] break def add_retained_message(self, topic, message): self._retained_messages[topic] = message def remove_retained_message(self, topic): try: del self._retained_messages[topic] except KeyError: pass def find_retained_messages_wildcards(self, topic): re_topic = compile_wildcards_topic(topic) for topic in self._retained_messages: mo = re_topic.match(topic) if mo: yield (topic, self._retained_messages[topic]) def find_retained_message(self, topic): if topic in self._retained_messages: return (topic, self._retained_messages[topic]) else: return None def iter_subscribers(self, topic): for session in self._subscribers[topic]: if session.client is not None: yield session for _, session, re_topic in self._wildcard_subscribers: if session.client is not None: mo = re_topic.match(topic) if mo: yield session def get_session(self, client_id, clean_start): session_present = False if client_id in self._sessions: session = self._sessions[client_id] if clean_start: for topic in session.subscriptions: self.remove_subscriber(topic, session) for topic in session.wildcard_subscriptions: self.remove_wildcard_subscriber(topic, session) session.clean() else: session_present = True else: session = Session(client_id) self._sessions[client_id] = session return session, session_present def remove_session(self, client_id): del self._sessions[client_id] def publish(self, topic, message, properties): """Publish given topic and message to all subscribers. """ for session in self.iter_subscribers(topic): session.client.publish(topic, message, False, properties)
[docs]class BrokerThread(threading.Thread): """The same as :class:`Broker`, but running in a thread. Create a broker and serve clients for 60 seconds: >>> broker = BrokerThread('broker.hivemq.com') >>> broker.start() >>> time.sleep(60) >>> broker.stop() """ def __init__(self, addresses): super().__init__() self._addresses = addresses self.daemon = True self._loop = asyncio.new_event_loop() self._broker_task = self._loop.create_task(self._run()) self._running = False def run(self): asyncio.set_event_loop(self._loop) self._running = True try: self._loop.run_until_complete(self._broker_task) finally: self._loop.close()
[docs] def stop(self): """Stop the broker. All clients will be disconnected and the thread will terminate. """ if not self._running: raise NotRunningError('The broker is already stopped.') self._running = False def cancel_broker_task(): self._broker_task.cancel() self._loop.call_soon_threadsafe(cancel_broker_task) self.join()
async def _run(self): broker = Broker(self._addresses) try: await broker.serve_forever() except asyncio.CancelledError: pass