Skip to content

Commit d372fac

Browse files
committed
websocket clietn hardening
1 parent f163df0 commit d372fac

File tree

3 files changed

+288
-69
lines changed

3 files changed

+288
-69
lines changed

src/mistapi/websockets/__ws_client.py

Lines changed: 123 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
self._auto_reconnect = auto_reconnect
6262
self._max_reconnect_attempts = max_reconnect_attempts
6363
self._reconnect_backoff = reconnect_backoff
64+
self._lock = threading.Lock()
6465
self._ws: websocket.WebSocketApp | None = None
6566
self._thread: threading.Thread | None = None
6667
self._queue: queue.Queue[dict | None] = queue.Queue()
@@ -81,7 +82,15 @@ def __init__(
8182
# Auth / URL helpers
8283

8384
def _build_ws_url(self) -> str:
84-
return f"wss://{self._mist_session._cloud_uri.replace('api.', 'api-ws.', 1)}/api-ws/v1/stream"
85+
cloud_uri = self._mist_session._cloud_uri
86+
if not cloud_uri.startswith("api."):
87+
logger.warning(
88+
"cloud_uri %r does not start with 'api.'; "
89+
"WebSocket URL may be incorrect",
90+
cloud_uri,
91+
)
92+
ws_host = cloud_uri.replace("api.", "api-ws.", 1)
93+
return f"wss://{ws_host}/api-ws/v1/stream"
8594

8695
def _get_headers(self) -> dict:
8796
if self._mist_session._apitoken:
@@ -105,7 +114,7 @@ def _get_cookie(self) -> str | None:
105114
c.name,
106115
)
107116
continue
108-
safe.append(f"{c.name}={c.value}")
117+
safe.append(f"{c.name}={c.value or ''}")
109118
return "; ".join(safe) if safe else None
110119
return None
111120

@@ -150,14 +159,23 @@ def on_close(self, callback: Callable[[int | None, str | None], None]) -> None:
150159
# Internal WebSocketApp handlers
151160

152161
def _handle_open(self, ws: websocket.WebSocketApp) -> None:
153-
for channel in self._channels:
154-
ws.send(json.dumps({"subscribe": channel}))
162+
try:
163+
for channel in self._channels:
164+
ws.send(json.dumps({"subscribe": channel}))
165+
except Exception as exc:
166+
logger.error("Subscription send failed: %s", exc)
167+
self._handle_error(ws, exc)
168+
ws.close()
169+
return
155170
self._reconnect_attempts = 0
156171
self._last_close_code = None
157172
self._last_close_msg = None
158173
self._connected.set()
159174
if self._on_open_cb:
160-
self._on_open_cb()
175+
try:
176+
self._on_open_cb()
177+
except Exception:
178+
logger.exception("on_open callback raised")
161179

162180
def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> None:
163181
if isinstance(message, bytes):
@@ -166,13 +184,20 @@ def _handle_message(self, ws: websocket.WebSocketApp, message: str | bytes) -> N
166184
data = json.loads(message)
167185
except (json.JSONDecodeError, TypeError):
168186
data = {"raw": message}
169-
self._queue.put(data)
170187
if self._on_message_cb:
171-
self._on_message_cb(data)
188+
try:
189+
self._on_message_cb(data)
190+
except Exception:
191+
logger.exception("on_message callback raised")
192+
else:
193+
self._queue.put(data)
172194

173195
def _handle_error(self, ws: websocket.WebSocketApp, error: Exception) -> None:
174196
if self._on_error_cb:
175-
self._on_error_cb(error)
197+
try:
198+
self._on_error_cb(error)
199+
except Exception:
200+
logger.exception("on_error callback raised")
176201

177202
def _handle_close(
178203
self,
@@ -209,80 +234,109 @@ def connect(self, run_in_background: bool = True) -> None:
209234
If True, runs the WebSocket loop in a daemon thread (non-blocking).
210235
If False, blocks the calling thread until disconnected.
211236
"""
212-
if self._connected.is_set() or (
213-
self._thread is not None and self._thread.is_alive()
214-
):
215-
raise RuntimeError("Already connected; call disconnect() first")
216-
self._user_disconnect.clear()
217-
self._finished.clear()
218-
self._reconnect_attempts = 0
219-
# Drain stale sentinel from previous connection
220-
while not self._queue.empty():
221-
try:
222-
self._queue.get_nowait()
223-
except queue.Empty:
224-
break
237+
with self._lock:
238+
if self._connected.is_set() or (
239+
self._thread is not None and self._thread.is_alive()
240+
):
241+
raise RuntimeError("Already connected; call disconnect() first")
242+
self._user_disconnect.clear()
243+
self._finished.clear()
244+
self._reconnect_attempts = 0
245+
# Drain stale sentinel from previous connection
246+
while not self._queue.empty():
247+
try:
248+
self._queue.get_nowait()
249+
except queue.Empty:
250+
break
225251

226-
self._ws = self._create_ws_app()
227-
if run_in_background:
228-
self._thread = threading.Thread(target=self._run_forever_safe, daemon=True)
229-
self._thread.start()
230-
else:
252+
self._ws = self._create_ws_app()
253+
if run_in_background:
254+
self._thread = threading.Thread(
255+
target=self._run_forever_safe, daemon=True
256+
)
257+
self._thread.start()
258+
if not run_in_background:
231259
self._run_forever_safe()
232260

233261
def _run_forever_safe(self) -> None:
234-
while True:
235-
try:
236-
sslopt = self._build_sslopt()
237-
self._ws.run_forever(
238-
ping_interval=self._ping_interval,
239-
ping_timeout=self._ping_timeout,
240-
sslopt=sslopt,
241-
)
242-
except Exception as exc:
243-
self._handle_error(self._ws, exc)
244-
self._handle_close(self._ws, -1, str(exc))
262+
try:
263+
while True:
264+
try:
265+
sslopt = self._build_sslopt()
266+
self._ws.run_forever(
267+
ping_interval=self._ping_interval,
268+
ping_timeout=self._ping_timeout,
269+
sslopt=sslopt,
270+
)
271+
except Exception as exc:
272+
self._handle_error(self._ws, exc)
273+
self._handle_close(self._ws, -1, str(exc))
245274

246-
if self._user_disconnect.is_set() or not self._auto_reconnect:
247-
break
275+
if self._user_disconnect.is_set() or not self._auto_reconnect:
276+
break
277+
278+
self._reconnect_attempts += 1
279+
if self._reconnect_attempts > self._max_reconnect_attempts:
280+
logger.warning(
281+
"Max reconnect attempts (%d) reached, giving up",
282+
self._max_reconnect_attempts,
283+
)
284+
break
248285

249-
self._reconnect_attempts += 1
250-
if self._reconnect_attempts > self._max_reconnect_attempts:
251-
logger.warning(
252-
"Max reconnect attempts (%d) reached, giving up",
286+
delay = self._reconnect_backoff * (
287+
2 ** (self._reconnect_attempts - 1)
288+
)
289+
logger.info(
290+
"Reconnecting in %.1fs (attempt %d/%d)",
291+
delay,
292+
self._reconnect_attempts,
253293
self._max_reconnect_attempts,
254294
)
255-
break
295+
if self._user_disconnect.wait(timeout=delay):
296+
break # disconnect() called during backoff
256297

257-
delay = self._reconnect_backoff * (2 ** (self._reconnect_attempts - 1))
258-
logger.info(
259-
"Reconnecting in %.1fs (attempt %d/%d)",
260-
delay,
261-
self._reconnect_attempts,
262-
self._max_reconnect_attempts,
263-
)
264-
if self._user_disconnect.wait(timeout=delay):
265-
break # disconnect() called during backoff
266-
267-
# Guard against a disconnect that happens immediately after the
268-
# backoff wait returns but before creating a new WebSocketApp.
269-
if self._user_disconnect.is_set():
270-
break
271-
272-
self._ws = self._create_ws_app()
298+
# Guard against a disconnect that happens immediately after the
299+
# backoff wait returns but before creating a new WebSocketApp.
300+
if self._user_disconnect.is_set():
301+
break
273302

274-
# Final close: put sentinel, call callback, signal finished
275-
self._queue.put(None)
276-
if self._on_close_cb:
277-
self._on_close_cb(self._last_close_code, self._last_close_msg)
278-
self._finished.set()
303+
with self._lock:
304+
old_ws = self._ws
305+
self._ws = self._create_ws_app()
306+
if old_ws:
307+
try:
308+
old_ws.close()
309+
except Exception:
310+
pass
311+
312+
# Final close: put sentinel, call callback
313+
self._queue.put(None)
314+
if self._on_close_cb:
315+
try:
316+
self._on_close_cb(self._last_close_code, self._last_close_msg)
317+
except Exception:
318+
logger.exception("on_close callback raised")
319+
finally:
320+
self._finished.set()
321+
322+
def disconnect(self, wait: bool = False, timeout: float | None = None) -> None:
323+
"""Close the WebSocket connection.
279324
280-
def disconnect(self) -> None:
281-
"""Close the WebSocket connection."""
325+
PARAMS
326+
-----------
327+
wait : bool, default False
328+
If True, block until the background thread has finished.
329+
timeout : float or None, default None
330+
Maximum seconds to wait for the thread to finish (only used
331+
when *wait* is True). ``None`` means wait indefinitely.
332+
"""
282333
self._user_disconnect.set()
283-
ws = self._ws
334+
with self._lock:
335+
ws = self._ws
284336
if ws:
285337
ws.close()
338+
if wait and self._thread is not None:
339+
self._thread.join(timeout=timeout)
286340

287341
def receive(self) -> Generator[dict, None, None]:
288342
"""
@@ -330,6 +384,6 @@ def __enter__(self) -> "_MistWebsocket":
330384
def __exit__(self, *args) -> None:
331385
self.disconnect()
332386

333-
def ready(self) -> bool | None:
387+
def ready(self) -> bool:
334388
"""Returns True if the WebSocket connection is open and ready."""
335389
return self._ws is not None and self._ws.ready()

src/mistapi/websockets/session.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ class SessionWithUrl(_MistWebsocket):
2828
Authenticated API session.
2929
url : str
3030
URL of the WebSocket channel to connect to.
31+
32+
.. warning::
33+
34+
The session's authentication credentials (API token or cookies)
35+
are sent to whatever host is specified in this URL. Only use
36+
trusted URLs — never pass user-supplied or untrusted input.
3137
ping_interval : int, default 30
3238
Interval in seconds to send WebSocket ping frames (keep-alive).
3339
ping_timeout : int, default 10
@@ -74,6 +80,8 @@ def __init__(
7480
max_reconnect_attempts: int = 5,
7581
reconnect_backoff: float = 2.0,
7682
) -> None:
83+
if not url.startswith("wss://"):
84+
raise ValueError("url must use the wss:// scheme")
7785
self._url = url
7886
super().__init__(
7987
mist_session,

0 commit comments

Comments
 (0)