Module opshin.optimize.optimize_const_folding

Expand source code
import typing
from collections import defaultdict
import logging

from ast import *
from ordered_set import OrderedSet

from pycardano import PlutusData

try:
    unparse
except NameError:
    from astunparse import unparse

from ..util import CompilingNodeTransformer, CompilingNodeVisitor, OPSHIN_LOGGER
from ..type_inference import INITIAL_SCOPE

"""
Pre-evaluates constant statements
"""

# Atomic Python types that an evaluated expression may have in order to be
# emitted as a folded constant (containers are handled separately by the folder).
ACCEPTED_ATOMIC_TYPES = [int, str, bytes, type(None), bool]

# Python builtins that are handed to the compile-time evaluator, in
# alphabetical order (one row per initial letter).
# NOTE(review): several entries (compile, delattr, input, open, setattr) have
# side effects when called at compile time -- confirm they are intended here.
SAFE_GLOBALS_LIST = [
    abs, all, any, ascii,
    bin, bool, bytes, bytearray,
    callable, chr, classmethod, compile, complex,
    delattr, dict, dir, divmod,
    enumerate,
    filter, float, format, frozenset,
    getattr,
    hasattr, hash, hex,
    id, input, int, isinstance, issubclass, iter,
    len, list,
    map, max, min,
    next,
    object, oct, open, ord,
    pow, print, property,
    range, repr, reversed, round,
    set, setattr, slice, sorted, staticmethod, str, sum, super,
    tuple, type,
    vars,
    zip,
]
# Name -> builtin lookup table, preserving the order of the list above.
SAFE_GLOBALS = dict((builtin.__name__, builtin) for builtin in SAFE_GLOBALS_LIST)


class ShallowNameDefCollector(CompilingNodeVisitor):
    """Collect the names bound directly by the visited statements.

    Names written through a ``Name`` node in ``Store`` context are recorded,
    as are the names of class and function definitions.  The bodies of
    classes and functions are not traversed, so names bound inside them are
    not reported.  The result is available as ``self.vars``.
    """

    step = "Collecting occuring variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def _remember(self, name) -> None:
        """Record ``name`` as bound in the inspected scope."""
        self.vars.add(name)

    def visit_Name(self, node: Name) -> None:
        # only writes bind a name; loads and deletes are irrelevant here
        if not isinstance(node.ctx, Store):
            return
        self._remember(node.id)

    def visit_ClassDef(self, node: ClassDef):
        # the class body (i.e. attribute names) is deliberately not visited
        self._remember(node.name)

    def visit_FunctionDef(self, node: FunctionDef):
        # the function body belongs to a nested scope and is not visited
        self._remember(node.name)


class DefinedTimesVisitor(CompilingNodeVisitor):
    """Count how many times each name is bound in the visited tree.

    Counts are kept in ``self.vars``, a ``defaultdict(int)``.  Keys are name
    strings for assignments, class definitions, function definitions and
    function parameters.  For import statements the key is the entry of
    ``node.names`` itself (not a string); OptimizeConstantFolding looks
    imports up with those same entries.

    Everything nested in a ``for``, ``while`` or ``if`` statement is traversed
    twice, so a binding made there is never counted exactly once.
    """

    step = "Collecting how often variables are written"

    def __init__(self):
        self.vars = defaultdict(int)

    def _visit_children_twice(self, node) -> None:
        """Traverse the children of ``node`` two times in a row."""
        # TODO future items: use this together with guaranteed available
        for _ in range(2):
            self.generic_visit(node)

    def _count_imported(self, node) -> None:
        """Bump the counter of every entry listed by an import statement."""
        for imported in node.names:
            self.vars[imported] += 1

    def visit_For(self, node: For) -> None:
        # all names bound in the loop end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_While(self, node: While) -> None:
        # all names bound in the loop end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_If(self, node: If) -> None:
        # all names bound in a branch end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_Name(self, node: Name) -> None:
        if not isinstance(node.ctx, Store):
            return
        self.vars[node.id] += 1

    def visit_ClassDef(self, node: ClassDef):
        # the class body (i.e. attribute names) is deliberately not visited
        self.vars[node.name] += 1

    def visit_FunctionDef(self, node: FunctionDef):
        self.vars[node.name] += 1
        # parameters are counted twice, they are generally assigned more than once
        for parameter in node.args.args:
            self.vars[parameter.arg] += 2
        self.generic_visit(node)

    def visit_Import(self, node: Import):
        self._count_imported(node)

    def visit_ImportFrom(self, node: ImportFrom):
        self._count_imported(node)


class OptimizeConstantFolding(CompilingNodeTransformer):
    """Replace expressions that can be evaluated at compile time by constants.

    Two parallel stacks are maintained while walking the tree, one entry per
    scope: the names bound in that scope (used to detect shadowed builtins)
    and the values of the names known to be constant.  Every expression is
    unparsed and evaluated against those values; when the evaluation succeeds
    and yields a supported value, the expression is replaced by a ``Constant``
    node carrying that value.

    NOTE(review): ``exec``/``eval`` are run on the source being compiled, and
    Python inserts ``__builtins__`` into a globals dict that lacks it, so
    SAFE_GLOBALS does not act as a sandbox.  Only run this on trusted source.
    """

    step = "Constant folding"

    def __init__(self):
        # bottom scope: every initially defined name that is not a safe builtin
        self.scopes_visible = [
            OrderedSet(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys())
        ]
        # per scope: name -> compile-time value
        self.scopes_constants = [dict()]
        # names that are bound exactly once in the module (set by visit_Module)
        self.constants = OrderedSet()

    def enter_scope(self):
        """Push a fresh, empty scope onto both stacks."""
        self.scopes_visible.append(OrderedSet())
        self.scopes_constants.append(dict())

    def add_var_visible(self, var: str):
        """Mark ``var`` as bound in the innermost scope."""
        self.scopes_visible[-1].add(var)

    def add_vars_visible(self, var: typing.Iterable[str]):
        """Mark every name in ``var`` as bound in the innermost scope."""
        self.scopes_visible[-1].update(var)

    def add_constant(self, var: str, value: typing.Any):
        """Register ``value`` as the constant value of ``var`` in the innermost scope."""
        self.scopes_constants[-1][var] = value

    def visible_vars(self):
        """Return the union of the names bound in all currently open scopes."""
        res_set = OrderedSet()
        for s in self.scopes_visible:
            res_set.update(s)
        return res_set

    def _constant_vars(self):
        """Return all known constants; inner scopes override outer ones."""
        res_d = {}
        for s in self.scopes_constants:
            res_d.update(s)
        return res_d

    def exit_scope(self):
        """Pop the innermost scope from both stacks."""
        self.scopes_visible.pop(-1)
        self.scopes_constants.pop(-1)

    def _non_overwritten_globals(self):
        """Return SAFE_GLOBALS with every shadowed builtin replaced by a stub.

        The stub raises ``ValueError`` when called, which makes the evaluation
        of any expression calling a shadowed builtin fail (and hence keeps the
        expression unfolded).
        """
        overwritten_vars = self.visible_vars()

        def err(*args, **kwargs):
            # accepts any call signature so that the failure is always the
            # intended ValueError instead of an incidental TypeError
            raise ValueError("Was overwritten!")

        return {
            k: (err if k in overwritten_vars else v)
            for k, v in SAFE_GLOBALS.items()
        }

    def _exec_globals(self):
        """Return the globals for ``exec``: usable builtins plus known constants."""
        g = self._non_overwritten_globals()
        g.update(self._constant_vars())
        return g

    @classmethod
    def _is_foldable(cls, value) -> bool:
        """Tell whether ``value`` may be stored in a folded ``Constant`` node.

        Atomic values must be instances of ACCEPTED_ATOMIC_TYPES or PlutusData.
        Lists and dicts must be non-empty and consist only of foldable values
        themselves; this keeps e.g. tuples, floats or functions from being
        smuggled into a constant inside a container.
        """
        if isinstance(value, list):
            return bool(value) and all(cls._is_foldable(v) for v in value)
        if isinstance(value, dict):
            return bool(value) and all(
                cls._is_foldable(k) and cls._is_foldable(v)
                for k, v in value.items()
            )
        return isinstance(value, tuple(ACCEPTED_ATOMIC_TYPES) + (PlutusData,))

    def update_constants(self, node):
        """Execute the statement ``node`` and record the names it binds.

        Failures are only logged at debug level: a statement that cannot be
        executed at compile time simply contributes no constants.
        """
        g = self._exec_globals()
        l = {}
        try:
            exec(unparse(node), g, l)
        except Exception as e:
            OPSHIN_LOGGER.debug(e)
        else:
            # everything the statement bound ended up in the local dict
            self.scopes_constants[-1].update(l)

    def visit_Module(self, node: Module) -> Module:
        """Collect module-level names and single-assignment names, then fold."""
        self.enter_scope()
        def_vars_collector = ShallowNameDefCollector()
        def_vars_collector.visit(node)
        self.add_vars_visible(def_vars_collector.vars)

        constant_collector = DefinedTimesVisitor()
        constant_collector.visit(node)
        # if it is only assigned exactly once, it must be a constant (due to immutability)
        self.constants = {
            c for c, i in constant_collector.vars.items() if i == 1
        }

        res = self.generic_visit(node)
        self.exit_scope()
        return res

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Register the function as a constant if possible and fold its body."""
        self.add_var_visible(node.name)
        if node.name in self.constants:
            g = self._exec_globals()
            try:
                # we need to pass the global dict as local dict here to make closures possible (rec functions)
                exec(unparse(node), g, g)
            except Exception as e:
                OPSHIN_LOGGER.debug(e)
            else:
                # the function is defined and was added to the globals
                self.scopes_constants[-1][node.name] = g[node.name]

        self.enter_scope()
        self.add_vars_visible(arg.arg for arg in node.args.args)
        def_vars_collector = ShallowNameDefCollector()
        for s in node.body:
            def_vars_collector.visit(s)
        self.add_vars_visible(def_vars_collector.vars)

        res_node = self.generic_visit(node)
        self.exit_scope()
        return res_node

    def visit_ClassDef(self, node: ClassDef):
        """Register a class that is defined once; the class body is not folded."""
        if node.name in self.constants:
            self.update_constants(node)
        return node

    def visit_ImportFrom(self, node: ImportFrom):
        """Execute the import if every imported entry is counted exactly once."""
        # the lookup uses the entries of node.names themselves, which is how
        # DefinedTimesVisitor keys imports
        if all(n in self.constants for n in node.names):
            self.update_constants(node)
        return node

    def visit_Import(self, node: Import):
        """Execute the import if every imported entry is counted exactly once."""
        # the lookup uses the entries of node.names themselves, which is how
        # DefinedTimesVisitor keys imports
        if all(n in self.constants for n in node.names):
            self.update_constants(node)
        return node

    def visit_Assign(self, node: Assign):
        """Record single-name constant assignments and fold the assigned value.

        Assignments with several targets or a non-name target are returned
        untouched.
        """
        if len(node.targets) != 1:
            return node
        target = node.targets[0]
        if not isinstance(target, Name):
            return node

        if target.id in self.constants:
            self.update_constants(node)
        node.value = self.visit(node.value)
        return node

    def visit_AnnAssign(self, node: AnnAssign):
        """Record constant annotated assignments and fold the assigned value.

        Assignments to a non-name target are returned untouched.
        """
        target = node.target
        if not isinstance(target, Name):
            return node

        if target.id in self.constants:
            self.update_constants(node)
        # a bare annotation ("x: int") has no value, there is nothing to fold
        if node.value is not None:
            node.value = self.visit(node.value)
        return node

    def generic_visit(self, node: AST):
        """Fold the children of ``node`` and then, if possible, ``node`` itself.

        Only expressions are evaluated.  An expression is left as it is when
        it cannot be unparsed, contains a ``print(`` call, fails to evaluate,
        or evaluates to a value that is not foldable.
        """
        node = super().generic_visit(node)
        if not isinstance(node, expr):
            # only evaluate expressions, not statements
            return node
        if isinstance(node, Constant):
            # prevents unnecessary computations
            return node
        try:
            node_source = unparse(node)
        except Exception as e:
            OPSHIN_LOGGER.debug("Error when trying to unparse node: %s", e)
            return node
        if "print(" in node_source:
            # do not optimize away print statements
            return node
        try:
            # we add preceding constant plutusdata definitions here!
            g = self._non_overwritten_globals()
            l = self._constant_vars()
            node_eval = eval(node_source, g, l)
        except Exception as e:
            OPSHIN_LOGGER.debug("Error trying to evaluate node: %s", e)
            return node

        if not self._is_foldable(node_eval):
            return node
        new_node = Constant(node_eval, None)
        copy_location(new_node, node)
        return new_node

Classes

class DefinedTimesVisitor

A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the visit method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are 'visit_' + class name of the node. So a TryFinally node visit function would be visit_TryFinally. This behavior can be changed by overriding the visit method. If no visitor function exists for a node (return value None) the generic_visit visitor is used instead.

Don't use the NodeVisitor if you want to apply changes to nodes during traversing. For this a special visitor exists (NodeTransformer) that allows modifications.

Expand source code
class DefinedTimesVisitor(CompilingNodeVisitor):
    """Count how many times each name is bound in the visited tree.

    Counts are kept in ``self.vars``, a ``defaultdict(int)``.  Keys are name
    strings for assignments, class definitions, function definitions and
    function parameters.  For import statements the key is the entry of
    ``node.names`` itself (not a string); OptimizeConstantFolding looks
    imports up with those same entries.

    Everything nested in a ``for``, ``while`` or ``if`` statement is traversed
    twice, so a binding made there is never counted exactly once.
    """

    step = "Collecting how often variables are written"

    def __init__(self):
        self.vars = defaultdict(int)

    def _visit_children_twice(self, node) -> None:
        """Traverse the children of ``node`` two times in a row."""
        # TODO future items: use this together with guaranteed available
        for _ in range(2):
            self.generic_visit(node)

    def _count_imported(self, node) -> None:
        """Bump the counter of every entry listed by an import statement."""
        for imported in node.names:
            self.vars[imported] += 1

    def visit_For(self, node: For) -> None:
        # all names bound in the loop end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_While(self, node: While) -> None:
        # all names bound in the loop end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_If(self, node: If) -> None:
        # all names bound in a branch end up with at least 2 assignments
        self._visit_children_twice(node)

    def visit_Name(self, node: Name) -> None:
        if not isinstance(node.ctx, Store):
            return
        self.vars[node.id] += 1

    def visit_ClassDef(self, node: ClassDef):
        # the class body (i.e. attribute names) is deliberately not visited
        self.vars[node.name] += 1

    def visit_FunctionDef(self, node: FunctionDef):
        self.vars[node.name] += 1
        # parameters are counted twice, they are generally assigned more than once
        for parameter in node.args.args:
            self.vars[parameter.arg] += 2
        self.generic_visit(node)

    def visit_Import(self, node: Import):
        self._count_imported(node)

    def visit_ImportFrom(self, node: ImportFrom):
        self._count_imported(node)

Ancestors

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeVisitor.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef)
def visit_For(self, node: ast.For) ‑> None
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_If(self, node: ast.If) ‑> None
def visit_Import(self, node: ast.Import)
def visit_ImportFrom(self, node: ast.ImportFrom)
def visit_Name(self, node: ast.Name) ‑> None
def visit_While(self, node: ast.While) ‑> None
class OptimizeConstantFolding

A NodeVisitor subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']:

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the generic_visit method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

node = YourTransformer().visit(node)

Expand source code
class OptimizeConstantFolding(CompilingNodeTransformer):
    """Replace expressions that can be evaluated at compile time by constants.

    Two parallel stacks are maintained while walking the tree, one entry per
    scope: the names bound in that scope (used to detect shadowed builtins)
    and the values of the names known to be constant.  Every expression is
    unparsed and evaluated against those values; when the evaluation succeeds
    and yields a supported value, the expression is replaced by a ``Constant``
    node carrying that value.

    NOTE(review): ``exec``/``eval`` are run on the source being compiled, and
    Python inserts ``__builtins__`` into a globals dict that lacks it, so
    SAFE_GLOBALS does not act as a sandbox.  Only run this on trusted source.
    """

    step = "Constant folding"

    def __init__(self):
        # bottom scope: every initially defined name that is not a safe builtin
        self.scopes_visible = [
            OrderedSet(INITIAL_SCOPE.keys()).difference(SAFE_GLOBALS.keys())
        ]
        # per scope: name -> compile-time value
        self.scopes_constants = [dict()]
        # names that are bound exactly once in the module (set by visit_Module)
        self.constants = OrderedSet()

    def enter_scope(self):
        """Push a fresh, empty scope onto both stacks."""
        self.scopes_visible.append(OrderedSet())
        self.scopes_constants.append(dict())

    def add_var_visible(self, var: str):
        """Mark ``var`` as bound in the innermost scope."""
        self.scopes_visible[-1].add(var)

    def add_vars_visible(self, var: typing.Iterable[str]):
        """Mark every name in ``var`` as bound in the innermost scope."""
        self.scopes_visible[-1].update(var)

    def add_constant(self, var: str, value: typing.Any):
        """Register ``value`` as the constant value of ``var`` in the innermost scope."""
        self.scopes_constants[-1][var] = value

    def visible_vars(self):
        """Return the union of the names bound in all currently open scopes."""
        res_set = OrderedSet()
        for s in self.scopes_visible:
            res_set.update(s)
        return res_set

    def _constant_vars(self):
        """Return all known constants; inner scopes override outer ones."""
        res_d = {}
        for s in self.scopes_constants:
            res_d.update(s)
        return res_d

    def exit_scope(self):
        """Pop the innermost scope from both stacks."""
        self.scopes_visible.pop(-1)
        self.scopes_constants.pop(-1)

    def _non_overwritten_globals(self):
        """Return SAFE_GLOBALS with every shadowed builtin replaced by a stub.

        The stub raises ``ValueError`` when called, which makes the evaluation
        of any expression calling a shadowed builtin fail (and hence keeps the
        expression unfolded).
        """
        overwritten_vars = self.visible_vars()

        def err(*args, **kwargs):
            # accepts any call signature so that the failure is always the
            # intended ValueError instead of an incidental TypeError
            raise ValueError("Was overwritten!")

        return {
            k: (err if k in overwritten_vars else v)
            for k, v in SAFE_GLOBALS.items()
        }

    def _exec_globals(self):
        """Return the globals for ``exec``: usable builtins plus known constants."""
        g = self._non_overwritten_globals()
        g.update(self._constant_vars())
        return g

    @classmethod
    def _is_foldable(cls, value) -> bool:
        """Tell whether ``value`` may be stored in a folded ``Constant`` node.

        Atomic values must be instances of ACCEPTED_ATOMIC_TYPES or PlutusData.
        Lists and dicts must be non-empty and consist only of foldable values
        themselves; this keeps e.g. tuples, floats or functions from being
        smuggled into a constant inside a container.
        """
        if isinstance(value, list):
            return bool(value) and all(cls._is_foldable(v) for v in value)
        if isinstance(value, dict):
            return bool(value) and all(
                cls._is_foldable(k) and cls._is_foldable(v)
                for k, v in value.items()
            )
        return isinstance(value, tuple(ACCEPTED_ATOMIC_TYPES) + (PlutusData,))

    def update_constants(self, node):
        """Execute the statement ``node`` and record the names it binds.

        Failures are only logged at debug level: a statement that cannot be
        executed at compile time simply contributes no constants.
        """
        g = self._exec_globals()
        l = {}
        try:
            exec(unparse(node), g, l)
        except Exception as e:
            OPSHIN_LOGGER.debug(e)
        else:
            # everything the statement bound ended up in the local dict
            self.scopes_constants[-1].update(l)

    def visit_Module(self, node: Module) -> Module:
        """Collect module-level names and single-assignment names, then fold."""
        self.enter_scope()
        def_vars_collector = ShallowNameDefCollector()
        def_vars_collector.visit(node)
        self.add_vars_visible(def_vars_collector.vars)

        constant_collector = DefinedTimesVisitor()
        constant_collector.visit(node)
        # if it is only assigned exactly once, it must be a constant (due to immutability)
        self.constants = {
            c for c, i in constant_collector.vars.items() if i == 1
        }

        res = self.generic_visit(node)
        self.exit_scope()
        return res

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Register the function as a constant if possible and fold its body."""
        self.add_var_visible(node.name)
        if node.name in self.constants:
            g = self._exec_globals()
            try:
                # we need to pass the global dict as local dict here to make closures possible (rec functions)
                exec(unparse(node), g, g)
            except Exception as e:
                OPSHIN_LOGGER.debug(e)
            else:
                # the function is defined and was added to the globals
                self.scopes_constants[-1][node.name] = g[node.name]

        self.enter_scope()
        self.add_vars_visible(arg.arg for arg in node.args.args)
        def_vars_collector = ShallowNameDefCollector()
        for s in node.body:
            def_vars_collector.visit(s)
        self.add_vars_visible(def_vars_collector.vars)

        res_node = self.generic_visit(node)
        self.exit_scope()
        return res_node

    def visit_ClassDef(self, node: ClassDef):
        """Register a class that is defined once; the class body is not folded."""
        if node.name in self.constants:
            self.update_constants(node)
        return node

    def visit_ImportFrom(self, node: ImportFrom):
        """Execute the import if every imported entry is counted exactly once."""
        # the lookup uses the entries of node.names themselves, which is how
        # DefinedTimesVisitor keys imports
        if all(n in self.constants for n in node.names):
            self.update_constants(node)
        return node

    def visit_Import(self, node: Import):
        """Execute the import if every imported entry is counted exactly once."""
        # the lookup uses the entries of node.names themselves, which is how
        # DefinedTimesVisitor keys imports
        if all(n in self.constants for n in node.names):
            self.update_constants(node)
        return node

    def visit_Assign(self, node: Assign):
        """Record single-name constant assignments and fold the assigned value.

        Assignments with several targets or a non-name target are returned
        untouched.
        """
        if len(node.targets) != 1:
            return node
        target = node.targets[0]
        if not isinstance(target, Name):
            return node

        if target.id in self.constants:
            self.update_constants(node)
        node.value = self.visit(node.value)
        return node

    def visit_AnnAssign(self, node: AnnAssign):
        """Record constant annotated assignments and fold the assigned value.

        Assignments to a non-name target are returned untouched.
        """
        target = node.target
        if not isinstance(target, Name):
            return node

        if target.id in self.constants:
            self.update_constants(node)
        # a bare annotation ("x: int") has no value, there is nothing to fold
        if node.value is not None:
            node.value = self.visit(node.value)
        return node

    def generic_visit(self, node: AST):
        """Fold the children of ``node`` and then, if possible, ``node`` itself.

        Only expressions are evaluated.  An expression is left as it is when
        it cannot be unparsed, contains a ``print(`` call, fails to evaluate,
        or evaluates to a value that is not foldable.
        """
        node = super().generic_visit(node)
        if not isinstance(node, expr):
            # only evaluate expressions, not statements
            return node
        if isinstance(node, Constant):
            # prevents unnecessary computations
            return node
        try:
            node_source = unparse(node)
        except Exception as e:
            OPSHIN_LOGGER.debug("Error when trying to unparse node: %s", e)
            return node
        if "print(" in node_source:
            # do not optimize away print statements
            return node
        try:
            # we add preceding constant plutusdata definitions here!
            g = self._non_overwritten_globals()
            l = self._constant_vars()
            node_eval = eval(node_source, g, l)
        except Exception as e:
            OPSHIN_LOGGER.debug("Error trying to evaluate node: %s", e)
            return node

        if not self._is_foldable(node_eval):
            return node
        new_node = Constant(node_eval, None)
        copy_location(new_node, node)
        return new_node

Ancestors

Class variables

var step

Methods

def add_constant(self, var: str, value: Any)
def add_var_visible(self, var: str)
def add_vars_visible(self, var: Iterable[str])
def enter_scope(self)
def exit_scope(self)
def generic_visit(self, node: ast.AST)

Called if no explicit visitor function exists for a node.

def update_constants(self, node)
def visible_vars(self)
def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_AnnAssign(self, node: ast.AnnAssign)
def visit_Assign(self, node: ast.Assign)
def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node: ast.FunctionDef) ‑> ast.FunctionDef
def visit_Import(self, node: ast.Import)
def visit_ImportFrom(self, node: ast.ImportFrom)
def visit_Module(self, node: ast.Module) ‑> ast.Module
class ShallowNameDefCollector

A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the visit method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are 'visit_' + class name of the node. So a TryFinally node visit function would be visit_TryFinally. This behavior can be changed by overriding the visit method. If no visitor function exists for a node (return value None) the generic_visit visitor is used instead.

Don't use the NodeVisitor if you want to apply changes to nodes during traversing. For this a special visitor exists (NodeTransformer) that allows modifications.

Expand source code
class ShallowNameDefCollector(CompilingNodeVisitor):
    """Collect the names bound directly by the visited statements.

    Names written through a ``Name`` node in ``Store`` context are recorded,
    as are the names of class and function definitions.  The bodies of
    classes and functions are not traversed, so names bound inside them are
    not reported.  The result is available as ``self.vars``.
    """

    step = "Collecting occuring variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def _remember(self, name) -> None:
        """Record ``name`` as bound in the inspected scope."""
        self.vars.add(name)

    def visit_Name(self, node: Name) -> None:
        # only writes bind a name; loads and deletes are irrelevant here
        if not isinstance(node.ctx, Store):
            return
        self._remember(node.id)

    def visit_ClassDef(self, node: ClassDef):
        # the class body (i.e. attribute names) is deliberately not visited
        self._remember(node.name)

    def visit_FunctionDef(self, node: FunctionDef):
        # the function body belongs to a nested scope and is not visited
        self._remember(node.name)

Ancestors

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeVisitor.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_Name(self, node: ast.Name) ‑> None