Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor for Dialect. #99

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clvm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .SExp import SExp
from .dialect import Dialect # noqa
from .operators import ( # noqa
QUOTE_ATOM,
QUOTE_ATOM, # deprecated
KEYWORD_TO_ATOM,
KEYWORD_FROM_ATOM,
)
Expand Down
26 changes: 26 additions & 0 deletions clvm/chainable_multi_op_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass
from typing import Optional, Tuple

from .types import CLVMObjectType, MultiOpFn, OperatorDict


@dataclass
class ChainableMultiOpFn:
"""
This structure handles clvm operators. Given an atom, it looks it up in a `dict`, then
falls back to calling `unknown_op_handler`.
"""
op_lookup: OperatorDict
unknown_op_handler: MultiOpFn

def __call__(
self, op: bytes, arguments: CLVMObjectType, max_cost: Optional[int] = None
) -> Tuple[int, CLVMObjectType]:
f = self.op_lookup.get(op)
if f:
try:
return f(arguments)
except TypeError:
# some operators require `max_cost`
return f(arguments, max_cost)
return self.unknown_op_handler(op, arguments, max_cost)
37 changes: 37 additions & 0 deletions clvm/chia_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from .casts import int_to_bytes
from .dialect import ConversionFn, Dialect, new_dialect, opcode_table_for_backend

KEYWORDS = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the dialect responsible for compiling clvm code into byte-code too? if not, we could just put the full function names in this list, removing the need for OP_REWRITE and some of the logic in op_atom_to_imp_table

# core opcodes 0x01-x08
". q a i c f r l x "

# opcodes on atoms as strings 0x09-0x0f
"= >s sha256 substr strlen concat . "

# opcodes on atoms as ints 0x10-0x17
"+ - * / divmod > ash lsh "

# opcodes on atoms as vectors of bools 0x18-0x1c
"logand logior logxor lognot . "

# opcodes for bls 1381 0x1d-0x1f
"point_add pubkey_for_exp . "

# bool opcodes 0x20-0x23
"not any all . "

# misc 0x24
"softfork "
).split()

KEYWORD_FROM_ATOM = {int_to_bytes(k): v for k, v in enumerate(KEYWORDS)}
KEYWORD_TO_ATOM = {v: k for k, v in KEYWORD_FROM_ATOM.items()}


def chia_dialect(strict: bool, to_python: ConversionFn, backend=None) -> Dialect:
quote_kw = KEYWORD_TO_ATOM["q"]
apply_kw = KEYWORD_TO_ATOM["a"]
dialect = new_dialect(quote_kw, apply_kw, strict, to_python, backend=backend)
table = opcode_table_for_backend(KEYWORD_TO_ATOM, backend=backend)
dialect.update(table)
return dialect
168 changes: 168 additions & 0 deletions clvm/dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from typing import Callable, Optional, Tuple

try:
import clvm_rs
except ImportError:
clvm_rs = None

from . import core_ops, more_ops
from .chainable_multi_op_fn import ChainableMultiOpFn
from .handle_unknown_op import (
handle_unknown_op_softfork_ready,
handle_unknown_op_strict,
)
from .run_program import _run_program
from .types import CLVMObjectType, ConversionFn, MultiOpFn, OperatorDict


OP_REWRITE = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems like it also would belong in the chia-specific dialect, no?

"+": "add",
"-": "subtract",
"*": "multiply",
"/": "div",
"i": "if",
"c": "cons",
"f": "first",
"r": "rest",
"l": "listp",
"x": "raise",
"=": "eq",
">": "gr",
">s": "gr_bytes",
}


def op_table_for_module(mod):

# python-implemented operators don't take `max_cost` and rust-implemented operators do
# So we make the `max_cost` operator optional with this trick
# TODO: have python-implemented ops also take `max_cost` and unify the API.

def elide_max_cost(f):
def inner_op(sexp, max_cost=None):
try:
return f(sexp, max_cost)
except TypeError:
return f(sexp)

return inner_op

return {
k: elide_max_cost(v) for k, v in mod.__dict__.items() if k.startswith("op_")
}


def op_imp_table_for_backend(backend):
if backend is None and clvm_rs:
backend = "native"

if backend == "native":
if clvm_rs is None:
raise RuntimeError("native backend not installed")
return clvm_rs.native_opcodes_dict()

table = {}
table.update(op_table_for_module(core_ops))
table.update(op_table_for_module(more_ops))
return table


def op_atom_to_imp_table(op_imp_table, keyword_to_atom, op_rewrite=OP_REWRITE):
op_atom_to_imp_table = {}
for op, bytecode in keyword_to_atom.items():
op_name = "op_%s" % op_rewrite.get(op, op)
op_f = op_imp_table.get(op_name)
if op_f:
op_atom_to_imp_table[bytecode] = op_f
return op_atom_to_imp_table


def opcode_table_for_backend(keyword_to_atom, backend):
op_imp_table = op_imp_table_for_backend(backend)
return op_atom_to_imp_table(op_imp_table, keyword_to_atom)


class Dialect:
def __init__(
self,
quote_kw: bytes,
apply_kw: bytes,
multi_op_fn: MultiOpFn,
to_python: ConversionFn,
):
self.quote_kw = quote_kw
self.apply_kw = apply_kw
self.opcode_lookup = dict()
self.multi_op_fn = ChainableMultiOpFn(self.opcode_lookup, multi_op_fn)
self.to_python = to_python
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought the convention these days was to declare members, with types and sometimes initial values in the class body, rather than just assigning them in the __init__() function.


def update(self, d: OperatorDict) -> None:
self.opcode_lookup.update(d)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

given that you have this update() function, why do you need to be able to chain the multi_op_fn with ChainableMultiOpFn? it looks more complex than it needs to be


def clear(self) -> None:
self.opcode_lookup.clear()

def run_program(
self,
program: CLVMObjectType,
env: CLVMObjectType,
max_cost: int,
pre_eval_f: Optional[
Callable[[CLVMObjectType, CLVMObjectType], Tuple[int, CLVMObjectType]]
] = None,
) -> Tuple[int, CLVMObjectType]:
cost, r = _run_program(
program,
env,
self.multi_op_fn,
self.quote_kw,
self.apply_kw,
max_cost,
pre_eval_f,
)
return cost, self.to_python(r)


def native_new_dialect(
quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn
) -> Dialect:
unknown_op_callback = (
clvm_rs.NATIVE_OP_UNKNOWN_STRICT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are these flags or magic-values exported by the rust bindings, to indicate it's a native rust function?

if strict
else clvm_rs.NATIVE_OP_UNKNOWN_NON_STRICT
)
dialect = clvm_rs.Dialect(
quote_kw,
apply_kw,
unknown_op_callback,
to_python=to_python,
)
return dialect


def python_new_dialect(
quote_kw: bytes, apply_kw: bytes, strict: bool, to_python: ConversionFn
) -> Dialect:
unknown_op_callback = (
handle_unknown_op_strict if strict else handle_unknown_op_softfork_ready
)
dialect = Dialect(
quote_kw,
apply_kw,
unknown_op_callback,
to_python=to_python,
)
return dialect


def new_dialect(
quote_kw: bytes,
apply_kw: bytes,
strict: bool,
to_python: ConversionFn,
backend=None,
):
if backend is None:
backend = "python" if clvm_rs is None else "native"
backend_f = native_new_dialect if backend == "native" else python_new_dialect
return backend_f(quote_kw, apply_kw, strict, to_python)
124 changes: 124 additions & 0 deletions clvm/handle_unknown_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import Tuple

from .CLVMObject import CLVMObject
from .EvalError import EvalError

from .costs import (
ARITH_BASE_COST,
ARITH_COST_PER_BYTE,
ARITH_COST_PER_ARG,
MUL_BASE_COST,
MUL_COST_PER_OP,
MUL_LINEAR_COST_PER_BYTE,
MUL_SQUARE_COST_PER_BYTE_DIVIDER,
CONCAT_BASE_COST,
CONCAT_COST_PER_ARG,
CONCAT_COST_PER_BYTE,
)


def handle_unknown_op_strict(op, arguments, _max_cost=None):
raise EvalError("unimplemented operator", arguments.to(op))


def args_len(op_name, args):
for arg in args.as_iter():
if arg.pair:
raise EvalError("%s requires int args" % op_name, arg)
yield len(arg.as_atom())


# unknown ops are reserved if they start with 0xffff
# otherwise, unknown ops are no-ops, but they have costs. The cost is computed
# like this:

# byte index (reverse):
# | 4 | 3 | 2 | 1 | 0 |
# +---+---+---+---+------------+
# | multiplier |XX | XXXXXX |
# +---+---+---+---+---+--------+
# ^ ^ ^
# | | + 6 bits ignored when computing cost
# cost_multiplier |
# + 2 bits
# cost_function

# 1 is always added to the multiplier before using it to multiply the cost, this
# is since cost may not be 0.

# cost_function is 2 bits and defines how cost is computed based on arguments:
# 0: constant, cost is 1 * (multiplier + 1)
# 1: computed like operator add, multiplied by (multiplier + 1)
# 2: computed like operator mul, multiplied by (multiplier + 1)
# 3: computed like operator concat, multiplied by (multiplier + 1)

# this means that unknown ops where cost_function is 1, 2, or 3, may still be
# fatal errors if the arguments passed are not atoms.


def handle_unknown_op_softfork_ready(
op: bytes, args: CLVMObject, max_cost: int
) -> Tuple[int, CLVMObject]:
# any opcode starting with ffff is reserved (i.e. fatal error)
# opcodes are not allowed to be empty
if len(op) == 0 or op[:2] == b"\xff\xff":
raise EvalError("reserved operator", args.to(op))

# all other unknown opcodes are no-ops
# the cost of the no-ops is determined by the opcode number, except the
# 6 least significant bits.

cost_function = (op[-1] & 0b11000000) >> 6
# the multiplier cannot be 0. it starts at 1

if len(op) > 5:
raise EvalError("invalid operator", args.to(op))

cost_multiplier = int.from_bytes(op[:-1], "big", signed=False) + 1

# 0 = constant
# 1 = like op_add/op_sub
# 2 = like op_multiply
# 3 = like op_concat
if cost_function == 0:
cost = 1
elif cost_function == 1:
# like op_add
cost = ARITH_BASE_COST
arg_size = 0
for length in args_len("unknown op", args):
arg_size += length
cost += ARITH_COST_PER_ARG
cost += arg_size * ARITH_COST_PER_BYTE
elif cost_function == 2:
# like op_multiply
cost = MUL_BASE_COST
operands = args_len("unknown op", args)
try:
vs = next(operands)
for rs in operands:
cost += MUL_COST_PER_OP
cost += (rs + vs) * MUL_LINEAR_COST_PER_BYTE
cost += (rs * vs) // MUL_SQUARE_COST_PER_BYTE_DIVIDER
# this is an estimate, since we don't want to actually multiply the
# values
vs += rs
except StopIteration:
pass

elif cost_function == 3:
# like concat
cost = CONCAT_BASE_COST
length = 0
for arg in args.as_iter():
if arg.pair:
raise EvalError("unknown op on list", arg)
cost += CONCAT_COST_PER_ARG
length += len(arg.atom)
cost += length * CONCAT_COST_PER_BYTE

cost *= cost_multiplier
if cost >= 2**32:
raise EvalError("invalid operator", args.to(op))

return (cost, args.to(b""))
Loading