projects
/
coq-hetmet.git
/ blobdiff
commit
grep
author
committer
pickaxe
?
search:
re
summary
|
shortlog
|
log
|
commit
|
commitdiff
|
tree
raw
|
inline
| side by side
update to use Control.GArrow instead of GHC.HetMet.GArrow
[coq-hetmet.git]
/
examples
/
GArrowPortShape.hs
diff --git
a/examples/GArrowPortShape.hs
b/examples/GArrowPortShape.hs
index
e746b5f
..
45cbdb2
100644
(file)
--- a/
examples/GArrowPortShape.hs
+++ b/
examples/GArrowPortShape.hs
@@
-19,11
+19,11
@@
-- information for certain nodes (the inference mechanism below adds
-- it on every node).
--
-- information for certain nodes (the inference mechanism below adds
-- it on every node).
--
-module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape)
+module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape, Detect(..), DetectM, freshM)
where
import Prelude hiding ( id, (.), lookup )
import Control.Category
where
import Prelude hiding ( id, (.), lookup )
import Control.Category
-import GHC.HetMet.GArrow
+import Control.GArrow
import Unify
import GArrowSkeleton
import Control.Monad.State
import Unify
import GArrowSkeleton
import Control.Monad.State
@@
-40,6
+40,11
@@
data PortShape a = PortUnit
| PortTensor (PortShape a) (PortShape a)
| PortFree a
| PortTensor (PortShape a) (PortShape a)
| PortFree a
+instance Show a => Show (PortShape a) where
+ show PortUnit = "U"
+ show (PortTensor p1 p2) = "("++show p1++"*"++show p2++")"
+ show (PortFree x) = show x
+
data GArrowPortShape m s a b =
GASPortPassthrough
(PortShape s)
data GArrowPortShape m s a b =
GASPortPassthrough
(PortShape s)
@@
-60,7
+65,11
@@
instance Unifiable UPort where
unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2)
unify' PortUnit PortUnit = emptyUnifier
unify' s1 s2 = error $ "Unifiable UPort got impossible unification case: "
unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2)
unify' PortUnit PortUnit = emptyUnifier
unify' s1 s2 = error $ "Unifiable UPort got impossible unification case: "
--- ++ show s1 ++ " and " ++ show s2
+
+ replace uv prep PortUnit = PortUnit
+ replace uv prep (PortTensor p1 p2) = PortTensor (replace uv prep p1) (replace uv prep p2)
+ replace uv prep (PortFree x) = if x==uv then prep else PortFree x
+
inject = PortFree
project (PortFree v) = Just v
project _ = Nothing
inject = PortFree
project (PortFree v) = Just v
project _ = Nothing
@@
-119,16
+128,19
@@
resolveG u (GASPortShapeWrapper x y g) = GASPortShapeWrapper (getU' u x) (getU'
resolveG' (GAS_loopr f) = GAS_loopr (resolveG' f)
resolveG' (GAS_misc g ) = GAS_misc $ resolveG u g
resolveG' (GAS_loopr f) = GAS_loopr (resolveG' f)
resolveG' (GAS_misc g ) = GAS_misc $ resolveG u g
-detectShape :: GArrowSkeleton m a b -> GArrowPortShape m () a b
+detectShape :: Detect m => GArrowSkeleton m a b -> GArrowPortShape m () a b
detectShape g = runM (detect g)
detectShape g = runM (detect g)
-runM :: DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
+runM :: Detect m => DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
runM f = let s = (emptyUnifier,uvarSupply)
g = evalState f s
(u,_) = execState f s
in resolveG u g
runM f = let s = (emptyUnifier,uvarSupply)
g = evalState f s
(u,_) = execState f s
in resolveG u g
-detect :: GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
+class Detect m where
+ detect' :: m x y -> DetectM (GArrowPortShape m UVar x y)
+
+detect :: Detect m => GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
detect (GAS_id ) = do { x <- freshM ; return $ GASPortShapeWrapper (PortFree x) (PortFree x) GAS_id }
detect (GAS_comp f g) = do { f' <- detect f
; g' <- detect g
detect (GAS_id ) = do { x <- freshM ; return $ GASPortShapeWrapper (PortFree x) (PortFree x) GAS_id }
detect (GAS_comp f g) = do { f' <- detect f
; g' <- detect g
@@
-182,28
+194,25
@@
detect GAS_unassoc = do { x <- freshM; y <- freshM; z <- freshM
}
detect (GAS_const i) = do { x <- freshM; return $ GASPortShapeWrapper PortUnit (PortFree x) (GAS_const i) }
}
detect (GAS_const i) = do { x <- freshM; return $ GASPortShapeWrapper PortUnit (PortFree x) (GAS_const i) }
--- FIXME: I need to fix the occurs check before I can make these different again
detect GAS_merge = do { x <- freshM
detect GAS_merge = do { x <- freshM
- ; y <- freshM
- ; return $ GASPortShapeWrapper (PortTensor (PortFree x) (PortFree y)) (PortFree x) GAS_merge }
+ ; return $ GASPortShapeWrapper (PortTensor (PortFree x) (PortFree x)) (PortFree x) GAS_merge }
+
detect (GAS_loopl f) = do { x <- freshM
; y <- freshM
; z <- freshM
detect (GAS_loopl f) = do { x <- freshM
; y <- freshM
; z <- freshM
- ; z' <- freshM -- remove once I fix the occurs check
; f' <- detect f
; unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x))
; f' <- detect f
; unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x))
- ; unifyM (snd $ shapes f') (PortTensor (PortFree z') (PortFree y))
+ ; unifyM (snd $ shapes f') (PortTensor (PortFree z) (PortFree y))
; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f'))
}
detect (GAS_loopr f) = do { x <- freshM
; y <- freshM
; z <- freshM
; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f'))
}
detect (GAS_loopr f) = do { x <- freshM
; y <- freshM
; z <- freshM
- ; z' <- freshM -- remove once I fix the occurs check
; f' <- detect f
; unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z))
; f' <- detect f
; unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z))
- ; unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z'))
+ ; unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z))
; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f'))
}
; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f'))
}
-detect (GAS_misc f) = error "GAS_misc: not implemented"
+detect (GAS_misc f) = detect' f