-- | Dense linear algebra on plain Haskell lists.
--
-- A 'Matrix' is a list of rows and a 'Vector' is a list of 'Float's.
-- The @mk*@ and @*list@ functions convert between these and raw lists.
module Densematrix
  ( Matrix, Vector
  , mmult, madd, msub, vouter, vdot, norm, mneg, mxpose, mident, msize
  , mkmat, mkrmat, mkcmat, mkvec, mkrvec, mkcvec
  , vadd, vsub, vsize, vneg
  , swaprow, swapcol, droprow, getrow, getcol
  , subscript, vsubscript, vecpart
  , update, vupdate, update2
  , vhd, vtl, mergevectors
  , matvecmult, vmmult, svmult
  , showmatrix, displayvector
  , minverse
  , veclist, matlist
  ) where

-- Data.List is the hierarchical home of the Haskell 98 'List' module;
-- current compilers no longer provide it under the old name.
import Data.List (transpose)
-- map2 and rjustify are used in this module but not defined here;
-- presumably they come from Utils (NOTE(review): confirm).
import Utils

-- | A matrix, stored as a list of rows.
type Matrix = [[Float]]

-- | A vector, stored as a flat list.
type Vector = [Float]

-- Arithmetic ------------------------------------------------------------

mmult  :: Matrix -> Matrix -> Matrix
madd   :: Matrix -> Matrix -> Matrix
msub   :: Matrix -> Matrix -> Matrix
vadd   :: Vector -> Vector -> Vector
vsub   :: Vector -> Vector -> Vector
vsize  :: Vector -> Int
vouter :: Vector -> Vector -> Matrix
vdot   :: Vector -> Vector -> Float
norm   :: Vector -> Float
vneg   :: Vector -> Vector
mneg   :: Matrix -> Matrix
mxpose :: Matrix -> Matrix
mident :: Int -> Matrix
msize  :: Matrix -> (Int,Int)

-- Conversions -----------------------------------------------------------

mkmat   :: [[Float]] -> Matrix
mkrmat  :: [[Float]] -> Matrix
mkcmat  :: [[Float]] -> Matrix
mkvec   :: [Float] -> Vector
mkrvec  :: [Float] -> Vector
mkcvec  :: [Float] -> Vector
veclist :: Vector -> [Float]
matlist :: Matrix -> [[Float]]

-- Row / column manipulation ----------------------------------------------

swaprow :: Int -> Int -> Matrix -> Matrix
swapcol :: Int -> Int -> Matrix -> Matrix
droprow :: Matrix -> Matrix
getrow  :: Int -> Matrix -> Vector
getcol  :: Int -> Matrix -> Vector

-- Element access and functional update ------------------------------------

subscript    :: Matrix -> (Int,Int) -> Float
vsubscript   :: Vector -> Int -> Float
vecpart      :: Int -> Int -> Vector -> Vector
update       :: Matrix -> (Int,Int) -> Float -> Matrix
update2      :: Matrix -> (Int,Int,Int) -> (Float,Float) -> Matrix
vupdate      :: Vector -> Int -> Float -> Vector
vhd          :: Vector -> Float
vtl          :: Vector -> Vector
mergevectors :: [Vector] -> Vector

-- Inversion via LU factorisation ------------------------------------------

minverse     :: Matrix -> Matrix
lu_decomp    :: Matrix -> Lu_factor
apply_factor :: Lu_factor -> Matrix -> Matrix
forward      :: [[Float]] -> [Int] -> Matrix -> Matrix
backward     :: [[Float]] -> [Int] -> Matrix -> Matrix

-- Display -----------------------------------------------------------------

showmatrix    :: Matrix -> [Char]
displayvector :: Vector -> [Char]

-- Matrix-vector products --------------------------------------------------

matvecmult :: Matrix -> Vector -> Vector
vmmult     :: Vector -> Matrix -> Vector
svmult     :: Float -> Vector -> Vector


-- Matrix product m * n.  Errors if the column count of m differs from the
-- row count of n.
mmult m n
  | compatible = [[vdot row col | col <- cols] | row <- m]
  | otherwise  = error errmsg
  where
    compatible = snd (msize m) == fst (msize n)
    -- Transpose n once so each column is built a single time, instead of
    -- re-extracting column j with (!!) for every row of m.
    cols       = transpose n
    errmsg     = "mmult in densematrix: incompatible matrices " ++
                 show (msize m) ++ "*" ++ show (msize n)

-- Matrix-vector product m * v: one dot product per row of m.
matvecmult m v = map (vdot v) m

-- Row-vector-matrix product v * m: the single row of [v] * m.
vmmult v m = concat (mmult [v] m)

-- Scale every element of v by s.
svmult s v = map (* s) v

-- Dot product of two vectors.
vdot u v = sum (map2 (*) u v)

-- NOTE: despite its name this is the *squared* Euclidean norm (the sum of
-- squares); no square root is taken.  Kept as-is because callers may rely
-- on it.
norm v = vdot v v

vneg = map negate

-- Outer product: row k of the result is u scaled by the k-th element of v,
-- so the result has (length v) rows and (length u) columns.
vouter u v = [map (k *) u | k <- v]

-- Element-wise sum / difference of matrices and vectors.
madd a b = map2 vadd a b

vadd u v = map2 (+) u v

msub a b = map2 vsub a b

vsub u v = map2 (-) u v

vsize = length

mneg rs = map (map negate) rs

mxpose = transpose

-- n x n identity matrix (empty for n <= 0).
mident n = [[if i == j then 1 else 0 | j <- [1 .. n]] | i <- [1 .. n]]

-- (rows, columns).  The column count is taken from the first row; an empty
-- matrix is reported as (0, 0) rather than failing on 'head'.
msize []        = (0, 0)
msize m@(r : _) = (length m, length r)

mkmat  = id
mkrmat = id
mkcmat = transpose
matlist = id

mkvec  = id
mkrvec = id
mkcvec = id
veclist = id

-- Swap two rows (0-based).
swaprow a b m = swapitems a b m

-- Swap two columns (0-based) by swapping within every row.
swapcol a b m = map (swapitems a b) m

-- Swap the elements at 0-based positions a and b of a list.
swapitems :: Int -> Int -> [a] -> [a]
swapitems a b xs
  | a < b     = toa ++ [xs !! b] ++ atob ++ [xs !! a] ++ pastb
  | a > b     = swapitems b a xs
  | otherwise = xs
  where
    toa   = take a xs
    atob  = take (b - a - 1) (drop (a + 1) xs)
    pastb = drop (b + 1) xs
-- Drop the first row of a matrix.
droprow = tail

-- Row i / column i of a matrix (0-based).
getrow i m = m !! i
getcol i m = map (!! i) m

-- Entry (i, j) of a matrix, 0-based, via list indexing.
subscript m (i, j) = m !! i !! j

vsubscript v n = v !! n

-- The `count` elements of a vector starting at index `start`.
vecpart start count = take count . drop start

-- Copy of m with entry (i, j) (0-based) replaced by val.
update m (i, j) val = take i m ++ [setCol (m !! i)] ++ drop (i + 1) m
  where
    setCol xs = take j xs ++ [val] ++ drop (j + 1) xs

-- Copy of m with entry (i1, j) set to val1 and entry (i2, j) set to val2.
-- Built from two single-entry updates so it is correct for any pair of
-- rows, not only adjacent ones (if i1 == i2, val2 wins).
update2 m (i1, i2, j) (val1, val2)
  = update (update m (i1, j) val1) (i2, j) val2

-- Copy of v with element i (0-based) replaced by val.
vupdate v i val = take i v ++ [val] ++ drop (i + 1) v

vhd = head
vtl = tail

mergevectors = concat

-- Render a matrix one row per line, surrounded by blank lines.
showmatrix m = "\n" ++ showmat m ++ "\n"

showmat :: [Vector] -> [Char]
showmat m = concatMap (\v -> displayvector v ++ "\n") m

-- Render a vector with each element right-justified in a 13-char field.
displayvector
  = concat .
    (map (rjustify 13)) .
    (map show)


-- LU factorisation of a square matrix, one elimination level at a time:
--   Lu_fact lvecs uvecs pivots
--   * lvecs  : per level, the multipliers for the rows below the pivot
--              (one level fewer than uvecs; the last 1x1 level has none)
--   * uvecs  : per level, the pivot row of the (pivoted) reduced block
--   * pivots : per level, the index swapped into position 0; the final
--              1x1 level records 0, i.e. no swap.
data Lu_factor = Lu_fact [[Float]] [[Float]] [Int]

-- Inverse of a square matrix, obtained by solving A X = I with the LU
-- factors.  A 0x0 matrix is its own inverse.
minverse [] = []
minverse m  = apply_factor (lu_decomp m) (mident (fst (msize m)))

-- Solve A X = B for every column of B at once, given the factors of A.
apply_factor factor m = backward uvecs ps (forward lvecs ps m)
  where
    Lu_fact lvecs uvecs ps = factor

-- Forward substitution.  At each level: apply that level's row swap to b,
-- take the first row y1, and subtract l * y1 from the remaining rows
-- before recursing on them.
forward [] p b = b
forward (l : ls) (p : ps) b
  = y1 : forward ls ps y'
  where
    y'       = msub y (vouter y1 l)
    (y1 : y) = swaprow 0 p b

-- Back substitution.  Solve the trailing block first, then the first row
-- using the pivot u1 and the rest of the pivot row u, and finally apply the
-- level's swap to the solution (a single swap is its own inverse).
backward [[u]] p y = map (map (/ u)) y
backward ((u1 : u) : us) (p : ps) (y : ys)
  = x
  where
    x     = swaprow 0 p (xk : nextx)
    xk    = map (/ u1) (vsub y (vmmult u nextx))
    nextx = backward us ps ys

-- Recursive LU factorisation with symmetric (diagonal) pivoting.  Each
-- level pivots, records the pivot row and multipliers, and recurses on the
-- Schur complement a11 - l * (tail u).  A zero pivot at any level,
-- including the final 1x1 block, raises the same "div by 0" error.
lu_decomp a
  | fst (msize a) == 1 = Lu_fact [] [[only_one]] [0]
  | otherwise          = Lu_fact (l : ls) (u : us) (p : ps)
  where
    Lu_fact ls us ps = lu_decomp a11'
    (a', p)   = pivot a
    u         = getrow 0 a'
    u11       = head u
    m         = getcol 0 a'
    l | u11 /= 0  = tail (map (/ u11) m)
      | otherwise = error divByZero
    a11       = map tail (tail a')
    a11'      = msub a11 (vouter (tail u) l)
    only_one
      | a00 /= 0  = a00
      | otherwise = error divByZero
    a00       = subscript a (0, 0)
    divByZero = "div by 0 in lu: block is\n " ++ show a

-- Swap row and column 0 with row and column p, where p is chosen by
-- 'findpivot'.  Returns the pivoted matrix together with p.
pivot :: Matrix -> (Matrix,Int)
pivot m
  = (swaprow 0 p (swapcol 0 p m), p)
  where
    p = findpivot m

-- Index of the first diagonal entry of largest magnitude.  Only diagonal
-- entries are considered, so a nonsingular matrix with a zero diagonal
-- (e.g. [[0,1],[1,0]]) still reaches the div-by-0 error in 'lu_decomp'.
findpivot :: Matrix -> Int
findpivot a
  = loc_of_max 0 diag
  where
    diag   = [a `subscript` (i, i) | i <- [0 .. fst (msize a) - 1]]
    absmax = maxlist (map absfloat diag)
    loc_of_max n (x : xs)
      | absfloat x == absmax = n
      | otherwise            = loc_of_max (n + 1) xs

absfloat :: Float -> Float
absfloat n
  | n < 0     = negate n
  | otherwise = n

maxlist :: [Float] -> Float
maxlist xs = foldl1 max xs