JAXとPyTorchの速度検証

JAXとPyTorchの性能差を検証します。今回、ベンチマークはpytestのbenchmarkモジュールを使って行いました。

今回のコードはここに置いてあります。python_bench/neural_network at master · Chachay/python_bench

環境

  • python 3.10
  • numpy 1.24.3
  • pytorch 2.0.0
  • jax
  • CUDA 11.7 + cudnn 8.5.0
  • NVIDIA Driver Version: 531.14
  • NVIDIA RTX3060

あいにくWindowsではPyTorch2.0のJIT機能は使えません。

ライブラリの性能

AlexNetおよびGoogleNetで比較し、JAXはPyTorchの1.4~2.0倍の性能とわかりました。Whisper-jaxでPyTorchからJAXに書き換えた性能向上分2倍と同等です。また、JAXのチュートリアルでも、2.5~3.4倍の性能と紹介されており、妥当な結果と思われます。

性能差はAlexNetやVGGのような単純な2次元畳み込みよりも、GoogleNetやResNetのようなモデルのほうがつきやすいようです。

両ライブラリ推論性能(ms)
Network
Flax PyTorch
AlexNet 2.7 (1.0) 3.8 (1.4)
GoogleNet 37.0 (1.0) 81.1 (2.2)

AlexNetの実装

推論(Inference)での性能差を計測するため、全結合層や学習に関する層を省略したAlexNetモデルを定義した。

import torch
from torch import nn

class AlexNetPyTorch(nn.Module):
    def __init__(self):
        super(AlexNetPyTorch, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )

    def forward(self,x):
        x = self.features(x)
        return x
from flax import linen as nn

class AlexNetFlax(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(96, (11, 11), strides=(4, 4), name='conv1')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        x = nn.Conv(256, (5, 5), padding=((2, 2),(2, 2)), name='conv2')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2 ,2))

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1, 1)), name='conv3')(x)
        x = nn.relu(x)

        x = nn.Conv(384, (3, 3), padding=((1, 1),(1 ,1)), name='conv4')(x)
        x = nn.relu(x)

        x = nn.Conv(256, (3, 3), padding=((1, 1),(1, 1)), name='conv5')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2))

        return x

GoogleNet

こちらはGitHubのレポジトリをご参照ください。

ベンチのコード

ベンチはpytest-benchmarkを使って準備しました。イメージを共有するため簡略化したものを紹介します。 ベンチはrun(x)を20回繰り返し実行(iterations)する計測を3セット(rounds)した結果を返します。 各測定の前に2回のwarmupを行います。

import pytest
import numpy as np

import torch
import jax
import jax.numpy as jnp

from models_flax import AlexNetFlax
from models_pytorch import AlexNetPyTorch

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetPytorch(benchmark):
    model = AlexNetPyTorch()
    model.to('cuda')
    model.eval()

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、チャネル数、画像高さ、画像幅 [N, C, H, W]
    x = np.random.rand(16, 3, 224, 224).astype(np.float32)
    x = torch.from_numpy(x).to('cuda')

    def run(_x):
        with torch.no_grad():
            return model(_x)

    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

@pytest.mark.benchmark(
    group="AlexNet",
    warmup=True
)
def test_AlexNetFlax(benchmark):
    model = AlexNetFlax()

    key1, key2 = jax.random.split(jax.random.PRNGKey(0))

    # FlaxとPyTorchでデータの並び順が異なることに注意
    # バッチ数、画像高さ、画像幅、チャネル数 [N, H, W, C]
    x = jax.random.normal(key1, (16, 224, 224, 3))
    weight = model.init(key2, x) # Initialization cal

    @jax.jit
    def run(_x):
        y = model.apply(weight, _x)
        # JAXは非同期実行するのでベンチのため結果がでるのを待ちます。
        jax.block_until_ready(y)
        return y

    # warm_upラウンドを2回いれることで、jitの時間を除外する
    benchmark.pedantic(run, args=(x,), warmup_rounds=2, iterations=20, rounds=3)

if __name__ == "__main__":
    pytest.main(['-v', __file__])

実行

全部の条件のベンチを行うコマンドはこちらです。

pytest benchmark_main.py --benchmark-compare

結果はこのようにグループで出力されます。

--------------------------------------------------------------------------- benchmark 'AlexNet': 2 tests ---------------------------------------------------------------------------
Name (time in ms)          Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_AlexNetFlax        2.7234 (1.0)      2.7352 (1.0)      2.7289 (1.0)      0.0059 (1.0)      2.7283 (1.0)      0.0088 (1.0)           1;0  366.4413 (1.0)           3          20
test_AlexNetPytorch     3.7357 (1.37)     3.9136 (1.43)     3.8180 (1.40)     0.0897 (15.15)    3.8047 (1.39)     0.1335 (15.09)         1;0  261.9172 (0.71)          3          20
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

----------------------------------------------------------------------------- benchmark 'GoogleNet': 2 tests ----------------------------------------------------------------------------
Name (time in ms)             Min                Max               Mean            StdDev             Median               IQR            Outliers      OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_GoogleNetFlax        36.6693 (1.0)      37.2125 (1.0)      36.9712 (1.0)      0.2766 (1.17)     37.0316 (1.0)      0.4074 (1.15)          1;0  27.0481 (1.0)           3          20
test_GoogleNetPytorch     80.8444 (2.20)     81.3176 (2.19)     81.0850 (2.19)     0.2367 (1.0)      81.0931 (2.19)     0.3549 (1.0)           1;0  12.3327 (0.46)          3          20
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

個別に実行する場合

# ひとつだけ
pytest .\benchmark_main.py::test_AlexNetPytorch
# 複数
pytest .\benchmark_main.py -k "test_GoogleNetPytorch or test_GoogleNetFlax"

参考

Win11でJAX!

OpenAIのwhisperをpytorchからjaxに書き直して70倍速くなった (sanchit-gandhi/whisper-jax)というニュースでjaxに興味持ちました。 実のところ、pytorch→jaxの寄与分は、この70倍のうち2倍とのことなのですが、それでもかなりのパフォーマンスです。

まずはjaxの手配から。WSL2 or Dockerとも思いましたが、Windowsネイティブで実行を目指しました。

jaxのビルド(Windows)

jaxはWindows向けに公式バイナリが配布されておらず、自分でビルドする必要があります。少し前までコミュニティビルドバイナリがあったようなのですが、23年4月30日現在、jaxlib 0.3.17+CUDA 11.1などバージョンが古いものしか見当たりません。

ビルド環境

コンパイラのバージョンの組み合わせなどによってビルドが通ったり通らんかったりしそうなのでメモしておきます。

このほか大事な点
  • Win 11の設定で開発者モードを有効にする
  • Bazelのバージョンは大事。bazeliskを使うと良いバージョンのbazelを選んでくれる。
  • realpathなど、bash系のコマンドを導入すること。Git bash付属のものを利用可能。(msys2のScoopならC:\Users\{ユーザ名}\scoop\apps\msys2\2023-03-18\usr\binなど)
  • jaxのクローンはできるだけドライブ直下に。パス長がギリギリになる。
  • exFATのドライブを使うとSynbolicLinkを作れないのでエラーが出る
  • サブコンポになってるTensorFlowなどがVC2022に対応してないかも (Issue #60062 · tensorflow/tensorflow). 必要に応じて環境変数BAZEL_VCを定義する(C:\Program Files(x86)\Microsoft Visual Studio\2019\BuildTools\VC)
私が確認した範囲ではjaxlib v0.3.24+ CUDA 11.7がWindowsでビルドできる最新の組み合わせでした。v0.4.7やCUDA 12.1はダメそう。

WindowsをP3MV3(1.6 USD/hr)をSelf-Hostedするお金があったらMatrix作って確かめます!

Git bash付属のコマンド類

Bazelの公式ではmsys2と書かれているがgit bashと周辺ツールのほうが素性が良さそうで、このあたりが使えるよう環境変数の$env:pathに追加します。

  • C:\Program Files\Git\cmd
  • C:\Program Files\Git\mingw64\bin
  • C:\Program Files\Git\usr\bin

追加後、特にBazelが使いたがるrealpathが動作すれば良い。

あわせてBAZEL_SH=C:\Program Files\Git\usr\bin\bash.exeとしておく。

ビルド作業

環境を整えたあと、ビルド。condaはjax用に環境つくっておきました。

conda create -n jax python=3.10
conda activate jax
conda install numpy

cd d:/
git clone https://github.com/google/jax.git
cd jax
git checkout jaxlib-v0.3.24
python .\build\build.py --enable_cuda `
  --cuda_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cudnn_path="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.7" `
  --cuda_version="11.7" --cudnn_version="8.5.0" --bazel_statup_options="--output_user_root=d:/tmp" `
  --bazel_path="D:/bazel.exe"

ビルド時間はCore i7 10th Gen(8 Core)で3時間くらいというとこでしょうか。

途中、CUDAコードのビルド時に依存パッケージの文字コード警告が出て、その後エラーが出てしまうことがありました。

external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(1): warning C4819: The file contains a character that cannot be represented in the current code page (932). Save the file in Unicode format to prevent data loss
...(中略)...
external/org_tensorflow/tensorflow/compiler/xla/stream_executor/gpu/asm_compiler.cc(106): error C2065: 'ptxas_path': undeclared identifier
...(後略)...

こちらはコンパイラの警告 (レベル 1) C4819 | Microsoft Learnに従って、asm_compiler.ccをBOM付UTF8(さらに念のため改行コードをCRLFに変換)し、build.pyを実行するコマンドを繰り返すことで完了までたどり着きました。 当該ファイルに変な文字が入っているようには見えなかったので不思議です。

最後までビルドが通ると下記ログがでます。

C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\setuptools\command\install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
C:\Users\chachay\miniconda3\envs\jax\lib\site-packages\wheel\bdist_wheel.py:83: RuntimeWarning: Config variable 'Py_DEBUG' is unset, Python ABI tag may be incorrect
if get_flag("Py_DEBUG", hasattr(sys, "gettotalrefcount"), warn=(impl == "cp")):
Output wheel: D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

To install the newly-built jaxlib wheel, run:
  pip install D:\jax\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

余談

github actionでバイナリ作ろうとしたらホストのメモリが足りずヒープエラーで強制終了したのですが、

  1. github actionでページングを有効化する(actions/configure-pages)
  2. BAZELの最大利用メモリを制限する.--local_ram_resources=2048(TF - Bazel Build options)
といった方法で解決できるそうです。ちなみに1を使いました。ただ、Github actionが360分でタイムアウトするので工夫が必要かなと思います。

whl配布するならselfhosted serverが欲しくなります…。

インストール

完成品のjaxlibはd:/jax/distにあります。jaxやflaxとあわせてインストールします。

cd d:/jax/dist
conda activate jax
pip install flax==0.6.4 . .\dist\jaxlib-0.3.24-cp310-cp310-win_amd64.whl

Bazelのキャッシュをきれいにするなら

bazel clean
bazel shutdown

jaxの試食

動作するか確認します。付属のサンプルスクリプトを走らせます。

python .\examples\kernel_lsq.py
MSE: 3.916308e-08

jaxpr of gram(linear_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,100] = dot_general[
      dimension_numbers=(((1,), (1,)), ((), ()))
      precision=(<Precision.HIGH: 1>, <Precision.HIGH: 1>)
      preferred_element_type=None
    ] a a
  in (b,) }

jaxpr of gram(rbf_kernel):
{ lambda ; a:f32[100,20]. let
    b:f32[100,1,20] = broadcast_in_dim[
      broadcast_dimensions=(0, 2)
      shape=(100, 1, 20)
    ] a
    c:f32[1,100,20] = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 100, 20)
    ] a
    d:f32[100,100,20] = sub b c
    e:f32[100,100,20] = integer_pow[y=2] d
    f:f32[100,100] = reduce_sum[axes=(2,)] e
    g:f32[100,100] = neg f
    h:f32[100,100] = exp g
  in (h,) }

動いた! 寝る!

参考

開発チーム 組織と人の成長戦略

開発チーム 組織と人の成長戦略
開発チーム 組織と人の成長戦略
  • 著者: David Loftesness, Alexander Grosse
  • 訳者: 武舎 るみ, 武舎 広幸
  • 出版社/メーカー: マイナビ出版
  • 発売日: 2020/5/29

急成長を前提とするスタートアップ向けの人事戦略論で、10人程度のチームから200人の会社を対象としているようです。「採用」「人事管理」「組織」「文化」「コミュニケーション」の5領域の観点から、必要なプロセス・管理を議論しており、よく構造化されています。 最終章の第12章に、急成長する企業で起こりがちな問題と処方箋(1~11章の内容への参照)がまとめられております。

残念な点として、本書内では議論の理解を補助する多様なブログ記事など副読リソースを提供していますが、bit.lyで紹介されているものは、リンクが切れていたり、乗っ取られていたりすることがあります。

読書メモ

採用

  • 人事のライフサイクルは、募集(紹介、ソーシング、応募)→面接→採用(内部)決定→身元紹介→採用と条件の通知→研修(オンボーディング)→ [中略] →退社手続き

Acquihiring

この他の採用の方法のひとつにAcquihiringがありますが、その内容にもページを割いてました。

  • 文化的不適合が原因で失敗に終わる例も多い。Acquihiring実施には、才能、スキル、文化、価値観でのマッチングが重要になる。マッチングは創業者だけでなく、極力すべての社員とも行う。
  • 被買収企業の現行プロダクトやビジネスの買収後の処遇を計画する。
  • 買収後に「現在手掛けているビジネス・プロダクトが終了する」「被買収会社のチームが解体されて買収会社に組み込まれる」可能性に対して期待値を形成する。
  • 買収前から(同じ釜の飯を食うくらい親密な業務交流があるような)協業で良い関係を築いていると成功しやすい。

研修(オンボーディング)

  • 最初の週、理想的には初日に上司と1 on 1を持つ
  • 幹部に引き合わせるなどし、会社の歴史や製品、現状を伝える機会を設ける(オリエンテーション)
  • 製品への理解、サポートチームで製品が顧客にどう受け取られているのか知ってもらう
  • チーム数が5以上、社員数が20名くらいであればチームローテーションという部門巡回型の職場体験が良い。チーム数がさらに多くなってくると巡回経路が長くなりすぎるため、正式に研修プログラムを作ると良い。様々なフォーマットがありえるが、伝統的日本企業の新卒社員研修のようなチーム配属を伸ばして会社全体を理解してもらうフォーマットも候補にあがる。

人事管理

  • 会社の規模が大きくなるにつれて組織の多層化が必要になる。そこで管理職が登場するが実装の段取りを誤るとと、現場のエンジニアが「二流市民」に落とされたと誤解したり、管理職にならないとキャリアの先がないように錯覚したりする。このため、エンジニア用・管理職用の2通りのキャリアパスを用意するなど、正しいコミュニケーションが処方箋となる (Creating a career path | by Alexander Grosse)
  • プロセス不在だと、業務で不文律がはびこるようになる

コミュニケーション

(COVID前の著書なので、実感としてリモートワークのメリット・デメリットいずれもニュース性がなくなってしまったが) リモートワークの弊害と取り入れに関する工夫も言及されていました。

  • 会社を構成する社員数が多くなると、社員の組み合わせ数が爆発的に増えるためミーティングがミーティングを呼ぶようになる。
  • コミュニケーション手段の選び方、タイミング、透明性、コミュニケーションを取らないという考え方も出てくる。