Add Plugin system

This commit is contained in:
Patrick Neff 2020-08-06 17:16:43 +02:00
parent 13f1c93778
commit 46681e1a6c
11 changed files with 298 additions and 80 deletions

View File

@ -1,2 +1,2 @@
[settings] [settings]
known_third_party = appdirs,click,markdown,nio,setuptools known_third_party = aiohttp,appdirs,click,markdown,nio,setuptools

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import getpass
import json import json
import logging import logging
import os import os
@ -14,6 +15,7 @@ from nio import (AsyncClient, AsyncClientConfig, InviteEvent,
RoomMessageText, SyncResponse, ToDeviceError) RoomMessageText, SyncResponse, ToDeviceError)
from .config import Config from .config import Config
from .plugins import all_plugins
from .utils import setup_logger from .utils import setup_logger
# from .message import TextMessage # from .message import TextMessage
@ -21,14 +23,18 @@ from .utils import setup_logger
class Bot(object): class Bot(object):
def __init__(self) -> None: def __init__(self) -> None:
self.__client = None
self.__logger = setup_logger(__name__) self.__logger = setup_logger(__name__)
self.logger.debug('Initializing Bot') self.logger.debug('Initializing Bot')
self.__loop = asyncio.get_running_loop() self.__loop = asyncio.get_running_loop()
self.__first_sync = True self.__first_sync = True
self.__command_plugins = {}
self.__message_plugins = {}
self.__runtime_plugins = {}
for s in (signal.SIGINT, signal.SIGTERM): for s in (signal.SIGINT, signal.SIGTERM):
self.__loop.add_signal_handler(s, self.__signal_handler) self.__loop.add_signal_handler(s, self.__signal_handler)
async def login(self) -> AsyncClient: async def login(self) -> (AsyncClient, None):
"""Login to the matrix homeserver defined in the config file. """Login to the matrix homeserver defined in the config file.
""" """
self.logger.debug('Starting login process') self.logger.debug('Starting login process')
@ -51,22 +57,27 @@ class Bot(object):
if not os.path.exists(Config.STORE_PATH): if not os.path.exists(Config.STORE_PATH):
os.makedirs(Config.STORE_PATH) os.makedirs(Config.STORE_PATH)
credentials = self._ask_credentials() credentials = self.__ask_credentials()
# Initialize the matrix client # Initialize the matrix client
client = AsyncClient( self.__client = AsyncClient(
credentials['homeserver'], credentials['homeserver'],
credentials['user_id'], credentials['user_id'],
store_path=Config.STORE_PATH, store_path=Config.STORE_PATH,
config=self.client_config, config=self.client_config,
) )
pw = click.prompt(click.style('Your Password', bold=True), pw = None
hide_input=True) try:
pw = getpass.getpass(click.style('Your Password: ', bold=True))
resp = await client.login(password=pw, except EOFError:
device_name=credentials['device_name']) print()
await self.shutdown()
return None
finally:
if pw:
resp = await self.client.login(
password=pw, device_name=credentials['device_name'])
del pw del pw
# check that we logged in succesfully # check that we logged in succesfully
@ -100,37 +111,37 @@ class Bot(object):
with open(Config.CONFIG_FILE, 'r') as f: with open(Config.CONFIG_FILE, 'r') as f:
self.__config = json.load(f) self.__config = json.load(f)
# Initialize the matrix client based on credentials from file # Initialize the matrix client based on credentials from file
client = AsyncClient( self.__client = AsyncClient(
self.config['homeserver'], self.config['credentials']['homeserver'],
self.config['user_id'], self.config['credentials']['user_id'],
device_id=self.config['device_id'], device_id=self.config['credentials']['device_id'],
store_path=Config.STORE_PATH, store_path=Config.STORE_PATH,
config=self.client_config, config=self.client_config,
) )
client.restore_login( self.__client.restore_login(
user_id=self.config['user_id'], user_id=self.config['credentials']['user_id'],
device_id=self.config['device_id'], device_id=self.config['credentials']['device_id'],
access_token=self.config['access_token'], access_token=self.config['credentials']['access_token'],
) )
self.logger.debug('Logged in using stored credentials.') self.logger.debug('Logged in using stored credentials.')
self.__client = client return self.__client
return client async def __upload_keys(self) -> None:
async def sync(self) -> None:
self.logger.debug('Starting sync')
next_batch = self.__read_next_batch()
if self.client.should_upload_keys:
await self.client.keys_upload() await self.client.keys_upload()
if self.client.should_query_keys: if self.client.should_query_keys:
await self.client.keys_query() await self.client.keys_query()
if self.client.should_claim_keys: if self.client.should_claim_keys:
await self.client.keys_claim(self.get_users_for_key_claiming()) await self.client.keys_claim(
self.client.get_users_for_key_claiming())
async def sync(self) -> None:
self.logger.debug('Starting sync')
next_batch = self.__read_next_batch()
self.__upload_keys()
await self.client.sync(timeout=30000, await self.client.sync(timeout=30000,
full_state=True, full_state=True,
@ -146,13 +157,11 @@ class Bot(object):
# Set up event callbacks # Set up event callbacks
client = await self.login() client = await self.login()
if getattr(self, 'client', None):
self.logger.debug('Adding callbacks') self.logger.debug('Adding callbacks')
client.add_to_device_callback(self.__to_device_callback, client.add_to_device_callback(self.__to_device_callback,
(KeyVerificationEvent, )) (KeyVerificationEvent, ))
# Sync encryption keys with the server self.__upload_keys()
# Required for participating in encrypted rooms
if self.client.should_upload_keys:
await self.client.keys_upload()
click.secho('\nStarting verification process...', click.secho('\nStarting verification process...',
bold=True, bold=True,
fg='green') fg='green')
@ -168,12 +177,17 @@ class Bot(object):
async def run(self) -> None: async def run(self) -> None:
await self.login() await self.login()
if self.__client:
self.client.add_response_callback(self.__sync_callback, self.client.add_response_callback(self.__sync_callback,
(SyncResponse, )) (SyncResponse, ))
self.client.add_event_callback(self.__message_callback, self.client.add_event_callback(self.__message_callback,
(RoomMessage, )) (RoomMessage, ))
self.client.add_event_callback(self.__invite_callback, (InviteEvent, )) self.client.add_event_callback(self.__invite_callback,
(InviteEvent, ))
self.__load_plugins()
for plugin in self.runtime_plugins.keys():
await self.runtime_plugins[plugin].on_run()
await self.sync_forever() await self.sync_forever()
async def find_room_by_id(self, room_id: str) -> (MatrixRoom, None): async def find_room_by_id(self, room_id: str) -> (MatrixRoom, None):
@ -182,7 +196,7 @@ class Bot(object):
return self.client.rooms[room_id] return self.client.rooms[room_id]
return None return None
def _ask_credentials(self) -> dict: def __ask_credentials(self) -> dict:
"""Ask the user for credentials """Ask the user for credentials
""" """
try: try:
@ -195,7 +209,7 @@ class Bot(object):
if not homeserver.startswith('https://'): if not homeserver.startswith('https://'):
homeserver = 'https://' + homeserver homeserver = 'https://' + homeserver
user_id = '@user:gaja-group.com' user_id = f'@{getpass.getuser()}:gaja-group.com'
user_id = click.prompt(click.style('Enter your full user ID', user_id = click.prompt(click.style('Enter your full user ID',
bold=True), bold=True),
default=user_id) default=user_id)
@ -207,7 +221,7 @@ class Bot(object):
) )
except click.exceptions.Abort: except click.exceptions.Abort:
sys.exit(0) sys.exit(1)
return { return {
'homeserver': homeserver, 'homeserver': homeserver,
@ -412,6 +426,7 @@ class Bot(object):
self.logger.info('Shutdown Bot') self.logger.info('Shutdown Bot')
for task in asyncio.Task.all_tasks(): for task in asyncio.Task.all_tasks():
task.cancel() task.cancel()
if getattr(self, 'client', None):
await self.client.close() await self.client.close()
async def __sync_callback(self, event: any) -> None: async def __sync_callback(self, event: any) -> None:
@ -426,14 +441,36 @@ class Bot(object):
async def __invite_callback(self, source: MatrixRoom, sender: any) -> None: async def __invite_callback(self, source: MatrixRoom, sender: any) -> None:
await self.client.join(source.room_id) await self.client.join(source.room_id)
async def __handle_text_message(self, room: MatrixRoom,
message: RoomMessageText) -> None:
self.logger.debug('Handling Text Message %s', message)
for plugin in self.message_plugins.keys():
await self.message_plugins[plugin].on_message(room, message)
async def __handle_command_message(self, room: MatrixRoom,
message: RoomMessageText) -> None:
self.logger.debug('Handling Command Message %s', message)
if 'help' not in self.command_plugins.keys():
self.plugins['help'] = all_plugins['help']
body = message.body.split(' ')
plugin = self.command_plugins['help']
if len(body) > 1:
if (body[1] in self.command_plugins.keys()):
plugin = self.command_plugins[body[1]]
self.logger.debug('Handling Command %s', body[1])
await plugin.on_command(room, message)
async def __text_message_callback(self, source: MatrixRoom, async def __text_message_callback(self, source: MatrixRoom,
message: RoomMessageText) -> None: message: RoomMessageText) -> None:
self.logger.debug('Text Message Recieved: %s %s: %s', source.room_id, self.logger.debug('Text Message Recieved: %s %s: %s', source.room_id,
message.sender, message.body) message.sender, message.body)
if (message.body.startswith(self.config['config']['chat_prefix'])):
return await self.__handle_command_message(source, message)
else:
return await self.__handle_text_message(source, message)
async def __message_callback(self, source: MatrixRoom, async def __message_callback(self, source: MatrixRoom,
message: RoomMessage) -> None: message: RoomMessage) -> None:
print(message)
self.logger.debug('Message Recieved') self.logger.debug('Message Recieved')
if (isinstance(message, RoomMessageText)): if (isinstance(message, RoomMessageText)):
await self.__text_message_callback(source, message) await self.__text_message_callback(source, message)
@ -441,6 +478,20 @@ class Bot(object):
def __signal_handler(self) -> None: def __signal_handler(self) -> None:
self.loop.create_task(self.shutdown()) self.loop.create_task(self.shutdown())
# Plugins
def __load_plugins(self) -> None:
for plugin in self.config['config']['plugins']:
if plugin in all_plugins.keys():
obj = all_plugins[plugin](self, plugin)
self.logger.info('Loading plugin %s', plugin)
if getattr(obj, 'on_command', None):
self.command_plugins[plugin] = obj
if getattr(obj, 'on_message', None):
self.message_plugins[plugin] = obj
if getattr(obj, 'on_run', None):
self.runtime_plugins[plugin] = obj
# Files # Files
def __write_details_to_disk(self, resp: LoginResponse, def __write_details_to_disk(self, resp: LoginResponse,
@ -466,7 +517,11 @@ class Bot(object):
'user_id': resp.user_id, 'user_id': resp.user_id,
'device_id': resp.device_id, 'device_id': resp.device_id,
'access_token': resp.access_token, 'access_token': resp.access_token,
} },
'config': {
'chat_prefix': Config.CHAT_PREFIX,
'plugins': ['help']
},
}, },
f, f,
) )
@ -501,3 +556,15 @@ class Bot(object):
@property @property
def client(self) -> AsyncClient: def client(self) -> AsyncClient:
return self.__client return self.__client
@property
def message_plugins(self) -> list:
return self.__message_plugins
@property
def command_plugins(self) -> list:
return self.__command_plugins
@property
def runtime_plugins(self) -> list:
return self.__runtime_plugins

View File

@ -4,10 +4,11 @@ import sys
import click import click
from .bot import Bot from .bot import Bot
from .utils import run_async, setup_logger from .client import send_message as client_send_message
from .config import Config from .config import Config
from .message import TextMessage, MarkdownMessage
from .exceptions import NoRoomException from .exceptions import NoRoomException
from .message import MarkdownMessage, TextMessage
from .utils import run_async, setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
@ -29,6 +30,7 @@ async def send_message(room_id: str, message: str, markdown: bool) -> None:
logger.debug('Sending a message to %s', room_id) logger.debug('Sending a message to %s', room_id)
bot = Bot() bot = Bot()
client = await bot.login() client = await bot.login()
if client:
await bot.sync() await bot.sync()
room = await bot.find_room_by_id(room_id) room = await bot.find_room_by_id(room_id)
if markdown: if markdown:
@ -113,5 +115,19 @@ def send(ctx: click.Context, room_id: str, message: list, markdown: bool,
run_async(send_message(room_id, ' '.join(message), markdown)) run_async(send_message(room_id, ' '.join(message), markdown))
@cli.group()
@click.pass_context
def client(ctx: click.Context) -> None:
pass
@client.command('send')
@click.pass_context
@click.argument('room_id')
@click.argument('message', nargs=-1, required=True)
def client_send(ctx: click.Context, room_id: str, message: list) -> None:
run_async(client_send_message(room_id, ' '.join(message)))
if __name__ == '__main__': if __name__ == '__main__':
cli(obj={}) cli(obj={})

18
matrix_bot/client.py Normal file
View File

@ -0,0 +1,18 @@
import aiohttp
async def send_message(room: str, message: str) -> None:
conn = aiohttp.UnixConnector(path='/tmp/test.sock')
try:
async with aiohttp.request('POST',
'http://localhost/',
json={
'type': 'message.text',
'room_id': room,
'message': message
},
connector=conn) as resp:
assert resp.status == 200
print(await resp.text())
finally:
await conn.close()

View File

@ -1,4 +1,5 @@
import os import os
import appdirs import appdirs
@ -6,6 +7,7 @@ class Config:
APP_NAME = 'matrix-bot' APP_NAME = 'matrix-bot'
# Default Homeserver URL # Default Homeserver URL
HOMESERVER_URL = 'https://matrix.gaja-group.com' HOMESERVER_URL = 'https://matrix.gaja-group.com'
CHAT_PREFIX = '!mbot'
@classmethod @classmethod
def init(cls, botname: str = 'default', loglevel: int = 40) -> None: def init(cls, botname: str = 'default', loglevel: int = 40) -> None:
@ -26,7 +28,6 @@ class Config:
'store') # local directory 'store') # local directory
cls.NEXT_BATCH_PATH = os.path.join(cls.DATA_DIRECTORY, cls.NEXT_BATCH_PATH = os.path.join(cls.DATA_DIRECTORY,
'next_batch') # local directory 'next_batch') # local directory
print(cls.CONFIG_DIRECTORY)
Config.init() Config.init()

View File

@ -1,11 +1,11 @@
import re
import logging import logging
import re
from nio import AsyncClient, MatrixRoom
from markdown import markdown from markdown import markdown
from nio import AsyncClient, MatrixRoom
from .utils import setup_logger
from .exceptions import NoRoomException from .exceptions import NoRoomException
from .utils import setup_logger
class Message(object): class Message(object):
@ -71,11 +71,12 @@ class MarkdownMessage(Message):
super().__init__(_type) super().__init__(_type)
self.formatted_body = markdown(body) self.formatted_body = markdown(body)
self.format = 'org.matrix.custom.html' self.format = 'org.matrix.custom.html'
self.body = self.__clean_html(self.formatted_body) self.body = self.clean_html(self.formatted_body)
self.msgtype = msgtype self.msgtype = msgtype
def __clean_html(self, raw_html: str) -> str: @classmethod
cleantext = re.sub(self.clean_regexp, '', raw_html) def clean_html(cls, raw_html: str) -> str:
cleantext = re.sub(cls.clean_regexp, '', raw_html)
return cleantext return cleantext
@property @property

View File

@ -0,0 +1,13 @@
import importlib
import os
import pkgutil
pkg_dir = os.path.dirname(__file__)
all_plugins = {}
for (module_loader, name, ispkg) in pkgutil.iter_modules([pkg_dir]):
if not name.startswith('_'):
cls = importlib.import_module('.' + name, __package__).Plugin
all_plugins[name] = cls

View File

@ -0,0 +1,21 @@
import logging
from nio import AsyncClient
from ..utils import setup_logger
class _Plugin(object):
def __init__(self, bot: any, name: str) -> None:
self.__name = name
self.__bot = bot
self.__client = bot.client
self.__logger = setup_logger(f'{__package__}.{self.__name}')
@property
def logger(self) -> logging.Logger:
return self.__logger
@property
def client(self) -> AsyncClient:
return self.__client

View File

@ -0,0 +1,17 @@
from nio import MatrixRoom, RoomMessageText
from ..message import MarkdownMessage
from ._plugin import _Plugin
class Plugin(_Plugin):
help = 'Echo the given input back into the room'
async def on_command(self, room: MatrixRoom,
message: RoomMessageText) -> None:
body = message.body
if (message.formatted_body):
body = MarkdownMessage.clean_html(message.formatted_body)
body = ' '.join(body.split(' ')[2:])
message = MarkdownMessage(body)
await message.send(self.client, room)

View File

@ -0,0 +1,27 @@
from nio import MatrixRoom, RoomMessageText
from ..config import Config
from ..message import MarkdownMessage
from ._plugin import _Plugin
class Plugin(_Plugin):
help = 'Display this help message'
@property
def help_message(self) -> str:
prefix = Config.CHAT_PREFIX
message = ('# Help\n'
'Use the bot by starting a message with '
f'`{prefix}`\n\n'
'## Usage\n')
for plugin in self.__bot.command_plugins:
plugin = self.__bot.command_plugins[plugin]
message += f'* {prefix} {plugin.__name} - {plugin.help}\n'
return message
async def on_command(self, room: MatrixRoom,
message: RoomMessageText) -> None:
message = MarkdownMessage(self.help_message)
await message.send(self.client, room)

View File

@ -0,0 +1,37 @@
import json
from aiohttp import web
from ..message import MarkdownMessage
from ._plugin import _Plugin
class Plugin(_Plugin):
help = 'Echo the given input back into the room'
async def __handle_request(self, request: web.Request) -> None:
response = {'status': 'Bad Request', 'status_code': 400}
if request.has_body:
body = await request.json()
if 'type' in body.keys():
self.logger.debug('Request recieved: %s %s', request, body)
if body['type'] == 'message.text' and 'room_id' in body.keys(
) and 'message' in body.keys():
room = await self.__bot.find_room_by_id(body['room_id'])
message = MarkdownMessage(body['message'])
await message.send(self.client, room)
response = {'status': 'OK', 'status_code': 200}
return web.Response(body=json.dumps(response),
status=response['status_code'],
content_type='application/json')
async def on_run(self) -> None:
app = web.Application()
app.router.add_post('/', self.__handle_request)
runner = web.AppRunner(app)
self.logger.debug('Setup Web server')
await runner.setup()
self.logger.debug('Web server starts listening')
site = web.UnixSite(runner, '/tmp/test.sock')
await site.start()