與自然語言處理類似,對預訓練視覺主干的遷移提高了模型在各種視覺任務上的性能。更大的數據集、可擴展的架構和新的訓練方法都推動了模型性能的提升。
然而,視覺模型仍然遠遠落后于語言模型。具體來說,迄今為止最大的視覺模型 ViT 只有 4B 參數,而入門級語言模型通常超過 10B 參數,更別說具有 540B 參數的大型語言模型。
為了探索 AI 模型的性能極限,Google Research 最近在 CV 領域的一項研究,率先將 Vision Transformer 參數量擴展到了 22B,提出 ViT-22B,與之前類似的模型參數量 4B 相比,可以說這是迄今為止最大的稠密型 ViT 模型。
論文地址:https://arxiv.org/pdf/2302.05442.pdf
對比之前最大的 ViT- G 和 ViT-e,表 1 給出了比較結果,由下表可得,ViT-22B 主要是擴展了模型的寬度,使得參數量更大,深度和 ViT-G 一樣。
當前的 ViT 大模型
正如這位知乎網友所說,難道是谷歌在 ChatGPT 上輸了一局,勢必要在 CV 領域爭口氣?
如何做到的?原來研究早期,他們發現在擴展 ViT 的過程中,出現了訓練不穩定性,并且可能會帶來架構變化。然后研究人員仔細設計模型,并且以前所未有的效率來實現模型并行訓練。ViT-22B 的質量是通過一套全面的任務來評估的,從(少樣本)分類到密集輸出任務,在這些任務中,它達到或超過了當前 SOTA 水平。例如,即使用作凍結的視覺特征提取器,ViT-22B 在 ImageNet 上的準確率也達到了 89.5%。通過訓練 text tower 來匹配這些視覺特征,它在 ImageNet 上實現了 85.9% 的零樣本設置準確率。此外,該模型可以看作是一個教師,用作蒸餾目標,研究人員訓練了一個 ViT-B 學生模型,在 ImageNet 上的準確率為 88.6%,達到了此類規模模型上 SOTA 水平。
模型架構
ViT-22B 是一種基于 Transformer 的編碼器模型,類似于原始 Vision Transformer 架構,但包含以下三個主要修改,以提高效率和大規模訓練的穩定性:并行層、查詢 / 鍵(QK)歸一化和 omitted biases。
并行層。正如 Wang 和 Komatsuzaki 研究所述,該研究設計了一個 Attention 和 MLP 并行結構:
這可以通過組合 MLP 和注意力塊的線性投影來實現額外的并行化。值得注意的是,用于查詢 / 鍵 / 值投影的矩陣乘法和 MLP 的第一線性層被融合到一個單獨的操作中,對于 MLP 的注意力外投影和第二層線性層也是如此。
QK 歸一化。訓練大模型的一個困難是模型的穩定性,在將 ViT 擴展的過程中,研究人員發現在幾千輪的 step 后訓練損失呈發散性。特別是在 8B 參數的模型中這種現象尤為突出。為了穩定模型訓練,研究人員采用 Gilmer 等人的方法,在點積注意力計算之前對查詢和鍵應用 LayerNorm 歸一化操作,以提升訓練的穩定性。具體來說,注意力權重計算為:
omitted biases。在 PaLM 之后,偏置項從 QKV 投影中移除,并且所有的 Layernorm 都在沒有偏置的情況下應用,從而提高了加速器的利用率 (3%),且質量沒有下降。然而,與 PaLM 不同的是,研究人員對 MLP 密集層使用了偏置項,即便如此,這種方式在兼顧質量的同時,速度沒有下降。
圖 2 展示了一個 ViT-22B 編碼器塊。嵌入層在原有 ViT 的基礎上進行了 patch 提取、線性投影和添加位置嵌入等操作。研究人員使用多頭注意力池化來聚合頭中的每個 token 表示。
ViT-22B 使用 14 × 14 的 patch,圖像分辨率為 224 × 224。ViT-22B 采用了一種學習到的一維位置嵌入。在對高分辨率圖像進行微調期間,研究人員根據預訓練的位置嵌入在原始圖像中的位置執行二維插值。
訓練基礎設施與效率
ViT-22B 使用 FLAX 庫,實現方式是 JAX,并在 Scenic 中構建。它同時利用了模型和數據并行性。值得一提的是,研究人員使用了 jax. xmap API,它提供了對所有中間體的分片(例如權重和激活)以及芯片間通信的顯式控制。研究人員將芯片組織成大小為 t × k 的 2D 邏輯網格,其中 t 是數據平行軸的大小,k 是模型軸的大小。然后,對于 t 組中的每個組,k 個設備獲得相同批次的圖像,每個設備只保留 1/k 的激活,并負責計算所有線性層輸出的 1/k(詳細內容如下)。
圖 3:異步并行線性操作(y = Ax):跨設備的重疊通信和計算的模型并行矩陣乘法。
異步并行線性操作。為了最大限度地提高吞吐量,必須考慮計算和通信。也就是說,如果希望這些操作在分析上等效于未分片的情況,就必須盡可能少地進行通信,理想情況下讓它們重疊,這樣就可以保持矩陣乘法單元(FLOP 的大部分容量所在)始終處于繁忙狀態。
參數分片。該模型在第一個軸上是數據并行的。每個參數可以在這個軸上完全復制,也可以讓每個設備保存它的一個塊。研究人員選擇從模型參數中分割一些大張量,以便能夠擬合更大的模型和批量大小。
使用這些技術,ViT-22B 在 TPUv4 上訓練期間,每個核每秒處理 1.15k token。ViT-22B 的模型 flops 利用率(MFU)為 54.9%,表明硬件的使用非常有效。請注意,PaLM 報告的 MFU 為 46.2%,而研究人員在相同硬件上為 ViT-e(僅數據并行)測量的 MFU 為 44.0%。
實驗結果
實驗探究了 ViT-22B 用于圖像分類的評估結果。
表 2 結果顯示,ViT-22B 在各種指標上仍有顯著的改善。此外,研究表明,像 ViT-22B 這樣的大型模型的 Linear probing 可以接近或超過具有高分辨率的小型模型的 full fine-tuning 性能,通常成本更小、更容易做到。
研究進一步在細粒度分類數據集 iNaturalist 2017 上測試線性可分離性,將 ViT-22B 與其他 ViT 變體進行比較。研究測試了 224px 和 384px 的輸入分辨率。結果如圖 4。研究觀察到 ViT-22B 明顯優于其他 ViT 變體,特別是在標準的 224px 輸入分辨率下。這表明 ViT-22B 中大量的參數對于從圖像中提取詳細信息是有用的。
表 3 顯示了 ViT-22B 對 CLIP、ALIGN、BASIC、CoCa、LiT 模型的零樣本遷移結果。表 3 底部比較了三個 ViT 模型性能。
在所有的 ImageNet 測試集中,ViT-22B 取得了相當或更好的結果。值得注意的是,ObjectNet 測試集上的零樣本結果與 ViT 模型大小高度相關。最大的 ViT-22B 將新的 SOTA 設置在具有挑戰性的 ObjectNet 測試集中。
Out-of-distribution (OOD)。研究構建了一個從 JFT 到 ImageNet 的標簽映射,以及從 ImageNet 到不同分布外數據集的標簽映射,即 ObjectNet、ImageNet-v2、ImageNet- R 和 ImageNet- A。
目前可以確認的結果是,與 ImageNet 上的改進一致,擴展模型增加了分布外性能。這適用于只看過 JFT 圖像的模型,以及在 ImageNet 上進行微調的模型。在這兩種情況下,ViT-22B 在更大的模型上都延續了 OOD 性能更好的趨勢(圖 5,表 11)。
此外,研究人員還研究了 ViT-22B 模型在語義分割和單目深度估計任務中捕獲的幾何和空間信息質量。
語義分割。研究人員在三個基準上評估 ViT-22B 作為語義分割主干:ADE20K、Pascal Context 和 Pascal VOC。從表 4 可以看出,當只看到少量分割掩碼時,ViT-22B 主干遷移效果更好。
單目深度估計。表 5 總結了研究的主要發現。從最上面的行(DPT 解碼器)中可以觀察到,與不同的主干相比,使用 ViT-22B 特性產生了最好的性能(在所有指標上)。通過將 ViT-22B 主干與 ViT-e(一個較小的模型,但在與 ViT-22B 相同的數據上進行訓練)進行比較,研究發現擴展架構可以提高性能。
此外,將 ViT-e 主干與 ViT-L(與 ViT-e 類似的架構,但訓練的數據更少)進行比較,研究發現這些改進也來自于擴展訓練前的數據。這些發現表明,更大的模型和更大的數據集都有助于提高性能。
該研究還在視頻數據集上進行了探索。表 6 展示了在 Kinetics 400 和 Moments in Time 數據集上的視頻分類結果,表明可以使用凍結的主干實現具有競爭力的結果。研究首先與 ViT-e 進行比較,ViT-e 擁有最大的先驗視覺主干模型,由 40 億個參數組成,并且也在 JFT 數據集上進行訓練。我們觀察到更大的 ViT-22B 模型在 Kinetics 400 上提高了 1.5 分,在 Moments in Time 上提高了 1.3 分。
最后研究注意到,通過完整的端到端微調,還有進一步改進的空間。
更多技術細節請參閱原論文。