Initial commit: 首次建仓,建立目录结构
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,123 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
class CairoType(ABC):
|
||||
"""
|
||||
Base type for all Cairo type representations. All types extend it.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeltType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo field element.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoolType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo boolean.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TupleType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo tuples without named fields.
|
||||
"""
|
||||
|
||||
types: List[CairoType] #: types of every tuple element.
|
||||
|
||||
|
||||
@dataclass
|
||||
class NamedTupleType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo tuples with named fields.
|
||||
"""
|
||||
|
||||
types: OrderedDict[str, CairoType] #: types of every tuple member.
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArrayType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo arrays.
|
||||
"""
|
||||
|
||||
inner_type: CairoType #: type of element inside array.
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo structures.
|
||||
"""
|
||||
|
||||
name: str #: Structure name
|
||||
# We need ordered dict, because it is important in serialization
|
||||
types: OrderedDict[str, CairoType] #: types of every structure member.
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo enums.
|
||||
"""
|
||||
|
||||
name: str
|
||||
variants: OrderedDict[str, CairoType]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo options.
|
||||
"""
|
||||
|
||||
type: CairoType
|
||||
|
||||
|
||||
@dataclass
|
||||
class UintType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo unsigned integers.
|
||||
"""
|
||||
|
||||
bits: int
|
||||
|
||||
def check_range(self, value: int):
|
||||
"""
|
||||
Utility method checking if the `value` is in range.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypeIdentifier(CairoType):
|
||||
"""
|
||||
Type representation of Cairo identifiers.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo unit `()`.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventType(CairoType):
|
||||
"""
|
||||
Type representation of Cairo Event.
|
||||
"""
|
||||
|
||||
name: str
|
||||
types: OrderedDict[str, CairoType]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,77 @@
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
class CairoType:
|
||||
"""
|
||||
Base class for cairo types.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeFelt(CairoType):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeCodeoffset(CairoType):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypePointer(CairoType):
|
||||
pointee: CairoType
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeIdentifier(CairoType):
|
||||
"""
|
||||
Represents a name of an unresolved type.
|
||||
This type can be resolved to TypeStruct or TypeDefinition.
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeStruct(CairoType):
|
||||
scope: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeFunction(CairoType):
|
||||
"""
|
||||
Represents a type of a function.
|
||||
"""
|
||||
|
||||
scope: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TypeTuple(CairoType):
|
||||
"""
|
||||
Represents a type of a named or unnamed tuple.
|
||||
For example, "(felt, felt*)" or "(a: felt, b: felt*)".
|
||||
"""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Item(CairoType):
|
||||
"""
|
||||
Represents a possibly named type item of a TypeTuple.
|
||||
For example: "felt" or "a: felt".
|
||||
"""
|
||||
|
||||
name: Optional[str]
|
||||
typ: CairoType
|
||||
|
||||
members: List["TypeTuple.Item"]
|
||||
has_trailing_comma: bool = dataclasses.field(hash=False, compare=False)
|
||||
|
||||
@property
|
||||
def is_named(self) -> bool:
|
||||
return all(member.name is not None for member in self.members)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExprIdentifier(CairoType):
|
||||
name: str
|
||||
@ -0,0 +1,46 @@
|
||||
from ....lark import Lark
|
||||
|
||||
from .cairo_types import CairoType
|
||||
from .parser_transformer import ParserTransformer
|
||||
|
||||
CAIRO_EBNF = """
|
||||
%import common.WS_INLINE
|
||||
%ignore WS_INLINE
|
||||
|
||||
IDENTIFIER: /[a-zA-Z_][a-zA-Z_0-9]*/
|
||||
_DBL_STAR: "**"
|
||||
COMMA: ","
|
||||
|
||||
?type: non_identifier_type
|
||||
| identifier -> type_struct
|
||||
|
||||
comma_separated{item}: item? (COMMA item)* COMMA?
|
||||
|
||||
named_type: identifier (":" type)? | non_identifier_type
|
||||
non_identifier_type: "felt" -> type_felt
|
||||
| "codeoffset" -> type_codeoffset
|
||||
| type "*" -> type_pointer
|
||||
| type _DBL_STAR -> type_pointer2
|
||||
| "(" comma_separated{named_type} ")" -> type_tuple
|
||||
|
||||
identifier: IDENTIFIER ("." IDENTIFIER)*
|
||||
"""
|
||||
|
||||
|
||||
def parse(code: str) -> CairoType:
|
||||
"""
|
||||
Parses the given string and returns a CairoType.
|
||||
"""
|
||||
|
||||
grammar = CAIRO_EBNF
|
||||
|
||||
grammar_parser = Lark(
|
||||
grammar=grammar,
|
||||
start=["type"],
|
||||
parser="lalr",
|
||||
)
|
||||
|
||||
parsed = grammar_parser.parse(code)
|
||||
transformed = ParserTransformer().transform(parsed)
|
||||
|
||||
return transformed
|
||||
@ -0,0 +1,138 @@
|
||||
import dataclasses
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from ....lark import Token, Transformer, v_args
|
||||
|
||||
from .cairo_types import (
|
||||
CairoType,
|
||||
ExprIdentifier,
|
||||
TypeCodeoffset,
|
||||
TypeFelt,
|
||||
TypeIdentifier,
|
||||
TypePointer,
|
||||
TypeStruct,
|
||||
TypeTuple,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ParserContext:
|
||||
"""
|
||||
Represents information that affects the parsing process.
|
||||
"""
|
||||
|
||||
# If True, treat type identifiers as resolved.
|
||||
resolved_types: bool = False
|
||||
|
||||
|
||||
class ParserError(Exception):
|
||||
"""
|
||||
Base exception for parsing process.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CommaSeparated:
|
||||
"""
|
||||
Represents a list of comma separated values, such as expressions or types.
|
||||
"""
|
||||
|
||||
args: list
|
||||
has_trailing_comma: bool
|
||||
|
||||
|
||||
class ParserTransformer(Transformer):
|
||||
"""
|
||||
Transforms the lark tree into an AST based on the classes defined in cairo_types.py.
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument, no-self-use
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.parser_context = ParserContext()
|
||||
|
||||
def __default__(self, data: str, children, meta):
|
||||
raise TypeError(f"Unable to parse tree node of type {data}")
|
||||
|
||||
def comma_separated(self, value) -> CommaSeparated:
|
||||
saw_comma = None
|
||||
args: list = []
|
||||
for v in value:
|
||||
if isinstance(v, Token) and v.type == "COMMA":
|
||||
if saw_comma is not False:
|
||||
raise ParserError("Unexpected comma.")
|
||||
saw_comma = True
|
||||
else:
|
||||
if saw_comma is False:
|
||||
raise ParserError("Expected a comma before this expression.")
|
||||
args.append(v)
|
||||
|
||||
# Reset state.
|
||||
saw_comma = False
|
||||
|
||||
if saw_comma is None:
|
||||
saw_comma = False
|
||||
|
||||
return CommaSeparated(args=args, has_trailing_comma=saw_comma)
|
||||
|
||||
# Types.
|
||||
|
||||
@v_args(meta=True)
|
||||
def named_type(self, meta, value) -> TypeTuple.Item:
|
||||
name: Optional[str]
|
||||
if len(value) == 1:
|
||||
# Unnamed type.
|
||||
(typ,) = value
|
||||
name = None
|
||||
if isinstance(typ, ExprIdentifier):
|
||||
typ = self.type_struct([typ])
|
||||
elif len(value) == 2:
|
||||
# Named type.
|
||||
identifier, typ = value
|
||||
assert isinstance(identifier, ExprIdentifier)
|
||||
assert isinstance(typ, CairoType)
|
||||
if "." in identifier.name:
|
||||
raise ParserError("Unexpected . in name.")
|
||||
name = identifier.name
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected number of values. {value}")
|
||||
|
||||
return TypeTuple.Item(name=name, typ=typ)
|
||||
|
||||
@v_args(meta=True)
|
||||
def type_felt(self, meta, value):
|
||||
return TypeFelt()
|
||||
|
||||
@v_args(meta=True)
|
||||
def type_codeoffset(self, meta, value):
|
||||
return TypeCodeoffset()
|
||||
|
||||
def type_struct(self, value):
|
||||
assert len(value) == 1 and isinstance(value[0], ExprIdentifier)
|
||||
if self.parser_context.resolved_types:
|
||||
# If parser_context.resolved_types is True, assume that the type is a struct.
|
||||
return TypeStruct(scope=value[0].name)
|
||||
|
||||
return TypeIdentifier(name=value[0].name)
|
||||
|
||||
@v_args(meta=True)
|
||||
def type_pointer(self, meta, value):
|
||||
return TypePointer(pointee=value[0])
|
||||
|
||||
@v_args(meta=True)
|
||||
def type_pointer2(self, meta, value):
|
||||
return TypePointer(pointee=TypePointer(pointee=value[0]))
|
||||
|
||||
@v_args(meta=True)
|
||||
def type_tuple(self, meta, value: Tuple[CommaSeparated]):
|
||||
(lst,) = value
|
||||
return TypeTuple(members=lst.args, has_trailing_comma=lst.has_trailing_comma)
|
||||
|
||||
@v_args(meta=True)
|
||||
def identifier(self, meta, value):
|
||||
return ExprIdentifier(name=".".join(x.value for x in value))
|
||||
|
||||
@v_args(meta=True)
|
||||
def identifier_def(self, meta, value):
|
||||
return ExprIdentifier(name=value[0].value)
|
||||
@ -0,0 +1,64 @@
|
||||
from typing import List
|
||||
|
||||
from ..constants import FIELD_PRIME
|
||||
|
||||
CairoData = List[int]
|
||||
|
||||
|
||||
MAX_UINT256 = (1 << 256) - 1
|
||||
MIN_UINT256 = 0
|
||||
|
||||
|
||||
def uint256_range_check(value: int):
|
||||
if not MIN_UINT256 <= value <= MAX_UINT256:
|
||||
raise ValueError(
|
||||
f"Uint256 is expected to be in range [0;2**256), got: {value}."
|
||||
)
|
||||
|
||||
|
||||
MIN_FELT = -FIELD_PRIME // 2
|
||||
MAX_FELT = FIELD_PRIME // 2
|
||||
|
||||
|
||||
def is_in_felt_range(value: int) -> bool:
|
||||
return 0 <= value < FIELD_PRIME
|
||||
|
||||
|
||||
def cairo_vm_range_check(value: int):
|
||||
if not is_in_felt_range(value):
|
||||
raise ValueError(
|
||||
f"Felt is expected to be in range [0; {FIELD_PRIME}), got: {value}."
|
||||
)
|
||||
|
||||
|
||||
def encode_shortstring(text: str) -> int:
|
||||
"""
|
||||
A function which encodes short string value (at most 31 characters) into cairo felt (MSB as first character)
|
||||
|
||||
:param text: A short string value in python
|
||||
:return: Short string value encoded into felt
|
||||
"""
|
||||
if len(text) > 31:
|
||||
raise ValueError(
|
||||
f"Shortstring cannot be longer than 31 characters, got: {len(text)}."
|
||||
)
|
||||
|
||||
try:
|
||||
text_bytes = text.encode("ascii")
|
||||
except UnicodeEncodeError as u_err:
|
||||
raise ValueError(f"Expected an ascii string. Found: {repr(text)}.") from u_err
|
||||
value = int.from_bytes(text_bytes, "big")
|
||||
|
||||
cairo_vm_range_check(value)
|
||||
return value
|
||||
|
||||
|
||||
def decode_shortstring(value: int) -> str:
|
||||
"""
|
||||
A function which decodes a felt value to short string (at most 31 characters)
|
||||
|
||||
:param value: A felt value
|
||||
:return: Decoded string which is corresponds to that felt
|
||||
"""
|
||||
cairo_vm_range_check(value)
|
||||
return "".join([chr(i) for i in value.to_bytes(31, byteorder="big")]).lstrip("\x00")
|
||||
@ -0,0 +1,121 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, cast
|
||||
|
||||
from .deprecated_parse import cairo_types as cairo_lang_types
|
||||
from .data_types import (
|
||||
ArrayType,
|
||||
CairoType,
|
||||
FeltType,
|
||||
NamedTupleType,
|
||||
StructType,
|
||||
TupleType,
|
||||
)
|
||||
from .deprecated_parse.parser import parse
|
||||
|
||||
|
||||
class UnknownCairoTypeError(ValueError):
|
||||
"""
|
||||
Error thrown when TypeParser finds type that was not declared prior to parsing.
|
||||
"""
|
||||
|
||||
type_name: str
|
||||
|
||||
def __init__(self, type_name: str):
|
||||
super().__init__(f"Type '{type_name}' is not defined")
|
||||
self.type_name = type_name
|
||||
|
||||
|
||||
class TypeParser:
|
||||
"""
|
||||
Low level utility class for parsing Cairo types that can be used in external methods.
|
||||
"""
|
||||
|
||||
defined_types: Dict[str, StructType]
|
||||
|
||||
def __init__(self, defined_types: Dict[str, StructType]):
|
||||
"""
|
||||
TypeParser constructor.
|
||||
|
||||
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
|
||||
"""
|
||||
self.defined_types = defined_types
|
||||
for name, struct in defined_types.items():
|
||||
if name != struct.name:
|
||||
raise ValueError(
|
||||
f"Keys must match name of type, '{name}' != '{struct.name}'."
|
||||
)
|
||||
|
||||
def parse_inline_type(self, type_string: str) -> CairoType:
|
||||
"""
|
||||
Inline type is one that can be used inline, for instance as return type. For instance
|
||||
(a: Uint256, b: felt*, c: (felt, felt)). Structure can only be referenced in inline type, can't be defined
|
||||
this way.
|
||||
|
||||
:param type_string: type to parse.
|
||||
"""
|
||||
parsed = parse(type_string)
|
||||
return self._transform_cairo_lang_type(parsed)
|
||||
|
||||
def _transform_cairo_lang_type(
|
||||
self, cairo_type: cairo_lang_types.CairoType
|
||||
) -> CairoType:
|
||||
"""
|
||||
For now, we use parse function from cairo-lang package. It will be replaced in the future, but we need to hide
|
||||
it from the users.
|
||||
This function takes types returned by cairo-lang package and maps them to our type classes.
|
||||
|
||||
:param cairo_type: type returned from parse_type function.
|
||||
:return: CairoType defined by our package.
|
||||
"""
|
||||
if isinstance(cairo_type, cairo_lang_types.TypeFelt):
|
||||
return FeltType()
|
||||
|
||||
if isinstance(cairo_type, cairo_lang_types.TypePointer):
|
||||
return ArrayType(self._transform_cairo_lang_type(cairo_type.pointee))
|
||||
|
||||
if isinstance(cairo_type, cairo_lang_types.TypeIdentifier):
|
||||
return self._get_struct(str(cairo_type.name))
|
||||
|
||||
if isinstance(cairo_type, cairo_lang_types.TypeTuple):
|
||||
# Cairo returns is_named when there are no members
|
||||
if cairo_type.is_named and len(cairo_type.members) != 0:
|
||||
assert all(member.name is not None for member in cairo_type.members)
|
||||
|
||||
return NamedTupleType(
|
||||
OrderedDict(
|
||||
(
|
||||
cast(
|
||||
str, member.name
|
||||
), # without that pyright is complaining
|
||||
self._transform_cairo_lang_type(member.typ),
|
||||
)
|
||||
for member in cairo_type.members
|
||||
)
|
||||
)
|
||||
|
||||
return TupleType(
|
||||
[
|
||||
self._transform_cairo_lang_type(member.typ)
|
||||
for member in cairo_type.members
|
||||
]
|
||||
)
|
||||
|
||||
# Contracts don't support codeoffset as input/output type, user can only use it if it was defined in types
|
||||
if isinstance(cairo_type, cairo_lang_types.TypeCodeoffset):
|
||||
return self._get_struct("codeoffset")
|
||||
|
||||
# Other options are: TypeFunction, TypeStruct
|
||||
# Neither of them are possible. In particular TypeStruct is not possible because we parse structs without
|
||||
# info about other structs, so they will be just TypeIdentifier (structure that was not parsed).
|
||||
|
||||
# This is an error of our logic, so we throw a RuntimeError.
|
||||
raise RuntimeError(
|
||||
f"Received unknown type '{cairo_type}' from parser."
|
||||
) # pragma: no cover
|
||||
|
||||
def _get_struct(self, name: str):
|
||||
if name not in self.defined_types:
|
||||
raise UnknownCairoTypeError(name)
|
||||
return self.defined_types[name]
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
from ...abi.v1.parser_transformer import parse
|
||||
from ..data_types import CairoType, EnumType, StructType, TypeIdentifier
|
||||
|
||||
|
||||
class UnknownCairoTypeError(ValueError):
|
||||
"""
|
||||
Error thrown when TypeParser finds type that was not declared prior to parsing.
|
||||
"""
|
||||
|
||||
type_name: str
|
||||
|
||||
def __init__(self, type_name: str):
|
||||
super().__init__(
|
||||
# pylint: disable=line-too-long
|
||||
f"Type '{type_name}' is not defined. Please report this issue at https://github.com/software-mansion/starknet.py/issues"
|
||||
)
|
||||
self.type_name = type_name
|
||||
|
||||
|
||||
class TypeParser:
|
||||
"""
|
||||
Low level utility class for parsing Cairo types that can be used in external methods.
|
||||
"""
|
||||
|
||||
defined_types: Dict[str, Union[StructType, EnumType]]
|
||||
|
||||
def __init__(self, defined_types: Dict[str, Union[StructType, EnumType]]):
|
||||
"""
|
||||
TypeParser constructor.
|
||||
|
||||
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
|
||||
"""
|
||||
self.defined_types = defined_types
|
||||
for name, defined_type in defined_types.items():
|
||||
if name != defined_type.name:
|
||||
raise ValueError(
|
||||
f"Keys must match name of type, '{name}' != '{defined_type.name}'."
|
||||
)
|
||||
|
||||
def parse_inline_type(self, type_string: str) -> CairoType:
|
||||
"""
|
||||
Inline type is one that can be used inline, for instance as return type. For instance
|
||||
(core::felt252, (), (core::felt252,)). Structure can only be referenced in inline type, can't be defined
|
||||
this way.
|
||||
|
||||
:param type_string: type to parse.
|
||||
"""
|
||||
parsed = parse(type_string, self.defined_types)
|
||||
if isinstance(parsed, TypeIdentifier):
|
||||
for defined_name in self.defined_types.keys():
|
||||
if parsed.name == defined_name.split("<")[0].strip(":"):
|
||||
return self.defined_types[defined_name]
|
||||
raise UnknownCairoTypeError(parsed.name)
|
||||
|
||||
return parsed
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Union
|
||||
|
||||
from ...abi.v2.parser_transformer import parse
|
||||
from ..data_types import (
|
||||
CairoType,
|
||||
EnumType,
|
||||
EventType,
|
||||
StructType,
|
||||
TypeIdentifier,
|
||||
)
|
||||
|
||||
|
||||
class UnknownCairoTypeError(ValueError):
|
||||
"""
|
||||
Error thrown when TypeParser finds type that was not declared prior to parsing.
|
||||
"""
|
||||
|
||||
type_name: str
|
||||
|
||||
def __init__(self, type_name: str):
|
||||
super().__init__(
|
||||
# pylint: disable=line-too-long
|
||||
f"Type '{type_name}' is not defined. Please report this issue at https://github.com/software-mansion/starknet.py/issues"
|
||||
)
|
||||
self.type_name = type_name
|
||||
|
||||
|
||||
class TypeParser:
|
||||
"""
|
||||
Low level utility class for parsing Cairo types that can be used in external methods.
|
||||
"""
|
||||
|
||||
defined_types: Dict[str, Union[StructType, EnumType, EventType]]
|
||||
|
||||
def __init__(
|
||||
self, defined_types: Dict[str, Union[StructType, EnumType, EventType]]
|
||||
):
|
||||
"""
|
||||
TypeParser constructor.
|
||||
|
||||
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
|
||||
"""
|
||||
self.defined_types = defined_types
|
||||
for name, defined_type in defined_types.items():
|
||||
if name != defined_type.name:
|
||||
raise ValueError(
|
||||
f"Keys must match name of type, '{name}' != '{defined_type.name}'."
|
||||
)
|
||||
|
||||
def update_defined_types(
|
||||
self, defined_types: Dict[str, Union[StructType, EnumType, EventType]]
|
||||
) -> None:
|
||||
self.defined_types.update(defined_types)
|
||||
|
||||
def add_defined_type(
|
||||
self, defined_type: Union[StructType, EnumType, EventType]
|
||||
) -> None:
|
||||
self.defined_types.update({defined_type.name: defined_type})
|
||||
|
||||
def parse_inline_type(self, type_string: str) -> CairoType:
|
||||
"""
|
||||
Inline type is one that can be used inline, for instance as return type. For instance
|
||||
(core::felt252, (), (core::felt252,)). Structure can only be referenced in inline type, can't be defined
|
||||
this way.
|
||||
|
||||
:param type_string: type to parse.
|
||||
"""
|
||||
parsed = parse(type_string, self.defined_types)
|
||||
if isinstance(parsed, TypeIdentifier):
|
||||
for defined_name in self.defined_types.keys():
|
||||
if parsed.name == defined_name.split("<")[0].strip(":"):
|
||||
return self.defined_types[defined_name]
|
||||
raise UnknownCairoTypeError(parsed.name)
|
||||
|
||||
return parsed
|
||||
@ -0,0 +1,7 @@
|
||||
# utils to use starknet library in ccxt
|
||||
from .constants import EC_ORDER
|
||||
from ..starkware.crypto.signature import grind_key
|
||||
|
||||
def get_private_key_from_eth_signature(eth_signature_hex: str) -> int:
|
||||
r = eth_signature_hex[2 : 64 + 2] if eth_signature_hex[0:2] == '0x' else eth_signature_hex[0 : 64]
|
||||
return grind_key(int(r, 16), EC_ORDER)
|
||||
@ -0,0 +1,15 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
def int_from_hex(number: Union[str, int]) -> int:
|
||||
return number if isinstance(number, int) else int(number, 16)
|
||||
|
||||
|
||||
def int_from_bytes(
|
||||
value: bytes,
|
||||
byte_order: Literal["big", "little"] = "big",
|
||||
signed: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Converts the given bytes object (parsed according to the given byte order) to an integer.
|
||||
"""
|
||||
return int.from_bytes(value, byteorder=byte_order, signed=signed)
|
||||
@ -0,0 +1,39 @@
|
||||
from pathlib import Path
|
||||
|
||||
# Address came from starkware-libs/starknet-addresses repository: https://github.com/starkware-libs/starknet-addresses
|
||||
FEE_CONTRACT_ADDRESS = (
|
||||
"0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7"
|
||||
)
|
||||
|
||||
DEFAULT_DEPLOYER_ADDRESS = (
|
||||
"0x041a78e741e5aF2fEc34B695679bC6891742439f7AFB8484Ecd7766661aD02BF"
|
||||
)
|
||||
|
||||
API_VERSION = 0
|
||||
|
||||
RPC_CONTRACT_NOT_FOUND_ERROR = 20
|
||||
RPC_INVALID_MESSAGE_SELECTOR_ERROR = 21
|
||||
RPC_CLASS_HASH_NOT_FOUND_ERROR = 28
|
||||
RPC_CONTRACT_ERROR = 40
|
||||
|
||||
DEFAULT_ENTRY_POINT_NAME = "__default__"
|
||||
DEFAULT_L1_ENTRY_POINT_NAME = "__l1_default__"
|
||||
DEFAULT_ENTRY_POINT_SELECTOR = 0
|
||||
DEFAULT_DECLARE_SENDER_ADDRESS = 1
|
||||
|
||||
# MAX_STORAGE_ITEM_SIZE and ADDR_BOUND must be consistent with the corresponding constant in
|
||||
# starkware/starknet/common/storage.cairo.
|
||||
MAX_STORAGE_ITEM_SIZE = 256
|
||||
ADDR_BOUND = 2**251 - MAX_STORAGE_ITEM_SIZE
|
||||
|
||||
FIELD_PRIME = 0x800000000000011000000000000000000000000000000000000000000000001
|
||||
EC_ORDER = 0x800000000000010FFFFFFFFFFFFFFFFB781126DCAE7B2321E66A241ADC64D2F
|
||||
|
||||
# From cairo-lang
|
||||
# int_from_bytes(b"STARKNET_CONTRACT_ADDRESS")
|
||||
CONTRACT_ADDRESS_PREFIX = 523065374597054866729014270389667305596563390979550329787219
|
||||
L2_ADDRESS_UPPER_BOUND = 2**251 - 256
|
||||
|
||||
QUERY_VERSION_BASE = 2**128
|
||||
|
||||
ROOT_PATH = Path(__file__).parent
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,79 @@
|
||||
from typing import Sequence
|
||||
|
||||
from ..constants import CONTRACT_ADDRESS_PREFIX, L2_ADDRESS_UPPER_BOUND
|
||||
from .utils import (
|
||||
HEX_PREFIX,
|
||||
_starknet_keccak,
|
||||
compute_hash_on_elements,
|
||||
encode_uint,
|
||||
get_bytes_length,
|
||||
)
|
||||
|
||||
|
||||
def compute_address(
|
||||
*,
|
||||
class_hash: int,
|
||||
constructor_calldata: Sequence[int],
|
||||
salt: int,
|
||||
deployer_address: int = 0,
|
||||
) -> int:
|
||||
"""
|
||||
Computes the contract address in the Starknet network - a unique identifier of the contract.
|
||||
|
||||
:param class_hash: class hash of the contract
|
||||
:param constructor_calldata: calldata for the contract constructor
|
||||
:param salt: salt used to calculate contract address
|
||||
:param deployer_address: address of the deployer (if not provided default 0 is used)
|
||||
:return: Contract's address
|
||||
"""
|
||||
|
||||
constructor_calldata_hash = compute_hash_on_elements(data=constructor_calldata)
|
||||
raw_address = compute_hash_on_elements(
|
||||
data=[
|
||||
CONTRACT_ADDRESS_PREFIX,
|
||||
deployer_address,
|
||||
salt,
|
||||
class_hash,
|
||||
constructor_calldata_hash,
|
||||
],
|
||||
)
|
||||
|
||||
return raw_address % L2_ADDRESS_UPPER_BOUND
|
||||
|
||||
|
||||
def get_checksum_address(address: str) -> str:
|
||||
"""
|
||||
Outputs formatted checksum address.
|
||||
|
||||
Follows implementation of starknet.js. It is not compatible with EIP55 as it treats hex string as encoded number,
|
||||
instead of encoding it as ASCII string.
|
||||
|
||||
:param address: Address to encode
|
||||
:return: Checksum address
|
||||
"""
|
||||
if not address.lower().startswith(HEX_PREFIX):
|
||||
raise ValueError(f"{address} is not a valid hexadecimal address.")
|
||||
|
||||
int_address = int(address, 16)
|
||||
string_address = address[2:].zfill(64)
|
||||
|
||||
address_in_bytes = encode_uint(int_address, get_bytes_length(int_address))
|
||||
address_hash = _starknet_keccak(address_in_bytes)
|
||||
|
||||
result = "".join(
|
||||
(
|
||||
char.upper()
|
||||
if char.isalpha() and (address_hash >> 256 - 4 * i - 1) & 1
|
||||
else char
|
||||
)
|
||||
for i, char in enumerate(string_address)
|
||||
)
|
||||
|
||||
return f"{HEX_PREFIX}{result}"
|
||||
|
||||
|
||||
def is_checksum_address(address: str) -> bool:
|
||||
"""
|
||||
Checks if provided string is in a checksum address format.
|
||||
"""
|
||||
return get_checksum_address(address) == address
|
||||
@ -0,0 +1,111 @@
|
||||
# File is copied from
|
||||
# https://github.com/starkware-libs/cairo-lang/blob/v0.13.1/src/starkware/starknet/core/os/contract_class/compiled_class_hash_objects.py
|
||||
|
||||
import dataclasses
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Union
|
||||
|
||||
from poseidon_py.poseidon_hash import poseidon_hash_many
|
||||
|
||||
|
||||
class BytecodeSegmentStructure(ABC):
|
||||
"""
|
||||
Represents the structure of the bytecode to allow loading it partially into the OS memory.
|
||||
See the documentation of the OS function `bytecode_hash_node` in `compiled_class.cairo`
|
||||
for more details.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def hash(self) -> int:
|
||||
"""
|
||||
Computes the hash of the node.
|
||||
"""
|
||||
|
||||
def bytecode_with_skipped_segments(self):
|
||||
"""
|
||||
Returns the bytecode of the node.
|
||||
Skipped segments are replaced with [-1, -2, -2, -2, ...].
|
||||
"""
|
||||
res: List[int] = []
|
||||
self.add_bytecode_with_skipped_segments(res)
|
||||
return res
|
||||
|
||||
@abstractmethod
|
||||
def add_bytecode_with_skipped_segments(self, data: List[int]):
|
||||
"""
|
||||
Same as bytecode_with_skipped_segments, but appends the result to the given list.
|
||||
"""
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BytecodeLeaf(BytecodeSegmentStructure):
|
||||
"""
|
||||
Represents a leaf in the bytecode segment tree.
|
||||
"""
|
||||
|
||||
data: List[int]
|
||||
|
||||
def hash(self) -> int:
|
||||
return poseidon_hash_many(self.data)
|
||||
|
||||
def add_bytecode_with_skipped_segments(self, data: List[int]):
|
||||
data.extend(self.data)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BytecodeSegmentedNode(BytecodeSegmentStructure):
|
||||
"""
|
||||
Represents an internal node in the bytecode segment tree.
|
||||
Each child can be loaded into memory or skipped.
|
||||
"""
|
||||
|
||||
segments: List["BytecodeSegment"]
|
||||
|
||||
def hash(self) -> int:
|
||||
return (
|
||||
poseidon_hash_many(
|
||||
itertools.chain( # pyright: ignore
|
||||
*[
|
||||
(node.segment_length, node.inner_structure.hash())
|
||||
for node in self.segments
|
||||
]
|
||||
)
|
||||
)
|
||||
+ 1
|
||||
)
|
||||
|
||||
def add_bytecode_with_skipped_segments(self, data: List[int]):
|
||||
for segment in self.segments:
|
||||
if segment.is_used:
|
||||
segment.inner_structure.add_bytecode_with_skipped_segments(data)
|
||||
else:
|
||||
data.append(-1)
|
||||
data.extend(-2 for _ in range(segment.segment_length - 1))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BytecodeSegment:
|
||||
"""
|
||||
Represents a child of BytecodeSegmentedNode.
|
||||
"""
|
||||
|
||||
# The length of the segment.
|
||||
segment_length: int
|
||||
# Should the segment (or part of it) be loaded to memory.
|
||||
# In other words, is the segment used during the execution.
|
||||
# Note that if is_used is False, the entire segment is not loaded to memory.
|
||||
# If is_used is True, it is possible that part of the segment will be skipped (according
|
||||
# to the "is_used" field of the child segments).
|
||||
is_used: bool
|
||||
# The inner structure of the segment.
|
||||
inner_structure: BytecodeSegmentStructure
|
||||
|
||||
def __post_init__(self):
|
||||
assert (
|
||||
self.segment_length > 0
|
||||
), f"Invalid segment length: {self.segment_length}."
|
||||
|
||||
|
||||
# Represents a nested list of integers. E.g., [1, [2, [3], 4], 5, 6].
|
||||
NestedIntList = Union[int, List[Any]]
|
||||
@ -0,0 +1,16 @@
|
||||
from ..constants import (
|
||||
DEFAULT_ENTRY_POINT_NAME,
|
||||
DEFAULT_ENTRY_POINT_SELECTOR,
|
||||
DEFAULT_L1_ENTRY_POINT_NAME,
|
||||
)
|
||||
from ..hash.utils import _starknet_keccak
|
||||
|
||||
|
||||
def get_selector_from_name(func_name: str) -> int:
|
||||
"""
|
||||
Returns the selector of a contract's function name.
|
||||
"""
|
||||
if func_name in [DEFAULT_ENTRY_POINT_NAME, DEFAULT_L1_ENTRY_POINT_NAME]:
|
||||
return DEFAULT_ENTRY_POINT_SELECTOR
|
||||
|
||||
return _starknet_keccak(data=func_name.encode("ascii"))
|
||||
@ -0,0 +1,12 @@
|
||||
from functools import reduce
|
||||
|
||||
from constants import ADDR_BOUND
|
||||
from hash.utils import _starknet_keccak, pedersen_hash
|
||||
|
||||
|
||||
def get_storage_var_address(var_name: str, *args: int) -> int:
|
||||
"""
|
||||
Returns the storage address of a Starknet storage variable given its name and arguments.
|
||||
"""
|
||||
res = _starknet_keccak(var_name.encode("ascii"))
|
||||
return reduce(pedersen_hash, args, res) % ADDR_BOUND
|
||||
@ -0,0 +1,78 @@
|
||||
import functools
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from ... import keccak
|
||||
|
||||
from ..common import int_from_bytes
|
||||
from ..constants import EC_ORDER
|
||||
from ...starkware.crypto.signature import (
|
||||
ECSignature,
|
||||
private_to_stark_key,
|
||||
sign
|
||||
# verify
|
||||
)
|
||||
from ...starkware.crypto.fast_pedersen_hash import (
|
||||
pedersen_hash
|
||||
)
|
||||
|
||||
MASK_250 = 2**250 - 1
|
||||
HEX_PREFIX = "0x"
|
||||
|
||||
|
||||
def _starknet_keccak(data: bytes) -> int:
|
||||
"""
|
||||
A variant of eth-keccak that computes a value that fits in a Starknet field element.
|
||||
"""
|
||||
return int_from_bytes(keccak.SHA3(data)) & MASK_250
|
||||
|
||||
|
||||
# def pedersen_hash(left: int, right: int) -> int:
|
||||
# """
|
||||
# One of two hash functions (along with _starknet_keccak) used throughout Starknet.
|
||||
# """
|
||||
# return cpp_hash(left, right)
|
||||
|
||||
|
||||
def compute_hash_on_elements(data: Sequence) -> int:
|
||||
"""
|
||||
Computes a hash chain over the data, in the following order:
|
||||
h(h(h(h(0, data[0]), data[1]), ...), data[n-1]), n).
|
||||
|
||||
The hash is initialized with 0 and ends with the data length appended.
|
||||
The length is appended in order to avoid collisions of the following kind:
|
||||
H([x,y,z]) = h(h(x,y),z) = H([w, z]) where w = h(x,y).
|
||||
"""
|
||||
return functools.reduce(pedersen_hash, [*data, len(data)], 0)
|
||||
|
||||
|
||||
def message_signature(
|
||||
msg_hash: int, priv_key: int, seed: Optional[int] = 32
|
||||
) -> ECSignature:
|
||||
"""
|
||||
Signs the message with private key.
|
||||
"""
|
||||
return sign(msg_hash, priv_key, seed)
|
||||
|
||||
|
||||
# def verify_message_signature(
|
||||
# msg_hash: int, signature: List[int], public_key: int
|
||||
# ) -> bool:
|
||||
# """
|
||||
# Verifies ECDSA signature of a given message hash with a given public key.
|
||||
# Returns true if public_key signs the message.
|
||||
# """
|
||||
# sig_r, sig_s = signature
|
||||
# # sig_w = pow(sig_s, -1, EC_ORDER)
|
||||
# return verify(msg_hash=msg_hash, r=sig_r, s=sig_s, public_key=public_key)
|
||||
|
||||
|
||||
def encode_uint(value: int, bytes_length: int = 32) -> bytes:
|
||||
return value.to_bytes(bytes_length, byteorder="big")
|
||||
|
||||
|
||||
def encode_uint_list(data: List[int]) -> bytes:
|
||||
return b"".join(encode_uint(x) for x in data)
|
||||
|
||||
|
||||
def get_bytes_length(value: int) -> int:
|
||||
return (value.bit_length() + 7) // 8
|
||||
Binary file not shown.
Binary file not shown.
@ -0,0 +1,45 @@
|
||||
"""
|
||||
TypedDict structures for TypedData
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, TypedDict
|
||||
|
||||
class Revision(Enum):
|
||||
"""
|
||||
Enum representing the revision of the specification to be used.
|
||||
"""
|
||||
|
||||
V0 = 0
|
||||
V1 = 1
|
||||
|
||||
|
||||
class ParameterDict(TypedDict):
|
||||
"""
|
||||
TypedDict representing a Parameter object
|
||||
"""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
|
||||
|
||||
class StarkNetDomainDict(TypedDict):
|
||||
"""
|
||||
TypedDict representing a domain object (both StarkNetDomain, StarknetDomain).
|
||||
"""
|
||||
|
||||
name: str
|
||||
version: str
|
||||
chainId: str
|
||||
revision: Optional[Revision]
|
||||
|
||||
|
||||
class TypedDataDict(TypedDict):
|
||||
"""
|
||||
TypedDict representing a TypedData object
|
||||
"""
|
||||
|
||||
types: Dict[str, List[ParameterDict]]
|
||||
primaryType: str
|
||||
domain: StarkNetDomainDict
|
||||
message: Dict[str, Any]
|
||||
@ -0,0 +1,24 @@
|
||||
# PayloadSerializer and FunctionSerializationAdapter would mostly be used by users
|
||||
from .data_serializers import (
|
||||
ArraySerializer,
|
||||
CairoDataSerializer,
|
||||
FeltSerializer,
|
||||
NamedTupleSerializer,
|
||||
PayloadSerializer,
|
||||
StructSerializer,
|
||||
TupleSerializer,
|
||||
Uint256Serializer,
|
||||
)
|
||||
from .errors import (
|
||||
CairoSerializerException,
|
||||
InvalidTypeException,
|
||||
InvalidValueException,
|
||||
)
|
||||
from .factory import (
|
||||
serializer_for_event,
|
||||
serializer_for_function,
|
||||
serializer_for_payload,
|
||||
serializer_for_type,
|
||||
)
|
||||
from .function_serialization_adapter import FunctionSerializationAdapter
|
||||
from .tuple_dataclass import TupleDataclass
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,40 @@
|
||||
from typing import List
|
||||
|
||||
from ..cairo.felt import CairoData
|
||||
|
||||
|
||||
class OutOfBoundsError(Exception):
|
||||
def __init__(self, position: int, requested_size: int, remaining_size: int):
|
||||
super().__init__(
|
||||
f"Requested {requested_size} elements, {remaining_size} available."
|
||||
)
|
||||
self.position = position
|
||||
self.requested_size = requested_size
|
||||
self.remaining_len = remaining_size
|
||||
|
||||
|
||||
class CalldataReader:
|
||||
_data: List[int]
|
||||
_position: int
|
||||
|
||||
def __init__(self, data: List[int]):
|
||||
self._data = data
|
||||
self._position = 0
|
||||
|
||||
@property
|
||||
def remaining_len(self) -> int:
|
||||
return len(self._data) - self._position
|
||||
|
||||
def read(self, size: int) -> CairoData:
|
||||
if size < 1:
|
||||
raise ValueError("size must be greater than 0")
|
||||
|
||||
if size > self.remaining_len:
|
||||
raise OutOfBoundsError(
|
||||
position=self._position,
|
||||
requested_size=size,
|
||||
remaining_size=self.remaining_len,
|
||||
)
|
||||
data = self._data[self._position : self._position + size]
|
||||
self._position += size
|
||||
return data
|
||||
@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator, Iterator, List
|
||||
|
||||
from ._calldata_reader import (
|
||||
CairoData,
|
||||
CalldataReader,
|
||||
OutOfBoundsError,
|
||||
)
|
||||
from .errors import InvalidTypeException, InvalidValueException
|
||||
|
||||
|
||||
class Context(ABC):
|
||||
"""
|
||||
Holds information about context when (de)serializing data. This is needed to inform what and where went
|
||||
wrong during processing. Every separate (de)serialization should have its own context.
|
||||
"""
|
||||
|
||||
_namespace_stack: List[str]
|
||||
|
||||
def __init__(self):
|
||||
self._namespace_stack = []
|
||||
|
||||
@property
|
||||
def current_entity(self):
|
||||
"""
|
||||
Name of currently processed entity.
|
||||
|
||||
:return: transformed path.
|
||||
"""
|
||||
return ".".join(self._namespace_stack)
|
||||
|
||||
@contextmanager
|
||||
def push_entity(self, name: str) -> Generator:
|
||||
"""
|
||||
Manager used for maintaining information about names of (de)serialized types. Wraps some errors with
|
||||
custom errors, adding information about the context.
|
||||
|
||||
:param name: name of (de)serialized entity.
|
||||
"""
|
||||
# This ensures the name will be popped if everything is ok. In case an exception is raised we want the stack to
|
||||
# be filled to wrap the error at the end.
|
||||
self._namespace_stack.append(name)
|
||||
yield
|
||||
self._namespace_stack.pop()
|
||||
|
||||
def ensure_valid_value(self, valid: bool, text: str):
|
||||
if not valid:
|
||||
raise InvalidValueException(f"{self._error_prefix}: {text}.")
|
||||
|
||||
def ensure_valid_type(self, value: Any, valid: bool, expected_type: str):
|
||||
if not valid:
|
||||
raise InvalidTypeException(
|
||||
f"{self._error_prefix}: expected {expected_type}, "
|
||||
f"received '{value}' of type '{type(value)}'."
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def _wrap_errors(self):
|
||||
try:
|
||||
yield
|
||||
except OutOfBoundsError as err:
|
||||
action_name = (
|
||||
f"deserialize '{self.current_entity}'"
|
||||
if self._namespace_stack
|
||||
else "deserialize"
|
||||
)
|
||||
# This way we can precisely inform user what's wrong when reading calldata.
|
||||
raise InvalidValueException(
|
||||
f"Not enough data to {action_name}. "
|
||||
f"Can't read {err.requested_size} values at position {err.position}, {err.remaining_len} available."
|
||||
) from err
|
||||
|
||||
# Those two are based on ValueError and TypeError, we have to catch them early
|
||||
except (InvalidValueException, InvalidTypeException) as err:
|
||||
raise err
|
||||
|
||||
except ValueError as err:
|
||||
raise InvalidValueException(f"{self._error_prefix}: {err}") from err
|
||||
except TypeError as err:
|
||||
raise InvalidTypeException(f"{self._error_prefix}: {err}") from err
|
||||
|
||||
@property
|
||||
def _error_prefix(self):
|
||||
if not self._namespace_stack:
|
||||
return "Error"
|
||||
return f"Error at path '{self.current_entity}'"
|
||||
|
||||
|
||||
class SerializationContext(Context):
|
||||
"""
|
||||
Context used during serialization.
|
||||
"""
|
||||
|
||||
# Type is iterator, because ContextManager doesn't work with pyright :|
|
||||
# https://github.com/microsoft/pyright/issues/476
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def create(cls) -> Iterator[SerializationContext]:
|
||||
context = cls()
|
||||
with context._wrap_errors():
|
||||
yield context
|
||||
|
||||
|
||||
class DeserializationContext(Context):
|
||||
"""
|
||||
Context used during deserialization.
|
||||
"""
|
||||
|
||||
reader: CalldataReader
|
||||
|
||||
def __init__(self, calldata: CairoData):
|
||||
"""
|
||||
Don't use default constructor. Use DeserializationContext.create context manager.
|
||||
"""
|
||||
super().__init__()
|
||||
self._namespace_stack = []
|
||||
self.reader = CalldataReader(calldata)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def create(cls, data: CairoData) -> Iterator[DeserializationContext]:
|
||||
context = cls(data)
|
||||
with context._wrap_errors():
|
||||
yield context
|
||||
context._ensure_all_values_read(len(data))
|
||||
|
||||
def _ensure_all_values_read(self, total_len: int):
|
||||
values_not_used = self.reader.remaining_len
|
||||
if values_not_used != 0:
|
||||
# We want to output up to 3 values. It there is more they will be truncated like "0x1,0x1,0x1..."
|
||||
max_values_to_show = 3
|
||||
values_to_show = min(values_not_used, max_values_to_show)
|
||||
example = ",".join(hex(v) for v in self.reader.read(values_to_show))
|
||||
suffix = "..." if values_not_used > max_values_to_show else ""
|
||||
|
||||
raise InvalidValueException(
|
||||
f"Last {values_not_used} values '{example}{suffix}' out of total {total_len} "
|
||||
"values were not used during deserialization."
|
||||
)
|
||||
@ -0,0 +1,10 @@
|
||||
from .array_serializer import ArraySerializer
|
||||
from .bool_serializer import BoolSerializer
|
||||
from .byte_array_serializer import ByteArraySerializer
|
||||
from .cairo_data_serializer import CairoDataSerializer
|
||||
from .felt_serializer import FeltSerializer
|
||||
from .named_tuple_serializer import NamedTupleSerializer
|
||||
from .payload_serializer import PayloadSerializer
|
||||
from .struct_serializer import StructSerializer
|
||||
from .tuple_serializer import TupleSerializer
|
||||
from .uint256_serializer import Uint256Serializer
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,82 @@
|
||||
# We have to use parametrised type from typing
|
||||
from collections import OrderedDict as _OrderedDict
|
||||
from typing import Dict, Generator, List, OrderedDict
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
# The actual serialization logic is very similar among all serializers: they either serialize data based on
|
||||
# position or their name. Having this logic reused adds indirection, but makes sure proper logic is used everywhere.
|
||||
|
||||
|
||||
def deserialize_to_list(
|
||||
deserializers: List[CairoDataSerializer], context: DeserializationContext
|
||||
) -> List:
|
||||
"""
|
||||
Deserializes data from context to list. This logic is used in every sequential type (arrays and tuples).
|
||||
"""
|
||||
result = []
|
||||
|
||||
for index, serializer in enumerate(deserializers):
|
||||
with context.push_entity(f"[{index}]"):
|
||||
result.append(serializer.deserialize_with_context(context))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def deserialize_to_dict(
|
||||
deserializers: OrderedDict[str, CairoDataSerializer],
|
||||
context: DeserializationContext,
|
||||
) -> OrderedDict:
|
||||
"""
|
||||
Deserializes data from context to dictionary. This logic is used in every type with named fields (structs,
|
||||
named tuples and payloads).
|
||||
"""
|
||||
result = _OrderedDict()
|
||||
|
||||
for key, serializer in deserializers.items():
|
||||
with context.push_entity(key):
|
||||
result[key] = serializer.deserialize_with_context(context)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_from_list(
|
||||
serializers: List[CairoDataSerializer], context: SerializationContext, values: List
|
||||
) -> Generator[int, None, None]:
|
||||
"""
|
||||
Serializes data from list. This logic is used in every sequential type (arrays and tuples).
|
||||
"""
|
||||
context.ensure_valid_value(
|
||||
len(serializers) == len(values),
|
||||
f"expected {len(serializers)} elements, {len(values)} provided",
|
||||
)
|
||||
|
||||
for index, (serializer, value) in enumerate(zip(serializers, values)):
|
||||
with context.push_entity(f"[{index}]"):
|
||||
yield from serializer.serialize_with_context(context, value)
|
||||
|
||||
|
||||
def serialize_from_dict(
|
||||
serializers: OrderedDict[str, CairoDataSerializer],
|
||||
context: SerializationContext,
|
||||
values: Dict,
|
||||
) -> Generator[int, None, None]:
|
||||
"""
|
||||
Serializes data from dict. This logic is used in every type with named fields (structs, named tuples and payloads).
|
||||
"""
|
||||
excessive_keys = set(values.keys()).difference(serializers.keys())
|
||||
context.ensure_valid_value(
|
||||
not excessive_keys,
|
||||
f"unexpected keys '{','.join(excessive_keys)}' were provided",
|
||||
)
|
||||
|
||||
for name, serializer in serializers.items():
|
||||
with context.push_entity(name):
|
||||
context.ensure_valid_value(name in values, f"key '{name}' is missing")
|
||||
yield from serializer.serialize_with_context(context, values[name])
|
||||
@ -0,0 +1,43 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator, Iterable, List
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ..data_serializers._common import (
|
||||
deserialize_to_list,
|
||||
serialize_from_list,
|
||||
)
|
||||
from ..data_serializers.cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArraySerializer(CairoDataSerializer[Iterable, List]):
|
||||
"""
|
||||
Serializer for arrays. In abi they are represented as a pointer to a type.
|
||||
Can serialize any iterable and prepends its length to resulting list.
|
||||
Deserializes data to a list.
|
||||
|
||||
Examples:
|
||||
[1,2,3] => [3,1,2,3]
|
||||
[] => [0]
|
||||
"""
|
||||
|
||||
inner_serializer: CairoDataSerializer
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> List:
|
||||
with context.push_entity("len"):
|
||||
[size] = context.reader.read(1)
|
||||
|
||||
return deserialize_to_list([self.inner_serializer] * size, context)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: List
|
||||
) -> Generator[int, None, None]:
|
||||
yield len(value)
|
||||
yield from serialize_from_list(
|
||||
[self.inner_serializer] * len(value), context, value
|
||||
)
|
||||
@ -0,0 +1,37 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator
|
||||
|
||||
from .._context import (
|
||||
Context,
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BoolSerializer(CairoDataSerializer[bool, int]):
|
||||
"""
|
||||
Serializer for boolean.
|
||||
"""
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> bool:
|
||||
[val] = context.reader.read(1)
|
||||
self._ensure_bool(context, val)
|
||||
return bool(val)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: bool
|
||||
) -> Generator[int, None, None]:
|
||||
context.ensure_valid_type(value, isinstance(value, bool), "bool")
|
||||
self._ensure_bool(context, value)
|
||||
yield int(value)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_bool(context: Context, value: int):
|
||||
context.ensure_valid_value(
|
||||
value in [0, 1],
|
||||
f"invalid value '{value}' - must be in [0, 2) range",
|
||||
)
|
||||
@ -0,0 +1,66 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator
|
||||
|
||||
from ...cairo.felt import decode_shortstring, encode_shortstring
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ._common import (
|
||||
deserialize_to_list,
|
||||
serialize_from_list,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
from .felt_serializer import FeltSerializer
|
||||
|
||||
BYTES_31_SIZE = 31
|
||||
|
||||
|
||||
@dataclass
|
||||
class ByteArraySerializer(CairoDataSerializer[str, str]):
|
||||
"""
|
||||
Serializer for ByteArrays. Serializes to and deserializes from str values.
|
||||
|
||||
Examples:
|
||||
"" => [0,0,0]
|
||||
"hello" => [0,448378203247,5]
|
||||
"""
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> str:
|
||||
with context.push_entity("data_array_len"):
|
||||
[size] = context.reader.read(1)
|
||||
|
||||
data = deserialize_to_list([FeltSerializer()] * size, context)
|
||||
|
||||
with context.push_entity("pending_word"):
|
||||
[pending_word] = context.reader.read(1)
|
||||
|
||||
with context.push_entity("pending_word_len"):
|
||||
[pending_word_len] = context.reader.read(1)
|
||||
|
||||
pending_word = decode_shortstring(pending_word)
|
||||
context.ensure_valid_value(
|
||||
len(pending_word) == pending_word_len,
|
||||
f"Invalid length {pending_word_len} for pending word {pending_word}",
|
||||
)
|
||||
|
||||
data_joined = "".join(map(decode_shortstring, data))
|
||||
return data_joined + pending_word
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: str
|
||||
) -> Generator[int, None, None]:
|
||||
context.ensure_valid_type(value, isinstance(value, str), "str")
|
||||
data = [
|
||||
value[i : i + BYTES_31_SIZE] for i in range(0, len(value), BYTES_31_SIZE)
|
||||
]
|
||||
pending_word = (
|
||||
"" if len(data) == 0 or len(data[-1]) == BYTES_31_SIZE else data.pop(-1)
|
||||
)
|
||||
|
||||
yield len(data)
|
||||
yield from serialize_from_list([FeltSerializer()] * len(data), context, data)
|
||||
yield encode_shortstring(pending_word)
|
||||
yield len(pending_word)
|
||||
@ -0,0 +1,71 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator, Generic, List, TypeVar
|
||||
|
||||
from .._calldata_reader import CairoData
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
|
||||
# Python type that is accepted by a serializer
|
||||
# pylint: disable=invalid-name
|
||||
SerializationType = TypeVar("SerializationType")
|
||||
|
||||
# Python type that will be returned from a serializer. Often same as SerializationType.
|
||||
# pylint: disable=invalid-name
|
||||
DeserializationType = TypeVar("DeserializationType")
|
||||
|
||||
|
||||
class CairoDataSerializer(ABC, Generic[SerializationType, DeserializationType]):
|
||||
"""
|
||||
Base class for serializing/deserializing data to/from calldata.
|
||||
"""
|
||||
|
||||
def deserialize(self, data: List[int]) -> DeserializationType:
|
||||
"""
|
||||
Transform calldata into python value.
|
||||
|
||||
:param data: calldata to deserialize.
|
||||
:return: defined DeserializationType.
|
||||
"""
|
||||
with DeserializationContext.create(data) as context:
|
||||
return self.deserialize_with_context(context)
|
||||
|
||||
def serialize(self, data: SerializationType) -> CairoData:
|
||||
"""
|
||||
Transform python data into calldata.
|
||||
|
||||
:param data: data to serialize.
|
||||
:return: calldata.
|
||||
"""
|
||||
with SerializationContext.create() as context:
|
||||
serialized_data = list(self.serialize_with_context(context, data))
|
||||
|
||||
return self.remove_units_from_serialized_data(serialized_data)
|
||||
|
||||
@abstractmethod
|
||||
def deserialize_with_context(
|
||||
self, context: DeserializationContext
|
||||
) -> DeserializationType:
|
||||
"""
|
||||
Transform calldata into python value.
|
||||
|
||||
:param context: context of this deserialization.
|
||||
:return: defined DeserializationType.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: SerializationType
|
||||
) -> Generator[int, None, None]:
|
||||
"""
|
||||
Transform python value into calldata.
|
||||
|
||||
:param context: context of this serialization.
|
||||
:param value: python value to serialize.
|
||||
:return: defined SerializationType.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def remove_units_from_serialized_data(serialized_data: List) -> List:
|
||||
return [x for x in serialized_data if x is not None]
|
||||
@ -0,0 +1,71 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Generator, OrderedDict, Tuple, Union
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
from ..tuple_dataclass import TupleDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumSerializer(CairoDataSerializer[Union[Dict, TupleDataclass], TupleDataclass]):
|
||||
"""
|
||||
Serializer of enums.
|
||||
Can serialize a dictionary and TupleDataclass.
|
||||
Deserializes data to a TupleDataclass.
|
||||
|
||||
Example:
|
||||
enum MyEnum {
|
||||
a: u128,
|
||||
b: u128
|
||||
}
|
||||
|
||||
{"a": 1} => [0, 1]
|
||||
{"b": 100} => [1, 100]
|
||||
TupleDataclass(variant='a', value=100) => [0, 100]
|
||||
"""
|
||||
|
||||
serializers: OrderedDict[str, CairoDataSerializer]
|
||||
|
||||
def deserialize_with_context(
|
||||
self, context: DeserializationContext
|
||||
) -> TupleDataclass:
|
||||
[variant_index] = context.reader.read(1)
|
||||
variant_name, serializer = self._get_variant(variant_index)
|
||||
|
||||
with context.push_entity("enum.variant: " + variant_name):
|
||||
result_dict = {
|
||||
"variant": variant_name,
|
||||
"value": serializer.deserialize_with_context(context),
|
||||
}
|
||||
|
||||
return TupleDataclass.from_dict(result_dict)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Union[Dict, TupleDataclass]
|
||||
) -> Generator[int, None, None]:
|
||||
if isinstance(value, Dict):
|
||||
items = list(value.items())
|
||||
if len(items) != 1:
|
||||
raise ValueError(
|
||||
"Can serialize only one enum variant, got: " + str(len(items))
|
||||
)
|
||||
|
||||
variant_name, variant_value = items[0]
|
||||
else:
|
||||
variant_name, variant_value = value
|
||||
|
||||
yield self._get_variant_index(variant_name)
|
||||
yield from self.serializers[variant_name].serialize_with_context(
|
||||
context, variant_value
|
||||
)
|
||||
|
||||
def _get_variant(self, variant_index: int) -> Tuple[str, CairoDataSerializer]:
|
||||
return list(self.serializers.items())[variant_index]
|
||||
|
||||
def _get_variant_index(self, variant_name: str) -> int:
|
||||
return list(self.serializers.keys()).index(variant_name)
|
||||
@ -0,0 +1,50 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator
|
||||
|
||||
from ...cairo.felt import encode_shortstring, is_in_felt_range
|
||||
from ...constants import FIELD_PRIME
|
||||
from .._context import (
|
||||
Context,
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeltSerializer(CairoDataSerializer[int, int]):
|
||||
"""
|
||||
Serializer for field element. At the time of writing it is the only existing numeric type.
|
||||
"""
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> int:
|
||||
[val] = context.reader.read(1)
|
||||
self._ensure_felt(context, val)
|
||||
return val
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: int
|
||||
) -> Generator[int, None, None]:
|
||||
if isinstance(value, str):
|
||||
warnings.warn(
|
||||
"Serializing shortstrings in FeltSerializer is deprecated. "
|
||||
"Use starknet_py.cairo.felt.encode_shortstring instead.",
|
||||
category=DeprecationWarning,
|
||||
)
|
||||
value = encode_shortstring(value)
|
||||
yield value
|
||||
return
|
||||
|
||||
context.ensure_valid_type(value, isinstance(value, int), "int")
|
||||
self._ensure_felt(context, value)
|
||||
yield value
|
||||
|
||||
@staticmethod
|
||||
def _ensure_felt(context: Context, value: int):
|
||||
context.ensure_valid_value(
|
||||
is_in_felt_range(value),
|
||||
f"invalid value '{value}' - must be in [0, {FIELD_PRIME}) range",
|
||||
)
|
||||
@ -0,0 +1,58 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Generator, NamedTuple, OrderedDict, Union
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ._common import (
|
||||
deserialize_to_dict,
|
||||
serialize_from_dict,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
from ..tuple_dataclass import TupleDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class NamedTupleSerializer(
|
||||
CairoDataSerializer[Union[Dict, NamedTuple, TupleDataclass], TupleDataclass]
|
||||
):
|
||||
"""
|
||||
Serializer for tuples with named fields.
|
||||
Can serialize a dictionary, a named tuple and TupleDataclass.
|
||||
Deserializes data to a TupleDataclass.
|
||||
|
||||
Example:
|
||||
{"a": 1, "b": 2} => [1,2]
|
||||
"""
|
||||
|
||||
serializers: OrderedDict[str, CairoDataSerializer]
|
||||
|
||||
def deserialize_with_context(
|
||||
self, context: DeserializationContext
|
||||
) -> TupleDataclass:
|
||||
as_dictionary = deserialize_to_dict(self.serializers, context)
|
||||
return TupleDataclass.from_dict(as_dictionary)
|
||||
|
||||
def serialize_with_context(
|
||||
self,
|
||||
context: SerializationContext,
|
||||
value: Union[Dict, NamedTuple, TupleDataclass],
|
||||
) -> Generator[int, None, None]:
|
||||
# We can't use isinstance(value, NamedTuple), because there is no NamedTuple type.
|
||||
context.ensure_valid_type(
|
||||
value,
|
||||
isinstance(value, (dict, TupleDataclass)) or self._is_namedtuple(value),
|
||||
"dict, NamedTuple or TupleDataclass",
|
||||
)
|
||||
|
||||
# noinspection PyUnresolvedReferences, PyProtectedMember
|
||||
values: Dict = value if isinstance(value, dict) else value._asdict()
|
||||
|
||||
yield from serialize_from_dict(self.serializers, context, values)
|
||||
|
||||
@staticmethod
|
||||
def _is_namedtuple(value) -> bool:
|
||||
return isinstance(value, tuple) and hasattr(value, "_fields")
|
||||
@ -0,0 +1,43 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OptionSerializer(CairoDataSerializer[Optional[Any], Optional[Any]]):
|
||||
"""
|
||||
Serializer for Option type.
|
||||
Can serialize None and common CairoTypes.
|
||||
Deserializes data to None or CairoType.
|
||||
|
||||
Example:
|
||||
None => [1]
|
||||
{"option1": 123, "option2": None} => [0, 123, 1]
|
||||
"""
|
||||
|
||||
serializer: CairoDataSerializer
|
||||
|
||||
def deserialize_with_context(
|
||||
self, context: DeserializationContext
|
||||
) -> Optional[Any]:
|
||||
(is_none,) = context.reader.read(1)
|
||||
if is_none == 1:
|
||||
return None
|
||||
|
||||
return self.serializer.deserialize_with_context(context)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Optional[Any]
|
||||
) -> Generator[int, None, None]:
|
||||
if value is None:
|
||||
yield 1
|
||||
else:
|
||||
yield 0
|
||||
yield from self.serializer.serialize_with_context(context, value)
|
||||
@ -0,0 +1,40 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Generator, List, Tuple
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputSerializer(CairoDataSerializer[List, Tuple]):
|
||||
"""
|
||||
Serializer for function output.
|
||||
Can't serialize anything.
|
||||
Deserializes data to a Tuple.
|
||||
|
||||
Example:
|
||||
[1, 1, 1] => (340282366920938463463374607431768211457)
|
||||
"""
|
||||
|
||||
serializers: List[CairoDataSerializer] = field(init=True)
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> Tuple:
|
||||
result = []
|
||||
|
||||
for index, serializer in enumerate(self.serializers):
|
||||
with context.push_entity("output[" + str(index) + "]"):
|
||||
result.append(serializer.deserialize_with_context(context))
|
||||
|
||||
return tuple(result)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Dict
|
||||
) -> Generator[int, None, None]:
|
||||
raise ValueError(
|
||||
"Output serializer can't be used to transform python data into calldata."
|
||||
)
|
||||
@ -0,0 +1,72 @@
|
||||
from collections import OrderedDict as _OrderedDict
|
||||
from dataclasses import InitVar, dataclass, field
|
||||
from typing import Dict, Generator, OrderedDict
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ._common import (
|
||||
deserialize_to_dict,
|
||||
serialize_from_dict,
|
||||
)
|
||||
from .array_serializer import ArraySerializer
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
from .felt_serializer import FeltSerializer
|
||||
from ..tuple_dataclass import TupleDataclass
|
||||
|
||||
SIZE_SUFFIX = "_len"
|
||||
SIZE_SUFFIX_LEN = len(SIZE_SUFFIX)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PayloadSerializer(CairoDataSerializer[Dict, TupleDataclass]):
|
||||
"""
|
||||
Serializer for payloads like function arguments/function outputs/events.
|
||||
Can serialize a dictionary.
|
||||
Deserializes data to a TupleDataclass.
|
||||
|
||||
Example:
|
||||
{"a": 1, "b": 2} => [1,2]
|
||||
"""
|
||||
|
||||
# Value present only in constructor.
|
||||
# We don't want to mutate the serializers received in constructor.
|
||||
input_serializers: InitVar[OrderedDict[str, CairoDataSerializer]]
|
||||
|
||||
serializers: OrderedDict[str, CairoDataSerializer] = field(init=False)
|
||||
|
||||
def __post_init__(self, input_serializers):
|
||||
"""
|
||||
ABI adds ARG_len for every argument ARG that is an array. We parse length as a part of ArraySerializer, so we
|
||||
need to remove those lengths from args.
|
||||
"""
|
||||
self.serializers = _OrderedDict(
|
||||
(key, serializer)
|
||||
for key, serializer in input_serializers.items()
|
||||
if not self._is_len_arg(key, input_serializers)
|
||||
)
|
||||
|
||||
def deserialize_with_context(
|
||||
self, context: DeserializationContext
|
||||
) -> TupleDataclass:
|
||||
as_dictionary = deserialize_to_dict(self.serializers, context)
|
||||
return TupleDataclass.from_dict(as_dictionary)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Dict
|
||||
) -> Generator[int, None, None]:
|
||||
yield from serialize_from_dict(self.serializers, context, value)
|
||||
|
||||
@staticmethod
|
||||
def _is_len_arg(arg_name: str, serializers: Dict[str, CairoDataSerializer]) -> bool:
|
||||
return (
|
||||
arg_name.endswith(SIZE_SUFFIX)
|
||||
and isinstance(serializers[arg_name], FeltSerializer)
|
||||
# There is an ArraySerializer under key that is arg_name without the size suffix
|
||||
and isinstance(
|
||||
serializers.get(arg_name[:-SIZE_SUFFIX_LEN]), ArraySerializer
|
||||
)
|
||||
)
|
||||
@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Generator, OrderedDict
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ._common import (
|
||||
deserialize_to_dict,
|
||||
serialize_from_dict,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StructSerializer(CairoDataSerializer[Dict, Dict]):
|
||||
"""
|
||||
Serializer of custom structures.
|
||||
Can serialize a dictionary.
|
||||
Deserializes data to a dictionary.
|
||||
|
||||
Example:
|
||||
{"a": 1, "b": 2} => [1,2]
|
||||
"""
|
||||
|
||||
serializers: OrderedDict[str, CairoDataSerializer]
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> Dict:
|
||||
return deserialize_to_dict(self.serializers, context)
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Dict
|
||||
) -> Generator[int, None, None]:
|
||||
yield from serialize_from_dict(self.serializers, context, value)
|
||||
@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator, Iterable, List, Tuple
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from ._common import (
|
||||
deserialize_to_list,
|
||||
serialize_from_list,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TupleSerializer(CairoDataSerializer[Iterable, Tuple]):
|
||||
"""
|
||||
Serializer for tuples without named fields.
|
||||
Can serialize any iterable.
|
||||
Deserializes data to a python tuple.
|
||||
|
||||
Example:
|
||||
(1,2,(3,4)) => [1,2,3,4]
|
||||
"""
|
||||
|
||||
serializers: List[CairoDataSerializer]
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> Tuple:
|
||||
return tuple(deserialize_to_list(self.serializers, context))
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Iterable
|
||||
) -> Generator[int, None, None]:
|
||||
yield from serialize_from_list(self.serializers, context, [*value])
|
||||
@ -0,0 +1,76 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator, TypedDict, Union
|
||||
|
||||
from ...cairo.felt import uint256_range_check
|
||||
from .._context import (
|
||||
Context,
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
U128_UPPER_BOUND = 2**128
|
||||
|
||||
|
||||
class Uint256Dict(TypedDict):
|
||||
low: int
|
||||
high: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Uint256Serializer(CairoDataSerializer[Union[int, Uint256Dict], int]):
|
||||
"""
|
||||
Serializer of Uint256. In Cairo it is represented by structure {low: Uint128, high: Uint128}.
|
||||
Can serialize an int.
|
||||
Deserializes data to an int.
|
||||
|
||||
Examples:
|
||||
0 => [0,0]
|
||||
1 => [1,0]
|
||||
2**128 => [0,1]
|
||||
3 + 2**128 => [3,1]
|
||||
"""
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> int:
|
||||
[low, high] = context.reader.read(2)
|
||||
|
||||
# Checking if resulting value is in [0, 2**256) range is not enough. Uint256 should be made of two uint128.
|
||||
with context.push_entity("low"):
|
||||
self._ensure_valid_uint128(low, context)
|
||||
with context.push_entity("high"):
|
||||
self._ensure_valid_uint128(high, context)
|
||||
|
||||
return (high << 128) + low
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Union[int, Uint256Dict]
|
||||
) -> Generator[int, None, None]:
|
||||
context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict")
|
||||
if isinstance(value, int):
|
||||
yield from self._serialize_from_int(value)
|
||||
else:
|
||||
yield from self._serialize_from_dict(context, value)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_from_int(value: int) -> Generator[int, None, None]:
|
||||
uint256_range_check(value)
|
||||
result = (value % 2**128, value // 2**128)
|
||||
yield from result
|
||||
|
||||
def _serialize_from_dict(
|
||||
self, context: SerializationContext, value: Uint256Dict
|
||||
) -> Generator[int, None, None]:
|
||||
with context.push_entity("low"):
|
||||
self._ensure_valid_uint128(value["low"], context)
|
||||
yield value["low"]
|
||||
with context.push_entity("high"):
|
||||
self._ensure_valid_uint128(value["high"], context)
|
||||
yield value["high"]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_valid_uint128(value: int, context: Context):
|
||||
context.ensure_valid_value(
|
||||
0 <= value < U128_UPPER_BOUND, "expected value in range [0;2**128)"
|
||||
)
|
||||
@ -0,0 +1,100 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generator, TypedDict, Union
|
||||
|
||||
from ...cairo.felt import uint256_range_check
|
||||
from .._context import (
|
||||
Context,
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
class Uint256Dict(TypedDict):
|
||||
low: int
|
||||
high: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UintSerializer(CairoDataSerializer[Union[int, Uint256Dict], int]):
|
||||
"""
|
||||
Serializer of uint. In Cairo there are few uints (u8, ..., u128 and u256).
|
||||
u256 is represented by structure {low: u128, high: u128}.
|
||||
Can serialize an int and dict.
|
||||
Deserializes data to an int.
|
||||
|
||||
Examples:
|
||||
if bits < 256:
|
||||
0 => [0]
|
||||
1 => [1]
|
||||
2**128-1 => [2**128-1]
|
||||
else:
|
||||
0 => [0,0]
|
||||
1 => [1,0]
|
||||
2**128 => [0,1]
|
||||
3 + 2**128 => [3,1]
|
||||
"""
|
||||
|
||||
bits: int
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> int:
|
||||
if self.bits < 256:
|
||||
(uint,) = context.reader.read(1)
|
||||
with context.push_entity("uint" + str(self.bits)):
|
||||
self._ensure_valid_uint(uint, context, self.bits)
|
||||
|
||||
return uint
|
||||
|
||||
[low, high] = context.reader.read(2)
|
||||
|
||||
# Checking if resulting value is in [0, 2**256) range is not enough. Uint256 should be made of two uint128.
|
||||
with context.push_entity("low"):
|
||||
self._ensure_valid_uint(low, context, bits=128)
|
||||
with context.push_entity("high"):
|
||||
self._ensure_valid_uint(high, context, bits=128)
|
||||
|
||||
return (high << 128) + low
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Union[int, Uint256Dict]
|
||||
) -> Generator[int, None, None]:
|
||||
context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict")
|
||||
if isinstance(value, int):
|
||||
yield from self._serialize_from_int(value, context, self.bits)
|
||||
else:
|
||||
yield from self._serialize_from_dict(context, value)
|
||||
|
||||
@staticmethod
|
||||
def _serialize_from_int(
|
||||
value: int, context: SerializationContext, bits: int
|
||||
) -> Generator[int, None, None]:
|
||||
if bits < 256:
|
||||
UintSerializer._ensure_valid_uint(value, context, bits)
|
||||
|
||||
yield value
|
||||
else:
|
||||
uint256_range_check(value)
|
||||
|
||||
result = (value % 2**128, value >> 128)
|
||||
yield from result
|
||||
|
||||
def _serialize_from_dict(
|
||||
self, context: SerializationContext, value: Uint256Dict
|
||||
) -> Generator[int, None, None]:
|
||||
with context.push_entity("low"):
|
||||
self._ensure_valid_uint(value["low"], context, bits=128)
|
||||
yield value["low"]
|
||||
with context.push_entity("high"):
|
||||
self._ensure_valid_uint(value["high"], context, bits=128)
|
||||
yield value["high"]
|
||||
|
||||
@staticmethod
|
||||
def _ensure_valid_uint(value: int, context: Context, bits: int):
|
||||
"""
|
||||
Ensures that value is a valid uint on `bits` bits.
|
||||
"""
|
||||
context.ensure_valid_value(
|
||||
0 <= value < 2**bits, "expected value in range [0;2**" + str(bits) + ")"
|
||||
)
|
||||
@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
from .._context import (
|
||||
DeserializationContext,
|
||||
SerializationContext,
|
||||
)
|
||||
from .cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UnitSerializer(CairoDataSerializer[None, None]):
|
||||
"""
|
||||
Serializer for unit type.
|
||||
Can only serialize None.
|
||||
Deserializes data to None.
|
||||
|
||||
Example:
|
||||
[] => None
|
||||
"""
|
||||
|
||||
def deserialize_with_context(self, context: DeserializationContext) -> None:
|
||||
return None
|
||||
|
||||
def serialize_with_context(
|
||||
self, context: SerializationContext, value: Optional[Any]
|
||||
) -> Generator[None, None, None]:
|
||||
if value is not None:
|
||||
raise ValueError("Can only serialize `None`.")
|
||||
yield None
|
||||
@ -0,0 +1,10 @@
|
||||
class CairoSerializerException(Exception):
|
||||
"""Exception thrown by CairoSerializer."""
|
||||
|
||||
|
||||
class InvalidTypeException(CairoSerializerException, TypeError):
|
||||
"""Exception thrown when invalid type was provided."""
|
||||
|
||||
|
||||
class InvalidValueException(CairoSerializerException, ValueError):
|
||||
"""Exception thrown when invalid value was provided."""
|
||||
@ -0,0 +1,229 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Union
|
||||
|
||||
from ..abi.v0 import Abi as AbiV0
|
||||
from ..abi.v1 import Abi as AbiV1
|
||||
from ..abi.v2 import Abi as AbiV2
|
||||
from ..cairo.data_types import (
|
||||
ArrayType,
|
||||
BoolType,
|
||||
CairoType,
|
||||
EnumType,
|
||||
EventType,
|
||||
FeltType,
|
||||
NamedTupleType,
|
||||
OptionType,
|
||||
StructType,
|
||||
TupleType,
|
||||
UintType,
|
||||
UnitType,
|
||||
)
|
||||
from .data_serializers import (
|
||||
BoolSerializer,
|
||||
ByteArraySerializer,
|
||||
)
|
||||
from .data_serializers.array_serializer import ArraySerializer
|
||||
from .data_serializers.cairo_data_serializer import (
|
||||
CairoDataSerializer,
|
||||
)
|
||||
from .data_serializers.enum_serializer import EnumSerializer
|
||||
from .data_serializers.felt_serializer import FeltSerializer
|
||||
from .data_serializers.named_tuple_serializer import (
|
||||
NamedTupleSerializer,
|
||||
)
|
||||
from .data_serializers.option_serializer import (
|
||||
OptionSerializer,
|
||||
)
|
||||
from .data_serializers.output_serializer import (
|
||||
OutputSerializer,
|
||||
)
|
||||
from .data_serializers.payload_serializer import (
|
||||
PayloadSerializer,
|
||||
)
|
||||
from .data_serializers.struct_serializer import (
|
||||
StructSerializer,
|
||||
)
|
||||
from .data_serializers.tuple_serializer import TupleSerializer
|
||||
from .data_serializers.uint256_serializer import (
|
||||
Uint256Serializer,
|
||||
)
|
||||
from .data_serializers.uint_serializer import UintSerializer
|
||||
from .data_serializers.unit_serializer import UnitSerializer
|
||||
from .errors import InvalidTypeException
|
||||
from .function_serialization_adapter import (
|
||||
FunctionSerializationAdapter,
|
||||
FunctionSerializationAdapterV1,
|
||||
)
|
||||
|
||||
_uint256_type = StructType("Uint256", OrderedDict(low=FeltType(), high=FeltType()))
|
||||
_byte_array_type = StructType(
|
||||
"core::byte_array::ByteArray",
|
||||
OrderedDict(
|
||||
data=ArrayType(FeltType()),
|
||||
pending_word=FeltType(),
|
||||
pending_word_len=UintType(bits=32),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def serializer_for_type(cairo_type: CairoType) -> CairoDataSerializer:
|
||||
"""
|
||||
Create a serializer for cairo type.
|
||||
|
||||
:param cairo_type: CairoType.
|
||||
:return: CairoDataSerializer.
|
||||
"""
|
||||
# pylint: disable=too-many-return-statements, too-many-branches
|
||||
if isinstance(cairo_type, FeltType):
|
||||
return FeltSerializer()
|
||||
|
||||
if isinstance(cairo_type, BoolType):
|
||||
return BoolSerializer()
|
||||
|
||||
if isinstance(cairo_type, StructType):
|
||||
# Special case: Uint256 is represented as struct
|
||||
if cairo_type == _uint256_type:
|
||||
return Uint256Serializer()
|
||||
|
||||
if cairo_type == _byte_array_type:
|
||||
return ByteArraySerializer()
|
||||
|
||||
return StructSerializer(
|
||||
OrderedDict(
|
||||
(name, serializer_for_type(member_type))
|
||||
for name, member_type in cairo_type.types.items()
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(cairo_type, ArrayType):
|
||||
return ArraySerializer(serializer_for_type(cairo_type.inner_type))
|
||||
|
||||
if isinstance(cairo_type, TupleType):
|
||||
return TupleSerializer(
|
||||
[serializer_for_type(member) for member in cairo_type.types]
|
||||
)
|
||||
|
||||
if isinstance(cairo_type, NamedTupleType):
|
||||
return NamedTupleSerializer(
|
||||
OrderedDict(
|
||||
(name, serializer_for_type(member_type))
|
||||
for name, member_type in cairo_type.types.items()
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(cairo_type, UintType):
|
||||
return UintSerializer(bits=cairo_type.bits)
|
||||
|
||||
if isinstance(cairo_type, OptionType):
|
||||
return OptionSerializer(serializer_for_type(cairo_type.type))
|
||||
|
||||
if isinstance(cairo_type, UnitType):
|
||||
return UnitSerializer()
|
||||
|
||||
if isinstance(cairo_type, EnumType):
|
||||
return EnumSerializer(
|
||||
OrderedDict(
|
||||
(name, serializer_for_type(variant_type))
|
||||
for name, variant_type in cairo_type.variants.items()
|
||||
)
|
||||
)
|
||||
if isinstance(cairo_type, EventType):
|
||||
return serializer_for_payload(cairo_type.types)
|
||||
|
||||
raise InvalidTypeException(f"Received unknown Cairo type '{cairo_type}'.")
|
||||
|
||||
|
||||
# We don't want to require users to use OrderedDict. Regular python requires order since python 3.7.
|
||||
def serializer_for_payload(payload: Dict[str, CairoType]) -> PayloadSerializer:
|
||||
"""
|
||||
Create PayloadSerializer for types listed in a dictionary. Please note that the order of fields in the dict is
|
||||
very important. Make sure the keys are provided in the right order.
|
||||
|
||||
:param payload: dictionary with cairo types.
|
||||
:return: PayloadSerializer that can be used to (de)serialize events/function calls.
|
||||
"""
|
||||
return PayloadSerializer(
|
||||
OrderedDict(
|
||||
(name, serializer_for_type(cairo_type))
|
||||
for name, cairo_type in payload.items()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def serializer_for_outputs(payload: List[CairoType]) -> OutputSerializer:
|
||||
"""
|
||||
Create OutputSerializer for types in list. Please note that the order of fields in the list is
|
||||
very important. Make sure the types are provided in the right order.
|
||||
|
||||
:param payload: list with cairo types.
|
||||
:return: OutputSerializer that can be used to deserialize function outputs.
|
||||
"""
|
||||
return OutputSerializer(
|
||||
serializers=[serializer_for_type(cairo_type) for cairo_type in payload]
|
||||
)
|
||||
|
||||
|
||||
EventV0 = AbiV0.Event
|
||||
EventV1 = AbiV1.Event
|
||||
EventV2 = EventType
|
||||
|
||||
|
||||
def serializer_for_event(event: EventV0 | EventV1 | EventV2) -> PayloadSerializer:
|
||||
"""
|
||||
Create serializer for an event.
|
||||
|
||||
:param event: parsed event.
|
||||
:return: PayloadSerializer that can be used to (de)serialize events.
|
||||
"""
|
||||
if isinstance(event, EventV0):
|
||||
return serializer_for_payload(event.data)
|
||||
if isinstance(event, EventV1):
|
||||
return serializer_for_payload(event.inputs)
|
||||
return serializer_for_payload(event.types)
|
||||
|
||||
|
||||
def serializer_for_function(
|
||||
abi_function: AbiV0.Function,
|
||||
) -> FunctionSerializationAdapter:
|
||||
"""
|
||||
Create FunctionSerializationAdapter for serializing function inputs and deserializing function outputs.
|
||||
|
||||
:param abi_function: parsed function's abi.
|
||||
:return: FunctionSerializationAdapter.
|
||||
"""
|
||||
return FunctionSerializationAdapter(
|
||||
inputs_serializer=serializer_for_payload(abi_function.inputs),
|
||||
outputs_deserializer=serializer_for_payload(abi_function.outputs),
|
||||
)
|
||||
|
||||
|
||||
def serializer_for_function_v1(
|
||||
abi_function: Union[AbiV1.Function, AbiV2.Function],
|
||||
) -> FunctionSerializationAdapter:
|
||||
"""
|
||||
Create FunctionSerializationAdapter for serializing function inputs and deserializing function outputs.
|
||||
|
||||
:param abi_function: parsed function's abi.
|
||||
:return: FunctionSerializationAdapter.
|
||||
"""
|
||||
return FunctionSerializationAdapterV1(
|
||||
inputs_serializer=serializer_for_payload(abi_function.inputs),
|
||||
outputs_deserializer=serializer_for_outputs(abi_function.outputs),
|
||||
)
|
||||
|
||||
|
||||
def serializer_for_constructor_v2(
|
||||
abi_function: AbiV2.Constructor,
|
||||
) -> FunctionSerializationAdapter:
|
||||
"""
|
||||
Create FunctionSerializationAdapter for serializing constructor inputs.
|
||||
|
||||
:param abi_function: parsed constructor's abi.
|
||||
:return: FunctionSerializationAdapter.
|
||||
"""
|
||||
return FunctionSerializationAdapterV1(
|
||||
inputs_serializer=serializer_for_payload(abi_function.inputs),
|
||||
outputs_deserializer=serializer_for_outputs([]),
|
||||
)
|
||||
@ -0,0 +1,110 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Set, Tuple
|
||||
|
||||
from ..cairo.felt import CairoData
|
||||
from .data_serializers.output_serializer import (
|
||||
OutputSerializer,
|
||||
)
|
||||
from .data_serializers.payload_serializer import (
|
||||
PayloadSerializer,
|
||||
)
|
||||
from .errors import InvalidTypeException
|
||||
from .tuple_dataclass import TupleDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionSerializationAdapter:
|
||||
"""
|
||||
Class serializing ``*args`` and ``**kwargs`` by adapting them to function inputs.
|
||||
"""
|
||||
|
||||
inputs_serializer: PayloadSerializer
|
||||
outputs_deserializer: PayloadSerializer
|
||||
|
||||
expected_args: Tuple[str] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
self.expected_args = tuple(
|
||||
self.inputs_serializer.serializers.keys()
|
||||
) # pyright: ignore
|
||||
|
||||
def serialize(self, *args, **kwargs) -> CairoData:
|
||||
"""
|
||||
Method using args and kwargs to match members and serialize them separately.
|
||||
|
||||
:return: Members serialized separately in SerializedPayload.
|
||||
"""
|
||||
named_arguments = self._merge_arguments(args, kwargs)
|
||||
return self.inputs_serializer.serialize(named_arguments)
|
||||
|
||||
def deserialize(self, data: List[int]) -> TupleDataclass:
|
||||
"""
|
||||
Deserializes data into TupleDataclass containing python representations.
|
||||
|
||||
:return: cairo data.
|
||||
"""
|
||||
return self.outputs_deserializer.deserialize(data)
|
||||
|
||||
def _merge_arguments(self, args: Tuple, kwargs: Dict) -> Dict:
|
||||
"""
|
||||
Merges positional and keyed arguments.
|
||||
"""
|
||||
# After this line we know that len(args) <= len(self.expected_args)
|
||||
self._ensure_no_unnecessary_positional_args(args)
|
||||
|
||||
named_arguments = dict(kwargs)
|
||||
for arg, input_name in zip(args, self.expected_args):
|
||||
if input_name in kwargs:
|
||||
raise InvalidTypeException(
|
||||
f"Both positional and named argument provided for '{input_name}'."
|
||||
)
|
||||
named_arguments[input_name] = arg
|
||||
|
||||
expected_args = set(self.expected_args)
|
||||
provided_args = set(named_arguments.keys())
|
||||
|
||||
# named_arguments might have unnecessary arguments coming from kwargs (we ensure that
|
||||
# len(args) <= len(self.expected_args) above)
|
||||
self._ensure_no_unnecessary_args(expected_args, provided_args)
|
||||
|
||||
# there might be some argument missing (not provided)
|
||||
self._ensure_no_missing_args(expected_args, provided_args)
|
||||
|
||||
return named_arguments
|
||||
|
||||
def _ensure_no_unnecessary_positional_args(self, args: Tuple):
|
||||
if len(args) > len(self.expected_args):
|
||||
raise InvalidTypeException(
|
||||
f"Provided {len(args)} positional arguments, {len(self.expected_args)} possible."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_no_unnecessary_args(expected_args: Set[str], provided_args: Set[str]):
|
||||
excessive_arguments = provided_args - expected_args
|
||||
if excessive_arguments:
|
||||
raise InvalidTypeException(
|
||||
f"Unnecessary named arguments provided: '{', '.join(excessive_arguments)}'."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_no_missing_args(expected_args: Set[str], provided_args: Set[str]):
|
||||
missing_arguments = expected_args - provided_args
|
||||
if missing_arguments:
|
||||
raise InvalidTypeException(
|
||||
f"Missing arguments: '{', '.join(missing_arguments)}'."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionSerializationAdapterV1(FunctionSerializationAdapter):
|
||||
outputs_deserializer: OutputSerializer
|
||||
|
||||
def deserialize(self, data: List[int]) -> Tuple:
|
||||
"""
|
||||
Deserializes data into TupleDataclass containing python representations.
|
||||
|
||||
:return: cairo data.
|
||||
"""
|
||||
return self.outputs_deserializer.deserialize(data)
|
||||
@ -0,0 +1,59 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields, make_dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
|
||||
@dataclass(frozen=True, eq=False)
|
||||
class TupleDataclass:
|
||||
"""
|
||||
Dataclass that behaves like a tuple at the same time. Used when data has defined order and names.
|
||||
For instance in case of named tuples or function responses.
|
||||
"""
|
||||
|
||||
# getattr is called when attribute is not found in object. For instance when using object.unknown_attribute.
|
||||
# This way pyright will know that there might be some arguments it doesn't know about and will stop complaining
|
||||
# about some fields that don't exist statically.
|
||||
def __getattr__(self, item):
|
||||
# This should always fail - only attributes that don't exist end up in here.
|
||||
# We use __getattribute__ to get the native error.
|
||||
return super().__getattribute__(item)
|
||||
|
||||
def __getitem__(self, item: int):
|
||||
field = fields(self)[item]
|
||||
return getattr(self, field.name)
|
||||
|
||||
def __iter__(self):
|
||||
return (getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def as_tuple(self) -> Tuple:
|
||||
"""
|
||||
Creates a regular tuple from TupleDataclass.
|
||||
"""
|
||||
return tuple(self)
|
||||
|
||||
def as_dict(self) -> Dict:
|
||||
"""
|
||||
Creates a regular dict from TupleDataclass.
|
||||
"""
|
||||
return {field.name: getattr(self, field.name) for field in fields(self)}
|
||||
|
||||
# Added for backward compatibility with previous implementation based on NamedTuple
|
||||
def _asdict(self):
|
||||
return self.as_dict()
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, TupleDataclass):
|
||||
return self.as_tuple() == other.as_tuple()
|
||||
return self.as_tuple() == other
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: Dict, *, name: Optional[str] = None) -> TupleDataclass:
|
||||
result_class = make_dataclass(
|
||||
name or "TupleDataclass",
|
||||
fields=[(key, type(value)) for key, value in data.items()],
|
||||
bases=(TupleDataclass,),
|
||||
frozen=True,
|
||||
eq=False,
|
||||
)
|
||||
return result_class(**data)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user