使用python删除曲线下方的数据点

时间:2011-10-31 19:52:27

标签: python matplotlib scipy scientific-computing

我需要将一些理论数据与python中的实际数据进行比较。 理论数据来自解决方程式。 为了改进比较,我想删除远离理论曲线的数据点。我的意思是,我想删除图中红色虚线下方和上方的点(用matplotlib制作)。 Data points and theoretical curves

理论曲线和数据点都是不同长度的数组。

我可以尝试以粗略的方式移除这些点,例如:可以使用以下方法检测第一个上点:

data2[(data2.redshift<0.4)&data2.dmodulus>1]
rec.array([('1997o', 0.374, 1.0203223485103787, 0.44354759972859786)], dtype=[('SN_name', '|S10'), ('redshift', '<f8'), ('dmodulus', '<f8'), ('dmodulus_error', '<f8')])    

但我想用一种不太粗略的方式。

那么,任何人都可以帮我找到一个简单的方法来消除问题点吗?

谢谢!

4 个答案:

答案 0 :(得分:4)

这可能是过度的,并且基于您的评论

  

理论曲线和数据点都是数组   不同的长度。

我会做以下事情:

  1. 截断数据集,使其x值位于理论集的最大值和最小值之内。
  2. 使用scipy.interpolate.interp1d和上述截断数据x值插值理论曲线。步骤(1)的原因是满足interp1d的约束。
  3. 使用numpy.where查找超出可接受理论值范围的数据y值。
  4. 不要像评论和其他答案中所建议的那样丢弃这些值。如果你想要清晰,可以通过绘制'内衬'一种颜色和'异常值'作为另一种颜色来指出它们。
  5. 我认为这是一个接近你所寻找的脚本。它有望帮助您实现您想要的目标:

    import numpy as np
    import scipy.interpolate as interpolate
    import matplotlib.pyplot as plt
    
    # make up data
    def makeUpData():
        '''Make many more data points (x,y,yerr) than theory (x,y),
        with theory yerr corresponding to a constant "sigma" in y, 
        about x,y value'''
        NX= 150
        dataX = (np.random.rand(NX)*1.1)**2
        dataY = (1.5*dataX+np.random.rand(NX)**2)*dataX
        dataErr = np.random.rand(NX)*dataX*1.3
        theoryX = np.arange(0,1,0.1)
        theoryY = theoryX*theoryX*1.5
        theoryErr = 0.5
        return dataX,dataY,dataErr,theoryX,theoryY,theoryErr
    
    def makeSameXrange(theoryX,dataX,dataY):
        '''
        Truncate the dataX and dataY ranges so that dataX min and max are with in
        the max and min of theoryX.
        '''
        minT,maxT = theoryX.min(),theoryX.max()
        goodIdxMax = np.where(dataX<maxT)
        goodIdxMin = np.where(dataX[goodIdxMax]>minT)
        return (dataX[goodIdxMax])[goodIdxMin],(dataY[goodIdxMax])[goodIdxMin]
    
    # take 'theory' and get values at every 'data' x point
    def theoryYatDataX(theoryX,theoryY,dataX):
        '''For every dataX point, find interpolated thoeryY value. theoryx needed
        for interpolation.'''
        f = interpolate.interp1d(theoryX,theoryY)
        return f(dataX[np.where(dataX<np.max(theoryX))])
    
    # collect valid points
    def findInlierSet(dataX,dataY,interpTheoryY,thoeryErr):
        '''Find where theoryY-theoryErr < dataY theoryY+theoryErr and return
        valid indicies.'''
        withinUpper = np.where(dataY<(interpTheoryY+theoryErr))
        withinLower = np.where(dataY[withinUpper]
                        >(interpTheoryY[withinUpper]-theoryErr))
        return (dataX[withinUpper])[withinLower],(dataY[withinUpper])[withinLower]
    
    def findOutlierSet(dataX,dataY,interpTheoryY,thoeryErr):
        '''Find where theoryY-theoryErr < dataY theoryY+theoryErr and return
        valid indicies.'''
        withinUpper = np.where(dataY>(interpTheoryY+theoryErr))
        withinLower = np.where(dataY<(interpTheoryY-theoryErr))
        return (dataX[withinUpper],dataY[withinUpper],
                dataX[withinLower],dataY[withinLower])
    if __name__ == "__main__":
    
        dataX,dataY,dataErr,theoryX,theoryY,theoryErr = makeUpData()
    
        TruncDataX,TruncDataY = makeSameXrange(theoryX,dataX,dataY)
    
        interpTheoryY = theoryYatDataX(theoryX,theoryY,TruncDataX)
    
        inDataX,inDataY = findInlierSet(TruncDataX,TruncDataY,interpTheoryY,
                                        theoryErr)
    
        outUpX,outUpY,outDownX,outDownY = findOutlierSet(TruncDataX,
                                                        TruncDataY,
                                                        interpTheoryY,
                                                        theoryErr)
        #print inlierIndex
        fig = plt.figure()
        ax = fig.add_subplot(211)
    
        ax.errorbar(dataX,dataY,dataErr,fmt='.',color='k')
        ax.plot(theoryX,theoryY,'r-')
        ax.plot(theoryX,theoryY+theoryErr,'r--')
        ax.plot(theoryX,theoryY-theoryErr,'r--')
        ax.set_xlim(0,1.4)
        ax.set_ylim(-.5,3)
        ax = fig.add_subplot(212)
    
        ax.plot(inDataX,inDataY,'ko')
        ax.plot(outUpX,outUpY,'bo')
        ax.plot(outDownX,outDownY,'ro')
        ax.plot(theoryX,theoryY,'r-')
        ax.plot(theoryX,theoryY+theoryErr,'r--')
        ax.plot(theoryX,theoryY-theoryErr,'r--')
        ax.set_xlim(0,1.4)
        ax.set_ylim(-.5,3)
        fig.savefig('findInliers.png')
    

    这个数字是结果: enter image description here

答案 1 :(得分:4)

最后我使用了一些Yann代码:

def theoryYatDataX(theoryX,theoryY,dataX):
'''For every dataX point, find interpolated theoryY value. theoryx needed
for interpolation.'''
f = interpolate.interp1d(theoryX,theoryY)
return f(dataX[np.where(dataX<np.max(theoryX))])

def findOutlierSet(data,interpTheoryY,theoryErr):
    '''Find where theoryY-theoryErr < dataY theoryY+theoryErr and return
    valid indicies.'''

    up = np.where(data.dmodulus > (interpTheoryY+theoryErr))
    low = np.where(data.dmodulus < (interpTheoryY-theoryErr))
    # join all the index together in a flat array
    out = np.hstack([up,low]).ravel()

    index = np.array(np.ones(len(data),dtype=bool))
    index[out]=False

    datain = data[index]
    dataout = data[out]

    return datain, dataout

def selectdata(data,theoryX,theoryY):
    """
    Data selection: z<1 and +-0.5 LFLRW separation
    """
    # Select data with redshift z<1
    data1 = data[data.redshift < 1]

    # From modulus to light distance:
    data1.dmodulus, data1.dmodulus_error = modulus2distance(data1.dmodulus,data1.dmodulus_error)

    # redshift data order
    data1.sort(order='redshift')

    # Outliers: distance to LFLRW curve bigger than +-0.5
    theoryErr = 0.5
    # Theory curve Interpolation to get the same points as data
    interpy = theoryYatDataX(theoryX,theoryY,data1.redshift)

    datain, dataout = findOutlierSet(data1,interpy,theoryErr)
    return datain, dataout

使用这些功能我终于可以获得:

Data selection

谢谢大家的帮助。

答案 2 :(得分:1)

看看红色曲线和点之间的差异,如果它大于红色曲线和虚线红色曲线之间的差异,则将其移除。

diff=np.abs(points-red_curve)
index= (diff>(dashed_curve-redcurve))
filtered=points[index]

但请认真听取NickLH的评论。你的数据看起来相当不错,没有任何过滤,你的“外表”都有一个非常大的错误,并且不会影响适合度。

答案 3 :(得分:0)

您可以使用numpy.where()来确定哪些xy对符合您的绘图标准,或者可能枚举几乎完全相同的事情。例如:

x_list = [ 1,  2,  3,  4,  5,  6 ]
y_list = ['f','o','o','b','a','r']

result = [y_list[i] for i, x in enumerate(x_list) if 2 <= x < 5]

print result

我确信您可以更改条件,以便上例中的'2'和'5'是曲线的函数