@@ -63,6 +63,9 @@ void            ramdiskrw(struct buf*);
 void*           kalloc(void);
 void            kfree(void *);
 void            kinit(void);
+void            inc_kmemref(void*);  // add one reference to a physical page
+void            dec_kmemref(void*);  // drop one reference (never frees)
+char            get_kmemref(void*);  // read a physical page's reference count
 
 // log.c
 void            initlog(int, struct superblock*);
@@ -147,6 +150,7 @@ void            trapinit(void);
 void            trapinithart(void);
 extern struct spinlock tickslock;
 void            usertrapret(void);
+int             cowhandler(pagetable_t, uint64);  // resolve a COW store fault; 0 on success
 
 // uart.c
 void            uartinit(void);
@@ -21,6 +21,9 @@ struct run {
 struct {
   struct spinlock lock;
   struct run *freelist;
+  // Reference count of every physical page, indexed by pa / PGSIZE.
+  // Protected by lock.
+  char ref_count[PHYSTOP/PGSIZE];
 } kmem;
 
 void
@@ -34,9 +37,16 @@ void
 freerange(void *pa_start, void *pa_end)
 {
   char *p;
+  int i;
   p = (char*)PGROUNDUP((uint64)pa_start);
-  for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+  for(i=0; i<(uint64)p/PGSIZE; i++) {
+    kmem.ref_count[i] = 0;
+  }
+  // kfree() drops one reference, so give each page exactly one to drop.
+  for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE) {
+    kmem.ref_count[(uint64)p/PGSIZE] = 1;
     kfree(p);
+  }
 }
 
 // Free the page of physical memory pointed at by pa,
@@ -51,6 +61,19 @@ kfree(void *pa)
   if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
     panic("kfree");
 
+  // Drop one reference and only free the page when it was the last one.
+  // The test and the decrement must happen under one acquire of kmem.lock:
+  // otherwise two concurrent callers can both see a count of 2, both
+  // decrement it to 0, and the page is never returned to the free list.
+  acquire(&kmem.lock);
+  if(kmem.ref_count[(uint64)pa / PGSIZE] > 1) {
+    kmem.ref_count[(uint64)pa / PGSIZE]--;
+    release(&kmem.lock);
+    return;
+  }
+  kmem.ref_count[(uint64)pa / PGSIZE] = 0;
+  release(&kmem.lock);
+
   // Fill with junk to catch dangling refs.
   memset(pa, 1, PGSIZE);
 
@@ -72,11 +95,44 @@ kalloc(void)
 
   acquire(&kmem.lock);
   r = kmem.freelist;
-  if(r)
+  if(r) {
     kmem.freelist = r->next;
+    kmem.ref_count[(uint64)r / PGSIZE] = 1;  // the caller's reference
+  }
   release(&kmem.lock);
 
   if(r)
     memset((char*)r, 5, PGSIZE); // fill with junk
   return (void*)r;
 }
+
+// Return the current reference count of the physical page at pa.
+char
+get_kmemref(void *pa)
+{
+  char refcnt;
+
+  acquire(&kmem.lock);
+  refcnt = kmem.ref_count[(uint64)pa / PGSIZE];
+  release(&kmem.lock);
+  return refcnt;
+}
+
+// Add one reference to the physical page at pa.
+void
+inc_kmemref(void *pa)
+{
+  acquire(&kmem.lock);
+  kmem.ref_count[(uint64)pa / PGSIZE]++;
+  release(&kmem.lock);
+}
+
+// Drop one reference to the physical page at pa without ever freeing it.
+// Use kfree() instead when this might be the last reference.
+void
+dec_kmemref(void *pa)
+{
+  acquire(&kmem.lock);
+  kmem.ref_count[(uint64)pa / PGSIZE]--;
+  release(&kmem.lock);
+}
@@ -343,6 +343,7 @@ typedef uint64 *pagetable_t; // 512 PTEs
 #define PTE_W (1L << 2)
 #define PTE_X (1L << 3)
 #define PTE_U (1L << 4) // user can access
+#define PTE_COW (1L << 8) // copy-on-write page (bit 8 is an RSW bit, reserved for supervisor software)
 
 // shift a physical address to the right place for a PTE.
 #define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
@@ -65,6 +65,13 @@ usertrap(void)
     intr_on();
 
     syscall();
+  } else if(r_scause() == 15) {
+    // Store/AMO page fault. A write to a copy-on-write page is resolved by
+    // cowhandler(); a fault outside the process's memory, or on any page
+    // that is not COW, kills the process.
+    uint64 va = r_stval();
+    if(va >= p->sz || cowhandler(p->pagetable, va) != 0)
+      p->killed = 1;
   } else if((which_dev = devintr()) != 0){
     // ok
   } else {
@@ -219,3 +226,45 @@ devintr()
   }
 }
 
+// Resolve a store fault on a copy-on-write page mapped at va in pagetable.
+// If the physical page's reference count is 1 the PTE is simply made
+// writable again; otherwise the page is copied into a freshly allocated one,
+// this PTE is pointed at the copy, and this mapping's reference to the
+// shared page is dropped with kfree().
+// Returns 0 on success, or -1 if va is not a valid user COW mapping or no
+// memory is available for the copy.
+int
+cowhandler(pagetable_t pagetable, uint64 va)
+{
+  pte_t *pte;
+  uint64 pa;
+  uint flags;
+  char refcnt;
+  char *mem;
+
+  if(va >= MAXVA)
+    return -1;
+  if((pte = walk(pagetable, va, 0)) == 0)
+    return -1;
+  if((*pte & PTE_V) == 0 || (*pte & PTE_U) == 0 || (*pte & PTE_COW) == 0)
+    return -1;
+
+  pa = PTE2PA(*pte);
+  refcnt = get_kmemref((void*)pa);
+  if(refcnt < 1)
+    return -1;
+
+  if(refcnt == 1) {
+    // Sole reference: no copy needed, just restore write permission.
+    *pte = (*pte & ~PTE_COW) | PTE_W;
+    return 0;
+  }
+
+  if((mem = kalloc()) == 0)
+    return -1;
+  memmove(mem, (char*)pa, PGSIZE);
+  flags = (PTE_FLAGS(*pte) & ~PTE_COW) | PTE_W;
+  *pte = PA2PTE(mem) | flags;
+  kfree((void*)pa);
+  return 0;
+}
@@ -315,22 +315,22 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
   pte_t *pte;
   uint64 pa, i;
   uint flags;
-  char *mem;
 
   for(i = 0; i < sz; i += PGSIZE){
     if((pte = walk(old, i, 0)) == 0)
       panic("uvmcopy: pte should exist");
     if((*pte & PTE_V) == 0)
       panic("uvmcopy: page not present");
+    // Share the physical page with the child instead of copying it.
+    // Writable pages become read-only COW pages (the child inherits flags).
+    if(*pte & PTE_W)
+      *pte = (*pte & ~PTE_W) | PTE_COW;
     pa = PTE2PA(*pte);
     flags = PTE_FLAGS(*pte);
-    if((mem = kalloc()) == 0)
-      goto err;
-    memmove(mem, (char*)pa, PGSIZE);
-    if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
-      kfree(mem);
+    if(mappages(new, i, PGSIZE, pa, flags) != 0){
       goto err;
     }
+    inc_kmemref((void*)pa);  // the child's mapping is a new reference
   }
   return 0;
 
@@ -360,23 +360,27 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
 {
   uint64 n, va0, pa0;
   pte_t *pte;
 
   while(len > 0){
     va0 = PGROUNDDOWN(dstva);
     if(va0 >= MAXVA)
       return -1;
     pte = walk(pagetable, va0, 0);
-    if(pte == 0 || (*pte & PTE_V) == 0 || (*pte & PTE_U) == 0 ||
-       (*pte & PTE_W) == 0)
+    if(pte == 0 || (*pte & PTE_V) == 0 || (*pte & PTE_U) == 0)
+      return -1;
+    // A read-only destination is acceptable only if it is a copy-on-write
+    // page: cowhandler() gives this page table a private writable copy and
+    // fails for any read-only page that is not COW.
+    if((*pte & PTE_W) == 0 && cowhandler(pagetable, va0) != 0)
       return -1;
     pa0 = PTE2PA(*pte);
     n = PGSIZE - (dstva - va0);
     if(n > len)
       n = len;
     memmove((void *)(pa0 + (dstva - va0)), src, n);
 
     len -= n;
     src += n;
     dstva = va0 + PGSIZE;
   }
   return 0;
|