FP8 為什麼也能訓練(含 Transformer 誤差界限推導)

更新 發佈閱讀 6 分鐘

直接談核心:

👉 為什麼只有 8-bit 的 FP8,仍然可以訓練像 GPT 這種超大型 Transformer?

使用「表示法 → 誤差模型 → Transformer 誤差傳播 → 為什麼不崩 → 工程關鍵」一步推導。


一、FP8 是什麼?

FP8 不是單一格式,主流有兩種:

1️⃣ E4M3(精度較高)

  • exponent:4 bit
  • mantissa:3 bit

2️⃣ E5M2(範圍較大)

  • exponent:5 bit
  • mantissa:2 bit

👉 對比:

vocus|新世代的創作平台

👉 關鍵事實:

FP8 幾乎沒有精度(只有 2~3 bit mantissa)


二、浮點誤差模型(數學核心)

任意浮點數可寫成:

fl(x) = x(1 + δ)

其中:

|δ| ≤ ε(machine epsilon)

FP8 的 ε 很大!

大約:

ε ≈ 2^{-2} ~ 2^{-3}0.25 ~ 0.125

👉 也就是:

❗ 單次誤差可能高達 10%~25%


三、問題來了:這麼大誤差為什麼不爆?

關鍵在:

👉 誤差在 Transformer 中會被「平均 + 正規化 + 抵消」


四、矩陣乘法誤差分析(核心推導)

Transformer 核心:

Y = XW

加入 FP8 誤差:

fl(X) = X(1 + δ₁)
fl(W) = W(1 + δ₂)

計算結果:

fl(Y) = fl(X) · fl(W)
XW (1 + δ₁ + δ₂ + δ₁δ₂)

忽略高階項:

XW (1 + δ)

👉 關鍵結論:

矩陣乘法的誤差仍然是「相對誤差」


五、誤差不會爆炸的原因(超重要)


1️⃣ 大數平均效應(Law of Large Numbers)

矩陣乘法本質:

y_i = Σ x_j w_j

👉 每一項都有誤差:

x_j w_j (1 + δ_j)

總和:

Σ x_j w_j + Σ x_j w_j δ_j

如果 δ 是隨機:

Σ δ_j ≈ 0(互相抵消)

👉 結論:

誤差會被平均掉


2️⃣ LayerNorm(數值穩定器)

Transformer 每層都有:

x' = (x - μ) / σ

👉 效果:

  • 拉回均值 0
  • 控制方差 = 1

👉 結論:

任何誤差都被重新縮放


3️⃣ Residual connection(殘差結構)

y = x + F(x)

👉 如果 F(x) 有誤差:

  • 原始 x 仍保留

👉 結論:

誤差不會累積爆炸


4️⃣ Softmax(只看相對大小)

softmax(x_i) = e^{x_i} / Σ e^{x_j}

👉 如果有誤差:

x_i → x_i + ε

結果影響:

👉 只影響「排序」,不是絕對值


👉 結論:

對低精度非常穩定


六、訓練動態的誤差界限

我們看參數更新:

θ_{t+1} = θ_t - η ∇L

加入 FP8:

L' =L(1 + δ)

更新變成:

θ_{t+1} = θ_t - η∇L - ηδ∇L

👉 誤差項:

ηδ∇L

如果:

  • η 很小(學習率)
  • δ 雖大但隨機

👉 長期效果:

Σ ηδ∇L ≈ 隨機噪聲

👉 結論:

FP8 誤差 ≈ SGD 噪聲


七、真正讓 FP8 可行的工程關鍵


1️⃣ Per-tensor scaling(最重要)

FP8 範圍太小 → 必須縮放:

x_FP8 = x / scale

👉 每個 tensor:

  • 動態選 scale
  • 保持數值落在可表示範圍


2️⃣ Mixed precision(一定有 FP16/32)

實際做法(NVIDIA GPU):

vocus|新世代的創作平台

累加一定用高精度


3️⃣ Stochastic rounding(隨機捨入)

不是:

0.490

而是:

以機率保留精度

👉 效果:

讓誤差變成“無偏噪聲”


八、為什麼 GPT 特別適合 FP8?


1️⃣ 超大維度(誤差平均)

  • hidden size:數千
  • attention:上萬乘法

👉 誤差自然抵消


2️⃣ 正規化層很多

  • LayerNorm
  • RMSNorm

👉 不讓數值爆炸


3️⃣ 訓練本來就是 noisy optimization

👉 FP8 只是多一點 noise


九、直觀理解(非常關鍵)

👉 FP8 ≈「極度粗糙的測量工具」

但:

  • 你在爬山(找最小值)
  • 不需要精確高度
  • 只需要「往下走」

十、為什麼能快這麼多?

在 NVIDIA GPU(如 Hopper / Blackwell):

vocus|新世代的創作平台

👉 原因:

  • 記憶體 ↓ 4倍
  • 帶寬 ↓ 4倍
  • Tensor Core 專用 FP8

十一、最終數學總結


誤差來源:

δ ≈ O(10^{-1})

但最終影響:

Σ δ_i / N0(平均)

訓練誤差:

FP8 誤差 ≈ SGD noise

十二、一句話總結

👉 FP8 能訓練 GPT,因為深度學習本質是“統計收斂”,不是“精確計算”




留言
avatar-img
sirius數字沙龍
18會員
427內容數
吃自助火鍋啦!不要客氣,想吃啥,請自行取用!
sirius數字沙龍的其他內容
2026/04/07
這個問題其實切到現代 AI 的核心: 👉 為什麼低精度(FP16 / BF16)不但可用,還能訓練像 GPT 這樣的大模型? 用「直覺 → 數學 → 工程技巧 → 為什麼可行」四層說明。 一、核心直覺(先講結論) 👉 神經網路不需要“精確”,只需要“方向正確” 訓練本質是: 參數
Thumbnail
2026/04/07
這個問題其實切到現代 AI 的核心: 👉 為什麼低精度(FP16 / BF16)不但可用,還能訓練像 GPT 這樣的大模型? 用「直覺 → 數學 → 工程技巧 → 為什麼可行」四層說明。 一、核心直覺(先講結論) 👉 神經網路不需要“精確”,只需要“方向正確” 訓練本質是: 參數
Thumbnail
2026/04/07
浮點數(floating-point)在不同程式語言中的差異,本質不是數學不同,而是「實作細節不同」。核心標準幾乎都來自 👉 IEEE 754 但「語言怎麼用、預設精度、誤差處理」會讓結果看起來不一樣。 拆成 4 層:標準 → 差異來源 → 各語言比較 → 實例 一、浮點數本質
Thumbnail
2026/04/07
浮點數(floating-point)在不同程式語言中的差異,本質不是數學不同,而是「實作細節不同」。核心標準幾乎都來自 👉 IEEE 754 但「語言怎麼用、預設精度、誤差處理」會讓結果看起來不一樣。 拆成 4 層:標準 → 差異來源 → 各語言比較 → 實例 一、浮點數本質
Thumbnail
2026/04/07
這一個題目已經從「數字表示」進入電腦底層操作核心了。 bitmask(位元遮罩)本質就是:👉 用二進制的每一個 bit 當作開關(0/1)來控制資料 用「權限 → 一般資料 → AI tensor」三層來說明。 一、什麼是 bitmask?
Thumbnail
2026/04/07
這一個題目已經從「數字表示」進入電腦底層操作核心了。 bitmask(位元遮罩)本質就是:👉 用二進制的每一個 bit 當作開關(0/1)來控制資料 用「權限 → 一般資料 → AI tensor」三層來說明。 一、什麼是 bitmask?
Thumbnail
看更多