diff --git a/filebridging/client.py b/filebridging/client.py index 3e29a21..23c3284 100644 --- a/filebridging/client.py +++ b/filebridging/client.py @@ -40,25 +40,27 @@ class Client: buffer_length_limit=10 ** 4, file_path=None, password=None, - token=None): + token=None, + ssl_handshake_timeout=None): self._host = host self._port = port self._ssl_context = ssl_context self._action = action self._standalone = standalone - self._stopping = False - self._reader = None - self._writer = None - # Shared queue of bytes - self.buffer = collections.deque() + self._file_path = file_path + self._password = password + self._token = token + self._ssl_handshake_timeout = ssl_handshake_timeout # How many bytes per chunk self._buffer_chunk_size = buffer_chunk_size # How many chunks in buffer self._buffer_length_limit = buffer_length_limit - self._file_path = file_path + # Shared queue of bytes + self.buffer = collections.deque() self._working = False - self._token = token - self._password = password + self._stopping = False + self._reader = None + self._writer = None self._encryption_complete = False self._file_name = None self._file_size = None @@ -138,6 +140,16 @@ class Client: def set_ssl_context(self, ssl_context: ssl.SSLContext): self._ssl_context = ssl_context + @property + def ssl_handshake_timeout(self) -> Union[int, None]: + """Return SSL handshake timeout. + + If SSL context is not set, return None. + Otherwise, return seconds to wait before considering handshake failed. + """ + if self.ssl_context: + return self._ssl_handshake_timeout + @property def token(self): """Session token. @@ -199,13 +211,13 @@ class Client: host=self.host, port=self.port, ssl=self.ssl_context, - ssl_handshake_timeout=5 + ssl_handshake_timeout=self.ssl_handshake_timeout ) except (ConnectionRefusedError, ConnectionResetError, ConnectionAbortedError) as exception: logging.error(f"Connection error: {exception}") return - except ssl.SSLCertVerificationError as exception: + except ssl.SSLError as exception: logging.error(f"SSL error: {exception}") return await self.connect(reader=reader, writer=writer) @@ -276,7 +288,7 @@ class Client: while 1: server_hello = await self.reader.readline() if not server_hello: - logging.error("Server disconnected.") + logging.error("Server refused connection.") return server_hello = server_hello.decode('utf-8').strip('\n').split('|') if self.action == 'receive' and server_hello[0] == 's': diff --git a/filebridging/server.py b/filebridging/server.py index dc34244..fc10c64 100644 --- a/filebridging/server.py +++ b/filebridging/server.py @@ -237,8 +237,10 @@ class Server: return def disconnect(self, connection_token: str) -> None: - del self.buffers[connection_token] - del self.connections[connection_token] + if connection_token in self.buffers: + del self.buffers[connection_token] + if connection_token in self.connections: + del self.connections[connection_token] def run(self): loop = asyncio.get_event_loop()