794c987c0fae2682cd09f1ba0c40f8f043a86954
[ghc-hetmet.git] / rts / win32 / OSMem.c
1 /* -----------------------------------------------------------------------------
2  *
3  * (c) The University of Glasgow 2006-2007
4  *
5  * OS-specific memory management
6  *
7  * ---------------------------------------------------------------------------*/
8
9 #include "Rts.h"
10 #include "sm/OSMem.h"
11 #include "RtsUtils.h"
12
13 #if HAVE_WINDOWS_H
14 #include <windows.h>
15 #endif
16
17 typedef struct alloc_rec_ {
18     char* base;     /* non-aligned base address, directly from VirtualAlloc */
19     int size;       /* Size in bytes */
20     struct alloc_rec_* next;
21 } alloc_rec;
22
23 typedef struct block_rec_ {
24     char* base;         /* base address, non-MBLOCK-aligned */
25     int size;           /* size in bytes */
26     struct block_rec_* next;
27 } block_rec;
28
29 /* allocs are kept in ascending order, and are the memory regions as
30    returned by the OS as we need to have matching VirtualAlloc and
31    VirtualFree calls. */
32 static alloc_rec* allocs = NULL;
33
34 /* free_blocks are kept in ascending order, and adjacent blocks are merged */
35 static block_rec* free_blocks = NULL;
36
37 void
38 osMemInit(void)
39 {
40     allocs = NULL;
41     free_blocks = NULL;
42 }
43
44 static
45 alloc_rec*
46 allocNew(nat n) {
47     alloc_rec* rec;
48     rec = (alloc_rec*)stgMallocBytes(sizeof(alloc_rec),"getMBlocks: allocNew");
49     rec->size = (n+1)*MBLOCK_SIZE;
50     rec->base =
51         VirtualAlloc(NULL, rec->size, MEM_RESERVE, PAGE_READWRITE);
52     if(rec->base==0) {
53         stgFree((void*)rec);
54         rec=0;
55         if (GetLastError() == ERROR_NOT_ENOUGH_MEMORY) {
56
57             errorBelch("out of memory");
58         } else {
59             sysErrorBelch(
60                 "getMBlocks: VirtualAlloc MEM_RESERVE %d blocks failed", n);
61         }
62     } else {
63         alloc_rec temp;
64         temp.base=0; temp.size=0; temp.next=allocs;
65
66         alloc_rec* it;
67         it=&temp;
68         for(; it->next!=0 && it->next->base<rec->base; it=it->next) ;
69         rec->next=it->next;
70         it->next=rec;
71
72         allocs=temp.next;
73     }
74     return rec;
75 }
76
77 static
78 void
79 insertFree(char* alloc_base, int alloc_size) {
80     block_rec temp;
81     block_rec* it;
82     block_rec* prev;
83
84     temp.base=0; temp.size=0; temp.next=free_blocks;
85     it = free_blocks;
86     prev = &temp;
87     for( ; it!=0 && it->base<alloc_base; prev=it, it=it->next) {}
88
89     if(it!=0 && alloc_base+alloc_size == it->base) {
90         if(prev->base + prev->size == alloc_base) {        /* Merge it, alloc, prev */
91             prev->size += alloc_size + it->size;
92             prev->next = it->next;
93             stgFree(it);
94         } else {                                            /* Merge it, alloc */
95             it->base = alloc_base;
96             it->size += alloc_size;
97         }
98     } else if(prev->base + prev->size == alloc_base) {     /* Merge alloc, prev */
99         prev->size += alloc_size;
100     } else {                                                /* Merge none */
101         block_rec* rec;
102         rec = (block_rec*)stgMallocBytes(sizeof(block_rec),"getMBlocks: insertFree");
103         rec->base=alloc_base;
104         rec->size=alloc_size;
105         rec->next = it;
106         prev->next=rec;
107     }
108     free_blocks=temp.next;
109 }
110
111 static
112 void*
113 findFreeBlocks(nat n) {
114     void* ret=0;
115     block_rec* it;
116     block_rec temp;
117     block_rec* prev;
118
119     int required_size;
120     it=free_blocks;
121     required_size = n*MBLOCK_SIZE;
122     temp.next=free_blocks; temp.base=0; temp.size=0;
123     prev=&temp;
124     /* TODO: Don't just take first block, find smallest sufficient block */
125     for( ; it!=0 && it->size<required_size; prev=it, it=it->next ) {}
126     if(it!=0) {
127         if( (((unsigned long)it->base) & MBLOCK_MASK) == 0) { /* MBlock aligned */
128             ret = (void*)it->base;
129             if(it->size==required_size) {
130                 prev->next=it->next;
131                 stgFree(it);
132             } else {
133                 it->base += required_size;
134                 it->size -=required_size;
135             }
136         } else {
137             char* need_base;
138             block_rec* next;
139             int new_size;
140             need_base = (char*)(((unsigned long)it->base) & ((unsigned long)~MBLOCK_MASK)) + MBLOCK_SIZE;
141             next = (block_rec*)stgMallocBytes(
142                     sizeof(block_rec)
143                     , "getMBlocks: findFreeBlocks: splitting");
144             new_size = need_base - it->base;
145             next->base = need_base +required_size;
146             next->size = it->size - (new_size+required_size);
147             it->size = new_size;
148             next->next = it->next;
149             it->next = next;
150             ret=(void*)need_base;
151         }
152     }
153     free_blocks=temp.next;
154     return ret;
155 }
156
157 /* VirtualAlloc MEM_COMMIT can't cross boundaries of VirtualAlloc MEM_RESERVE,
158    so we might need to do many VirtualAlloc MEM_COMMITs.  We simply walk the
159    (ordered) allocated blocks. */
160 static void
161 commitBlocks(char* base, int size) {
162     alloc_rec* it;
163     it=allocs;
164     for( ; it!=0 && (it->base+it->size)<=base; it=it->next ) {}
165     for( ; it!=0 && size>0; it=it->next ) {
166         int size_delta;
167         void* temp;
168         size_delta = it->size - (base-it->base);
169         if(size_delta>size) size_delta=size;
170         temp = VirtualAlloc(base, size_delta, MEM_COMMIT, PAGE_READWRITE);
171         if(temp==0) {
172             sysErrorBelch("getMBlocks: VirtualAlloc MEM_COMMIT failed");
173             stg_exit(EXIT_FAILURE);
174         }
175         size-=size_delta;
176         base+=size_delta;
177     }
178 }
179
180 void *
181 osGetMBlocks(nat n) {
182     void* ret;
183     ret = findFreeBlocks(n);
184     if(ret==0) {
185         alloc_rec* alloc;
186         alloc = allocNew(n);
187         /* We already belch in allocNew if it fails */
188         if (alloc == 0) {
189             stg_exit(EXIT_FAILURE);
190         } else {
191             insertFree(alloc->base, alloc->size);
192             ret = findFreeBlocks(n);
193         }
194     }
195
196     if(ret!=0) {
197         /* (In)sanity tests */
198         if (((W_)ret & MBLOCK_MASK) != 0) {
199             barf("getMBlocks: misaligned block returned");
200         }
201
202         commitBlocks(ret, MBLOCK_SIZE*n);
203     }
204
205     return ret;
206 }
207
208 void osFreeMBlocks(char *addr, nat n)
209 {
210     alloc_rec *p;
211     lnat nBytes = (lnat)n * MBLOCK_SIZE;
212
213     insertFree(addr, nBytes);
214
215     p = allocs;
216     while ((p != NULL) && (addr >= (p->base + p->size))) {
217         p = p->next;
218     }
219     while (nBytes > 0) {
220         if ((p == NULL) || (p->base > addr)) {
221             errorBelch("Memory to be freed isn't allocated\n");
222             stg_exit(EXIT_FAILURE);
223         }
224         if (p->base + p->size >= addr + nBytes) {
225             if (!VirtualFree(addr, nBytes, MEM_DECOMMIT)) {
226                 sysErrorBelch("osFreeMBlocks: VirtualFree MEM_DECOMMIT failed");
227                 stg_exit(EXIT_FAILURE);
228             }
229             nBytes = 0;
230         }
231         else {
232             lnat bytesToFree = p->base + p->size - addr;
233             if (!VirtualFree(addr, bytesToFree, MEM_DECOMMIT)) {
234                 sysErrorBelch("osFreeMBlocks: VirtualFree MEM_DECOMMIT failed");
235                 stg_exit(EXIT_FAILURE);
236             }
237             addr += bytesToFree;
238             nBytes -= bytesToFree;
239             p = p->next;
240         }
241     }
242 }
243
244 void osReleaseFreeMemory(void)
245 {
246     alloc_rec *prev_a, *a;
247     alloc_rec head_a;
248     block_rec *prev_fb, *fb;
249     block_rec head_fb;
250     char *a_end, *fb_end;
251
252     /* go through allocs and free_blocks in lockstep, looking for allocs
253        that are completely free, and uncommit them */
254
255     head_a.base = 0;
256     head_a.size = 0;
257     head_a.next = allocs;
258     head_fb.base = 0;
259     head_fb.size = 0;
260     head_fb.next = free_blocks;
261     prev_a = &head_a;
262     a = allocs;
263     prev_fb = &head_fb;
264     fb = free_blocks;
265
266     while (a != NULL) {
267         a_end = a->base + a->size;
268         while (fb != NULL && fb->base + fb->size < a_end) {
269             prev_fb = fb;
270             fb = fb->next;
271         }
272
273         fb_end = fb->base + fb->size;
274         if (fb->base <= a->base) {
275             /* The alloc is within the free block. Now we need to know
276                if it sticks out at either end. */
277             if (fb_end == a_end) {
278                 if (fb->base == a->base) {
279                     /* fb and a are identical, so just free fb */
280                     prev_fb->next = fb->next;
281                     stgFree(fb);
282                     fb = prev_fb->next;
283                 }
284                 else {
285                     /* fb begins earlier, so truncate it to not include a */
286                     fb->size = a->base - fb->base;
287                 }
288             }
289             else {
290                 /* fb ends later, so we'll make fb just be the part
291                    after a. First though, if it also starts earlier,
292                    we make a new free block record for the before bit. */
293                 if (fb->base != a->base) {
294                     block_rec *new_fb;
295
296                     new_fb = (block_rec *)stgMallocBytes(sizeof(block_rec),"osReleaseFreeMemory");
297                     new_fb->base = fb->base;
298                     new_fb->size = a->base - fb->base;
299                     new_fb->next = fb;
300                     prev_fb->next = new_fb;
301                 }
302                 fb->size = fb_end - a_end;
303                 fb->base = a_end;
304             }
305             /* Now we can free the alloc */
306             prev_a->next = a->next;
307             if(!VirtualFree((void *)a->base, 0, MEM_RELEASE)) {
308                 sysErrorBelch("freeAllMBlocks: VirtualFree MEM_RELEASE failed");
309                 stg_exit(EXIT_FAILURE);
310             }
311             stgFree(a);
312             a = prev_a->next;
313         }
314         else {
315             /* Otherwise this alloc is not freeable, so go on to the
316                next one */
317             prev_a = a;
318             a = a->next;
319         }
320     }
321
322     allocs = head_a.next;
323     free_blocks = head_fb.next;
324 }
325
326 void
327 osFreeAllMBlocks(void)
328 {
329     {
330         block_rec* next;
331         block_rec* it;
332         next=0;
333         it = free_blocks;
334         for(; it!=0; ) {
335             next = it->next;
336             stgFree(it);
337             it=next;
338         }
339     }
340     {
341         alloc_rec* next;
342         alloc_rec* it;
343         next=0;
344         it=allocs;
345         for(; it!=0; ) {
346             if(!VirtualFree((void*)it->base, 0, MEM_RELEASE)) {
347                 sysErrorBelch("freeAllMBlocks: VirtualFree MEM_RELEASE failed");
348                 stg_exit(EXIT_FAILURE);
349             }
350             next = it->next;
351             stgFree(it);
352             it=next;
353         }
354     }
355 }
356
357 lnat getPageSize (void)
358 {
359     static lnat pagesize = 0;
360     if (pagesize) {
361         return pagesize;
362     } else {
363         SYSTEM_INFO sSysInfo;
364         GetSystemInfo(&sSysInfo);
365         pagesize = sSysInfo.dwPageSize;
366         return pagesize;
367     }
368 }
369
370 void setExecutable (void *p, lnat len, rtsBool exec)
371 {
372     DWORD dwOldProtect = 0;
373     if (VirtualProtect (p, len,
374                         exec ? PAGE_EXECUTE_READWRITE : PAGE_READWRITE,
375                         &dwOldProtect) == 0)
376     {
377         sysErrorBelch("setExecutable: failed to protect 0x%p; old protection: %lu\n",
378                       p, (unsigned long)dwOldProtect);
379         stg_exit(EXIT_FAILURE);
380     }
381 }