521 lines
18 KiB
Python
521 lines
18 KiB
Python
import argparse
|
|
import asyncio
|
|
import collections
|
|
import logging
|
|
import os
|
|
import random
|
|
import ssl
|
|
import string
|
|
import sys
|
|
|
|
|
|
class Client:
    """Asyncio client that sends a file to, or receives a file from, a server.

    Protocol as implemented below: the sender writes
    ``s|<token>|<file name>|<file size>\\n`` and the receiver writes
    ``r|<token>\\n``; both then wait for a ``start!`` line from the server
    (the receiver also accepts an ``info|<file name>|<file size>`` line)
    before raw file bytes are streamed. When a password is set, the file is
    encrypted/decrypted with the ``openssl`` command-line tool.
    """

    def __init__(self, host='localhost', port=3001,
                 buffer_chunk_size=10**4, buffer_length_limit=10**4,
                 password=None, token=None):
        self._host = host
        self._port = port
        self._stopping = False
        # Shared queue of bytes (not used by the methods of this class)
        self.buffer = collections.deque()
        # How many bytes per chunk
        self._buffer_chunk_size = buffer_chunk_size
        # How many chunks in buffer
        self._buffer_length_limit = buffer_length_limit
        self._file_path = None
        self._working = False
        self._token = token
        self._password = password
        self._ssl_context = None
        self._encryption_complete = False
        # True only when the last openssl encryption exited with status 0
        self._encryption_succeeded = False
        # Reference to the background encryption task, so it is not
        # garbage-collected while still running
        self._encryption_task = None
        self._file_name = None
        self._file_size = None

    @property
    def host(self) -> str:
        return self._host

    @property
    def port(self) -> int:
        return self._port

    @property
    def stopping(self) -> bool:
        return self._stopping

    @property
    def buffer_length_limit(self) -> int:
        return self._buffer_length_limit

    @property
    def buffer_chunk_size(self) -> int:
        return self._buffer_chunk_size

    @property
    def file_path(self) -> str:
        return self._file_path

    @property
    def working(self) -> bool:
        return self._working

    @property
    def ssl_context(self) -> ssl.SSLContext:
        return self._ssl_context

    def set_ssl_context(self, ssl_context: ssl.SSLContext):
        self._ssl_context = ssl_context

    @property
    def token(self):
        return self._token

    @property
    def password(self):
        """Password for file encryption or decryption."""
        return self._password

    @property
    def encryption_complete(self):
        return self._encryption_complete

    @property
    def file_name(self):
        return self._file_name

    @property
    def file_size(self):
        return self._file_size

    async def run_sending_client(self, file_path='~/output.txt'):
        """Connect, announce the file, wait for ``start!``, then send it.

        :param file_path: path of the file to send (``~`` is expanded).
        """
        self._file_path = file_path
        absolute_path = os.path.abspath(os.path.expanduser(file_path))
        file_name = os.path.basename(absolute_path)
        file_size = os.path.getsize(absolute_path)
        try:
            reader, writer = await asyncio.open_connection(
                host=self.host,
                port=self.port,
                ssl=self.ssl_context
            )
        except OSError as exception:
            # Refused connection, unreachable host, DNS or TLS failure...
            logging.error(exception)
            return
        writer.write(
            f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
        )
        self.set_file_information(file_name=file_name,
                                  file_size=file_size)
        await writer.drain()
        # Wait for server start signal
        while True:
            server_hello = await reader.readline()
            if not server_hello:
                logging.info("Server disconnected.")
                writer.close()
                return
            server_hello = server_hello.decode('utf-8').strip('\n')
            if server_hello == 'start!':
                break
            logging.info(f"Server said: {server_hello}")
        await self.send(writer=writer)

    async def _run_openssl(self, input_file, output_file, decrypt=False):
        """Run ``openssl enc`` (AES-256-CBC, PBKDF2) from `input_file` to
        `output_file`, decrypting when `decrypt` is true.

        Arguments are passed without a shell and the password is fed on
        stdin (``-pass stdin``), so neither file names nor the password are
        interpreted by a shell or visible on the command line.

        :return: True if openssl exited with status 0, False otherwise.
        """
        command = [
            'openssl', 'enc', '-aes-256-cbc',
            '-md', 'sha512', '-pbkdf2', '-iter', '100000', '-salt',
        ]
        if decrypt:
            command.append('-d')
        command += ['-in', input_file, '-out', output_file, '-pass', 'stdin']
        try:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=asyncio.subprocess.PIPE,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE
            )
            stdout, stderr = await process.communicate(
                input=f"{self.password}\n".encode('utf-8')
            )
        except (OSError, ValueError) as exception:
            # e.g. openssl not installed, or a NUL byte in a file name
            logging.error(f"Could not run openssl: {exception}")
            return False
        if process.returncode != 0:
            logging.error(
                "openssl exited with status {code}:\n{o}\n{er}".format(
                    code=process.returncode,
                    o=stdout.decode(errors='replace').strip(),
                    er=stderr.decode(errors='replace').strip()
                )
            )
            return False
        return True

    async def encrypt_file(self, input_file, output_file):
        """Encrypt `input_file` into `output_file` using the client password.

        `encryption_complete` becomes True once openssl has finished,
        whether or not it succeeded, so that `send` stops waiting.
        """
        self._encryption_complete = False
        self._encryption_succeeded = False
        logging.info("Encrypting file...")
        try:
            self._encryption_succeeded = await self._run_openssl(
                input_file=input_file,
                output_file=output_file
            )
        finally:
            # Always flag completion so `send` can never wait forever
            self._encryption_complete = True
        if self._encryption_succeeded:
            logging.info("Encryption completed.")
        else:
            logging.error("Encryption failed.")

    def _show_progress(self, verb, done):
        """Redraw the one-line progress bar for `done` transferred bytes."""
        total = self.file_size or 0
        if total > 0:
            new_progress = min(int(done / total * 100), 100)
        else:
            # Empty (or unknown-size) file: avoid dividing by zero
            new_progress = 100
        progress_showed = (new_progress // 10) * 10
        sys.stdout.write(
            f"\t\t\t{verb} `{self.file_name}`: "
            f"{'#' * (progress_showed // 10)}"
            f"{'.' * ((100 - progress_showed) // 10)}\t"
            f"{new_progress}% completed "
            f"({min(done, total) // 1000} "
            f"of {total // 1000} KB)\r"
        )
        sys.stdout.flush()

    async def send(self, writer: asyncio.StreamWriter):
        """Stream the file at `file_path` to `writer`, then close `writer`.

        With a password, the file is encrypted to ``<file_path>.enc`` in a
        background task and that file is streamed while it is being written.
        """
        self._working = True
        source_path = os.path.expanduser(self.file_path)
        file_path = source_path
        if self.password:
            file_path = source_path + '.enc'
            # Remove already-encrypted file if present (salt would differ)
            if os.path.isfile(file_path):
                os.remove(file_path)
            self._encryption_complete = False
            self._encryption_succeeded = False
            self._encryption_task = asyncio.ensure_future(
                self.encrypt_file(
                    input_file=source_path,
                    output_file=file_path
                )
            )
            # Give encryption an edge: wait until openssl creates its output
            while not os.path.isfile(file_path):
                if self.encryption_complete:
                    # openssl finished without ever creating the output file
                    logging.error("Encrypted file was not created; "
                                  "nothing was sent.")
                    writer.close()
                    return
                if self.stopping:
                    writer.close()
                    return
                await asyncio.sleep(.5)
        logging.info("Sending file...")
        bytes_sent = 0
        with open(file_path, 'rb') as file_to_send:
            while not self.stopping:
                output_data = file_to_send.read(self.buffer_chunk_size)
                if not output_data:
                    # If encryption is in progress, wait and read again later
                    if self.password and not self.encryption_complete:
                        await asyncio.sleep(1)
                        continue
                    break
                try:
                    writer.write(output_data)
                    await writer.drain()
                except ConnectionError:
                    # Covers ConnectionResetError and BrokenPipeError
                    logging.info('Server closed the connection.')
                    self.stop()
                    break
                bytes_sent += len(output_data)
                self._show_progress('Sending', bytes_sent)
        sys.stdout.write('\n')
        sys.stdout.flush()
        if (self.password and self.encryption_complete
                and not self._encryption_succeeded):
            logging.error("Encryption failed; the data sent is incomplete.")
        writer.close()

    async def run_receiving_client(self, file_path='~/input.txt'):
        """Connect, wait for file information and ``start!``, then receive.

        :param file_path: destination folder (``~`` is expanded).
        """
        self._file_path = file_path
        try:
            reader, writer = await asyncio.open_connection(
                host=self.host,
                port=self.port,
                ssl=self.ssl_context
            )
        except OSError as exception:
            # Refused connection, unreachable host, DNS or TLS failure...
            logging.error(exception)
            return
        writer.write(f"r|{self.token}\n".encode('utf-8'))
        await writer.drain()
        # Wait for server start signal
        while True:
            server_hello = await reader.readline()
            if not server_hello:
                logging.info("Server disconnected.")
                writer.close()
                return
            server_hello = server_hello.decode('utf-8').strip('\n')
            if server_hello.startswith('info|'):
                # File names may contain '|': the size follows the last one
                file_name, _, file_size = (
                    server_hello[len('info|'):].rpartition('|')
                )
                try:
                    self.set_file_information(file_name=file_name,
                                              file_size=file_size)
                except ValueError:
                    logging.error(f"Invalid file information: {server_hello}")
                    writer.close()
                    return
            elif server_hello == 'start!':
                break
            else:
                logging.info(f"Server said: {server_hello}")
        await self.receive(reader=reader)
        writer.close()

    async def receive(self, reader: asyncio.StreamReader):
        """Write the incoming stream into `file_path`/`file_name`.

        Only the base name of the peer-supplied `file_name` is used, so the
        file always lands inside the destination folder. With a password,
        data is saved as ``<name>.enc`` and decrypted with openssl.
        """
        self._working = True
        safe_name = os.path.basename(self.file_name or '')
        if safe_name in ('', '.', '..'):
            logging.error(f"Refusing unsafe file name: `{self.file_name}`")
            return
        file_path = os.path.join(
            os.path.abspath(os.path.expanduser(self.file_path)),
            safe_name
        )
        original_file_path = file_path
        if self.password:
            file_path += '.enc'
        logging.info("Receiving file...")
        with open(file_path, 'wb') as file_to_receive:
            bytes_received = 0
            while not self.stopping:
                try:
                    input_data = await reader.read(self.buffer_chunk_size)
                except ConnectionError:
                    logging.info('Server closed the connection.')
                    break
                if not input_data:
                    break
                file_to_receive.write(input_data)
                bytes_received += len(input_data)
                self._show_progress('Receiving', bytes_received)
        sys.stdout.write('\n')
        sys.stdout.flush()
        logging.info("File received.")
        if self.password:
            logging.info("Decrypting file...")
            if await self._run_openssl(input_file=file_path,
                                       output_file=original_file_path,
                                       decrypt=True):
                logging.info("Decryption completed.")
            else:
                logging.error("Decryption failed")

    def stop(self, *_):
        """Request a graceful stop; extra arguments are ignored (the
        signature accepts signal-handler style arguments).

        :raises KeyboardInterrupt: if no transfer has started yet.
        """
        if self.working:
            logging.info("Received interruption signal, stopping...")
            self._stopping = True
        else:
            raise KeyboardInterrupt("Not working yet...")

    def set_file_information(self, file_name=None, file_size=None):
        """Store file name and size; `file_size` is converted with `int`.

        :raises ValueError: if `file_size` is not an integer string/number.
        """
        if file_name is not None:
            self._file_name = file_name
        if file_size is not None:
            self._file_size = int(file_size)
|
|
|
|
|
|
def get_action(action):
    """Parse abbreviations for `action`.

    A string whose lower-cased form starts with 'r' means 'receive', one
    starting with 's' means 'send'. Anything else, including non-strings
    and the empty string, yields None.
    """
    if not isinstance(action, str):
        return None
    first_letter = action.lower()[:1]
    return {'r': 'receive', 's': 'send'}.get(first_letter)
|
|
|
|
|
|
def get_file_path(path, action='receive'):
    """Check that file `path` is correct for `action` and return it.

    - 'send': `path` must be an existing regular file.
    - 'receive': the directory received files are written into must be
      writable. That directory is `path` itself when it is an existing
      directory, otherwise its parent directory.

    Returns `path` unchanged when valid. Otherwise returns None, logging an
    error unless `path` is None.
    """
    if isinstance(path, str):
        if action == 'send' and os.path.isfile(path):
            return path
        if action == 'receive':
            absolute_path = os.path.abspath(path)
            # Received files are written *inside* the destination folder,
            # so that folder (not its parent) must be writable
            if os.path.isdir(absolute_path):
                target_dir = absolute_path
            else:
                target_dir = os.path.dirname(absolute_path)
            if os.access(target_dir, os.W_OK):
                return path
    if path is not None:
        logging.error(f"Invalid file: `{path}`")
    return None
|
|
|
|
|
|
def _config_value(name, value=None):
    """Return `value`, falling back to `config.<name>` when `value` is None.

    `config` is an optional local module. When it cannot be imported or
    does not define `name`, None is returned. This is the same outcome as
    the ``from config import <name>`` / ``except ImportError`` pattern.
    """
    if value is not None:
        return value
    try:
        import config
    except ImportError:
        return None
    return getattr(config, name, None)


def _is_valid_path(path, action):
    """Return True if `path` is an existing file ('send') or folder ('receive')."""
    if path is None:
        return False
    absolute_path = os.path.abspath(path)
    if action == 'send':
        return os.path.isfile(absolute_path)
    if action == 'receive':
        return os.path.isdir(absolute_path)
    return False


def main():
    """Command-line entry point: gather settings, then send or receive a file.

    Each setting is taken from the command line first, then from the
    optional local `config` module, and finally from interactive prompts.
    """
    # noinspection SpellCheckingInspection
    log_formatter = logging.Formatter(
        "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
        style='%'
    )
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.DEBUG)

    # Quiet asyncio's internal logger (the logger asyncio.selector_events
    # writes to is logging.getLogger('asyncio'))
    logging.getLogger('asyncio').setLevel(logging.ERROR)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    console_handler.setLevel(logging.DEBUG)
    root_logger.addHandler(console_handler)

    # Parse command-line arguments
    cli_parser = argparse.ArgumentParser(description='Run client',
                                         allow_abbrev=False)
    cli_parser.add_argument('--host', type=str,
                            default=None,
                            required=False,
                            help='server address')
    cli_parser.add_argument('--port', type=int,
                            default=None,
                            required=False,
                            help='server port')
    cli_parser.add_argument('--action', type=str,
                            default=None,
                            required=False,
                            help='[S]end or [R]eceive')
    cli_parser.add_argument('--path', type=str,
                            default=None,
                            required=False,
                            help='File path to send / folder path to receive')
    cli_parser.add_argument('--password', '--p', '--pass', type=str,
                            default=None,
                            required=False,
                            help='Password for file encryption or decryption')
    cli_parser.add_argument('--token', '--t', '--session_token', type=str,
                            default=None,
                            required=False,
                            help='Session token '
                                 '(must be the same for both clients)')
    cli_parser.add_argument('others',
                            metavar='R or S',
                            nargs='*',
                            help='[S]end or [R]eceive (see `action`)')
    args = vars(cli_parser.parse_args())

    # Command-line values win; missing ones are looked up in `config`
    host = _config_value('host', args['host'])
    port = _config_value('port', args['port'])
    # Take `s`, `r` etc. from --action, then positional arguments, then config
    action = get_action(args['action'])
    if action is None:
        for arg in args['others']:
            action = get_action(arg)
            if action:
                break
    if action is None:
        action = get_action(_config_value('action'))
    # A configured path is a path, not an action: use it as-is
    file_path = _config_value('file_path', args['path'])
    password = _config_value('password', args['password'])
    token = _config_value('token', args['token'])

    # Prompt the user for anything still missing
    while host is None:
        host = input("Enter host:\t\t\t\t\t\t")
    while port is None:
        try:
            port = int(input("Enter port:\t\t\t\t\t\t"))
        except ValueError:
            logging.info("Invalid port. Enter a valid port number!")
            port = None
    while action is None:
        action = get_action(
            input("Do you want to (R)eceive or (S)end a file?\t\t")
        )
    if not _is_valid_path(file_path, action):
        file_path = None
    while file_path is None:
        if action == 'send':
            prompt = "Enter file to send:\t\t\t\t\t\t"
        else:
            prompt = "Enter destination folder:\t\t\t\t\t\t"
        candidate = get_file_path(path=input(prompt), action=action)
        # get_file_path returns None on invalid input: keep prompting
        if _is_valid_path(candidate, action):
            file_path = candidate
    if password is None:
        logging.warning(
            "You have provided no password for file encryption.\n"
            "Your file will be unencrypted unless you provide a password in "
            "config file."
        )
    if token is None and action == 'send':
        # Generate a random [6-10] chars-long alphanumerical token
        token = ''.join(
            random.SystemRandom().choice(
                string.ascii_uppercase + string.digits
            )
            for _ in range(random.SystemRandom().randint(6, 10))
        )
        logging.info(
            "You have not provided a token for this connection.\n"
            f"A token has been generated for you:\t\t{token}\n"
            "Your peer must be informed of this token.\n"
            "For future connections, you may provide a custom token writing "
            "it in config file."
        )
    while token is None or not (6 <= len(token) <= 10):
        token = input("Please enter a 6-10 chars token.\t\t\t\t")

    # asyncio.get_event_loop() is deprecated without a running loop
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    client = Client(
        host=host,
        port=port,
        password=password,
        token=token
    )
    certificate = _config_value('certificate')
    if certificate is not None:
        _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
        # The chain is verified against `certificate`; host-name matching
        # is disabled
        _ssl_context.check_hostname = False
        _ssl_context.load_verify_locations(certificate)
        client.set_ssl_context(_ssl_context)
    else:
        logging.warning("Please consider using SSL.")

    logging.info("Starting client...")
    if action == 'send':
        coroutine = client.run_sending_client(file_path=file_path)
    else:
        coroutine = client.run_receiving_client(file_path=file_path)
    try:
        loop.run_until_complete(coroutine)
    finally:
        loop.close()
    logging.info("Stopped client")
|
|
|
|
|
|
if __name__ == "__main__":
    # Start the command-line client only when run as a script, not on import
    main()
|