Module opshin.rewrite.rewrite_import_dataclasses
Expand source code
from ast import *
from typing import Optional
import pluthon as plt
from frozenlist2 import frozenlist
from ..typed_ast import RawPlutoExpr
from ..type_impls import (
FunctionType,
InstanceType,
PolymorphicFunction,
PolymorphicFunctionType,
RawTupleType,
RecordType,
Type,
)
from ..util import CompilingNodeTransformer, OLambda, OVar
# Rewrite step for the ``dataclasses`` module: it resolves
# ``from dataclasses import dataclass, astuple`` (the only two names allowed),
# requires a ``dataclass`` import before any class definition, requires every
# class to carry a single ``@dataclass`` decorator, and requires an
# ``astuple`` import before ``astuple(...)`` is called.
"""
Checks that there was an import of dataclass if there are any class definitions
"""
class AstupleImpl(PolymorphicFunction):
    """Typing and code generation for ``dataclasses.astuple`` in opshin.

    The function is polymorphic: its signature depends on the dataclass
    (record) type of the single argument it is applied to.
    """

    def type_from_args(self, args: list[Type]) -> FunctionType:
        """Derive the concrete signature of ``astuple`` for ``args``.

        :param args: argument types of the call; exactly one is expected and
            it must be an instance of a record (dataclass) type.
        :returns: a function type from ``args`` to a raw tuple whose element
            types are the record's field types, in ``record.fields`` order.
        :raises AssertionError: on a wrong argument count or a non-record
            argument.
        """
        n_args = len(args)
        assert n_args == 1, f"'astuple' takes one argument, but {n_args} were given"
        instance = args[0]
        is_record_instance = isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        )
        # The message is only evaluated when the assertion fails.
        assert (
            is_record_instance
        ), f"'astuple' expects a dataclass instance, found {instance.python_type()}"
        element_types = frozenlist(
            [typ for _name, typ in instance.typ.record.fields]
        )
        return FunctionType(args, InstanceType(RawTupleType(element_types)))

    def impl_from_args(self, args: list[Type]) -> plt.AST:
        """Build the pluthon implementation of ``astuple`` for ``args``.

        :param args: argument types of the call; the first must be an
            instance of a record (dataclass) type.
        :returns: a one-parameter lambda that applies ``plt.Fields`` to its
            argument (presumably yielding the record's field list, which then
            serves as the raw tuple — confirm against pluthon).
        :raises AssertionError: if the first argument is not a record instance.
        """
        instance = args[0]
        assert isinstance(instance, InstanceType) and isinstance(
            instance.typ, RecordType
        ), "Can only convert dataclass instances with astuple"
        return OLambda(["x"], plt.Fields(OVar("x")))
# Type given to every local name bound by ``from dataclasses import astuple``
# (see RewriteImportDataclasses.visit_ImportFrom). It wraps AstupleImpl, whose
# type_from_args / impl_from_args derive the signature and code from the
# argument types at each use.
ASTUPLE_TYPE = InstanceType(PolymorphicFunctionType(AstupleImpl()))
class RewriteImportDataclasses(CompilingNodeTransformer):
    """Resolve ``from dataclasses import ...`` and validate dataclass usage.

    * Only ``dataclass`` and ``astuple`` may be imported from ``dataclasses``.
      The import statement is removed from the tree; each imported
      ``astuple`` (optionally aliased) becomes an assignment binding its local
      name to a placeholder expression typed as ``ASTUPLE_TYPE``.
    * Every class definition requires a prior ``dataclass`` import and must
      carry exactly one decorator, ``@dataclass`` or ``@dataclass(...)``.
    * Every call to a function named ``astuple`` requires a prior ``astuple``
      import, including calls inside class (method) bodies.
    """

    step = "Resolving the import and usage of dataclass"

    def __init__(self):
        # Flags flip to True once the respective name has been imported from
        # "dataclasses"; later visits consult them.
        self.imports_dataclasses = False
        self.imports_astuple = False

    def visit_ImportFrom(self, node: ImportFrom):
        """Validate a ``dataclasses`` import and replace it with typed bindings.

        Imports from any other module are returned unchanged. A
        ``from dataclasses import ...`` statement is replaced by a (possibly
        empty) list of ``Assign`` nodes, one per imported ``astuple``.

        :raises AssertionError: if a name other than ``dataclass`` or
            ``astuple`` is imported, or if ``dataclass`` is aliased.
        """
        if node.module != "dataclasses":
            return node
        imported_names = {name.name for name in node.names}
        additional_assigns = []
        for imported_name in node.names:
            assert imported_name.name in {
                "dataclass",
                "astuple",
            }, "Only 'dataclass' and 'astuple' may be imported from dataclasses"
            if imported_name.name == "dataclass":
                assert (
                    imported_name.asname is None
                ), "Imports of dataclass from dataclasses cannot be aliased"
            else:
                # Bind the (possibly aliased) local name. The value is only a
                # dummy plt.Unit(); the callable's behavior is carried by its
                # polymorphic type ASTUPLE_TYPE.
                target_name = imported_name.asname or "astuple"
                additional_assigns.append(
                    Assign(
                        targets=[
                            Name(
                                id=target_name,
                                typ=ASTUPLE_TYPE,
                                ctx=Store(),
                            )
                        ],
                        value=RawPlutoExpr(
                            typ=ASTUPLE_TYPE,
                            expr=plt.Unit(),
                        ),
                    )
                )
        # Record the imports only after every imported name has been validated.
        self.imports_dataclasses = (
            self.imports_dataclasses or "dataclass" in imported_names
        )
        self.imports_astuple = self.imports_astuple or "astuple" in imported_names
        return additional_assigns

    def visit_Call(self, node: Call) -> Call:
        """Reject calls to ``astuple`` that were not preceded by its import.

        :raises AssertionError: if ``astuple(...)`` is called without
            ``from dataclasses import astuple`` having been seen.
        """
        node = self.generic_visit(node)
        if isinstance(node.func, Name) and node.func.id == "astuple":
            assert (
                self.imports_astuple
            ), "astuple must be imported via 'from dataclasses import astuple'"
        return node

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Require a single ``@dataclass`` decorator, then visit the class body.

        :raises AssertionError: if ``dataclass`` was not imported, or if the
            class does not carry exactly one ``@dataclass`` /
            ``@dataclass(...)`` decorator.
        """
        assert (
            self.imports_dataclasses
        ), "dataclasses must be imported in order to use datum classes"
        assert (
            len(node.decorator_list) == 1
        ), "Class definitions must have the decorator @dataclass"
        decorator = node.decorator_list[0]
        if isinstance(decorator, Call):
            # Accept the called form, e.g. @dataclass(frozen=True).
            decorator = decorator.func
        elif not isinstance(decorator, Name):
            raise AssertionError("Class definitions must have the decorator @dataclass")
        assert (
            isinstance(decorator, Name) and decorator.id == "dataclass"
        ), "Class definitions must have the decorator @dataclass"
        # Descend into the class (decorators, bases, method bodies) so that the
        # checks in visit_Call also apply to calls made inside the class.
        return self.generic_visit(node)
Classes
class AstupleImpl (*args, **kwargs)-
Expand source code
class AstupleImpl(PolymorphicFunction): def type_from_args(self, args: list[Type]) -> FunctionType: assert ( len(args) == 1 ), f"'astuple' takes one argument, but {len(args)} were given" arg = args[0] assert isinstance(arg, InstanceType) and isinstance( arg.typ, RecordType ), f"'astuple' expects a dataclass instance, found {arg.python_type()}" return FunctionType( args, InstanceType( RawTupleType( frozenlist([field_typ for _, field_typ in arg.typ.record.fields]), ) ), ) def impl_from_args(self, args: list[Type]) -> plt.AST: arg = args[0] assert isinstance(arg, InstanceType) and isinstance( arg.typ, RecordType ), "Can only convert dataclass instances with astuple" return OLambda(["x"], plt.Fields(OVar("x")))Ancestors
Methods
def impl_from_args(self, args: list[Type]) ‑> pluthon.pluthon_ast.AST-
Expand source code
def impl_from_args(self, args: list[Type]) -> plt.AST: arg = args[0] assert isinstance(arg, InstanceType) and isinstance( arg.typ, RecordType ), "Can only convert dataclass instances with astuple" return OLambda(["x"], plt.Fields(OVar("x"))) def type_from_args(self, args: list[Type]) ‑> FunctionType-
Expand source code
def type_from_args(self, args: list[Type]) -> FunctionType: assert ( len(args) == 1 ), f"'astuple' takes one argument, but {len(args)} were given" arg = args[0] assert isinstance(arg, InstanceType) and isinstance( arg.typ, RecordType ), f"'astuple' expects a dataclass instance, found {arg.python_type()}" return FunctionType( args, InstanceType( RawTupleType( frozenlist([field_typ for _, field_typ in arg.typ.record.fields]), ) ), )
class RewriteImportDataclasses-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is ``None``, the node will be removed from its location; otherwise it is replaced with the return value. The return value may be the original node, in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )

Keep in mind that if the node you're operating on has child nodes, you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

    node = YourTransformer().visit(node)
Expand source code
class RewriteImportDataclasses(CompilingNodeTransformer): step = "Resolving the import and usage of dataclass" def __init__(self): self.imports_dataclasses = False self.imports_astuple = False def visit_ImportFrom(self, node: ImportFrom): if node.module != "dataclasses": return node imported_names = {name.name for name in node.names} additional_assigns = [] for imported_name in node.names: assert imported_name.name in { "dataclass", "astuple", }, "Only 'dataclass' and 'astuple' may be imported from dataclasses" if imported_name.name == "dataclass": assert ( imported_name.asname == None ), "Imports of dataclass from dataclasses cannot be aliased" else: target_name = imported_name.asname or "astuple" additional_assigns.append( Assign( targets=[ Name( id=target_name, typ=ASTUPLE_TYPE, ctx=Store(), ) ], value=RawPlutoExpr( typ=ASTUPLE_TYPE, expr=plt.Unit(), ), ) ) self.imports_dataclasses = ( self.imports_dataclasses or "dataclass" in imported_names ) self.imports_astuple = self.imports_astuple or "astuple" in imported_names return additional_assigns def visit_Call(self, node: Call) -> Call: node = self.generic_visit(node) if isinstance(node.func, Name) and node.func.id == "astuple": assert ( self.imports_astuple ), "astuple must be imported via 'from dataclasses import astuple'" return node def visit_ClassDef(self, node: ClassDef) -> ClassDef: assert ( self.imports_dataclasses ), "dataclasses must be imported in order to use datum classes" assert ( len(node.decorator_list) == 1 ), "Class definitions must have the decorator @dataclass" if isinstance(node.decorator_list[0], Call): node_decorator = node.decorator_list[0].func elif isinstance(node.decorator_list[0], Name): node_decorator = node.decorator_list[0] else: raise AssertionError("Class definitions must have the decorator @dataclass") assert isinstance( node_decorator, Name ), "Class definitions must have the decorator @dataclass" assert ( node_decorator.id == "dataclass" ), "Class definitions must have the 
decorator @dataclass" return nodeAncestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)-
Inherited from:
CompilingNodeTransformer.visit
Visit a node.
def visit_Call(self, node: ast.Call) ‑> ast.Call-
Expand source code
def visit_Call(self, node: Call) -> Call: node = self.generic_visit(node) if isinstance(node.func, Name) and node.func.id == "astuple": assert ( self.imports_astuple ), "astuple must be imported via 'from dataclasses import astuple'" return node def visit_ClassDef(self, node: ast.ClassDef) ‑> ast.ClassDef-
Expand source code
def visit_ClassDef(self, node: ClassDef) -> ClassDef: assert ( self.imports_dataclasses ), "dataclasses must be imported in order to use datum classes" assert ( len(node.decorator_list) == 1 ), "Class definitions must have the decorator @dataclass" if isinstance(node.decorator_list[0], Call): node_decorator = node.decorator_list[0].func elif isinstance(node.decorator_list[0], Name): node_decorator = node.decorator_list[0] else: raise AssertionError("Class definitions must have the decorator @dataclass") assert isinstance( node_decorator, Name ), "Class definitions must have the decorator @dataclass" assert ( node_decorator.id == "dataclass" ), "Class definitions must have the decorator @dataclass" return node def visit_ImportFrom(self, node: ast.ImportFrom)-
Expand source code
def visit_ImportFrom(self, node: ImportFrom): if node.module != "dataclasses": return node imported_names = {name.name for name in node.names} additional_assigns = [] for imported_name in node.names: assert imported_name.name in { "dataclass", "astuple", }, "Only 'dataclass' and 'astuple' may be imported from dataclasses" if imported_name.name == "dataclass": assert ( imported_name.asname == None ), "Imports of dataclass from dataclasses cannot be aliased" else: target_name = imported_name.asname or "astuple" additional_assigns.append( Assign( targets=[ Name( id=target_name, typ=ASTUPLE_TYPE, ctx=Store(), ) ], value=RawPlutoExpr( typ=ASTUPLE_TYPE, expr=plt.Unit(), ), ) ) self.imports_dataclasses = ( self.imports_dataclasses or "dataclass" in imported_names ) self.imports_astuple = self.imports_astuple or "astuple" in imported_names return additional_assigns