Module opshin.optimize.optimize_remove_deadvars
Expand source code
from ast import *
from copy import copy
from collections import defaultdict
from ordered_set import OrderedSet
from ..util import CompilingNodeVisitor, CompilingNodeTransformer
from ..type_inference import INITIAL_SCOPE
from ..typed_ast import TypedAnnAssign, TypedFunctionDef, TypedClassDef, TypedName
"""
Removes assignments to variables that are never read
"""
class NameLoadCollector(CompilingNodeVisitor):
    """
    Counts, per variable name, how often the name is read in the visited tree.

    The result is available in ``self.loaded`` (name -> number of recorded reads).
    """

    step = "Collecting used variables"

    def __init__(self):
        # missing names start at 0, so reads can simply be incremented
        self.loaded = defaultdict(int)

    def visit_Name(self, node: TypedName) -> None:
        """Record one read if the name occurs in load context."""
        if not isinstance(node.ctx, Load):
            return
        self.loaded[node.id] += 1

    def visit_ClassDef(self, node: TypedClassDef):
        """Skip class definitions; their content (i.e. attribute names) is ignored."""
        return None

    def visit_FunctionDef(self, node: TypedFunctionDef):
        """
        Visit only the function body (type hints of the arguments are ignored)
        and count the names bound by the function type as reads.
        """
        for stmt in node.body:
            self.visit(stmt)
        function_type = node.typ.typ
        for bound_name in function_type.bound_vars.keys():
            self.loaded[bound_name] += 1
        self_binding = function_type.bind_self
        if self_binding is not None:
            self.loaded[self_binding] += 1
class SafeOperationVisitor(CompilingNodeVisitor):
    """
    Answers whether an expression node is whitelisted as an operation that
    can not throw an error.

    ``guaranteed_names`` is the collection of names that are treated as
    safe to look up.
    """

    step = "Collecting computations that can not throw errors"

    def __init__(self, guaranteed_names):
        self.guaranteed_names = guaranteed_names

    def generic_visit(self, node: AST) -> bool:
        """Anything without a dedicated visitor below counts as unsafe."""
        return False

    def visit_Lambda(self, node: Lambda) -> bool:
        """Defining a lambda does not compute anything, hence it is safe."""
        return True

    def visit_Constant(self, node: Constant) -> bool:
        """Constants can not fail."""
        return True

    def visit_RawPlutoExpr(self, node) -> bool:
        """These expressions are not evaluated further."""
        return True

    def visit_Name(self, node: Name) -> bool:
        """A name is safe exactly when it is among the guaranteed names."""
        known_names = self.guaranteed_names
        return node.id in known_names
class OptimizeRemoveDeadvars(CompilingNodeTransformer):
    """
    Replaces assignments, class definitions and function definitions whose
    name is never read by ``Pass()``.

    An assignment is only dropped when ``SafeOperationVisitor`` accepts its
    right-hand side, so that removing it can not hide an error. For that
    check the transformer tracks, per scope, which names are guaranteed to
    be available at the current position.
    """

    step = "Removing unused variables"

    # set of names that are read somewhere in the module (set by visit_Module)
    loaded_vars = None

    # names that are guaranteed to be available to the current node
    # this acts differently to the type inferencer! in particular, ite/while/for all produce their own scope
    # NOTE: this is a class attribute and hence shared by all instances;
    # visit_Module restores its depth on exit so that no scope leaks from
    # one run into the next.
    guaranteed_avail_names = [
        list(INITIAL_SCOPE.keys()) + ["isinstance", "Union", "Dict", "List"]
    ]

    def guaranteed(self, name: str) -> bool:
        """Return True if ``name`` is recorded in any currently open scope."""
        return any(name in scope for scope in self.guaranteed_avail_names)

    def enter_scope(self):
        """Open a new, empty innermost scope."""
        self.guaranteed_avail_names.append([])

    def exit_scope(self):
        """Discard the innermost scope."""
        self.guaranteed_avail_names.pop()

    def set_guaranteed(self, name: str):
        """Record ``name`` as guaranteed available in the innermost scope."""
        self.guaranteed_avail_names[-1].append(name)

    def _is_safe(self, expr):
        """Ask SafeOperationVisitor about ``expr``, given the names of all open scopes."""
        available = [
            name for scope in self.guaranteed_avail_names for name in scope
        ]
        return SafeOperationVisitor(available).visit(expr)

    def visit_Module(self, node: Module) -> Module:
        """
        Repeatedly remove unread definitions until the set of read names is stable.

        Returns a shallow copy of ``node`` with the rewritten body.
        """
        # repeat until no more change due to removal
        # i.e. b = a; c = b needs 2 passes to remove c and b
        node_cp = copy(node)
        self.loaded_vars = None
        base_depth = len(self.guaranteed_avail_names)
        try:
            while True:
                # collect all variable names
                collector = NameLoadCollector()
                collector.visit(node_cp)
                loaded_vars = OrderedSet(collector.loaded.keys()) | {"validator_0"}
                # break if the set of loaded vars did not change -> set of vars to remove does also not change
                if loaded_vars == self.loaded_vars:
                    break
                # remove unloaded ones
                self.loaded_vars = loaded_vars
                self.enter_scope()
                node_cp.body = [self.visit(s) for s in node_cp.body]
                self.exit_scope()
        finally:
            # drop every scope opened during this run, even if a visitor raised,
            # so the shared scope stack is left exactly as it was found
            del self.guaranteed_avail_names[base_depth:]
        return node_cp

    def visit_If(self, node: If):
        """Visit both branches in their own scope; names set in both stay guaranteed afterwards."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        self.enter_scope()
        node_cp.body = [self.visit(s) for s in node.body]
        scope_body_cp = self.guaranteed_avail_names[-1].copy()
        self.exit_scope()
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        scope_orelse_cp = self.guaranteed_avail_names[-1].copy()
        self.exit_scope()
        # what remains after this in the scope is the intersection of both
        for var in OrderedSet(scope_body_cp).intersection(scope_orelse_cp):
            self.set_guaranteed(var)
        return node_cp

    def visit_While(self, node: While):
        """Visit the loop body and the else clause, each in its own scope."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        self.enter_scope()
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        # the body may run zero times, so names it binds are not guaranteed
        # inside the else clause
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        self.exit_scope()
        return node_cp

    def visit_For(self, node: For):
        """
        Visit the loop body with the loop target guaranteed, and the else
        clause in a separate scope.

        Raises AssertionError if the loop target is not a plain name.
        """
        node_cp = copy(node)
        assert isinstance(node.target, Name), "Can only assign to singleton name"
        self.enter_scope()
        # the loop variable is bound whenever the body executes
        self.set_guaranteed(node.target.id)
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        # the body may run zero times, so neither the target nor names bound
        # in the body are guaranteed inside the else clause
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        self.exit_scope()
        return node_cp

    def visit_Assign(self, node: Assign):
        """
        Replace the assignment by ``Pass()`` if its single name target is
        never read and its value is safe; otherwise keep it and mark the
        targets as guaranteed.

        Raises AssertionError if a kept assignment has a non-name target.
        """
        removable = (
            len(node.targets) == 1
            and isinstance(node.targets[0], Name)
            and node.targets[0].id not in self.loaded_vars
            and self._is_safe(node.value)
        )
        if removable:
            return Pass()
        for t in node.targets:
            assert isinstance(
                t, Name
            ), "Need to have name for dead var remover to work"
            self.set_guaranteed(t.id)
        return self.generic_visit(node)

    def visit_AnnAssign(self, node: TypedAnnAssign):
        """
        Replace the annotated assignment by ``Pass()`` if its name target is
        never read, its value is safe and the assignment is an upcast;
        otherwise keep it and mark the target as guaranteed.

        Raises AssertionError if a kept assignment has a non-name target.
        """
        removable = (
            isinstance(node.target, Name)
            and node.target.id not in self.loaded_vars
            and self._is_safe(node.value)
            # only upcasts are safe!
            and node.target.typ >= node.value.typ
        )
        if removable:
            return Pass()
        assert isinstance(
            node.target, Name
        ), "Need to have assignments to name for dead var remover to work"
        self.set_guaranteed(node.target.id)
        return self.generic_visit(node)

    def visit_ClassDef(self, node: ClassDef):
        """Keep the class (unchanged) if its name is read, else replace it by ``Pass()``."""
        if node.name in self.loaded_vars:
            self.set_guaranteed(node.name)
            return node
        return Pass()

    def visit_FunctionDef(self, node: FunctionDef):
        """
        Keep the function if its name is read, visiting its body in a new
        scope in which the positional arguments are guaranteed; otherwise
        replace it by ``Pass()``.
        """
        node_cp = copy(node)
        if node.name in self.loaded_vars:
            self.set_guaranteed(node.name)
            self.enter_scope()
            # variable names are available here
            for a in node.args.args:
                self.set_guaranteed(a.arg)
            node_cp.body = [self.visit(s) for s in node.body]
            self.exit_scope()
            return node_cp
        return Pass()
Classes
class NameLoadCollector
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.
This class is meant to be subclassed, with the subclass adding visitor methods.
Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.
Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.
Expand source code
class NameLoadCollector(CompilingNodeVisitor):
    """
    Counts, per variable name, how often the name is read in the visited tree.

    The result is available in ``self.loaded`` (name -> number of recorded reads).
    """

    step = "Collecting used variables"

    def __init__(self):
        # missing names start at 0, so reads can simply be incremented
        self.loaded = defaultdict(int)

    def visit_Name(self, node: TypedName) -> None:
        """Record one read if the name occurs in load context."""
        if not isinstance(node.ctx, Load):
            return
        self.loaded[node.id] += 1

    def visit_ClassDef(self, node: TypedClassDef):
        """Skip class definitions; their content (i.e. attribute names) is ignored."""
        return None

    def visit_FunctionDef(self, node: TypedFunctionDef):
        """
        Visit only the function body (type hints of the arguments are ignored)
        and count the names bound by the function type as reads.
        """
        for stmt in node.body:
            self.visit(stmt)
        function_type = node.typ.typ
        for bound_name in function_type.bound_vars.keys():
            self.loaded[bound_name] += 1
        self_binding = function_type.bind_self
        if self_binding is not None:
            self.loaded[self_binding] += 1
Ancestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeVisitor
.visit
Visit a node.
def visit_ClassDef(self, node: TypedClassDef)
def visit_FunctionDef(self, node: TypedFunctionDef)
def visit_Name(self, node: TypedName) ‑> None
class OptimizeRemoveDeadvars
-
A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.
The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.
Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:
class RewriteName(NodeTransformer):
    def visit_Name(self, node):
        return Subscript(
            value=Name(id='data', ctx=Load()),
            slice=Constant(value=node.id),
            ctx=node.ctx
        )
Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the `generic_visit` method for the node first.
For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this:
node = YourTransformer().visit(node)
Expand source code
class OptimizeRemoveDeadvars(CompilingNodeTransformer):
    """
    Replaces assignments, class definitions and function definitions whose
    name is never read by ``Pass()``.

    An assignment is only dropped when ``SafeOperationVisitor`` accepts its
    right-hand side, so that removing it can not hide an error. For that
    check the transformer tracks, per scope, which names are guaranteed to
    be available at the current position.
    """

    step = "Removing unused variables"

    # set of names that are read somewhere in the module (set by visit_Module)
    loaded_vars = None

    # names that are guaranteed to be available to the current node
    # this acts differently to the type inferencer! in particular, ite/while/for all produce their own scope
    # NOTE: this is a class attribute and hence shared by all instances;
    # visit_Module restores its depth on exit so that no scope leaks from
    # one run into the next.
    guaranteed_avail_names = [
        list(INITIAL_SCOPE.keys()) + ["isinstance", "Union", "Dict", "List"]
    ]

    def guaranteed(self, name: str) -> bool:
        """Return True if ``name`` is recorded in any currently open scope."""
        return any(name in scope for scope in self.guaranteed_avail_names)

    def enter_scope(self):
        """Open a new, empty innermost scope."""
        self.guaranteed_avail_names.append([])

    def exit_scope(self):
        """Discard the innermost scope."""
        self.guaranteed_avail_names.pop()

    def set_guaranteed(self, name: str):
        """Record ``name`` as guaranteed available in the innermost scope."""
        self.guaranteed_avail_names[-1].append(name)

    def _is_safe(self, expr):
        """Ask SafeOperationVisitor about ``expr``, given the names of all open scopes."""
        available = [
            name for scope in self.guaranteed_avail_names for name in scope
        ]
        return SafeOperationVisitor(available).visit(expr)

    def visit_Module(self, node: Module) -> Module:
        """
        Repeatedly remove unread definitions until the set of read names is stable.

        Returns a shallow copy of ``node`` with the rewritten body.
        """
        # repeat until no more change due to removal
        # i.e. b = a; c = b needs 2 passes to remove c and b
        node_cp = copy(node)
        self.loaded_vars = None
        base_depth = len(self.guaranteed_avail_names)
        try:
            while True:
                # collect all variable names
                collector = NameLoadCollector()
                collector.visit(node_cp)
                loaded_vars = OrderedSet(collector.loaded.keys()) | {"validator_0"}
                # break if the set of loaded vars did not change -> set of vars to remove does also not change
                if loaded_vars == self.loaded_vars:
                    break
                # remove unloaded ones
                self.loaded_vars = loaded_vars
                self.enter_scope()
                node_cp.body = [self.visit(s) for s in node_cp.body]
                self.exit_scope()
        finally:
            # drop every scope opened during this run, even if a visitor raised,
            # so the shared scope stack is left exactly as it was found
            del self.guaranteed_avail_names[base_depth:]
        return node_cp

    def visit_If(self, node: If):
        """Visit both branches in their own scope; names set in both stay guaranteed afterwards."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        self.enter_scope()
        node_cp.body = [self.visit(s) for s in node.body]
        scope_body_cp = self.guaranteed_avail_names[-1].copy()
        self.exit_scope()
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        scope_orelse_cp = self.guaranteed_avail_names[-1].copy()
        self.exit_scope()
        # what remains after this in the scope is the intersection of both
        for var in OrderedSet(scope_body_cp).intersection(scope_orelse_cp):
            self.set_guaranteed(var)
        return node_cp

    def visit_While(self, node: While):
        """Visit the loop body and the else clause, each in its own scope."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        self.enter_scope()
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        # the body may run zero times, so names it binds are not guaranteed
        # inside the else clause
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        self.exit_scope()
        return node_cp

    def visit_For(self, node: For):
        """
        Visit the loop body with the loop target guaranteed, and the else
        clause in a separate scope.

        Raises AssertionError if the loop target is not a plain name.
        """
        node_cp = copy(node)
        assert isinstance(node.target, Name), "Can only assign to singleton name"
        self.enter_scope()
        # the loop variable is bound whenever the body executes
        self.set_guaranteed(node.target.id)
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        # the body may run zero times, so neither the target nor names bound
        # in the body are guaranteed inside the else clause
        self.enter_scope()
        node_cp.orelse = [self.visit(s) for s in node.orelse]
        self.exit_scope()
        return node_cp

    def visit_Assign(self, node: Assign):
        """
        Replace the assignment by ``Pass()`` if its single name target is
        never read and its value is safe; otherwise keep it and mark the
        targets as guaranteed.

        Raises AssertionError if a kept assignment has a non-name target.
        """
        removable = (
            len(node.targets) == 1
            and isinstance(node.targets[0], Name)
            and node.targets[0].id not in self.loaded_vars
            and self._is_safe(node.value)
        )
        if removable:
            return Pass()
        for t in node.targets:
            assert isinstance(
                t, Name
            ), "Need to have name for dead var remover to work"
            self.set_guaranteed(t.id)
        return self.generic_visit(node)

    def visit_AnnAssign(self, node: TypedAnnAssign):
        """
        Replace the annotated assignment by ``Pass()`` if its name target is
        never read, its value is safe and the assignment is an upcast;
        otherwise keep it and mark the target as guaranteed.

        Raises AssertionError if a kept assignment has a non-name target.
        """
        removable = (
            isinstance(node.target, Name)
            and node.target.id not in self.loaded_vars
            and self._is_safe(node.value)
            # only upcasts are safe!
            and node.target.typ >= node.value.typ
        )
        if removable:
            return Pass()
        assert isinstance(
            node.target, Name
        ), "Need to have assignments to name for dead var remover to work"
        self.set_guaranteed(node.target.id)
        return self.generic_visit(node)

    def visit_ClassDef(self, node: ClassDef):
        """Keep the class (unchanged) if its name is read, else replace it by ``Pass()``."""
        if node.name in self.loaded_vars:
            self.set_guaranteed(node.name)
            return node
        return Pass()

    def visit_FunctionDef(self, node: FunctionDef):
        """
        Keep the function if its name is read, visiting its body in a new
        scope in which the positional arguments are guaranteed; otherwise
        replace it by ``Pass()``.
        """
        node_cp = copy(node)
        if node.name in self.loaded_vars:
            self.set_guaranteed(node.name)
            self.enter_scope()
            # variable names are available here
            for a in node.args.args:
                self.set_guaranteed(a.arg)
            node_cp.body = [self.visit(s) for s in node.body]
            self.exit_scope()
            return node_cp
        return Pass()
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var guaranteed_avail_names
var loaded_vars
var step
Methods
def enter_scope(self)
def exit_scope(self)
def guaranteed(self, name: str) ‑> bool
def set_guaranteed(self, name: str)
def visit(self, node)
-
Inherited from:
CompilingNodeTransformer
.visit
Visit a node.
def visit_AnnAssign(self, node: TypedAnnAssign)
def visit_Assign(self, node: ast.Assign)
def visit_ClassDef(self, node: ast.ClassDef)
def visit_For(self, node: ast.For)
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_If(self, node: ast.If)
def visit_Module(self, node: ast.Module) ‑> ast.Module
def visit_While(self, node: ast.While)
class SafeOperationVisitor (guaranteed_names)
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method.
This class is meant to be subclassed, with the subclass adding visitor methods.
Per default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`) the `generic_visit` visitor is used instead.
Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.
Expand source code
class SafeOperationVisitor(CompilingNodeVisitor):
    """
    Answers whether an expression node is whitelisted as an operation that
    can not throw an error.

    ``guaranteed_names`` is the collection of names that are treated as
    safe to look up.
    """

    step = "Collecting computations that can not throw errors"

    def __init__(self, guaranteed_names):
        self.guaranteed_names = guaranteed_names

    def generic_visit(self, node: AST) -> bool:
        """Anything without a dedicated visitor below counts as unsafe."""
        return False

    def visit_Lambda(self, node: Lambda) -> bool:
        """Defining a lambda does not compute anything, hence it is safe."""
        return True

    def visit_Constant(self, node: Constant) -> bool:
        """Constants can not fail."""
        return True

    def visit_RawPlutoExpr(self, node) -> bool:
        """These expressions are not evaluated further."""
        return True

    def visit_Name(self, node: Name) -> bool:
        """A name is safe exactly when it is among the guaranteed names."""
        known_names = self.guaranteed_names
        return node.id in known_names
Ancestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def generic_visit(self, node: ast.AST) ‑> bool
-
Called if no explicit visitor function exists for a node.
def visit(self, node)
-
Inherited from:
CompilingNodeVisitor
.visit
Visit a node.
def visit_Constant(self, node: ast.Constant) ‑> bool
def visit_Lambda(self, node: ast.Lambda) ‑> bool
def visit_Name(self, node: ast.Name) ‑> bool
def visit_RawPlutoExpr(self, node) ‑> bool