From 3b7aa265abb114089e1c7ae3b29e06a03cafb5fc Mon Sep 17 00:00:00 2001 From: Davte Date: Mon, 13 Apr 2020 12:58:37 +0200 Subject: [PATCH] Refactoring --- README.md | 32 ++++++++++ filebridging/client.py | 133 ++++++++++++++++++++++------------------- filebridging/server.py | 118 ++++++++++++++++++++++++++---------- 3 files changed, 190 insertions(+), 93 deletions(-) diff --git a/README.md b/README.md index e442bee..a99cccc 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,35 @@ # filebridging Share files via a bridge server using TCP over SSL and aes-256-cbc encryption. + +## Requirements +Python3.8+ is needed for this package. + +## Usage +If you need a virtual environment, create it. +```bash +python3.8 -m venv env; +alias pip="env/bin/pip"; +alias python="env/bin/python"; +``` + +Install filebridging and read the help. +```bash +pip install filebridging +python -m filebridging.server --help +python -m filebridging.client --help +``` + +## Examples +Client-server example +```bash +# 3 distinct tabs +python -m filebridging.server --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/server.key +python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/file_to_send +python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads +``` + +Client-client example +```bash + +``` diff --git a/filebridging/client.py b/filebridging/client.py index d6a5732..99220ca 100644 --- a/filebridging/client.py +++ b/filebridging/client.py @@ -20,6 +20,8 @@ class Client: self._host = host self._port = port self._stopping = False + self._reader = None + self._writer = None # Shared queue of bytes self.buffer = collections.deque() # How many bytes per chunk @@ -47,6 +49,14 @@ class Client: def stopping(self) -> bool: return self._stopping + @property + def reader(self) -> asyncio.StreamReader: + return self._reader + + @property + def writer(self) -> asyncio.StreamWriter: + return self._writer + @property def buffer_length_limit(self) -> int: return self._buffer_length_limit @@ -91,36 +101,52 @@ class Client: def file_size(self): return self._file_size - async def run_sending_client(self, file_path='~/output.txt'): + async def run_client(self, file_path, action): self._file_path = file_path - file_name = os.path.basename(os.path.abspath(file_path)) - file_size = os.path.getsize(os.path.abspath(file_path)) + if action == 'send': + file_name = os.path.basename(os.path.abspath(file_path)) + file_size = os.path.getsize(os.path.abspath(file_path)) + self.set_file_information( + file_name=file_name, + file_size=file_size + ) try: reader, writer = await asyncio.open_connection( host=self.host, port=self.port, ssl=self.ssl_context ) + self._reader = reader + self._writer = writer except ConnectionRefusedError as exception: logging.error(exception) return writer.write( - f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8') + ( + f"s|{self.token}|" + f"{self.file_name}|{self.file_size}\n".encode('utf-8') + ) if action == 'send' + else f"r|{self.token}\n".encode('utf-8') ) - self.set_file_information(file_name=file_name, - file_size=file_size) await writer.drain() # Wait for server start signal while 1: server_hello = await reader.readline() if not server_hello: - logging.error("Server disconnected.") + logging.info("Server disconnected.") return - server_hello = server_hello.decode('utf-8').strip('\n') - if server_hello == 'start!': + server_hello = server_hello.decode('utf-8').strip('\n').split('|') + if action == 'receive' and server_hello[0] == 's': + self.set_file_information(file_name=server_hello[2], + file_size=server_hello[3]) + elif server_hello[0] == 'start!': break - logging.info(f"Server said: {server_hello}") - await self.send(writer=writer) + else: + logging.info(f"Server said: {'|'.join(server_hello)}") + if action == 'send': + await self.send(writer=writer) + else: + await self.receive(reader=reader) async def encrypt_file(self, input_file, output_file): self._encryption_complete = False @@ -177,7 +203,7 @@ class Client: writer.write(output_data) await writer.drain() except ConnectionResetError: - logging.info('Server closed the connection.') + logging.error('Server closed the connection.') self.stop() break bytes_sent += self.buffer_chunk_size @@ -200,36 +226,6 @@ class Client: writer.close() return - async def run_receiving_client(self, file_path='~/input.txt'): - self._file_path = file_path - try: - reader, writer = await asyncio.open_connection( - host=self.host, - port=self.port, - ssl=self.ssl_context - ) - except ConnectionRefusedError as exception: - logging.error(exception) - return - writer.write(f"r|{self.token}\n".encode('utf-8')) - await writer.drain() - # Wait for server start signal - while 1: - server_hello = await reader.readline() - if not server_hello: - logging.info("Server disconnected.") - return - server_hello = server_hello.decode('utf-8').strip('\n') - if server_hello.startswith('info'): - _, file_name, file_size = server_hello.split('|') - self.set_file_information(file_name=file_name, - file_size=file_size) - elif server_hello == 'start!': - break - else: - logging.info(f"Server said: {server_hello}") - await self.receive(reader=reader) - async def receive(self, reader: asyncio.StreamReader): self._working = True file_path = os.path.join( @@ -293,6 +289,7 @@ class Client: if self.working: logging.info("Received interruption signal, stopping...") self._stopping = True + self.writer.close() else: raise KeyboardInterrupt("Not working yet...") @@ -302,6 +299,23 @@ class Client: if file_size is not None: self._file_size = int(file_size) + def run(self, file_path, action): + loop = asyncio.get_event_loop() + try: + loop.run_until_complete( + self.run_client(file_path=file_path, + action=action) + ) + except KeyboardInterrupt: + logging.error("Interrupted") + for task in asyncio.all_tasks(loop): + task.cancel() + self.writer.close() + loop.run_until_complete( + self.writer.wait_closed() + ) + loop.close() + def get_action(action): """Parse abbreviations for `action`.""" @@ -362,6 +376,10 @@ def main(): default=None, required=False, help='server port') + cli_parser.add_argument('--certificate', type=str, + default=None, + required=False, + help='server SSL certificate') cli_parser.add_argument('--action', type=str, default=None, required=False, @@ -386,6 +404,7 @@ def main(): args = vars(cli_parser.parse_args()) host = args['host'] port = args['port'] + certificate = args['certificate'] action = get_action(args['action']) file_path = args['path'] password = args['password'] @@ -431,6 +450,11 @@ def main(): from config import token except ImportError: token = None + if certificate is None or not os.path.isfile(certificate): + try: + from config import certificate + except ImportError: + certificate = None # If import fails, prompt user for host or port new_settings = {} # After getting these settings, offer to store them @@ -516,42 +540,29 @@ def main(): logging.info("Configuration values stored.") else: logging.info("Proceeding without storing values...") - loop = asyncio.get_event_loop() client = Client( host=host, port=port, password=password, token=token ) - try: - from config import certificate + if certificate is not None: _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) _ssl_context.check_hostname = False _ssl_context.load_verify_locations(certificate) client.set_ssl_context(_ssl_context) - except ImportError: + else: logging.warning( "Please consider using SSL. To do so, add in `config.py` or " "provide via Command Line Interface the path to a valid SSL " "certificate. Example:\n\n" "certificate = 'path/to/certificate.crt'" ) - # noinspection PyUnusedLocal - certificate = None logging.info("Starting client...") - if action == 'send': - loop.run_until_complete( - client.run_sending_client( - file_path=file_path - ) - ) - else: - loop.run_until_complete( - client.run_receiving_client( - file_path=file_path - ) - ) - loop.close() + client.run( + file_path=file_path, + action=action + ) logging.info("Stopped client") diff --git a/filebridging/server.py b/filebridging/server.py index 01131d3..6a5d879 100644 --- a/filebridging/server.py +++ b/filebridging/server.py @@ -7,12 +7,14 @@ import argparse import asyncio import collections import logging +import os import ssl +from typing import Union class Server: def __init__(self, host='localhost', port=5000, - buffer_chunk_size=10**4, buffer_length_limit=10**4): + buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4): self._host = host self._port = port self.connections = collections.OrderedDict() @@ -57,9 +59,9 @@ class Server: @property def buffer_is_full(self): return ( - sum(len(buffer) - for buffer in self.buffers.values()) - >= self.buffer_length_limit + sum(len(buffer) + for buffer in self.buffers.values()) + >= self.buffer_length_limit ) def set_ssl_context(self, ssl_context: ssl.SSLContext): @@ -80,7 +82,7 @@ class Server: logging.error(e) break except Exception as e: - logging.error(e, exc_info=True) + logging.error(f"Unexpected exception:\n{e}", exc_info=True) async def run_writer(self, writer, connection_token): consecutive_interruptions = 0 @@ -105,6 +107,7 @@ class Server: writer.write(input_data) await writer.drain() except ConnectionResetError as e: + logging.error("Here") logging.error(e) break except Exception as e: @@ -134,7 +137,32 @@ class Server: sender=False, receiver=False ) - if client_hello[0] == 's': + + async def _write(message: Union[list, str, bytes], + terminate_line=True) -> int: + # Adapt + if type(message) is list: + message = '|'.join(message) + if type(message) is str: + if terminate_line: + message += '\n' + message = message.encode('utf-8') + if type(message) is not bytes: + return 1 + try: + writer.write(message) + await writer.drain() + except ConnectionResetError: + logging.error("Client disconnected.") + except Exception as e: + logging.error(f"Unexpected exception:\n{e}", exc_info=True) + else: + return 0 # On success, return 0 + # On exception, disconnect and return 1 + self.disconnect(connection_token=connection_token) + return 1 + + if client_hello[0] == 's': # Sender client connection if self.connections[connection_token]['sender']: await self.refuse_connection( writer=writer, @@ -151,19 +179,19 @@ class Server: while not self.connections[connection_token]['receiver']: index += 1 if index >= step: - writer.write("Waiting for receiver...\n".encode('utf-8')) - await writer.drain() + if await _write("Waiting for receiver..."): + return step += 1 index = 0 await asyncio.sleep(.5) # Send start signal to client - writer.write("start!\n".encode('utf-8')) - await writer.drain() + if await _write("start!"): + return logging.info("Incoming transmission starting...") await self.run_reader(reader=reader, connection_token=connection_token) logging.info("Incoming transmission ended") - else: # Receiver client connection + elif client_hello[0] == 'r': # Receiver client connection if self.connections[connection_token]['receiver']: await self.refuse_connection( writer=writer, @@ -177,25 +205,32 @@ class Server: while not self.connections[connection_token]['sender']: index += 1 if index >= step: - writer.write("Waiting for sender...\n".encode('utf-8')) - await writer.drain() + if await _write("Waiting for sender..."): + return step += 1 index = 0 await asyncio.sleep(.5) # Send file information and start signal to client writer.write( - "info|" + "s|hidden_token|" f"{self.connections[connection_token]['file_name']}|" f"{self.connections[connection_token]['file_size']}" "\n".encode('utf-8') ) - writer.write("start!\n".encode('utf-8')) - await writer.drain() + if await _write("start!"): + return await self.run_writer(writer=writer, connection_token=connection_token) logging.info("Outgoing transmission ended") - del self.buffers[connection_token] - del self.connections[connection_token] + self.disconnect(connection_token=connection_token) + else: + await self.refuse_connection(writer=writer, + message="Invalid client_hello!") + return + + def disconnect(self, connection_token: str) -> None: + del self.buffers[connection_token] + del self.connections[connection_token] def run(self): loop = asyncio.get_event_loop() @@ -255,19 +290,29 @@ def main(): root_logger.addHandler(console_handler) # Parse command-line arguments - parser = argparse.ArgumentParser(description='Run server', - allow_abbrev=False) - parser.add_argument('--host', type=str, - default=None, - required=False, - help='server address') - parser.add_argument('--port', type=int, - default=None, - required=False, - help='server port') - args = vars(parser.parse_args()) + cli_parser = argparse.ArgumentParser(description='Run server', + allow_abbrev=False) + cli_parser.add_argument('--host', type=str, + default=None, + required=False, + help='server address') + cli_parser.add_argument('--port', type=int, + default=None, + required=False, + help='server port') + cli_parser.add_argument('--certificate', type=str, + default=None, + required=False, + help='server SSL certificate') + cli_parser.add_argument('--key', type=str, + default=None, + required=False, + help='server SSL key') + args = vars(cli_parser.parse_args()) host = args['host'] port = args['port'] + certificate = args['certificate'] + key = args['key'] # If host and port are not provided from command-line, try to import them if host is None: @@ -296,13 +341,22 @@ def main(): port=port, ) try: - # noinspection PyUnresolvedReferences - from config import certificate, key + if certificate is None or not os.path.isfile(certificate): + from config import certificate + if key is None or not os.path.isfile(key): + from config import key + if not os.path.isfile(certificate): + certificate = None + if not os.path.isfile(key): + key = None + except ImportError: + pass + if certificate and key: _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) _ssl_context.check_hostname = False _ssl_context.load_cert_chain(certificate, key) server.set_ssl_context(_ssl_context) - except ImportError: + else: logging.warning( "Please consider using SSL. To do so, add in `config.py` or " "provide via Command Line Interface the path to a valid SSL "