Apitore blog

Apitoreを運営していた元起業家のブログ

deeplearning4jのword2vecの限界とその上手な使い方

はじめに

前回の記事でdeeplearning4jのword2vecに日本語の形態素解析器kuromojiを組み込みました。その後、日本語Wikipediaを学習させていたのですが・・・丸5日学習させても終わりませんでした。そこで、どこに時間がかかっているのか分析し、どういう運用が現状で良いか考えてみました。 amarec (20160922-173547)

deeplearning4jのword2vecのボトルネック

青空文庫の「吾輩は猫である」は著作権切れで無料に使えるので、こちらを学習データにしました。データはこちらから辿れます。全部だと少し量が多いので、先頭から10000行くらいを使いました。時間を計測するのがめんどくさかったので結論だけ言います。

  • パラメータは実行速度にそこまで影響しない
  • 自作したTokenizeFactoryを使うと致命的に遅くなる。これは私の実装の問題ではなく、形態素解析しながら処理すると遅くなるようだ。DefaultTokenizerFactoryを使うだけで劇的に早くなるので、予めmecabやkuromojiで分かち書きすれば良い。mecabやkuromojiを使えば分かち書きは一瞬で終了する。
  • deeplearning4jのword2vecの学習は、本家のC実装に比べると遅い。本家のC実装では、日本語wikipedia全文を3時間程度で処理する。

本家word2vecでのモデルの作り方とdeeplearning4jのword2vecでのモデルの読み込み方

本家のC版のword2vecはsvnレポジトリが死んでいるので、非正規になりますがsvnからgithubへコピーしたこちらのプロジェクトを使います。環境はcygwinでやりました。makeをすれば一発で使えるようになりました。 学習に入ります。日本語のWikipediaは分かち書きにしておきます。

$ mecab -Owakati infile -o outfile

私は以下のプログラムを書いて、分かち書きをしました。

public static void main (final String[] args) throws Exception {
  String infile   = args[0];
  String outfile  = args[1];
  BufferedReader br = new BufferedReader(new FileReader(infile));
  BufferedWriter bw = new BufferedWriter(new FileWriter(outfile));
  Tokenizer tokenizer = new Tokenizer();
  String line;
  int count=0;
  while ((line = br.readLine()) != null) {
    count++;
    //System.out.println(count);
    List<Token> list = tokenizer.tokenize(line);
    StringBuffer sb = new StringBuffer();
    for (Token tok: list) {
      String str = tok.getSurface();
      sb.append(str.toLowerCase());
      sb.append(" ");
    }
    bw.write(sb.toString()+"\n");
    if (count%100000 == 0) {
      bw.flush();
      System.out.println(count);
    }
  }
  bw.flush();
  br.close();
  bw.close();
}

本家word2vecの学習に入ります。パラメータは前回の記事のパラメータに近くなるようにしたつもりです。間違ってたら教えて下さい。

$ ./word2vec -train wakati.txt -output out.model -cbow 0 -size 200 -hs 0 -window 5 -sample 1e-3 -negative 5 -threads 6 -iter 5 -min-count 5 -binary 0

「binary 0」がポイントです。バイナリファイルにしないとモデルが大きくなってしまいますが、deeplearning4jのword2vecで読み込めなかったので仕方ないです。バイナリファイルにしない場合は、モデルはベクトル表現で記述されます。出力したモデルをdeeplearning4jのword2vecで読み込みます。

InputStream is = new FileInputStream("out.model");
WordVectors vec = WordVectorSerializer.loadTxtVectors(is, true);
Collection<String> lst = vec.wordsNearest("day", 10);
System.out.println(lst);

おわりに

本家のword2vecはさすがのgoogle製、爆速です。deeplearning4jはまだversionが0.5.0なので、word2vecの速度については仕方がないと思います。一方で、deeplearning4jはJavaなので、Spring Bootを使ったアプリケーションとは相性抜群です。APIを公開したら、また記事にして紹介します。