2020-04-19 19:54:52 +02:00

392 lines
14 KiB
Python

"""Server class.
May be a local server or a publicly reachable server.
Arguments
- host: localhost, IPv4 address or domain (e.g. www.example.com)
- port: port to reach (must be enabled)
- certificate [optional]: server certificate for SSL
- key [optional]: needed only for standalone clients
"""
import argparse
import asyncio
import collections
import logging
import os
import ssl
from typing import Union
class Server:
    """File bridging server relaying data from a sender to a receiver.

    Each client opens the connection with a `client_hello` line made of four
    `|`-separated fields: role (`s` for sender, `r` for receiver), connection
    token, file name and file size. A sender and a receiver presenting the
    same token are bridged: chunks read from the sender are queued in
    `buffers[token]` and written to the receiver.
    """

    def __init__(self, host='localhost', port=5000, ssl_context=None,
                 buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
        """Store the listening address, SSL context and buffering limits.

        Arguments
        - host: address to listen on
        - port: port to listen on
        - ssl_context [optional]: `ssl.SSLContext` passed to asyncio
        - buffer_chunk_size: maximum number of bytes read per chunk
        - buffer_length_limit: maximum number of queued chunks (summed over
            all connections) before readers pause
        """
        self._host = host
        self._port = port
        self._ssl_context = ssl_context
        # Per-token state: dict with `sender` / `receiver` flags and, once
        # the sender has joined, `file_name` / `file_size`
        self.connections = collections.OrderedDict()
        # Dict of queues of bytes
        self.buffers = collections.OrderedDict()
        # How many bytes per chunk
        self._buffer_chunk_size = buffer_chunk_size
        # How many chunks in buffer
        self._buffer_length_limit = buffer_length_limit
        self._working = False
        self._server = None

    @property
    def host(self) -> str:
        """Address the server listens on."""
        return self._host

    @property
    def port(self) -> int:
        """Port the server listens on."""
        return self._port

    @property
    def buffer_length_limit(self) -> int:
        """Maximum total number of queued chunks."""
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        """Maximum number of bytes read from a sender at once."""
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        """Value of the internal `_working` flag."""
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        """Underlying asyncio server (None until `run_server` starts it)."""
        return self._server

    @property
    def ssl_context(self) -> ssl.SSLContext:
        """SSL context handed to asyncio, or None for plain connections."""
        return self._ssl_context

    @property
    def buffer_is_full(self):
        """True when the chunks queued in all buffers reach the limit."""
        queued_chunks = sum(len(buffer) for buffer in self.buffers.values())
        return queued_chunks >= self.buffer_length_limit

    def set_ssl_context(self, ssl_context: ssl.SSLContext):
        """Replace the SSL context used by `run_server`."""
        self._ssl_context = ssl_context

    async def run_reader(self, reader: asyncio.StreamReader, connection_token):
        """Read chunks from the sender and queue them for the receiver.

        Stops at end of stream (queueing the final empty chunk as an
        end-of-stream marker for `run_writer`), when the connection buffer
        has been removed, on connection reset, or after repeated
        unexpected errors.
        """
        received_bytes = 0
        errors = 0
        while True:
            try:
                # Wait one second if buffer is full
                while self.buffer_is_full:
                    await asyncio.sleep(1)
                input_data = await reader.read(self.buffer_chunk_size)
                received_bytes += len(input_data)
                if connection_token not in self.buffers:
                    break
                self.buffers[connection_token].append(input_data)
                if not input_data:
                    # `read` returns b'' only at EOF and would keep doing so
                    # without ever yielding: stop instead of flooding the
                    # buffer with empty chunks
                    break
            except ConnectionResetError as e:
                logging.error(e)
                break
            except Exception as e:
                logging.error(f"Unexpected exception:\n{e}", exc_info=True)
                # A stored stream exception is raised again on every read
                # without yielding to the loop: give up after a few attempts
                errors += 1
                if errors > 3:
                    break
                await asyncio.sleep(0.5)
        print(received_bytes)

    async def run_writer(self, writer, connection_token):
        """Forward queued chunks to the receiver, then close `writer`.

        Stops on the empty end-of-stream chunk, when the connection buffer
        has been removed, after 1.5 s without data, on connection reset or
        after repeated unexpected errors.
        """
        consecutive_interruptions = 0
        errors = 0
        while connection_token in self.buffers:
            try:
                input_data = self.buffers[connection_token].popleft()
            except IndexError:
                # Slow down if buffer is empty; after 1.5 s of silence, break
                consecutive_interruptions += 1
                if consecutive_interruptions > 3:
                    print("Too many interruptions")
                    break
                await asyncio.sleep(.5)
                continue
            consecutive_interruptions = 0
            if not input_data:
                # Empty chunk: the sender reached end of stream
                break
            try:
                writer.write(input_data)
                await writer.drain()
            except ConnectionResetError as e:
                logging.error(e)
                break
            except Exception as e:
                logging.error(e, exc_info=True)
                errors += 1
                if errors > 3:
                    break
                await asyncio.sleep(0.5)
        writer.close()

    async def _wait_for_peer(self, connection_token, peer, notify) -> bool:
        """Wait until `peer` (`'sender'` or `'receiver'`) joins the token.

        While waiting, send "Waiting for <peer>..." through `notify` at
        growing intervals. Return True once the peer has joined, False if
        `notify` failed or the connection was dropped in the meantime.
        """
        index, step = 0, 1
        while True:
            connection = self.connections.get(connection_token)
            if connection is None:
                # The peer disconnected this token while we were waiting
                return False
            if connection[peer]:
                return True
            index += 1
            if index >= step:
                if await notify(f"Waiting for {peer}..."):
                    return False
                step += 1
                index = 0
            await asyncio.sleep(.5)

    async def connect(self,
                      reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
        """Connect with client.

        Decide whether client is sender or receiver and start transmission.
        The client connection is always closed when handling ends.
        """
        try:
            await self._serve_client(reader=reader, writer=writer)
        finally:
            # Closing an already closed writer is harmless; this covers
            # every early-return path so no client socket is left open
            writer.close()

    async def _serve_client(self,
                            reader: asyncio.StreamReader,
                            writer: asyncio.StreamWriter):
        """Parse the client hello and run the sender or receiver dialogue."""
        client_hello = await reader.readline()
        try:
            client_hello = client_hello.decode('utf-8').strip('\n').split('|')
        except UnicodeDecodeError:
            logging.error("Invalid client hello.")
            return
        # Validate before registering the token, so that invalid clients do
        # not leave stale entries in `self.connections`
        if len(client_hello) != 4 or client_hello[0] not in ('s', 'r'):
            await self.refuse_connection(writer=writer,
                                         message="Invalid client_hello!")
            return
        role, connection_token = client_hello[0], client_hello[1]
        if connection_token not in self.connections:
            self.connections[connection_token] = dict(
                sender=False,
                receiver=False
            )

        async def _write(message: Union[list, str, bytes],
                         terminate_line=True) -> int:
            """Send `message` to the client; return 0 on success, else 1.

            Lists are joined with `|`; strings are encoded as UTF-8 (with a
            trailing newline if `terminate_line`). On failure the connection
            token is disconnected.
            """
            # Adapt
            if isinstance(message, list):
                message = '|'.join(map(str, message))
            if isinstance(message, str):
                if terminate_line:
                    message += '\n'
                message = message.encode('utf-8')
            if not isinstance(message, bytes):
                return 1
            try:
                writer.write(message)
                await writer.drain()
            except ConnectionResetError:
                logging.error("Client disconnected.")
            except Exception as e:
                logging.error(f"Unexpected exception:\n{e}", exc_info=True)
            else:
                return 0  # On success, return 0
            # On exception, disconnect and return 1
            logging.error("Disconnecting...")
            self.disconnect(connection_token=connection_token)
            return 1

        if role == 's':  # Sender client connection
            if self.connections[connection_token]['sender']:
                await self.refuse_connection(
                    writer=writer,
                    message="Invalid token! "
                            "A sender client is already connected!\n"
                )
                return
            self.connections[connection_token]['sender'] = True
            self.connections[connection_token]['file_name'] = client_hello[2]
            self.connections[connection_token]['file_size'] = client_hello[3]
            self.buffers[connection_token] = collections.deque()
            logging.info("Sender is connecting...")
            if not await self._wait_for_peer(
                    connection_token=connection_token,
                    peer='receiver',
                    notify=_write
            ):
                return
            # Send start signal to client
            if await _write("start!"):
                return
            logging.info("Incoming transmission starting...")
            await self.run_reader(reader=reader,
                                  connection_token=connection_token)
            logging.info("Incoming transmission ended")
        else:  # Receiver client connection
            if self.connections[connection_token]['receiver']:
                await self.refuse_connection(
                    writer=writer,
                    message="Invalid token! "
                            "A receiver client is already connected!\n"
                )
                return
            self.connections[connection_token]['receiver'] = True
            logging.info("Receiver is connecting...")
            if not await self._wait_for_peer(
                    connection_token=connection_token,
                    peer='sender',
                    notify=_write
            ):
                return
            # Send file information and start signal to client
            if await _write(
                    ['s',
                     'hidden_token',
                     self.connections[connection_token]['file_name'],
                     self.connections[connection_token]['file_size']]
            ):
                return
            if await _write("start!"):
                return
            await self.run_writer(writer=writer,
                                  connection_token=connection_token)
            logging.info("Outgoing transmission ended")
            self.disconnect(connection_token=connection_token)

    def disconnect(self, connection_token: str) -> None:
        """Forget buffer and connection state of `connection_token`."""
        self.buffers.pop(connection_token, None)
        self.connections.pop(connection_token, None)

    def run(self):
        """Run the server until interrupted, then shut down and close loop."""
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Recent Python versions do not create a loop implicitly
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        logging.info("Starting file bridging server...")
        try:
            loop.run_until_complete(self.run_server())
        except KeyboardInterrupt:
            print()
            logging.info("Stopping...")
            # Cancel connection tasks (they should be done but are pending)
            for task in asyncio.all_tasks(loop):
                task.cancel()
            # The server is None if the interruption came before start-up
            if self.server is not None:
                self.server.close()
                loop.run_until_complete(
                    self.server.wait_closed()
                )
        loop.close()
        logging.info("Stopped.")

    async def run_server(self):
        """Start listening on `host:port` and serve clients forever."""
        self._server = await asyncio.start_server(
            ssl=self.ssl_context,
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port,
        )
        async with self.server:
            logging.info("Running at `{s.host}:{s.port}`".format(s=self))
            await self.server.serve_forever()

    @staticmethod
    async def refuse_connection(writer: asyncio.StreamWriter,
                                message: str = None):
        """Send a `message` via writer and close it.

        The writer is closed even if the client is already gone and the
        message cannot be delivered.
        """
        if message is None:
            message = "Connection refused!\n"
        try:
            writer.write(
                message.encode('utf-8')
            )
            await writer.drain()
        except ConnectionError as e:
            logging.error(e)
        finally:
            writer.close()
def _existing_file(path):
    """Return `path` if it points to an existing file, otherwise None.

    Values that are not paths (e.g. `None`) yield None instead of making
    `os.path.isfile` raise `TypeError`.
    """
    if isinstance(path, (str, bytes, os.PathLike)) and os.path.isfile(path):
        return path
    return None


def main():
    """Configure logging, gather settings and run the server.

    Host, port, certificate and key are taken from the command line;
    missing values are looked up in the optional `config` module. Host and
    port are prompted for as a last resort. SSL is enabled only when both
    certificate and key point to existing files.
    """
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    # Silence asyncio selector messages below ERROR level
    # noinspection PyUnresolvedReferences
    asyncio.selector_events.logger.setLevel(logging.ERROR)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    cli_parser = argparse.ArgumentParser(description='Run server',
                                         allow_abbrev=False)
    cli_parser.add_argument('--host', type=str,
                            default=None,
                            required=False,
                            help='server address')
    cli_parser.add_argument('--port', type=int,
                            default=None,
                            required=False,
                            help='server port')
    cli_parser.add_argument('--certificate', type=str,
                            default=None,
                            required=False,
                            help='server SSL certificate')
    cli_parser.add_argument('--key', type=str,
                            default=None,
                            required=False,
                            help='server SSL key')
    args = vars(cli_parser.parse_args())
    host = args['host']
    port = args['port']

    # If host and port are not provided from command-line, try to import them
    if host is None:
        try:
            from config import host
        except ImportError:
            host = None
    if port is None:
        try:
            from config import port
        except ImportError:
            port = None

    # If import fails, prompt user for host or port
    while host is None:
        host = input("Enter host:\t\t\t\t\t\t")
    while port is None:
        try:
            port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            logging.info("Invalid port. Enter a valid port number!")
            port = None

    # Certificate and key: command-line paths win if they exist; otherwise
    # fall back to `config`. Anything that is not an existing file is dropped.
    certificate = _existing_file(args['certificate'])
    if certificate is None:
        try:
            from config import certificate
        except ImportError:
            certificate = None
        certificate = _existing_file(certificate)
    key = _existing_file(args['key'])
    if key is None:
        try:
            from config import key
        except ImportError:
            key = None
        key = _existing_file(key)

    ssl_context = None
    if certificate and key:
        # CLIENT_AUTH is the purpose for server-side sockets
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(certificate, key)
    else:
        logging.warning(
            "Please consider using SSL. To do so, add in `config.py` or "
            "provide via Command Line Interface the path to a valid SSL "
            "key and certificate. Example:\n\n"
            "key = 'path/to/secret.key'\n"
            "certificate = 'path/to/certificate.crt'"
        )

    server = Server(
        host=host,
        port=port,
        ssl_context=ssl_context
    )
    server.run()
# Start the server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()