uaccess clean up

上级 ae8f1bcf
...@@ -205,9 +205,9 @@ void release(struct spinlock*); ...@@ -205,9 +205,9 @@ void release(struct spinlock*);
int argcheckptr(void *argval, int); int argcheckptr(void *argval, int);
int argcheckstr(const char*); int argcheckstr(const char*);
int fetchint64(uptr, u64*); int fetchint64(uptr, u64*);
int fetchstr(char*, const char*, unsigned); int fetchstr(char*, const char*, u64);
int umemcpy(void*, void*, u64); int fetchmem(void*, const void*, u64);
int kmemcpy(void*, void*, u64); int putmem(void*, const void*, u64);
u64 syscall(u64 a0, u64 a1, u64 a2, u64 a3, u64 a4, u64 num); u64 syscall(u64 a0, u64 a1, u64 a2, u64 a3, u64 a4, u64 num);
// string.c // string.c
......
...@@ -282,7 +282,7 @@ netbind(int sock, void *xaddr, int xaddrlen) ...@@ -282,7 +282,7 @@ netbind(int sock, void *xaddr, int xaddrlen)
if (addr == nullptr) if (addr == nullptr)
return -1; return -1;
if (umemcpy(addr, xaddr, xaddrlen)) if (fetchmem(addr, xaddr, xaddrlen))
return -1; return -1;
lwip_core_lock(); lwip_core_lock();
...@@ -311,7 +311,7 @@ netaccept(int sock, void *xaddr, void *xaddrlen) ...@@ -311,7 +311,7 @@ netaccept(int sock, void *xaddr, void *xaddrlen)
void *addr; void *addr;
int ss; int ss;
if (umemcpy(&len, lenptr, sizeof(*lenptr))) if (fetchmem(&len, lenptr, sizeof(*lenptr)))
return -1; return -1;
addr = kmalloc(len, "sockaddr"); addr = kmalloc(len, "sockaddr");
...@@ -326,7 +326,7 @@ netaccept(int sock, void *xaddr, void *xaddrlen) ...@@ -326,7 +326,7 @@ netaccept(int sock, void *xaddr, void *xaddrlen)
return ss; return ss;
} }
if (kmemcpy(xaddrlen, &len, sizeof(len)) || kmemcpy(xaddr, addr, len)) { if (putmem(xaddrlen, &len, sizeof(len)) || putmem(xaddr, addr, len)) {
lwip_core_lock(); lwip_core_lock();
lwip_close(ss); lwip_close(ss);
lwip_core_unlock(); lwip_core_unlock();
...@@ -357,7 +357,7 @@ netwrite(int sock, char *ubuf, int len) ...@@ -357,7 +357,7 @@ netwrite(int sock, char *ubuf, int len)
return -1; return -1;
cc = MIN(len, PGSIZE); cc = MIN(len, PGSIZE);
if (umemcpy(kbuf, ubuf, cc)) { if (fetchmem(kbuf, ubuf, cc)) {
kfree(kbuf); kfree(kbuf);
return -1; return -1;
} }
...@@ -388,7 +388,7 @@ netread(int sock, char *ubuf, int len) ...@@ -388,7 +388,7 @@ netread(int sock, char *ubuf, int len)
return r; return r;
} }
kmemcpy(ubuf, kbuf, r); putmem(ubuf, kbuf, r);
kfree(kbuf); kfree(kbuf);
return r; return r;
} }
......
...@@ -10,15 +10,32 @@ ...@@ -10,15 +10,32 @@
#include "cpu.hh" #include "cpu.hh"
#include "kmtrace.hh" #include "kmtrace.hh"
extern "C" int __fetchstr(char* dst, const char* usrc, unsigned size); extern "C" int __uaccess_mem(void* dst, const void* src, u64 size);
extern "C" int __fetchint64(uptr addr, u64* ip); extern "C" int __uaccess_str(char* dst, const char* src, u64 size);
extern "C" int __uaccess_int64(uptr addr, u64* ip);
int int
fetchstr(char* dst, const char* usrc, unsigned size) fetchmem(void* dst, const void* usrc, u64 size)
{ {
if(mycpu()->ncli != 0) if(mycpu()->ncli != 0)
panic("fetchstr: cli'd"); panic("fetchstr: cli'd");
return __fetchstr(dst, usrc, size); return __uaccess_mem(dst, usrc, size);
}
int
putmem(void *udst, const void *src, u64 size)
{
if(mycpu()->ncli != 0)
panic("fetchstr: cli'd");
return __uaccess_mem(udst, src, size);
}
int
fetchstr(char* dst, const char* usrc, u64 size)
{
if(mycpu()->ncli != 0)
panic("fetchstr: cli'd");
return __uaccess_str(dst, usrc, size);
} }
int int
...@@ -26,7 +43,7 @@ fetchint64(uptr addr, u64 *ip) ...@@ -26,7 +43,7 @@ fetchint64(uptr addr, u64 *ip)
{ {
if(mycpu()->ncli != 0) if(mycpu()->ncli != 0)
panic("fetchstr: cli'd"); panic("fetchstr: cli'd");
return __fetchint64(addr, ip); return __uaccess_int64(addr, ip);
} }
// Fetch the nul-terminated string at addr from process p. // Fetch the nul-terminated string at addr from process p.
...@@ -60,43 +77,6 @@ argcheckptr(void *p, int size) ...@@ -60,43 +77,6 @@ argcheckptr(void *p, int size)
return 0; return 0;
} }
static int
umemptr(void *umem, void **ret, u64 size)
{
uptr ptr = (uptr) umem;
for(uptr va = PGROUNDDOWN(ptr); va < ptr+size; va = va + PGSIZE)
if(pagefault(myproc()->vmap, va, 0) < 0)
return -1;
*ret = umem;
return 0;
}
int
umemcpy(void *dst, void *umem, u64 size)
{
void *ptr;
if (umemptr(umem, &ptr, size))
return -1;
memmove(dst, ptr, size);
return 0;
}
int
kmemcpy(void *umem, void *src, u64 size)
{
void *ptr;
if (umemptr(umem, &ptr, size))
return -1;
memmove(ptr, src, size);
return 0;
}
u64 u64
syscall(u64 a0, u64 a1, u64 a2, u64 a3, u64 a4, u64 num) syscall(u64 a0, u64 a1, u64 a2, u64 a3, u64 a4, u64 num)
{ {
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "bits.hh" #include "bits.hh"
#include "kalloc.hh" #include "kalloc.hh"
extern "C" void __fetch_end(void); extern "C" void __uaccess_end(void);
struct intdesc idt[256] __attribute__((aligned(16))); struct intdesc idt[256] __attribute__((aligned(16)));
...@@ -60,7 +60,7 @@ do_pagefault(struct trapframe *tf) ...@@ -60,7 +60,7 @@ do_pagefault(struct trapframe *tf)
} }
cprintf("pagefault: failed in kernel\n"); cprintf("pagefault: failed in kernel\n");
tf->rax = -1; tf->rax = -1;
tf->rip = (u64)__fetch_end; tf->rip = (u64)__uaccess_end;
return 0; return 0;
} else if (tf->err & FEC_U) { } else if (tf->err & FEC_U) {
sti(); sti();
......
#include "mmu.h" #include "mmu.h"
#include "asmdefines.h" #include "asmdefines.h"
#define ENTRY(name) .globl name ; .align 8; name :
// We aren't allowed to touch rbx,rsp,rbp,r12-r15
.code64 .code64
.globl __fetchint64
.align 8
// rdi user src // rdi user src
// rsi kernel dst // rsi kernel dst
// We aren't allowed to touch rbx,rsp,rbp,r12-r15 ENTRY(__uaccess_int64)
__fetchint64:
mov %gs:0x8, %r11 mov %gs:0x8, %r11
movl $1, PROC_UACCESS(%r11) movl $1, PROC_UACCESS(%r11)
mov (%rdi), %r10 mov (%rdi), %r10
mov %r10, (%rsi) mov %r10, (%rsi)
mov $0, %rax mov $0, %rax
jmp __fetch_end jmp __uaccess_end
.globl __fetchstr // rdi dst
.align 8 // rsi src
// rdi kernel dst // rdx dst len
// rsi user src ENTRY(__uaccess_str)
// rdx kernel len
// We aren't allowed to touch rbx,rsp,rbp,r12-r15
__fetchstr:
mov %gs:0x8, %r11 mov %gs:0x8, %r11
movl $1, PROC_UACCESS(%r11) movl $1, PROC_UACCESS(%r11)
...@@ -40,11 +38,30 @@ __fetchstr: ...@@ -40,11 +38,30 @@ __fetchstr:
// Error // Error
movq $-1, %rax movq $-1, %rax
2: // Done 2: // Done
jmp __fetch_end jmp __uaccess_end
// rdi dst
// rsi src
// rdx len
ENTRY(__uaccess_mem)
mov %gs:0x8, %r11
movl $1, PROC_UACCESS(%r11)
.globl __fetch_end // %rcx is loop instruction counter
mov %rdx, %rcx
xor %rax, %rax
1:
movb (%rsi), %r10b
movb %r10b, (%rdi)
inc %rdi
inc %rsi
loop 1b
// Done
jmp __uaccess_end
.globl __uaccess_end
.align 8 .align 8
__fetch_end: __uaccess_end:
movl $0, PROC_UACCESS(%r11) movl $0, PROC_UACCESS(%r11)
ret ret
......
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论