在SML中将两个矩阵相乘

时间:2018-11-16 18:41:19

标签: sml smlnj

我想用SML / NJ编写一个函数,该函数将2个矩阵作为参数并将其相乘。

我只能使用:

  • 函数点,它包含2个向量并计算标量积:

    fun dot (xs: int list, ys: int list): int =
        List.foldl (fn (x,y) => x+y)
           0
           (ListPair.map (fn (x,y) => x*y) (xs, ys))
    
  • 函数转置,它采用1个矩阵并计算该矩阵的转置:

    fun transpose (m: 'a list list): 'a list list =
        List.tabulate (List.length (List.nth (m, 0)),
               fn x => List.map (fn y => (List.nth (y, x))) m)
    
    • 匿名函数

    • 结构列表,ListPair和数学

我要编写的函数应该是这样的:

fun multiply (a: int list list, b: int list list): int list list

到目前为止,我已经做到了:

fun multiply (a: int list list, b: int list list): int list list =
case a of
[] => []
  | g::rep => [(List.map (fn y => dot(g, y)) transpose(b))] @ (multiply(rep, b))

但是我得到了这个错误:

test.sml:66.21-66.62 Error: operator and operand do not agree [tycon mismatch]
  operator domain: int list list
  operand:         'Z list list -> 'Z list list
  in expression:
(List.map (fn y => dot <exp>)) transpose

如果在函数的最后一行乘以b而不是tranpose(b),我不会出错,但是,当然,如果这样做,我不会得到我想要的结果:

fun multiply (a: int list list, b: int list list): int list list =
case a of
[] => []
  | g::rep => [(List.map (fn y => dot(g, y)) b)] @ (multiply(rep, b))

我不知道该怎么办。有人能帮我吗?

2 个答案:

答案 0 :(得分:1)

存在a solution for OCaml on RosettaCode您可以翻译。

给出插图

         | [ a,   [ c,
         |   b ]    d ]
---------+-------------
[ 1, 2 ] |   w      x
[ 3, 4 ] |   y      z

然后针对第一个矩阵中的每一行,计算与第二个矩阵中相同数量dot乘积。即w = dot ([1, 2], [a, b])。提取第一个矩阵的行很容易,因为您可以使用列表递归。

提取第二个矩阵的列并不容易,因为它们与列表表示正交(即a是第一行的第一元素,b是第二个矩阵的第一元素行,c是第一行的第二个元素,d是第二行的第二个元素。

您可以通过执行transpose来简化从第二个矩阵中提取列的操作,在这种情况下,提取列等同于提取行。此时,您可以采用“行”(即第一个矩阵中的行和第二个矩阵中的转置列(“行”))的成对dot乘积。

我鼓励对这种类型的操作使用Array2,因为当“矩阵”(列表)参差不齐(行长不同)时,您还可以避免错误处理。

答案 1 :(得分:-1)

至少我成功使用了您的方案

exception dim_Error;

fun listPair (l1:int list) (l2:int list): (int * int) List.list =  
`if not (List.length l1 = List.length l2) then raise dim_Error
    else List.tabulate (List.length l1, fn i => (List.nth(l1, i),List.nth(l2, i)));`    

fun dotprod (l1:int list) (l2:int list): int = 
    List.foldl (fn (x,y) => x+y) 0 (map (fn (x,y) => x*y) (listPair l1 l2));

fun transpose (matrix: int list list):int list list =
    List.tabulate (List.length (List.nth (matrix, 0)), fn i => List.map (fn j => List.nth (j, i)) matrix);

fun multiply (matrix1: int list list) (matrix2: int list list): int list list =
case matrix1 of 
    [] => []
    |    x::xr => [List.map (fn y => dotprod x y) (transpose matrix2)] @ (multiply xr matrix2);

val A = [[ 21,  ~2,  3, ~1], [  2, 115,  6,  7], [ ~6,  9, 210, 11], [ 13, 14, ~19, 15], [ ~2, 21, ~99,  3], [  5, ~4,  ~1,  9]];
val B = [[2, ~6, 0, 3, 9], [~4, ~1, 2, 1, ~2], [~6, 5, 3, ~5, 1], [10, 9, ~4, 9, ~7]];
val C = [[~4, ~3, 7], [1, ~1, ~2], [~5, 3, 7], [8, ~7, 2], [~8, 3, 6]];

multiply (multiply  A B) C;
相关问题