update to use Control.GArrow instead of GHC.HetMet.GArrow
[coq-hetmet.git] / examples / GArrowPortShape.hs
index 8e5078b..45cbdb2 100644 (file)
 -- information for certain nodes (the inference mechanism below adds
 -- it on every node).
 --
-module GArrowPortShape (GArrowPortShape(..), PortShape(..))
+module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape, Detect(..), DetectM, freshM)
 where
 import Prelude hiding ( id, (.), lookup )
 import Control.Category
-import GHC.HetMet.GArrow
+import Control.GArrow
 import Unify
 import GArrowSkeleton
 import Control.Monad.State
@@ -40,6 +40,11 @@ data PortShape a = PortUnit
                  | PortTensor (PortShape a) (PortShape a)
                  | PortFree a
 
+instance Show a => Show (PortShape a) where
+ show PortUnit           = "U"
+ show (PortTensor p1 p2) = "("++show p1++"*"++show p2++")"
+ show (PortFree x)       = show x
+
 data GArrowPortShape m s a b =
     GASPortPassthrough
       (PortShape s)
@@ -58,7 +63,13 @@ type UPort = PortShape UVar
 
 instance Unifiable UPort where
   unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2)
-  unify' _ _                                   = error "impossible"
+  unify' PortUnit PortUnit                     = emptyUnifier
+  unify' s1 s2                                 = error $ "Unifiable UPort got impossible unification case: "
+
+  replace uv prep PortUnit                    = PortUnit
+  replace uv prep (PortTensor p1 p2)          = PortTensor (replace uv prep p1) (replace uv prep p2)
+  replace uv prep (PortFree x)                = if x==uv then prep else PortFree x
+
   inject                                       = PortFree
   project (PortFree v)                         = Just v
   project _                                    = Nothing
@@ -117,18 +128,24 @@ resolveG u (GASPortShapeWrapper x y g) = GASPortShapeWrapper (getU' u x) (getU'
   resolveG' (GAS_loopr f)       = GAS_loopr (resolveG' f)
   resolveG' (GAS_misc g )       = GAS_misc $ resolveG u g
 
-runM :: DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
+detectShape :: Detect m => GArrowSkeleton m a b -> GArrowPortShape m () a b
+detectShape g = runM (detect g)
+
+runM :: Detect m => DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
 runM f = let s     = (emptyUnifier,uvarSupply)
              g     = evalState f s
              (u,_) = execState f s
           in resolveG u g
 
-detect :: GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
+class Detect m where
+  detect' :: m x y -> DetectM (GArrowPortShape m UVar x y)
+
+detect :: Detect m => GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
 detect (GAS_id      ) = do { x <- freshM ; return $ GASPortShapeWrapper (PortFree x) (PortFree x) GAS_id }
-detect (GAS_comp g f) = do { f' <- detect f
+detect (GAS_comp f g) = do { f' <- detect f
                            ; g' <- detect g
                            ; unifyM (snd $ shapes f') (fst $ shapes g')
-                           ; return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc g') (GAS_misc f'))
+                           ; return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc f') (GAS_misc g'))
                            }
 detect (GAS_first  f) = do { x <- freshM
                            ; f' <- detect f
@@ -176,10 +193,26 @@ detect GAS_unassoc    = do { x <- freshM; y <- freshM; z <- freshM
                                         GAS_unassoc
                            }
 detect (GAS_const i)  = do { x <- freshM; return $ GASPortShapeWrapper PortUnit (PortFree x) (GAS_const i) }
+
 detect GAS_merge      = do { x <- freshM
                            ; return $ GASPortShapeWrapper (PortTensor (PortFree x) (PortFree x)) (PortFree x) GAS_merge }
-detect (GAS_loopl f)  = error "not implemented"
-detect (GAS_loopr f)  = error "not implemented"
 
-detect (GAS_misc f)   = error "not implemented"
+detect (GAS_loopl f)  = do { x <- freshM
+                           ; y <- freshM
+                           ; z <- freshM
+                           ; f' <- detect f
+                           ; unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x))
+                           ; unifyM (snd $ shapes f') (PortTensor (PortFree z) (PortFree y))
+                           ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f'))
+                           }
+detect (GAS_loopr f)  = do { x <- freshM
+                           ; y <- freshM
+                           ; z <- freshM
+                           ; f' <- detect f
+                           ; unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z))
+                           ; unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z))
+                           ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f'))
+                           }
+
+detect (GAS_misc f)   = detect' f