# pyright: reportUnknownVariableType=false, reportUnknownParameterType=false, reportUnknownMemberType=false, reportUnknownArgumentType=false, reportMissingTypeArgument=false, reportMissingTypeStubs=false
"""
Imports and extends the ``sympy`` library for symbolic mathematics.
Contains tools for converting Sympy expressions to Python modules and functions.
"""
import re
from collections.abc import Callable, Iterable, Sequence
from functools import lru_cache
from typing import Any, Literal, Self, cast, overload
import numpy as np
import sympy as sp
from sympy import default_sort_key
from sympy.utilities.lambdify import lambdastr, lambdify
from .settings import SETTINGS
from .units import Quantity
from .utypes import Numpy1DArray
[docs]
@lru_cache
def to_identifier(s: sp.Symbol | str) -> str:
    """
    Converts a Sympy symbol to a valid Python identifier.

    String inputs are assumed to be identifiers already and are
    returned unchanged. For symbols, the symbol name is cleaned:

    - ``,`` is replaced with ``_``
    - ``^`` is replaced with ``__``
    - ``'`` is replaced with ``prime``
    - the substring ``text`` is replaced with ``T``
    - the substring ``lambda`` is replaced with ``lam``
    - all remaining characters that are not alphanumeric or ``_`` are removed

    The ``text`` replacement is done to differentiate between the
    symbols ``\\text{kg}`` (returns ``Tkg``) and ``kg`` (returns ``kg``).
    Note that *every* occurrence of the substring ``text`` is replaced,
    not only the Latex command ``\\text{}``.

    Parameters
    ----------
    s : sp.Symbol | str
        Input symbol or string representation

    Returns
    -------
    str
        Valid Python identifier created from the input symbol

    Raises
    ------
    ValueError
        In case the cleaned symbol name is not a valid identifier
    """

    # assume that input strings are already identifiers
    if isinstance(s, str):
        return s

    original_name = s.name

    # applied in this exact order
    # "text" → "T": need to differentiate between symbols "\text{m}" and "m",
    # the string "text" is a bit long
    # "lambda" → "lam": the substring "lambda" cannot exist in the identifier
    substitutions = (
        (",", "_"),
        ("^", "__"),
        ("'", "prime"),
        ("text", "T"),
        ("lambda", "lam"),
    )

    identifier = original_name

    for old, new in substitutions:
        identifier = identifier.replace(old, new)

    # remove all non-alphanumeric or _
    identifier = re.sub(r"\W+", "", identifier)

    if not identifier.isidentifier():
        raise ValueError(f"Symbol could not be converted to a valid Python identifier: {original_name}")

    return identifier
[docs]
@lru_cache
def get_args(e: sp.Basic) -> list[str]:
    """
    Collects the identifiers of all free symbols in the input
    expression and returns them in sorted order.

    Each free symbol is converted with
    :py:func:`encomp.sympy.to_identifier`, the resulting
    strings determine the sort order.

    .. note::
        The result is cached with ``lru_cache``: repeated calls with the
        same expression return the *same* list object,
        callers should not modify it in place.

    Parameters
    ----------
    e : sp.Basic
        Input expression

    Returns
    -------
    list[str]
        Sorted list of identifiers for each free symbol
    """

    identifiers = [to_identifier(symbol) for symbol in e.free_symbols]
    identifiers.sort()

    return identifiers
[docs]
def recursive_subs(e: sp.Basic, replacements: list[tuple[sp.Symbol, sp.Basic]]) -> sp.Basic:
    """
    Substitute the expressions in ``replacements`` recursively.

    The complete list of replacements is applied with ``e.subs()``
    until the expression no longer changes, at most
    ``len(replacements) + 1`` times. If the expression still changes
    after the last pass, the result of the last pass is returned.

    This might not be necessary in all cases, Sympy's builtin
    ``subs()`` method should also do this recursively.

    .. note::
        The order of the tuples in ``replacements`` might matter,
        make sure to order these sensibly in case the expression contains
        a lot of nested substitutions.

    Parameters
    ----------
    e : sp.Basic
        Input expression
    replacements : list[tuple[sp.Symbol, sp.Basic]]
        List of replacements: ``symbol, replace``

    Returns
    -------
    sp.Basic
        Substituted expression
    """

    # the loop body always runs at least once (also for an empty list),
    # no separate handling of the "never substituted" case is required
    for _ in range(len(replacements) + 1):
        substituted = e.subs(replacements)

        # fixed point reached, nothing more to substitute
        if substituted == e:
            return substituted

        e = substituted

    # pass limit reached: return the result of the last pass
    return e
[docs]
def simplify_exponents(e: sp.Basic) -> sp.Basic:
    """
    Simplifies an expression by combining float and int exponents.
    This is not done automatically by Sympy.

    The expression tree is rebuilt recursively. Each power with a float
    exponent that is (within the default tolerance of ``Float.epsilon_eq``)
    equal to an integer gets this integer as exponent instead.
    The *nearest* integer is used for the comparison, exponents slightly
    below an integer (for example ``2.9999999999999996``) are
    converted as well.

    Adapted from
    https://stackoverflow.com/questions/54243832/sympy-wont-simplify-or-expand-exponential-with-decimals

    Parameters
    ----------
    e : sp.Basic
        A Sympy expression, potentially containing mixed float and int exponents

    Returns
    -------
    sp.Basic
        Simplified expression with float and int exponents combined
    """

    def rewrite(expr: sp.Basic, new_args: tuple[sp.Basic, ...]) -> sp.Basic:
        # rebuilds the power, the float exponent is replaced with
        # an integer in case they are (almost) equal
        rebuilt_args = list(new_args)
        exponent = rebuilt_args[1]

        # round to the nearest integer, int() would truncate towards zero
        # and miss exponents that are slightly below an integer
        nearest_int = round(float(exponent))

        if exponent.epsilon_eq(nearest_int):
            rebuilt_args[1] = sp.Integer(nearest_int)

        return type(expr)(*rebuilt_args)

    def is_float_pow(expr: sp.Basic) -> bool:
        return expr.is_Pow and expr.args[1].is_Float

    # atoms (symbols, numbers) are returned as-is
    if not e.args:
        return e

    # simplify the arguments first, then rebuild this node
    new_args = tuple(simplify_exponents(arg) for arg in e.args)

    if is_float_pow(e):
        return rewrite(e, new_args)

    return type(e)(*new_args)
[docs]
def get_sol_expr(
    equations: sp.Equality | list[sp.Equality],
    symbol: sp.Symbol,
    avoid: set[sp.Symbol] | None = None,
) -> sp.Basic | None:
    """
    Wrapper around ``sp.solve`` that returns the solution expression
    for a *single* symbol, or None in case ``sp.solve`` did not
    return any solution for the specified symbol.

    Only the equations that actually contain the symbol are considered.
    In case more than one equation remains, equations that contain
    ``symbol`` on the LHS are tried one by one first.

    Parameters
    ----------
    equations : sp.Equality | list[sp.Equality]
        List of equations or a single equation
    symbol : sp.Symbol
        Symbol to solve for (isolate)
    avoid : set[sp.Symbol] | None, optional
        Set of symbols to avoid in the substitution expressions, by default None.
        This is only checked for the solutions from single equations
        with ``symbol`` on the LHS, not for the combined solution.

    Returns
    -------
    sp.Basic | None
        Expression that equals ``symbol``, or None in case the
        equation(s) could not be solved
    """

    avoid = set() if avoid is None else avoid

    candidates = [equations] if isinstance(equations, sp.Equality) else equations

    # sort key: number of free symbols on the LHS first,
    # default_sort_key is the secondary key to keep the order consistent
    def eqn_simplicity(eqn: sp.Eq) -> tuple[int, tuple[Any, ...]]:
        return len(eqn.lhs.free_symbols), default_sort_key(eqn)

    # unique equations that contain the requested symbol
    # this might leave multiple equations, there's no guarantee
    # that these can be solved
    relevant = sorted(
        {eqn for eqn in candidates if symbol in eqn.free_symbols},
        key=eqn_simplicity,
    )

    # with multiple relevant equations, first try each equation
    # that has the symbol on the LHS on its own
    if len(relevant) > 1:
        for eqn in relevant:
            if symbol not in eqn.lhs.free_symbols:
                continue

            single_solution = get_sol_expr(eqn, symbol)

            if single_solution is None:
                continue

            # don't return an expression that contains symbols to be avoided
            if avoid & single_solution.free_symbols:
                continue

            return single_solution

    # otherwise, solve all relevant equations at once
    # dict=True avoids inconsistent return types from sp.solve
    # make sure to define the assumptions correctly for all symbols, otherwise the
    # Sympy solver might not be able to find an explicit solution
    solutions = sp.solve(relevant, symbol, dict=True)

    if not solutions:
        return None

    # only the first dict in the list of solutions is used
    first_solution = solutions[0]

    # the values are sorted with default_sort_key to keep the output consistent
    ordered_values = sorted(first_solution.values(), key=default_sort_key)

    return cast(sp.Basic, ordered_values[0])
[docs]
def get_lambda_kwargs(
    value_map: dict[sp.Symbol | str, Quantity | np.ndarray],
    include: Sequence[sp.Symbol | str] | None = None,
    *,
    units: bool = False,
) -> dict[str, Quantity | np.ndarray]:
    """
    Returns a mapping from identifier to value
    based on the input value map (symbol or identifier to value).
    If ``include`` is not None, only these symbols will be included.

    Values that are not ``Quantity`` instances are passed through unchanged.

    Parameters
    ----------
    value_map : dict[sp.Symbol | str, Quantity | np.ndarray]
        Mapping from symbol or symbol identifier to value
    include : Sequence[sp.Symbol | str] | None, optional
        Optional sequence of symbols or symbol identifiers to include, by default None
    units : bool, optional
        Whether to keep the units, if False the magnitude of each Quantity
        is used (after calling ``to_base_units()``), by default False

    Returns
    -------
    dict[str, Quantity | np.ndarray]
        Mapping from identifier to value
    """

    # set of identifiers for constant-time membership checks,
    # None means that all symbols are included
    included = None if include is None else {to_identifier(n) for n in include}

    def _get_val(
        x: Quantity[Any, Numpy1DArray] | Numpy1DArray,
    ) -> Quantity | Numpy1DArray:
        if not isinstance(x, Quantity):
            return x

        base = x.to_base_units()

        return base if units else base.m

    kwargs: dict[str, Quantity | np.ndarray] = {}

    for key, value in value_map.items():
        # convert each key only once
        identifier = to_identifier(key)

        if included is None or identifier in included:
            kwargs[identifier] = _get_val(value)

    return kwargs
# typing overloads for get_lambda: the type of the first
# element in the returned tuple depends on the value of ``to_str``
# to_str=True: the first element is a string
@overload
def get_lambda(e: sp.Basic, *, to_str: Literal[True]) -> tuple[str, list[str]]: ...
# to_str=False: the first element is a callable
@overload
def get_lambda(e: sp.Basic, *, to_str: Literal[False]) -> tuple[Callable, list[str]]: ...
# to_str omitted: the first element is a callable
@overload
def get_lambda(e: sp.Basic) -> tuple[Callable, list[str]]: ...
[docs]
@lru_cache
def get_lambda(e: sp.Basic, *, to_str: bool = False) -> tuple[Callable | str, list[str]]:
    """
    Creates a lambda function from the input expression,
    the parameter names are valid identifiers.

    Uses ``lambdastr`` if ``to_str`` is True, otherwise ``lambdify``.
    The result is cached with ``lru_cache``, the returned objects
    are shared between calls with identical inputs.

    Parameters
    ----------
    e : sp.Basic
        Input expression
    to_str : bool, optional
        Whether to return the string representation of the lambda function,
        by default False

    Returns
    -------
    tuple[Callable | str, list[str]]
        The lambda function (or string representation) and the list
        of parameters to the function
    """

    # function parameters: sorted identifiers of the free symbols
    param_names = get_args(e)

    # replace each symbol with a symbol named after its identifier
    # (keeping the assumptions), otherwise the symbols will be converted
    # to dummy identifiers (even if dummify=False)
    renamed = {symbol: sp.Symbol(to_identifier(symbol), **symbol.assumptions0) for symbol in e.free_symbols}
    e_identifiers = e.subs(renamed)

    if to_str:
        fcn = lambdastr(param_names, e_identifiers, dummify=False)
    else:
        fcn = lambdify(param_names, e_identifiers, dummify=False)

    return fcn, param_names
[docs]
def get_lambda_matrix(M: sp.MutableDenseMatrix) -> tuple[str, list[str]]:
    """
    Converts the input matrix into the source code of a lambda
    function that returns an array.

    Converts the matrix to Python source, it is not possible to
    use in-memory lambda functions for this. Use ``eval(src)``
    on the output from this function to create a function object,
    the name ``np`` must be available when evaluating the source.

    The parameters of the lambda function are sorted, the order is
    identical to the returned list of parameters.

    Parameters
    ----------
    M : sp.MutableDenseMatrix
        Input matrix

    Returns
    -------
    tuple[str, list[str]]
        Python source code for the function and a list of parameters
    """

    nrows, ncols = M.shape

    all_args: set[str] = set()
    rows: list[list[str]] = []

    for i in range(nrows):
        row: list[str] = []

        for j in range(ncols):
            fcn_str, n_args = get_lambda(M[i, j], to_str=True)
            all_args.update(n_args)

            # remove the "lambda x, y, z:" part and extra parens,
            # they are added back later
            body = fcn_str.split(":", 1)[-1].strip().removeprefix("(").removesuffix(")")
            row.append(body)

        rows.append(row)

    # remove quotes around strings, they are mathematical expressions
    funcs = str(rows).replace("'", "").replace('"', "")

    # the same sorted order must be used for the lambda parameters and for
    # the returned list, iterating over the set directly gives an arbitrary order
    params = sorted(all_args)

    # TODO: "VisibleDeprecationWarning: Creating an ndarray from ragged..."
    # when mixing input vectors and floats
    func_src = f"lambda {', '.join(params)}: np.array({funcs})"

    return func_src, params
[docs]
@lru_cache
def get_function(e: sp.Basic, *, units: bool = False) -> Callable:
    """
    Wrapper around :py:func:`encomp.sympy.get_lambda` that
    converts a mapping of values to the inputs of the lambda function.

    Parameters
    ----------
    e : sp.Basic
        Input expression
    units : bool, optional
        Whether to keep the units, if False Quantity is converted
        to float (after calling ``to_base_units()``), by default False

    Returns
    -------
    Callable
        Function that evaluates the input expression. It takes a single
        dict that maps symbols (or identifiers) to values, entries that
        are not parameters of the expression are ignored.
    """

    lambda_fcn, param_names = get_lambda(e)

    def expr_func(params: dict) -> Any:  # noqa: ANN401
        # keep only the parameters of the lambda function, keyed by identifier
        kwargs = get_lambda_kwargs(params, param_names, units=units)

        return lambda_fcn(**kwargs)

    return expr_func
[docs]
def evaluate(
    e: sp.Basic,
    value_map: dict[sp.Symbol, Quantity | np.ndarray],
    *,
    units: bool = False,
) -> Quantity | np.ndarray:
    """
    Computes the value of the input expression based on
    the symbol values in ``value_map``.

    Parameters
    ----------
    e : sp.Basic
        Input expression to evaluate
    value_map : dict[sp.Symbol, Quantity | np.ndarray]
        Mapping from symbol to value for all required symbols in ``e``,
        additional symbols may be present
    units : bool, optional
        Whether to keep the units, if False Quantity is converted
        to float (after calling ``to_base_units()``), by default False

    Returns
    -------
    Quantity | np.ndarray
        Value of the expression, Quantity if ``units=True`` otherwise float
    """

    # the function object is cached by get_function
    evaluator = get_function(e, units=units)
    result = evaluator(value_map)

    return cast(Quantity | np.ndarray, result)
[docs]
def substitute_unknowns(
    e: sp.Basic,
    knowns: set[sp.Symbol],
    eqns: list[sp.Equality],
    avoid: set[sp.Symbol] | None = None,
) -> sp.Basic:
    """
    Replaces the unknown symbols in the input expression with
    expressions obtained from the equations ``eqns``.
    Nested substitutions are handled by also resolving the unknown
    symbols that appear in the substitution expressions.

    Parameters
    ----------
    e : sp.Basic
        Input expression that potentially contains unknown symbols
    knowns : set[sp.Symbol]
        Set of known symbols
    eqns : list[sp.Equality]
        List of equations that define the unknown symbols in terms of known ones
    avoid : set[sp.Symbol] | None, optional
        Set of symbols to avoid in the substitution expressions, by default None.
        These symbols are never substituted themselves.

    Returns
    -------
    sp.Basic
        The substituted expression

    Raises
    ------
    ValueError
        In case an unknown symbol could not be isolated with the equations
    """

    avoid = set() if avoid is None else avoid

    # symbols that are never substituted
    not_substituted = knowns | avoid

    # pairs of (symbol, expression), in the order they were resolved
    resolved: list[tuple[sp.Symbol, sp.Basic]] = []

    def _get_unknowns(expr: sp.Basic) -> list[sp.Symbol]:
        # sorted unknown symbols in expr that have not been resolved yet
        handled = [pair[0] for pair in resolved]
        ordered = cast(list[sp.Symbol], sorted(expr.free_symbols, key=default_sort_key))

        return [n for n in ordered if n not in not_substituted and n not in handled]

    # worklist of symbols to resolve, grows while it is processed
    pending = _get_unknowns(e)
    position = 0

    while position < len(pending):
        current = pending[position]
        position += 1

        current_expr = get_sol_expr(eqns, current, avoid=avoid)

        if current_expr is None:
            raise ValueError(f"Symbol {current} could not be isolated based on the specified equations.")

        # the expression for the current symbol might contain even more
        # unknown symbols, these are added to the end of the worklist
        # this will not loop infinitely since symbols that were already
        # resolved are not returned by _get_unknowns
        pending.extend(_get_unknowns(current_expr))
        resolved.append((current, current_expr))

    # reversed order: the "deepest" level of substitutions must be done first
    # the substitutions are applied until the expression no longer changes
    return recursive_subs(e, resolved[::-1])
[docs]
def typeset_chemical(s: str) -> str:
    """
    Typesets chemical formulas using Latex.

    Text is wrapped with ``\\text{}``, each occurrence of an upper-case
    letter followed by ``_`` and a single digit (for example ``H_2``)
    gets the ``_`` and the digit placed outside of the text command:
    ``H_2O`` is converted to ``\\text{H}_2\\text{O}``.

    Parameters
    ----------
    s : str
        Input string

    Returns
    -------
    str
        Output string with chemical formulas typeset correctly
    """

    text_open = "\\text{"
    subscript_pattern = r"[A-Z]_\d"

    # surround each "X_1" occurrence with | and split on this character
    marked = re.sub(subscript_pattern, r"|\g<0>|", s)

    # start inside of a text command
    pieces = [text_open]

    for segment in marked.split("|"):
        if not segment:
            continue

        if re.match(subscript_pattern, segment):
            # close the text command after the letter, add the subscript
            # outside of it and open a new text command
            pieces += [segment[:-2], "}", "_" + segment[-1], text_open]
        else:
            pieces.append(segment)

    # remove a text command that was opened but not filled
    if pieces[-1] == text_open:
        pieces.pop()

    typeset_str = "".join(pieces)

    # close the last text command in case it is still open
    if typeset_str.count("{") - typeset_str.count("}") == 1:
        typeset_str += "}"

    return typeset_str
[docs]
def typeset(x: str | int) -> str:
    """
    Does some additional typesetting for the input
    Latex string, for example ``\\text{}`` around
    strings and upper-case characters.

    Use comma ``,`` to separate parts of the input, for example

    .. code:: none

        input,i

    will be typeset as ``\\text{input},i``: the ``i`` is
    a separate part and is typeset with a math font.
    Spaces around commas will be removed to make sub- and
    superscripts more compact.

    Use ``~`` before a single upper-case letter to typeset it
    with a math font. A part that only consists of ``~`` is left unchanged.

    Parts that contain spaces or backslashes (Latex code) are left unchanged.
    Parts that start with an upper-case letter are passed
    to :py:func:`encomp.sympy.typeset_chemical`.

    The input is returned without modification (converted to str) in case
    the setting ``typeset_symbol_scripts`` in ``encomp.settings`` is disabled.

    Parameters
    ----------
    x : str | int
        Input string or int (will be converted to str)

    Returns
    -------
    str
        Output Latex string
    """

    x = str(x)

    if not SETTINGS.typeset_symbol_scripts:
        return x

    parts = [n.strip() for n in x.split(",")]

    for i, p in enumerate(parts):
        # avoid typesetting single upper-case letters as text
        # if they start with ~
        # the slice p[1:2] is empty for a lone "~" (indexing p[1] would raise)
        if p.startswith("~") and p[1:2].isupper():
            parts[i] = p[1]
            continue

        # only typeset single words, also ignore Latex code
        if " " in p or "\\" in p:
            continue

        alpha_str = "".join(n for n in p if n.isalpha())

        # typeset everything except 1-letter lower case as text
        typeset_text = len(alpha_str) >= 2 or (len(alpha_str) == 1 and alpha_str.isupper())

        if not typeset_text:
            continue

        if re.match("[A-Z]", p):
            # handle chemical compounds
            parts[i] = typeset_chemical(p)
        else:
            parts[i] = "\\text{" + p + "}"

    return ",".join(parts)
[docs]
class Symbol(sp.Symbol):
    """
    Subclass of ``sp.Symbol`` with methods for adding
    decorations (prefix, suffix, sub- and superscripts) to the symbol name.

    The methods are also copied to ``sp.Symbol`` by ``_patch_symbol_class``,
    they must not depend on other attributes that only exist on this subclass.
    """

    def decorate(
        self,
        prefix: str | int | None = None,
        suffix: str | int | None = None,
        prefix_sub: str | int | None = None,
        prefix_sup: str | int | None = None,
        suffix_sub: str | int | None = None,
        suffix_sup: str | int | None = None,
    ) -> Self:
        """
        Method that decorates a symbol with
        subscripts and/or superscripts before or after the symbol.
        Returns a new symbol object with the same assumptions (i.e. real,
        positive, complex, etc...) as the input.

        Using LaTeX syntax supported by ``sympy``:

        .. code:: none

            {prefix}^{prefix_sup}_{prefix_sub}{symbol}_{suffix_sub}^{suffix_sup}{suffix}

        ``symbol`` is the input symbol.
        The ``prefix`` and ``suffix`` parts are added without ``_`` or ``^``.
        Each of the parts (except ``symbol``) can be empty.
        All parts except ``symbol`` are surrounded with braces.
        The sub- and superscripts are passed to
        :py:func:`encomp.sympy.typeset`, ``prefix`` and ``suffix`` are not.

        The decorations can be string or integer, floats are not allowed.

        In case the input symbol already contains sub- or superscripts,
        the decorations are not appended to those. Instead, a new level
        is introduced. To keep things simple, make sure that the input symbol
        is a simple symbol.

        Use the ``append`` method to append to an existing sub-
        or superscript in the suffix.

        Parameters
        ----------
        prefix : str | int | None, optional
            Prefix added before the symbol, by default None
        suffix : str | int | None, optional
            Suffix added after the symbol, by default None
        prefix_sub : str | int | None, optional
            Subscript prefix before the symbol and after ``prefix``, by default None
        prefix_sup : str | int | None, optional
            Superscript prefix before the symbol and after ``prefix``, by default None
        suffix_sub : str | int | None, optional
            Subscript suffix after the symbol and before ``suffix``, by default None
        suffix_sup : str | int | None, optional
            Superscript suffix after the symbol and before ``suffix``, by default None

        Returns
        -------
        Symbol
            A new symbol with the same assumptions as the input, with added decorations
        """

        def _typeset_or_none(script: str | int | None) -> str | None:
            return None if script is None else typeset(script)

        # (delimiter, text, is the base symbol) in output order
        slots = [
            ("", prefix, False),
            ("^", _typeset_or_none(prefix_sup), False),
            ("_", _typeset_or_none(prefix_sub), False),
            ("", self.name, True),
            ("_", _typeset_or_none(suffix_sub), False),
            ("^", _typeset_or_none(suffix_sup), False),
            ("", suffix, False),
        ]

        decorated_parts = []

        for delimiter, text, is_base in slots:
            if text is None:
                continue

            text = str(text)

            # don't add extra braces around the base symbol
            # the base symbol is identified by its position: a decoration
            # with the same text as the symbol name still needs braces
            if not is_base:
                text = "{" + text + "}"

            decorated_parts.append(delimiter + text)

        decorated_symbol = "".join(decorated_parts)

        # assumptions0 contains assumptions that are not None
        return self.__class__(decorated_symbol, **self.assumptions0)

    def append(self, s: str | int, where: Literal["sub", "sup"] = "sub") -> Self:
        """
        Adds the input ``s`` to an existing sub- or superscript.
        Creates the sub- or superscript if it does not exist.

        The text is added to the part of the symbol name that follows the
        *last* ``_`` (subscript) or ``^`` (superscript), everything
        before this delimiter is kept as-is.

        .. note::
            The complete remainder of the name after the last delimiter
            is treated as the existing script. The name should end with
            the sub- or superscript that is appended to.

        Parameters
        ----------
        s : str | int
            Text or index to be added, this is passed to :py:func:`encomp.sympy.typeset`
        where : Literal['sub', 'sup'], optional
            Whether to append to the subscript or superscript, by default 'sub'

        Returns
        -------
        Symbol
            A new symbol with the same assumptions as the input,
            with updated sub- or superscript
        """

        delimiter = "_" if where == "sub" else "^"
        symbol = self.name
        s = typeset(s)

        # split at the last delimiter only, delimiters that appear earlier
        # in the name must remain in the base part
        base_symbol, found, existing_script = symbol.rpartition(delimiter)

        if not found:
            decorated_symbol = symbol + delimiter + "{" + s + "}"
        else:
            # assume that the input Latex symbol is correct,
            # don't deal with unbalanced braces
            existing_script = existing_script.removeprefix("{").removesuffix("}")
            decorated_symbol = base_symbol + delimiter + "{" + existing_script + s + "}"

        return self.__class__(decorated_symbol, **self.assumptions0)

    def _(self, x: str | int) -> Self:
        """
        Add subscript ``x``.
        """

        return self.append(x, where="sub")

    def __(self, x: str | int) -> Self:
        """
        Add superscript ``x``.
        """

        return self.append(x, where="sup")

    def delta(self) -> Self:
        """
        Add ``\\Delta`` prefix.
        """

        return self.decorate(prefix="\\Delta")
def _patch_symbol_class(dest: type[sp.Symbol] | sp.Symbol) -> None:
    """
    Copies the decoration methods from ``Symbol`` to ``dest``.
    Attributes that already exist on ``dest`` are not overwritten.

    Parameters
    ----------
    dest : type[sp.Symbol] | sp.Symbol
        Class or object that receives the methods
    """

    for method_name in ("decorate", "append", "delta", "_", "__"):
        if hasattr(dest, method_name):
            continue

        setattr(dest, method_name, getattr(Symbol, method_name))
[docs]
def symbols(inp: str, **kwargs: Any) -> list[Symbol]:  # noqa: ANN401
    """
    Wrapper around ``sp.symbols`` that returns a list of symbols
    with the decoration methods from ``Symbol``.

    Parameters
    ----------
    inp : str
        Input passed to ``sp.symbols``
    **kwargs : Any
        Additional keyword arguments passed to ``sp.symbols``

    Returns
    -------
    list[Symbol]
        List of the created symbols

    Raises
    ------
    ValueError
        In case ``sp.symbols`` did not return an iterable
    """

    created = sp.symbols(inp, **kwargs)

    if isinstance(created, Iterable):
        for symbol in created:
            _patch_symbol_class(symbol)

        return cast(list[Symbol], list(created))

    raise ValueError(
        f"Expected more than one input symbol, use sp.Symbol('{inp}') directly to create a single symbol"
    )
# executed when the module is imported: adds the decoration methods
# ("decorate", "append", "delta", "_", "__") to the sp.Symbol class itself
_patch_symbol_class(sp.Symbol)