Refactoring

This commit is contained in:
Davte 2020-04-13 19:45:05 +02:00
parent 3b7aa265ab
commit 3f5384f9e9
3 changed files with 305 additions and 109 deletions

View File

@ -9,16 +9,25 @@ import random
import ssl
import string
import sys
from typing import Union
from . import utilities
class Client:
def __init__(self, host='localhost', port=3001,
buffer_chunk_size=10**4, buffer_length_limit=10**4,
password=None, token=None):
def __init__(self, host='localhost', port=5000, ssl_context=None,
action=None,
standalone=False,
buffer_chunk_size=10 ** 4,
buffer_length_limit=10 ** 4,
file_path=None,
password=None,
token=None):
self._host = host
self._port = port
self._ssl_context = ssl_context
self._action = action
self._standalone = standalone
self._stopping = False
self._reader = None
self._writer = None
@ -28,7 +37,7 @@ class Client:
self._buffer_chunk_size = buffer_chunk_size
# How many chunks in buffer
self._buffer_length_limit = buffer_length_limit
self._file_path = None
self._file_path = file_path
self._working = False
self._token = token
self._password = password
@ -36,6 +45,7 @@ class Client:
self._encryption_complete = False
self._file_name = None
self._file_size = None
self._file_size_string = None
@property
def host(self) -> str:
@ -45,6 +55,21 @@ class Client:
def port(self) -> int:
return self._port
@property
def action(self) -> str:
"""Client role.
Possible values:
- `send`
- `receive`
"""
return self._action
@property
def standalone(self) -> bool:
"""Tell whether client should run as server as well."""
return self._standalone
@property
def stopping(self) -> bool:
return self._stopping
@ -101,52 +126,128 @@ class Client:
def file_size(self):
return self._file_size
async def run_client(self, file_path, action):
self._file_path = file_path
if action == 'send':
file_name = os.path.basename(os.path.abspath(file_path))
file_size = os.path.getsize(os.path.abspath(file_path))
@property
def file_size_string(self):
return self._file_size_string
async def run_client(self) -> None:
if self.action == 'send':
file_name = os.path.basename(os.path.abspath(self.file_path))
file_size = os.path.getsize(os.path.abspath(self.file_path))
# File size increases after encryption
# "Salted_" (8 bytes) + salt (8 bytes)
# Then, 1-16 bytes are added to make file_size a multiple of 16
# i.e., (32 - file_size mod 16) bytes are added to original size
if self.password:
file_size += 32 - (file_size % 16)
self.set_file_information(
file_name=file_name,
file_size=file_size
)
try:
reader, writer = await asyncio.open_connection(
if self.standalone:
server = await asyncio.start_server(
ssl=self.ssl_context,
client_connected_cb=self._connect,
host=self.host,
port=self.port,
ssl=self.ssl_context
)
self._reader = reader
self._writer = writer
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(
(
f"s|{self.token}|"
f"{self.file_name}|{self.file_size}\n".encode('utf-8')
) if action == 'send'
else f"r|{self.token}\n".encode('utf-8')
)
await writer.drain()
async with server:
logging.info("Running at `{s.host}:{s.port}`".format(s=self))
await server.serve_forever()
else:
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
except (ConnectionRefusedError, ConnectionResetError) as exception:
logging.error(f"Connection error: {exception}")
return
await self.connect(reader=reader, writer=writer)
async def _connect(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter):
try:
return await self.connect(reader, writer)
except KeyboardInterrupt:
print()
except Exception as e:
logging.error(e)
async def connect(self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter):
self._reader = reader
self._writer = writer
async def _write(message: Union[list, str, bytes],
terminate_line=True) -> int:
"""Framework for `asyncio.StreamWriter.write` method.
Create string from list, encode it, send and drain writer.
Return 0 on success, 1 on error.
"""
# Adapt
if type(message) is list:
message = '|'.join(map(str, message))
if type(message) is str:
if terminate_line:
message += '\n'
message = message.encode('utf-8')
if type(message) is not bytes:
return 1
try:
writer.write(message)
await writer.drain()
except ConnectionResetError:
logging.error("Client disconnected.")
except Exception as e:
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
else:
return 0 # On success, return 0
# On exception, return 1
return 1
if self.action == 'send' or not self.standalone:
if await _write(
[self.action[0], self.token,
self.file_name, self.file_size]
):
return
# Wait for server start signal
while 1:
server_hello = await reader.readline()
server_hello = await self.reader.readline()
if not server_hello:
logging.info("Server disconnected.")
logging.error("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n').split('|')
if action == 'receive' and server_hello[0] == 's':
if self.action == 'receive' and server_hello[0] == 's':
self.set_file_information(file_name=server_hello[2],
file_size=server_hello[3])
elif (
self.standalone
and self.action == 'send'
and server_hello[0] == 'r'
):
# Check token
if server_hello[1] != self.token:
if await _write("Invalid session token!"):
return
return
elif server_hello[0] == 'start!':
break
else:
logging.info(f"Server said: {'|'.join(server_hello)}")
if action == 'send':
await self.send(writer=writer)
if self.standalone:
if await _write("start!"):
return
break
if self.action == 'send':
await self.send(writer=self.writer)
else:
await self.receive(reader=reader)
await self.receive(reader=self.reader)
async def encrypt_file(self, input_file, output_file):
self._encryption_complete = False
@ -201,28 +302,27 @@ class Client:
break
try:
writer.write(output_data)
await writer.drain()
await asyncio.wait_for(writer.drain(), timeout=3.0)
except ConnectionResetError:
print() # New line after progress_bar
logging.error('Server closed the connection.')
self.stop()
break
bytes_sent += self.buffer_chunk_size
except asyncio.exceptions.TimeoutError:
print() # New line after progress_bar
logging.error('Server closed the connection.')
self.stop()
break
bytes_sent += len(output_data)
new_progress = min(
int(bytes_sent / self.file_size * 100),
100
)
progress_showed = (new_progress // 10) * 10
sys.stdout.write(
f"\t\t\tSending `{self.file_name}`: "
f"{'#' * (progress_showed // 10)}"
f"{'.' * ((100 - progress_showed) // 10)}\t"
f"{new_progress}% completed "
f"({min(bytes_sent, self.file_size) // 1000} "
f"of {self.file_size // 1000} KB)\r"
self.print_progress_bar(
progress=new_progress,
bytes_=bytes_sent,
)
sys.stdout.flush()
sys.stdout.write('\n')
sys.stdout.flush()
print() # New line after progress_bar
writer.close()
return
@ -242,26 +342,19 @@ class Client:
bytes_received = 0
while not self.stopping:
input_data = await reader.read(self.buffer_chunk_size)
bytes_received += self.buffer_chunk_size
bytes_received += len(input_data)
new_progress = min(
int(bytes_received / self.file_size * 100),
100
)
progress_showed = (new_progress // 10) * 10
sys.stdout.write(
f"\t\t\tReceiving `{self.file_name}`: "
f"{'#' * (progress_showed // 10)}"
f"{'.' * ((100 - progress_showed) // 10)}\t"
f"{new_progress}% completed "
f"({min(bytes_received, self.file_size) // 1000} "
f"of {self.file_size // 1000} KB)\r"
self.print_progress_bar(
progress=new_progress,
bytes_=bytes_received
)
sys.stdout.flush()
if not input_data:
break
file_to_receive.write(input_data)
sys.stdout.write('\n')
sys.stdout.flush()
print() # New line after sys.stdout.write
logging.info("File received.")
if self.password:
logging.info("Decrypting file...")
@ -289,7 +382,8 @@ class Client:
if self.working:
logging.info("Received interruption signal, stopping...")
self._stopping = True
self.writer.close()
if self.writer:
self.writer.close()
else:
raise KeyboardInterrupt("Not working yet...")
@ -298,24 +392,62 @@ class Client:
self._file_name = file_name
if file_size is not None:
self._file_size = int(file_size)
self._file_size_string = utilities.get_file_size_representation(
self.file_size
)
def run(self, file_path, action):
def run(self):
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(
self.run_client(file_path=file_path,
action=action)
self.run_client()
)
except KeyboardInterrupt:
print()
logging.error("Interrupted")
for task in asyncio.all_tasks(loop):
task.cancel()
self.writer.close()
if self.writer:
self.writer.close()
loop.run_until_complete(
self.writer.wait_closed()
self.wait_closed()
)
loop.close()
def print_progress_bar(self, progress: int, bytes_: int):
"""Print client progress bar.
`progress` % = `bytes_string` transferred
out of `self.file_size_string`.
"""
action = {
'send': "Sending",
'receive': "Receiving"
}[self.action]
bytes_string = utilities.get_file_size_representation(
bytes_
)
utilities.print_progress_bar(
prefix=f"\t\t\t{action} `{self.file_name}`: ",
done_symbol='#',
pending_symbol='.',
progress=progress,
scale=5,
suffix=(
" completed "
f"({bytes_string} "
f"of {self.file_size_string})"
)
)
@staticmethod
async def wait_closed() -> None:
"""Give time to cancelled tasks to end properly.
Sleep .1 second and return.
"""
await asyncio.sleep(.1)
def get_action(action):
"""Parse abbreviations for `action`."""
@ -380,6 +512,11 @@ def main():
default=None,
required=False,
help='server SSL certificate')
cli_parser.add_argument('--key', type=str,
default=None,
required=False,
help='server SSL key (required only for '
'SSL-secured standalone client)')
cli_parser.add_argument('--action', type=str,
default=None,
required=False,
@ -397,6 +534,9 @@ def main():
required=False,
help='Session token '
'(must be the same for both clients)')
cli_parser.add_argument('--standalone',
action='store_true',
help='Run both as client and server')
cli_parser.add_argument('others',
metavar='R or S',
nargs='*',
@ -405,10 +545,12 @@ def main():
host = args['host']
port = args['port']
certificate = args['certificate']
key = args['key']
action = get_action(args['action'])
file_path = args['path']
password = args['password']
token = args['token']
standalone = args['standalone']
# If host and port are not provided from command-line, try to import them
sys.path.append(os.path.abspath('.'))
@ -455,6 +597,11 @@ def main():
from config import certificate
except ImportError:
certificate = None
if key is None or not os.path.isfile(key):
try:
from config import key
except ImportError:
key = None
# If import fails, prompt user for host or port
new_settings = {} # After getting these settings, offer to store them
@ -540,17 +687,18 @@ def main():
logging.info("Configuration values stored.")
else:
logging.info("Proceeding without storing values...")
client = Client(
host=host,
port=port,
password=password,
token=token
)
ssl_context = None
if certificate is not None:
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
_ssl_context.check_hostname = False
_ssl_context.load_verify_locations(certificate)
client.set_ssl_context(_ssl_context)
if key is None: # Server-dependent client
ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.SERVER_AUTH
)
ssl_context.load_verify_locations(certificate)
else: # Standalone client
ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH
)
ssl_context.load_cert_chain(certificate, key)
else:
logging.warning(
"Please consider using SSL. To do so, add in `config.py` or "
@ -559,10 +707,17 @@ def main():
"certificate = 'path/to/certificate.crt'"
)
logging.info("Starting client...")
client.run(
client = Client(
host=host,
port=port,
ssl_context=ssl_context,
action=action,
standalone=standalone,
file_path=file_path,
action=action
password=password,
token=token
)
client.run()
logging.info("Stopped client")

View File

@ -13,10 +13,11 @@ from typing import Union
class Server:
def __init__(self, host='localhost', port=5000,
def __init__(self, host='localhost', port=5000, ssl_context=None,
buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
self._host = host
self._port = port
self._ssl_context = ssl_context
self.connections = collections.OrderedDict()
# Dict of queues of bytes
self.buffers = collections.OrderedDict()
@ -87,27 +88,24 @@ class Server:
async def run_writer(self, writer, connection_token):
consecutive_interruptions = 0
errors = 0
while 1:
while connection_token in self.buffers:
try:
try:
if connection_token not in self.buffers:
break
input_data = self.buffers[connection_token].popleft()
except IndexError:
# Slow down if buffer is short
consecutive_interruptions += 1
if consecutive_interruptions > 3:
break
await asyncio.sleep(.5)
continue
else:
consecutive_interruptions = 0
if not input_data:
input_data = self.buffers[connection_token].popleft()
except IndexError:
# Slow down if buffer is empty; after 1.5 s of silence, break
consecutive_interruptions += 1
if consecutive_interruptions > 3:
break
await asyncio.sleep(.5)
continue
else:
consecutive_interruptions = 0
if not input_data:
break
try:
writer.write(input_data)
await writer.drain()
except ConnectionResetError as e:
logging.error("Here")
logging.error(e)
break
except Exception as e:
@ -127,7 +125,7 @@ class Server:
"""
client_hello = await reader.readline()
client_hello = client_hello.decode('utf-8').strip('\n').split('|')
if len(client_hello) not in (2, 4,):
if len(client_hello) != 4:
await self.refuse_connection(writer=writer,
message="Invalid client_hello!")
return
@ -142,7 +140,7 @@ class Server:
terminate_line=True) -> int:
# Adapt
if type(message) is list:
message = '|'.join(message)
message = '|'.join(map(str, message))
if type(message) is str:
if terminate_line:
message += '\n'
@ -211,12 +209,13 @@ class Server:
index = 0
await asyncio.sleep(.5)
# Send file information and start signal to client
writer.write(
"s|hidden_token|"
f"{self.connections[connection_token]['file_name']}|"
f"{self.connections[connection_token]['file_size']}"
"\n".encode('utf-8')
)
if await _write(
['s',
'hidden_token',
self.connections[connection_token]['file_name'],
self.connections[connection_token]['file_size']]
):
return
if await _write("start!"):
return
await self.run_writer(writer=writer,
@ -238,6 +237,7 @@ class Server:
try:
loop.run_until_complete(self.run_server())
except KeyboardInterrupt:
print()
logging.info("Stopping...")
# Cancel connection tasks (they should be done but are pending)
for task in asyncio.all_tasks(loop):
@ -336,10 +336,6 @@ def main():
logging.info("Invalid port. Enter a valid port number!")
port = None
server = Server(
host=host,
port=port,
)
try:
if certificate is None or not os.path.isfile(certificate):
from config import certificate
@ -350,12 +346,12 @@ def main():
if not os.path.isfile(key):
key = None
except ImportError:
pass
certificate = None
key = None
ssl_context = None
if certificate and key:
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
_ssl_context.check_hostname = False
_ssl_context.load_cert_chain(certificate, key)
server.set_ssl_context(_ssl_context)
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
ssl_context.load_cert_chain(certificate, key)
else:
logging.warning(
"Please consider using SSL. To do so, add in `config.py` or "
@ -364,6 +360,11 @@ def main():
"key = 'path/to/secret.key'\n"
"certificate = 'path/to/certificate.crt'"
)
server = Server(
host=host,
port=port,
ssl_context=ssl_context
)
server.run()

View File

@ -2,11 +2,51 @@
import logging
import signal
import sys
units_of_measurements = {
1: 'bytes',
1000: 'KB',
1000 * 1000: 'MB',
1000 * 1000 * 1000: 'GB',
1000 * 1000 * 1000 * 1000: 'TB',
}
def get_file_size_representation(file_size):
scale, unit = get_scale_and_unit(file_size=file_size)
if scale < 10:
return f"{file_size} {unit}"
return f"{(file_size // (scale / 100)) / 100:.2f} {unit}"
def get_scale_and_unit(file_size):
scale, unit = min(units_of_measurements.items())
for scale, unit in sorted(units_of_measurements.items(), reverse=True):
if file_size > scale:
break
return scale, unit
def print_progress_bar(prefix='',
suffix='',
done_symbol="#",
pending_symbol=".",
progress=0,
scale=10):
progress_showed = (progress // scale) * scale
sys.stdout.write(
f"{prefix}"
f"{done_symbol * (progress_showed // scale)}"
f"{pending_symbol * ((100 - progress_showed) // scale)}\t"
f"{progress}%"
f"{suffix} \r"
)
sys.stdout.flush()
def timed_input(message: str = None,
timeout: int = 5):
class TimeoutExpired(Exception):
pass