"""File-bridging TCP server.

The first client to connect is treated as the sender: its bytes are read in
chunks into an in-memory queue. The next client is treated as the receiver:
queued chunks are written to it, followed by EOF once the sender has finished.
"""
import argparse
import asyncio
import collections
import logging
from typing import Optional


class Server:
    """Relay a byte stream from one connected client (sender) to another (receiver).

    Data is held in ``buffer``, a deque of ``bytes`` chunks shared by the
    reader coroutine of the sender connection and the writer coroutine of the
    receiver connection.
    """

    def __init__(self, host='localhost', port=5000, buffer_chunk_size=10**4,
                 buffer_length_limit=10**4):
        """Configure the server; no socket is opened until ``run_server``.

        :param host: address to bind to.
        :param port: TCP port to bind to.
        :param buffer_chunk_size: maximum number of bytes read per chunk.
        :param buffer_length_limit: maximum number of chunks held in ``buffer``
            before the reader waits for the writer to drain it.
        """
        self._host = host
        self._port = port
        self._stopping = False
        self.buffer = collections.deque()  # Shared queue of bytes chunks
        self._buffer_chunk_size = buffer_chunk_size  # How many bytes per chunk
        self._buffer_length_limit = buffer_length_limit  # How many chunks in buffer
        # True once a sender has connected; cleared when the receiver finishes.
        self._working = False
        # Set by the reader at sender EOF; cleared by the writer after forwarding EOF.
        self.at_eof = False
        self._server = None

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def working(self) -> bool:
        return self._working

    @property
    def server(self) -> Optional[asyncio.base_events.Server]:
        """The listening server, or None before ``run_server`` has started it."""
        return self._server

    async def run_reader(self, reader: asyncio.StreamReader):
        """Read chunks from the sender into ``buffer`` until EOF, error or stop.

        Sets ``at_eof`` once the sender's stream is exhausted. Empty chunks are
        never queued.
        """
        while not self.stopping:
            # Back-pressure: wait while the buffer holds the maximum number of chunks.
            if len(self.buffer) >= self.buffer_length_limit:
                await asyncio.sleep(1)
                continue
            try:
                input_data = await reader.read(self.buffer_chunk_size)
            except Exception:
                # A failed stream re-raises on every read, so retrying would
                # spin without ever yielding to the event loop.
                logging.exception("Error while reading from sender")
                break
            if input_data:
                self.buffer.append(input_data)
            if reader.at_eof():
                # The last chunk is queued before the flag is raised, so the
                # writer never sees EOF while data is still pending.
                self.at_eof = True
                break

    async def run_writer(self, writer: asyncio.StreamWriter):
        """Forward chunks from ``buffer`` to the receiver, then send EOF.

        Returns after EOF has been forwarded (clearing ``at_eof``), on a write
        error, or when the server is stopping.
        """
        while not self.stopping:
            # Slow down if buffer is short, giving the reader time to refill it.
            if len(self.buffer) < 3:
                await asyncio.sleep(.1)
            try:
                input_data = self.buffer.popleft()
            except IndexError:
                if not self.at_eof:
                    continue
                # Buffer drained and the sender has finished: signal EOF.
                try:
                    writer.write_eof()
                    await writer.drain()
                except Exception:
                    logging.exception("Error while sending EOF to receiver")
                self.at_eof = False
                break
            try:
                writer.write(input_data)
                await writer.drain()
            except Exception:
                # The receiver connection is broken; further writes would only
                # discard more chunks.
                logging.exception("Error while writing to receiver")
                break

    async def connect(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Connect with client. Decide whether client is sender or receiver and start transmission.

        The first client (while ``working`` is False) is the sender; any client
        connecting while ``working`` is True is the receiver. The client's
        writer is closed when its transmission ends.

        NOTE(review): if the sender disconnects without reaching EOF,
        ``working`` stays True and a receiver will wait for more data until the
        server stops; a role handshake (see TODO) would resolve this.
        """
        # TODO: ask peer role instead of inferring it from connection order
        peer_is_sender = not self.working
        self._working = True
        try:
            if peer_is_sender:
                logging.info("Sender is connecting...")
                await self.run_reader(reader=reader)
                logging.info("Incoming transmission ended")
            else:
                logging.info("Receiver is connecting...")
                try:
                    await self.run_writer(writer=writer)
                finally:
                    # Only the receiver finishing completes a transfer; the
                    # next client is a sender again.
                    self._working = False
                logging.info("Outgoing transmission ended")
        finally:
            writer.close()
            try:
                await writer.wait_closed()
            except OSError as e:
                logging.debug("Error while closing client connection: %s", e)

    def run(self):
        """Serve on a new event loop until KeyboardInterrupt, then shut down cleanly."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        logging.info("Starting file bridging server...")
        try:
            loop.run_until_complete(self.run_server())
        except KeyboardInterrupt:
            logging.info("Stopping...")
            # Cancel connection tasks (they should be done but are pending)
            # and let the loop process the cancellations.
            tasks = asyncio.all_tasks(loop)
            for task in tasks:
                task.cancel()
            if tasks:
                loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
            # The interrupt may arrive before start_server has returned.
            if self.server is not None:
                self.server.close()
                loop.run_until_complete(self.server.wait_closed())
        finally:
            loop.close()
        logging.info("Stopped.")

    async def run_server(self):
        """Start listening on ``host``:``port`` and serve clients until cancelled."""
        self._server = await asyncio.start_server(
            client_connected_cb=self.connect, host=self.host, port=self.port
        )
        async with self.server:
            try:
                await self.server.serve_forever()
            except KeyboardInterrupt:
                logging.info("Stopping...")
                self.server.close()
                await self.server.wait_closed()

    def stop(self, *_):
        """Request a graceful stop of the running transmission loops.

        Extra positional arguments are ignored (presumably so this can be used
        as a signal handler — confirm with the caller). If no transmission is
        active, or a stop is already in progress, raises KeyboardInterrupt
        instead.
        """
        if self.working and not self.stopping:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")


if __name__ == '__main__':
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Run server', allow_abbrev=False)
    parser.add_argument('--host', type=str, default=None, required=False,
                        help='server address')
    parser.add_argument('--port', type=int, default=None, required=False,
                        help='server port')
    args = vars(parser.parse_args())
    _host = args['host']
    _port = args['port']

    # If _host and _port are not provided from command-line, try to import them
    if _host is None:
        try:
            from config import host as _host
        except ImportError:
            _host = None
    if _port is None:
        try:
            from config import port as _port
        except ImportError:
            _port = None

    # If import fails, prompt user for _host or _port
    while _host is None:
        _host = input("Enter host:\t\t\t\t\t\t")
    while _port is None:
        try:
            candidate_port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            candidate_port = None
        # TCP ports are 16-bit unsigned integers.
        if candidate_port is None or not 0 <= candidate_port <= 65535:
            logging.info("Invalid port. Enter a valid port number!")
        else:
            _port = candidate_port

    server = Server(
        host=_host,
        port=_port,
    )
    server.run()