PermutationFeatureImportance返回的功能比模型提供的功能更多

时间:2019-06-21 06:53:18

标签: c# .net machine-learning ml.net

我有一个具有35个功能的模型。将字符串值转换为向量后,我最终得到了具有54个特征的转换模型。

训练后,我想评估特征权重,但是最终我得到了比可用特征多得多的特征(104次,所以两次)。找到相应要素列的结果是“索引超出范围”。

PFI

var featureImportanceMetrics =
                permutationFeatureImportance
                    .Select((metric, index) => new { index, metric.RSquared })
                    .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

任何人都可以告诉我我做错了什么,还是我刚刚发现ML.Net库中的缺陷?

(请注意,我注释了一些算法,因为我想使用BinaryClassification进行播放,但无法使其正常运行且没有错误)

代码

 public class Employee
        {
            [LoadColumn(0)]
            public float Age { get; set; }
            [LoadColumn(1)]
            [ColumnName("Label")]
            public float Attrition { get; set; }
            [LoadColumn(2)]
            public string BusinessTravel { get; set; }
            [LoadColumn(3)]
            public float DailyRate { get; set; }
            [LoadColumn(4)]
            public string Department { get; set; }
            [LoadColumn(5)]
            public float DistanceFromHome { get; set; }
            [LoadColumn(6)]
            public float Education { get; set; }
            [LoadColumn(7)]
            public string EducationField { get; set; }
            [LoadColumn(8)]
            public float EmployeeCount { get; set; }
            [LoadColumn(9)]
            public float EmployeeNumber { get; set; }
            [LoadColumn(10)]
            public float EnvironmentSatisfaction { get; set; }
            [LoadColumn(11)]
            public string Gender { get; set; }
            [LoadColumn(12)]
            public float HourlyRate { get; set; }
            [LoadColumn(13)]
            public float JobInvolvement { get; set; }
            [LoadColumn(14)]
            public float JobLevel { get; set; }
            [LoadColumn(15)]
            public string JobRole { get; set; }
            [LoadColumn(16)]
            public float JobSatisfaction { get; set; }
            [LoadColumn(17)]
            public string MaritalStatus { get; set; }
            [LoadColumn(18)]
            public float MonthlyIncome { get; set; }
            [LoadColumn(19)]
            public float MonthlyRate { get; set; }
            [LoadColumn(20)]
            public float NumCompaniesWorked { get; set; }
            [LoadColumn(21)]
            public string Over18 { get; set; }
            [LoadColumn(22)]
            public string OverTime { get; set; }
            [LoadColumn(23)]
            public float PercentSalaryHike { get; set; }
            [LoadColumn(24)]
            public float PerformanceRating { get; set; }
            [LoadColumn(25)]
            public float RelationshipSatisfaction { get; set; }
            [LoadColumn(26)]
            public float StandardHours { get; set; }
            [LoadColumn(27)]
            public float StockOptionLevel { get; set; }
            [LoadColumn(28)]
            public float TotalWorkingYears { get; set; }
            [LoadColumn(29)]
            public float TrainingTimesLastYear { get; set; }
            [LoadColumn(30)]
            public float WorkLifeBalance { get; set; }
            [LoadColumn(31)]
            public float YearsAtCompany { get; set; }
            [LoadColumn(32)]
            public float YearsInCurrentRole { get; set; }
            [LoadColumn(33)]
            public float YearsSinceLastPromotion { get; set; }
            [LoadColumn(34)]
            public float YearsWithCurrManager { get; set; }
        }
        public class EmployeeTransformed
        {
            public float Age { get; set; }
            [ColumnName("Label")]
            public float Attrition { get; set; }
            public float[] BusinessTravel { get; set; }
            public float DailyRate { get; set; }
            public float[] Department { get; set; }
            public float DistanceFromHome { get; set; }
            public float Education { get; set; }
            public float[] EducationField { get; set; }
            public float EmployeeCount { get; set; }
            public float EmployeeNumber { get; set; }
            public float EnvironmentSatisfaction { get; set; }
            public float[] Gender { get; set; }
            public float HourlyRate { get; set; }
            public float JobInvolvement { get; set; }
            public float JobLevel { get; set; }
            public float[] JobRole { get; set; }
            public float JobSatisfaction { get; set; }
            public float[] MaritalStatus { get; set; }
            public float MonthlyIncome { get; set; }
            public float MonthlyRate { get; set; }
            public float NumCompaniesWorked { get; set; }
            public float[] Over18 { get; set; }
            public float[] OverTime { get; set; }
            public float PercentSalaryHike { get; set; }
            public float PerformanceRating { get; set; }
            public float RelationshipSatisfaction { get; set; }
            public float StandardHours { get; set; }
            public float StockOptionLevel { get; set; }
            public float TotalWorkingYears { get; set; }
            public float TrainingTimesLastYear { get; set; }
            public float WorkLifeBalance { get; set; }
            public float YearsAtCompany { get; set; }
            public float YearsInCurrentRole { get; set; }
            public float YearsSinceLastPromotion { get; set; }
            public float YearsWithCurrManager { get; set; }

        }

        public ActionResult Turnover()
        {
            MLContext mlContext = new MLContext();

            var _appPath = AppDomain.CurrentDomain.BaseDirectory;
            //var _dataPath = Path.Combine(_appPath, "Datasets", "WA_Fn-UseC_-HR-Employee-Attrition.csv");
            var _dataPath = Path.Combine(_appPath, "Datasets", "attrition_small_dataset.csv");

            // Load data from file
            IDataView dataView = mlContext.Data.LoadFromTextFile<Employee>(_dataPath, separatorChar: ',', hasHeader: true);
            var a = mlContext.Data.CreateEnumerable<Employee>(dataView, true).ToList();

            // Define categorical transform estimator
            var categoricalEstimator = mlContext.Transforms.Categorical.OneHotEncoding("BusinessTravel")
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("Department"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("EducationField"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("Gender"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("JobRole"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("MaritalStatus"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("Over18"))
            .Append(mlContext.Transforms.Categorical.OneHotEncoding("OverTime"));
            IDataView transformedData = categoricalEstimator.Fit(dataView).Transform(dataView);

            // Split into train and test dataset
            DataOperationsCatalog.TrainTestData dataSplit = mlContext.Data.TrainTestSplit(transformedData, testFraction: 0.2);
            IDataView trainData = dataSplit.TrainSet;
            IDataView testData = dataSplit.TestSet;

            // Get the column names of input features.
            string[] featureColumnNames =
                trainData.Schema
                    .Select(column => column.Name)
                    .Where(columnName => columnName != "Label").ToArray();

            // Define estimator with data pre-processing steps
            IEstimator<ITransformer> dataPrepEstimator =
                mlContext.Transforms.Concatenate("Features", featureColumnNames)
                    .Append(mlContext.Transforms.NormalizeMinMax("Features"));

            IDataView preprocessedTrainData = dataPrepEstimator.Fit(trainData).Transform(trainData);
            var e = mlContext.Data.CreateEnumerable<EmployeeTransformed>(preprocessedTrainData, true).ToList();

            /*
            //  Define Stochastic Dual Coordinate Ascent machine learning estimator
            //var sdcaEstimator = mlContext.Regression.Trainers.Sdca(labelColumnName: "Age", featureColumnName: "Features");
            //var sdcaEstimator = mlContext.Regression.Trainers.Sdca(labelColumnName: "Attrition", maximumNumberOfIterations: 100);
            //var sdcaEstimator = mlContext.BinaryClassification.Trainers.FastTree(labelColumnName : "Attrition", featureColumnName : "Features", numberOfLeaves: 50, numberOfTrees: 50, minimumExampleCountPerLeaf: 20);
            */
            var sdcaEstimator = mlContext.Regression.Trainers.Sdca();

            // Train machine learning model
            var sdcaModel = sdcaEstimator.Fit(preprocessedTrainData);

            // Explain the model with Permutation Feature Importance (PFI)
            ImmutableArray<RegressionMetricsStatistics> permutationFeatureImportance =
            mlContext
                .Regression
                .PermutationFeatureImportance(sdcaModel, preprocessedTrainData, permutationCount: 3);

            // Order features by importance
            var featureImportanceMetrics =
                permutationFeatureImportance
                    .Select((metric, index) => new { index, metric.RSquared })
                    .OrderByDescending(myFeatures => Math.Abs(myFeatures.RSquared.Mean));

            var line = "Feature\tPFI <br>";

            var z = featureColumnNames;
            foreach (var feature in featureImportanceMetrics)
            {
                line += $"{featureColumnNames[feature.index],-20}|\t{feature.RSquared.Mean:F6} <br>";
            }

            return Content(line);
        }

1 个答案:

答案 0 :(得分:0)

我解决了我的问题,希望这对其他人也有帮助。 ML.net文档主要侧重于快乐的流程,对框架的各种更改使各种代码示例不再更新。

转换内容后,我使用转换后的架构而不是原始架构。我必须过滤掉标签,功能列。对于每个字符串属性,ML.net都会将其转换为Int32和一个Vector(float []),但具有属性“ IsHidden”,因此您也可以过滤掉它们,并留下所需的属性。

var categoricalEstimator = mlContext.Transforms.Categorical.OneHotEncoding("BusinessTravel");
dataView = categoricalEstimator.Fit(dataView).Transform(dataView);

DataOperationsCatalog.TrainTestData dataSplit = mlContext.Data.TrainTestSplit(dataView, testFraction: 0.2);
IDataView trainData = dataSplit.TrainSet;
IDataView testData = dataSplit.TestSet;

var pipeline = mlContext.Transforms.Concatenate("Features", featureColumns)
.Append(mlContext.Transforms.NormalizeMinMax("Features"))                    .Append(mlContext.BinaryClassification.Trainers.SdcaLogisticRegression());

var model = pipeline.Fit(trainData);
var transformedData = model.Transform(trainData);
var linearPredictor = model.LastTransformer;

var permutationMetrics = mlContext.BinaryClassification.PermutationFeatureImportance(linearPredictor, transformedData, permutationCount: 30);

var sortedIndices = permutationMetrics.Select((metrics, index) => new { index, metrics.AreaUnderRocCurve })
                .OrderByDescending(feature => Math.Abs(feature.AreaUnderRocCurve.Mean))
                .Select(feature => feature.index);


            var sb = new System.Text.StringBuilder();

            // Calculate metrics of the model on the test data.
            var trainedModelMetrics = mlContext.BinaryClassification.Evaluate(model.Transform(testData), labelColumnName: "Label");

            sb.Append("<h1>Binary Classification Model, Predicting Employee Turnover</h1>");
            sb.Append(String.Format("<h3>Accuracy:{0}</h3>",trainedModelMetrics.Accuracy));
            sb.Append(String.Format("<h3>F1Score:{0}</h3>", trainedModelMetrics.F1Score));

            sb.Append("<table border=1><thead><tr><th>Feature</th><th>Model Weight</th><th>Change in AUC</th><th>95% Confidence in the Mean Change in AUC</th></tr></thead><tbody>");
            var auc = permutationMetrics.Select(x => x.AreaUnderRocCurve).ToArray();

            foreach (int i in sortedIndices)
            {
                if (transformedData.Schema[i].IsHidden || transformedData.Schema[i].Name == "Label" || transformedData.Schema[i].Name == "Features")
                {
                    continue;
                }

                var s = String.Format("<tr><td>{0}</td><td>{1:0.00}</td><td>{2:G4}</td><td>{3:G4}</td></tr>",
                    transformedData.Schema[i].Name,
                    linearPredictor.Model.SubModel.Weights[i],
                    auc[i].Mean,
                    1.96 * auc[i].StandardError);
                sb.Append(s);
            }
            sb.Append("</tbody></table>");



            return Content(sb.ToString());