-- | Block-matrix helpers: block-diagonal scaling ('scale'), a preconditioner
-- built from an approximate ("pseudo") block factorisation ('precond'), and
-- multiplication by the block diagonal ('uncondition').  Everything else is
-- an internal helper or a structural sanity check ('testmat').
module Matlib (scale, precond, uncondition) where

-- NOTE(review): 'List' is the Haskell 98 module name; with a modern GHC and
-- no haskell98 package this import would have to be 'Data.List'.
import List (genericLength)
import Matrix
import AbsDensematrix


-- | Transforms a system: returns a new matrix and right-hand side.
type Scale_function = Matrix -> Vector -> (Matrix,Vector)

-- | Applies an operator derived from the matrix to a vector.
type Precond_function = Matrix -> Vector -> Vector

-- | A block together with its (row, column) position in the block matrix.
type Block_tuple = (Int,Int,Block)


-- | Block-diagonal scaling; see 'bdscl'.
scale :: Scale_function
scale = bdscl

-- | Applies ('mvmult') the block-diagonal part of @a@ (the blocks whose row
-- and column indices agree) to @v@.
uncondition :: Precond_function
uncondition a v = mvmult (diag_matrix a) v


-- | Block-diagonal scaling of a system.
--
-- Every off-diagonal block in block row @r@ of @m@, and block @r@ of @vs@,
-- is premultiplied by 'binverse' of the diagonal block of row @r@.  The
-- diagonal blocks themselves are replaced by 'bident' built from
-- @fst (bsize b)@ instead of being multiplied out.
--
-- Fails (through 'diag_blocks') unless every row of @m@ holds exactly one
-- diagonal block.
bdscl :: Matrix -> Vector -> (Matrix,Vector)
bdscl m vs = (m', vs')
  where
    m'  = mkmatrix [ map scaleBlock (getrow i m) | i <- blocknums m ]
    vs' = mkvector [ bvecmult (row_factor r) (vsubscript r vs)
                   | r <- [0 .. vsize vs - 1] ]

    scaleBlock (r, c, blk)
      | r == c    = (r, c, bident (fst (bsize blk)))
      | otherwise = (r, c, bmult (row_factor r) blk)

    row_factor n
      | okindex n diag_inverses = diag_inverses !! n
      | otherwise               = error "bdscl:dinverse in matrix"

    -- Indexed by row number; 'diag_blocks' guarantees one entry per row.
    diag_inverses = map (\(_, _, blk) -> binverse blk) (diag_blocks m)


-- | Preconditioner: factors @a@ with 'pseudofactor', treating the trailing
-- @nwells@ block rows/columns separately, then applies the factors to @v@
-- with 'backsubs' (forward then backward substitution).
precond :: Int -> Precond_function
precond nwells a v = backsubs (pseudofactor firstwell a) v
  where
    -- Index of the first of the trailing @nwells@ block rows.
    firstwell = numrows a - nwells


-- | Approximate block factorisation of @mat@, returned as a single matrix
-- @F@ whose diagonal blocks are stored already inverted (via 'binverse') and
-- whose off-diagonal blocks are the factor entries read by 'forward' and
-- 'backward'.
--
-- * Rows @k < n@ ('bilu'): only the blocks of row @k@ of @mat@ lying in
--   columns @< n@ are kept (no fill-in).  Off-diagonal ones are copied
--   unchanged; the diagonal becomes
--   @binverse (A_kk - sum_{j<k} F_kj * F_jj * F_jk)@.
--
-- * Columns @>= n@ of rows @< n@, and every column of rows @>= n@ ('bclu'):
--   a block is produced wherever 'exist' places the position inside the row
--   profile (see 'exist' below).
--
-- The result is defined in terms of itself (@mat'@ / @b'@); this relies on
-- laziness, each entry reading only entries it does not itself feed.
--
-- NOTE(review): upper blocks built by @bclu'@ (i < j) are premultiplied by
-- the inverted diagonal @F_ii@, whereas upper blocks copied by @bilu'@ are
-- not, and 'backward' multiplies every upper contribution by the inverted
-- diagonal again.  Confirm this mixed scaling is intended.
--
-- NOTE(review): @b i j@ is also read at profile positions that may have no
-- stored block in @mat@; this assumes 'msubscript' copes with that (e.g.
-- returns a zero block) -- confirm.
pseudofactor :: Int -> Matrix -> Matrix
pseudofactor n mat = mat'
  where
    mat'     = mkmatrix [ newrow k | k <- blocknums mat ]
    newrow k = bilu k ++ bclu k
    maxblock = last (blocknums mat)

    -- Blocks of the original matrix and of the (lazily built) result.
    b  i j = msubscript i j mat
    b' i j = msubscript i j mat'

    bilu k
      | k >= n    = []
      | otherwise = map bilu' (colsbefore n (getrow k mat))

    bilu' (r, c, oldb)
      | r /= c    = (r, c, oldb)
      | otherwise = (r, c, binverse newb)
      where
        newb
          | null sumpart = oldb
          | otherwise    = oldb `bsub` (bsum sumpart)
        -- Row r of the result restricted to columns j < r; row r's own
        -- diagonal is never forced here, so the self-reference is safe.
        sumpart = [ f_rj `bmult` (b' j j) `bmult` (b' j u)
                  | (u, j, f_rj) <- colsbefore r (getrow r mat') ]

    -- Column of the head block stored in row r of the original matrix (the
    -- leftmost one if rows are sorted), or maxblock for an empty row.
    first_col_in_row r =
      case getrow r mat of
        []              -> maxblock
        ((_, c, _) : _) -> c
    -- Only rows n .. maxblock are ever queried, hence the offset.
    firstcols  = [ first_col_in_row r | r <- [n .. maxblock] ]
    firstcol r = firstcols !! (r - n)

    -- Symmetric profile test: (r, c) exists when min r c is not left of the
    -- first stored column of row (max r c).
    exist r c
      | c > r     = exist c r
      | otherwise = c >= firstcol r

    bclu r
      | r < n     = [ bclu' r c | c <- [n .. maxblock], exist r c ]
      | otherwise = [ bclu' r c | c <- blocknums mat, exist r c ]

    bclu' i j =
      case result of
        [] -> error "pseudofactor"
        (blk : _)
          | i == j    -> (i, j, binverse blk)
          | otherwise -> (i, j, blk)
      where
        -- Upper: F_ii * (A_ij - sum_{k<i} A_ik * A_kj)
        -- Lower: A_ij - sum_{k<j} A_ik * F_kk * F_kj  (inverted if i == j)
        result
          | i < j     = [b' i i] `mult` ([b i j] `sub` sumpartU)
          | otherwise = [b i j] `sub` sumpartL
        sumpartU = sumblocks [ [bik] `mult` [b k j]
                             | (_, k, bik) <- getrow i mat, k < i ]
        sumpartL = sumblocks [ [bik] `mult` [b' k k] `mult` [b' k j]
                             | (_, k, bik) <- getrow i mat, k < j ]


-- | The blocks of a row whose column index is below @n@.
colsbefore :: Ord col => col -> [(row, col, blk)] -> [(row, col, blk)]
colsbefore n row = [ t | t@(_, c, _) <- row, c < n ]


-- | Applies the factors produced by 'pseudofactor': forward substitution
-- followed by backward substitution.
backsubs :: Matrix -> Vector -> Vector
backsubs m v = backward m (forward m v)


-- | Forward substitution with the strictly lower blocks of @m@:
--
-- > y_k = M_kk * (r_k - sum_{j<k} M_kj * y_j)
--
-- where @M_kk@ is the diagonal block exactly as stored (for factors from
-- 'pseudofactor' that is already the inverted diagonal).
forward :: Matrix -> Vector -> Vector
forward m rs = mkvector [ y k | k <- blocknums m ]
  where
    b i j = msubscript i j m
    r k   = vsubscript k rs
    y k   = bvecmult (b k k) (terms k)
    terms k
      | null (sumpart k) = r k
      | otherwise        = r k `vecsub` (vecsum (sumpart k))
    -- Shared per-row lists, so each off-diagonal product is evaluated at
    -- most once even though y is called repeatedly.
    sumpart k = sumparts !! k
    sumparts  = [ [ b_ij `bvecmult` (y j) | (i, j, b_ij) <- getrow k m, j < i ]
                | k <- blocknums m ]


-- | Backward substitution with the strictly upper blocks of @m@:
--
-- > p_k = y_k - M_kk * sum_{j>k} M_kj * p_j
--
-- with @M_kk@ the diagonal block exactly as stored.
backward :: Matrix -> Vector -> [Vec]
backward m ys = mkvector ps
  where
    b i j = msubscript i j m
    y k   = vsubscript k ys
    -- ps refers to its own later elements; laziness turns this into a
    -- back-to-front solve.
    ps  = [ p' k | k <- blocknums m ]
    p k = ps !! k
    p' k
      | null (terms k) = y k
      | otherwise      = y k `vecsub` ((b k k) `bvecmult` (vecsum (terms k)))
    terms k = termss !! k
    termss  = [ [ b_ij `bvecmult` (p j) | (i, j, b_ij) <- getrow k m, j > i ]
              | k <- blocknums m ]


-- Optional-block arithmetic: a list holding at most one block, where []
-- stands for an absent (zero) block.

-- | Product of optional blocks; absent when either operand is absent.
mult :: [Block] -> [Block] -> [Block]
mult [ ] [ ] = []
mult [_] [ ] = []
mult [ ] [_] = []
mult [x] [y] = [bmult x y]
mult _   _   = error "mult: operands must hold at most one block"

-- | Difference of optional blocks, treating an absent block as zero.
sub :: [Block] -> [Block] -> [Block]
sub [ ] [ ] = []
sub [x] [ ] = [x]
sub [ ] [y] = [bneg y]
sub [x] [y] = [bsub x y]
sub _   _   = error "sub: operands must hold at most one block"

-- | Sum of all present blocks; absent when there are none.
sumblocks :: [[Block]] -> [Block]
sumblocks xs
  | null blocks = []
  | otherwise   = [bsum blocks]
  where
    blocks = concat xs


-- | Left-to-right sum of a non-empty list of blocks.
bsum :: [Block] -> Block
bsum [] = error "bsum: no blocks to sum"
bsum ms = foldl1 badd ms

-- | The matrix made of only the diagonal blocks of @m@.
diag_matrix :: Matrix -> Matrix
diag_matrix m = mkmatrix [ filter is_diag (getrow i m) | i <- blocknums m ]

-- | The diagonal block of every row of @m@, in row order.
--
-- Each row must hold exactly one diagonal block: callers such as 'bdscl'
-- index the result by row number, so a missing or duplicated diagonal would
-- silently misalign it.
diag_blocks :: Matrix -> [Block_tuple]
diag_blocks m
  | all isSingleton diags = concat diags
  | otherwise =
      error "matlib:diag_blocks: given matrix with missing diagonal block"
  where
    diags = [ filter is_diag (getrow i m) | i <- blocknums m ]
    isSingleton [_] = True
    isSingleton _   = False

-- | Whether a block sits on the diagonal.
is_diag :: Block_tuple -> Bool
is_diag (r, c, _) = r == c


-- | Block row indices @0 .. numrows m - 1@.
blocknums :: Matrix -> [Int]
blocknums m = [0 .. numrows m - 1]


-- | Left-to-right sum of a non-empty list of vectors.
vecsum :: [Vec] -> Vec
vecsum [] = error "vecsum: no vecs to sum"
vecsum ms = foldl1 vecadd ms

-- | Whether @n@ is a valid index into the list.
okindex :: Int -> [a] -> Bool
okindex n xs = 0 <= n && n < length xs

-- | Debugging aid: runs all structural checks on @m@ and reports problems,
-- or a reassuring message when none are found.
testmat :: Matrix -> [Char]
testmat m
  | null result = "Matrix is probably ok"
  | otherwise   = result
  where
    result = rowsnumbered m ++ symmetric m ++ sorted m

-- | Reports every block whose stored row index differs from the row that
-- contains it.
rowsnumbered :: Matrix -> [Char]
rowsnumbered m = concat [ goodrow i (getrow i m) | i <- blocknums m ]
  where
    goodrow i row = concatMap (isrow i) row
    isrow i (r, _, _)
      | i == r    = []
      | otherwise = "Row " ++ show i ++ " is misnumbered\n"

-- | Reports every block (r, c) with no block (c, r), i.e. a structurally
-- non-symmetric pattern.
symmetric :: Matrix -> [Char]
symmetric m = concat [ symmetric_row (getrow i m) | i <- blocknums m ]
  where
    symmetric_row row = concatMap has_corresponding row
    has_corresponding (r, c, _)
      | exists c r = []
      | otherwise  =
          "Cannot find corresponding block for " ++ show (r, c) ++ "\n"
    exists r c = any (iscol c) (getrow r m)
    iscol c (_, j, _) = c == j

-- | Reports every row whose blocks are not in ascending order.
sorted :: Matrix -> [Char]
sorted m = concat [ sortedrow i (getrow i m) | i <- blocknums m ]
  where
    sortedrow i row
      | row == sort row = []
      | otherwise = "Row number " ++ show i ++ " is not properly sorted.\n"

-- | Stable merge sort.  Lists of length 0 or 1 are returned as-is (the
-- empty list must be a base case, otherwise it would split into two empty
-- halves forever).
sort :: (Ord a) => [a] -> [a]
sort xs
  | n <= 1    = xs
  | otherwise = merge (sort us) (sort vs)
  where
    n        = genericLength xs :: Int
    (us, vs) = splitAt (n `div` 2) xs


-- | Merges two ascending lists; on ties the element of the first list comes
-- first, which keeps 'sort' stable.
merge :: (Ord a) => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge (x:xs) (y:ys)
  | x <= y    = x : merge xs (y:ys)
  | otherwise = y : merge (x:xs) ys