Module opshin.compiler

Expand source code
import ast
import copy
import typing
from ast import Load, Name, Constant, Slice

import pluthon as plt
import uplc.ast as uplc
from pycardano import PlutusData
from uplc.ast import data_from_cbor

from .bridge import to_uplc_builtin
from .prelude import Nothing
from .type_impls import (
    InstanceType,
    UnionType,
    UnitType,
    RecordType,
    transform_ext_params_map,
    AnyType,
    transform_output_map,
    ClassType,
    PolymorphicFunctionInstanceType,
    ListType,
    TupleType,
    PairType,
    IntegerInstanceType,
    empty_list,
    DictType,
    ByteStringType,
    FunctionType,
    OUnit,
)
from .type_inference import map_to_orig_name, AggressiveTypeInferencer
from .typed_ast import *

from .compiler_config import DEFAULT_CONFIG
from .optimize.optimize_const_folding import OptimizeConstantFolding
from .optimize.optimize_remove_deadconstants import OptimizeRemoveDeadconstants
from .optimize.optimize_union_expansion import OptimizeUnionExpansion

from .rewrite.rewrite_assert_none import RewriteAssertNone
from .rewrite.rewrite_augassign import RewriteAugAssign
from .rewrite.rewrite_cast_condition import RewriteConditions
from .rewrite.rewrite_comparison_chaining import RewriteComparisonChaining
from .rewrite.rewrite_empty_dicts import RewriteEmptyDicts
from .rewrite.rewrite_empty_lists import RewriteEmptyLists
from .rewrite.rewrite_forbidden_overwrites import RewriteForbiddenOverwrites
from .rewrite.rewrite_forbidden_return import RewriteForbiddenReturn
from .rewrite.rewrite_import import RewriteImport
from .rewrite.rewrite_import_dataclasses import RewriteImportDataclasses
from .rewrite.rewrite_import_hashlib import RewriteImportHashlib
from .rewrite.rewrite_import_integrity_check import RewriteImportIntegrityCheck
from .rewrite.rewrite_import_plutusdata import RewriteImportPlutusData
from .rewrite.rewrite_import_typing import RewriteImportTyping
from .rewrite.rewrite_import_uplc_builtins import RewriteImportUPLCBuiltins
from .rewrite.rewrite_inject_builtins import RewriteInjectBuiltins
from .rewrite.rewrite_orig_name import RewriteOrigName
from .rewrite.rewrite_remove_type_stuff import RewriteRemoveTypeStuff
from .rewrite.rewrite_scoping import RewriteScoping
from .rewrite.rewrite_subscript38 import RewriteSubscript38
from .rewrite.rewrite_tuple_assign import RewriteTupleAssign
from .optimize.optimize_remove_pass import OptimizeRemovePass
from .optimize.optimize_remove_deadvars import OptimizeRemoveDeadvars, NameLoadCollector
from .util import (
    CompilingNodeTransformer,
    NoOp,
    OVar,
    OLambda,
    OLet,
    OPSHIN_LOGGER,
    all_vars,
    SafeOLambda,
    opshin_name_scheme_compatible_varname,
    force_params,
    SafeApply,
    SafeLambda,
    written_vars,
    custom_fix_missing_locations,
)


# Maps the Python boolean operator AST node types to the pluthon
# constructors that implement them (looked up by type in visit_BoolOp)
BoolOpMap = {
    ast.And: plt.And,
    ast.Or: plt.Or,
}


def rec_constant_map_data(c):
    """
    Recursively converts a Python constant into its uplc PlutusData representation

    Booleans and integers become PlutusInteger, None becomes constructor 0 without fields,
    bytes and (encoded) strings become PlutusByteString, lists and dicts are converted
    element-wise. PlutusData objects are round-tripped through their CBOR encoding.
    Raises NotImplementedError for any other type.
    """
    # bool needs to be handled before int, as every bool is also an int
    if isinstance(c, bool):
        return uplc.PlutusInteger(int(c))
    if isinstance(c, int):
        return uplc.PlutusInteger(c)
    if c is None:
        return uplc.PlutusConstr(0, [])
    if isinstance(c, bytes):
        return uplc.PlutusByteString(c)
    if isinstance(c, str):
        return uplc.PlutusByteString(c.encode())
    if isinstance(c, list):
        return uplc.PlutusList(list(map(rec_constant_map_data, c)))
    if isinstance(c, dict):
        return uplc.PlutusMap(
            {
                rec_constant_map_data(key): rec_constant_map_data(value)
                for key, value in c.items()
            }
        )
    # This can occur when PlutusData is generated during constant folding
    if isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")


def rec_constant_map(c):
    """
    Recursively converts a Python constant into a uplc builtin constant

    Booleans, integers, None, bytes and strings map to the corresponding uplc builtin,
    lists are converted element-wise and dicts become a list of pairs whose
    keys and values are converted with rec_constant_map_data.
    PlutusData objects are round-tripped through their CBOR encoding.
    Raises NotImplementedError for any other type.
    """
    # bool needs to be handled before int, as every bool is also an int
    if isinstance(c, bool):
        return uplc.BuiltinBool(c)
    if isinstance(c, int):
        return uplc.BuiltinInteger(c)
    if c is None:
        return uplc.BuiltinUnit()
    if isinstance(c, bytes):
        return uplc.BuiltinByteString(c)
    if isinstance(c, str):
        return uplc.BuiltinString(c)
    if isinstance(c, list):
        return uplc.BuiltinList(list(map(rec_constant_map, c)))
    if isinstance(c, dict):
        pairs = [
            uplc.BuiltinPair(
                rec_constant_map_data(key),
                rec_constant_map_data(value),
            )
            for key, value in c.items()
        ]
        return uplc.BuiltinList(pairs)
    # This can occur when PlutusData is generated during constant folding
    if isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")


def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
    """
    Wraps the validator function to enable a double function as minting script

    pass_through defines how many parameters x would normally take and should be passed through to x

    The wrapper takes the passed through parameters followed by two arguments a0 and a1.
    """
    passed_names = [f"v{i}" for i in range(pass_through)]
    # x with all passed through parameters already applied
    partially_applied = plt.Apply(x, *(OVar(name) for name in passed_names))
    # if the second argument has constructor 0 = script context
    second_is_script_context = plt.DelayedChooseData(
        OVar("a1"),
        plt.EqualsInteger(plt.Constructor(OVar("a1")), plt.Integer(0)),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
    )
    # call the validator with a0, a1, and plug in "Nothing" for data
    call_without_datum = plt.Apply(
        OVar("p"),
        plt.UPLCConstant(to_uplc_builtin(Nothing())),
        OVar("a0"),
        OVar("a1"),
    )
    # else call the validator with a0, a1 and return (now partially bound)
    call_partially_bound = plt.Apply(OVar("p"), OVar("a0"), OVar("a1"))
    return OLambda(
        passed_names + ["a0", "a1"],
        OLet(
            [("p", partially_applied)],
            plt.Ite(
                second_is_script_context,
                call_without_datum,
                call_partially_bound,
            ),
        ),
    )


# A compiled statement: a function that receives the term to continue with
# and returns the term for the statement followed by that continuation
# (see visit_sequence, which chains these from the last statement to the first)
CallAST = typing.Callable[[plt.AST], plt.AST]


class PlutoCompiler(CompilingNodeTransformer):
    """
    Expects a TypedAST and returns UPLC/Pluto like code
    """

    step = "Compiling python statements to UPLC"

    def __init__(
        self,
        force_three_params=False,
        validator_function_name="validator",
        config=DEFAULT_CONFIG,
    ):
        """
        Stores the compilation parameters and initializes the compilation state

        force_three_params and validator_function_name control how visit_Module wraps the module,
        config needs to provide fast_access_skip which has to be None or greater than 1
        """
        # parameters
        self.force_three_params = force_three_params
        self.validator_function_name = validator_function_name
        self.config = config
        skip = self.config.fast_access_skip
        assert (
            skip is None or skip > 1
        ), "Parameter fast-access-skip needs to be greater than 1 or omitted"
        # marked knowledge during compilation
        # stack of the types of the functions that are currently being compiled
        self.current_function_typ: typing.List[FunctionType] = []

    def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
        """
        Compiles a list of statements into a single function of the continuation

        The statements are only visited once the returned function is applied.
        They are visited from last to first, each one wrapping the term built so far.
        """

        def apply_statements(continuation: plt.AST):
            term = continuation
            for stmt in reversed(node_seq):
                term = self.visit(stmt)(term)
            return term

        return apply_statements

    def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
        """
        Compiles a binary operation by applying the implementation
        provided by the type of the left operand to both operands
        """
        operator_impl = node.left.typ.binop(node.op, node.right)
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        return plt.Apply(operator_impl, lhs, rhs)

    def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
        """
        Compiles a chain of `and` / `or` into nested applications of the mapped operator

        The values are combined from left to right, i.e. `a and b and c` becomes And(And(a, b), c).
        Raises NotImplementedError if the operator has no entry in BoolOpMap.
        """
        op = BoolOpMap.get(type(node.op))
        if op is None:
            # fail with a clear message instead of "'NoneType' object is not callable"
            raise NotImplementedError(
                f"Boolean operator {type(node.op).__name__} is not supported"
            )
        assert len(node.values) >= 2, "Need to compare at least two values"
        ops = op(
            self.visit(node.values[0]),
            self.visit(node.values[1]),
        )
        for v in node.values[2:]:
            ops = op(ops, self.visit(v))
        return ops

    def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
        """
        Compiles a unary operation by applying the implementation
        provided by the type of the operand to the operand
        """
        operator_impl = node.operand.typ.unop(node.op)
        operand = self.visit(node.operand)
        return plt.Apply(operator_impl, operand)

    def visit_Compare(self, node: TypedCompare) -> plt.AST:
        """
        Compiles a comparison of exactly two values

        The implementation is provided by the type of the left operand,
        based on the comparison operator and the type of the right operand.
        """
        assert len(node.ops) == 1, "Only single comparisons are supported"
        assert len(node.comparators) == 1, "Only single comparisons are supported"
        right = node.comparators[0]
        compare_impl = node.left.typ.cmp(node.ops[0], right.typ)
        lhs = self.visit(node.left)
        rhs = self.visit(right)
        return plt.Apply(compare_impl, lhs, rhs)

    def visit_Module(self, node: TypedModule) -> plt.AST:
        """
        Compiles a typed module into a plt.Program with version (1, 0, 0)

        If no validator function name is configured, only the statements of the module are compiled.
        Otherwise a call to the validator function is appended to the module body and the result
        is wrapped in a lambda that takes one parameter per argument of the validator.
        If force_three_params is set, the validator is additionally wrapped to double function
        as minting and spending script. A RuntimeError is raised if this is not possible.
        """

        def name_error_bindings():
            # write all variables that are ever read
            # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
            name_load_visitor = NameLoadCollector()
            name_load_visitor.visit(node)
            all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))
            return [
                (
                    x,
                    plt.Delay(plt.TraceError(f"NameError: {map_to_orig_name(x)}")),
                )
                for x in all_vs
            ]

        if self.validator_function_name is None:
            # there is no validator to call, only compile the statements of the module
            validator = plt.Let(
                name_error_bindings(),
                self.visit_sequence(node.body)(OUnit),
            )
            return plt.Program((1, 0, 0), validator)

        # for validators find main function
        # TODO can use more sophisiticated procedure here i.e. functions marked by comment
        main_fun: typing.Optional[ast.FunctionDef] = None
        for s in node.body:
            if (
                isinstance(s, ast.FunctionDef)
                and s.orig_name == self.validator_function_name
            ):
                # no early exit, the last matching definition is used
                main_fun = s
        assert (
            main_fun is not None
        ), f"Could not find function named {self.validator_function_name}"
        main_fun_typ: FunctionType = main_fun.typ.typ
        assert isinstance(
            main_fun_typ, FunctionType
        ), f"Variable named {self.validator_function_name} is not of type function"

        # check if this is a contract written to double function
        enable_double_func_mint_spend = False
        if len(main_fun_typ.argtyps) >= 3 and self.force_three_params:
            # check if is possible
            second_last_arg = main_fun_typ.argtyps[-2]
            assert isinstance(
                second_last_arg, InstanceType
            ), "Can not pass Class into validator"
            if isinstance(second_last_arg.typ, UnionType):
                possible_types = second_last_arg.typ.typs
            else:
                possible_types = [second_last_arg.typ]
            if any(isinstance(t, UnitType) for t in possible_types):
                OPSHIN_LOGGER.warning(
                    "The redeemer is annotated to be 'None'. This value is usually encoded in PlutusData with constructor id 0 and no fields. If you want the script to double function as minting and spending script, annotate the second argument with 'NoRedeemer'."
                )
            # the wrapper distinguishes the cases by constructor id 0,
            # so no possible type of this argument may use that constructor id
            enable_double_func_mint_spend = not any(
                (isinstance(t, RecordType) and t.record.constructor == 0)
                or isinstance(t, UnitType)
                for t in possible_types
            )
            if not enable_double_func_mint_spend:
                OPSHIN_LOGGER.warning(
                    "The second argument to the validator function potentially has constructor id 0. The validator will not be able to double function as minting script and spending script."
                )

        # append the call to the validator with the parameters of the wrapping lambda
        body = node.body + (
            [
                TypedReturn(
                    TypedCall(
                        func=ast.Name(
                            id=main_fun.name,
                            typ=InstanceType(main_fun_typ),
                            ctx=ast.Load(),
                        ),
                        typ=main_fun_typ.rettyp,
                        args=[
                            RawPlutoExpr(
                                expr=transform_ext_params_map(a)(
                                    OVar(f"val_param{i}")
                                ),
                                typ=a,
                            )
                            for i, a in enumerate(main_fun_typ.argtyps)
                        ],
                    )
                )
            ]
        )
        self.current_function_typ.append(FunctionType([], InstanceType(AnyType())))
        try:
            validator = SafeOLambda(
                [f"val_param{i}" for i, _ in enumerate(main_fun_typ.argtyps)],
                plt.Let(
                    name_error_bindings(),
                    self.visit_sequence(body)(OUnit),
                ),
            )
        finally:
            # keep the stack balanced even if compiling the body fails
            self.current_function_typ.pop()
        if enable_double_func_mint_spend:
            validator = wrap_validator_double_function(
                validator, pass_through=len(main_fun_typ.argtyps) - 3
            )
        elif self.force_three_params:
            # Error if the double function is enforced but not possible
            raise RuntimeError(
                "The contract can not always detect if it was passed three or two parameters on-chain."
            )

        cp = plt.Program((1, 0, 0), validator)
        return cp

    def visit_Constant(self, node: Constant) -> plt.AST:
        """
        Compiles a constant into a UPLC constant

        Logs a warning for non-empty bytes constants whose content can be parsed as hex,
        as these are likely intended to be hex-encoded bytestrings.
        """
        value = node.value
        if isinstance(value, bytes) and value != b"":
            looks_like_hex = True
            try:
                bytes.fromhex(value.decode())
            except ValueError:
                # not decodable or not hex, nothing to warn about
                looks_like_hex = False
            if looks_like_hex:
                OPSHIN_LOGGER.warning(
                    f"The string {value} looks like it is supposed to be a hex-encoded bytestring but is actually utf8-encoded. Try using `bytes.fromhex('{value.decode()}')` instead."
                )
        return plt.UPLCConstant(rec_constant_map(value))

    def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
        """
        Returns the unit value, the argument is ignored
        """
        unit = plt.Unit()
        return unit

    def visit_Assign(self, node: TypedAssign) -> CallAST:
        """
        Compiles the assignment of a value to a single name

        The value is bound under the opshin-internal variable name
        and the name itself is bound to a delayed access to that variable.
        """
        assert (
            len(node.targets) == 1
        ), "Assignments to more than one variable not supported yet"
        target = node.targets[0]
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        compiled_e = self.visit(node.value)
        varname = target.id
        if getattr(target, "is_wrapped", False) or (
            isinstance(target.typ, InstanceType)
            and isinstance(target.typ.typ, (AnyType, UnionType))
        ):
            # if this is a wrapped variable or Union/Any, we need to map it back to the external parameter type
            # TODO this is terribly inefficient. we would rather want to cast once when entering the body and cast back when leaving
            compiled_e = transform_output_map(node.value.typ)(compiled_e)

        def bind(continuation):
            # first evaluate the term, then wrap in a delay
            return plt.Let(
                [
                    (opshin_name_scheme_compatible_varname(varname), compiled_e),
                    (varname, plt.Delay(OVar(varname))),
                ],
                continuation,
            )

        return bind

    def visit_AnnAssign(self, node: TypedAnnAssign) -> CallAST:
        """
        Compiles an annotated assignment of a value to a name

        Values of type Any/Union are mapped with transform_ext_params_map of the target type,
        and values assigned to targets of type Any/Union are mapped with transform_output_map.
        The value is bound like in a plain assignment.
        """
        target = node.target
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        assert isinstance(
            target.typ, InstanceType
        ), "Can only assign instances to instances"
        val = self.visit(node.value)
        data_like = (AnyType, UnionType)
        value_typ = node.value.typ
        if isinstance(value_typ, InstanceType) and isinstance(
            value_typ.typ, data_like
        ):
            # we need to map this as it will originate from PlutusData
            # AnyType is the only type other than the builtin itself that can be cast to builtin values
            val = transform_ext_params_map(target.typ)(val)
        if isinstance(target.typ, InstanceType) and isinstance(
            target.typ.typ, data_like
        ):
            # we need to map this back as it will be treated as PlutusData
            # AnyType is the only type other than the builtin itself that can be cast to from builtin values
            val = transform_output_map(value_typ)(val)

        def bind(continuation):
            return plt.Let(
                [
                    (opshin_name_scheme_compatible_varname(target.id), val),
                    (target.id, plt.Delay(OVar(target.id))),
                ],
                continuation,
            )

        return bind

    def visit_Name(self, node: Name) -> plt.AST:
        """
        Compiles the loading of a name

        Names of classes compile to their constructor, all other names to forcing the
        variable of that name. Only the Load context is supported.
        """
        if not isinstance(node.ctx, Load):
            raise NotImplementedError(f"Context {node.ctx} not supported")
        if isinstance(node.typ, ClassType):
            # if this is not an instance but a class, call the constructor
            return node.typ.constr()
        value = plt.Force(plt.Var(node.id))
        if getattr(node, "is_wrapped", False):
            # wrapped variables are mapped with transform_ext_params_map on every access
            value = transform_ext_params_map(node.typ)(value)
        return value

    def visit_Expr(self, node: TypedExpr) -> CallAST:
        """
        Compiles an expression statement, i.e. an expression whose value is discarded
        """

        def evaluate_and_continue(continuation):
            # we exploit UPLCs eager evaluation here
            # the expression is computed even though its value is eventually discarded
            # Note this really only makes sense for Trace
            # we use an invalid name here to avoid conflicts
            ignore_value = OLambda(["0"], continuation)
            return plt.Apply(ignore_value, self.visit(node.value))

        return evaluate_and_continue

    def visit_Call(self, node: TypedCall) -> plt.AST:
        """
        Compiles the call of a function

        The arguments are bound to p0, p1, ... first and then passed on delayed,
        preceded by the bound self (if any) and the variables bound by the function.
        """
        func_typ = node.func.typ
        # TODO function is actually not of type polymorphic function type here anymore
        if isinstance(func_typ, PolymorphicFunctionInstanceType):
            # edge case for weird builtins that are polymorphic
            implementation = func_typ.polymorphic_function.impl_from_args(
                func_typ.typ.argtyps
            )
            func_plt = force_params(implementation)
            bind_self = None
        else:
            assert isinstance(func_typ, InstanceType) and isinstance(
                func_typ.typ, FunctionType
            ), "Can only call instances of functions"
            func_plt = self.visit(node.func)
            bind_self = func_typ.typ.bind_self
        bound_vs = sorted(func_typ.typ.bound_vars.keys())
        is_isinstance_call = getattr(node.func, "orig_id", None) == "isinstance"
        args = []
        for position, (arg, expected) in enumerate(
            zip(node.args, func_typ.typ.argtyps)
        ):
            # now impl_from_args has been chosen, skip type arg
            if is_isinstance_call and position == 1:
                continue
            assert isinstance(expected, InstanceType)
            # pass in all arguments evaluated with the statemonad
            compiled_arg = self.visit(arg)
            if isinstance(expected.typ, (AnyType, UnionType)):
                # if the function expects input of generic type data, wrap data before passing it inside
                compiled_arg = transform_output_map(arg.typ)(compiled_arg)
            args.append(compiled_arg)
        # First assign to let to ensure that the arguments are evaluated before the call, but need to delay
        # as this is a variable assignment
        # Also bring all states of variables read inside the function into scope / update with value in current state
        # before call to simulate statemonad with current state being passed in
        self_args = [] if bind_self is None else [plt.Var(bind_self)]
        bound_args = [plt.Var(name) for name in bound_vs]
        delayed_params = [plt.Delay(OVar(f"p{i}")) for i, _ in enumerate(args)]
        return OLet(
            [(f"p{i}", a) for i, a in enumerate(args)],
            SafeApply(func_plt, *self_args, *bound_args, *delayed_params),
        )

    def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
        """
        Compiles a function definition

        The function becomes a delayed lambda bound to the name of the function.
        Its parameters are the bound self (if any), the sorted bound variables
        and the declared arguments, in this order.
        """
        body = node.body.copy()
        # defaults to returning None if there is no return statement
        if node.typ.typ.rettyp.typ == AnyType():
            ret_val = OUnit
        else:
            ret_val = plt.Unit()
        read_vs = sorted(node.typ.typ.bound_vars.keys())
        if node.typ.typ.bind_self is not None:
            read_vs.insert(0, node.typ.typ.bind_self)
        # the type of the function needs to be known while compiling its return statements
        self.current_function_typ.append(node.typ.typ)
        try:
            compiled_body = self.visit_sequence(body)(ret_val)
        finally:
            # keep the stack balanced even if compiling the body fails
            self.current_function_typ.pop()
        return lambda x: plt.Let(
            [
                (
                    node.name,
                    plt.Delay(
                        SafeLambda(
                            read_vs + [a.arg for a in node.args.args],
                            compiled_body,
                        )
                    ),
                )
            ],
            x,
        )

    def visit_While(self, node: TypedWhile) -> CallAST:
        """
        Compiles a while loop into a function that is applied to itself

        An else block is compiled as statements following the loop.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            loop_only = copy.copy(node)
            loop_only.orelse = []
            return self.visit_sequence([loop_only, *node.orelse])
        compiled_c = self.visit(node.test)
        compiled_s = self.visit_sequence(node.body)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def loop_function(after_loop):
            # the while loop calls itself, updating the values at overwritten names
            # by overwriting them with arguments to its self-recall
            next_iteration = plt.Apply(
                OVar("while"),
                OVar("while"),
                *copy.deepcopy(pwritten_vs),
            )
            return plt.Lambda(
                [opshin_name_scheme_compatible_varname("while")] + written_vs,
                plt.Ite(
                    compiled_c,
                    compiled_s(next_iteration),
                    after_loop,
                ),
            )

        def compile_loop(continuation):
            # the continuation receives the values written by the loop as parameters
            adjusted_next = SafeLambda(written_vs, continuation)
            leave_loop = SafeApply(
                OVar("adjusted_next"), *copy.deepcopy(pwritten_vs)
            )
            return OLet(
                [
                    ("adjusted_next", adjusted_next),
                    ("while", loop_function(leave_loop)),
                ],
                plt.Apply(OVar("while"), OVar("while"), *copy.deepcopy(pwritten_vs)),
            )

        return compile_loop

    def visit_For(self, node: TypedFor) -> CallAST:
        """
        Compiles a for loop over a list into a function that is applied to itself

        An else block is compiled as statements following the loop.
        Raises NotImplementedError for iterables that are not lists.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            loop_only = copy.copy(node)
            loop_only.orelse = []
            return self.visit_sequence([loop_only, *node.orelse])
        assert isinstance(node.iter.typ, InstanceType)
        if not isinstance(node.iter.typ.typ, ListType):
            raise NotImplementedError(
                "Compilation of for statements for anything but lists not implemented yet"
            )
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        compiled_s = self.visit_sequence(node.body)
        compiled_iter = self.visit(node.iter)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def loop_function(after_loop):
            # recursive call on the remaining elements, passing on the written values
            next_iteration = plt.Apply(
                OVar("for"),
                OVar("for"),
                plt.TailList(OVar("iter")),
                *copy.deepcopy(pwritten_vs),
            )
            # bind the loop variable to the first remaining element for the body
            run_body = plt.Let(
                [(node.target.id, plt.Delay(plt.HeadList(OVar("iter"))))],
                compiled_s(next_iteration),
            )
            loop_params = [
                opshin_name_scheme_compatible_varname("for"),
                opshin_name_scheme_compatible_varname("iter"),
            ]
            return plt.Lambda(
                loop_params + written_vs,
                plt.IteNullList(
                    OVar("iter"),
                    after_loop,
                    run_body,
                ),
            )

        def compile_loop(continuation):
            # the continuation receives the loop variable and the written values as parameters
            adjusted_next = plt.Lambda([node.target.id] + written_vs, continuation)
            leave_loop = plt.Apply(
                OVar("adjusted_next"),
                plt.Var(node.target.id),
                *copy.deepcopy(pwritten_vs),
            )
            return OLet(
                [
                    ("adjusted_next", adjusted_next),
                    ("for", loop_function(leave_loop)),
                ],
                plt.Apply(
                    OVar("for"),
                    OVar("for"),
                    compiled_iter,
                    *copy.deepcopy(pwritten_vs),
                ),
            )

        return compile_loop

    def visit_If(self, node: TypedIf) -> CallAST:
        """
        Compiles an if statement

        Both branches end with a call to the continuation, which receives the
        variables written inside of the if statement as parameters.
        """
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def compile_branches(continuation):
            adjusted_next = SafeLambda(written_vs, continuation)
            condition = self.visit(node.test)
            then_branch = self.visit_sequence(node.body)(
                SafeApply(OVar("adjusted_next"), *copy.deepcopy(pwritten_vs))
            )
            else_branch = self.visit_sequence(node.orelse)(
                SafeApply(OVar("adjusted_next"), *copy.deepcopy(pwritten_vs))
            )
            return OLet(
                [("adjusted_next", adjusted_next)],
                plt.Ite(condition, then_branch, else_branch),
            )

        return compile_branches

    def visit_Return(self, node: TypedReturn) -> CallAST:
        """
        Compiles a return statement

        The continuation is discarded, the compiled value is the result.
        Values are mapped with transform_output_map if the enclosing function returns Any/Union.
        """
        value_plt = self.visit(node.value)
        assert self.current_function_typ, "Can not handle Return outside of a function"
        return_typ = self.current_function_typ[-1].rettyp.typ
        if isinstance(return_typ, (AnyType, UnionType)):
            value_plt = transform_output_map(node.value.typ)(value_plt)

        def discard_continuation(_):
            return value_plt

        return discard_continuation

    def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
        assert isinstance(
            node.value.typ, InstanceType
        ), "Can only access elements of instances, not classes"
        if isinstance(node.value.typ.typ, TupleType):
            assert isinstance(
                node.slice, Constant
            ), "Only constant index access for tuples is supported"
            assert isinstance(
                node.slice.value, int
            ), "Only constant index integer access for tuples is supported"
            index = node.slice.value
            if index < 0:
                index += len(node.value.typ.typ.typs)
            assert isinstance(node.ctx, Load), "Tuples are read-only"
            return plt.FunctionalTupleAccess(
                self.visit(node.value),
                index,
                len(node.value.typ.typ.typs),
            )
        if isinstance(node.value.typ.typ, PairType):
            assert isinstance(
                node.slice, Constant
            ), "Only constant index access for pairs is supported"
            assert isinstance(
                node.slice.value, int
            ), "Only constant index integer access for pairs is supported"
            index = node.slice.value
            if index < 0:
                index += 2
            assert isinstance(node.ctx, Load), "Pairs are read-only"
            assert (
                0 <= index < 2
            ), f"Pairs only have 2 elements, index should be 0 or 1, is {node.slice.value}"
            member_func = plt.FstPair if index == 0 else plt.SndPair
            # the content of pairs is always Data, so we need to unwrap
            member_typ = node.typ
            return transform_ext_params_map(member_typ)(
                member_func(
                    self.visit(node.value),
                ),
            )
        if isinstance(node.value.typ.typ, ListType):
            if not isinstance(node.slice, Slice):
                assert (
                    node.slice.typ == IntegerInstanceType
                ), "Only single element list index access supported"
                if isinstance(node.slice, Constant) and node.slice.value >= 0:
                    index = node.slice.value
                    return plt.ConstantIndexAccessListFast(
                        self.visit(node.value),
                        index,
                    )
                return OLet(
                    [
                        (
                            "l",
                            self.visit(node.value),
                        ),
                        (
                            "raw_i",
                            self.visit(node.slice),
                        ),
                        (
                            "i",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_i"), plt.LengthList(OVar("l"))
                                ),
                                OVar("raw_i"),
                            ),
                        ),
                    ],
                    (
                        plt.IndexAccessListFast(self.config.fast_access_skip)(
                            OVar("l"), OVar("i")
                        )
                        if self.config.fast_access_skip is not None
                        else plt.IndexAccessList(OVar("l"), OVar("i"))
                    ),
                )
            else:
                assert (
                    node.slice.upper is not None
                ), "Only slices with upper bound supported"
                assert (
                    node.slice.lower is not None
                ), "Only slices with lower bound supported"
                return OLet(
                    [
                        (
                            "xs",
                            self.visit(node.value),
                        ),
                        (
                            "raw_i",
                            self.visit(node.slice.lower),
                        ),
                        (
                            "i",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_i"),
                                    plt.LengthList(OVar("xs")),
                                ),
                                OVar("raw_i"),
                            ),
                        ),
                        (
                            "raw_j",
                            self.visit(node.slice.upper),
                        ),
                        (
                            "j",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_j"),
                                    plt.LengthList(OVar("xs")),
                                ),
                                OVar("raw_j"),
                            ),
                        ),
                        (
                            "drop",
                            plt.Ite(
                                plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
                                plt.Integer(0),
                                OVar("i"),
                            ),
                        ),
                        (
                            "take",
                            plt.SubtractInteger(OVar("j"), OVar("drop")),
                        ),
                    ],
                    plt.Ite(
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                        empty_list(node.value.typ.typ.typ),
                        plt.SliceList(
                            OVar("drop"),
                            OVar("take"),
                            OVar("xs"),
                            empty_list(node.value.typ.typ.typ),
                        ),
                    ),
                )
        elif isinstance(node.value.typ.typ, DictType):
            dict_typ = node.value.typ.typ
            if not isinstance(node.slice, Slice):
                return OLet(
                    [
                        (
                            "key",
                            transform_output_map(node.slice.typ)(
                                self.visit(node.slice),
                            ),
                        )
                    ],
                    transform_ext_params_map(dict_typ.value_typ)(
                        plt.SndPair(
                            plt.FindList(
                                self.visit(node.value),
                                OLambda(
                                    ["x"],
                                    plt.EqualsData(
                                        OVar("key"),
                                        plt.FstPair(OVar("x")),
                                    ),
                                ),
                                plt.TraceError("KeyError"),
                            )
                        ),
                    ),
                )
        elif isinstance(node.value.typ.typ, ByteStringType):
            if not isinstance(node.slice, Slice):
                if isinstance(node.slice, Constant) and node.slice.value >= 0:
                    return plt.IndexByteString(
                        self.visit(node.value),
                        self.visit(node.slice),
                    )
                elif isinstance(node.slice, Constant) and node.slice.value < 0:
                    return plt.IndexByteString(
                        self.visit(node.value),
                        plt.AddInteger(
                            self.visit(node.slice),
                            plt.LengthOfByteString(self.visit(node.value)),
                        ),
                    )
                return OLet(
                    [
                        (
                            "bs",
                            self.visit(node.value),
                        ),
                        (
                            "raw_ix",
                            self.visit(node.slice),
                        ),
                        (
                            "ix",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_ix"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_ix"),
                                    plt.LengthOfByteString(OVar("bs")),
                                ),
                                OVar("raw_ix"),
                            ),
                        ),
                    ],
                    plt.IndexByteString(OVar("bs"), OVar("ix")),
                )
            elif isinstance(node.slice, Slice):
                return OLet(
                    [
                        (
                            "bs",
                            self.visit(node.value),
                        ),
                        (
                            "raw_i",
                            self.visit(node.slice.lower),
                        ),
                        (
                            "i",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_i"),
                                    plt.LengthOfByteString(OVar("bs")),
                                ),
                                OVar("raw_i"),
                            ),
                        ),
                        (
                            "raw_j",
                            self.visit(node.slice.upper),
                        ),
                        (
                            "j",
                            plt.Ite(
                                plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
                                plt.AddInteger(
                                    OVar("raw_j"),
                                    plt.LengthOfByteString(OVar("bs")),
                                ),
                                OVar("raw_j"),
                            ),
                        ),
                        (
                            "drop",
                            plt.Ite(
                                plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
                                plt.Integer(0),
                                OVar("i"),
                            ),
                        ),
                        (
                            "take",
                            plt.SubtractInteger(OVar("j"), OVar("drop")),
                        ),
                    ],
                    plt.Ite(
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                        plt.ByteString(b""),
                        plt.SliceByteString(
                            OVar("drop"),
                            OVar("take"),
                            OVar("bs"),
                        ),
                    ),
                )
        raise NotImplementedError(
            f'Could not implement subscript "{node.slice}" of "{node.value}"'
        )

    def visit_Tuple(self, node: TypedTuple) -> plt.AST:
        """Compile a tuple literal into a functional tuple of its compiled elements."""
        compiled_elements = [self.visit(element) for element in node.elts]
        return plt.FunctionalTuple(*compiled_elements)

    def visit_ClassDef(self, node: TypedClassDef) -> CallAST:
        """
        Compile a class definition into a statement wrapper.

        The returned callable takes the compiled continuation and binds the
        class name to its delayed constructor around it.
        """

        def bind_class(continuation):
            binding = (node.name, plt.Delay(node.class_typ.constr()))
            return plt.Let([binding], continuation)

        return bind_class

    def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
        """
        Compile an attribute access on an instance.

        The accessor is obtained from the type of the accessed value and is
        applied to the compiled value.
        """
        assert isinstance(
            node.value.typ, InstanceType
        ), "Can only access attributes of instances"
        compiled_value = self.visit(node.value)
        accessor = node.value.typ.attribute(node.attr)
        return plt.Apply(accessor, compiled_value)

    def visit_Assert(self, node: TypedAssert) -> CallAST:
        """
        Compile an assert statement into a statement wrapper.

        The returned callable continues with the given continuation if the
        compiled test holds. Otherwise it yields an error term that is applied
        to a trace of the compiled message (or to unit if there is no message).
        """

        def guard(continuation):
            condition = self.visit(node.test)
            error = plt.Error()
            if node.msg is None:
                payload = plt.Unit()
            else:
                payload = plt.Trace(self.visit(node.msg), plt.Unit())
            return plt.Ite(condition, continuation, plt.Apply(error, payload))

        return guard

    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
        """Return the expression carried by the node as is, without compiling anything."""
        raw_expr = node.expr
        return raw_expr

    def visit_List(self, node: TypedList) -> plt.AST:
        """
        Compile a list literal.

        The elements are visited in reverse order and each one is consed onto
        the list accumulated so far, starting from the empty list of the
        element type.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, ListType)
        element_typ = node.typ.typ.typ
        compiled_list = empty_list(element_typ)
        for item in reversed(node.elts):
            compiled_item = self.visit(item)
            if isinstance(element_typ.typ, (AnyType, UnionType)):
                # lists of generic type data hold wrapped data, so wrap the element before adding it
                compiled_item = transform_output_map(item.typ)(compiled_item)
            compiled_list = plt.MkCons(compiled_item, compiled_list)
        return compiled_list

    def visit_Dict(self, node: TypedDict) -> plt.AST:
        """
        Compile a dict literal into a list of data pairs.

        Every key and every value is compiled and passed through the output
        transformation of its own type before both are combined into a pair.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, DictType)
        pairs = plt.EmptyDataPairList()
        # NOTE(review): unlike visit_List, the entries are not iterated in
        # reverse before being consed onto the list - confirm that the
        # resulting order of entries is intended
        for key, value in zip(node.keys, node.values):
            pair = plt.MkPairData(
                transform_output_map(key.typ)(
                    self.visit(key),
                ),
                transform_output_map(value.typ)(
                    self.visit(value),
                ),
            )
            pairs = plt.MkCons(pair, pairs)
        return pairs

    def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
        """
        Compile a conditional expression.

        If the expression is of union type, each branch whose own type is not
        a UnionType is passed through its output transformation first.
        """
        if not isinstance(node.typ.typ, UnionType):
            return plt.Ite(
                self.visit(node.test),
                self.visit(node.body),
                self.visit(node.orelse),
            )
        then_branch = self.visit(node.body)
        else_branch = self.visit(node.orelse)
        if not isinstance(node.body.typ, UnionType):
            then_branch = transform_output_map(node.body.typ)(then_branch)
        if not isinstance(node.orelse.typ, UnionType):
            else_branch = transform_output_map(node.orelse.typ)(else_branch)
        return plt.Ite(self.visit(node.test), then_branch, else_branch)

    def visit_ListComp(self, node: TypedListComp) -> plt.AST:
        """
        Compile a list comprehension over a single list generator.

        The element expression becomes the mapping function. If the generator
        has conditions, their conjunction becomes the filter function.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        generator = node.generators[0]
        iter_typ = generator.iter.typ
        assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
        assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            generator.target, Name
        ), "Can only assign value to singleton element"
        source_list = self.visit(generator.iter)
        # combine all conditions of the generator into one conjunction
        condition = None
        for test in generator.ifs:
            compiled_test = self.visit(test)
            if condition is None:
                condition = compiled_test
            else:
                condition = plt.And(condition, compiled_test)
        # inside both functions the loop variable is bound to the current element
        mapper = OLambda(
            ["x"],
            plt.Let(
                [(generator.target.id, plt.Delay(OVar("x")))],
                self.visit(node.elt),
            ),
        )
        result_seed = empty_list(node.elt.typ)
        if condition is None:
            return plt.MapList(source_list, mapper, result_seed)
        predicate = OLambda(
            ["x"],
            plt.Let(
                [(generator.target.id, plt.Delay(OVar("x")))],
                condition,
            ),
        )
        return plt.MapFilterList(source_list, predicate, mapper, result_seed)

    def visit_DictComp(self, node: TypedDictComp) -> plt.AST:
        """
        Compile a dict comprehension over a single list generator.

        The mapping function builds a data pair from the transformed key and
        value expressions. If the generator has conditions, their conjunction
        becomes the filter function.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        generator = node.generators[0]
        iter_typ = generator.iter.typ
        assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
        assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            generator.target, Name
        ), "Can only assign value to singleton element"
        source_list = self.visit(generator.iter)
        # combine all conditions of the generator into one conjunction
        condition = None
        for test in generator.ifs:
            compiled_test = self.visit(test)
            if condition is None:
                condition = compiled_test
            else:
                condition = plt.And(condition, compiled_test)
        # inside both functions the loop variable is bound to the current element
        mapper = OLambda(
            ["x"],
            plt.Let(
                [(generator.target.id, plt.Delay(OVar("x")))],
                plt.MkPairData(
                    transform_output_map(node.key.typ)(
                        self.visit(node.key),
                    ),
                    transform_output_map(node.value.typ)(
                        self.visit(node.value),
                    ),
                ),
            ),
        )
        result_seed = plt.EmptyDataPairList()
        if condition is None:
            return plt.MapList(source_list, mapper, result_seed)
        predicate = OLambda(
            ["x"],
            plt.Let(
                [(generator.target.id, plt.Delay(OVar("x")))],
                condition,
            ),
        )
        return plt.MapFilterList(source_list, predicate, mapper, result_seed)

    def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST:
        """Compile a formatted value by applying the stringifier of its type to it."""
        stringifier = node.value.typ.stringify()
        compiled_value = self.visit(node.value)
        return plt.Apply(stringifier, compiled_value)

    def visit_JoinedStr(self, node: TypedJoinedStr) -> plt.AST:
        """
        Compile an f-string.

        Starting from the empty text, the compiled parts are taken in reverse
        order and each one is put in front of the text built so far.
        """
        result = plt.Text("")
        for part in reversed(node.values):
            compiled_part = self.visit(part)
            result = plt.AppendString(compiled_part, result)
        return result

    def generic_visit(self, node: TypedAST) -> plt.AST:
        """
        Reject the given node, as this method itself does not compile anything.

        :raises NotImplementedError: always
        """
        message = f"Can not compile {node}"
        raise NotImplementedError(message)


def parse(
    source: str,
    filename=None,
) -> ast.AST:
    """
    Parse source code into an AST

    Currently passes everything through Python's ast module.

    :param source: the source code to parse
    :param filename: name of the file the source stems from; if omitted,
        the placeholder "<unknown>" is used instead
    :return: the parsed module
    :raises SyntaxError: if the source is not valid Python code
    """
    # the builtin compile() behind ast.parse does not accept None as filename,
    # so fall back to the placeholder that ast.parse itself uses by default
    if filename is None:
        filename = "<unknown>"
    return ast.parse(source, filename=filename)


def compile(
    prog: ast.AST,
    filename=None,
    validator_function_name="validator",
    config=DEFAULT_CONFIG,
) -> plt.Program:
    """
    Compile a parsed program

    Runs all rewriting, type inference and optimization passes one after
    another over the program, fixes up missing locations after each of them
    and finally hands the result to the PlutoCompiler.
    """
    passes = [
        # Important to call this one first - it imports all further files
        RewriteImport(filename=filename),
        # Rewrites that simplify the python code
        RewriteForbiddenReturn(),
        OptimizeConstantFolding() if config.constant_folding else NoOp(),
        OptimizeUnionExpansion() if config.expand_union_types else NoOp(),
        RewriteSubscript38(),
        RewriteAugAssign(),
        RewriteComparisonChaining(),
        RewriteTupleAssign(),
        RewriteImportIntegrityCheck(),
        RewriteImportPlutusData(),
        RewriteImportHashlib(),
        RewriteImportTyping(),
        RewriteForbiddenOverwrites(),
        RewriteImportDataclasses(),
        RewriteInjectBuiltins(),
        RewriteConditions(),
        # Save the original names of variables
        RewriteOrigName(),
        RewriteScoping(),
        # The type inference needs to be run after complex python operations were rewritten
        AggressiveTypeInferencer(config.allow_isinstance_anything),
        # Rewrites that circumvent the type inference or use its results
        RewriteAssertNone(),
        RewriteEmptyLists(),
        RewriteEmptyDicts(),
        RewriteImportUPLCBuiltins(),
        RewriteRemoveTypeStuff(),
        # Apply optimizations
        OptimizeRemoveDeadvars() if config.remove_dead_code else NoOp(),
        OptimizeRemoveDeadconstants() if config.remove_dead_code else NoOp(),
        OptimizeRemovePass(),
    ]
    for compiler_pass in passes:
        prog = custom_fix_missing_locations(compiler_pass.visit(prog))

    # the compiler runs last
    backend = PlutoCompiler(
        force_three_params=config.force_three_params,
        validator_function_name=validator_function_name,
        config=config,
    )
    return backend.visit(prog)

Functions

def compile(prog: ast.AST, filename=None, validator_function_name='validator', config=CompilationConfig(compress_patterns=True, iterative_unfold_patterns=False, constant_index_access_list=True, constant_folding=False, allow_isinstance_anything=False, force_three_params=False, remove_dead_code=True, fast_access_skip=None, expand_union_types=False)) ‑> pluthon.pluthon_ast.Program
Expand source code
def compile(
    prog: ast.AST,
    filename=None,
    validator_function_name="validator",
    config=DEFAULT_CONFIG,
) -> plt.Program:
    """
    Compile a parsed program

    Runs all rewriting, type inference and optimization passes one after
    another over the program, fixes up missing locations after each of them
    and finally hands the result to the PlutoCompiler.
    """
    passes = [
        # Important to call this one first - it imports all further files
        RewriteImport(filename=filename),
        # Rewrites that simplify the python code
        RewriteForbiddenReturn(),
        OptimizeConstantFolding() if config.constant_folding else NoOp(),
        OptimizeUnionExpansion() if config.expand_union_types else NoOp(),
        RewriteSubscript38(),
        RewriteAugAssign(),
        RewriteComparisonChaining(),
        RewriteTupleAssign(),
        RewriteImportIntegrityCheck(),
        RewriteImportPlutusData(),
        RewriteImportHashlib(),
        RewriteImportTyping(),
        RewriteForbiddenOverwrites(),
        RewriteImportDataclasses(),
        RewriteInjectBuiltins(),
        RewriteConditions(),
        # Save the original names of variables
        RewriteOrigName(),
        RewriteScoping(),
        # The type inference needs to be run after complex python operations were rewritten
        AggressiveTypeInferencer(config.allow_isinstance_anything),
        # Rewrites that circumvent the type inference or use its results
        RewriteAssertNone(),
        RewriteEmptyLists(),
        RewriteEmptyDicts(),
        RewriteImportUPLCBuiltins(),
        RewriteRemoveTypeStuff(),
        # Apply optimizations
        OptimizeRemoveDeadvars() if config.remove_dead_code else NoOp(),
        OptimizeRemoveDeadconstants() if config.remove_dead_code else NoOp(),
        OptimizeRemovePass(),
    ]
    for compiler_pass in passes:
        prog = custom_fix_missing_locations(compiler_pass.visit(prog))

    # the compiler runs last
    backend = PlutoCompiler(
        force_three_params=config.force_three_params,
        validator_function_name=validator_function_name,
        config=config,
    )
    return backend.visit(prog)
def parse(source: str, filename=None) ‑> ast.AST

Parse source code into an AST

Currently passes everything through Python's ast module.

Expand source code
def parse(
    source: str,
    filename=None,
) -> ast.AST:
    """
    Parse source code into an AST

    Currently passes everything through Python's ast module.

    :param source: the source code to parse
    :param filename: name of the file the source stems from; if omitted,
        the placeholder "<unknown>" is used instead
    :return: the parsed module
    :raises SyntaxError: if the source is not valid Python code
    """
    # the builtin compile() behind ast.parse does not accept None as filename,
    # so fall back to the placeholder that ast.parse itself uses by default
    if filename is None:
        filename = "<unknown>"
    return ast.parse(source, filename=filename)
def rec_constant_map(c)
Expand source code
def rec_constant_map(c):
    """
    Recursively map a Python constant to the corresponding UPLC builtin value

    Elements of lists are mapped by this function again, whereas keys and
    values of dicts are mapped by rec_constant_map_data.
    """
    if c is None:
        return uplc.BuiltinUnit()
    # bool has to be tested before int, as every bool is an int as well
    scalar_builtins = (
        (bool, uplc.BuiltinBool),
        (int, uplc.BuiltinInteger),
        (bytes, uplc.BuiltinByteString),
        (str, uplc.BuiltinString),
    )
    for python_type, builtin in scalar_builtins:
        if isinstance(c, python_type):
            return builtin(c)
    if isinstance(c, list):
        return uplc.BuiltinList([rec_constant_map(element) for element in c])
    if isinstance(c, dict):
        pairs = [
            uplc.BuiltinPair(rec_constant_map_data(key), rec_constant_map_data(value))
            for key, value in c.items()
        ]
        return uplc.BuiltinList(pairs)
    # This can occur when PlutusData is generated during constant folding
    if isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")
def rec_constant_map_data(c)
Expand source code
def rec_constant_map_data(c):
    """
    Recursively map a Python constant to the corresponding UPLC data value

    Booleans become integers, None becomes the constructor with id 0 and no
    fields, and strings are stored as their encoded bytes.
    """
    # bool needs no separate case: int(c) covers it and leaves other ints unchanged
    if isinstance(c, (bool, int)):
        return uplc.PlutusInteger(int(c) if isinstance(c, bool) else c)
    elif c is None:
        return uplc.PlutusConstr(0, [])
    elif isinstance(c, bytes):
        return uplc.PlutusByteString(c)
    elif isinstance(c, str):
        return uplc.PlutusByteString(c.encode())
    elif isinstance(c, list):
        return uplc.PlutusList([rec_constant_map_data(element) for element in c])
    elif isinstance(c, dict):
        mapped = {
            rec_constant_map_data(key): rec_constant_map_data(value)
            for key, value in c.items()
        }
        return uplc.PlutusMap(mapped)
    # This can occur when PlutusData is generated during constant folding
    elif isinstance(c, PlutusData):
        return data_from_cbor(c.to_cbor())
    raise NotImplementedError(f"Unsupported constant type {type(c)}")
def wrap_validator_double_function(x: pluthon.pluthon_ast.AST, pass_through: int = 0)

Wraps the validator function to enable a double function as minting script

pass_through defines how many parameters x would normally take and should be passed through to x

Expand source code
def wrap_validator_double_function(x: plt.AST, pass_through: int = 0):
    """
    Wraps the validator function so that it can double function as minting script

    pass_through is the number of leading parameters that x would normally
    take and that are handed through to x unchanged
    """
    passed_names = [f"v{i}" for i in range(pass_through)]
    partially_applied = plt.Apply(x, *(OVar(name) for name in passed_names))
    # holds if the second argument has constructor 0 = script context
    second_is_context = plt.DelayedChooseData(
        OVar("a1"),
        plt.EqualsInteger(plt.Constructor(OVar("a1")), plt.Integer(0)),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
        plt.Bool(False),
    )
    # call the validator with a0, a1, and plug in "Nothing" for data
    call_without_datum = plt.Apply(
        OVar("p"),
        plt.UPLCConstant(to_uplc_builtin(Nothing())),
        OVar("a0"),
        OVar("a1"),
    )
    # else call the validator with a0, a1 and return (now partially bound)
    call_partially = plt.Apply(OVar("p"), OVar("a0"), OVar("a1"))
    return OLambda(
        passed_names + ["a0", "a1"],
        OLet(
            [("p", partially_applied)],
            plt.Ite(second_is_context, call_without_datum, call_partially),
        ),
    )

Classes

class PlutoCompiler (force_three_params=False, validator_function_name='validator', config=CompilationConfig(compress_patterns=True, iterative_unfold_patterns=False, constant_index_access_list=True, constant_folding=False, allow_isinstance_anything=False, force_three_params=False, remove_dead_code=True, fast_access_skip=None, expand_union_types=False))

Expects a TypedAST and returns UPLC/Pluto like code

Expand source code
class PlutoCompiler(CompilingNodeTransformer):
    """
    Expects a TypedAST and returns UPLC/Pluto like code
    """

    step = "Compiling python statements to UPLC"

    def __init__(
        self,
        force_three_params=False,
        validator_function_name="validator",
        config=DEFAULT_CONFIG,
    ):
        """
        Store the compilation parameters and set up the compilation state.

        The fast_access_skip value of the config has to be either None or
        greater than 1.
        """
        # parameters
        self.force_three_params = force_three_params
        self.validator_function_name = validator_function_name
        self.config = config
        skip = self.config.fast_access_skip
        assert (
            skip is None or skip > 1
        ), "Parameter fast-access-skip needs to be greater than 1 or omitted"
        # marked knowledge during compilation
        self.current_function_typ: typing.List[FunctionType] = []

    def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
        """
        Compile a sequence of statements into a single statement wrapper.

        The returned callable compiles the statements from last to first and
        wraps each of them around the term built so far, starting with the
        given continuation.
        """

        def wrap_all(continuation: plt.AST):
            wrapped = continuation
            for stmt in reversed(node_seq):
                wrapped = self.visit(stmt)(wrapped)
            return wrapped

        return wrap_all

    def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
        """
        Compile a binary operation.

        The implementation is looked up on the type of the left operand and
        applied to both compiled operands.
        """
        implementation = node.left.typ.binop(node.op, node.right)
        compiled_left = self.visit(node.left)
        compiled_right = self.visit(node.right)
        return plt.Apply(implementation, compiled_left, compiled_right)

    def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
        """
        Compile a boolean operation over two or more operands.

        The first two operands are combined first. Every further operand is
        then combined with the result so far, from left to right.
        """
        combine = BoolOpMap.get(type(node.op))
        assert len(node.values) >= 2, "Need to compare at least to values"
        operands = node.values
        combined = combine(self.visit(operands[0]), self.visit(operands[1]))
        for operand in operands[2:]:
            combined = combine(combined, self.visit(operand))
        return combined

    def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
        """
        Compile a unary operation.

        The implementation is looked up on the type of the operand and applied
        to the compiled operand.
        """
        implementation = node.operand.typ.unop(node.op)
        compiled_operand = self.visit(node.operand)
        return plt.Apply(implementation, compiled_operand)

    def visit_Compare(self, node: TypedCompare) -> plt.AST:
        """
        Compile a comparison with exactly one operator and one comparator.

        The implementation is looked up on the type of the left operand and
        applied to the compiled left operand and the compiled comparator.
        """
        assert len(node.ops) == 1, "Only single comparisons are supported"
        assert len(node.comparators) == 1, "Only single comparisons are supported"
        right = node.comparators[0]
        implementation = node.left.typ.cmp(node.ops[0], right.typ)
        return plt.Apply(implementation, self.visit(node.left), self.visit(right))

    def visit_Module(self, node: TypedModule) -> plt.AST:
        """
        Compile a whole module into a program.

        If a validator function name is set, the module body is extended by a return
        of a call to that function, and the result is wrapped in a lambda that takes
        one parameter per argument of the validator. Otherwise only the module body
        is compiled.

        :raises RuntimeError: if three parameters are enforced but the double
            function wrapper can not be enabled
        """
        # extract actually read variables by each function
        if self.validator_function_name is not None:
            # for validators find main function
            # TODO can use more sophisticated procedure here i.e. functions marked by comment
            main_fun: typing.Optional[InstanceType] = None
            # the loop does not stop at the first match, so the last matching definition is used
            for s in node.body:
                if (
                    isinstance(s, ast.FunctionDef)
                    and s.orig_name == self.validator_function_name
                ):
                    main_fun = s
            assert (
                main_fun is not None
            ), f"Could not find function named {self.validator_function_name}"
            main_fun_typ: FunctionType = main_fun.typ.typ
            assert isinstance(
                main_fun_typ, FunctionType
            ), f"Variable named {self.validator_function_name} is not of type function"

            # check if this is a contract written to double function
            enable_double_func_mint_spend = False
            if len(main_fun_typ.argtyps) >= 3 and self.force_three_params:
                # check if is possible
                second_last_arg = main_fun_typ.argtyps[-2]
                assert isinstance(
                    second_last_arg, InstanceType
                ), "Can not pass Class into validator"
                # a union contributes all of its member types, anything else only itself
                if isinstance(second_last_arg.typ, UnionType):
                    possible_types = second_last_arg.typ.typs
                else:
                    possible_types = [second_last_arg.typ]
                if any(isinstance(t, UnitType) for t in possible_types):
                    OPSHIN_LOGGER.warning(
                        "The redeemer is annotated to be 'None'. This value is usually encoded in PlutusData with constructor id 0 and no fields. If you want the script to double function as minting and spending script, annotate the second argument with 'NoRedeemer'."
                    )
                # only enabled if no possible type is a record with constructor id 0 or the unit type
                enable_double_func_mint_spend = not any(
                    (isinstance(t, RecordType) and t.record.constructor == 0)
                    or isinstance(t, UnitType)
                    for t in possible_types
                )
                if not enable_double_func_mint_spend:
                    OPSHIN_LOGGER.warning(
                        "The second argument to the validator function potentially has constructor id 0. The validator will not be able to double function as minting script and spending script."
                    )

            # append a return of the call to the main function, whose arguments are the
            # validator parameters passed through transform_ext_params_map of their type
            body = node.body + (
                [
                    TypedReturn(
                        TypedCall(
                            func=ast.Name(
                                id=main_fun.name,
                                typ=InstanceType(main_fun_typ),
                                ctx=ast.Load(),
                            ),
                            typ=main_fun_typ.rettyp,
                            args=[
                                RawPlutoExpr(
                                    expr=transform_ext_params_map(a)(
                                        OVar(f"val_param{i}")
                                    ),
                                    typ=a,
                                )
                                for i, a in enumerate(main_fun_typ.argtyps)
                            ],
                        )
                    )
                ]
            )
            # the module level is compiled with a function type without arguments on the
            # stack; it is popped again once the body has been compiled
            self.current_function_typ.append(FunctionType([], InstanceType(AnyType())))
            name_load_visitor = NameLoadCollector()
            name_load_visitor.visit(node)
            all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

            # write all variables that are ever read
            # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
            validator = SafeOLambda(
                [f"val_param{i}" for i, _ in enumerate(main_fun_typ.argtyps)],
                plt.Let(
                    [
                        (
                            x,
                            plt.Delay(
                                plt.TraceError(f"NameError: {map_to_orig_name(x)}")
                            ),
                        )
                        for x in all_vs
                    ],
                    self.visit_sequence(body)(OUnit),
                ),
            )
            self.current_function_typ.pop()
            if enable_double_func_mint_spend:
                # all arguments except the last three are passed through by the wrapper
                validator = wrap_validator_double_function(
                    validator, pass_through=len(main_fun_typ.argtyps) - 3
                )
            elif self.force_three_params:
                # Error if the double function is enforced but not possible
                raise RuntimeError(
                    "The contract can not always detect if it was passed three or two parameters on-chain."
                )
        else:
            # no validator function: the module body is compiled without a surrounding lambda
            name_load_visitor = NameLoadCollector()
            name_load_visitor.visit(node)
            all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

            body = node.body
            # write all variables that are ever read
            # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
            validator = plt.Let(
                [
                    (
                        x,
                        plt.Delay(plt.TraceError(f"NameError: {map_to_orig_name(x)}")),
                    )
                    for x in all_vs
                ],
                self.visit_sequence(body)(OUnit),
            )

        # the compiled term is packed into a program together with the tuple (1, 0, 0)
        cp = plt.Program((1, 0, 0), validator)
        return cp

    def visit_Constant(self, node: Constant) -> plt.AST:
        """
        Compile a constant into a UPLC constant.

        A warning is logged for a non-empty bytes constant whose decoded
        content is itself valid hex, as a hex-encoded bytestring was then
        likely intended.
        """
        if isinstance(node.value, bytes) and node.value != b"":
            try:
                bytes.fromhex(node.value.decode())
                looks_like_hex = True
            except ValueError:
                looks_like_hex = False
            if looks_like_hex:
                OPSHIN_LOGGER.warning(
                    f"The string {node.value} looks like it is supposed to be a hex-encoded bytestring but is actually utf8-encoded. Try using `bytes.fromhex('{node.value.decode()}')` instead."
                )
        return plt.UPLCConstant(rec_constant_map(node.value))

    def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
        """Compile to the unit value; the argument is ignored."""
        unit = plt.Unit()
        return unit

    def visit_Assign(self, node: TypedAssign) -> CallAST:
        """
        Compile an assignment to a single name into a statement wrapper.

        The returned callable binds the compiled value and a delayed reference
        to it around the given continuation.
        """
        assert (
            len(node.targets) == 1
        ), "Assignments to more than one variable not supported yet"
        target = node.targets[0]
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        compiled_value = self.visit(node.value)
        name = target.id
        needs_wrapping = getattr(target, "is_wrapped", False) or (
            isinstance(target.typ, InstanceType)
            and isinstance(target.typ.typ, (AnyType, UnionType))
        )
        if needs_wrapping:
            # wrapped variables and Union/Any variables have to be mapped back to the external parameter type
            # TODO this is terribly inefficient. we would rather want to cast once when entering the body and cast back when leaving
            compiled_value = transform_output_map(node.value.typ)(compiled_value)

        def bind(continuation):
            # the term is evaluated first, the variable itself is a delayed reference to it
            return plt.Let(
                [
                    (opshin_name_scheme_compatible_varname(name), compiled_value),
                    (name, plt.Delay(OVar(name))),
                ],
                continuation,
            )

        return bind

    def visit_AnnAssign(self, node: TypedAnnAssign) -> CallAST:
        """
        Compile an annotated assignment to a single name.

        The value is converted when it is typed Any/Union, and converted back
        when the target is typed Any/Union. Returns a function that wraps the
        compiled continuation in the resulting bindings.
        """
        target = node.target
        assert isinstance(
            target, Name
        ), "Assignments to other things then names are not supported"
        assert isinstance(
            target.typ, InstanceType
        ), "Can only assign instances to instances"
        val = self.visit(node.value)
        data_like = (AnyType, UnionType)
        value_typ = node.value.typ
        if isinstance(value_typ, InstanceType) and isinstance(value_typ.typ, data_like):
            # the value originates from PlutusData, map it to the target type
            val = transform_ext_params_map(target.typ)(val)
        if isinstance(target.typ, InstanceType) and isinstance(
            target.typ.typ, data_like
        ):
            # the target is treated as PlutusData, map the value back
            val = transform_output_map(node.value.typ)(val)

        def bind(continuation):
            name = node.target.id
            return plt.Let(
                [
                    (opshin_name_scheme_compatible_varname(name), val),
                    (name, plt.Delay(OVar(name))),
                ],
                continuation,
            )

        return bind

    def visit_Name(self, node: Name) -> plt.AST:
        """
        Compile a name that is being read.

        Raises NotImplementedError for any context other than Load. Names of
        class type compile to the class constructor; all other names compile
        to a forced variable lookup, converted first if the name is wrapped.
        """
        if not isinstance(node.ctx, Load):
            raise NotImplementedError(f"Context {node.ctx} not supported")
        if isinstance(node.typ, ClassType):
            # a class rather than an instance: hand out its constructor
            return node.typ.constr()
        if getattr(node, "is_wrapped", False):
            convert = transform_ext_params_map(node.typ)
            return convert(plt.Force(plt.Var(node.id)))
        return plt.Force(plt.Var(node.id))

    def visit_Expr(self, node: TypedExpr) -> CallAST:
        """
        Compile an expression statement whose value is discarded.

        The expression is passed as argument to a lambda that ignores it, so
        UPLC's eager evaluation still computes it (this mostly matters for
        Trace).
        """

        def evaluate_then_continue(continuation):
            # "0" is not a valid identifier, so it cannot clash with user variables
            ignoring_lambda = OLambda(["0"], continuation)
            return plt.Apply(ignoring_lambda, self.visit(node.value))

        return evaluate_then_continue

    def visit_Call(self, node: TypedCall) -> plt.AST:
        """
        Compile a function call.

        Arguments are bound to `p0`, `p1`, ... first, so that they are
        evaluated before the call. They are then passed delayed, after the
        optional bound self and the bound variables of the function, which
        are passed in sorted order.
        """
        # TODO function is actually not of type polymorphic function type here anymore
        callee_typ = node.func.typ
        if isinstance(callee_typ, PolymorphicFunctionInstanceType):
            # edge case for weird builtins that are polymorphic
            func_plt = force_params(
                callee_typ.polymorphic_function.impl_from_args(callee_typ.typ.argtyps)
            )
            bind_self = None
        else:
            assert isinstance(callee_typ, InstanceType) and isinstance(
                callee_typ.typ, FunctionType
            ), "Can only call instances of functions"
            func_plt = self.visit(node.func)
            bind_self = callee_typ.typ.bind_self
        signature = callee_typ.typ
        bound_vs = sorted(signature.bound_vars.keys())
        calls_isinstance = getattr(node.func, "orig_id", None) == "isinstance"
        args = []
        for position, (arg_node, param_typ) in enumerate(
            zip(node.args, signature.argtyps)
        ):
            # the implementation has been chosen already, so the type argument is dropped
            if calls_isinstance and position == 1:
                continue
            assert isinstance(param_typ, InstanceType)
            compiled_arg = self.visit(arg_node)
            if isinstance(param_typ.typ, (AnyType, UnionType)):
                # parameters of generic data type receive their argument converted
                compiled_arg = transform_output_map(arg_node.typ)(compiled_arg)
            args.append(compiled_arg)
        slot_names = [f"p{i}" for i in range(len(args))]
        call_args = []
        if bind_self is not None:
            call_args.append(plt.Var(bind_self))
        call_args.extend(plt.Var(n) for n in bound_vs)
        call_args.extend(plt.Delay(OVar(slot)) for slot in slot_names)
        return OLet(
            list(zip(slot_names, args)),
            SafeApply(func_plt, *call_args),
        )

    def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
        """
        Compile a function definition.

        The body is compiled while the function's type is on top of
        `self.current_function_typ`. The resulting lambda takes the optional
        bound self, then the sorted bound variables, then the declared
        arguments, and is bound (delayed) to the function name.
        """
        fun_typ = node.typ.typ
        body = node.body.copy()
        # value produced when the body finishes without a return statement
        ret_val = OUnit if fun_typ.rettyp.typ == AnyType() else plt.Unit()
        read_vs = sorted(fun_typ.bound_vars.keys())
        if fun_typ.bind_self is not None:
            read_vs = [fun_typ.bind_self] + read_vs
        self.current_function_typ.append(fun_typ)
        compiled_body = self.visit_sequence(body)(ret_val)
        self.current_function_typ.pop()

        def bind(continuation):
            parameters = read_vs + [a.arg for a in node.args.args]
            definition = plt.Delay(SafeLambda(parameters, compiled_body))
            return plt.Let([(node.name, definition)], continuation)

        return bind

    def visit_While(self, node: TypedWhile) -> CallAST:
        """
        Compile a while loop into a lambda that applies itself.

        The loop lambda takes itself plus every name written inside the loop;
        each iteration re-applies it with the current values of those names.
        When the condition fails, control passes to `adjusted_next`, the
        continuation abstracted over the same names.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            cn = copy.copy(node)
            cn.orelse = []
            return self.visit_sequence([cn] + node.orelse)
        compiled_c = self.visit(node.test)
        compiled_s = self.visit_sequence(node.body)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def current_state():
            # every use site gets its own copy of the variable nodes
            return copy.deepcopy(pwritten_vs)

        def loop_lambda(on_exit):
            next_round = plt.Apply(OVar("while"), OVar("while"), *current_state())
            return plt.Lambda(
                [opshin_name_scheme_compatible_varname("while")] + written_vs,
                plt.Ite(compiled_c, compiled_s(next_round), on_exit),
            )

        def compile_with(continuation):
            leave = SafeApply(OVar("adjusted_next"), *current_state())
            return OLet(
                [
                    ("adjusted_next", SafeLambda(written_vs, continuation)),
                    ("while", loop_lambda(leave)),
                ],
                plt.Apply(OVar("while"), OVar("while"), *current_state()),
            )

        return compile_with

    def visit_For(self, node: TypedFor) -> CallAST:
        """
        Compile a for loop over a list into a lambda that applies itself.

        The loop lambda takes itself, the remaining list and every name
        written inside the loop. On an empty list it passes control to
        `adjusted_next`; otherwise it binds the loop target to the head of
        the list and continues with the tail.

        Raises NotImplementedError when the iterable is not a list.
        """
        if node.orelse:
            # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
            cn = copy.copy(node)
            cn.orelse = []
            return self.visit_sequence([cn] + node.orelse)
        assert isinstance(node.iter.typ, InstanceType)
        if not isinstance(node.iter.typ.typ, ListType):
            raise NotImplementedError(
                "Compilation of for statements for anything but lists not implemented yet"
            )
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        compiled_s = self.visit_sequence(node.body)
        compiled_iter = self.visit(node.iter)
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def current_state():
            # every use site gets its own copy of the variable nodes
            return copy.deepcopy(pwritten_vs)

        def loop_lambda(on_exhausted):
            next_round = plt.Apply(
                OVar("for"),
                OVar("for"),
                plt.TailList(OVar("iter")),
                *current_state(),
            )
            iteration = plt.Let(
                [(node.target.id, plt.Delay(plt.HeadList(OVar("iter"))))],
                compiled_s(next_round),
            )
            parameters = [
                opshin_name_scheme_compatible_varname("for"),
                opshin_name_scheme_compatible_varname("iter"),
            ] + written_vs
            return plt.Lambda(
                parameters,
                plt.IteNullList(OVar("iter"), on_exhausted, iteration),
            )

        def compile_with(continuation):
            leave = plt.Apply(
                OVar("adjusted_next"),
                plt.Var(node.target.id),
                *current_state(),
            )
            return OLet(
                [
                    (
                        "adjusted_next",
                        plt.Lambda([node.target.id] + written_vs, continuation),
                    ),
                    ("for", loop_lambda(leave)),
                ],
                plt.Apply(
                    OVar("for"),
                    OVar("for"),
                    compiled_iter,
                    *current_state(),
                ),
            )

        return compile_with

    def visit_If(self, node: TypedIf) -> CallAST:
        """
        Compile an if statement.

        The continuation is abstracted over all names written in either
        branch and bound to `adjusted_next`; both branches end by applying it
        to the current values of those names.
        """
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]

        def compile_with(continuation):
            def resume():
                return SafeApply(OVar("adjusted_next"), *copy.deepcopy(pwritten_vs))

            join_point = ("adjusted_next", SafeLambda(written_vs, continuation))
            condition = self.visit(node.test)
            then_branch = self.visit_sequence(node.body)(resume())
            else_branch = self.visit_sequence(node.orelse)(resume())
            return OLet(
                [join_point],
                plt.Ite(condition, then_branch, else_branch),
            )

        return compile_with

    def visit_Return(self, node: TypedReturn) -> CallAST:
        """
        Compile a return statement.

        The returned function discards the continuation it is given and
        yields the compiled value, converted first when the enclosing
        function is declared to return Any or a Union.
        """
        value_plt = self.visit(node.value)
        assert self.current_function_typ, "Can not handle Return outside of a function"
        declared = self.current_function_typ[-1].rettyp.typ
        if isinstance(declared, (AnyType, UnionType)):
            value_plt = transform_output_map(node.value.typ)(value_plt)

        def finish(_):
            return value_plt

        return finish

    def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
        """
        Compile `value[slice]` for tuples, pairs, lists, dicts and bytestrings.

        Tuples and pairs accept constant integer indices only. Lists and
        bytestrings accept single indices and slices; indices below zero are
        shifted by the container length. Dicts accept single keys only.

        Raises NotImplementedError for every other combination.
        """
        assert isinstance(
            node.value.typ, InstanceType
        ), "Can only access elements of instances, not classes"
        container_typ = node.value.typ.typ

        def wrap_negative(raw_name, length):
            # shift an index below zero by the container length
            return plt.Ite(
                plt.LessThanInteger(OVar(raw_name), plt.Integer(0)),
                plt.AddInteger(OVar(raw_name), length),
                OVar(raw_name),
            )

        def slice_bindings(seq_name, length_of):
            # bindings shared by list and bytestring slicing; sub-expressions
            # are compiled in the order value, lower bound, upper bound
            return [
                (seq_name, self.visit(node.value)),
                ("raw_i", self.visit(node.slice.lower)),
                ("i", wrap_negative("raw_i", length_of(OVar(seq_name)))),
                ("raw_j", self.visit(node.slice.upper)),
                ("j", wrap_negative("raw_j", length_of(OVar(seq_name)))),
                (
                    "drop",
                    plt.Ite(
                        plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
                        plt.Integer(0),
                        OVar("i"),
                    ),
                ),
                ("take", plt.SubtractInteger(OVar("j"), OVar("drop"))),
            ]

        if isinstance(container_typ, TupleType):
            assert isinstance(
                node.slice, Constant
            ), "Only constant index access for tuples is supported"
            assert isinstance(
                node.slice.value, int
            ), "Only constant index integer access for tuples is supported"
            size = len(container_typ.typs)
            index = node.slice.value
            if index < 0:
                index += size
            assert isinstance(node.ctx, Load), "Tuples are read-only"
            return plt.FunctionalTupleAccess(self.visit(node.value), index, size)

        if isinstance(container_typ, PairType):
            assert isinstance(
                node.slice, Constant
            ), "Only constant index access for pairs is supported"
            assert isinstance(
                node.slice.value, int
            ), "Only constant index integer access for pairs is supported"
            index = node.slice.value
            if index < 0:
                index += 2
            assert isinstance(node.ctx, Load), "Pairs are read-only"
            assert (
                0 <= index < 2
            ), f"Pairs only have 2 elements, index should be 0 or 1, is {node.slice.value}"
            if index == 0:
                member_func = plt.FstPair
            else:
                member_func = plt.SndPair
            # the content of pairs is always Data, so we need to unwrap
            unwrap_member = transform_ext_params_map(node.typ)
            return unwrap_member(member_func(self.visit(node.value)))

        if isinstance(container_typ, ListType):
            if isinstance(node.slice, Slice):
                assert (
                    node.slice.upper is not None
                ), "Only slices with upper bound supported"
                assert (
                    node.slice.lower is not None
                ), "Only slices with lower bound supported"
                return OLet(
                    slice_bindings("xs", plt.LengthList),
                    plt.Ite(
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                        empty_list(node.value.typ.typ.typ),
                        plt.SliceList(
                            OVar("drop"),
                            OVar("take"),
                            OVar("xs"),
                            empty_list(node.value.typ.typ.typ),
                        ),
                    ),
                )
            assert (
                node.slice.typ == IntegerInstanceType
            ), "Only single element list index access supported"
            if isinstance(node.slice, Constant) and node.slice.value >= 0:
                return plt.ConstantIndexAccessListFast(
                    self.visit(node.value),
                    node.slice.value,
                )
            index_bindings = [
                ("l", self.visit(node.value)),
                ("raw_i", self.visit(node.slice)),
                ("i", wrap_negative("raw_i", plt.LengthList(OVar("l")))),
            ]
            if self.config.fast_access_skip is not None:
                lookup = plt.IndexAccessListFast(self.config.fast_access_skip)(
                    OVar("l"), OVar("i")
                )
            else:
                lookup = plt.IndexAccessList(OVar("l"), OVar("i"))
            return OLet(index_bindings, lookup)

        if isinstance(container_typ, DictType) and not isinstance(node.slice, Slice):
            key_binding = (
                "key",
                transform_output_map(node.slice.typ)(self.visit(node.slice)),
            )
            unwrap_value = transform_ext_params_map(container_typ.value_typ)
            entry = plt.FindList(
                self.visit(node.value),
                OLambda(
                    ["x"],
                    plt.EqualsData(OVar("key"), plt.FstPair(OVar("x"))),
                ),
                plt.TraceError("KeyError"),
            )
            return OLet([key_binding], unwrap_value(plt.SndPair(entry)))

        if isinstance(container_typ, ByteStringType):
            if isinstance(node.slice, Slice):
                return OLet(
                    slice_bindings("bs", plt.LengthOfByteString),
                    plt.Ite(
                        plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                        plt.ByteString(b""),
                        plt.SliceByteString(
                            OVar("drop"),
                            OVar("take"),
                            OVar("bs"),
                        ),
                    ),
                )
            if isinstance(node.slice, Constant) and node.slice.value >= 0:
                return plt.IndexByteString(
                    self.visit(node.value),
                    self.visit(node.slice),
                )
            if isinstance(node.slice, Constant) and node.slice.value < 0:
                return plt.IndexByteString(
                    self.visit(node.value),
                    plt.AddInteger(
                        self.visit(node.slice),
                        plt.LengthOfByteString(self.visit(node.value)),
                    ),
                )
            return OLet(
                [
                    ("bs", self.visit(node.value)),
                    ("raw_ix", self.visit(node.slice)),
                    (
                        "ix",
                        wrap_negative("raw_ix", plt.LengthOfByteString(OVar("bs"))),
                    ),
                ],
                plt.IndexByteString(OVar("bs"), OVar("ix")),
            )

        raise NotImplementedError(
            f'Could not implement subscript "{node.slice}" of "{node.value}"'
        )

    def visit_Tuple(self, node: TypedTuple) -> plt.AST:
        """Compile a tuple literal from its compiled elements, in source order."""
        members = [self.visit(e) for e in node.elts]
        return plt.FunctionalTuple(*members)

    def visit_ClassDef(self, node: TypedClassDef) -> CallAST:
        """Compile a class definition by binding the class name to its delayed constructor."""

        def bind(continuation):
            binding = (node.name, plt.Delay(node.class_typ.constr()))
            return plt.Let([binding], continuation)

        return bind

    def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
        """Compile `value.attr` by applying the accessor supplied by the value's type to the compiled value."""
        owner_typ = node.value.typ
        assert isinstance(
            owner_typ, InstanceType
        ), "Can only access attributes of instances"
        obj = self.visit(node.value)
        accessor = owner_typ.attribute(node.attr)
        return plt.Apply(accessor, obj)

    def visit_Assert(self, node: TypedAssert) -> CallAST:
        """
        Compile an assert statement.

        If the test holds, execution proceeds with the continuation.
        Otherwise an error is raised; when a message is given it is traced
        as part of the failing branch.
        """

        def check(continuation):
            condition = self.visit(node.test)
            if node.msg is not None:
                payload = plt.Trace(self.visit(node.msg), plt.Unit())
            else:
                payload = plt.Unit()
            return plt.Ite(
                condition,
                continuation,
                plt.Apply(plt.Error(), payload),
            )

        return check

    def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
        """Return the expression carried by the node unchanged."""
        embedded = node.expr
        return embedded

    def visit_List(self, node: TypedList) -> plt.AST:
        """
        Compile a list literal.

        Elements are prepended starting from the last one, so the compiled
        list keeps source order. Elements of lists typed Any/Union are
        converted before insertion.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, ListType)
        el_typ = node.typ.typ.typ
        compiled = empty_list(el_typ)
        for item in node.elts[::-1]:
            element = self.visit(item)
            if isinstance(el_typ.typ, (AnyType, UnionType)):
                # the list holds generic data, so convert the element first
                element = transform_output_map(item.typ)(element)
            compiled = plt.MkCons(element, compiled)
        return compiled

    def visit_Dict(self, node: TypedDict) -> plt.AST:
        """
        Compile a dict literal into a list of data pairs.

        Every key and value is compiled and converted with
        `transform_output_map` for its own type. Pairs are prepended while
        walking the literal front to back, so the pair written last in the
        source ends up at the head of the compiled list.
        """
        assert isinstance(node.typ, InstanceType)
        assert isinstance(node.typ.typ, DictType)
        pairs = plt.EmptyDataPairList()
        for k, v in zip(node.keys, node.values):
            entry = plt.MkPairData(
                transform_output_map(k.typ)(
                    self.visit(k),
                ),
                transform_output_map(v.typ)(
                    self.visit(v),
                ),
            )
            pairs = plt.MkCons(entry, pairs)
        return pairs

    def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
        """
        Compile a conditional expression `body if test else orelse`.

        When the result is typed as a Union, each branch that is not itself
        a Union is converted with `transform_output_map`, so both branches
        share one representation. Both branches are compiled before the test.
        """

        def is_union(typ) -> bool:
            # Expression types wrap the actual type (the result check below
            # reads `node.typ.typ`), so a test on the wrapper alone never
            # recognises a Union; look through it.
            return isinstance(typ, UnionType) or (
                isinstance(typ, InstanceType) and isinstance(typ.typ, UnionType)
            )

        if isinstance(node.typ.typ, UnionType):
            body = self.visit(node.body)
            orelse = self.visit(node.orelse)
            if not is_union(node.body.typ):
                body = transform_output_map(node.body.typ)(body)
            if not is_union(node.orelse.typ):
                orelse = transform_output_map(node.orelse.typ)(orelse)
            return plt.Ite(self.visit(node.test), body, orelse)
        return plt.Ite(
            self.visit(node.test),
            self.visit(node.body),
            self.visit(node.orelse),
        )

    def visit_ListComp(self, node: TypedListComp) -> plt.AST:
        """
        Compile a list comprehension with exactly one generator over a list.

        All `if` clauses are combined with `And` into a single filter. Without
        any clause the comprehension is a plain map.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        gen = node.generators[0]
        assert isinstance(gen.iter.typ, InstanceType), "Only lists are valid generators"
        assert isinstance(gen.iter.typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            gen.target, Name
        ), "Can only assign value to singleton element"
        lst = self.visit(gen.iter)
        ifs = None
        for compiled_cond in [self.visit(cond) for cond in gen.ifs]:
            ifs = compiled_cond if ifs is None else plt.And(ifs, compiled_cond)

        def per_element(body):
            # lambda over one list element that exposes it under the target name
            return OLambda(
                ["x"],
                plt.Let([(gen.target.id, plt.Delay(OVar("x")))], body),
            )

        map_fun = per_element(self.visit(node.elt))
        empty_list_con = empty_list(node.elt.typ)
        if ifs is None:
            return plt.MapList(lst, map_fun, empty_list_con)
        return plt.MapFilterList(lst, per_element(ifs), map_fun, empty_list_con)

    def visit_DictComp(self, node: TypedDictComp) -> plt.AST:
        """
        Compile a dict comprehension with exactly one generator over a list.

        Each element is mapped to a data pair of the converted key and value.
        All `if` clauses are combined with `And` into a single filter.
        """
        assert len(node.generators) == 1, "Currently only one generator supported"
        gen = node.generators[0]
        assert isinstance(gen.iter.typ, InstanceType), "Only lists are valid generators"
        assert isinstance(gen.iter.typ.typ, ListType), "Only lists are valid generators"
        assert isinstance(
            gen.target, Name
        ), "Can only assign value to singleton element"
        lst = self.visit(gen.iter)
        ifs = None
        for compiled_cond in [self.visit(cond) for cond in gen.ifs]:
            ifs = compiled_cond if ifs is None else plt.And(ifs, compiled_cond)

        def per_element(body):
            # lambda over one list element that exposes it under the target name
            return OLambda(
                ["x"],
                plt.Let([(gen.target.id, plt.Delay(OVar("x")))], body),
            )

        map_fun = per_element(
            plt.MkPairData(
                transform_output_map(node.key.typ)(self.visit(node.key)),
                transform_output_map(node.value.typ)(self.visit(node.value)),
            )
        )
        empty_list_con = plt.EmptyDataPairList()
        if ifs is None:
            return plt.MapList(lst, map_fun, empty_list_con)
        return plt.MapFilterList(lst, per_element(ifs), map_fun, empty_list_con)

    def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST:
        """Compile one f-string placeholder by applying the `stringify` function of the value's type to the compiled value."""
        to_text = node.value.typ.stringify()
        return plt.Apply(to_text, self.visit(node.value))

    def visit_JoinedStr(self, node: TypedJoinedStr) -> plt.AST:
        """Compile an f-string by appending its parts, from the last to the first, in front of an empty text."""
        result = plt.Text("")
        for part in node.values[::-1]:
            result = plt.AppendString(self.visit(part), result)
        return result

    def generic_visit(self, node: TypedAST) -> plt.AST:
        """Fallback for nodes without a dedicated visitor: always raises NotImplementedError."""
        message = f"Can not compile {node}"
        raise NotImplementedError(message)

Ancestors

Class variables

var step

Methods

def generic_visit(self, node: TypedAST) ‑> pluthon.pluthon_ast.AST

Called if no explicit visitor function exists for a node.

Expand source code
def generic_visit(self, node: TypedAST) -> plt.AST:
    """Fallback for nodes without a dedicated visitor: always raises NotImplementedError."""
    message = f"Can not compile {node}"
    raise NotImplementedError(message)
def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_AnnAssign(self, node: TypedAnnAssign) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_AnnAssign(self, node: TypedAnnAssign) -> CallAST:
    """
    Compile an annotated assignment to a single name.

    The value is converted when it is typed Any/Union, and converted back
    when the target is typed Any/Union. Returns a function that wraps the
    compiled continuation in the resulting bindings.
    """
    target = node.target
    assert isinstance(
        target, Name
    ), "Assignments to other things then names are not supported"
    assert isinstance(
        target.typ, InstanceType
    ), "Can only assign instances to instances"
    val = self.visit(node.value)
    data_like = (AnyType, UnionType)
    value_typ = node.value.typ
    if isinstance(value_typ, InstanceType) and isinstance(value_typ.typ, data_like):
        # the value originates from PlutusData, map it to the target type
        val = transform_ext_params_map(target.typ)(val)
    if isinstance(target.typ, InstanceType) and isinstance(
        target.typ.typ, data_like
    ):
        # the target is treated as PlutusData, map the value back
        val = transform_output_map(node.value.typ)(val)

    def bind(continuation):
        name = node.target.id
        return plt.Let(
            [
                (opshin_name_scheme_compatible_varname(name), val),
                (name, plt.Delay(OVar(name))),
            ],
            continuation,
        )

    return bind
def visit_Assert(self, node: TypedAssert) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_Assert(self, node: TypedAssert) -> CallAST:
    """
    Compile an ``assert`` statement.

    The returned callable builds an if-then-else: when the test holds the
    continuation is used, otherwise an ``Error`` term is applied to either a
    trace of the assertion message or, if there is no message, to unit.

    :param node: the typed assert statement
    :return: a function mapping the continuation term to the compiled term
    """

    def guard(continuation):
        compiled_test = self.visit(node.test)
        error_term = plt.Error()
        if node.msg is not None:
            failure_argument = plt.Trace(self.visit(node.msg), plt.Unit())
        else:
            failure_argument = plt.Unit()
        on_failure = plt.Apply(error_term, failure_argument)
        return plt.Ite(compiled_test, continuation, on_failure)

    return guard
def visit_Assign(self, node: TypedAssign) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_Assign(self, node: TypedAssign) -> CallAST:
    """
    Compile a plain assignment ``name = value``.

    Only a single ``Name`` target is supported. If the target is flagged
    ``is_wrapped`` or is typed Any/Union, the compiled value is passed through
    ``transform_output_map`` first.

    :param node: the typed assignment
    :return: a function mapping the continuation term to the compiled term
    """
    targets = node.targets
    assert (
        len(targets) == 1
    ), "Assignments to more than one variable not supported yet"
    target = targets[0]
    assert isinstance(
        target, Name
    ), "Assignments to other things then names are not supported"
    compiled_value = self.visit(node.value)
    varname = target.id
    if getattr(target, "is_wrapped", False) or (
        isinstance(target.typ, InstanceType)
        and isinstance(target.typ.typ, (AnyType, UnionType))
    ):
        # wrapped variables and Union/Any variables are stored in their external representation
        # TODO this is terribly inefficient. we would rather want to cast once when entering the body and cast back when leaving
        compiled_value = transform_output_map(node.value.typ)(compiled_value)

    def bind(continuation):
        # first evaluate the term, then expose it under the variable name wrapped in a delay
        bindings = [
            (opshin_name_scheme_compatible_varname(varname), compiled_value),
            (varname, plt.Delay(OVar(varname))),
        ]
        return plt.Let(bindings, continuation)

    return bind
def visit_Attribute(self, node: TypedAttribute) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Attribute(self, node: TypedAttribute) -> plt.AST:
    """
    Compile an attribute access ``value.attr``.

    The accessor term is obtained from the type of the accessed object and
    applied to the compiled object.

    :param node: the typed attribute access
    :return: the compiled term
    """
    owner_typ = node.value.typ
    assert isinstance(
        owner_typ, InstanceType
    ), "Can only access attributes of instances"
    compiled_owner = self.visit(node.value)
    accessor = owner_typ.attribute(node.attr)
    return plt.Apply(accessor, compiled_owner)
def visit_BinOp(self, node: TypedBinOp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_BinOp(self, node: TypedBinOp) -> plt.AST:
    """
    Compile a binary operation.

    The implementation of the operator is requested from the type of the left
    operand and applied to both compiled operands.

    :param node: the typed binary operation
    :return: the compiled term
    """
    operator_impl = node.left.typ.binop(node.op, node.right)
    compiled_left = self.visit(node.left)
    compiled_right = self.visit(node.right)
    return plt.Apply(operator_impl, compiled_left, compiled_right)
def visit_BoolOp(self, node: TypedBoolOp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_BoolOp(self, node: TypedBoolOp) -> plt.AST:
    """
    Compile a boolean operation (``and`` / ``or``) over two or more operands.

    The operands are folded from left to right with the combinator registered
    in ``BoolOpMap`` for the operator type.

    :param node: the typed boolean operation
    :return: the compiled term
    """
    combine = BoolOpMap.get(type(node.op))
    assert len(node.values) >= 2, "Need to compare at least to values"
    first, second, *remaining = node.values
    folded = combine(self.visit(first), self.visit(second))
    for operand in remaining:
        folded = combine(folded, self.visit(operand))
    return folded
def visit_Call(self, node: TypedCall) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Call(self, node: TypedCall) -> plt.AST:
    """
    Compile a function call.

    The callee term is either a specialised implementation of a polymorphic
    builtin (chosen from the argument types) or the compiled function
    expression. Each argument is compiled, optionally converted with
    ``transform_output_map`` when the declared parameter type is Any/Union,
    and bound to a let-variable ``p<i>``. The callee is then applied to the
    optional bound ``self`` name, the callee's bound variables (sorted by name)
    and delayed references to the ``p<i>`` variables.

    :param node: the typed call expression
    :return: the compiled term
    """
    # compiled_args = " ".join(f"({self.visit(a)} {STATEMONAD})" for a in node.args)
    # return rf"(\{STATEMONAD} -> ({self.visit(node.func)} {compiled_args})"
    # TODO function is actually not of type polymorphic function type here anymore
    if isinstance(node.func.typ, PolymorphicFunctionInstanceType):
        # edge case for weird builtins that are polymorphic
        func_plt = force_params(
            node.func.typ.polymorphic_function.impl_from_args(
                node.func.typ.typ.argtyps
            )
        )
        bind_self = None
    else:
        assert isinstance(node.func.typ, InstanceType) and isinstance(
            node.func.typ.typ, FunctionType
        ), "Can only call instances of functions"
        func_plt = self.visit(node.func)
        bind_self = node.func.typ.typ.bind_self
    # names of variables the callee reads from its defining scope, in a stable (sorted) order
    bound_vs = sorted(list(node.func.typ.typ.bound_vars.keys()))
    args = []
    for i, (a, t) in enumerate(zip(node.args, node.func.typ.typ.argtyps)):
        # now impl_from_args has been chosen, skip type arg
        if (
            hasattr(node.func, "orig_id")
            and node.func.orig_id == "isinstance"
            and i == 1
        ):
            continue
        assert isinstance(t, InstanceType)
        # pass in all arguments evaluated with the statemonad
        a_int = self.visit(a)
        if isinstance(t.typ, AnyType) or isinstance(t.typ, UnionType):
            # if the function expects input of generic type data, wrap data before passing it inside
            a_int = transform_output_map(a.typ)(a_int)
        args.append(a_int)
    # First assign to let to ensure that the arguments are evaluated before the call, but need to delay
    # as this is a variable assignment
    # Also bring all states of variables read inside the function into scope / update with value in current state
    # before call to simulate statemonad with current state being passed in
    return OLet(
        [(f"p{i}", a) for i, a in enumerate(args)],
        SafeApply(
            func_plt,
            *([plt.Var(bind_self)] if bind_self is not None else []),
            *[plt.Var(n) for n in bound_vs],
            *[plt.Delay(OVar(f"p{i}")) for i in range(len(args))],
        ),
    )
def visit_ClassDef(self, node: TypedClassDef) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_ClassDef(self, node: TypedClassDef) -> CallAST:
    """
    Compile a class definition.

    The class name is bound to the delayed constructor term of the class type.

    :param node: the typed class definition
    :return: a function mapping the continuation term to the compiled term
    """

    def bind(continuation):
        bindings = [(node.name, plt.Delay(node.class_typ.constr()))]
        return plt.Let(bindings, continuation)

    return bind
def visit_Compare(self, node: TypedCompare) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Compare(self, node: TypedCompare) -> plt.AST:
    """
    Compile a comparison with exactly one operator and one comparator.

    The comparison implementation is requested from the type of the left
    operand and applied to both compiled operands.

    :param node: the typed comparison
    :return: the compiled term
    """
    assert len(node.ops) == 1, "Only single comparisons are supported"
    assert len(node.comparators) == 1, "Only single comparisons are supported"
    (cmp_operator,) = node.ops
    (right,) = node.comparators
    compare_impl = node.left.typ.cmp(cmp_operator, right.typ)
    compiled_left = self.visit(node.left)
    compiled_right = self.visit(right)
    return plt.Apply(compare_impl, compiled_left, compiled_right)
def visit_Constant(self, node: ast.Constant) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Constant(self, node: Constant) -> plt.AST:
    """
    Compile a constant into a UPLC constant.

    For non-empty byte constants whose content decodes to valid hexadecimal
    text, a warning is logged because the author probably meant to write
    ``bytes.fromhex(...)``.

    :param node: the constant node
    :return: the compiled constant term
    """
    value = node.value
    if isinstance(value, bytes) and value != b"":
        try:
            bytes.fromhex(value.decode())
            looks_like_hex = True
        except ValueError:
            looks_like_hex = False
        if looks_like_hex:
            OPSHIN_LOGGER.warning(
                f"The string {node.value} looks like it is supposed to be a hex-encoded bytestring but is actually utf8-encoded. Try using `bytes.fromhex('{node.value.decode()}')` instead."
            )
    return plt.UPLCConstant(rec_constant_map(value))
def visit_Dict(self, node: TypedDict) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Dict(self, node: TypedDict) -> plt.AST:
    """
    Compile a dict literal into a list of data pairs.

    Keys and values are compiled and converted with ``transform_output_map``.
    Every pair is prepended to the list built so far, so the resulting list
    holds the pairs in reverse source order.

    :param node: the typed dict literal
    :return: the compiled term
    """
    assert isinstance(node.typ, InstanceType)
    assert isinstance(node.typ.typ, DictType)
    pairs = plt.EmptyDataPairList()
    for key_node, value_node in zip(node.keys, node.values):
        key_data = transform_output_map(key_node.typ)(self.visit(key_node))
        value_data = transform_output_map(value_node.typ)(self.visit(value_node))
        pairs = plt.MkCons(plt.MkPairData(key_data, value_data), pairs)
    return pairs
def visit_DictComp(self, node: TypedDictComp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_DictComp(self, node: TypedDictComp) -> plt.AST:
    """
    Compile a dict comprehension with a single generator over a list.

    Every element of the list is mapped to a data pair of the converted key and
    value expressions. All ``if`` clauses of the generator are and-ed together
    and used as a filter.

    :param node: the typed dict comprehension
    :return: the compiled term
    """
    generators = node.generators
    assert len(generators) == 1, "Currently only one generator supported"
    comprehension = generators[0]
    iter_typ = comprehension.iter.typ
    assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
    assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
    assert isinstance(
        comprehension.target, Name
    ), "Can only assign value to singleton element"
    source_list = self.visit(comprehension.iter)

    condition = None
    for condition_node in comprehension.ifs:
        if condition is None:
            condition = self.visit(condition_node)
        else:
            condition = plt.And(condition, self.visit(condition_node))

    def per_element(body):
        # lambda over one list element that exposes it under the loop variable name
        return OLambda(
            ["x"],
            plt.Let([(comprehension.target.id, plt.Delay(OVar("x")))], body),
        )

    key_data = transform_output_map(node.key.typ)(self.visit(node.key))
    value_data = transform_output_map(node.value.typ)(self.visit(node.value))
    map_fun = per_element(plt.MkPairData(key_data, value_data))
    empty_pairs = plt.EmptyDataPairList()
    if condition is None:
        return plt.MapList(source_list, map_fun, empty_pairs)
    filter_fun = per_element(condition)
    return plt.MapFilterList(source_list, filter_fun, map_fun, empty_pairs)
def visit_Expr(self, node: TypedExpr) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_Expr(self, node: TypedExpr) -> CallAST:
    """
    Compile an expression statement.

    UPLC evaluates eagerly, so the expression is computed although its value
    is discarded afterwards (this really only makes sense for Trace).

    :param node: the typed expression statement
    :return: a function mapping the continuation term to the compiled term
    """

    def sequence(continuation):
        # "0" is an invalid identifier, which avoids conflicts with user variables
        discard_argument = OLambda(["0"], continuation)
        return plt.Apply(discard_argument, self.visit(node.value))

    return sequence
def visit_For(self, node: TypedFor) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_For(self, node: TypedFor) -> CallAST:
    """
    Compile a ``for`` loop over a list.

    The loop becomes a lambda named "for" that receives itself, the remaining
    list and all variables written inside the loop. On an empty list it
    continues with the code after the loop, otherwise it binds the head of the
    list to the loop variable, runs the body and calls itself with the tail.

    :param node: the typed for statement
    :return: a function mapping the continuation term to the compiled term
    :raises NotImplementedError: if the iterated value is not a list
    """
    if node.orelse:
        # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
        # the loop is compiled from a shallow copy without else-clause, followed by the else body
        cn = copy.copy(node)
        cn.orelse = []
        return self.visit_sequence([cn] + node.orelse)
    assert isinstance(node.iter.typ, InstanceType)
    if isinstance(node.iter.typ.typ, ListType):
        assert isinstance(
            node.target, Name
        ), "Can only assign value to singleton element"
        compiled_s = self.visit_sequence(node.body)
        compiled_iter = self.visit(node.iter)
        # variables written in the loop are threaded through every iteration as extra parameters
        written_vs = written_vars(node)
        pwritten_vs = [plt.Var(x) for x in written_vs]
        # s_fun builds the recursive loop lambda; x is the term to run once the list is exhausted
        s_fun = lambda x: plt.Lambda(
            [
                opshin_name_scheme_compatible_varname("for"),
                opshin_name_scheme_compatible_varname("iter"),
            ]
            + written_vs,
            plt.IteNullList(
                OVar("iter"),
                x,
                plt.Let(
                    [(node.target.id, plt.Delay(plt.HeadList(OVar("iter"))))],
                    compiled_s(
                        # recursive step: pass the loop itself, the list tail and the current variable values
                        plt.Apply(
                            OVar("for"),
                            OVar("for"),
                            plt.TailList(OVar("iter")),
                            *copy.deepcopy(pwritten_vs),
                        )
                    ),
                ),
            ),
        )
        # NOTE(review): visit_While and visit_If build "adjusted_next" with SafeLambda/SafeApply,
        # this method uses plt.Lambda/plt.Apply - confirm the difference is intended
        return lambda x: OLet(
            [
                ("adjusted_next", plt.Lambda([node.target.id] + written_vs, x)),
                (
                    "for",
                    s_fun(
                        plt.Apply(
                            OVar("adjusted_next"),
                            plt.Var(node.target.id),
                            *copy.deepcopy(pwritten_vs),
                        )
                    ),
                ),
            ],
            # initial call: the loop applied to itself, the full list and the current variable values
            plt.Apply(
                OVar("for"),
                OVar("for"),
                compiled_iter,
                *copy.deepcopy(pwritten_vs),
            ),
        )
    raise NotImplementedError(
        "Compilation of for statements for anything but lists not implemented yet"
    )
def visit_FormattedValue(self, node: TypedFormattedValue) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_FormattedValue(self, node: TypedFormattedValue) -> plt.AST:
    """
    Compile a formatted value inside an f-string.

    The stringification term of the value's type is applied to the compiled
    value.

    :param node: the typed formatted value
    :return: the compiled term
    """
    to_string = node.value.typ.stringify()
    compiled_value = self.visit(node.value)
    return plt.Apply(to_string, compiled_value)
def visit_FunctionDef(self, node: TypedFunctionDef) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_FunctionDef(self, node: TypedFunctionDef) -> CallAST:
    """
    Compile a function definition.

    The body is compiled with the function type pushed onto
    ``self.current_function_typ``. The function name is bound to a delayed
    lambda whose parameters are the optional bound ``self`` name, the bound
    variables (sorted by name) and the declared arguments.

    :param node: the typed function definition
    :return: a function mapping the continuation term to the compiled term
    """
    function_typ = node.typ.typ
    statements = list(node.body)
    # value produced when the body finishes without an explicit return statement
    fallthrough = OUnit if function_typ.rettyp.typ == AnyType() else plt.Unit()
    captured_names = sorted(function_typ.bound_vars.keys())
    if function_typ.bind_self is not None:
        captured_names = [function_typ.bind_self] + captured_names
    self.current_function_typ.append(function_typ)
    compiled_body = self.visit_sequence(statements)(fallthrough)
    self.current_function_typ.pop()

    def bind(continuation):
        parameters = captured_names + [argument.arg for argument in node.args.args]
        function_term = plt.Delay(SafeLambda(parameters, compiled_body))
        return plt.Let([(node.name, function_term)], continuation)

    return bind
def visit_If(self, node: TypedIf) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_If(self, node: TypedIf) -> CallAST:
    """
    Compile an ``if`` statement.

    The code following the statement is wrapped in a lambda named
    "adjusted_next" over all variables written inside the statement. Both
    branches end by calling it with the current values of those variables.

    :param node: the typed if statement
    :return: a function mapping the continuation term to the compiled term
    """
    assigned_names = written_vars(node)
    assigned_refs = [plt.Var(name) for name in assigned_names]

    def resume():
        # every use gets its own copy of the variable references
        return SafeApply(OVar("adjusted_next"), *copy.deepcopy(assigned_refs))

    def branch(continuation):
        next_binding = ("adjusted_next", SafeLambda(assigned_names, continuation))
        compiled_test = self.visit(node.test)
        then_term = self.visit_sequence(node.body)(resume())
        else_term = self.visit_sequence(node.orelse)(resume())
        return OLet([next_binding], plt.Ite(compiled_test, then_term, else_term))

    return branch
def visit_IfExp(self, node: TypedIfExp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_IfExp(self, node: TypedIfExp) -> plt.AST:
    """
    Compile a conditional expression ``body if test else orelse``.

    If the expression as a whole is typed as a Union, each branch that is not
    itself union-typed is converted with ``transform_output_map``, so that
    both branches produce the same representation.

    :param node: the typed conditional expression
    :return: the compiled if-then-else term
    """

    def is_union_typed(typ) -> bool:
        # Expression types are InstanceType wrappers, so the wrapped type has to be inspected.
        # The bare UnionType check is kept so that anything skipped before is still skipped.
        return isinstance(typ, UnionType) or (
            isinstance(typ, InstanceType) and isinstance(typ.typ, UnionType)
        )

    if isinstance(node.typ.typ, UnionType):
        body = self.visit(node.body)
        orelse = self.visit(node.orelse)
        if not is_union_typed(node.body.typ):
            body = transform_output_map(node.body.typ)(body)
        if not is_union_typed(node.orelse.typ):
            orelse = transform_output_map(node.orelse.typ)(orelse)
        return plt.Ite(self.visit(node.test), body, orelse)
    return plt.Ite(
        self.visit(node.test),
        self.visit(node.body),
        self.visit(node.orelse),
    )
def visit_JoinedStr(self, node: TypedJoinedStr) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_JoinedStr(self, node: TypedJoinedStr) -> plt.AST:
    """
    Compile an f-string.

    Starting from the empty text, the compiled parts are prepended from last
    to first, which yields a right-nested chain of string appends.

    :param node: the typed joined string
    :return: the compiled term
    """
    result = plt.Text("")
    for part in node.values[::-1]:
        result = plt.AppendString(self.visit(part), result)
    return result
def visit_List(self, node: TypedList) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_List(self, node: TypedList) -> plt.AST:
    """
    Compile a list literal.

    The elements are visited from last to first and prepended to an empty list
    of the element type, so the compiled list keeps the source order.

    :param node: the typed list literal
    :return: the compiled term
    """
    assert isinstance(node.typ, InstanceType)
    assert isinstance(node.typ.typ, ListType)
    element_typ = node.typ.typ.typ
    compiled_list = empty_list(element_typ)
    for element_node in reversed(node.elts):
        compiled_element = self.visit(element_node)
        if isinstance(element_typ.typ, (AnyType, UnionType)):
            # lists of generic data hold wrapped elements
            compiled_element = transform_output_map(element_node.typ)(
                compiled_element
            )
        compiled_list = plt.MkCons(compiled_element, compiled_list)
    return compiled_list
def visit_ListComp(self, node: TypedListComp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_ListComp(self, node: TypedListComp) -> plt.AST:
    """
    Compile a list comprehension with a single generator over a list.

    Every element of the source list is mapped with the compiled element
    expression. All ``if`` clauses of the generator are and-ed together and
    used as a filter.

    :param node: the typed list comprehension
    :return: the compiled term
    """
    generators = node.generators
    assert len(generators) == 1, "Currently only one generator supported"
    comprehension = generators[0]
    iter_typ = comprehension.iter.typ
    assert isinstance(iter_typ, InstanceType), "Only lists are valid generators"
    assert isinstance(iter_typ.typ, ListType), "Only lists are valid generators"
    assert isinstance(
        comprehension.target, Name
    ), "Can only assign value to singleton element"
    source_list = self.visit(comprehension.iter)

    condition = None
    for condition_node in comprehension.ifs:
        if condition is None:
            condition = self.visit(condition_node)
        else:
            condition = plt.And(condition, self.visit(condition_node))

    def per_element(body):
        # lambda over one list element that exposes it under the loop variable name
        return OLambda(
            ["x"],
            plt.Let([(comprehension.target.id, plt.Delay(OVar("x")))], body),
        )

    map_fun = per_element(self.visit(node.elt))
    empty_result = empty_list(node.elt.typ)
    if condition is None:
        return plt.MapList(source_list, map_fun, empty_result)
    filter_fun = per_element(condition)
    return plt.MapFilterList(source_list, filter_fun, map_fun, empty_result)
def visit_Module(self, node: TypedModule) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Module(self, node: TypedModule) -> plt.AST:
    """
    Compile a whole module into a pluthon program.

    If ``self.validator_function_name`` is set, the function with that original
    name is looked up, a call to it with the validator parameters is appended to
    the module body and the result is wrapped in a lambda over those
    parameters. Otherwise the module body is compiled as is.
    In both cases every variable of the module is pre-bound to a delayed
    "NameError" trace.

    :param node: the typed module
    :return: the compiled program (version 1.0.0)
    :raises RuntimeError: if three parameters are enforced but the double
        function wrapper can not be enabled
    """
    # extract actually read variables by each function
    if self.validator_function_name is not None:
        # for validators find main function
        # TODO can use more sophisiticated procedure here i.e. functions marked by comment
        # NOTE(review): annotated as InstanceType, but the value assigned below is a FunctionDef node
        main_fun: typing.Optional[InstanceType] = None
        for s in node.body:
            if (
                isinstance(s, ast.FunctionDef)
                and s.orig_name == self.validator_function_name
            ):
                # no break: if several definitions match, the last one is used
                main_fun = s
        assert (
            main_fun is not None
        ), f"Could not find function named {self.validator_function_name}"
        main_fun_typ: FunctionType = main_fun.typ.typ
        assert isinstance(
            main_fun_typ, FunctionType
        ), f"Variable named {self.validator_function_name} is not of type function"

        # check if this is a contract written to double function
        enable_double_func_mint_spend = False
        if len(main_fun_typ.argtyps) >= 3 and self.force_three_params:
            # check if is possible
            second_last_arg = main_fun_typ.argtyps[-2]
            assert isinstance(
                second_last_arg, InstanceType
            ), "Can not pass Class into validator"
            if isinstance(second_last_arg.typ, UnionType):
                possible_types = second_last_arg.typ.typs
            else:
                possible_types = [second_last_arg.typ]
            if any(isinstance(t, UnitType) for t in possible_types):
                OPSHIN_LOGGER.warning(
                    "The redeemer is annotated to be 'None'. This value is usually encoded in PlutusData with constructor id 0 and no fields. If you want the script to double function as minting and spending script, annotate the second argument with 'NoRedeemer'."
                )
            # double functioning is only enabled if no possible type is the unit type
            # or a record with constructor id 0
            enable_double_func_mint_spend = not any(
                (isinstance(t, RecordType) and t.record.constructor == 0)
                or isinstance(t, UnitType)
                for t in possible_types
            )
            if not enable_double_func_mint_spend:
                OPSHIN_LOGGER.warning(
                    "The second argument to the validator function potentially has constructor id 0. The validator will not be able to double function as minting script and spending script."
                )

        # append "return <validator>(val_param0, val_param1, ...)" to the module body,
        # converting every raw parameter with transform_ext_params_map for its declared type
        body = node.body + (
            [
                TypedReturn(
                    TypedCall(
                        func=ast.Name(
                            id=main_fun.name,
                            typ=InstanceType(main_fun_typ),
                            ctx=ast.Load(),
                        ),
                        typ=main_fun_typ.rettyp,
                        args=[
                            RawPlutoExpr(
                                expr=transform_ext_params_map(a)(
                                    OVar(f"val_param{i}")
                                ),
                                typ=a,
                            )
                            for i, a in enumerate(main_fun_typ.argtyps)
                        ],
                    )
                )
            ]
        )
        # the appended return statement is compiled in the context of a function returning Any
        self.current_function_typ.append(FunctionType([], InstanceType(AnyType())))
        name_load_visitor = NameLoadCollector()
        name_load_visitor.visit(node)
        all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

        # write all variables that are ever read
        # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
        validator = SafeOLambda(
            [f"val_param{i}" for i, _ in enumerate(main_fun_typ.argtyps)],
            plt.Let(
                [
                    (
                        x,
                        plt.Delay(
                            plt.TraceError(f"NameError: {map_to_orig_name(x)}")
                        ),
                    )
                    for x in all_vs
                ],
                self.visit_sequence(body)(OUnit),
            ),
        )
        self.current_function_typ.pop()
        if enable_double_func_mint_spend:
            # all parameters except the last three are passed through the wrapper
            validator = wrap_validator_double_function(
                validator, pass_through=len(main_fun_typ.argtyps) - 3
            )
        elif self.force_three_params:
            # Error if the double function is enforced but not possible
            raise RuntimeError(
                "The contract can not always detect if it was passed three or two parameters on-chain."
            )
    else:
        name_load_visitor = NameLoadCollector()
        name_load_visitor.visit(node)
        all_vs = sorted(set(all_vars(node)) | set(name_load_visitor.loaded.keys()))

        body = node.body
        # write all variables that are ever read
        # once at the beginning so that we can always access them (only potentially causing a nameerror at runtime)
        validator = plt.Let(
            [
                (
                    x,
                    plt.Delay(plt.TraceError(f"NameError: {map_to_orig_name(x)}")),
                )
                for x in all_vs
            ],
            self.visit_sequence(body)(OUnit),
        )

    cp = plt.Program((1, 0, 0), validator)
    return cp
def visit_Name(self, node: ast.Name) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Name(self, node: Name) -> plt.AST:
    """
    Compile a name that is read.

    Class names compile to the constructor of the class. All other names
    compile to the forced variable, converted with ``transform_ext_params_map``
    when the node is flagged ``is_wrapped``.

    :param node: the name node
    :return: the compiled term
    :raises NotImplementedError: if the name is not used in a load context
    """
    if not isinstance(node.ctx, Load):
        raise NotImplementedError(f"Context {node.ctx} not supported")
    name_typ = node.typ
    if isinstance(name_typ, ClassType):
        # this is not an instance but a class, so call the constructor
        return name_typ.constr()
    forced_variable = plt.Force(plt.Var(node.id))
    if getattr(node, "is_wrapped", False):
        return transform_ext_params_map(node.typ)(forced_variable)
    return forced_variable
def visit_NoneType(self, _: Any | None) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_NoneType(self, _: typing.Optional[typing.Any]) -> plt.AST:
    """
    Compile ``None`` to the unit value.

    :param _: ignored
    :return: the unit term
    """
    unit_term = plt.Unit()
    return unit_term
def visit_RawPlutoExpr(self, node: RawPlutoExpr) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_RawPlutoExpr(self, node: RawPlutoExpr) -> plt.AST:
    """
    Return the already compiled term carried by the node unchanged.

    :param node: a node wrapping a ready-made term in its ``expr`` attribute
    :return: ``node.expr``
    """
    embedded_term = node.expr
    return embedded_term
def visit_Return(self, node: TypedReturn) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_Return(self, node: TypedReturn) -> CallAST:
    """
    Compile a ``return`` statement.

    If the enclosing function is declared to return Any or a Union, the
    compiled value is converted with ``transform_output_map``. The returned
    callable ignores its continuation.

    :param node: the typed return statement
    :return: a function that discards the continuation and yields the value
    """
    compiled_value = self.visit(node.value)
    assert self.current_function_typ, "Can not handle Return outside of a function"
    declared_return = self.current_function_typ[-1].rettyp.typ
    if isinstance(declared_return, (AnyType, UnionType)):
        compiled_value = transform_output_map(node.value.typ)(compiled_value)

    def finish(_):
        # statements after a return are dropped
        return compiled_value

    return finish
def visit_Subscript(self, node: TypedSubscript) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Subscript(self, node: TypedSubscript) -> plt.AST:
    """
    Compile a subscript expression ``value[slice]``.

    Supported combinations:

    - tuples and pairs with a constant integer index,
    - lists with an integer index or a slice with both bounds,
    - dicts with a key,
    - byte strings with an integer index or a slice.

    Negative indices and slice bounds are shifted by the length of the
    indexed value, either at compile time (tuples, pairs) or in the generated
    code.

    :param node: the typed subscript expression
    :return: the compiled term
    :raises NotImplementedError: for every other combination
    """
    assert isinstance(
        node.value.typ, InstanceType
    ), "Can only access elements of instances, not classes"
    if isinstance(node.value.typ.typ, TupleType):
        assert isinstance(
            node.slice, Constant
        ), "Only constant index access for tuples is supported"
        assert isinstance(
            node.slice.value, int
        ), "Only constant index integer access for tuples is supported"
        index = node.slice.value
        if index < 0:
            index += len(node.value.typ.typ.typs)
        # NOTE(review): unlike the pair branch below, the index is not range checked here
        assert isinstance(node.ctx, Load), "Tuples are read-only"
        return plt.FunctionalTupleAccess(
            self.visit(node.value),
            index,
            len(node.value.typ.typ.typs),
        )
    if isinstance(node.value.typ.typ, PairType):
        assert isinstance(
            node.slice, Constant
        ), "Only constant index access for pairs is supported"
        assert isinstance(
            node.slice.value, int
        ), "Only constant index integer access for pairs is supported"
        index = node.slice.value
        if index < 0:
            index += 2
        assert isinstance(node.ctx, Load), "Pairs are read-only"
        assert (
            0 <= index < 2
        ), f"Pairs only have 2 elements, index should be 0 or 1, is {node.slice.value}"
        member_func = plt.FstPair if index == 0 else plt.SndPair
        # the content of pairs is always Data, so we need to unwrap
        member_typ = node.typ
        return transform_ext_params_map(member_typ)(
            member_func(
                self.visit(node.value),
            ),
        )
    if isinstance(node.value.typ.typ, ListType):
        if not isinstance(node.slice, Slice):
            assert (
                node.slice.typ == IntegerInstanceType
            ), "Only single element list index access supported"
            if isinstance(node.slice, Constant) and node.slice.value >= 0:
                # non-negative constant index: no length adjustment needed
                index = node.slice.value
                return plt.ConstantIndexAccessListFast(
                    self.visit(node.value),
                    index,
                )
            # general case: bind list and index once, shift negative indices by the list length
            return OLet(
                [
                    (
                        "l",
                        self.visit(node.value),
                    ),
                    (
                        "raw_i",
                        self.visit(node.slice),
                    ),
                    (
                        "i",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_i"), plt.LengthList(OVar("l"))
                            ),
                            OVar("raw_i"),
                        ),
                    ),
                ],
                (
                    plt.IndexAccessListFast(self.config.fast_access_skip)(
                        OVar("l"), OVar("i")
                    )
                    if self.config.fast_access_skip is not None
                    else plt.IndexAccessList(OVar("l"), OVar("i"))
                ),
            )
        else:
            assert (
                node.slice.upper is not None
            ), "Only slices with upper bound supported"
            assert (
                node.slice.lower is not None
            ), "Only slices with lower bound supported"
            # i / j are the bounds after shifting negative values by the list length,
            # drop is the lower bound clamped to 0, take is the number of elements to keep
            return OLet(
                [
                    (
                        "xs",
                        self.visit(node.value),
                    ),
                    (
                        "raw_i",
                        self.visit(node.slice.lower),
                    ),
                    (
                        "i",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_i"),
                                plt.LengthList(OVar("xs")),
                            ),
                            OVar("raw_i"),
                        ),
                    ),
                    (
                        "raw_j",
                        self.visit(node.slice.upper),
                    ),
                    (
                        "j",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_j"),
                                plt.LengthList(OVar("xs")),
                            ),
                            OVar("raw_j"),
                        ),
                    ),
                    (
                        "drop",
                        plt.Ite(
                            plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
                            plt.Integer(0),
                            OVar("i"),
                        ),
                    ),
                    (
                        "take",
                        plt.SubtractInteger(OVar("j"), OVar("drop")),
                    ),
                ],
                # an upper bound at or below the lower bound yields the empty list
                plt.Ite(
                    plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                    empty_list(node.value.typ.typ.typ),
                    plt.SliceList(
                        OVar("drop"),
                        OVar("take"),
                        OVar("xs"),
                        empty_list(node.value.typ.typ.typ),
                    ),
                ),
            )
    elif isinstance(node.value.typ.typ, DictType):
        dict_typ = node.value.typ.typ
        if not isinstance(node.slice, Slice):
            # the key is converted with transform_output_map and compared against the first
            # component of each pair; the second component of the first match is converted back
            # slices of dicts fall through to the NotImplementedError at the end
            return OLet(
                [
                    (
                        "key",
                        transform_output_map(node.slice.typ)(
                            self.visit(node.slice),
                        ),
                    )
                ],
                transform_ext_params_map(dict_typ.value_typ)(
                    plt.SndPair(
                        plt.FindList(
                            self.visit(node.value),
                            OLambda(
                                ["x"],
                                plt.EqualsData(
                                    OVar("key"),
                                    plt.FstPair(OVar("x")),
                                ),
                            ),
                            plt.TraceError("KeyError"),
                        )
                    ),
                ),
            )
    elif isinstance(node.value.typ.typ, ByteStringType):
        if not isinstance(node.slice, Slice):
            if isinstance(node.slice, Constant) and node.slice.value >= 0:
                return plt.IndexByteString(
                    self.visit(node.value),
                    self.visit(node.slice),
                )
            elif isinstance(node.slice, Constant) and node.slice.value < 0:
                # NOTE(review): node.value is compiled twice here, so the byte string
                # expression appears twice in the generated term
                return plt.IndexByteString(
                    self.visit(node.value),
                    plt.AddInteger(
                        self.visit(node.slice),
                        plt.LengthOfByteString(self.visit(node.value)),
                    ),
                )
            # non-constant index: bind byte string and index once, shift negative indices by the length
            return OLet(
                [
                    (
                        "bs",
                        self.visit(node.value),
                    ),
                    (
                        "raw_ix",
                        self.visit(node.slice),
                    ),
                    (
                        "ix",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_ix"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_ix"),
                                plt.LengthOfByteString(OVar("bs")),
                            ),
                            OVar("raw_ix"),
                        ),
                    ),
                ],
                plt.IndexByteString(OVar("bs"), OVar("ix")),
            )
        elif isinstance(node.slice, Slice):
            # NOTE(review): unlike the list slice branch, the bounds are not asserted to be
            # present here - presumably missing bounds are filled in earlier; confirm
            return OLet(
                [
                    (
                        "bs",
                        self.visit(node.value),
                    ),
                    (
                        "raw_i",
                        self.visit(node.slice.lower),
                    ),
                    (
                        "i",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_i"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_i"),
                                plt.LengthOfByteString(OVar("bs")),
                            ),
                            OVar("raw_i"),
                        ),
                    ),
                    (
                        "raw_j",
                        self.visit(node.slice.upper),
                    ),
                    (
                        "j",
                        plt.Ite(
                            plt.LessThanInteger(OVar("raw_j"), plt.Integer(0)),
                            plt.AddInteger(
                                OVar("raw_j"),
                                plt.LengthOfByteString(OVar("bs")),
                            ),
                            OVar("raw_j"),
                        ),
                    ),
                    (
                        "drop",
                        plt.Ite(
                            plt.LessThanEqualsInteger(OVar("i"), plt.Integer(0)),
                            plt.Integer(0),
                            OVar("i"),
                        ),
                    ),
                    (
                        "take",
                        plt.SubtractInteger(OVar("j"), OVar("drop")),
                    ),
                ],
                # an upper bound at or below the lower bound yields the empty byte string
                plt.Ite(
                    plt.LessThanEqualsInteger(OVar("j"), OVar("i")),
                    plt.ByteString(b""),
                    plt.SliceByteString(
                        OVar("drop"),
                        OVar("take"),
                        OVar("bs"),
                    ),
                ),
            )
    raise NotImplementedError(
        f'Could not implement subscript "{node.slice}" of "{node.value}"'
    )
def visit_Tuple(self, node: TypedTuple) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_Tuple(self, node: TypedTuple) -> plt.AST:
    """Compile a tuple expression.

    Every entry of ``node.elts`` is compiled with ``self.visit`` in source
    order, and the compiled entries are passed positionally to
    ``plt.FunctionalTuple``.
    """
    compiled_elements = [self.visit(element) for element in node.elts]
    return plt.FunctionalTuple(*compiled_elements)
def visit_UnaryOp(self, node: TypedUnaryOp) ‑> pluthon.pluthon_ast.AST
Expand source code
def visit_UnaryOp(self, node: TypedUnaryOp) -> plt.AST:
    """Compile a unary operation.

    The operand's type supplies the operator implementation through
    ``typ.unop(node.op)``; the result applies that implementation to the
    compiled operand.
    """
    operand = node.operand
    # Resolve the operator first, then compile the operand (same order as
    # the evaluation of the arguments in a single call expression).
    operator_impl = operand.typ.unop(node.op)
    compiled_operand = self.visit(operand)
    return plt.Apply(operator_impl, compiled_operand)
def visit_While(self, node: TypedWhile) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_While(self, node: TypedWhile) -> CallAST:
    """Compile a ``while`` statement.

    The loop is compiled to a lambda that receives itself as first argument,
    followed by the names reported by ``written_vars(node)``. Each iteration
    re-invokes the loop with the current values of those names, which is how
    overwritten names are carried from one iteration to the next.

    Returns a function that takes the compiled continuation (the code that
    follows the loop) and produces the AST for "loop, then continuation".
    """
    if node.orelse:
        # If there is orelse, transform it to an appended sequence (TODO check if this is correct)
        loop_only = copy.copy(node)
        loop_only.orelse = []
        return self.visit_sequence([loop_only] + node.orelse)

    # Compile the condition before the body (same order as the source).
    condition = self.visit(node.test)
    body = self.visit_sequence(node.body)
    loop_vars = written_vars(node)
    loop_var_refs = [plt.Var(name) for name in loop_vars]

    def fresh_refs():
        # Every use site gets its own copy of the variable reference nodes.
        return copy.deepcopy(loop_var_refs)

    def build_loop(on_exit):
        # Loop lambda: while the condition holds, run the body and then
        # call the loop again; otherwise evaluate ``on_exit``.
        recurse = plt.Apply(OVar("while"), OVar("while"), *fresh_refs())
        return plt.Lambda(
            [opshin_name_scheme_compatible_varname("while")] + loop_vars,
            plt.Ite(condition, body(recurse), on_exit),
        )

    def compile_with_continuation(x):
        # "adjusted_next" wraps the continuation in a lambda over the loop
        # variables, so it can be applied to their values on loop exit.
        exit_call = SafeApply(OVar("adjusted_next"), *fresh_refs())
        bindings = [
            ("adjusted_next", SafeLambda(loop_vars, x)),
            ("while", build_loop(exit_call)),
        ]
        # Start the loop by handing it itself plus the current values.
        entry = plt.Apply(OVar("while"), OVar("while"), *fresh_refs())
        return OLet(bindings, entry)

    return compile_with_continuation
def visit_sequence(self, node_seq: List[typedstmt]) ‑> Callable[[pluthon.pluthon_ast.AST], pluthon.pluthon_ast.AST]
Expand source code
def visit_sequence(self, node_seq: typing.List[typedstmt]) -> CallAST:
    """Compile a list of statements into one continuation-taking function.

    Nothing is compiled until the returned function is called. When it is
    called with the code that should run after the sequence, the statements
    are visited from last to first and each compiled statement wraps what
    has been built so far, so the first statement ends up outermost.
    """

    def compile_onto(continuation: plt.AST):
        wrapped = continuation
        # Walk back to front: the last statement wraps the continuation,
        # the one before it wraps that result, and so on.
        for stmt in node_seq[::-1]:
            wrap_stmt = self.visit(stmt)
            wrapped = wrap_stmt(wrapped)
        return wrapped

    return compile_onto