1 {-# OPTIONS_GHC -XModalTypes -ddump-types -XNoMonoPatBinds -XFlexibleContexts #-}
2 module DotProduct
3 where
4 import GHC.HetMet.CodeTypes hiding ((-))
5 import Prelude hiding ( id, (.) )
7 --------------------------------------------------------------------------------
8 -- Dot Product
9 --
10 --  This shows how to build a two-level program one step at a time by
11 --  slowly rearranging it until the brackets can be inserted.
12 --
14 -- a one-level function to compute the dot product of two vectors
15 dotproduct :: [Int] -> [Int] -> Int
16 dotproduct v1 v2 =
17   case v1 of
18     []     -> 0
19     (a:ax) -> case v2 of
20                    []     -> 0
21                    (b:bx) ->
22                        (a*b)+(dotproduct ax bx)
24 -- A slightly modified version of the dot product: note that we
25 -- check for zeroes and ones to avoid multiplying.  In a one-level
26 -- program this yields no advantage, however!
27 dotproduct' :: [Int] -> [Int] -> Int
28 dotproduct' v1 v2 =
29   case v1 of
30     []     -> 0
31     (0:ax) -> case v2 of
32                    []     -> 0
33                    (b:bx) -> (dotproduct' ax bx)
34     (1:ax) -> case v2 of
35                    []     -> 0
36                    (b:bx) -> b+(dotproduct' ax bx)
37     (a:ax) -> case v2 of
38                    []     -> 0
39                    (b:bx) ->
40                        (a*b)+(dotproduct' ax bx)
42 -- A two-level version of the dot product.  Note how we ask for the first
43 -- vector, then produce a program which is optimized for multiplying
44 -- by that particular vector.  If there are zeroes or ones in the
45 -- original vector, we will emit code which is faster than a one-level
46 -- dot product.
48 dotproduct'' :: forall g.
50                 GuestLanguageMult        g Integer =>
51                 GuestIntegerLiteral      g         =>
52                 [Integer] -> <[ [Integer] -> Integer ]>@g
53 dotproduct'' v1 =
54   case v1 of
55     []     -> <[ \v2 -> 0 ]>
56     (0:ax) -> <[ \v2 -> case v2 of
57                           []     -> 0
58                           (b:bx) -> ~~(dotproduct'' ax) bx ]>
59     (1:ax) -> <[ \v2 -> case v2 of
60                           []     -> 0
61                           (b:bx) -> b + ~~(dotproduct'' ax) bx ]>
63     (a:ax) -> <[ \v2 -> case v2 of
64                           []     -> 0
65                           (b:bx) -> ~~(guestIntegerLiteral a) * b + ~~(dotproduct'' ax) bx ]>