xv6 Lab5: COW
Lapin Gris Lv3

Lab5 的任务是在 xv6 内核中实现 COW fork。叫做 lazy fork 也行。

Implement copy-on-write fork

关键实现

struct kmem

在 struct kmem 结构体中引入引用计数成员。

struct {
struct spinlock lock;
struct run *freelist;
int ref_count[PHYSTOP / PGSIZE];
} kmem;

并引入几个快捷操作函数,

int get_kmemref(void *pa) {
acquire(&kmem.lock);
int refcnt = kmem.ref_count[(uint64)pa / PGSIZE];
release(&kmem.lock);
return refcnt;
}

void inc_kmemref(void *pa) {
acquire(&kmem.lock);
kmem.ref_count[(uint64)pa / PGSIZE]++;
release(&kmem.lock);
}

void dec_kmemref(void *pa) {
acquire(&kmem.lock);
kmem.ref_count[(uint64)pa / PGSIZE]--;
release(&kmem.lock);
}

uvmcopy

uvmcopy 是用来拷贝父进程的地址空间的,将其修改为 COW 形式,将原有的 kalloc 行为延迟的运行时。实现方式为,

  • 移除父子进程地址空间的写权限
  • 增加 COW 标记
  • 增加引用计数(kfree 需要)
// Given a parent process's page table, copy
// its memory into a child's page table.
int uvmcopy(pagetable_t old, pagetable_t new, uint64 sz) {
pte_t *pte;
uint64 pa, i;
uint flags;

for (i = 0; i < sz; i += PGSIZE) {
if ((pte = walk(old, i, 0)) == 0)
panic("uvmcopy: pte should exist");
if ((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");

*pte = (*pte & (~PTE_W)) | PTE_COW;
pa = PTE2PA(*pte);
flags = PTE_FLAGS(*pte);
if (mappages(new, i, PGSIZE, (uint64)pa, flags) != 0) {
goto err;
}
inc_kmemref((void *)pa);
}
return 0;

err:
uvmunmap(new, 0, i / PGSIZE, 1);
return -1;
}

usertrap

sscause 寄存器值位 15 代表 page write fault,此时触发 cowhandler 复制地址空间,处理完成移除页面 COW 位,并将页面标记为可写。

void usertrap(void) {	
// ...
} else if (r_scause() == 15) {
uint64 va0 = r_stval();
if (cowhandler(p->pagetable, va0) != 0) {
p->killed = 1;
}
// ...
}

int cowhandler(pagetable_t pagetable, uint64 va) {
pte_t *pte;
uint64 *mem;

if ((pte = walk(pagetable, va, 0)) == 0) {
return -1;
}

if (!(*pte & PTE_COW) || !(*pte & PTE_U) || !(*pte & PTE_V)) {
return -1;
}

uint64 pa = PTE2PA(*pte);

int refcnt = get_kmemref((void *)pa);
if (refcnt == 1) {
*pte = (*pte & (~PTE_COW)) | PTE_W; // mark it writeable & remove COW bit
return 0;
} else if (refcnt > 1) {
if ((mem = kalloc()) == 0) {
return -1;
}
memmove((void *)mem, (void *)pa, PGSIZE);
kfree((void *)pa); // remove old ref

uint flags = PTE_FLAGS(*pte);
*pte = (PA2PTE(mem) | (flags & ~PTE_COW)) | PTE_W; // make it writeable & remove COW bit
return 0;
}

return -1;
}

copyout

copyout 过程中也属于一个 case。需要进行 cow 处理。

int copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len) {
uint64 n, va0, pa0;
pte_t *pte;
uint64 *mem;

while (len > 0) {
va0 = PGROUNDDOWN(dstva);
pa0 = walkaddr(pagetable, va0);
if (pa0 == 0)
return -1;

if (va0 >= MAXVA || va0 < PGSIZE)
return -1;

if ((pte = walk(pagetable, va0, 0)) == 0) {
return -1;
}

if ((*pte & PTE_V) && (*pte & PTE_COW) && (*pte & PTE_U)) {
int refcnt = get_kmemref((void *)pa0);
if (refcnt == 1) {
*pte = (*pte & (~PTE_COW)) | PTE_W;
} else if (refcnt > 1) {
if ((mem = kalloc()) == 0) {
return -1;
}

memmove((void *)mem, (void *)pa0, PGSIZE);
uint flags = PTE_FLAGS(*pte);
*pte = (PA2PTE(mem) | (flags & ~PTE_COW)) | PTE_W; // writeable & remove COW bit

pa0 = (uint64)mem;
}
}
n = PGSIZE - (dstva - va0);
if (n > len)
n = len;
memmove((void *)(pa0 + (dstva - va0)), src, n);

len -= n;
src += n;
dstva = va0 + PGSIZE;
}
return 0;
}

Show me the code

diff --git a/kernel/defs.h b/kernel/defs.h
index a3c962b..8c4ebb8 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -63,6 +63,9 @@ void ramdiskrw(struct buf*);
void* kalloc(void);
void kfree(void *);
void kinit(void);
+void inc_kmemref(void*);
+void dec_kmemref(void*);
+char get_kmemref(void*);

// log.c
void initlog(int, struct superblock*);
@@ -147,6 +150,7 @@ void trapinit(void);
void trapinithart(void);
extern struct spinlock tickslock;
void usertrapret(void);
+int cowhandler(pagetable_t, uint64);

// uart.c
void uartinit(void);
diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index 0699e7e..4baa07c 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -21,6 +21,7 @@ struct run {
struct {
struct spinlock lock;
struct run *freelist;
+ char ref_count[PHYSTOP/PGSIZE];
} kmem;

void
@@ -34,9 +35,15 @@ void
freerange(void *pa_start, void *pa_end)
{
char *p;
+ int i;
p = (char*)PGROUNDUP((uint64)pa_start);
- for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE)
+ for(i=0; i<(uint64)p/PGSIZE; i++) {
+ kmem.ref_count[i] = 0;
+ }
+ for(; p + PGSIZE <= (char*)pa_end; p += PGSIZE) {
+ kmem.ref_count[(uint64)p/PGSIZE] = 1;
kfree(p);
+ }
}

// Free the page of physical memory pointed at by pa,
@@ -51,6 +58,11 @@ kfree(void *pa)
if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
panic("kfree");

+ if (get_kmemref(pa) > 1) {
+ dec_kmemref(pa);
+ return;
+ }
+
// Fill with junk to catch dangling refs.
memset(pa, 1, PGSIZE);

@@ -72,11 +84,40 @@ kalloc(void)

acquire(&kmem.lock);
r = kmem.freelist;
- if(r)
+ if(r) {
kmem.freelist = r->next;
+ kmem.ref_count[(uint64)r / PGSIZE] = 1;
+ }
release(&kmem.lock);

if(r)
memset((char*)r, 5, PGSIZE); // fill with junk
return (void*)r;
}
+
+char
+get_kmemref(void *pa){
+ acquire(&kmem.lock);
+ char refcnt = kmem.ref_count[(uint64)pa/PGSIZE];
+ release(&kmem.lock);
+ return refcnt;
+}
+
+void
+inc_kmemref(void *pa){
+ acquire(&kmem.lock);
+ kmem.ref_count[(uint64)pa/PGSIZE]++;
+ release(&kmem.lock);
+}
+
+void
+dec_kmemref(void *pa){
+ acquire(&kmem.lock);
+ kmem.ref_count[(uint64)pa/PGSIZE]--;
+ release(&kmem.lock);
+
+}
+
+
+
+
diff --git a/kernel/riscv.h b/kernel/riscv.h
index 20a01db..7eb2bd4 100644
--- a/kernel/riscv.h
+++ b/kernel/riscv.h
@@ -343,6 +343,7 @@ typedef uint64 *pagetable_t; // 512 PTEs
#define PTE_W (1L << 2)
#define PTE_X (1L << 3)
#define PTE_U (1L << 4) // user can access
+#define PTE_COW (1L << 8) // cow page

// shift a physical address to the right place for a PTE.
#define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
diff --git a/kernel/trap.c b/kernel/trap.c
index 512c850..e6be7fd 100644
--- a/kernel/trap.c
+++ b/kernel/trap.c
@@ -65,6 +65,15 @@ usertrap(void)
intr_on();

syscall();
+ } else if(r_scause() == 15) { // 写页面错
+ uint64 va0 = r_stval();
+ if(va0 > p->sz) {
+ p->killed = 1;
+ } else if(cowhandler(p->pagetable,va0) !=0 ) {
+ p->killed = 1;
+ } else if(va0 < PGSIZE) {
+ p->killed = 1;
+ }
} else if((which_dev = devintr()) != 0){
// ok
} else {
@@ -219,3 +228,37 @@ devintr()
}
}

+int
+cowhandler(pagetable_t pagetable, uint64 va)
+{
+ char *mem;
+ if (va >= MAXVA)
+ return -1;
+ pte_t *pte = walk(pagetable, va, 0);
+ if (pte == 0)
+ return -1;
+ // check the PTE
+ if ((*pte & PTE_COW) == 0 || (*pte & PTE_U) == 0 || (*pte & PTE_V) == 0) {
+ return -1;
+ }
+ uint64 pa = PTE2PA(*pte);
+ char refcnt = get_kmemref((void *)pa);
+ if(refcnt == 1) {
+ *pte = (*pte & (~PTE_COW)) | PTE_W;
+ return 0;
+ }
+ if(refcnt > 1) {
+ if ((mem = kalloc()) == 0) {
+ return -1;
+ }
+ // copy old data to new mem
+ memmove((char*)mem, (char*)pa, PGSIZE);
+ kfree((void*)pa);
+ uint flags = PTE_FLAGS(*pte);
+ *pte = (PA2PTE(mem) | flags | PTE_W);
+ *pte &= ~PTE_COW;
+ return 0;
+ }
+ return -1;
+}
+
diff --git a/kernel/vm.c b/kernel/vm.c
index 5c31e87..78551b0 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -5,7 +5,8 @@
#include "riscv.h"
#include "defs.h"
#include "fs.h"
-
+#include "spinlock.h"
+#include "proc.h"
/*
* the kernel's page table.
*/
@@ -160,8 +161,8 @@ mappages(pagetable_t pagetable, uint64 va, uint64 size, uint64 pa, int perm)
for(;;){
if((pte = walk(pagetable, a, 1)) == 0)
return -1;
- if(*pte & PTE_V)
- panic("mappages: remap");
+// if(*pte & PTE_V)
+// panic("mappages: remap");
*pte = PA2PTE(pa) | perm | PTE_V;
if(a == last)
break;
@@ -315,22 +316,22 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
pte_t *pte;
uint64 pa, i;
uint flags;
- char *mem;

for(i = 0; i < sz; i += PGSIZE){
if((pte = walk(old, i, 0)) == 0)
panic("uvmcopy: pte should exist");
if((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");
+
+ if(*pte & PTE_W) {
+ *pte = (*pte & (~PTE_W)) | PTE_COW;
+ }
pa = PTE2PA(*pte);
flags = PTE_FLAGS(*pte);
- if((mem = kalloc()) == 0)
- goto err;
- memmove(mem, (char*)pa, PGSIZE);
- if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
- kfree(mem);
+ if(mappages(new, i, PGSIZE, (uint64)pa, flags) != 0){
goto err;
}
+ inc_kmemref((void *)pa);
}
return 0;

@@ -360,23 +361,48 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)
{
uint64 n, va0, pa0;
pte_t *pte;
-
+
while(len > 0){
va0 = PGROUNDDOWN(dstva);
+ pa0 = walkaddr(pagetable, va0);
+ if(pa0 == 0)
+ return -1;
+ struct proc *p = myproc();
if(va0 >= MAXVA)
return -1;
- pte = walk(pagetable, va0, 0);
- if(pte == 0 || (*pte & PTE_V) == 0 || (*pte & PTE_U) == 0 ||
- (*pte & PTE_W) == 0)
+ if(va0 < PGSIZE)
+ return -1;
+ if((pte = walk(pagetable, va0, 0))==0) {
+ p->killed = 1;
return -1;
- pa0 = PTE2PA(*pte);
+ }
+ if ((va0 < p->sz) && (*pte & PTE_V) &&
+ (*pte & PTE_COW)&&(*pte & PTE_U)) {
+ char refcnt = get_kmemref((void *)pa0);
+ if(refcnt == 1) {
+ *pte = (*pte &(~PTE_COW)) | PTE_W;
+ }else if(refcnt > 1){
+ char *mem;
+ dec_kmemref((void *)pa0);
+ if ((mem = kalloc()) == 0) {
+ p->killed = 1;
+ return -1;
+ }
+ memmove(mem, (char*)pa0, PGSIZE);
+ uint flags = PTE_FLAGS(*pte);
+ *pte = (PA2PTE(mem) | flags | PTE_W);
+ *pte &= ~PTE_COW;
+ pa0 = (uint64)mem;
+
+ }
+ }
n = PGSIZE - (dstva - va0);
if(n > len)
n = len;
memmove((void *)(pa0 + (dstva - va0)), src, n);

len -= n;
- src += n;
+ src += n;
dstva = va0 + PGSIZE;
}
return 0;

测试结果

$ make qemu-gdb
(12.3s)
== Test simple ==
simple: OK
== Test three ==
three: OK
== Test file ==
file: OK
== Test usertests ==
$ make qemu-gdb
(118.3s)
== Test usertests: copyin ==
usertests: copyin: OK
== Test usertests: copyout ==
usertests: copyout: OK
== Test usertests: all tests ==
usertests: all tests: OK
== Test time ==
time: OK
Score: 110/110