

LOL Cross-Validation

Eric Bridgeford

2020-06-25

Cross Validation to Determine Error of Embedding Alg composed with Classifier

require(lolR)
require(ggplot2)
n = 400
d = 30
r = 3

Here, we assess how well a given model generalizes by estimating its cross-validated error. We simulate data with n=400 and d=30:

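# simulate n samples in d dimensions
testdat <- lol.sims.rtrunk(n, d)
X <- testdat$X
Y <- testdat$Y
# scatter the first two dimensions, colored by class label
data <- data.frame(x1 = X[,1], x2 = X[,2], y = Y)
data$y <- factor(data$y)
ggplot(data, aes(x = x1, y = x2, color = y)) +
  geom_point() + xlab("x1") + ylab("x2") +
  ggtitle("Simulated Data")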
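A cross-validated error estimate holds out part of the data, fits the model on the rest, and counts mistakes on the held-out part. The base-R sketch below illustrates a k-fold version of that idea under stated assumptions: the helpers `nearest_centroid_fit`, `nearest_centroid_predict`, and `xval_error` are hypothetical, a nearest-centroid rule stands in for LDA, and a shifted Gaussian simulation stands in for `lol.sims.rtrunk`. It is not how lolR computes its estimates. Setting `k` equal to the number of samples gives leave-one-out.

```r
# Hypothetical base-R sketch of a k-fold cross-validated misclassification rate.
# A nearest-centroid classifier stands in for MASS::lda; this is not lolR code.
nearest_centroid_fit <- function(X, y) {
  y <- droplevels(factor(y))
  # K x d matrix of class means, one row per class present in training data
  centroids <- do.call(rbind, lapply(levels(y), function(l)
    colMeans(X[y == l, , drop = FALSE])))
  list(centroids = centroids, levels = levels(y))
}

nearest_centroid_predict <- function(fit, Xnew) {
  # n_new x K matrix of squared Euclidean distances to each class mean
  d2 <- sapply(seq_len(nrow(fit$centroids)), function(j)
    rowSums(sweep(Xnew, 2, fit$centroids[j, ])^2))
  d2 <- matrix(d2, nrow = nrow(Xnew))
  # label of the closest centroid for each row
  fit$levels[max.col(-d2, ties.method = "first")]
}

xval_error <- function(X, y, k = 10, seed = 1) {
  X <- as.matrix(X)
  y <- factor(y)
  n <- nrow(X)
  set.seed(seed)
  # randomly assign samples to k folds; k = n gives leave-one-out
  folds <- sample(rep_len(seq_len(k), n))
  wrong <- logical(n)
  for (f in seq_len(k)) {
    test <- folds == f
    fit <- nearest_centroid_fit(X[!test, , drop = FALSE], y[!test])
    pred <- nearest_centroid_predict(fit, X[test, , drop = FALSE])
    wrong[test] <- pred != as.character(y[test])
  }
  # fraction of held-out samples that were misclassified
  mean(wrong)
}

# usage on illustrative two-class Gaussian data (not lol.sims.rtrunk)
set.seed(42)
n <- 200; d <- 30
y <- rep(c(0, 1), each = n / 2)
X <- matrix(rnorm(n * d), n, d)
X[y == 1, 1:3] <- X[y == 1, 1:3] + 1.5  # shift class 1 in the first 3 coordinates
err_10fold <- xval_error(X, y, k = 10)
err_loo <- xval_error(X, y, k = n)
print(c(ten_fold = err_10fold, loo = err_loo))
```

Leave-one-out refits the classifier n times, so it costs more than 10-fold but uses nearly all the data for every fit.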

We arbitrarily select LOL as our algorithm and estimate the leave-one-out (LOO) cross-validated error of the LDA classifier applied to its embedding. We embed the data into r=3 dimensions with the resulting model and visualize the first 2:

# embed with LOL into r dimensions; estimate the LOO error of LDA on the embedding
result <- lol.xval.eval(X, Y, r, alg = lol.project.lol, alg.return = "A",
                        classifier = MASS::lda, classifier.return = "class", k = 'loo')
data <- data.frame(x1 = result$model$Xr[,1], x2 = result$model$Xr[,2], y = Y)
data$y <- factor(data$y)
ggplot(data, aes(x = x1, y = x2, color = y)) +
  geom_point() + xlab("x1") + ylab("x2") +
  ggtitle(sprintf("Projected Data using LOL, L=%.2f", result$lhat))
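When an embedding is composed with a classifier, the embedding itself is part of the model being validated. This hypothetical sketch re-estimates the projection on every training split, so a held-out sample never shapes the embedding it is scored in. `prcomp` (PCA) stands in for LOL and a nearest-centroid rule stands in for LDA; `embed_pca`, `project`, and `loo_error_composed` are illustrative names, and the sketch reflects an assumption about good practice rather than a description of `lol.xval.eval`'s internals.

```r
# Hypothetical sketch: PCA (prcomp) stands in for LOL, nearest centroid for LDA.
# The projection is refit without sample i before sample i is classified.
embed_pca <- function(Xtrain, r) {
  pc <- prcomp(Xtrain, center = TRUE, scale. = FALSE)
  list(center = pc$center, A = pc$rotation[, seq_len(r), drop = FALSE])
}

# center with the training mean, then project onto the r retained directions
project <- function(emb, X) sweep(as.matrix(X), 2, emb$center) %*% emb$A

loo_error_composed <- function(X, y, r) {
  X <- as.matrix(X)
  y <- factor(y)
  wrong <- logical(nrow(X))
  for (i in seq_len(nrow(X))) {
    emb <- embed_pca(X[-i, , drop = FALSE], r)   # embedding fit without sample i
    Ztr <- project(emb, X[-i, , drop = FALSE])
    zte <- project(emb, X[i, , drop = FALSE])
    ytr <- droplevels(y[-i])
    # class centroids in the r-dimensional embedding
    cents <- do.call(rbind, lapply(levels(ytr), function(l)
      colMeans(Ztr[ytr == l, , drop = FALSE])))
    d2 <- rowSums(sweep(cents, 2, as.vector(zte))^2)
    wrong[i] <- levels(ytr)[which.min(d2)] != as.character(y[i])
  }
  mean(wrong)
}

# usage on illustrative two-class Gaussian data
set.seed(7)
n <- 100; d <- 30
y <- rep(c(0, 1), each = n / 2)
X <- matrix(rnorm(n * d), n, d)
X[y == 1, 1:3] <- X[y == 1, 1:3] + 1.5
err_r3 <- loo_error_composed(X, y, r = 3)
print(err_r3)
```

PCA ignores the labels, whereas LOL uses them; replacing `embed_pca` with another projection function is the only change needed to compare embeddings under the same validation loop.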

Cross Validation to Determine Optimal Embedding Dimensions of Embedding Alg composed with Classifier

# score each candidate dimension in rs by LOO cross-validation and keep the best
result <- lol.xval.optimal_dimselect(X, Y, rs = c(5, 10, 15), alg = lol.project.lol, alg.return = "A",
                                     classifier = MASS::lda, classifier.return = "class", k = 'loo')
data <- data.frame(x1 = result$model$Xr[,1], x2 = result$model$Xr[,2], y = Y)
data$y <- factor(data$y)
ggplot(data, aes(x = x1, y = x2, color = y)) +
  geom_point() + xlab("x1") + ylab("x2") +
  ggtitle(sprintf("Projected Data using LOL, L=%.2f", result$optimal.lhat))

ggplot(result$foldmeans.data, aes(x = r, y = lhat)) +
  geom_line() +
  xlab("Embedding Dimensions, r") +
  ylab("Misclassification Rate, L") +
  ggtitle("Impact on Misclassification Rate of Embedding Dimension")

print(sprintf("optimal dimension: %d", result$optimal.r))
## [1] "optimal dimension: 5"
print(sprintf("Misclassification rate at rhat = %d: %.2f", result$optimal.r, result$optimal.lhat))
## [1] "Misclassification rate at rhat = 5: 0.07"
## [2] "Misclassification rate at rhat = 5: 0.07"
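The dimension search can be sketched the same way: compute a cross-validated error for each candidate r on identical folds and keep the minimizer. This is a hypothetical base-R illustration, not `lol.xval.optimal_dimselect`; PCA again stands in for LOL and nearest centroid for LDA, and the returned fields `curve`, `optimal.r`, and `optimal.lhat` only mirror the names used above. Breaking ties toward the smallest r is an assumption of this sketch.

```r
# Hypothetical sketch: pick the embedding dimension r minimising a k-fold
# cross-validated misclassification rate (PCA stands in for LOL, nearest
# centroid for LDA; neither is lolR's implementation).
cv_error_pca_nc <- function(X, y, r, k = 10) {
  X <- as.matrix(X)
  y <- factor(y)
  folds <- sample(rep_len(seq_len(k), nrow(X)))
  wrong <- logical(nrow(X))
  for (f in seq_len(k)) {
    te <- folds == f
    pc <- prcomp(X[!te, , drop = FALSE])          # embedding fit on training folds only
    A <- pc$rotation[, seq_len(r), drop = FALSE]
    Ztr <- sweep(X[!te, , drop = FALSE], 2, pc$center) %*% A
    Zte <- sweep(X[te, , drop = FALSE], 2, pc$center) %*% A
    ytr <- droplevels(y[!te])
    cents <- do.call(rbind, lapply(levels(ytr), function(l)
      colMeans(Ztr[ytr == l, , drop = FALSE])))
    # squared distance of each held-out row to each class centroid
    d2 <- sapply(seq_len(nrow(cents)), function(j)
      rowSums(sweep(Zte, 2, cents[j, ])^2))
    d2 <- matrix(d2, nrow = nrow(Zte))
    pred <- levels(ytr)[max.col(-d2, ties.method = "first")]
    wrong[te] <- pred != as.character(y[te])
  }
  mean(wrong)
}

select_dim <- function(X, y, rs, k = 10, seed = 1) {
  # same seed for every r, so every candidate is scored on the same folds
  lhat <- sapply(rs, function(r) { set.seed(seed); cv_error_pca_nc(X, y, r, k) })
  curve <- data.frame(r = rs, lhat = lhat)
  best <- which.min(lhat)  # first minimum: smallest r among ties when rs is ascending
  list(curve = curve, optimal.r = rs[best], optimal.lhat = lhat[best])
}

# usage on illustrative two-class Gaussian data
set.seed(3)
n <- 200; d <- 30
y <- rep(c(0, 1), each = n / 2)
X <- matrix(rnorm(n * d), n, d)
X[y == 1, 1:3] <- X[y == 1, 1:3] + 1.5
sel <- select_dim(X, y, rs = c(5, 10, 15))
print(sel$curve)
print(sprintf("optimal dimension: %d", sel$optimal.r))
```

Reusing the same seed for every r keeps the folds fixed, so differences along the curve reflect r rather than a different random split.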
