Fix #3441: detect errors in partial sequences
[ghc-base.git] / GHC / IO / Encoding / UTF8.hs
1 {-# OPTIONS_GHC  -XNoImplicitPrelude -funbox-strict-fields #-}
2 {-# LANGUAGE BangPatterns #-}
3 -----------------------------------------------------------------------------
4 -- |
5 -- Module      :  GHC.IO.Encoding.UTF8
6 -- Copyright   :  (c) The University of Glasgow, 2009
7 -- License     :  see libraries/base/LICENSE
8 -- 
9 -- Maintainer  :  libraries@haskell.org
10 -- Stability   :  internal
11 -- Portability :  non-portable
12 --
13 -- UTF-8 Codec for the IO library
14 --
15 -- Portions Copyright   : (c) Tom Harper 2008-2009,
16 --                        (c) Bryan O'Sullivan 2009,
17 --                        (c) Duncan Coutts 2009
18 --
19 -----------------------------------------------------------------------------
20
21 module GHC.IO.Encoding.UTF8 (
22   utf8,
23   utf8_bom,
24   ) where
25
26 import GHC.Base
27 import GHC.Real
28 import GHC.Num
29 import GHC.IORef
30 -- import GHC.IO
31 import GHC.IO.Exception
32 import GHC.IO.Buffer
33 import GHC.IO.Encoding.Types
34 import GHC.Word
35 import Data.Bits
36 import Data.Maybe
37
38 utf8 :: TextEncoding
39 utf8 = TextEncoding { mkTextDecoder = utf8_DF,
40                       mkTextEncoder = utf8_EF }
41
42 utf8_DF :: IO (TextDecoder ())
43 utf8_DF =
44   return (BufferCodec {
45              encode   = utf8_decode,
46              close    = return (),
47              getState = return (),
48              setState = const $ return ()
49           })
50
51 utf8_EF :: IO (TextEncoder ())
52 utf8_EF =
53   return (BufferCodec {
54              encode   = utf8_encode,
55              close    = return (),
56              getState = return (),
57              setState = const $ return ()
58           })
59
60 utf8_bom :: TextEncoding
61 utf8_bom = TextEncoding { mkTextDecoder = utf8_bom_DF,
62                           mkTextEncoder = utf8_bom_EF }
63
64 utf8_bom_DF :: IO (TextDecoder Bool)
65 utf8_bom_DF = do
66    ref <- newIORef True
67    return (BufferCodec {
68              encode   = utf8_bom_decode ref,
69              close    = return (),
70              getState = readIORef ref,
71              setState = writeIORef ref
72           })
73
74 utf8_bom_EF :: IO (TextEncoder Bool)
75 utf8_bom_EF = do
76    ref <- newIORef True
77    return (BufferCodec {
78              encode   = utf8_bom_encode ref,
79              close    = return (),
80              getState = readIORef ref,
81              setState = writeIORef ref
82           })
83
84 utf8_bom_decode :: IORef Bool -> DecodeBuffer
85 utf8_bom_decode ref
86   input@Buffer{  bufRaw=iraw, bufL=ir, bufR=iw,  bufSize=_  }
87   output
88  = do
89    first <- readIORef ref
90    if not first
91       then utf8_decode input output
92       else do
93        let no_bom = do writeIORef ref False; utf8_decode input output
94        if iw - ir < 1 then return (input,output) else do
95        c0 <- readWord8Buf iraw ir
96        if (c0 /= bom0) then no_bom else do
97        if iw - ir < 2 then return (input,output) else do
98        c1 <- readWord8Buf iraw (ir+1)
99        if (c1 /= bom1) then no_bom else do
100        if iw - ir < 3 then return (input,output) else do
101        c2 <- readWord8Buf iraw (ir+2)
102        if (c2 /= bom2) then no_bom else do
103        -- found a BOM, ignore it and carry on
104        writeIORef ref False
105        utf8_decode input{ bufL = ir + 3 } output
106
107 utf8_bom_encode :: IORef Bool -> EncodeBuffer
108 utf8_bom_encode ref input
109   output@Buffer{ bufRaw=oraw, bufL=_, bufR=ow, bufSize=os }
110  = do
111   b <- readIORef ref
112   if not b then utf8_encode input output
113            else if os - ow < 3
114                   then return (input,output)
115                   else do
116                     writeIORef ref False
117                     writeWord8Buf oraw ow     bom0
118                     writeWord8Buf oraw (ow+1) bom1
119                     writeWord8Buf oraw (ow+2) bom2
120                     utf8_encode input output{ bufR = ow+3 }
121
122 bom0, bom1, bom2 :: Word8
123 bom0 = 0xef
124 bom1 = 0xbb
125 bom2 = 0xbf
126
127 utf8_decode :: DecodeBuffer
128 utf8_decode 
129   input@Buffer{  bufRaw=iraw, bufL=ir0, bufR=iw,  bufSize=_  }
130   output@Buffer{ bufRaw=oraw, bufL=_,   bufR=ow0, bufSize=os }
131  = let 
132        loop !ir !ow
133          | ow >= os || ir >= iw = done ir ow
134          | otherwise = do
135               c0 <- readWord8Buf iraw ir
136               case c0 of
137                 _ | c0 <= 0x7f -> do 
138                            ow' <- writeCharBuf oraw ow (unsafeChr (fromIntegral c0))
139                            loop (ir+1) ow'
140                   | c0 >= 0xc0 && c0 <= 0xdf ->
141                            if iw - ir < 2 then done ir ow else do
142                            c1 <- readWord8Buf iraw (ir+1)
143                            if (c1 < 0x80 || c1 >= 0xc0) then invalid else do
144                            ow' <- writeCharBuf oraw ow (chr2 c0 c1)
145                            loop (ir+2) ow'
146                   | c0 >= 0xe0 && c0 <= 0xef ->
147                       case iw - ir of
148                         1 -> done ir ow
149                         2 -> do -- check for an error even when we don't have
150                                 -- the full sequence yet (#3341)
151                            c1 <- readWord8Buf iraw (ir+1)
152                            if not (validate3 c0 c1 0x80) 
153                               then invalid else done ir ow
154                         _ -> do
155                            c1 <- readWord8Buf iraw (ir+1)
156                            c2 <- readWord8Buf iraw (ir+2)
157                            if not (validate3 c0 c1 c2) then invalid else do
158                            ow' <- writeCharBuf oraw ow (chr3 c0 c1 c2)
159                            loop (ir+3) ow'
160                   | c0 >= 0xf0 ->
161                       case iw - ir of
162                         1 -> done ir ow
163                         2 -> do -- check for an error even when we don't have
164                                 -- the full sequence yet (#3341)
165                            c1 <- readWord8Buf iraw (ir+1)
166                            if not (validate4 c0 c1 0x80 0x80)
167                               then invalid else done ir ow
168                         3 -> do
169                            c1 <- readWord8Buf iraw (ir+1)
170                            c2 <- readWord8Buf iraw (ir+2)
171                            if not (validate4 c0 c1 c2 0x80)
172                               then invalid else done ir ow
173                         _ -> do
174                            c1 <- readWord8Buf iraw (ir+1)
175                            c2 <- readWord8Buf iraw (ir+2)
176                            c3 <- readWord8Buf iraw (ir+3)
177                            if not (validate4 c0 c1 c2 c3) then invalid else do
178                            ow' <- writeCharBuf oraw ow (chr4 c0 c1 c2 c3)
179                            loop (ir+4) ow'
180                   | otherwise ->
181                            invalid
182          where
183            invalid = if ir > ir0 then done ir ow else ioe_decodingError
184
185        -- lambda-lifted, to avoid thunks being built in the inner-loop:
186        done !ir !ow = return (if ir == iw then input{ bufL=0, bufR=0 }
187                                           else input{ bufL=ir },
188                          output{ bufR=ow })
189    in
190    loop ir0 ow0
191
192 ioe_decodingError :: IO a
193 ioe_decodingError = ioException
194      (IOError Nothing InvalidArgument "utf8_decode"
195           "invalid UTF-8 byte sequence" Nothing Nothing)
196
197 utf8_encode :: EncodeBuffer
198 utf8_encode
199   input@Buffer{  bufRaw=iraw, bufL=ir0, bufR=iw,  bufSize=_  }
200   output@Buffer{ bufRaw=oraw, bufL=_,   bufR=ow0, bufSize=os }
201  = let 
202       done !ir !ow = return (if ir == iw then input{ bufL=0, bufR=0 }
203                                          else input{ bufL=ir },
204                              output{ bufR=ow })
205       loop !ir !ow
206         | ow >= os || ir >= iw = done ir ow
207         | otherwise = do
208            (c,ir') <- readCharBuf iraw ir
209            case ord c of
210              x | x <= 0x7F   -> do
211                     writeWord8Buf oraw ow (fromIntegral x)
212                     loop ir' (ow+1)
213                | x <= 0x07FF ->
214                     if os - ow < 2 then done ir ow else do
215                     let (c1,c2) = ord2 c
216                     writeWord8Buf oraw ow     c1
217                     writeWord8Buf oraw (ow+1) c2
218                     loop ir' (ow+2)
219                | x <= 0xFFFF -> do
220                     if os - ow < 3 then done ir ow else do
221                     let (c1,c2,c3) = ord3 c
222                     writeWord8Buf oraw ow     c1
223                     writeWord8Buf oraw (ow+1) c2
224                     writeWord8Buf oraw (ow+2) c3
225                     loop ir' (ow+3)
226                | otherwise -> do
227                     if os - ow < 4 then done ir ow else do
228                     let (c1,c2,c3,c4) = ord4 c
229                     writeWord8Buf oraw ow     c1
230                     writeWord8Buf oraw (ow+1) c2
231                     writeWord8Buf oraw (ow+2) c3
232                     writeWord8Buf oraw (ow+3) c4
233                     loop ir' (ow+4)
234    in
235    loop ir0 ow0
236
237 -- -----------------------------------------------------------------------------
238 -- UTF-8 primitives, lifted from Data.Text.Fusion.Utf8
239   
240 ord2   :: Char -> (Word8,Word8)
241 ord2 c = assert (n >= 0x80 && n <= 0x07ff) (x1,x2)
242     where
243       n  = ord c
244       x1 = fromIntegral $ (n `shiftR` 6) + 0xC0
245       x2 = fromIntegral $ (n .&. 0x3F)   + 0x80
246
247 ord3   :: Char -> (Word8,Word8,Word8)
248 ord3 c = assert (n >= 0x0800 && n <= 0xffff) (x1,x2,x3)
249     where
250       n  = ord c
251       x1 = fromIntegral $ (n `shiftR` 12) + 0xE0
252       x2 = fromIntegral $ ((n `shiftR` 6) .&. 0x3F) + 0x80
253       x3 = fromIntegral $ (n .&. 0x3F) + 0x80
254
255 ord4   :: Char -> (Word8,Word8,Word8,Word8)
256 ord4 c = assert (n >= 0x10000) (x1,x2,x3,x4)
257     where
258       n  = ord c
259       x1 = fromIntegral $ (n `shiftR` 18) + 0xF0
260       x2 = fromIntegral $ ((n `shiftR` 12) .&. 0x3F) + 0x80
261       x3 = fromIntegral $ ((n `shiftR` 6) .&. 0x3F) + 0x80
262       x4 = fromIntegral $ (n .&. 0x3F) + 0x80
263
264 chr2       :: Word8 -> Word8 -> Char
265 chr2 (W8# x1#) (W8# x2#) = C# (chr# (z1# +# z2#))
266     where
267       !y1# = word2Int# x1#
268       !y2# = word2Int# x2#
269       !z1# = uncheckedIShiftL# (y1# -# 0xC0#) 6#
270       !z2# = y2# -# 0x80#
271 {-# INLINE chr2 #-}
272
273 chr3          :: Word8 -> Word8 -> Word8 -> Char
274 chr3 (W8# x1#) (W8# x2#) (W8# x3#) = C# (chr# (z1# +# z2# +# z3#))
275     where
276       !y1# = word2Int# x1#
277       !y2# = word2Int# x2#
278       !y3# = word2Int# x3#
279       !z1# = uncheckedIShiftL# (y1# -# 0xE0#) 12#
280       !z2# = uncheckedIShiftL# (y2# -# 0x80#) 6#
281       !z3# = y3# -# 0x80#
282 {-# INLINE chr3 #-}
283
284 chr4             :: Word8 -> Word8 -> Word8 -> Word8 -> Char
285 chr4 (W8# x1#) (W8# x2#) (W8# x3#) (W8# x4#) =
286     C# (chr# (z1# +# z2# +# z3# +# z4#))
287     where
288       !y1# = word2Int# x1#
289       !y2# = word2Int# x2#
290       !y3# = word2Int# x3#
291       !y4# = word2Int# x4#
292       !z1# = uncheckedIShiftL# (y1# -# 0xF0#) 18#
293       !z2# = uncheckedIShiftL# (y2# -# 0x80#) 12#
294       !z3# = uncheckedIShiftL# (y3# -# 0x80#) 6#
295       !z4# = y4# -# 0x80#
296 {-# INLINE chr4 #-}
297
298 between :: Word8                -- ^ byte to check
299         -> Word8                -- ^ lower bound
300         -> Word8                -- ^ upper bound
301         -> Bool
302 between x y z = x >= y && x <= z
303 {-# INLINE between #-}
304
305 validate3          :: Word8 -> Word8 -> Word8 -> Bool
306 {-# INLINE validate3 #-}
307 validate3 x1 x2 x3 = validate3_1 ||
308                      validate3_2 ||
309                      validate3_3 ||
310                      validate3_4
311   where
312     validate3_1 = (x1 == 0xE0) &&
313                   between x2 0xA0 0xBF &&
314                   between x3 0x80 0xBF
315     validate3_2 = between x1 0xE1 0xEC &&
316                   between x2 0x80 0xBF &&
317                   between x3 0x80 0xBF
318     validate3_3 = x1 == 0xED &&
319                   between x2 0x80 0x9F &&
320                   between x3 0x80 0xBF
321     validate3_4 = between x1 0xEE 0xEF &&
322                   between x2 0x80 0xBF &&
323                   between x3 0x80 0xBF
324
325 validate4             :: Word8 -> Word8 -> Word8 -> Word8 -> Bool
326 {-# INLINE validate4 #-}
327 validate4 x1 x2 x3 x4 = validate4_1 ||
328                         validate4_2 ||
329                         validate4_3
330   where 
331     validate4_1 = x1 == 0xF0 &&
332                   between x2 0x90 0xBF &&
333                   between x3 0x80 0xBF &&
334                   between x4 0x80 0xBF
335     validate4_2 = between x1 0xF1 0xF3 &&
336                   between x2 0x80 0xBF &&
337                   between x3 0x80 0xBF &&
338                   between x4 0x80 0xBF
339     validate4_3 = x1 == 0xF4 &&
340                   between x2 0x80 0x8F &&
341                   between x3 0x80 0xBF &&
342                   between x4 0x80 0xBF