Source code for magicclass.macro

from __future__ import annotations
from contextlib import contextmanager
from functools import wraps
import inspect
from copy import deepcopy
from enum import Enum
from collections import UserList
from pathlib import Path
from typing import Callable, Iterable, Iterator, Any, overload, TypeVar
from types import FunctionType, MethodType, BuiltinFunctionType
import numpy as np

T = TypeVar("T")

[docs]class Symbol: # Map of how to convert object into a symbol. _type_map: dict[type, Callable[[Any], str]] = { type: lambda e: e.__name__, FunctionType: lambda e: e.__name__, MethodType: lambda e: e.__name__, BuiltinFunctionType: lambda e: e.__name__, Enum: lambda e: repr(str(e.name)), Path: lambda e: f"r'{e}'", } def __init__(self, seq: str, object_id: int = None, type: type = Any): self.data = str(seq) self.object_id = object_id or id(seq) self.type = type self.valid = True def __repr__(self) -> str: return self.data def __str__(self) -> str: return self.data def __hash__(self) -> int: return self.object_id def __eq__(self, other: Symbol) -> bool: if not isinstance(other, Symbol): raise TypeError(f"'==' is not supported between Symbol and {type(other)}") return self.object_id == other.object_id
[docs] def as_parameter(self, default=inspect._empty): return inspect.Parameter(self.data, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=default, annotation=self.type)
[docs] @classmethod def register_type(cls, type: type[T], function: Callable[[T], str]): if not callable(function): raise TypeError("The second argument must be callable.") cls._type_map[type] = function
[docs]def symbol(obj: Any) -> Symbol: if isinstance(obj, Symbol): return obj valid = True objtype = type(obj) if isinstance(obj, str): seq = repr(obj) elif np.isscalar(obj): # int, float, bool, ... seq = obj elif isinstance(obj, tuple): seq = "(" + ", ".join(symbol(a).data for a in obj) + ")" if objtype is not tuple: seq = objtype.__name__ + seq elif isinstance(obj, list): seq = "[" + ", ".join(symbol(a).data for a in obj) + "]" if objtype is not list: seq = f"{objtype.__name__}({seq})" elif isinstance(obj, dict): seq = "{" + ", ".join(f"{symbol(k)}: {symbol(v)}" for k, v in obj.items()) + "}" if objtype is not dict: seq = f"{objtype.__name__}({seq})" elif isinstance(obj, set): seq = "{" + ", ".join(symbol(a).data for a in obj) + "}" if objtype is not set: seq = f"{objtype.__name__}({seq})" elif isinstance(obj, slice): seq = f"{objtype.__name__}({obj.start}, {obj.stop}, {obj.step})" elif objtype in Symbol._type_map: seq = Symbol._type_map[objtype](obj) else: for k, func in Symbol._type_map.items(): if isinstance(obj, k): seq = func(obj) break else: seq = f"var{hex(id(obj))}" # hexadecimals are easier to distinguish valid = False sym = Symbol(seq, id(obj), type(obj)) sym.valid = valid return sym
[docs]def register_type(type: type[T], function: Callable[[T], str]): return Symbol.register_type(type, function)
[docs]class Expr: """ Python expression class. Inspired by Julia (https://docs.julialang.org/en/v1/manual/metaprogramming/), this class enables efficient macro recording and macro operation. Expr objects are mainly composed of "head" and "args". "Head" denotes what kind of operation it is, and "args" denotes the arguments needed for the operation. Other attributes, such as "symbol", is not necessary as a Expr object but it is useful to create executable codes. """ n: int = 0 # a map of how to conver expression into string. _map: dict[Head, Callable[[Expr], str]] = { Head.init : lambda e: f"{e.args[0]} = {e.args[1]}({', '.join(map(str, e.args[2:]))})", Head.getattr: lambda e: f"{e.args[0]}.{e.args[1]}", Head.setattr: lambda e: f"{e.args[0]}.{e.args[1]} = {e.args[2]}", Head.getitem: lambda e: f"{e.args[0]}[{e.args[1]}]", Head.setitem: lambda e: f"{e.args[0]}[{e.args[1]}] = {e.args[2]}", Head.call : lambda e: f"{e.args[0]}({', '.join(map(str, e.args[1:]))})", Head.assign : lambda e: f"{e.args[0]}={e.args[1]}", Head.value : lambda e: str(e.args[0]), Head.comment: lambda e: f"# {e.args[0]}", } def __init__(self, head: Head, args: Iterable[Any]): self.head = Head(head) if head == Head.value: self.args = [args[0]] else: self.args = list(map(self.__class__.parse, args)) self.number = self.__class__.n self.__class__.n += 1 def __repr__(self) -> str: return self._repr() def _repr(self, ind: int = 0) -> str: """ Recursively expand expressions until it reaches value/assign expression. """ if self.head in (Head.value, Head.assign): return str(self) out = [f"head: {self.head.name}\n{' '*ind}args:\n"] for i, arg in enumerate(self.args): out.append(f"{i:>{ind+2}}: {arg._repr(ind+4)}\n") return "".join(out) def __str__(self) -> str: return self.__class__._map[self.head](self) def __eq__(self, expr: Expr) -> bool: if self.head != Head.value: raise TypeError(f"Expression must be value, got {self.head}") if isinstance(expr, str): return self.args[0] == expr elif isinstance(expr, self.__class__): return self.args[0] == expr.args[0] else: raise ValueError(f"'==' is not supported between Expr and {type(expr)}")
[docs] def copy(self): return deepcopy(self)
[docs] def eval(self, _globals: dict[Symbol: Any] = {}, _locals: dict[Symbol: Any] = {}): _globals = {sym.data: v for sym, v in _globals} _locals = {sym.data: v for sym, v in _locals} if self.head in (Head.assign, Head.setitem, Head.setattr): return exec(str(self), _globals, _locals) else: return eval(str(self), _globals, _locals)
[docs] @classmethod def parse_method(cls, obj: Any, func: Callable, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Expr: """ Make a method call expression. Expression: obj.func(*args, **kwargs) """ method = cls(head=Head.getattr, args=[symbol(obj), func]) inputs = [method] + cls.convert_args(args, kwargs) return cls(head=Head.call, args=inputs)
[docs] @classmethod def parse_init(cls, obj: Any, init_cls: type, args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None) -> Expr: if args is None: args = () if kwargs is None: kwargs = {} sym = symbol(obj) inputs = [sym, init_cls] + cls.convert_args(args, kwargs) return cls(head=Head.init, args=inputs)
[docs] @classmethod def parse_call(cls, func: Callable, args: tuple[Any, ...] = None, kwargs: dict[str, Any] = None) -> Expr: """ Make a function call expression. Expression: func(*args, **kwargs) """ if args is None: args = () if kwargs is None: kwargs = {} inputs = [func] + cls.convert_args(args, kwargs) return cls(head=Head.call, args=inputs)
[docs] @classmethod def convert_args(cls, args: tuple[Any, ...], kwargs: dict[str|Symbol, Any]) -> list: inputs = [] for a in args: inputs.append(a) for k, v in kwargs.items(): inputs.append(cls(Head.assign, [Symbol(k), v])) return inputs
[docs] @classmethod def parse(cls, a: Any) -> Expr: return a if isinstance(a, cls) else cls(Head.value, [symbol(a)])
[docs] def iter_args(self) -> Iterator[Symbol]: """ Recursively iterate along all the arguments. """ for arg in self.args: if isinstance(arg, self.__class__): yield from arg.iter_args() elif isinstance(arg, Symbol): yield arg else: raise RuntimeError(arg)
[docs] def iter_values(self) -> Iterator[Expr]: """ Recursively iterate along all the values. """ for arg in self.args: if isinstance(arg, self.__class__): if arg.head == Head.value: yield arg else: yield from arg.iter_values()
[docs] def iter_expr(self) -> Iterator[Expr]: """ Recursively iterate over all the nested Expr, until it reached to non-nested Expr. This method is useful in macro generation. """ yielded = False for arg in self.args: if isinstance(arg, self.__class__): yield from arg.iter_expr() yielded = True if not yielded: yield self
[docs] def format(self, mapping: dict[Symbol, Symbol|Expr], inplace: bool = False) -> Expr: if not inplace: self = self.copy() for arg in self.iter_values(): try: new = mapping[arg.args[0]] except KeyError: pass else: arg.args[0] = new return self
[docs]class Macro(UserList): """ List with pretty output customized for macro. """ def __init__(self, iterable: Iterable = (), *, active: bool = True): super().__init__(iterable) self.active = active
[docs] def append(self, __object: Expr): if not isinstance(__object, Expr): raise TypeError("Cannot append objects to Macro except for MacroExpr objecs.") return super().append(__object)
def __str__(self) -> str: return "\n".join(map(str, self)) @overload def __getitem__(self, key: int | str) -> Expr: ... @overload def __getitem__(self, key: slice) -> Macro[Expr]: ... def __getitem__(self, key): return super().__getitem__(key) def __iter__(self) -> Iterator[Expr]: return super().__iter__() def __repr__(self) -> str: return ",\n".join(repr(expr) for expr in self)
[docs] @contextmanager def context(self, active: bool): was_active = self.active self.active = active yield self.active = was_active
[docs] def record(self, function=None, *, returned_callback: Callable[[Expr], Expr]=None): def wrapper(func): if isinstance(func, MethodType): if func.__name__ == "__init__": def make_expr(*args, **kwargs): return Expr.parse_init(args[0], args[0].__class__, args[1:], kwargs) elif func.__name__ == "__call__": def make_expr(*args, **kwargs): return Expr.parse_call(Expr(Head.getattr, [args[0], func]), args[1:], kwargs) else: def make_expr(*args, **kwargs): return Expr.parse_method(args[0], func, args[1:], kwargs) else: def make_expr(*args, **kwargs): return Expr.parse_call(func, args, kwargs) @wraps(func) def macro_recorder_equipped(*args, **kwargs): with self.context(False): out = func(*args, **kwargs) if self.active: expr = make_expr(*args, **kwargs) if returned_callback is not None: expr = out(expr) self.append(expr) return out return macro_recorder_equipped return wrapper if function is None else wrapper(function)