public class SentimentRecurrentIterator implements DataSetIterator {
private static final long serialVersionUID = 3720823704002518845L;
private final WordVectors wordVectors;
private final int batchSize;
private final int vectorSize;
private final int truncateLength;
private int cursor = 0;
private int posCursor = 0;
private int negCursor = 0;
private final File[] positiveFiles;
private final File[] negativeFiles;
private final int numPositives;
private final int numNegatives;
private final int numTotals;
private final Random rnd;
private final TokenizerFactory tokenizerFactory;
* @param dataDirectory the directory of the IMDB review data set
* @param wordVectors WordVectors object
* @param batchSize Size of each minibatch for training
* @param truncateLength If reviews exceed
* @param train If true: return the training data. If false: return the testing data.
public SentimentRecurrentIterator(String dataDirectory, WordVectors wordVectors, int batchSize, int truncateLength, boolean train) throws IOException {
this.batchSize = batchSize;
this.vectorSize = wordVectors.lookupTable().layerSize();
File p = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/positive/") + "/");
File n = new File(FilenameUtils.concat(dataDirectory, (train ? "train" : "test") + "/negative/") + "/");
positiveFiles = p.listFiles();
negativeFiles = n.listFiles();
numPositives = positiveFiles.length;
numNegatives = negativeFiles.length;
numTotals = numPositives+numNegatives;
rnd = new Random(1);
this.wordVectors = wordVectors;
this.truncateLength = truncateLength;
tokenizerFactory = new DefaultTokenizerFactory();
tokenizerFactory.setTokenPreProcessor(new CommonPreprocessor());
public DataSet next(int num) {
if (cursor >= positiveFiles.length + negativeFiles.length) throw new NoSuchElementException();
return nextDataSet(num);
}catch(IOException e){
throw new RuntimeException(e);
private DataSet nextDataSet(int num) throws IOException {
//First: load reviews to String. Alternate positive and negative reviews
List<String> reviews = new ArrayList<>(num);
boolean[] positive = new boolean[num];
for( int i=0; i<num && cursor<totalExamples(); i++ ){
int idx = rnd.nextInt(numTotals);
boolean mode = modeJudge(idx);
if (mode) {
String review = FileUtils.readFileToString(positiveFiles[posCursor]);
positive[i] = true;
} else {
String review = FileUtils.readFileToString(negativeFiles[negCursor]);
positive[i] = false;
//Second: tokenize reviews and filter out unknown words
List<List<String>> allTokens = new ArrayList<>(reviews.size());
int maxLength = 0;
for(String s : reviews){
List<String> tokens = tokenizerFactory.create(s).getTokens();
List<String> tokensFiltered = new ArrayList<>();
for(String t : tokens ){
if(wordVectors.hasWord(t)) tokensFiltered.add(t);
maxLength = Math.max(maxLength,tokensFiltered.size());
//If longest review exceeds 'truncateLength': only take the first 'truncateLength' words
if(maxLength > truncateLength) maxLength = truncateLength;
//Create data for training
//Here: we have reviews.size() examples of varying lengths
INDArray features = Nd4j.create(reviews.size(), vectorSize, maxLength);
INDArray labels = Nd4j.create(reviews.size(), 2, maxLength); //Two labels: positive or negative
//Because we are dealing with reviews of different lengths and only one output at the final time step: use padding arrays
//Mask arrays contain 1 if data is present at that time step for that example, or 0 if data is just padding
INDArray featuresMask = Nd4j.zeros(reviews.size(), maxLength);
INDArray labelsMask = Nd4j.zeros(reviews.size(), maxLength);
int[] temp = new int[2];
for( int i=0; i<reviews.size(); i++ ){
List<String> tokens = allTokens.get(i);
temp[0] = i;
//Get word vectors for each word in review, and put them in the training data
for( int j=0; j<tokens.size() && j<maxLength; j++ ){
String token = tokens.get(j);
INDArray vector = wordVectors.getWordVectorMatrix(token);
features.put(new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(j)}, vector);
temp[1] = j;
featuresMask.putScalar(temp, 1.0); //Word is present (not padding) for this example + time step -> 1.0 in features mask
int idx = (positive[i] ? 0 : 1);
int lastIdx = Math.min(tokens.size(),maxLength);
labels.putScalar(new int[]{i,idx,lastIdx-1},1.0); //Set label: [0,1] for negative, [1,0] for positive
labelsMask.putScalar(new int[]{i,lastIdx-1},1.0); //Specify that an output exists at the final time step for this example
return new DataSet(features,labels,featuresMask,labelsMask);
private boolean modeJudge(int idx) {
if (posCursor >= numPositives)
return false;
else if (negCursor >= numNegatives)
return true;
else if (idx < numPositives)
return true;
return false;
public int totalExamples() {
return positiveFiles.length + negativeFiles.length;
public int inputColumns() {
return vectorSize;
public int totalOutcomes() {
return 2;
public void reset() {
cursor = 0;
posCursor = 0;
negCursor = 0;
public boolean resetSupported() {
return true;
public boolean asyncSupported() {
return true;
public int batch() {
return batchSize;
public int cursor() {
return cursor;
public int numExamples() {
return totalExamples();
public void setPreProcessor(DataSetPreProcessor preProcessor) {
throw new UnsupportedOperationException();
public List<String> getLabels() {
return Arrays.asList("positive","negative");
public boolean hasNext() {
return cursor < numExamples();
public DataSet next() {
return next(batchSize);
public void remove() {
public DataSetPreProcessor getPreProcessor() {
throw new UnsupportedOperationException("Not implemented");
// Command-line driver for the sentiment model: builds a MultiLayerNetwork, creates train/test
// iterators over the review data, evaluates on the test set once per epoch and saves the model.
// NOTE(review): this excerpt appears to be missing lines (the body of the network configuration,
// the training call, the evaluation call and closing braces) -- confirm against the full source.
public class SentimentRecurrentTrainCmd {
// Arguments as used below: args[0] = word-vector text file, args[1] = data directory handed to
// SentimentRecurrentIterator, args[2] = path the trained model is written to.
public static void main (final String[] args) throws Exception {
// Load the word vectors from a text-format vector file.
WordVectors wvec = WordVectorSerializer.loadTxtVectors(new File(args[0]));
// Input size = dimensionality of one word vector.
int numInputs = wvec.lookupTable().layerSize();
// Two output classes; SentimentRecurrentIterator.getLabels() lists "positive" and "negative".
int numOutputs = 2;
int batchSize = 50;
// NOTE(review): 'iterations' is not referenced in the visible lines; presumably used by the
// missing configuration builder calls -- verify.
int iterations = 5;
int nEpochs = 5;
// Network configuration: layer 0 is a GravesLSTM, layer 1 an RnnOutputLayer.
// NOTE(review): the builder arguments for both layers are not present in this excerpt.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.layer(0, new GravesLSTM.Builder()
.layer(1, new RnnOutputLayer.Builder()
MultiLayerNetwork model = new MultiLayerNetwork(conf);
// Report the score on every iteration (listener frequency 1).
model.setListeners(new ScoreIterationListener(1));
System.out.println("Starting training");
// Training iterator: mini-batches of 'batchSize' reviews, truncated to 300 tokens,
// wrapped in an AsyncDataSetIterator with a queue size of 1.
DataSetIterator train = new AsyncDataSetIterator(
new SentimentRecurrentIterator(args[1],wvec,batchSize,300,true),1);
// Test iterator: mini-batches of 100 reviews, same truncation length.
DataSetIterator test = new AsyncDataSetIterator(
new SentimentRecurrentIterator(args[1],wvec,100,300,false),1);
for( int i=0; i<nEpochs; i++ ){
// NOTE(review): the call that fits the model on 'train' is not visible before this message.
System.out.println("Epoch " + i + " complete. Starting evaluation:");
Evaluation evaluation = new Evaluation();
// Run every test mini-batch through the network, passing both mask arrays.
while(test.hasNext()) {
DataSet t = test.next();
INDArray features = t.getFeatures();
INDArray lables = t.getLabels();
INDArray inMask = t.getFeaturesMaskArray();
INDArray outMask = t.getLabelsMaskArray();
INDArray predicted = model.output(features,false,inMask,outMask);
// NOTE(review): 'predicted' and 'lables' are not consumed in the visible lines; presumably
// they are fed to 'evaluation' and the iterators are reset in lines missing here -- verify.
// Serialize the trained model to args[2]; the final 'true' asks ModelSerializer to also
// save the updater state.
System.out.println("Save model");
ModelSerializer.writeModel(model, new FileOutputStream(args[2]), true);
実験条件はこれまでと同じです。

Examples labeled as 0 classified by model as 0: 136 times
Examples labeled as 0 classified by model as 1: 31 times
Examples labeled as 1 classified by model as 0: 96 times
Examples labeled as 1 classified by model as 1: 998 times

Accuracy: 0.8993
Precision: 0.778
Recall: 0.8633
F1 Score: 0.8185

F値82%はなかなかのものです。そして前回のニューラルネットの結果と違い、ちゃんとpositiveの正解率も上がっています。RNN+LSTMにしただけでこれほど目に見えて改善するとは驚きです。パラメータチューニングをもう少し丁寧にやればもっと性能が上がるかも?