Module opshin.rewrite.rewrite_scoping

Expand source code
import typing
from ast import *
from copy import copy

from ordered_set import OrderedSet

from .rewrite_forbidden_overwrites import FORBIDDEN_NAMES
from ..type_inference import INITIAL_SCOPE
from ..util import CompilingNodeTransformer, CompilingNodeVisitor

# NOTE(review): this string comes after the imports, so Python does not treat
# it as the module docstring (``__doc__``); it is evaluated and discarded at
# import time. Move it above the imports if it is meant to document the module.
"""
Rewrites all variable names to point to the definition in the nearest enclosing scope
"""


class ShallowNameDefCollector(CompilingNodeVisitor):
    """
    Collect the names bound directly in one scope.

    Nested function bodies and class bodies are not descended into; only the
    names they introduce into the current scope are recorded. Names are kept
    in ``self.vars`` in first-seen order.
    """

    step = "Collecting defined variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        # only names in a store context (assignment targets) create a binding
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        # the class name itself is bound in the current scope
        self.vars.add(node.name)
        # methods will be put in global scope so add them now
        method_names = (
            stmt.name for stmt in node.body if isinstance(stmt, FunctionDef)
        )
        for method_name in method_names:
            self.vars.add(method_name)
        # the class body (i.e. attribute names) is intentionally not visited

    def visit_FunctionDef(self, node: FunctionDef):
        # only the function name is bound here; its body is a separate scope
        self.vars.add(node.name)


class RewriteScoping(CompilingNodeTransformer):
    """
    Rewrite all variable names so each one unambiguously points to its
    definition in the nearest enclosing scope.

    Every module/function scope gets a numeric id. A name ``x`` bound in the
    scope with id ``k`` is renamed to ``x_k``. Names of ``INITIAL_SCOPE`` and
    ``FORBIDDEN_NAMES`` live in a pseudo-scope with id ``-1`` and are never
    renamed.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id handed out to the next scope opened by enter_scope
    latest_scope_id: int
    # stack of (names bound in the scope, scope id); the innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]
    # (rewritten name, original name) of the class whose methods are currently
    # being visited; used to resolve ``Self`` in visit_Name
    current_Self: typing.Tuple[str, str]

    def variable_scope_id(self, name: str) -> int:
        """find the id of the scope in which this variable is defined (closest to its usage)

        Scopes are searched from the innermost outwards.

        :raises NameError: if ``name`` is not bound in any enclosing scope
        """
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """Push a new, empty scope with a fresh id."""
        self.scopes.append((OrderedSet(), self.latest_scope_id))
        self.latest_scope_id += 1

    def exit_scope(self):
        """Pop the innermost scope."""
        self.scopes.pop()

    def set_variable_scope(self, name: str):
        """Record ``name`` as bound in the innermost scope."""
        self.scopes[-1][0].add(name)

    def map_name(self, name: str):
        """Return the scoped name ``{name}_{scope_id}`` for ``name``.

        Names resolved in the initial pseudo-scope (id ``-1``) are returned
        unchanged.

        :raises NameError: if ``name`` is not bound in any enclosing scope
        """
        scope_id = self.variable_scope_id(name)
        if scope_id == -1:
            # do not rewrite Dict, Union, etc
            return name
        return f"{name}_{scope_id}"

    def visit_Module(self, node: Module) -> Module:
        """Reset the scope state and rewrite all top-level statements.

        All names bound at module level are registered before any statement is
        visited, so a statement may reference a name bound later in the module.
        """
        self.latest_scope_id = 0
        self.scopes = [(OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES), -1)]
        node_cp = copy(node)
        self.enter_scope()
        # vars defined in this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        node_cp.body = [self.visit(s) for s in node.body]
        # NOTE(review): the module scope is intentionally(?) left on the stack
        return node_cp

    def visit_Name(self, node: Name) -> Name:
        """Return a copy of ``node`` whose id is replaced by the scoped name."""
        nc = copy(node)
        # setting is handled in either enclosing module or function
        if node.id == "Self":
            # ``idSelf`` is presumably attached by an earlier pass; it must match
            # the original name of the class whose method is being visited
            assert node.idSelf == self.current_Self[1]
            nc.idSelf_new = self.current_Self[0]
        nc.id = self.map_name(node.id)
        return nc

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Rewrite a class definition.

        The class name and field annotations are handled by RecordScoper; each
        method is then rewritten in its own function scope (method names are
        kept as-is).
        """
        cp_node = RecordScoper.scope(node, self)
        for i, attribute in enumerate(cp_node.body):
            if isinstance(attribute, FunctionDef):
                # ``orig_name`` is presumably set on the class by an earlier pass
                self.current_Self = (cp_node.name, cp_node.orig_name)
                cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
        return cp_node

    def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> FunctionDef:
        """Rewrite a function definition and its body in a fresh scope.

        Argument and return annotations are resolved in the enclosing scope,
        as Python evaluates them there. Resolving them after the arguments are
        bound would make a parameter that shadows a type name
        (``def f(A: A)``) capture its own annotation.

        :param method: True for class methods, whose names are not rewritten
        """
        node_cp = copy(node)
        # setting is handled in either enclosing module or function
        node_cp.name = node.name if method else self.map_name(node.name)
        # annotations are evaluated in the enclosing scope (before entering ours)
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # vars defined in this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """Pass through absent optional children (e.g. a missing annotation)."""
        return node


class RecordScoper(NodeTransformer):
    """
    Scope the header of a class (record) definition on behalf of RewriteScoping.

    The class name is mapped to its scoped name, and the annotations of the
    record fields (``AnnAssign`` statements) are rewritten with the scoper.
    Field names are left unchanged. Method bodies are skipped here, because
    RewriteScoping rewrites them afterwards in their own function scope.
    """

    # the RewriteScoping instance whose scope stack is used to resolve names
    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """Return a copy of ``c`` with its name and field annotations scoped."""
        f = cls(scoper)
        return f.visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy the body list: generic_visit replaces list entries in place and
        # would otherwise modify the body of the original class node
        node_cp.body = list(c.body)
        return self.generic_visit(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        # Methods are rewritten by RewriteScoping.visit_ClassDef in their own
        # scope. Descending into them here would rewrite annotated assignments
        # in method bodies once in the wrong scope and then a second time,
        # when the already-scoped name no longer resolves.
        return node

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        node_cp = copy(node)
        node_cp.annotation = self._scoper.visit(node_cp.annotation)
        return node_cp

Classes

class RecordScoper (scoper: RewriteScoping)

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RecordScoper(NodeTransformer):
    """
    Scope the header of a class (record) definition on behalf of RewriteScoping.

    The class name is mapped to its scoped name, and the annotations of the
    record fields (``AnnAssign`` statements) are rewritten with the scoper.
    Field names are left unchanged. Method bodies are skipped here, because
    RewriteScoping rewrites them afterwards in their own function scope.
    """

    # the RewriteScoping instance whose scope stack is used to resolve names
    _scoper: RewriteScoping

    def __init__(self, scoper: RewriteScoping):
        self._scoper = scoper

    @classmethod
    def scope(cls, c: ClassDef, scoper: RewriteScoping) -> ClassDef:
        """Return a copy of ``c`` with its name and field annotations scoped."""
        f = cls(scoper)
        return f.visit(c)

    def visit_ClassDef(self, c: ClassDef) -> ClassDef:
        node_cp = copy(c)
        node_cp.name = self._scoper.map_name(node_cp.name)
        # copy the body list: generic_visit replaces list entries in place and
        # would otherwise modify the body of the original class node
        node_cp.body = list(c.body)
        return self.generic_visit(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        # Methods are rewritten by RewriteScoping.visit_ClassDef in their own
        # scope. Descending into them here would rewrite annotated assignments
        # in method bodies once in the wrong scope and then a second time,
        # when the already-scoped name no longer resolves.
        return node

    def visit_AnnAssign(self, node: AnnAssign) -> AnnAssign:
        assert isinstance(
            node.target, Name
        ), "Record elements must have named attributes"
        node_cp = copy(node)
        node_cp.annotation = self._scoper.visit(node_cp.annotation)
        return node_cp

Ancestors

  • ast.NodeTransformer
  • ast.NodeVisitor

Static methods

def scope(c: ast.ClassDef, scoper: RewriteScoping) ‑> ast.ClassDef

Methods

def visit_AnnAssign(self, node: ast.AnnAssign) ‑> ast.AnnAssign
def visit_ClassDef(self, c: ast.ClassDef) ‑> ast.ClassDef
class RewriteScoping

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteScoping(CompilingNodeTransformer):
    """
    Rewrite all variable names so each one unambiguously points to its
    definition in the nearest enclosing scope.

    Every module/function scope gets a numeric id. A name ``x`` bound in the
    scope with id ``k`` is renamed to ``x_k``. Names of ``INITIAL_SCOPE`` and
    ``FORBIDDEN_NAMES`` live in a pseudo-scope with id ``-1`` and are never
    renamed.
    """

    step = "Rewrite all variables to inambiguously point to the definition in the nearest enclosing scope"
    # id handed out to the next scope opened by enter_scope
    latest_scope_id: int
    # stack of (names bound in the scope, scope id); the innermost scope is last
    scopes: typing.List[typing.Tuple[OrderedSet, int]]
    # (rewritten name, original name) of the class whose methods are currently
    # being visited; used to resolve ``Self`` in visit_Name
    current_Self: typing.Tuple[str, str]

    def variable_scope_id(self, name: str) -> int:
        """find the id of the scope in which this variable is defined (closest to its usage)

        Scopes are searched from the innermost outwards.

        :raises NameError: if ``name`` is not bound in any enclosing scope
        """
        for scope, scope_id in reversed(self.scopes):
            if name in scope:
                return scope_id
        raise NameError(
            f"free variable '{name}' referenced before assignment in enclosing scope"
        )

    def enter_scope(self):
        """Push a new, empty scope with a fresh id."""
        self.scopes.append((OrderedSet(), self.latest_scope_id))
        self.latest_scope_id += 1

    def exit_scope(self):
        """Pop the innermost scope."""
        self.scopes.pop()

    def set_variable_scope(self, name: str):
        """Record ``name`` as bound in the innermost scope."""
        self.scopes[-1][0].add(name)

    def map_name(self, name: str):
        """Return the scoped name ``{name}_{scope_id}`` for ``name``.

        Names resolved in the initial pseudo-scope (id ``-1``) are returned
        unchanged.

        :raises NameError: if ``name`` is not bound in any enclosing scope
        """
        scope_id = self.variable_scope_id(name)
        if scope_id == -1:
            # do not rewrite Dict, Union, etc
            return name
        return f"{name}_{scope_id}"

    def visit_Module(self, node: Module) -> Module:
        """Reset the scope state and rewrite all top-level statements.

        All names bound at module level are registered before any statement is
        visited, so a statement may reference a name bound later in the module.
        """
        self.latest_scope_id = 0
        self.scopes = [(OrderedSet(INITIAL_SCOPE.keys() | FORBIDDEN_NAMES), -1)]
        node_cp = copy(node)
        self.enter_scope()
        # vars defined in this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        node_cp.body = [self.visit(s) for s in node.body]
        # NOTE(review): the module scope is intentionally(?) left on the stack
        return node_cp

    def visit_Name(self, node: Name) -> Name:
        """Return a copy of ``node`` whose id is replaced by the scoped name."""
        nc = copy(node)
        # setting is handled in either enclosing module or function
        if node.id == "Self":
            # ``idSelf`` is presumably attached by an earlier pass; it must match
            # the original name of the class whose method is being visited
            assert node.idSelf == self.current_Self[1]
            nc.idSelf_new = self.current_Self[0]
        nc.id = self.map_name(node.id)
        return nc

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Rewrite a class definition.

        The class name and field annotations are handled by RecordScoper; each
        method is then rewritten in its own function scope (method names are
        kept as-is).
        """
        cp_node = RecordScoper.scope(node, self)
        for i, attribute in enumerate(cp_node.body):
            if isinstance(attribute, FunctionDef):
                # ``orig_name`` is presumably set on the class by an earlier pass
                self.current_Self = (cp_node.name, cp_node.orig_name)
                cp_node.body[i] = self.visit_FunctionDef(attribute, method=True)
        return cp_node

    def visit_FunctionDef(self, node: FunctionDef, method: bool = False) -> FunctionDef:
        """Rewrite a function definition and its body in a fresh scope.

        Argument and return annotations are resolved in the enclosing scope,
        as Python evaluates them there. Resolving them after the arguments are
        bound would make a parameter that shadows a type name
        (``def f(A: A)``) capture its own annotation.

        :param method: True for class methods, whose names are not rewritten
        """
        node_cp = copy(node)
        # setting is handled in either enclosing module or function
        node_cp.name = node.name if method else self.map_name(node.name)
        # annotations are evaluated in the enclosing scope (before entering ours)
        arg_annotations = [self.visit(a.annotation) for a in node.args.args]
        node_cp.returns = self.visit(node.returns)
        self.enter_scope()
        node_cp.args = copy(node.args)
        node_cp.args.args = []
        # args are defined in this scope
        for a, annotation in zip(node.args.args, arg_annotations):
            a_cp = copy(a)
            self.set_variable_scope(a.arg)
            a_cp.arg = self.map_name(a.arg)
            a_cp.annotation = annotation
            node_cp.args.args.append(a_cp)
        # vars defined in this scope
        shallow_node_def_collector = ShallowNameDefCollector()
        for s in node.body:
            shallow_node_def_collector.visit(s)
        for var_name in shallow_node_def_collector.vars:
            self.set_variable_scope(var_name)
        # map all vars and recurse
        node_cp.body = [self.visit(s) for s in node.body]
        self.exit_scope()
        return node_cp

    def visit_NoneType(self, node: None) -> None:
        """Pass through absent optional children (e.g. a missing annotation)."""
        return node

Ancestors

  • opshin.util.CompilingNodeTransformer

Class variables

var current_Self : Tuple[str, str]
var latest_scope_id : int
var scopes : List[Tuple[ordered_set.OrderedSet, int]]
var step

Methods

def enter_scope(self)
def exit_scope(self)
def map_name(self, name: str)
def set_variable_scope(self, name: str)
def variable_scope_id(self, name: str) ‑> int

find the id of the scope in which this variable is defined (closest to its usage)

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
def visit_FunctionDef(self, node: ast.FunctionDef, method: bool = False) ‑> ast.FunctionDef
def visit_Module(self, node: ast.Module) ‑> ast.Module
def visit_Name(self, node: ast.Name) ‑> ast.Name
def visit_NoneType(self, node: None) ‑> None
class ShallowNameDefCollector

A node visitor base class that walks the abstract syntax tree and calls a visitor function for every node found. This function may return a value which is forwarded by the visit method.

This class is meant to be subclassed, with the subclass adding visitor methods.

Per default the visitor functions for the nodes are 'visit_' + class name of the node. So a TryFinally node visit function would be visit_TryFinally. This behavior can be changed by overriding the visit method. If no visitor function exists for a node (return value None) the generic_visit visitor is used instead.

Don't use the NodeVisitor if you want to apply changes to nodes during traversing. For this a special visitor exists (NodeTransformer) that allows modifications.

Expand source code
class ShallowNameDefCollector(CompilingNodeVisitor):
    """
    Collect the names bound directly in one scope.

    Nested function bodies and class bodies are not descended into; only the
    names they introduce into the current scope are recorded. Names are kept
    in ``self.vars`` in first-seen order.
    """

    step = "Collecting defined variable names"

    def __init__(self):
        self.vars = OrderedSet()

    def visit_Name(self, node: Name) -> None:
        # only names in a store context (assignment targets) create a binding
        if not isinstance(node.ctx, Store):
            return
        self.vars.add(node.id)

    def visit_ClassDef(self, node: ClassDef):
        # the class name itself is bound in the current scope
        self.vars.add(node.name)
        # methods will be put in global scope so add them now
        method_names = (
            stmt.name for stmt in node.body if isinstance(stmt, FunctionDef)
        )
        for method_name in method_names:
            self.vars.add(method_name)
        # the class body (i.e. attribute names) is intentionally not visited

    def visit_FunctionDef(self, node: FunctionDef):
        # only the function name is bound here; its body is a separate scope
        self.vars.add(node.name)

Ancestors

  • opshin.util.CompilingNodeVisitor

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeVisitor.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef)
def visit_FunctionDef(self, node: ast.FunctionDef)
def visit_Name(self, node: ast.Name) ‑> None