Module opshin.rewrite.rewrite_cast_condition

Expand source code
from copy import copy

from ast import *

from ..util import CompilingNodeTransformer

# Rewrites every condition into an explicit cast to bool, i.e. ``~bool(<condition>)``.
# (Kept as a comment: a string literal placed after the imports would not
# become the module docstring, it would just be a discarded expression.)

# Name under which the builtin ``bool`` is re-bound at the top of each module.
# "~" is not a valid identifier character, so Python source code can never
# refer to or rebind this name -- presumably so user code cannot shadow the cast.
SPECIAL_BOOL = "~bool"


class RewriteConditions(CompilingNodeTransformer):
    step = "Rewriting conditions to bools"

    def visit_Module(self, node: Module) -> Module:
        node.body.insert(0, Assign([Name(SPECIAL_BOOL, Store())], Name("bool", Load())))
        return self.generic_visit(node)

    def visit_If(self, node: If) -> If:
        if_cp = copy(node)
        if_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(if_cp)

    def visit_IfExp(self, node: IfExp) -> IfExp:
        if_cp = copy(node)
        if_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(if_cp)

    def visit_While(self, node: While) -> While:
        while_cp = copy(node)
        while_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(while_cp)

    def visit_BoolOp(self, node: BoolOp) -> BoolOp:
        bo_cp = copy(node)
        bo_cp.values = [
            Call(Name(SPECIAL_BOOL, Load()), [self.visit(v)], []) for v in bo_cp.values
        ]
        return self.generic_visit(bo_cp)

    def visit_Assert(self, node: Assert) -> Assert:
        assert_cp = copy(node)
        assert_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(assert_cp)

Classes

class RewriteConditions

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteConditions(CompilingNodeTransformer):
    step = "Rewriting conditions to bools"

    def visit_Module(self, node: Module) -> Module:
        node.body.insert(0, Assign([Name(SPECIAL_BOOL, Store())], Name("bool", Load())))
        return self.generic_visit(node)

    def visit_If(self, node: If) -> If:
        if_cp = copy(node)
        if_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(if_cp)

    def visit_IfExp(self, node: IfExp) -> IfExp:
        if_cp = copy(node)
        if_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(if_cp)

    def visit_While(self, node: While) -> While:
        while_cp = copy(node)
        while_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(while_cp)

    def visit_BoolOp(self, node: BoolOp) -> BoolOp:
        bo_cp = copy(node)
        bo_cp.values = [
            Call(Name(SPECIAL_BOOL, Load()), [self.visit(v)], []) for v in bo_cp.values
        ]
        return self.generic_visit(bo_cp)

    def visit_Assert(self, node: Assert) -> Assert:
        assert_cp = copy(node)
        assert_cp.test = Call(Name(SPECIAL_BOOL, Load()), [node.test], [])
        return self.generic_visit(assert_cp)

Ancestors

Class variables

var step

Human-readable description of this rewrite step: "Rewriting conditions to bools".

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_Assert(self, node: ast.Assert) ‑> ast.Assert
def visit_BoolOp(self, node: ast.BoolOp) ‑> ast.BoolOp
def visit_If(self, node: ast.If) ‑> ast.If
def visit_IfExp(self, node: ast.IfExp) ‑> ast.IfExp
def visit_Module(self, node: ast.Module) ‑> ast.Module
def visit_While(self, node: ast.While) ‑> ast.While