Module opshin.optimize.optimize_union_expansion

Expand source code
from _ast import BoolOp, Call, FunctionDef, If, UnaryOp
from ast import *
from typing import Any, List
from ..util import CompilingNodeTransformer
from copy import deepcopy

"""
Expand union types
"""


def type_to_suffix(typ: expr) -> str:
    """
    Encode a type expression as a string usable inside a generated name.

    The expression is unparsed, spaces are removed, existing double
    underscores are widened to triple underscores, and the characters
    ``[``, ``]``, ``,`` and ``.`` are spelled out as ``_l_``, ``_r_``,
    ``_c_`` and ``_d_``.  Returns ``"UnknownType"`` if unparsing fails.
    """
    try:
        text = unparse(typ)
    except Exception:
        return "UnknownType"
    # Applied strictly in this order; later rules see earlier results.
    for old, new in (
        (" ", ""),
        ("__", "___"),
        ("[", "_l_"),
        ("]", "_r_"),
        (",", "_c_"),
        (".", "_d_"),
    ):
        text = text.replace(old, new)
    return text


class RemoveDeadCode(CompilingNodeTransformer):
    """
    Specialise a function body for known argument types.

    Calls ``isinstance(<arg>, <Name>)`` on an argument listed in ``arg_types``
    are replaced by a boolean constant.  The constant is propagated through
    ``not`` / ``and`` / ``or`` / conditional expressions, and every ``if``
    statement whose test became constant is replaced by the branch it selects.

    :param arg_types: argument name -> type string (a bare type name, or the
        ``type_to_suffix`` encoding of a more complex annotation).
    """

    def __init__(self, arg_types: dict[str, str]):
        self.arg_types = arg_types
        # Number of FunctionDefs currently being visited.  Depth 0 is the
        # function being specialised; anything deeper is a nested definition.
        self._function_depth = 0

    @staticmethod
    def _param_names(params: arguments) -> set:
        """Return the names bound by a parameter list."""
        names = {
            a.arg for a in (*params.posonlyargs, *params.args, *params.kwonlyargs)
        }
        names.update(a.arg for a in (params.vararg, params.kwarg) if a is not None)
        return names

    def _visit_shadowed(self, shadowed: set, visit_children):
        """
        Run ``visit_children()`` with ``shadowed`` names hidden from ``arg_types``.

        Inside a nested scope these names refer to new bindings whose types
        are not known, so ``isinstance`` checks on them must not be folded.
        """
        outer_types = self.arg_types
        self.arg_types = {k: v for k, v in outer_types.items() if k not in shadowed}
        try:
            return visit_children()
        finally:
            self.arg_types = outer_types

    @staticmethod
    def _fold_constant_ifs(stmts: list) -> list:
        """
        Replace each already-visited ``if <constant>`` in ``stmts`` by the
        branch it selects.

        A non-empty list that ends up empty gets a ``pass`` so that the
        enclosing body stays valid Python.
        """
        folded = []
        for stmt in stmts:
            if isinstance(stmt, If) and isinstance(stmt.test, Constant):
                folded.extend(stmt.body if stmt.test.value else stmt.orelse)
            else:
                folded.append(stmt)
        if stmts and not folded:
            folded.append(copy_location(Pass(), stmts[0]))
        return folded

    def generic_visit(self, node):
        """
        Visit the children of ``node`` as usual, then prune constant ``if``
        statements from its statement lists.

        This covers compound statements without a dedicated visitor (``for``,
        ``while``, ``with``, ``try``, exception handlers, ...).  Their bodies
        may contain ``isinstance`` checks on a specialised argument too.
        """
        node = super().generic_visit(node)
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            # Lambda/IfExp bodies are single expressions, not statement lists
            if isinstance(stmts, list):
                setattr(node, field, self._fold_constant_ifs(stmts))
        return node

    def visit_FunctionDef(self, node: FunctionDef) -> Any:
        """
        Prune the body of ``node``.

        For functions nested inside the one being specialised, parameters
        that shadow a specialised argument hide its known type.
        """
        shadowed = self._param_names(node.args) if self._function_depth else set()

        def visit_body():
            self._function_depth += 1
            try:
                node.body = self.visit_sequence(node.body)
            finally:
                self._function_depth -= 1
            return node

        return self._visit_shadowed(shadowed, visit_body)

    def visit_Lambda(self, node: Lambda) -> Any:
        # Lambda parameters shadow specialised arguments of the same name
        return self._visit_shadowed(
            self._param_names(node.args), lambda: self.generic_visit(node)
        )

    def _visit_comprehension(self, node):
        """Comprehension targets shadow specialised arguments of the same name."""
        targets = {
            name.id
            for gen in node.generators
            for name in walk(gen.target)
            if isinstance(name, Name)
        }
        return self._visit_shadowed(targets, lambda: self.generic_visit(node))

    visit_ListComp = visit_SetComp = visit_DictComp = visit_GeneratorExp = (
        _visit_comprehension
    )

    def visit_sequence(self, stmts):
        """Visit each statement of ``stmts`` and drop statically dead branches."""
        return self._fold_constant_ifs([self.visit(stmt) for stmt in stmts])

    def visit_If(self, node: If) -> Any:
        """
        Visit the test and both branches of an ``if``.

        When the test folds to a constant, the enclosing statement list
        replaces the whole ``if`` by the selected branch.  Tests produced by
        union handling are typically ``isinstance(...)``, ``not isinstance(...)``
        and ``and`` / ``or`` combinations of those.
        """
        node.test = self.visit(node.test)
        node.body = self.visit_sequence(node.body)
        node.orelse = self.visit_sequence(node.orelse)
        return node

    def visit_Call(self, node: Call) -> Any:
        node = self.generic_visit(node)
        # Fold ``isinstance(<arg>, <TypeName>)`` when <arg>'s type is known
        if (
            isinstance(node.func, Name)
            and node.func.id == "isinstance"
            and len(node.args) == 2
        ):
            arg, typ = node.args
            if isinstance(arg, Name) and isinstance(typ, Name):
                known_type = self.arg_types.get(arg.id)
                if known_type is not None:
                    return copy_location(Constant(value=known_type == typ.id), node)
        return node

    def visit_BoolOp(self, node: BoolOp) -> Any:
        """
        Fold an ``and`` / ``or`` whose operands became constants.

        Neutral constants (truthy in ``and``, falsy in ``or``) are dropped.
        A deciding constant collapses the whole expression to a boolean.
        """
        node.values = [self.visit(v) for v in node.values]
        is_and = isinstance(node.op, And)
        is_or = isinstance(node.op, Or)

        if all(isinstance(v, Constant) for v in node.values):
            truths = [bool(v.value) for v in node.values]
            if is_and:
                return copy_location(Constant(value=all(truths)), node)
            if is_or:
                return copy_location(Constant(value=any(truths)), node)

        # NOTE(review): a deciding constant also discards operands that would
        # have been evaluated before it (e.g. ``f() and False``).  This is only
        # safe if those operands cannot fail or have side effects — confirm.
        if is_and:
            if any(isinstance(v, Constant) and not v.value for v in node.values):
                return copy_location(Constant(value=False), node)
            node.values = [
                v for v in node.values if not (isinstance(v, Constant) and v.value)
            ]
        elif is_or:
            if any(isinstance(v, Constant) and v.value for v in node.values):
                return copy_location(Constant(value=True), node)
            node.values = [
                v for v in node.values if not isinstance(v, Constant) or v.value
            ]
        # A single remaining operand is the whole expression
        if len(node.values) == 1:
            return node.values[0]
        return node

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        node.operand = self.visit(node.operand)
        # Only ``not <constant>`` is folded
        if isinstance(node.op, Not) and isinstance(node.operand, Constant):
            return copy_location(Constant(value=not node.operand.value), node)
        return node

    def visit_IfExp(self, node: IfExp) -> Any:
        node.test = self.visit(node.test)
        node.body = self.visit(node.body)
        node.orelse = self.visit(node.orelse)
        # A constant test selects one arm statically
        if isinstance(node.test, Constant):
            return node.body if node.test.value else node.orelse
        return node


class OptimizeUnionExpansion(CompilingNodeTransformer):
    """
    Add specialised copies of functions whose positional parameters are
    annotated with a union type.

    For every ``FunctionDef`` found in a statement list, one variant is added
    per combination of concrete union members (and of members left as a
    union).  Each variant carries the concrete member as its annotation and
    is simplified with ``RemoveDeadCode`` using that knowledge.  The original
    function is kept and lists its variants' names in ``expanded_variants``.
    """

    step = "Expanding Unions"

    def visit(self, node):
        """Expand functions in ``node``'s statement lists, then visit ``node``."""
        if hasattr(node, "body") and isinstance(node.body, list):
            node.body = self.visit_sequence(node.body)
        if hasattr(node, "orelse") and isinstance(node.orelse, list):
            node.orelse = self.visit_sequence(node.orelse)
        if hasattr(node, "finalbody") and isinstance(node.finalbody, list):
            node.finalbody = self.visit_sequence(node.finalbody)
        return super().visit(node)

    def is_Union_annotation(self, ann: expr):
        """
        Return the member type expressions of a union annotation, or ``False``.

        Recognises ``Union[A, B, ...]`` and ``A | B | ...``.  A single-member
        ``Union[A]`` yields ``[A]``.
        """
        if isinstance(ann, Subscript) and isinstance(ann.value, Name):
            if ann.value.id == "Union":
                members = ann.slice
                # ``Union[A]`` has a bare expression as its slice rather than a
                # Tuple; reading ``.elts`` on it would raise AttributeError.
                if isinstance(members, Tuple):
                    return list(members.elts)
                return [members]
        if isinstance(ann, BinOp) and isinstance(ann.op, BitOr):
            return self.flatten_union_bitor(ann)
        return False

    def flatten_union_bitor(self, node):
        # Recursively collect all types in a | b | c expression
        if isinstance(node, BinOp) and isinstance(node.op, BitOr):
            return self.flatten_union_bitor(node.left) + self.flatten_union_bitor(
                node.right
            )
        else:
            return [node]

    def split_functions(
        self, stmt: FunctionDef, args: list, arg_types: dict, naming=""
    ) -> List[FunctionDef]:
        """
        Recursively generate variants of ``stmt`` for all combinations of
        expanded union argument types.

        :param stmt: function to specialise.
        :param args: one entry per positional parameter of ``stmt``.  Each
            entry is the list of union members still to expand, or a falsy
            value if the parameter is not (or no longer) a union.
        :param arg_types: parameter name -> type string already fixed for
            ``stmt``; handed to ``RemoveDeadCode``.
        :param naming: name prefix for the generated variants.
        :return: the generated variants (``stmt`` itself is not included).
        """
        new_functions = []
        for i, members in enumerate(args):
            if not members:
                continue
            # Only this list's slots are replaced, so a shallow copy suffices
            remaining = list(args)
            remaining[i] = False
            param_name = stmt.args.args[i].arg
            for typ in members:
                new_f = deepcopy(stmt)
                # Give each variant its own annotation node instead of sharing
                # the node with the original union annotation.
                new_f.args.args[i].annotation = deepcopy(typ)
                typ_str = typ.id if isinstance(typ, Name) else type_to_suffix(typ)
                new_f.name = f"{naming}_{typ_str}"
                new_arg_types = {**arg_types, param_name: typ_str}
                new_f = RemoveDeadCode(new_arg_types).visit(new_f)
                new_functions.append(new_f)
                new_functions.extend(
                    self.split_functions(new_f, remaining, new_arg_types, new_f.name)
                )
            # Variants in which this argument stays a union
            new_functions.extend(
                self.split_functions(stmt, remaining, arg_types, f"{naming}_Union")
            )
            # Only the first remaining union is handled at this level; the
            # recursive calls above take care of the others.
            break

        return new_functions

    def visit_sequence(self, body):
        """
        Return ``body`` with the union-expanded variants of each function
        definition inserted directly after it.
        """
        new_body = []
        for stmt in body:
            new_body.append(stmt)
            if isinstance(stmt, FunctionDef):
                args = [
                    self.is_Union_annotation(arg.annotation) for arg in stmt.args.args
                ]
                # "+" cannot occur in a Python identifier, so generated names
                # cannot collide with user-defined functions.
                new_funcs = self.split_functions(stmt, args, {}, stmt.name + "+")
                # Record the variants on the original definition
                stmt.expanded_variants = [f.name for f in new_funcs]
                new_body.extend(new_funcs)
        return new_body

Functions

def type_to_suffix(typ: ast.expr) -> str
Expand source code
def type_to_suffix(typ: expr) -> str:
    """
    Encode a type expression as a string usable inside a generated name.

    The expression is unparsed, spaces are removed, existing double
    underscores are widened to triple underscores, and the characters
    ``[``, ``]``, ``,`` and ``.`` are spelled out as ``_l_``, ``_r_``,
    ``_c_`` and ``_d_``.  Returns ``"UnknownType"`` if unparsing fails.
    """
    try:
        text = unparse(typ)
    except Exception:
        return "UnknownType"
    # Applied strictly in this order; later rules see earlier results.
    for old, new in (
        (" ", ""),
        ("__", "___"),
        ("[", "_l_"),
        ("]", "_r_"),
        (",", "_c_"),
        (".", "_d_"),
    ):
        text = text.replace(old, new)
    return text

Classes

class OptimizeUnionExpansion

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class OptimizeUnionExpansion(CompilingNodeTransformer):
    """
    Add specialised copies of functions whose positional parameters are
    annotated with a union type.

    For every ``FunctionDef`` found in a statement list, one variant is added
    per combination of concrete union members (and of members left as a
    union).  Each variant carries the concrete member as its annotation and
    is simplified with ``RemoveDeadCode`` using that knowledge.  The original
    function is kept and lists its variants' names in ``expanded_variants``.
    """

    step = "Expanding Unions"

    def visit(self, node):
        """Expand functions in ``node``'s statement lists, then visit ``node``."""
        if hasattr(node, "body") and isinstance(node.body, list):
            node.body = self.visit_sequence(node.body)
        if hasattr(node, "orelse") and isinstance(node.orelse, list):
            node.orelse = self.visit_sequence(node.orelse)
        if hasattr(node, "finalbody") and isinstance(node.finalbody, list):
            node.finalbody = self.visit_sequence(node.finalbody)
        return super().visit(node)

    def is_Union_annotation(self, ann: expr):
        """
        Return the member type expressions of a union annotation, or ``False``.

        Recognises ``Union[A, B, ...]`` and ``A | B | ...``.  A single-member
        ``Union[A]`` yields ``[A]``.
        """
        if isinstance(ann, Subscript) and isinstance(ann.value, Name):
            if ann.value.id == "Union":
                members = ann.slice
                # ``Union[A]`` has a bare expression as its slice rather than a
                # Tuple; reading ``.elts`` on it would raise AttributeError.
                if isinstance(members, Tuple):
                    return list(members.elts)
                return [members]
        if isinstance(ann, BinOp) and isinstance(ann.op, BitOr):
            return self.flatten_union_bitor(ann)
        return False

    def flatten_union_bitor(self, node):
        # Recursively collect all types in a | b | c expression
        if isinstance(node, BinOp) and isinstance(node.op, BitOr):
            return self.flatten_union_bitor(node.left) + self.flatten_union_bitor(
                node.right
            )
        else:
            return [node]

    def split_functions(
        self, stmt: FunctionDef, args: list, arg_types: dict, naming=""
    ) -> List[FunctionDef]:
        """
        Recursively generate variants of ``stmt`` for all combinations of
        expanded union argument types.

        :param stmt: function to specialise.
        :param args: one entry per positional parameter of ``stmt``.  Each
            entry is the list of union members still to expand, or a falsy
            value if the parameter is not (or no longer) a union.
        :param arg_types: parameter name -> type string already fixed for
            ``stmt``; handed to ``RemoveDeadCode``.
        :param naming: name prefix for the generated variants.
        :return: the generated variants (``stmt`` itself is not included).
        """
        new_functions = []
        for i, members in enumerate(args):
            if not members:
                continue
            # Only this list's slots are replaced, so a shallow copy suffices
            remaining = list(args)
            remaining[i] = False
            param_name = stmt.args.args[i].arg
            for typ in members:
                new_f = deepcopy(stmt)
                # Give each variant its own annotation node instead of sharing
                # the node with the original union annotation.
                new_f.args.args[i].annotation = deepcopy(typ)
                typ_str = typ.id if isinstance(typ, Name) else type_to_suffix(typ)
                new_f.name = f"{naming}_{typ_str}"
                new_arg_types = {**arg_types, param_name: typ_str}
                new_f = RemoveDeadCode(new_arg_types).visit(new_f)
                new_functions.append(new_f)
                new_functions.extend(
                    self.split_functions(new_f, remaining, new_arg_types, new_f.name)
                )
            # Variants in which this argument stays a union
            new_functions.extend(
                self.split_functions(stmt, remaining, arg_types, f"{naming}_Union")
            )
            # Only the first remaining union is handled at this level; the
            # recursive calls above take care of the others.
            break

        return new_functions

    def visit_sequence(self, body):
        """
        Return ``body`` with the union-expanded variants of each function
        definition inserted directly after it.
        """
        new_body = []
        for stmt in body:
            new_body.append(stmt)
            if isinstance(stmt, FunctionDef):
                args = [
                    self.is_Union_annotation(arg.annotation) for arg in stmt.args.args
                ]
                # "+" cannot occur in a Python identifier, so generated names
                # cannot collide with user-defined functions.
                new_funcs = self.split_functions(stmt, args, {}, stmt.name + "+")
                # Record the variants on the original definition
                stmt.expanded_variants = [f.name for f in new_funcs]
                new_body.extend(new_funcs)
        return new_body

Ancestors

Class variables

var step

Methods

def flatten_union_bitor(self, node)
Expand source code
def flatten_union_bitor(self, node):
    """
    Return the operands of an ``A | B | C`` chain in left-to-right order.

    Any node that is not a ``|`` binary operation is returned as a
    one-element list.
    """
    members = []
    pending = [node]
    while pending:
        current = pending.pop()
        if isinstance(current, BinOp) and isinstance(current.op, BitOr):
            # Push right first so the left operand is expanded first
            pending.append(current.right)
            pending.append(current.left)
        else:
            members.append(current)
    return members
def is_Union_annotation(self, ann: ast.expr)
Expand source code
def is_Union_annotation(self, ann: expr):
    """
    Return the member type expressions of a union annotation, or ``False``.

    Recognises ``Union[A, B, ...]`` and ``A | B | ...``.  A single-member
    ``Union[A]`` yields ``[A]``.
    """
    if isinstance(ann, Subscript) and isinstance(ann.value, Name):
        if ann.value.id == "Union":
            members = ann.slice
            # ``Union[A]`` has a bare expression as its slice rather than a
            # Tuple; reading ``.elts`` on it would raise AttributeError.
            if isinstance(members, Tuple):
                return list(members.elts)
            return [members]
    if isinstance(ann, BinOp) and isinstance(ann.op, BitOr):
        return self.flatten_union_bitor(ann)
    return False
def split_functions(self, stmt: ast.FunctionDef, args: list, arg_types: dict, naming='') -> List[ast.FunctionDef]

Recursively generate variants of a function with all possible combinations of expanded union types for its arguments.

Expand source code
def split_functions(
    self, stmt: FunctionDef, args: list, arg_types: dict, naming=""
) -> List[FunctionDef]:
    """
    Recursively generate variants of ``stmt`` for all combinations of
    expanded union argument types.

    :param stmt: function to specialise.
    :param args: one entry per positional parameter of ``stmt``.  Each entry
        is the list of union members still to expand, or a falsy value if
        the parameter is not (or no longer) a union.
    :param arg_types: parameter name -> type string already fixed for
        ``stmt``; handed to ``RemoveDeadCode``.
    :param naming: name prefix for the generated variants.
    :return: the generated variants (``stmt`` itself is not included).
    """
    new_functions = []
    for i, members in enumerate(args):
        if not members:
            continue
        # Only this list's slots are replaced, so a shallow copy suffices
        remaining = list(args)
        remaining[i] = False
        param_name = stmt.args.args[i].arg
        for typ in members:
            new_f = deepcopy(stmt)
            # Give each variant its own annotation node instead of sharing the
            # node with the original union annotation.
            new_f.args.args[i].annotation = deepcopy(typ)
            typ_str = typ.id if isinstance(typ, Name) else type_to_suffix(typ)
            new_f.name = f"{naming}_{typ_str}"
            new_arg_types = {**arg_types, param_name: typ_str}
            new_f = RemoveDeadCode(new_arg_types).visit(new_f)
            new_functions.append(new_f)
            new_functions.extend(
                self.split_functions(new_f, remaining, new_arg_types, new_f.name)
            )
        # Variants in which this argument stays a union
        new_functions.extend(
            self.split_functions(stmt, remaining, arg_types, f"{naming}_Union")
        )
        # Only the first remaining union is handled at this level; the
        # recursive calls above take care of the others.
        break

    return new_functions
def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

Expand source code
def visit(self, node):
    """
    Expand union-typed functions in every statement list of ``node``
    (``body``, ``orelse``, ``finalbody``), then dispatch to the regular
    visitor.
    """
    for field in ("body", "orelse", "finalbody"):
        stmts = getattr(node, field, None)
        # Expression-valued bodies (lambda, conditional expression) are skipped
        if isinstance(stmts, list):
            setattr(node, field, self.visit_sequence(stmts))
    return super().visit(node)
def visit_sequence(self, body)
Expand source code
def visit_sequence(self, body):
    """
    Return ``body`` with the union-expanded variants of each function
    definition inserted directly after it.

    Each expanded definition records its variants' names in
    ``expanded_variants``.
    """
    expanded = []
    for stmt in body:
        expanded.append(stmt)
        if not isinstance(stmt, FunctionDef):
            continue
        union_members = [
            self.is_Union_annotation(param.annotation) for param in stmt.args.args
        ]
        # "+" cannot occur in a Python identifier, so generated names cannot
        # collide with user-defined functions.
        variants = self.split_functions(stmt, union_members, {}, stmt.name + "+")
        stmt.expanded_variants = [variant.name for variant in variants]
        expanded.extend(variants)
    return expanded
class RemoveDeadCode (arg_types: dict[str, str])

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RemoveDeadCode(CompilingNodeTransformer):
    """
    Specialise a function body for known argument types.

    Calls ``isinstance(<arg>, <Name>)`` on an argument listed in ``arg_types``
    are replaced by a boolean constant.  The constant is propagated through
    ``not`` / ``and`` / ``or`` / conditional expressions, and every ``if``
    statement whose test became constant is replaced by the branch it selects.

    :param arg_types: argument name -> type string (a bare type name, or the
        ``type_to_suffix`` encoding of a more complex annotation).
    """

    def __init__(self, arg_types: dict[str, str]):
        self.arg_types = arg_types
        # Number of FunctionDefs currently being visited.  Depth 0 is the
        # function being specialised; anything deeper is a nested definition.
        self._function_depth = 0

    @staticmethod
    def _param_names(params: arguments) -> set:
        """Return the names bound by a parameter list."""
        names = {
            a.arg for a in (*params.posonlyargs, *params.args, *params.kwonlyargs)
        }
        names.update(a.arg for a in (params.vararg, params.kwarg) if a is not None)
        return names

    def _visit_shadowed(self, shadowed: set, visit_children):
        """
        Run ``visit_children()`` with ``shadowed`` names hidden from ``arg_types``.

        Inside a nested scope these names refer to new bindings whose types
        are not known, so ``isinstance`` checks on them must not be folded.
        """
        outer_types = self.arg_types
        self.arg_types = {k: v for k, v in outer_types.items() if k not in shadowed}
        try:
            return visit_children()
        finally:
            self.arg_types = outer_types

    @staticmethod
    def _fold_constant_ifs(stmts: list) -> list:
        """
        Replace each already-visited ``if <constant>`` in ``stmts`` by the
        branch it selects.

        A non-empty list that ends up empty gets a ``pass`` so that the
        enclosing body stays valid Python.
        """
        folded = []
        for stmt in stmts:
            if isinstance(stmt, If) and isinstance(stmt.test, Constant):
                folded.extend(stmt.body if stmt.test.value else stmt.orelse)
            else:
                folded.append(stmt)
        if stmts and not folded:
            folded.append(copy_location(Pass(), stmts[0]))
        return folded

    def generic_visit(self, node):
        """
        Visit the children of ``node`` as usual, then prune constant ``if``
        statements from its statement lists.

        This covers compound statements without a dedicated visitor (``for``,
        ``while``, ``with``, ``try``, exception handlers, ...).  Their bodies
        may contain ``isinstance`` checks on a specialised argument too.
        """
        node = super().generic_visit(node)
        for field in ("body", "orelse", "finalbody"):
            stmts = getattr(node, field, None)
            # Lambda/IfExp bodies are single expressions, not statement lists
            if isinstance(stmts, list):
                setattr(node, field, self._fold_constant_ifs(stmts))
        return node

    def visit_FunctionDef(self, node: FunctionDef) -> Any:
        """
        Prune the body of ``node``.

        For functions nested inside the one being specialised, parameters
        that shadow a specialised argument hide its known type.
        """
        shadowed = self._param_names(node.args) if self._function_depth else set()

        def visit_body():
            self._function_depth += 1
            try:
                node.body = self.visit_sequence(node.body)
            finally:
                self._function_depth -= 1
            return node

        return self._visit_shadowed(shadowed, visit_body)

    def visit_Lambda(self, node: Lambda) -> Any:
        # Lambda parameters shadow specialised arguments of the same name
        return self._visit_shadowed(
            self._param_names(node.args), lambda: self.generic_visit(node)
        )

    def _visit_comprehension(self, node):
        """Comprehension targets shadow specialised arguments of the same name."""
        targets = {
            name.id
            for gen in node.generators
            for name in walk(gen.target)
            if isinstance(name, Name)
        }
        return self._visit_shadowed(targets, lambda: self.generic_visit(node))

    visit_ListComp = visit_SetComp = visit_DictComp = visit_GeneratorExp = (
        _visit_comprehension
    )

    def visit_sequence(self, stmts):
        """Visit each statement of ``stmts`` and drop statically dead branches."""
        return self._fold_constant_ifs([self.visit(stmt) for stmt in stmts])

    def visit_If(self, node: If) -> Any:
        """
        Visit the test and both branches of an ``if``.

        When the test folds to a constant, the enclosing statement list
        replaces the whole ``if`` by the selected branch.  Tests produced by
        union handling are typically ``isinstance(...)``, ``not isinstance(...)``
        and ``and`` / ``or`` combinations of those.
        """
        node.test = self.visit(node.test)
        node.body = self.visit_sequence(node.body)
        node.orelse = self.visit_sequence(node.orelse)
        return node

    def visit_Call(self, node: Call) -> Any:
        node = self.generic_visit(node)
        # Fold ``isinstance(<arg>, <TypeName>)`` when <arg>'s type is known
        if (
            isinstance(node.func, Name)
            and node.func.id == "isinstance"
            and len(node.args) == 2
        ):
            arg, typ = node.args
            if isinstance(arg, Name) and isinstance(typ, Name):
                known_type = self.arg_types.get(arg.id)
                if known_type is not None:
                    return copy_location(Constant(value=known_type == typ.id), node)
        return node

    def visit_BoolOp(self, node: BoolOp) -> Any:
        """
        Fold an ``and`` / ``or`` whose operands became constants.

        Neutral constants (truthy in ``and``, falsy in ``or``) are dropped.
        A deciding constant collapses the whole expression to a boolean.
        """
        node.values = [self.visit(v) for v in node.values]
        is_and = isinstance(node.op, And)
        is_or = isinstance(node.op, Or)

        if all(isinstance(v, Constant) for v in node.values):
            truths = [bool(v.value) for v in node.values]
            if is_and:
                return copy_location(Constant(value=all(truths)), node)
            if is_or:
                return copy_location(Constant(value=any(truths)), node)

        # NOTE(review): a deciding constant also discards operands that would
        # have been evaluated before it (e.g. ``f() and False``).  This is only
        # safe if those operands cannot fail or have side effects — confirm.
        if is_and:
            if any(isinstance(v, Constant) and not v.value for v in node.values):
                return copy_location(Constant(value=False), node)
            node.values = [
                v for v in node.values if not (isinstance(v, Constant) and v.value)
            ]
        elif is_or:
            if any(isinstance(v, Constant) and v.value for v in node.values):
                return copy_location(Constant(value=True), node)
            node.values = [
                v for v in node.values if not isinstance(v, Constant) or v.value
            ]
        # A single remaining operand is the whole expression
        if len(node.values) == 1:
            return node.values[0]
        return node

    def visit_UnaryOp(self, node: UnaryOp) -> Any:
        node.operand = self.visit(node.operand)
        # Only ``not <constant>`` is folded
        if isinstance(node.op, Not) and isinstance(node.operand, Constant):
            return copy_location(Constant(value=not node.operand.value), node)
        return node

    def visit_IfExp(self, node: IfExp) -> Any:
        node.test = self.visit(node.test)
        node.body = self.visit(node.body)
        node.orelse = self.visit(node.orelse)
        # A constant test selects one arm statically
        if isinstance(node.test, Constant):
            return node.body if node.test.value else node.orelse
        return node

Ancestors

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_BoolOp(self, node: ast.BoolOp) -> Any
Expand source code
def visit_BoolOp(self, node: BoolOp) -> Any:
    """
    Fold an ``and`` / ``or`` expression whose operands became constants.

    If every operand is a constant, the result is the boolean value of the
    whole expression.  Otherwise neutral constants (truthy in ``and``, falsy
    in ``or``) are dropped, and a deciding constant collapses the expression
    to ``False`` / ``True``.  A single surviving operand replaces the
    expression.

    NOTE(review): a deciding constant also discards operands that precede it
    (e.g. ``f() and False``); only safe if they are free of side effects.
    """
    node.values = [self.visit(operand) for operand in node.values]
    is_and = isinstance(node.op, And)
    is_or = isinstance(node.op, Or)

    if all(isinstance(operand, Constant) for operand in node.values):
        truths = [bool(operand.value) for operand in node.values]
        if is_and:
            return Constant(value=all(truths))
        if is_or:
            return Constant(value=any(truths))

    if is_and:
        if any(isinstance(v, Constant) and not v.value for v in node.values):
            return Constant(value=False)
        node.values = [
            v for v in node.values if not (isinstance(v, Constant) and v.value)
        ]
    elif is_or:
        if any(isinstance(v, Constant) and v.value for v in node.values):
            return Constant(value=True)
        node.values = [
            v for v in node.values if not isinstance(v, Constant) or v.value
        ]

    return node.values[0] if len(node.values) == 1 else node
def visit_Call(self, node: ast.Call) -> Any
Expand source code
def visit_Call(self, node: Call) -> Any:
    """
    Replace ``isinstance(<arg>, <TypeName>)`` by a boolean constant when the
    type of ``<arg>`` is recorded in ``self.arg_types``.

    Only a plain name as first argument and a plain name as type are folded.
    Every other call is returned after its children have been visited.
    """
    node = self.generic_visit(node)
    callee = node.func
    if not (isinstance(callee, Name) and callee.id == "isinstance"):
        return node
    if len(node.args) != 2:
        return node
    subject, checked = node.args
    if not (isinstance(subject, Name) and isinstance(checked, Name)):
        return node
    known_type = self.arg_types.get(subject.id)
    if known_type is None:
        return node
    # ``checked`` is a Name here, so its type string is simply its id
    return Constant(value=(known_type == checked.id))
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any
Expand source code
def visit_FunctionDef(self, node: FunctionDef) -> Any:
    """
    Prune dead branches in the body of ``node``.

    The outermost function visited is the one being specialised.  For any
    function nested inside it, parameters shadow the specialised arguments
    of the same name.  Their known types are therefore hidden while the
    nested body is processed; otherwise an ``isinstance`` check on the inner
    parameter would be folded using the outer argument's type.
    """
    depth = getattr(self, "_function_depth", 0)
    outer_types = self.arg_types
    if depth > 0:
        params = node.args
        shadowed = {
            a.arg for a in (*params.posonlyargs, *params.args, *params.kwonlyargs)
        }
        shadowed.update(a.arg for a in (params.vararg, params.kwarg) if a is not None)
        self.arg_types = {k: v for k, v in outer_types.items() if k not in shadowed}
    self._function_depth = depth + 1
    try:
        node.body = self.visit_sequence(node.body)
    finally:
        # Restore state even if visiting the body raises
        self._function_depth = depth
        self.arg_types = outer_types
    return node
def visit_If(self, node: ast.If) -> Any

Common types for ast.If.test:

ast.Name      - `if x:`                     (truthiness of a variable)
ast.Constant  - `if True:`, `if 0:`         (literal truthy/falsy)
ast.Call      - `if func()`, `isinstance()`       (function call)
ast.Compare   - `if x > 3:`                 (comparison)
ast.BoolOp    - `if x and y:`               (`and` / `or` logic)
ast.UnaryOp   - `if not x:`                 (negation, e.g. `not`)
ast.BinOp     - `if x + y:`                 (binary operation)
ast.Attribute - `if obj.ready:`             (attribute access)
ast.Subscript - `if arr[0]:`                (indexing)
ast.Lambda    - `if lambda x: x > 0:`       (lambda - rare)
ast.IfExp     - `if a if cond else b:`      (ternary - rare)

The most likely to be used are ast.Call (if isinstance(...)), ast.BoolOp (if isinstance(...) and/or isinstance(...)), and ast.UnaryOp (if not isinstance(...))
Expand source code
def visit_If(self, node: If) -> Any:
    """
    Simplify the test of an ``if`` statement and prune both of its branches.

    In union-specialised code the test is usually ``isinstance(...)``,
    ``not isinstance(...)``, or an ``and`` / ``or`` of such checks.  Any other
    expression kind (names, comparisons, attribute access, ...) is simply
    visited.  The ``if`` itself is returned; the caller decides whether a
    now-constant test lets it splice in a single branch.
    """
    node.test = self.visit(node.test)
    for branch in ("body", "orelse"):
        setattr(node, branch, self.visit_sequence(getattr(node, branch)))
    return node
def visit_IfExp(self, node: ast.IfExp) -> Any
Expand source code
def visit_IfExp(self, node: IfExp) -> Any:
    """
    Visit a conditional expression ``a if cond else b``.

    When ``cond`` folds to a constant, the selected arm replaces the whole
    expression; otherwise the (visited) expression is returned.
    """
    node.test, node.body, node.orelse = (
        self.visit(part) for part in (node.test, node.body, node.orelse)
    )
    if not isinstance(node.test, Constant):
        return node
    return node.body if node.test.value else node.orelse
def visit_UnaryOp(self, node: ast.UnaryOp) -> Any
Expand source code
def visit_UnaryOp(self, node: UnaryOp) -> Any:
    """
    Visit a unary operation and fold ``not <constant>`` into a boolean
    constant.  Other unary operators are left untouched.
    """
    node.operand = self.visit(node.operand)
    foldable = isinstance(node.op, Not) and isinstance(node.operand, Constant)
    return Constant(value=not node.operand.value) if foldable else node
def visit_sequence(self, stmts)
Expand source code
def visit_sequence(self, stmts):
    """
    Visit ``stmts`` and replace each ``if`` whose test became a constant by
    the branch that test selects.

    If pruning removes every statement of a non-empty list, a single ``pass``
    is returned instead.  Function and ``if`` bodies must not be empty, so an
    empty list would make the surrounding AST invalid.
    """
    new_stmts = []
    for stmt in stmts:
        s = self.visit(stmt)
        if isinstance(s, If) and isinstance(s.test, Constant):
            new_stmts.extend(s.body if s.test.value else s.orelse)
        else:
            new_stmts.append(s)
    if stmts and not new_stmts:
        new_stmts.append(copy_location(Pass(), stmts[0]))
    return new_stmts