hoist orig_ty (ForAllTy tv ty) = case hoist ty ty of
(tvs,theta,tau) -> (tv:tvs,theta,tau)
hoist orig_ty (FunTy arg res)
- | isPredTy arg = case hoist res res of
- (tvs,theta,tau) -> (tvs,arg:theta,tau)
+ | isPredTy arg' = case hoist res res of
+ (tvs,theta,tau) -> (tvs,arg':theta,tau)
| otherwise = case hoist res res of
- (tvs,theta,tau) -> (tvs,theta,mkFunTy arg tau)
+ (tvs,theta,tau) -> (tvs,theta,mkFunTy arg' tau)
+ where
+ arg' = hoistForAllTys arg -- Don't forget to apply hoist recursively
+ -- to the argument type
hoist orig_ty (NoteTy _ ty) = hoist orig_ty ty
hoist orig_ty ty = ([], [], orig_ty)