269 lines
10 KiB
Python
269 lines
10 KiB
Python
"""Cross platform abstractions for inter-process communication
|
|
|
|
On Unix, this uses AF_UNIX sockets.
|
|
On Windows, this uses NamedPipes.
|
|
"""
|
|
|
|
import base64
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
|
|
from typing import Optional, Callable
|
|
from typing_extensions import Final, Type
|
|
|
|
from types import TracebackType
|
|
|
|
if sys.platform != 'win32':
    import socket

    # On Unix the IPC connection object is a plain (AF_UNIX) socket.
    _IPCHandle = socket.socket
else:
    import ctypes

    # This may be private, but it is needed for IPC on Windows, and is basically stable
    import _winapi

    # On Windows the IPC connection is a raw Win32 handle, represented as an int.
    _IPCHandle = int

    # Win32 functions that this module calls directly through ctypes.
    kernel32 = ctypes.windll.kernel32
    DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
    FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
|
|
|
|
|
|
class IPCException(Exception):
    """Raised when connecting, reading, writing or waiting on an IPC channel fails."""
|
|
|
|
|
|
class IPCBase:
    """Base class for communication between the dmypy client and server.

    This contains logic shared between the client and server, such as reading
    and writing.
    """

    connection: _IPCHandle

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        # ``timeout`` is in seconds (the Windows paths multiply it by 1000 to
        # get milliseconds); on Windows a falsy value (None or 0) means an
        # infinite wait in read/write.
        self.name = name
        self.timeout = timeout

    def read(self, size: int = 100000) -> bytes:
        """Read bytes from an IPC connection until it's empty.

        On Windows, ``ReadFile`` is issued repeatedly while it reports
        ``ERROR_MORE_DATA`` and reading stops once a read completes with no
        error. On Unix, ``recv`` is called until it returns ``b''`` (the peer
        shut down its write side).

        Args:
            size: maximum number of bytes requested per underlying read call.

        Returns:
            All bytes read, concatenated.

        Raises:
            IPCException: on Windows, if the I/O wait fails or times out, the
                read is aborted, or an unexpected error code is reported.
        """
        bdata = bytearray()
        if sys.platform == 'win32':
            while True:
                ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                except BaseException:
                    # Don't leave an outstanding overlapped read behind.
                    ov.cancel()
                    raise
                _, err = ov.GetOverlappedResult(True)
                more = ov.getbuffer()
                if more:
                    bdata.extend(more)
                if err == 0:
                    # we are done!
                    break
                elif err == _winapi.ERROR_MORE_DATA:
                    # read again
                    continue
                elif err == _winapi.ERROR_OPERATION_ABORTED:
                    raise IPCException("ReadFile operation aborted.")
                else:
                    # Fail loudly instead of looping forever on an unknown code.
                    raise IPCException(f"Unexpected error from ReadFile: {err}")
        else:
            while True:
                more = self.connection.recv(size)
                if not more:
                    break
                bdata.extend(more)
        return bytes(bdata)

    def write(self, data: bytes) -> None:
        """Write bytes to an IPC connection.

        On Unix the socket's write side is shut down after sending, so the
        peer's ``read`` sees end-of-stream and no further writes are possible
        on this connection.

        Raises:
            IPCException: on Windows, if the write fails, the I/O wait fails
                or times out, or fewer bytes than requested were written.
        """
        if sys.platform == 'win32':
            try:
                ov, err = _winapi.WriteFile(self.connection, data, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
                assert isinstance(err, int)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                    elif err != 0:
                        raise IPCException(f"Failed writing to pipe with error: {err}")
                except BaseException:
                    ov.cancel()
                    raise
                bytes_written, err = ov.GetOverlappedResult(True)
            except OSError as e:
                raise IPCException(f"Failed to write with error: {e.winerror}") from e
            # Explicit checks rather than ``assert`` so they still run under ``python -O``.
            if err != 0:
                raise IPCException(f"Failed writing to pipe with error: {err}")
            if bytes_written != len(data):
                raise IPCException(
                    f"Short write to pipe: wrote {bytes_written} of {len(data)} bytes")
        else:
            self.connection.sendall(data)
            self.connection.shutdown(socket.SHUT_WR)

    def close(self) -> None:
        """Close the underlying connection.

        On Windows the handle is reset to NULL after closing, so calling this
        again is a no-op instead of closing a stale (possibly reused) handle.
        """
        if sys.platform == 'win32':
            if self.connection != _winapi.NULL:
                _winapi.CloseHandle(self.connection)
                self.connection = _winapi.NULL
        else:
            self.connection.close()
|
|
|
|
|
|
class IPCClient(IPCBase):
    """The client side of an IPC connection."""

    def __init__(self, name: str, timeout: Optional[float]) -> None:
        """Connect to the server endpoint ``name``.

        Args:
            name: the named pipe path (Windows) or AF_UNIX socket path (Unix).
            timeout: seconds to wait. On Windows a falsy value waits forever
                for the pipe; on Unix it is passed to ``socket.settimeout``.

        Raises:
            IPCException: on Windows, if the pipe is not found, the wait for
                it times out, or the pipe is busy.
        """
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
            try:
                _winapi.WaitNamedPipe(self.name, timeout)
            except FileNotFoundError as e:
                raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
            except OSError as e:
                if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
                    raise IPCException("Timed out waiting for connection.") from e
                else:
                    raise
            try:
                self.connection = _winapi.CreateFile(
                    self.name,
                    _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0,
                    _winapi.NULL,
                    _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED,
                    _winapi.NULL,
                )
            except OSError as e:
                if e.winerror == _winapi.ERROR_PIPE_BUSY:
                    raise IPCException("The connection is busy.") from e
                else:
                    raise
            try:
                _winapi.SetNamedPipeHandleState(self.connection,
                                                _winapi.PIPE_READMODE_MESSAGE,
                                                None,
                                                None)
            except BaseException:
                # The constructor is failing, so no caller can close this handle.
                _winapi.CloseHandle(self.connection)
                raise
        else:
            self.connection = socket.socket(socket.AF_UNIX)
            try:
                self.connection.settimeout(timeout)
                self.connection.connect(name)
            except BaseException:
                # The constructor is failing, so no caller can close this socket.
                self.connection.close()
                raise

    def __enter__(self) -> 'IPCClient':
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> None:
        # Always release the connection when leaving the ``with`` block.
        self.close()
|
|
|
|
|
|
class IPCServer(IPCBase):
    """The server side of an IPC connection.

    On Windows this owns a single-instance named pipe; on Unix it owns a
    listening AF_UNIX socket inside a fresh temporary directory. Entering the
    context waits for one client; exiting releases that client connection.
    """

    BUFFER_SIZE: Final = 2 ** 16

    def __init__(self, name: str, timeout: Optional[float] = None) -> None:
        if sys.platform == 'win32':
            # A random suffix makes each server's pipe name distinct.
            name = r'\\.\pipe\{}-{}.pipe'.format(
                name, base64.urlsafe_b64encode(os.urandom(6)).decode())
        else:
            name = f'{name}.sock'
        super().__init__(name, timeout)
        if sys.platform == 'win32':
            self.connection = _winapi.CreateNamedPipe(self.name,
                _winapi.PIPE_ACCESS_DUPLEX
                | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
                | _winapi.FILE_FLAG_OVERLAPPED,
                _winapi.PIPE_READMODE_MESSAGE
                | _winapi.PIPE_TYPE_MESSAGE
                | _winapi.PIPE_WAIT
                | 0x8,  # PIPE_REJECT_REMOTE_CLIENTS
                1,  # one instance
                self.BUFFER_SIZE,
                self.BUFFER_SIZE,
                _winapi.NMPWAIT_WAIT_FOREVER,
                0,  # Use default security descriptor
            )
            if self.connection == -1:  # INVALID_HANDLE_VALUE
                err = _winapi.GetLastError()
                raise IPCException(f'Invalid handle to pipe: {err}')
        else:
            self.sock_directory = tempfile.mkdtemp()
            sockfile = os.path.join(self.sock_directory, self.name)
            self.sock = socket.socket(socket.AF_UNIX)
            try:
                self.sock.bind(sockfile)
                self.sock.listen(1)
                if timeout is not None:
                    self.sock.settimeout(timeout)
            except BaseException:
                # The constructor is failing, so cleanup() can never be called
                # on this object: release the socket and temp dir here.
                self.sock.close()
                shutil.rmtree(self.sock_directory, ignore_errors=True)
                raise

    def __enter__(self) -> 'IPCServer':
        """Wait for a client to connect, then return ``self``.

        Raises:
            IPCException: if no client connects before the timeout (on
                Windows, also if the connect wait fails; the pipe handle is
                closed in that case).
        """
        if sys.platform == 'win32':
            # NOTE: It is theoretically possible that this will hang forever if the
            # client never connects, though this can be "solved" by killing the server
            try:
                ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
                # TODO: remove once typeshed supports Literal types
                assert isinstance(ov, _winapi.Overlapped)
            except OSError as e:
                # Don't raise if the client already exists, or the client already connected
                if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
                    raise
            else:
                try:
                    timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                    res = _winapi.WaitForSingleObject(ov.event, timeout)
                    # Raise rather than assert so a timeout is still honored
                    # under ``python -O`` (otherwise the blocking
                    # GetOverlappedResult below would hang).
                    if res != _winapi.WAIT_OBJECT_0:
                        raise IPCException(f"Bad result from I/O wait: {res}")
                except BaseException:
                    ov.cancel()
                    _winapi.CloseHandle(self.connection)
                    # Mark the handle as closed so close()/cleanup() skip it.
                    self.connection = _winapi.NULL
                    raise
                _, err = ov.GetOverlappedResult(True)
                if err != 0:
                    raise IPCException(f"Failed to connect NamedPipe with error: {err}")
        else:
            try:
                self.connection, _ = self.sock.accept()
            except socket.timeout as e:
                raise IPCException('The socket timed out') from e
        return self

    def __exit__(self,
                 exc_ty: 'Optional[Type[BaseException]]' = None,
                 exc_val: Optional[BaseException] = None,
                 exc_tb: Optional[TracebackType] = None,
                 ) -> None:
        """Release the current client connection (the server stays usable)."""
        if sys.platform == 'win32':
            try:
                # Wait for the client to finish reading the last write before disconnecting
                if not FlushFileBuffers(self.connection):
                    raise IPCException("Failed to flush NamedPipe buffer, "
                                       "maybe the client hung up?")
            finally:
                DisconnectNamedPipe(self.connection)
        else:
            self.close()

    def cleanup(self) -> None:
        """Release all server resources.

        On Windows this closes the pipe handle. On Unix it closes the listening
        socket and removes its temporary directory; ``connection_name`` cannot
        be used afterwards.
        """
        if sys.platform == 'win32':
            self.close()
        else:
            self.sock.close()
            shutil.rmtree(self.sock_directory)

    @property
    def connection_name(self) -> str:
        """The address a client should pass to ``IPCClient``."""
        if sys.platform == 'win32':
            return self.name
        else:
            return self.sock.getsockname()
|