Module opshin.rewrite.rewrite_import

Expand source code
import ast

import importlib
import importlib.util
import pathlib
import typing
import sys
from ast import *
from ordered_set import OrderedSet

from ..util import CompilingNodeTransformer

"""
Checks that there was an import of dataclass if there are any class definitions
"""


def import_module(name, package=None):
    """An approximate implementation of import.

    Resolve ``name`` relative to ``package`` and return the module, importing
    it through the finders in ``sys.meta_path`` if it is not yet present in
    ``sys.modules``.  For a dotted name the parent package is imported first
    (recursively) and the new submodule is bound as an attribute of it.

    :param name: module name; may be relative (leading dots) if ``package`` is given
    :param package: anchor package used to resolve a relative ``name``
    :return: the imported module object
    :raises ModuleNotFoundError: if no finder locates the module, or if the
        parent of a dotted name is not a package
    :raises Exception: whatever executing the module body raises; the
        half-initialised module is removed from ``sys.modules`` first
    """
    absolute_name = importlib.util.resolve_name(name, package)
    try:
        return sys.modules[absolute_name]
    except KeyError:
        pass

    path = None
    if "." in absolute_name:
        parent_name, _, child_name = absolute_name.rpartition(".")
        parent_module = import_module(parent_name)
        # Importing the parent may already have imported the child (e.g. the
        # package __init__ does `from . import child`); reuse it instead of
        # executing the child a second time and creating a duplicate module.
        if absolute_name in sys.modules:
            return sys.modules[absolute_name]
        parent_spec = getattr(parent_module, "__spec__", None)
        if parent_spec is not None:
            path = parent_spec.submodule_search_locations
        if path is None:
            # Only packages contain submodules. Without this check the finders
            # would search sys.path and could load an unrelated top-level
            # module under the dotted name.
            msg = f"No module named {absolute_name!r}; {parent_name!r} is not a package"
            raise ModuleNotFoundError(msg, name=absolute_name)
    for finder in sys.meta_path:
        spec = finder.find_spec(absolute_name, path)
        if spec is not None:
            break
    else:
        msg = f"No module named {absolute_name!r}"
        raise ModuleNotFoundError(msg, name=absolute_name)
    module = importlib.util.module_from_spec(spec)
    sys.modules[absolute_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Like the real import system: never leave a partially executed module
        # behind, otherwise later imports would silently return it.
        sys.modules.pop(absolute_name, None)
        raise
    if path is not None:
        setattr(parent_module, child_name, module)
    return module


class RewriteLocation(CompilingNodeTransformer):
    """Overwrite the source location of every node that passes through ``visit``.

    Each visited node receives the location attributes of ``orig_node`` (via
    :func:`ast.copy_location`) before being handed on to the base class, so
    that e.g. inlined code points at the line of ``orig_node``.

    :param orig_node: node whose location is copied onto the visited nodes
    """

    def __init__(self, orig_node):
        self.orig_node = orig_node

    def visit(self, node):
        """Stamp ``node`` with ``orig_node``'s location, then defer to the base visitor."""
        return super().visit(ast.copy_location(node, self.orig_node))


class RewriteImport(CompilingNodeTransformer):
    """Inline ``from <pkg> import *`` statements into the AST being transformed.

    Each such statement is replaced by the statements of the imported ``.py``
    file, after that file's own imports have been resolved recursively. Imports
    of ``pycardano``, ``typing``, ``dataclasses``, ``hashlib``, ``opshin.bridge``
    and ``opshin.std.integrity`` are left untouched. Every file is inlined at
    most once; later imports of an already-inlined file are dropped.

    :param filename: path of the file being transformed; its directory is
        temporarily appended to ``sys.path`` while an import is resolved
    :param package: anchor package passed to :func:`import_module`
    :param resolved_imports: set of already-inlined module paths; it is shared
        (not copied) with recursive resolvers so each file is inlined only once
    """

    step = "Resolving imports"

    def __init__(self, filename=None, package=None, resolved_imports=None):
        self.filename = filename
        self.package = package
        # Compare with None explicitly: an empty OrderedSet is falsy, so `or`
        # would silently swap a caller-supplied empty set for a private one.
        self.resolved_imports = (
            OrderedSet() if resolved_imports is None else resolved_imports
        )

    def visit_ImportFrom(
        self, node: ImportFrom
    ) -> typing.Union[ImportFrom, typing.List[AST], None]:
        """Replace ``from <pkg> import *`` with the statements of ``<pkg>``'s source file.

        The target module is imported (and therefore executed) so that its file
        can be located. Its source is then parsed, and every node is relocated to
        ``node``'s position. Its own imports are resolved recursively with the
        shared ``resolved_imports`` set.

        :param node: the ``ImportFrom`` statement being visited
        :return: ``node`` itself for whitelisted modules, ``None`` if the file
            was already inlined, otherwise the list of inlined statements
        :raises AssertionError: if the import is not ``from <pkg> import *`` or
            does not resolve to a single ``.py`` file
        """
        if node.module in {
            "pycardano",
            "typing",
            "dataclasses",
            "hashlib",
            "opshin.bridge",
            "opshin.std.integrity",
        }:
            return node
        assert (
            len(node.names) == 1
            and node.names[0].name == "*"
            and node.names[0].asname is None
        ), "The import must have the form 'from <pkg> import *'"
        # TODO set anchor point according to own package
        # NOTE(review): node.level is not consulted, so `from .x import *` is
        # looked up as the absolute module `x` (typically found through the
        # directory appended to sys.path below).
        search_dir = None
        if self.filename:
            search_dir = str(pathlib.Path(self.filename).parent.absolute())
            sys.path.append(search_dir)
        try:
            module = import_module(node.module, self.package)
        finally:
            # Restore sys.path even if the import fails.
            if search_dir is not None:
                sys.path.pop()
        # Builtin modules have no __file__ and namespace packages have None.
        module_path = getattr(module, "__file__", None)
        assert module_path is not None, "The import must import a single python file."
        module_file = pathlib.Path(module_path)
        if module_file in self.resolved_imports:
            # Import was already resolved and its names are visible
            return None
        assert (
            module_file.suffix == ".py"
        ), "The import must import a single python file."
        self.resolved_imports.add(module_file)
        # Parse raw bytes so the source is decoded like Python itself does
        # (PEP 263 coding cookie, UTF-8 default), not with the locale encoding.
        resolved = parse(module_file.read_bytes(), filename=module_file.name)
        # annotate this to point to the original line number!
        RewriteLocation(node).visit(resolved)
        # Recursively resolve the imported file's imports. The resolver shares
        # self.resolved_imports, so the files it inlines are recorded here too.
        recursive_resolver = RewriteImport(
            filename=str(module_file),
            package=module.__package__,
            resolved_imports=self.resolved_imports,
        )
        recursively_resolved: Module = recursive_resolver.visit(resolved)
        return recursively_resolved.body

Functions

def import_module(name, package=None)

An approximate implementation of import.

Expand source code
def import_module(name, package=None):
    """An approximate implementation of import.

    Resolve ``name`` relative to ``package`` and return the module, importing
    it through the finders in ``sys.meta_path`` if it is not yet present in
    ``sys.modules``.  For a dotted name the parent package is imported first
    (recursively) and the new submodule is bound as an attribute of it.

    :param name: module name; may be relative (leading dots) if ``package`` is given
    :param package: anchor package used to resolve a relative ``name``
    :return: the imported module object
    :raises ModuleNotFoundError: if no finder locates the module, or if the
        parent of a dotted name is not a package
    :raises Exception: whatever executing the module body raises; the
        half-initialised module is removed from ``sys.modules`` first
    """
    absolute_name = importlib.util.resolve_name(name, package)
    try:
        return sys.modules[absolute_name]
    except KeyError:
        pass

    path = None
    if "." in absolute_name:
        parent_name, _, child_name = absolute_name.rpartition(".")
        parent_module = import_module(parent_name)
        # Importing the parent may already have imported the child (e.g. the
        # package __init__ does `from . import child`); reuse it instead of
        # executing the child a second time and creating a duplicate module.
        if absolute_name in sys.modules:
            return sys.modules[absolute_name]
        parent_spec = getattr(parent_module, "__spec__", None)
        if parent_spec is not None:
            path = parent_spec.submodule_search_locations
        if path is None:
            # Only packages contain submodules. Without this check the finders
            # would search sys.path and could load an unrelated top-level
            # module under the dotted name.
            msg = f"No module named {absolute_name!r}; {parent_name!r} is not a package"
            raise ModuleNotFoundError(msg, name=absolute_name)
    for finder in sys.meta_path:
        spec = finder.find_spec(absolute_name, path)
        if spec is not None:
            break
    else:
        msg = f"No module named {absolute_name!r}"
        raise ModuleNotFoundError(msg, name=absolute_name)
    module = importlib.util.module_from_spec(spec)
    sys.modules[absolute_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Like the real import system: never leave a partially executed module
        # behind, otherwise later imports would silently return it.
        sys.modules.pop(absolute_name, None)
        raise
    if path is not None:
        setattr(parent_module, child_name, module)
    return module

Classes

class RewriteImport (filename=None, package=None, resolved_imports=None)

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteImport(CompilingNodeTransformer):
    """Inline ``from <pkg> import *`` statements into the AST being transformed.

    Each such statement is replaced by the statements of the imported ``.py``
    file, after that file's own imports have been resolved recursively. Imports
    of ``pycardano``, ``typing``, ``dataclasses``, ``hashlib``, ``opshin.bridge``
    and ``opshin.std.integrity`` are left untouched. Every file is inlined at
    most once; later imports of an already-inlined file are dropped.

    :param filename: path of the file being transformed; its directory is
        temporarily appended to ``sys.path`` while an import is resolved
    :param package: anchor package passed to :func:`import_module`
    :param resolved_imports: set of already-inlined module paths; it is shared
        (not copied) with recursive resolvers so each file is inlined only once
    """

    step = "Resolving imports"

    def __init__(self, filename=None, package=None, resolved_imports=None):
        self.filename = filename
        self.package = package
        # Compare with None explicitly: an empty OrderedSet is falsy, so `or`
        # would silently swap a caller-supplied empty set for a private one.
        self.resolved_imports = (
            OrderedSet() if resolved_imports is None else resolved_imports
        )

    def visit_ImportFrom(
        self, node: ImportFrom
    ) -> typing.Union[ImportFrom, typing.List[AST], None]:
        """Replace ``from <pkg> import *`` with the statements of ``<pkg>``'s source file.

        The target module is imported (and therefore executed) so that its file
        can be located. Its source is then parsed, and every node is relocated to
        ``node``'s position. Its own imports are resolved recursively with the
        shared ``resolved_imports`` set.

        :param node: the ``ImportFrom`` statement being visited
        :return: ``node`` itself for whitelisted modules, ``None`` if the file
            was already inlined, otherwise the list of inlined statements
        :raises AssertionError: if the import is not ``from <pkg> import *`` or
            does not resolve to a single ``.py`` file
        """
        if node.module in {
            "pycardano",
            "typing",
            "dataclasses",
            "hashlib",
            "opshin.bridge",
            "opshin.std.integrity",
        }:
            return node
        assert (
            len(node.names) == 1
            and node.names[0].name == "*"
            and node.names[0].asname is None
        ), "The import must have the form 'from <pkg> import *'"
        # TODO set anchor point according to own package
        # NOTE(review): node.level is not consulted, so `from .x import *` is
        # looked up as the absolute module `x` (typically found through the
        # directory appended to sys.path below).
        search_dir = None
        if self.filename:
            search_dir = str(pathlib.Path(self.filename).parent.absolute())
            sys.path.append(search_dir)
        try:
            module = import_module(node.module, self.package)
        finally:
            # Restore sys.path even if the import fails.
            if search_dir is not None:
                sys.path.pop()
        # Builtin modules have no __file__ and namespace packages have None.
        module_path = getattr(module, "__file__", None)
        assert module_path is not None, "The import must import a single python file."
        module_file = pathlib.Path(module_path)
        if module_file in self.resolved_imports:
            # Import was already resolved and its names are visible
            return None
        assert (
            module_file.suffix == ".py"
        ), "The import must import a single python file."
        self.resolved_imports.add(module_file)
        # Parse raw bytes so the source is decoded like Python itself does
        # (PEP 263 coding cookie, UTF-8 default), not with the locale encoding.
        resolved = parse(module_file.read_bytes(), filename=module_file.name)
        # annotate this to point to the original line number!
        RewriteLocation(node).visit(resolved)
        # Recursively resolve the imported file's imports. The resolver shares
        # self.resolved_imports, so the files it inlines are recorded here too.
        recursive_resolver = RewriteImport(
            filename=str(module_file),
            package=module.__package__,
            resolved_imports=self.resolved_imports,
        )
        recursively_resolved: Module = recursive_resolver.visit(resolved)
        return recursively_resolved.body

Ancestors

Class variables

var step

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

def visit_ImportFrom(self, node: ast.ImportFrom) -> Union[ast.ImportFrom, List[ast.AST], None]
Expand source code
def visit_ImportFrom(
    self, node: ImportFrom
) -> typing.Union[ImportFrom, typing.List[AST], None]:
    """Replace ``from <pkg> import *`` with the statements of ``<pkg>``'s source file.

    Imports of the whitelisted modules are returned unchanged. The target module
    is imported (and therefore executed) so that its file can be located. Its
    source is then parsed, and every node is relocated to ``node``'s position.
    Its own imports are resolved recursively with the shared
    ``resolved_imports`` set.

    :param node: the ``ImportFrom`` statement being visited
    :return: ``node`` itself for whitelisted modules, ``None`` if the file was
        already inlined, otherwise the list of inlined statements
    :raises AssertionError: if the import is not ``from <pkg> import *`` or does
        not resolve to a single ``.py`` file
    """
    if node.module in {
        "pycardano",
        "typing",
        "dataclasses",
        "hashlib",
        "opshin.bridge",
        "opshin.std.integrity",
    }:
        return node
    assert (
        len(node.names) == 1
        and node.names[0].name == "*"
        and node.names[0].asname is None
    ), "The import must have the form 'from <pkg> import *'"
    # TODO set anchor point according to own package
    # NOTE(review): node.level is not consulted, so `from .x import *` is looked
    # up as the absolute module `x` (typically found through the directory
    # appended to sys.path below).
    search_dir = None
    if self.filename:
        search_dir = str(pathlib.Path(self.filename).parent.absolute())
        sys.path.append(search_dir)
    try:
        module = import_module(node.module, self.package)
    finally:
        # Restore sys.path even if the import fails.
        if search_dir is not None:
            sys.path.pop()
    # Builtin modules have no __file__ and namespace packages have None.
    module_path = getattr(module, "__file__", None)
    assert module_path is not None, "The import must import a single python file."
    module_file = pathlib.Path(module_path)
    if module_file in self.resolved_imports:
        # Import was already resolved and its names are visible
        return None
    assert (
        module_file.suffix == ".py"
    ), "The import must import a single python file."
    self.resolved_imports.add(module_file)
    # Parse raw bytes so the source is decoded like Python itself does
    # (PEP 263 coding cookie, UTF-8 default), not with the locale encoding.
    resolved = parse(module_file.read_bytes(), filename=module_file.name)
    # annotate this to point to the original line number!
    RewriteLocation(node).visit(resolved)
    # Recursively resolve the imported file's imports. The resolver shares
    # self.resolved_imports, so the files it inlines are recorded here too.
    recursive_resolver = RewriteImport(
        filename=str(module_file),
        package=module.__package__,
        resolved_imports=self.resolved_imports,
    )
    recursively_resolved: Module = recursive_resolver.visit(resolved)
    return recursively_resolved.body
class RewriteLocation (orig_node)

A :class:`NodeVisitor` subclass that walks the abstract syntax tree and allows modification of nodes.

The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node. If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value. The return value may be the original node in which case no replacement takes place.

Here is an example transformer that rewrites all occurrences of name lookups (foo) to data['foo']::

class RewriteName(NodeTransformer):

   def visit_Name(self, node):
       return Subscript(
           value=Name(id='data', ctx=Load()),
           slice=Constant(value=node.id),
           ctx=node.ctx
       )

Keep in mind that if the node you're operating on has child nodes you must either transform the child nodes yourself or call the :meth:`generic_visit` method for the node first.

For nodes that were part of a collection of statements (that applies to all statement nodes), the visitor may also return a list of nodes rather than just a single node.

Usually you use the transformer like this::

node = YourTransformer().visit(node)

Expand source code
class RewriteLocation(CompilingNodeTransformer):
    """Overwrite the source location of every node that passes through ``visit``.

    Each visited node receives the location attributes of ``orig_node`` (via
    :func:`ast.copy_location`) before being handed on to the base class, so
    that e.g. inlined code points at the line of ``orig_node``.

    :param orig_node: node whose location is copied onto the visited nodes
    """

    def __init__(self, orig_node):
        self.orig_node = orig_node

    def visit(self, node):
        """Stamp ``node`` with ``orig_node``'s location, then defer to the base visitor."""
        return super().visit(ast.copy_location(node, self.orig_node))

Ancestors

Methods

def visit(self, node)

Inherited from: CompilingNodeTransformer.visit

Visit a node.

Expand source code
def visit(self, node):
    """Stamp ``node`` with ``self.orig_node``'s location, then defer to the base visitor."""
    return super().visit(ast.copy_location(node, self.orig_node))