"""Cross platform abstractions for inter-process communication On Unix, this uses AF_UNIX sockets. On Windows, this uses NamedPipes. """ import base64 import os import shutil import sys import tempfile from typing import Optional, Callable from typing_extensions import Final, Type from types import TracebackType if sys.platform == 'win32': # This may be private, but it is needed for IPC on Windows, and is basically stable import _winapi import ctypes _IPCHandle = int kernel32 = ctypes.windll.kernel32 DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers else: import socket _IPCHandle = socket.socket class IPCException(Exception): """Exception for IPC issues.""" pass class IPCBase: """Base class for communication between the dmypy client and server. This contains logic shared between the client and server, such as reading and writing. """ connection: _IPCHandle def __init__(self, name: str, timeout: Optional[float]) -> None: self.name = name self.timeout = timeout def read(self, size: int = 100000) -> bytes: """Read bytes from an IPC connection until its empty.""" bdata = bytearray() if sys.platform == 'win32': while True: ov, err = _winapi.ReadFile(self.connection, size, overlapped=True) try: if err == _winapi.ERROR_IO_PENDING: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) if res != _winapi.WAIT_OBJECT_0: raise IPCException(f"Bad result from I/O wait: {res}") except BaseException: ov.cancel() raise _, err = ov.GetOverlappedResult(True) more = ov.getbuffer() if more: bdata.extend(more) if err == 0: # we are done! break elif err == _winapi.ERROR_MORE_DATA: # read again continue elif err == _winapi.ERROR_OPERATION_ABORTED: raise IPCException("ReadFile operation aborted.") else: while True: more = self.connection.recv(size) if not more: break bdata.extend(more) return bytes(bdata) def write(self, data: bytes) -> None: """Write bytes to an IPC connection.""" if sys.platform == 'win32': try: ov, err = _winapi.WriteFile(self.connection, data, overlapped=True) # TODO: remove once typeshed supports Literal types assert isinstance(ov, _winapi.Overlapped) assert isinstance(err, int) try: if err == _winapi.ERROR_IO_PENDING: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) if res != _winapi.WAIT_OBJECT_0: raise IPCException(f"Bad result from I/O wait: {res}") elif err != 0: raise IPCException(f"Failed writing to pipe with error: {err}") except BaseException: ov.cancel() raise bytes_written, err = ov.GetOverlappedResult(True) assert err == 0, err assert bytes_written == len(data) except OSError as e: raise IPCException(f"Failed to write with error: {e.winerror}") from e else: self.connection.sendall(data) self.connection.shutdown(socket.SHUT_WR) def close(self) -> None: if sys.platform == 'win32': if self.connection != _winapi.NULL: _winapi.CloseHandle(self.connection) else: self.connection.close() class IPCClient(IPCBase): """The client side of an IPC connection.""" def __init__(self, name: str, timeout: Optional[float]) -> None: super().__init__(name, timeout) if sys.platform == 'win32': timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER try: _winapi.WaitNamedPipe(self.name, timeout) except FileNotFoundError as e: raise IPCException(f"The NamedPipe at {self.name} was not found.") from e except OSError as e: if e.winerror == _winapi.ERROR_SEM_TIMEOUT: raise IPCException("Timed out waiting for connection.") from e else: raise try: self.connection = _winapi.CreateFile( self.name, _winapi.GENERIC_READ | _winapi.GENERIC_WRITE, 0, _winapi.NULL, _winapi.OPEN_EXISTING, _winapi.FILE_FLAG_OVERLAPPED, _winapi.NULL, ) except OSError as e: if e.winerror == _winapi.ERROR_PIPE_BUSY: raise IPCException("The connection is busy.") from e else: raise _winapi.SetNamedPipeHandleState(self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None) else: self.connection = socket.socket(socket.AF_UNIX) self.connection.settimeout(timeout) self.connection.connect(name) def __enter__(self) -> 'IPCClient': return self def __exit__(self, exc_ty: 'Optional[Type[BaseException]]' = None, exc_val: Optional[BaseException] = None, exc_tb: Optional[TracebackType] = None, ) -> None: self.close() class IPCServer(IPCBase): BUFFER_SIZE: Final = 2 ** 16 def __init__(self, name: str, timeout: Optional[float] = None) -> None: if sys.platform == 'win32': name = r'\\.\pipe\{}-{}.pipe'.format( name, base64.urlsafe_b64encode(os.urandom(6)).decode()) else: name = f'{name}.sock' super().__init__(name, timeout) if sys.platform == 'win32': self.connection = _winapi.CreateNamedPipe(self.name, _winapi.PIPE_ACCESS_DUPLEX | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE | _winapi.FILE_FLAG_OVERLAPPED, _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_WAIT | 0x8, # PIPE_REJECT_REMOTE_CLIENTS 1, # one instance self.BUFFER_SIZE, self.BUFFER_SIZE, _winapi.NMPWAIT_WAIT_FOREVER, 0, # Use default security descriptor ) if self.connection == -1: # INVALID_HANDLE_VALUE err = _winapi.GetLastError() raise IPCException(f'Invalid handle to pipe: {err}') else: self.sock_directory = tempfile.mkdtemp() sockfile = os.path.join(self.sock_directory, self.name) self.sock = socket.socket(socket.AF_UNIX) self.sock.bind(sockfile) self.sock.listen(1) if timeout is not None: self.sock.settimeout(timeout) def __enter__(self) -> 'IPCServer': if sys.platform == 'win32': # NOTE: It is theoretically possible that this will hang forever if the # client never connects, though this can be "solved" by killing the server try: ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True) # TODO: remove once typeshed supports Literal types assert isinstance(ov, _winapi.Overlapped) except OSError as e: # Don't raise if the client already exists, or the client already connected if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA): raise else: try: timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE res = _winapi.WaitForSingleObject(ov.event, timeout) assert res == _winapi.WAIT_OBJECT_0 except BaseException: ov.cancel() _winapi.CloseHandle(self.connection) raise _, err = ov.GetOverlappedResult(True) assert err == 0 else: try: self.connection, _ = self.sock.accept() except socket.timeout as e: raise IPCException('The socket timed out') from e return self def __exit__(self, exc_ty: 'Optional[Type[BaseException]]' = None, exc_val: Optional[BaseException] = None, exc_tb: Optional[TracebackType] = None, ) -> None: if sys.platform == 'win32': try: # Wait for the client to finish reading the last write before disconnecting if not FlushFileBuffers(self.connection): raise IPCException("Failed to flush NamedPipe buffer," "maybe the client hung up?") finally: DisconnectNamedPipe(self.connection) else: self.close() def cleanup(self) -> None: if sys.platform == 'win32': self.close() else: shutil.rmtree(self.sock_directory) @property def connection_name(self) -> str: if sys.platform == 'win32': return self.name else: return self.sock.getsockname()