我使用 Visual Studio 的图形化工具(ML.NET Model Builder)创建了模型,Visual Studio 为我生成了一个项目和代码。一切都正常,只是在输入时我必须以字符串形式给出图像文件的路径。能否直接把位图(Bitmap)格式的图像传给神经网络?
我找到了许多示例,但它们都与我的代码不同,好像我用的是另一个版本。我尝试修改找到的代码,但遇到了各种错误。
请解释一下现在如何用Microsoft.ML 1.5来完成这个任务?如何调整下面生成的代码以使用位图图像(而不是路径输入)?
我的ModelInput.cs
// This file was auto-generated by ML.NET Model Builder.
using Microsoft.ML.Data;
namespace MLTestAppML.Model
{
/// <summary>
/// One row of the tab-separated training file: the class label and the image location.
/// </summary>
public class ModelInput
{
    /// <summary>Class label, read from column 0 of the data file.</summary>
    [ColumnName("Label")]
    [LoadColumn(0)]
    public string Label { get; set; }

    /// <summary>
    /// Image file path as a string, read from column 1. The training pipeline loads the
    /// image bytes from this path, which is why a path (not a Bitmap) is required here.
    /// </summary>
    [ColumnName("ImageSource")]
    [LoadColumn(1)]
    public string ImageSource { get; set; }
}
}

我的ModelBuilder.cs
// This file was auto-generated by ML.NET Model Builder.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using MLTestAppML.Model;
using Microsoft.ML.Vision;
namespace MLTestAppML.ConsoleApp
{
/// <summary>
/// Trains, cross-validates and saves the image-classification model generated by ML.NET Model Builder.
/// </summary>
public static class ModelBuilder
{
    // Paths written by Model Builder; they point into the generating user's temp folder and
    // will likely need adjusting on any other machine.
    private static readonly string TRAIN_DATA_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\e43005d1-d83d-4f35-ab8d-7dbc3e693583.tsv";
    private static readonly string MODEL_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\MLVSTools\MLTestAppML\MLTestAppML.Model\MLModel.zip";

    // Create MLContext to be shared across the model creation workflow objects
    // Set a random seed for repeatable/deterministic results across multiple trainings.
    private static readonly MLContext mlContext = new MLContext(seed: 1);

    /// <summary>
    /// Loads the training TSV, builds and trains the pipeline, prints cross-validation
    /// metrics, and saves the trained model to <c>MODEL_FILEPATH</c>.
    /// </summary>
    public static void CreateModel()
    {
        // Load Data
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
            path: TRAIN_DATA_FILEPATH,
            hasHeader: true,
            separatorChar: '\t',
            allowQuoting: true,
            allowSparse: false);

        // Build training pipeline
        IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

        // Train Model
        ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

        // Evaluate quality of Model (cross-validation fits the pipeline again on each fold)
        Evaluate(mlContext, trainingDataView, trainingPipeline);

        // Save model
        SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
    }

    /// <summary>
    /// Builds the training pipeline. It key-encodes the label, reads each image file into raw
    /// bytes, and trains an image classifier on those bytes.
    /// </summary>
    /// <param name="mlContext">Context used to create the transforms and the trainer.</param>
    /// <returns>The untrained estimator chain.</returns>
    public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
    {
        // Data process configuration with pipeline data transformations.
        // imageFolder is null, so the "ImageSource" values are presumably used as full file paths.
        var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
            .Append(mlContext.Transforms.LoadRawImageBytes("ImageSource_featurized", null, "ImageSource"))
            .Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));

        // Set the training algorithm; PredictedLabel is mapped back from a key to the original label value.
        var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(new ImageClassificationTrainer.Options() { LabelColumnName = "Label", FeatureColumnName = "Features" })
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        return dataProcessPipeline.Append(trainer);
    }

    /// <summary>
    /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.
    /// </summary>
    /// <returns>The trained transformer chain.</returns>
    public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        Console.WriteLine("=============== Training model ===============");

        ITransformer model = trainingPipeline.Fit(trainingDataView);

        Console.WriteLine("=============== End of training process ===============");
        return model;
    }

    /// <summary>
    /// Runs 5-fold cross-validation on the single available dataset and prints averaged metrics.
    /// </summary>
    private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
        // in order to evaluate and get the model's accuracy metrics
        Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
        var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
        PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
    }

    /// <summary>
    /// Saves the trained model and its input schema to a .zip file.
    /// </summary>
    /// <param name="modelRelativePath">Path resolved against the assembly folder; an absolute path is used as-is.</param>
    private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
    {
        // Save/persist the trained model to a .ZIP file
        Console.WriteLine($"=============== Saving the model ===============");

        // Resolve the destination once, so the path written and the path reported cannot differ.
        string modelPath = GetAbsolutePath(modelRelativePath);
        mlContext.Model.Save(mlModel, modelInputSchema, modelPath);
        Console.WriteLine("The model is saved to {0}", modelPath);
    }

    /// <summary>
    /// Combines <paramref name="relativePath"/> with the folder containing the assembly that
    /// defines <c>Program</c>. Per <see cref="Path.Combine(string, string)"/>, a rooted
    /// <paramref name="relativePath"/> is returned unchanged.
    /// </summary>
    public static string GetAbsolutePath(string relativePath)
    {
        FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
        string assemblyFolderPath = _dataRoot.Directory.FullName;

        string fullPath = Path.Combine(assemblyFolderPath, relativePath);

        return fullPath;
    }

    /// <summary>
    /// Prints the metrics of a single multiclass evaluation, including per-class log-loss.
    /// </summary>
    public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"* Metrics for multi-class classification model ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
        {
            Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
        }
        Console.WriteLine($"************************************************************");
    }

    /// <summary>
    /// Prints the mean, sample standard deviation and 95% confidence-interval half-width of the
    /// micro/macro accuracy, log-loss and log-loss reduction across the cross-validation folds.
    /// </summary>
    public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
    {
        // Materialize once. Otherwise every Average/StdDev/CI call below re-runs the deferred Select chains.
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToList();
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToList();
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToList();
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToList();
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"* Metrics for Multi-class Classification model ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
        Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }

    /// <summary>
    /// Sample (Bessel-corrected, n - 1 denominator) standard deviation of <paramref name="values"/>.
    /// </summary>
    /// <returns>The standard deviation; NaN when exactly one value is supplied.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="values"/> is empty.</exception>
    public static double CalculateStandardDeviation(IEnumerable<double> values)
    {
        // Materialize once so a lazy sequence is not enumerated three times.
        double[] samples = values.ToArray();
        double average = samples.Average();
        double sumOfSquaresOfDifferences = samples.Select(val => (val - average) * (val - average)).Sum();
        double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (samples.Length - 1));
        return standardDeviation;
    }

    /// <summary>
    /// Half-width of the 95% confidence interval for the mean of <paramref name="values"/>,
    /// computed as 1.96 * s / sqrt(n).
    /// </summary>
    /// <returns>The half-width; NaN when exactly one value is supplied.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="values"/> is empty.</exception>
    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        double[] samples = values.ToArray();

        // The standard error of the mean is s / sqrt(n). z = 1.96 is the normal approximation;
        // with only a handful of folds a Student-t critical value would give a wider interval.
        double confidenceInterval95 = 1.96 * CalculateStandardDeviation(samples) / Math.Sqrt(samples.Length);
        return confidenceInterval95;
    }
}
}

发布于 2022-08-26 19:27:28
更新
较新版本的 Model Builder 支持直接从 byte[] 进行预测,请确保您安装的是 16.13.10.2241902 版本。
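您需要用该版本重新训练模型,之后就可以像下面这样进行预测。如果手头是 System.Drawing.Bitmap 对象,可以先把它保存到 MemoryStream(例如 bitmap.Save(ms, ImageFormat.Png)),再用 ms.ToArray() 得到所需的 byte[]。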
// Read the sample image into memory; the newer generated ModelInput takes the raw bytes instead of a path.
byte[] imageBytes = File.ReadAllBytes(@"C:\Users\ivmendoza\Documents\MIL CoPilot Workspaces\AlignersInBag\2 Condtions DS Tiny\NonTypical\MX2-RND40K_20220725020628_PIDV_12443583U22N_12443583L22N.JPG");
var sampleData = new AlignersCounterModel.ModelInput { ImageSource = imageBytes };

// Load the model and run a single prediction on the in-memory image.
var result = AlignersCounterModel.Predict(sampleData);
输入/输出类
/// <summary>
/// Path-based prediction input: an optional label plus the image file path.
/// </summary>
public class ModelInput
{
    /// <summary>Class label (not needed to obtain a prediction).</summary>
    [ColumnName("Label")]
    public string Label { get; set; }

    /// <summary>Path of the image file to classify.</summary>
    [ColumnName("ImageSource")]
    public string ImageSource { get; set; }
}
/// <summary>
/// In-memory prediction input: the encoded image bytes are bound directly to the
/// "Features" column, skipping any file-loading step.
/// </summary>
public class ModelInputBytes
{
    /// <summary>Class label (not needed to obtain a prediction).</summary>
    [ColumnName("Label")]
    public string Label { get; set; }

    /// <summary>
    /// Encoded image bytes (e.g. the contents of a JPG/PNG file or a Bitmap saved to a stream).
    /// NOTE(review): assumes the loaded model consumes "Features" directly — confirm against the saved pipeline.
    /// </summary>
    [ColumnName("Features")]
    public byte[] ImageBytes { get; set; }
}
/// <summary>
/// Prediction result: the predicted label and the raw per-class scores.
/// </summary>
public class ModelOutput
{
    /// <summary>Predicted label, read from the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")]
    public string Prediction { get; set; }

    /// <summary>Per-class scores, read from the "Score" column.</summary>
    [ColumnName("Score")]
    public float[] Score { get; set; }
}
使用模型的代码(Consumption code)
/// <summary>
/// Predicts the label for the image whose file path is given in <paramref name="input"/>.
/// </summary>
/// <param name="input">Input row; <c>ImageSource</c> holds the image file path.</param>
/// <returns>The predicted label and per-class scores.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
public static ModelOutput Predict(ModelInput input)
{
    if (input == null)
    {
        throw new ArgumentNullException(nameof(input));
    }

    MLContext mlContext = new MLContext();

    // Load model; the input schema returned by Load is not used here.
    ITransformer mlModel = mlContext.Model.Load(MLNetModelPath, out _);

    // The path has to be turned into image bytes before the loaded model runs.
    // Presumably the saved model does not contain that image-loading step itself.
    ITransformer dataPreProcessTransform = LoadImageFromFileTransformer(input, mlContext);

    // PredictionEngine is IDisposable; release it once this single prediction is done.
    using (var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(dataPreProcessTransform.Append(mlModel)))
    {
        return predEngine.Predict(input);
    }
}
/// <summary>
/// Builds and fits the preprocessing that turns an image file path into the "Features" column.
/// It key-encodes the label, reads the file at <c>ImageSource</c> into bytes, and copies those
/// bytes into "Features".
/// </summary>
/// <param name="input">Row used to fit the preprocessing; <c>ImageSource</c> holds the file path.</param>
/// <param name="mlContext">Context used to create and fit the transforms.</param>
/// <returns>The fitted preprocessing transformer.</returns>
/// <exception cref="ArgumentNullException">Either argument is null.</exception>
public static ITransformer LoadImageFromFileTransformer(ModelInput input, MLContext mlContext)
{
    if (input == null)
    {
        throw new ArgumentNullException(nameof(input));
    }
    if (mlContext == null)
    {
        throw new ArgumentNullException(nameof(mlContext));
    }

    // The second positional parameter of LoadRawImageBytes is imageFolder, not inputColumnName.
    // Passing "ImageSource" positionally made it the folder and left the input column defaulting
    // to the (non-existent) output column. Named arguments make the intent explicit and match
    // the training pipeline: no folder, read the path from "ImageSource".
    var dataPreProcess = mlContext.Transforms.Conversion.MapValueToKey(@"Label", @"Label")
        .Append(mlContext.Transforms.LoadRawImageBytes(
            outputColumnName: @"ImageSource_featurized",
            imageFolder: null,
            inputColumnName: @"ImageSource"))
        .Append(mlContext.Transforms.CopyColumns(@"Features", @"ImageSource_featurized"));

    var dataView = mlContext.Data.LoadFromEnumerable(new[] { input });
    return dataPreProcess.Fit(dataView);
}
/// <summary>
/// Predicts the label for an image supplied as in-memory bytes.
/// </summary>
/// <param name="input">Input row; <c>ImageBytes</c> is bound to the "Features" column.</param>
/// <returns>The predicted label and per-class scores.</returns>
/// <exception cref="ArgumentNullException"><paramref name="input"/> is null.</exception>
/// <remarks>
/// No image-loading transform is applied, so the bytes go straight to the model.
/// NOTE(review): assumes the model at MLNetModelPath does not itself require an "ImageSource" path column — confirm.
/// </remarks>
public static ModelOutput PredictFromBytes(ModelInputBytes input)
{
    if (input == null)
    {
        throw new ArgumentNullException(nameof(input));
    }

    MLContext mlContext = new MLContext();

    // Load model; the input schema returned by Load is not used here.
    ITransformer mlModel = mlContext.Model.Load(MLNetModelPath, out _);

    // PredictionEngine is IDisposable; release it once this single prediction is done.
    using (var predEngine = mlContext.Model.CreatePredictionEngine<ModelInputBytes, ModelOutput>(mlModel))
    {
        return predEngine.Predict(input);
    }
}
https://stackoverflow.com/questions/69159459
复制相似问题