# filebridging/src/server.py
#
# (Source-viewer metadata: 188 lines, 5.9 KiB, Python.)

import argparse
import asyncio
import collections
import logging
class Server:
    """Relay a byte stream from one TCP client (sender) to another (receiver).

    The first client that connects while no transfer is in progress is
    treated as the sender: its data is read in chunks into a shared
    in-memory buffer. A client connecting while a transfer is in progress is
    treated as the receiver: buffered chunks are written to it and, once the
    sender has reached EOF and the buffer is drained, an EOF is sent to it.
    """

    def __init__(self, host='localhost', port=5000,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared queue of bytes chunks
        self._buffer_chunk_size = buffer_chunk_size  # How many bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
        self._working = False  # True while a transfer is in progress
        self.at_eof = False  # Set once the sender's stream has ended
        self._server = None

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        return self._server

    async def run_reader(self, reader):
        """Read chunks from `reader` into the shared buffer until EOF.

        Reading pauses while the buffer already holds `buffer_length_limit`
        chunks. The loop ends when the peer's stream reaches EOF (and
        `at_eof` is then set), when `stop()` has been requested, or when a
        read fails (the error is logged).
        """
        while not self.stopping:
            # Back-pressure: wait for the receiver to drain a full buffer
            while (len(self.buffer) >= self.buffer_length_limit
                   and not self.stopping):
                await asyncio.sleep(1)
            if self.stopping:
                break
            try:
                input_data = await reader.read(self.buffer_chunk_size)
            except Exception as e:
                # Do not retry a broken connection forever
                logging.error("Error while reading from sender: %s", e)
                break
            if input_data:
                self.buffer.append(input_data)
            # read() returns b'' only at EOF; either signal ends the transfer
            if not input_data or reader.at_eof():
                self.at_eof = True
                break

    async def run_writer(self, writer):
        """Write buffered chunks to `writer` until the sender's data is exhausted.

        When the buffer is empty and the sender has reached EOF, an EOF is
        written to `writer` and `at_eof` is reset for the next transfer. The
        loop also ends when `stop()` has been requested or when a write
        fails (the error is logged).
        """
        while not self.stopping:
            try:
                # Slow down if buffer is short
                if len(self.buffer) < 3:
                    await asyncio.sleep(.1)
                try:
                    input_data = self.buffer.popleft()
                except IndexError:
                    if not self.at_eof:
                        continue  # Sender still transmitting: wait for data
                    writer.write_eof()
                    await writer.drain()
                    self.at_eof = False
                    break
                writer.write(input_data)
                await writer.drain()
            except Exception as e:
                # Do not retry a broken connection forever
                logging.error("Error while writing to receiver: %s", e)
                break

    async def connect(self,
                      reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
        """Connect with client.

        Decide whether client is sender or receiver and start transmission.

        A client arriving while no transfer is in progress is the sender;
        any other client is the receiver. The transfer slot (`working`) is
        released after the receiver's session ends, or as soon as a sender
        disconnects without having reached EOF. The client's writer is
        always closed before returning.
        """
        peer_is_sender = not self.working  # TODO: ask peer role
        self._working = True
        try:
            if peer_is_sender:
                logging.info("Sender is connecting...")
                await self.run_reader(reader=reader)
                logging.info("Incoming transmission ended")
            else:
                logging.info("Receiver is connecting...")
                await self.run_writer(writer=writer)
                logging.info("Outgoing transmission ended")
        finally:
            # A sender that reached EOF leaves its data buffered for a
            # receiver: keep the slot busy so the next client is treated
            # as the receiver rather than as another sender.
            if not (peer_is_sender and self.at_eof):
                self._working = False  # Reset peer_is_sender
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                logging.debug("Error while closing connection: %s", e)
        return

    async def run_server(self):
        """Start listening on (host, port) and serve clients until interrupted."""
        self._server = await asyncio.start_server(
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port
        )
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")
        self.server.close()
        await self.server.wait_closed()
        return

    def stop(self, *_):
        """Request a graceful stop of the running transfer loops.

        Extra positional arguments are accepted and ignored. If no transfer
        is in progress (or a stop was already requested), KeyboardInterrupt
        is raised instead.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")
if __name__ == '__main__':
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server',
                                     allow_abbrev=False)
    parser.add_argument('--host', type=str,
                        default=None,
                        required=False,
                        help='server address')
    parser.add_argument('--port', type=int,
                        default=None,
                        required=False,
                        help='server port')
    args = vars(parser.parse_args())
    _host = args['host']
    _port = args['port']

    # If _host and _port are not provided from command-line, try to import them
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None

    # If import fails, prompt user for _host or _port
    while _host is None:
        # An empty answer would make start_server bind to every interface,
        # so treat it as "no answer" and ask again.
        _host = input("Enter host:\t\t\t\t\t\t").strip() or None
    while _port is None:
        try:
            _port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            _port = None
        # Reject out-of-range ports here instead of failing in start_server
        if _port is None or not 0 < _port < 65536:
            logging.info("Invalid port. Enter a valid port number!")
            _port = None

    server = Server(
        host=_host,
        port=_port,
    )
    logging.info("Starting file bridging server...")
    try:
        # asyncio.run creates, cancels leftover tasks on exit, and closes
        # the event loop; get_event_loop() without a running loop is
        # deprecated.
        asyncio.run(server.run_server())
    except KeyboardInterrupt:
        logging.info("Stopping...")