Module opshin.rewrite.rewrite_import_typing

Expand source code
from ast import *
from typing import Optional

from ..util import CompilingNodeTransformer

"""
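Checks that the special typing names (Dict, List, Union, Self) are imported from typing before they are used, and tags Self annotations in class methods with the name of the enclosing class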
"""


# Names that may be imported from typing; visit_ImportFrom rejects any other
# name, and visit_Name rejects use of these names before they are imported.
ALLOWED_TYPING_IMPORTS = {"Dict", "List", "Union", "Self"}


class RewriteImportTyping(CompilingNodeTransformer):
    """Validate imports from ``typing`` and resolve ``Self`` annotations.

    The transformer records which names were imported from ``typing``,
    rejects disallowed or aliased typing imports, rejects use of the special
    typing names before they were imported, and tags ``Self`` annotations in
    method signatures with the name of the enclosing class (stored in the
    ``idSelf`` attribute of the annotation node).
    """

    step = "Checking import and usage of typing"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Names imported from typing by the import statements visited so far.
        self.imports = set()

    def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
        """Record names imported from ``typing``.

        Imports from any other module are returned unchanged. For an import
        from ``typing`` every imported name is added to ``self.imports`` and
        ``None`` is returned (a node transformer treats a ``None`` result as
        removal of the node).

        Raises:
            ValueError: if a name outside ``ALLOWED_TYPING_IMPORTS`` is
                imported, or if an imported name is aliased with ``as``.
        """
        if node.module != "typing":
            return node

        for n in node.names:
            if n.name not in ALLOWED_TYPING_IMPORTS:
                raise ValueError(
                    f"Only the following imports from typing are allowed: {ALLOWED_TYPING_IMPORTS}, found {n.name}"
                )
            if n.asname is not None:
                raise ValueError("Imports from typing cannot be aliased")
            self.imports.add(n.name)
        return None

    def visit_Name(self, node: Name) -> Optional[Name]:
        """Reject use of a special typing name that has not been imported.

        Raises:
            ValueError: if ``node.id`` is in ``ALLOWED_TYPING_IMPORTS`` but
                is not in ``self.imports``.
        """
        if node.id in ALLOWED_TYPING_IMPORTS and node.id not in self.imports:
            raise ValueError(
                f"{node.id} used, which is a keyword for special OpShin types, but typing not imported. Please add 'from typing import {node.id}'"
            )
        return node

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Tag ``Self`` annotations of the class's methods with the class name.

        For every function defined directly in the class body, each ``Self``
        annotation found on a regular argument, inside a ``Union[...]``
        argument annotation, or as the return annotation receives an
        ``idSelf`` attribute holding ``node.name``.

        Note that this method does not call ``generic_visit`` itself; it only
        inspects the signatures of the functions in the class body.

        Returns:
            The same ``node``, with annotation nodes tagged in place.

        Raises:
            AssertionError: if ``Self`` or ``Union`` is used in one of the
                inspected positions without having been imported from typing.
        """

        def require_import(name: str) -> None:
            # Raised explicitly rather than via ``assert`` so that the check
            # is still performed when Python runs with -O.
            if name not in self.imports:
                raise AssertionError(
                    f"{name} used but not imported from typing. Please add 'from typing import {name}'"
                )

        def tag_if_self(annotation) -> None:
            if isinstance(annotation, Name) and annotation.id == "Self":
                require_import("Self")
                annotation.idSelf = node.name

        for attribute in node.body:
            if not isinstance(attribute, FunctionDef):
                continue
            for arg in attribute.args.args:
                annotation = arg.annotation
                tag_if_self(annotation)
                # Only a plain ``Union[...]`` is inspected; the base of the
                # subscript may be something other than a Name (for example
                # ``typing.List[int]``), which has no ``id`` attribute.
                if (
                    isinstance(annotation, Subscript)
                    and isinstance(annotation.value, Name)
                    and annotation.value.id == "Union"
                ):
                    require_import("Union")
                    members = annotation.slice
                    # A single-member ``Union[X]`` has X itself as the slice
                    # instead of a tuple with ``elts``.
                    elements = getattr(members, "elts", None)
                    if elements is None:
                        elements = [members]
                    for element in elements:
                        tag_if_self(element)
            tag_if_self(attribute.returns)

        return node

Classes

class RewriteImportTyping (*args, **kwargs)

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteImportTyping(CompilingNodeTransformer):
    """Validate imports from ``typing`` and resolve ``Self`` annotations.

    The transformer records which names were imported from ``typing``,
    rejects disallowed or aliased typing imports, rejects use of the special
    typing names before they were imported, and tags ``Self`` annotations in
    method signatures with the name of the enclosing class (stored in the
    ``idSelf`` attribute of the annotation node).
    """

    step = "Checking import and usage of typing"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Names imported from typing by the import statements visited so far.
        self.imports = set()

    def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
        """Record names imported from ``typing``.

        Imports from any other module are returned unchanged. For an import
        from ``typing`` every imported name is added to ``self.imports`` and
        ``None`` is returned (a node transformer treats a ``None`` result as
        removal of the node).

        Raises:
            ValueError: if a name outside ``ALLOWED_TYPING_IMPORTS`` is
                imported, or if an imported name is aliased with ``as``.
        """
        if node.module != "typing":
            return node

        for n in node.names:
            if n.name not in ALLOWED_TYPING_IMPORTS:
                raise ValueError(
                    f"Only the following imports from typing are allowed: {ALLOWED_TYPING_IMPORTS}, found {n.name}"
                )
            if n.asname is not None:
                raise ValueError("Imports from typing cannot be aliased")
            self.imports.add(n.name)
        return None

    def visit_Name(self, node: Name) -> Optional[Name]:
        """Reject use of a special typing name that has not been imported.

        Raises:
            ValueError: if ``node.id`` is in ``ALLOWED_TYPING_IMPORTS`` but
                is not in ``self.imports``.
        """
        if node.id in ALLOWED_TYPING_IMPORTS and node.id not in self.imports:
            raise ValueError(
                f"{node.id} used, which is a keyword for special OpShin types, but typing not imported. Please add 'from typing import {node.id}'"
            )
        return node

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Tag ``Self`` annotations of the class's methods with the class name.

        For every function defined directly in the class body, each ``Self``
        annotation found on a regular argument, inside a ``Union[...]``
        argument annotation, or as the return annotation receives an
        ``idSelf`` attribute holding ``node.name``.

        Note that this method does not call ``generic_visit`` itself; it only
        inspects the signatures of the functions in the class body.

        Returns:
            The same ``node``, with annotation nodes tagged in place.

        Raises:
            AssertionError: if ``Self`` or ``Union`` is used in one of the
                inspected positions without having been imported from typing.
        """

        def require_import(name: str) -> None:
            # Raised explicitly rather than via ``assert`` so that the check
            # is still performed when Python runs with -O.
            if name not in self.imports:
                raise AssertionError(
                    f"{name} used but not imported from typing. Please add 'from typing import {name}'"
                )

        def tag_if_self(annotation) -> None:
            if isinstance(annotation, Name) and annotation.id == "Self":
                require_import("Self")
                annotation.idSelf = node.name

        for attribute in node.body:
            if not isinstance(attribute, FunctionDef):
                continue
            for arg in attribute.args.args:
                annotation = arg.annotation
                tag_if_self(annotation)
                # Only a plain ``Union[...]`` is inspected; the base of the
                # subscript may be something other than a Name (for example
                # ``typing.List[int]``), which has no ``id`` attribute.
                if (
                    isinstance(annotation, Subscript)
                    and isinstance(annotation.value, Name)
                    and annotation.value.id == "Union"
                ):
                    require_import("Union")
                    members = annotation.slice
                    # A single-member ``Union[X]`` has X itself as the slice
                    # instead of a tuple with ``elts``.
                    elements = getattr(members, "elts", None)
                    if elements is None:
                        elements = [members]
                    for element in elements:
                        tag_if_self(element)
            tag_if_self(attribute.returns)

        return node

Ancestors

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef:
    """Tag ``Self`` annotations of the class's methods with the class name.

    For every function defined directly in the class body, each ``Self``
    annotation found on a regular argument, inside a ``Union[...]`` argument
    annotation, or as the return annotation receives an ``idSelf`` attribute
    holding ``node.name``.

    Note that this method does not call ``generic_visit`` itself; it only
    inspects the signatures of the functions in the class body.

    Returns:
        The same ``node``, with annotation nodes tagged in place.

    Raises:
        AssertionError: if ``Self`` or ``Union`` is used in one of the
            inspected positions without being present in ``self.imports``.
    """

    def require_import(name: str) -> None:
        # Raised explicitly rather than via ``assert`` so that the check is
        # still performed when Python runs with -O.
        if name not in self.imports:
            raise AssertionError(
                f"{name} used but not imported from typing. Please add 'from typing import {name}'"
            )

    def tag_if_self(annotation) -> None:
        if isinstance(annotation, Name) and annotation.id == "Self":
            require_import("Self")
            annotation.idSelf = node.name

    for attribute in node.body:
        if not isinstance(attribute, FunctionDef):
            continue
        for arg in attribute.args.args:
            annotation = arg.annotation
            tag_if_self(annotation)
            # Only a plain ``Union[...]`` is inspected; the base of the
            # subscript may be something other than a Name (for example
            # ``typing.List[int]``), which has no ``id`` attribute.
            if (
                isinstance(annotation, Subscript)
                and isinstance(annotation.value, Name)
                and annotation.value.id == "Union"
            ):
                require_import("Union")
                members = annotation.slice
                # A single-member ``Union[X]`` has X itself as the slice
                # instead of a tuple with ``elts``.
                elements = getattr(members, "elts", None)
                if elements is None:
                    elements = [members]
                for element in elements:
                    tag_if_self(element)
        tag_if_self(attribute.returns)

    return node
def visit_ImportFrom(self, node: ast.ImportFrom) ‑> ast.ImportFrom | None
Expand source code
def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
    """Record names imported from ``typing``.

    An import from any other module is handed back untouched. For an import
    from ``typing`` each imported name is validated and stored in
    ``self.imports``, and ``None`` is returned in place of the node.

    Raises:
        ValueError: if a name outside ``ALLOWED_TYPING_IMPORTS`` is imported,
            or if an imported name carries an ``as`` alias.
    """
    if node.module == "typing":
        for alias in node.names:
            imported = alias.name
            if imported not in ALLOWED_TYPING_IMPORTS:
                raise ValueError(
                    f"Only the following imports from typing are allowed: {ALLOWED_TYPING_IMPORTS}, found {imported}"
                )
            if alias.asname is not None:
                raise ValueError("Imports from typing cannot be aliased")
            self.imports.add(imported)
        return None
    return node
def visit_Name(self, node: ast.Name) ‑> ast.Name | None
Expand source code
def visit_Name(self, node: Name) -> Optional[Name]:
    """Reject use of a special typing name that has not been imported.

    Returns the node unchanged when its identifier is not one of
    ``ALLOWED_TYPING_IMPORTS`` or is already present in ``self.imports``.

    Raises:
        ValueError: if the identifier is a special typing name that is not in
            ``self.imports``.
    """
    identifier = node.id
    # Guard clause: nothing to check for ordinary names or imported ones.
    if identifier not in ALLOWED_TYPING_IMPORTS or identifier in self.imports:
        return node
    raise ValueError(
        f"{identifier} used, which is a keyword for special OpShin types, but typing not imported. Please add 'from typing import {identifier}'"
    )