From c4b4e14dc01ab7a0f322e3747dc489fea178c8a9 Mon Sep 17 00:00:00 2001 From: Davte Date: Mon, 13 Apr 2020 21:00:00 +0200 Subject: [PATCH] First working version --- .gitignore | 7 + README.md | 52 ++- filebridging/__init__.py | 18 + filebridging/client.py | 795 ++++++++++++++++++++++++++++++++++++++ filebridging/server.py | 378 ++++++++++++++++++ filebridging/utilities.py | 69 ++++ requirements.txt | 0 setup.py | 75 ++++ 8 files changed, 1393 insertions(+), 1 deletion(-) create mode 100644 filebridging/__init__.py create mode 100644 filebridging/client.py create mode 100644 filebridging/server.py create mode 100644 filebridging/utilities.py create mode 100644 requirements.txt create mode 100644 setup.py diff --git a/.gitignore b/.gitignore index 7f7cccc..4087523 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,11 @@ # ---> Python + +# Configuration file +*config.py + +# Data folder +data/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 13083b5..fff1643 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,53 @@ # filebridging -Share files via a bridge server. \ No newline at end of file +Share files via a bridge server using TCP over SSL and end-to-end encryption. + +## Requirements +Python3.8+ is needed for this package. + +## Usage +If you need a virtual environment, create it. +```bash +python3.8 -m venv env; +alias pip="env/bin/pip"; +alias python="env/bin/python"; +``` + +Install filebridging and read the help. +```bash +pip install filebridging +python -m filebridging.server --help +python -m filebridging.client --help +``` + +## Examples +* Client-server example + ```bash + # 3 distinct tabs + python -m filebridging.server --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/server.key + python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/file_to_send + python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads + ``` + +* Client-client example + ```bash + # 2 distinct tabs + python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/private.key --token 12345678 --password supersecretpasswordhere --path ~/file_to_send --standalone + python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads + ``` + The receiver client may be standalone as well: just add the `--key` parameter (for SSL-secured sessions) and the `--standalone` flag. + +* Configuration file example + ```python + #!/bin/python + + host = "www.example.com" + port = 5000 + certificate = "/path/to/public.crt" + key = "/path/to/private.key" + + action = 'r' + password = 'verysecretpassword' + token = 'sessiontok' + file_path = '.' + ``` diff --git a/filebridging/__init__.py b/filebridging/__init__.py new file mode 100644 index 0000000..fdc6956 --- /dev/null +++ b/filebridging/__init__.py @@ -0,0 +1,18 @@ +"""General information about this package. + +Python 3.8+ is needed to use this package. +```python3.8+ +from filebridging.client import Client +from filebridging.server import Server +help(Client) +help(Server) +``` +""" + +__author__ = "Davide Testa" +__email__ = "davide@davte.it" +__credits__ = [] +__license__ = "GNU General Public License v3.0" +__version__ = "0.0.1" +__maintainer__ = "Davide Testa" +__contact__ = "t.me/davte" diff --git a/filebridging/client.py b/filebridging/client.py new file mode 100644 index 0000000..1cff6ad --- /dev/null +++ b/filebridging/client.py @@ -0,0 +1,795 @@ +"""Receiver and sender client class. + +Arguments + - host: localhost, IPv4 address or domain (e.g. www.example.com) + - port: port to reach (must be enabled) + - action: either [S]end or [R]eceive + - file_path: file to send / destination folder + - token: session token (6-10 alphanumerical characters) + - certificate [optional]: server certificate for SSL + - key [optional]: needed only for standalone clients + - password [optional]: necessary to end-to-end encryption + - standalone [optional]: allow client-to-client communication (the host + must be reachable by both clients) +""" + +import argparse +import asyncio +import collections +import logging +import os +import random +import ssl +import string +import sys +from typing import Union + +from . import utilities + + +class Client: + """Sender or receiver client. + + Create a Client object providing host, port and other optional parameters. + Then, run it with `Client().run()` method + """ + def __init__(self, host='localhost', port=5000, ssl_context=None, + action=None, + standalone=False, + buffer_chunk_size=10 ** 4, + buffer_length_limit=10 ** 4, + file_path=None, + password=None, + token=None): + self._host = host + self._port = port + self._ssl_context = ssl_context + self._action = action + self._standalone = standalone + self._stopping = False + self._reader = None + self._writer = None + # Shared queue of bytes + self.buffer = collections.deque() + # How many bytes per chunk + self._buffer_chunk_size = buffer_chunk_size + # How many chunks in buffer + self._buffer_length_limit = buffer_length_limit + self._file_path = file_path + self._working = False + self._token = token + self._password = password + self._ssl_context = None + self._encryption_complete = False + self._file_name = None + self._file_size = None + self._file_size_string = None + + @property + def host(self) -> str: + """Host to reach. + + For standalone clients, you must be able to listen this host. + """ + return self._host + + @property + def port(self) -> int: + """Port number.""" + return self._port + + @property + def action(self) -> str: + """Client role. + + Possible values: + - `send` + - `receive` + """ + return self._action + + @property + def standalone(self) -> bool: + """Tell whether client should run as server as well.""" + return self._standalone + + @property + def stopping(self) -> bool: + return self._stopping + + @property + def reader(self) -> asyncio.StreamReader: + return self._reader + + @property + def writer(self) -> asyncio.StreamWriter: + return self._writer + + @property + def buffer_length_limit(self) -> int: + """Max number of buffer chunks in memory. + + You may want to reduce this limit to allocate less memory, or increase + it to boost performance. + """ + return self._buffer_length_limit + + @property + def buffer_chunk_size(self) -> int: + """Length (bytes) of buffer chunks in memory. + + You may want to reduce this limit to allocate less memory, or increase + it to boost performance. + """ + return self._buffer_chunk_size + + @property + def file_path(self) -> str: + """Path of file to send or destination folder.""" + return self._file_path + + @property + def working(self) -> bool: + return self._working + + @property + def ssl_context(self) -> ssl.SSLContext: + return self._ssl_context + + def set_ssl_context(self, ssl_context: ssl.SSLContext): + self._ssl_context = ssl_context + + @property + def token(self): + """Session token. + + 6-10 alphanumerical characters to provide to server to link sender and + receiver. + """ + return self._token + + @property + def password(self): + """Password for file encryption or decryption.""" + return self._password + + @property + def encryption_complete(self): + return self._encryption_complete + + @property + def file_name(self): + return self._file_name + + @property + def file_size(self): + return self._file_size + + @property + def file_size_string(self): + """Formatted file size (e.g. 64.22 MB).""" + return self._file_size_string + + async def run_client(self) -> None: + if self.action == 'send': + file_name = os.path.basename(os.path.abspath(self.file_path)) + file_size = os.path.getsize(os.path.abspath(self.file_path)) + # File size increases after encryption + # "Salted_" (8 bytes) + salt (8 bytes) + # Then, 1-16 bytes are added to make file_size a multiple of 16 + # i.e., (32 - file_size mod 16) bytes are added to original size + if self.password: + file_size += 32 - (file_size % 16) + self.set_file_information( + file_name=file_name, + file_size=file_size + ) + if self.standalone: + server = await asyncio.start_server( + ssl=self.ssl_context, + client_connected_cb=self._connect, + host=self.host, + port=self.port, + ) + async with server: + logging.info("Running at `{s.host}:{s.port}`".format(s=self)) + await server.serve_forever() + else: + try: + reader, writer = await asyncio.open_connection( + host=self.host, + port=self.port, + ssl=self.ssl_context + ) + except (ConnectionRefusedError, ConnectionResetError) as exception: + logging.error(f"Connection error: {exception}") + return + await self.connect(reader=reader, writer=writer) + + async def _connect(self, reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + """Wrap connect method to catch exceptions. + + This is required since callbacks are never awaited and potential + exception would be logged at loop.close(). + Only standalone clients need this wrapper, regular clients might use + connect method directly. + """ + try: + return await self.connect(reader, writer) + except KeyboardInterrupt: + print() + except Exception as e: + logging.error(e) + + async def connect(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + """Communicate with the server or the other client. + + Send information about the client (connection token, role, file name + and size), get information from the server (file name and size), wait + for start signal and then send or receive the file. + """ + self._reader = reader + self._writer = writer + + async def _write(message: Union[list, str, bytes], + terminate_line=True) -> int: + """Framework for `asyncio.StreamWriter.write` method. + + Create string from list, encode it, send and drain writer. + Return 0 on success, 1 on error. + """ + # Adapt + if type(message) is list: + message = '|'.join(map(str, message)) + if type(message) is str: + if terminate_line: + message += '\n' + message = message.encode('utf-8') + if type(message) is not bytes: + return 1 + try: + writer.write(message) + await writer.drain() + except ConnectionResetError: + logging.error("Client disconnected.") + except Exception as e: + logging.error(f"Unexpected exception:\n{e}", exc_info=True) + else: + return 0 # On success, return 0 + # On exception, return 1 + return 1 + + if self.action == 'send' or not self.standalone: + if await _write( + [self.action[0], self.token, + self.file_name, self.file_size] + ): + return + # Wait for server start signal + while 1: + server_hello = await self.reader.readline() + if not server_hello: + logging.error("Server disconnected.") + return + server_hello = server_hello.decode('utf-8').strip('\n').split('|') + if self.action == 'receive' and server_hello[0] == 's': + + self.set_file_information(file_name=server_hello[2], + file_size=server_hello[3]) + elif ( + self.standalone + and self.action == 'send' + and server_hello[0] == 'r' + ): + # Check token + if server_hello[1] != self.token: + if await _write("Invalid session token!"): + return + return + elif server_hello[0] == 'start!': + break + else: + logging.info(f"Server said: {'|'.join(server_hello)}") + if self.standalone: + if await _write("start!"): + return + break + if self.action == 'send': + await self.send(writer=self.writer) + else: + await self.receive(reader=self.reader) + + async def encrypt_file(self, input_file, output_file): + """Use openssl to encrypt the input_file. + + The encrypted file will overwrite `output_file` if it exists. + """ + self._encryption_complete = False + logging.info("Encrypting file...") + stdout, stderr = ''.encode(), ''.encode() + try: + _subprocess = await asyncio.create_subprocess_shell( + "openssl enc -aes-256-cbc " + "-md sha512 -pbkdf2 -iter 100000 -salt " + f"-in \"{input_file}\" -out \"{output_file}\" " + f"-pass pass:{self.password}" + ) + stdout, stderr = await _subprocess.communicate() + except Exception as e: + logging.error( + "Exception {e}:\n{o}\n{er}".format( + e=e, + o=stdout.decode().strip(), + er=stderr.decode().strip() + ) + ) + logging.info("Encryption completed.") + self._encryption_complete = True + + async def send(self, writer: asyncio.StreamWriter): + """Encrypt and send the file. + + Caution: if no password is provided, the file will be sent as clear + text. + """ + self._working = True + file_path = self.file_path + if self.password: + file_path = self.file_path + '.enc' + # Remove already-encrypted file if present (salt would differ) + if os.path.isfile(file_path): + os.remove(file_path) + asyncio.ensure_future( + self.encrypt_file( + input_file=self.file_path, + output_file=file_path + ) + ) + # Give encryption an edge + while not os.path.isfile(file_path): + await asyncio.sleep(.5) + logging.info("Sending file...") + bytes_sent = 0 + with open(file_path, 'rb') as file_to_send: + while not self.stopping: + output_data = file_to_send.read(self.buffer_chunk_size) + if not output_data: + # If encryption is in progress, wait and read again later + if self.password and not self.encryption_complete: + await asyncio.sleep(1) + continue + break + try: + writer.write(output_data) + await asyncio.wait_for(writer.drain(), timeout=3.0) + except ConnectionResetError: + print() # New line after progress_bar + logging.error('Server closed the connection.') + self.stop() + break + except asyncio.exceptions.TimeoutError: + print() # New line after progress_bar + logging.error('Server closed the connection.') + self.stop() + break + bytes_sent += len(output_data) + new_progress = min( + int(bytes_sent / self.file_size * 100), + 100 + ) + self.print_progress_bar( + progress=new_progress, + bytes_=bytes_sent, + ) + print() # New line after progress_bar + writer.close() + return + + async def receive(self, reader: asyncio.StreamReader): + """Download the file and decrypt it. + + If no password is provided, the file cannot be decrypted. + """ + self._working = True + file_path = os.path.join( + os.path.abspath( + self.file_path + ), + self.file_name + ) + original_file_path = file_path + if self.password: + file_path += '.enc' + logging.info("Receiving file...") + with open(file_path, 'wb') as file_to_receive: + bytes_received = 0 + while not self.stopping: + input_data = await reader.read(self.buffer_chunk_size) + bytes_received += len(input_data) + new_progress = min( + int(bytes_received / self.file_size * 100), + 100 + ) + self.print_progress_bar( + progress=new_progress, + bytes_=bytes_received + ) + if not input_data: + break + file_to_receive.write(input_data) + print() # New line after sys.stdout.write + if bytes_received < self.file_size: + logging.warning("Transmission terminated too soon!") + if self.password: + logging.error("Partial files can not be decrypted!") + return + logging.info("File received.") + if self.password: + logging.info("Decrypting file...") + stdout, stderr = ''.encode(), ''.encode() + try: + _subprocess = await asyncio.create_subprocess_shell( + "openssl enc -aes-256-cbc " + "-md sha512 -pbkdf2 -iter 100000 -salt -d " + f"-in \"{file_path}\" -out \"{original_file_path}\" " + f"-pass pass:{self.password}" + ) + stdout, stderr = await _subprocess.communicate() + logging.info("Decryption completed.") + except Exception as e: + logging.error( + "Exception {e}:\n{o}\n{er}".format( + e=e, + o=stdout.decode().strip(), + er=stderr.decode().strip() + ) + ) + logging.info("Decryption failed", exc_info=True) + + def stop(self, *_): + if self.working: + logging.info("Received interruption signal, stopping...") + self._stopping = True + if self.writer: + self.writer.close() + else: + raise KeyboardInterrupt("Not working yet...") + + def set_file_information(self, file_name=None, file_size=None): + if file_name is not None: + self._file_name = file_name + if file_size is not None: + self._file_size = int(file_size) + self._file_size_string = utilities.get_file_size_representation( + self.file_size + ) + + def run(self): + loop = asyncio.get_event_loop() + try: + loop.run_until_complete( + self.run_client() + ) + except KeyboardInterrupt: + print() + logging.error("Interrupted") + for task in asyncio.all_tasks(loop): + task.cancel() + if self.writer: + self.writer.close() + loop.run_until_complete( + self.wait_closed() + ) + loop.close() + + def print_progress_bar(self, progress: int, bytes_: int): + """Print client progress bar. + + `progress` % = `bytes_string` transferred + out of `self.file_size_string`. + """ + action = { + 'send': "Sending", + 'receive': "Receiving" + }[self.action] + bytes_string = utilities.get_file_size_representation( + bytes_ + ) + utilities.print_progress_bar( + prefix=f"\t\t\t{action} `{self.file_name}`: ", + done_symbol='#', + pending_symbol='.', + progress=progress, + scale=5, + suffix=( + " completed " + f"({bytes_string} " + f"of {self.file_size_string})" + ) + ) + + @staticmethod + async def wait_closed() -> None: + """Give time to cancelled tasks to end properly. + + Sleep .1 second and return. + """ + await asyncio.sleep(.1) + + +def get_action(action): + """Parse abbreviations for `action`.""" + if not isinstance(action, str): + return + elif action.lower().startswith('r'): + return 'receive' + elif action.lower().startswith('s'): + return 'send' + + +def get_file_path(path, action='receive'): + """Check that file `path` is correct and return it.""" + path = os.path.abspath( + os.path.expanduser(path) + ) + if ( + isinstance(path, str) + and action == 'send' + and os.path.isfile(path) + ): + return path + elif ( + isinstance(path, str) + and action == 'receive' + and os.access(os.path.dirname(path), os.W_OK) + ): + return path + elif path is not None: + logging.error(f"Invalid file: `{path}`") + + +def main(): + # noinspection SpellCheckingInspection + log_formatter = logging.Formatter( + "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", + style='%' + ) + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + + # noinspection PyUnresolvedReferences + asyncio.selector_events.logger.setLevel(logging.ERROR) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_formatter) + console_handler.setLevel(logging.DEBUG) + root_logger.addHandler(console_handler) + + # Parse command-line arguments + cli_parser = argparse.ArgumentParser(description='Run client', + allow_abbrev=False) + cli_parser.add_argument('--host', type=str, + default=None, + required=False, + help='server address') + cli_parser.add_argument('--port', type=int, + default=None, + required=False, + help='server port') + cli_parser.add_argument('--certificate', type=str, + default=None, + required=False, + help='server SSL certificate') + cli_parser.add_argument('--key', type=str, + default=None, + required=False, + help='server SSL key (required only for ' + 'SSL-secured standalone client)') + cli_parser.add_argument('--action', type=str, + default=None, + required=False, + help='[S]end or [R]eceive') + cli_parser.add_argument('--path', type=str, + default=None, + required=False, + help='File path to send / folder path to receive') + cli_parser.add_argument('--password', '--p', '--pass', type=str, + default=None, + required=False, + help='Password for file encryption or decryption') + cli_parser.add_argument('--token', '--t', '--session_token', type=str, + default=None, + required=False, + help='Session token ' + '(must be the same for both clients)') + cli_parser.add_argument('--standalone', + action='store_true', + help='Run both as client and server') + cli_parser.add_argument('others', + metavar='R or S', + nargs='*', + help='[S]end or [R]eceive (see `action`)') + args = vars(cli_parser.parse_args()) + host = args['host'] + port = args['port'] + certificate = args['certificate'] + key = args['key'] + action = get_action(args['action']) + file_path = args['path'] + password = args['password'] + token = args['token'] + standalone = args['standalone'] + + # If host and port are not provided from command-line, try to import them + sys.path.append(os.path.abspath('.')) + if host is None: + try: + from config import host + except ImportError: + host = None + if port is None: + try: + from config import port + except ImportError: + port = None + # Take `s`, `r` etc. from command line as `action` + if action is None: + for arg in args['others']: + action = get_action(arg) + if action: + break + if action is None: + try: + from config import action + action = get_action(action) + except ImportError: + action = None + if file_path is None: + try: + from config import file_path + file_path = get_action(file_path) + except ImportError: + file_path = None + if password is None: + try: + from config import password + except ImportError: + password = None + if token is None: + try: + from config import token + except ImportError: + token = None + if certificate is None or not os.path.isfile(certificate): + try: + from config import certificate + except ImportError: + certificate = None + if key is None or not os.path.isfile(key): + try: + from config import key + except ImportError: + key = None + + # If import fails, prompt user for host or port + new_settings = {} # After getting these settings, offer to store them + while host is None: + host = input("Enter host:\t\t\t\t\t\t") + new_settings['host'] = host + while port is None: + try: + port = int(input("Enter port:\t\t\t\t\t\t")) + except ValueError: + logging.info("Invalid port. Enter a valid port number!") + port = None + new_settings['port'] = port + while action is None: + action = get_action( + input("Do you want to (R)eceive or (S)end a file?\t\t") + ) + if file_path is not None and ( + (action == 'send' + and not os.path.isfile(os.path.abspath(file_path))) + or (action == 'receive' + and not os.path.isdir(os.path.abspath(file_path))) + ): + file_path = None + while file_path is None: + if action == 'send': + file_path = get_file_path( + path=input(f"Enter file to send:\t\t\t\t\t"), + action=action + ) + if file_path and not os.path.isfile(os.path.abspath(file_path)): + file_path = None + elif action == 'receive': + file_path = get_file_path( + path=input(f"Enter destination folder:\t\t\t\t"), + action=action + ) + if file_path and not os.path.isdir(os.path.abspath(file_path)): + file_path = None + if password is None: + logging.warning( + "You have provided no password for file encryption.\n" + "Your file will be unencoded unless you provide a password in " + "config file." + ) + if token is None and action == 'send': + # Generate a random [6-10] chars-long alphanumerical token + token = ''.join( + random.SystemRandom().choice( + string.ascii_uppercase + string.digits + ) + for _ in range(random.SystemRandom().randint(6, 10)) + ) + logging.info( + "You have not provided a token for this connection.\n" + f"A token has been generated for you:\t\t\t{token}\n" + "Your peer must be informed of this token.\n" + "For future connections, you may provide a custom token writing " + "it in config file." + ) + while token is None or not (6 <= len(token) <= 10): + token = input("Please enter a 6-10 chars token.\t\t\t") + if new_settings: + answer = utilities.timed_input( + "You may store the following configuration values in " + "`config.py`.\n\n" + '\n'.join( + '\t\t'.join(map(str, item)) + for item in new_settings.items() + ) + '\n\n' + 'Do you want to store them?\t\t\t\t', + timeout=3 + ) + if answer: + with open('config.py', 'a') as configuration_file: + configuration_file.writelines( + [ + f'{name} = "{value}"\n' + if type(value) is str + else f'{name} = {value}\n' + for name, value in new_settings.items() + ] + ) + logging.info("Configuration values stored.") + else: + logging.info("Proceeding without storing values...") + ssl_context = None + if certificate and key and standalone: # Standalone client + ssl_context = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH + ) + ssl_context.load_cert_chain(certificate, key) + elif certificate: # Server-dependent client + ssl_context = ssl.create_default_context( + purpose=ssl.Purpose.SERVER_AUTH + ) + ssl_context.load_verify_locations(certificate) + else: + logging.warning( + "Please consider using SSL. To do so, add in `config.py` or " + "provide via Command Line Interface the path to a valid SSL " + "certificate. Example:\n\n" + "certificate = 'path/to/certificate.crt'" + ) + logging.info("Starting client...") + client = Client( + host=host, + port=port, + ssl_context=ssl_context, + action=action, + standalone=standalone, + file_path=file_path, + password=password, + token=token + ) + client.run() + logging.info("Stopped client") + + +if __name__ == '__main__': + main() diff --git a/filebridging/server.py b/filebridging/server.py new file mode 100644 index 0000000..7110632 --- /dev/null +++ b/filebridging/server.py @@ -0,0 +1,378 @@ +"""Server class. + +May be a local server or a publicly reachable server. + +Arguments + - host: localhost, IPv4 address or domain (e.g. www.example.com) + - port: port to reach (must be enabled) + - certificate [optional]: server certificate for SSL + - key [optional]: needed only for standalone clients +""" + +import argparse +import asyncio +import collections +import logging +import os +import ssl +from typing import Union + + +class Server: + def __init__(self, host='localhost', port=5000, ssl_context=None, + buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4): + self._host = host + self._port = port + self._ssl_context = ssl_context + self.connections = collections.OrderedDict() + # Dict of queues of bytes + self.buffers = collections.OrderedDict() + # How many bytes per chunk + self._buffer_chunk_size = buffer_chunk_size + # How many chunks in buffer + self._buffer_length_limit = buffer_length_limit + self._working = False + self._server = None + self._ssl_context = None + + @property + def host(self) -> str: + return self._host + + @property + def port(self) -> int: + return self._port + + @property + def buffer_length_limit(self) -> int: + return self._buffer_length_limit + + @property + def buffer_chunk_size(self) -> int: + return self._buffer_chunk_size + + @property + def working(self) -> bool: + return self._working + + @property + def server(self) -> asyncio.base_events.Server: + return self._server + + @property + def ssl_context(self) -> ssl.SSLContext: + return self._ssl_context + + @property + def buffer_is_full(self): + return ( + sum(len(buffer) + for buffer in self.buffers.values()) + >= self.buffer_length_limit + ) + + def set_ssl_context(self, ssl_context: ssl.SSLContext): + self._ssl_context = ssl_context + + async def run_reader(self, reader, connection_token): + while 1: + try: + # Wait one second if buffer is full + while self.buffer_is_full: + await asyncio.sleep(1) + continue + input_data = await reader.read(self.buffer_chunk_size) + if connection_token not in self.buffers: + break + self.buffers[connection_token].append(input_data) + except ConnectionResetError as e: + logging.error(e) + break + except Exception as e: + logging.error(f"Unexpected exception:\n{e}", exc_info=True) + + async def run_writer(self, writer, connection_token): + consecutive_interruptions = 0 + errors = 0 + while connection_token in self.buffers: + try: + input_data = self.buffers[connection_token].popleft() + except IndexError: + # Slow down if buffer is empty; after 1.5 s of silence, break + consecutive_interruptions += 1 + if consecutive_interruptions > 3: + break + await asyncio.sleep(.5) + continue + else: + consecutive_interruptions = 0 + if not input_data: + break + try: + writer.write(input_data) + await writer.drain() + except ConnectionResetError as e: + logging.error(e) + break + except Exception as e: + logging.error(e, exc_info=True) + errors += 1 + if errors > 3: + break + await asyncio.sleep(0.5) + writer.close() + + async def connect(self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter): + """Connect with client. + + Decide whether client is sender or receiver and start transmission. + """ + client_hello = await reader.readline() + client_hello = client_hello.decode('utf-8').strip('\n').split('|') + if len(client_hello) != 4: + await self.refuse_connection(writer=writer, + message="Invalid client_hello!") + return + connection_token = client_hello[1] + if connection_token not in self.connections: + self.connections[connection_token] = dict( + sender=False, + receiver=False + ) + + async def _write(message: Union[list, str, bytes], + terminate_line=True) -> int: + # Adapt + if type(message) is list: + message = '|'.join(map(str, message)) + if type(message) is str: + if terminate_line: + message += '\n' + message = message.encode('utf-8') + if type(message) is not bytes: + return 1 + try: + writer.write(message) + await writer.drain() + except ConnectionResetError: + logging.error("Client disconnected.") + except Exception as e: + logging.error(f"Unexpected exception:\n{e}", exc_info=True) + else: + return 0 # On success, return 0 + # On exception, disconnect and return 1 + self.disconnect(connection_token=connection_token) + return 1 + + if client_hello[0] == 's': # Sender client connection + if self.connections[connection_token]['sender']: + await self.refuse_connection( + writer=writer, + message="Invalid token! " + "A sender client is already connected!\n" + ) + return + self.connections[connection_token]['sender'] = True + self.connections[connection_token]['file_name'] = client_hello[2] + self.connections[connection_token]['file_size'] = client_hello[3] + self.buffers[connection_token] = collections.deque() + logging.info("Sender is connecting...") + index, step = 0, 1 + while not self.connections[connection_token]['receiver']: + index += 1 + if index >= step: + if await _write("Waiting for receiver..."): + return + step += 1 + index = 0 + await asyncio.sleep(.5) + # Send start signal to client + if await _write("start!"): + return + logging.info("Incoming transmission starting...") + await self.run_reader(reader=reader, + connection_token=connection_token) + logging.info("Incoming transmission ended") + elif client_hello[0] == 'r': # Receiver client connection + if self.connections[connection_token]['receiver']: + await self.refuse_connection( + writer=writer, + message="Invalid token! " + "A receiver client is already connected!\n" + ) + return + self.connections[connection_token]['receiver'] = True + logging.info("Receiver is connecting...") + index, step = 0, 1 + while not self.connections[connection_token]['sender']: + index += 1 + if index >= step: + if await _write("Waiting for sender..."): + return + step += 1 + index = 0 + await asyncio.sleep(.5) + # Send file information and start signal to client + if await _write( + ['s', + 'hidden_token', + self.connections[connection_token]['file_name'], + self.connections[connection_token]['file_size']] + ): + return + if await _write("start!"): + return + await self.run_writer(writer=writer, + connection_token=connection_token) + logging.info("Outgoing transmission ended") + self.disconnect(connection_token=connection_token) + else: + await self.refuse_connection(writer=writer, + message="Invalid client_hello!") + return + + def disconnect(self, connection_token: str) -> None: + del self.buffers[connection_token] + del self.connections[connection_token] + + def run(self): + loop = asyncio.get_event_loop() + logging.info("Starting file bridging server...") + try: + loop.run_until_complete(self.run_server()) + except KeyboardInterrupt: + print() + logging.info("Stopping...") + # Cancel connection tasks (they should be done but are pending) + for task in asyncio.all_tasks(loop): + task.cancel() + loop.run_until_complete( + self.server.wait_closed() + ) + loop.close() + logging.info("Stopped.") + + async def run_server(self): + self._server = await asyncio.start_server( + ssl=self.ssl_context, + client_connected_cb=self.connect, + host=self.host, + port=self.port, + ) + async with self.server: + logging.info("Running at `{s.host}:{s.port}`".format(s=self)) + await self.server.serve_forever() + + @staticmethod + async def refuse_connection(writer: asyncio.StreamWriter, + message: str = None): + """Send a `message` via writer and close it.""" + if message is None: + message = "Connection refused!\n" + writer.write( + message.encode('utf-8') + ) + await writer.drain() + writer.close() + + +def main(): + # noinspection SpellCheckingInspection + log_formatter = logging.Formatter( + "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", + style='%' + ) + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + + # noinspection PyUnresolvedReferences + asyncio.selector_events.logger.setLevel(logging.ERROR) + + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_formatter) + console_handler.setLevel(logging.DEBUG) + root_logger.addHandler(console_handler) + + # Parse command-line arguments + cli_parser = argparse.ArgumentParser(description='Run server', + allow_abbrev=False) + cli_parser.add_argument('--host', type=str, + default=None, + required=False, + help='server address') + cli_parser.add_argument('--port', type=int, + default=None, + required=False, + help='server port') + cli_parser.add_argument('--certificate', type=str, + default=None, + required=False, + help='server SSL certificate') + cli_parser.add_argument('--key', type=str, + default=None, + required=False, + help='server SSL key') + args = vars(cli_parser.parse_args()) + host = args['host'] + port = args['port'] + certificate = args['certificate'] + key = args['key'] + + # If host and port are not provided from command-line, try to import them + if host is None: + try: + from config import host + except ImportError: + host = None + if port is None: + try: + from config import port + except ImportError: + port = None + + # If import fails, prompt user for host or port + while host is None: + host = input("Enter host:\t\t\t\t\t\t") + while port is None: + try: + port = int(input("Enter port:\t\t\t\t\t\t")) + except ValueError: + logging.info("Invalid port. Enter a valid port number!") + port = None + + try: + if certificate is None or not os.path.isfile(certificate): + from config import certificate + if key is None or not os.path.isfile(key): + from config import key + if not os.path.isfile(certificate): + certificate = None + if not os.path.isfile(key): + key = None + except ImportError: + certificate = None + key = None + ssl_context = None + if certificate and key: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain(certificate, key) + else: + logging.warning( + "Please consider using SSL. To do so, add in `config.py` or " + "provide via Command Line Interface the path to a valid SSL " + "key and certificate. Example:\n\n" + "key = 'path/to/secret.key'\n" + "certificate = 'path/to/certificate.crt'" + ) + server = Server( + host=host, + port=port, + ssl_context=ssl_context + ) + server.run() + + +if __name__ == '__main__': + main() diff --git a/filebridging/utilities.py b/filebridging/utilities.py new file mode 100644 index 0000000..66a206d --- /dev/null +++ b/filebridging/utilities.py @@ -0,0 +1,69 @@ +"""Useful functions.""" + +import logging +import signal +import sys + +units_of_measurements = { + 1: 'bytes', + 1000: 'KB', + 1000 * 1000: 'MB', + 1000 * 1000 * 1000: 'GB', + 1000 * 1000 * 1000 * 1000: 'TB', +} + + +def get_file_size_representation(file_size): + scale, unit = get_scale_and_unit(file_size=file_size) + if scale < 10: + return f"{file_size} {unit}" + return f"{(file_size // (scale / 100)) / 100:.2f} {unit}" + + +def get_scale_and_unit(file_size): + scale, unit = min(units_of_measurements.items()) + for scale, unit in sorted(units_of_measurements.items(), reverse=True): + if file_size > scale: + break + return scale, unit + + +def print_progress_bar(prefix='', + suffix='', + done_symbol="#", + pending_symbol=".", + progress=0, + scale=10): + progress_showed = (progress // scale) * scale + sys.stdout.write( + f"{prefix}" + f"{done_symbol * (progress_showed // scale)}" + f"{pending_symbol * ((100 - progress_showed) // scale)}\t" + f"{progress}%" + f"{suffix} \r" + ) + sys.stdout.flush() + + +def timed_input(message: str = None, + timeout: int = 5): + class TimeoutExpired(Exception): + pass + + def interrupted(signal_number, stack_frame): + """Called when read times out.""" + raise TimeoutExpired + + if message is None: + message = f"Enter something within {timeout} seconds" + + signal.alarm(timeout) + signal.signal(signal.SIGALRM, interrupted) + try: + given_input = input(message) + except TimeoutExpired: + given_input = None + print() # Print end of line + logging.info("Timeout!") + signal.alarm(0) + return given_input diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e460863 --- /dev/null +++ b/setup.py @@ -0,0 +1,75 @@ +"""Setup.""" + +import codecs +import os +import re +import setuptools +import sys + +if sys.version_info < (3, 8): + raise RuntimeError("Python3.8+ is needed to use this library") + +here = os.path.abspath(os.path.dirname(__file__)) + + +def read(*parts): + """Read file in `part.part.part.part.ext`. + + Start from `here` and follow the path given by `*parts` + """ + with codecs.open(os.path.join(here, *parts), 'r') as fp: + return fp.read() + + +def find_information(info, *file_path_parts): + """Read information in file.""" + version_file = read(*file_path_parts) + version_match = re.search( + r"^__{info}__ = ['\"]([^'\"]*)['\"]".format( + info=info + ), + version_file, + re.M + ) + if version_match: + return version_match.group(1) + raise RuntimeError("Unable to find version string.") + + +with open("README.md", "r") as readme_file: + long_description = readme_file.read() + +setuptools.setup( + name='filebridging', + version=find_information("version", "filebridging", "__init__.py"), + author=find_information("author", "filebridging", "__init__.py"), + author_email=find_information("email", "filebridging", "__init__.py"), + description=( + "Share files via a bridge server using TCP over SSL and end-to-end " + "encryption." + ), + license=find_information("license", "filebridging", "__init__.py"), + long_description=long_description, + long_description_content_type="text/markdown", + url="https://gogs.davte.it/davte/filebridging", + packages=setuptools.find_packages(), + platforms=['any'], + install_requires=[], + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Environment :: Console", + "Framework :: AsyncIO", + "Intended Audience :: End Users/Desktop", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Communications :: File Sharing", + ], + keywords=( + 'file share ' + 'tcp ssl tls end-to-end encryption ' + 'python asyncio async' + ), + include_package_data=True, +)