usse/scrape/venv/lib/python3.10/site-packages/cattrs/strategies/_unions.py

236 lines
8.2 KiB
Python
Raw Normal View History

2023-12-22 14:26:01 +00:00
from collections import defaultdict
from typing import Any, Callable, Dict, Optional, Type, Union
from attrs import NOTHING
from cattrs import BaseConverter, Converter
from cattrs._compat import get_newtype_base, is_literal, is_subclass, is_union_type
__all__ = [
"default_tag_generator",
"configure_tagged_union",
"configure_union_passthrough",
]
def default_tag_generator(typ: Type) -> str:
"""Return the class name."""
return typ.__name__
def configure_tagged_union(
union: Any,
converter: Converter,
tag_generator: Callable[[Type], str] = default_tag_generator,
tag_name: str = "_type",
default: Optional[Type] = NOTHING,
) -> None:
"""
Configure the converter so that `union` (which should be a union) is
un/structured with the help of an additional piece of data in the
unstructured payload, the tag.
:param converter: The converter to apply the strategy to.
:param tag_generator: A `tag_generator` function is used to map each
member of the union to a tag, which is then included in the
unstructured payload. The default tag generator returns the name of
the class.
:param tag_name: The key under which the tag will be set in the
unstructured payload. By default, `'_type'`.
:param default: An optional class to be used if the tag information
is not present when structuring.
The tagged union strategy currently only works with the dict
un/structuring base strategy.
.. versionadded:: 23.1.0
"""
args = union.__args__
tag_to_hook = {}
exact_cl_unstruct_hooks = {}
for cl in args:
tag = tag_generator(cl)
struct_handler = converter._structure_func.dispatch(cl)
unstruct_handler = converter._unstructure_func.dispatch(cl)
def structure_union_member(val: dict, _cl=cl, _h=struct_handler) -> cl:
return _h(val, _cl)
def unstructure_union_member(val: union, _h=unstruct_handler) -> dict:
return _h(val)
tag_to_hook[tag] = structure_union_member
exact_cl_unstruct_hooks[cl] = unstructure_union_member
cl_to_tag = {cl: tag_generator(cl) for cl in args}
if default is not NOTHING:
default_handler = converter._structure_func.dispatch(default)
def structure_default(val: dict, _cl=default, _h=default_handler):
return _h(val, _cl)
tag_to_hook = defaultdict(lambda: structure_default, tag_to_hook)
cl_to_tag = defaultdict(lambda: default, cl_to_tag)
def unstructure_tagged_union(
val: union,
_exact_cl_unstruct_hooks=exact_cl_unstruct_hooks,
_cl_to_tag=cl_to_tag,
_tag_name=tag_name,
) -> Dict:
res = _exact_cl_unstruct_hooks[val.__class__](val)
res[_tag_name] = _cl_to_tag[val.__class__]
return res
if default is NOTHING:
def structure_tagged_union(
val: dict, _, _tag_to_cl=tag_to_hook, _tag_name=tag_name
) -> union:
val = val.copy()
return _tag_to_cl[val.pop(_tag_name)](val)
else:
def structure_tagged_union(
val: dict,
_,
_tag_to_hook=tag_to_hook,
_tag_name=tag_name,
_dh=default_handler,
_default=default,
) -> union:
if _tag_name in val:
val = val.copy()
return _tag_to_hook[val.pop(_tag_name)](val)
return _dh(val, _default)
converter.register_unstructure_hook(union, unstructure_tagged_union)
converter.register_structure_hook(union, structure_tagged_union)
def configure_union_passthrough(union: Any, converter: BaseConverter) -> None:
"""
Configure the converter to support validating and passing through unions of the
provided types and their subsets.
For example, all mature JSON libraries natively support producing unions of ints,
floats, Nones, and strings. Using this strategy, a converter can be configured
to efficiently validate and pass through unions containing these types.
The most important point is that another library (in this example the JSON
library) handles producing the union, and the converter is configured to just
validate it.
Literals of provided types are also supported, and are checked by value.
NewTypes of provided types are also supported.
The strategy is designed to be O(1) in execution time, and independent of the
ordering of types in the union.
If the union contains a class and one or more of its subclasses, the subclasses
will also be included when validating the superclass.
.. versionadded:: 23.2.0
"""
args = set(union.__args__)
def make_structure_native_union(exact_type: Any) -> Callable:
# `exact_type` is likely to be a subset of the entire configured union (`args`).
literal_values = {
v for t in exact_type.__args__ if is_literal(t) for v in t.__args__
}
# We have no idea what the actual type of `val` will be, so we can't
# use it blindly with an `in` check since it might not be hashable.
# So we do an additional check when handling literals.
# Note: do no use `literal_values` here, since {0, False} gets reduced to {0}
literal_classes = {
v.__class__
for t in exact_type.__args__
if is_literal(t)
for v in t.__args__
}
non_literal_classes = {
get_newtype_base(t) or t
for t in exact_type.__args__
if not is_literal(t) and ((get_newtype_base(t) or t) in args)
}
# We augment the set of allowed classes with any configured subclasses of
# the exact subclasses.
non_literal_classes |= {
a for a in args if any(is_subclass(a, c) for c in non_literal_classes)
}
# We check for spillover - union types not handled by the strategy.
# If spillover exists and we fail to validate our types, we call
# further into the converter with the rest.
spillover = {
a
for a in exact_type.__args__
if (get_newtype_base(a) or a) not in non_literal_classes
and not is_literal(a)
}
if spillover:
spillover_type = (
Union[tuple(spillover)] if len(spillover) > 1 else next(iter(spillover))
)
def structure_native_union(
val: Any,
_: Any,
classes=non_literal_classes,
vals=literal_values,
converter=converter,
spillover=spillover_type,
) -> exact_type:
if val.__class__ in literal_classes and val in vals:
return val
if val.__class__ in classes:
return val
return converter.structure(val, spillover)
else:
def structure_native_union(
val: Any, _: Any, classes=non_literal_classes, vals=literal_values
) -> exact_type:
if val.__class__ in literal_classes and val in vals:
return val
if val.__class__ in classes:
return val
raise TypeError(f"{val} ({val.__class__}) not part of {_}")
return structure_native_union
def contains_native_union(exact_type: Any) -> bool:
"""Can we handle this type?"""
if is_union_type(exact_type):
type_args = set(exact_type.__args__)
# We special case optionals, since they are very common
# and are handled a little more efficiently by default.
if len(type_args) == 2 and type(None) in type_args:
return False
literal_classes = {
lit_arg.__class__
for t in type_args
if is_literal(t)
for lit_arg in t.__args__
}
non_literal_types = {
get_newtype_base(t) or t for t in type_args if not is_literal(t)
}
return (literal_classes | non_literal_types) & args
return False
converter.register_structure_hook_factory(
contains_native_union, make_structure_native_union
)