Module opshin.rewrite.rewrite_scoping
Expand source code
import typing
from ast import *
from copy import copy
from ordered_set import OrderedSet
from .rewrite_forbidden_overwrites import FORBIDDEN_NAMES
from ..type_inference import INITIAL_SCOPE
from ..util import CompilingNodeTransformer, CompilingNodeVisitor
"""
Rewrites all variable names to point to the definition in the nearest enclosing scope
"""
class ShallowNameDefCollector(CompilingNodeVisitor):
    """
    Collect the names bound directly in one scope, without descending into
    nested function or class bodies.

    After visiting the statements of a scope, ``vars`` holds, in first-seen
    order, every name that is
    - stored to (a ``Name`` node with a ``Store`` context),
    - the name of a class or function defined in that scope, or
    - the name of a method of such a class (methods are later put into the
      global scope, so they are registered here up front).
    """

    step = "Collecting defined variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        # only bindings count; plain loads are ignored
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        self.vars.add(node.name)
        # methods will be put in global scope so register them now; the rest
        # of the class body (i.e. attribute names) is deliberately skipped
        method_names = [
            stmt.name for stmt in node.body if isinstance(stmt, FunctionDef)
        ]
        for method_name in method_names:
            self.vars.add(method_name)

    def visit_FunctionDef(self, node: FunctionDef):
        # register the function itself, but do not recurse into its body,
        # which forms a scope of its own
        self.vars.add(node.name)
class RewriteScoping(CompilingNodeTransformer):
    """
    Rename every variable so that it unambiguously refers to the definition
    in the nearest enclosing scope.

    The module and every function open a scope with a fresh integer id. A
    name bound in that scope is rewritten to ``f"{name}_{id}"``. Names from
    ``INITIAL_SCOPE`` and ``FORBIDDEN_NAMES`` live in a pseudo-scope with id
    ``-1`` and are left unchanged.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id handed out to the next scope opened by enter_scope
    latest_scope_id: int
    # stack of (names bound in the scope, scope id); innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]
    # (rewritten class name, original class name) of the class whose methods
    # are currently being visited
    current_Self: typing.Tuple[str, str]

    def variable_scope_id(self, name: str) -> int:
        """find the id of the scope in which this variable is defined (closest to its usage)

        :raises NameError: if no scope on the stack binds ``name``
        """
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """push a new, empty scope with a fresh id"""
        self.scopes.append((OrderedSet(), self.latest_scope_id))
        self.latest_scope_id += 1

    def exit_scope(self):
        """pop the innermost scope"""
        self.scopes.pop()

    def set_variable_scope(self, name: str):
        """record ``name`` as bound in the innermost scope"""
        self.scopes[-1][0].add(name)

    def map_name(self, name: str):
        """return the scope-qualified name of ``name`` (names of scope -1 are kept)"""
        scope_id = self.variable_scope_id(name)
        if scope_id == -1:
            # do not rewrite Dict, Union, etc
            return name
        return f"{name}_{scope_id}"

    def _declare_shallow_names(self, body: typing.List[stmt]) -> None:
        """bind every name defined directly in ``body`` in the innermost scope"""
        collector = ShallowNameDefCollector()
        for s in body:
            collector.visit(s)
        for var_name in collector.vars:
            self.set_variable_scope(var_name)

    def visit_Module(self, node: Module) -> Module:
        """rewrite all names of a module; resets all scope state first"""
        self.latest_scope_id = 0
        # pseudo-scope -1: these names are never renamed
        self.scopes = [(OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES), -1)]
        node_cp = copy(node)
        self.enter_scope()
        # every name bound at module level is visible throughout the module,
        # so declare them all before rewriting any statement
        self._declare_shallow_names(node.body)
        node_cp.body = [self.visit(s) for s in node.body]
        return node_cp

    def visit_Name(self, node: Name) -> Name:
        """rewrite a name reference to its scope-qualified form"""
        nc = copy(node)
        # setting is handled in either enclosing module or function
        if node.id == "Self":
            # NOTE(review): idSelf is presumably attached by an earlier pass and
            # holds the original name of the enclosing class
            assert node.idSelf == self.current_Self[1]
            nc.idSelf_new = self.current_Self[0]
        nc.id = self.map_name(node.id)
        return nc

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """scope the class name and fields, then scope each method"""
        cp_node = RecordScoper.scope(node, self)
        for i, attribute in enumerate(cp_node.body):
            if isinstance(attribute, FunctionDef):
                # (re)set before every method so that ``Self`` in the method
                # refers to this class
                self.current_Self = (cp_node.name, cp_node.orig_name)
                cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
        return cp_node

    def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> FunctionDef:
        """
        Rewrite a function definition and everything inside it.

        :param method: if True, the function is a method; its own name is
            kept as-is instead of being mapped through the enclosing scope
        """
        node_cp = copy(node)
        # the function name is bound in the enclosing scope (declared by the
        # enclosing module or function)
        node_cp.name = node.name if method else self.map_name(node.name)
        # like Python, resolve annotations in the enclosing scope, before the
        # function's own scope exists, so an argument named like a type cannot
        # shadow that type in the other annotations
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # vars defined in this scope
        self._declare_shallow_names(node.body)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """absent optional children (e.g. missing annotations) stay absent"""
        return node
class RecordScoper(NodeTransformer):
    """
    Scope the record-level parts of a class definition for a
    ``RewriteScoping`` pass: the class name and the annotations of its
    annotated fields. Methods are left untouched; ``RewriteScoping`` scopes
    them itself.
    """

    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """return a scoped copy of class ``c``, resolving names via ``scoper``"""
        f = cls(scoper)
        return f.visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy() is shallow and generic_visit rewrites list fields in place;
        # give the copy its own body list so the input class is not mutated
        node_cp.body = list(c.body)
        return self.generic_visit(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        # methods are scoped by RewriteScoping.visit_FunctionDef, which also
        # maps the annotations of annotated assignments in the method body;
        # descending here would map those annotations a second time
        return node

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        """scope the type annotation of a record field"""
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        node_cp = copy(node)
        node_cp.annotation = self._scoper.visit(node_cp.annotation)
        return node_cp
Classes
class RecordScoper (scoper: RewriteScoping)
-
A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes. The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location; otherwise it is replaced with the return value. The return value may be the original node, in which case no replacement takes place. Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:
    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes, you must either transform the child nodes yourself or call the `generic_visit` method for the node first. For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this:

    node = YourTransformer().visit(node)
Expand source code
class RecordScoper(NodeTransformer):
    """
    Scope the record-level parts of a class definition for a
    ``RewriteScoping`` pass: the class name and the annotations of its
    annotated fields. Methods are left untouched; ``RewriteScoping`` scopes
    them itself.
    """

    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """return a scoped copy of class ``c``, resolving names via ``scoper``"""
        f = cls(scoper)
        return f.visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy() is shallow and generic_visit rewrites list fields in place;
        # give the copy its own body list so the input class is not mutated
        node_cp.body = list(c.body)
        return self.generic_visit(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        # methods are scoped by RewriteScoping.visit_FunctionDef, which also
        # maps the annotations of annotated assignments in the method body;
        # descending here would map those annotations a second time
        return node

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        """scope the type annotation of a record field"""
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        node_cp = copy(node)
        node_cp.annotation = self._scoper.visit(node_cp.annotation)
        return node_cp
Ancestors
- ast.NodeTransformer
- ast.NodeVisitor
Static methods
def scope(c: ast.ClassDef, scoper: RewriteScoping) ‑> ast.ClassDef
Methods
def visit_AnnAssign(self, node: ast.AnnAssign) ‑> ast.AnnAssign
def visit_ClassDef(self, c: ast.ClassDef) ‑> ast.ClassDef
class RewriteScoping
-
A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes. The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location; otherwise it is replaced with the return value. The return value may be the original node, in which case no replacement takes place. Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:
    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes, you must either transform the child nodes yourself or call the `generic_visit` method for the node first. For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this:

    node = YourTransformer().visit(node)
Expand source code
class RewriteScoping(CompilingNodeTransformer):
    """
    Rename every variable so that it unambiguously refers to the definition
    in the nearest enclosing scope.

    The module and every function open a scope with a fresh integer id. A
    name bound in that scope is rewritten to ``f"{name}_{id}"``. Names from
    ``INITIAL_SCOPE`` and ``FORBIDDEN_NAMES`` live in a pseudo-scope with id
    ``-1`` and are left unchanged.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id handed out to the next scope opened by enter_scope
    latest_scope_id: int
    # stack of (names bound in the scope, scope id); innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]
    # (rewritten class name, original class name) of the class whose methods
    # are currently being visited
    current_Self: typing.Tuple[str, str]

    def variable_scope_id(self, name: str) -> int:
        """find the id of the scope in which this variable is defined (closest to its usage)

        :raises NameError: if no scope on the stack binds ``name``
        """
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """push a new, empty scope with a fresh id"""
        self.scopes.append((OrderedSet(), self.latest_scope_id))
        self.latest_scope_id += 1

    def exit_scope(self):
        """pop the innermost scope"""
        self.scopes.pop()

    def set_variable_scope(self, name: str):
        """record ``name`` as bound in the innermost scope"""
        self.scopes[-1][0].add(name)

    def map_name(self, name: str):
        """return the scope-qualified name of ``name`` (names of scope -1 are kept)"""
        scope_id = self.variable_scope_id(name)
        if scope_id == -1:
            # do not rewrite Dict, Union, etc
            return name
        return f"{name}_{scope_id}"

    def _declare_shallow_names(self, body: typing.List[stmt]) -> None:
        """bind every name defined directly in ``body`` in the innermost scope"""
        collector = ShallowNameDefCollector()
        for s in body:
            collector.visit(s)
        for var_name in collector.vars:
            self.set_variable_scope(var_name)

    def visit_Module(self, node: Module) -> Module:
        """rewrite all names of a module; resets all scope state first"""
        self.latest_scope_id = 0
        # pseudo-scope -1: these names are never renamed
        self.scopes = [(OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES), -1)]
        node_cp = copy(node)
        self.enter_scope()
        # every name bound at module level is visible throughout the module,
        # so declare them all before rewriting any statement
        self._declare_shallow_names(node.body)
        node_cp.body = [self.visit(s) for s in node.body]
        return node_cp

    def visit_Name(self, node: Name) -> Name:
        """rewrite a name reference to its scope-qualified form"""
        nc = copy(node)
        # setting is handled in either enclosing module or function
        if node.id == "Self":
            # NOTE(review): idSelf is presumably attached by an earlier pass and
            # holds the original name of the enclosing class
            assert node.idSelf == self.current_Self[1]
            nc.idSelf_new = self.current_Self[0]
        nc.id = self.map_name(node.id)
        return nc

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """scope the class name and fields, then scope each method"""
        cp_node = RecordScoper.scope(node, self)
        for i, attribute in enumerate(cp_node.body):
            if isinstance(attribute, FunctionDef):
                # (re)set before every method so that ``Self`` in the method
                # refers to this class
                self.current_Self = (cp_node.name, cp_node.orig_name)
                cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
        return cp_node

    def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> FunctionDef:
        """
        Rewrite a function definition and everything inside it.

        :param method: if True, the function is a method; its own name is
            kept as-is instead of being mapped through the enclosing scope
        """
        node_cp = copy(node)
        # the function name is bound in the enclosing scope (declared by the
        # enclosing module or function)
        node_cp.name = node.name if method else self.map_name(node.name)
        # like Python, resolve annotations in the enclosing scope, before the
        # function's own scope exists, so an argument named like a type cannot
        # shadow that type in the other annotations
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # vars defined in this scope
        self._declare_shallow_names(node.body)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """absent optional children (e.g. missing annotations) stay absent"""
        return node
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var current_Self : Tuple[str, str]
var latest_scope_id : int
var scopes : List[Tuple[ordered_set.OrderedSet, int]]
var step
Methods
def enter_scope(self)
def exit_scope(self)
def map_name(self, name: str)
def set_variable_scope(self, name: str)
def variable_scope_id(self, name: str) ‑> int
-
find the id of the scope in which this variable is defined (closest to its usage)
def visit(self, node)
-
Inherited from:
CompilingNodeTransformer
.visit
Visit a node.
def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
def visit_FunctionDef(self, node: ast.FunctionDef, method: bool = False) ‑> ast.FunctionDef
def visit_Module(self, node: ast.Module) ‑> ast.Module
def visit_Name(self, node: ast.Name) ‑> ast.Name
def visit_NoneType(self, node: None) ‑> None
class ShallowNameDefCollector
-
A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the `visit` method. This class is meant to be subclassed, with the subclass adding visitor methods.
By default the visitor functions for the nodes are `'visit_'` + class name of the node. So a `TryFinally` node visit function would be `visit_TryFinally`. This behavior can be changed by overriding the `visit` method. If no visitor function exists for a node (return value `None`), the `generic_visit` visitor is used instead.
Don't use the `NodeVisitor` if you want to apply changes to nodes during traversing. For this a special visitor exists (`NodeTransformer`) that allows modifications.
Expand source code
class ShallowNameDefCollector(CompilingNodeVisitor):
    """
    Collect the names bound directly in one scope, without descending into
    nested function or class bodies.

    After visiting the statements of a scope, ``vars`` holds, in first-seen
    order, every name that is
    - stored to (a ``Name`` node with a ``Store`` context),
    - the name of a class or function defined in that scope, or
    - the name of a method of such a class (methods are later put into the
      global scope, so they are registered here up front).
    """

    step = "Collecting defined variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        # only bindings count; plain loads are ignored
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        self.vars.add(node.name)
        # methods will be put in global scope so register them now; the rest
        # of the class body (i.e. attribute names) is deliberately skipped
        method_names = [
            stmt.name for stmt in node.body if isinstance(stmt, FunctionDef)
        ]
        for method_name in method_names:
            self.vars.add(method_name)

    def visit_FunctionDef(self, node: FunctionDef):
        # register the function itself, but do not recurse into its body,
        # which forms a scope of its own
        self.vars.add(node.name)
Ancestors
- CompilingNodeVisitor
- TypedNodeVisitor
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from:
CompilingNodeVisitor
.visit
Visit a node.
def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_Name(self, node: ast.Name) ‑> None