Module opshin.optimize.optimize_remove_deadconstants
Expand source code
from ast import *
from copy import copy
from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE
"""
Removes expressions that are safely side effect free in sequences of statements
(e.g. constants, names, lambdas, string comments)
"""
class SafeOperationVisitor(CompilingNodeVisitor):
    """
    Decide whether evaluating an expression is guaranteed not to fail.

    ``visit`` returns ``True`` only for expressions on an explicit whitelist.
    Every other node is conservatively reported as unsafe (``False``), so a
    caller may drop an expression only when this visitor returns ``True``.

    :param guaranteed_names: container of names known to be bound where the
        expression is evaluated; loading one of these names cannot fail
    """

    step = "Collecting computations that can not throw errors"

    def __init__(self, guaranteed_names):
        self.guaranteed_names = guaranteed_names

    def generic_visit(self, node: AST) -> bool:
        # generally every operation is unsafe except we whitelist it
        return False

    def visit_Lambda(self, node: Lambda) -> bool:
        # Creating a lambda does not run its body. However, Python evaluates
        # the default values of its parameters at definition time, and those
        # may fail, so the definition is only safe if every default is safe.
        # kw_defaults holds None for keyword-only parameters without default.
        defaults = list(node.args.defaults) + [
            d for d in node.args.kw_defaults if d is not None
        ]
        return all(self.visit(d) for d in defaults)

    def visit_Constant(self, node: Constant) -> bool:
        # Constants can not fail
        return True

    def visit_RawPlutoExpr(self, node) -> bool:
        # these expressions are not evaluated further
        return True

    def visit_Name(self, node: Name) -> bool:
        # Loading a name can only fail if the name is not bound
        return node.id in self.guaranteed_names
class OptimizeRemoveDeadConstants(CompilingNodeTransformer):
    """
    Remove expression statements whose evaluation is known to be safe.

    An expression statement (``Expr``) is dropped when ``SafeOperationVisitor``
    reports that evaluating its value cannot fail. Examples are constants,
    string comments, and names that are guaranteed to be bound.

    To decide which names are bound, the transformer keeps a stack of scopes
    (``guaranteed_avail_names``). A name is recorded in the innermost scope
    when it is bound by:

    - an ``Assign`` to a plain name,
    - an ``AnnAssign`` to a plain name that has a value,
    - a ``def`` statement,
    - a function parameter.

    Only bindings visited *before* an expression statement can make it
    removable. Bindings made inside ``if``/``while``/``for`` bodies are
    discarded when that body ends, because those bodies may not run. Class
    bodies are treated the same way, because their bindings become class
    attributes.
    """

    step = "Removing dead expressions"
    # Template for the outermost scope: names that are always available.
    # It is only read; every instance works on its own copy (see __init__).
    guaranteed_avail_names = [
        list(INITIAL_SCOPE.keys()) + ["isinstance", "Union", "Dict", "List"]
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance scope stack. Mutating the class attribute directly
        # would share state between instances. An interrupted traversal
        # would also leak its names into every later run.
        self.guaranteed_avail_names = [
            list(scope) for scope in type(self).guaranteed_avail_names
        ]

    def enter_scope(self):
        """Open a new, empty innermost scope."""
        self.guaranteed_avail_names.append([])

    def exit_scope(self):
        """Discard the innermost scope together with every name recorded in it."""
        self.guaranteed_avail_names.pop()

    def set_guaranteed(self, name: str):
        """Record ``name`` as bound in the innermost scope."""
        self.guaranteed_avail_names[-1].append(name)

    def visit_stmts(self, stmts):
        """
        Visit a list of statements and return the transformed list.

        Statements for which the visitor returns ``None`` are dropped. A
        returned list of statements is spliced in, matching the
        ``NodeTransformer`` contract for statement lists.

        NOTE(review): if every statement is removed, the result is empty.
        For example, a function body consisting only of a docstring ends up
        empty. Confirm that later passes accept empty bodies.
        """
        res = []
        for s in stmts:
            r = self.visit(s)
            if r is None:
                continue
            if isinstance(r, list):
                res.extend(r)
            else:
                res.append(r)
        return res

    def _visit_scoped_stmts(self, stmts):
        """Visit ``stmts`` in a fresh scope that is discarded afterwards."""
        self.enter_scope()
        try:
            return self.visit_stmts(stmts)
        finally:
            # always rebalance the stack, even if visiting raised
            self.exit_scope()

    def visit_Module(self, node: Module):
        node_cp = copy(node)
        node_cp.body = self._visit_scoped_stmts(node.body)
        return node_cp

    def visit_Expr(self, node: Expr):
        # Flatten the scope stack: a name bound in any enclosing scope is
        # available here. A set keeps the membership checks cheap.
        guaranteed = {
            name for scope in self.guaranteed_avail_names for name in scope
        }
        if SafeOperationVisitor(guaranteed).visit(node.value):
            return None
        return node

    def visit_FunctionDef(self, node: FunctionDef):
        node_cp = copy(node)
        # The function name is bound in the enclosing scope by the def
        # statement itself, i.e. before the body can ever run.
        self.set_guaranteed(node.name)
        self.enter_scope()
        try:
            for a in node.args.args:
                self.set_guaranteed(a.arg)
            node_cp.body = self.visit_stmts(node.body)
        finally:
            self.exit_scope()
        return node_cp

    def visit_ClassDef(self, node: ClassDef):
        # Names bound in a class body (fields, methods) are class attributes,
        # not names of the enclosing scope, so they are discarded afterwards.
        node_cp = copy(node)
        node_cp.body = self._visit_scoped_stmts(node.body)
        return node_cp

    def visit_If(self, node: If):
        # Either branch may be skipped, so a name bound in one branch is
        # neither visible in the other branch nor after the statement.
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self._visit_scoped_stmts(node.body)
        node_cp.orelse = self._visit_scoped_stmts(node.orelse)
        return node_cp

    def visit_While(self, node: While):
        # The loop body may run zero times, so its bindings are not
        # guaranteed in the else branch or after the loop.
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self._visit_scoped_stmts(node.body)
        node_cp.orelse = self._visit_scoped_stmts(node.orelse)
        return node_cp

    def visit_For(self, node: For):
        # The loop body may run zero times, so its bindings are not
        # guaranteed in the else branch or after the loop.
        node_cp = copy(node)
        node_cp.target = self.visit(node.target)
        node_cp.iter = self.visit(node.iter)
        node_cp.body = self._visit_scoped_stmts(node.body)
        node_cp.orelse = self._visit_scoped_stmts(node.orelse)
        return node_cp

    def visit_Assign(self, node: Assign):
        for t in node.targets:
            if isinstance(t, Name):
                self.set_guaranteed(t.id)
        return self.generic_visit(node)

    def visit_AnnAssign(self, node: AnnAssign):
        # A bare annotation such as ``x: int`` declares but does not bind x.
        if isinstance(node.target, Name) and node.value is not None:
            self.set_guaranteed(node.target.id)
        return self.generic_visit(node)
Classes
class OptimizeRemoveDeadConstants-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes. The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place. Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`::

    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first. For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this::
node = YourTransformer().visit(node)
Expand source code
class OptimizeRemoveDeadConstants(CompilingNodeTransformer): step = "Removing dead expressions" guaranteed_avail_names = [ list(INITIAL_SCOPE.keys()) + ["isinstance", "Union", "Dict", "List"] ] def enter_scope(self): self.guaranteed_avail_names.append([]) def exit_scope(self): self.guaranteed_avail_names.pop() def set_guaranteed(self, name: str): self.guaranteed_avail_names[-1].append(name) def visit_stmts(self, stmts): res = [] for s in stmts: r = self.visit(s) if r is not None: res.append(r) return res def visit_Module(self, node: Module): node_cp = copy(node) self.enter_scope() node_cp.body = self.visit_stmts(node.body) self.exit_scope() return node_cp def visit_Expr(self, node: Expr): if SafeOperationVisitor(sum(self.guaranteed_avail_names, [])).visit(node.value): return None return node def visit_FunctionDef(self, node: FunctionDef): node_cp = copy(node) self.set_guaranteed(node.name) self.enter_scope() for a in node.args.args: self.set_guaranteed(a.arg) node_cp.body = self.visit_stmts(node.body) self.exit_scope() return node_cp def visit_Assign(self, node: Assign): for t in node.targets: if isinstance(t, Name): self.set_guaranteed(t.id) return self.generic_visit(node) def visit_AnnAssign(self, node: AnnAssign): if isinstance(node.target, Name): self.set_guaranteed(node.target.id) return self.generic_visit(node)Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var guaranteed_avail_names
var step
Methods
def enter_scope(self)-
Expand source code
def enter_scope(self): self.guaranteed_avail_names.append([]) def exit_scope(self)-
Expand source code
def exit_scope(self): self.guaranteed_avail_names.pop() def set_guaranteed(self, name: str)-
Expand source code
def set_guaranteed(self, name: str): self.guaranteed_avail_names[-1].append(name) def visit(self, node)-
Inherited from:
CompilingNodeTransformer.visit
Visit a node.
def visit_AnnAssign(self, node: ast.AnnAssign)-
Expand source code
def visit_AnnAssign(self, node: AnnAssign): if isinstance(node.target, Name): self.set_guaranteed(node.target.id) return self.generic_visit(node) def visit_Assign(self, node: ast.Assign)-
Expand source code
def visit_Assign(self, node: Assign): for t in node.targets: if isinstance(t, Name): self.set_guaranteed(t.id) return self.generic_visit(node) def visit_Expr(self, node: ast.Expr)-
Expand source code
def visit_Expr(self, node: Expr): if SafeOperationVisitor(sum(self.guaranteed_avail_names, [])).visit(node.value): return None return node def visit_FunctionDef(self, node: ast.FunctionDef)-
Expand source code
def visit_FunctionDef(self, node: FunctionDef): node_cp = copy(node) self.set_guaranteed(node.name) self.enter_scope() for a in node.args.args: self.set_guaranteed(a.arg) node_cp.body = self.visit_stmts(node.body) self.exit_scope() return node_cp def visit_Module(self, node: ast.Module)-
Expand source code
def visit_Module(self, node: Module): node_cp = copy(node) self.enter_scope() node_cp.body = self.visit_stmts(node.body) self.exit_scope() return node_cp def visit_stmts(self, stmts)-
Expand source code
def visit_stmts(self, stmts): res = [] for s in stmts: r = self.visit(s) if r is not None: res.append(r) return res
class SafeOperationVisitor (guaranteed_names)-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method. This class is meant to be subclassed, with the subclass adding visitor methods.
Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.
Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.
Expand source code
class SafeOperationVisitor(CompilingNodeVisitor): step = "Collecting computations that can not throw errors" def __init__(self, guaranteed_names): self.guaranteed_names = guaranteed_names def generic_visit(self, node: AST) -> bool: # generally every operation is unsafe except we whitelist it return False def visit_Lambda(self, node: Lambda) -> bool: # lambda definition is fine as it actually doesn't compute anything return True def visit_Constant(self, node: Constant) -> bool: # Constants can not fail return True def visit_RawPlutoExpr(self, node) -> bool: # these expressions are not evaluated further return True def visit_Name(self, node: Name) -> bool: return node.id in self.guaranteed_namesAncestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def generic_visit(self, node: ast.AST) ‑> bool-
Called if no explicit visitor function exists for a node.
Expand source code
def generic_visit(self, node: AST) -> bool: # generally every operation is unsafe except we whitelist it return False def visit(self, node)-
Inherited from:
CompilingNodeVisitor.visit
Visit a node.
def visit_Constant(self, node: ast.Constant) ‑> bool-
Expand source code
def visit_Constant(self, node: Constant) -> bool: # Constants can not fail return True def visit_Lambda(self, node: ast.Lambda) ‑> bool-
Expand source code
def visit_Lambda(self, node: Lambda) -> bool: # lambda definition is fine as it actually doesn't compute anything return True def visit_Name(self, node: ast.Name) ‑> bool-
Expand source code
def visit_Name(self, node: Name) -> bool: return node.id in self.guaranteed_names def visit_RawPlutoExpr(self, node) ‑> bool-
Expand source code
def visit_RawPlutoExpr(self, node) -> bool: # these expressions are not evaluated further return True