データ読み込みに関するメモ
今日はDL4Jでのデータ読み込みに関するメモです。
簡単に書きたいので、ソースとコメントのみにします。
// RecordReaderはある媒体からデータを読み込んで共通形式のデータ(List<Writable>)を生成するインタフェースです。 // CSV・正規表現・画像・JSON・XML・YAMLなど、多くの実装があります。 // RNNで使用するシーケンシャルなデータはRecordReaderを継承するSequenceRecordReaderインタフェースが担当しますが、 // 時間という軸を表現するため、Listがもう一皮ついて、List<List<Writable>>を生成します。 RecordReader recordReader = new CSVRecordReader(); // 読み込むファイルをInputSplitに指定し、初期化します。 // InputSplitはRecordReaderに媒体の場所を分割して知らせます。今の場合、ファイル一つだけなので意味がないですが、 // 別の実装(NumberedFileInputSplit)を使えば、特定範囲の数字がついた複数のファイルを順に処理するといったことができます。 // (例: file1.csv, ..., fileX.csv) InputSplit inputSplit = new FileSplit(new File("train.csv")); recordReader.initialize(inputSplit); // 共通形式(List<Writable>)のデータからミニバッチデータセット(DataSet)を生成するDataSetIteratorを生成します。 // 分類・回帰などそれぞれの用途に合うコンストラクタが存在します。今回は分類です。 int batchSize = 50; // 行のうち、ラベル(正解)は何カラム目か? int labelIndex = 0; // 今回の場合、分類ですので、分類の数を指定しています。2種類ですね。 int numberOfClasses = 2; DataSetIterator trainIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numberOfClasses); // ...中略... // Deep Learningのネットワークを生成・初期化します。 MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); // nEpochs回分、モデルを訓練させます。 for (int n = 0; n < nEpochs; n++) { model.fit(trainIter); } // 後はテストデータでモデルを評価(model.eval)をしたり、実データで予想(model.output)をしたり、 // モデルをディスクに保存(ModelSerializer.writeModel)すればいいですね。