NaturalDeductionContext: more permutation proofs
[coq-hetmet.git] / examples / Unify.hs
index afe3405..34761ea 100644 (file)
@@ -1,5 +1,11 @@
--- | A very simple unification engine; used by GArrowTikZ
-module Unify(UVar, Unifier, Unifiable(..), mergeU, emptyUnifier, getU, uvarSupply, unify, resolve, occurs)
+-- | A very simple finite-sized-term unification engine; used by GArrowTikZ
+module Unify(UVar, Unifier, Unifiable(..), mergeU, emptyUnifier, getU, uvarSupply, unify, resolve)
+-- 
+-- | Terminology: a value of type @t@ (for which an instance
+-- @Unifiable t@ exists) is "fully resolved" with respect to some
+-- value of type @Unifier t@ if no @UVar@ which occurs in the
+-- @t@-value is a key in the unifier map.
+--
 where
 import Prelude hiding (lookup)
 import Data.Map hiding (map)
@@ -7,8 +13,6 @@ import Data.Tree
 import Data.List (nub)
 import Control.Monad.Error
 
--- TO DO: propagate occurs-check errors through the Unifier instead of using Prelude.error
-
 -- | a unification variable
 newtype UVar = UVar Int
  deriving (Ord, Eq)
@@ -16,12 +20,16 @@ newtype UVar = UVar Int
 instance Show UVar where
  show (UVar v) = "u" ++ show v
 
--- | A unifier is a map from unification /variables/ to unification /values/ of type @t@.
+-- | A unifier is a map from unification /variables/ to unification
+-- /values/ of type @t@.  Invariant: values of the map are always
+-- fully resolved with respect to the map.
 data Unifier t = Unifier (Map UVar t)
+               | UnifierError String
 
--- | Resolves a unification variable according to a Unifier (not recursively).
+-- | Resolves a unification variable according to a Unifier.
 getU :: Unifier t -> UVar -> Maybe t
-getU (Unifier u) v = lookup v u
+getU (Unifier      u) v = lookup v u
+getU (UnifierError e) v = error e
 
 -- | An infinite supply of distinct unification variables.
 uvarSupply :: [UVar]
@@ -34,7 +42,7 @@ emptyUnifier :: Unifier t
 emptyUnifier =  Unifier empty
 
 -- | Types for which we know how to do unification.
-class Unifiable t where
+class Show t => Unifiable t where
 
   -- | Turns a @UVar@ into a @t@
   inject      :: UVar -> t
@@ -44,40 +52,47 @@ class Unifiable t where
 
   -- | Instances must implement this; it is called by @(unify x y)@
   --   only when both @(project x)@ and @(project y)@ are @Nothing@
-  unify'  :: t -> t -> Unifier t
+  unify'      :: t -> t -> Unifier t
 
-  -- | Returns a list of all @UVars@ occurring in @t@; duplicates are okay and resolution should not be performed.
+  -- | Returns a list of all @UVars@ occurring in @t@
   occurrences :: t -> [UVar]
 
--- | Returns a list of all UVars occurring anywhere in t and any UVars which
---   occur in values unified therewith.
-resolve :: Unifiable t => Unifier t -> UVar -> [UVar]
-resolve (Unifier u) v | member v u = v:(concatMap (resolve (Unifier u)) $ occurrences $ u ! v)
-                      | otherwise  = [v]
+  -- | @(replace vrep trep t)@ returns a copy of @t@ in which all occurrences of @vrep@ have been replaced by @trep@
+  replace     :: UVar -> t -> t -> t
 
--- | The occurs check.
-occurs :: Unifiable t => Unifier t -> UVar -> t -> Bool
-occurs u v x = elem v $ concatMap (resolve u) (occurrences x)
+-- | Returns a copy of the @t@ argument in which all @UVar@
+-- occurrences have been replaced by fully-resolved @t@ values.
+resolve :: Unifiable t => Unifier t -> t -> t
+resolve (UnifierError e) _ = error e
+resolve (Unifier m) t      = resolve' (toList m) t
+ where
+  resolve' []            t                         = t
+  resolve' ((uv,v):rest) t | Just uvt <- project t = if uvt == uv
+                                                     then v        -- we got this out of the unifier, so it must be fully resolved
+                                                     else resolve' rest t
+                           | otherwise             = resolve' rest (replace uv v t)
 
 -- | Given two unifiables, find their most general unifier.
 unify :: Unifiable t => t -> t -> Unifier t
-unify v1 v2 | (Just v1') <- project v1, (Just v2') <- project v2, v1'==v2'                   = emptyUnifier
-unify v1 v2 | (Just v1') <- project v1 = if  occurs emptyUnifier v1' v2
-                                         then error "occurs check failed"
-                                         else Unifier $ insert v1' v2 empty
-unify v1 v2 | (Just v2') <- project v2 = unify v2 v1
-unify v1 v2 |  _         <- project v1,  _         <- project v2                             = unify' v1 v2
+unify v1 v2 | (Just v1') <- project v1, (Just v2') <- project v2, v1'==v2'  = emptyUnifier
+unify v1 v2 | (Just v1') <- project v1                                      = if  elem v1' (occurrences v2)
+                                                                              then UnifierError "occurs check failed in Unify.unify"
+                                                                              else Unifier $ insert v1' v2 empty
+unify v1 v2 | (Just v2') <- project v2                                      = unify v2 v1
+unify v1 v2 |  _         <- project v1,  _         <- project v2            = unify' v1 v2
 
 -- | Merge two unifiers into a single unifier.
 mergeU :: Unifiable t => Unifier t -> Unifier t -> Unifier t
-mergeU (Unifier u) u' = foldr (\(k,v) -> \uacc -> mergeU' uacc k v) u' (toList u)
+mergeU ue@(UnifierError _) _  = ue
+mergeU    (Unifier      u) u' = foldr (\(k,v) -> \uacc -> mergeU' uacc k (resolve uacc v)) u' (toList u)
  where
-  mergeU' u@(Unifier m) v1 v2 | member v1 m    = mergeU u $ unify (m ! v1) v2
-                              | occurs u v1 v2 = error "occurs check failed"
-                              | otherwise      = Unifier $ insert v1 v2 m
+  mergeU' ue@(UnifierError _) _ _                                              = ue
+  mergeU'  u@(Unifier m) v1 v2 | member v1 m                                   = mergeU u $ unify (m ! v1) v2
+                               | Just v2' <- project (resolve u v2), v2' == v1 = u
+                               | elem v1 (occurrences v2)                      = UnifierError "occurs check failed in Unify.mergeU"
+                               | otherwise                                     = Unifier $ insert v1 v2 m
                                                            
 -- | Enumerates the unification variables, sorted by occurs-check.
 sortU :: (Unifiable t, Eq t) => Unifier t -> [UVar]
-sortU u@(Unifier um) = reverse $ nub $ concatMap (resolve u) (keys um)
-
-
+sortU u@(Unifier um)      = reverse $ nub $ concatMap (\k -> occurrences (um!k)) (keys um)
+sortU   (UnifierError ue) = error ue