matrix-bot/matrix_bot/bot.py

569 lines
23 KiB
Python
Raw Normal View History

2020-08-03 18:32:53 +02:00
import asyncio
2020-08-06 17:16:43 +02:00
import getpass
2020-08-05 23:13:23 +02:00
import json
import logging
import os
import signal
import sys
import traceback
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
import click
from nio import (AsyncClient, AsyncClientConfig, InviteEvent,
KeyVerificationCancel, KeyVerificationEvent,
KeyVerificationKey, KeyVerificationMac, KeyVerificationStart,
LocalProtocolError, LoginResponse, MatrixRoom, RoomMessage,
RoomMessageText, SyncResponse, ToDeviceError)
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
from .config import Config
2020-08-06 17:16:43 +02:00
from .plugins import all_plugins
2020-08-05 23:13:23 +02:00
from .utils import setup_logger
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
# from .message import TextMessage
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
class Bot(object):
def __init__(self) -> None:
    """Prepare logging, the plugin registries and signal handling.

    The currently running asyncio event loop is captured, so the bot has
    to be constructed from code that is already running inside a loop.
    SIGINT and SIGTERM are routed to the bot's signal handler.
    """
    self.__client = None
    self.__logger = setup_logger(__name__)
    self.logger.debug('Initializing Bot')

    self.__loop = asyncio.get_running_loop()
    self.__first_sync = True

    # Plugin registries, filled in when the plugins are loaded.
    self.__command_plugins = {}
    self.__message_plugins = {}
    self.__runtime_plugins = {}

    for signum in (signal.SIGINT, signal.SIGTERM):
        self.__loop.add_signal_handler(signum, self.__signal_handler)
2020-08-06 17:16:43 +02:00
async def login(self) -> (AsyncClient, None):
    """Log in to the matrix homeserver defined in the config file.

    If ``Config.CONFIG_FILE`` does not exist yet, the user is asked for
    homeserver, user ID, device name and password. After a successful
    password login the details are written to ``Config.CONFIG_FILE`` and
    loaded from there as the active configuration. If the file exists,
    the stored credentials are used to restore the session.

    Returns:
        The client that was set up, or ``None`` when the password prompt
        was ended with EOF (the bot is shut down in that case).

    Raises:
        SystemExit: when a prompt is aborted, no password is entered or
            the homeserver rejects the login.
    """
    self.logger.debug('Starting login process')
    self.__client_config = AsyncClientConfig(
        max_limit_exceeded=0,
        max_timeouts=0,
        store_sync_tokens=True,
        encryption_enabled=True,
    )
    # If there are no previously-saved credentials, we'll use the password
    if not os.path.exists(Config.CONFIG_FILE):
        self.logger.debug('Starting password verification process')
        click.secho(
            'First time use. Did not find credential file. Asking '
            'for homeserver, user, and password to create '
            'credential file.\n',
            bold=True,
        )
        os.makedirs(Config.STORE_PATH, exist_ok=True)

        credentials = self.__ask_credentials()
        # Initialize the matrix client
        self.__client = AsyncClient(
            credentials['homeserver'],
            credentials['user_id'],
            store_path=Config.STORE_PATH,
            config=self.client_config,
        )

        try:
            pw = getpass.getpass(click.style('Your Password: ', bold=True))
        except EOFError:
            print()
            await self.shutdown()
            return None

        # An empty password used to skip the login call and then crash on
        # the unbound response variable; treat it as a failed login.
        if not pw:
            self.logger.warning('Failed to log in: no password was entered')
            sys.exit(1)

        try:
            resp = await self.client.login(
                password=pw, device_name=credentials['device_name'])
        finally:
            # Do not keep the password around longer than necessary.
            del pw

        # check that we logged in successfully
        if not isinstance(resp, LoginResponse):
            self.logger.debug(f'homeserver = {credentials["homeserver"]}; '
                              f' user = {credentials["user_id"]}')
            self.logger.warning(f'Failed to log in: {resp}')
            sys.exit(1)

        self.__write_details_to_disk(resp, credentials)
        # Load what was just written so that the in-memory configuration
        # has the same 'credentials'/'config' layout that the rest of the
        # bot reads, on the first run as well as on later ones.
        with open(Config.CONFIG_FILE, 'r') as f:
            self.__config = json.load(f)
        click.secho(
            'Logged in using a password. Credentials were stored. '
            'On next execution the stored login credentials will '
            'be used.',
            fg='green',
        )
    # Otherwise the config file exists, so we'll use the stored credentials
    else:
        self.logger.debug('Reading credentials.json')
        # open the file in read-only mode
        with open(Config.CONFIG_FILE, 'r') as f:
            self.__config = json.load(f)
        stored = self.config['credentials']
        # Initialize the matrix client based on credentials from file
        self.__client = AsyncClient(
            stored['homeserver'],
            stored['user_id'],
            device_id=stored['device_id'],
            store_path=Config.STORE_PATH,
            config=self.client_config,
        )
        self.__client.restore_login(
            user_id=stored['user_id'],
            device_id=stored['device_id'],
            access_token=stored['access_token'],
        )
        self.logger.debug('Logged in using stored credentials.')

    return self.__client
2020-08-05 23:13:23 +02:00
2020-08-06 17:16:43 +02:00
async def __upload_keys(self) -> None:
    """Run the key requests that the client reports as pending.

    Upload, query and claim are checked and performed in that order.
    """
    matrix = self.client
    if matrix.should_upload_keys:
        await matrix.keys_upload()
    if matrix.should_query_keys:
        await matrix.keys_query()
    if matrix.should_claim_keys:
        pending_users = matrix.get_users_for_key_claiming()
        await matrix.keys_claim(pending_users)
async def sync(self) -> None:
    """Handle pending key requests, then perform one full-state sync."""
    self.logger.debug('Starting sync')
    await self.__upload_keys()
    matrix = self.client
    await matrix.sync(full_state=True, timeout=30000)
2020-08-05 23:13:23 +02:00
async def sync_forever(self) -> None:
    """Hand control to the client's endless full-state sync loop."""
    # NOTE: the stored next-batch token is not restored here; the call
    # to self.__read_next_batch() is intentionally left disabled.
    matrix = self.client
    await matrix.sync_forever(full_state=True, timeout=30000)
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
async def verify(self) -> None:
    """Login and wait for and perform emoji verify.

    Registers the to-device callback for key verification events, runs
    the pending key requests and then syncs forever so that the other
    party can start an emoji verification with this device.

    Returns without doing anything further when ``login()`` did not
    produce a client.
    """
    client = await self.login()
    # login() returns None when the password prompt was aborted; the
    # client attribute is already set at that point, so the return value
    # is the only reliable indicator.
    if client is None:
        return

    # Set up event callbacks
    self.logger.debug('Adding callbacks')
    client.add_to_device_callback(self.__to_device_callback,
                                  (KeyVerificationEvent, ))
    # This coroutine has to be awaited, otherwise it never runs.
    await self.__upload_keys()
    click.secho('\nStarting verification process...',
                bold=True,
                fg='green')
    click.secho(
        '\nThis program is ready and waiting for the other '
        'party to initiate an emoji verification with us by '
        'selecting "Verify by Emoji" in their Matrix '
        'client.',
        fg='green',
    )
    await self.sync_forever()
2020-08-04 12:03:56 +02:00
2020-08-05 23:13:23 +02:00
async def run(self) -> None:
    """Log in, register callbacks, start the plugins and sync forever.

    When no client is available after the login attempt, nothing else
    is done.
    """
    await self.login()
    if not self.__client:
        return

    registrations = (
        (self.client.add_response_callback, self.__sync_callback,
         (SyncResponse, )),
        (self.client.add_event_callback, self.__message_callback,
         (RoomMessage, )),
        (self.client.add_event_callback, self.__invite_callback,
         (InviteEvent, )),
    )
    for register, callback, event_types in registrations:
        register(callback, event_types)

    self.__load_plugins()
    for runtime_plugin in self.runtime_plugins.values():
        await runtime_plugin.on_run()
    await self.sync_forever()
2020-08-04 12:03:56 +02:00
2020-08-05 23:13:23 +02:00
async def find_room_by_id(self, room_id: str) -> (MatrixRoom, None):
    """Return the client's room with the given ID, or None if unknown."""
    known_rooms = self.client.rooms
    return known_rooms[room_id] if room_id in known_rooms else None
2020-08-03 18:32:53 +02:00
2020-08-06 17:16:43 +02:00
def __ask_credentials(self) -> dict:
    """Ask the user for homeserver URL, user ID and device name.

    Every prompt offers a default: ``Config.HOMESERVER_URL``, a user ID
    built from the local login name, and ``matrix-bot`` respectively.

    Returns:
        A dict with the keys ``homeserver``, ``user_id`` and
        ``device_name``.

    Raises:
        SystemExit: when the user aborts one of the prompts.
    """
    try:
        homeserver = click.prompt(
            click.style('Enter your homeserver URL', bold=True),
            default=Config.HOMESERVER_URL,
        )
        # Only add a scheme when none was given. Prefixing every value
        # that does not start with https:// would turn "http://host"
        # into the invalid "https://http://host".
        if not homeserver.startswith(('http://', 'https://')):
            homeserver = 'https://' + homeserver

        user_id = click.prompt(
            click.style('Enter your full user ID', bold=True),
            default=f'@{getpass.getuser()}:gaja-group.com',
        )

        device_name = click.prompt(
            click.style('Choose a name for this device', bold=True),
            default='matrix-bot',
        )
    except click.exceptions.Abort:
        sys.exit(1)

    return {
        'homeserver': homeserver,
        'user_id': user_id,
        'device_name': device_name,
    }
2020-08-05 23:13:23 +02:00
# Callbacks
async def __to_device_callback(self, event):  # noqa
    """Handle events sent to device (interactive emoji verification).

    Depending on the type of ``event``:

    * ``KeyVerificationStart`` - accept the verification and share our
      key, provided the other device offers emoji verification.
    * ``KeyVerificationKey`` - show the emojis and ask the user whether
      they match; confirm, reject or cancel accordingly.
    * ``KeyVerificationMac`` - send our MAC to finish the verification.
    * ``KeyVerificationCancel`` - tell the user who cancelled and why.

    Any other event type is logged and ignored. Errors raised while
    handling an event are logged with their traceback instead of being
    propagated; task cancellation is passed on untouched.
    """
    try:
        client = self.client
        if isinstance(event, KeyVerificationStart):  # first step
            # Example event (shortened):
            #   KeyVerificationStart(
            #       sender='@user2:example.org',
            #       transaction_id='SomeTxId',
            #       from_device='DEVICEIDXY',
            #       method='m.sas.v1',
            #       key_agreement_protocols=[
            #           'curve25519-hkdf-sha256', 'curve25519'],
            #       hashes=['sha256'],
            #       message_authentication_codes=[
            #           'hkdf-hmac-sha256', 'hmac-sha256'],
            #       short_authentication_string=['decimal', 'emoji'])
            if 'emoji' not in event.short_authentication_string:
                click.echo(
                    'Other device does not support emoji verification '
                    f'{event.short_authentication_string}.')
                return
            resp = await client.accept_key_verification(
                event.transaction_id)
            if isinstance(resp, ToDeviceError):
                # Logger methods do not take click's styling keywords;
                # passing fg= here raised a TypeError.
                self.logger.warning(
                    'accept_key_verification failed with %s', resp)
            sas = client.key_verifications[event.transaction_id]
            todevice_msg = sas.share_key()
            resp = await client.to_device(todevice_msg)
            if isinstance(resp, ToDeviceError):
                self.logger.warning('to_device failed with %s', resp)
        elif isinstance(event, KeyVerificationCancel):  # anytime
            # Example event (shortened):
            #   KeyVerificationCancel(
            #       sender='@user2:example.org',
            #       transaction_id='SomeTxId',
            #       code='m.mismatched_sas',
            #       reason='Mismatched short authentication string')
            #
            # There is no need to issue a
            # client.cancel_key_verification(tx_id, reject=False)
            # here. The SAS flow is already cancelled.
            # We only need to inform the user.
            click.echo('\nVerification has been cancelled by '
                       f'{event.sender} for reason "{event.reason}".')
        elif isinstance(event, KeyVerificationKey):  # second step
            # Example event (shortened):
            #   KeyVerificationKey(
            #       sender='@user2:example.org',
            #       transaction_id='SomeTxId',
            #       key='SomeCryptoKey')
            click.secho('\nEmoji verification initiated.\n')
            sas = client.key_verifications[event.transaction_id]
            emoji_list = [' '.join(e) for e in sas.get_emoji()]
            click.echo(', '.join(emoji_list))
            print()
            try:
                if click.confirm(
                        click.style('Do the emojis match?', bold=True), ):
                    click.secho(
                        '\nMatch! The verification for this '
                        'device will be accepted.',
                        fg='green',
                    )
                    resp = await client.confirm_short_auth_string(
                        event.transaction_id)
                    if isinstance(resp, ToDeviceError):
                        click.secho(
                            'confirm_short_auth_string failed with '
                            f'{resp}',
                            fg='red',
                        )
                else:  # no, don't match, reject
                    click.secho(
                        '\nNo match! Device will NOT be verified '
                        'by rejecting verification.',
                        fg='yellow',
                    )
                    resp = await client.cancel_key_verification(
                        event.transaction_id, reject=True)
                    if isinstance(resp, ToDeviceError):
                        click.secho(
                            f'cancel_key_verification failed with {resp}',
                            fg='red',
                        )
            except click.exceptions.Abort:  # C or anything for cancel
                click.secho(
                    'Cancelled by user! Verification will be '
                    'cancelled.',
                    fg='red',
                )
                resp = await client.cancel_key_verification(
                    event.transaction_id, reject=False)
                if isinstance(resp, ToDeviceError):
                    self.logger.warning(
                        'cancel_key_verification failed with %s', resp)
        elif isinstance(event, KeyVerificationMac):  # third step
            # Example event (shortened):
            #   KeyVerificationMac(
            #       sender='@user2:example.org',
            #       transaction_id='SomeTxId',
            #       mac={'ed25519:DEVICEIDXY': 'SomeKey1',
            #            'ed25519:SomeKey2': 'SomeKey3'},
            #       keys='SomeCryptoKey4')
            sas = client.key_verifications[event.transaction_id]
            try:
                todevice_msg = sas.get_mac()
            except LocalProtocolError as e:
                # e.g. it might have been cancelled by ourselves
                click.secho(
                    f'Cancelled or protocol error: Reason: {e}.\n'
                    f'Verification with {event.sender} not '
                    'concluded. Try again?',
                    fg='yellow',
                )
            else:
                resp = await client.to_device(todevice_msg)
                if isinstance(resp, ToDeviceError):
                    self.logger.warning('to_device failed with %s', resp)
                click.secho(
                    'Emoji verification was successful! Please use Ctrl+C '
                    'to exit.',
                    fg='green')
        else:
            self.logger.warning(
                'Received unexpected event type %s. '
                'Event is %s. Event will be ignored.', type(event), event)
    except asyncio.CancelledError:
        # Never swallow cancellation, otherwise the bot's shutdown cannot
        # stop the task this callback runs in.
        raise
    except Exception:
        self.logger.critical(traceback.format_exc())
async def shutdown(self) -> None:
    """Cancel all running tasks and close the matrix client.

    Every other task is cancelled first, then the client (if any) is
    closed, and only then is the calling task itself cancelled, so that
    closing the client cannot be interrupted by our own cancellation.
    """
    self.logger.info('Shutdown Bot')
    # asyncio.Task.all_tasks() was removed in Python 3.9; the module
    # level functions are the supported replacement.
    current = asyncio.current_task()
    for task in asyncio.all_tasks():
        if task is not current:
            task.cancel()

    if getattr(self, 'client', None):
        await self.client.close()

    # Keep the original "cancel everything" behaviour for the caller.
    if current is not None:
        current.cancel()
2020-08-03 18:32:53 +02:00
2020-08-05 23:13:23 +02:00
async def __sync_callback(self, event: any) -> None:
    """Join pending invites on the first sync and store the sync token.

    ``event.next_batch`` is written to ``Config.NEXT_BATCH_PATH`` on
    every call.
    """
    self.logger.debug('Client syncing and saving next batch token')
    if self.__first_sync and self.client.invited_rooms:
        for invited_room in self.client.invited_rooms:
            await self.client.join(invited_room)
    self.__first_sync = False

    with open(Config.NEXT_BATCH_PATH, 'w') as token_file:
        token_file.write(event.next_batch)
2020-08-05 23:13:23 +02:00
async def __invite_callback(self, source: MatrixRoom, sender: any) -> None:
    """Join the room an invite was received for."""
    invited_room_id = source.room_id
    await self.client.join(invited_room_id)
2020-08-06 17:16:43 +02:00
async def __handle_text_message(self, room: MatrixRoom,
                                message: RoomMessageText) -> None:
    """Pass a plain text message to every registered message plugin."""
    self.logger.debug('Handling Text Message %s', message)
    for message_plugin in self.message_plugins.values():
        await message_plugin.on_message(room, message)
async def __handle_command_message(self, room: MatrixRoom,
                                   message: RoomMessageText) -> None:
    """Dispatch a command message to the matching command plugin.

    The second space-separated word of the message body selects the
    plugin. If it is missing or does not name a registered command
    plugin, the ``help`` plugin handles the message.
    """
    self.logger.debug('Handling Command Message %s', message)
    if 'help' not in self.command_plugins:
        # The fallback must be an instance stored in the command registry,
        # created the same way the plugin loader creates plugins. The
        # previous code wrote the plugin class to a non-existent
        # "plugins" attribute.
        self.command_plugins['help'] = all_plugins['help'](self, 'help')

    body = message.body.split(' ')
    plugin = self.command_plugins['help']
    if len(body) > 1 and body[1] in self.command_plugins:
        plugin = self.command_plugins[body[1]]
        self.logger.debug('Handling Command %s', body[1])
    await plugin.on_command(room, message)
2020-08-05 23:13:23 +02:00
async def __text_message_callback(self, source: MatrixRoom,
                                  message: RoomMessageText) -> None:
    """Route a text message to command handling or plain text handling.

    Messages whose body starts with the configured chat prefix are
    treated as commands; everything else is a plain text message.
    """
    self.logger.debug('Text Message Recieved: %s %s: %s', source.room_id,
                      message.sender, message.body)
    is_command = message.body.startswith(
        self.config['config']['chat_prefix'])
    handler = (self.__handle_command_message
               if is_command else self.__handle_text_message)
    return await handler(source, message)
2020-08-05 23:13:23 +02:00
async def __message_callback(self, source: MatrixRoom,
                             message: RoomMessage) -> None:
    """Forward text messages to the text handler; ignore other kinds."""
    self.logger.debug('Message Recieved')
    if not isinstance(message, RoomMessageText):
        return
    await self.__text_message_callback(source, message)
def __signal_handler(self) -> None:
    """Schedule the bot's shutdown on its event loop."""
    shutdown_coro = self.shutdown()
    self.loop.create_task(shutdown_coro)
2020-08-06 17:16:43 +02:00
# Plugins
def __load_plugins(self) -> None:
    """Instantiate the configured plugins and register them by hook.

    Names from the configuration that are not present in ``all_plugins``
    are skipped. A plugin is added to each registry whose hook method
    (``on_command``, ``on_message``, ``on_run``) it provides.
    """
    hook_registries = (
        ('on_command', self.command_plugins),
        ('on_message', self.message_plugins),
        ('on_run', self.runtime_plugins),
    )
    for plugin_name in self.config['config']['plugins']:
        if plugin_name not in all_plugins:
            continue
        instance = all_plugins[plugin_name](self, plugin_name)
        self.logger.info('Loading plugin %s', plugin_name)
        for hook, registry in hook_registries:
            if getattr(instance, hook, None):
                registry[plugin_name] = instance
2020-08-05 23:13:23 +02:00
# Files
def __write_details_to_disk(self, resp: LoginResponse,
                            credentials: dict) -> None:
    """Store the login details and default settings in the config file.

    Later logins can then be made without a password.

    Arguments:
    ---------
        resp : LoginResponse - successful client login response
        credentials : dict - The credentials used to sign in
    """
    with open(Config.CONFIG_FILE, 'w') as config_file:
        stored_credentials = {
            'homeserver': credentials['homeserver'],
            'device_name': credentials['device_name'],
            'user_id': resp.user_id,
            'device_id': resp.device_id,
            'access_token': resp.access_token,
        }
        default_settings = {
            'chat_prefix': Config.CHAT_PREFIX,
            'plugins': ['help'],
        }
        payload = {
            'credentials': stored_credentials,
            'config': default_settings,
        }
        json.dump(payload, config_file)
def __read_next_batch(self) -> (str, None):
    """Load a previously stored sync token into the client, if present.

    Returns the client's ``next_batch`` value afterwards.
    """
    token_path = Config.NEXT_BATCH_PATH
    token_stored = os.path.exists(token_path)
    if token_stored:
        # Read the previously written token and tell the client to use it.
        with open(token_path, 'r') as token_file:
            self.client.next_batch = token_file.read()
    return self.client.next_batch
# Properties
@property
def loop(self) -> asyncio.AbstractEventLoop:
    """The event loop captured when the bot was created."""
    return self.__loop
@property
def logger(self) -> logging.Logger:
    """The logger set up for the bot in the constructor."""
    return self.__logger
@property
def config(self) -> dict:
    """The bot configuration; it is assigned during login()."""
    return self.__config
@property
def client_config(self) -> AsyncClientConfig:
    """The client configuration; it is assigned during login()."""
    return self.__client_config
@property
def client(self) -> AsyncClient:
    """The matrix client; None until login() has created it."""
    return self.__client
2020-08-06 17:16:43 +02:00
@property
def message_plugins(self) -> dict:
    """Registry of plugins that provide ``on_message``, keyed by name.

    The registry is a dict (it is created as ``{}`` and filled by plugin
    name), so the return annotation is ``dict`` rather than ``list``.
    """
    return self.__message_plugins
@property
def command_plugins(self) -> dict:
    """Registry of plugins that provide ``on_command``, keyed by name.

    The registry is a dict (it is created as ``{}`` and filled by plugin
    name), so the return annotation is ``dict`` rather than ``list``.
    """
    return self.__command_plugins
@property
def runtime_plugins(self) -> dict:
    """Registry of plugins that provide ``on_run``, keyed by name.

    The registry is a dict (it is created as ``{}`` and filled by plugin
    name), so the return annotation is ``dict`` rather than ``list``.
    """
    return self.__runtime_plugins