LibsvmにJavaバインディングを使用しようとしています:
http://www.csie.ntu.edu.tw/~cjlin/libsvm/
Yで簡単に線形分離可能な「トリビアル」な例を実装しました。データは次のように定義されています。
double[][] train = new double[1000][];
double[][] test = new double[10][];
for (int i = 0; i < train.length; i++){
if (i+1 > (train.length/2)){ // 50% positive
double[] vals = {1,0,i+i};
train[i] = vals;
} else {
double[] vals = {0,0,i-i-i-2}; // 50% negative
train[i] = vals;
}
}
最初の「機能」はクラスであり、トレーニングセットも同様に定義されています。
モデルをトレーニングするには:
private svm_model svmTrain() {
svm_problem prob = new svm_problem();
int dataCount = train.length;
prob.y = new double[dataCount];
prob.l = dataCount;
prob.x = new svm_node[dataCount][];
for (int i = 0; i < dataCount; i++){
double[] features = train[i];
prob.x[i] = new svm_node[features.length-1];
for (int j = 1; j < features.length; j++){
svm_node node = new svm_node();
node.index = j;
node.value = features[j];
prob.x[i][j-1] = node;
}
prob.y[i] = features[0];
}
svm_parameter param = new svm_parameter();
param.probability = 1;
param.gamma = 0.5;
param.nu = 0.5;
param.C = 1;
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 20000;
param.eps = 0.001;
svm_model model = svm.svm_train(prob, param);
return model;
}
次に、使用するモデルを評価します。
public int evaluate(double[] features) {
svm_node node = new svm_node();
for (int i = 1; i < features.length; i++){
node.index = i;
node.value = features[i];
}
svm_node[] nodes = new svm_node[1];
nodes[0] = node;
int totalClasses = 2;
int[] labels = new int[totalClasses];
svm.svm_get_labels(_model,labels);
double[] prob_estimates = new double[totalClasses];
double v = svm.svm_predict_probability(_model, nodes, prob_estimates);
for (int i = 0; i < totalClasses; i++){
System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")");
}
System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");
return (int)v;
}
渡された配列がテストセットからのポイントである場合。
結果は常にクラス0を返します。正確な結果は次のとおりです。
(0:0.9882998314585194)(1:0.011700168541480586)(Actual:0.0 Prediction:0.0)
(0:0.9883952943701599)(1:0.011604705629839989)(Actual:0.0 Prediction:0.0)
(0:0.9884899803606306)(1:0.011510019639369528)(Actual:0.0 Prediction:0.0)
(0:0.9885838957058696)(1:0.011416104294130458)(Actual:0.0 Prediction:0.0)
(0:0.9886770466322342)(1:0.011322953367765776)(Actual:0.0 Prediction:0.0)
(0:0.9870913229268679)(1:0.012908677073132284)(Actual:1.0 Prediction:0.0)
(0:0.9868781382588805)(1:0.013121861741119505)(Actual:1.0 Prediction:0.0)
(0:0.986661444476744)(1:0.013338555523255982)(Actual:1.0 Prediction:0.0)
(0:0.9864411843906802)(1:0.013558815609319848)(Actual:1.0 Prediction:0.0)
(0:0.9862172999068877)(1:0.013782700093112332)(Actual:1.0 Prediction:0.0)
この分類子が機能しない理由を誰かが説明できますか?私が台無しにしたステップ、または私が欠けているステップはありますか?
ありがとう
あなたの評価方法が間違っているように私には思えます。次のようになります。
public double evaluate(double[] features, svm_model model)
{
svm_node[] nodes = new svm_node[features.length-1];
for (int i = 1; i < features.length; i++)
{
svm_node node = new svm_node();
node.index = i;
node.value = features[i];
nodes[i-1] = node;
}
int totalClasses = 2;
int[] labels = new int[totalClasses];
svm.svm_get_labels(model,labels);
double[] prob_estimates = new double[totalClasses];
double v = svm.svm_predict_probability(model, nodes, prob_estimates);
for (int i = 0; i < totalClasses; i++){
System.out.print("(" + labels[i] + ":" + prob_estimates[i] + ")");
}
System.out.println("(Actual:" + features[0] + " Prediction:" + v + ")");
return v;
}
これは、次のRコードのデータを使用してテストした上記の例のリワークです。 http://cbio.ensmp.fr/~jvert/svn/tutorials/practical/svmbasic/svmbasic_notes.pdf
import libsvm.*;
public class libsvmTest {
public static void main(String [] args) {
double[][] xtrain = ...
double[][] xtest = ...
double[][] ytrain = ...
double[][] ytest = ...
svm_model m = svmTrain(xtrain,ytrain);
double[] ypred = svmPredict(xtest, m);
for (int i = 0; i < xtest.length; i++){
System.out.println("(Actual:" + ytest[i][0] + " Prediction:" + ypred[i] + ")");
}
}
static svm_model svmTrain(double[][] xtrain, double[][] ytrain) {
svm_problem prob = new svm_problem();
int recordCount = xtrain.length;
int featureCount = xtrain[0].length;
prob.y = new double[recordCount];
prob.l = recordCount;
prob.x = new svm_node[recordCount][featureCount];
for (int i = 0; i < recordCount; i++){
double[] features = xtrain[i];
prob.x[i] = new svm_node[features.length];
for (int j = 0; j < features.length; j++){
svm_node node = new svm_node();
node.index = j;
node.value = features[j];
prob.x[i][j] = node;
}
prob.y[i] = ytrain[i][0];
}
svm_parameter param = new svm_parameter();
param.probability = 1;
param.gamma = 0.5;
param.nu = 0.5;
param.C = 100;
param.svm_type = svm_parameter.C_SVC;
param.kernel_type = svm_parameter.LINEAR;
param.cache_size = 20000;
param.eps = 0.001;
svm_model model = svm.svm_train(prob, param);
return model;
}
static double[] svmPredict(double[][] xtest, svm_model model)
{
double[] yPred = new double[xtest.length];
for(int k = 0; k < xtest.length; k++){
double[] fVector = xtest[k];
svm_node[] nodes = new svm_node[fVector.length];
for (int i = 0; i < fVector.length; i++)
{
svm_node node = new svm_node();
node.index = i;
node.value = fVector[i];
nodes[i] = node;
}
int totalClasses = 2;
int[] labels = new int[totalClasses];
svm.svm_get_labels(model,labels);
double[] prob_estimates = new double[totalClasses];
yPred[k] = svm.svm_predict_probability(model, nodes, prob_estimates);
}
return yPred;
}
}
出力は次のとおりです。
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:1.0 Prediction:1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
(Actual:-1.0 Prediction:-1.0)
LibSVMのJava実装の少しリファクタリングされたバージョンを作成しました。これは使いやすいかもしれません: https://github.com/syeedibnfaiz/libsvm-Java-kernel 。 Demo.Javaクラスを見て、その使用方法を確認してください。