テクノロジー

2024年11月15日

リカレントニューラルネットワーク(RNN)とは

回帰結合型ニューラルネットワークと呼ばれるもので、系列データ(連続データ)に対して適用されるNNである。
系列データとしては、文章や時系列データなど、その前後のつながりに意味のあるデータのことを指します。

著者プロフィール

IT分野における教育の先駆者として、多くのエンジニアを育成するプログラミングスクールの運営、Web開発やAI研修を行なっています。幅広いレベルの受講生に対して実践的なスキルを提供。生徒の成長を第一に考え、効果的で魅力的な教育プログラムの設計に情熱を注いでいます。

ゴール

  • 「課題 リカレントニューラルネットワーク」を解くうえで必要な知識や技術について理解する

目的

  • リカレントニューラルネットワークのアルゴリズムを理解できる
  • リカレントニューラルネットワークをスクラッチで実装できる

どのように学ぶか

「課題 リカレントニューラルネットワーク」の流れに沿って解説していきます。

リカレントニューラルネットワーク(RNN、Recurrent Neural Network)

回帰結合型ニューラルネットワークと呼ばれるもので、系列データ(連続データ)に対して適用されるNNである。

系列データとしては、文章や時系列データなど、その前後のつながりに意味のあるデータのことを指します。

アルゴリズム概要

RNNの基本構成

tを時間とすると、普通のネットワークは下記のような構成になります。下記の図はt=1の場合を表しています。

Image from Gyazo

これがリカレントニューラルネットワークでは、次のようになります。

Image from Gyazo

このように、1つ前の時間(t-1)の中間層の出力を、現在の時間(t)の中間層の入力として受け取っていることがわかると思います。

これがリカレントニューラルネットワークの基本的な構成になります。

RNNの学習方法

RNNは上記で見てきた通り、以前の時間も考慮して、学習を進めていかなければならないため、通常のNNの確率的勾配降下法が適用できません。
そうした場合に、用いられる学習方法としては、様々提案されていますが、シンプルでよく使われているのがBPTT(Back Propagation Through Time)と呼ばれる手法です。

前述のRNNの図を一部抽出し、より一般化して見てみると下記のような構成になります。

Image from Gyazo

Wh/Wxはそれぞれが重みで、この重みを更新していきます。

入力から出力までを式に表すと下記のようになります。

\[a_t = x_{t}\cdot W_{x} + h_{t-1}\cdot W_{h} + B\\ h_t = tanh(a_t)\]

$a_t$ : 時刻tの活性化関数を通す前の状態 (batch_size, n_nodes)

$h_t$ : 時刻tの状態・出力 (batch_size, n_nodes)

$x_{t}$ : 時刻tの入力 (batch_size, n_features)

$W_{x}$ : 入力に対する重み (n_features, n_nodes)

$h_{t-1}$ : 時刻t-1の状態(前の時刻から伝わる順伝播) (batch_size, n_nodes)

$W_{h}$ : 状態に対する重み。 (n_nodes, n_nodes)

$B$ : バイアス項 (n_nodes,)

初期状態 $h_{0}$ はすべて0とすることが多いですが、任意の値を与えることも可能です。

上記の処理を系列数n_sequences回繰り返すことになります。RNN全体への入力 $x$ は(batch_size, n_sequences, n_features)のような配列で渡されることになり、そこから各時刻の配列を取り出していきます。

分類問題であれば、それぞれの時刻のhに対して全結合層とソフトマックス関数(またはシグモイド関数)を使用します。タスクによっては最後の時刻のhだけを使用することもあります。

この式を基に、各種勾配を計算し、重みを更新していきます。

ここで気を付けなければならない点は、1つ前の時間の中間層の重みに関しても考慮しないといけない点になります。よって、1つ前の時間の誤差も計算する必要があります。

このリカレントニューラルネットワークに関しては、理論の理解は難しいと思いますので、概念的なところを正しく理解したうえで、ライブラリで実装できることを目標としましょう。

ひとまず、概念部分の理解を深めるため、フォワードプロパゲーションの処理をスクラッチで作ってみます。

まとめ

  • リカレントニューラルネットワークとその構造は、上記の説明で理解できる
  • 理解を深めるために一から順伝播を確認して実装する

ダイビックのことをもっと知ってみませんか?