图片尺寸为20×25,每个像素点用0或1表示,共形成500维的特征数据;类别为数字2-9和英文字母a-z,共33类。
装载数据后,使用 mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features") 进行模型训练。
每类500个样本时,测试集正确率为75%;
每类1000个样本时,测试集正确率为70%;
每类3000个样本时,测试集正确率仅为26%。
为什么训练数据越多,正确率反而急剧下降?
代码如下:
// Train the model: load trainData.txt, hold out 20% for evaluation, fit a
// Concatenate -> SDCA multiclass pipeline, evaluate it, and save it to imageModel.zip.
MLContext mlContext = new MLContext(seed: 1);
string baseDirectory = AppDomain.CurrentDomain.BaseDirectory;

// Host.columns / Host.featuresColumnNames are defined elsewhere; presumably a "Label"
// column plus the 500 pixel columns -- TODO confirm against Host.
TextLoader textLoader = mlContext.Data.CreateTextReader(Host.columns,
    hasHeader: true,
    separatorChar: ',');
// Path.Combine instead of "+ \\name": BaseDirectory already ends with a separator.
IDataView fullData = textLoader.Read(Path.Combine(baseDirectory, "trainData.txt"));

// Hold out 20% of the rows for evaluation. The split goes through the multiclass
// catalog (rather than Clustering) so the call site matches the task being trained.
(IDataView trainingDataView, IDataView testingDataView) =
    mlContext.MulticlassClassification.TrainTestSplit(fullData, testFraction: 0.2);

// Concatenate the individual pixel columns into the single "Features" vector the trainer reads.
var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", Host.featuresColumnNames);

SdcaMultiClassTrainer sdcaMultiClassTrainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features");
var trainingPipeline = dataProcessPipeline.Append(sdcaMultiClassTrainer);

// A single Fit over the whole pipeline. The former standalone Fit/Transform calls on
// dataProcessPipeline were unused, and one of them reused the wrong transformer.
ITransformer trainedModel = trainingPipeline.Fit(trainingDataView);

// Score the held-out split. NOTE(review): `metrics` is computed but never reported here;
// printing it would let you compare the evaluator's accuracy with the hand-rolled one.
var predictions = trainedModel.Transform(testingDataView);
var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score");

// FileShare.None: no other process should write to the model file while it is being saved.
using (var fs = new FileStream(Path.Combine(baseDirectory, "imageModel.zip"), FileMode.Create, FileAccess.Write, FileShare.None))
{
    mlContext.Model.Save(trainedModel, fs);
}
// Test-set validation: load imageModel.zip, predict every item in imageDatas, and print
// the fraction whose highest-scoring class index equals the item's Label.
MLContext mlContext = new MLContext(seed: 1);
ITransformer trainedModel;
using (var stream = new FileStream(Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "imageModel.zip"), FileMode.Open, FileAccess.Read, FileShare.Read))
{
    trainedModel = mlContext.Model.Load(stream);
}
var predEngine = trainedModel.CreatePredictionEngine<ImageData, ImageResult>(mlContext);

// Lowest winning score among correct predictions; starts at 0.5, so it only records values below 0.5.
float minScore = 0.5f;
// Integer counters: these are counts, and float counters would be inexact beyond 2^24.
int totalCount = 0;
int rightCount = 0;
foreach (ImageData id in imageDatas)
{
    ImageResult ir = predEngine.Predict(id);
    List<float> results = ir.Score.ToList();
    float max = results.Max();
    // Index of the first highest score, i.e. the predicted class slot.
    int resultIndex = results.IndexOf(max);
    // NOTE(review): assumes Score slot i corresponds to Label value i. If Label is ever
    // converted to a key type (e.g. MapValueToKey), slot order follows the key mapping
    // instead, and this comparison becomes invalid -- confirm.
    if (resultIndex == id.Label)
    {
        if (max < minScore)
        {
            minScore = max;
        }
        rightCount++;
    }
    totalCount++;
}

if (totalCount == 0)
{
    // Avoid printing NaN from 0/0 when there is nothing to evaluate.
    Console.WriteLine("正确率:N/A(没有测试样本)");
}
else
{
    // Same float value (and thus the same printed text) as the original computation.
    Console.WriteLine("正确率:" + (float)rightCount / totalCount);
}
Console.ReadKey();