Refactoring

This commit is contained in:
Davte 2020-04-13 12:58:37 +02:00
parent 932760bcb6
commit 3b7aa265ab
3 changed files with 190 additions and 93 deletions

View File

@ -1,3 +1,35 @@
# filebridging
Share files via a bridge server using TCP over SSL and aes-256-cbc encryption.
## Requirements
Python3.8+ is needed for this package.
## Usage
If you need a virtual environment, create it.
```bash
python3.8 -m venv env;
alias pip="env/bin/pip";
alias python="env/bin/python";
```
Install filebridging and read the help.
```bash
pip install filebridging
python -m filebridging.server --help
python -m filebridging.client --help
```
## Examples
Client-server example
```bash
# 3 distinct tabs
python -m filebridging.server --host localhost --port 5000 --certificate ~/.ssh/server.crt --key ~/.ssh/server.key
python -m filebridging.client s --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/file_to_send
python -m filebridging.client r --host localhost --port 5000 --certificate ~/.ssh/server.crt --token 12345678 --password supersecretpasswordhere --path ~/Downloads
```
Client-client example
```bash
```

View File

@ -20,6 +20,8 @@ class Client:
self._host = host
self._port = port
self._stopping = False
self._reader = None
self._writer = None
# Shared queue of bytes
self.buffer = collections.deque()
# How many bytes per chunk
@ -47,6 +49,14 @@ class Client:
def stopping(self) -> bool:
return self._stopping
@property
def reader(self) -> asyncio.StreamReader:
return self._reader
@property
def writer(self) -> asyncio.StreamWriter:
return self._writer
@property
def buffer_length_limit(self) -> int:
return self._buffer_length_limit
@ -91,36 +101,52 @@ class Client:
def file_size(self):
return self._file_size
async def run_sending_client(self, file_path='~/output.txt'):
async def run_client(self, file_path, action):
self._file_path = file_path
if action == 'send':
file_name = os.path.basename(os.path.abspath(file_path))
file_size = os.path.getsize(os.path.abspath(file_path))
self.set_file_information(
file_name=file_name,
file_size=file_size
)
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
self._reader = reader
self._writer = writer
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(
f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
(
f"s|{self.token}|"
f"{self.file_name}|{self.file_size}\n".encode('utf-8')
) if action == 'send'
else f"r|{self.token}\n".encode('utf-8')
)
self.set_file_information(file_name=file_name,
file_size=file_size)
await writer.drain()
# Wait for server start signal
while 1:
server_hello = await reader.readline()
if not server_hello:
logging.error("Server disconnected.")
logging.info("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n')
if server_hello == 'start!':
server_hello = server_hello.decode('utf-8').strip('\n').split('|')
if action == 'receive' and server_hello[0] == 's':
self.set_file_information(file_name=server_hello[2],
file_size=server_hello[3])
elif server_hello[0] == 'start!':
break
logging.info(f"Server said: {server_hello}")
else:
logging.info(f"Server said: {'|'.join(server_hello)}")
if action == 'send':
await self.send(writer=writer)
else:
await self.receive(reader=reader)
async def encrypt_file(self, input_file, output_file):
self._encryption_complete = False
@ -177,7 +203,7 @@ class Client:
writer.write(output_data)
await writer.drain()
except ConnectionResetError:
logging.info('Server closed the connection.')
logging.error('Server closed the connection.')
self.stop()
break
bytes_sent += self.buffer_chunk_size
@ -200,36 +226,6 @@ class Client:
writer.close()
return
async def run_receiving_client(self, file_path='~/input.txt'):
self._file_path = file_path
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(f"r|{self.token}\n".encode('utf-8'))
await writer.drain()
# Wait for server start signal
while 1:
server_hello = await reader.readline()
if not server_hello:
logging.info("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n')
if server_hello.startswith('info'):
_, file_name, file_size = server_hello.split('|')
self.set_file_information(file_name=file_name,
file_size=file_size)
elif server_hello == 'start!':
break
else:
logging.info(f"Server said: {server_hello}")
await self.receive(reader=reader)
async def receive(self, reader: asyncio.StreamReader):
self._working = True
file_path = os.path.join(
@ -293,6 +289,7 @@ class Client:
if self.working:
logging.info("Received interruption signal, stopping...")
self._stopping = True
self.writer.close()
else:
raise KeyboardInterrupt("Not working yet...")
@ -302,6 +299,23 @@ class Client:
if file_size is not None:
self._file_size = int(file_size)
def run(self, file_path, action):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
self.run_client(file_path=file_path,
action=action)
)
except KeyboardInterrupt:
logging.error("Interrupted")
for task in asyncio.all_tasks(loop):
task.cancel()
self.writer.close()
loop.run_until_complete(
self.writer.wait_closed()
)
loop.close()
def get_action(action):
"""Parse abbreviations for `action`."""
@ -362,6 +376,10 @@ def main():
default=None,
required=False,
help='server port')
cli_parser.add_argument('--certificate', type=str,
default=None,
required=False,
help='server SSL certificate')
cli_parser.add_argument('--action', type=str,
default=None,
required=False,
@ -386,6 +404,7 @@ def main():
args = vars(cli_parser.parse_args())
host = args['host']
port = args['port']
certificate = args['certificate']
action = get_action(args['action'])
file_path = args['path']
password = args['password']
@ -431,6 +450,11 @@ def main():
from config import token
except ImportError:
token = None
if certificate is None or not os.path.isfile(certificate):
try:
from config import certificate
except ImportError:
certificate = None
# If import fails, prompt user for host or port
new_settings = {} # After getting these settings, offer to store them
@ -516,42 +540,29 @@ def main():
logging.info("Configuration values stored.")
else:
logging.info("Proceeding without storing values...")
loop = asyncio.get_event_loop()
client = Client(
host=host,
port=port,
password=password,
token=token
)
try:
from config import certificate
if certificate is not None:
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
_ssl_context.check_hostname = False
_ssl_context.load_verify_locations(certificate)
client.set_ssl_context(_ssl_context)
except ImportError:
else:
logging.warning(
"Please consider using SSL. To do so, add in `config.py` or "
"provide via Command Line Interface the path to a valid SSL "
"certificate. Example:\n\n"
"certificate = 'path/to/certificate.crt'"
)
# noinspection PyUnusedLocal
certificate = None
logging.info("Starting client...")
if action == 'send':
loop.run_until_complete(
client.run_sending_client(
file_path=file_path
client.run(
file_path=file_path,
action=action
)
)
else:
loop.run_until_complete(
client.run_receiving_client(
file_path=file_path
)
)
loop.close()
logging.info("Stopped client")

View File

@ -7,7 +7,9 @@ import argparse
import asyncio
import collections
import logging
import os
import ssl
from typing import Union
class Server:
@ -80,7 +82,7 @@ class Server:
logging.error(e)
break
except Exception as e:
logging.error(e, exc_info=True)
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
async def run_writer(self, writer, connection_token):
consecutive_interruptions = 0
@ -105,6 +107,7 @@ class Server:
writer.write(input_data)
await writer.drain()
except ConnectionResetError as e:
logging.error("Here")
logging.error(e)
break
except Exception as e:
@ -134,7 +137,32 @@ class Server:
sender=False,
receiver=False
)
if client_hello[0] == 's':
async def _write(message: Union[list, str, bytes],
terminate_line=True) -> int:
# Adapt
if type(message) is list:
message = '|'.join(message)
if type(message) is str:
if terminate_line:
message += '\n'
message = message.encode('utf-8')
if type(message) is not bytes:
return 1
try:
writer.write(message)
await writer.drain()
except ConnectionResetError:
logging.error("Client disconnected.")
except Exception as e:
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
else:
return 0 # On success, return 0
# On exception, disconnect and return 1
self.disconnect(connection_token=connection_token)
return 1
if client_hello[0] == 's': # Sender client connection
if self.connections[connection_token]['sender']:
await self.refuse_connection(
writer=writer,
@ -151,19 +179,19 @@ class Server:
while not self.connections[connection_token]['receiver']:
index += 1
if index >= step:
writer.write("Waiting for receiver...\n".encode('utf-8'))
await writer.drain()
if await _write("Waiting for receiver..."):
return
step += 1
index = 0
await asyncio.sleep(.5)
# Send start signal to client
writer.write("start!\n".encode('utf-8'))
await writer.drain()
if await _write("start!"):
return
logging.info("Incoming transmission starting...")
await self.run_reader(reader=reader,
connection_token=connection_token)
logging.info("Incoming transmission ended")
else: # Receiver client connection
elif client_hello[0] == 'r': # Receiver client connection
if self.connections[connection_token]['receiver']:
await self.refuse_connection(
writer=writer,
@ -177,23 +205,30 @@ class Server:
while not self.connections[connection_token]['sender']:
index += 1
if index >= step:
writer.write("Waiting for sender...\n".encode('utf-8'))
await writer.drain()
if await _write("Waiting for sender..."):
return
step += 1
index = 0
await asyncio.sleep(.5)
# Send file information and start signal to client
writer.write(
"info|"
"s|hidden_token|"
f"{self.connections[connection_token]['file_name']}|"
f"{self.connections[connection_token]['file_size']}"
"\n".encode('utf-8')
)
writer.write("start!\n".encode('utf-8'))
await writer.drain()
if await _write("start!"):
return
await self.run_writer(writer=writer,
connection_token=connection_token)
logging.info("Outgoing transmission ended")
self.disconnect(connection_token=connection_token)
else:
await self.refuse_connection(writer=writer,
message="Invalid client_hello!")
return
def disconnect(self, connection_token: str) -> None:
del self.buffers[connection_token]
del self.connections[connection_token]
@ -255,19 +290,29 @@ def main():
root_logger.addHandler(console_handler)
# Parse command-line arguments
parser = argparse.ArgumentParser(description='Run server',
cli_parser = argparse.ArgumentParser(description='Run server',
allow_abbrev=False)
parser.add_argument('--host', type=str,
cli_parser.add_argument('--host', type=str,
default=None,
required=False,
help='server address')
parser.add_argument('--port', type=int,
cli_parser.add_argument('--port', type=int,
default=None,
required=False,
help='server port')
args = vars(parser.parse_args())
cli_parser.add_argument('--certificate', type=str,
default=None,
required=False,
help='server SSL certificate')
cli_parser.add_argument('--key', type=str,
default=None,
required=False,
help='server SSL key')
args = vars(cli_parser.parse_args())
host = args['host']
port = args['port']
certificate = args['certificate']
key = args['key']
# If host and port are not provided from command-line, try to import them
if host is None:
@ -296,13 +341,22 @@ def main():
port=port,
)
try:
# noinspection PyUnresolvedReferences
from config import certificate, key
if certificate is None or not os.path.isfile(certificate):
from config import certificate
if key is None or not os.path.isfile(key):
from config import key
if not os.path.isfile(certificate):
certificate = None
if not os.path.isfile(key):
key = None
except ImportError:
pass
if certificate and key:
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
_ssl_context.check_hostname = False
_ssl_context.load_cert_chain(certificate, key)
server.set_ssl_context(_ssl_context)
except ImportError:
else:
logging.warning(
"Please consider using SSL. To do so, add in `config.py` or "
"provide via Command Line Interface the path to a valid SSL "