Apitore blog

Apitoreを運営していた元起業家のブログ

Kaggleのタイタニックチュートリアルで粘ったら精度80%を超えた

はじめに

過去2回にわたってKaggleのタイタニックチュートリアルをやったけど、精度80%は超えたいと思ってたので諦めきれずにあがいてみた。そしてついに80%を超えたので記事にする。精度80.9%で2016年7月5日現在で366位/4430人中となった。 【過去の記事】

やったこと

ひとことで言うと、オーバーフィッティングを疑って、学習データを間引いた。 そもそも、生データ眺めてみていて、どうも同じようなデータで正負が反対のものが多い印象だった。例えば、SibSp=0かつParch=0かつSex=maleかつPclass=3で見てみると、264人中32人が生存している。32人と232人の生死が別れた理由を年齢で見比べてみると、年齢不詳を除けば、生存者は16~45だった。さらに16~45で絞ってみると、171人中26名が生存している。残る情報はCabinとTicketだが、Cabinはほとんどブランクのため使い物にならず、Ticketも特に奇数偶数で生死に違いがあるわけでもなさそうだった。 つまり、ほとんど同じような素性で生死が分かれている状況だった。こうなってくるともう機械学習だろうが人だろうが生死を推測できない。 私はタイタニックチュートリアルの最初の投稿で「家族の生死が本人の生死に関係する」という予測を立てた。これは間違ってないと思う。ということで、この素性を中心に信頼できそうな素性を足して、オーバーフィッティングを避ける形で適当にデータを間引くことにした。曖昧な部分をあえて残すことで、識別器の汎化性能を上げる。 使った素性を以下に示す。学習データはもともとランダムに並び替えられているそうなので、先頭から700個を使うことにした。700を決めるために、学習データの数を400~891の間で50刻みに動かして一番精度が良かったものを選んでいて、結構適当である。まじめにやるならば、CV的なアプローチで学習データの適切な数を求めたほうがいいと思う。識別器はRandomForest(Ranger)を使った。

属性名
Pclass クラス
FamilyStatus 家族の生死
Sex 性別
RangeAge 年齢

おわりに

結果的に、精度は80%を超えた。検定していないのでなんとも言えないが、テストデータが400以上で約1.5%の改善・・・6個正解が増えた・・・有意差は・・・ないかもなあ。というわけで、目標の80%は超えたので本当にこれで終わり。小手先のテクニックは出し尽くした!

ソースコード

関係ないコードも混じってるけど、そこは適宜削除してください。

library(ranger)
# 家族の生死のリスト準備
familystatus = function(d) {
  rows = nrow(d)
  rtn = NULL
  for (i in 1:rows) {
    if (d[i,"SibSp"]==0 && d[i,"Parch"]==0) {
      # do nothing
    } else {
      name = as.character(d[i,"FamilyName"])
      survive = d[i,"Survived"]
      if (survive==0) survive=-1
      if (is.null(rtn[name]) || is.na(rtn[name])) {
        rtn[name]=survive
      } else {
        rtn[name]=rtn[name]+survive
      }
    }
  }
  rtn
}
# 家族の生死素性化
addfamilystatus = function(d,l) {
  rows = nrow(d)
  rtn = NULL
  for (i in 1:rows) {
    if (d[i,"SibSp"]==0 && d[i,"Parch"]==0) {
      rtn = c(rtn,"U")
    } else {
      name = as.character(d[i,"FamilyName"])
      if (is.na(l[name])) {
        rtn = c(rtn,"U")
      } else if (l[name]>0) {
        rtn = c(rtn,"A")
      } else {
        rtn = c(rtn,"D")
      }
    }
  }
  rtn
}
# 家族の生死のリストアップデート
updatefamilystatus = function(d,l) {
  rows = nrow(d)
  rtn = l
  for (i in 1:rows) {
    if (d[i,"SibSp"]==0 && d[i,"Parch"]==0) {
      # do nothing
    } else {
      name = as.character(d[i,"FamilyName"])
      survive = d[i,"Survived"]
      if (survive==0) survive=-1
      if (is.null(rtn[name]) || is.na(rtn[name])) {
        rtn[name]=survive
      } else {
        rtn[name]=rtn[name]+survive
      }
    }
  }
  rtn
}
## Model
data = read.csv("train.csv")
data$Pclass = ifelse(is.na(data$Pclass), as.character(-1), as.character(data$Pclass))
data$FamilyName = gsub(",.+$","",data$Name)
data$Cabin = gsub("[0-9 ].*$","",data$Cabin)
data$Age = ifelse(is.na(data$Age), -1, round(data$Age))
data$RangeSibSp = ifelse(data$SibSp>2, "2", "1")
data$RangeParch = ifelse(data$Parch>2, "2", "1")
data$TicketHead = NULL
data$TicketNo = NULL
tickets = strsplit(as.character(data$Ticket)," ")
data$RangeAge = NULL
rows = nrow(data)
for (i in 1:rows) {
  if (length(tickets[[i]])==1) {
    data[i, "TicketHead"] = "None"
    data[i, "TicketNo"] = gsub("...$","",tickets[[i]][1])
  } else if (length(tickets[[i]])==3) {
    data[i, "TicketHead"] = paste(tickets[[i]][1],tickets[[i]][2],sep="")
    data[i, "TicketNo"] = gsub("...$","",tickets[[i]][3])
  } else {
    data[i, "TicketHead"] = tickets[[i]][1]
    data[i, "TicketNo"] = gsub("...$","",tickets[[i]][2])
  }
  chk = data[i,"Age"]
  if (chk == -1) {
    data[i,"RangeAge"] = as.character("UNK")
  } else if (chk==0) {
    data[i,"RangeAge"] = as.character("1")
  } else if (chk>=1 && chk<6) {
    data[i,"RangeAge"] = as.character("2")
  } else if (chk>=6 && chk<10) {
    data[i,"RangeAge"] = as.character("3")
  } else if (chk>=10 && chk<20) {
    data[i,"RangeAge"] = as.character("4")
  } else if (chk>=20 && chk<30) {
    data[i,"RangeAge"] = as.character("5")
  } else if (chk>=30 && chk<40) {
    data[i,"RangeAge"] = as.character("6")
  } else if (chk>=40 && chk<50) {
    data[i,"RangeAge"] = as.character("7")
  } else if (chk>=50 && chk<60) {
    data[i,"RangeAge"] = as.character("8")
  } else {
    data[i,"RangeAge"] = as.character("9")
  }
}
data$Age = as.character(data$Age)
fsurvivelist = familystatus(data)
data$FamilyStatus = addfamilystatus(data,fsurvivelist)
bst = ranger(formula=Survived~Pclass+FamilyStatus+Sex+RangeAge,
             data=data[1:700,],
             num.trees = 100,
             classification=TRUE,
             write.forest=TRUE,
             seed=7485)
print(mean(data[1:700,]$Survived == bst$predictions))
## Predict
data.test = NULL
data.test = read.csv("test.csv")
data.test$Survived = 0
data.test$Pclass = ifelse(is.na(data.test$Pclass), as.character(-1), as.character(round(data.test$Pclass)))
data.test$FamilyName = gsub(",.+$","",data.test$Name)
data.test$Cabin = gsub("[0-9 ].*$","",data.test$Cabin)
data.test$Age = ifelse(is.na(data.test$Age), -1, round(data.test$Age))
data.test$RangeSibSp = ifelse(data.test$SibSp>2, "2", "1")
data.test$RangeParch = ifelse(data.test$Parch>2, "2", "1")
data.test$RangeAge = NULL
rows = nrow(data.test)
for (i in 1:rows) {
  chk = data.test[i,"Age"]
  if (chk == -1) {
    data.test[i,"RangeAge"] = as.character("UNK")
  } else if (chk==0) {
    data.test[i,"RangeAge"] = as.character("1")
  } else if (chk>=1 && chk<6) {
    data.test[i,"RangeAge"] = as.character("2")
  } else if (chk>=6 && chk<10) {
    data.test[i,"RangeAge"] = as.character("3")
  } else if (chk>=10 && chk<20) {
    data.test[i,"RangeAge"] = as.character("4")
  } else if (chk>=20 && chk<30) {
    data.test[i,"RangeAge"] = as.character("5")
  } else if (chk>=30 && chk<40) {
    data.test[i,"RangeAge"] = as.character("6")
  } else if (chk>=40 && chk<50) {
    data.test[i,"RangeAge"] = as.character("7")
  } else if (chk>=50 && chk<60) {
    data.test[i,"RangeAge"] = as.character("8")
  } else {
    data.test[i,"RangeAge"] = as.character("9")
  }
}
data.test$Age = as.character(data.test$Age)
#data.test$FamilyStatus = addfamilystatus(data.test,fsurvivelist)
#pred = predict(bst, data.test)
#prediction = pred$predictions
#rtn=NULL
#rtn$PassengerId=data.test$PassengerId
#rtn$Survived = pred$predictions
#write.csv(rtn,'mypredict.csv')
rtn.pred = NULL
rows = nrow(data.test)
for (j in 1:rows) {
  dtest = data.test
  dtest$FamilyStatus = addfamilystatus(dtest,fsurvivelist)
  pred = predict(bst, dtest)
  prediction = pred$predictions
  rtn.pred = c(rtn.pred,prediction[j])
  # update fsurvivelist
  dtest$Survived = prediction
  fsurvivelist = updatefamilystatus(dtest[j,],fsurvivelist)
}
rtn=NULL
rtn$PassengerId=data.test$PassengerId
rtn$Survived = rtn.pred
write.csv(rtn,'mypredict.csv')