188 lines
5.9 KiB
Python
188 lines
5.9 KiB
Python
import argparse
|
|
import asyncio
|
|
import collections
|
|
import logging
|
|
|
|
|
|
class Server:
    """Relay bytes from a "sender" client to a "receiver" client.

    A client that connects while the server is idle is treated as the
    sender: its stream is read in chunks into the shared deque
    ``self.buffer``. A client that connects while a transfer is in progress
    is treated as the receiver: chunks are popped from ``self.buffer`` and
    written to it, followed by an EOF once the sender's stream has ended.
    """

    def __init__(self, host='localhost', port=5000,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4):
        self._host = host
        self._port = port
        self._stopping = False  # Set by stop(); never cleared afterwards
        self.buffer = collections.deque()  # Shared queue of byte chunks
        self._buffer_chunk_size = buffer_chunk_size  # Max bytes per read
        self._buffer_length_limit = buffer_length_limit  # Max chunks buffered
        self._working = False  # True once a client connects; reset when a handler ends
        self.at_eof = False  # Set when the sender's stream hits EOF; cleared by the writer
        self._server = None  # Assigned by run_server()

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        # None until run_server() has started listening.
        return self._server

    async def run_reader(self, reader):
        """Read chunks from the sender into ``self.buffer`` until EOF.

        While the buffer holds ``buffer_length_limit`` chunks or more,
        reading pauses and the buffer is re-checked once per second.
        Empty reads are not buffered. When the sender's stream reaches
        EOF, ``self.at_eof`` is set and the method returns so the receiver
        side can finish. A read error is logged and ends the transfer;
        it is not retried in a loop.
        """
        while not self.stopping:
            # Back-pressure: wait for the receiver to drain the buffer.
            while (len(self.buffer) >= self.buffer_length_limit
                   and not self.stopping):
                await asyncio.sleep(1)
            if self.stopping:
                break
            try:
                input_data = await reader.read(self.buffer_chunk_size)
            except Exception:
                logging.exception("Error while reading from sender")
                break
            if input_data:
                self.buffer.append(input_data)
            if reader.at_eof():
                # Without this break, read() keeps returning b'' forever.
                self.at_eof = True
                break

    async def run_writer(self, writer):
        """Write buffered chunks to the receiver, then send EOF.

        Sleeps 0.1 s before each pop whenever fewer than three chunks are
        queued. Once the buffer is empty and ``self.at_eof`` is set, writes
        EOF to the receiver, clears ``self.at_eof`` and returns. A write
        error is logged and ends the transfer; the chunk being written when
        the error occurred is dropped.
        """
        while not self.stopping:
            # Slow down if buffer is short
            if len(self.buffer) < 3:
                await asyncio.sleep(.1)
            try:
                input_data = self.buffer.popleft()
            except IndexError:
                if not self.at_eof:
                    continue  # Sender still transmitting; poll again.
                try:
                    writer.write_eof()
                    await writer.drain()
                except Exception:
                    logging.exception("Error while sending EOF to receiver")
                self.at_eof = False
                break
            try:
                writer.write(input_data)
                await writer.drain()
            except Exception:
                logging.exception("Error while writing to receiver")
                break

    async def connect(self,
                      reader: asyncio.StreamReader,
                      writer: asyncio.StreamWriter):
        """Connect with client.

        Decide whether client is sender or receiver and start transmission.
        The first client to connect while idle is the sender; a client that
        connects while a transfer is in progress is the receiver. The
        client's connection is closed when its transfer ends, however it
        ends.
        """
        peer_is_sender = not self.working  # TODO: ask peer role
        self._working = True
        try:
            if peer_is_sender:
                logging.info("Sender is connecting...")
                await self.run_reader(reader=reader)
                logging.info("Incoming transmission ended")
            else:
                logging.info("Receiver is connecting...")
                await self.run_writer(writer=writer)
                logging.info("Outgoing transmission ended")
        finally:
            self._working = False  # Reset peer_is_sender
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                # The peer may already have dropped the connection.
                logging.debug(e)

    async def run_server(self):
        """Listen on ``host``/``port`` and serve clients until interrupted.

        Leaving the ``async with`` block closes the listening server and
        waits for it to finish closing.
        """
        self._server = await asyncio.start_server(
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port,
        )
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")

    def stop(self, *_):
        """Request a graceful stop of the active transfer.

        If a transfer is in progress and no stop was requested yet, sets
        the ``stopping`` flag polled by the reader/writer loops. Otherwise
        raises ``KeyboardInterrupt``. The extra positional arguments are
        ignored (presumably to allow use as a signal handler; confirm).
        NOTE(review): ``stopping`` is never cleared, so later connections
        end immediately after a stop.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")
|
|
|
|
|
|
if __name__ == '__main__':
    # Log everything (DEBUG and up) to the console.
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server',
                                     allow_abbrev=False)
    parser.add_argument('--host', type=str,
                        default=None,
                        required=False,
                        help='server address')
    parser.add_argument('--port', type=int,
                        default=None,
                        required=False,
                        help='server port')
    args = vars(parser.parse_args())
    _host = args['host']
    _port = args['port']

    # If _host and _port are not provided from command-line, try to import
    # them from a local `config` module.
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None

    # If import fails, prompt user for _host or _port
    while _host is None:
        _host = input("Enter host:\t\t\t\t\t\t")
    while _port is None:
        try:
            _port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            _port = None
        else:
            # TCP ports are 16-bit; out-of-range values would only fail
            # later inside start_server().
            if not 0 <= _port <= 65535:
                _port = None
        if _port is None:
            logging.info("Invalid port. Enter a valid port number!")

    server = Server(
        host=_host,
        port=_port,
    )
    logging.info("Starting file bridging server...")
    try:
        # asyncio.run() creates, runs and closes its own event loop and
        # cancels leftover connection tasks; asyncio.get_event_loop() without
        # a running loop is deprecated since Python 3.10.
        asyncio.run(server.run_server())
    except KeyboardInterrupt:
        logging.info("Stopping...")
|