userfaultfd: non-cooperative: wake userfaults after UFFDIO_UNREGISTER
diff --git a/fs/userfaultfd.c b/fs/userfaultfd.c
index d96e2f30084bcfab552ffe3005af090abfe319c9..26e1ef00b63c1f2c58be46440f892033127bcf4b 100644
--- a/fs/userfaultfd.c
+++ b/fs/userfaultfd.c
@@ -12,6 +12,7 @@
  *  mm/ksm.c (mm hashing).
  */
 
+#include <linux/list.h>
 #include <linux/hashtable.h>
 #include <linux/sched.h>
 #include <linux/mm.h>
@@ -45,12 +46,16 @@ struct userfaultfd_ctx {
        wait_queue_head_t fault_wqh;
        /* waitqueue head for the pseudo fd to wakeup poll/read */
        wait_queue_head_t fd_wqh;
+       /* waitqueue head for events */
+       wait_queue_head_t event_wqh;
        /* a refile sequence protected by fault_pending_wqh lock */
        struct seqcount refile_seq;
        /* pseudo fd refcounting */
        atomic_t refcount;
        /* userfaultfd syscall flags */
        unsigned int flags;
+       /* features requested from userspace */
+       unsigned int features;
        /* state machine */
        enum userfaultfd_state state;
        /* released */
@@ -59,10 +64,17 @@ struct userfaultfd_ctx {
        struct mm_struct *mm;
 };
 
+struct userfaultfd_fork_ctx {
+       struct userfaultfd_ctx *orig;
+       struct userfaultfd_ctx *new;
+       struct list_head list;
+};
+
 struct userfaultfd_wait_queue {
        struct uffd_msg msg;
        wait_queue_t wq;
        struct userfaultfd_ctx *ctx;
+       bool waken;
 };
 
 struct userfaultfd_wake_range {
@@ -86,6 +98,12 @@ static int userfaultfd_wake_function(wait_queue_t *wq, unsigned mode,
        if (len && (start > uwq->msg.arg.pagefault.address ||
                    start + len <= uwq->msg.arg.pagefault.address))
                goto out;
+       WRITE_ONCE(uwq->waken, true);
+       /*
+        * The implicit smp_mb__before_spinlock() in try_to_wake_up()
+        * renders uwq->waken visible to other CPUs before the task is
+        * woken.
+        */
        ret = wake_up_state(wq->private, mode);
        if (ret)
                /*
@@ -135,6 +153,8 @@ static void userfaultfd_ctx_put(struct userfaultfd_ctx *ctx)
                VM_BUG_ON(waitqueue_active(&ctx->fault_pending_wqh));
                VM_BUG_ON(spin_is_locked(&ctx->fault_wqh.lock));
                VM_BUG_ON(waitqueue_active(&ctx->fault_wqh));
+               VM_BUG_ON(spin_is_locked(&ctx->event_wqh.lock));
+               VM_BUG_ON(waitqueue_active(&ctx->event_wqh));
                VM_BUG_ON(spin_is_locked(&ctx->fd_wqh.lock));
                VM_BUG_ON(waitqueue_active(&ctx->fd_wqh));
                mmdrop(ctx->mm);
@@ -162,7 +182,7 @@ static inline struct uffd_msg userfault_msg(unsigned long address,
        msg.arg.pagefault.address = address;
        if (flags & FAULT_FLAG_WRITE)
                /*
-                * If UFFD_FEATURE_PAGEFAULT_FLAG_WRITE was set in the
+                * If UFFD_FEATURE_PAGEFAULT_FLAG_WP was set in the
                 * uffdio_api.features and UFFD_PAGEFAULT_FLAG_WRITE
                 * was not set in a UFFD_EVENT_PAGEFAULT, it means it
                 * was a read fault, otherwise if set it means it's
@@ -264,6 +284,7 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
        struct userfaultfd_wait_queue uwq;
        int ret;
        bool must_wait, return_to_userland;
+       long blocking_state;
 
        BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
 
@@ -334,10 +355,13 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
        uwq.wq.private = current;
        uwq.msg = userfault_msg(vmf->address, vmf->flags, reason);
        uwq.ctx = ctx;
+       uwq.waken = false;
 
        return_to_userland =
                (vmf->flags & (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE)) ==
                (FAULT_FLAG_USER|FAULT_FLAG_KILLABLE);
+       blocking_state = return_to_userland ? TASK_INTERRUPTIBLE :
+                        TASK_KILLABLE;
 
        spin_lock(&ctx->fault_pending_wqh.lock);
        /*
@@ -350,8 +374,7 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
         * following the spin_unlock to happen before the list_add in
         * __add_wait_queue.
         */
-       set_current_state(return_to_userland ? TASK_INTERRUPTIBLE :
-                         TASK_KILLABLE);
+       set_current_state(blocking_state);
        spin_unlock(&ctx->fault_pending_wqh.lock);
 
        must_wait = userfaultfd_must_wait(ctx, vmf->address, vmf->flags,
@@ -364,6 +387,29 @@ int handle_userfault(struct vm_fault *vmf, unsigned long reason)
                wake_up_poll(&ctx->fd_wqh, POLLIN);
                schedule();
                ret |= VM_FAULT_MAJOR;
+
+               /*
+                * False wakeups can originate even from rwsem before
+                * up_read(); however, userfaults will wait either for a
+                * targeted wakeup on the specific uwq waitqueue from
+                * wake_userfault() or for signals or for uffd
+                * release.
+                */
+               while (!READ_ONCE(uwq.waken)) {
+                       /*
+                        * This needs the full smp_store_mb()
+                        * guarantee, as the state write must be
+                        * visible to other CPUs before we re-read
+                        * uwq.waken below.
+                        */
+                       set_current_state(blocking_state);
+                       if (READ_ONCE(uwq.waken) ||
+                           READ_ONCE(ctx->released) ||
+                           (return_to_userland ? signal_pending(current) :
+                            fatal_signal_pending(current)))
+                               break;
+                       schedule();
+               }
        }
 
        __set_current_state(TASK_RUNNING);
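
For illustration only, outside the patch: the targeted wakeup that sets uwq->waken above is normally triggered by the monitor resolving the fault with UFFDIO_COPY. A minimal userspace sketch, assuming uffd, a prepared source buffer page and page_size come from the surrounding program:

/* Hedged sketch: resolving a pagefault from the monitor side; UFFDIO_COPY
 * installs the page and wakes the thread blocked in handle_userfault(). */
#include <linux/userfaultfd.h>
#include <sys/ioctl.h>

static int resolve_fault(int uffd, const struct uffd_msg *msg,
                         void *page, unsigned long page_size)
{
        struct uffdio_copy copy = {
                .dst  = msg->arg.pagefault.address & ~((__u64)page_size - 1),
                .src  = (unsigned long)page,
                .len  = page_size,
                .mode = 0,
        };

        return ioctl(uffd, UFFDIO_COPY, &copy);
}
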
@@ -425,6 +471,196 @@ out:
        return ret;
 }
 
+static int userfaultfd_event_wait_completion(struct userfaultfd_ctx *ctx,
+                                            struct userfaultfd_wait_queue *ewq)
+{
+       int ret = 0;
+
+       ewq->ctx = ctx;
+       init_waitqueue_entry(&ewq->wq, current);
+
+       spin_lock(&ctx->event_wqh.lock);
+       /*
+        * After the __add_wait_queue the uwq is visible to userland
+        * through poll/read().
+        */
+       __add_wait_queue(&ctx->event_wqh, &ewq->wq);
+       for (;;) {
+               set_current_state(TASK_KILLABLE);
+               if (ewq->msg.event == 0)
+                       break;
+               if (READ_ONCE(ctx->released) ||
+                   fatal_signal_pending(current)) {
+                       ret = -1;
+                       __remove_wait_queue(&ctx->event_wqh, &ewq->wq);
+                       break;
+               }
+
+               spin_unlock(&ctx->event_wqh.lock);
+
+               wake_up_poll(&ctx->fd_wqh, POLLIN);
+               schedule();
+
+               spin_lock(&ctx->event_wqh.lock);
+       }
+       __set_current_state(TASK_RUNNING);
+       spin_unlock(&ctx->event_wqh.lock);
+
+       /*
+        * ctx may go away after this if the userfault pseudo fd is
+        * already released.
+        */
+
+       userfaultfd_ctx_put(ctx);
+       return ret;
+}
+
+static void userfaultfd_event_complete(struct userfaultfd_ctx *ctx,
+                                      struct userfaultfd_wait_queue *ewq)
+{
+       ewq->msg.event = 0;
+       wake_up_locked(&ctx->event_wqh);
+       __remove_wait_queue(&ctx->event_wqh, &ewq->wq);
+}
+
+int dup_userfaultfd(struct vm_area_struct *vma, struct list_head *fcs)
+{
+       struct userfaultfd_ctx *ctx = NULL, *octx;
+       struct userfaultfd_fork_ctx *fctx;
+
+       octx = vma->vm_userfaultfd_ctx.ctx;
+       if (!octx || !(octx->features & UFFD_FEATURE_EVENT_FORK)) {
+               vma->vm_userfaultfd_ctx = NULL_VM_UFFD_CTX;
+               vma->vm_flags &= ~(VM_UFFD_WP | VM_UFFD_MISSING);
+               return 0;
+       }
+
+       list_for_each_entry(fctx, fcs, list)
+               if (fctx->orig == octx) {
+                       ctx = fctx->new;
+                       break;
+               }
+
+       if (!ctx) {
+               fctx = kmalloc(sizeof(*fctx), GFP_KERNEL);
+               if (!fctx)
+                       return -ENOMEM;
+
+               ctx = kmem_cache_alloc(userfaultfd_ctx_cachep, GFP_KERNEL);
+               if (!ctx) {
+                       kfree(fctx);
+                       return -ENOMEM;
+               }
+
+               atomic_set(&ctx->refcount, 1);
+               ctx->flags = octx->flags;
+               ctx->state = UFFD_STATE_RUNNING;
+               ctx->features = octx->features;
+               ctx->released = false;
+               ctx->mm = vma->vm_mm;
+               atomic_inc(&ctx->mm->mm_count);
+
+               userfaultfd_ctx_get(octx);
+               fctx->orig = octx;
+               fctx->new = ctx;
+               list_add_tail(&fctx->list, fcs);
+       }
+
+       vma->vm_userfaultfd_ctx.ctx = ctx;
+       return 0;
+}
+
+static int dup_fctx(struct userfaultfd_fork_ctx *fctx)
+{
+       struct userfaultfd_ctx *ctx = fctx->orig;
+       struct userfaultfd_wait_queue ewq;
+
+       msg_init(&ewq.msg);
+
+       ewq.msg.event = UFFD_EVENT_FORK;
+       ewq.msg.arg.reserved.reserved1 = (unsigned long)fctx->new;
+
+       return userfaultfd_event_wait_completion(ctx, &ewq);
+}
+
+void dup_userfaultfd_complete(struct list_head *fcs)
+{
+       int ret = 0;
+       struct userfaultfd_fork_ctx *fctx, *n;
+
+       list_for_each_entry_safe(fctx, n, fcs, list) {
+               if (!ret)
+                       ret = dup_fctx(fctx);
+               list_del(&fctx->list);
+               kfree(fctx);
+       }
+}
+
+void mremap_userfaultfd_prep(struct vm_area_struct *vma,
+                            struct vm_userfaultfd_ctx *vm_ctx)
+{
+       struct userfaultfd_ctx *ctx;
+
+       ctx = vma->vm_userfaultfd_ctx.ctx;
+       if (ctx && (ctx->features & UFFD_FEATURE_EVENT_REMAP)) {
+               vm_ctx->ctx = ctx;
+               userfaultfd_ctx_get(ctx);
+       }
+}
+
+void mremap_userfaultfd_complete(struct vm_userfaultfd_ctx *vm_ctx,
+                                unsigned long from, unsigned long to,
+                                unsigned long len)
+{
+       struct userfaultfd_ctx *ctx = vm_ctx->ctx;
+       struct userfaultfd_wait_queue ewq;
+
+       if (!ctx)
+               return;
+
+       if (to & ~PAGE_MASK) {
+               userfaultfd_ctx_put(ctx);
+               return;
+       }
+
+       msg_init(&ewq.msg);
+
+       ewq.msg.event = UFFD_EVENT_REMAP;
+       ewq.msg.arg.remap.from = from;
+       ewq.msg.arg.remap.to = to;
+       ewq.msg.arg.remap.len = len;
+
+       userfaultfd_event_wait_completion(ctx, &ewq);
+}
+
+void madvise_userfault_dontneed(struct vm_area_struct *vma,
+                               struct vm_area_struct **prev,
+                               unsigned long start, unsigned long end)
+{
+       struct mm_struct *mm = vma->vm_mm;
+       struct userfaultfd_ctx *ctx;
+       struct userfaultfd_wait_queue ewq;
+
+       ctx = vma->vm_userfaultfd_ctx.ctx;
+       if (!ctx || !(ctx->features & UFFD_FEATURE_EVENT_MADVDONTNEED))
+               return;
+
+       userfaultfd_ctx_get(ctx);
+       up_read(&mm->mmap_sem);
+
+       *prev = NULL; /* We wait for ACK w/o the mmap semaphore */
+
+       msg_init(&ewq.msg);
+
+       ewq.msg.event = UFFD_EVENT_MADVDONTNEED;
+       ewq.msg.arg.madv_dn.start = start;
+       ewq.msg.arg.madv_dn.end = end;
+
+       userfaultfd_event_wait_completion(ctx, &ewq);
+
+       down_read(&mm->mmap_sem);
+}
+
 static int userfaultfd_release(struct inode *inode, struct file *file)
 {
        struct userfaultfd_ctx *ctx = file->private_data;
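
For illustration only, not part of the diff: a minimal sketch of how a monitor might consume the UFFD_EVENT_REMAP and UFFD_EVENT_MADVDONTNEED messages queued by the helpers above, assuming the uapi struct uffd_msg in this tree exposes the arg.remap and arg.madv_dn layouts used on the kernel side here:

/* Hedged sketch: draining remap/madvise-dontneed events from the uffd. */
#include <linux/userfaultfd.h>
#include <stdio.h>
#include <unistd.h>

static void drain_event(int uffd)
{
        struct uffd_msg msg;

        /* read() delivers one struct uffd_msg per pending event */
        if (read(uffd, &msg, sizeof(msg)) != sizeof(msg))
                return;

        switch (msg.event) {
        case UFFD_EVENT_REMAP:
                printf("remap %llx -> %llx len %llx\n",
                       (unsigned long long)msg.arg.remap.from,
                       (unsigned long long)msg.arg.remap.to,
                       (unsigned long long)msg.arg.remap.len);
                break;
        case UFFD_EVENT_MADVDONTNEED:
                printf("dontneed %llx-%llx\n",
                       (unsigned long long)msg.arg.madv_dn.start,
                       (unsigned long long)msg.arg.madv_dn.end);
                break;
        }
}
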
@@ -489,25 +725,36 @@ wakeup:
 }
 
 /* fault_pending_wqh.lock must be held by the caller */
-static inline struct userfaultfd_wait_queue *find_userfault(
-       struct userfaultfd_ctx *ctx)
+static inline struct userfaultfd_wait_queue *find_userfault_in(
+               wait_queue_head_t *wqh)
 {
        wait_queue_t *wq;
        struct userfaultfd_wait_queue *uwq;
 
-       VM_BUG_ON(!spin_is_locked(&ctx->fault_pending_wqh.lock));
+       VM_BUG_ON(!spin_is_locked(&wqh->lock));
 
        uwq = NULL;
-       if (!waitqueue_active(&ctx->fault_pending_wqh))
+       if (!waitqueue_active(wqh))
                goto out;
        /* walk in reverse to provide FIFO behavior to read userfaults */
-       wq = list_last_entry(&ctx->fault_pending_wqh.task_list,
-                            typeof(*wq), task_list);
+       wq = list_last_entry(&wqh->task_list, typeof(*wq), task_list);
        uwq = container_of(wq, struct userfaultfd_wait_queue, wq);
 out:
        return uwq;
 }
 
+static inline struct userfaultfd_wait_queue *find_userfault(
+               struct userfaultfd_ctx *ctx)
+{
+       return find_userfault_in(&ctx->fault_pending_wqh);
+}
+
+static inline struct userfaultfd_wait_queue *find_userfault_evt(
+               struct userfaultfd_ctx *ctx)
+{
+       return find_userfault_in(&ctx->event_wqh);
+}
+
 static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
 {
        struct userfaultfd_ctx *ctx = file->private_data;
@@ -539,18 +786,59 @@ static unsigned int userfaultfd_poll(struct file *file, poll_table *wait)
                smp_mb();
                if (waitqueue_active(&ctx->fault_pending_wqh))
                        ret = POLLIN;
+               else if (waitqueue_active(&ctx->event_wqh))
+                       ret = POLLIN;
+
                return ret;
        default:
-               BUG();
+               WARN_ON_ONCE(1);
+               return POLLERR;
        }
 }
 
+static const struct file_operations userfaultfd_fops;
+
+static int resolve_userfault_fork(struct userfaultfd_ctx *ctx,
+                                 struct userfaultfd_ctx *new,
+                                 struct uffd_msg *msg)
+{
+       int fd;
+       struct file *file;
+       unsigned int flags = new->flags & UFFD_SHARED_FCNTL_FLAGS;
+
+       fd = get_unused_fd_flags(flags);
+       if (fd < 0)
+               return fd;
+
+       file = anon_inode_getfile("[userfaultfd]", &userfaultfd_fops, new,
+                                 O_RDWR | flags);
+       if (IS_ERR(file)) {
+               put_unused_fd(fd);
+               return PTR_ERR(file);
+       }
+
+       fd_install(fd, file);
+       msg->arg.reserved.reserved1 = 0;
+       msg->arg.fork.ufd = fd;
+
+       return 0;
+}
+
 static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                                    struct uffd_msg *msg)
 {
        ssize_t ret;
        DECLARE_WAITQUEUE(wait, current);
        struct userfaultfd_wait_queue *uwq;
+       /*
+        * Handling a fork event requires sleeping operations, so
+        * we drop the event_wqh lock, then do these ops, then
+        * lock it back and wake up the waiter. While the lock is
+        * dropped the ewq may go away, so we keep track of it
+        * carefully.
+        */
+       LIST_HEAD(fork_event);
+       struct userfaultfd_ctx *fork_nctx = NULL;
 
        /* always take the fd_wqh lock before the fault_pending_wqh lock */
        spin_lock(&ctx->fd_wqh.lock);
@@ -602,6 +890,29 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
                        break;
                }
                spin_unlock(&ctx->fault_pending_wqh.lock);
+
+               spin_lock(&ctx->event_wqh.lock);
+               uwq = find_userfault_evt(ctx);
+               if (uwq) {
+                       *msg = uwq->msg;
+
+                       if (uwq->msg.event == UFFD_EVENT_FORK) {
+                               fork_nctx = (struct userfaultfd_ctx *)
+                                       (unsigned long)
+                                       uwq->msg.arg.reserved.reserved1;
+                               list_move(&uwq->wq.task_list, &fork_event);
+                               spin_unlock(&ctx->event_wqh.lock);
+                               ret = 0;
+                               break;
+                       }
+
+                       userfaultfd_event_complete(ctx, uwq);
+                       spin_unlock(&ctx->event_wqh.lock);
+                       ret = 0;
+                       break;
+               }
+               spin_unlock(&ctx->event_wqh.lock);
+
                if (signal_pending(current)) {
                        ret = -ERESTARTSYS;
                        break;
@@ -618,6 +929,23 @@ static ssize_t userfaultfd_ctx_read(struct userfaultfd_ctx *ctx, int no_wait,
        __set_current_state(TASK_RUNNING);
        spin_unlock(&ctx->fd_wqh.lock);
 
+       if (!ret && msg->event == UFFD_EVENT_FORK) {
+               ret = resolve_userfault_fork(ctx, fork_nctx, msg);
+
+               if (!ret) {
+                       spin_lock(&ctx->event_wqh.lock);
+                       if (!list_empty(&fork_event)) {
+                               uwq = list_first_entry(&fork_event,
+                                                      typeof(*uwq),
+                                                      wq.task_list);
+                               list_del(&uwq->wq.task_list);
+                               __add_wait_queue(&ctx->event_wqh, &uwq->wq);
+                               userfaultfd_event_complete(ctx, uwq);
+                       }
+                       spin_unlock(&ctx->event_wqh.lock);
+               }
+       }
+
        return ret;
 }
 
@@ -796,7 +1124,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
 
                /* check not compatible vmas */
                ret = -EINVAL;
-               if (cur->vm_ops)
+               if (!vma_is_anonymous(cur))
                        goto out_unlock;
 
                /*
@@ -821,7 +1149,7 @@ static int userfaultfd_register(struct userfaultfd_ctx *ctx,
        do {
                cond_resched();
 
-               BUG_ON(vma->vm_ops);
+               BUG_ON(!vma_is_anonymous(vma));
                BUG_ON(vma->vm_userfaultfd_ctx.ctx &&
                       vma->vm_userfaultfd_ctx.ctx != ctx);
 
@@ -947,7 +1275,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                 * provides for more strict behavior to notice
                 * unregistration errors.
                 */
-               if (cur->vm_ops)
+               if (!vma_is_anonymous(cur))
                        goto out_unlock;
 
                found = true;
@@ -961,7 +1289,7 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
        do {
                cond_resched();
 
-               BUG_ON(vma->vm_ops);
+               BUG_ON(!vma_is_anonymous(vma));
 
                /*
                 * Nothing to do: this vma is already registered into this
@@ -974,6 +1302,19 @@ static int userfaultfd_unregister(struct userfaultfd_ctx *ctx,
                        start = vma->vm_start;
                vma_end = min(end, vma->vm_end);
 
+               if (userfaultfd_missing(vma)) {
+                       /*
+                        * Wake any concurrent pending userfault while
+                        * we unregister, so they will not hang
+                        * permanently and userland does not have to
+                        * call UFFDIO_WAKE explicitly.
+                        */
+                       struct userfaultfd_wake_range range;
+                       range.start = start;
+                       range.len = vma_end - start;
+                       wake_userfault(vma->vm_userfaultfd_ctx.ctx, &range);
+               }
+
                new_flags = vma->vm_flags & ~(VM_UFFD_MISSING | VM_UFFD_WP);
                prev = vma_merge(mm, prev, start, vma_end, new_flags,
                                 vma->anon_vma, vma->vm_file, vma->vm_pgoff,
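
For illustration only, outside the patch: the behavioural change in the hunk above is what the subject line announces. With it, a monitor tearing down a registration no longer needs a trailing UFFDIO_WAKE, because the kernel wakes any faults still blocked on the range. A minimal sketch, with uffd, addr and len assumed to come from the caller:

/* Hedged sketch: unregistering a range; blocked faulters are woken by the kernel. */
#include <linux/userfaultfd.h>
#include <sys/ioctl.h>

static int stop_tracking(int uffd, unsigned long addr, unsigned long len)
{
        struct uffdio_range range = {
                .start = addr,
                .len   = len,
        };

        /* no explicit UFFDIO_WAKE needed after this returns */
        return ioctl(uffd, UFFDIO_UNREGISTER, &range);
}
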
@@ -1145,6 +1486,14 @@ out:
        return ret;
 }
 
+static inline unsigned int uffd_ctx_features(__u64 user_features)
+{
+       /*
+        * For the current set of features the bits just coincide
+        */
+       return (unsigned int)user_features;
+}
+
 /*
  * userland asks for a certain API version and we return which bits
  * and ioctl commands are implemented in this kernel for such API
@@ -1156,6 +1505,7 @@ static int userfaultfd_api(struct userfaultfd_ctx *ctx,
        struct uffdio_api uffdio_api;
        void __user *buf = (void __user *)arg;
        int ret;
+       __u64 features;
 
        ret = -EINVAL;
        if (ctx->state != UFFD_STATE_WAIT_API)
@@ -1163,19 +1513,23 @@ static int userfaultfd_api(struct userfaultfd_ctx *ctx,
        ret = -EFAULT;
        if (copy_from_user(&uffdio_api, buf, sizeof(uffdio_api)))
                goto out;
-       if (uffdio_api.api != UFFD_API || uffdio_api.features) {
+       features = uffdio_api.features;
+       if (uffdio_api.api != UFFD_API || (features & ~UFFD_API_FEATURES)) {
                memset(&uffdio_api, 0, sizeof(uffdio_api));
                if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
                        goto out;
                ret = -EINVAL;
                goto out;
        }
+       /* report all available features and ioctls to userland */
        uffdio_api.features = UFFD_API_FEATURES;
        uffdio_api.ioctls = UFFD_API_IOCTLS;
        ret = -EFAULT;
        if (copy_to_user(buf, &uffdio_api, sizeof(uffdio_api)))
                goto out;
        ctx->state = UFFD_STATE_RUNNING;
+       /* only enable the requested features for this uffd context */
+       ctx->features = uffd_ctx_features(features);
        ret = 0;
 out:
        return ret;
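
For illustration only, outside the patch: the resulting handshake as seen from userspace. The caller requests a feature subset; on success the kernel reports everything it supports in uffdio_api.features while only the requested bits are enabled on the context. The feature flag names are assumed to match this tree's uapi header:

/* Hedged sketch: UFFDIO_API handshake opting in to the non-cooperative events. */
#include <linux/userfaultfd.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <sys/syscall.h>
#include <unistd.h>

static int open_uffd_with_events(void)
{
        int uffd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
        struct uffdio_api api = {
                .api = UFFD_API,
                .features = UFFD_FEATURE_EVENT_FORK |
                            UFFD_FEATURE_EVENT_REMAP |
                            UFFD_FEATURE_EVENT_MADVDONTNEED,
        };

        if (uffd < 0)
                return -1;

        /* on success api.features reports all features the kernel supports */
        if (ioctl(uffd, UFFDIO_API, &api)) {
                close(uffd);
                return -1;
        }
        return uffd;
}
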
@@ -1262,6 +1616,7 @@ static void init_once_userfaultfd_ctx(void *mem)
 
        init_waitqueue_head(&ctx->fault_pending_wqh);
        init_waitqueue_head(&ctx->fault_wqh);
+       init_waitqueue_head(&ctx->event_wqh);
        init_waitqueue_head(&ctx->fd_wqh);
        seqcount_init(&ctx->refile_seq);
 }
@@ -1302,6 +1657,7 @@ static struct file *userfaultfd_file_create(int flags)
 
        atomic_set(&ctx->refcount, 1);
        ctx->flags = flags;
+       ctx->features = 0;
        ctx->state = UFFD_STATE_WAIT_API;
        ctx->released = false;
        ctx->mm = current->mm;