Fix Windows memory freeing: add a check for fb == NULL; fixes trac #4506
[ghc-hetmet.git] / rts / win32 / OSMem.c
1 /* -----------------------------------------------------------------------------
2  *
3  * (c) The University of Glasgow 2006-2007
4  *
5  * OS-specific memory management
6  *
7  * ---------------------------------------------------------------------------*/
8
9 #include "Rts.h"
10 #include "sm/OSMem.h"
11 #include "RtsUtils.h"
12
13 #if HAVE_WINDOWS_H
14 #include <windows.h>
15 #endif
16
17 typedef struct alloc_rec_ {
18     char* base;     /* non-aligned base address, directly from VirtualAlloc */
19     int size;       /* Size in bytes */
20     struct alloc_rec_* next;
21 } alloc_rec;
22
23 typedef struct block_rec_ {
24     char* base;         /* base address, non-MBLOCK-aligned */
25     int size;           /* size in bytes */
26     struct block_rec_* next;
27 } block_rec;
28
29 /* allocs are kept in ascending order, and are the memory regions as
30    returned by the OS as we need to have matching VirtualAlloc and
31    VirtualFree calls. */
32 static alloc_rec* allocs = NULL;
33
34 /* free_blocks are kept in ascending order, and adjacent blocks are merged */
35 static block_rec* free_blocks = NULL;
36
37 void
38 osMemInit(void)
39 {
40     allocs = NULL;
41     free_blocks = NULL;
42 }
43
44 static
45 alloc_rec*
46 allocNew(nat n) {
47     alloc_rec* rec;
48     rec = (alloc_rec*)stgMallocBytes(sizeof(alloc_rec),"getMBlocks: allocNew");
49     rec->size = (n+1)*MBLOCK_SIZE;
50     rec->base =
51         VirtualAlloc(NULL, rec->size, MEM_RESERVE, PAGE_READWRITE);
52     if(rec->base==0) {
53         stgFree((void*)rec);
54         rec=0;
55         if (GetLastError() == ERROR_NOT_ENOUGH_MEMORY) {
56
57             errorBelch("out of memory");
58         } else {
59             sysErrorBelch(
60                 "getMBlocks: VirtualAlloc MEM_RESERVE %d blocks failed", n);
61         }
62     } else {
63         alloc_rec temp;
64         temp.base=0; temp.size=0; temp.next=allocs;
65
66         alloc_rec* it;
67         it=&temp;
68         for(; it->next!=0 && it->next->base<rec->base; it=it->next) ;
69         rec->next=it->next;
70         it->next=rec;
71
72         allocs=temp.next;
73     }
74     return rec;
75 }
76
77 static
78 void
79 insertFree(char* alloc_base, int alloc_size) {
80     block_rec temp;
81     block_rec* it;
82     block_rec* prev;
83
84     temp.base=0; temp.size=0; temp.next=free_blocks;
85     it = free_blocks;
86     prev = &temp;
87     for( ; it!=0 && it->base<alloc_base; prev=it, it=it->next) {}
88
89     if(it!=0 && alloc_base+alloc_size == it->base) {
90         if(prev->base + prev->size == alloc_base) {        /* Merge it, alloc, prev */
91             prev->size += alloc_size + it->size;
92             prev->next = it->next;
93             stgFree(it);
94         } else {                                            /* Merge it, alloc */
95             it->base = alloc_base;
96             it->size += alloc_size;
97         }
98     } else if(prev->base + prev->size == alloc_base) {     /* Merge alloc, prev */
99         prev->size += alloc_size;
100     } else {                                                /* Merge none */
101         block_rec* rec;
102         rec = (block_rec*)stgMallocBytes(sizeof(block_rec),"getMBlocks: insertFree");
103         rec->base=alloc_base;
104         rec->size=alloc_size;
105         rec->next = it;
106         prev->next=rec;
107     }
108     free_blocks=temp.next;
109 }
110
111 static
112 void*
113 findFreeBlocks(nat n) {
114     void* ret=0;
115     block_rec* it;
116     block_rec temp;
117     block_rec* prev;
118
119     int required_size;
120     it=free_blocks;
121     required_size = n*MBLOCK_SIZE;
122     temp.next=free_blocks; temp.base=0; temp.size=0;
123     prev=&temp;
124     /* TODO: Don't just take first block, find smallest sufficient block */
125     for( ; it!=0 && it->size<required_size; prev=it, it=it->next ) {}
126     if(it!=0) {
127         if( (((unsigned long)it->base) & MBLOCK_MASK) == 0) { /* MBlock aligned */
128             ret = (void*)it->base;
129             if(it->size==required_size) {
130                 prev->next=it->next;
131                 stgFree(it);
132             } else {
133                 it->base += required_size;
134                 it->size -=required_size;
135             }
136         } else {
137             char* need_base;
138             block_rec* next;
139             int new_size;
140             need_base = (char*)(((unsigned long)it->base) & ((unsigned long)~MBLOCK_MASK)) + MBLOCK_SIZE;
141             next = (block_rec*)stgMallocBytes(
142                     sizeof(block_rec)
143                     , "getMBlocks: findFreeBlocks: splitting");
144             new_size = need_base - it->base;
145             next->base = need_base +required_size;
146             next->size = it->size - (new_size+required_size);
147             it->size = new_size;
148             next->next = it->next;
149             it->next = next;
150             ret=(void*)need_base;
151         }
152     }
153     free_blocks=temp.next;
154     return ret;
155 }
156
157 /* VirtualAlloc MEM_COMMIT can't cross boundaries of VirtualAlloc MEM_RESERVE,
158    so we might need to do many VirtualAlloc MEM_COMMITs.  We simply walk the
159    (ordered) allocated blocks. */
160 static void
161 commitBlocks(char* base, int size) {
162     alloc_rec* it;
163     it=allocs;
164     for( ; it!=0 && (it->base+it->size)<=base; it=it->next ) {}
165     for( ; it!=0 && size>0; it=it->next ) {
166         int size_delta;
167         void* temp;
168         size_delta = it->size - (base-it->base);
169         if(size_delta>size) size_delta=size;
170         temp = VirtualAlloc(base, size_delta, MEM_COMMIT, PAGE_READWRITE);
171         if(temp==0) {
172             sysErrorBelch("getMBlocks: VirtualAlloc MEM_COMMIT failed");
173             stg_exit(EXIT_FAILURE);
174         }
175         size-=size_delta;
176         base+=size_delta;
177     }
178 }
179
180 void *
181 osGetMBlocks(nat n) {
182     void* ret;
183     ret = findFreeBlocks(n);
184     if(ret==0) {
185         alloc_rec* alloc;
186         alloc = allocNew(n);
187         /* We already belch in allocNew if it fails */
188         if (alloc == 0) {
189             stg_exit(EXIT_FAILURE);
190         } else {
191             insertFree(alloc->base, alloc->size);
192             ret = findFreeBlocks(n);
193         }
194     }
195
196     if(ret!=0) {
197         /* (In)sanity tests */
198         if (((W_)ret & MBLOCK_MASK) != 0) {
199             barf("getMBlocks: misaligned block returned");
200         }
201
202         commitBlocks(ret, MBLOCK_SIZE*n);
203     }
204
205     return ret;
206 }
207
208 void osFreeMBlocks(char *addr, nat n)
209 {
210     alloc_rec *p;
211     lnat nBytes = (lnat)n * MBLOCK_SIZE;
212
213     insertFree(addr, nBytes);
214
215     p = allocs;
216     while ((p != NULL) && (addr >= (p->base + p->size))) {
217         p = p->next;
218     }
219     while (nBytes > 0) {
220         if ((p == NULL) || (p->base > addr)) {
221             errorBelch("Memory to be freed isn't allocated\n");
222             stg_exit(EXIT_FAILURE);
223         }
224         if (p->base + p->size >= addr + nBytes) {
225             if (!VirtualFree(addr, nBytes, MEM_DECOMMIT)) {
226                 sysErrorBelch("osFreeMBlocks: VirtualFree MEM_DECOMMIT failed");
227                 stg_exit(EXIT_FAILURE);
228             }
229             nBytes = 0;
230         }
231         else {
232             lnat bytesToFree = p->base + p->size - addr;
233             if (!VirtualFree(addr, bytesToFree, MEM_DECOMMIT)) {
234                 sysErrorBelch("osFreeMBlocks: VirtualFree MEM_DECOMMIT failed");
235                 stg_exit(EXIT_FAILURE);
236             }
237             addr += bytesToFree;
238             nBytes -= bytesToFree;
239             p = p->next;
240         }
241     }
242 }
243
244 void osReleaseFreeMemory(void)
245 {
246     alloc_rec *prev_a, *a;
247     alloc_rec head_a;
248     block_rec *prev_fb, *fb;
249     block_rec head_fb;
250     char *a_end, *fb_end;
251
252     /* go through allocs and free_blocks in lockstep, looking for allocs
253        that are completely free, and uncommit them */
254
255     head_a.base = 0;
256     head_a.size = 0;
257     head_a.next = allocs;
258     head_fb.base = 0;
259     head_fb.size = 0;
260     head_fb.next = free_blocks;
261     prev_a = &head_a;
262     a = allocs;
263     prev_fb = &head_fb;
264     fb = free_blocks;
265
266     while (a != NULL) {
267         a_end = a->base + a->size;
268         /* If a is freeable then there is a single freeblock in fb that
269            covers it. The end of this free block must be >= the end of
270            a, so skip anything in fb that ends before a. */
271         while (fb != NULL && fb->base + fb->size < a_end) {
272             prev_fb = fb;
273             fb = fb->next;
274         }
275
276         if (fb == NULL) {
277             /* If we have nothing left in fb, then neither a nor
278                anything later in the list is freeable, so we are done. */
279             break;
280         }
281         else {
282             fb_end = fb->base + fb->size;
283             /* We have a candidate fb. But does it really cover a? */
284             if (fb->base <= a->base) {
285                 /* Yes, the alloc is within the free block. Now we need
286                    to know if it sticks out at either end. */
287                 if (fb_end == a_end) {
288                     if (fb->base == a->base) {
289                         /* fb and a are identical, so just free fb */
290                         prev_fb->next = fb->next;
291                         stgFree(fb);
292                         fb = prev_fb->next;
293                     }
294                     else {
295                         /* fb begins earlier, so truncate it to not include a */
296                         fb->size = a->base - fb->base;
297                     }
298                 }
299                 else {
300                     /* fb ends later, so we'll make fb just be the part
301                        after a. First though, if it also starts earlier,
302                        we make a new free block record for the before bit. */
303                     if (fb->base != a->base) {
304                         block_rec *new_fb;
305
306                         new_fb = (block_rec *)stgMallocBytes(sizeof(block_rec),"osReleaseFreeMemory");
307                         new_fb->base = fb->base;
308                         new_fb->size = a->base - fb->base;
309                         new_fb->next = fb;
310                         prev_fb->next = new_fb;
311                     }
312                     fb->size = fb_end - a_end;
313                     fb->base = a_end;
314                 }
315                 /* Now we can free the alloc */
316                 prev_a->next = a->next;
317                 if(!VirtualFree((void *)a->base, 0, MEM_RELEASE)) {
318                     sysErrorBelch("freeAllMBlocks: VirtualFree MEM_RELEASE failed");
319                     stg_exit(EXIT_FAILURE);
320                 }
321                 stgFree(a);
322                 a = prev_a->next;
323             }
324             else {
325                 /* Otherwise this alloc is not freeable, so go on to the
326                    next one */
327                 prev_a = a;
328                 a = a->next;
329             }
330         }
331     }
332
333     allocs = head_a.next;
334     free_blocks = head_fb.next;
335 }
336
337 void
338 osFreeAllMBlocks(void)
339 {
340     {
341         block_rec* next;
342         block_rec* it;
343         next=0;
344         it = free_blocks;
345         for(; it!=0; ) {
346             next = it->next;
347             stgFree(it);
348             it=next;
349         }
350     }
351     {
352         alloc_rec* next;
353         alloc_rec* it;
354         next=0;
355         it=allocs;
356         for(; it!=0; ) {
357             if(!VirtualFree((void*)it->base, 0, MEM_RELEASE)) {
358                 sysErrorBelch("freeAllMBlocks: VirtualFree MEM_RELEASE failed");
359                 stg_exit(EXIT_FAILURE);
360             }
361             next = it->next;
362             stgFree(it);
363             it=next;
364         }
365     }
366 }
367
368 lnat getPageSize (void)
369 {
370     static lnat pagesize = 0;
371     if (pagesize) {
372         return pagesize;
373     } else {
374         SYSTEM_INFO sSysInfo;
375         GetSystemInfo(&sSysInfo);
376         pagesize = sSysInfo.dwPageSize;
377         return pagesize;
378     }
379 }
380
381 void setExecutable (void *p, lnat len, rtsBool exec)
382 {
383     DWORD dwOldProtect = 0;
384     if (VirtualProtect (p, len,
385                         exec ? PAGE_EXECUTE_READWRITE : PAGE_READWRITE,
386                         &dwOldProtect) == 0)
387     {
388         sysErrorBelch("setExecutable: failed to protect 0x%p; old protection: %lu\n",
389                       p, (unsigned long)dwOldProtect);
390         stg_exit(EXIT_FAILURE);
391     }
392 }