Re - ImageJで学ぶ!: 第71回 ImageJで深層学習モデルを動かす!-NVIDIA pre-trained modelの例-

2021年12月17日金曜日

第71回 ImageJで深層学習モデルを動かす!-NVIDIA pre-trained modelの例-

深層学習はアカデミックからの視点で見ればその歴史はその分野内では長いほうで、以前から脳の機能を模倣したアルゴリズムとして研究されてきました。しかし、計算量が膨大になるため、ハードウェア的に計算コストがかかりすぎるなどのリミテーションがあり(あったので)、大衆化までの道のりが長くなってしまいました。しかし、TensorFlowやChainer(現在のPyTorchにも影響を与えた秀逸な日本製・日本発機械学習向けの計算機ライブラリ、現在は開発は終了している)などの深層学習に親和性の高いオープンソースの計算処理ライブラリや、オープンデータ(例えば、MNISTやImageNetなど)を世界で共有しようというデータサイエンス界の文化もこの分野が急速に注目される追い風になりました。

医療応用も進んでいます。今日では、「医用AI」「医療AI」「医療機器プログラム(あるいはすでに申請済みの医療機器プログラムへの機能追加をするなど)」などという枠として承認申請ができるようになり、診療報酬制度下での運用方法も厚労省を中心に議論されています。例えば、CTスキャナーでは、深層学習を使った超解像画像を作ったり、内視鏡では、内視鏡画像からリアルタイムに病変を検出し、悪性度などを予測する、など、臨床でのクリニカル・クエッションに応えようとするモデル(いわゆる、医療AI)が販売されています。

このような中で、今回はImageJで深層学習モデルを動かす方法について、NVIDIA社がNGCという取り組みの中で研究者とともに公開している深層学習モデル(肺のセグメンテーション)を例に、その概要と簡単な使い方を説明します。

NVIDIA NGC


NGCは「NVIDIA GPU CLOUD」の略です。このNGCの取り組みでは、GPUを普及させるために、GPUを利用するとメリットの大きい深層学習モデル開発などを対象として、必要なツールを提供しています。
訓練済みの深層学習モデルデータや、モデルを動かすために必要なコンピュータプログラム群などです。

図 NVIDIA NGC Catalogのページ例

これらのツールはカタログにまとめられており、ユーザーが必要とするサービスを選べるようになっています。クラウド上で深層学習を動かせるように、環境が整えられています。
とはいうものの、クラウドサービスとなると、個人の研究やちょっとした勉強レベルでは手を出しにくいですし、試してみようと思いづらいところもあるのではないでしょうか(難しそうだから)。また慣れが必要と感じてしまうのは私だけでしょうか。普段からクラウドが当然で、「あ、PCの電源つけよう!クラウド上の!」とすいすいできる人はクラウドサービスエンジニアくらいでしょう。私のような凡人には遠い感覚です。ふだんの使い慣れたパソコンで、なんならImageJで動かしたい(できるなら無料で)と思ってしまいます。

今回は、このような私が無料・深層学習・ImageJにトライしてみました。
以下、手順になります。

pre-trained modelのダウンロード


訓練済みモデル(pre-trained model)は、NVIDIA NGC > Catalog > modelsのページからダウンロードできます。キーワード検索できます。例として、「lung」で検索すると、このようなモデルのリストが表示されました。

図 「lung」検索結果

今回はこの「clara_pt_covid19_ct_lung_segmentation」を使ってみたいと思います。このモデルは、「A pre-trained model for volumetric (3D) segmentation of the lung from CT images.」ということで、3Dのボリュームデータから、肺の領域を予測してくれるようです。

では、このモデルをダウンロードし、解凍します。
ダウンロード時は、「files.zip」というファイル名になっていました。
解凍すると、中身はこのようになっています。

図 files.zip

モデルデータは、解凍したフォルダー内のmodelsフォルダーに入っています。
この例では、「model.pt」「model.ts」です(拡張子が見えないときは、右クリックからプロパティを開いて確認できます)。

図 modelファイル(files>modelsフォルダの中身)

さて、このptファイルは何のファイルでしょうか。私はTensorflow派なので最初は何のファイルだか分かりませんでした。このptファイルは、PyTorchで作成された深層学習モデル用のファイルです。tsファイルはトーチスクリプトファイルと呼ばれるファイルで、ptファイルと同様に、モデルデータを保存しているデータです。ptファイルとの違いは、モデルを推論実行させる際に高度な処理の設定を追加できることです。この「.pt」拡張子のファイルは、モデルの学習状態とモデルの構造の両方を保持できます。ただし、モデルの学習状態のみを保持することもできます。モデルの学習状態のみを保持している場合は、モデルの構造は保持しません。これはクリエイター事情です。なので、ptとtsのファイルが合わせてある場合は、モデルとして利用する際に両方のファイルが必要であることが多いです。今回のケースは2つとも必要です。

CPUかGPUか


深層学習モデルは、(特にPyTorchの場合)GPU用あるいはCPU用に出力することができます。NVIDIA NGCで公開されているモデルはすべてGPU用になっています。そのため、なにもしなければ、GPUを持っていない人はここから先の操作ができません。しかし、それでは「大衆化」とは呼べません(個人的に)。ここでは、CPUでも動かせるようにモデルに前処理を追加します。GPU前提の方は、ここの処理は不要です。CPUで動かしたい方は、この処理が必要です。
GoogleColabを使って、GPU用に保存されたモデル(先ほどのpt, tsファイル)を、CPU用にリトレースします。新しくColabノートブックを作成し、ランタイムをGPUに切り替えます。方法は以下の通りです。

      import torch
      # 自分のモデルが保存されている場所を指定します。
      # Colabノートブックに一時的に直接アップロードした場合は、"model.pt"や"model.ts"としてください
      p2state_dict = "/content/drive/MyDrive/NVIDIA_PreTrained_Models/clara_pt_covid19_ct_lung_segmentation/models/model.pt"
      p2m = "/content/drive/MyDrive/NVIDIA_PreTrained_Models/clara_pt_covid19_ct_lung_segmentation/models/model.ts"
      state_dict = torch.load(p2state_dict) # 学習重み
      model = torch.load(p2m) # モデルの構造など 
      model.load_state_dict(state_dict) # All keys matched successfully, モデルに学習状態をロード
      model = model.to("cpu")# gpuにトレースするときはmodel = model.to("cuda")
      traced_model = torch.jit.trace(model, (None,224,224,32))# traced_model = torch.jit.script(model) # こちらでもよさそう(試していません)
      torch.jit.save(traced_model, "model_cpu.pt")
      # model_cpu.ptというファイルがColab内のフォルダに作成されるのでダウンロードして使います。
  	

CPU用にリトレースしたモデルデータは、こちらに公開しています。

ImageJ/Fijiで動かしてみる(プラグインを開発する)


ここから、Javaの強みである「Write once, run anywhere」を享受させてもらいます。
今回は、プラグインとしてファイルを作るのではなく、ImageJをライブラリとして使って、簡単に動作を確かめる方法で解説させていただきます。
最終的に動かすところまではできるので、プラグインとしてJar化して使いたいなどは、ImageJ公式のプラグイン開発方法をご参照ください。
深層学習モデルを推論実行させるために必要なライブラリとして、今回はAmazonが開発している「Deep Java Library」を使います。
開発環境には、Eclipseを使っていきます。
Javaは、私の環境では、AdoptOpenJDK11を使っています。

では、Mavenプロジェクトを新しく作成し、pom.xmlを次のように編集しています。
groupid、artifactid、versionなどは自分の設定に合わせてください。
一部抜粋していますが、<project></project>で挟まれる領域です。

        <groupid>com.vis.machinelearning</groupid>
        <artifactid>ai</artifactid>
        <version>0.0.1-SNAPSHOT</version>

        <properties>
            <!--Minimal version for compiling TensorFlow Java is JDK 8-->
            <maven .compiler.source="">1.11</maven>
            <maven .compiler.target="">1.11</maven>
        </properties>

        <dependencies>
            <dependency>
                <groupid>net.imagej</groupid>
                <artifactid>ij</artifactid>
                <version>1.53j</version>
            </dependency>
            <dependency>
                <groupid>ai.djl</groupid>
                <artifactid>api</artifactid>
                <version>0.14.0</version>
            </dependency>
            <!--https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-engine-->
            <dependency>
                <groupid>ai.djl.pytorch</groupid>
                <artifactid>pytorch-engine</artifactid>
                <version>0.14.0</version>
            </dependency>
            <!--https://mvnrepository.com/artifact/ai.djl.pytorch/pytorch-native-auto-->
            <dependency>
                <groupid>ai.djl.pytorch</groupid>
                <artifactid>pytorch-native-auto</artifactid>
                <version>1.9.1</version>
                <scope>runtime</scope>
            </dependency>
            <!--https://mvnrepository.com/artifact/ai.djl.tensorflow/tensorflow-engine-->
            <dependency>
                <groupid>ai.djl.tensorflow</groupid>
                <artifactid>tensorflow-engine</artifactid>
                <version>0.14.0</version>
            </dependency>
            <!--https://mvnrepository.com/artifact/ai.djl.tensorflow/tensorflow-native-auto-->
            <dependency>
                <groupid>ai.djl.tensorflow</groupid>
                <artifactid>tensorflow-native-auto</artifactid>
                <version>2.4.1</version>
                <scope>runtime</scope>
            </dependency>
            <dependency>
                <groupid>ai.djl</groupid>
                <artifactid>model-zoo</artifactid>
                <version>0.14.0</version>
            </dependency>
            <dependency>
                <groupid>ai.djl.pytorch</groupid>
                <artifactid>pytorch-model-zoo</artifactid>
                <version>0.14.0</version>
            </dependency>
        </dependencies>
    </project>

今回は、Windowsで動かしていますが、MacでもLinuxでも動くFatなライブラリをインストールしています(そのはずです)。必要なものは、ImageJ、DJL、DJL用のPyTorchのラッパーです。Tesorflowのラッパーも加えればTensorflowモデルを動かすことができます。上記のpom.xmlにはこれらが含まれています。
うまくライブラリを参照できると、エラーのない状態になります。


図 うまくライブラリをMavenで参照できた状態(少し時間がかかります)

サンプルデータ


ここでは、サンプルCTシリーズ画像として、TCIAで公開されているCOVID-19データで試していきます。
モデルのdocs>READMEを見てみると、下記のような記載があります。

## Input
Input: 1 channel CT image with intensity in HU and arbitary spacing

1. Resampling spacing to (0.8, 0.8, 5) mm
2. Clipping intensity to [-1500, 500] HU
3. Converting to channel first
4. Randomly cropping the volume to a fixed size (224, 224, 32)
5. Randomly applying spatial flipping
6. Randomly applying spatial rotation
7. Randomly shifting intensity of the volume

## Output
Output: 2 channels
- Label 0: everything else
- Label 1: lung

入力は0.8*0.8*5.0ボクセルのCTシリーズ画像を前提としているようです。
入力画像をこのボクセルサイズになるようにリサンプリングしておきます。
ただし、後から肺の領域をクロップし、224*224*32、つまり、マトリクスサイズが224*224の32スライスのボリュームに整形する(4の手順)ので、ここはそこまで厳密でなくともよいかと思います(結局、クロップするときのサイズで拡大率が変わります。ただ、何もしないよりはばらつきが少なくなり、ある程度の整合性が担保されます。ここは開発者に詳細を伺わなければ解釈が難しいところです)。

リサンプリング済みのデータはこちらです。
オリジナルはこちらです(非圧縮済み)。

オリジナルからインプットデータを作成したい方のために、Fijiを使って作成する手順を簡単に示します。
  1. 例えば、ここで利用するオリジナルのサンプルCTシリーズは、Voxel size: 0.7031x0.7031x1.25 mm^3です(ImageJで開く>Image>Show Info...から確認)。
  2. XY方向のリサンプリングは、縦横方向のマトリクスサイズを、0.7031/0.8倍して補正します。オリジナルの512*512マトリクスを450*450にします。
  3. Z方向のリサンプリングは、簡易的ですが、Image>Stacks>ResliceZというFijiのプラグイン機能を利用します。ImageJではデフォルトで入っていないので、ここはFijiを使ったほうが楽です。Z軸方向に5.0mmにしたいので、5.0としてリスライスすればOKです。
  4. 次に、大まかに肺領域を32スライス選びます(Image>Duplicateからスライス選択)。
  5. 最後に、矩形のROIを肺を囲うように1つ設定し、ROIが設定された画像をTIFとして1つのファイルに保存します。ほかのフォーマットはお勧めしません。
  6. これで入力画像の準備ができました。
図 サンプルデータ
(ボクセルサイズ0.8*0.8*5.0, 32スライスに限定)

モデルのロードから推論実行まで


ここではメソッド全体を記載しています。
各コードの解説はコード内のコメントに記載しています。
  
  public static void main(String[] args) throws IOException, ModelException, TranslateException {
	   
		/*
         * load model
         * load image
         * create input
         * predict
         * save results
		 */
        //リソースにモデルファイルを置き、モデルファイルまでのパスで設定している例
		String model_loc = "clara_pt_covid19_ct_lung_segmentation/models/model_cpu.pt";
		URI modelUri = null;
		try {
			modelUri = HelloPyTorch.class.getClassLoader().getResource(model_loc).toURI();
		} catch (URISyntaxException e) {
			e.printStackTrace();
		}
		Path modelPath = Paths.get(modelUri);
		
		//224*224*32 in HU unit.
		ImagePlus in = new ImagePlus("C:\\Users\\ユーザー\\Desktop\\Resampled.tif");//入力画像ファイルへのパス
		ij.gui.Roi rect = in.getRoi();//画像からROIを取得
		System.out.println("roi size:"+rect.getFloatWidth()+" "+rect.getFloatHeight());
		float[] vol = new float[224*224*32];
		NDList input = new NDList();
		NDManager manager = NDManager.newBaseManager();
		int vox_pos = 0;
		for(int i=0;i<in.getNSlices();i++) {
			in.setPosition(i+1);
			in.setRoi(rect);
			Calibration cal = in.getCalibration().copy();
			ImagePlus sliceCrop = in.crop();
			sliceCrop = sliceCrop.resize(224,224,"bilinear");
			ImageProcessor crop = sliceCrop.getProcessor();
			crop.setCalibrationTable(cal.getCTable());
			crop.setMinAndMax(-1500d+32768, 500d+32768);//ct 16-bit on IJ.
			for(int r =0;r<224;r++) {
				for(int h =0;h<224;h++) {
					float val = crop.getPixelValue(h, r);
					vol[vox_pos++] = val;
					if(i==0 && r < 10 && h < 10) {
						// System.out.println(val);//test print pixel val in HU.
					}
				}
			}
		}
		Shape s = new Shape(new long[] {1l,1l,32l,224l,224l});//batch,channel,d,h,w
		NDArray ndar = manager.create(vol, s);//create input to predict.
//		System.out.println(ndar.getShape());
		input.add(ndar);
		
        //入力から出力までのクライテリアを定義する入力も出力もNDListになる。
		Criteria<NDList, NDList> criteria =
                Criteria.builder()
                        .setTypes(NDList.class, NDList.class)
                        .optModelPath(modelPath)
                        .optTranslator(new NoopTranslator())
                        .build();
		//ZooModelクラスを利用して、訓練済みモデルをモデル化し、予測実行する
        try (ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);
            Predictor<NDList, NDList> predictor = model.newPredictor()) {
                NDList result =  predictor.predict(input);
                NDArray res = result.get(0);
                NDArray bg = res.get(new NDIndex("0,0,:,:,:"));
                NDArray lung = res.get(new NDIndex("0,1,:,:,:"));
                System.out.println(res.getShape());//(1, 2, 32, 224, 224)
                System.out.println("min max : "+res.min()+", "+res.max());
                System.out.println(bg.getShape());//(32, 224, 224)
                System.out.println(lung.getShape());//(32, 224, 224)
                /*
                 * - Label 0: everything else
                 * - Label 1: lung
                 */
                //予測結果を画像へ
                float[] pred = res.toFloatArray();//1*2*32*224*224 = 3211264
                
                for(int c = 0;c<2;c++) {
                	ImageStack stack = new ImageStack(224, 224, 32);
                    for(int d=0;d<32;d++) {
                    	float[] pred_lbl = new float[224*224];//per slice
                    	int pred_lbl_pos = 0;
                    	for(int h=0;h<224;h++) {
                    		for(int w=0;w<224;w++) {
                    			if(c == 0) {
                					pred_lbl[pred_lbl_pos] = pred[(d*224*224)+pred_lbl_pos];
                				} else {
                					pred_lbl[pred_lbl_pos] = pred[c*(224*224*32)-1+((d)*224*224)+pred_lbl_pos];
                				}
                    			pred_lbl_pos++;
                            }
                        }
                    	FloatProcessor fp = new FloatProcessor(224, 224, pred_lbl);
                    	stack.setProcessor(fp, d+1);
                    }
                    ImagePlus imp = new ImagePlus("", stack);
                    imp.show();
                    IJ.saveAs(imp, "tif", System.getProperty("user.home") + "/Desktop/test_"+(c+1));
                }
        }
		
	}
    
実行すると、推論結果が得られます。結果の予測画像はデスクトップに保存されます。

図 推論結果(左:予測画像肺以外、右:肺)

図 予測画像の二値化結果

あとは、これを元のサイズに戻し、切り出した位置にもっていけば、どの領域が肺か、肺以外かを予想した結果を確認できます(この部分は省略します。ImageROIにしてオリジナル画像に重ねて表示するなどで実装できます)。

まとめ


今回は、ImageJを用いた深層学習モデルの利用例を紹介しました。Tesorflowモデルも動くので、いろいろなことができると思います。また、深層学習が必要ない場合は、WEKAという機械学習ライブラリを利用することもできます。
自分だけの・自分の施設だけの深層学習モデルをあなたの手で動かしてみることもできるのではないでしょうか。

References

  • NVIDIA NGC https://docs.nvidia.com/ngc/ngc-overview/index.html
  • https://xtech.nikkei.com/atcl/nxt/column/18/00001/03341/
Visionary Imaging Services, Inc.

0 件のコメント:

コメントを投稿