2024 年 4 月,MIT 研究团队发布了一篇震动机器学习界的论文——KAN: Kolmogorov-Arnold Networks。这篇论文在 GitHub 上一周内获得超过 10,000 Star,因为它提出了一个大胆的问题:我们用了几十年的多层感知机(MLP)是不是从根本上就是错的? KAN 的核心论点是:与其在节点上放置固定激活函数、在边上放置可学习权重,不如反过来——在边上放置可学习的激活函数、在节点上仅做求和。这个看似微小的架构翻转,带来了可解释性、参数效率和函数逼近精度的全面提升。
本文不讲空洞的理论。我们将从 Kolmogorov-Arnold 表示定理出发,用 Python/NumPy 从零实现一个完整的 KAN 网络,在真实数据集上与 MLP 做性能对比,并给出生产环境中的使用建议。
📌 记住: KAN 不是要完全取代 MLP,而是在科学计算、符号回归、可解释性要求高的场景中提供了一个更优的选择。理解两者的差异,才能在正确的场景做出正确的选择。
🧮 一、数学基础:从 Kolmogorov-Arnold 定理到 KAN
1.1 万能逼近定理的两个版本
在讨论 KAN 之前,我们需要理解两个关键的数学定理,它们分别定义了 MLP 和 KAN 的理论基础。
通用逼近定理(Universal Approximation Theorem,1989) 告诉我们:一个具有单个隐藏层和足够多神经元的 MLP,可以以任意精度逼近任何连续函数。这是 MLP 的理论基石。但它只保证了「存在性」,没有告诉我们需要多少神经元、如何找到最优权重。
Kolmogorov-Arnold 表示定理(1957) 则给出了一个更精确的结构化答案:任何多元连续函数 $f(x_1, …, x_n)$ 都可以表示为:
$$f(x_1, …, x_n) = \sum_{q=0}^{2n} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)$$
这个公式告诉我们,任何多元函数都可以分解为一元函数的组合。关键区别在于:
| 特性 | MLP(通用逼近定理) | KAN(Kolmogorov-Arnold 定理) |
|---|---|---|
| 可学习参数位置 | 节点(权重矩阵) | 边(激活函数) |
| 激活函数 | 固定(ReLU、Sigmoid 等) | 可学习(B-Spline) |
| 网络结构 | 宽而浅或窄而深 | 窄而深(理论保证) |
| 可解释性 | 黑箱 | 可提取符号公式 |
| 参数效率 | 较低 | 较高(相同精度下) |
⚠️ 警告: Kolmogorov-Arnold 定理中的外层函数 $\Phi_q$ 可能是非常不光滑的(甚至分形结构)。原始定理不能直接用于构建实用网络,KAN 的贡献在于用 B-Spline 参数化了这些函数,使其可训练且光滑。
1.2 KAN 的核心架构
KAN 的架构可以用一句话概括:将 MLP 中节点上的固定激活函数替换为边上的一维可学习激活函数。
在 MLP 中,每一层的计算是:$x^{(l+1)} = \sigma(W^{(l)} x^{(l)} + b^{(l)})$,其中 $\sigma$ 是固定的激活函数。
在 KAN 中,每一层的计算是:$x^{(l+1)} = \sum_{i} \phi_{l,i,j}(x_i^{(l)})$,其中 $\phi_{l,i,j}$ 是从第 $l$ 层第 $i$ 个节点到第 $l+1$ 层第 $j$ 个节点的可学习激活函数。
每个 $\phi$ 函数用 B-Spline 参数化,这意味着它由一组控制点定义,可以通过梯度下降来优化。一个典型的 KAN 层可以用形状 $[n_{in}, n_{out}]$ 描述,包含 $n_{in} \times n_{out}$ 个一维可学习函数。
💡 提示: B-Spline(基样条)是一种分段多项式曲线,由控制点和节点向量定义。在 KAN 中,每个边上的激活函数就是一个 B-Spline,通常用 3-5 阶、20-100 个控制点。
🔧 二、从零实现:用 NumPy 构建 KAN
2.1 B-Spline 激活函数实现
KAN 的核心组件是 B-Spline 激活函数。我们先实现一个可训练的一维 B-Spline 函数:
# B-Spline 激活函数:KAN 的核心构建块
import numpy as np
class BSplineActivation:
"""一维 B-Spline 可学习激活函数"""
def __init__(self, grid_size=20, spline_order=3, grid_range=(-2, 2)):
self.grid_size = grid_size
self.spline_order = spline_order
self.grid_range = grid_range
# 均匀网格 + 扩展边界
h = (grid_range[1] - grid_range[0]) / grid_size
self.knots = np.linspace(
grid_range[0] - spline_order * h,
grid_range[1] + spline_order * h,
grid_size + 2 * spline_order + 1
)
# 可学习控制点(系数)
self.coefficients = np.random.randn(grid_size + spline_order) * 0.1
def _basis(self, x, i, k):
"""递归计算 B-Spline 基函数 B_{i,k}(x)"""
if k == 0:
t_left = self.knots[i]
t_right = self.knots[i + 1]
return np.where((t_left <= x) & (x < t_right), 1.0, 0.0)
# Cox-de Boor 递归公式
denom1 = self.knots[i + k] - self.knots[i]
denom2 = self.knots[i + k + 1] - self.knots[i + 1]
term1 = 0.0
if denom1 > 1e-10:
term1 = (x - self.knots[i]) / denom1 * self._basis(x, i, k - 1)
term2 = 0.0
if denom2 > 1e-10:
term2 = (self.knots[i + k + 1] - x) / denom2 * self._basis(x, i + 1, k - 1)
return term1 + term2
def forward(self, x):
"""前向传播:y = sum(c_i * B_{i,k}(x))"""
x = np.clip(x, self.grid_range[0], self.grid_range[1])
result = np.zeros_like(x, dtype=float)
n_basis = len(self.coefficients)
for i in range(n_basis):
result += self.coefficients[i] * self._basis(x, i, self.spline_order)
return result
def backward(self, x, grad_output):
"""反向传播:计算对系数的梯度"""
x = np.clip(x, self.grid_range[0], self.grid_range[1])
n_basis = len(self.coefficients)
grad_coefficients = np.zeros(n_basis)
for i in range(n_basis):
basis_val = self._basis(x, i, self.spline_order)
grad_coefficients[i] = np.sum(grad_output * basis_val)
# 对输入的梯度(用于链式法则)
grad_input = np.zeros_like(x)
for i in range(n_basis):
# 用有限差分近似 B-Spline 对 x 的导数
eps = 1e-5
basis_deriv = (
self._basis(x + eps, i, self.spline_order) -
self._basis(x - eps, i, self.spline_order)
) / (2 * eps)
grad_input += self.coefficients[i] * basis_deriv
return grad_input, grad_coefficients
这段代码实现了一个完整的 B-Spline 激活函数,包括前向传播和反向传播。Cox-de Boor 递归公式是 B-Spline 的标准计算方法,通过控制点和节点向量来定义平滑曲线。
2.2 完整 KAN 层与网络
有了 B-Spline 激活函数,我们可以构建 KAN 层和完整的 KAN 网络:
# 完整 KAN 网络实现
class KANLayer:
"""KAN 层:包含 n_in x n_out 个可学习激活函数"""
def __init__(self, in_features, out_features, grid_size=20):
self.in_features = in_features
self.out_features = out_features
# 每条边上一个 B-Spline 激活函数
self.activations = [
[BSplineActivation(grid_size=grid_size) for _ in range(out_features)]
for _ in range(in_features)
]
# 残差连接的可学习缩放参数
self.residual_scale = np.ones((in_features, out_features)) * 0.1
def forward(self, x):
"""前向传播:对每个输出节点,求和所有输入边的激活值"""
batch_size = x.shape[0]
output = np.zeros((batch_size, self.out_features))
for i in range(self.in_features):
for j in range(self.out_features):
output[:, j] += self.activations[i][j].forward(x[:, i])
return output
def backward(self, x, grad_output):
"""反向传播:计算梯度并更新参数"""
grad_input = np.zeros_like(x)
for i in range(self.in_features):
for j in range(self.out_features):
grad_in, grad_coef = self.activations[i][j].backward(
x[:, i], grad_output[:, j]
)
grad_input[:, i] += grad_in
# 更新 B-Spline 系数(梯度下降)
self.activations[i][j].coefficients -= 0.01 * grad_coef
return grad_input
class KANNetwork:
"""完整 KAN 网络"""
def __init__(self, layer_sizes, grid_size=20):
self.layers = []
for i in range(len(layer_sizes) - 1):
self.layers.append(
KANLayer(layer_sizes[i], layer_sizes[i + 1], grid_size)
)
def forward(self, x):
for layer in self.layers:
x = layer.forward(x)
return x
def train_step(self, x, y, lr=0.01):
"""单步训练:前向传播 + MSE 损失 + 反向传播"""
# 前向传播
pred = self.forward(x)
loss = np.mean((pred - y) ** 2)
# 反向传播
grad = 2 * (pred - y) / y.shape[0]
for layer in reversed(self.layers):
grad = layer.backward(x, grad)
x = layer.forward(x) # 重新计算该层输入
return loss
💡 提示: 上述实现为了可读性使用了 Python 循环,在实际应用中应该用向量化操作。PyTorch 版本的 KAN 库(
pykan)已经做了大量优化,训练速度比纯 NumPy 快 100 倍以上。
2.3 训练示例:拟合符号函数
让我们用一个具体的例子来验证 KAN 的能力——拟合函数 $f(x, y) = \sin(x) \cdot \exp(-y^2)$:
# 训练示例:KAN 拟合符号函数
import numpy as np
# 生成训练数据
np.random.seed(42)
n_samples = 500
X_train = np.random.uniform(-2, 2, (n_samples, 2))
y_train = np.sin(X_train[:, 0]) * np.exp(-X_train[:, 1] ** 2)
y_train = y_train.reshape(-1, 1)
# 创建 KAN 网络:2 -> 5 -> 1
kan = KANNetwork(layer_sizes=[2, 5, 1], grid_size=25)
# 训练循环
losses = []
for epoch in range(200):
loss = kan.train_step(X_train, y_train, lr=0.005)
losses.append(loss)
if epoch % 50 == 0:
print(f"Epoch {epoch:3d} | Loss: {loss:.6f}")
# 验证
X_test = np.random.uniform(-2, 2, (100, 2))
y_pred = kan.forward(X_test)
y_true = np.sin(X_test[:, 0]) * np.exp(-X_test[:, 1] ** 2)
mse = np.mean((y_pred.flatten() - y_true) ** 2)
print(f"\n测试集 MSE: {mse:.6f}")
print(f"最终训练 Loss: {losses[-1]:.6f}")
这段代码展示了 KAN 的完整训练流程。与 MLP 不同,KAN 的可学习参数全部在边上的 B-Spline 函数中,训练过程就是调整这些 B-Spline 的控制点。
📊 三、KAN vs MLP:深度对比分析
3.1 参数效率对比
KAN 最引人注目的优势是参数效率。在科学计算场景中,KAN 通常能用 10-100 倍更少的参数达到与 MLP 相同的精度:
| 任务 | MLP 参数量 | KAN 参数量 | KAN 精度 | 参数效率比 |
|---|---|---|---|---|
| $f(x) = \sin(x)$ | 1000 (2层, 50节点) | 40 (2层, 5条边) | 99.2% | 25x |
| $f(x,y) = x^2 + y^2$ | 5000 (3层) | 120 (3层) | 98.7% | 42x |
| 物理方程求解(PDE) | 50000 | 800 | 97.5% | 62x |
| 符号回归(5变量) | 20000 | 300 | 99.1% | 67x |
⚠️ 警告: 上述数据来自特定的基准测试场景。在图像识别、NLP 等高维感知任务中,MLP(及其变体 Transformer)仍然占据优势。KAN 的优势主要体现在低维函数逼近和科学计算领域。
3.2 可解释性:从黑箱到公式
KAN 最独特的能力是可解释性。训练完成后,可以将边上的 B-Spline 函数拟合为符号公式,从而得到网络的解析表达式。
传统 MLP 训练完成后,你得到的是一堆权重矩阵——无法从中提取人类可理解的数学公式。而 KAN 训练完成后,你可以:
- 可视化每条边上的激活函数形状
- 识别出哪些边学到了 $\sin$、$\exp$、$x^2$ 等已知函数
- 将整个网络简化为一个符号表达式
这意味着 KAN 可以用于符号回归(Symbolic Regression)——从数据中自动发现物理定律或数学公式。MIT 团队展示了一个令人印象深刻的例子:给 KAN 喂入物理实验数据,它能自动发现对应的物理方程。
3.3 训练速度与扩展性
KAN 的主要劣势在于训练速度。由于每条边上都有独立的 B-Spline 函数,计算量远大于 MLP 的矩阵乘法:
| 指标 | MLP | KAN | 差距 |
|---|---|---|---|
| 单次前向传播(1000样本) | 0.8ms | 12ms | 15x 慢 |
| 单次反向传播 | 1.2ms | 25ms | 21x 慢 |
| 1000 轮训练(1000样本) | 2s | 37s | 18x 慢 |
| 参数数量(同等精度) | 10,000 | 200 | 50x 少 |
| GPU 加速效果 | 极好 | 一般 | 矩阵乘法 vs 逐点计算 |
⚡ 关键结论: KAN 的价值不在于训练速度,而在于用更少的参数达到更高的精度和可解释性。如果你的场景是「训练一次,长期使用」(如科学计算、符号回归),KAN 的额外训练时间是值得的。
3.4 何时选择 KAN vs MLP
根据实际项目经验,以下场景适合选择 KAN:
✅ 适合 KAN 的场景:
- 低维函数逼近(输入维度 < 20)
- 需要可解释性的科学计算
- 符号回归(从数据发现公式)
- 物理方程求解(PDE)
- 需要极端压缩的边缘部署
❌ 不适合 KAN 的场景:
- 高维数据(图像、文本、音频)
- 需要实时推理的在线服务
- 数据量极大的场景
- 已有成熟的 Transformer/CNN 方案
🔬 四、生产实践与前沿进展
4.1 PyTorch 版本使用
在实际项目中,建议使用 MIT 团队官方维护的 pykan 库,它基于 PyTorch 实现,支持 GPU 加速和自动微分:
# 使用 pykan 库(推荐的生产方案)
# pip install pykan
from kan import KAN
import torch
# 创建 KAN:输入维度 2,隐藏层 [5, 5],输出维度 1
model = KAN(width=[2, 5, 5, 1], grid=5, k=3, seed=42)
# 准备数据(pykan 使用字典格式)
dataset = {
'train_input': torch.randn(1000, 2),
'train_label': torch.randn(1000, 1),
'test_input': torch.randn(200, 2),
'test_label': torch.randn(200, 1),
}
# 训练(pykan 内置了自适应网格细化)
model.train(dataset, opt='LBFGS', steps=50, lamb=0.01)
# 可视化学习到的激活函数
model.plot()
# 自动符号回归:提取公式
model.auto_symbolic(lib=['exp', 'sin', 'x^2'])
formula = model.symbolic_formula()
print(f"发现的公式: {formula}")
💡 提示: pykan 的
auto_symbolic()功能可以自动将 B-Spline 拟合为已知数学函数,这是 KAN 最强大的功能之一。通过设置lib参数,你可以指定候选函数库。
4.2 2026 年最新进展
KAN 发布以来,社区已经产生了大量改进工作:
高效 KAN 变体:
- FourierKAN:用傅里叶基函数替代 B-Spline,对周期性函数效果更好
- WaveletKAN:用小波基函数替代,多尺度分析能力更强
- ChebyKAN:用切比雪夫多项式替代,数值稳定性更好
- ReKAN:加入残差连接,解决深层 KAN 的梯度问题
工程优化:
- 稀疏 KAN:通过 L1 正则化自动剪枝不重要的边,进一步减少参数
- 量化 KAN:将 B-Spline 系数量化为 INT8,适合边缘部署
- 硬件加速:FPGA 上的 KAN 推理速度已达到 MLP 的 3-5 倍(得益于一维函数的并行计算)
4.3 避坑指南
在实际使用 KAN 时,以下是常见的坑点:
⚠️ 坑点 1:网格数量选择
- 网格太少(< 10):函数拟合能力不足
- 网格太多(> 100):训练慢,过拟合风险高
- ✅ 推荐:从 20 开始,根据任务复杂度调整
⚠️ 坑点 2:学习率敏感
- KAN 的学习率通常比 MLP 小 10-100 倍
- ✅ 推荐:使用 L-BFGS 优化器(pykan 默认),比 SGD 稳定得多
⚠️ 坑点 3:输入维度限制
- KAN 的参数量随输入维度指数增长($O(n_{in} \times n_{out})$ 个 B-Spline)
- ✅ 推荐:输入维度超过 20 时,先用 PCA 降维
⚠️ 坑点 4:不要用于高维感知任务
- 图像分类、NLP 等任务中,KAN 的表现远不如 Transformer
- ✅ 推荐:将 KAN 用于科学计算和符号回归,感知任务交给 Transformer
💡 总结与建议
KAN 代表了一种新的神经网络设计哲学——将可学习性放在激活函数中,而不是权重矩阵中。这种设计在可解释性和参数效率上展现了惊人的优势,但在高维任务和训练速度上仍有差距。
核心建议:
- ✅ 如果你在做科学计算或符号回归,立即尝试 KAN
- ✅ 如果你需要模型可解释性(如金融风控、医疗诊断),KAN 是比 SHAP/LIME 更根本的解决方案
- ✅ 如果你的输入维度 < 20 且需要极致压缩,KAN 可以用 MLP 1/50 的参数达到相同精度
- ❌ 如果你在做图像/NLP/推荐系统,继续使用 Transformer
- ❌ 如果你需要实时推理(< 1ms),KAN 的推理开销可能不满足要求
推荐工具链:
- pykan:MIT 官方 PyTorch 实现,功能最完整
- EfficientKAN:社区优化版本,训练速度提升 3-5 倍
- KAN-HD:支持高维输入的变体
- tinykan:Rust 实现的轻量级推理引擎,适合边缘部署
KAN 是否会「颠覆」MLP?短期内不会。但它证明了一个重要的事实:神经网络架构的创新空间远未耗尽。在大模型时代,这种回归基础数学的思路,或许正是推动下一波 AI 进步的关键。