diff --git a/src/client.py b/src/client.py index 1559338..47ed61d 100644 --- a/src/client.py +++ b/src/client.py @@ -58,9 +58,17 @@ class Client: if not output_data: break writer.write(output_data) + try: + await writer.drain() + except ConnectionResetError: + logging.info('Server closed the connection.') + self.stop() + break + else: + # If transmission has succeeded, write end of file + writer.write_eof() await writer.drain() - writer.write_eof() - await writer.drain() + return async def run_receiving_client(self, file_path='~/input.txt'): self._file_path = file_path @@ -117,8 +125,8 @@ if __name__ == '__main__': loop = asyncio.get_event_loop() client = Client( # host='127.0.0.1', # localhost - host='5.249.159.33', # Aruba - port=(5000 if action == 'send' else 5001), + host='davte.it', # Aruba + port=5000, ) # loop.add_signal_handler(signal.SIGINT, client.stop, loop) logging.info("Starting client...") @@ -127,4 +135,4 @@ if __name__ == '__main__': else: loop.run_until_complete(client.run_receiving_client(file_path=_file_path)) loop.close() - logging.info("Stopped server") + logging.info("Stopped client") diff --git a/src/server.py b/src/server.py index 3fb03e2..c87e0bd 100644 --- a/src/server.py +++ b/src/server.py @@ -4,29 +4,25 @@ import logging class Server: - def __init__(self, host='localhost', input_port=5000, output_port=5001, + def __init__(self, host='localhost', port=5000, buffer_chunk_size=10**4, buffer_length_limit=10**4): self._host = host - self._input_port = input_port - self._output_port = output_port + self._port = port self._stopping = False self.buffer = collections.deque() # Shared queue of bytes self._buffer_chunk_size = buffer_chunk_size # How many bytes per chunk self._buffer_length_limit = buffer_length_limit # How many chunks in buffer self._working = False self.at_eof = False + self._server = None @property def host(self) -> str: return self._host @property - def input_port(self) -> int: - return self._input_port - - @property - def output_port(self) -> int: - return self._output_port + def port(self) -> int: + return self._port @property def stopping(self) -> bool: @@ -44,6 +40,10 @@ class Server: def working(self) -> bool: return self._working + @property + def server(self) -> asyncio.base_events.Server: + return self._server + async def run_reader(self, reader): while not self.stopping: try: @@ -79,23 +79,33 @@ class Server: except Exception as e: logging.error(e) - # noinspection PyUnusedLocal - async def handle_incoming_data(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self._working = True - asyncio.ensure_future(self.run_reader(reader=reader)) + async def connect(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + """Connect with client. - # noinspection PyUnusedLocal - async def handle_outgoing_data(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): - self._working = True - asyncio.ensure_future(self.run_writer(writer=writer)) + Decide whether client is sender or receiver and start transmission. + """ + if not self.working: + self._working = True + logging.info("Sender is connecting...") + await self.run_reader(reader=reader) + logging.info("Incoming transmission ended") + else: + logging.info("Receiver is connecting...") + await self.run_writer(writer=writer) + logging.info("Outgoing transmission ended") + self.stop() + return async def run_server(self): - reader_server = await asyncio.start_server(client_connected_cb=self.handle_incoming_data, - host=self.host, port=self.input_port) - await asyncio.start_server(client_connected_cb=self.handle_outgoing_data, - host=self.host, port=self.output_port) - async with reader_server: - await reader_server.serve_forever() + self._server = await asyncio.start_server( + client_connected_cb=self.connect, + host=self.host, + port=self.port + ) + async with self.server: + await self.server.serve_forever() return def stop(self, *_): @@ -107,6 +117,7 @@ class Server: if __name__ == '__main__': + # noinspection SpellCheckingInspection log_formatter = logging.Formatter( "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", style='%' @@ -123,8 +134,7 @@ if __name__ == '__main__': server = Server( # host='127.0.0.1', # localhost host='5.249.159.33', # Aruba - input_port=5000, - output_port=5001 + port=5000, ) logging.info("Starting file bridging server...") try: