280 lines
9.3 KiB
Python
280 lines
9.3 KiB
Python
from typing import Dict, List, Union, Tuple, Any, cast
|
|
|
|
from typing_extensions import Final
|
|
|
|
|
|
# Supported Python literal types. All tuple items must have supported
# literal types as well, but we can't represent the type precisely
# (hence the loose Tuple[object, ...]).
LiteralValue = Union[str, bytes, int, bool, float, complex, Tuple[object, ...], None]


# Some literals are singletons and handled specially (None, False and True).
# They are never stored in the per-type dicts; Literals.literal_index() maps
# them to the fixed indexes 0, 1 and 2, and every other literal index is
# offset by this count.
NUM_SINGLETONS: Final = 3
|
|
|
|
|
|
class Literals:
    """Collection of literal values used in a compilation group and related helpers.

    Values are stored in one dict per literal type; each dict maps a value to
    its index within that type (0, 1, ...). literal_index() converts that to
    a position in a single combined array laid out as:

        None, False, True, str values, bytes values, int values,
        float values, complex values, tuple values
    """

    def __init__(self) -> None:
        # Each dict maps value to literal index (0, 1, ...)
        self.str_literals: Dict[str, int] = {}
        self.bytes_literals: Dict[bytes, int] = {}
        self.int_literals: Dict[int, int] = {}
        self.float_literals: Dict[float, int] = {}
        self.complex_literals: Dict[complex, int] = {}
        # NOTE(review): dict keys are compared with ==, so tuples such as
        # (1, 2) and (1.0, 2.0) share a single entry here -- confirm that
        # callers never need to tell them apart.
        self.tuple_literals: Dict[Tuple[object, ...], int] = {}

    def _tables(self) -> Tuple[Tuple[type, Dict[Any, int]], ...]:
        """Return (type, value-to-index dict) pairs in literals-array order.

        This is the single place that defines the order of the non-singleton
        sections of the literals array.
        """
        return (
            (str, self.str_literals),
            (bytes, self.bytes_literals),
            (int, self.int_literals),
            (float, self.float_literals),
            (complex, self.complex_literals),
            (tuple, self.tuple_literals),
        )

    def record_literal(self, value: LiteralValue) -> None:
        """Ensure that the literal value is available in generated code.

        Recording a value that is already known is a no-op. Recording a
        tuple also records each of its items.

        Raises:
            AssertionError: if the value is not of a supported literal type.
        """
        if value is None or value is True or value is False:
            # These are special cased and always present
            return
        for literal_type, table in self._tables():
            if not isinstance(value, literal_type):
                continue
            if value not in table:
                if isinstance(value, tuple):
                    # Items are recorded before the tuple gets its own index,
                    # so nested tuples always receive a smaller index than
                    # the tuple that contains them.
                    for item in value:
                        self.record_literal(cast(Any, item))
                table[value] = len(table)
            return
        # Raised explicitly (rather than 'assert False') so that the check
        # is not stripped when running with -O.
        raise AssertionError('invalid literal: %r' % (value,))

    def literal_index(self, value: LiteralValue) -> int:
        """Return the index to the literals array for given value.

        Raises:
            KeyError: if the value has a supported type but was never
                passed to record_literal().
            AssertionError: if the value is not of a supported literal type.
        """
        # The array contains first None and booleans, followed by all str values,
        # followed by bytes values, etc.
        if value is None:
            return 0
        if value is False:
            return 1
        if value is True:
            return 2
        offset = NUM_SINGLETONS
        for literal_type, table in self._tables():
            if isinstance(value, literal_type):
                return offset + table[value]
            # Skip past this type's whole section of the array.
            offset += len(table)
        # Raised explicitly (rather than 'assert False') so that we never
        # fall through and return None when running with -O.
        raise AssertionError('invalid literal: %r' % (value,))

    def num_literals(self) -> int:
        """Return the total size of the literals array, singletons included."""
        # The first three are for None, False and True
        return NUM_SINGLETONS + sum(len(table) for _, table in self._tables())

    # The following methods return the C encodings of literal values
    # of different types

    def encoded_str_values(self) -> List[bytes]:
        return _encode_str_values(self.str_literals)

    def encoded_int_values(self) -> List[bytes]:
        return _encode_int_values(self.int_literals)

    def encoded_bytes_values(self) -> List[bytes]:
        return _encode_bytes_values(self.bytes_literals)

    def encoded_float_values(self) -> List[str]:
        return _encode_float_values(self.float_literals)

    def encoded_complex_values(self) -> List[str]:
        return _encode_complex_values(self.complex_literals)

    def encoded_tuple_values(self) -> List[str]:
        """Encode tuple values into a C array.

        The format of the result is like this:

           <number of tuples>
           <length of the first tuple>
           <literal index of first item>
           ...
           <literal index of last item>
           <length of the second tuple>
           ...

        Every entry is the decimal string form of the number.
        """
        tuples = self.tuple_literals
        tuple_by_index = {index: value for value, index in tuples.items()}
        count = len(tuples)
        result = [str(count)]
        for i in range(count):
            items = tuple_by_index[i]
            result.append(str(len(items)))
            # Items are referenced by their position in the combined
            # literals array, not stored inline.
            result.extend(str(self.literal_index(cast(Any, item))) for item in items)
        return result
|
|
|
|
|
|
def _encode_str_values(values: Dict[str, int]) -> List[bytes]:
    """Encode str literals into a list of byte strings ("lines").

    Literals are emitted in index order. Each line starts with the number of
    literals it holds (encoded with format_int) followed by the literals
    themselves, each encoded with format_str_literal. A new line is started
    once adding a literal would push the current line's payload past 70
    bytes; a single long literal still gets a line of its own. The result
    always ends with an empty byte string.
    """
    text_by_index = {index: text for text, index in values.items()}
    lines: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        item = format_str_literal(text_by_index[position])
        item_size = len(item)
        if pending_size > 0 and pending_size + item_size > 70:
            # Current line is full: flush it and start a fresh one.
            lines.append(format_int(len(pending)) + b''.join(pending))
            pending, pending_size = [], 0
        pending.append(item)
        pending_size += item_size
    if pending:
        lines.append(format_int(len(pending)) + b''.join(pending))
    # Empty terminator entry.
    lines.append(b'')
    return lines
|
|
|
|
|
|
def _encode_bytes_values(values: Dict[bytes, int]) -> List[bytes]:
    """Encode bytes literals into a list of byte strings ("lines").

    Literals are emitted in index order. Each line starts with the number of
    literals it holds (encoded with format_int); each literal is its length
    (encoded with format_int) followed by its raw bytes. A new line is
    started once adding a literal would push the current line's payload past
    70 bytes; a single long literal still gets a line of its own. The result
    always ends with an empty byte string.
    """
    data_by_index = {index: data for data, index in values.items()}
    lines: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        data = data_by_index[position]
        item = format_int(len(data)) + data
        item_size = len(item)
        if pending_size > 0 and pending_size + item_size > 70:
            # Current line is full: flush it and start a fresh one.
            lines.append(format_int(len(pending)) + b''.join(pending))
            pending, pending_size = [], 0
        pending.append(item)
        pending_size += item_size
    if pending:
        lines.append(format_int(len(pending)) + b''.join(pending))
    # Empty terminator entry.
    lines.append(b'')
    return lines
|
|
|
|
|
|
def format_int(n: int) -> bytes:
    """Format an integer using a variable-length binary encoding.

    The value is split into 7-bit groups, most significant group first.
    Every byte except the last has its high bit set to signal that more
    groups follow. Values below 128 are a single byte.
    """
    if n < 128:
        return bytes([n])
    groups = []
    while n > 0:
        # Collected least significant group first; reversed below.
        groups.append(n & 0x7f)
        n >>= 7
    groups.reverse()
    # If the highest bit is set, more 7-bit digits follow
    flagged = [group | 0x80 for group in groups[:-1]]
    return bytes(flagged + groups[-1:])
|
|
|
|
|
|
def format_str_literal(s: str) -> bytes:
    """Encode a str as its UTF-8 byte length (via format_int) plus the UTF-8 bytes.

    Raises UnicodeEncodeError if the string cannot be encoded as UTF-8
    (for example, if it contains a lone surrogate).
    """
    payload = s.encode('utf-8')
    length_prefix = format_int(len(payload))
    return length_prefix + payload
|
|
|
|
|
|
def _encode_int_values(values: Dict[int, int]) -> List[bytes]:
    """Encode int values into C strings.

    Values are stored in base 10 and separated by 0 bytes. They are emitted
    in index order, grouped into lines; each line starts with the number of
    values it holds (encoded with format_int). A new line is started once
    adding a value would push the digit count of the current line past 70
    (separators are not counted). The result always ends with an empty
    byte string.
    """
    number_by_index = {index: number for number, index in values.items()}
    lines: List[bytes] = []
    pending: List[bytes] = []
    pending_size = 0
    for position in range(len(values)):
        digits = b'%d' % number_by_index[position]
        if pending_size > 0 and pending_size + len(digits) > 70:
            # Current line is full: flush it and start a fresh one.
            lines.append(format_int(len(pending)) + b'\0'.join(pending))
            pending, pending_size = [], 0
        pending.append(digits)
        pending_size += len(digits)
    if pending:
        lines.append(format_int(len(pending)) + b'\0'.join(pending))
    # Empty terminator entry.
    lines.append(b'')
    return lines
|
|
|
|
|
|
def float_to_c(x: float) -> str:
    """Return C literal representation of a float value.

    Finite values use Python's str() form. Non-finite values have no literal
    syntax in C, so they are mapped to the <math.h> macros INFINITY,
    -INFINITY and NAN.
    """
    s = str(x)
    if s == 'inf':
        return 'INFINITY'
    elif s == '-inf':
        return '-INFINITY'
    elif s == 'nan':
        # str() yields 'nan' for every NaN, which is not a valid C expression.
        return 'NAN'
    return s
|
|
|
|
|
|
def _encode_float_values(values: Dict[float, int]) -> List[str]:
    """Encode float values as the entries of a C array.

    The result contains the number of values followed by the individual
    values in index order, each converted with float_to_c.
    """
    number_by_index = {index: number for number, index in values.items()}
    count = len(values)
    encoded = [str(count)]
    encoded.extend(float_to_c(number_by_index[i]) for i in range(count))
    return encoded
|
|
|
|
|
|
def _encode_complex_values(values: Dict[complex, int]) -> List[str]:
    """Encode complex values as the entries of a C array.

    The result contains the number of values followed by pairs of doubles
    (real part, then imaginary part) representing the complex numbers in
    index order. Each part is converted with float_to_c.
    """
    number_by_index = {index: number for number, index in values.items()}
    count = len(values)
    encoded = [str(count)]
    for i in range(count):
        number = number_by_index[i]
        encoded.extend((float_to_c(number.real), float_to_c(number.imag)))
    return encoded
|