Module opshin.rewrite.rewrite_scoping

Expand source code
import typing
from ast import *
from copy import copy

from ordered_set import OrderedSet

from .rewrite_forbidden_overwrites import FORBIDDEN_NAMES
from ..type_inference import INITIAL_SCOPE
from ..util import CompilingNodeTransformer, CompilingNodeVisitor

"""
Rewrites all variable names to point to the definition in the nearest enclosing scope
"""


class ShallowNameDefCollector(CompilingNodeVisitor):
    """Collect the names that the visited statements bind in their own scope.

    Assigned names (``Name`` nodes in ``Store`` context) and the names of class
    and function definitions are gathered in ``self.vars``. Nothing inside a
    class or function definition is visited.
    """

    step = "Collecting defined variable names"

    def __init__(self):
        # OrderedSet keeps the names in the order they are first seen
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        """Record ``node.id`` when the name is being assigned to."""
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        """Record only the class name; attribute names in the body are not variables."""
        class_name = node.name
        self.vars.add(class_name)

    def visit_FunctionDef(self, node: FunctionDef):
        """Record only the function name; the function body is a separate scope."""
        function_name = node.name
        self.vars.add(function_name)


class RewriteScoping(CompilingNodeTransformer):
    """Rename variables so that each refers unambiguously to its definition.

    Every module and function body forms a scope with a numeric id, assigned in
    visiting order starting at 0. A name ``x`` that resolves to the scope with id
    ``i`` is rewritten to ``x_i``. Names that resolve to the initial pseudo-scope
    (id ``-1``, filled from ``INITIAL_SCOPE`` and ``FORBIDDEN_NAMES``) keep their
    original spelling. Class bodies do not open a scope; they are handled by
    ``RecordScoper``.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id that the next call to enter_scope hands out
    latest_scope_id: int
    # stack of (names defined in the scope, scope id); the innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]

    def variable_scope_id(self, name: str) -> int:
        """Return the id of the innermost scope on the stack that defines ``name``.

        :raises NameError: if no scope on the stack defines ``name``
        """
        # search from the innermost scope outwards
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """Push a new, empty scope that receives the next unused scope id."""
        fresh_id = self.latest_scope_id
        self.latest_scope_id += 1
        self.scopes.append((OrderedSet(), fresh_id))

    def exit_scope(self):
        """Discard the innermost scope; its id is never handed out again."""
        del self.scopes[-1]

    def set_variable_scope(self, name: str):
        """Declare ``name`` as defined in the innermost scope."""
        innermost_names, _ = self.scopes[-1]
        innermost_names.add(name)

    def map_name(self, name: str):
        """Return the unambiguous spelling ``<name>_<scope id>`` of ``name``.

        Names that resolve to the initial pseudo-scope (id ``-1``) are returned
        unchanged.
        """
        scope_id = self.variable_scope_id(name)
        # do not rewrite Dict, Union, etc
        return name if scope_id == -1 else f"{name}_{scope_id}"

    def visit_Module(self, node: Module) -> Module:
        """Rewrite every name in the module, starting from a fresh scope stack.

        The initial pseudo-scope (id ``-1``) holds the names of ``INITIAL_SCOPE``
        and ``FORBIDDEN_NAMES``; names resolved there are not renamed. The module
        body itself becomes scope 0.
        """
        builtin_names = OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES)
        self.scopes = [(builtin_names, -1)]
        self.latest_scope_id = 0
        self.enter_scope()
        # register all module-level bindings up front, so that uses appearing
        # before the defining statement still resolve to the module scope
        collector = ShallowNameDefCollector()
        for stmt in node.body:
            collector.visit(stmt)
        for var_name in collector.vars:
            self.set_variable_scope(var_name)
        module_cp = copy(node)
        module_cp.body = [self.visit(stmt) for stmt in node.body]
        return module_cp

    def visit_Name(self, node: Name) -> Name:
        """Return a copy of ``node`` carrying its scope-qualified identifier."""
        # the defining scope was registered earlier by the enclosing module or function
        renamed = copy(node)
        renamed.id = self.map_name(renamed.id)
        return renamed

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Delegate class (record) definitions to ``RecordScoper``."""
        return RecordScoper.scope(node, scoper=self)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Rewrite a function definition, opening a new scope for its body.

        The function's own name is resolved in the enclosing scope, where the
        enclosing module or function declared it. Argument and return annotations
        are resolved in the enclosing scope as well. This matches Python, which
        evaluates annotations where the function is defined. The arguments and
        every name bound directly in the body are declared in the new scope
        before the body is rewritten.
        """
        node_cp = copy(node)
        node_cp.name = self.map_name(node.name)
        # Resolve annotations before entering the function scope. Otherwise an
        # argument sharing its name with a type (e.g. `def f(Foo: Foo)`) would
        # capture that type in its own annotation, in later arguments'
        # annotations and in the return annotation.
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # declare every name bound directly in the body (assignments, nested
        # defs/classes) so that all uses in the body resolve to this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """Return ``node`` unchanged.

        ``self.visit`` is called on optional fields such as ``a.annotation`` and
        ``node.returns``, which may be None. Presumably the visitor dispatch
        routes those calls here; confirm against ``CompilingNodeTransformer.visit``.
        """
        return node


class RecordScoper(NodeTransformer):
    """Apply scope renaming to a class (record) definition.

    The class name is resolved through the given ``RewriteScoping`` instance, and
    so are the annotations of annotated attributes. Attribute names themselves
    are left untouched, and no new scope is opened for the class body.
    """

    # the RewriteScoping whose current scope stack resolves names
    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """Return the scope-rewritten version of class definition ``c``."""
        return cls(scoper).visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        """Rename the class and transform its children on a copy of ``c``."""
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy() is shallow, so node_cp shares its list fields (body, bases, ...)
        # with c, and generic_visit rewrites list fields in place. Give the copy
        # its own lists so that the original ClassDef's fields are not overwritten.
        for field, value in iter_fields(node_cp):
            if isinstance(value, list):
                setattr(node_cp, field, list(value))
        return self.generic_visit(node_cp)

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        """Resolve the annotation of a record attribute on a copy of ``node``.

        The attribute name (target) and any assigned value are not rewritten.
        """
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        rewritten = copy(node)
        rewritten.annotation = self._scoper.visit(node.annotation)
        return rewritten

Classes

class RecordScoper (scoper: RewriteScoping)

A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:

class RewriteName(NodeTransformer):
    """Turn every name lookup ``foo`` into the subscript ``data['foo']``."""

    def visit_Name(self, node):
        lookup_key = Constant(value=node.id)
        data_ref = Name(id='data', ctx=Load())
        return Subscript(value=data_ref, slice=lookup_key, ctx=node.ctx)

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the `generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

node = YourTransformer().visit(node)

Expand source code
class RecordScoper(NodeTransformer):
    """Apply scope renaming to a class (record) definition.

    The class name is resolved through the given ``RewriteScoping`` instance, and
    so are the annotations of annotated attributes. Attribute names themselves
    are left untouched, and no new scope is opened for the class body.
    """

    # the RewriteScoping whose current scope stack resolves names
    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """Return the scope-rewritten version of class definition ``c``."""
        return cls(scoper).visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        """Rename the class and transform its children on a copy of ``c``."""
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy() is shallow, so node_cp shares its list fields (body, bases, ...)
        # with c, and generic_visit rewrites list fields in place. Give the copy
        # its own lists so that the original ClassDef's fields are not overwritten.
        for field, value in iter_fields(node_cp):
            if isinstance(value, list):
                setattr(node_cp, field, list(value))
        return self.generic_visit(node_cp)

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        """Resolve the annotation of a record attribute on a copy of ``node``.

        The attribute name (target) and any assigned value are not rewritten.
        """
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        rewritten = copy(node)
        rewritten.annotation = self._scoper.visit(node.annotation)
        return rewritten

Ancestors

  • ast.NodeTransformer
  • ast.NodeVisitor

Static methods

def scope(c: ast.ClassDef, scoper: RewriteScoping) ‑> ast.ClassDef
Expand source code
@classmethod
def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
    """Return the scope-rewritten version of class definition ``c``."""
    return cls(scoper).visit(c)

Methods

def visit_AnnAssign(self, node: ast.AnnAssign) ‑> ast.AnnAssign
Expand source code
def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
    """Resolve the annotation of a record attribute on a copy of ``node``.

    The attribute name (target) and any assigned value are not rewritten.
    """
    assert isinstance(
        node.target, Name
    ), "Record elements must have named attributes"
    rewritten = copy(node)
    rewritten.annotation = self._scoper.visit(node.annotation)
    return rewritten
def visit_ClassDef(self, c: ast.ClassDef) ‑> ast.ClassDef
Expand source code
def visit_ClassDef(self, c: ClassDef) -> ClassDef:
    """Rename the class and transform its children on a copy of ``c``."""
    node_cp = copy(c)
    node_cp.name = self._scoper.map_name(node_cp.name)
    # copy() is shallow, so node_cp shares its list fields (body, bases, ...)
    # with c, and generic_visit rewrites list fields in place. Give the copy
    # its own lists so that the original ClassDef's fields are not overwritten.
    for field, value in iter_fields(node_cp):
        if isinstance(value, list):
            setattr(node_cp, field, list(value))
    return self.generic_visit(node_cp)
class RewriteScoping

A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:

class RewriteName(NodeTransformer):
    """Turn every name lookup ``foo`` into the subscript ``data['foo']``."""

    def visit_Name(self, node):
        lookup_key = Constant(value=node.id)
        data_ref = Name(id='data', ctx=Load())
        return Subscript(value=data_ref, slice=lookup_key, ctx=node.ctx)

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the `generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

node = YourTransformer().visit(node)

Expand source code
class RewriteScoping(CompilingNodeTransformer):
    """Rename variables so that each refers unambiguously to its definition.

    Every module and function body forms a scope with a numeric id, assigned in
    visiting order starting at 0. A name ``x`` that resolves to the scope with id
    ``i`` is rewritten to ``x_i``. Names that resolve to the initial pseudo-scope
    (id ``-1``, filled from ``INITIAL_SCOPE`` and ``FORBIDDEN_NAMES``) keep their
    original spelling. Class bodies do not open a scope; they are handled by
    ``RecordScoper``.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id that the next call to enter_scope hands out
    latest_scope_id: int
    # stack of (names defined in the scope, scope id); the innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]

    def variable_scope_id(self, name: str) -> int:
        """Return the id of the innermost scope on the stack that defines ``name``.

        :raises NameError: if no scope on the stack defines ``name``
        """
        # search from the innermost scope outwards
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """Push a new, empty scope that receives the next unused scope id."""
        fresh_id = self.latest_scope_id
        self.latest_scope_id += 1
        self.scopes.append((OrderedSet(), fresh_id))

    def exit_scope(self):
        """Discard the innermost scope; its id is never handed out again."""
        del self.scopes[-1]

    def set_variable_scope(self, name: str):
        """Declare ``name`` as defined in the innermost scope."""
        innermost_names, _ = self.scopes[-1]
        innermost_names.add(name)

    def map_name(self, name: str):
        """Return the unambiguous spelling ``<name>_<scope id>`` of ``name``.

        Names that resolve to the initial pseudo-scope (id ``-1``) are returned
        unchanged.
        """
        scope_id = self.variable_scope_id(name)
        # do not rewrite Dict, Union, etc
        return name if scope_id == -1 else f"{name}_{scope_id}"

    def visit_Module(self, node: Module) -> Module:
        """Rewrite every name in the module, starting from a fresh scope stack.

        The initial pseudo-scope (id ``-1``) holds the names of ``INITIAL_SCOPE``
        and ``FORBIDDEN_NAMES``; names resolved there are not renamed. The module
        body itself becomes scope 0.
        """
        builtin_names = OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES)
        self.scopes = [(builtin_names, -1)]
        self.latest_scope_id = 0
        self.enter_scope()
        # register all module-level bindings up front, so that uses appearing
        # before the defining statement still resolve to the module scope
        collector = ShallowNameDefCollector()
        for stmt in node.body:
            collector.visit(stmt)
        for var_name in collector.vars:
            self.set_variable_scope(var_name)
        module_cp = copy(node)
        module_cp.body = [self.visit(stmt) for stmt in node.body]
        return module_cp

    def visit_Name(self, node: Name) -> Name:
        """Return a copy of ``node`` carrying its scope-qualified identifier."""
        # the defining scope was registered earlier by the enclosing module or function
        renamed = copy(node)
        renamed.id = self.map_name(renamed.id)
        return renamed

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Delegate class (record) definitions to ``RecordScoper``."""
        return RecordScoper.scope(node, scoper=self)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Rewrite a function definition, opening a new scope for its body.

        The function's own name is resolved in the enclosing scope, where the
        enclosing module or function declared it. Argument and return annotations
        are resolved in the enclosing scope as well. This matches Python, which
        evaluates annotations where the function is defined. The arguments and
        every name bound directly in the body are declared in the new scope
        before the body is rewritten.
        """
        node_cp = copy(node)
        node_cp.name = self.map_name(node.name)
        # Resolve annotations before entering the function scope. Otherwise an
        # argument sharing its name with a type (e.g. `def f(Foo: Foo)`) would
        # capture that type in its own annotation, in later arguments'
        # annotations and in the return annotation.
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # declare every name bound directly in the body (assignments, nested
        # defs/classes) so that all uses in the body resolve to this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """Return ``node`` unchanged.

        ``self.visit`` is called on optional fields such as ``a.annotation`` and
        ``node.returns``, which may be None. Presumably the visitor dispatch
        routes those calls here; confirm against ``CompilingNodeTransformer.visit``.
        """
        return node

Ancestors

Class variables

var latest_scope_id : int
var scopes : List[Tuple[ordered_set.OrderedSet, int]]
var step

Methods

def enter_scope(self)
Expand source code
def enter_scope(self):
    """Push a new, empty scope that receives the next unused scope id."""
    fresh_id = self.latest_scope_id
    self.latest_scope_id += 1
    self.scopes.append((OrderedSet(), fresh_id))
def exit_scope(self)
Expand source code
def exit_scope(self):
    """Discard the innermost scope; its id is never handed out again."""
    del self.scopes[-1]
def map_name(self, name: str)
Expand source code
def map_name(self, name: str):
    """Return the unambiguous spelling ``<name>_<scope id>`` of ``name``.

    Names that resolve to the initial pseudo-scope (id ``-1``) are returned
    unchanged.
    """
    scope_id = self.variable_scope_id(name)
    # do not rewrite Dict, Union, etc
    return name if scope_id == -1 else f"{name}_{scope_id}"
def set_variable_scope(self, name: str)
Expand source code
def set_variable_scope(self, name: str):
    """Declare ``name`` as defined in the innermost scope."""
    innermost_names, _ = self.scopes[-1]
    innermost_names.add(name)
def variable_scope_id(self, name: str) ‑> int

find the id of the scope in which this variable is defined (closest to its usage)

Expand source code
def variable_scope_id(self, name: str) -> int:
    """Return the id of the innermost scope on the stack that defines ``name``.

    :raises NameError: if no scope on the stack defines ``name``
    """
    # search from the innermost scope outwards
    for scope, scope_id in reversed(self.scopes):
        if name in scope:
            return scope_id
    raise NameError(
        f"free variable '{name}' referenced before assignment in enclosing scope"
    )
def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef:
    """Delegate class (record) definitions to ``RecordScoper``."""
    return RecordScoper.scope(node, scoper=self)
def visit_FunctionDef(self, node: ast.FunctionDef) ‑> ast.FunctionDef
Expand source code
def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
    """Rewrite a function definition, opening a new scope for its body.

    The function's own name is resolved in the enclosing scope, where the
    enclosing module or function declared it. Argument and return annotations
    are resolved in the enclosing scope as well. This matches Python, which
    evaluates annotations where the function is defined. The arguments and
    every name bound directly in the body are declared in the new scope before
    the body is rewritten.
    """
    node_cp = copy(node)
    node_cp.name = self.map_name(node.name)
    # Resolve annotations before entering the function scope. Otherwise an
    # argument sharing its name with a type (e.g. `def f(Foo: Foo)`) would
    # capture that type in its own annotation, in later arguments'
    # annotations and in the return annotation.
    arg_annotations = [self.visit(a.annotation) for a in node.args.args]
    node_cp.returns = self.visit(node.returns)
    self.enter_scope()
    node_cp.args = copy(node.args)
    node_cp.args.args = []
    # args are defined in this scope
    for a, annotation in zip(node.args.args, arg_annotations):
        a_cp = copy(a)
        self.set_variable_scope(a.arg)
        a_cp.arg = self.map_name(a.arg)
        a_cp.annotation = annotation
        node_cp.args.args.append(a_cp)
    # declare every name bound directly in the body (assignments, nested
    # defs/classes) so that all uses in the body resolve to this scope
    shallow_node_def_collector = ShallowNameDefCollector()
    for s in node.body:
        shallow_node_def_collector.visit(s)
    for var_name in shallow_node_def_collector.vars:
        self.set_variable_scope(var_name)
    # map all vars and recurse
    node_cp.body = [self.visit(s) for s in node.body]
    self.exit_scope()
    return node_cp
def visit_Module(self, node: ast.Module) ‑> ast.Module
Expand source code
def visit_Module(self, node: Module) -> Module:
    """Rewrite every name in the module, starting from a fresh scope stack.

    The initial pseudo-scope (id ``-1``) holds the names of ``INITIAL_SCOPE``
    and ``FORBIDDEN_NAMES``; names resolved there are not renamed. The module
    body itself becomes scope 0.
    """
    builtin_names = OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES)
    self.scopes = [(builtin_names, -1)]
    self.latest_scope_id = 0
    self.enter_scope()
    # register all module-level bindings up front, so that uses appearing
    # before the defining statement still resolve to the module scope
    collector = ShallowNameDefCollector()
    for stmt in node.body:
        collector.visit(stmt)
    for var_name in collector.vars:
        self.set_variable_scope(var_name)
    module_cp = copy(node)
    module_cp.body = [self.visit(stmt) for stmt in node.body]
    return module_cp
def visit_Name(self, node: ast.Name) ‑> ast.Name
Expand source code
def visit_Name(self, node: Name) -> Name:
    """Return a copy of ``node`` carrying its scope-qualified identifier."""
    # the defining scope was registered earlier by the enclosing module or function
    renamed = copy(node)
    renamed.id = self.map_name(renamed.id)
    return renamed
def visit_NoneType(self, node: None) ‑> None
Expand source code
def visit_NoneType(self, node: None) -> None:
    """Return ``node`` unchanged.

    ``self.visit`` is called on optional fields such as ``a.annotation`` and
    ``node.returns``, which may be None. Presumably the visitor dispatch routes
    those calls here; confirm against ``CompilingNodeTransformer.visit``.
    """
    return node
class ShallowNameDefCollector

A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the visit method.

This class is meant to be subclassed, with the subclass adding visitor methods.

By default, the visitor function for a node is named 'visit_' + the node's class name, so the visit function for a TryFinally node would be visit_TryFinally. This behavior can be changed by overriding the visit method. If no visitor function exists for a node (return value None), the generic_visit visitor is used instead.

Don't use the NodeVisitor if you want to apply changes to nodes during traversing. For this a special visitor exists (NodeTransformer) that allows modifications.

Expand source code
class ShallowNameDefCollector(CompilingNodeVisitor):
    """Collect the names that the visited statements bind in their own scope.

    Assigned names (``Name`` nodes in ``Store`` context) and the names of class
    and function definitions are gathered in ``self.vars``. Nothing inside a
    class or function definition is visited.
    """

    step = "Collecting defined variable names"

    def __init__(self):
        # OrderedSet keeps the names in the order they are first seen
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        """Record ``node.id`` when the name is being assigned to."""
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        """Record only the class name; attribute names in the body are not variables."""
        class_name = node.name
        self.vars.add(class_name)

    def visit_FunctionDef(self, node: FunctionDef):
        """Record only the function name; the function body is a separate scope."""
        function_name = node.name
        self.vars.add(function_name)

Ancestors

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeVisitor.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef)
Expand source code
def visit_ClassDef(self, node: ClassDef):
    """Record only the class name; attribute names in the body are not variables."""
    class_name = node.name
    self.vars.add(class_name)
def visit_FunctionDef(self, node: ast.FunctionDef)
Expand source code
def visit_FunctionDef(self, node: FunctionDef):
    """Record only the function name; the function body is a separate scope."""
    function_name = node.name
    self.vars.add(function_name)
def visit_Name(self, node: ast.Name) ‑> None
Expand source code
def visit_Name(self, node: Name) -> None:
    """Record ``node.id`` when the name is being assigned to."""
    if not isinstance(node.ctx, Store):
        return
    self.vars.add(node.id)