Module opshin.rewrite.rewrite_import
Expand source code
import ast
import importlib
import importlib.util
import pathlib
import typing
import sys
from ast import *
from ordered_set import OrderedSet
from ..util import CompilingNodeTransformer
"""
Resolves imports of the form 'from <pkg> import *' by replacing them with the (recursively resolved) body of the imported module
"""
def import_module(name, package=None):
    """An approximate implementation of import.

    Resolves ``name`` (relative to ``package`` when ``name`` is a relative
    module name) and returns the matching entry of ``sys.modules`` if there is
    one.  Otherwise the module is located through the finders registered in
    ``sys.meta_path``, executed, registered in ``sys.modules`` and bound as an
    attribute on its parent package.

    :param name: absolute or relative module name
    :param package: anchor package used to resolve a relative ``name``
    :return: the imported module object
    :raises ModuleNotFoundError: if none of the finders can locate the module
    """
    absolute_name = importlib.util.resolve_name(name, package)
    try:
        return sys.modules[absolute_name]
    except KeyError:
        pass

    path = None
    if "." in absolute_name:
        # Import the parent first: its search locations tell the finders
        # where to look for the child module.
        parent_name, _, child_name = absolute_name.rpartition(".")
        parent_module = import_module(parent_name)
        path = parent_module.__spec__.submodule_search_locations

    for finder in sys.meta_path:
        spec = finder.find_spec(absolute_name, path)
        if spec is not None:
            break
    else:
        msg = f"No module named {absolute_name!r}"
        raise ModuleNotFoundError(msg, name=absolute_name)

    module = importlib.util.module_from_spec(spec)
    # Register before executing so that circular imports can find the module.
    sys.modules[absolute_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind: a later import of the
        # same name would otherwise silently return the broken module.
        if sys.modules.get(absolute_name) is module:
            del sys.modules[absolute_name]
        raise
    if path is not None:
        setattr(parent_module, child_name, module)
    return module
class RewriteLocation(CompilingNodeTransformer):
    """
    Copies the source location of a fixed reference node onto each node
    that is passed to ``visit``.
    """

    def __init__(self, orig_node):
        # node whose location attributes are copied onto the visited nodes
        self.orig_node = orig_node

    def visit(self, node):
        # copy_location mutates the node in place and hands it back
        relocated = ast.copy_location(node, self.orig_node)
        return super().visit(relocated)
class RewriteImport(CompilingNodeTransformer):
    """
    Replaces every ``from <pkg> import *`` statement by the body of the
    imported module, resolving the imports inside that module recursively.

    Imports from a fixed list of known modules are left untouched. A module
    file that was already inlined is not inlined again; the repeated import
    statement is removed instead.
    """

    step = "Resolving imports"

    def __init__(self, filename=None, package=None, resolved_imports=None):
        """
        :param filename: path of the file being rewritten; its directory is
            temporarily added to ``sys.path`` while resolving imports
        :param package: package name used to anchor relative module names
        :param resolved_imports: set of module files that were already inlined;
            it is shared with (and extended by) the recursive resolvers
        """
        self.filename = filename
        self.package = package
        # Compare against None so that an empty set passed by the caller is
        # shared instead of being silently replaced by a fresh one.
        self.resolved_imports = (
            OrderedSet() if resolved_imports is None else resolved_imports
        )

    def visit_ImportFrom(
        self, node: ImportFrom
    ) -> typing.Union[ImportFrom, typing.List[AST], None]:
        """
        Resolve a single ``from ... import ...`` statement.

        :return: the unchanged node for the known pass-through modules,
            ``None`` (statement removed) if the module file was already
            inlined, otherwise the list of statements of the imported module
        """
        if node.module in [
            "pycardano",
            "typing",
            "dataclasses",
            "hashlib",
            "opshin.bridge",
            "opshin.std.integrity",
        ]:
            return node
        assert (
            len(node.names) == 1
        ), "The import must have the form 'from <pkg> import *'"
        assert (
            node.names[0].name == "*"
        ), "The import must have the form 'from <pkg> import *'"
        assert (
            node.names[0].asname is None
        ), "The import must have the form 'from <pkg> import *'"
        # TODO set anchor point according to own package
        search_dir = None
        if self.filename:
            search_dir = str(pathlib.Path(self.filename).parent.absolute())
            sys.path.append(search_dir)
        try:
            module = import_module(node.module, self.package)
        finally:
            if search_dir is not None:
                # Remove exactly the entry added above (searching from the
                # end), also when the import failed or the imported module
                # modified sys.path itself.
                for index in range(len(sys.path) - 1, -1, -1):
                    if sys.path[index] == search_dir:
                        del sys.path[index]
                        break
        assert (
            getattr(module, "__file__", None) is not None
        ), "The import must import a single python file."
        module_file = pathlib.Path(module.__file__)
        if module_file.suffix == ".pyc":
            module_file = module_file.with_suffix(".py")
        # Validate before recording the file, so that a rejected import does
        # not show up as resolved.
        assert (
            module_file.suffix == ".py"
        ), "The import must import a single python file."
        if module_file in self.resolved_imports:
            # Import was already resolved and its names are visible
            return None
        self.resolved_imports.add(module_file)
        # visit the imported file again - make sure that recursive imports are resolved accordingly
        # Read raw bytes so that the parser applies the source encoding rules
        # (BOM / coding cookie, UTF-8 default) instead of the locale encoding.
        module_content = module_file.read_bytes()
        resolved = parse(module_content, filename=module_file.name)
        # annotate this to point to the original line number!
        RewriteLocation(node).visit(resolved)
        # recursively import all statements there
        recursive_resolver = RewriteImport(
            filename=str(module_file),
            package=module.__package__,
            resolved_imports=self.resolved_imports,
        )
        recursively_resolved: Module = recursive_resolver.visit(resolved)
        # The set is shared with the recursive resolver, so this is normally a
        # no-op; it is kept as a safeguard.
        self.resolved_imports.update(recursive_resolver.resolved_imports)
        return recursively_resolved.body
Functions
def import_module(name, package=None)
-
An approximate implementation of import.
Classes
class RewriteImport (filename=None, package=None, resolved_imports=None)
-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is ``None``, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this::
node = YourTransformer().visit(node)
Expand source code
class RewriteImport(CompilingNodeTransformer):
    """
    Replaces every ``from <pkg> import *`` statement by the body of the
    imported module, resolving the imports inside that module recursively.

    Imports from a fixed list of known modules are left untouched. A module
    file that was already inlined is not inlined again; the repeated import
    statement is removed instead.
    """

    step = "Resolving imports"

    def __init__(self, filename=None, package=None, resolved_imports=None):
        """
        :param filename: path of the file being rewritten; its directory is
            temporarily added to ``sys.path`` while resolving imports
        :param package: package name used to anchor relative module names
        :param resolved_imports: set of module files that were already inlined;
            it is shared with (and extended by) the recursive resolvers
        """
        self.filename = filename
        self.package = package
        # Compare against None so that an empty set passed by the caller is
        # shared instead of being silently replaced by a fresh one.
        self.resolved_imports = (
            OrderedSet() if resolved_imports is None else resolved_imports
        )

    def visit_ImportFrom(
        self, node: ImportFrom
    ) -> typing.Union[ImportFrom, typing.List[AST], None]:
        """
        Resolve a single ``from ... import ...`` statement.

        :return: the unchanged node for the known pass-through modules,
            ``None`` (statement removed) if the module file was already
            inlined, otherwise the list of statements of the imported module
        """
        if node.module in [
            "pycardano",
            "typing",
            "dataclasses",
            "hashlib",
            "opshin.bridge",
            "opshin.std.integrity",
        ]:
            return node
        assert (
            len(node.names) == 1
        ), "The import must have the form 'from <pkg> import *'"
        assert (
            node.names[0].name == "*"
        ), "The import must have the form 'from <pkg> import *'"
        assert (
            node.names[0].asname is None
        ), "The import must have the form 'from <pkg> import *'"
        # TODO set anchor point according to own package
        search_dir = None
        if self.filename:
            search_dir = str(pathlib.Path(self.filename).parent.absolute())
            sys.path.append(search_dir)
        try:
            module = import_module(node.module, self.package)
        finally:
            if search_dir is not None:
                # Remove exactly the entry added above (searching from the
                # end), also when the import failed or the imported module
                # modified sys.path itself.
                for index in range(len(sys.path) - 1, -1, -1):
                    if sys.path[index] == search_dir:
                        del sys.path[index]
                        break
        assert (
            getattr(module, "__file__", None) is not None
        ), "The import must import a single python file."
        module_file = pathlib.Path(module.__file__)
        if module_file.suffix == ".pyc":
            module_file = module_file.with_suffix(".py")
        # Validate before recording the file, so that a rejected import does
        # not show up as resolved.
        assert (
            module_file.suffix == ".py"
        ), "The import must import a single python file."
        if module_file in self.resolved_imports:
            # Import was already resolved and its names are visible
            return None
        self.resolved_imports.add(module_file)
        # visit the imported file again - make sure that recursive imports are resolved accordingly
        # Read raw bytes so that the parser applies the source encoding rules
        # (BOM / coding cookie, UTF-8 default) instead of the locale encoding.
        module_content = module_file.read_bytes()
        resolved = parse(module_content, filename=module_file.name)
        # annotate this to point to the original line number!
        RewriteLocation(node).visit(resolved)
        # recursively import all statements there
        recursive_resolver = RewriteImport(
            filename=str(module_file),
            package=module.__package__,
            resolved_imports=self.resolved_imports,
        )
        recursively_resolved: Module = recursive_resolver.visit(resolved)
        # The set is shared with the recursive resolver, so this is normally a
        # no-op; it is kept as a safeguard.
        self.resolved_imports.update(recursive_resolver.resolved_imports)
        return recursively_resolved.body
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Class variables
var step
Methods
def visit(self, node)
-
Inherited from: CompilingNodeTransformer.visit
Visit a node.
def visit_ImportFrom(self, node: ast.ImportFrom) ‑> ast.ImportFrom | List[ast.AST] | None
class RewriteLocation (orig_node)
-
A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The `NodeTransformer` will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is ``None``, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (``foo``) to ``data['foo']``::

    class RewriteName(NodeTransformer):

        def visit_Name(self, node):
            return Subscript(
                value=Name(id='data', ctx=Load()),
                slice=Constant(value=node.id),
                ctx=node.ctx
            )
Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.
Usually you use the transformer like this::
node = YourTransformer().visit(node)
Expand source code
class RewriteLocation(CompilingNodeTransformer):
    """
    Copies the source location of a fixed reference node onto each node
    that is passed to ``visit``.
    """

    def __init__(self, orig_node):
        # node whose location attributes are copied onto the visited nodes
        self.orig_node = orig_node

    def visit(self, node):
        # copy_location mutates the node in place and hands it back
        relocated = ast.copy_location(node, self.orig_node)
        return super().visit(relocated)
Ancestors
- CompilingNodeTransformer
- TypedNodeTransformer
- ast.NodeTransformer
- ast.NodeVisitor
Methods
def visit(self, node)
-
Inherited from: CompilingNodeTransformer.visit
Visit a node.