236 lines
8.2 KiB
Python
236 lines
8.2 KiB
Python
from collections import defaultdict
|
|
from typing import Any, Callable, Dict, Optional, Type, Union
|
|
|
|
from attrs import NOTHING
|
|
|
|
from cattrs import BaseConverter, Converter
|
|
from cattrs._compat import get_newtype_base, is_literal, is_subclass, is_union_type
|
|
|
|
# Public API of this module.
__all__ = [
    "default_tag_generator",
    "configure_tagged_union",
    "configure_union_passthrough",
]
|
|
|
|
|
|
def default_tag_generator(typ: Type) -> str:
    """Return the class name of `typ`, used as its tag.

    :param typ: The class to produce a tag for.
    :return: `typ.__name__`.
    """
    return typ.__name__
|
|
|
|
|
|
def configure_tagged_union(
    union: Any,
    converter: Converter,
    tag_generator: Callable[[Type], str] = default_tag_generator,
    tag_name: str = "_type",
    default: Optional[Type] = NOTHING,
) -> None:
    """
    Configure the converter so that `union` (which should be a union) is
    un/structured with the help of an additional piece of data in the
    unstructured payload, the tag.

    :param union: The union to configure; its members are read from
        `union.__args__`.
    :param converter: The converter to apply the strategy to.
    :param tag_generator: A `tag_generator` function is used to map each
        member of the union to a tag, which is then included in the
        unstructured payload. The default tag generator returns the name of
        the class.
    :param tag_name: The key under which the tag will be set in the
        unstructured payload. By default, `'_type'`.
    :param default: An optional class to be used if the tag information
        is not present when structuring.

    The tagged union strategy currently only works with the dict
    un/structuring base strategy.

    .. versionadded:: 23.1.0
    """
    args = union.__args__
    tag_to_hook = {}
    exact_cl_unstruct_hooks = {}
    cl_to_tag = {}
    for cl in args:
        # Generate the tag once per member so that structuring and
        # unstructuring are guaranteed to agree on it.
        tag = tag_generator(cl)
        struct_handler = converter._structure_func.dispatch(cl)
        unstruct_handler = converter._unstructure_func.dispatch(cl)

        # The per-iteration values are bound as argument defaults to avoid
        # the late-binding closure pitfall.
        def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
            return _h(val, _cl)

        def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
            return _h(val)

        tag_to_hook[tag] = structure_union_member
        exact_cl_unstruct_hooks[cl] = unstructure_union_member
        cl_to_tag[cl] = tag

    if default is not NOTHING:
        default_handler = converter._structure_func.dispatch(default)

        def structure_default(val: dict, _cl=default, _h=default_handler):
            return _h(val, _cl)

        # NOTE(review): the fallback value here is the default *class*, not a
        # tag. It is not reached at the moment, because the plain-dict lookup
        # in `_exact_cl_unstruct_hooks` raises KeyError for unknown classes
        # before the tag lookup happens.
        cl_to_tag = defaultdict(lambda: default, cl_to_tag)

    def unstructure_tagged_union(
        val: union,
        _exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
        _cl_to_tag=cl_to_tag,
        _tag_name=tag_name,
    ) -> Dict:
        # Dispatch on the exact class of the value, then add the tag.
        res = _exact_cl_unstruct_hooks[val.__class__](val)
        res[_tag_name] = _cl_to_tag[val.__class__]
        return res

    if default is NOTHING:

        def structure_tagged_union(
            val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
        ) -> union:
            # Copy first so the caller's payload is not mutated by `pop`.
            val = val.copy()
            return _tag_to_cl[val.pop(_tag_name)](val)

    else:

        def structure_tagged_union(
            val: dict,
            _,
            _tag_to_hook=tag_to_hook,
            _tag_name=tag_name,
            _dh=default_handler,
            _default=default,
            _fallback=structure_default,
        ) -> union:
            if _tag_name in val:
                # Copy first so the caller's payload is not mutated by `pop`.
                val = val.copy()
                # Unknown tags fall back to the default hook. `dict.get` is
                # used instead of a defaultdict so that unknown tags coming
                # from the payload are not stored in the mapping forever.
                return _tag_to_hook.get(val.pop(_tag_name), _fallback)(val)
            return _dh(val, _default)

    converter.register_unstructure_hook(union, unstructure_tagged_union)
    converter.register_structure_hook(union, structure_tagged_union)
|
|
|
|
|
|
def configure_union_passthrough(union: Any, converter: BaseConverter) -> None:
    """
    Configure the converter to support validating and passing through unions of the
    provided types and their subsets.

    For example, all mature JSON libraries natively support producing unions of ints,
    floats, Nones, and strings. Using this strategy, a converter can be configured
    to efficiently validate and pass through unions containing these types.

    The most important point is that another library (in this example the JSON
    library) handles producing the union, and the converter is configured to just
    validate it.

    Literals of provided types are also supported, and are checked by value.

    NewTypes of provided types are also supported.

    The strategy is designed to be O(1) in execution time, and independent of the
    ordering of types in the union.

    If the union contains a class and one or more of its subclasses, the subclasses
    will also be included when validating the superclass.

    :param union: The union of types to pass through; its members are read from
        `union.__args__`.
    :param converter: The converter on which the structure hook factory is
        registered.

    .. versionadded:: 23.2.0
    """
    args = set(union.__args__)

    def make_structure_native_union(exact_type: Any) -> Callable:
        """Build the structure hook for one concrete union type."""
        # `exact_type` is likely to be a subset of the entire configured union (`args`).
        literal_values = {
            v for t in exact_type.__args__ if is_literal(t) for v in t.__args__
        }

        # We have no idea what the actual type of `val` will be, so we can't
        # use it blindly with an `in` check since it might not be hashable.
        # So we do an additional check when handling literals.
        # Note: do not use `literal_values` here, since {0, False} gets reduced to {0}
        literal_classes = {
            v.__class__
            for t in exact_type.__args__
            if is_literal(t)
            for v in t.__args__
        }

        non_literal_classes = {
            get_newtype_base(t) or t
            for t in exact_type.__args__
            if not is_literal(t) and ((get_newtype_base(t) or t) in args)
        }

        # We augment the set of allowed classes with any configured subclasses of
        # the exact subclasses.
        non_literal_classes |= {
            a for a in args if any(is_subclass(a, c) for c in non_literal_classes)
        }

        # We check for spillover - union types not handled by the strategy.
        # If spillover exists and we fail to validate our types, we call
        # further into the converter with the rest.
        spillover = {
            a
            for a in exact_type.__args__
            if (get_newtype_base(a) or a) not in non_literal_classes
            and not is_literal(a)
        }

        if spillover:
            spillover_type = (
                Union[tuple(spillover)] if len(spillover) > 1 else next(iter(spillover))
            )

            def structure_native_union(
                val: Any,
                _: Any,
                classes=non_literal_classes,
                vals=literal_values,
                converter=converter,
                spillover=spillover_type,
            ) -> exact_type:
                # The class check comes first so that `val in vals` is only
                # evaluated for values whose class matches a literal's class.
                if val.__class__ in literal_classes and val in vals:
                    return val
                if val.__class__ in classes:
                    return val
                return converter.structure(val, spillover)

        else:

            def structure_native_union(
                val: Any, _: Any, classes=non_literal_classes, vals=literal_values
            ) -> exact_type:
                # The class check comes first so that `val in vals` is only
                # evaluated for values whose class matches a literal's class.
                if val.__class__ in literal_classes and val in vals:
                    return val
                if val.__class__ in classes:
                    return val
                raise TypeError(f"{val} ({val.__class__}) not part of {_}")

        return structure_native_union

    def contains_native_union(exact_type: Any) -> bool:
        """Can we handle this type?"""
        if not is_union_type(exact_type):
            return False

        type_args = set(exact_type.__args__)
        # We special case optionals, since they are very common
        # and are handled a little more efficiently by default.
        if len(type_args) == 2 and type(None) in type_args:
            return False

        literal_classes = {
            lit_arg.__class__
            for t in type_args
            if is_literal(t)
            for lit_arg in t.__args__
        }
        non_literal_types = {
            get_newtype_base(t) or t for t in type_args if not is_literal(t)
        }

        # Return a real bool, as annotated, rather than the intersection set.
        return bool((literal_classes | non_literal_types) & args)

    converter.register_structure_hook_factory(
        contains_native_union, make_structure_native_union
    )
|