filebridging/src/server.py
2020-04-10 10:18:24 +02:00

224 lines
7.1 KiB
Python

import argparse
import asyncio
import collections
import logging
import ssl
class Server:
    """Relay bytes from a "sender" client to a "receiver" client.

    Data read from the sender is queued chunk by chunk in an in-memory
    ``collections.deque`` (``self.buffer``) and written out to the receiver.
    The buffer is shared by every connection handled by this instance, so
    it is only meaningful with one sender and one receiver at a time.
    An empty chunk (``b''``) in the buffer marks the end of the stream.
    """

    def __init__(self, host='localhost', port=5000,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
        """Store configuration; no network activity happens here.

        :param host: address passed to ``asyncio.start_server``.
        :param port: port passed to ``asyncio.start_server``.
        :param buffer_chunk_size: maximum number of bytes per ``read`` call.
        :param buffer_length_limit: number of queued chunks at which reading
            from the sender pauses until the receiver catches up.
        """
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared queue of byte chunks
        self._buffer_chunk_size = buffer_chunk_size  # How many bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
        self._working = False
        self._server = None
        self._ssl_context = None

    @property
    def host(self) -> str:
        """Address the server binds to."""
        return self._host

    @property
    def port(self) -> int:
        """Port the server binds to."""
        return self._port

    @property
    def stopping(self) -> bool:
        """True once ``stop`` has requested the transfer loops to end."""
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        """Maximum number of queued chunks before the reader pauses."""
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        """Maximum number of bytes requested per read."""
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        """True from sender connection until the receiver finishes."""
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        """The listening server, or None before ``run_server`` starts it."""
        return self._server

    @property
    def ssl_context(self) -> ssl.SSLContext:
        """TLS context used by ``run_server``, or None for plain TCP."""
        return self._ssl_context

    def set_ssl_context(self, ssl_context: ssl.SSLContext):
        """Set the TLS context to use for the next ``run_server`` call."""
        self._ssl_context = ssl_context

    async def run_reader(self, reader):
        """Copy chunks from the sender ``reader`` into the shared buffer.

        Reading pauses while the buffer holds ``buffer_length_limit`` chunks
        or more. At end of stream a single empty chunk is appended as the
        EOF marker (``run_writer`` stops on it) and the method returns.
        Gives up after more than 3 read errors or when ``stop`` is called.
        """
        errors = 0
        while not self.stopping:
            try:
                # Back-pressure: wait for the receiver to drain the buffer,
                # but stop waiting as soon as a stop is requested (otherwise
                # a full buffer with no receiver would block forever).
                while (len(self.buffer) >= self.buffer_length_limit
                       and not self.stopping):
                    await asyncio.sleep(1)
                if self.stopping:
                    break
                input_data = await reader.read(self.buffer_chunk_size)
                self.buffer.append(input_data)
                if not input_data:
                    # EOF: the empty chunk just queued tells the writer to
                    # stop; reading again would only queue more of them.
                    break
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError is an Exception subclass;
                # never swallow cancellation.
                raise
            except Exception as e:
                logging.error(e)
                errors += 1
                if errors > 3:
                    break
                await asyncio.sleep(0.5)

    async def run_writer(self, writer):
        """Send queued chunks to the receiver ``writer`` until EOF.

        Stops on the empty EOF chunk, when the buffer stays empty for more
        than 3 consecutive polls (0.5 s apart), after more than 3 write
        errors, or when ``stop`` is called. The writer is always closed.
        """
        consecutive_interruptions = 0
        errors = 0
        try:
            while not self.stopping:
                try:
                    try:
                        input_data = self.buffer.popleft()
                    except IndexError:
                        # Slow down if buffer is short
                        consecutive_interruptions += 1
                        if consecutive_interruptions > 3:
                            break
                        await asyncio.sleep(.5)
                        continue
                    consecutive_interruptions = 0
                    if not input_data:
                        break  # EOF marker queued by run_reader
                    writer.write(input_data)
                    await writer.drain()
                except asyncio.CancelledError:
                    raise  # Exception subclass before Python 3.8
                except Exception as e:
                    logging.error(e)
                    errors += 1
                    if errors > 3:
                        break
                    await asyncio.sleep(0.5)
        finally:
            writer.close()

    async def connect(self,
                      reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
        """Connect with client.

        Decide whether client is sender or receiver and start transmission.
        A first line of exactly ``sender\\n`` makes the client the sender.
        Any other first line makes it a receiver. The server acknowledges
        with ``Start!\\n``. The client's connection is closed when its side
        of the transfer ends, fails or is cancelled.
        """
        try:
            client_hello = await reader.readline()
            peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
            writer.write("Start!\n".encode('utf-8'))  # Send start signal to client
            await writer.drain()
            if peer_is_sender:
                self._working = True
                logging.info("Sender is connecting...")
                await self.run_reader(reader=reader)
                logging.info("Incoming transmission ended")
            else:
                logging.info("Receiver is connecting...")
                await self.run_writer(writer=writer)
                logging.info("Outgoing transmission ended")
                self._working = False
        finally:
            # Release the transport in every case. Left open, it leaks, and
            # on Python 3.12+ it also blocks Server.wait_closed() at shutdown.
            # close() is idempotent, so a prior close in run_writer is fine.
            writer.close()

    def run(self):
        """Run the server on a fresh event loop until KeyboardInterrupt."""
        # asyncio.get_event_loop() without a running loop is deprecated
        # (and raises on Python 3.14); create and install a loop explicitly.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        logging.info("Starting file bridging server...")
        try:
            loop.run_until_complete(self.run_server())
        except KeyboardInterrupt:
            logging.info("Stopping...")
            # Cancel connection tasks (they should be done but are pending)
            # and let the loop actually process the cancellations.
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(
                    asyncio.gather(*pending, return_exceptions=True)
                )
            # The interrupt may arrive before start_server has returned.
            if self.server is not None:
                self.server.close()
                loop.run_until_complete(self.server.wait_closed())
        finally:
            loop.close()
        logging.info("Stopped.")

    async def run_server(self):
        """Start listening and serve clients until cancelled."""
        self._server = await asyncio.start_server(
            ssl=self.ssl_context,
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port,
        )
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")
                self.server.close()
                await self.server.wait_closed()
        return

    def stop(self, *_):
        """Request the transfer loops to end.

        Extra positional arguments are ignored; the ``*_`` signature looks
        intended for use as a signal handler (not wired up here, confirm).

        :raises KeyboardInterrupt: if no transfer is in progress or a stop
            was already requested.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")
if __name__ == '__main__':
    # Log everything (DEBUG and up) to the console.
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server',
                                     allow_abbrev=False)
    parser.add_argument('--host', type=str,
                        default=None,
                        required=False,
                        help='server address')
    parser.add_argument('--port', type=int,
                        default=None,
                        required=False,
                        help='server port')
    args = vars(parser.parse_args())
    _host = args['host']
    _port = args['port']

    # If _host and _port are not provided from command-line, try to import
    # them from an optional local `config` module.
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None

    # If import fails, prompt user for _host or _port
    while _host is None:
        _host = input("Enter host:\t\t\t\t\t\t")
    while _port is None:
        try:
            _port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            _port = None
        else:
            # int() accepts any integer; reject values bind() cannot use
            # instead of failing later when the server starts.
            if not 1 <= _port <= 65535:
                _port = None
        if _port is None:
            logging.info("Invalid port. Enter a valid port number!")

    server = Server(
        host=_host,
        port=_port,
    )

    # Enable TLS only if `config` provides `certificate` and `key`; the
    # try covers just the import so real SSL errors are not masked.
    try:
        from config import certificate, key
    except ImportError:
        logging.info("Please consider using SSL.")
    else:
        _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        _ssl_context.check_hostname = False
        _ssl_context.load_cert_chain(certificate, key)
        server.set_ssl_context(_ssl_context)
    server.run()