多元LSTM的keras输入形状

时间:2018-08-07 07:41:48

标签: r keras lstm

我试图在有两个输入的喀拉拉邦中拟合LSTM模型

y是形状为(100,10)的输出 x是形状为(100,20)的输入

library(keras)

x_train_vec <- matrix(rnorm(2000), ncol = 20, nrow = 100)
x_train_arr <- array(data = x_train_vec, dim = c(nrow(x_train_vec), 1, 20))


y_train_vec <- matrix(rnorm(1000), ncol = 10, nrow = 100)
y_train_arr <- array(data = y_train_vec, dim = c(nrow(x_train_vec), 1, 10))


> dim(x_train_arr)
[1] 100   1  20
> dim(y_train_arr)
[1] 100   1  10

现在我要适合LSTM模型

model <- keras_model_sequential()

model %>%
  layer_lstm(units            = 50, 
             input_shape      = c(1,10), 
             batch_size       = 1) %>% 
  layer_dense(units = 1)

model %>% 
  compile(loss = 'mae', optimizer = 'adam')

model %>% fit(x          = x_train_arr, 
              y          = y_train_arr, 
              batch_size = 1,
              epochs     = 10, 
              verbose    = 1, 
              shuffle    = FALSE)

但是我得到这个错误:

  

py_call_impl(可调用,dots $ args,dots $ keywords)错误:
  ValueError:检查输入时出错:预期lstm_21_input具有   形状(1,10),但数组的形状为(1,20)

如果将输入大小更改为c(1,20),则会得到:

  

py_call_impl(可调用,dots $ args,dots $ keywords)错误:
  ValueError:检查目标时出错:预期density_13具有2   尺寸,但数组的形状为(100,1,10)

我也使用了不同的设置,但从未奏效。

2 个答案:

答案 0 :(得分:0)

如果您的Keras版本是<2.0,则需要使用model.add(TimeDistributed(Dense(1)))。

请注意,该语法适用于python,您需要找到R等值。

答案 1 :(得分:0)

我弄清楚了如何使它工作:

        protected void Page_Load(object sender, EventArgs e)
    {

        if (!IsPostBack)
        {
            //string query = @"Select * from Studentsinfor";
            var data = db.Database.SqlQuery<StudentsInfo>(query);
            ReportViewer1.SizeToReportContent = true;
            ReportViewer1.LocalReport.ReportPath = Server.MapPath("IDCards.rdlc");
            ReportViewer1.LocalReport.DataSources.Clear();
            ReportDataSource ds = new ReportDataSource("DataSet1", data);
            ReportViewer1.LocalReport.DataSources.Add(ds);
            this.ReportViewer1.LocalReport.EnableExternalImages = true;

            /* begin added part */

            // get absolute path to Project folder
            string path = new Uri(Server.MapPath("~/Photos")).AbsoluteUri; // adjust path to Project folder here

            // set above path to report parameter
            var parameter = new ReportParameter[1];
            parameter[0] = new ReportParameter("ImagePath", path); // adjust parameter name here
            ReportViewer1.LocalReport.SetParameters(parameter);
            /* end of added part */

            ReportViewer1.LocalReport.Refresh();