MLP深層学習 LSTM曽和 修平
LSTM
LSTMとは
• RNNは長い順伝搬ネットワークに展開される
→勾配消失問題が発生する(長い系列が扱えない!)
• LSTMはこの問題を解決し、長い系列を扱えるようにする
LSTMの構造• RNNの中間層の各ユニットを「メモリユニット」というものに置き換える
メモリユニットは前の状態を「覚えていたり」「忘れていたり」することで
RN LST
中間層のユニット
メモリユニット
勾配消失問題を解決しようとする
中間層のユニットとメモリユニット• 中間層の1ユニット
f
WjiWjj’
• メモリユニット
1時刻前の中間層の出力
入力層の出力
Wji
Wjj’
Wjj’I
I
Wji
Wjj’
F
FWji
Wjj’O
O
1時刻前のメモリセルの出力
メモリセル
忘却ゲート
入力ゲート
出力ゲート
Wji
LSTMの構造下の方から見ていく
Wjj’ Wii
この部分は中間層の1ユニットと全く同じ
※添字[j]はユニットの番号,[t]は時刻
u
tj =
X
i
w
(in)ji x
ti +
X
j0
wjj0zt1j0
f(utj)
f(utj)
出力:
(in)
LSTMの構造
1時刻前のメモリセルの出力 を考慮
Wji
Wjj’
I
u
I,tj =
X
i
w
(I,in)ji x
ti +
X
j0
w
Ijj0z
t1j0 + w
Ij s
t1j
s
t1j
出力: g
I,tj = f(uI,t
j )
g
I,tj
LSTMの構造
先に計算した各セルの積
出力:
入力ゲート
f(utj)g
I,tj
f(utj)g
I,tj
これがメモリセルへの入力の1つとなる
LSTMの構造
出力:
Wji
Wjj’
F
F
1時刻前のメモリセルの出力 を「どれだけ覚えているか」をs
t1j
u
F,tj =
X
i
w
(F,in)ji x
ti +
X
j0
w
Fjj0z
t1j0 + w
Fj s
t1j
g
F,tj = f(uF,t
j )
表現している。
LSTMの構造
先に計算した各セルの積
出力:
これがメモリセルへの入力の1つとなる
忘却ゲート
g
F,tj
g
F,tj s
t1j
s
t1j
g
F,tj s
t1j
が1に近ければよく前のメモリセルの状態をよく覚えている事になる
逆に0に近ければ前のメモリセルの状態を忘れた事になる
g
F,tj
LSTMの構造
メモリセルは入力ゲートの出力と入力ゲートの出力の和
メモリセル
忘却ゲート
入力ゲート
s
tj = g
F,tj s
t1j + g
I,tj f(ut
j)
メモリセルや各ゲートの意味については後述
s
tj
LSTMの構造
Wji
Wjj’
O
O
s
tj
u
O,tj =
X
i
w
(O,in)ji x
ti +
X
j0
w
Ojj0z
t1j0 + w
Oj s
tj
出力: g
O,tj = f(uO,t
j )
g
O,tj
現在のメモリセルの値を考慮
LSTMの構造
出力ゲート
s
tj
f(stj)
z
tj
g
O,tj
z
tj = g
O,tj f(stj)
この値が1つのメモリユニットの出力値となる
ゲートの役割
入力重み衝突
i j重み w
ユニットiの入力はユニットjに重みwをかけて「伝達」される
伝達するためには重みwを大きくしてユニットjを活性させないといけない
一方、無関係な入力が入った時は伝達したくない・・というジレンマがある
入力重み衝突
i j重み w
例)ユニットiは仮に「英語の名詞」に対して反応するとする
入力「pen」がきたら活性する。これは次に伝達したい。
入力「ペン」はそもそも日本語なのでこの特徴を抽出する上で関係ない。これは伝達したくない・・
→重みは大きくしたい
→重みは小さくしたい
入力ゲート
・前のユニットの状態を伝達するかどうかを決定
i j重み w
入力ゲート
・前のユニットの出力が必要ものならゲートを開ける・前のユニットの出力が関係ないものならゲートを閉じる
・この判断をするのが↓の部分のネットワーク
Wji
Wjj’
Ig
I,tj
出力重み衝突/出力ゲート
入力重み衝突と同じ。
・前のユニットの状態を「受け取る」かどうかを決定(入力ゲート)
(出力ゲート)・前のユニットの状態を「伝達する」かどうかを決定
Wji
Wjj’
O
O
s
tj
g
O,tj
・この判断をするのが↓の部分のネットワーク
メモリセル
・メモリセルはこれまでの状態を保持している
メモリセル
忘却ゲート
入力ゲート
s
tj = g
F,tj s
t1j + g
I,tj f(ut
j)
s
tj
・しかし、入力の系列がガラッと変わった時、今までの状 態を捨てたい事がある
メモリセル
主語が男か女かを判断例)
He is a student and she is a student
・「and」の前後で文が独立している
このような場合,she is ・・の文を学習するにあたって
これまでの状態(He is a・・)を捨てたい
忘却ゲート
・これまでの状態を覚えておくか、忘れるかを判断する
Wji
Wjj’
F
F
忘却ゲート
s
t1j
g
F,tj s
t1j
s
t1j
忘れるべきか、覚えておくべきかを判断
逆伝搬計算
誤差関数に関する勾配の求め方(復習)・中間層l ← 中間層l+1の勾配を誤差逆伝搬法で求める
・l層のあるユニットjへの総入力は
u
(l)j =
nX
i=1
w
(l)ji z
(l1)i
En
w
(l)ji
=En
u
(l)j
u
(l)j
w
(l)ji
・勾配を求める為微分する
これが計算できれば勾配が求まる→重みが更新できる
誤差関数に関する勾配の求め方(復習)・中間層l ← 中間層l+1の勾配を誤差逆伝搬法で求める
En
w
(l)ji
=En
u
(l)j
u
(l)j
w
(l)ji この部分はそのまま微分可能
この部分は出力層→中間層の時以外はこのまま計算できない
はどうすれば計算できるのか。En
u
(l)j
誤差関数に関する勾配の求め方(復習)
ユニットjが変動すると,Enはどう影響を受けるのか?
j
0
k
l層l+1層
ユニットjの出力分だけ、次の層の各中間ユニットの総入力が 影響を受ける→この影響が出力層まで連鎖していく
ユニットjの出力 x 重みの分だけ総入力に影響がある
誤差関数に関する勾配の求め方(復習)
つまり、uj^(l)がEnに与える影響(変動) はEn
u
(l)j
En
u
(l)j
=X
k
En
u
(l+1)k
u
(l+1)k
u
(l)j と書ける
そしてこの本ではこの値を「デルタ」と呼んでいる
(l)j ⌘ En
u
(l)j
このデルタさえわかれば後はEn
w
(l)ji
=En
u
(l)j
u
(l)j
w
(l)ji
に代入すれば簡単に勾配が求まる。
誤差関数に関する勾配の求め方(復習)では、デルタはどう変形できるか。
En
u
(l)j
=X
k
En
u
(l+1)k
u
(l+1)k
u
(l)j
右辺第一項はl+1層のデルタになっている
(l)j ⌘ En
u
(l)j
(l)j =
X
k
(l+1)k
u
(l+1)k
u
(l)j ので、こう書ける
さて、右辺第二項について考える。
誤差関数に関する勾配の求め方(復習)l+1層のユニットkに関する総入力uは
u
(l+1)k =
X
j
w
(l+1)kj z
(l)j =
X
j
w
(l+1)kj f(u
(l)j )
この式をuj^(l)で微分すると
u
(l+1)k
u
(l)j
= w
(l+1)kj f
0(u(l)j )
よって
(l)j =
X
k
(l+1)k (w
(l+1)kj f
0(u(l)j ))
LSTMの逆伝搬計算LSTMのメモリユニットの各「デルタ」を計算する
最適化対象の変数は以下
Wji
Wjj’
Wjj’II
Wji Wjj’F F
Wji Wjj’OO
Wji
Wj F
Wj I
Wj O
・・これまでと同じ
・・入力ゲート値の重み
・・忘却ゲート値の重み
・・出力ゲート値の重み
逆伝搬計算
Wji
Wjj’
O
O
s
tj
g
O,tj
まずはこのセルのデルタを考える※vk^tは次の出力層への総入力
出力層に関して 次時刻のメモリユニットに関して
O,tj =
X
k
out,tk
v
tk
u
O,tj
+X
j0
(t+1)j0
u
(l+1)k
u
(O,t)j
j’
逆伝搬計算
v
tk =
X
j
w
outkj z
tj
この部分を求める出力層のユニットへの総入力は
これをuj^(O,t)で微分するとv
tk
u
O,tj
= w
outkj f
0(uO,tj )f(stj)
z
ij = g
O,tj f(stj)
出力ゲート
s
tj
f(stj)
z
tj
g
O,tj
である事に注意
O,tj =
X
k
out,tk
v
tk
u
O,tj
+X
j0
(t+1)j0
u
(l+1)k
u
(O,t)j
j’
逆伝搬計算
この部分を求めるこれは先と同じように計算できる
v
tk
u
O,tj
= w
outkj f
0(uO,tj )f(stj)
出力ゲート
s
tj
f(stj)
z
tj
g
O,tj
O,tj =
X
k
out,tk
v
tk
u
O,tj
+X
j0
(t+1)j0
u
(l+1)k
u
(O,t)j
j’
j’u
(t+1)j0
u
O,tj
= wj0jf0(uO,t
j )f(stj)
逆伝搬計算
✏
tj =
X
k
w
outkj δ
out,tk +
X
j0
wj0jδt+1j0
ここで
とおき
v
tk
u
O,tj
= w
outkj f
0(uO,tj )f(stj)
をデルタO,tに代入すると・・
δ
O,tj = f
0(uO,tj )f(stj)✏
tj となる
j’u
(t+1)j0
u
O,tj
= wj0jf0(uO,t
j )f(stj)
逆伝搬計算
出力ゲート
s
tj
f(stj)
z
tj
g
O,tj
次は、このセルのデルタを求める
これも先と同様の考え方をする。
e
tj =
X
k
out,tk
v
tk
s
tj
+X
j0
(t+1)j0
u
(l+1)k
s
tj
出力層に関して 次時刻のメモリユニットに関して
j’
逆伝搬計算
v
tk =
X
j
w
outkj z
tj
出力層のユニットへの総入力は
これをsj^tで微分すると
z
ij = g
O,tj f(stj) である事に注意
e
tj =
X
k
out,tk
v
tk
s
tj
+X
j0
(t+1)j0
u
(l+1)k
s
tj
v
tk
s
tj
=X
j
w
outkj g
O,tj f
0(stj)
j’
逆伝搬計算
こちらに関しても同様の流れで計算して
e
tj =
X
k
out,tk
v
tk
s
tj
+X
j0
(t+1)j0
u
(l+1)k
s
tj
✏
tj =
X
k
w
outkj δ
out,tk +
X
j0
wj0jδt+1j0 とおくと
eδ
tj = g
O,tj f
0(stj)✏tj
j’
j’u
(t+1)j0
u
O,tj
= wj0jgO,tj f
0(stj)
逆伝搬計算
デルタの定義式をもう一度眺める
(l)j =
X
k
(l+1)k (w
(l+1)kj f
0(u(l)j ))
伝搬元のデルタ
伝搬元の重み
現在の層の出力を微分したもの
これらの積の和(伝搬元のユニット数)
逆伝搬計算
次は、メモリセルのデルタを求めるメモリセル
忘却ゲート
入力ゲート
s
tj
このセルの「伝搬元」は・・・外部出力向け ・セル自身への帰還・入力ゲート ・忘却ゲート ・出力ゲートつまり、これら全ての「デルタ」x重みの和がメモリセルの「デルタ」
メモリセルは総入力sj^tを受け、恒等写像の活性化関数を経てsj^tを返すと考える
f
0(stj) = 1
逆伝搬計算
メモリセル
忘却ゲート
入力ゲート
s
tj ・外部出力向け のデルタ
これは先程計算済み。
eδ
tj = g
O,tj f
0(stj)✏tj
逆伝搬計算
メモリセル
忘却ゲート
入力ゲート
s
tj ・セル自身への帰還 のデルタ
1時刻後のメモリセルの値と1時刻後のゲート値の積
g
F,t+1j
cell,t+1j
忘却ゲートメモリセル
伝搬元のデルタ伝搬元の重み
逆伝搬計算
メモリセル
忘却ゲート
入力ゲート
s
tj ・入力ゲート
・忘却ゲート・出力ゲート からのデルタ
それぞれ・・ w
Ij
I,t+1j
w
Fj
F,t+1j
w
Oj
O,tj
逆伝搬計算
よって、メモリセルのデルタは下記のように表される
メモリセル
忘却ゲート
入力ゲート
s
tj
cell,tj = e
tj + g
F,t+1j
cell,t+1j + w
Ij
I,t+1j + w
Fj
F,t+1j + w
Oj
O,tj
まだ
F,tj
I,tjと
が求められていないので求めていく。
最初のセルに関するデルタ
tj
逆伝搬計算
デルタの定義式をもう一度眺める(再掲)
(l)j =
X
k
(l+1)k (w
(l+1)kj f
0(u(l)j ))
伝搬元のデルタ
伝搬元の重み
現在の層の出力を微分したもの
これらの積の和(伝搬元のユニット数)
逆伝搬計算
F,tjまず、
Wji
Wjj’
F
F
に関して。伝搬元のデルタは? =
cell,tj
伝搬元の重みは? = s
t1j
現在の層の出力を微分したものは? = f
0(uF,tj )
伝搬元は単一のセルからなので、和を取る必要はない。
F,tj = f
0(uF,tj )st−1
j
cell,tjよって・・
逆伝搬計算
次、 に関して。伝搬元のデルタは? =
cell,tj
伝搬元の重みは? = 現在の層の出力を微分したものは? = 伝搬元は単一のセルからなので、和を取る必要はない。
よって・・
I,tj
Wji
Wjj’
Ig
I,tj
f(utj)
f
0(uI,tj )
I,tj = f
0(uI,tj )f(ut
j)cell,tj
逆伝搬計算
最後、 に関して。伝搬元のデルタは? =
cell,tj
伝搬元の重みは? = 現在の層の出力を微分したものは? = 伝搬元は単一のセルからなので、和を取る必要はない。
よって・・
tj
Wjj’ Wii
f(utj)
g
I,tj
f
0(utj)
tj = g
I,tj f
0(utj)
cell,tj
逆伝搬計算
これで、全てのデルタの計算ができた。
重みの更新ができる
Top Related