GArrowTikZ: render more of the structural stuff in gray!50
[coq-hetmet.git] / examples / GArrowPortShape.hs
1 {-# LANGUAGE MultiParamTypeClasses, GADTs, FlexibleContexts, FlexibleInstances, TypeFamilies #-}
2 -----------------------------------------------------------------------------
3 -- |
4 -- Module      :  GArrowPortShape
5 -- Copyright   :  none
6 -- License     :  public domain
7 --
8 -- Maintainer  :  Adam Megacz <megacz@acm.org>
9 -- Stability   :  experimental
10 --
11 -- | We cannot, at run time, query to find out the input and output
12 -- port types of a GArrowSkeleton since Haskell erases types during
13 -- compilation.  Using Data.Typeable is problematic here because
14 -- GAS_comp and GAS_loop{l,r} have an existential type.
15 --
16 -- In spite of this, we can determine the "shape" of the ports --
17 -- which ports are of unit type, and which ports must be tensors.  A
18 -- GArrowPortShape is a GArrowSkeleton along with this
19 -- information for certain nodes (the inference mechanism below adds
20 -- it on every node).
21 --
22 module GArrowPortShape (GArrowPortShape(..), PortShape(..), detectShape)
23 where
24 import Prelude hiding ( id, (.), lookup )
25 import Control.Category
26 import GHC.HetMet.GArrow
27 import Unify
28 import GArrowSkeleton
29 import Control.Monad.State
30
31 --
32 -- | Please keep in mind that the "shapes" computed below are simply the
33 -- least-complicated shapes that could possibly work.  Just because a
34 -- GArrowPortShape has an input port of shape (x,y)
35 -- doesn't mean it couldn't later be used in a context where its input
36 -- port had shape ((a,b),y)!  However, you can be assured that it
37 -- won't be used in a context where the input port has shape ().
38 --
39 data PortShape a = PortUnit
40                  | PortTensor (PortShape a) (PortShape a)
41                  | PortFree a
42
43 instance Show a => Show (PortShape a) where
44  show PortUnit           = "U"
45  show (PortTensor p1 p2) = "("++show p1++"*"++show p2++")"
46  show (PortFree x)       = show x
47
48 data GArrowPortShape m s a b =
49     GASPortPassthrough
50       (PortShape s)
51       (PortShape s)
52       (m a b)
53   | GASPortShapeWrapper
54       (PortShape s)
55       (PortShape s)
56       (GArrowSkeleton (GArrowPortShape m s) a b)
57
58 --
59 -- implementation below; none of this is exported
60 --
61
62 type UPort = PortShape UVar
63
64 instance Unifiable UPort where
65   unify' (PortTensor x1 y1) (PortTensor x2 y2) = mergeU (unify x1 x2) (unify y1 y2)
66   unify' PortUnit PortUnit                     = emptyUnifier
67   unify' s1 s2                                 = error $ "Unifiable UPort got impossible unification case: "
68
69   replace uv prep PortUnit                    = PortUnit
70   replace uv prep (PortTensor p1 p2)          = PortTensor (replace uv prep p1) (replace uv prep p2)
71   replace uv prep (PortFree x)                = if x==uv then prep else PortFree x
72
73   inject                                       = PortFree
74   project (PortFree v)                         = Just v
75   project _                                    = Nothing
76   occurrences (PortFree v)                     = [v]
77   occurrences (PortTensor x y)                 = occurrences x ++ occurrences y
78   occurrences PortUnit                         = []
79
80 -- detection monad
81 type DetectM a = State ((Unifier UPort),[UVar]) a
82
83 shapes :: GArrowPortShape m UVar a b -> (UPort,UPort)
84 shapes (GASPortPassthrough  x y _) = (x,y)
85 shapes (GASPortShapeWrapper x y _) = (x,y)
86
87 unifyM :: UPort -> UPort -> DetectM ()
88 unifyM p1 p2 = do { (u,vars) <- get
89                   ; put (mergeU u $ unify p1 p2 , vars)
90                   }
91
92 freshM :: DetectM UVar
93 freshM = do { (u,(v:vars)) <- get
94             ; put (u,vars)
95             ; return v
96             }
97
98 -- recursive version of getU
99 getU' :: Unifier UPort -> UPort -> PortShape ()
100 getU' u (PortTensor x y)  = PortTensor (getU' u x) (getU' u y)
101 getU' _ PortUnit          = PortUnit
102 getU' u x@(PortFree v)    = case Unify.getU u v  of
103                                      Nothing -> PortFree () -- or x
104                                      Just x' -> getU' u x'
105
106 resolveG :: Unifier UPort -> (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
107 resolveG u (GASPortPassthrough  x y m) = GASPortPassthrough  (getU' u x) (getU' u y) m
108 resolveG u (GASPortShapeWrapper x y g) = GASPortShapeWrapper (getU' u x) (getU' u y) (resolveG' g)
109  where
110   resolveG' :: GArrowSkeleton (GArrowPortShape m UVar)             a b -> 
111                GArrowSkeleton (GArrowPortShape m ())   a b
112   resolveG' (GAS_id           ) = GAS_id
113   resolveG' (GAS_comp      f g) = GAS_comp (resolveG' f) (resolveG' g)
114   resolveG' (GAS_first       f) = GAS_first (resolveG' f)
115   resolveG' (GAS_second      f) = GAS_second (resolveG' f)
116   resolveG' GAS_cancell         = GAS_cancell
117   resolveG' GAS_cancelr         = GAS_cancelr
118   resolveG' GAS_uncancell       = GAS_uncancell
119   resolveG' GAS_uncancelr       = GAS_uncancelr
120   resolveG' GAS_drop            = GAS_drop
121   resolveG' (GAS_const i)       = GAS_const i
122   resolveG' GAS_copy            = GAS_copy
123   resolveG' GAS_merge           = GAS_merge
124   resolveG' GAS_swap            = GAS_swap
125   resolveG' GAS_assoc           = GAS_assoc
126   resolveG' GAS_unassoc         = GAS_unassoc
127   resolveG' (GAS_loopl f)       = GAS_loopl (resolveG' f)
128   resolveG' (GAS_loopr f)       = GAS_loopr (resolveG' f)
129   resolveG' (GAS_misc g )       = GAS_misc $ resolveG u g
130
131 detectShape :: GArrowSkeleton m a b -> GArrowPortShape m () a b
132 detectShape g = runM (detect g)
133
134 runM :: DetectM (GArrowPortShape m UVar a b) -> GArrowPortShape m () a b
135 runM f = let s     = (emptyUnifier,uvarSupply)
136              g     = evalState f s
137              (u,_) = execState f s
138           in resolveG u g
139
140 detect :: GArrowSkeleton m a b -> DetectM (GArrowPortShape m UVar a b)
141 detect (GAS_id      ) = do { x <- freshM ; return $ GASPortShapeWrapper (PortFree x) (PortFree x) GAS_id }
142 detect (GAS_comp f g) = do { f' <- detect f
143                            ; g' <- detect g
144                            ; unifyM (snd $ shapes f') (fst $ shapes g')
145                            ; return $ GASPortShapeWrapper (fst $ shapes f') (snd $ shapes g') (GAS_comp (GAS_misc f') (GAS_misc g'))
146                            }
147 detect (GAS_first  f) = do { x <- freshM
148                            ; f' <- detect f
149                            ; return $ GASPortShapeWrapper
150                                         (PortTensor (fst $ shapes f') (PortFree x))
151                                         (PortTensor (snd $ shapes f') (PortFree x))
152                                         (GAS_first (GAS_misc f'))
153                            }
154 detect (GAS_second f) = do { x <- freshM
155                            ; f' <- detect f
156                            ; return $ GASPortShapeWrapper
157                                         (PortTensor (PortFree x) (fst $ shapes f'))
158                                         (PortTensor (PortFree x) (snd $ shapes f'))
159                                         (GAS_second (GAS_misc f'))
160                            }
161 detect GAS_cancell    = do { x <- freshM; return$GASPortShapeWrapper (PortTensor PortUnit (PortFree x)) (PortFree x) GAS_cancell }
162 detect GAS_cancelr    = do { x <- freshM; return$GASPortShapeWrapper (PortTensor (PortFree x) PortUnit) (PortFree x) GAS_cancelr }
163 detect GAS_uncancell  = do { x <- freshM; return$GASPortShapeWrapper (PortFree x) (PortTensor PortUnit (PortFree x)) GAS_uncancell }
164 detect GAS_uncancelr  = do { x <- freshM; return$GASPortShapeWrapper (PortFree x) (PortTensor (PortFree x) PortUnit) GAS_uncancelr }
165 detect GAS_drop       = do { x <- freshM; return$GASPortShapeWrapper (PortFree x) PortUnit GAS_drop }
166 detect GAS_copy       = do { x <- freshM
167                            ; return $ GASPortShapeWrapper (PortFree x) (PortTensor (PortFree x) (PortFree x)) GAS_copy }
168 detect GAS_swap       = do { x <- freshM
169                            ; y <- freshM
170                            ; let x' = PortFree x
171                            ; let y' = PortFree y
172                            ; return $ GASPortShapeWrapper (PortTensor x' y') (PortTensor y' x') GAS_swap
173                            }
174 detect GAS_assoc      = do { x <- freshM; y <- freshM; z <- freshM
175                            ; let x' = PortFree x
176                            ; let y' = PortFree y
177                            ; let z' = PortFree z
178                            ; return $ GASPortShapeWrapper
179                                         (PortTensor (PortTensor x' y') z')
180                                         (PortTensor x' (PortTensor y' z'))
181                                         GAS_assoc
182                            }
183 detect GAS_unassoc    = do { x <- freshM; y <- freshM; z <- freshM
184                            ; let x' = PortFree x
185                            ; let y' = PortFree y
186                            ; let z' = PortFree z
187                            ; return $ GASPortShapeWrapper
188                                         (PortTensor x' (PortTensor y' z'))
189                                         (PortTensor (PortTensor x' y') z')
190                                         GAS_unassoc
191                            }
192 detect (GAS_const i)  = do { x <- freshM; return $ GASPortShapeWrapper PortUnit (PortFree x) (GAS_const i) }
193
194 detect GAS_merge      = do { x <- freshM
195                            ; return $ GASPortShapeWrapper (PortTensor (PortFree x) (PortFree x)) (PortFree x) GAS_merge }
196
197 detect (GAS_loopl f)  = do { x <- freshM
198                            ; y <- freshM
199                            ; z <- freshM
200                            ; f' <- detect f
201                            ; unifyM (fst $ shapes f') (PortTensor (PortFree z) (PortFree x))
202                            ; unifyM (snd $ shapes f') (PortTensor (PortFree z) (PortFree y))
203                            ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopl (GAS_misc f'))
204                            }
205 detect (GAS_loopr f)  = do { x <- freshM
206                            ; y <- freshM
207                            ; z <- freshM
208                            ; f' <- detect f
209                            ; unifyM (fst $ shapes f') (PortTensor (PortFree x) (PortFree z))
210                            ; unifyM (snd $ shapes f') (PortTensor (PortFree y) (PortFree z))
211                            ; return $ GASPortShapeWrapper (PortFree x) (PortFree y) (GAS_loopr (GAS_misc f'))
212                            }
213
214 detect (GAS_misc f)   = error "GAS_misc: not implemented"
215