deeplearning4jでword2vecのベクトルデータからRNNで極性判定してみる

はじめに

今回はword2vecとRNNを組み合わせて極性判定をしてみます。deeplearning4jのサンプルプログラムがほぼそのまま使えました。
amarec (20161016-163226)

関連情報

実装

公式のプログラムをほぼそのまま使いました。前回同様、不均衡なデータ（positiveとnegativeの件数が大きく異なるデータ）でも使えるように修正しています。具体的には、positiveとnegativeのファイルを交互に読むのではなく、件数の比率に応じてランダムに選んで読むようにしています。

public class SentimentRecurrentIterator implements DataSetIterator {

  /**
   *
   */
  private static final long serialVersionUID = 3720823704002518845L;

  private final WordVectors wordVectors;
  private final int batchSize;
  private final int vectorSize;
  private final int truncateLength;

  private int cursor = 0;
  private int posCursor = 0;
  private int negCursor = 0;
  private final File[] positiveFiles;
  private final File[] negativeFiles;
  private final int numPositives;
  private final int numNegatives;
  private final int numTotals;
  private final Random rnd;
  private final TokenizerFactory tokenizerFactory;

  /**
   * @param dataDirectory the directory of the IMDB review data set
   * @param wordVectors WordVectors object
   * @param batchSize Size of each minibatch for training
   * @param truncateLength If reviews exceed
   * @param train If true: return the training data. If false: return the testing data.
   */
  public SentimentRecurrentIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
    this.batchSize = batchSize;
    this.vectorSize = wordVectors.lookupTable().layerSize();

    File p = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/positive/") + "/");
    File n = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/negative/") + "/");
    positiveFiles = p.listFiles();
    negativeFiles = n.listFiles();
    numPositives  = positiveFiles.length;
    numNegatives  = negativeFiles.length;
    numTotals     = numPositives+numNegatives;
    rnd           = new Random(1);

    this.wordVectors = wordVectors;
    this.truncateLength = truncateLength;

    tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
  }


  @Override
  public DataSet next(int num) {
    if (cursor >= positiveFiles.length + negativeFiles.length) throw new NoSuchElementException();
    try{
      return nextDataSet(num);
    }catch(IOException e){
      throw new RuntimeException(e);
    }
  }

  private DataSet nextDataSet(int num) throws IOException {
    //First: load reviews to String. Alternate positive and negative reviews
    List<String> reviews = new ArrayList<>(num);
    boolean[] positive = new boolean[num];
    for( int i=0; i<num && cursor<totalExamples(); i++ ){
      int idx = rnd.nextInt(numTotals);
      boolean mode = modeJudge(idx);
      if (mode) {
        String review = FileUtils.readFileToString(positiveFiles[posCursor]);
        reviews.add(review);
        positive[i] = true;
        posCursor++;
      } else {
        String review = FileUtils.readFileToString(negativeFiles[negCursor]);
        reviews.add(review);
        positive[i] = false;
        negCursor++;
      }
      cursor++;
    }

    //Second: tokenize reviews and filter out unknown words
    List<List<String>> allTokens = new ArrayList<>(reviews.size());
    int maxLength = 0;
    for(String s : reviews){
      List<String> tokens = tokenizerFactory.create(s).getTokens();
      List<String> tokensFiltered = new ArrayList<>();
      for(String t : tokens ){
        if(wordVectors.hasWord(t)) tokensFiltered.add(t);
      }
      allTokens.add(tokensFiltered);
      maxLength = Math.max(maxLength,tokensFiltered.size());
    }

    //If longest review exceeds 'truncateLength': only take the first 'truncateLength' words
    if(maxLength > truncateLength) maxLength = truncateLength;

    //Create data for training
    //Here: we have reviews.size() examples of varying lengths
    INDArray features = Nd4j.create(reviews.size(), vectorSize, maxLength);
    INDArray labels = Nd4j.create(reviews.size(), 2, maxLength);    //Two labels: positive or negative
    //Because we are dealing with reviews of different lengths and only one output at the final time step: use padding arrays
    //Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
    INDArray featuresMask = Nd4j.zeros(reviews.size(), maxLength);
    INDArray labelsMask = Nd4j.zeros(reviews.size(), maxLength);

    int[] temp = new int[2];
    for( int i=0; i<reviews.size(); i++ ){
      List<String> tokens = allTokens.get(i);
      temp[0] = i;
      //Get word vectors for each word in review, and put them in the training data
      for( int j=0; j<tokens.size() && j<maxLength; j++ ){
        String token = tokens.get(j);
        INDArray vector = wordVectors.getWordVectorMatrix(token);
        features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);

        temp[1] = j;
        featuresMask.putScalar(temp, 1.0);  //Word is present (not padding) for this example + time step -> 1.0 in features mask
      }

      int idx = (positive[i] ? 0 : 1);
      int lastIdx = Math.min(tokens.size(),maxLength);
      labels.putScalar(new int[]{i,idx,lastIdx-1},1.0);   //Set label: [0,1] for negative, [1,0] for positive
      labelsMask.putScalar(new int[]{i,lastIdx-1},1.0);   //Specify that an output exists at the final time step for this example
    }

    return new DataSet(features,labels,featuresMask,labelsMask);
  }

  private boolean modeJudge(int idx) {
    if (posCursor >= numPositives)
      return false;
    else if (negCursor >= numNegatives)
      return true;
    else if (idx < numPositives)
      return true;
    else
      return false;
  }

  @Override
  public int totalExamples() {
    return positiveFiles.length + negativeFiles.length;
  }

  @Override
  public int inputColumns() {
    return vectorSize;
  }

  @Override
  public int totalOutcomes() {
    return 2;
  }

  @Override
  public void reset() {
    cursor = 0;
    posCursor = 0;
    negCursor = 0;
  }

  public boolean resetSupported() {
    return true;
  }

  @Override
  public boolean asyncSupported() {
    return true;
  }

  @Override
  public int batch() {
    return batchSize;
  }

  @Override
  public int cursor() {
    return cursor;
  }

  @Override
  public int numExamples() {
    return totalExamples();
  }

  @Override
  public void setPreProcessor(DataSetPreProcessor preProcessor) {
    throw new UnsupportedOperationException();
  }

  @Override
  public List<String> getLabels() {
    return Arrays.asList("positive","negative");
  }

  @Override
  public boolean hasNext() {
    return cursor < numExamples();
  }

  @Override
  public DataSet next() {
    return next(batchSize);
  }

  @Override
  public void remove() {

  }
  @Override
  public  DataSetPreProcessor getPreProcessor() {
    throw new UnsupportedOperationException("Not implemented");
  }
}

学習部分はこちらです。相変わらずパラメータ等は適当です。プログラムを見ればわかるとおり、今回のRNNにはLSTM（GravesLSTM）を使っています。LSTMは単語の並び順に加えて、離れた位置にある単語どうしの依存関係も考慮できるモデルです。

/**
 * Command-line entry point that trains a GravesLSTM sentiment classifier on word2vec features.
 * After every epoch it evaluates the model on the test split, and at the end it saves the model.
 *
 * <p>Usage: {@code SentimentRecurrentTrainCmd <wordVectorsTxt> <dataDirectory> <modelOutputPath>}
 */
public class SentimentRecurrentTrainCmd {

  /**
   * Trains, evaluates and saves the model.
   *
   * @param args [0] word vectors in text format, [1] data directory containing
   *     {@code train/} and {@code test/} splits, [2] output path for the serialized model
   */
  public static void main (final String[] args) throws Exception {
    if (args.length < 3) {
      System.err.println("Usage: SentimentRecurrentTrainCmd <wordVectorsTxt> <dataDirectory> <modelOutputPath>");
      System.exit(1);
      return;
    }

    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    int numInputs      = wvec.lookupTable().layerSize();
    int numOutputs     = 2;
    int batchSize      = 50;
    int testBatchSize  = 100;
    int truncateLength = 300;   // maximum number of words used per review
    int iterations     = 5;
    int nEpochs        = 5;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(7485)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.0018)
        .updater(Updater.RMSPROP)
        .regularization(true).l2(1e-5)
        .weightInit(WeightInit.XAVIER)
        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
        .gradientNormalizationThreshold(1.0)
        .list()
        .layer(0, new GravesLSTM.Builder()
            .nIn(numInputs).nOut(200)
            .activation("softsign")
            .build())
        .layer(1, new RnnOutputLayer.Builder()
            .lossFunction(LossFunctions.LossFunction.MCXENT)
            .activation("softmax")
            .nIn(200).nOut(numOutputs)
            .build())
        .pretrain(false).backprop(true).build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(1));


    System.out.println("Starting training");
    DataSetIterator train = new AsyncDataSetIterator(
        new SentimentRecurrentIterator(args[1],wvec,batchSize,truncateLength,true),1);
    DataSetIterator test = new AsyncDataSetIterator(
        new SentimentRecurrentIterator(args[1],wvec,testBatchSize,truncateLength,false),1);
    for( int i=0; i<nEpochs; i++ ){
      model.fit(train);
      train.reset();

      System.out.println("Epoch " + i + " complete. Starting evaluation:");
      Evaluation evaluation = new Evaluation();
      while(test.hasNext()) {
        DataSet t = test.next();
        INDArray features = t.getFeatures();
        INDArray labels = t.getLabels();
        INDArray inMask = t.getFeaturesMaskArray();
        INDArray outMask = t.getLabelsMaskArray();
        INDArray predicted = model.output(features,false,inMask,outMask);
        // Only the masked-in time step (the last word of each review) is scored.
        evaluation.evalTimeSeries(labels,predicted,outMask);
      }
      test.reset();
      System.out.println(evaluation.stats());
    }

    System.out.println("Save model");
    // Close the stream explicitly; closing an already-closed FileOutputStream is harmless.
    try (FileOutputStream out = new FileOutputStream(args[2])) {
      ModelSerializer.writeModel(model, out, true);
    }
  }

}

実験

実験条件はこれまでと同じです。以下の結果では、ラベル0がpositive、ラベル1がnegativeを表します（PrecisionとRecallは2クラスの平均値です）。

Examples labeled as 0 classified by model as 0: 136 times
Examples labeled as 0 classified by model as 1: 31 times
Examples labeled as 1 classified by model as 0: 96 times
Examples labeled as 1 classified by model as 1: 998 times
Accuracy: 0.8993
Precision: 0.778
Recall: 0.8633
F1 Score: 0.8185

F値82%はなかなかのものです。そして前回のニューラルネットの結果と違い、positive（ラベル0）の再現率もきちんと上がっています（136/167 ≈ 81%）。RNN+LSTMにしただけでこれほど目に見えて改善するとは驚きです。パラメータチューニングをもう少し丁寧にやれば、さらに性能が上がるかもしれません。

おわりに

RNN+LSTMモデルにしたら精度がグーンと上がりました。この一連の投稿で、今回はじめて自然言語処理で深層学習をやってみました。理論は少し勉強していましたが、実際に使ってみるとRNN+LSTMはすごいですね。LSTMは比較的シンプルなRNNのモデルなのに、ここまで精度が上がるとは・・・。とは言え、既存手法もパラメータチューニングをすれば精度が上がる可能性はあります。RNN+LSTMで手軽に精度を上げられることは脅威ですね。ということは、技術的に他社と差別化を図るためには、他社が集められないほど大量のデータを抱える等の努力が必要ですね。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です