Module opshin.util
Expand source code
from _ast import Name, Store, ClassDef, FunctionDef, Load
from collections import defaultdict
from copy import copy, deepcopy
import logging
import typing
import ast
from dataclasses import dataclass
import pycardano
from frozendict import frozendict
from frozenlist2 import frozenlist
import uplc.ast as uplc
import pluthon as plt
from hashlib import sha256
# Logger named "opshin"; a StreamHandler is attached to it as a module-level
# side effect, i.e. as soon as this module is imported.
OPSHIN_LOGGER = logging.getLogger("opshin")
OPSHIN_LOG_HANDLER = logging.StreamHandler()
OPSHIN_LOGGER.addHandler(OPSHIN_LOG_HANDLER)
class FileContextFilter(logging.Filter):
    """
    This is a filter which injects contextual information into the log.
    The information is about the currently inspected AST node.
    The information needs to be updated inside the NodeTransformer and NodeVisitor classes.

    Position attributes that the current node does not carry (for example on
    ``ast.Module`` or on expression-context nodes), or that are ``None``,
    fall back to the same defaults that are used when no node is set.
    """

    # Name of the inspected file; not read by this class itself.
    file_name = "unknown"
    # AST node currently being inspected, or None if none has been set yet.
    node: ast.AST = None

    # (record attribute, fallback value) pairs written onto every log record.
    _POSITION_DEFAULTS = (
        ("lineno", 1),
        ("col_offset", 0),
        ("end_lineno", 1),
        ("end_col_offset", 0),
    )

    def filter(self, record):
        """
        Attach the position of the current node to ``record``.

        :param record: the log record to annotate in place
        :return: always True, i.e. the record is never suppressed
        """
        for attr, fallback in self._POSITION_DEFAULTS:
            # getattr with a default also covers ``self.node is None`` and
            # nodes without location info, which previously raised
            # AttributeError from inside the logging machinery.
            value = getattr(self.node, attr, None)
            setattr(record, attr, fallback if value is None else value)
        return True
# Single shared filter instance; the visitor and transformer classes in this
# module overwrite its `node` attribute on every `visit` call.
OPSHIN_LOG_CONTEXT_FILTER = FileContextFilter()
# Registered on the handler so that records passing through it are annotated.
OPSHIN_LOG_HANDLER.addFilter(OPSHIN_LOG_CONTEXT_FILTER)
def distinct(xs: list):
    """Returns true iff the list consists of distinct elements"""
    # Duplicates collapse when the elements are put into a set, so the sizes
    # only agree when no element occurs twice.
    total = len(xs)
    unique = len(set(xs))
    return total == unique
class TypedNodeTransformer(ast.NodeTransformer):
    """
    Node transformer whose dispatch ignores a leading ``Typed`` in the node's
    class name, so e.g. a ``TypedName`` node is handled by ``visit_Name``.
    """

    def visit(self, node):
        """Visit a node."""
        # Keep the logging context pointing at the node being processed.
        OPSHIN_LOG_CONTEXT_FILTER.node = node
        prefix = "Typed"
        kind = node.__class__.__name__
        if kind.startswith(prefix):
            kind = kind[len(prefix) :]
        handler = getattr(self, "visit_" + kind, self.generic_visit)
        return handler(node)
class TypedNodeVisitor(ast.NodeVisitor):
    """
    Node visitor whose dispatch ignores a leading ``Typed`` in the node's
    class name, so e.g. a ``TypedName`` node is handled by ``visit_Name``.
    """

    def visit(self, node):
        """Visit a node."""
        # Keep the logging context pointing at the node being processed.
        OPSHIN_LOG_CONTEXT_FILTER.node = node
        prefix = "Typed"
        kind = node.__class__.__name__
        if kind.startswith(prefix):
            kind = kind[len(prefix) :]
        handler = getattr(self, "visit_" + kind, self.generic_visit)
        return handler(node)
class CompilerError(Exception):
    """
    Exception that carries another exception together with the AST node and
    the compilation step in which it was raised.
    """

    def __init__(self, orig_err: Exception, node: ast.AST, compilation_step: str):
        """
        :param orig_err: the exception that was originally raised
        :param node: the AST node that was being processed
        :param compilation_step: description of the step that failed
        """
        self.orig_err, self.node, self.compilation_step = (
            orig_err,
            node,
            compilation_step,
        )
class CompilingNodeTransformer(TypedNodeTransformer):
    """
    Transformer that reports any failure during a visit as a CompilerError
    annotated with the offending node and this transformer's ``step``.
    """

    # Description of the compilation step, stored on raised CompilerErrors.
    step = "Node transformation"

    def visit(self, node):
        """Visit ``node``, wrapping foreign exceptions in a CompilerError."""
        OPSHIN_LOG_CONTEXT_FILTER.node = node
        try:
            result = super().visit(node)
        except CompilerError:
            # Already annotated by a nested visit; pass it through unchanged.
            raise
        except Exception as err:
            raise CompilerError(err, node, self.step)
        return result
class NoOp(CompilingNodeTransformer):
    """A variation of the Compiling Node transformer that performs no changes"""
class CompilingNodeVisitor(TypedNodeVisitor):
    """
    Visitor that reports any failure during a visit as a CompilerError
    annotated with the offending node and this visitor's ``step``.
    """

    # Description of the compilation step, stored on raised CompilerErrors.
    step = "Node visiting"

    def visit(self, node):
        """Visit ``node``, wrapping foreign exceptions in a CompilerError."""
        try:
            result = super().visit(node)
        except CompilerError:
            # Already annotated by a nested visit; pass it through unchanged.
            raise
        except Exception as err:
            raise CompilerError(err, node, self.step)
        return result
def data_from_json(j: typing.Dict[str, typing.Any]) -> uplc.PlutusData:
    """
    Convert a JSON-style dictionary into the corresponding uplc data object.

    The first matching key decides the variant, checked in the order
    "bytes" (hex string), "int", "list", "map" (entries with "k" and "v"),
    and finally "constructor" together with "fields".

    :raises NotImplementedError: if none of the known keys is present
    """

    def _convert_all(items):
        return frozenlist([data_from_json(item) for item in items])

    if "bytes" in j:
        raw = bytes.fromhex(j["bytes"])
        return uplc.PlutusByteString(raw)
    if "int" in j:
        number = int(j["int"])
        return uplc.PlutusInteger(number)
    if "list" in j:
        return uplc.PlutusList(_convert_all(j["list"]))
    if "map" in j:
        converted = {}
        for entry in j["map"]:
            # convert the key before the value, as a dict display would
            key = data_from_json(entry["k"])
            value = data_from_json(entry["v"])
            converted[key] = value
        return uplc.PlutusMap(frozendict(converted))
    if "constructor" in j and "fields" in j:
        return uplc.PlutusConstr(j["constructor"], _convert_all(j["fields"]))
    raise NotImplementedError(f"Unknown datum representation {j}")
def datum_to_cbor(d: pycardano.Datum) -> bytes:
    """Serialize the datum by calling ``pycardano.PlutusData.to_cbor`` on it."""
    serialize = pycardano.PlutusData.to_cbor
    return serialize(d)
def datum_to_json(d: pycardano.Datum) -> str:
    """Serialize the datum by calling ``pycardano.PlutusData.to_json`` on it."""
    serialize = pycardano.PlutusData.to_json
    return serialize(d)
def custom_fix_missing_locations(node, parent=None):
    """
    Works like ast.fix_missing_locations but forces it onto everything

    Every node below ``node`` (and ``node`` itself) whose ``lineno``,
    ``col_offset``, ``end_lineno`` or ``end_col_offset`` is missing or None
    receives the value of its closest ancestor. The starting values are
    taken from ``parent`` where present and default to line 1, column 0.

    :param node: root of the tree to fix, modified in place
    :param parent: optional node that supplies the initial location
    :return: the same ``node`` object
    """
    location_defaults = {
        "lineno": 1,
        "col_offset": 0,
        "end_lineno": 1,
        "end_col_offset": 0,
    }

    def _fill(current, inherited):
        effective = {}
        for attr, inherited_value in inherited.items():
            own_value = getattr(current, attr, None)
            if own_value is None:
                setattr(current, attr, inherited_value)
                own_value = inherited_value
            # children inherit whatever this node ends up with
            effective[attr] = own_value
        for child in ast.iter_child_nodes(current):
            _fill(child, effective)

    seed = {
        attr: getattr(parent, attr, fallback)
        for attr, fallback in location_defaults.items()
    }
    _fill(node, seed)
    return node
# Maps the serialized form of a structure to the pattern instance built for it.
_patterns_cached = {}


def make_pattern(structure: plt.AST) -> plt.Pattern:
    """Creates a shared pattern from the given lambda, cached so that it is re-used in subsequent calls"""
    cache_key = structure.dumps()
    if _patterns_cached.get(cache_key) is None:
        # Build, at runtime, a dataclass deriving from plt.Pattern whose
        # ``compose`` returns a copy of ``structure``. The class name embeds
        # the SHA-256 of the serialized structure.
        class_name = f"AdHocPattern_{sha256(cache_key.encode()).hexdigest()}"
        namespace = {"compose": lambda self: deepcopy(structure)}
        pattern_cls = dataclass(type(class_name, (plt.Pattern,), namespace))
        _patterns_cached[cache_key] = pattern_cls()
    # Hand out a copy so that the cached instance itself is never exposed.
    return deepcopy(_patterns_cached[cache_key])
def patternize(method):
    """
    Decorator that passes the return value of ``method`` through
    ``make_pattern``.

    The wrapper keeps the name, docstring and other metadata of ``method``
    (via ``functools.wraps``) so that decorated functions remain
    recognisable in tracebacks and introspection.

    :param method: callable whose result is handed to ``make_pattern``
    :return: wrapper accepting the same arguments as ``method``
    """
    # Imported locally so that this function does not rely on a module-level
    # import of functools.
    import functools

    @functools.wraps(method)
    def wrapped(*args, **kwargs):
        return make_pattern(method(*args, **kwargs))

    return wrapped
def force_params(lmd: plt.Lambda) -> plt.Lambda:
    """
    Return a lambda with the same parameters whose body first rebinds every
    parameter to its forced value.

    A ``plt.Pattern`` is composed, processed recursively and wrapped into a
    pattern again. Any other input yields ``None``.
    """
    if isinstance(lmd, plt.Lambda):
        forced_bindings = [(name, plt.Force(plt.Var(name))) for name in lmd.vars]
        body = plt.Let(forced_bindings, lmd.term)
        return plt.Lambda(lmd.vars, body)
    if isinstance(lmd, plt.Pattern):
        composed = lmd.compose()
        return make_pattern(force_params(composed))
    return None
class NameWriteCollector(CompilingNodeVisitor):
    """
    Counts, per name, how often it is written: stores to a name, class
    definitions, function definitions and function parameters.
    """

    step = "Collecting variables that are written"

    def __init__(self):
        # name -> number of writes seen so far
        self.written = defaultdict(int)

    def _record(self, name):
        self.written[name] += 1

    def visit_Name(self, node: Name) -> None:
        if not isinstance(node.ctx, Store):
            return
        self._record(node.id)

    def visit_ClassDef(self, node: ClassDef):
        # only the class name counts; the class body is deliberately skipped
        self._record(node.name)

    def visit_FunctionDef(self, node: FunctionDef):
        # annotations are never visited; only the name, the parameters in
        # ``args.args`` and the body statements are considered
        self._record(node.name)
        for param in node.args.args:
            self._record(param.arg)
        for stmt in node.body:
            self.visit(stmt)
def written_vars(node):
    """
    Returns the sorted names of all variables written to in this node
    """
    write_collector = NameWriteCollector()
    write_collector.visit(node)
    return sorted(write_collector.written)
class NameReadCollector(CompilingNodeVisitor):
    """
    Counts, per name, how often it is read (``Name`` nodes in ``Load``
    context). Annotations and the bodies of class definitions are skipped.
    """

    step = "Collecting variables that are read"

    def __init__(self):
        # name -> number of reads seen so far
        self.read = defaultdict(int)

    def visit_AnnAssign(self, node) -> None:
        # ignore annotations of variables
        # A bare annotation such as ``x: int`` has no value (``None``);
        # visiting ``None`` would fail inside ``generic_visit``.
        if node.value is not None:
            self.visit(node.value)
        self.visit(node.target)

    def visit_FunctionDef(self, node) -> None:
        # ignore annotations of parameters and return
        for b in node.body:
            self.visit(b)

    def visit_Name(self, node: Name) -> None:
        if isinstance(node.ctx, Load):
            self.read[node.id] += 1

    def visit_ClassDef(self, node: ClassDef):
        # ignore the content (i.e. attribute names) of class definitions
        pass
def read_vars(node):
    """
    Returns the sorted names of all variables read in this node
    """
    read_collector = NameReadCollector()
    read_collector.visit(node)
    return sorted(read_collector.read)
def all_vars(node):
    """Returns the sorted names of all variables read or written in this node"""
    names = set(read_vars(node))
    names.update(written_vars(node))
    return sorted(names)
def externally_bound_vars(node: FunctionDef):
    """A superset of the variables bound from an outer scope"""
    read = set(read_vars(node))
    # names written inside the node, plus "isinstance", are never reported
    excluded = set(written_vars(node)) | {"isinstance"}
    return sorted(read - excluded)
def opshin_name_scheme_compatible_varname(n: str):
    """Returns the given name prefixed with the character "1"."""
    # NOTE(review): presumably the digit prefix keeps these names apart from
    # any valid Python identifier - confirm against the naming scheme.
    return "1{}".format(n)
def OVar(name: str):
    """Returns a ``plt.Var`` for the scheme-compatible form of ``name``."""
    mangled = opshin_name_scheme_compatible_varname(name)
    return plt.Var(mangled)
def OLambda(names: typing.List[str], term: plt.AST):
    """Returns a ``plt.Lambda`` over the scheme-compatible forms of ``names``."""
    mangled = list(map(opshin_name_scheme_compatible_varname, names))
    return plt.Lambda(mangled, term)
def OLet(bindings: typing.List[typing.Tuple[str, plt.AST]], term: plt.AST):
    """Returns a ``plt.Let`` whose bound names are made scheme-compatible."""
    mangled_bindings = []
    for bound_name, bound_term in bindings:
        mangled_name = opshin_name_scheme_compatible_varname(bound_name)
        mangled_bindings.append((mangled_name, bound_term))
    return plt.Let(mangled_bindings, term)
def SafeLambda(vars: typing.List[str], term: plt.AST) -> plt.Lambda:
    """
    Returns a ``plt.Lambda`` over ``vars``; an empty parameter list is
    replaced by the single placeholder parameter "0_".
    """
    params = vars if vars else ["0_"]
    return plt.Lambda(params, term)
def SafeOLambda(vars: typing.List[str], term: plt.AST) -> plt.Lambda:
    """
    Returns an ``OLambda`` over ``vars``; an empty parameter list is
    replaced by the single placeholder parameter "0_".
    """
    params = vars if vars else ["0_"]
    return OLambda(params, term)
def SafeApply(term: plt.AST, *vars: typing.List[plt.AST]) -> plt.Apply:
    """
    Returns a ``plt.Apply`` of ``term`` to the given arguments; without any
    arguments a single delayed unit value is passed instead.
    """
    arguments = vars or (plt.Delay(plt.Unit()),)
    return plt.Apply(term, *arguments)
Functions
def OLambda(names: List[str], term: pluthon.pluthon_ast.AST)
def OLet(bindings: List[Tuple[str, pluthon.pluthon_ast.AST]], term: pluthon.pluthon_ast.AST)
def OVar(name: str)
def SafeApply(term: pluthon.pluthon_ast.AST, *vars: List[pluthon.pluthon_ast.AST]) ‑> pluthon.pluthon_ast.Apply
def SafeLambda(vars: List[str], term: pluthon.pluthon_ast.AST) ‑> pluthon.pluthon_ast.Lambda
def SafeOLambda(vars: List[str], term: pluthon.pluthon_ast.AST) ‑> pluthon.pluthon_ast.Lambda
def all_vars(node)
def custom_fix_missing_locations(node, parent=None)
-
Works like ast.fix_missing_location but forces it onto everything
def data_from_json(j: Dict[str, Any]) ‑> uplc.ast.PlutusData
def datum_to_cbor(d: pycardano.plutus.PlutusData | dict | int | bytes | pycardano.serialization.IndefiniteList | pycardano.serialization.RawCBOR | pycardano.plutus.RawPlutusData) ‑> bytes
def datum_to_json(d: pycardano.plutus.PlutusData | dict | int | bytes | pycardano.serialization.IndefiniteList | pycardano.serialization.RawCBOR | pycardano.plutus.RawPlutusData) ‑> str
def distinct(xs: list)
-
Returns true iff the list consists of distinct elements
def externally_bound_vars(node: ast.FunctionDef)
-
A superset of the variables bound from an outer scope
def force_params(lmd: pluthon.pluthon_ast.Lambda) ‑> pluthon.pluthon_ast.Lambda
def make_pattern(structure: pluthon.pluthon_ast.AST) ‑> pluthon.pluthon_ast.Pattern
-
Creates a shared pattern from the given lambda, cached so that it is re-used in subsequent calls
def opshin_name_scheme_compatible_varname(n: str)
def patternize(method)
def read_vars(node)
-
Returns all variable names read in this node
def written_vars(node)
-
Returns all variable names written to in this node
Classes
class CompilerError (orig_err: Exception, node: ast.AST, compilation_step: str)
-
Common base class for all non-exit exceptions.
Expand source code
class CompilerError(Exception): def __init__(self, orig_err: Exception, node: ast.AST, compilation_step: str): self.orig_err = orig_err self.node = node self.compilation_step = compilation_step
Ancestors
- builtins.Exception
- builtins.BaseException
class CompilingNodeTransformer
-
A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:

    class RewriteName(NodeTransformer):
        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the `generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

    node = YourTransformer().visit(node)
Expand source code
class CompilingNodeTransformer(TypedNodeTransformer): step = "Node transformation" def visit(self, node): OPSHIN_LOG_CONTEXT_FILTER.node = node try: return super().visit(node) except Exception as e: if isinstance(e, CompilerError): raise e raise CompilerError(e, node, self.step)
Ancestors
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Subclasses
- PlutoCompiler
- OptimizeConstantFolding
- OptimizeRemoveDeadconstants
- OptimizeRemoveDeadvars
- OptimizeRemovePass
- RewriteAugAssign
- RewriteConditions
- RewriteComparisonChaining
- RewriteEmptyDicts
- RewriteEmptyLists
- RewriteForbiddenOverwrites
- RewriteForbiddenReturn
- RewriteImport
- RewriteLocation
- RewriteImportDataclasses
- RewriteImportHashlib
- RewriteImportIntegrityCheck
- RewriteImportPlutusData
- RewriteImportTyping
- RewriteImportUPLCBuiltins
- RewriteInjectBuiltinsConstr
- RewriteInjectBuiltins
- RewriteOrigName
- RewriteRemoveTypeStuff
- RewriteScoping
- RewriteSubscript38
- RewriteTupleAssign
- AggressiveTypeInferencer
- NoOp
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
TypedNodeTransformer
.visit
Visit a node.
class CompilingNodeVisitor
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.

Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.

Expand source code
class CompilingNodeVisitor(TypedNodeVisitor): step = "Node visiting" def visit(self, node): try: return super().visit(node) except Exception as e: if isinstance(e, CompilerError): raise e raise CompilerError(e, node, self.step)
Ancestors
- TypedNodeVisitor
- ast.NodeVisitor
Subclasses
- DefinedTimesVisitor
- ShallowNameDefCollector
- NameLoadCollector
- SafeOperationVisitor
- ShallowNameDefCollector
- NameReadCollector
- NameWriteCollector
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
TypedNodeVisitor
.visit
Visit a node.
class FileContextFilter (name='')
-
This is a filter which injects contextual information into the log.
The information is about the currently inspected AST node. The information needs to be updated inside the NodeTransformer and NodeVisitor classes.
Initialize a filter.
Initialize with the name of the logger which, together with its children, will have its events allowed through the filter. If no name is specified, allow every event.
Expand source code
class FileContextFilter(logging.Filter): """ This is a filter which injects contextual information into the log. The information is about the currently inspected AST node. The information needs to be updated inside the NodeTransformer and NodeVisitor classes. """ file_name = "unknown" node: ast.AST = None def filter(self, record): if self.node is None: record.lineno = 1 record.col_offset = 0 record.end_lineno = 1 record.end_col_offset = 0 else: record.lineno = self.node.lineno record.col_offset = self.node.col_offset record.end_lineno = self.node.end_lineno record.end_col_offset = self.node.end_col_offset return True
Ancestors
- logging.Filter
Class variables
var file_name
var node : ast.AST
Methods
def filter(self, record)
-
Determine if the specified record is to be logged.
Returns True if the record should be logged, or False otherwise. If deemed appropriate, the record may be modified in-place.
class NameReadCollector
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.

Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.

Expand source code
class NameReadCollector(CompilingNodeVisitor): step = "Collecting variables that are read" def __init__(self): self.read = defaultdict(int) def visit_AnnAssign(self, node) -> None: # ignore annotations of variables self.visit(node.value) self.visit(node.target) def visit_FunctionDef(self, node) -> None: # ignore annotations of paramters and return for b in node.body: self.visit(b) def visit_Name(self, node: Name) -> None: if isinstance(node.ctx, Load): self.read[node.id] += 1 def visit_ClassDef(self, node: ClassDef): # ignore the content (i.e. attribute names) of class definitions pass
Ancestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeVisitor
.visit
Visit a node.
def visit_AnnAssign(self, node) ‑> None
def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node) ‑> None
def visit_Name(self, node: ast.Name) ‑> None
class NameWriteCollector
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.

Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.

Expand source code
class NameWriteCollector(CompilingNodeVisitor): step = "Collecting variables that are written" def __init__(self): self.written = defaultdict(int) def visit_Name(self, node: Name) -> None: if isinstance(node.ctx, Store): self.written[node.id] += 1 def visit_ClassDef(self, node: ClassDef): # ignore the content (i.e. attribute names) of class definitions self.written[node.name] += 1 pass def visit_FunctionDef(self, node: FunctionDef): # ignore the type hints of function arguments self.written[node.name] += 1 for a in node.args.args: self.written[a.arg] += 1 for s in node.body: self.visit(s)
Ancestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeVisitor
.visit
Visit a node.
def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_Name(self, node: ast.Name) ‑> None
class NoOp
-
A variation of the Compiling Node transformer that performs no changes
Expand source code
class NoOp(CompilingNodeTransformer): """A variation of the Compiling Node transformer that performs no changes""" pass
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeTransformer
.visit
Visit a node.
class TypedNodeTransformer
-
A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:

    class RewriteName(NodeTransformer):
        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the `generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

    node = YourTransformer().visit(node)
Expand source code
class TypedNodeTransformer(ast.NodeTransformer): def visit(self, node): """Visit a node.""" OPSHIN_LOG_CONTEXT_FILTER.node = node node_class_name = node.__class__.__name__ if node_class_name.startswith("Typed"): node_class_name = node_class_name[len("Typed") :] method = "visit_" + node_class_name visitor = getattr(self, method, self.generic_visit) return visitor(node)
Ancestors
- ast.NodeTransformer
- ast.NodeVisitor
Subclasses
Methods
def visit(self, node)
-
Visit a node.
class TypedNodeVisitor
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.

Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.

Expand source code
class TypedNodeVisitor(ast.NodeVisitor): def visit(self, node): """Visit a node.""" OPSHIN_LOG_CONTEXT_FILTER.node = node node_class_name = node.__class__.__name__ if node_class_name.startswith("Typed"): node_class_name = node_class_name[len("Typed") :] method = "visit_" + node_class_name visitor = getattr(self, method, self.generic_visit) return visitor(node)
Ancestors
- ast.NodeVisitor
Subclasses
Methods
def visit(self, node)
-
Visit a node.