Module opshin.optimize.optimize_fold_if_fallthrough
Expand source code
from ast import *
from copy import copy
from ..typed_util import (
ScopedSequenceNodeTransformer,
annotate_compound_statement_fallthrough,
)
"""
If exactly one branch of an if-statement can fall through, fold the following
statements in the enclosing sequence into that branch.
"""
class OptimizeFoldIfFallthrough(ScopedSequenceNodeTransformer):
    """Fold trailing statements into the single if-branch that can fall through.

    When exactly one of an ``if`` statement's two branches is marked as able
    to fall through, the statements that follow the ``if`` in the enclosing
    sequence are moved to the end of that branch.
    """

    step = "Folding trailing statements into sole fallthrough if-branches"

    def fold_sequence(self, statements):
        """Visit a statement sequence and fold trailing statements into ifs.

        Each non-``None`` entry is visited in order; entries whose visit
        returns ``None`` are dropped.  When a visited ``If`` has differing
        ``body_can_fall_through`` / ``orelse_can_fall_through`` attributes
        (each defaulting to ``True`` when absent) and at least one entry
        follows it, everything after it is appended to the branch flagged as
        able to fall through, that branch is folded again, and the ``If``
        becomes the final element of the result.

        :param statements: sequence of statement nodes; may contain ``None``.
        :return: a new list with the rewritten statements.
        """
        result = []
        for position, original in enumerate(statements):
            if original is None:
                continue
            visited = self.visit(original)
            if visited is None:
                continue
            if not isinstance(visited, If):
                result.append(visited)
                continue

            body_falls = getattr(visited, "body_can_fall_through", True)
            orelse_falls = getattr(visited, "orelse_can_fall_through", True)
            if body_falls == orelse_falls or position + 1 >= len(statements):
                # Both or neither branch falls through, or nothing follows
                # the if-statement: keep it where it is.
                result.append(visited)
                continue

            remainder = statements[position + 1 :]
            # The combined branch is folded again so the moved statements can
            # in turn be absorbed by an if-statement inside that branch.
            if body_falls:
                visited.body = self.fold_sequence(visited.body + remainder)
            else:
                visited.orelse = self.fold_sequence(visited.orelse + remainder)
            # NOTE(review): presumably this refreshes the fallthrough
            # attributes after the branch changed; confirm in typed_util.
            result.append(annotate_compound_statement_fallthrough(visited))
            # Every remaining statement now lives inside the if-statement.
            return result
        return result

    def visit_Module(self, node: Module) -> Module:
        """Apply the parent class's Module visit, then fold the module body."""
        clone = super().visit_Module(node)
        clone.body = self.fold_sequence(clone.body)
        return annotate_compound_statement_fallthrough(clone)

    def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef:
        """Apply the parent class's FunctionDef visit, then fold the body."""
        clone = super().visit_FunctionDef(node)
        clone.body = self.fold_sequence(clone.body)
        return annotate_compound_statement_fallthrough(clone)

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Apply the parent class's ClassDef visit, then fold the class body."""
        clone = super().visit_ClassDef(node)
        clone.body = self.fold_sequence(clone.body)
        return annotate_compound_statement_fallthrough(clone)

    def visit_If(self, node: If) -> If:
        """Return a shallow copy of the if-statement with both branches folded."""
        clone = copy(node)
        clone.test = self.visit(node.test)
        clone.body = self.fold_sequence(node.body)
        clone.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(clone)

    def visit_While(self, node: While) -> While:
        """Return a shallow copy of the while-loop with body and orelse folded."""
        clone = copy(node)
        clone.test = self.visit(node.test)
        clone.body = self.fold_sequence(node.body)
        clone.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(clone)

    def visit_For(self, node: For) -> For:
        """Return a shallow copy of the for-loop with body and orelse folded."""
        clone = copy(node)
        clone.target = self.visit(node.target)
        clone.iter = self.visit(node.iter)
        clone.body = self.fold_sequence(node.body)
        clone.orelse = self.fold_sequence(node.orelse)
        return annotate_compound_statement_fallthrough(clone)
Classes
class OptimizeFoldIfFallthrough-
Rewrite nested statement sequences while preserving the surrounding node.
Expand source code
class OptimizeFoldIfFallthrough(ScopedSequenceNodeTransformer): step = "Folding trailing statements into sole fallthrough if-branches" def fold_sequence(self, statements): folded = [] i = 0 while i < len(statements): if statements[i] is None: i += 1 continue stmt = self.visit(statements[i]) if stmt is None: i += 1 continue if isinstance(stmt, If): body_can_fall_through = getattr(stmt, "body_can_fall_through", True) orelse_can_fall_through = getattr(stmt, "orelse_can_fall_through", True) if body_can_fall_through != orelse_can_fall_through and i + 1 < len( statements ): trailing = statements[i + 1 :] if body_can_fall_through: stmt.body = self.fold_sequence(stmt.body + trailing) else: stmt.orelse = self.fold_sequence(stmt.orelse + trailing) folded.append(annotate_compound_statement_fallthrough(stmt)) break folded.append(stmt) i += 1 return folded def visit_Module(self, node: Module) -> Module: node_cp = super().visit_Module(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef: node_cp = super().visit_FunctionDef(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_ClassDef(self, node: ClassDef) -> ClassDef: node_cp = super().visit_ClassDef(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_If(self, node: If) -> If: node_cp = copy(node) node_cp.test = self.visit(node.test) node_cp.body = self.fold_sequence(node.body) node_cp.orelse = self.fold_sequence(node.orelse) return annotate_compound_statement_fallthrough(node_cp) def visit_While(self, node: While) -> While: node_cp = copy(node) node_cp.test = self.visit(node.test) node_cp.body = self.fold_sequence(node.body) node_cp.orelse = self.fold_sequence(node.orelse) return annotate_compound_statement_fallthrough(node_cp) def visit_For(self, node: For) -> For: node_cp = 
copy(node) node_cp.target = self.visit(node.target) node_cp.iter = self.visit(node.iter) node_cp.body = self.fold_sequence(node.body) node_cp.orelse = self.fold_sequence(node.orelse) return annotate_compound_statement_fallthrough(node_cp)Ancestors
- ScopedSequenceNodeTransformer
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var step
Methods
def fold_sequence(self, statements)-
Expand source code
def fold_sequence(self, statements): folded = [] i = 0 while i < len(statements): if statements[i] is None: i += 1 continue stmt = self.visit(statements[i]) if stmt is None: i += 1 continue if isinstance(stmt, If): body_can_fall_through = getattr(stmt, "body_can_fall_through", True) orelse_can_fall_through = getattr(stmt, "orelse_can_fall_through", True) if body_can_fall_through != orelse_can_fall_through and i + 1 < len( statements ): trailing = statements[i + 1 :] if body_can_fall_through: stmt.body = self.fold_sequence(stmt.body + trailing) else: stmt.orelse = self.fold_sequence(stmt.orelse + trailing) folded.append(annotate_compound_statement_fallthrough(stmt)) break folded.append(stmt) i += 1 return folded def visit(self, node)-
Inherited from:
ScopedSequenceNodeTransformer.visit — Visit a node.
def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef-
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef: node_cp = super().visit_ClassDef(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_For(self, node: ast.For) ‑> ast.For-
Expand source code
def visit_For(self, node: For) -> For: node_cp = copy(node) node_cp.target = self.visit(node.target) node_cp.iter = self.visit(node.iter) node_cp.body = self.fold_sequence(node.body) node_cp.orelse = self.fold_sequence(node.orelse) return annotate_compound_statement_fallthrough(node_cp) def visit_FunctionDef(self, node: ast.FunctionDef) ‑> ast.FunctionDef-
Expand source code
def visit_FunctionDef(self, node: FunctionDef) -> FunctionDef: node_cp = super().visit_FunctionDef(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_If(self, node: ast.If) ‑> ast.If-
Expand source code
def visit_If(self, node: If) -> If: node_cp = copy(node) node_cp.test = self.visit(node.test) node_cp.body = self.fold_sequence(node.body) node_cp.orelse = self.fold_sequence(node.orelse) return annotate_compound_statement_fallthrough(node_cp) def visit_Module(self, node: ast.Module) ‑> ast.Module-
Expand source code
def visit_Module(self, node: Module) -> Module: node_cp = super().visit_Module(node) node_cp.body = self.fold_sequence(node_cp.body) return annotate_compound_statement_fallthrough(node_cp) def visit_While(self, node: ast.While) ‑> ast.While-
Expand source code
def visit_While(self, node: While) -> While:
    """Return a shallow copy of the while-loop with body and orelse folded."""
    clone = copy(node)
    clone.test = self.visit(node.test)
    clone.body = self.fold_sequence(node.body)
    clone.orelse = self.fold_sequence(node.orelse)
    return annotate_compound_statement_fallthrough(clone)