How do I perform linear regression in BigQuery?

Time: 2016-09-13 04:46:12

Tags: google-bigquery

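BigQuery has some statistical aggregate functions, such as STDDEV(X) and CORR(X,Y), but it does not offer a function that performs linear regression directly.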

How can I compute a linear regression using the functions that do exist?

3 answers:

Answer 0: (score: 11)

Edit of the edit: see the next answer; BigQuery now supports linear regression. --Fh

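The following query performs linear regression using a calculation that is numerically stable and easy to adapt to any input table. It uses the built-in function CORR to produce the slope and intercept of the best-fit model Y = SLOPE * X + INTERCEPT, along with the Pearson correlation coefficient.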

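As an example, we use the public natality dataset to compute birth weight as a linear function of pregnancy duration, broken down by state. You could write this more compactly, but we use several layers of subqueries to highlight how the pieces fit together. To apply it to another dataset, you only need to replace the innermost query.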

SELECT Bucket,
       SLOPE,
       (SUM_OF_Y - SLOPE * SUM_OF_X) / N AS INTERCEPT,
       CORRELATION
FROM (
    SELECT Bucket,
           N,
           SUM_OF_X,
           SUM_OF_Y,
           CORRELATION * STDDEV_OF_Y / STDDEV_OF_X AS SLOPE,
           CORRELATION
    FROM (
        SELECT Bucket,
               COUNT(*) AS N,
               SUM(X) AS SUM_OF_X,
               SUM(Y) AS SUM_OF_Y,
               STDDEV_POP(X) AS STDDEV_OF_X,
               STDDEV_POP(Y) AS STDDEV_OF_Y,
               CORR(X,Y) AS CORRELATION
        FROM (SELECT state AS Bucket,
                     gestation_weeks AS X,
                     weight_pounds AS Y
              FROM [publicdata.samples.natality])
        WHERE Bucket IS NOT NULL AND
              X IS NOT NULL AND
              Y IS NOT NULL
        GROUP BY Bucket));

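Using the STDDEV_POP and CORR functions makes this query more numerically stable than summing the products of X and Y and then taking differences and dividing. If you run both methods on a well-behaved dataset, though, you can verify that they produce the same results to high precision.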
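To make that comparison concrete, here is a minimal sketch that computes the slope both ways side by side. It is written in standard SQL rather than the legacy SQL used above, and it assumes that `bigquery-public-data.samples.natality` is the same public sample table; the column aliases (`slope_corr`, `slope_sums`, and so on) are made up for illustration.

```sql
-- Sketch only: compares the CORR/STDDEV_POP slope with the sum-of-products slope.
-- Assumes the standard SQL path `bigquery-public-data.samples.natality`.
WITH points AS (
  SELECT state AS bucket,
         gestation_weeks AS x,
         weight_pounds AS y
  FROM `bigquery-public-data.samples.natality`
  WHERE state IS NOT NULL
    AND gestation_weeks IS NOT NULL
    AND weight_pounds IS NOT NULL
),
stats AS (
  SELECT bucket,
         COUNT(*) AS n,
         SUM(x) AS sum_x,
         SUM(y) AS sum_y,
         SUM(x * x) AS sum_xx,
         SUM(x * y) AS sum_xy,
         -- Same slope as in the query above: corr * sd(y) / sd(x)
         CORR(x, y) * STDDEV_POP(y) / STDDEV_POP(x) AS slope_corr
  FROM points
  GROUP BY bucket
)
SELECT bucket,
       slope_corr,
       -- Textbook formula: subtracts large, nearly equal sums,
       -- which is where precision can be lost on badly scaled data.
       (n * sum_xy - sum_x * sum_y) / (n * sum_xx - sum_x * sum_x) AS slope_sums,
       (sum_y - slope_corr * sum_x) / n AS intercept
FROM stats
ORDER BY bucket;
```

On this dataset the two slope columns should agree closely; the difference would be expected to matter only when X values are large relative to their spread.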

Answer 1: (score: 2)

Good news! BigQuery now has native support for ML.

To produce a linear regression, use CREATE MODEL, then make predictions with SELECT FROM ML.PREDICT.

Docs:

Fun example: When will Stack Overflow reply
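As a rough illustration of the CREATE MODEL and ML.PREDICT flow described in this answer, the sketch below fits birth weight against gestation weeks. The dataset `mydataset` and model name `natality_lr` are hypothetical (the dataset would have to exist already), and the table path `bigquery-public-data.samples.natality` is assumed to be the public sample used in the first answer.

```sql
-- Hypothetical names: `mydataset` must already exist; `natality_lr` is arbitrary.
CREATE OR REPLACE MODEL mydataset.natality_lr
OPTIONS (model_type = 'linear_reg', input_label_cols = ['weight_pounds']) AS
SELECT gestation_weeks, weight_pounds
FROM `bigquery-public-data.samples.natality`
WHERE year > 2000
  AND gestation_weeks IS NOT NULL
  AND weight_pounds IS NOT NULL;

-- ML.PREDICT returns the input columns plus predicted_<label>,
-- here predicted_weight_pounds.
SELECT gestation_weeks, predicted_weight_pounds
FROM ML.PREDICT(MODEL mydataset.natality_lr,
                (SELECT 38 AS gestation_weeks));

-- The fitted coefficients (the counterpart of SLOPE and INTERCEPT above)
-- can be inspected with ML.WEIGHTS.
SELECT *
FROM ML.WEIGHTS(MODEL mydataset.natality_lr);
```

Unlike the hand-written query in the first answer, this fits one model over all rows rather than one per state.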

Answer 2: (score: 0)

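The code here creates a linear regression model from the public natality (live births) dataset and stores it in a dataset named demo_bq_ml. That dataset must be created before running the statement below.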

%%bq query
CREATE OR REPLACE MODEL demo_bq_ml.babyweight_model_asis
OPTIONS
  (model_type='linear_reg', labels=['weight_pounds']) AS

WITH natality_data AS (
  SELECT
    weight_pounds, -- this is the label; because it is continuous, we need to use regression
    CAST(is_male AS STRING) AS is_male,
    mother_age,
    CAST(plurality AS STRING) AS plurality,
    gestation_weeks,
    CAST(alcohol_use AS STRING) AS alcohol_use,
    CAST(year AS STRING) AS year,
    ABS(FARM_FINGERPRINT(CONCAT(CAST(YEAR AS STRING), CAST(month AS STRING)))) AS hashmonth
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
)

SELECT
    weight_pounds,
    is_male,
    mother_age,
    plurality,
    gestation_weeks,
    alcohol_use,
    year
FROM
    natality_data
WHERE
  MOD(hashmonth, 4) < 3  -- select 75% of the data as training
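Since the statement above trains on the rows where MOD(hashmonth, 4) < 3, the remaining rows can serve as a held-out set. The following is a sketch, not part of the original answer, of how the model might be evaluated on that remaining 25% with ML.EVALUATE; it repeats the same casts and filters so the feature columns match the training query.

```sql
-- Sketch: evaluate the model on the rows excluded from training.
SELECT *
FROM ML.EVALUATE(MODEL demo_bq_ml.babyweight_model_asis, (
  SELECT
    weight_pounds,
    CAST(is_male AS STRING) AS is_male,
    mother_age,
    CAST(plurality AS STRING) AS plurality,
    gestation_weeks,
    CAST(alcohol_use AS STRING) AS alcohol_use,
    CAST(year AS STRING) AS year
  FROM
    publicdata.samples.natality
  WHERE
    year > 2000
    AND gestation_weeks > 0
    AND mother_age > 0
    AND plurality > 0
    AND weight_pounds > 0
    -- complement of the training filter: the remaining 25% of the data
    AND MOD(ABS(FARM_FINGERPRINT(CONCAT(CAST(year AS STRING), CAST(month AS STRING)))), 4) = 3
));
```

For a linear_reg model the result is a single row of regression metrics such as mean_absolute_error and r2_score.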