Module opshin.rewrite.rewrite_import_typing
Expand source code
from ast import *
from typing import Optional
from ..util import CompilingNodeTransformer
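"""
Checks that 'from typing import Dict, List, Union' was imported if there are any class definitions,
and, when 'from typing import Self' is used, records the enclosing class name on 'Self' annotations
"""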
class RewriteImportTyping(CompilingNodeTransformer):
    """Check the program's ``typing`` imports and tag ``Self`` annotations.

    Accepted ``from typing import ...`` statements:

    * ``from typing import Self`` (``Self`` as the only imported name) --
      enables tagging of ``Self`` annotations in class methods.
    * ``from typing import Dict, List, Union`` (exactly these names, in this
      order, without aliases) -- required before any class definition.

    Both accepted statements are removed from the tree; any other
    ``from typing import ...`` fails an assertion. Imports from other
    modules are returned unchanged.
    """

    step = "Checking import and usage of typing"
    # Class-level default; the instance attribute is set to True once the
    # required ``from typing import Dict, List, Union`` has been visited.
    imports_typing = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True once ``from typing import Self`` has been visited.
        self.imports_Self = False

    def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
        """Validate and remove ``from typing import ...`` statements.

        :param node: the import statement being visited.
        :return: ``node`` unchanged for imports from any module other than
            ``typing``; ``None`` (removing the statement) for an accepted
            ``typing`` import.
        :raises AssertionError: if a ``typing`` import is neither
            ``from typing import Self`` nor exactly
            ``from typing import Dict, List, Union``.
        """
        if node.module != "typing":
            return node
        if len(node.names) == 1 and node.names[0].name == "Self":
            self.imports_Self = True
            return None
        assert (
            len(node.names) == 3
        ), "The program must contain one 'from typing import Dict, List, Union'"
        for alias, expected in zip(node.names, ("Dict", "List", "Union")):
            assert (
                alias.name == expected
            ), "The program must contain one 'from typing import Dict, List, Union'"
            assert (
                alias.asname is None
            ), "The program must contain one 'from typing import Dict, List, Union'"
        self.imports_typing = True
        return None

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Require the typing import and tag ``Self`` annotations in methods.

        If ``from typing import Self`` has been seen, tag every ``Self`` in a
        method's positional-argument annotations or return annotation. This
        covers a bare ``Self`` and ``Self`` as a member of ``Union[...]``. Each
        such name gets an ``idSelf`` attribute holding this class's name.
        Child nodes of the class are not visited further by this transformer.

        :param node: the class definition being visited.
        :return: the same ``node``, modified in place.
        :raises AssertionError: if ``from typing import Dict, List, Union`` has
            not been visited before this class definition.
        """
        assert (
            self.imports_typing
        ), "typing must be imported in order to use datum classes"
        if self.imports_Self:
            for attribute in node.body:
                if not isinstance(attribute, FunctionDef):
                    continue
                for arg in attribute.args.args:
                    self._tag_self(arg.annotation, node.name)
                # Return annotations are handled exactly like argument
                # annotations, so ``Union[..., Self]`` is tagged there too.
                self._tag_self(attribute.returns, node.name)
        return node

    @staticmethod
    def _tag_self(annotation, class_name: str) -> None:
        """Store ``class_name`` as ``idSelf`` on each ``Self`` in ``annotation``.

        Handles a bare ``Self`` name and ``Self`` as a member of a ``Union[...]``
        subscript. Any other annotation, including ``None``, is ignored. The
        ``id`` of a tagged name stays ``"Self"``.
        """
        # NOTE(review): ``idSelf`` is not a standard ast field; presumably a
        # later compilation pass reads it to resolve ``Self`` -- not visible here.
        if isinstance(annotation, Name):
            if annotation.id == "Self":
                annotation.idSelf = class_name
            return
        if not (
            isinstance(annotation, Subscript)
            # ``value`` may be e.g. an Attribute, which has no ``id``.
            and isinstance(annotation.value, Name)
            and annotation.value.id == "Union"
        ):
            return
        members = annotation.slice
        # ``Union[A, B]`` stores a Tuple slice; ``Union[A]`` stores the bare element.
        elements = members.elts if isinstance(members, Tuple) else [members]
        for element in elements:
            if isinstance(element, Name) and element.id == "Self":
                element.idSelf = class_name
Classes
class RewriteImportTyping (*args, **kwargs)
-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is `None`, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (`foo`) to `data['foo']`::

    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

    node = YourTransformer().visit(node)
Expand source code
class RewriteImportTyping(CompilingNodeTransformer):
    """Check the program's ``typing`` imports and tag ``Self`` annotations.

    Accepted ``from typing import ...`` statements:

    * ``from typing import Self`` (``Self`` as the only imported name) --
      enables tagging of ``Self`` annotations in class methods.
    * ``from typing import Dict, List, Union`` (exactly these names, in this
      order, without aliases) -- required before any class definition.

    Both accepted statements are removed from the tree; any other
    ``from typing import ...`` fails an assertion. Imports from other
    modules are returned unchanged.
    """

    step = "Checking import and usage of typing"
    # Class-level default; the instance attribute is set to True once the
    # required ``from typing import Dict, List, Union`` has been visited.
    imports_typing = False

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # True once ``from typing import Self`` has been visited.
        self.imports_Self = False

    def visit_ImportFrom(self, node: ImportFrom) -> Optional[ImportFrom]:
        """Validate and remove ``from typing import ...`` statements.

        :param node: the import statement being visited.
        :return: ``node`` unchanged for imports from any module other than
            ``typing``; ``None`` (removing the statement) for an accepted
            ``typing`` import.
        :raises AssertionError: if a ``typing`` import is neither
            ``from typing import Self`` nor exactly
            ``from typing import Dict, List, Union``.
        """
        if node.module != "typing":
            return node
        if len(node.names) == 1 and node.names[0].name == "Self":
            self.imports_Self = True
            return None
        assert (
            len(node.names) == 3
        ), "The program must contain one 'from typing import Dict, List, Union'"
        for alias, expected in zip(node.names, ("Dict", "List", "Union")):
            assert (
                alias.name == expected
            ), "The program must contain one 'from typing import Dict, List, Union'"
            assert (
                alias.asname is None
            ), "The program must contain one 'from typing import Dict, List, Union'"
        self.imports_typing = True
        return None

    def visit_ClassDef(self, node: ClassDef) -> ClassDef:
        """Require the typing import and tag ``Self`` annotations in methods.

        If ``from typing import Self`` has been seen, tag every ``Self`` in a
        method's positional-argument annotations or return annotation. This
        covers a bare ``Self`` and ``Self`` as a member of ``Union[...]``. Each
        such name gets an ``idSelf`` attribute holding this class's name.
        Child nodes of the class are not visited further by this transformer.

        :param node: the class definition being visited.
        :return: the same ``node``, modified in place.
        :raises AssertionError: if ``from typing import Dict, List, Union`` has
            not been visited before this class definition.
        """
        assert (
            self.imports_typing
        ), "typing must be imported in order to use datum classes"
        if self.imports_Self:
            for attribute in node.body:
                if not isinstance(attribute, FunctionDef):
                    continue
                for arg in attribute.args.args:
                    self._tag_self(arg.annotation, node.name)
                # Return annotations are handled exactly like argument
                # annotations, so ``Union[..., Self]`` is tagged there too.
                self._tag_self(attribute.returns, node.name)
        return node

    @staticmethod
    def _tag_self(annotation, class_name: str) -> None:
        """Store ``class_name`` as ``idSelf`` on each ``Self`` in ``annotation``.

        Handles a bare ``Self`` name and ``Self`` as a member of a ``Union[...]``
        subscript. Any other annotation, including ``None``, is ignored. The
        ``id`` of a tagged name stays ``"Self"``.
        """
        # NOTE(review): ``idSelf`` is not a standard ast field; presumably a
        # later compilation pass reads it to resolve ``Self`` -- not visible here.
        if isinstance(annotation, Name):
            if annotation.id == "Self":
                annotation.idSelf = class_name
            return
        if not (
            isinstance(annotation, Subscript)
            # ``value`` may be e.g. an Attribute, which has no ``id``.
            and isinstance(annotation.value, Name)
            and annotation.value.id == "Union"
        ):
            return
        members = annotation.slice
        # ``Union[A, B]`` stores a Tuple slice; ``Union[A]`` stores the bare element.
        elements = members.elts if isinstance(members, Tuple) else [members]
        for element in elements:
            if isinstance(element, Name) and element.id == "Self":
                element.idSelf = class_name
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var imports_typing
var step
Methods
def visit(self, node)
-
Inherited from: `CompilingNodeTransformer.visit`
Visit a node.
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.ImportFrom | None