# matrix-bot/matrix_bot/bot.py
#
# (569 lines, 23 KiB, Python)

import asyncio
import getpass
import json
import logging
import os
import signal
import sys
import traceback
from typing import Optional

import click
from nio import (AsyncClient, AsyncClientConfig, InviteEvent,
                 KeyVerificationCancel, KeyVerificationEvent,
                 KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
                 LocalProtocolError, LoginResponse, MatrixRoom, RoomMessage,
                 RoomMessageText, SyncResponse, ToDeviceError)

from .config import Config
from .plugins import all_plugins
from .utils import setup_logger
# from .message import TextMessage
class Bot(object):
    """Matrix bot built on matrix-nio's ``AsyncClient``.

    The bot logs in (interactively on first use, otherwise with the access
    token stored in ``Config.CONFIG_FILE``), joins rooms it is invited to,
    dispatches room text messages to the configured plugins and can run an
    interactive emoji (SAS) device verification via ``verify()``.

    Instances must be created while an event loop is running: the
    constructor calls ``asyncio.get_running_loop()`` to install SIGINT and
    SIGTERM handlers that schedule ``shutdown()``.
    """

    def __init__(self) -> None:
        self.__client = None
        # Both are populated by ``login()``; initialised here so that the
        # ``client_config`` and ``config`` properties never raise
        # AttributeError before a login happened.
        self.__client_config = None
        self.__config = {}
        self.__logger = setup_logger(__name__)
        self.logger.debug('Initializing Bot')
        self.__loop = asyncio.get_running_loop()
        self.__first_sync = True
        # Plugin instances keyed by plugin name, grouped by the hooks they
        # implement (see ``__load_plugins``).
        self.__command_plugins = {}
        self.__message_plugins = {}
        self.__runtime_plugins = {}
        for sig in (signal.SIGINT, signal.SIGTERM):
            self.__loop.add_signal_handler(sig, self.__signal_handler)

    async def login(self) -> Optional[AsyncClient]:
        """Login to the matrix homeserver defined in the config file.

        Without an existing ``Config.CONFIG_FILE`` the user is asked for
        homeserver, user ID, device name and password, and the resulting
        access token is stored in that file. Otherwise the stored access
        token is restored.

        Returns:
            The ``AsyncClient``, or ``None`` if the user closed the password
            prompt with EOF (the bot is shut down in that case).

        Exits the process with status 1 if the password login fails.
        """
        self.logger.debug('Starting login process')
        self.__client_config = AsyncClientConfig(
            max_limit_exceeded=0,
            max_timeouts=0,
            store_sync_tokens=True,
            encryption_enabled=True,
        )
        if os.path.exists(Config.CONFIG_FILE):
            self.__login_with_stored_credentials()
        elif not await self.__login_with_password():
            return None
        return self.__client

    async def __login_with_password(self) -> bool:
        """Interactive first-time login; return False if the user aborted."""
        self.logger.debug('Starting password verification process')
        click.secho(
            'First time use. Did not find credential file. Asking '
            'for homeserver, user, and password to create '
            'credential file.\n',
            bold=True,
        )
        os.makedirs(Config.STORE_PATH, exist_ok=True)
        credentials = self.__ask_credentials()
        # Initialize the matrix client
        self.__client = AsyncClient(
            credentials['homeserver'],
            credentials['user_id'],
            store_path=Config.STORE_PATH,
            config=self.client_config,
        )
        try:
            pw = getpass.getpass(click.style('Your Password: ', bold=True))
        except EOFError:
            print()
            await self.shutdown()
            return False
        # Stays None for an empty password, which is reported as a failed
        # login below instead of raising UnboundLocalError.
        resp = None
        try:
            if pw:
                resp = await self.client.login(
                    password=pw, device_name=credentials['device_name'])
        finally:
            # Drop the plaintext password as soon as it is no longer needed.
            del pw
        # check that we logged in successfully
        if not isinstance(resp, LoginResponse):
            self.logger.debug(f'homeserver = {credentials["homeserver"]}; '
                              f' user = {credentials["user_id"]}')
            self.logger.warning(f'Failed to log in: {resp}')
            sys.exit(1)
        # Keep the in-memory config in the same layout as the file on disk;
        # the rest of the class reads config['credentials'] and
        # config['config'].
        self.__config = self.__write_details_to_disk(resp, credentials)
        click.secho(
            'Logged in using a password. Credentials were stored. '
            'On next execution the stored login credentials will '
            'be used.',
            fg='green',
        )
        return True

    def __login_with_stored_credentials(self) -> None:
        """Restore the session saved in ``Config.CONFIG_FILE``."""
        self.logger.debug('Reading credentials.json')
        with open(Config.CONFIG_FILE, 'r') as f:
            self.__config = json.load(f)
        stored = self.config['credentials']
        # Initialize the matrix client based on credentials from file
        self.__client = AsyncClient(
            stored['homeserver'],
            stored['user_id'],
            device_id=stored['device_id'],
            store_path=Config.STORE_PATH,
            config=self.client_config,
        )
        self.__client.restore_login(
            user_id=stored['user_id'],
            device_id=stored['device_id'],
            access_token=stored['access_token'],
        )
        self.logger.debug('Logged in using stored credentials.')

    async def __upload_keys(self) -> None:
        """Upload, query and claim encryption keys when the client asks to."""
        if self.client.should_upload_keys:
            await self.client.keys_upload()
        if self.client.should_query_keys:
            await self.client.keys_query()
        if self.client.should_claim_keys:
            await self.client.keys_claim(
                self.client.get_users_for_key_claiming())

    async def sync(self) -> None:
        """Handle pending key operations, then run a single full sync."""
        self.logger.debug('Starting sync')
        await self.__upload_keys()
        await self.client.sync(timeout=30000, full_state=True)

    async def sync_forever(self) -> None:
        """Sync with the homeserver in a loop (30 s long-poll timeout)."""
        await self.client.sync_forever(timeout=30000, full_state=True)

    async def verify(self) -> None:
        """Login and wait for and perform emoji verify."""
        client = await self.login()
        if client is None:
            # The password prompt was aborted; the bot is already shut down.
            return
        # Set up event callbacks
        self.logger.debug('Adding callbacks')
        client.add_to_device_callback(self.__to_device_callback,
                                      (KeyVerificationEvent, ))
        await self.__upload_keys()
        click.secho('\nStarting verification process...',
                    bold=True,
                    fg='green')
        click.secho(
            '\nThis program is ready and waiting for the other '
            'party to initiate an emoji verification with us by '
            'selecting "Verify by Emoji" in their Matrix '
            'client.',
            fg='green',
        )
        await self.sync_forever()

    async def run(self) -> None:
        """Login, register callbacks, start runtime plugins, sync forever."""
        client = await self.login()
        if client is None:
            # The password prompt was aborted; the bot is already shut down.
            return
        client.add_response_callback(self.__sync_callback, (SyncResponse, ))
        client.add_event_callback(self.__message_callback, (RoomMessage, ))
        client.add_event_callback(self.__invite_callback, (InviteEvent, ))
        self.__load_plugins()
        for plugin in self.runtime_plugins.values():
            await plugin.on_run()
        await self.sync_forever()

    async def find_room_by_id(self, room_id: str) -> Optional[MatrixRoom]:
        """Return ``self.client.rooms[room_id]``, or ``None`` if absent."""
        return self.client.rooms.get(room_id)

    def __ask_credentials(self) -> dict:
        """Ask the user for homeserver URL, user ID and device name.

        Returns a dict with the keys ``homeserver``, ``user_id`` and
        ``device_name``. Exits the process with status 1 if the user
        aborts a prompt.
        """
        try:
            homeserver = click.prompt(
                click.style('Enter your homeserver URL', bold=True),
                default=Config.HOMESERVER_URL,
            )
            # Default to HTTPS, but keep an explicitly typed scheme instead
            # of producing e.g. 'https://http://host'.
            if not homeserver.startswith(('https://', 'http://')):
                homeserver = 'https://' + homeserver
            user_id = click.prompt(
                click.style('Enter your full user ID', bold=True),
                default=f'@{getpass.getuser()}:gaja-group.com',
            )
            device_name = click.prompt(
                click.style('Choose a name for this device', bold=True),
                default='matrix-bot',
            )
        except click.exceptions.Abort:
            sys.exit(1)
        return {
            'homeserver': homeserver,
            'user_id': user_id,
            'device_name': device_name,
        }

    # Callbacks

    async def __to_device_callback(self, event: KeyVerificationEvent) -> None:
        """Handle events sent to device (the emoji verification flow).

        Step 1 ``KeyVerificationStart``: accept and share our key.
        Step 2 ``KeyVerificationKey``: show the emojis and ask the user.
        Step 3 ``KeyVerificationMac``: send our MAC.
        ``KeyVerificationCancel`` may arrive at any time.

        Unexpected exceptions are logged with their traceback instead of
        being propagated.
        """
        try:
            if isinstance(event, KeyVerificationStart):
                await self.__on_verification_start(event)
            elif isinstance(event, KeyVerificationCancel):
                self.__on_verification_cancel(event)
            elif isinstance(event, KeyVerificationKey):
                await self.__on_verification_key(event)
            elif isinstance(event, KeyVerificationMac):
                await self.__on_verification_mac(event)
            else:
                self.logger.warning(
                    f'Received unexpected event type {type(event)}. '
                    f'Event is {event}. Event will be ignored.')
        except Exception:
            # Exception, not BaseException: task cancellation
            # (asyncio.CancelledError), KeyboardInterrupt and SystemExit
            # must still propagate.
            self.logger.critical(traceback.format_exc())

    async def __on_verification_start(
            self, event: KeyVerificationStart) -> None:
        """Step 1: accept the verification request and share our key.

        Example event::

            KeyVerificationStart(
                sender='@user2:example.org', transaction_id='SomeTxId',
                from_device='DEVICEIDXY', method='m.sas.v1',
                key_agreement_protocols=['curve25519-hkdf-sha256',
                                         'curve25519'],
                hashes=['sha256'],
                message_authentication_codes=['hkdf-hmac-sha256',
                                              'hmac-sha256'],
                short_authentication_string=['decimal', 'emoji'])
        """
        client = self.client
        if 'emoji' not in event.short_authentication_string:
            click.echo(
                'Other device does not support emoji verification '
                f'{event.short_authentication_string}.')
            return
        resp = await client.accept_key_verification(event.transaction_id)
        if isinstance(resp, ToDeviceError):
            self.logger.warning(f'accept_key_verification failed with {resp}')
        sas = client.key_verifications[event.transaction_id]
        todevice_msg = sas.share_key()
        resp = await client.to_device(todevice_msg)
        if isinstance(resp, ToDeviceError):
            self.logger.warning(f'to_device failed with {resp}')

    def __on_verification_cancel(self, event: KeyVerificationCancel) -> None:
        """At any time: report a verification cancelled by the other side.

        Example event::

            KeyVerificationCancel(
                sender='@user2:example.org', transaction_id='SomeTxId',
                code='m.mismatched_sas',
                reason='Mismatched short authentication string')

        There is no need to issue a
        ``client.cancel_key_verification(tx_id, reject=False)`` here: the
        SAS flow is already cancelled, we only need to inform the user.
        """
        click.echo('\nVerification has been cancelled by '
                   f'{event.sender} for reason "{event.reason}".')

    async def __on_verification_key(self, event: KeyVerificationKey) -> None:
        """Step 2: show the emojis and let the user confirm or reject them.

        Example event::

            KeyVerificationKey(sender='@user2:example.org',
                               transaction_id='SomeTxId',
                               key='SomeCryptoKey')

        NOTE: ``click.confirm`` is a synchronous prompt, so the event loop
        is blocked while waiting for the user's answer.
        """
        client = self.client
        click.secho('\nEmoji verification initiated.\n')
        sas = client.key_verifications[event.transaction_id]
        click.echo(', '.join(' '.join(emoji) for emoji in sas.get_emoji()))
        print()
        try:
            if click.confirm(click.style('Do the emojis match?', bold=True)):
                click.secho(
                    '\nMatch! The verification for this '
                    'device will be accepted.',
                    fg='green',
                )
                resp = await client.confirm_short_auth_string(
                    event.transaction_id)
                if isinstance(resp, ToDeviceError):
                    click.secho(
                        'confirm_short_auth_string failed with '
                        f'{resp}',
                        fg='red',
                    )
            else:  # no, don't match, reject
                click.secho(
                    '\nNo match! Device will NOT be verified '
                    'by rejecting verification.',
                    fg='yellow',
                )
                resp = await client.cancel_key_verification(
                    event.transaction_id, reject=True)
                if isinstance(resp, ToDeviceError):
                    click.secho(
                        f'cancel_key_verification failed with {resp}',
                        fg='red',
                    )
        except click.exceptions.Abort:  # Ctrl+C or anything for cancel
            click.secho(
                'Cancelled by user! Verification will be '
                'cancelled.',
                fg='red',
            )
            resp = await client.cancel_key_verification(
                event.transaction_id, reject=False)
            if isinstance(resp, ToDeviceError):
                self.logger.warning(
                    f'cancel_key_verification failed with {resp}')

    async def __on_verification_mac(self, event: KeyVerificationMac) -> None:
        """Step 3: send our MAC to conclude the verification.

        Example event::

            KeyVerificationMac(
                sender='@user2:example.org', transaction_id='SomeTxId',
                mac={'ed25519:DEVICEIDXY': 'SomeKey1',
                     'ed25519:SomeKey2': 'SomeKey3'},
                keys='SomeCryptoKey4')
        """
        client = self.client
        sas = client.key_verifications[event.transaction_id]
        try:
            todevice_msg = sas.get_mac()
        except LocalProtocolError as e:
            # e.g. it might have been cancelled by ourselves
            click.secho(
                f'Cancelled or protocol error: Reason: {e}.\n'
                f'Verification with {event.sender} not '
                'concluded. Try again?',
                fg='yellow',
            )
        else:
            resp = await client.to_device(todevice_msg)
            if isinstance(resp, ToDeviceError):
                self.logger.warning(f'to_device failed with {resp}')
            click.secho(
                'Emoji verification was successful! Please use Ctrl+C '
                'to exit.',
                fg='green')

    async def shutdown(self) -> None:
        """Cancel every other running task and close the Matrix client."""
        self.logger.info('Shutdown Bot')
        # asyncio.Task.all_tasks() no longer exists on Python 3.9+. The
        # current task is skipped so that this coroutine is not cancelled
        # before it has closed the client.
        current = asyncio.current_task()
        for task in asyncio.all_tasks():
            if task is not current:
                task.cancel()
        if self.client is not None:
            await self.client.close()

    async def __sync_callback(self, event: SyncResponse) -> None:
        """Join pending invites on the first sync and persist the token."""
        self.logger.debug('Client syncing and saving next batch token')
        if self.__first_sync:
            # Iterate over a copy in case joining mutates invited_rooms.
            for room_id in list(self.client.invited_rooms):
                await self.client.join(room_id)
            self.__first_sync = False
        with open(Config.NEXT_BATCH_PATH, 'w') as next_batch_token:
            next_batch_token.write(event.next_batch)

    async def __invite_callback(self, source: MatrixRoom,
                                sender: InviteEvent) -> None:
        """Accept any room invitation by joining the room."""
        await self.client.join(source.room_id)

    async def __handle_text_message(self, room: MatrixRoom,
                                    message: RoomMessageText) -> None:
        """Pass a non-command text message to every message plugin."""
        self.logger.debug('Handling Text Message %s', message)
        for plugin in self.message_plugins.values():
            await plugin.on_message(room, message)

    async def __handle_command_message(self, room: MatrixRoom,
                                       message: RoomMessageText) -> None:
        """Dispatch a prefixed command to the matching command plugin.

        The second whitespace-separated word of the message selects the
        plugin; a missing or unknown command falls back to ``help``.
        """
        self.logger.debug('Handling Command Message %s', message)
        if 'help' not in self.command_plugins:
            # ``help`` is always available as fallback, even when it is not
            # enabled in the config. Instantiate it the same way
            # ``__load_plugins`` does.
            self.command_plugins['help'] = all_plugins['help'](self, 'help')
        words = message.body.split()
        plugin = self.command_plugins['help']
        if len(words) > 1 and words[1] in self.command_plugins:
            plugin = self.command_plugins[words[1]]
            self.logger.debug('Handling Command %s', words[1])
        await plugin.on_command(room, message)

    async def __text_message_callback(self, source: MatrixRoom,
                                      message: RoomMessageText) -> None:
        """Route a text message to command or plain-message handling."""
        self.logger.debug('Text Message Recieved: %s %s: %s', source.room_id,
                          message.sender, message.body)
        if message.body.startswith(self.config['config']['chat_prefix']):
            return await self.__handle_command_message(source, message)
        return await self.__handle_text_message(source, message)

    async def __message_callback(self, source: MatrixRoom,
                                 message: RoomMessage) -> None:
        """Entry point for room messages; only text messages are handled."""
        self.logger.debug('Message Recieved')
        if isinstance(message, RoomMessageText):
            await self.__text_message_callback(source, message)

    def __signal_handler(self) -> None:
        """Schedule ``shutdown()`` on the event loop (SIGINT/SIGTERM)."""
        self.loop.create_task(self.shutdown())

    # Plugins

    def __load_plugins(self) -> None:
        """Instantiate the plugins listed in ``config['config']['plugins']``.

        Each plugin is registered as command, message and/or runtime plugin
        depending on whether it provides ``on_command``, ``on_message`` and
        ``on_run``. Unknown plugin names are logged and skipped.
        """
        for name in self.config['config']['plugins']:
            plugin_cls = all_plugins.get(name)
            if plugin_cls is None:
                self.logger.warning('Unknown plugin %s, skipping', name)
                continue
            obj = plugin_cls(self, name)
            self.logger.info('Loading plugin %s', name)
            if getattr(obj, 'on_command', None):
                self.command_plugins[name] = obj
            if getattr(obj, 'on_message', None):
                self.message_plugins[name] = obj
            if getattr(obj, 'on_run', None):
                self.runtime_plugins[name] = obj

    # Files

    def __write_details_to_disk(self, resp: LoginResponse,
                                credentials: dict) -> dict:
        """Write the required login details to disk.

        It will allow following logins to be made without password. The
        file holds the access token, so it is created with mode 0o600
        (owner read/write only).

        Arguments:
        ---------
            resp : LoginResponse - successful client login response
            credentials : dict - The credentials used to sign in

        Returns:
            The dict written to ``Config.CONFIG_FILE``; it has the same
            layout that ``login`` reads back on the next start.
        """
        details = {
            'credentials': {
                'homeserver': credentials['homeserver'],
                'device_name': credentials['device_name'],
                'user_id': resp.user_id,
                'device_id': resp.device_id,
                'access_token': resp.access_token,
            },
            'config': {
                'chat_prefix': Config.CHAT_PREFIX,
                'plugins': ['help'],
            },
        }
        fd = os.open(Config.CONFIG_FILE,
                     os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with open(fd, 'w') as f:
            json.dump(details, f)
        return details

    def __read_next_batch(self) -> Optional[str]:
        """Load the sync token saved by ``__sync_callback`` into the client.

        Returns the token read, or ``None`` when no token file exists.
        """
        if not os.path.exists(Config.NEXT_BATCH_PATH):
            return None
        with open(Config.NEXT_BATCH_PATH, 'r') as next_batch_token:
            self.client.next_batch = next_batch_token.read()
        return self.client.next_batch

    # Properties

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        return self.__loop

    @property
    def logger(self) -> logging.Logger:
        return self.__logger

    @property
    def config(self) -> dict:
        """Loaded configuration: ``{'credentials': {...}, 'config': {...}}``."""
        return self.__config

    @property
    def client_config(self) -> Optional[AsyncClientConfig]:
        return self.__client_config

    @property
    def client(self) -> Optional[AsyncClient]:
        return self.__client

    @property
    def message_plugins(self) -> dict:
        return self.__message_plugins

    @property
    def command_plugins(self) -> dict:
        return self.__command_plugins

    @property
    def runtime_plugins(self) -> dict:
        return self.__runtime_plugins