Refactoring

This commit is contained in:
Davte 2020-04-13 12:58:37 +02:00
parent 932760bcb6
commit 3b7aa265ab
3 changed files with 190 additions and 93 deletions

View File

@ -1,3 +1,35 @@
# filebridging # filebridging
Share files via a bridge server using TCP over SSL and aes-256-cbc encryption. Share files via a bridge server using TCP over SSL and aes-256-cbc encryption.
## Requirements
Python3.8+ is needed for this package.
## Usage
If you need a virtual environment, create it.
```bash
python3.8 -m venv env;
alias pip="env/bin/pip";
alias python="env/bin/python";
```
Install filebridging and read the help.
```bash
pip install filebridging
python -m filebridging.server --help
python -m filebridging.client --help
```
## Examples
Client-server example
```bash
# 3 distinct tabs
python -m filebridging.server --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/server.key
python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/file_to_send
python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads
```
Client-client example
```bash
```

View File

@ -20,6 +20,8 @@ class Client:
self._host = host self._host = host
self._port = port self._port = port
self._stopping = False self._stopping = False
self._reader = None
self._writer = None
# Shared queue of bytes # Shared queue of bytes
self.buffer = collections.deque() self.buffer = collections.deque()
# How many bytes per chunk # How many bytes per chunk
@ -47,6 +49,14 @@ class Client:
def stopping(self) -> bool: def stopping(self) -> bool:
return self._stopping return self._stopping
@property
def reader(self) -> asyncio.StreamReader:
return self._reader
@property
def writer(self) -> asyncio.StreamWriter:
return self._writer
@property @property
def buffer_length_limit(self) -> int: def buffer_length_limit(self) -> int:
return self._buffer_length_limit return self._buffer_length_limit
@ -91,36 +101,52 @@ class Client:
def file_size(self): def file_size(self):
return self._file_size return self._file_size
async def run_sending_client(self, file_path='~/output.txt'): async def run_client(self, file_path, action):
self._file_path = file_path self._file_path = file_path
file_name = os.path.basename(os.path.abspath(file_path)) if action == 'send':
file_size = os.path.getsize(os.path.abspath(file_path)) file_name = os.path.basename(os.path.abspath(file_path))
file_size = os.path.getsize(os.path.abspath(file_path))
self.set_file_information(
file_name=file_name,
file_size=file_size
)
try: try:
reader, writer = await asyncio.open_connection( reader, writer = await asyncio.open_connection(
host=self.host, host=self.host,
port=self.port, port=self.port,
ssl=self.ssl_context ssl=self.ssl_context
) )
self._reader = reader
self._writer = writer
except ConnectionRefusedError as exception: except ConnectionRefusedError as exception:
logging.error(exception) logging.error(exception)
return return
writer.write( writer.write(
f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8') (
f"s|{self.token}|"
f"{self.file_name}|{self.file_size}\n".encode('utf-8')
) if action == 'send'
else f"r|{self.token}\n".encode('utf-8')
) )
self.set_file_information(file_name=file_name,
file_size=file_size)
await writer.drain() await writer.drain()
# Wait for server start signal # Wait for server start signal
while 1: while 1:
server_hello = await reader.readline() server_hello = await reader.readline()
if not server_hello: if not server_hello:
logging.error("Server disconnected.") logging.info("Server disconnected.")
return return
server_hello = server_hello.decode('utf-8').strip('\n') server_hello = server_hello.decode('utf-8').strip('\n').split('|')
if server_hello == 'start!': if action == 'receive' and server_hello[0] == 's':
self.set_file_information(file_name=server_hello[2],
file_size=server_hello[3])
elif server_hello[0] == 'start!':
break break
logging.info(f"Server said: {server_hello}") else:
await self.send(writer=writer) logging.info(f"Server said: {'|'.join(server_hello)}")
if action == 'send':
await self.send(writer=writer)
else:
await self.receive(reader=reader)
async def encrypt_file(self, input_file, output_file): async def encrypt_file(self, input_file, output_file):
self._encryption_complete = False self._encryption_complete = False
@ -177,7 +203,7 @@ class Client:
writer.write(output_data) writer.write(output_data)
await writer.drain() await writer.drain()
except ConnectionResetError: except ConnectionResetError:
logging.info('Server closed the connection.') logging.error('Server closed the connection.')
self.stop() self.stop()
break break
bytes_sent += self.buffer_chunk_size bytes_sent += self.buffer_chunk_size
@ -200,36 +226,6 @@ class Client:
writer.close() writer.close()
return return
async def run_receiving_client(self, file_path='~/input.txt'):
self._file_path = file_path
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(f"r|{self.token}\n".encode('utf-8'))
await writer.drain()
# Wait for server start signal
while 1:
server_hello = await reader.readline()
if not server_hello:
logging.info("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n')
if server_hello.startswith('info'):
_, file_name, file_size = server_hello.split('|')
self.set_file_information(file_name=file_name,
file_size=file_size)
elif server_hello == 'start!':
break
else:
logging.info(f"Server said: {server_hello}")
await self.receive(reader=reader)
async def receive(self, reader: asyncio.StreamReader): async def receive(self, reader: asyncio.StreamReader):
self._working = True self._working = True
file_path = os.path.join( file_path = os.path.join(
@ -293,6 +289,7 @@ class Client:
if self.working: if self.working:
logging.info("Received interruption signal, stopping...") logging.info("Received interruption signal, stopping...")
self._stopping = True self._stopping = True
self.writer.close()
else: else:
raise KeyboardInterrupt("Not working yet...") raise KeyboardInterrupt("Not working yet...")
@ -302,6 +299,23 @@ class Client:
if file_size is not None: if file_size is not None:
self._file_size = int(file_size) self._file_size = int(file_size)
def run(self, file_path, action):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
self.run_client(file_path=file_path,
action=action)
)
except KeyboardInterrupt:
logging.error("Interrupted")
for task in asyncio.all_tasks(loop):
task.cancel()
self.writer.close()
loop.run_until_complete(
self.writer.wait_closed()
)
loop.close()
def get_action(action): def get_action(action):
"""Parse abbreviations for `action`.""" """Parse abbreviations for `action`."""
@ -362,6 +376,10 @@ def main():
default=None, default=None,
required=False, required=False,
help='server port') help='server port')
cli_parser.add_argument('--certificate', type=str,
default=None,
required=False,
help='server SSL certificate')
cli_parser.add_argument('--action', type=str, cli_parser.add_argument('--action', type=str,
default=None, default=None,
required=False, required=False,
@ -386,6 +404,7 @@ def main():
args = vars(cli_parser.parse_args()) args = vars(cli_parser.parse_args())
host = args['host'] host = args['host']
port = args['port'] port = args['port']
certificate = args['certificate']
action = get_action(args['action']) action = get_action(args['action'])
file_path = args['path'] file_path = args['path']
password = args['password'] password = args['password']
@ -431,6 +450,11 @@ def main():
from config import token from config import token
except ImportError: except ImportError:
token = None token = None
if certificate is None or not os.path.isfile(certificate):
try:
from config import certificate
except ImportError:
certificate = None
# If import fails, prompt user for host or port # If import fails, prompt user for host or port
new_settings = {} # After getting these settings, offer to store them new_settings = {} # After getting these settings, offer to store them
@ -516,42 +540,29 @@ def main():
logging.info("Configuration values stored.") logging.info("Configuration values stored.")
else: else:
logging.info("Proceeding without storing values...") logging.info("Proceeding without storing values...")
loop = asyncio.get_event_loop()
client = Client( client = Client(
host=host, host=host,
port=port, port=port,
password=password, password=password,
token=token token=token
) )
try: if certificate is not None:
from config import certificate
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
_ssl_context.check_hostname = False _ssl_context.check_hostname = False
_ssl_context.load_verify_locations(certificate) _ssl_context.load_verify_locations(certificate)
client.set_ssl_context(_ssl_context) client.set_ssl_context(_ssl_context)
except ImportError: else:
logging.warning( logging.warning(
"Please consider using SSL. To do so, add in `config.py` or " "Please consider using SSL. To do so, add in `config.py` or "
"provide via Command Line Interface the path to a valid SSL " "provide via Command Line Interface the path to a valid SSL "
"certificate. Example:\n\n" "certificate. Example:\n\n"
"certificate = 'path/to/certificate.crt'" "certificate = 'path/to/certificate.crt'"
) )
# noinspection PyUnusedLocal
certificate = None
logging.info("Starting client...") logging.info("Starting client...")
if action == 'send': client.run(
loop.run_until_complete( file_path=file_path,
client.run_sending_client( action=action
file_path=file_path )
)
)
else:
loop.run_until_complete(
client.run_receiving_client(
file_path=file_path
)
)
loop.close()
logging.info("Stopped client") logging.info("Stopped client")

View File

@ -7,12 +7,14 @@ import argparse
import asyncio import asyncio
import collections import collections
import logging import logging
import os
import ssl import ssl
from typing import Union
class Server: class Server:
def __init__(self, host='localhost', port=5000, def __init__(self, host='localhost', port=5000,
buffer_chunk_size=10**4, buffer_length_limit=10**4): buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
self._host = host self._host = host
self._port = port self._port = port
self.connections = collections.OrderedDict() self.connections = collections.OrderedDict()
@ -57,9 +59,9 @@ class Server:
@property @property
def buffer_is_full(self): def buffer_is_full(self):
return ( return (
sum(len(buffer) sum(len(buffer)
for buffer in self.buffers.values()) for buffer in self.buffers.values())
>= self.buffer_length_limit >= self.buffer_length_limit
) )
def set_ssl_context(self, ssl_context: ssl.SSLContext): def set_ssl_context(self, ssl_context: ssl.SSLContext):
@ -80,7 +82,7 @@ class Server:
logging.error(e) logging.error(e)
break break
except Exception as e: except Exception as e:
logging.error(e, exc_info=True) logging.error(f"Unexpected exception:\n{e}", exc_info=True)
async def run_writer(self, writer, connection_token): async def run_writer(self, writer, connection_token):
consecutive_interruptions = 0 consecutive_interruptions = 0
@ -105,6 +107,7 @@ class Server:
writer.write(input_data) writer.write(input_data)
await writer.drain() await writer.drain()
except ConnectionResetError as e: except ConnectionResetError as e:
logging.error("Here")
logging.error(e) logging.error(e)
break break
except Exception as e: except Exception as e:
@ -134,7 +137,32 @@ class Server:
sender=False, sender=False,
receiver=False receiver=False
) )
if client_hello[0] == 's':
async def _write(message: Union[list, str, bytes],
terminate_line=True) -> int:
# Adapt
if type(message) is list:
message = '|'.join(message)
if type(message) is str:
if terminate_line:
message += '\n'
message = message.encode('utf-8')
if type(message) is not bytes:
return 1
try:
writer.write(message)
await writer.drain()
except ConnectionResetError:
logging.error("Client disconnected.")
except Exception as e:
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
else:
return 0 # On success, return 0
# On exception, disconnect and return 1
self.disconnect(connection_token=connection_token)
return 1
if client_hello[0] == 's': # Sender client connection
if self.connections[connection_token]['sender']: if self.connections[connection_token]['sender']:
await self.refuse_connection( await self.refuse_connection(
writer=writer, writer=writer,
@ -151,19 +179,19 @@ class Server:
while not self.connections[connection_token]['receiver']: while not self.connections[connection_token]['receiver']:
index += 1 index += 1
if index >= step: if index >= step:
writer.write("Waiting for receiver...\n".encode('utf-8')) if await _write("Waiting for receiver..."):
await writer.drain() return
step += 1 step += 1
index = 0 index = 0
await asyncio.sleep(.5) await asyncio.sleep(.5)
# Send start signal to client # Send start signal to client
writer.write("start!\n".encode('utf-8')) if await _write("start!"):
await writer.drain() return
logging.info("Incoming transmission starting...") logging.info("Incoming transmission starting...")
await self.run_reader(reader=reader, await self.run_reader(reader=reader,
connection_token=connection_token) connection_token=connection_token)
logging.info("Incoming transmission ended") logging.info("Incoming transmission ended")
else: # Receiver client connection elif client_hello[0] == 'r': # Receiver client connection
if self.connections[connection_token]['receiver']: if self.connections[connection_token]['receiver']:
await self.refuse_connection( await self.refuse_connection(
writer=writer, writer=writer,
@ -177,25 +205,32 @@ class Server:
while not self.connections[connection_token]['sender']: while not self.connections[connection_token]['sender']:
index += 1 index += 1
if index >= step: if index >= step:
writer.write("Waiting for sender...\n".encode('utf-8')) if await _write("Waiting for sender..."):
await writer.drain() return
step += 1 step += 1
index = 0 index = 0
await asyncio.sleep(.5) await asyncio.sleep(.5)
# Send file information and start signal to client # Send file information and start signal to client
writer.write( writer.write(
"info|" "s|hidden_token|"
f"{self.connections[connection_token]['file_name']}|" f"{self.connections[connection_token]['file_name']}|"
f"{self.connections[connection_token]['file_size']}" f"{self.connections[connection_token]['file_size']}"
"\n".encode('utf-8') "\n".encode('utf-8')
) )
writer.write("start!\n".encode('utf-8')) if await _write("start!"):
await writer.drain() return
await self.run_writer(writer=writer, await self.run_writer(writer=writer,
connection_token=connection_token) connection_token=connection_token)
logging.info("Outgoing transmission ended") logging.info("Outgoing transmission ended")
del self.buffers[connection_token] self.disconnect(connection_token=connection_token)
del self.connections[connection_token] else:
await self.refuse_connection(writer=writer,
message="Invalid client_hello!")
return
def disconnect(self, connection_token: str) -> None:
del self.buffers[connection_token]
del self.connections[connection_token]
def run(self): def run(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -255,19 +290,29 @@ def main():
root_logger.addHandler(console_handler) root_logger.addHandler(console_handler)
# Parse command-line arguments # Parse command-line arguments
parser = argparse.ArgumentParser(description='Run server', cli_parser = argparse.ArgumentParser(description='Run server',
allow_abbrev=False) allow_abbrev=False)
parser.add_argument('--host', type=str, cli_parser.add_argument('--host', type=str,
default=None, default=None,
required=False, required=False,
help='server address') help='server address')
parser.add_argument('--port', type=int, cli_parser.add_argument('--port', type=int,
default=None, default=None,
required=False, required=False,
help='server port') help='server port')
args = vars(parser.parse_args()) cli_parser.add_argument('--certificate', type=str,
default=None,
required=False,
help='server SSL certificate')
cli_parser.add_argument('--key', type=str,
default=None,
required=False,
help='server SSL key')
args = vars(cli_parser.parse_args())
host = args['host'] host = args['host']
port = args['port'] port = args['port']
certificate = args['certificate']
key = args['key']
# If host and port are not provided from command-line, try to import them # If host and port are not provided from command-line, try to import them
if host is None: if host is None:
@ -296,13 +341,22 @@ def main():
port=port, port=port,
) )
try: try:
# noinspection PyUnresolvedReferences if certificate is None or not os.path.isfile(certificate):
from config import certificate, key from config import certificate
if key is None or not os.path.isfile(key):
from config import key
if not os.path.isfile(certificate):
certificate = None
if not os.path.isfile(key):
key = None
except ImportError:
pass
if certificate and key:
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
_ssl_context.check_hostname = False _ssl_context.check_hostname = False
_ssl_context.load_cert_chain(certificate, key) _ssl_context.load_cert_chain(certificate, key)
server.set_ssl_context(_ssl_context) server.set_ssl_context(_ssl_context)
except ImportError: else:
logging.warning( logging.warning(
"Please consider using SSL. To do so, add in `config.py` or " "Please consider using SSL. To do so, add in `config.py` or "
"provide via Command Line Interface the path to a valid SSL " "provide via Command Line Interface the path to a valid SSL "