from __future__ import annotations import inspect import sys from collections.abc import Callable, Iterable, Mapping from contextlib import AbstractContextManager from types import TracebackType from typing import TYPE_CHECKING, Any if sys.version_info < (3, 11): from ._exceptions import BaseExceptionGroup if TYPE_CHECKING: _Handler = Callable[[BaseException], Any] class _Catcher: def __init__(self, handler_map: Mapping[tuple[type[BaseException], ...], _Handler]): self._handler_map = handler_map def __enter__(self) -> None: pass def __exit__( self, etype: type[BaseException] | None, exc: BaseException | None, tb: TracebackType | None, ) -> bool: if exc is not None: unhandled = self.handle_exception(exc) if unhandled is exc: return False elif unhandled is None: return True else: if isinstance(exc, BaseExceptionGroup): try: raise unhandled from exc.__cause__ except BaseExceptionGroup: # Change __context__ to __cause__ because Python 3.11 does this # too unhandled.__context__ = exc.__cause__ raise raise unhandled from exc return False def handle_exception(self, exc: BaseException) -> BaseException | None: excgroup: BaseExceptionGroup | None if isinstance(exc, BaseExceptionGroup): excgroup = exc else: excgroup = BaseExceptionGroup("", [exc]) new_exceptions: list[BaseException] = [] for exc_types, handler in self._handler_map.items(): matched, excgroup = excgroup.split(exc_types) if matched: try: try: raise matched except BaseExceptionGroup: result = handler(matched) except BaseExceptionGroup as new_exc: if new_exc is matched: new_exceptions.append(new_exc) else: new_exceptions.extend(new_exc.exceptions) except BaseException as new_exc: new_exceptions.append(new_exc) else: if inspect.iscoroutine(result): raise TypeError( f"Error trying to handle {matched!r} with {handler!r}. " "Exception handler must be a sync function." ) from exc if not excgroup: break if new_exceptions: if len(new_exceptions) == 1: return new_exceptions[0] return BaseExceptionGroup("", new_exceptions) elif ( excgroup and len(excgroup.exceptions) == 1 and excgroup.exceptions[0] is exc ): return exc else: return excgroup def catch( __handlers: Mapping[type[BaseException] | Iterable[type[BaseException]], _Handler] ) -> AbstractContextManager[None]: if not isinstance(__handlers, Mapping): raise TypeError("the argument must be a mapping") handler_map: dict[ tuple[type[BaseException], ...], Callable[[BaseExceptionGroup]] ] = {} for type_or_iterable, handler in __handlers.items(): iterable: tuple[type[BaseException]] if isinstance(type_or_iterable, type) and issubclass( type_or_iterable, BaseException ): iterable = (type_or_iterable,) elif isinstance(type_or_iterable, Iterable): iterable = tuple(type_or_iterable) else: raise TypeError( "each key must be either an exception classes or an iterable thereof" ) if not callable(handler): raise TypeError("handlers must be callable") for exc_type in iterable: if not isinstance(exc_type, type) or not issubclass( exc_type, BaseException ): raise TypeError( "each key must be either an exception classes or an iterable " "thereof" ) if issubclass(exc_type, BaseExceptionGroup): raise TypeError( "catching ExceptionGroup with catch() is not allowed. " "Use except instead." ) handler_map[iterable] = handler return _Catcher(handler_map)