]> git.kernelconcepts.de Git - karo-tx-linux.git/blobdiff - kernel/audit.c
audit: fix memleak in auditd_send_unicast_skb.
[karo-tx-linux.git] / kernel / audit.c
index a2f7803a68d019e51ff423a2cc8cde8970c6c573..07def5e49cc9e29abd2dc4d5980564726a74f858 100644 (file)
@@ -59,6 +59,7 @@
 #include <linux/mutex.h>
 #include <linux/gfp.h>
 #include <linux/pid.h>
+#include <linux/slab.h>
 
 #include <linux/audit.h>
 
@@ -111,18 +112,19 @@ struct audit_net {
  * @pid: auditd PID
  * @portid: netlink portid
  * @net: the associated network namespace
- * @lock: spinlock to protect write access
+ * @rcu: RCU head
  *
  * Description:
  * This struct is RCU protected; you must either hold the RCU lock for reading
- * or the included spinlock for writing.
+ * or the associated spinlock for writing.
  */
 static struct auditd_connection {
        struct pid *pid;
        u32 portid;
        struct net *net;
-       spinlock_t lock;
-} auditd_conn;
+       struct rcu_head rcu;
+} *auditd_conn = NULL;
+static DEFINE_SPINLOCK(auditd_conn_lock);
 
 /* If audit_rate_limit is non-zero, limit the rate of sending audit records
  * to that number per second.  This prevents DoS attacks, but results in
@@ -152,12 +154,7 @@ static atomic_t    audit_lost = ATOMIC_INIT(0);
 /* Hash for inode-based rules */
 struct list_head audit_inode_hash[AUDIT_INODE_BUCKETS];
 
-/* The audit_freelist is a list of pre-allocated audit buffers (if more
- * than AUDIT_MAXFREE are in use, the audit buffer is freed instead of
- * being placed on the freelist). */
-static DEFINE_SPINLOCK(audit_freelist_lock);
-static int        audit_freelist_count;
-static LIST_HEAD(audit_freelist);
+static struct kmem_cache *audit_buffer_cache;
 
 /* queue msgs to send via kauditd_task */
 static struct sk_buff_head audit_queue;
@@ -192,17 +189,12 @@ DEFINE_MUTEX(audit_cmd_mutex);
  * should be at least that large. */
 #define AUDIT_BUFSIZ 1024
 
-/* AUDIT_MAXFREE is the number of empty audit_buffers we keep on the
- * audit_freelist.  Doing so eliminates many kmalloc/kfree calls. */
-#define AUDIT_MAXFREE  (2*NR_CPUS)
-
 /* The audit_buffer is used when formatting an audit record.  The caller
  * locks briefly to get the record off the freelist or to allocate the
  * buffer, and locks briefly to send the buffer to the netlink layer or
  * to place it on a transmit queue.  Multiple audit_buffers can be in
  * use simultaneously. */
 struct audit_buffer {
-       struct list_head     list;
        struct sk_buff       *skb;      /* formatted skb ready to send */
        struct audit_context *ctx;      /* NULL or associated context */
        gfp_t                gfp_mask;
@@ -224,9 +216,11 @@ struct audit_reply {
 int auditd_test_task(struct task_struct *task)
 {
        int rc;
+       struct auditd_connection *ac;
 
        rcu_read_lock();
-       rc = (auditd_conn.pid && auditd_conn.pid == task_tgid(task) ? 1 : 0);
+       ac = rcu_dereference(auditd_conn);
+       rc = (ac && ac->pid == task_tgid(task) ? 1 : 0);
        rcu_read_unlock();
 
        return rc;
@@ -234,22 +228,21 @@ int auditd_test_task(struct task_struct *task)
 
 /**
  * auditd_pid_vnr - Return the auditd PID relative to the namespace
- * @auditd: the auditd connection
  *
  * Description:
- * Returns the PID in relation to the namespace, 0 on failure.  This function
- * takes the RCU read lock internally, but if the caller needs to protect the
- * auditd_connection pointer it should take the RCU read lock as well.
+ * Returns the PID in relation to the namespace, 0 on failure.
  */
-static pid_t auditd_pid_vnr(const struct auditd_connection *auditd)
+static pid_t auditd_pid_vnr(void)
 {
        pid_t pid;
+       const struct auditd_connection *ac;
 
        rcu_read_lock();
-       if (!auditd || !auditd->pid)
+       ac = rcu_dereference(auditd_conn);
+       if (!ac || !ac->pid)
                pid = 0;
        else
-               pid = pid_vnr(auditd->pid);
+               pid = pid_vnr(ac->pid);
        rcu_read_unlock();
 
        return pid;
@@ -442,6 +435,24 @@ static int audit_set_failure(u32 state)
        return audit_do_config_change("audit_failure", &audit_failure, state);
 }
 
+/**
+ * auditd_conn_free - RCU helper to release an auditd connection struct
+ * @rcu: RCU head
+ *
+ * Description:
+ * Drop any references inside the auditd connection tracking struct and free
+ * the memory.
+ */
+ static void auditd_conn_free(struct rcu_head *rcu)
+ {
+       struct auditd_connection *ac;
+
+       ac = container_of(rcu, struct auditd_connection, rcu);
+       put_pid(ac->pid);
+       put_net(ac->net);
+       kfree(ac);
+ }
+
 /**
  * auditd_set - Set/Reset the auditd connection state
  * @pid: auditd PID
@@ -450,27 +461,33 @@ static int audit_set_failure(u32 state)
  *
  * Description:
  * This function will obtain and drop network namespace references as
- * necessary.
+ * necessary.  Returns zero on success, negative values on failure.
  */
-static void auditd_set(struct pid *pid, u32 portid, struct net *net)
+static int auditd_set(struct pid *pid, u32 portid, struct net *net)
 {
        unsigned long flags;
+       struct auditd_connection *ac_old, *ac_new;
 
-       spin_lock_irqsave(&auditd_conn.lock, flags);
-       if (auditd_conn.pid)
-               put_pid(auditd_conn.pid);
-       if (pid)
-               auditd_conn.pid = get_pid(pid);
-       else
-               auditd_conn.pid = NULL;
-       auditd_conn.portid = portid;
-       if (auditd_conn.net)
-               put_net(auditd_conn.net);
-       if (net)
-               auditd_conn.net = get_net(net);
-       else
-               auditd_conn.net = NULL;
-       spin_unlock_irqrestore(&auditd_conn.lock, flags);
+       if (!pid || !net)
+               return -EINVAL;
+
+       ac_new = kzalloc(sizeof(*ac_new), GFP_KERNEL);
+       if (!ac_new)
+               return -ENOMEM;
+       ac_new->pid = get_pid(pid);
+       ac_new->portid = portid;
+       ac_new->net = get_net(net);
+
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       ac_old = rcu_dereference_protected(auditd_conn,
+                                          lockdep_is_held(&auditd_conn_lock));
+       rcu_assign_pointer(auditd_conn, ac_new);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
+
+       if (ac_old)
+               call_rcu(&ac_old->rcu, auditd_conn_free);
+
+       return 0;
 }
 
 /**
@@ -558,26 +575,40 @@ static void kauditd_retry_skb(struct sk_buff *skb)
 
 /**
  * auditd_reset - Disconnect the auditd connection
+ * @ac: auditd connection state
  *
  * Description:
  * Break the auditd/kauditd connection and move all the queued records into the
- * hold queue in case auditd reconnects.
+ * hold queue in case auditd reconnects.  It is important to note that the @ac
+ * pointer should never be dereferenced inside this function as it may be NULL
+ * or invalid, you can only compare the memory address!  If @ac is NULL then
+ * the connection will always be reset.
  */
-static void auditd_reset(void)
+static void auditd_reset(const struct auditd_connection *ac)
 {
+       unsigned long flags;
        struct sk_buff *skb;
+       struct auditd_connection *ac_old;
 
        /* if it isn't already broken, break the connection */
-       rcu_read_lock();
-       if (auditd_conn.pid)
-               auditd_set(0, 0, NULL);
-       rcu_read_unlock();
+       spin_lock_irqsave(&auditd_conn_lock, flags);
+       ac_old = rcu_dereference_protected(auditd_conn,
+                                          lockdep_is_held(&auditd_conn_lock));
+       if (ac && ac != ac_old) {
+               /* someone already registered a new auditd connection */
+               spin_unlock_irqrestore(&auditd_conn_lock, flags);
+               return;
+       }
+       rcu_assign_pointer(auditd_conn, NULL);
+       spin_unlock_irqrestore(&auditd_conn_lock, flags);
 
-       /* flush all of the main and retry queues to the hold queue */
+       if (ac_old)
+               call_rcu(&ac_old->rcu, auditd_conn_free);
+
+       /* flush the retry queue to the hold queue, but don't touch the main
+        * queue since we need to process that normally for multicast */
        while ((skb = skb_dequeue(&audit_retry_queue)))
                kauditd_hold_skb(skb);
-       while ((skb = skb_dequeue(&audit_queue)))
-               kauditd_hold_skb(skb);
 }
 
 /**
@@ -597,6 +628,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
        u32 portid;
        struct net *net;
        struct sock *sk;
+       struct auditd_connection *ac;
 
        /* NOTE: we can't call netlink_unicast while in the RCU section so
         *       take a reference to the network namespace and grab local
@@ -606,15 +638,16 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
         *       section netlink_unicast() should safely return an error */
 
        rcu_read_lock();
-       if (!auditd_conn.pid) {
+       ac = rcu_dereference(auditd_conn);
+       if (!ac) {
                rcu_read_unlock();
+               kfree_skb(skb);
                rc = -ECONNREFUSED;
                goto err;
        }
-       net = auditd_conn.net;
-       get_net(net);
+       net = get_net(ac->net);
        sk = audit_get_sk(net);
-       portid = auditd_conn.portid;
+       portid = ac->portid;
        rcu_read_unlock();
 
        rc = netlink_unicast(sk, skb, portid, 0);
@@ -625,8 +658,8 @@ static int auditd_send_unicast_skb(struct sk_buff *skb)
        return rc;
 
 err:
-       if (rc == -ECONNREFUSED)
-               auditd_reset();
+       if (ac && rc == -ECONNREFUSED)
+               auditd_reset(ac);
        return rc;
 }
 
@@ -749,6 +782,7 @@ static int kauditd_thread(void *dummy)
        u32 portid = 0;
        struct net *net = NULL;
        struct sock *sk = NULL;
+       struct auditd_connection *ac;
 
 #define UNICAST_RETRIES 5
 
@@ -756,23 +790,23 @@ static int kauditd_thread(void *dummy)
        while (!kthread_should_stop()) {
                /* NOTE: see the lock comments in auditd_send_unicast_skb() */
                rcu_read_lock();
-               if (!auditd_conn.pid) {
+               ac = rcu_dereference(auditd_conn);
+               if (!ac) {
                        rcu_read_unlock();
                        goto main_queue;
                }
-               net = auditd_conn.net;
-               get_net(net);
+               net = get_net(ac->net);
                sk = audit_get_sk(net);
-               portid = auditd_conn.portid;
+               portid = ac->portid;
                rcu_read_unlock();
 
                /* attempt to flush the hold queue */
                rc = kauditd_send_queue(sk, portid,
                                        &audit_hold_queue, UNICAST_RETRIES,
                                        NULL, kauditd_rehold_skb);
-               if (rc < 0) {
+               if (ac && rc < 0) {
                        sk = NULL;
-                       auditd_reset();
+                       auditd_reset(ac);
                        goto main_queue;
                }
 
@@ -780,9 +814,9 @@ static int kauditd_thread(void *dummy)
                rc = kauditd_send_queue(sk, portid,
                                        &audit_retry_queue, UNICAST_RETRIES,
                                        NULL, kauditd_hold_skb);
-               if (rc < 0) {
+               if (ac && rc < 0) {
                        sk = NULL;
-                       auditd_reset();
+                       auditd_reset(ac);
                        goto main_queue;
                }
 
@@ -790,12 +824,13 @@ main_queue:
                /* process the main queue - do the multicast send and attempt
                 * unicast, dump failed record sends to the retry queue; if
                 * sk == NULL due to previous failures we will just do the
-                * multicast send and move the record to the retry queue */
+                * multicast send and move the record to the hold queue */
                rc = kauditd_send_queue(sk, portid, &audit_queue, 1,
                                        kauditd_send_multicast_skb,
-                                       kauditd_retry_skb);
-               if (sk == NULL || rc < 0)
-                       auditd_reset();
+                                       (sk ?
+                                        kauditd_retry_skb : kauditd_hold_skb));
+               if (ac && rc < 0)
+                       auditd_reset(ac);
                sk = NULL;
 
                /* drop our netns reference, no auditd sends past this line */
@@ -1126,7 +1161,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                s.failure               = audit_failure;
                /* NOTE: use pid_vnr() so the PID is relative to the current
                 *       namespace */
-               s.pid                   = auditd_pid_vnr(&auditd_conn);
+               s.pid                   = auditd_pid_vnr();
                s.rate_limit            = audit_rate_limit;
                s.backlog_limit         = audit_backlog_limit;
                s.lost                  = atomic_read(&audit_lost);
@@ -1169,7 +1204,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                        /* test the auditd connection */
                        audit_replace(req_pid);
 
-                       auditd_pid = auditd_pid_vnr(&auditd_conn);
+                       auditd_pid = auditd_pid_vnr();
                        /* only the current auditd can unregister itself */
                        if ((!new_pid) && (new_pid != auditd_pid)) {
                                audit_log_config_change("audit_pid", new_pid,
@@ -1183,19 +1218,30 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                                return -EEXIST;
                        }
 
-                       if (audit_enabled != AUDIT_OFF)
-                               audit_log_config_change("audit_pid", new_pid,
-                                                       auditd_pid, 1);
-
                        if (new_pid) {
                                /* register a new auditd connection */
-                               auditd_set(req_pid, NETLINK_CB(skb).portid,
-                                          sock_net(NETLINK_CB(skb).sk));
+                               err = auditd_set(req_pid,
+                                                NETLINK_CB(skb).portid,
+                                                sock_net(NETLINK_CB(skb).sk));
+                               if (audit_enabled != AUDIT_OFF)
+                                       audit_log_config_change("audit_pid",
+                                                               new_pid,
+                                                               auditd_pid,
+                                                               err ? 0 : 1);
+                               if (err)
+                                       return err;
+
                                /* try to process any backlog */
                                wake_up_interruptible(&kauditd_wait);
-                       } else
+                       } else {
+                               if (audit_enabled != AUDIT_OFF)
+                                       audit_log_config_change("audit_pid",
+                                                               new_pid,
+                                                               auditd_pid, 1);
+
                                /* unregister the auditd connection */
-                               auditd_reset();
+                               auditd_reset(NULL);
+                       }
                }
                if (s.mask & AUDIT_STATUS_RATE_LIMIT) {
                        err = audit_set_rate_limit(s.rate_limit);
@@ -1463,10 +1509,11 @@ static void __net_exit audit_net_exit(struct net *net)
 {
        struct audit_net *aunet = net_generic(net, audit_net_id);
 
-       rcu_read_lock();
-       if (net == auditd_conn.net)
-               auditd_reset();
-       rcu_read_unlock();
+       /* NOTE: you would think that we would want to check the auditd
+        * connection and potentially reset it here if it lives in this
+        * namespace, but since the auditd connection tracking struct holds a
+        * reference to this namespace (see auditd_set()) we are only ever
+        * going to get here after that connection has been released */
 
        netlink_kernel_release(aunet->sk);
 }
@@ -1486,8 +1533,9 @@ static int __init audit_init(void)
        if (audit_initialized == AUDIT_DISABLED)
                return 0;
 
-       memset(&auditd_conn, 0, sizeof(auditd_conn));
-       spin_lock_init(&auditd_conn.lock);
+       audit_buffer_cache = kmem_cache_create("audit_buffer",
+                                              sizeof(struct audit_buffer),
+                                              0, SLAB_PANIC, NULL);
 
        skb_queue_head_init(&audit_queue);
        skb_queue_head_init(&audit_retry_queue);
@@ -1554,60 +1602,33 @@ __setup("audit_backlog_limit=", audit_backlog_limit_set);
 
 static void audit_buffer_free(struct audit_buffer *ab)
 {
-       unsigned long flags;
-
        if (!ab)
                return;
 
        kfree_skb(ab->skb);
-       spin_lock_irqsave(&audit_freelist_lock, flags);
-       if (audit_freelist_count > AUDIT_MAXFREE)
-               kfree(ab);
-       else {
-               audit_freelist_count++;
-               list_add(&ab->list, &audit_freelist);
-       }
-       spin_unlock_irqrestore(&audit_freelist_lock, flags);
+       kmem_cache_free(audit_buffer_cache, ab);
 }
 
-static struct audit_buffer * audit_buffer_alloc(struct audit_context *ctx,
-                                               gfp_t gfp_mask, int type)
+static struct audit_buffer *audit_buffer_alloc(struct audit_context *ctx,
+                                              gfp_t gfp_mask, int type)
 {
-       unsigned long flags;
-       struct audit_buffer *ab = NULL;
-       struct nlmsghdr *nlh;
-
-       spin_lock_irqsave(&audit_freelist_lock, flags);
-       if (!list_empty(&audit_freelist)) {
-               ab = list_entry(audit_freelist.next,
-                               struct audit_buffer, list);
-               list_del(&ab->list);
-               --audit_freelist_count;
-       }
-       spin_unlock_irqrestore(&audit_freelist_lock, flags);
-
-       if (!ab) {
-               ab = kmalloc(sizeof(*ab), gfp_mask);
-               if (!ab)
-                       goto err;
-       }
+       struct audit_buffer *ab;
 
-       ab->ctx = ctx;
-       ab->gfp_mask = gfp_mask;
+       ab = kmem_cache_alloc(audit_buffer_cache, gfp_mask);
+       if (!ab)
+               return NULL;
 
        ab->skb = nlmsg_new(AUDIT_BUFSIZ, gfp_mask);
        if (!ab->skb)
                goto err;
+       if (!nlmsg_put(ab->skb, 0, 0, type, 0, 0))
+               goto err;
 
-       nlh = nlmsg_put(ab->skb, 0, 0, type, 0, 0);
-       if (!nlh)
-               goto out_kfree_skb;
+       ab->ctx = ctx;
+       ab->gfp_mask = gfp_mask;
 
        return ab;
 
-out_kfree_skb:
-       kfree_skb(ab->skb);
-       ab->skb = NULL;
 err:
        audit_buffer_free(ab);
        return NULL;
@@ -1638,10 +1659,10 @@ unsigned int audit_serial(void)
 }
 
 static inline void audit_get_stamp(struct audit_context *ctx,
-                                  struct timespec *t, unsigned int *serial)
+                                  struct timespec64 *t, unsigned int *serial)
 {
        if (!ctx || !auditsc_get_stamp(ctx, t, serial)) {
-               *t = CURRENT_TIME;
+               ktime_get_real_ts64(t);
                *serial = audit_serial();
        }
 }
@@ -1665,7 +1686,7 @@ struct audit_buffer *audit_log_start(struct audit_context *ctx, gfp_t gfp_mask,
                                     int type)
 {
        struct audit_buffer *ab;
-       struct timespec t;
+       struct timespec64 t;
        unsigned int uninitialized_var(serial);
 
        if (audit_initialized != AUDIT_INITIALIZED)
@@ -1718,8 +1739,8 @@ struct audit_buffer *audit_log_start(struct audit_context *ctx, gfp_t gfp_mask,
        }
 
        audit_get_stamp(ab->ctx, &t, &serial);
-       audit_log_format(ab, "audit(%lu.%03lu:%u): ",
-                        t.tv_sec, t.tv_nsec/1000000, serial);
+       audit_log_format(ab, "audit(%llu.%03lu:%u): ",
+                        (unsigned long long)t.tv_sec, t.tv_nsec/1000000, serial);
 
        return ab;
 }
@@ -1988,22 +2009,10 @@ void audit_log_cap(struct audit_buffer *ab, char *prefix, kernel_cap_t *cap)
 
 static void audit_log_fcaps(struct audit_buffer *ab, struct audit_names *name)
 {
-       kernel_cap_t *perm = &name->fcap.permitted;
-       kernel_cap_t *inh = &name->fcap.inheritable;
-       int log = 0;
-
-       if (!cap_isclear(*perm)) {
-               audit_log_cap(ab, "cap_fp", perm);
-               log = 1;
-       }
-       if (!cap_isclear(*inh)) {
-               audit_log_cap(ab, "cap_fi", inh);
-               log = 1;
-       }
-
-       if (log)
-               audit_log_format(ab, " cap_fe=%d cap_fver=%x",
-                                name->fcap.fE, name->fcap_ver);
+       audit_log_cap(ab, "cap_fp", &name->fcap.permitted);
+       audit_log_cap(ab, "cap_fi", &name->fcap.inheritable);
+       audit_log_format(ab, " cap_fe=%d cap_fver=%x",
+                        name->fcap.fE, name->fcap_ver);
 }
 
 static inline int audit_copy_fcaps(struct audit_names *name,