Module opshin.rewrite.rewrite_annotate_fallthrough

Expand source code
import ast
from copy import copy
from ast import *

from ..util import CompilingNodeTransformer
from ..typed_util import annotate_compound_statement_fallthrough
from .rewrite_cast_condition import SPECIAL_BOOL


class RewriteAnnotateFallthrough(CompilingNodeTransformer):
    """Annotate statements with a boolean ``can_fall_through`` attribute.

    What this transformer sets:

    * ``Return`` statements are marked ``can_fall_through = False``.
    * ``Assert`` statements are marked ``False`` when their test is
      recognised as constantly falsy (see ``expr_is_definitely_false``),
      ``True`` otherwise.
    * ``Module``, ``FunctionDef``, ``ClassDef``, ``If``, ``For`` and ``While``
      nodes are handed to ``annotate_compound_statement_fallthrough`` after
      their children were visited; that helper decides their annotation.
    * Every other statement keeps an existing annotation or defaults to
      ``True``.

    The explicit ``visit_*`` methods operate on a shallow copy of the node
    they receive, so the attribute is not written onto the node passed in.
    """

    step = "Annotating statement fallthrough"

    @staticmethod
    def expr_is_definitely_false(node):
        """Return True if ``node`` is statically known to be falsy.

        Recognised shapes:

        * a ``Constant`` whose value is falsy;
        * a call ``SPECIAL_BOOL(<expr>)`` with exactly one positional
          argument and no keywords, where ``<expr>`` is itself recognised
          (checked recursively).

        Anything else yields False, meaning "not provably false".
        """
        if isinstance(node, Constant):
            return not bool(node.value)
        if not (isinstance(node, Call) and isinstance(node.func, Name)):
            return False
        func = node.func
        # ``orig_id`` is not a standard ``ast.Name`` field (presumably attached
        # by an earlier rewrite step -- confirm). Read it defensively so that a
        # call to any other function on a Name lacking the attribute yields
        # False instead of raising AttributeError.
        is_bool_cast = (
            func.id == SPECIAL_BOOL
            or getattr(func, "orig_id", None) == SPECIAL_BOOL
        )
        if not is_bool_cast or len(node.args) != 1 or node.keywords:
            return False
        return RewriteAnnotateFallthrough.expr_is_definitely_false(node.args[0])

    def generic_visit(self, node):
        """Visit children via the parent class, then default the annotation.

        If the visited result is an ``ast.stmt`` it is guaranteed to carry
        ``can_fall_through`` afterwards: an existing value is kept, a missing
        one becomes ``True``.
        """
        visited = super().generic_visit(node)
        if isinstance(visited, ast.stmt):
            visited.can_fall_through = getattr(visited, "can_fall_through", True)
        return visited

    def visit_Module(self, node: Module) -> Module:
        """Visit a copy of the module, then annotate it as a compound statement."""
        module_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(module_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Visit a copy of the function, then annotate it as a compound statement."""
        func_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(func_cp)

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Visit a copy of the class, then annotate it as a compound statement."""
        class_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(class_cp)

    def visit_If(self, node: If) -> If:
        """Visit a copy of the ``if``, then annotate it as a compound statement."""
        if_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(if_cp)

    def visit_For(self, node: For) -> For:
        """Visit a copy of the ``for``, then annotate it as a compound statement."""
        for_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(for_cp)

    def visit_While(self, node: While) -> While:
        """Visit a copy of the ``while``, then annotate it as a compound statement."""
        while_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(while_cp)

    def visit_Return(self, node: Return) -> Return:
        """Return a visited copy of ``node`` marked as never falling through."""
        return_cp = self.generic_visit(copy(node))
        # generic_visit defaulted the flag to True; a return always overrides it.
        return_cp.can_fall_through = False
        return return_cp

    def visit_Assert(self, node: Assert) -> Assert:
        """Return a visited copy of ``node``; it falls through unless its test
        is recognised as constantly false."""
        assert_cp = self.generic_visit(copy(node))
        assert_cp.can_fall_through = not self.expr_is_definitely_false(assert_cp.test)
        return assert_cp

Classes

class RewriteAnnotateFallthrough

A `NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`:

class RewriteName(NodeTransformer):

   # Replace each ``Name`` node with a subscript on the name ``data`` whose
   # key is the original identifier; the node's ``ctx`` is carried over.
   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes, you must either transform the child nodes yourself or call the `generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this:

node = YourTransformer().visit(node)

Expand source code
class RewriteAnnotateFallthrough(CompilingNodeTransformer):
    """Annotate statements with a boolean ``can_fall_through`` attribute.

    What this transformer sets:

    * ``Return`` statements are marked ``can_fall_through = False``.
    * ``Assert`` statements are marked ``False`` when their test is
      recognised as constantly falsy (see ``expr_is_definitely_false``),
      ``True`` otherwise.
    * ``Module``, ``FunctionDef``, ``ClassDef``, ``If``, ``For`` and ``While``
      nodes are handed to ``annotate_compound_statement_fallthrough`` after
      their children were visited; that helper decides their annotation.
    * Every other statement keeps an existing annotation or defaults to
      ``True``.

    The explicit ``visit_*`` methods operate on a shallow copy of the node
    they receive, so the attribute is not written onto the node passed in.
    """

    step = "Annotating statement fallthrough"

    @staticmethod
    def expr_is_definitely_false(node):
        """Return True if ``node`` is statically known to be falsy.

        Recognised shapes:

        * a ``Constant`` whose value is falsy;
        * a call ``SPECIAL_BOOL(<expr>)`` with exactly one positional
          argument and no keywords, where ``<expr>`` is itself recognised
          (checked recursively).

        Anything else yields False, meaning "not provably false".
        """
        if isinstance(node, Constant):
            return not bool(node.value)
        if not (isinstance(node, Call) and isinstance(node.func, Name)):
            return False
        func = node.func
        # ``orig_id`` is not a standard ``ast.Name`` field (presumably attached
        # by an earlier rewrite step -- confirm). Read it defensively so that a
        # call to any other function on a Name lacking the attribute yields
        # False instead of raising AttributeError.
        is_bool_cast = (
            func.id == SPECIAL_BOOL
            or getattr(func, "orig_id", None) == SPECIAL_BOOL
        )
        if not is_bool_cast or len(node.args) != 1 or node.keywords:
            return False
        return RewriteAnnotateFallthrough.expr_is_definitely_false(node.args[0])

    def generic_visit(self, node):
        """Visit children via the parent class, then default the annotation.

        If the visited result is an ``ast.stmt`` it is guaranteed to carry
        ``can_fall_through`` afterwards: an existing value is kept, a missing
        one becomes ``True``.
        """
        visited = super().generic_visit(node)
        if isinstance(visited, ast.stmt):
            visited.can_fall_through = getattr(visited, "can_fall_through", True)
        return visited

    def visit_Module(self, node: Module) -> Module:
        """Visit a copy of the module, then annotate it as a compound statement."""
        module_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(module_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Visit a copy of the function, then annotate it as a compound statement."""
        func_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(func_cp)

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Visit a copy of the class, then annotate it as a compound statement."""
        class_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(class_cp)

    def visit_If(self, node: If) -> If:
        """Visit a copy of the ``if``, then annotate it as a compound statement."""
        if_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(if_cp)

    def visit_For(self, node: For) -> For:
        """Visit a copy of the ``for``, then annotate it as a compound statement."""
        for_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(for_cp)

    def visit_While(self, node: While) -> While:
        """Visit a copy of the ``while``, then annotate it as a compound statement."""
        while_cp = self.generic_visit(copy(node))
        return annotate_compound_statement_fallthrough(while_cp)

    def visit_Return(self, node: Return) -> Return:
        """Return a visited copy of ``node`` marked as never falling through."""
        return_cp = self.generic_visit(copy(node))
        # generic_visit defaulted the flag to True; a return always overrides it.
        return_cp.can_fall_through = False
        return return_cp

    def visit_Assert(self, node: Assert) -> Assert:
        """Return a visited copy of ``node``; it falls through unless its test
        is recognised as constantly false."""
        assert_cp = self.generic_visit(copy(node))
        assert_cp.can_fall_through = not self.expr_is_definitely_false(assert_cp.test)
        return assert_cp

Ancestors

Class variables

var step

Static methods

def expr_is_definitely_false(node)
Expand source code
@staticmethod
def expr_is_definitely_false(node):
    """Return True if ``node`` is statically known to be falsy.

    Recognised shapes:

    * a ``Constant`` whose value is falsy;
    * a call ``SPECIAL_BOOL(<expr>)`` with exactly one positional argument
      and no keywords, where ``<expr>`` is itself recognised (checked
      recursively).

    Anything else yields False, meaning "not provably false".
    """
    if isinstance(node, Constant):
        return not bool(node.value)
    if not (isinstance(node, Call) and isinstance(node.func, Name)):
        return False
    func = node.func
    # ``orig_id`` is not a standard ``ast.Name`` field (presumably attached by
    # an earlier rewrite step -- confirm). Read it defensively so that a call
    # to any other function on a Name lacking the attribute yields False
    # instead of raising AttributeError.
    is_bool_cast = (
        func.id == SPECIAL_BOOL
        or getattr(func, "orig_id", None) == SPECIAL_BOOL
    )
    if not is_bool_cast or len(node.args) != 1 or node.keywords:
        return False
    return RewriteAnnotateFallthrough.expr_is_definitely_false(node.args[0])

Methods

def generic_visit(self, node)

Called if no explicit visitor function exists for a node.

Expand source code
def generic_visit(self, node):
    """Visit children via the parent class, then default the annotation.

    A result that is an ``ast.stmt`` always carries ``can_fall_through``
    afterwards: an existing value is kept, a missing one becomes ``True``.
    Non-statement results are returned untouched.
    """
    result = super().generic_visit(node)
    if not isinstance(result, ast.stmt):
        return result
    already_set = getattr(result, "can_fall_through", True)
    result.can_fall_through = already_set
    return result
def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_Assert(self, node: ast.Assert) ‑> ast.Assert
Expand source code
def visit_Assert(self, node: Assert) -> Assert:
    """Return a visited shallow copy of ``node`` with ``can_fall_through`` set.

    The flag is False exactly when ``expr_is_definitely_false`` reports that
    the (already visited) test expression is constantly false.
    """
    annotated = self.generic_visit(copy(node))
    always_fails = self.expr_is_definitely_false(annotated.test)
    annotated.can_fall_through = not always_fails
    return annotated
def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef:
    """Visit a shallow copy of the class definition and return the result of
    ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_class = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_class)
    return annotated
def visit_For(self, node: ast.For) ‑> ast.For
Expand source code
def visit_For(self, node: For) -> For:
    """Visit a shallow copy of the ``for`` loop and return the result of
    ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_loop = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_loop)
    return annotated
def visit_FunctionDef(self, node: ast.FunctionDef) ‑> ast.FunctionDef
Expand source code
def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
    """Visit a shallow copy of the function definition and return the result
    of ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_func = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_func)
    return annotated
def visit_If(self, node: ast.If) ‑> ast.If
Expand source code
def visit_If(self, node: If) -> If:
    """Visit a shallow copy of the ``if`` statement and return the result of
    ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_branch = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_branch)
    return annotated
def visit_Module(self, node: ast.Module) ‑> ast.Module
Expand source code
def visit_Module(self, node: Module) -> Module:
    """Visit a shallow copy of the module and return the result of
    ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_module = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_module)
    return annotated
def visit_Return(self, node: ast.Return) ‑> ast.Return
Expand source code
def visit_Return(self, node: Return) -> Return:
    """Return a visited shallow copy of ``node`` marked as never falling
    through (``can_fall_through = False``)."""
    annotated = self.generic_visit(copy(node))
    # Unconditionally overrides whatever generic_visit put on the copy.
    setattr(annotated, "can_fall_through", False)
    return annotated
def visit_While(self, node: ast.While) ‑> ast.While
Expand source code
def visit_While(self, node: While) -> While:
    """Visit a shallow copy of the ``while`` loop and return the result of
    ``annotate_compound_statement_fallthrough`` applied to it."""
    visited_loop = self.generic_visit(copy(node))
    annotated = annotate_compound_statement_fallthrough(visited_loop)
    return annotated