对二维散点数据进行sklearn线性回归

时间:2014-10-31 20:54:33

标签: python matplotlib scikit-learn regression linear-regression

我在对以元组形式给出的二维散点数据执行sklearn线性回归时遇到了问题。这些数据是使用np.genfromtxt从csv文件的文本中读取生成的。

以下是我的代码中可以完整运行的相关部分:

导入模块:

import numpy as np
import scipy
import matplotlib as mpl
import matplotlib.pyplot as plt
from pylab import *

from matplotlib import pyplot as plt
from matplotlib import rc

from sklearn import linear_model

元组形式的数据集:

# Read the Seyfert CSV (comma-delimited); with unpack=True the result is
# unpacked into four separate arrays, one per name on the left-hand side.
bpt_class_sey, logLOIII_blue_sey, logLOIII_red_sey, logOII_OIII_sey = np.genfromtxt('/Users/iMacHome/Downloads/seyfert_LOIII_LOGOIIOIII.csv', delimiter=',', unpack=True)

回归:

# Create the linear-regression model object.
regr = linear_model.LinearRegression()

# Fit with logLOIII_blue_sey as the predictor and logOII_OIII_sey as the target.
# NOTE(review): both arguments come straight from the genfromtxt unpacking and
# are presumably 1-D arrays; the IndexError reported in this post is likely
# raised here because fit() expects X with shape (n_samples, n_features) --
# confirm against the scikit-learn documentation.
regr.fit(logLOIII_blue_sey, logOII_OIII_sey)

回归图:

# Draw the fitted line: predictor values on x, the model's predictions on y.
# NOTE(review): axScatter is not defined in this excerpt; it is created in the
# full listing further down.
axScatter.plot(logLOIII_blue_sey, regr.predict(logLOIII_blue_sey), color='blue',linewidth=3)

现在,我收到以下错误:

IndexError: tuple index out of range

我想知道是否有人知道如何快速解决这个小问题。我已经尽可能多地在网上查找过,但没有找到对这种情况发生原因的解释或可能的解决方案,也没有找到针对我这种格式的数据集执行统计上可靠的线性回归的其他全新方法。完整代码如下:

import numpy as np
import scipy
import matplotlib as mpl
import matplotlib.pyplot as plt
from pylab import *

from matplotlib import pyplot as plt
from matplotlib import rc

from sklearn import datasets, linear_model

# Load three comma-delimited CSV files (composite, star-forming, Seyfert).
# With unpack=True each result is unpacked into four separate arrays, one per
# name on the left-hand side.
bpt_class_comp, logLOIII_blue_comp, logLOIII_red_comp, logOII_OIII_comp = np.genfromtxt('/Users/iMacHome/Downloads/composite_LOIII_LOGOIIOIII.csv', delimiter=',', unpack=True)
bpt_class_sf, logLOIII_blue_sf, logLOIII_red_sf, logOII_OIII_sf         = np.genfromtxt('/Users/iMacHome/Downloads/starforming_LOIII_LOGOIIOIII.csv', delimiter=',', unpack=True)
bpt_class_sey, logLOIII_blue_sey, logLOIII_red_sey, logOII_OIII_sey     = np.genfromtxt('/Users/iMacHome/Downloads/seyfert_LOIII_LOGOIIOIII.csv', delimiter=',', unpack=True)

# Create the linear-regression model object.
regr = linear_model.LinearRegression()

# Fit the Seyfert sample: logLOIII_blue_sey as predictor, logOII_OIII_sey as target.
# NOTE(review): both arrays are presumably 1-D; the IndexError reported in this
# post is likely raised here because fit() expects X with shape
# (n_samples, n_features) -- confirm against the scikit-learn documentation.
# NOTE(review): the sample output quoted later in the post contains a nan value;
# if the fitted arrays contain NaNs they would presumably need filtering first -- verify.
regr.fit(logLOIII_blue_sey, logOII_OIII_sey)

# NOTE(review): 132 here is passed to plt.figure, not add_subplot, so it is
# presumably the figure number rather than a subplot layout -- confirm intent.
fig = plt.figure(132)

# Left panel (first of a 1x3 grid): scatter of the "blue" luminosity column
# against logOII_OIII for the three samples.
axScatter = fig.add_subplot(131)
axScatter.set_ylabel(r'$\mathrm{log([OII]/[OIII])}$', fontsize='medium')
axScatter.set_xlabel(r'$\mathrm{log[}L\mathrm{_{[OIII]}\ (erg\ {s^{-1}})]}$', fontsize='medium')
axScatter.set_ylim(-1.5, 1.0)
axScatter.set_xlim(39, 44)
axScatter.tick_params(axis='both', which='major', labelsize=10)
axScatter.tick_params(axis='both', which='minor', labelsize=10)
axScatter.xaxis.labelpad = 10
axScatter.scatter(logLOIII_blue_sey, logOII_OIII_sey, marker="o", c='0.1',s=15,lw=1)
axScatter.scatter(logLOIII_blue_comp, logOII_OIII_comp, marker="o", c='#6633ff',s=15,lw=1)
axScatter.scatter(logLOIII_blue_sf, logOII_OIII_sf, marker="o", c='0.7',s=15,lw=1)

# Overlay the regression line on the left panel.
# NOTE(review): x and the predictions are wrapped in a single list here, whereas
# the shorter excerpt earlier in the post passes them as two separate arguments;
# the list form is presumably not what was intended -- verify.
axScatter.plot([logLOIII_blue_sey, regr.predict(logLOIII_blue_sey)], color='blue',linewidth=3)

# Middle panel: same layout using the "red" luminosity column. The name
# axScatter is rebound to the new axes; the y-label call is left commented out.
axScatter = fig.add_subplot(132)
#axScatter.set_ylabel(r'$\mathrm{log([OII]/[OIII])}$', fontsize='medium')
axScatter.set_xlabel(r'$\mathrm{log[}L\mathrm{_{[OIII]} (erg {s^{-1}})]}$', fontsize='medium')
axScatter.set_ylim(-1.5, 1.0)
axScatter.set_xlim(39, 44)
axScatter.tick_params(axis='both', which='major', labelsize=10)
axScatter.tick_params(axis='both', which='minor', labelsize=10)
axScatter.xaxis.labelpad = 10
axScatter.scatter(logLOIII_red_sey, logOII_OIII_sey, marker="o", c='0.1',s=15,lw=1)
axScatter.scatter(logLOIII_red_comp, logOII_OIII_comp, marker="o", c='#6633ff',s=15,lw=1)
axScatter.scatter(logLOIII_red_sf, logOII_OIII_sf, marker="o", c='0.7',s=15,lw=1)

# Right panel: axes are created but nothing is drawn on them in this listing.
axHistogram = fig.add_subplot(133)

plt.show()

例如,使用:

print logLOIII_blue_sey

输出:

[ 42.30730188  42.67215043  42.15924954  41.61370469  41.94149606
  41.22327958  42.15549254  42.07837228  41.43995205  41.44106463]

print logLOIII_blue_sey

[ -0.05883232 -0.10934038 -0.10249362  0.71041126  0.12462513  0.69641850
   0.11334571         nan -0.07256197  0.72781828  0.02585652  0.70823414]

0 个答案:

没有答案