import argparse
import asyncio
import collections
import logging
import ssl


class Server:
    """Bridge a byte stream from one *sender* client to *receiver* clients.

    Every connecting client first sends a hello line. A client whose hello is
    ``sender`` has its stream read in chunks into ``self.buffer``, a deque shared
    by all connections. Any other client is treated as a receiver and gets the
    chunks popped from that deque written back to it. An empty ``b""`` chunk in
    the buffer marks the end of the sender's stream.
    """

    def __init__(self, host='localhost', port=5000, buffer_chunk_size=10**4, buffer_length_limit=10**4):
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared queue of bytes
        self._buffer_chunk_size = buffer_chunk_size  # How many bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
        self._working = False  # True while a sender connection is being read
        self._server = None  # Set by run_server()
        self._ssl_context = None  # Optional TLS context passed to start_server

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        return self._working

    @property
    def server(self) -> asyncio.base_events.Server:
        # None until run_server() has started listening.
        return self._server

    @property
    def ssl_context(self) -> ssl.SSLContext:
        return self._ssl_context

    def set_ssl_context(self, ssl_context: ssl.SSLContext):
        self._ssl_context = ssl_context

    async def run_reader(self, reader):
        """Pump chunks from the sender's stream into the shared buffer.

        Reads at most ``buffer_chunk_size`` bytes per call. Pauses while the
        buffer holds ``buffer_length_limit`` chunks or more. When the stream
        reaches EOF, the empty chunk returned by ``reader.read`` is queued once
        as the end-of-transmission marker (``run_writer`` stops when it pops an
        empty chunk) and the method returns.

        Returns early when ``stopping`` is set, or after more than three
        read errors.
        """
        errors = 0
        while not self.stopping:
            # Back-pressure: wait for receivers to drain the buffer, but keep
            # honouring stop() so a full buffer cannot pin this loop forever.
            while len(self.buffer) >= self.buffer_length_limit and not self.stopping:
                await asyncio.sleep(1)
            if self.stopping:
                break
            try:
                input_data = await reader.read(self.buffer_chunk_size)
            except Exception as e:
                # A broken stream tends to raise on every read; bound the
                # retries (like run_writer) instead of spinning forever.
                logging.error(e)
                errors += 1
                if errors > 3:
                    break
                await asyncio.sleep(0.5)
                continue
            self.buffer.append(input_data)
            if not input_data:
                # EOF: the single empty chunk just queued tells the writer to finish.
                break

    async def run_writer(self, writer):
        """Pop chunks from the shared buffer and write them to a receiver.

        Stops when it pops an empty chunk (end of transmission), when
        ``stopping`` is set, after more than three consecutive polls of an
        empty buffer (about 2 s, polling every 0.5 s), or after more than three
        write errors. The writer is always closed on exit.
        """
        consecutive_interruptions = 0
        errors = 0
        try:
            while not self.stopping:
                try:
                    input_data = self.buffer.popleft()
                except IndexError:
                    # Slow down if buffer is short; give up if it stays empty.
                    consecutive_interruptions += 1
                    if consecutive_interruptions > 3:
                        break
                    await asyncio.sleep(.5)
                    continue
                consecutive_interruptions = 0
                if not input_data:
                    break
                try:
                    writer.write(input_data)
                    await writer.drain()
                except Exception as e:
                    logging.error(e)
                    errors += 1
                    if errors > 3:
                        break
                    await asyncio.sleep(0.5)
        finally:
            writer.close()

    async def connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Connect with client.

        Read the client's hello line to decide whether it is the sender (hello
        ``sender``, surrounding whitespace and CR/LF ignored) or a receiver.
        Then reply ``Start!`` and run the matching transfer loop. The
        connection's writer is closed when the transfer ends or fails.
        """
        try:
            client_hello = await reader.readline()
            peer_is_sender = client_hello.decode('utf-8').strip() == 'sender'
            writer.write("Start!\n".encode('utf-8'))  # Send start signal to client
            await writer.drain()

            if peer_is_sender:
                self._working = True
                logging.info("Sender is connecting...")
                try:
                    await self.run_reader(reader=reader)
                finally:
                    # Only the sender's own lifetime controls the working flag.
                    self._working = False
                logging.info("Incoming transmission ended")
            else:
                logging.info("Receiver is connecting...")
                await self.run_writer(writer=writer)
                logging.info("Outgoing transmission ended")
        finally:
            # Closing is idempotent; run_writer may already have closed it.
            writer.close()

    def run(self):
        """Run the server on a fresh event loop until interrupted (Ctrl+C)."""
        # asyncio.get_event_loop() is deprecated when no loop is running;
        # create and install one explicitly instead.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        logging.info("Starting file bridging server...")
        try:
            loop.run_until_complete(self.run_server())
        except KeyboardInterrupt:
            logging.info("Stopping...")
            # Cancel connection tasks (they should be done but are pending)
            # and let them run to completion so their cleanup executes.
            pending = asyncio.all_tasks(loop)
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            # The interrupt may arrive before start_server() returned.
            if self.server is not None:
                self.server.close()
                loop.run_until_complete(self.server.wait_closed())
        finally:
            loop.close()
            asyncio.set_event_loop(None)
        logging.info("Stopped.")

    async def run_server(self):
        """Listen on ``host``:``port`` (TLS when an SSL context is set) and serve until cancelled."""
        self._server = await asyncio.start_server(
            ssl=self.ssl_context,
            client_connected_cb=self.connect,
            host=self.host,
            port=self.port,
        )
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")
                self.server.close()
                await self.server.wait_closed()

    def stop(self, *_):
        """Request a graceful stop of the running transfer loops.

        Extra positional arguments are accepted and ignored, so this can be used
        as a callback that receives arguments (e.g. a signal handler; no such
        registration is visible in this file). If no sender is currently being
        read, or a stop is already in progress, raises KeyboardInterrupt.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")


if __name__ == '__main__':
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server', allow_abbrev=False)
    parser.add_argument('--host', type=str, default=None, required=False, help='server address')
    parser.add_argument('--port', type=int, default=None, required=False, help='server port')
    args = vars(parser.parse_args())
    _host = args['host']
    _port = args['port']

    # If _host and _port are not provided from command-line, try to import them
    # from an optional local `config` module.
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None

    # If import fails, prompt user for _host or _port.
    # Empty host input and out-of-range ports are rejected and re-prompted.
    while not _host:
        _host = input("Enter host:\t\t\t\t\t\t").strip()
    while _port is None:
        try:
            _port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            logging.info("Invalid port. Enter a valid port number!")
            _port = None
            continue
        if not 0 < _port <= 65535:
            logging.info("Invalid port. Enter a valid port number!")
            _port = None

    server = Server(
        host=_host,
        port=_port,
    )

    # Enable TLS when the config module provides certificate and key paths.
    try:
        from config import certificate, key
        _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        _ssl_context.check_hostname = False
        _ssl_context.load_cert_chain(certificate, key)
        server.set_ssl_context(_ssl_context)
    except ImportError:
        logging.info("Please consider using SSL.")

    server.run()