Module opshin.optimize.optimize_fold_if_fallthrough

Expand source code
"""
If exactly one branch of an if-statement can fall through, fold the statements
that follow it in the enclosing sequence into that branch.

For example, ``if c: return x`` followed by ``y = 1`` becomes
``if c: return x else: y = 1``.
"""

# NOTE: this text must be the first statement of the module to become its
# ``__doc__``; placed after the imports it would only be a discarded expression.

from ast import *
from copy import copy

from ..typed_util import (
    ScopedSequenceNodeTransformer,
    annotate_compound_statement_fallthrough,
)


class OptimizeFoldIfFallthrough(ScopedSequenceNodeTransformer):
    """Fold the statements following a sole-fallthrough if-statement into it.

    If exactly one branch of an ``If`` can fall through, the statements that
    follow the ``If`` in the enclosing sequence can only be reached through
    that branch, so they are moved into it. Fallthrough information is read
    from the ``body_can_fall_through`` / ``orelse_can_fall_through`` attributes
    (presumably populated by ``annotate_compound_statement_fallthrough``); a
    missing attribute counts as "can fall through".
    """

    # Description of this transformation step.
    step = "Folding trailing statements into sole fallthrough if-branches"

    def fold_sequence(self, statements):
        """Visit ``statements`` and fold trailing statements into if-branches.

        ``None`` entries, and statements whose visit returns ``None``, are
        dropped. When a visited ``If`` has exactly one branch that can fall
        through and statements follow it, those statements are appended to
        that branch (and folded recursively there), and the ``If`` becomes the
        last element of the result.

        :param statements: statement nodes to process; may contain ``None``
        :return: a new list of visited, folded statements
        """

        def sole_fallthrough_branch(stmt):
            # Name of the only branch of ``stmt`` that can fall through, or None
            # if ``stmt`` is not an ``If`` or both/neither branches can fall
            # through. Missing annotations count as "can fall through".
            if not isinstance(stmt, If):
                return None
            body_can_fall_through = getattr(stmt, "body_can_fall_through", True)
            orelse_can_fall_through = getattr(stmt, "orelse_can_fall_through", True)
            if body_can_fall_through == orelse_can_fall_through:
                return None
            return "body" if body_can_fall_through else "orelse"

        def append_trailing(visited, trailing):
            # Append the not-yet-visited ``trailing`` statements to ``visited``,
            # a list already produced by ``fold_sequence``. Such a list can hold
            # a sole-fallthrough ``If`` only as its last element (the loop below
            # stops right after one), so only that element is inspected and the
            # already-visited statements are not visited a second time.
            if visited:
                last = visited[-1]
                branch = sole_fallthrough_branch(last)
                if branch is not None:
                    setattr(
                        last, branch, append_trailing(getattr(last, branch), trailing)
                    )
                    return visited[:-1] + [annotate_compound_statement_fallthrough(last)]
            return visited + self.fold_sequence(trailing)

        folded = []
        for i, statement in enumerate(statements):
            if statement is None:
                continue
            stmt = self.visit(statement)
            if stmt is None:
                continue
            branch = sole_fallthrough_branch(stmt)
            if branch is not None and i + 1 < len(statements):
                # Everything after ``stmt`` is only reachable through ``branch``.
                trailing = statements[i + 1 :]
                setattr(stmt, branch, append_trailing(getattr(stmt, branch), trailing))
                folded.append(annotate_compound_statement_fallthrough(stmt))
                break
            folded.append(stmt)
        return folded

    def visit_Module(self, node: Module) -> Module:
        """Run the base transformer on the module, then fold its top-level statements."""
        node_cp = super().visit_Module(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Run the base transformer on the function, then fold its body."""
        node_cp = super().visit_FunctionDef(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Run the base transformer on the class, then fold its body."""
        node_cp = super().visit_ClassDef(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_If(self, node: If) -> If:
        """Visit an ``if`` on a shallow copy, folding both of its branches."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_While(self, node: While) -> While:
        """Visit a ``while`` loop on a shallow copy, folding its body and ``else`` block."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_For(self, node: For) -> For:
        """Visit a ``for`` loop on a shallow copy, folding its body and ``else`` block."""
        node_cp = copy(node)
        node_cp.target = self.visit(node.target)
        node_cp.iter = self.visit(node.iter)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

Classes

class OptimizeFoldIfFallthrough

Fold the statements that follow an if-statement into its branch when that branch is the only one that can fall through.

Expand source code
class OptimizeFoldIfFallthrough(ScopedSequenceNodeTransformer):
    """Fold the statements following a sole-fallthrough if-statement into it.

    If exactly one branch of an ``If`` can fall through, the statements that
    follow the ``If`` in the enclosing sequence can only be reached through
    that branch, so they are moved into it. Fallthrough information is read
    from the ``body_can_fall_through`` / ``orelse_can_fall_through`` attributes
    (presumably populated by ``annotate_compound_statement_fallthrough``); a
    missing attribute counts as "can fall through".
    """

    # Description of this transformation step.
    step = "Folding trailing statements into sole fallthrough if-branches"

    def fold_sequence(self, statements):
        """Visit ``statements`` and fold trailing statements into if-branches.

        ``None`` entries, and statements whose visit returns ``None``, are
        dropped. When a visited ``If`` has exactly one branch that can fall
        through and statements follow it, those statements are appended to
        that branch (and folded recursively there), and the ``If`` becomes the
        last element of the result.

        :param statements: statement nodes to process; may contain ``None``
        :return: a new list of visited, folded statements
        """

        def sole_fallthrough_branch(stmt):
            # Name of the only branch of ``stmt`` that can fall through, or None
            # if ``stmt`` is not an ``If`` or both/neither branches can fall
            # through. Missing annotations count as "can fall through".
            if not isinstance(stmt, If):
                return None
            body_can_fall_through = getattr(stmt, "body_can_fall_through", True)
            orelse_can_fall_through = getattr(stmt, "orelse_can_fall_through", True)
            if body_can_fall_through == orelse_can_fall_through:
                return None
            return "body" if body_can_fall_through else "orelse"

        def append_trailing(visited, trailing):
            # Append the not-yet-visited ``trailing`` statements to ``visited``,
            # a list already produced by ``fold_sequence``. Such a list can hold
            # a sole-fallthrough ``If`` only as its last element (the loop below
            # stops right after one), so only that element is inspected and the
            # already-visited statements are not visited a second time.
            if visited:
                last = visited[-1]
                branch = sole_fallthrough_branch(last)
                if branch is not None:
                    setattr(
                        last, branch, append_trailing(getattr(last, branch), trailing)
                    )
                    return visited[:-1] + [annotate_compound_statement_fallthrough(last)]
            return visited + self.fold_sequence(trailing)

        folded = []
        for i, statement in enumerate(statements):
            if statement is None:
                continue
            stmt = self.visit(statement)
            if stmt is None:
                continue
            branch = sole_fallthrough_branch(stmt)
            if branch is not None and i + 1 < len(statements):
                # Everything after ``stmt`` is only reachable through ``branch``.
                trailing = statements[i + 1 :]
                setattr(stmt, branch, append_trailing(getattr(stmt, branch), trailing))
                folded.append(annotate_compound_statement_fallthrough(stmt))
                break
            folded.append(stmt)
        return folded

    def visit_Module(self, node: Module) -> Module:
        """Run the base transformer on the module, then fold its top-level statements."""
        node_cp = super().visit_Module(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Run the base transformer on the function, then fold its body."""
        node_cp = super().visit_FunctionDef(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Run the base transformer on the class, then fold its body."""
        node_cp = super().visit_ClassDef(node)
        node_cp.body = self.fold_sequence(node_cp.body)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_If(self, node: If) -> If:
        """Visit an ``if`` on a shallow copy, folding both of its branches."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_While(self, node: While) -> While:
        """Visit a ``while`` loop on a shallow copy, folding its body and ``else`` block."""
        node_cp = copy(node)
        node_cp.test = self.visit(node.test)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

    def visit_For(self, node: For) -> For:
        """Visit a ``for`` loop on a shallow copy, folding its body and ``else`` block."""
        node_cp = copy(node)
        node_cp.target = self.visit(node.target)
        node_cp.iter = self.visit(node.iter)
        node_cp.body = self.fold_sequence(node.body)
        node_cp.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(node_cp)

Ancestors

Class variables

var step

Methods

def fold_sequence(self, statements)
Expand source code
def fold_sequence(self, statements):
    """Visit ``statements`` and fold trailing statements into if-branches.

    ``None`` entries, and statements whose visit returns ``None``, are
    dropped. When a visited ``If`` has exactly one branch that can fall
    through and statements follow it, those statements are appended to that
    branch (and folded recursively there), and the ``If`` becomes the last
    element of the result.

    :param statements: statement nodes to process; may contain ``None``
    :return: a new list of visited, folded statements
    """

    def sole_fallthrough_branch(stmt):
        # Name of the only branch of ``stmt`` that can fall through, or None if
        # ``stmt`` is not an ``If`` or both/neither branches can fall through.
        # Missing annotations count as "can fall through".
        if not isinstance(stmt, If):
            return None
        body_can_fall_through = getattr(stmt, "body_can_fall_through", True)
        orelse_can_fall_through = getattr(stmt, "orelse_can_fall_through", True)
        if body_can_fall_through == orelse_can_fall_through:
            return None
        return "body" if body_can_fall_through else "orelse"

    def append_trailing(visited, trailing):
        # Append the not-yet-visited ``trailing`` statements to ``visited``, a
        # list already produced by ``fold_sequence``. Such a list can hold a
        # sole-fallthrough ``If`` only as its last element (the loop below stops
        # right after one), so only that element is inspected and the
        # already-visited statements are not visited a second time.
        if visited:
            last = visited[-1]
            branch = sole_fallthrough_branch(last)
            if branch is not None:
                setattr(last, branch, append_trailing(getattr(last, branch), trailing))
                return visited[:-1] + [annotate_compound_statement_fallthrough(last)]
        return visited + self.fold_sequence(trailing)

    folded = []
    for i, statement in enumerate(statements):
        if statement is None:
            continue
        stmt = self.visit(statement)
        if stmt is None:
            continue
        branch = sole_fallthrough_branch(stmt)
        if branch is not None and i + 1 < len(statements):
            # Everything after ``stmt`` is only reachable through ``branch``.
            trailing = statements[i + 1 :]
            setattr(stmt, branch, append_trailing(getattr(stmt, branch), trailing))
            folded.append(annotate_compound_statement_fallthrough(stmt))
            break
        folded.append(stmt)
    return folded
def visit(self, node)

Inherited from: ScopedSequenceNodeTransformer.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef:
    """Let the base transformer handle the class, then fold its statement list."""
    transformed = super().visit_ClassDef(node)
    folded_body = self.fold_sequence(transformed.body)
    transformed.body = folded_body
    return annotate_compound_statement_fallthrough(transformed)
def visit_For(self, node: ast.For) -> ast.For
Expand source code
def visit_For(self, node: For) -> For:
    """Visit a ``for`` loop on a shallow copy; fold its body and ``else`` block."""
    loop = copy(node)
    loop.target = self.visit(node.target)
    loop.iter = self.visit(node.iter)
    # Body first, then the ``else`` block, each from the original node.
    for field in ("body", "orelse"):
        setattr(loop, field, self.fold_sequence(getattr(node, field)))
    return annotate_compound_statement_fallthrough(loop)
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef
Expand source code
def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
    """Let the base transformer handle the function, then fold its statement list."""
    transformed = super().visit_FunctionDef(node)
    folded_body = self.fold_sequence(transformed.body)
    transformed.body = folded_body
    return annotate_compound_statement_fallthrough(transformed)
def visit_If(self, node: ast.If) -> ast.If
Expand source code
def visit_If(self, node: If) -> If:
    """Visit an ``if`` on a shallow copy; fold both of its branches."""
    branch_node = copy(node)
    branch_node.test = self.visit(node.test)
    # Body first, then the ``else`` branch, each from the original node.
    for field in ("body", "orelse"):
        setattr(branch_node, field, self.fold_sequence(getattr(node, field)))
    return annotate_compound_statement_fallthrough(branch_node)
def visit_Module(self, node: ast.Module) -> ast.Module
Expand source code
def visit_Module(self, node: Module) -> Module:
    """Let the base transformer handle the module, then fold its top-level statements."""
    transformed = super().visit_Module(node)
    folded_body = self.fold_sequence(transformed.body)
    transformed.body = folded_body
    return annotate_compound_statement_fallthrough(transformed)
def visit_While(self, node: ast.While) -> ast.While
Expand source code
def visit_While(self, node: While) -> While:
    """Visit a ``while`` loop on a shallow copy; fold its body and ``else`` block."""
    loop = copy(node)
    loop.test = self.visit(node.test)
    # Body first, then the ``else`` block, each from the original node.
    for field in ("body", "orelse"):
        setattr(loop, field, self.fold_sequence(getattr(node, field)))
    return annotate_compound_statement_fallthrough(loop)