-- | Dense linear algebra on immutable arrays.
--
-- A 'Matrix' is indexed by @(row, column)@ pairs and a 'Vector' by a single
-- index; both types come from "Types". Functions that destructure 'bounds'
-- against a literal @1@ fail with a pattern-match error when that lower bound
-- is not 1.
module LinearAlgebra where

import Array
import List

import Types

-- | Matrix-vector product @m * v@.
--
-- Calls 'error' when the column count of @m@ differs from the length of @v@.
apply :: Num a => Matrix a -> Vector a -> Vector a
apply m v
  | cm == cv  = accumArray (+) 0 (1, rm) [(r, v ! c * m ! (r, c)) | (r, c) <- indices m]
  | otherwise = error "apply: matrix and vector dimensions are not compatible"
  where
    ((1, 1), (rm, cm)) = bounds m
    (1, cv)            = bounds v

-- | Matrix product @a * b@.
--
-- Calls 'error' when the column count of @a@ differs from the row count of @b@.
m_mul :: Num a => Matrix a -> Matrix a -> Matrix a
m_mul a b
  | ca == rb  =
      accumArray (+) 0 ((1, 1), (ra, cb))
        [((r, c), a ! (r, t) * b ! (t, c)) | r <- [1 .. ra], t <- [1 .. ca], c <- [1 .. cb]]
  | otherwise = error "m_mul: matrix dimensions are not compatible"
  where
    ((1, 1), (ra, ca)) = bounds a
    ((1, 1), (rb, cb)) = bounds b

-- | Element-wise sum of two matrices with identical bounds.
m_add :: Num a => Matrix a -> Matrix a -> Matrix a
m_add = m_zipWith (+)

-- | Element-wise difference of two matrices with identical bounds.
m_sub :: Num a => Matrix a -> Matrix a -> Matrix a
m_sub = m_zipWith (-)

-- | Apply a function to every entry of a matrix, keeping its bounds.
m_map :: (a -> b) -> Matrix a -> Matrix b
m_map f a = listArray (bounds a) (map f (elems a))

-- | Combine two matrices entry by entry.
--
-- Calls 'error' unless both matrices have exactly the same bounds.
m_zipWith :: (a -> b -> c) -> Matrix a -> Matrix b -> Matrix c
m_zipWith f a b
  | compatible = listArray (bounds a) (zipWith f (elems a) (elems b))
  | otherwise  = error "m_zipWith: matrix dimensions are not compatible"
  where
    compatible = bounds a == bounds b

-- | Element-wise sum of two vectors with identical bounds.
v_add :: Num a => Vector a -> Vector a -> Vector a
v_add = v_zipWith (+)

-- | Element-wise difference of two vectors with identical bounds.
v_sub :: Num a => Vector a -> Vector a -> Vector a
v_sub = v_zipWith (-)

-- | Apply a function to every entry of a vector, keeping its bounds.
v_map :: (a -> b) -> Vector a -> Vector b
v_map f a = listArray (bounds a) (map f (elems a))

-- | Combine two vectors entry by entry.
--
-- Calls 'error' unless both vectors have exactly the same bounds. The
-- signature is the most general type the definition admits (any 'Ix' index),
-- so every existing caller keeps type-checking.
v_zipWith :: Ix i => (a -> b -> c) -> Array i a -> Array i b -> Array i c
v_zipWith f a b
  | compatible = listArray (bounds a) (zipWith f (elems a) (elems b))
  | otherwise  = error "v_zipWith: vector dimensions are not compatible"
  where
    compatible = bounds a == bounds b

-- | Transpose: the entry at @(r, c)@ moves to @(c, r)@.
m_transpose :: Matrix a -> Matrix a
m_transpose m = array ((1, 1), (cols, rows)) [((c, r), x) | ((r, c), x) <- assocs m]
  where
    ((1, 1), (rows, cols)) = bounds m

-- | Whether the matrix has as many rows as columns.
m_is_square :: Matrix a -> Bool
m_is_square m = rows == cols
  where
    ((1, 1), (rows, cols)) = bounds m
-- | The @s@-by-@s@ zero matrix.
m_zero :: Num a => Index -> Matrix a
m_zero s = accumArray (+) 0 ((1, 1), (s, s)) []

-- | The @s@-by-@s@ identity matrix.
m_unit :: Num a => Index -> Matrix a
m_unit s = accumArray (+) 0 ((1, 1), (s, s)) [((i, i), 1) | i <- [1 .. s]]

-- | Basis of the nullspace of a square matrix, one basis vector per column:
-- every column @x@ of the result satisfies @m * x = 0@.
--
-- Computed as the left nullspace of the transpose; calls 'error' on
-- non-square input (via 'gauss_jordan').
nullspace :: Matrix Exact -> Matrix Exact
nullspace = m_transpose . left_nullspace . m_transpose

-- | Basis of the left nullspace of a square matrix, one basis vector per row:
-- every row @y@ of the result satisfies @y * m = 0@.
--
-- 'gauss_jordan' applies the same row operations to @m@ and to the identity.
-- The rows it never used as pivots are zero in the reduced matrix, so the
-- corresponding rows of the accumulated transformation annihilate @m@.
left_nullspace :: Matrix Exact -> Matrix Exact
left_nullspace m = m_select_rows free transform
  where
    (free, _, transform) = gauss_jordan m

-- | The matrix made of the given rows of @matrix@, in the order listed.
-- The column count is unchanged.
m_select_rows :: [Index] -> Matrix a -> Matrix a
m_select_rows rows matrix =
  listArray ((1, 1), (length rows, size)) [matrix ! (r, c) | r <- rows, c <- [1 .. size]]
  where
    size = (snd . snd . bounds) matrix  -- number of columns

-- | Inverse of a square matrix.
--
-- Calls 'error' when the matrix is singular (some row was never used as a
-- pivot), or, via 'gauss_jordan', when it is not square.
m_inv :: Matrix Exact -> Matrix Exact
m_inv m
  | null free = transform
  | otherwise = error "m_inv: matrix isn't invertible"
  where
    (free, _, transform) = gauss_jordan m

-- | Gauss-Jordan elimination on a square matrix.
--
-- Returns @(free, reduced, transform)@ where
--
-- * @free@ lists the row indices never used as a pivot row; it is empty
--   exactly when every column produced a pivot (the matrix is invertible);
-- * @reduced@ is @m@ after elimination: every pivot column @c@ is the unit
--   column with its 1 in row @c@;
-- * @transform@ is the identity with the same row operations applied, so
--   @transform * m == reduced@. When @free@ is empty it is the inverse of @m@.
--
-- Columns are processed left to right. In each one, the candidate row with
-- the largest absolute entry becomes the pivot (ties go to the earliest
-- candidate). A column whose candidates are all zero is skipped. A 0-by-0
-- matrix yields @([], m, identity)@. Calls 'error' on non-square input.
-- NOTE(review): 'Exact' must support '/', 'abs' and ordering; presumably an
-- exact type such as Rational (see "Types").
gauss_jordan :: Matrix Exact -> ([Index], Matrix Exact, Matrix Exact)
gauss_jordan m
  -- foldr over the reversed column list applies step 1 first, then step 2, ...
  -- Unlike foldr1 (.), it is total when there are no columns at all.
  | m_is_square m = foldr step ([1 .. size], m, m_unit size) (reverse [1 .. size])
  | otherwise     = error "gauss_jordan: not a square matrix"
  where
    size = (snd . snd . bounds) m

    -- Eliminate column c. rs holds the rows not yet used as pivot rows.
    -- Row c is always among them, because only rows of earlier columns are
    -- ever removed; that is why the pivot is parked in row c.
    step c (rs, m0, i0)
      | v /= 0    = (delete c rs, sweep m1, sweep i1)
      | otherwise = (rs, m0, i0)
      where
        (r, v)    = pivot c rs m0
        -- move the pivot row into row c and scale it so the pivot becomes 1
        swap_norm = multiply c (1 / v) . (if r /= c then swap r c else id)
        (m1, i1)  = (swap_norm m0, swap_norm i0)
        -- clear column c in every other row, taking multipliers from m1
        sweep     = eliminate c m1

    -- rs is never empty here: at step c at most c - 1 rows have been removed.
    pivot c rs a = foldl1 max' [(r, a ! (r, c)) | r <- rs]
    max' (r1, v1) (r2, v2) = if abs v1 >= abs v2 then (r1, v1) else (r2, v2)

    swap r s a = a // concat [[((r, c), a ! (s, c)), ((s, c), a ! (r, c))] | c <- [1 .. size]]
    multiply r f a = a // [((r, c), f * a ! (r, c)) | c <- [1 .. size]]
    eliminate w ref a =
      a // [((r, c), a ! (r, c) - ref ! (r, w) * a ! (w, c)) | r <- [1 .. size], r /= w, c <- [1 .. size]]