deeplearning4jのRNNで極性判定を作った~Early Stop編~

はじめに

deeplearning4jのRNN+LSTMで日本語極性判定技術を作っています。先日公開してからかなり反響があり、アクセス数増加、Apitoreのユーザーも増え、Qiitaではデイリーランキング5位になりました。ありがとうございます!さて、今回は過学習を回避し、かつ学習効率が下がったら学習を止めてしまう、Early Stoppingを試しました。ついでに学習データも増強したので、学習データ増強before/afterも評価しようと思います。
amarec (20161031-202055)

デモサイト

ソースコード

関連情報

Early Stoppingの実装

例のごとく、deeplearning4jの公式でサンプルプログラムが公開されていました。今回はこの中から、epochで5回連続性能が上がらなかったら終了、という感じで学習させることにします。

下にコードを載せます。以前の実装との違いは、EarlyStoppingModelSaverやIEarlyStoppingTrainerですね。

/**
 * args[0] input: word2vecファイル名
 * args[1] input: train/test親フォルダ名
 * args[2] output: 出力ディレクトリ名
 *
 * @param args
 * @throws Exception
 */
public static void main (final String[] args) throws Exception {
  if (args[0]==null || args[1]==null || args[2]==null)
    System.exit(1);

  WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
  int numInputs   = wvec.lookupTable().layerSize();
  int numOutputs  = 2; // positive or negative
  int batchSize   = 50;
  int iterations  = 5;
  int nEpochs     = 100;
  int thresEpochs = 5;

  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
      .seed(1)
      .iterations(iterations)
      .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
      .learningRate(0.0018)
      .updater(Updater.RMSPROP)
      .regularization(true).l2(1e-5)
      .weightInit(WeightInit.XAVIER)
      .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
      .gradientNormalizationThreshold(1.0)
      .list()
      .layer(0, new GravesLSTM.Builder()
          .nIn(numInputs).nOut(numInputs)
          .activation("softsign")
          .build())
      .layer(1, new RnnOutputLayer.Builder()
          .lossFunction(LossFunctions.LossFunction.MCXENT)
          .activation("softmax")
          .nIn(numInputs).nOut(numOutputs)
          .build())
      .pretrain(false).backprop(true).build();

  MultiLayerNetwork model = new MultiLayerNetwork(conf);
  model.setListeners(new ScoreIterationListener(1));

  LOG.info("Starting training");
  DataSetIterator train = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,batchSize,300,true),2);
  DataSetIterator test = new AsyncDataSetIterator(
      new SentimentRecurrentIterator(args[1],wvec,100,300,false),2);

  EarlyStoppingModelSaver<MultiLayerNetwork> saver = new LocalFileModelSaver(args[2]);
  EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
      .epochTerminationConditions(new MaxEpochsTerminationCondition(nEpochs),
              new ScoreImprovementEpochTerminationCondition(thresEpochs))
      .scoreCalculator(new DataSetLossCalculator(test, true))
      .modelSaver(saver)
      .build();

  IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new EarlyStoppingTrainer(esConf,model,train);
  EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
  LOG.info("Termination reason: " + result.getTerminationReason());
  LOG.info("Termination details: " + result.getTerminationDetails());
  LOG.info("Total epochs: " + result.getTotalEpochs());
  LOG.info("Best epoch number: " + result.getBestModelEpoch());
  LOG.info("Score at best epoch: " + result.getBestModelScore());
}

実験

ついでに学習データも増強しました。以前作った目的語リストでTwitterを再クロールしました。Twitterの検索APIは最大7日前まで遡って検索でき、前回データを取ったのが2週間前なので新しいデータがとれます。さて、増強前後の性能を見てみましょう。条件は以下になります。

学習データ

セット positive negative
Before 6,000 40,900
After 10,780 76,442

Early Stopping用の評価データ

セット positive negative
Before 160 1,100
After 996 7,692

性能評価用データ

セット positive negative
共通 200 800

学習は時間はそれぞれ・・・

  • beforeが約1日、epochは7で終了
  • afterは約1.5日、epochは16で終了

となっています。GPUを使っていないので1 epoch当たりの実行速度は極めて遅いですね。あと、epochが全然回っていないので、学習設定があまり良くないかもしれません。パラメータは相変わらずdeeplearning4jのサンプルプログラムやネットで見かけたよく使われる数値です。BeforeとAfterでEarly Stopping用のデータが違うのもご容赦ください。兎にも角にも、まずは精度を見てみましょう。

まずはBeforeを

Examples labeled as 0 classified by model as 0: 90 times
Examples labeled as 0 classified by model as 1: 110 times
Examples labeled as 1 classified by model as 0: 10 times
Examples labeled as 1 classified by model as 1: 790 times
Accuracy:  0.88
Precision: 0.8889
Recall:    0.7188
F1 Score:  0.7948

つづいてAfter

Examples labeled as 0 classified by model as 0: 107 times
Examples labeled as 0 classified by model as 1: 93 times
Examples labeled as 1 classified by model as 0: 6 times
Examples labeled as 1 classified by model as 1: 794 times
Accuracy:  0.901
Precision: 0.921
Recall:    0.7638
F1 Score:  0.835

結果から分かる通り、純粋に性能が上がっています。どちらもややprecision重視になっているので、positiveの判定性能は5割くらいですね。うーん、使えるレベルにはないかなあ?改善点があるとするとパラメータチューニングもそうですが、学習データはDistant Supervision方式なので人手のアノテーションデータに差し替えて行きたいです。ちなみにDistant Supervision方式を取っているので、この極性判定技術は「Twitter社の極性判定技術をF値で83.5%再現している」ということになります。なので性能はどんどん上げていきたいですね。

おわりに

今回はdeeplearning4jのRNNでEarly stoppingを試してみました。過学習を避けることができるので、普通に導入すべきですね。Early Stoppingは今回紹介したものの他にも色々と停止条件を設定できます。正直、どのパラメータをどの程度に設定すべきかは神のみぞ知るセカイなので、色々と試すために計算資源の潤沢さが必要になりますね。

もっと極性判定の性能を上げていきたいので、もしよろしければこちらのデモサイトで極性判定結果の改善にご協力ください。無料で作れる日本語極性判定技術を盛り上げて行きたいです。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です