deeplearning4jでword2vecのベクトルデータからNNで極性判定してみる

はじめに

Word2Vecの出力を機械学習の入力にすると精度が良くなると評判です。そこで今回は、以前作ったWord2Vecでツイートをベクトルデータにして、それをニューラルネットに突っ込んで極性判定してみました。
amarec (20161016-163219)

関連情報

実装

今回もdeeplearning4jを使います。従来手法はニューラルネットではなくSVMを用いる場合が多いですが、最近deeplearning4jで色々やってたというのもあったので、そのまま行きます。今回もdeeplearning4jの公式のサンプルプログラムを存分に活用・拡張させていただきました。

まずSentimentツイートの取得部分を作ります。公式のサンプルプログラムを多少いじりました。大きな変更点としては、ポジティブとネガティブで学習データがインバランスでも学習できるようにしました。

public class SentimentSimpleIterator implements DataSetIterator {

  /**
   *
   */
  private static final long serialVersionUID = 3720823704002518845L;

  private final WordVectors wordVectors;
  private final int batchSize;
  private final int vectorSize;

  private int cursor = 0;
  private int posCursor = 0;
  private int negCursor = 0;
  private final File[] positiveFiles;
  private final File[] negativeFiles;
  private final int numPositives;
  private final int numNegatives;
  private final int numTotals;
  private final Random rnd;
  private final TokenizerFactory tokenizerFactory;

  /**
   * @param dataDirectory the directory of the IMDB review data set
   * @param wordVectors WordVectors object
   * @param batchSize Size of each minibatch for training
   * @param truncateLength If reviews exceed
   * @param train If true: return the training data. If false: return the testing data.
   */
  public SentimentSimpleIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
    this.batchSize = batchSize;
    this.vectorSize = wordVectors.lookupTable().layerSize();

    File p = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/positive/") + "/");
    File n = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/negative/") + "/");
    positiveFiles = p.listFiles();
    negativeFiles = n.listFiles();
    numPositives  = positiveFiles.length;
    numNegatives  = negativeFiles.length;
    numTotals     = numPositives+numNegatives;
    rnd           = new Random(1);

    this.wordVectors = wordVectors;

    tokenizerFactory = new DefaultTokenizerFactory();
    tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
  }


  @Override
  public DataSet next(int num) {
    if (cursor >= positiveFiles.length + negativeFiles.length) throw new NoSuchElementException();
    try{
      return nextDataSet(num);
    }catch(IOException e){
      throw new RuntimeException(e);
    }
  }

  private DataSet nextDataSet(int num) throws IOException {
    //First: load reviews to String. Alternate positive and negative reviews
    List<String> reviews = new ArrayList<>(num);
    boolean[] positive = new boolean[num];
    for( int i=0; i<num && cursor<totalExamples(); i++ ){
      int idx = rnd.nextInt(numTotals);
      boolean mode = modeJudge(idx);
      if (mode) {
        String review = FileUtils.readFileToString(positiveFiles[posCursor]);
        reviews.add(review);
        positive[i] = true;
        posCursor++;
      } else {
        String review = FileUtils.readFileToString(negativeFiles[negCursor]);
        reviews.add(review);
        positive[i] = false;
        negCursor++;
      }
      cursor++;
    }

    //Second: tokenize reviews and filter out unknown words
    List<List<String>> allTokens = new ArrayList<>(reviews.size());
    for(String s : reviews){
      List<String> tokens = tokenizerFactory.create(s).getTokens();
      List<String> tokensFiltered = new ArrayList<>();
      for(String t : tokens ){
        if(wordVectors.hasWord(t)) tokensFiltered.add(t);
      }
      allTokens.add(tokensFiltered);
    }

    //Create data for training
    //Here: we have reviews.size() examples of varying lengths
    List<INDArray> inputs = new ArrayList<>();
    List<INDArray> labels = new ArrayList<>();

    for( int i=0; i<reviews.size(); i++ ){
      List<String> tokens = allTokens.get(i);
      //Get word vectors for each word in review, and put them in the training data
      for( int j=0; j<tokens.size() && j<1; j++ ){
        INDArray vector = wordVectors.getWordVectorsMean(tokens);
        inputs.add(vector);
        int idx = (positive[i] ? 0 : 1);
        INDArray label = FeatureUtil.toOutcomeVector(idx, 2);
        labels.add(label);
      }
    }

    return new DataSet(Nd4j.vstack(inputs.toArray(new INDArray[0])), Nd4j.vstack(labels.toArray(new INDArray[0])));
  }

  private boolean modeJudge(int idx) {
    if (posCursor >= numPositives)
      return false;
    else if (negCursor >= numNegatives)
      return true;
    else if (idx < numPositives)
      return true;
    else
      return false;
  }

  @Override
  public int totalExamples() {
    return positiveFiles.length + negativeFiles.length;
  }

  @Override
  public int inputColumns() {
    return vectorSize;
  }

  @Override
  public int totalOutcomes() {
    return 2;
  }

  @Override
  public void reset() {
    cursor = 0;
    posCursor = 0;
    negCursor = 0;
  }

  public boolean resetSupported() {
    return true;
  }

  @Override
  public boolean asyncSupported() {
    return true;
  }

  @Override
  public int batch() {
    return batchSize;
  }

  @Override
  public int cursor() {
    return cursor;
  }

  @Override
  public int numExamples() {
    return totalExamples();
  }

  @Override
  public void setPreProcessor(DataSetPreProcessor preProcessor) {
    throw new UnsupportedOperationException();
  }

  @Override
  public List<String> getLabels() {
    return Arrays.asList("positive","negative");
  }

  @Override
  public boolean hasNext() {
    return cursor < numExamples();
  }

  @Override
  public DataSet next() {
    return next(batchSize);
  }

  @Override
  public void remove() {

  }
  @Override
  public  DataSetPreProcessor getPreProcessor() {
    throw new UnsupportedOperationException("Not implemented");
  }

  /** Convenience method for loading review to String */
  public String loadReviewToString(int index) throws IOException{
    File f;
    if(index%2 == 0) f = positiveFiles[index/2];
    else f = negativeFiles[index/2];
    return FileUtils.readFileToString(f);
  }

  /** Convenience method to get label for review */
  public boolean isPositiveReview(int index){
    return index%2 == 0;
  }
}

続いて学習部分のプログラム。Word2Vecの出力を受け取って、pos/negの二値に収束させるニューラルネットです。まだdeeplearning4jの深層学習モジュールの使い方がよくわかっていないので結構適当です。たぶんパラメータチューニングは色々出来ると思います。

public class SentimentTrainCmd {

  public static void main (final String[] args) throws Exception {
    WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
    int numInputs   = wvec.lookupTable().layerSize();
    int numOutputs  = 2;
    int batchSize   = 50;
    int iterations  = 5;
    int nEpochs     = 5;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(7485)
        .iterations(iterations)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .learningRate(0.01)
        .updater(Updater.RMSPROP)
        .list()
        .layer(0, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
            .weightInit(WeightInit.XAVIER)
            .activation("softmax").weightInit(WeightInit.XAVIER)
            .nIn(numInputs).nOut(numOutputs)
            .build())
        .pretrain(false).backprop(true).build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(new ScoreIterationListener(10));


    System.out.println("Starting training");
    DataSetIterator train = new AsyncDataSetIterator(
        new SentimentSimpleIterator(args[1],wvec,batchSize,300,true),1);
    DataSetIterator test = new AsyncDataSetIterator(
        new SentimentSimpleIterator(args[1],wvec,100,300,false),1);
    for( int i=0; i<nEpochs; i++ ){
      model.fit(train);
      train.reset();

      System.out.println("Epoch " + i + " complete. Starting evaluation:");
      Evaluation evaluation = new Evaluation();
      while(test.hasNext()) {
        DataSet t = test.next();
        INDArray features = t.getFeatures();
        INDArray lables = t.getLabels();
        INDArray predicted = model.output(features,false);
        evaluation.eval(lables,predicted);
      }
      test.reset();
      System.out.println(evaluation.stats());
    }

    System.out.println("Save model");
    ModelSerializer.writeModel(model, new FileOutputStream(args[2]), true);
  }

}

実験

実験条件はこれまでと同じです。さて肝心の精度ですが、、、

Examples labeled as 0 classified by model as 0: 37 times
Examples labeled as 0 classified by model as 1: 130 times
Examples labeled as 1 classified by model as 0: 3 times
Examples labeled as 1 classified by model as 1: 1091 times
Accuracy: 0.8945
Precision: 0.9093
Recall: 0.6094
F1 Score: 0.7297

Accuracyが約90%で、F値も約73%と悪くないです。しかしよく見てみると、分類結果はほとんどnegativeを出力していることがわかります。学習もテストも恣意的なことは何一つないのですが、、、うーん問題ですね。negativeを出力しやすい原因は明らかで、学習データがnegativeが多いからです。一般的に、ラベルに偏りがある場合、SVMなどの機械学習ではそのラベルを出力しやすくなってしまいます。俗に言う過学習です。学習データのバランスを揃えてやればうまくいくと思いますが、、、そういう旧時代的な調整は深層学習という新時代にはなるべくしたくないですね。と言っても、これだと極性判定器としては使い物にならないので何とかしないといけませんが・・・。

おわりに

第3回にて、これまでやってきた機械学習のイメージ通りの結果が出てきました。学習データがインバランスの場合はデータの多い方に結果が偏ることはよくあります。今回の場合はnegativeが多いので、無事negativeに倒れてくれてよかった(?)です。さて、次回はRNNを試してみます。deeplearning4jのサンプルプログラムにword2vecからのRNNがあったので、簡単に試せるでしょう。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です