CircleLossなんとなく理解した
はじめに
よく知らなかったのだけれど、距離学習(計量学習、Metric Learning)というのが流行っているらしい。
本技術は、分類器というよりその前段階までの特徴量抽出の部分に着目し、ラベル情報を使っていい感じに分かれるように埋め込む技術なんだとか。
ちょっと面白そうなので、CVPR2020の論文([2002.10857] Circle Loss: A Unified Perspective of Pair Similarity Optimization)で理解して実装してみました。
その前に:深層距離学習(Deep Metric Learning)とは
ざっくり言うと、『意味の近い入力同士の距離が近くなり、意味の遠い入力同士の距離が遠くなるように、特徴量空間に埋め込む方法』の意味です。
このデータ間の『距離』を明示的に埋め込めるようになると、
- ユーザが探しているものに似たもの(=意味空間においてカテゴリが近いもの)をサジェストする
- 通常状態とは特徴量的にかけ離れたような『異常検知』する
- 顔検知などのクラスが無数にあるデータをいい感じにクラスタリングして同一クラスを同定する
などができるようになります(以下サイトが詳しいです)。
そして、勘のいい皆さんはお気づきだと思いますが、対象は必ずしも画像に限りません。
たとえば、自然言語処理で有名なBERTは、雑に表現すれば各単語をよしなに768次元に埋め込んでくれる素晴らしい技術ですが、あくまで単語の埋め込みしかしてくれません。これを拡張し、いい感じの文章自埋め込み表現を獲得する方法として、深層距離学習の一種のsiamese networkが使われています([1908.10084] Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks)。*1
このように、多様なデータを対象に、データ間の距離をいい感じに埋め込んでくれる手法が、深層距離学習になります。
CircleLossの論文を読み解く
論文は以下です。
概要
- ペア類似度最適化(深層距離学習に相当)は、クラス内類似度(同一クラス内のデータ間の類似度、この分野ではSpと表記)を最大化し、クラス間類似度(異なるクラス同士の類似度、Snと表記)を最小化するタスクであり、(Sn - Sp)を最小化することが目的である。
- 一方、単純に(Sn - Sp)を計算して最適化するのは柔軟性に欠け、重みづけが必要である。その結果、円状の決定境界を持つCircleLossが開発された。
- 顔認識や人物の再識別、画像検索データセットを用いた結果、実験的にSOTA。
ということです。
この論文のすごいところ
有り体に言ってしまえば、深層距離学習の損失関数を簡潔に記述し、効率的で簡単な計算に落とし込んでいる点です。下図が、CircleLossを端的に解説している図になります。
図の(a)に相当するものが従来の深層距離学習であり、(b)がCircleLossになります。クラス内類似度Spを最大化し、クラス間類似度Snを最小化するという目的において、(b)の方が合理的だ、ということが明らかですね。
本論文では、ペア類似度最適化問題((sn − sp)の最小化問題)を、(an*sn − ap*sp)一般化し、このanとapの最適な値を導出することで、CircleLossを導出しています。
さらに、本論文では、深層距離学習で扱われてきた二つのアプローチ:
- 組み合わせ系(Triplet loss、Siamese Net等)
- Softmax Cross-Entropy系
の二つの深層距離学習が本質的に同じであることを看破し、CircleLossを縮退させることでこれらの二つが得られると主張しています。さらに、効率の良い勾配を形成することで、容易に良好な距離学習が可能だ、とも主張しています。
以上が、本論文の主な寄与になります。
どうやって統一するの?
論文では、距離学習で用いられる関数は以下の式(1)に統一できると主張しています。
ここで、γはスケーリングパラメータ、mは損失関数のマージン、K、Lはそれぞれ、あるxに対するクラス内類似度スコアの数およびクラス間の類似度スコアの数です。
この統一系の損失関数について、①N-1個のクラス間類似度を正規化し(L→N-1)、②一つの正規化済みクラス内類似度のみを算出するようにすれば(K→1)、Softmax系の損失関数が得られます。
また、統一系の損失関数について、γで割って極限を取れば、組み合わせ系損失関数になることが明らかになります。
以上のように、深層距離学習に関する手法は複数あるように見えて、その実一つの損失関数を縮退させているだけなんだよ、というのが本論文の主張です(しゅごい)。
さらに、最適化で必ず問題となる勾配の在り方についても議論されています。端的に言うと、下図のように、CircleLossにすれば、他の損失関数ではどうにもならないプラトー領域であっても勾配を取ることが出来、適切に距離を学習できるんだよ、というのが本論文の主張です。
特に組み合わせ系の損失関数は、学習のための組み合わせのサンプリングが難しく学習が大変、というのも、何となく↑の図で見えるような気がします。
どうやって最適化するの?
以降では統一系の損失関数で議論します。式(1)では、最小化対象を(sn-sp+m)としていましたが、ここではマージンmをいったん無視し、(an*sn - ap*sp)に拡張します。
それじゃあ、このanとapはどう決めればいいのさ?ということになりますが、ここで一つの考え方を導入します。すなわち、最適な状態Op or Onから状態が乖離しているほど、その係数を早く修正すべきである、という考え方です。具体的には、(5)式です。
ざっくり意味を説明すると、
- クラス内類似度spが最適状態Opよりも小さい場合、apを大きくする
- クラス間類似度snが最適状態Onよりも大きい場合、anを大きくする
という戦略が(5)式です。なお、[x]+は非負にする(ReLUと同じ)という意味です。
さて、ここでマージンの存在を復活させます。(4)式のsnとsp両方にマージンがあると考え、以下の(6)式とします。
ここに、(5)式の係数を大入試、整理すると、exp関数内は以下の式になります。
はい、Circleが出てきました。タイトル回収です。マージンありの考え方の元、最適な状態からの乖離を優先的に修正する、という要素を入れることにより、Circleが導出されました。
あとは、このOp、On、Δp、Δnをどう決めるか、という問題が出ます。
ここで、元々マージンだったmを導入し、
Op = 1+m
On = -m
Δp = 1-m
Δn = m
とします。すると、(7)式は大幅に簡易となり、
となります。式から分かる通り、基本的にsnが0、spが1になるように最適化を進めることになり、Circleの半径がマージンで規定される、という式になりました。
以上から、比較的少ないハイパーパラメータ(スケール係数γとマージンmのみ)のみで、損失関数が得られることがわかります。
他、利点として以下三つが議論されています。
- snとspをバランスよく最適化可能
- 徐々に減衰する勾配
- 明確な収束目標
以上がCircleLossの理論基盤になります。
※補足の感想など
論文では、本来確率を算出するものとして使われるsoftmaxの概念を完全に無視することで、これだけ自由度を得られたんだ!というような話がされています。
よくよく考えると、目標とする確率がlabelという形で0 or 1で表現されているのって結構無理がある気もするので(犬っぽい猫がいたっていいじゃんか、という話)、softmax=確率、の呪縛から逃れる良い手法だったのかもしれないな、と勝手に思ってます。
他と比較するとどうなの?
あとはひたすら実験結果を比較しています。
①顔認識(MFC1データセット、R34はResNet34、R100はResNet100)
②顔認識(LFW、YTFmCFP-FPについてResNet34)
③顔認識(TARの比較、IJB-Cについて)
④人の再識別
⑤類似画像検索
結論としては、いろいろなタスクにおいてCircleLossつよい。
ハイパーパラメータの探索とクラス内/クラス間類似度の比較
MFC1でパラメータスタディをすると、最良のパラメータでは他手法を圧倒する。
また、クラス間類似度の変遷をみると、確かにsoftmax系よりCircleLossのクラス内/クラス間類似度が良好なのがわかる。
収束後の類似度ペアのプロットを見ると、たしかにマージンを小さくするとより理想的な埋め込みに近づいていることがわかる。
以上のように、CircleLossはとてもすごい(小並感)。
論文の結論
- softmax系損失や組み合わせ系損失などの深層距離学習の大半は、クラス内類似度とクラス間類似度を類似ペアに埋め込んでいる(つまり、大枠は同じことをしている)。
- クラス内/クラス間類似度について、最適状態に対する距離に応じたペナルティを与えることで、CircleLossが得られる。
- 柔軟性が高く収束目標が明確であり、最新手法と同等の性能を実現できる。
というお話でした。
じゃあ試してみますか
ということで、解説するだけではつまらないので、手実装してみました。コードはここにあります。
課題設定
お手軽に以下を試してみました。
- MNISTを対象に埋め込み、UMAPで可視化する。可視化結果について、CircleLoss、何もしていないCNN、Softmax系の深層距離学習を比較する。
- 実際に数字の認識率が向上するかを評価してみる。
- (あんまり意味はないけれど、何となく面白そうなのでCircleLossとSoftmax系のキメラ手法を試してみる。)
環境設定
Win10 (GeForce GTX1080Ti)
python=3.7.7
で実行しています。依存ライブラリは以下の通り。
pytorch=1.6.0(CUDA=10.2)
umap-learn=0.4.6
pandas
numpy
torchsummary
matplotlib
実装方法
CircleLoss:論文中の下式を愚直に実装(ここのCircleLoss)。
※あとで著者実装見たんですが、softplus関数とか使えば簡単だったんだね…
Softmax系:SphereFace、ArcFace、CosFaceはArcFaceの論文([1801.07698] ArcFace: Additive Angular Margin Loss for Deep Face Recognition)を元に実装(ここのCosLayerが該当)。
ここに出てくるm1~m3について、
(m1, m2, m3) =
(1.00, 0.00, 0.00) ⇒ 普通のsoftmax
(1.35, 0.00, 0.00) ⇒ SphereFace
(1.00, 0.50, 0.00) ⇒ ArcFace
(1.00, 0.00, 0.35) ⇒ CosFace
と使い分けることが出来る。
また、比較としてAdaCosも論文に基づいて併せて実装した([1905.00292] AdaCos: Adaptively Scaling Cosine Logits for Effectively Learning Deep Face Representations)。こちらは、Loss関数を工夫するというより、スケールパラメータ(L4式のs)を賢く設定するというのが売りなので、基本はL4式に則って実装しています。
また、最後のFC層(softmax層)までのネットワークは、以下のものを使っています。
結果1:MNISTの埋め込み
まずは何もしない状態と、ただのクラス分け後の分布を見ます。そもそもUMAPがよしなにやってくれるので、それなりに綺麗に分かれて見えます。
一方、クラス分け学習後の結果を見ると、たとえば7と2がつながっているなど、上手くいっていないんだろうなあ、という部分が垣間見えます。
続いてsoftmax系。
こちらはクラス内でよくまとまるようになっています。ただし、やはり2と7は区別が難しいようで、SphereFaceとAdaCosではつながり気味です。MNIST的には、あまり単純な損失関数は向かないということでしょうか。
最後に、CircleLoss。
論文の言う通り、めちゃくちゃ綺麗に分かれています。SphereFaceやAdaCosで見られたアメーバの触手のような部分もなくなり、綺麗に分かれています。ちょっとMNISTは課題が単純すぎたかもしれません。。。
結果2:クラス分類性能の向上
それぞれについて、20回分の正解数(各総数は10,000)を可視化した。
上記を見ると、意外なことに埋め込みではイマイチだったSphereFaceが二番目に良好な結果が出ている。2と7の曖昧な決定領域が、意外と分類に役立ているのかもしれない。
また、ArcFaceは、上手くいくと高いスコアを出すが、イマイチな時は壊滅している。原因は不明であり、調査が必要だが、初期設定のマージンが少なすぎるのが原因かもしれない。
ここでは、意外とCircleLossはイマイチな結果をはいている。全体として正解率が高いことから、少々課題が簡単すぎた可能性がある。
結果3:CircleLossとsoftmax系の組み合わせ
最後に、完全な蛇足としてCircleLossとsoftmax系を組み合わせてみた。計算式から考えても、部分的にlossを厳しく設定した、程度の効果しかないと思われるが、なんか面白そう&簡単に実装できたので、分類性能のみ表示する。埋め込み結果については、githubのここにXXX_combined.pngとして保存してある。
なんか下がった…!!
まあ、冷静に考えれば過剰にLossがかかっているので、クラス分類性能を犠牲にしてでも分けてやる、と学習した結果と理解すればよさそう。
やっぱりMNISTは課題が簡単すぎたね…パトラッシュ…
結論
- CircleLoss大体理解した。とてもすごい。天才。
- 実装できた。クラス間距離の最大化とクラス内距離の最小化、という意味ではCircleLossは十分すごい。けれど、クラス分類性能の話になると、MNISTはちょっと簡単すぎた。
- CircleLossとsoftmaxを考えなしに組み合わせると精度下がるので、あまり意味がない。
どうも超長記事になりましたが、ここまで追ってくれた人はありがとうございました。
感想と蛇足
・久しぶりの画像系実装だったので結構疲れた…
・でも深層距離学習楽しい
・天才っているんだなあ
・複数の深層距離学習をたくさん実装出来てめっちゃどや顔だったんだけれど、この記事作成中に深層距離学習用ライブラリを見つけてしまったよ…!!!かなしいよ!!!(;^ω^)