How to calculate the point between two overlapping linear data sets?

Time: 2014-08-06 21:03:30

Tags: arrays algorithm sorting machine-learning classification

I have two overlapping sets of data (see plot below). I need to find the point between these sets at which one would guess an unknown data point belongs to one particular category.

If I had a new data point (say 5000) and had to bet $$$ on whether it belongs to Group A or Group B, how do I calculate the point that makes my bet most certain?

See the sample data sets below, along with the attached plot showing the approximate point between these groups (calculated by eye).

GROUP A
[385,515,975,1136,2394,2436,4051,4399,4484,4768,4768,4849,4856,4954,5020,5020,5020,5020,5020,5020,5020,5020,5020,5052,5163,5200,5271,5421,5421,5442,5746,5765,5903,5992,5992,6046,6122,6205,6208,6239,6310,6360,6416,6512,6536,6543,6581,6609,6696,6699,6752,6796,6806,6855,6859,6886,6906,6911,6923,6953,7016,7072,7086,7089,7110,7232,7278,7293,7304,7309,7348,7367,7378,7380,7419,7453,7454,7492,7506,7549,7563,7721,7723,7731,7745,7750,7751,7783,7791,7813,7813,7814,7818,7833,7863,7875,7886,7887,7902,7907,7935,7942,7942,7948,7973,7995,8002,8013,8013,8015,8024,8025,8030,8038,8041,8050,8056,8060,8064,8071,8081,8082,8085,8093,8124,8139,8142,8167,8179,8204,8214,8223,8225,8247,8248,8253,8258,8264,8265,8265,8269,8277,8278,8289,8300,8312,8314,8323,8328,8334,8363,8369,8390,8397,8399,8399,8401,8436,8442,8456,8457,8471,8474,8483,8503,8511,8516,8533,8560,8571,8575,8583,8592,8593,8626,8635,8635,8644,8659,8685,8695,8695,8702,8714,8715,8717,8729,8732,8740,8743,8750,8756,8772,8772,8778,8797,8828,8840,8840,8843,8856,8865,8874,8876,8878,8885,8887,8893,8896,8905,8910,8955,8970,8971,8991,8995,9014,9016,9042,9043,9063,9069,9104,9106,9107,9116,9131,9157,9227,9359,9471]

GROUP B
[12,16,29,32,33,35,39,42,44,44,44,45,45,45,45,45,45,45,45,45,47,51,51,51,57,57,60,61,61,62,71,75,75,75,75,75,75,76,76,76,76,76,76,79,84,84,85,89,93,93,95,96,97,98,100,100,100,100,100,102,102,103,105,108,109,109,109,109,109,109,109,109,109,109,109,109,110,110,112,113,114,114,116,116,118,119,120,121,122,124,125,128,129,130,131,132,133,133,137,138,144,144,146,146,146,148,149,149,150,150,150,151,153,155,157,159,164,164,164,167,169,170,171,171,171,171,173,174,175,176,176,177,178,179,180,181,181,183,184,185,187,191,193,199,203,203,205,205,206,212,213,214,214,219,224,224,224,225,225,226,227,227,228,231,234,234,235,237,240,244,245,245,246,246,246,248,249,250,250,251,255,255,257,264,264,267,270,271,271,281,282,286,286,291,291,292,292,294,295,299,301,302,304,304,304,304,304,306,308,314,318,329,340,344,345,356,359,363,368,368,371,375,379,386,389,390,392,394,408,418,438,440,456,456,458,460,461,467,491,503,505,508,524,557,558,568,591,609,622,656,665,668,687,705,728,817,839,965,1013,1093,1126,1512,1935,2159,2384,2424,2426,2484,2738,2746,2751,3006,3184,3184,3184,3184,3184,4023,5842,5842,6502,7443,7781,8132,8237,8501]

Data Plot

Array Stats:

                      Group A    Group B
Total Numbers             231        286
Mean                  7534.71     575.56
Standard Deviation    1595.04    1316.03

5 Answers:

Answer 0 (Score: 4)

This can be viewed as a binary classification problem with a single continuous predictor. You can think of it as fitting a one-split decision tree: finding the threshold t such that you predict Group A for values >= t.

To do this, you choose the t that minimizes the entropy of the resulting splits. Suppose for some t you have the following counts:

| | <t | >= t | | Group A | X | Y | | Group B | Z | W |

&lt;的熵。分裂是 - (X /(X + Z))* log(X /(X + Z)) - (Z /(X + Z))* log(Z /(X + Z))。 &gt; = split的熵是 - (Y /(Y + W))* log(Y /(Y + W)) - (W /(Y + W))* log(W /(Y + W) )。这看起来比它更脏;它只是分组中每个组的比例p的-p * log(p)之和。

You take the weighted average of the two entropies, weighted by the overall size of each split. So the first term is weighted by (X+Z)/(X+Y+Z+W) and the other by (Y+W)/(X+Y+Z+W).
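
As a small sketch of that search (my own illustration, not from the answer), assuming every distinct data value is tried as a candidate t:

import math

def side_entropy(x, z):
    # Entropy of one side of the split, where x and z are the counts
    # of Group A and Group B that land on that side.
    n = x + z
    h = 0.0
    for c in (x, z):
        if 0 < c < n:
            h -= (c / n) * math.log(c / n)
    return h

def best_threshold(a, b):
    # Try each distinct data value as threshold t and keep the one
    # minimizing the size-weighted average entropy of the two sides.
    total = len(a) + len(b)
    best_t, best_h = None, float('inf')
    for t in sorted(set(a) | set(b)):
        xa = sum(1 for v in a if v < t)  # Group A count below t (X)
        xb = sum(1 for v in b if v < t)  # Group B count below t (Z)
        lo, hi = xa + xb, total - xa - xb
        h = (lo / total) * side_entropy(xa, xb) \
            + (hi / total) * side_entropy(len(a) - xa, len(b) - xb)
        if h < best_h:
            best_t, best_h = t, h
    return best_t

# e.g. best_threshold(group_a, group_b) with the two arrays above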

Answer 1 (Score: 4)

I just want to point out another approach, using density estimation.

Given your data, it is easy to fit a smoothed pdf using kernel density estimation. The Python code below shows how to use the kde module in scipy.

from scipy.stats.kde import gaussian_kde
from numpy import linspace
import matplotlib.pyplot as plt
data1 = [385,515,975,1136,2394,2436,4051,4399,4484,4768,4768,4849,4856,4954,5020,5020,5020,5020,5020,5020,5020,5020,5020,5052,5163,5200,5271,5421,5421,5442,5746,5765,5903,5992,5992,6046,6122,6205,6208,6239,6310,6360,6416,6512,6536,6543,6581,6609,6696,6699,6752,6796,6806,6855,6859,6886,6906,6911,6923,6953,7016,7072,7086,7089,7110,7232,7278,7293,7304,7309,7348,7367,7378,7380,7419,7453,7454,7492,7506,7549,7563,7721,7723,7731,7745,7750,7751,7783,7791,7813,7813,7814,7818,7833,7863,7875,7886,7887,7902,7907,7935,7942,7942,7948,7973,7995,8002,8013,8013,8015,8024,8025,8030,8038,8041,8050,8056,8060,8064,8071,8081,8082,8085,8093,8124,8139,8142,8167,8179,8204,8214,8223,8225,8247,8248,8253,8258,8264,8265,8265,8269,8277,8278,8289,8300,8312,8314,8323,8328,8334,8363,8369,8390,8397,8399,8399,8401,8436,8442,8456,8457,8471,8474,8483,8503,8511,8516,8533,8560,8571,8575,8583,8592,8593,8626,8635,8635,8644,8659,8685,8695,8695,8702,8714,8715,8717,8729,8732,8740,8743,8750,8756,8772,8772,8778,8797,8828,8840,8840,8843,8856,8865,8874,8876,8878,8885,8887,8893,8896,8905,8910,8955,8970,8971,8991,8995,9014,9016,9042,9043,9063,9069,9104,9106,9107,9116,9131,9157,9227,9359,9471]
data2 = [12,16,29,32,33,35,39,42,44,44,44,45,45,45,45,45,45,45,45,45,47,51,51,51,57,57,60,61,61,62,71,75,75,75,75,75,75,76,76,76,76,76,76,79,84,84,85,89,93,93,95,96,97,98,100,100,100,100,100,102,102,103,105,108,109,109,109,109,109,109,109,109,109,109,109,109,110,110,112,113,114,114,116,116,118,119,120,121,122,124,125,128,129,130,131,132,133,133,137,138,144,144,146,146,146,148,149,149,150,150,150,151,153,155,157,159,164,164,164,167,169,170,171,171,171,171,173,174,175,176,176,177,178,179,180,181,181,183,184,185,187,191,193,199,203,203,205,205,206,212,213,214,214,219,224,224,224,225,225,226,227,227,228,231,234,234,235,237,240,244,245,245,246,246,246,248,249,250,250,251,255,255,257,264,264,267,270,271,271,281,282,286,286,291,291,292,292,294,295,299,301,302,304,304,304,304,304,306,308,314,318,329,340,344,345,356,359,363,368,368,371,375,379,386,389,390,392,394,408,418,438,440,456,456,458,460,461,467,491,503,505,508,524,557,558,568,591,609,622,656,665,668,687,705,728,817,839,965,1013,1093,1126,1512,1935,2159,2384,2424,2426,2484,2738,2746,2751,3006,3184,3184,3184,3184,3184,4023,5842,5842,6502,7443,7781,8132,8237,8501]

pdf1 = gaussian_kde(data1)
pdf2 = gaussian_kde(data2)

x = linspace(0, 9500, 1000)
plt.plot(x, pdf1(x),'r')
plt.plot(x, pdf2(x),'g')
plt.legend(['data1 pdf', 'data2 pdf'])

plt.show()

[Plot: red = data1 pdf, green = data2 pdf]

In the plot, green is the pdf of the second data set and red is the pdf of the first. Clearly, the decision boundary is the vertical line through the point where the green curve intersects the red one.

To find the boundary numerically, we can do something like the following (assuming there is only one intersection point, otherwise this doesn't make sense):

min_diff = 10000
min_diff_x = -1
for x in linspace(3600, 4000, 400):
    # track the x where the two densities are closest, i.e. where they cross
    diff = abs(pdf1(x)[0] - pdf2(x)[0])
    if diff < min_diff:
        min_diff = diff
        min_diff_x = x
print(min_diff, min_diff_x)

We find that the boundary is located at approximately 3762.

If the two pdfs have more than one intersection point, then to predict which class a data point x belongs to, we compute pdf1(x) and pdf2(x); the class whose pdf gives the maximum value minimizes the Bayes risk. See here for more details on the topic of Bayes risk and on evaluating the probability of prediction error.

The example below actually involves three pdfs; at any query point x, we should query the three pdfs separately and pick the one with the maximum value of pdf(x) as the predicted class. A minimal sketch of that rule follows the figure.

[Plot: three overlapping pdfs]
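
Here is a minimal sketch of that decision rule (my own illustration, not from the answer), assuming pdfs is a list of fitted gaussian_kde objects, one per class:

import numpy as np

def predict(x, pdfs):
    # Evaluate each class's kernel density at x and return the index
    # of the class with the highest density.
    return int(np.argmax([pdf(x)[0] for pdf in pdfs]))

# e.g. predict(5000, [pdf1, pdf2]) returns 0 if data1's density
# is higher at 5000, else 1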

Answer 2 (Score: 4)

Under reasonable assumptions, a good discriminant is the unique data value at which the area of B's probability density to the left of the split point equals the area of A's density to the right of it (or vice versa, which gives the same point).

A simple way to find this is to compute the two empirical cumulative distribution functions (CDFs), as shown here, and search them for the split point. It is the point where the two CDFs sum to 1.

Briefly, building an empirical CDF is just sorting each data set and using the data as the x-axis values. Drawing the curve from left to right, it starts at y = 0 and takes a 1/n step upward at each x value. Such a curve rises from y = 0 for x < data[1] to y = CDF(x) = 1 for x >= data[n]. There is a slightly more refined method that gives a continuous, piecewise-linear curve rather than stair steps, which under some assumptions is a better estimator of the true CDF.

Note that the discussion above is only to build intuition. The CDF is represented perfectly by its sorted data array. No new data structure is needed; that is, x[i], i = 1, 2, ..., n is the x value at which the curve reaches y = i/n.

With the two CDFs, R(x) and B(x), following your plot, you want to find the unique point x such that |1 - R(x) - B(x)| is minimized (with piecewise-linear CDFs you will always be able to make it zero). This can be done quite easily with a bisection search.
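
Before the full Java implementation below, here is a minimal Python sketch of that bisection (my own illustration, using simple step-function CDFs and ignoring the vertical-step complication the rest of this answer handles):

import bisect

def ecdf(sorted_data, x):
    # Empirical CDF: fraction of points <= x
    return bisect.bisect_right(sorted_data, x) / len(sorted_data)

def find_split(a, b, iters=100):
    # Bisect for the x where the two empirical CDFs sum to 1, i.e.
    # the mass of a below x equals the mass of b above x. The sum is
    # monotone in x, so bisection converges to the crossing point.
    a, b = sorted(a), sorted(b)
    lo, hi = min(a[0], b[0]), max(a[-1], b[-1])
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if ecdf(a, mid) + ecdf(b, mid) > 1.0:
            hi = mid  # too much combined mass below mid; move left
        else:
            lo = mid
    return 0.5 * (lo + hi)

# e.g. find_split(group_a, group_b) with the two arrays above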

A nice thing about this approach is that you can make it dynamic by maintaining the two CDFs in sorted sets (balanced binary search trees). As points are added, it is easy to find the new dividing point.

The sorted sets need "order statistics". Here is a reference. By that I mean you need to be able to query the sorted set to retrieve the ordinal of any stored x value in the CDF. This can be done with skip lists as well as trees.

I coded up a variant of this algorithm. It uses the piecewise-linear CDF approximation, but also allows for "vertical steps" at repeated data points. This complicates the algorithm somewhat, but it's not too bad. Then I used bisection (rather than a combinatorial binary search) to find the split point. The normal bisection algorithm needs modification to accommodate the vertical "steps" in the CDF. I think I have all this right, but it's lightly tested.

One edge case that is handled is when the data sets have disjoint ranges. This will find a point between the top of the lower set and the bottom of the higher one, which is a perfectly valid discriminator. But you may want to do something fancier, like return some kind of weighted average.

Note that if you have a good idea of the actual minimum and maximum values the data can achieve, and they don't occur in the data, you should consider adding them so that the CDFs are not inadvertently biased.

On your example data, the code produces 4184.76, which looks pretty close to the value you picked in your chart (somewhat less than halfway between the minimum and maximum data).

Note that I did not sort the data, because it already was sorted. Sorting is definitely necessary.

public class SplitData {

    // Return: i such that a[i] <= x < a[i+1] if i,i+1 in range
    // else -1 if x < a[0]
    // else a.length if x >= a[a.length - 1]
    static int hi_bracket(double[] a, double x) {
        if (x < a[0]) return -1;
        if (x >= a[a.length - 1]) return a.length;
        int lo = 0, hi = a.length - 1;
        while (lo + 1 < hi) {
            int mid = (lo + hi) / 2;
            if (x < a[mid])
                hi = mid;
            else 
                lo = mid;
        }
        return lo;
    }

    // Return: i such that a[i-1] < x <= a[i] if i-1,i in range
    // else -1 if x <= a[0]
    // else a.length if x > a[a.length - 1]
    static int lo_bracket(double[] a, double x) {
        if (x <= a[0]) return -1;
        if (x > a[a.length - 1]) return a.length;
        int lo = 0, hi = a.length - 1;
        while (lo + 1 < hi) {
            int mid = (lo + hi) / 2;
            if (x <= a[mid])
                hi = mid;
            else
                lo = mid;
        }
        return hi;
    }

    // Interpolate the CDF value for the data a at value x.  Returns a range.
    static void interpolate_cdf(double[] a, double x, double[] rtn) {
        int lo_i1 = lo_bracket(a, x);
        if (lo_i1 == -1) {
            rtn[0] = rtn[1] = 0;
            return;
        }
        int hi_i0 = hi_bracket(a, x);
        if (hi_i0 == a.length) {
            rtn[0] = rtn[1] = 1;
            return;
        }
        if (hi_i0 + 1 == lo_i1) {  // normal interpolation
            rtn[0] = rtn[1] 
                = (hi_i0 + (x - a[hi_i0]) / (a[lo_i1] - a[hi_i0])) 
                    / (a.length - 1);
            return;
        }
        // we're on a joint or step; return range answer
        rtn[0] = (double)lo_i1 / (a.length - 1);  
        rtn[1] = (double)hi_i0 / (a.length - 1);
        assert rtn[0] <= rtn[1];
    }

    // Find the data value where the two given data set's empirical CDFs
    // sum to 1. This is a good discrimination value for new data.
    // This deals with the case where there's a step in either or both CDFs.
    static double find_bisector(double[] a, double[] b) {
        assert a.length > 0;
        assert b.length > 0;
        double lo = Math.min(a[0], b[0]);
        double hi = Math.max(a[a.length - 1], b[b.length - 1]);
        double eps = (hi - lo) * 1e-7;
        double[] a_rtn = new double[2], b_rtn = new double[2];
        while (hi - lo > eps) {
            double mid = 0.5 * (lo + hi);
            interpolate_cdf(a, mid, a_rtn);
            interpolate_cdf(b, mid, b_rtn);
            if (1 < a_rtn[0] + b_rtn[0])
                hi = mid;
            else if (a_rtn[1] + b_rtn[1] < 1)
                lo = mid;
            else
                return mid;  // 1 is included in the interpolated range
        }
        return 0.5 * (lo + hi);
    }

    public static void main(String[] args) {
        double split = find_bisector(a, b);
        System.err.println("Split at x = " + split);
    }

    static final double[] a = {
        385, 515, 975, 1136, 2394, 2436, 4051, 4399, 4484, 4768, 4768, 4849,
        4856, 4954, 5020, 5020, 5020, 5020, 5020, 5020, 5020, 5020, 5020, 5052,
        5163, 5200, 5271, 5421, 5421, 5442, 5746, 5765, 5903, 5992, 5992, 6046,
        6122, 6205, 6208, 6239, 6310, 6360, 6416, 6512, 6536, 6543, 6581, 6609,
        6696, 6699, 6752, 6796, 6806, 6855, 6859, 6886, 6906, 6911, 6923, 6953,
        7016, 7072, 7086, 7089, 7110, 7232, 7278, 7293, 7304, 7309, 7348, 7367,
        7378, 7380, 7419, 7453, 7454, 7492, 7506, 7549, 7563, 7721, 7723, 7731,
        7745, 7750, 7751, 7783, 7791, 7813, 7813, 7814, 7818, 7833, 7863, 7875,
        7886, 7887, 7902, 7907, 7935, 7942, 7942, 7948, 7973, 7995, 8002, 8013,
        8013, 8015, 8024, 8025, 8030, 8038, 8041, 8050, 8056, 8060, 8064, 8071,
        8081, 8082, 8085, 8093, 8124, 8139, 8142, 8167, 8179, 8204, 8214, 8223,
        8225, 8247, 8248, 8253, 8258, 8264, 8265, 8265, 8269, 8277, 8278, 8289,
        8300, 8312, 8314, 8323, 8328, 8334, 8363, 8369, 8390, 8397, 8399, 8399,
        8401, 8436, 8442, 8456, 8457, 8471, 8474, 8483, 8503, 8511, 8516, 8533,
        8560, 8571, 8575, 8583, 8592, 8593, 8626, 8635, 8635, 8644, 8659, 8685,
        8695, 8695, 8702, 8714, 8715, 8717, 8729, 8732, 8740, 8743, 8750, 8756,
        8772, 8772, 8778, 8797, 8828, 8840, 8840, 8843, 8856, 8865, 8874, 8876,
        8878, 8885, 8887, 8893, 8896, 8905, 8910, 8955, 8970, 8971, 8991, 8995,
        9014, 9016, 9042, 9043, 9063, 9069, 9104, 9106, 9107, 9116, 9131, 9157,
        9227, 9359, 9471
    };
    static final double[] b = {
        12, 16, 29, 32, 33, 35, 39, 42, 44, 44, 44, 45, 45, 45, 45, 45, 45, 45,
        45, 45, 47, 51, 51, 51, 57, 57, 60, 61, 61, 62, 71, 75, 75, 75, 75, 75,
        75, 76, 76, 76, 76, 76, 76, 79, 84, 84, 85, 89, 93, 93, 95, 96, 97, 98,
        100, 100, 100, 100, 100, 102, 102, 103, 105, 108, 109, 109, 109, 109,
        109, 109, 109, 109, 109, 109, 109, 109, 110, 110, 112, 113, 114, 114,
        116, 116, 118, 119, 120, 121, 122, 124, 125, 128, 129, 130, 131, 132,
        133, 133, 137, 138, 144, 144, 146, 146, 146, 148, 149, 149, 150, 150,
        150, 151, 153, 155, 157, 159, 164, 164, 164, 167, 169, 170, 171, 171,
        171, 171, 173, 174, 175, 176, 176, 177, 178, 179, 180, 181, 181, 183,
        184, 185, 187, 191, 193, 199, 203, 203, 205, 205, 206, 212, 213, 214,
        214, 219, 224, 224, 224, 225, 225, 226, 227, 227, 228, 231, 234, 234,
        235, 237, 240, 244, 245, 245, 246, 246, 246, 248, 249, 250, 250, 251,
        255, 255, 257, 264, 264, 267, 270, 271, 271, 281, 282, 286, 286, 291,
        291, 292, 292, 294, 295, 299, 301, 302, 304, 304, 304, 304, 304, 306,
        308, 314, 318, 329, 340, 344, 345, 356, 359, 363, 368, 368, 371, 375,
        379, 386, 389, 390, 392, 394, 408, 418, 438, 440, 456, 456, 458, 460,
        461, 467, 491, 503, 505, 508, 524, 557, 558, 568, 591, 609, 622, 656,
        665, 668, 687, 705, 728, 817, 839, 965, 1013, 1093, 1126, 1512, 1935,
        2159, 2384, 2424, 2426, 2484, 2738, 2746, 2751, 3006, 3184, 3184, 3184,
        3184, 3184, 4023, 5842, 5842, 6502, 7443, 7781, 8132, 8237, 8501
    };
}

Answer 3 (Score: 3)

You could compute the Mahalanobis distance of the new point with respect to each set. The set for which the new point has the lowest distance is the most likely match.

  The Mahalanobis distance is a measure of the distance between a point P and a distribution D, introduced by P. C. Mahalanobis in 1936. It is a multi-dimensional generalization of the idea of measuring how many standard deviations away P is from the mean of D. This distance is zero if P is at the mean of D, and grows as P moves away from the mean.

Since your space is one-dimensional, the calculation should simplify to the following (a small sketch follows the list):

  1. Compute the standard deviation of each distribution.
  2. Compute the mean of each distribution.
  3. For each distribution, compute how many standard deviations the point is away from that distribution's mean.
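
Here is a minimal sketch of those three steps (my own illustration, not from the answer); in one dimension the Mahalanobis distance reduces to the absolute z-score:

from statistics import mean, stdev

def z_distance(point, data):
    # 1-D Mahalanobis distance: standard deviations from the mean
    return abs(point - mean(data)) / stdev(data)

def classify(point, group_a, group_b):
    # Assign the point to the group it is fewest standard deviations from
    return 'A' if z_distance(point, group_a) < z_distance(point, group_b) else 'B'

# e.g. classify(5000, group_a, group_b) with the two arrays above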

Answer 4 (Score: 3)

You are describing a one-dimensional statistical classification problem in which you are looking for the "decision boundary". You have plenty of options to choose from:

  • logistic regression
  • nearest-neighbour classifiers
  • support vector machines
  • multi-layer perceptrons
  • ...

But since the problem is simple (one-dimensional, with two well-separated classes) and the decision boundary is a rather empty region, I suspect that no heavyweight statistical method would significantly outperform a simple eyeball-based guess.
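
For what it's worth, here is a hedged sketch of the first option with scikit-learn (my own illustration, not part of this answer; it assumes scikit-learn is installed and reuses data1 and data2 from the code in Answer 1): fit a logistic regression and solve for the x where the predicted probability of Group A crosses 0.5:

import numpy as np
from sklearn.linear_model import LogisticRegression

# label Group A as 1, Group B as 0; scale x down to help the solver
X = np.array(data1 + data2, dtype=float).reshape(-1, 1) / 1000.0
y = np.array([1] * len(data1) + [0] * len(data2))

clf = LogisticRegression().fit(X, y)
# the boundary is where w*x + b = 0, i.e. predicted probability 0.5
boundary = -clf.intercept_[0] / clf.coef_[0][0] * 1000.0
print("decision boundary ~", boundary)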
