Can someone tell me what the last loop is doing?

Date: 2018-05-26 04:32:30

Tags: python machine-learning data-science

import os
import tarfile
import urllib.request
import pandas as pd
import hashlib
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit 

DOWNLOAD_ROOT = "https://raw.githubusercontent.com/ageron/handson-ml/master/"
HOUSING_PATH = os.path.join("datasets", "housing")
HOUSING_URL = DOWNLOAD_ROOT + "datasets/housing/housing.tgz"

def fetch_housing_data(housing_url=HOUSING_URL, housing_path=HOUSING_PATH):
    os.makedirs(housing_path, exist_ok=True)
    tgz_path = os.path.join(housing_path, "housing.tgz")
    urllib.request.urlretrieve(housing_url, tgz_path)
    with tarfile.open(tgz_path) as housing_tgz:
        housing_tgz.extractall(path=housing_path)
# downloads housing.tgz and extracts housing.csv


def load_housing_data(housing_path=HOUSING_PATH):
    csv_path = os.path.join(housing_path, "housing.csv")
    return pd.read_csv(csv_path)
# that function loads the data into a pandas DataFrame object


# call the functions to download and load the housing data
fetch_housing_data()
housing = load_housing_data()
housing.head()

# total_bedrooms has fewer non-null entries than the other columns (missing values); deal with this later
# ocean_proximity has object dtype: since it comes from a CSV file, it can contain text
housing.describe()
# shows summary statistics of the numerical attributes
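You can check both notes above with isna().sum() and value_counts(). This is a minimal sketch on a made-up four-row DataFrame. It only reuses the column names total_bedrooms and ocean_proximity. On the real data you would call the same methods on housing.

```python
import numpy as np
import pandas as pd

# Hypothetical mini-dataset that only reuses two column names from housing.csv
df = pd.DataFrame({
    "total_bedrooms": [129.0, 1106.0, np.nan, 235.0],
    "ocean_proximity": ["NEAR BAY", "NEAR BAY", "INLAND", "<1H OCEAN"],
})

# Missing values per column: this is why total_bedrooms shows
# fewer non-null entries than the other columns
missing = df.isna().sum()
print(missing)

# Text columns get the generic object dtype
# (newer pandas versions may use a dedicated string dtype instead)
print(df["ocean_proximity"].dtype)

# Count how many rows fall into each text category
counts = df["ocean_proximity"].value_counts()
print(counts)
```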


%matplotlib inline
import matplotlib.pyplot as plt
housing.hist(bins=50, figsize=(20, 15))
plt.show()
# plots a histogram for each numerical attribute: the x axis is the range of values
# (e.g. housing prices), the y axis is the number of instances in each range
# median income has been scaled and capped at 15 for higher incomes and 0.5 for lower incomes

# since housing prices have been capped at $500K, possibly delete those districts
# so the model won't learn the capped values: the real price may be higher, so those labels could be wrong
# many histograms are tail-heavy: they extend much farther to the right of the median than to the left
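If you do decide to drop the capped districts, a boolean mask is enough. This is a minimal sketch with two assumptions:

- The label column is median_house_value.
- Capped prices are stored as 500001. Check housing["median_house_value"].max() on the real data before relying on this.

The rows below are made up.

```python
import pandas as pd

# Assumed cap value; verify it against the real column's maximum
CAP = 500001.0

# Made-up example rows, not taken from housing.csv
df = pd.DataFrame({"median_house_value": [452600.0, 500001.0, 341300.0, 500001.0]})

# Keep only districts whose label is strictly below the cap
uncapped = df[df["median_house_value"] < CAP]
print(len(df), "->", len(uncapped))  # 4 -> 2
```

Dropping these rows also means the model never sees districts above the cap. This only makes sense if you don't need predictions beyond $500K.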

import numpy as np 

def split_train_test(data, test_ratio):
    # a random permutation of all row positions 0..len(data)-1
    shuffled_indices = np.random.permutation(len(data))
    # number of rows that go into the test set
    test_set_size = int(len(data) * test_ratio)
    # the first test_set_size shuffled positions form the test set,
    # the rest are used for training
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return data.iloc[train_indices], data.iloc[test_indices]
# reload the data, since this runs in a separate cell
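Running split_train_test on a small made-up DataFrame shows what it returns. The function is repeated from above so the snippet runs on its own. The 10-row frame and the seed are only for illustration.

```python
import numpy as np
import pandas as pd

def split_train_test(data, test_ratio):
    # Random permutation of the row positions 0..len(data)-1
    shuffled_indices = np.random.permutation(len(data))
    # Number of rows that go into the test set
    test_set_size = int(len(data) * test_ratio)
    # First test_set_size shuffled positions -> test, the rest -> train
    test_indices = shuffled_indices[:test_set_size]
    train_indices = shuffled_indices[test_set_size:]
    return data.iloc[train_indices], data.iloc[test_indices]

np.random.seed(42)  # fixed seed so the split is reproducible
data = pd.DataFrame({"x": range(10)})
train_set, test_set = split_train_test(data, 0.2)
print(len(train_set), len(test_set))  # 8 2
```

The permutation is redrawn on every call. Without a fixed seed, each run therefore produces a different test set. The StratifiedShuffleSplit used later avoids this with random_state=42.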
housing = load_housing_data()

# create an income category attribute to stratify on
housing["income_cat"] = np.ceil(housing["median_income"] / 1.5)
housing["income_cat"] = housing["income_cat"].where(housing["income_cat"] < 5, 5.0)
# income is now grouped into categories 1-5; stratified sampling keeps the test set
# representative of the population's income distribution, which a purely random split may not
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
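The two income_cat lines divide median_income by 1.5, round up, and merge every category above 5 into 5. Here is a worked example on a few made-up income values. It assigns the result of where() back rather than using inplace=True.

```python
import numpy as np
import pandas as pd

# Made-up median_income values (in the dataset's scaled units)
income = pd.Series([0.5, 1.5, 1.6, 4.0, 5.9, 7.2, 15.0])

# Divide by 1.5 and round up
cat = np.ceil(income / 1.5)
# Cap at 5: keep values < 5, replace everything else with 5.0
cat = cat.where(cat < 5, 5.0)
print(cat.tolist())  # [1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 5.0]
```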

This is the loop at the end of the code:

for train_index,test_index in split.split(housing,housing["income_cat"]):
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

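Can someone explain to me what the last for loop does? Basically it is supposed to stratify the dataset into a training set and a test set. I'm especially confused by the loop header: why is the whole DataFrame object passed as the first argument, followed by the income category column? Does it stratify on each income category that was created, and thereby act on all the remaining attributes in the whole DataFrame object?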

1 answer:

Answer 0 (score: 2)

I believe you have already read this: http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html#sklearn.model_selection.StratifiedShuffleSplit.split

So split takes two arguments:

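housing: the training data, of shape (n_samples, n_features), where n_samples is the number of samples and n_features is the number of features.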

housing["income_cat"]: the target variable for supervised learning problems. Stratification is done based on these y labels.
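Stratification depends only on the labels, so the first argument matters only through its number of rows. The scikit-learn documentation for split notes that np.zeros(n_samples) may be used as a placeholder for X, so passing the whole DataFrame is harmless. The sketch below uses made-up data. It compares the indices produced from the full DataFrame with those from a dummy array.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Made-up data: 10 rows, labels in two income categories
df = pd.DataFrame({"feature": np.arange(10.0),
                   "income_cat": [1, 1, 1, 1, 1, 2, 2, 2, 2, 2]})

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)

# Once with the full DataFrame as X, once with a dummy array of the same length
idx_full = [(tr, te) for tr, te in split.split(df, df["income_cat"])]
idx_dummy = [(tr, te) for tr, te in split.split(np.zeros(len(df)), df["income_cat"])]

# With an integer random_state, both calls draw the same split
same = all(np.array_equal(a_tr, b_tr) and np.array_equal(a_te, b_te)
           for (a_tr, a_te), (b_tr, b_te) in zip(idx_full, idx_dummy))
print(same)  # True
```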

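It returns a generator that yields one tuple per split: n_splits tuples in total. Here n_splits=1, so the loop body runs once. Each tuple has two entries, and each entry is an ndarray: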

First entry: the training set indices for that split.

Second entry: the test set indices for that split.
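The indices only say which rows go where. housing.loc[train_index] and housing.loc[test_index] then copy every column of those rows into the two sets. That is how the whole DataFrame gets split, even though only income_cat decides the assignment. Here is a sketch on a made-up 100-row frame with unequal category shares. The column names mimic the question; the data are not real. It checks that the test set keeps the same proportions.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit

# Made-up frame: 100 rows, income_cat 1..5 with shares 10/20/30/25/15
cats = [1] * 10 + [2] * 20 + [3] * 30 + [4] * 25 + [5] * 15
housing = pd.DataFrame({"median_income": np.linspace(0.5, 15.0, 100),
                        "income_cat": cats})

split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
for train_index, test_index in split.split(housing, housing["income_cat"]):
    # The indices are row positions; .loc works here only because the
    # default RangeIndex makes labels equal positions (.iloc is safer otherwise)
    strat_train_set = housing.loc[train_index]
    strat_test_set = housing.loc[test_index]

# Category shares in the full data and in the stratified test set
print(housing["income_cat"].value_counts(normalize=True).sort_index().tolist())
print(strat_test_set["income_cat"].value_counts(normalize=True).sort_index().tolist())
```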