矩阵乘法 = 线性变换的复合。AB 先做 B 的变换,再做 A 的变换。

矩阵的基本运算

package main

import "fmt"

type Matrix struct {
    rows, cols int
    data       [][]float64
}

func NewMatrix(rows, cols int) *Matrix {
    data := make([][]float64, rows)
    for i := range data {
        data[i] = make([]float64, cols)
    }
    return &Matrix{rows, cols, data}
}

// 矩阵乘法:O(n³)
func Mul(a, b *Matrix) *Matrix {
    if a.cols != b.rows {
        panic("维度不匹配")
    }
    c := NewMatrix(a.rows, b.cols)
    for i := 0; i < a.rows; i++ {
        for j := 0; j < b.cols; j++ {
            for k := 0; k < a.cols; k++ {
                c.data[i][j] += a.data[i][k] * b.data[k][j]
            }
        }
    }
    return c
}

// 转置
func Transpose(a *Matrix) *Matrix {
    b := NewMatrix(a.cols, a.rows)
    for i := 0; i < a.rows; i++ {
        for j := 0; j < a.cols; j++ {
            b.data[j][i] = a.data[i][j]
        }
    }
    return b
}

func main() {
    // 2x2 矩阵乘法
    a := &Matrix{2, 2, [][]float64{{1, 2}, {3, 4}}}
    b := &Matrix{2, 2, [][]float64{{5, 6}, {7, 8}}}
    c := Mul(a, b)
    fmt.Println(c.data)  // [[19 22] [43 50]]
}

高斯消元:解线性方程组

高斯消元把增广矩阵化为行阶梯型,然后回代求解。时间复杂度 O(n³)。是数值线性代数的基础,背后是矩阵的 LU 分解。

package main

import (
    "fmt"
    "math"
)

// 高斯消元解 Ax = b,返回解向量 x
// augmented: n x (n+1) 增广矩阵 [A|b]
func gaussianElimination(aug [][]float64) []float64 {
    n := len(aug)
    for col := 0; col < n; col++ {
        // 选主元(部分主元法,减少数值误差)
        maxRow := col
        for r := col + 1; r < n; r++ {
            if math.Abs(aug[r][col]) > math.Abs(aug[maxRow][col]) {
                maxRow = r
            }
        }
        aug[col], aug[maxRow] = aug[maxRow], aug[col]

        pivot := aug[col][col]
        if math.Abs(pivot) < 1e-10 {
            continue  // 奇异矩阵
        }

        // 消元
        for r := col + 1; r < n; r++ {
            factor := aug[r][col] / pivot
            for j := col; j <= n; j++ {
                aug[r][j] -= factor * aug[col][j]
            }
        }
    }

    // 回代
    x := make([]float64, n)
    for i := n - 1; i >= 0; i-- {
        x[i] = aug[i][n]
        for j := i + 1; j < n; j++ {
            x[i] -= aug[i][j] * x[j]
        }
        x[i] /= aug[i][i]
    }
    return x
}

func main() {
    // 解方程组:x + 2y = 5,3x + 4y = 6
    aug := [][]float64{
        {1, 2, 5},
        {3, 4, 6},
    }
    x := gaussianElimination(aug)
    fmt.Printf("x=%.1f, y=%.1f\n", x[0], x[1])  // x=-4.0, y=4.5
}

矩阵与神经网络的关系

神经网络的一层前向传播可以写成 output = activation(W × input + b),其中 W 是权重矩阵,b 是偏置向量。批量计算时,同时处理 batch_size 个样本,input 变成矩阵,整个计算是两次矩阵乘法加一次激活函数。GPU 之所以快,是因为它专门为大矩阵乘法优化了硬件。

矩阵概念神经网络对应Go 实现
权重矩阵 W全连接层的参数[][]float64
矩阵乘法 Wx线性变换(前向传播)Mul(W, x)
转置 Wᵀ反向传播的梯度计算Transpose(W)
矩阵求逆正规方程法求解 θ高斯消元 / LU 分解
行列式判断矩阵是否可逆递归展开 / LU 分解
口诀前向传播 = 矩阵乘法链,反向传播 = 链式法则 + 转置。
推荐做法
  • 用列优先存储(column-major)提高缓存命中率
  • 大矩阵运算直接用 gonum 或 NumPy 风格库
  • 数值计算中始终用部分主元法而非朴素消元
不推荐
  • 矩阵乘法不满足交换律——AB ≠ BA,顺序要对
  • 直接求矩阵逆来解线性方程组——用 LU 分解更稳定更快
常见误区
  • 浮点累加误差在高维矩阵乘法中会放大——必要时用 Kahan 求和

判断标准:能用矩阵表示一个神经网络层,并写出前向传播代码 → 掌握本章。

线性代数是数学中最实用的分支——它统一了几何、方程组和数据科学。

— Gilbert Strang《线性代数导论》