如何正确创建自定义 SymPy Symbol 类(支持完整运算重载)

在 sympy 中创建真正可用的自定

义符号类需继承 `expr` 并重写 `add`/`mul`/`pow` 的核心方法(如 `flatten` 和 `__new__`),而非仅继承 `symbol`;否则幂运算等复合操作仍会退化为原生类型。

SymPy 的核心代数运算(如 +、*、**)并非完全通过操作符重载(如 __pow__)驱动,而是由顶层工厂类(如 Add、Mul、Pow)统一调度并执行表达式规范化(如合并同类项、指数相加)。因此,单纯让 CustomSymbol 重写 __pow__ 是无效的——当执行 a**2 * a**3 时,SymPy 内部会先将整个乘积识别为 Mul,再调用 Pow.flatten 或直接构造 Pow(a, 5),完全绕过 CustomSymbol.__pow__。

✅ 正确做法是构建一套协同工作的自定义表达式家族

  1. 定义基类 CustomExpr:统一实现 __add__/__mul__/__pow__,强制返回自定义运算类;
  2. 子类化 Symbol + CustomExpr:获得符号语义与自定义行为双重能力;
  3. 重写 CustomAdd/CustomMul/CustomPow 的关键方法
    • flatten(cls, seq):复制 SymPy 源码,将所有 Add/Mul/Pow 替换为对应自定义类(注意处理系数排序、非交换乘法等细节);
    • __new__(仅 CustomPow):缓存并确保返回 CustomPow 实例;
    • _eval_subs(self, old, new):返回 None,防止默认替换逻辑将自定义类“降级”为原生类;
  4. 全局替换运算符别名(可选但推荐)
    Add = CustomAdd
    Mul = CustomMul
    Pow = CustomPow

    这能确保后续所有 SymPy 内部构造(如 expand()、simplify())也使用你的类——否则它们仍会生成原生类型。

⚠️ 注意事项:

  • 不要试图只覆盖 __pow__ 或 __mul__:SymPy 的 Mul 构造器会跳过实例方法,直接调用 Mul.flatten;
  • flatten 方法极其关键:它负责归并 a**2 * a**3 → a**5,若未重写,结果必为 Pow 而非 CustomPow;
  • 所有自定义类必须继承 CustomExpr(而非仅 Expr),以保证运算符重载链完整;
  • 若需 expand()、rewrite() 等高级功能,必须为每个自定义类实现对应的 _eval_expand_* 或 _eval_rewrite 方法,否则它们会回退到原生逻辑并破坏类型一致性。

下面是一个最小可行示例(已精简关键逻辑,生产环境请严格按 SymPy 源码补全 flatten):

from sympy import Expr, Symbol, Add, Mul, Pow, S
from sympy.core.add import _addsort
from sympy.core.mul import _mulsort, _keep_coeff

class CustomExpr(Expr):
    def __add__(self, other): return CustomAdd(self, other)
    def __radd__(self, other): return CustomAdd(other, self)
    def __mul__(self, other): return CustomMul(self, other)
    def __rmul__(self, other): return CustomMul(other, self)
    def __pow__(self, other): return CustomPow(self, other)

class CustomSymbol(CustomExpr, Symbol):
    def __new__(cls, name, **kwargs):
        return Symbol.__new__(cls, name, **kwargs)

class CustomAdd(CustomExpr, Add):
    @classmethod
    def flatten(cls, seq):
        # 复制 sympy.core.add.Add.flatten 源码,将内部 Add→CustomAdd, Mul→CustomMul, Pow→CustomPow
        from sympy.core.add import _addsort
        terms = []
        for x in seq:
            if isinstance(x, cls):
                terms.extend(x.args)
            else:
                terms.append(x)
        # ...(省略归并常数、排序等完整逻辑)
        coeff, nonnumber = S.Zero, []
        for t in terms:
            if t.is_Number:
                coeff += t
            else:
                nonnumber.append(t)
        if coeff is S.Zero and not nonnumber:
            return [S.Zero], []
        if coeff is S.Zero:
            newseq = nonnumber
        else:
            newseq = [coeff] + nonnumber
        _addsort(newseq)  # 排序
        return newseq, {}

    def _eval_subs(self, old, new): return None

class CustomMul(CustomExpr, Mul):
    @classmethod
    def flatten(cls, seq):
        # 同理:复制 Mul.flatten,替换为 CustomMul/CustomAdd/CustomPow
        ...
    def _eval_subs(self, old, new): return None

class CustomPow(CustomExpr, Pow):
    def __new__(cls, b, e, evaluate=None):
        # 复制 Pow.__new__,确保返回 CustomPow
        if evaluate is None:
            evaluate = global_parameters.evaluate
        if evaluate:
            # ...(标准求值逻辑)
            pass
        return Expr.__new__(cls, b, e)

    def _eval_subs(self, old, new): return None

# 全局启用(关键!)
Add = CustomAdd
Mul = CustomMul
Pow = CustomPow

# 验证
a = CustomSymbol('a')
x = a**2 * a**3
print(type(x))  # 
print(x)        # a**5

总结:SymPy 的设计决定了自定义符号必须是“生态级”扩展——不是单点重载,而是构建与 Symbol/Add/Mul/Pow 深度耦合的平行类体系。虽然工作量较大,但这是唯一能保证符号计算全程保持自定义语义(如附加元数据、特殊求值规则)的稳健方案。