Module opshin.rewrite.rewrite_import_dataclasses

Expand source code
from ast import *
from typing import Optional

import pluthon as plt
from frozenlist2 import frozenlist

from ..typed_ast import RawPlutoExpr
from ..type_impls import (
    FunctionType,
    InstanceType,
    PolymorphicFunction,
    PolymorphicFunctionType,
    RawTupleType,
    RecordType,
    Type,
)
from ..util import CompilingNodeTransformer, OLambda, OVar

"""
Checks that there was an import of dataclass if there are any class definitions
"""


class AstupleImpl(PolymorphicFunction):
    """
    Polymorphic implementation of ``dataclasses.astuple``.

    The result type depends on the record type of the argument, so both the
    signature (``type_from_args``) and the implementation (``impl_from_args``)
    are derived from the argument types.
    """

    def type_from_args(self, args: list[Type]) -> FunctionType:
        """
        Compute the signature of ``astuple`` for the given argument types.

        :param args: types of the call arguments; exactly one dataclass
            (record) instance is accepted
        :return: a function type from ``args`` to an instance of a raw tuple
            type whose element types are the record's field types, in the
            order of ``record.fields``
        :raises AssertionError: on a wrong argument count or a non-record
            argument
        """
        arg_count = len(args)
        assert (
            arg_count == 1
        ), f"'astuple' takes one argument, but {arg_count} were given"
        instance = args[0]
        assert isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        ), f"'astuple' expects a dataclass instance, found {instance.python_type()}"
        element_types = frozenlist(
            [field_type for _field_name, field_type in instance.typ.record.fields]
        )
        return FunctionType(args, InstanceType(RawTupleType(element_types)))

    def impl_from_args(self, args: list[Type]) -> plt.AST:
        """
        Build the implementation of ``astuple`` for the given argument types.

        :param args: argument types; the first must be a dataclass instance
        :return: a one-argument lambda that applies ``plt.Fields`` to its
            argument
        :raises AssertionError: if the first argument is not a record instance
        """
        instance = args[0]
        is_record = isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        )
        assert is_record, "Can only convert dataclass instances with astuple"
        return OLambda(["x"], plt.Fields(OVar("x")))


# Type bound to names imported via ``from dataclasses import astuple``:
# an instance of a polymorphic function type backed by AstupleImpl.
ASTUPLE_TYPE = InstanceType(
    PolymorphicFunctionType(AstupleImpl()),
)


class RewriteImportDataclasses(CompilingNodeTransformer):
    """
    Resolve ``from dataclasses import ...`` and check dataclass usage.

    - ``from dataclasses import dataclass`` is removed from the tree. It
      enables class definitions, each of which must be decorated with
      ``@dataclass`` (optionally called, e.g. ``@dataclass()``).
    - ``from dataclasses import astuple [as alias]`` is replaced by an
      assignment binding ``alias`` (default ``astuple``) with type
      ``ASTUPLE_TYPE``.
    - A call to the bare name ``astuple`` is rejected unless ``astuple`` was
      imported from ``dataclasses``.
    - Any other name imported from ``dataclasses`` is rejected.
    """

    step = "Resolving the import and usage of dataclass"

    def __init__(self):
        # Flags recording which names have been imported from ``dataclasses``
        # in the statements visited so far.
        self.imports_dataclasses = False
        self.imports_astuple = False

    def visit_ImportFrom(self, node: ImportFrom):
        """
        Rewrite an import from ``dataclasses``; other imports pass through.

        :param node: the import statement
        :return: ``node`` unchanged if it does not import from ``dataclasses``,
            otherwise a (possibly empty) list of assignments that replaces it
        :raises AssertionError: if a name other than ``dataclass`` or
            ``astuple`` is imported, or if ``dataclass`` is aliased
        """
        if node.module != "dataclasses":
            return node
        imported_names = {name.name for name in node.names}
        additional_assigns = []
        for imported_name in node.names:
            assert imported_name.name in {
                "dataclass",
                "astuple",
            }, "Only 'dataclass' and 'astuple' may be imported from dataclasses"
            if imported_name.name == "dataclass":
                assert (
                    imported_name.asname is None
                ), "Imports of dataclass from dataclasses cannot be aliased"
            else:
                # Bind the (possibly aliased) name with the polymorphic astuple
                # type; the bound value expression itself is just plt.Unit().
                target_name = imported_name.asname or "astuple"
                additional_assigns.append(
                    Assign(
                        targets=[
                            Name(
                                id=target_name,
                                typ=ASTUPLE_TYPE,
                                ctx=Store(),
                            )
                        ],
                        value=RawPlutoExpr(
                            typ=ASTUPLE_TYPE,
                            expr=plt.Unit(),
                        ),
                    )
                )
        self.imports_dataclasses = (
            self.imports_dataclasses or "dataclass" in imported_names
        )
        self.imports_astuple = self.imports_astuple or "astuple" in imported_names
        return additional_assigns

    def visit_Call(self, node: Call) -> Call:
        """
        Visit a call and reject a bare ``astuple`` call that was never imported.

        :param node: the call expression
        :return: the call after its children have been visited
        :raises AssertionError: if ``astuple(...)`` is called without
            ``from dataclasses import astuple``
        """
        node = self.generic_visit(node)
        if isinstance(node.func, Name) and node.func.id == "astuple":
            assert (
                self.imports_astuple
            ), "astuple must be imported via 'from dataclasses import astuple'"
        return node

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """
        Check that a class definition is a dataclass, then visit its body.

        :param node: the class definition
        :return: the class definition after its children have been visited
        :raises AssertionError: if ``dataclass`` was not imported, or the class
            does not have exactly one decorator that is ``@dataclass`` or
            ``@dataclass(...)``
        """
        assert (
            self.imports_dataclasses
        ), "dataclasses must be imported in order to use datum classes"
        assert (
            len(node.decorator_list) == 1
        ), "Class definitions must have the decorator @dataclass"
        if isinstance(node.decorator_list[0], Call):
            node_decorator = node.decorator_list[0].func
        elif isinstance(node.decorator_list[0], Name):
            node_decorator = node.decorator_list[0]
        else:
            raise AssertionError("Class definitions must have the decorator @dataclass")
        assert isinstance(
            node_decorator, Name
        ), "Class definitions must have the decorator @dataclass"
        assert (
            node_decorator.id == "dataclass"
        ), "Class definitions must have the decorator @dataclass"
        # Visit the class body as well. Otherwise nested nodes (e.g. an
        # ``astuple(...)`` call inside a method, or a nested class) would skip
        # the checks above and in visit_Call.
        return self.generic_visit(node)

Classes

class AstupleImpl (*args, **kwargs)
Expand source code
class AstupleImpl(PolymorphicFunction):
    """
    Polymorphic implementation of ``dataclasses.astuple``.

    The result type depends on the record type of the argument, so both the
    signature (``type_from_args``) and the implementation (``impl_from_args``)
    are derived from the argument types.
    """

    def type_from_args(self, args: list[Type]) -> FunctionType:
        """
        Compute the signature of ``astuple`` for the given argument types.

        :param args: types of the call arguments; exactly one dataclass
            (record) instance is accepted
        :return: a function type from ``args`` to an instance of a raw tuple
            type whose element types are the record's field types, in the
            order of ``record.fields``
        :raises AssertionError: on a wrong argument count or a non-record
            argument
        """
        arg_count = len(args)
        assert (
            arg_count == 1
        ), f"'astuple' takes one argument, but {arg_count} were given"
        instance = args[0]
        assert isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        ), f"'astuple' expects a dataclass instance, found {instance.python_type()}"
        element_types = frozenlist(
            [field_type for _field_name, field_type in instance.typ.record.fields]
        )
        return FunctionType(args, InstanceType(RawTupleType(element_types)))

    def impl_from_args(self, args: list[Type]) -> plt.AST:
        """
        Build the implementation of ``astuple`` for the given argument types.

        :param args: argument types; the first must be a dataclass instance
        :return: a one-argument lambda that applies ``plt.Fields`` to its
            argument
        :raises AssertionError: if the first argument is not a record instance
        """
        instance = args[0]
        is_record = isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        )
        assert is_record, "Can only convert dataclass instances with astuple"
        return OLambda(["x"], plt.Fields(OVar("x")))

Ancestors

PolymorphicFunction

Methods

def impl_from_args(self, args: list[Type]) ‑> pluthon.pluthon_ast.AST
Expand source code
def impl_from_args(self, args: list[Type]) -> plt.AST:
    """
    Build the implementation of ``astuple`` for the given argument types.

    :param args: argument types; the first must be a dataclass instance
    :return: a one-argument lambda that applies ``plt.Fields`` to its argument
    :raises AssertionError: if the first argument is not a record instance
    """
    instance = args[0]
    is_record = isinstance(instance, InstanceType) and isinstance(
        instance.typ, RecordType
    )
    assert is_record, "Can only convert dataclass instances with astuple"
    return OLambda(["x"], plt.Fields(OVar("x")))
def type_from_args(self, args: list[Type]) ‑> FunctionType
Expand source code
def type_from_args(self, args: list[Type]) -> FunctionType:
    """
    Compute the signature of ``astuple`` for the given argument types.

    :param args: types of the call arguments; exactly one dataclass (record)
        instance is accepted
    :return: a function type from ``args`` to an instance of a raw tuple type
        whose element types are the record's field types, in the order of
        ``record.fields``
    :raises AssertionError: on a wrong argument count or a non-record argument
    """
    arg_count = len(args)
    assert (
        arg_count == 1
    ), f"'astuple' takes one argument, but {arg_count} were given"
    instance = args[0]
    assert isinstance(instance, InstanceType) and isinstance(
        instance.typ, RecordType
    ), f"'astuple' expects a dataclass instance, found {instance.python_type()}"
    element_types = frozenlist(
        [field_type for _field_name, field_type in instance.typ.record.fields]
    )
    return FunctionType(args, InstanceType(RawTupleType(element_types)))
class RewriteImportDataclasses

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteImportDataclasses(CompilingNodeTransformer):
    """
    Resolve ``from dataclasses import ...`` and check dataclass usage.

    - ``from dataclasses import dataclass`` is removed from the tree. It
      enables class definitions, each of which must be decorated with
      ``@dataclass`` (optionally called, e.g. ``@dataclass()``).
    - ``from dataclasses import astuple [as alias]`` is replaced by an
      assignment binding ``alias`` (default ``astuple``) with type
      ``ASTUPLE_TYPE``.
    - A call to the bare name ``astuple`` is rejected unless ``astuple`` was
      imported from ``dataclasses``.
    - Any other name imported from ``dataclasses`` is rejected.
    """

    step = "Resolving the import and usage of dataclass"

    def __init__(self):
        # Flags recording which names have been imported from ``dataclasses``
        # in the statements visited so far.
        self.imports_dataclasses = False
        self.imports_astuple = False

    def visit_ImportFrom(self, node: ImportFrom):
        """
        Rewrite an import from ``dataclasses``; other imports pass through.

        :param node: the import statement
        :return: ``node`` unchanged if it does not import from ``dataclasses``,
            otherwise a (possibly empty) list of assignments that replaces it
        :raises AssertionError: if a name other than ``dataclass`` or
            ``astuple`` is imported, or if ``dataclass`` is aliased
        """
        if node.module != "dataclasses":
            return node
        imported_names = {name.name for name in node.names}
        additional_assigns = []
        for imported_name in node.names:
            assert imported_name.name in {
                "dataclass",
                "astuple",
            }, "Only 'dataclass' and 'astuple' may be imported from dataclasses"
            if imported_name.name == "dataclass":
                assert (
                    imported_name.asname is None
                ), "Imports of dataclass from dataclasses cannot be aliased"
            else:
                # Bind the (possibly aliased) name with the polymorphic astuple
                # type; the bound value expression itself is just plt.Unit().
                target_name = imported_name.asname or "astuple"
                additional_assigns.append(
                    Assign(
                        targets=[
                            Name(
                                id=target_name,
                                typ=ASTUPLE_TYPE,
                                ctx=Store(),
                            )
                        ],
                        value=RawPlutoExpr(
                            typ=ASTUPLE_TYPE,
                            expr=plt.Unit(),
                        ),
                    )
                )
        self.imports_dataclasses = (
            self.imports_dataclasses or "dataclass" in imported_names
        )
        self.imports_astuple = self.imports_astuple or "astuple" in imported_names
        return additional_assigns

    def visit_Call(self, node: Call) -> Call:
        """
        Visit a call and reject a bare ``astuple`` call that was never imported.

        :param node: the call expression
        :return: the call after its children have been visited
        :raises AssertionError: if ``astuple(...)`` is called without
            ``from dataclasses import astuple``
        """
        node = self.generic_visit(node)
        if isinstance(node.func, Name) and node.func.id == "astuple":
            assert (
                self.imports_astuple
            ), "astuple must be imported via 'from dataclasses import astuple'"
        return node

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """
        Check that a class definition is a dataclass, then visit its body.

        :param node: the class definition
        :return: the class definition after its children have been visited
        :raises AssertionError: if ``dataclass`` was not imported, or the class
            does not have exactly one decorator that is ``@dataclass`` or
            ``@dataclass(...)``
        """
        assert (
            self.imports_dataclasses
        ), "dataclasses must be imported in order to use datum classes"
        assert (
            len(node.decorator_list) == 1
        ), "Class definitions must have the decorator @dataclass"
        if isinstance(node.decorator_list[0], Call):
            node_decorator = node.decorator_list[0].func
        elif isinstance(node.decorator_list[0], Name):
            node_decorator = node.decorator_list[0]
        else:
            raise AssertionError("Class definitions must have the decorator @dataclass")
        assert isinstance(
            node_decorator, Name
        ), "Class definitions must have the decorator @dataclass"
        assert (
            node_decorator.id == "dataclass"
        ), "Class definitions must have the decorator @dataclass"
        # Visit the class body as well. Otherwise nested nodes (e.g. an
        # ``astuple(...)`` call inside a method, or a nested class) would skip
        # the checks above and in visit_Call.
        return self.generic_visit(node)

Ancestors

CompilingNodeTransformer

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_Call(self, node: ast.Call) ‑> ast.Call
Expand source code
def visit_Call(self, node: Call) -> Call:
    """
    Visit a call and reject a bare ``astuple`` call that was never imported.

    :param node: the call expression
    :return: the call after its children have been visited
    :raises AssertionError: if ``astuple(...)`` is called without
        ``from dataclasses import astuple``
    """
    visited = self.generic_visit(node)
    callee = visited.func
    calls_astuple = isinstance(callee, Name) and callee.id == "astuple"
    assert (
        not calls_astuple or self.imports_astuple
    ), "astuple must be imported via 'from dataclasses import astuple'"
    return visited
def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef:
    """
    Check that a class definition is a dataclass, then visit its body.

    :param node: the class definition
    :return: the class definition after its children have been visited
    :raises AssertionError: if ``dataclass`` was not imported, or the class
        does not have exactly one decorator that is ``@dataclass`` or
        ``@dataclass(...)``
    """
    assert (
        self.imports_dataclasses
    ), "dataclasses must be imported in order to use datum classes"
    assert (
        len(node.decorator_list) == 1
    ), "Class definitions must have the decorator @dataclass"
    if isinstance(node.decorator_list[0], Call):
        node_decorator = node.decorator_list[0].func
    elif isinstance(node.decorator_list[0], Name):
        node_decorator = node.decorator_list[0]
    else:
        raise AssertionError("Class definitions must have the decorator @dataclass")
    assert isinstance(
        node_decorator, Name
    ), "Class definitions must have the decorator @dataclass"
    assert (
        node_decorator.id == "dataclass"
    ), "Class definitions must have the decorator @dataclass"
    # Visit the class body as well. Otherwise nested nodes (e.g. an
    # ``astuple(...)`` call inside a method, or a nested class) would skip
    # the checks above and in visit_Call.
    return self.generic_visit(node)
def visit_ImportFrom(self, node: ast.ImportFrom)
Expand source code
def visit_ImportFrom(self, node: ImportFrom):
    """
    Rewrite an import from ``dataclasses``; other imports pass through.

    :param node: the import statement
    :return: ``node`` unchanged if it does not import from ``dataclasses``,
        otherwise a (possibly empty) list of assignments that replaces it
    :raises AssertionError: if a name other than ``dataclass`` or ``astuple``
        is imported, or if ``dataclass`` is aliased
    """
    if node.module != "dataclasses":
        return node
    imported_names = {name.name for name in node.names}
    additional_assigns = []
    for imported_name in node.names:
        assert imported_name.name in {
            "dataclass",
            "astuple",
        }, "Only 'dataclass' and 'astuple' may be imported from dataclasses"
        if imported_name.name == "dataclass":
            assert (
                imported_name.asname is None
            ), "Imports of dataclass from dataclasses cannot be aliased"
        else:
            # Bind the (possibly aliased) name with the polymorphic astuple
            # type; the bound value expression itself is just plt.Unit().
            target_name = imported_name.asname or "astuple"
            additional_assigns.append(
                Assign(
                    targets=[
                        Name(
                            id=target_name,
                            typ=ASTUPLE_TYPE,
                            ctx=Store(),
                        )
                    ],
                    value=RawPlutoExpr(
                        typ=ASTUPLE_TYPE,
                        expr=plt.Unit(),
                    ),
                )
            )
    self.imports_dataclasses = (
        self.imports_dataclasses or "dataclass" in imported_names
    )
    self.imports_astuple = self.imports_astuple or "astuple" in imported_names
    return additional_assigns