wibble in setExecutable
[ghc-hetmet.git] / rts / win32 / OSMem.c
1 /* -----------------------------------------------------------------------------
2  *
3  * (c) The University of Glasgow 2006-2007
4  *
5  * OS-specific memory management
6  *
7  * ---------------------------------------------------------------------------*/
8
9 #include "Rts.h"
10 #include "OSMem.h"
11 #include "RtsUtils.h"
12 #include "RtsMessages.h"
13
14 #if HAVE_WINDOWS_H
15 #include <windows.h>
16 #endif
17
18 /* alloc_rec keeps the info we need to have matching VirtualAlloc and
19    VirtualFree calls.
20 */
21 typedef struct alloc_rec_ {
22     char* base;     /* non-aligned base address, directly from VirtualAlloc */
23     int size;       /* Size in bytes */
24     struct alloc_rec_* next;
25 } alloc_rec;
26
27 typedef struct block_rec_ {
28     char* base;         /* base address, non-MBLOCK-aligned */
29     int size;           /* size in bytes */
30     struct block_rec_* next;
31 } block_rec;
32
33 static alloc_rec* allocs = NULL;
34 static block_rec* free_blocks = NULL;
35
36 void
37 osMemInit(void)
38 {
39     allocs = NULL;
40     free_blocks = NULL;
41 }
42
43 static
44 alloc_rec*
45 allocNew(nat n) {
46     alloc_rec* rec;
47     rec = (alloc_rec*)stgMallocBytes(sizeof(alloc_rec),"getMBlocks: allocNew");
48     rec->size = (n+1)*MBLOCK_SIZE;
49     rec->base = 
50         VirtualAlloc(NULL, rec->size, MEM_RESERVE, PAGE_READWRITE);
51     if(rec->base==0) {
52         stgFree((void*)rec);
53         rec=0;
54         if (GetLastError() == ERROR_NOT_ENOUGH_MEMORY) {
55
56             errorBelch("out of memory");
57         } else {
58             sysErrorBelch(
59                 "getMBlocks: VirtualAlloc MEM_RESERVE %d blocks failed", n);
60         }
61     } else {
62                 alloc_rec temp;
63                 temp.base=0; temp.size=0; temp.next=allocs;
64
65         alloc_rec* it;
66         it=&temp;
67         for(; it->next!=0 && it->next->base<rec->base; it=it->next) ;
68         rec->next=it->next;
69         it->next=rec;
70
71                 allocs=temp.next;
72     }
73     return rec;
74 }
75
76 static
77 void
78 insertFree(char* alloc_base, int alloc_size) {
79     block_rec temp;
80     block_rec* it;
81     block_rec* prev;
82
83     temp.base=0; temp.size=0; temp.next=free_blocks;
84     it = free_blocks;
85     prev = &temp;
86     for( ; it!=0 && it->base<alloc_base; prev=it, it=it->next) {}
87
88     if(it!=0 && alloc_base+alloc_size == it->base) {
89         if(prev->base + prev->size == alloc_base) {        /* Merge it, alloc, prev */
90             prev->size += alloc_size + it->size;
91             prev->next = it->next;
92             stgFree(it);
93         } else {                                            /* Merge it, alloc */
94             it->base = alloc_base;
95             it->size += alloc_size;
96         }
97     } else if(prev->base + prev->size == alloc_base) {     /* Merge alloc, prev */
98         prev->size += alloc_size;
99     } else {                                                /* Merge none */
100         block_rec* rec;
101         rec = (block_rec*)stgMallocBytes(sizeof(block_rec),"getMBlocks: insertFree");
102         rec->base=alloc_base;
103         rec->size=alloc_size;
104         rec->next = it;
105         prev->next=rec;
106     }
107     free_blocks=temp.next;
108 }
109
110 static
111 void*
112 findFreeBlocks(nat n) {
113     void* ret=0;
114     block_rec* it;
115     block_rec temp;
116     block_rec* prev;
117
118     int required_size;
119     it=free_blocks;
120     required_size = n*MBLOCK_SIZE;
121     temp.next=free_blocks; temp.base=0; temp.size=0;
122     prev=&temp;
123     /* TODO: Don't just take first block, find smallest sufficient block */
124     for( ; it!=0 && it->size<required_size; prev=it, it=it->next ) {}
125     if(it!=0) {
126         if( (((unsigned long)it->base) & MBLOCK_MASK) == 0) { /* MBlock aligned */
127             ret = (void*)it->base;
128             if(it->size==required_size) {
129                 prev->next=it->next;
130                 stgFree(it);
131             } else {
132                 it->base += required_size;
133                 it->size -=required_size;
134             }
135         } else {
136             char* need_base;
137             block_rec* next;
138             int new_size;
139             need_base = (char*)(((unsigned long)it->base) & ((unsigned long)~MBLOCK_MASK)) + MBLOCK_SIZE;
140             next = (block_rec*)stgMallocBytes(
141                     sizeof(block_rec)
142                     , "getMBlocks: findFreeBlocks: splitting");
143             new_size = need_base - it->base;
144             next->base = need_base +required_size;
145             next->size = it->size - (new_size+required_size);
146             it->size = new_size;
147             next->next = it->next;
148             it->next = next;
149             ret=(void*)need_base;
150         }
151     }
152     free_blocks=temp.next;
153     return ret;
154 }
155
156 /* VirtualAlloc MEM_COMMIT can't cross boundaries of VirtualAlloc MEM_RESERVE,
157    so we might need to do many VirtualAlloc MEM_COMMITs.  We simply walk the
158    (ordered) allocated blocks. */
159 static void
160 commitBlocks(char* base, int size) {
161     alloc_rec* it;
162     it=allocs;
163     for( ; it!=0 && (it->base+it->size)<=base; it=it->next ) {}
164     for( ; it!=0 && size>0; it=it->next ) {
165         int size_delta;
166         void* temp;
167         size_delta = it->size - (base-it->base);
168         if(size_delta>size) size_delta=size;
169         temp = VirtualAlloc(base, size_delta, MEM_COMMIT, PAGE_READWRITE);
170         if(temp==0) {
171             sysErrorBelch("getMBlocks: VirtualAlloc MEM_COMMIT failed");
172             stg_exit(EXIT_FAILURE);
173         }
174         size-=size_delta;
175         base+=size_delta;
176     }
177 }
178
179 void *
180 osGetMBlocks(nat n) {
181     void* ret;
182     ret = findFreeBlocks(n);
183     if(ret==0) {
184         alloc_rec* alloc;
185         alloc = allocNew(n);
186         /* We already belch in allocNew if it fails */
187         if (alloc == 0) {
188             stg_exit(EXIT_FAILURE);
189         } else {
190             insertFree(alloc->base, alloc->size);
191             ret = findFreeBlocks(n);
192         }
193     }
194
195     if(ret!=0) {
196         /* (In)sanity tests */
197         if (((W_)ret & MBLOCK_MASK) != 0) {
198             barf("getMBlocks: misaligned block returned");
199         }
200
201         commitBlocks(ret, MBLOCK_SIZE*n);
202     }
203
204     return ret;
205 }
206
207 void
208 osFreeAllMBlocks(void)
209 {
210     {
211         block_rec* next;
212         block_rec* it;
213         next=0;
214         it = free_blocks;
215         for(; it!=0; ) {
216             next = it->next;
217             stgFree(it);
218             it=next;
219         }
220     }
221     {
222         alloc_rec* next;
223         alloc_rec* it;
224         next=0;
225         it=allocs;
226         for(; it!=0; ) {
227             if(!VirtualFree((void*)it->base, 0, MEM_RELEASE)) {
228                 sysErrorBelch("freeAllMBlocks: VirtualFree MEM_RELEASE failed");
229                 stg_exit(EXIT_FAILURE);
230             }
231             next = it->next;
232             stgFree(it);
233             it=next;
234         }
235     }
236 }
237
238 lnat getPageSize (void)
239 {
240     static lnat pagesize = 0;
241     if (pagesize) {
242         return pagesize;
243     } else {
244         SYSTEM_INFO sSysInfo;
245         GetSystemInfo(&sSysInfo);
246         pagesize = sSysInfo.dwPageSize;
247         return pagesize;
248     }
249 }
250
251 void setExecutable (void *p, lnat len, rtsBool exec)
252 {
253     DWORD dwOldProtect = 0;
254     if (VirtualProtect (p, len, 
255                         exec ? PAGE_EXECUTE_READWRITE : PAGE_READWRITE, 
256                         &dwOldProtect) == 0)
257     {
258         sysErrorBelch("setExecutable: failed to protect 0x%p; old protection: %lu\n",
259                       p, (unsigned long)dwOldProtect);
260         stg_exit(EXIT_FAILURE);
261     }
262 }