我一直在参照 ML.NET 教程中的这个示例: https://github.com/dotnet/samples/tree/master/machine-learning/tutorials/GitHubIssueClassification
我构建了这个例子的自己的版本,它从.xlsx文件(不同的数据集)中读取数据,并将其分为训练集和测试集。它运行良好并能做出正确的预测,但我无论如何也搞不明白为什么当我将_testSet输入时,评估指标(所有参数)总是显示0。当我输入_trainSet时,它评估为1,这是预期的结果。
即使我将TestFraction设为0.5,它仍然评估为0。
using System;
using System.Data;
using System.Data.OleDb;
using System.Collections.Generic;
using System.Linq;
using System.IO;
using Microsoft.ML;

namespace Test.Repository
{
    /// <summary>
    /// One input row: the topic (used as the label) and the subject text it is predicted from.
    /// </summary>
    public class SearchEntry
    {
        [LoadColumn(0)]
        public string Topic { get; set; }

        [LoadColumn(1)]
        public string Subject { get; set; }
    }

    /// <summary>
    /// Prediction output; <see cref="Topic"/> is bound to the "PredictedLabel" column.
    /// </summary>
    public class SearchPrediction
    {
        [ColumnName("PredictedLabel")]
        public string Topic;
    }

    /// <summary>
    /// Loads topic/subject rows from an Excel sheet, splits them into train and test sets,
    /// builds and trains a multiclass classification pipeline, and prints evaluation metrics.
    /// The methods share static state and are meant to be called in the order
    /// <see cref="LoadModelData"/>, <see cref="ProcessData"/>, <see cref="BuildAndTrainModel"/>,
    /// <see cref="Evaluate"/>.
    /// </summary>
    public class Googler
    {
        /// <summary>Directory portion of the first command-line argument.</summary>
        private static string _appPath
        {
            get { return Path.GetDirectoryName(Environment.GetCommandLineArgs()[0]); }
        }

        /// <summary>Path of the Excel workbook that is queried for the input rows.</summary>
        public string SourceExcel { get; set; } = @"..\..\..\..\Test.Repository\model\in_data.xlsx";

        /// <summary>Path intended for the saved model (not used by the methods in this class).</summary>
        public string ModelSavePath { get; set; } = @"..\..\..\..\Test.Repository\model\model";

        /// <summary>Value passed as <c>testFraction</c> when splitting the data.</summary>
        public double TestFraction { get; set; } = 0.2d;

        private static IDataView _trainingDataView;
        private static MLContext _mlContext;
        private static ITransformer _trainedModel;
        private static IEstimator<ITransformer> pipeline;
        private static PredictionEngine<SearchEntry, SearchPrediction> _predEngine;
        private static List<SearchEntry> _trainSet;
        private static List<SearchEntry> _testSet;

        /// <summary>
        /// Creates the ML context, reads every row of the <c>data</c> sheet, and splits the rows
        /// into <c>_trainSet</c> / <c>_testSet</c>; the training rows are also exposed as a data view.
        /// </summary>
        public void LoadModelData()
        {
            _mlContext = new MLContext(seed: 0);

            var table = Heplers.Excel.Query(SourceExcel, "SELECT * FROM [data$]");

            // Deferred projection of each table row onto a SearchEntry.
            var entries =
                from row in table.AsEnumerable()
                select new SearchEntry
                {
                    Topic = (string)row["Topic"],
                    Subject = (string)row["Subject"]
                };

            var allRows = _mlContext.Data.LoadFromEnumerable(entries);

            // NOTE(review): per the ML.NET docs, rows that share a samplingKeyColumnName value are
            // kept in the same partition - confirm this is intended for the label column.
            var partitions = _mlContext.Data.TrainTestSplit(
                allRows,
                testFraction: TestFraction,
                samplingKeyColumnName: "Topic");

            _trainSet = _mlContext.Data
                .CreateEnumerable<SearchEntry>(partitions.TrainSet, reuseRowObject: false)
                .ToList();
            _testSet = _mlContext.Data
                .CreateEnumerable<SearchEntry>(partitions.TestSet, reuseRowObject: false)
                .ToList();

            _trainingDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_trainSet);
        }

        /// <summary>
        /// Builds the data-preparation pipeline: Topic is mapped to the key column "Label",
        /// Subject is featurized, the result is copied into "Features", and a cache checkpoint is appended.
        /// </summary>
        public void ProcessData()
        {
            Console.WriteLine($"=============== Processing Data ===============");

            var labelMapper = _mlContext.Transforms.Conversion.MapValueToKey(
                inputColumnName: "Topic",
                outputColumnName: "Label");
            var subjectFeaturizer = _mlContext.Transforms.Text.FeaturizeText(
                inputColumnName: "Subject",
                outputColumnName: "SubjectFeaturized");
            var featureAssembler = _mlContext.Transforms.Concatenate("Features", "SubjectFeaturized");

            pipeline = labelMapper
                .Append(subjectFeaturizer)
                .Append(featureAssembler)
                .AppendCacheCheckpoint(_mlContext);

            Console.WriteLine($"=============== Finished Processing Data ===============");
        }

        /// <summary>
        /// Appends the SDCA non-calibrated trainer and the key-to-value mapping of "PredictedLabel"
        /// to the prepared pipeline, then fits it on the training data view.
        /// </summary>
        public void BuildAndTrainModel()
        {
            var trainer = _mlContext.MulticlassClassification.Trainers.SdcaNonCalibrated("Label", "Features");
            var predictedLabelMapper = _mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel");
            var trainingPipeline = pipeline
                .Append(trainer)
                .Append(predictedLabelMapper);

            Console.WriteLine($"=============== Training the model ===============");
            _trainedModel = trainingPipeline.Fit(_trainingDataView);
            Console.WriteLine($"=============== Finished Training the model Ending time: {DateTime.Now.ToString()} ===============");
        }

        /// <summary>
        /// Scores the held-out <c>_testSet</c> with the trained model and prints the
        /// multiclass classification metrics to the console.
        /// </summary>
        public void Evaluate()
        {
            Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Starting time: {DateTime.Now.ToString()} ===============");

            var testDataView = _mlContext.Data.LoadFromEnumerable<SearchEntry>(_testSet);
            var predictions = _trainedModel.Transform(testDataView);
            var testMetrics = _mlContext.MulticlassClassification.Evaluate(predictions);

            Console.WriteLine($"=============== Evaluating to get model's accuracy metrics - Ending time: {DateTime.Now.ToString()} ===============");

            // The same banner line opens and closes the metrics report.
            string banner = $"*************************************************************************************************************";

            Console.WriteLine(banner);
            Console.WriteLine($"* Metrics for Multi-class Classification model - Test Data ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"* MicroAccuracy: {testMetrics.MicroAccuracy:0.###}");
            Console.WriteLine($"* MacroAccuracy: {testMetrics.MacroAccuracy:0.###}");
            Console.WriteLine($"* LogLoss: {testMetrics.LogLoss:#.###}");
            Console.WriteLine($"* LogLossReduction: {testMetrics.LogLossReduction:#.###}");
            Console.WriteLine(banner);
        }
    }
}
输出如下:
************************************************************************************************************** Metrics for Multi-class Classification model - Test Data *------------------------------------------------------------------------------------------------------------* MicroAccuracy: 0* MacroAccuracy: 0* LogLoss: * LogLossReduction: NaN*************************************************************************************************************
回答:
进行了以下更改:
var split = _mlContext.Data .TrainTestSplit(dataview, testFraction: TestFraction, samplingKeyColumnName: "Topic");
改为
var split = _mlContext.Data .TrainTestSplit(dataview, testFraction: TestFraction);
使用 samplingKeyColumnName: "Topic" 时,Topic 值相同的所有行会被分到同一个集合中(要么全部进入训练集,要么全部进入测试集)。结果我的测试集只有2个唯一的主题,而且这些主题在训练集中从未出现过,模型无法预测它们,因此评估指标为0。去掉该参数后,测试集中有6个唯一的主题。
但我仍然不满意结果。我总共有10个唯一的主题,感觉测试集必须至少包含每个主题的一些条目。Microsoft.ML的TrainTestSplit似乎并不能保证这一点。
我编写了一个自定义的分割器:
/// <summary>
/// Splits <paramref name="searchEntries"/> into a training set and a test set, sampling each
/// topic separately: the entries are shuffled, grouped by <c>Topic</c>, and the first
/// <c>ceil(groupSize * testFraction)</c> entries of every group go to the test set.
/// For a positive <paramref name="testFraction"/> every topic therefore contributes at least
/// one entry to the test set.
/// </summary>
/// <param name="searchEntries">All available entries.</param>
/// <param name="testFraction">Fraction of each topic's entries to place in the test set.</param>
/// <param name="seed">
/// Optional seed for the shuffle; pass a value to get a reproducible split.
/// When omitted, a default-seeded <see cref="Random"/> is used, as before.
/// </param>
/// <returns>The training entries and the test entries.</returns>
private (List<SearchEntry> TrainSet, List<SearchEntry> TestSet) TrainTestSplit(List<SearchEntry> searchEntries, double testFraction, int? seed = null)
{
    var rand = seed.HasValue ? new Random(seed.Value) : new Random();

    var testSet = searchEntries
        // Attach a random sort key to every entry, then order by it to shuffle.
        .Select(entry => new { SortKey = rand.Next(), Entry = entry })
        .OrderBy(keyed => keyed.SortKey)
        .Select(keyed => keyed.Entry)
        .GroupBy(entry => entry.Topic)
        // Each group already holds every entry of its topic, so its own count is the
        // per-topic total; no need to rescan the whole input for each group.
        // Rounding up means a topic with a single entry ends up entirely in the test set.
        .SelectMany(topicGroup => topicGroup.Take((int)Math.Ceiling(topicGroup.Count() * testFraction)))
        .ToList();

    // Everything not picked for testing is used for training.
    var trainSet = searchEntries.Except(testSet).ToList();

    return (trainSet, testSet);
}