flaw-visual/Flaw/Visual/Geometry/Simplification.hs
{-|
Module: Flaw.Visual.Geometry.Simplification
Description: Geometry simplification algorithm.
License: MIT
-}
{-# LANGUAGE FlexibleContexts, ViewPatterns #-}
module Flaw.Visual.Geometry.Simplification
( simplifyGeometry
) where
import Control.Monad
import Control.Monad.ST
import qualified Data.Map.Strict as M
import Data.STRef
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as VSM
import qualified Data.Vector.Unboxed.Mutable as VUM
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import Flaw.Math
data PairKey = PairKey {-# UNPACK #-} !Word32 {-# UNPACK #-} !Word32 deriving (Eq, Ord)
data Pair = Pair
{ pair_vertex0 :: {-# UNPACK #-} !Word32
, pair_vertex1 :: {-# UNPACK #-} !Word32
, pair_optimalVertex :: {-# UNPACK #-} !Float3
, pair_cost :: {-# UNPACK #-} !Float
}
instance Storable Pair where
sizeOf _ = 24
alignment _ = 4
peek ptr = Pair
<$> peek (castPtr ptr)
<*> peek (castPtr ptr `plusPtr` 4)
<*> peek (castPtr ptr `plusPtr` 8)
<*> peek (castPtr ptr `plusPtr` 20)
poke ptr (Pair v0 v1 ov c) = do
poke (castPtr ptr) v0
poke (castPtr ptr `plusPtr` 4) v1
poke (castPtr ptr `plusPtr` 8) ov
poke (castPtr ptr `plusPtr` 20) c
-- | Simplify geometry.
-- Algorithm: http://cseweb.ucsd.edu/~ravir/190/2016/garland97.pdf
simplifyGeometry :: Int -> VS.Vector Float3 -> VS.Vector Word32 -> (VS.Vector Float3, VS.Vector Word32)
simplifyGeometry iterationsCount vertices indices = runST $ do
let
verticesCount = VG.length vertices
indicesCount = VG.length indices
trianglesCount = indicesCount `quot` 3
-- allocate memory
vertexPositions <- VS.thaw vertices
vertexParents <- VUM.new verticesCount
let
f i = when (i < verticesCount) $ do
VGM.unsafeWrite vertexParents i i
f $ i + 1
in f 0
vertexRanks <- VUM.replicate verticesCount (1 :: Word32)
vertexQuadrics <- VSM.replicate verticesCount 0
pairHeap <- VSM.new indicesCount
pairHeapSizeRef <- newSTRef 0
pairHeapIndexByKeyRef <- newSTRef M.empty
let
calculatePair (PairKey v0 v1) = do
quadric0 <- VGM.unsafeRead vertexQuadrics $ fromIntegral v0
quadric1 <- VGM.unsafeRead vertexQuadrics $ fromIntegral v1
let
quadricSum@(Float4x4
q11 q12 q13 q14
q21 q22 q23 q24
q31 q32 q33 q34
_q41 _q42 _q43 _q44
) = quadric0 + quadric1
q_12_12 = q11 * q22 - q21 * q12
q_12_13 = q11 * q23 - q21 * q13
q_12_14 = q11 * q24 - q21 * q14
q_12_23 = q12 * q23 - q22 * q13
q_12_24 = q12 * q24 - q22 * q14
q_12_34 = q13 * q24 - q23 * q14
q_123_123 = q_12_12 * q33 - q_12_13 * q32 + q_12_23 * q31
q_123_124 = q_12_12 * q34 - q_12_14 * q32 + q_12_24 * q31
q_123_134 = q_12_13 * q34 - q_12_14 * q33 + q_12_34 * q31
q_123_234 = q_12_23 * q34 - q_12_24 * q33 + q_12_34 * q32
-- q_1234_1234 = q_123_123 * q44 - q_123_124 * q43 + q_123_134 * q42 - q_123_234 * q41
aff :: Float3 -> Float4
aff (Float3 x y z) = Float4 x y z 1
cost :: Float4x4 -> Float3 -> Float
cost q (aff -> p) = p `dot` (q `mul` p)
pv0 <- VGM.unsafeRead vertexPositions $ fromIntegral v0
pv1 <- VGM.unsafeRead vertexPositions $ fromIntegral v1
let
(oc, ov) =
-- try global minimum
if abs q_123_123 > 1e-4 then let
v = Float3
(negate $ q_123_234 / q_123_123)
(q_123_134 / q_123_123)
(negate $ q_123_124 / q_123_123)
in (cost quadricSum v, v)
-- try minimum along the edge
else let
v10 = aff pv1 - aff pv0
u = quadricSum `mul` v10
h = u `dot` v10
in if abs h > 1e-4 then let
hinv = 1 / h
v = pv0 * vecFromScalar ((u `dot` aff pv1) * hinv) - pv1 * vecFromScalar ((u `dot` aff pv0) * hinv)
in (cost quadricSum v, v)
-- try minimum from ends and the middle
else let
pm = (pv0 + pv1) * 0.5
in min (cost quadricSum pm, pm) $ min (cost quadricSum pv0, pv0) (cost quadricSum pv1, pv1)
return Pair
{ pair_vertex0 = v0
, pair_vertex1 = v1
, pair_optimalVertex = ov
, pair_cost = oc - cost quadric0 pv0 - cost quadric1 pv1
}
-- calculate initial vertex quadrics
let
f i = when (i < trianglesCount) $ do
let
i1 = fromIntegral $ indices VG.! (i * 3)
i2 = fromIntegral $ indices VG.! (i * 3 + 1)
i3 = fromIntegral $ indices VG.! (i * 3 + 2)
p1 = vertices VG.! i1
p2 = vertices VG.! i2
p3 = vertices VG.! i3
normal@(Float3 a b c) = normalize $ cross (p2 - p1) (p3 - p1)
d = negate $ dot normal p1
aa = a * a
ab = a * b
ac = a * c
ad = a * d
bb = b * b
bc = b * c
bd = b * d
cc = c * c
cd = c * d
dd = d * d
q = Float4x4
aa ab ac ad
ab bb bc bd
ac bc cc cd
ad bd cd dd
VGM.unsafeModify vertexQuadrics (+ q) i1
VGM.unsafeModify vertexQuadrics (+ q) i2
VGM.unsafeModify vertexQuadrics (+ q) i3
f $ i + 1
in f 0
-- add initial pairs
let
f i = when (i < trianglesCount) $ do
let
i1 = indices VG.! (i * 3)
i2 = indices VG.! (i * 3 + 1)
i3 = indices VG.! (i * 3 + 2)
addPair v0 v1 = unless (v0 == v1) $ do
let key = PairKey (min v0 v1) (max v0 v1)
pairHeapIndexByKey <- readSTRef pairHeapIndexByKeyRef
unless (M.member key pairHeapIndexByKey) $ do
pairHeapSize <- readSTRef pairHeapSizeRef
VGM.unsafeWrite pairHeap pairHeapSize =<< calculatePair key
writeSTRef pairHeapSizeRef $! pairHeapSize + 1
writeSTRef pairHeapIndexByKeyRef $! M.insert key pairHeapSize pairHeapIndexByKey
addPair i1 i2
addPair i2 i3
addPair i1 i3
f $ i + 1
in f 0
-- add reverse pairs in index too
do
pairs <- readSTRef pairHeapIndexByKeyRef
writeSTRef pairHeapIndexByKeyRef $! foldr (\(PairKey v0 v1, p) -> M.insert (PairKey v1 v0) p) pairs (M.toList pairs)
-- find open edges and add additional costs for vertices
do
pairHeapIndexByKey <- readSTRef pairHeapIndexByKeyRef
pairHeapSize <- readSTRef pairHeapSizeRef
-- count triangles for every edge
pairTriangleCounts <- VUM.replicate pairHeapSize (0 :: Word32)
let
f i = when (i < trianglesCount) $ do
let
i1 = indices VG.! (i * 3)
i2 = indices VG.! (i * 3 + 1)
i3 = indices VG.! (i * 3 + 2)
accountForPair a b = case M.lookup (PairKey a b) pairHeapIndexByKey of
Just h -> VGM.unsafeModify pairTriangleCounts (+ 1) h
Nothing -> return ()
accountForPair i1 i2
accountForPair i2 i3
accountForPair i1 i3
f $ i + 1
in f 0
-- for every open edge add penalty to its vertices for moving from the edge
let
f i = when (i < pairHeapSize) $ do
triangleCount <- VGM.unsafeRead pairTriangleCounts i
when (triangleCount == 1) $ do
Pair
{ pair_vertex0 = v0
, pair_vertex1 = v1
} <- VGM.unsafeRead pairHeap i
pv0 <- VGM.unsafeRead vertexPositions $ fromIntegral v0
pv1 <- VGM.unsafeRead vertexPositions $ fromIntegral v1
when (norm2 (pv1 - pv0) > 1e-8) $ do
let
r@(Float3 rx ry rz) = normalize $ pv1 - pv0
rxx = rx * rx
rxy = rx * ry
rxz = rx * rz
ryy = ry * ry
ryz = ry * rz
rzz = rz * rz
a = Float3 (rxx - 1) rxy rxz
b = Float3 rxy (ryy - 1) ryz
c = Float3 rxz ryz (rzz - 1)
d = pv0 - r * vecFromScalar (dot r pv0)
aa = dot a a
ab = dot a b
ac = dot a c
ad = dot a d
bb = dot b b
bc = dot b c
bd = dot b d
cc = dot c c
cd = dot c d
dd = dot d d
q = Float4x4
aa ab ac ad
ab bb bc bd
ac bc cc cd
ad bd cd dd
VGM.unsafeModify vertexQuadrics (+ q) $ fromIntegral v0
VGM.unsafeModify vertexQuadrics (+ q) $ fromIntegral v1
f $ i + 1
in f 0
-- vertex disjoint-set functions
let
-- get vertex parent
vertexParent a = do
p <- VGM.unsafeRead vertexParents a
if p == a then return p else do
pp <- vertexParent p
VGM.unsafeWrite vertexParents a pp
return pp
-- union vertices
unionVertices a b = do
pa <- vertexParent a
pb <- vertexParent b
if pa == pb then return pa else do
ra <- VGM.unsafeRead vertexRanks pa
rb <- VGM.unsafeRead vertexRanks pb
if ra > rb then do
VGM.unsafeWrite vertexParents pb pa
VGM.unsafeWrite vertexRanks pa $ ra + rb
return pa
else do
VGM.unsafeWrite vertexParents pa pb
VGM.unsafeWrite vertexRanks pb $ ra + rb
return pb
-- heap functions
-- swap elements
heapSwap a pa@Pair
{ pair_vertex0 = a0
, pair_vertex1 = a1
} b pb@Pair
{ pair_vertex0 = b0
, pair_vertex1 = b1
} = do
VGM.unsafeWrite pairHeap b pa
VGM.unsafeWrite pairHeap a pb
modifySTRef' pairHeapIndexByKeyRef
$ M.insert (PairKey a0 a1) b
. M.insert (PairKey a1 a0) b
. M.insert (PairKey b0 b1) a
. M.insert (PairKey b1 b0) a
-- sift down
heapSiftDown i = do
pairHeapSize <- readSTRef pairHeapSizeRef
p <- VGM.unsafeRead pairHeap i
let
l = i * 2 + 1
r = i * 2 + 2
swap m mp = do
heapSwap i p m mp
heapSiftDown m
lp <- if l < pairHeapSize then VGM.unsafeRead pairHeap l else return p
rp <- if r < pairHeapSize then VGM.unsafeRead pairHeap r else return p
if pair_cost lp < pair_cost p && pair_cost lp <= pair_cost rp then swap l lp
else if pair_cost rp < pair_cost p && pair_cost rp <= pair_cost lp then swap r rp
else return ()
-- sift up
heapSiftUp i = when (i > 0) $ do
p <- VGM.unsafeRead pairHeap i
let m = (i - 1) `quot` 2
mp <- VGM.unsafeRead pairHeap m
when (pair_cost p < pair_cost mp) $ do
heapSwap i p m mp
heapSiftUp m
-- update
heapUpdate i =
if i > 0 then do
p <- VGM.unsafeRead pairHeap i
let m = (i - 1) `quot` 2
mp <- VGM.unsafeRead pairHeap m
case compare (pair_cost p) (pair_cost mp) of
LT -> heapSiftUp i
GT -> heapSiftDown i
EQ -> return ()
else heapSiftDown i
-- delete
heapDelete i = do
pairHeapSize <- readSTRef pairHeapSizeRef
let m = pairHeapSize - 1
writeSTRef pairHeapSizeRef $! m
Pair
{ pair_vertex0 = v0
, pair_vertex1 = v1
} <- VGM.unsafeRead pairHeap i
if i < m then do
mp@Pair
{ pair_vertex0 = mv0
, pair_vertex1 = mv1
} <- VGM.unsafeRead pairHeap m
VGM.unsafeWrite pairHeap i mp
modifySTRef' pairHeapIndexByKeyRef
$ M.insert (PairKey mv0 mv1) i
. M.insert (PairKey mv1 mv0) i
. M.delete (PairKey v0 v1)
. M.delete (PairKey v1 v0)
heapUpdate i
else modifySTRef' pairHeapIndexByKeyRef
$ M.delete (PairKey v0 v1)
. M.delete (PairKey v1 v0)
-- make actual heap from pairs
do
pairHeapSize <- readSTRef pairHeapSizeRef
let
f i = when (i >= 0) $ do
heapSiftDown i
f $ i - 1
in f $ pairHeapSize - 1
-- contraction step
let
contraction = do
-- get minimal cost pair
Pair
{ pair_vertex0 = pv0
, pair_vertex1 = pv1
, pair_optimalVertex = vo
} <- VGM.unsafeRead pairHeap 0
-- union vertices
v <- unionVertices (fromIntegral pv0) (fromIntegral pv1)
-- update position
VGM.unsafeWrite vertexPositions v vo
-- delete minimal cost pair
heapDelete 0
-- figure out what vertex got replaced
let
(v0, v1) = case (fromIntegral v == pv0, fromIntegral v == pv1) of
(True, False) -> (pv0, pv1)
(False, True) -> (pv1, pv0)
_ -> error $ show ("impossible pair", v, pv0, pv1)
-- sum up quadrics
v1q <- VGM.unsafeRead vertexQuadrics (fromIntegral v1)
VGM.unsafeModify vertexQuadrics (+ v1q) (fromIntegral v0)
-- update all pairs with v1 to use v0
pairsWithV1
<- M.takeWhileAntitone (\(PairKey a _) -> a == v1)
. M.dropWhileAntitone (\(PairKey a _) -> a < v1)
<$> readSTRef pairHeapIndexByKeyRef
forM_ (M.keys pairsWithV1) $ \oldKey@(PairKey _v1 b) -> unless (v0 == b) $ do
let newKey = PairKey v0 b
pairHeapIndex <- readSTRef pairHeapIndexByKeyRef
Just h <- return $ M.lookup oldKey pairHeapIndex
-- if contraction made double edge, remove it
if M.member newKey pairHeapIndex then heapDelete h
-- else fix edge to point to v0
else do
writeSTRef pairHeapIndexByKeyRef $
( M.insert newKey h
. M.insert (PairKey b v0) h
. M.delete oldKey
. M.delete (PairKey b v1)
) pairHeapIndex
VGM.unsafeModify pairHeap (\p -> p
{ pair_vertex0 = min b v0
, pair_vertex1 = max b v0
}) h
heapUpdate h
-- update all pairs using v0
pairsWithV0
<- M.takeWhileAntitone (\(PairKey a _) -> a == v0)
. M.dropWhileAntitone (\(PairKey a _) -> a < v0)
<$> readSTRef pairHeapIndexByKeyRef
forM_ (M.keys pairsWithV0) $ \(PairKey _v0 b) -> do
Just h <- M.lookup (PairKey v0 b) <$> readSTRef pairHeapIndexByKeyRef
VGM.unsafeWrite pairHeap h =<< calculatePair (PairKey (min v0 b) (max v0 b))
heapUpdate h
-- perform contractions
replicateM_ iterationsCount contraction
-- filter out indices with degenerate triangles
newIndices <- VSM.new indicesCount
end <- let
f p i =
if i < trianglesCount then do
i1 <- vertexParent $ fromIntegral $ indices VG.! (i * 3)
i2 <- vertexParent $ fromIntegral $ indices VG.! (i * 3 + 1)
i3 <- vertexParent $ fromIntegral $ indices VG.! (i * 3 + 2)
if i1 == i2 || i1 == i3 || i2 == i3 then f p (i + 1) else do
VSM.unsafeWrite newIndices p $ fromIntegral i1
VSM.unsafeWrite newIndices (p + 1) $ fromIntegral i2
VSM.unsafeWrite newIndices (p + 2) $ fromIntegral i3
f (p + 3) (i + 1)
else return p
in f 0 0
resultVertices <- VG.unsafeFreeze vertexPositions
resultIndices <- VG.unsafeFreeze $ VGM.slice 0 end newIndices
return (resultVertices, resultIndices)