From 3f5384f9e9523d96bc2d11ec1abe1f515ae652be Mon Sep 17 00:00:00 2001 From: Davte Date: Mon, 13 Apr 2020 19:45:05 +0200 Subject: [PATCH] Refactoring --- filebridging/client.py | 303 ++++++++++++++++++++++++++++---------- filebridging/server.py | 69 ++++----- filebridging/utilities.py | 42 +++++- 3 files changed, 305 insertions(+), 109 deletions(-) diff --git a/filebridging/client.py b/filebridging/client.py index 99220ca..385b8c3 100644 --- a/filebridging/client.py +++ b/filebridging/client.py @@ -9,16 +9,25 @@ import random import ssl import string import sys +from typing import Union from . import utilities class Client: - def __init__(self, host='localhost', port=3001, - buffer_chunk_size=10**4, buffer_length_limit=10**4, - password=None, token=None): + def __init__(self, host='localhost', port=5000, ssl_context=None, + action=None, + standalone=False, + buffer_chunk_size=10 ** 4, + buffer_length_limit=10 ** 4, + file_path=None, + password=None, + token=None): self._host = host self._port = port + self._ssl_context = ssl_context + self._action = action + self._standalone = standalone self._stopping = False self._reader = None self._writer = None @@ -28,7 +37,7 @@ class Client: self._buffer_chunk_size = buffer_chunk_size # How many chunks in buffer self._buffer_length_limit = buffer_length_limit - self._file_path = None + self._file_path = file_path self._working = False self._token = token self._password = password @@ -36,6 +45,7 @@ class Client: self._encryption_complete = False self._file_name = None self._file_size = None + self._file_size_string = None @property def host(self) -> str: @@ -45,6 +55,21 @@ class Client: def port(self) -> int: return self._port + @property + def action(self) -> str: + """Client role. + + Possible values: + - `send` + - `receive` + """ + return self._action + + @property + def standalone(self) -> bool: + """Tell whether client should run as server as well.""" + return self._standalone + @property def stopping(self) -> bool: return self._stopping @@ -101,52 +126,128 @@ class Client: def file_size(self): return self._file_size - async def run_client(self, file_path, action): - self._file_path = file_path - if action == 'send': - file_name = os.path.basename(os.path.abspath(file_path)) - file_size = os.path.getsize(os.path.abspath(file_path)) + @property + def file_size_string(self): + return self._file_size_string + + async def run_client(self) -> None: + if self.action == 'send': + file_name = os.path.basename(os.path.abspath(self.file_path)) + file_size = os.path.getsize(os.path.abspath(self.file_path)) + # File size increases after encryption + # "Salted_" (8 bytes) + salt (8 bytes) + # Then, 1-16 bytes are added to make file_size a multiple of 16 + # i.e., (32 - file_size mod 16) bytes are added to original size + if self.password: + file_size += 32 - (file_size % 16) self.set_file_information( file_name=file_name, file_size=file_size ) - try: - reader, writer = await asyncio.open_connection( + if self.standalone: + server = await asyncio.start_server( + ssl=self.ssl_context, + client_connected_cb=self._connect, host=self.host, port=self.port, - ssl=self.ssl_context ) - self._reader = reader - self._writer = writer - except ConnectionRefusedError as exception: - logging.error(exception) - return - writer.write( - ( - f"s|{self.token}|" - f"{self.file_name}|{self.file_size}\n".encode('utf-8') - ) if action == 'send' - else f"r|{self.token}\n".encode('utf-8') - ) - await writer.drain() + async with server: + logging.info("Running at `{s.host}:{s.port}`".format(s=self)) + await server.serve_forever() + else: + try: + reader, writer = await asyncio.open_connection( + host=self.host, + port=self.port, + ssl=self.ssl_context + ) + except (ConnectionRefusedError, ConnectionResetError) as exception: + logging.error(f"Connection error: {exception}") + return + await self.connect(reader=reader, writer=writer) + + async def _connect(self, reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + try: + return await self.connect(reader, writer) + except KeyboardInterrupt: + print() + except Exception as e: + logging.error(e) + + async def connect(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + self._reader = reader + self._writer = writer + + async def _write(message: Union[list, str, bytes], + terminate_line=True) -> int: + """Framework for `asyncio.StreamWriter.write` method. + + Create string from list, encode it, send and drain writer. + Return 0 on success, 1 on error. + """ + # Adapt + if type(message) is list: + message = '|'.join(map(str, message)) + if type(message) is str: + if terminate_line: + message += '\n' + message = message.encode('utf-8') + if type(message) is not bytes: + return 1 + try: + writer.write(message) + await writer.drain() + except ConnectionResetError: + logging.error("Client disconnected.") + except Exception as e: + logging.error(f"Unexpected exception:\n{e}", exc_info=True) + else: + return 0 # On success, return 0 + # On exception, return 1 + return 1 + + if self.action == 'send' or not self.standalone: + if await _write( + [self.action[0], self.token, + self.file_name, self.file_size] + ): + return # Wait for server start signal while 1: - server_hello = await reader.readline() + server_hello = await self.reader.readline() if not server_hello: - logging.info("Server disconnected.") + logging.error("Server disconnected.") return server_hello = server_hello.decode('utf-8').strip('\n').split('|') - if action == 'receive' and server_hello[0] == 's': + if self.action == 'receive' and server_hello[0] == 's': + self.set_file_information(file_name=server_hello[2], file_size=server_hello[3]) + elif ( + self.standalone + and self.action == 'send' + and server_hello[0] == 'r' + ): + # Check token + if server_hello[1] != self.token: + if await _write("Invalid session token!"): + return + return elif server_hello[0] == 'start!': break else: logging.info(f"Server said: {'|'.join(server_hello)}") - if action == 'send': - await self.send(writer=writer) + if self.standalone: + if await _write("start!"): + return + break + if self.action == 'send': + await self.send(writer=self.writer) else: - await self.receive(reader=reader) + await self.receive(reader=self.reader) async def encrypt_file(self, input_file, output_file): self._encryption_complete = False @@ -201,28 +302,27 @@ class Client: break try: writer.write(output_data) - await writer.drain() + await asyncio.wait_for(writer.drain(), timeout=3.0) except ConnectionResetError: + print() # New line after progress_bar logging.error('Server closed the connection.') self.stop() break - bytes_sent += self.buffer_chunk_size + except asyncio.exceptions.TimeoutError: + print() # New line after progress_bar + logging.error('Server closed the connection.') + self.stop() + break + bytes_sent += len(output_data) new_progress = min( int(bytes_sent / self.file_size * 100), 100 ) - progress_showed = (new_progress // 10) * 10 - sys.stdout.write( - f"\t\t\tSending `{self.file_name}`: " - f"{'#' * (progress_showed // 10)}" - f"{'.' * ((100 - progress_showed) // 10)}\t" - f"{new_progress}% completed " - f"({min(bytes_sent, self.file_size) // 1000} " - f"of {self.file_size // 1000} KB)\r" + self.print_progress_bar( + progress=new_progress, + bytes_=bytes_sent, ) - sys.stdout.flush() - sys.stdout.write('\n') - sys.stdout.flush() + print() # New line after progress_bar writer.close() return @@ -242,26 +342,19 @@ class Client: bytes_received = 0 while not self.stopping: input_data = await reader.read(self.buffer_chunk_size) - bytes_received += self.buffer_chunk_size + bytes_received += len(input_data) new_progress = min( int(bytes_received / self.file_size * 100), 100 ) - progress_showed = (new_progress // 10) * 10 - sys.stdout.write( - f"\t\t\tReceiving `{self.file_name}`: " - f"{'#' * (progress_showed // 10)}" - f"{'.' * ((100 - progress_showed) // 10)}\t" - f"{new_progress}% completed " - f"({min(bytes_received, self.file_size) // 1000} " - f"of {self.file_size // 1000} KB)\r" + self.print_progress_bar( + progress=new_progress, + bytes_=bytes_received ) - sys.stdout.flush() if not input_data: break file_to_receive.write(input_data) - sys.stdout.write('\n') - sys.stdout.flush() + print() # New line after sys.stdout.write logging.info("File received.") if self.password: logging.info("Decrypting file...") @@ -289,7 +382,8 @@ class Client: if self.working: logging.info("Received interruption signal, stopping...") self._stopping = True - self.writer.close() + if self.writer: + self.writer.close() else: raise KeyboardInterrupt("Not working yet...") @@ -298,24 +392,62 @@ class Client: self._file_name = file_name if file_size is not None: self._file_size = int(file_size) + self._file_size_string = utilities.get_file_size_representation( + self.file_size + ) - def run(self, file_path, action): + def run(self): loop = asyncio.get_event_loop() try: loop.run_until_complete( - self.run_client(file_path=file_path, - action=action) + self.run_client() ) except KeyboardInterrupt: + print() logging.error("Interrupted") for task in asyncio.all_tasks(loop): task.cancel() - self.writer.close() + if self.writer: + self.writer.close() loop.run_until_complete( - self.writer.wait_closed() + self.wait_closed() ) loop.close() + def print_progress_bar(self, progress: int, bytes_: int): + """Print client progress bar. + + `progress` % = `bytes_string` transferred + out of `self.file_size_string`. + """ + action = { + 'send': "Sending", + 'receive': "Receiving" + }[self.action] + bytes_string = utilities.get_file_size_representation( + bytes_ + ) + utilities.print_progress_bar( + prefix=f"\t\t\t{action} `{self.file_name}`: ", + done_symbol='#', + pending_symbol='.', + progress=progress, + scale=5, + suffix=( + " completed " + f"({bytes_string} " + f"of {self.file_size_string})" + ) + ) + + @staticmethod + async def wait_closed() -> None: + """Give time to cancelled tasks to end properly. + + Sleep .1 second and return. + """ + await asyncio.sleep(.1) + def get_action(action): """Parse abbreviations for `action`.""" @@ -380,6 +512,11 @@ def main(): default=None, required=False, help='server SSL certificate') + cli_parser.add_argument('--key', type=str, + default=None, + required=False, + help='server SSL key (required only for ' + 'SSL-secured standalone client)') cli_parser.add_argument('--action', type=str, default=None, required=False, @@ -397,6 +534,9 @@ def main(): required=False, help='Session token ' '(must be the same for both clients)') + cli_parser.add_argument('--standalone', + action='store_true', + help='Run both as client and server') cli_parser.add_argument('others', metavar='R or S', nargs='*', @@ -405,10 +545,12 @@ def main(): host = args['host'] port = args['port'] certificate = args['certificate'] + key = args['key'] action = get_action(args['action']) file_path = args['path'] password = args['password'] token = args['token'] + standalone = args['standalone'] # If host and port are not provided from command-line, try to import them sys.path.append(os.path.abspath('.')) @@ -455,6 +597,11 @@ def main(): from config import certificate except ImportError: certificate = None + if key is None or not os.path.isfile(key): + try: + from config import key + except ImportError: + key = None # If import fails, prompt user for host or port new_settings = {} # After getting these settings, offer to store them @@ -540,17 +687,18 @@ def main(): logging.info("Configuration values stored.") else: logging.info("Proceeding without storing values...") - client = Client( - host=host, - port=port, - password=password, - token=token - ) + ssl_context = None if certificate is not None: - _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) - _ssl_context.check_hostname = False - _ssl_context.load_verify_locations(certificate) - client.set_ssl_context(_ssl_context) + if key is None: # Server-dependent client + ssl_context = ssl.create_default_context( + purpose=ssl.Purpose.SERVER_AUTH + ) + ssl_context.load_verify_locations(certificate) + else: # Standalone client + ssl_context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH + ) + ssl_context.load_cert_chain(certificate, key) else: logging.warning( "Please consider using SSL. To do so, add in `config.py` or " @@ -559,10 +707,17 @@ def main(): "certificate = 'path/to/certificate.crt'" ) logging.info("Starting client...") - client.run( + client = Client( + host=host, + port=port, + ssl_context=ssl_context, + action=action, + standalone=standalone, file_path=file_path, - action=action + password=password, + token=token ) + client.run() logging.info("Stopped client") diff --git a/filebridging/server.py b/filebridging/server.py index 6a5d879..e01e2bb 100644 --- a/filebridging/server.py +++ b/filebridging/server.py @@ -13,10 +13,11 @@ from typing import Union class Server: - def __init__(self, host='localhost', port=5000, + def __init__(self, host='localhost', port=5000, ssl_context=None, buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4): self._host = host self._port = port + self._ssl_context = ssl_context self.connections = collections.OrderedDict() # Dict of queues of bytes self.buffers = collections.OrderedDict() @@ -87,27 +88,24 @@ class Server: async def run_writer(self, writer, connection_token): consecutive_interruptions = 0 errors = 0 - while 1: + while connection_token in self.buffers: try: - try: - if connection_token not in self.buffers: - break - input_data = self.buffers[connection_token].popleft() - except IndexError: - # Slow down if buffer is short - consecutive_interruptions += 1 - if consecutive_interruptions > 3: - break - await asyncio.sleep(.5) - continue - else: - consecutive_interruptions = 0 - if not input_data: + input_data = self.buffers[connection_token].popleft() + except IndexError: + # Slow down if buffer is empty; after 1.5 s of silence, break + consecutive_interruptions += 1 + if consecutive_interruptions > 3: break + await asyncio.sleep(.5) + continue + else: + consecutive_interruptions = 0 + if not input_data: + break + try: writer.write(input_data) await writer.drain() except ConnectionResetError as e: - logging.error("Here") logging.error(e) break except Exception as e: @@ -127,7 +125,7 @@ class Server: """ client_hello = await reader.readline() client_hello = client_hello.decode('utf-8').strip('\n').split('|') - if len(client_hello) not in (2, 4,): + if len(client_hello) != 4: await self.refuse_connection(writer=writer, message="Invalid client_hello!") return @@ -142,7 +140,7 @@ class Server: terminate_line=True) -> int: # Adapt if type(message) is list: - message = '|'.join(message) + message = '|'.join(map(str, message)) if type(message) is str: if terminate_line: message += '\n' @@ -211,12 +209,13 @@ class Server: index = 0 await asyncio.sleep(.5) # Send file information and start signal to client - writer.write( - "s|hidden_token|" - f"{self.connections[connection_token]['file_name']}|" - f"{self.connections[connection_token]['file_size']}" - "\n".encode('utf-8') - ) + if await _write( + ['s', + 'hidden_token', + self.connections[connection_token]['file_name'], + self.connections[connection_token]['file_size']] + ): + return if await _write("start!"): return await self.run_writer(writer=writer, @@ -238,6 +237,7 @@ class Server: try: loop.run_until_complete(self.run_server()) except KeyboardInterrupt: + print() logging.info("Stopping...") # Cancel connection tasks (they should be done but are pending) for task in asyncio.all_tasks(loop): @@ -336,10 +336,6 @@ def main(): logging.info("Invalid port. Enter a valid port number!") port = None - server = Server( - host=host, - port=port, - ) try: if certificate is None or not os.path.isfile(certificate): from config import certificate @@ -350,12 +346,12 @@ def main(): if not os.path.isfile(key): key = None except ImportError: - pass + certificate = None + key = None + ssl_context = None if certificate and key: - _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) - _ssl_context.check_hostname = False - _ssl_context.load_cert_chain(certificate, key) - server.set_ssl_context(_ssl_context) + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certificate, key) else: logging.warning( "Please consider using SSL. To do so, add in `config.py` or " @@ -364,6 +360,11 @@ def main(): "key = 'path/to/secret.key'\n" "certificate = 'path/to/certificate.crt'" ) + server = Server( + host=host, + port=port, + ssl_context=ssl_context + ) server.run() diff --git a/filebridging/utilities.py b/filebridging/utilities.py index a372cd5..66a206d 100644 --- a/filebridging/utilities.py +++ b/filebridging/utilities.py @@ -2,11 +2,51 @@ import logging import signal +import sys + +units_of_measurements = { + 1: 'bytes', + 1000: 'KB', + 1000 * 1000: 'MB', + 1000 * 1000 * 1000: 'GB', + 1000 * 1000 * 1000 * 1000: 'TB', +} + + +def get_file_size_representation(file_size): + scale, unit = get_scale_and_unit(file_size=file_size) + if scale < 10: + return f"{file_size} {unit}" + return f"{(file_size // (scale / 100)) / 100:.2f} {unit}" + + +def get_scale_and_unit(file_size): + scale, unit = min(units_of_measurements.items()) + for scale, unit in sorted(units_of_measurements.items(), reverse=True): + if file_size > scale: + break + return scale, unit + + +def print_progress_bar(prefix='', + suffix='', + done_symbol="#", + pending_symbol=".", + progress=0, + scale=10): + progress_showed = (progress // scale) * scale + sys.stdout.write( + f"{prefix}" + f"{done_symbol * (progress_showed // scale)}" + f"{pending_symbol * ((100 - progress_showed) // scale)}\t" + f"{progress}%" + f"{suffix} \r" + ) + sys.stdout.flush() def timed_input(message: str = None, timeout: int = 5): - class TimeoutExpired(Exception): pass