1 module LinearAlgebra where
    2 
    3 import Array
    4 import List
    5 
    6 import Types
    7         
    8 apply :: Num a => Matrix a -> Vector a -> Vector a
    9 apply m v             | cm == cv   = accumArray (+) 0 (1, rm) [(r, v ! c * m ! (r, c)) | (r, c) <- indices m]
   10                         | otherwise  = error "apply: matrix and vector dimensions are not compatible"
   11   where 
   12         ((1, 1), (rm, cm))       = bounds m
   13         (    1 ,      cv )         = bounds v
   14 
   15 m_mul :: Num a => Matrix a -> Matrix a -> Matrix a
   16 m_mul a b             | ca == rb   = accumArray (+) 0 ((1, 1), (ra, cb)) [((r, c), a ! (r, t) * b ! (t, c)) | r <- [1..ra], t <- [1..ca], c <- [1..cb]]
   17                         | otherwise  = error "m_mul: matrix dimensions are not compatible"
   18   where
   19         ((1, 1), (ra, ca))     = bounds a
   20         ((1, 1), (rb, cb))         = bounds b
   21 
   22 m_add :: Num a => Matrix a -> Matrix a -> Matrix a
   23 m_add                         = m_zipWith (+)
   24 
   25 m_sub :: Num a => Matrix a -> Matrix a -> Matrix a
   26 m_sub                           = m_zipWith (-)
   27 
   28 m_map :: (a -> b) -> Matrix a -> Matrix b
   29 m_map f a                      = listArray (bounds a) (map f (elems a))
   30 
   31 m_zipWith :: (a -> b -> c) -> Matrix a -> Matrix b -> Matrix c
   32 m_zipWith f a b         | compatible   = listArray (bounds a) (zipWith f (elems a) (elems b))
   33                         | otherwise  = error "m_zipWith: matrix dimensions are not compatible"
   34   where
   35         compatible      = bounds a == bounds b
   36 
   37 v_add :: Num a => Vector a -> Vector a -> Vector a
   38 v_add                     = v_zipWith (+)
   39 
   40 v_sub :: Num a => Vector a -> Vector a -> Vector a
   41 v_sub                     = v_zipWith (-)
   42 
   43 v_map :: (a -> b) -> Vector a -> Vector b
   44 v_map f a                      = listArray (bounds a) (map f (elems a))
   45 
   46 v_zipWith f a b         | compatible   = listArray (bounds a) (zipWith f (elems a) (elems b))
   47                         | otherwise  = error "v_zipWith: vector dimensions are not compatible"
   48   where
   49         compatible      = bounds a == bounds b
   50 
   51 m_transpose :: Matrix a -> Matrix a
   52 m_transpose m                   = let ((1, 1), (r, c)) = bounds m in array ((1, 1), (c, r)) [((c, r), v) | ((r, c), v) <- assocs m]
   53 
   54 m_is_square :: Matrix a -> Bool
   55 m_is_square m                   = let ((1, 1), (r, c)) = bounds m in r == c
   56 
   57 m_zero :: Num a => Index -> Matrix a
   58 m_zero s                          = accumArray (+) 0 ((1, 1), (s, s)) []
   59 
   60 m_unit :: Num a => Index -> Matrix a
   61 m_unit s                       = accumArray (+) 0 ((1, 1), (s, s)) [((i, i), 1) | i <- [1 .. s]]
   62 
   63 nullspace :: Matrix Exact -> Matrix Exact
   64 nullspace                   = m_transpose . left_nullspace . m_transpose
   65 
   66 left_nullspace :: Matrix Exact -> Matrix Exact
   67 left_nullspace m                     = let (rows, _, i) = gauss_jordan m in m_select_rows rows i
   68 
   69 m_select_rows :: [Index] -> Matrix a -> Matrix a
   70 m_select_rows rows matrix             = listArray ((1, 1), (length rows, size)) [matrix ! (r, c) | r <- rows, c <- [1 .. size]]
   71   where
   72         size              = (snd . snd . bounds) matrix
   73 
   74 m_inv :: Matrix Exact -> Matrix Exact
   75 m_inv m                | l == []     = i
   76                         | otherwise  = error "m_inv: matrix isn't invertible"
   77   where
   78         (l, u, i)         = gauss_jordan m
   79         ((1,1), (r, c))               = bounds m
   80 
   81 gauss_jordan :: Matrix Exact -> ([Index], Matrix Exact, Matrix Exact)
   82 gauss_jordan m   | m_is_square m = (foldr1 (.) [step c | c <- reverse [1 .. size]]) ([1 .. size], m, m_unit size)
   83                         | otherwise  = error "gauss_jordan: not a square matrix"
   84   where
   85         step c (rs, m0, i0)   = if v /= 0 then (delete c rs, m2, i2) else (rs, m0, i0)
   86           where
   87                 (m2, i2)           = (sweep     m1, sweep     i1)
   88                 (m1, i1)           = (swap_norm m0, swap_norm i0)
   89                 swap_norm         = (multiply c (1 / v)) . (if r /= c then swap r c else id)
   90                 sweep                  = eliminate c m1
   91                 (r, v)               = pivot c rs m0
   92         
   93         pivot c rs m0           = foldl1 max' [(r, m0 ! (r, c)) | r <- rs]
   94         max' (r1, v1) (r2, v2)         = if (abs(v1) >= abs(v2)) then (r1, v1) else (r2, v2)
   95         
   96         swap      r s  m             = m // concat ([[((r, c), m ! (s, c)),      ((s, c), m ! (r, c))] | c <- [1 .. size]])
   97         multiply  r f  m             = m //         [ ((r, c),              f           * m ! (r, c))  | c <- [1 .. size]]
   98         eliminate w m1 m             = m //         [ ((r, c), m ! (r, c) - m1 ! (r, w) * m ! (w, c))  | r <- [1 .. size], r /= w, c <- [1 .. size]]
   99 
  100         size              = (snd . snd . bounds) m