diff --git a/python/fusion_engine_client/applications/p1_capture.py b/python/fusion_engine_client/applications/p1_capture.py index 00eeeb6a..5da5892f 100755 --- a/python/fusion_engine_client/applications/p1_capture.py +++ b/python/fusion_engine_client/applications/p1_capture.py @@ -303,7 +303,7 @@ def _set_read_timeout(self) -> None: # If this is a TCP/UDP/UNIX socket, configure it for non-blocking reads. We'll apply a read timeout with # select() in the read loop. - if isinstance(self.input_transport, socket.socket): + if isinstance(self.input_transport, SocketTransport): self.input_transport.setblocking(0) # This function won't do anything if neither timestamp is enabled. enable_socket_timestamping( @@ -349,7 +349,7 @@ def process_input(self) -> None: try: # If this is a TCP/UDP socket, use select() to implement a read timeout so we can wake up # periodically and print status if there's no incoming data. - if isinstance(self.input_transport, socket.socket): + if isinstance(self.input_transport, SocketTransport): ready = select.select([self.input_transport], [], [], self.read_timeout_sec) if ready[0]: received_data, kernel_ts, hw_ts = recv(self.input_transport, self.read_size_bytes) diff --git a/python/fusion_engine_client/utils/socket_timestamping.py b/python/fusion_engine_client/utils/socket_timestamping.py index d178071d..786d761b 100755 --- a/python/fusion_engine_client/utils/socket_timestamping.py +++ b/python/fusion_engine_client/utils/socket_timestamping.py @@ -13,6 +13,8 @@ import sys from typing import BinaryIO, List, Optional, Tuple, Union +from .transport_utils import SocketTransport + _CMSG = Tuple[int, int, bytes] @@ -102,7 +104,8 @@ def parse_timestamps_from_ancdata(ancdata: List[_CMSG]) -> Tuple[Optional[float] return tuple(timestamps) -def enable_socket_timestamping(sock: socket.socket, enable_sw_timestamp: bool, enable_hw_timestamp: bool) -> bool: +def enable_socket_timestamping(sock: Union[socket.socket, SocketTransport, BinaryIO], + enable_sw_timestamp: bool, enable_hw_timestamp: bool) -> bool: '''! Enable kernel-level hardware or software timestamping of incoming socket data. @@ -112,7 +115,13 @@ def enable_socket_timestamping(sock: socket.socket, enable_sw_timestamp: bool, e @return `True` if timestamping is supported on the host OS. ''' - if sys.platform == "linux": + if isinstance(sock, SocketTransport): + sock = sock.socket + + # Handle non-sockets (websocket, BinaryIO (file), etc.) gracefully. + if not isinstance(sock, socket.socket): + return False + elif sys.platform == "linux": if enable_sw_timestamp or enable_hw_timestamp: flags = 0 if enable_sw_timestamp: @@ -127,7 +136,8 @@ def enable_socket_timestamping(sock: socket.socket, enable_sw_timestamp: bool, e return False -def recv(sock: Union[socket.socket, BinaryIO], buffer_size: int) -> Tuple[bytes, Optional[float], Optional[float]]: +def recv(sock: Union[socket.socket, SocketTransport, BinaryIO], buffer_size: int) -> \ + Tuple[bytes, Optional[float], Optional[float]]: '''! Receive data from the specified socket and capture timestamps, if enabled. @@ -139,6 +149,9 @@ def recv(sock: Union[socket.socket, BinaryIO], buffer_size: int) -> Tuple[bytes, - The kernel timestamp, if enabled - The hardware timestamp, if enabled ''' + if isinstance(sock, SocketTransport): + sock = sock.socket + # Handle non-sockets (websocket, BinaryIO (file), etc.) gracefully. if not isinstance(sock, socket.socket): received_data = sock.read(buffer_size) diff --git a/python/fusion_engine_client/utils/transport_utils.py b/python/fusion_engine_client/utils/transport_utils.py index bec18845..6e2efc15 100644 --- a/python/fusion_engine_client/utils/transport_utils.py +++ b/python/fusion_engine_client/utils/transport_utils.py @@ -147,6 +147,35 @@ def write(self, data: Union[bytes, bytearray]) -> int: raise RuntimeError('Output file not opened.') +class SocketTransport: + """! + @brief Socket wrapper class, protecting against multiple close() calls. + + All other member or function accesses are deferred to the underlying `socket.socket` instance. + """ + def __init__(self, *args, **kwargs): + self._socket = socket.socket(*args, **kwargs) + self._closed = False + + @property + def socket(self): + return self._socket + + def close(self): + if not self._closed: + self._closed = True + self._socket.close() + + def __getattr__(self, name): + return getattr(self._socket, name) + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + class WebsocketTransport: """! @brief Websocket wrapper class, mimicking the Python socket API. @@ -226,7 +255,7 @@ def __setattr__(self, item: str, value: Any) -> None: {TRANSPORT_HELP_OPTIONS} """ -TransportClass = Union[socket.socket, serial.Serial, WebsocketTransport, FileTransport] +TransportClass = Union[SocketTransport, serial.Serial, WebsocketTransport, FileTransport] def create_transport(descriptor: str, timeout_sec: float = None, print_func: Callable = None, mode: str = 'both', @@ -272,7 +301,7 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal if print_func is not None: print_func(f'Connecting to tcp://{ip_address}:{port}.') - transport = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + transport = SocketTransport(socket.AF_INET, socket.SOCK_STREAM) if timeout_sec is not None: transport.settimeout(timeout_sec) try: @@ -288,7 +317,7 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal if print_func is not None: print_func(f'Connecting to udp://:{port}.') - transport = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + transport = SocketTransport(socket.AF_INET, socket.SOCK_DGRAM) transport.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) if timeout_sec is not None: transport.settimeout(timeout_sec) @@ -324,7 +353,7 @@ def create_transport(descriptor: str, timeout_sec: float = None, print_func: Cal if print_func is not None: print_func(f'Connecting to unix://{path}.') - transport = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + transport = SocketTransport(socket.AF_UNIX, socket.SOCK_STREAM) if timeout_sec is not None: transport.settimeout(timeout_sec) transport.connect(path) @@ -375,7 +404,7 @@ def recv_from_transport(transport: TransportClass, size_bytes: int) -> bytes: @return A `bytes` array. ''' try: - if isinstance(transport, (socket.socket, WebsocketTransport)): + if isinstance(transport, (SocketTransport, WebsocketTransport)): return transport.recv(size_bytes) else: return transport.read(size_bytes) @@ -384,7 +413,7 @@ def recv_from_transport(transport: TransportClass, size_bytes: int) -> bytes: def set_read_timeout(transport: TransportClass, timeout_sec: float): - if isinstance(transport, socket.socket): + if isinstance(transport, SocketTransport): if timeout_sec == 0: transport.setblocking(False) else: