Apitore blog

Apitoreを運営していた元起業家のブログ

deeplearning4jのdoc2vecで極性判定してみた

はじめに

deeplearning4jのdoc2vec (正確にはparagraph2vec) で極性判定をしてみました。学習データは自作した目的語リスト(約2,300語)を含む極性ツイートです。TwitterのSearchAPIでTwitter社がつけた極性付きツイートをクロールしています。 amarec (20161015-202032)

実装

ネット上に情報が少ないですが、公式にサンプルプログラミングがあるのでそちらを参考にしました。まずpomファイルから。

<dependency>
  <groupId>org.deeplearning4j</groupId>
  <artifactId>deeplearning4j-ui</artifactId>
  <version>0.6.0</version>
</dependency>
<dependency>
  <groupId>org.deeplearning4j</groupId>
  <artifactId>deeplearning4j-nlp</artifactId>
  <version>0.6.0</version>
</dependency>
<dependency>
  <groupId>org.nd4j</groupId>
  <artifactId>nd4j-native</artifactId>
  <version>0.6.0</version>
</dependency>

つづいてプログラム。

ClassPathResource resource1 = new ClassPathResource("paravec/train");
LabelAwareIterator iter1 = new FileLabelAwareIterator.Builder()
    .addSourceFolder(resource1.getFile())
    .build();
TokenizerFactory t = new DefaultTokenizerFactory();
t.setTokenPreProcessor(new CommonPreprocessor());
System.out.println( "Build model..." );
int batchSize   = 1000;
int iterations  = 5;
int layerSize   = 200;
ParagraphVectors vec = new ParagraphVectors.Builder()
    .batchSize(batchSize)
    .epochs(20)
    .trainWordVectors(true)
    .minWordFrequency(1)
    //.useAdaGrad(false)
    //.layerSize(layerSize)
    //.iterations(iterations)
    .seed(1)
    //.windowSize(5)
    .learningRate(0.025)
    .minLearningRate(1e-3)
    //.sampling(0)
    //.negativeSample(0)
    .iterate(iter1)
    .tokenizerFactory(t)
    .workers(6)
    //.labelsSource(new LabelsSource(Arrays.asList("negative", "neutral","positive")))
    //.stopWords(new ArrayList<String>())
    .build();
vec.fit();
WordVectorSerializer.writeParagraphVectors(vec, args[0]);
ClassPathResource resource2 = new ClassPathResource("paravec/test");
LabelAwareIterator iter2 = new FileLabelAwareIterator.Builder()
    .addSourceFolder(resource2.getFile())
    .build();
MeansBuilder meansBuilder = new MeansBuilder(
    (InMemoryLookupTable<VocabWord>) vec.getLookupTable(), t);
LabelSeeker seeker = new LabelSeeker(iter2.getLabelsSource().getLabels(),
    (InMemoryLookupTable<VocabWord>) vec.getLookupTable());
int crr=0,err=0;
while (iter2.hasNextDocument()) {
  LabelledDocument document = iter2.nextDocument();
  INDArray documentAsCentroid = meansBuilder.documentAsVector(document);
  List<Pair<String, Double>> scores = seeker.getScores(documentAsCentroid);
  String pLabel = null;
  Double pScore= -10000D;
  System.out.println("Document '" + document.getLabel() + "' falls into the following categories: ");
  for (Pair<String, Double> score: scores) {
    System.out.println("        " + score.getFirst() + ": " + score.getSecond());
    if (pScore<score.getSecond()) {
      pScore = score.getSecond();
      pLabel = score.getFirst();
    }
  }
  if (document.getLabel().equals(pLabel))
    crr++;
  else
    err++;
}
double acc = 1.0*crr/(crr+err);
System.out.println("Accuracy: " + acc + ", Correct: " + crr);

学習データは以下のような形で格納します。

  • positiveラベルのデータ
    • paravec/train/positive/1.txt
    • paravec/train/positive/2.txt
    • ...
  • negativeラベルのデータ
    • paravec/train/negative/1.txt
    • paravec/train/negative/2.txt
    • ...

フォルダ名がラベルになっているのが特徴です。データの形式はword2vecのときと同様で、今回の場合はツイートに形態素解析を適用して半角スペースで分かち書きしておきます。1ファイル1ツイートにしてください。学習パラメータは適当です。各自で試行錯誤する部分ですね。

実験

約2300語の目的語リストを作り、その単語を含む極性付きツイートを学習データとします。極性付きツイートは、例えば「iPhone :)」と検索するとポジティブツイートが、「iPhone :(」と検索するとネガティブツイートが取得できます。Twitter社がどのようなアルゴリズムで極性判定をしているかわかりませんが、精度はあまり良くなさそうです。Distant Supervision的な学習方法なので、まあ何とかなるでしょう! 学習はpositive 6000ツイート、negative 40900ツイートです。同じように集めたのに極端にネガティブが多くなりました。評価用に別途 positive 160ツイート、negative 1100ツイートを用意しました。 気になる精度は・・・約29%!!!低っ!!!!? ちょっと予測結果のスコアを確認してみましたが、positiveのスコアがかなり高く出ていました。ほぼpositiveに分類されています。うーん、なぜだ?ちなみに、「positiveのスコアが0.7を超えない場合はnegativeとする」というposi/nega分類にしたら、精度は80%くらいになりました。うーん、こういうヒューリスティックは良くないな・・・。 学習データの量がラベルでインバランスなのが原因かと思いましたが、揃えてもあまり変わりませんでした。どこかにバグがあるのかな?ひとつ言えるのは、データがまだ少ないので、OOV(Out-Of-Vocabulary)が多くなっています。それでpositiveのスコアが高くなるとは限りませんが、語彙数は増やす必要がありますね。

おわりに

deeplearning4jのdoc2vecで極性判定をしてみました。性能が非常に悪いので、現状使えるとは言い難いです。もう少し改善をしてからWebAPIとして公開します。