1 
    2 
    3 
    4 
    5 
    6 
    7 
    8 
    9 
   10 
   11 
   12 
   13 
   14 
   15 
   16 
   17 
   18 
   19 
   20 
   21 
   22 
   23 
   24 
   25 
   26 
   27 
   28 
   29 
   30 
   31 
   32 
   33 
   34 
   35 
-- Dense matrices and vectors represented as plain lists of Floats.
-- Exports arithmetic, construction/conversion, row/column manipulation,
-- element access/update, display helpers, and matrix inversion (minverse).
module Densematrix(Matrix,Vector,
        mmult, madd, msub, vouter, vdot, norm ,mneg ,mxpose,mident, msize,
        mkmat, mkrmat, mkcmat, mkvec, mkrvec, mkcvec,
        vadd, vsub, vsize, vneg,
        swaprow, swapcol ,droprow, getrow, getcol,
        subscript, vsubscript, vecpart,
        update ,vupdate, update2,
        vhd, vtl,mergevectors,
        matvecmult, vmmult,svmult,
        showmatrix,displayvector,
        minverse,
        veclist, matlist ) where
   48 
   49 import List (transpose)
   50 import Utils
   51 
-- A matrix is a list of rows (getrow indexes the outer list); each row is a
-- list of Floats.
type Matrix    =  [[Float]]
-- A vector is a flat list of Floats.
type Vector    =  [Float]
   54 
   55 
   56 
-- Matrix and vector arithmetic.
mmult   :: Matrix -> Matrix -> Matrix
madd    :: Matrix -> Matrix -> Matrix
msub    :: Matrix -> Matrix -> Matrix
vadd    :: Vector -> Vector -> Vector
vsub    :: Vector -> Vector -> Vector
vsize   :: Vector -> Int
vouter  :: Vector -> Vector -> Matrix
vdot    :: Vector -> Vector -> Float
-- NOTE: norm is defined as the dot product of a vector with itself
-- (no square root is taken).
norm    :: Vector -> Float
vneg    :: Vector -> Vector
mneg    :: Matrix -> Matrix
mxpose  :: Matrix -> Matrix
-- Identity matrix of the given size.
mident  :: Int -> Matrix
-- (number of rows, length of the first row).
msize   :: Matrix -> (Int,Int)
-- Conversions between raw lists and Matrix/Vector. mkcmat transposes its
-- argument (i.e. it takes a list of columns); the others are identities.
mkmat   :: [[Float]] -> Matrix
mkrmat  :: [[Float]] -> Matrix
mkcmat  :: [[Float]] -> Matrix
mkvec   :: [Float] -> Vector
mkrvec  :: [Float] -> Vector
mkcvec  :: [Float] -> Vector
veclist :: Vector -> [Float]
matlist :: Matrix -> [[Float]]
   79 
-- Row/column manipulation; all indices are 0-based.
swaprow    :: Int -> Int -> Matrix -> Matrix
swapcol    :: Int -> Int -> Matrix -> Matrix
-- Removes the first row.
droprow    :: Matrix -> Matrix
getrow     :: Int -> Matrix -> Vector
getcol     :: Int -> Matrix -> Vector
   85 
-- Element access and functional update; indices are 0-based, (row,column).
subscript    :: Matrix -> (Int,Int) -> Float
vsubscript   :: Vector -> Int -> Float
-- vecpart start count v: count elements of v beginning at index start.
vecpart      :: Int -> Int -> Vector -> Vector
update       :: Matrix -> (Int,Int) -> Float -> Matrix
-- update2 m (row1,row2,col) (val1,val2): set one column in two rows.
update2      :: Matrix -> (Int,Int,Int) -> (Float,Float) -> Matrix
vupdate      :: Vector -> Int -> Float -> Vector
vhd          :: Vector -> Float
vtl          :: Vector -> Vector
-- Concatenates a list of vectors into one vector.
mergevectors :: [Vector] -> Vector
   95 
-- Inversion via LU factorisation. Only minverse is exported; the rest are
-- internal steps (factorise, then forward and backward passes).
minverse      :: Matrix -> Matrix
lu_decomp     :: Matrix -> Lu_factor
apply_factor  :: Lu_factor -> Matrix -> Matrix
forward       :: [[Float]] -> [Int] -> Matrix -> Matrix
backward      :: [[Float]] -> [Int] -> Matrix -> Matrix
  101 
-- Textual rendering of matrices and vectors.
showmatrix :: Matrix -> [Char]
displayvector :: Vector -> [Char]
  104 
-- Mixed products: matrix*vector, (row) vector*matrix, scalar*vector.
matvecmult :: Matrix -> Vector -> Vector
vmmult :: Vector -> Matrix -> Vector
svmult :: Float -> Vector -> Vector
  108 
  109 
  110 
  111 
  112 
  113 
  114 
  115 
  116 
  117 
  118 
  119 
  120 
  121 mmult m n = if compatible then [row i | i <- [0..((length m)-1)]]
  122             else error errmsg
  123             where
  124               compatible = (snd (msize m)) == (fst (msize n))
  125               row i = [item i j | j <- [0..length(head n)-1]]
  126               item i j = vdot (m!!i) (map (!!j) n)
  127               errmsg = "mmult in densematrix: incompatible matrices " ++
  128                   (show (msize m)) ++ "*" ++ (show (msize n))
  129 
  130 
  131 
  132 
  133 
  134 matvecmult m v = map (vdot v) m
  135 
  136 
  137 
  138 
  139 vmmult v m = res
  140      where
  141         res' = mmult [v] m
  142         res = concat res'
  143 
  144 
  145 svmult s v = map (*s) v
  146 
  147 
  148 vdot u v = sum (map2 (*) u v)
  149 
  150 norm v = vdot v v
  151 
  152 vneg = map negate
  153 
  154 vouter u v = [ map (k*) u | k <- v]
  155 
  156 
  157 madd a b = map2 vadd a b
  158 
  159 vadd u v = map2 (+) u v
  160 
  161 msub a b = map2 vsub a b
  162 
  163 vsub u v = map2 (-) u v
  164 
  165 vsize = (length)
  166 
  167 mneg rs = map (map negate) rs
  168 
  169 mxpose = transpose
  170 
  171 mident n
  172    = [ row k | k <- [1..n] ]
  173      where
  174         row i = (rep (i-1) 0) ++ [1] ++ (rep (n-i) 0)
  175 
  176 
  177 msize m = (length m , length(head m))
  178 
  179 mkmat = id
  180 mkrmat = id
  181 mkcmat = transpose
  182 matlist = id
  183 
  184 
  185 mkvec =  id
  186 mkrvec = id
  187 mkcvec = id
  188 veclist = id
  189 
  190 
  191 
  192 swaprow a b m = swapitems a b m
  193 swapcol a b m = map (swapitems a b) m
  194 
  195 swapitems :: Int -> Int -> [a] -> [a]
  196 swapitems a b xs
  197    = if a< b then toa ++ [xs!!b] ++ atob ++ [xs!!a] ++ pastb
  198      else if a > b then swapitems b a xs else xs
  199      where
  200         toa = take a xs
  201         atob = take (b-a-1) (drop (a+1) xs)
  202         pastb = drop (b+1) xs
  203 
  204 
  205 droprow = tail
  206 
  207 
  208 getrow i m =  m!!i
  209 getcol i m = map (!!i) m
  210 
  211 
  212 subscript m (i,j) = m !! i !! j
  213 
  214 vsubscript v n = v!!n
  215 
  216 vecpart start intber
  217    = (take intber) . (drop start)
  218 
  219 update m (i,j) val
  220    = (take i m) ++ [(f (m!!i))] ++ (drop (i+1) m)
  221      where
  222         f xs = (take j xs) ++ [val] ++ (drop (j+1) xs)
  223 
  224 update2 m (i1,i2,j) (val1,val2)
  225    = (take i1 m) ++ [(f1 (m!!i1))] ++ [(f2 (m!!i2))] ++ (drop (i1+2) m)
  226      where
  227         f1 xs = (take j xs) ++ [val1] ++ (drop (j+1) xs)
  228         f2 xs = (take j xs) ++ [val2] ++ (drop (j+1) xs)
  229 
  230 
  231 vupdate v i val
  232    = (take i v) ++ [val] ++ (drop (i+1) v)
  233 
  234 
  235 vhd = head
  236 vtl = tail
  237 
  238 mergevectors = concat
  239 
  240 
  241 showmatrix m = "\n" ++ showmat m ++ "\n"
  242 
  243 showmat :: [Vector] -> [Char]
  244 showmat m
  245    = concat (map show_row m)
  246      where
  247         show_row v = (displayvector v) ++ "\n"
  248 
  249 
  250 displayvector
  251    =  concat                 .
  252       (map (rjustify 13)) .
  253       (map show)
  254 
  255 
  256 
  257 
  258 
  259 
  260 
  261 
  262 
  263 
  264 
  265 
  266 
  267 
  268 
  269 
  270 
  271 
  272 
  273 
  274 
  275 
  276 
  277 
  278 
  279 
  280 
  281 
  282 
  283 
  284 
  285 
  286 
  287 
  288 
  289 
  290 
  291 
  292 
  293 
  294 
  295 
  296 
  297 
  298 
  299 
  300 
  301 
-- Result of lu_decomp. As built there, the fields are, one entry per
-- elimination step: the scaled first column below the pivot (the L part),
-- the first row of the pivoted block (the U part), and the pivot index
-- chosen at that step.
data Lu_factor = Lu_fact [[Float]] [[Float]] [Int]
  303 
  304 
  305 
  306 minverse m
  307    = apply_factor (lu_decomp m) (mident (fst(msize m)))
  308 
  309 
  310 apply_factor factor m
  311    = ((backward uvecs ps) . (forward lvecs ps)) m
  312      where
  313         (Lu_fact lvecs uvecs ps) = factor
  314 
  315 
  316   
  317 forward  [] p b = b
  318 forward (l:ls) (p:ps) b
  319    = y1 : (forward ls ps y')
  320      where
  321         y'  = msub y (vouter y1 l)
  322         (y1:y) = swaprow 0 p b
  323 
  324 
  325 backward [[u]] p y = map (map (/u)) y
  326 backward ((u1:u):us) (p:ps) (y:ys)
  327   = x
  328     where
  329        x = swaprow 0 p (xk:nextx)
  330        xk = map (/u1) (vsub y (vmmult u nextx))
  331        nextx = backward us ps ys
  332 
  333 
  334 
  335 
  336 lu_decomp a
  337    = if (fst(msize a) == 1) then Lu_fact [] [[only_one]] [1]
  338       else Lu_fact (l:ls)  (u:us)  (p:ps)
  339      where
  340         Lu_fact ls us ps  = lu_decomp a11'
  341         (a',p)  = pivot a
  342         u11  = head u
  343         u  = getrow 0 a'
  344         m  = getcol 0 a'
  345         l  = if not(u11 == 0) then tail (map (/ u11) m)
  346              else error ("div by 0 in lu: block is\n  "++(show a))
  347         a11  = map tail (tail a')
  348         a11' = msub a11 (vouter (tail u) l)
  349         only_one  = subscript a (0,0)
  350 
  351 
  352 
  353 
  354 pivot :: Matrix -> (Matrix,Int)
  355 pivot m
  356    = (swaprow 0 p (swapcol 0 p m), p)
  357      where
  358         p = findpivot m
  359 
  360 
  361 findpivot :: Matrix -> Int
  362 findpivot a
  363    = loc_of_max 0 diag
  364      where
  365        diag = [a `subscript` (i,i) |  i <- [0..(fst(msize a))-1] ]
  366        absmax = maxlist (map absfloat diag)
  367        loc_of_max n (x:xs)
  368                 = if (absfloat x) == absmax then n
  369                   else loc_of_max (n+1) xs
  370        
  371 
  372 absfloat :: Float -> Float
  373 absfloat n = if n < 0 then -n
  374              else n
  375 
  376 maxlist :: [Float] -> Float
  377 maxlist xs = foldl1 max xs
  378 
  379 
  380 
  381 
  382 
  383 
  384 
  385 
  386 
  387 
  388