inet_diag: fix inet_diag_bc_audit(), CVE-2011-2213
[linux-flexiantxendom0-natty.git] / net / ipv4 / inet_diag.c
index 8aa7d51..65c23d9 100644 (file)
@@ -1,8 +1,6 @@
 /*
  * inet_diag.c Module for monitoring INET transport protocols sockets.
  *
- * Version:    $Id: inet_diag.c,v 1.3 2002/02/01 22:01:04 davem Exp $
- *
  * Authors:    Alexey Kuznetsov, <kuznet@ms2.inr.ac.ru>
  *
  *     This program is free software; you can redistribute it and/or
@@ -11,10 +9,12 @@
  *      2 of the License, or (at your option) any later version.
  */
 
+#include <linux/kernel.h>
 #include <linux/module.h>
 #include <linux/types.h>
 #include <linux/fcntl.h>
 #include <linux/random.h>
+#include <linux/slab.h>
 #include <linux/cache.h>
 #include <linux/init.h>
 #include <linux/time.h>
@@ -27,6 +27,7 @@
 #include <net/inet_hashtables.h>
 #include <net/inet_timewait_sock.h>
 #include <net/inet6_hashtables.h>
+#include <net/netlink.h>
 
 #include <linux/inet.h>
 #include <linux/stddef.h>
@@ -49,6 +50,27 @@ static struct sock *idiagnl;
 #define INET_DIAG_PUT(skb, attrtype, attrlen) \
        RTA_DATA(__RTA_PUT(skb, attrtype, attrlen))
 
+static DEFINE_MUTEX(inet_diag_table_mutex);
+
+static const struct inet_diag_handler *inet_diag_lock_handler(int type)
+{      /* Look up the handler registered for @type and take inet_diag_table_mutex. */
+       if (!inet_diag_table[type])     /* unlocked peek; the slot is re-checked under the mutex below */
+               request_module("net-pf-%d-proto-%d-type-%d", PF_NETLINK,
+                              NETLINK_INET_DIAG, type);
+
+       mutex_lock(&inet_diag_table_mutex);
+       if (!inet_diag_table[type])
+               return ERR_PTR(-ENOENT);        /* NB: returns with the mutex still held */
+
+       return inet_diag_table[type];   /* mutex stays held; pair with inet_diag_unlock_handler() */
+}
+
+static inline void inet_diag_unlock_handler(
+       const struct inet_diag_handler *handler)        /* @handler is not read by the body */
+{
+       mutex_unlock(&inet_diag_table_mutex);   /* drops the mutex taken in inet_diag_lock_handler() */
+}
+
 static int inet_csk_diag_fill(struct sock *sk,
                              struct sk_buff *skb,
                              int ext, u32 pid, u32 seq, u16 nlmsg_flags,
@@ -60,7 +82,7 @@ static int inet_csk_diag_fill(struct sock *sk,
        struct nlmsghdr  *nlh;
        void *info = NULL;
        struct inet_diag_meminfo  *minfo = NULL;
-       unsigned char    *b = skb->tail;
+       unsigned char    *b = skb_tail_pointer(skb);
        const struct inet_diag_handler *handler;
 
        handler = inet_diag_table[unlh->nlmsg_type];
@@ -95,10 +117,10 @@ static int inet_csk_diag_fill(struct sock *sk,
        r->id.idiag_cookie[0] = (u32)(unsigned long)sk;
        r->id.idiag_cookie[1] = (u32)(((unsigned long)sk >> 31) >> 1);
 
-       r->id.idiag_sport = inet->sport;
-       r->id.idiag_dport = inet->dport;
-       r->id.idiag_src[0] = inet->rcv_saddr;
-       r->id.idiag_dst[0] = inet->daddr;
+       r->id.idiag_sport = inet->inet_sport;
+       r->id.idiag_dport = inet->inet_dport;
+       r->id.idiag_src[0] = inet->inet_rcv_saddr;
+       r->id.idiag_dst[0] = inet->inet_daddr;
 
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->idiag_family == AF_INET6) {
@@ -111,7 +133,7 @@ static int inet_csk_diag_fill(struct sock *sk,
        }
 #endif
 
-#define EXPIRES_IN_MS(tmo)  ((tmo - jiffies) * 1000 + HZ - 1) / HZ
+#define EXPIRES_IN_MS(tmo)  DIV_ROUND_UP((tmo - jiffies) * 1000, HZ)
 
        if (icsk->icsk_pending == ICSK_TIME_RETRANS) {
                r->idiag_timer = 1;
@@ -135,10 +157,10 @@ static int inet_csk_diag_fill(struct sock *sk,
        r->idiag_inode = sock_i_ino(sk);
 
        if (minfo) {
-               minfo->idiag_rmem = atomic_read(&sk->sk_rmem_alloc);
+               minfo->idiag_rmem = sk_rmem_alloc_get(sk);
                minfo->idiag_wmem = sk->sk_wmem_queued;
                minfo->idiag_fmem = sk->sk_forward_alloc;
-               minfo->idiag_tmem = atomic_read(&sk->sk_wmem_alloc);
+               minfo->idiag_tmem = sk_wmem_alloc_get(sk);
        }
 
        handler->idiag_get_info(sk, r, info);
@@ -147,12 +169,12 @@ static int inet_csk_diag_fill(struct sock *sk,
            icsk->icsk_ca_ops && icsk->icsk_ca_ops->get_info)
                icsk->icsk_ca_ops->get_info(sk, ext, skb);
 
-       nlh->nlmsg_len = skb->tail - b;
+       nlh->nlmsg_len = skb_tail_pointer(skb) - b;
        return skb->len;
 
 rtattr_failure:
 nlmsg_failure:
-       skb_trim(skb, b - skb->data);
+       nlmsg_trim(skb, b);
        return -EMSGSIZE;
 }
 
@@ -163,7 +185,7 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
 {
        long tmo;
        struct inet_diag_msg *r;
-       const unsigned char *previous_tail = skb->tail;
+       const unsigned char *previous_tail = skb_tail_pointer(skb);
        struct nlmsghdr *nlh = NLMSG_PUT(skb, pid, seq,
                                         unlh->nlmsg_type, sizeof(*r));
 
@@ -177,8 +199,6 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
                tmo = 0;
 
        r->idiag_family       = tw->tw_family;
-       r->idiag_state        = tw->tw_state;
-       r->idiag_timer        = 0;
        r->idiag_retrans      = 0;
        r->id.idiag_if        = tw->tw_bound_dev_if;
        r->id.idiag_cookie[0] = (u32)(unsigned long)tw;
@@ -189,7 +209,7 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
        r->id.idiag_dst[0]    = tw->tw_daddr;
        r->idiag_state        = tw->tw_substate;
        r->idiag_timer        = 3;
-       r->idiag_expires      = (tmo * 1000 + HZ - 1) / HZ;
+       r->idiag_expires      = DIV_ROUND_UP(tmo * 1000, HZ);
        r->idiag_rqueue       = 0;
        r->idiag_wqueue       = 0;
        r->idiag_uid          = 0;
@@ -205,10 +225,10 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,
                               &tw6->tw_v6_daddr);
        }
 #endif
-       nlh->nlmsg_len = skb->tail - previous_tail;
+       nlh->nlmsg_len = skb_tail_pointer(skb) - previous_tail;
        return skb->len;
 nlmsg_failure:
-       skb_trim(skb, previous_tail - skb->data);
+       nlmsg_trim(skb, previous_tail);
        return -EMSGSIZE;
 }
 
@@ -233,18 +253,23 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        struct inet_hashinfo *hashinfo;
        const struct inet_diag_handler *handler;
 
-       handler = inet_diag_table[nlh->nlmsg_type];
-       BUG_ON(handler == NULL);
+       handler = inet_diag_lock_handler(nlh->nlmsg_type);
+       if (IS_ERR(handler)) {
+               err = PTR_ERR(handler);
+               goto unlock;
+       }
+
        hashinfo = handler->idiag_hashinfo;
+       err = -EINVAL;
 
        if (req->idiag_family == AF_INET) {
-               sk = inet_lookup(hashinfo, req->id.idiag_dst[0],
+               sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0],
                                 req->id.idiag_dport, req->id.idiag_src[0],
                                 req->id.idiag_sport, req->id.idiag_if);
        }
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        else if (req->idiag_family == AF_INET6) {
-               sk = inet6_lookup(hashinfo,
+               sk = inet6_lookup(&init_net, hashinfo,
                                  (struct in6_addr *)req->id.idiag_dst,
                                  req->id.idiag_dport,
                                  (struct in6_addr *)req->id.idiag_src,
@@ -253,11 +278,12 @@ static int inet_diag_get_exact(struct sk_buff *in_skb,
        }
 #endif
        else {
-               return -EINVAL;
+               goto unlock;
        }
 
+       err = -ENOENT;
        if (sk == NULL)
-               return -ENOENT;
+               goto unlock;
 
        err = -ESTALE;
        if ((req->id.idiag_cookie[0] != INET_DIAG_NOCOOKIE ||
@@ -294,6 +320,8 @@ out:
                else
                        sock_put(sk);
        }
+unlock:
+       inet_diag_unlock_handler(handler);
        return err;
 }
 
@@ -341,7 +369,7 @@ static int inet_diag_bc_run(const void *bc, int len,
                        yes = entry->sport >= op[1].no;
                        break;
                case INET_DIAG_BC_S_LE:
-                       yes = entry->dport <= op[1].no;
+                       yes = entry->sport <= op[1].no;
                        break;
                case INET_DIAG_BC_D_GE:
                        yes = entry->dport >= op[1].no;
@@ -381,7 +409,7 @@ static int inet_diag_bc_run(const void *bc, int len,
                                if (addr[0] == 0 && addr[1] == 0 &&
                                    addr[2] == htonl(0xffff) &&
                                    bitstring_match(addr + 3, cond->addr,
-                                                   cond->prefix_len))
+                                                   cond->prefix_len))
                                        break;
                        }
                        yes = 0;
@@ -397,7 +425,7 @@ static int inet_diag_bc_run(const void *bc, int len,
                        bc += op->no;
                }
        }
-       return (len == 0);
+       return len == 0;
 }
 
 static int valid_cc(const void *bc, int len, int cc)
@@ -409,7 +437,7 @@ static int valid_cc(const void *bc, int len, int cc)
                        return 0;
                if (cc == len)
                        return 1;
-               if (op->yes < 4)
+               if (op->yes < 4 || op->yes & 3)
                        return 0;
                len -= op->yes;
                bc  += op->yes;
@@ -419,11 +447,11 @@ static int valid_cc(const void *bc, int len, int cc)
 
 static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
 {
-       const unsigned char *bc = bytecode;
+       const void *bc = bytecode;
        int  len = bytecode_len;
 
        while (len > 0) {
-               struct inet_diag_bc_op *op = (struct inet_diag_bc_op *)bc;
+               const struct inet_diag_bc_op *op = bc;
 
 //printk("BC: %d %d %d {%d} / %d\n", op->code, op->yes, op->no, op[1].no, len);
                switch (op->code) {
@@ -434,22 +462,20 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
                case INET_DIAG_BC_S_LE:
                case INET_DIAG_BC_D_GE:
                case INET_DIAG_BC_D_LE:
-                       if (op->yes < 4 || op->yes > len + 4)
-                               return -EINVAL;
                case INET_DIAG_BC_JMP:
-                       if (op->no < 4 || op->no > len + 4)
+                       if (op->no < 4 || op->no > len + 4 || op->no & 3)
                                return -EINVAL;
                        if (op->no < len &&
                            !valid_cc(bytecode, bytecode_len, len - op->no))
                                return -EINVAL;
                        break;
                case INET_DIAG_BC_NOP:
-                       if (op->yes < 4 || op->yes > len + 4)
-                               return -EINVAL;
                        break;
                default:
                        return -EINVAL;
                }
+               if (op->yes < 4 || op->yes > len + 4 || op->yes & 3)
+                       return -EINVAL;
                bc  += op->yes;
                len -= op->yes;
        }
@@ -462,9 +488,11 @@ static int inet_csk_diag_dump(struct sock *sk,
 {
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
                struct inet_diag_entry entry;
-               struct rtattr *bc = (struct rtattr *)(r + 1);
+               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
+                                                         sizeof(*r),
+                                                         INET_DIAG_REQ_BYTECODE);
                struct inet_sock *inet = inet_sk(sk);
 
                entry.family = sk->sk_family;
@@ -477,14 +505,14 @@ static int inet_csk_diag_dump(struct sock *sk,
                } else
 #endif
                {
-                       entry.saddr = &inet->rcv_saddr;
-                       entry.daddr = &inet->daddr;
+                       entry.saddr = &inet->inet_rcv_saddr;
+                       entry.daddr = &inet->inet_daddr;
                }
-               entry.sport = inet->num;
-               entry.dport = ntohs(inet->dport);
+               entry.sport = inet->inet_num;
+               entry.dport = ntohs(inet->inet_dport);
                entry.userlocks = sk->sk_userlocks;
 
-               if (!inet_diag_bc_run(RTA_DATA(bc), RTA_PAYLOAD(bc), &entry))
+               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
                        return 0;
        }
 
@@ -499,9 +527,11 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
 {
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
                struct inet_diag_entry entry;
-               struct rtattr *bc = (struct rtattr *)(r + 1);
+               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
+                                                         sizeof(*r),
+                                                         INET_DIAG_REQ_BYTECODE);
 
                entry.family = tw->tw_family;
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
@@ -518,9 +548,9 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                }
                entry.sport = tw->tw_num;
                entry.dport = ntohs(tw->tw_dport);
-               entry.userlocks = 0; 
+               entry.userlocks = 0;
 
-               if (!inet_diag_bc_run(RTA_DATA(bc), RTA_PAYLOAD(bc), &entry))
+               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
                        return 0;
        }
 
@@ -535,7 +565,7 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
 {
        const struct inet_request_sock *ireq = inet_rsk(req);
        struct inet_sock *inet = inet_sk(sk);
-       unsigned char *b = skb->tail;
+       unsigned char *b = skb_tail_pointer(skb);
        struct inet_diag_msg *r;
        struct nlmsghdr *nlh;
        long tmo;
@@ -557,7 +587,7 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
        if (tmo < 0)
                tmo = 0;
 
-       r->id.idiag_sport = inet->sport;
+       r->id.idiag_sport = inet->inet_sport;
        r->id.idiag_dport = ireq->rmt_port;
        r->id.idiag_src[0] = ireq->loc_addr;
        r->id.idiag_dst[0] = ireq->rmt_addr;
@@ -574,12 +604,12 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
                               &inet6_rsk(req)->rmt_addr);
        }
 #endif
-       nlh->nlmsg_len = skb->tail - b;
+       nlh->nlmsg_len = skb_tail_pointer(skb) - b;
 
        return skb->len;
 
 nlmsg_failure:
-       skb_trim(skb, b - skb->data);
+       nlmsg_trim(skb, b);
        return -1;
 }
 
@@ -590,7 +620,7 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct listen_sock *lopt;
-       struct rtattr *bc = NULL;
+       const struct nlattr *bc = NULL;
        struct inet_sock *inet = inet_sk(sk);
        int j, s_j;
        int reqnum, s_reqnum;
@@ -610,9 +640,10 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        if (!lopt || !lopt->qlen)
                goto out;
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
-               bc = (struct rtattr *)(r + 1);
-               entry.sport = inet->num;
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
+               bc = nlmsg_find_attr(cb->nlh, sizeof(*r),
+                                    INET_DIAG_REQ_BYTECODE);
+               entry.sport = inet->inet_num;
                entry.userlocks = sk->sk_userlocks;
        }
 
@@ -644,8 +675,8 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
                                        &ireq->rmt_addr;
                                entry.dport = ntohs(ireq->rmt_port);
 
-                               if (!inet_diag_bc_run(RTA_DATA(bc),
-                                                   RTA_PAYLOAD(bc), &entry))
+                               if (!inet_diag_bc_run(nla_data(bc),
+                                                     nla_len(bc), &entry))
                                        continue;
                        }
 
@@ -676,8 +707,10 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
        const struct inet_diag_handler *handler;
        struct inet_hashinfo *hashinfo;
 
-       handler = inet_diag_table[cb->nlh->nlmsg_type];
-       BUG_ON(handler == NULL);
+       handler = inet_diag_lock_handler(cb->nlh->nlmsg_type);
+       if (IS_ERR(handler))
+               goto unlock;
+
        hashinfo = handler->idiag_hashinfo;
 
        s_i = cb->args[1];
@@ -687,13 +720,15 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                if (!(r->idiag_states & (TCPF_LISTEN | TCPF_SYN_RECV)))
                        goto skip_listen_ht;
 
-               inet_listen_lock(hashinfo);
                for (i = s_i; i < INET_LHTABLE_SIZE; i++) {
                        struct sock *sk;
-                       struct hlist_node *node;
+                       struct hlist_nulls_node *node;
+                       struct inet_listen_hashbucket *ilb;
 
                        num = 0;
-                       sk_for_each(sk, node, &hashinfo->listening_hash[i]) {
+                       ilb = &hashinfo->listening_hash[i];
+                       spin_lock_bh(&ilb->lock);
+                       sk_nulls_for_each(sk, node, &ilb->head) {
                                struct inet_sock *inet = inet_sk(sk);
 
                                if (num < s_num) {
@@ -701,7 +736,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        continue;
                                }
 
-                               if (r->id.idiag_sport != inet->sport &&
+                               if (r->id.idiag_sport != inet->inet_sport &&
                                    r->id.idiag_sport)
                                        goto next_listen;
 
@@ -711,7 +746,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        goto syn_recv;
 
                                if (inet_csk_diag_dump(sk, skb, cb) < 0) {
-                                       inet_listen_unlock(hashinfo);
+                                       spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
 
@@ -720,7 +755,7 @@ syn_recv:
                                        goto next_listen;
 
                                if (inet_diag_dump_reqs(skb, sk, cb) < 0) {
-                                       inet_listen_unlock(hashinfo);
+                                       spin_unlock_bh(&ilb->lock);
                                        goto done;
                                }
 
@@ -729,45 +764,51 @@ next_listen:
                                cb->args[4] = 0;
                                ++num;
                        }
+                       spin_unlock_bh(&ilb->lock);
 
                        s_num = 0;
                        cb->args[3] = 0;
                        cb->args[4] = 0;
                }
-               inet_listen_unlock(hashinfo);
 skip_listen_ht:
                cb->args[0] = 1;
                s_i = num = s_num = 0;
        }
 
        if (!(r->idiag_states & ~(TCPF_LISTEN | TCPF_SYN_RECV)))
-               return skb->len;
+               goto unlock;
 
-       for (i = s_i; i < hashinfo->ehash_size; i++) {
+       for (i = s_i; i <= hashinfo->ehash_mask; i++) {
                struct inet_ehash_bucket *head = &hashinfo->ehash[i];
+               spinlock_t *lock = inet_ehash_lockp(hashinfo, i);
                struct sock *sk;
-               struct hlist_node *node;
+               struct hlist_nulls_node *node;
+
+               num = 0;
+
+               if (hlist_nulls_empty(&head->chain) &&
+                       hlist_nulls_empty(&head->twchain))
+                       continue;
 
                if (i > s_i)
                        s_num = 0;
 
-               read_lock_bh(&head->lock);
-               num = 0;
-               sk_for_each(sk, node, &head->chain) {
+               spin_lock_bh(lock);
+               sk_nulls_for_each(sk, node, &head->chain) {
                        struct inet_sock *inet = inet_sk(sk);
 
                        if (num < s_num)
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
                                goto next_normal;
-                       if (r->id.idiag_sport != inet->sport &&
+                       if (r->id.idiag_sport != inet->inet_sport &&
                            r->id.idiag_sport)
                                goto next_normal;
-                       if (r->id.idiag_dport != inet->dport &&
+                       if (r->id.idiag_dport != inet->inet_dport &&
                            r->id.idiag_dport)
                                goto next_normal;
                        if (inet_csk_diag_dump(sk, skb, cb) < 0) {
-                               read_unlock_bh(&head->lock);
+                               spin_unlock_bh(lock);
                                goto done;
                        }
 next_normal:
@@ -789,88 +830,60 @@ next_normal:
                                    r->id.idiag_dport)
                                        goto next_dying;
                                if (inet_twsk_diag_dump(tw, skb, cb) < 0) {
-                                       read_unlock_bh(&head->lock);
+                                       spin_unlock_bh(lock);
                                        goto done;
                                }
 next_dying:
                                ++num;
                        }
                }
-               read_unlock_bh(&head->lock);
+               spin_unlock_bh(lock);
        }
 
 done:
        cb->args[1] = i;
        cb->args[2] = num;
+unlock:
+       inet_diag_unlock_handler(handler);
        return skb->len;
 }
 
-static inline int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
+static int inet_diag_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
 {
-       if (!(nlh->nlmsg_flags&NLM_F_REQUEST))
-               return 0;
-
-       if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX)
-               goto err_inval;
-
-       if (inet_diag_table[nlh->nlmsg_type] == NULL)
-               return -ENOENT;
-
-       if (NLMSG_LENGTH(sizeof(struct inet_diag_req)) > skb->len)
-               goto err_inval;
-
-       if (nlh->nlmsg_flags&NLM_F_DUMP) {
-               if (nlh->nlmsg_len >
-                   (4 + NLMSG_SPACE(sizeof(struct inet_diag_req)))) {
-                       struct rtattr *rta = (void *)(NLMSG_DATA(nlh) +
-                                                sizeof(struct inet_diag_req));
-                       if (rta->rta_type != INET_DIAG_REQ_BYTECODE ||
-                           rta->rta_len < 8 ||
-                           rta->rta_len >
-                           (nlh->nlmsg_len -
-                            NLMSG_SPACE(sizeof(struct inet_diag_req))))
-                               goto err_inval;
-                       if (inet_diag_bc_audit(RTA_DATA(rta), RTA_PAYLOAD(rta)))
-                               goto err_inval;
+       int hdrlen = sizeof(struct inet_diag_req);      /* passed below as the header length preceding attributes */
+
+       if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||
+           nlmsg_len(nlh) < hdrlen)
+               return -EINVAL; /* message type out of range, or payload shorter than the request header */
+
+       if (nlh->nlmsg_flags & NLM_F_DUMP) {    /* dump request: vet any filter bytecode before starting the dump */
+               if (nlmsg_attrlen(nlh, hdrlen)) {
+                       struct nlattr *attr;
+
+                       attr = nlmsg_find_attr(nlh, hdrlen,
+                                              INET_DIAG_REQ_BYTECODE);
+                       if (attr == NULL ||
+                           nla_len(attr) < sizeof(struct inet_diag_bc_op) ||
+                           inet_diag_bc_audit(nla_data(attr), nla_len(attr)))
+                               return -EINVAL; /* no bytecode attribute, shorter than one op, or audit rejected it */
                 }
+
                return netlink_dump_start(idiagnl, skb, nlh,
                                          inet_diag_dump, NULL);
-       } else
-               return inet_diag_get_exact(skb, nlh);
+       }
 
-err_inval:
-       return -EINVAL;
+       return inet_diag_get_exact(skb, nlh);   /* NLM_F_DUMP not set: single-socket query */
 }
 
+static DEFINE_MUTEX(inet_diag_mutex);
 
-static inline void inet_diag_rcv_skb(struct sk_buff *skb)
-{
-       if (skb->len >= NLMSG_SPACE(0)) {
-               int err;
-               struct nlmsghdr *nlh = (struct nlmsghdr *)skb->data;
-
-               if (nlh->nlmsg_len < sizeof(*nlh) ||
-                   skb->len < nlh->nlmsg_len)
-                       return;
-               err = inet_diag_rcv_msg(skb, nlh);
-               if (err || nlh->nlmsg_flags & NLM_F_ACK)
-                       netlink_ack(skb, nlh, err);
-       }
-}
-
-static void inet_diag_rcv(struct sock *sk, int len)
+static void inet_diag_rcv(struct sk_buff *skb) /* receive callback handed to netlink_kernel_create() */
 {
-       struct sk_buff *skb;
-       unsigned int qlen = skb_queue_len(&sk->sk_receive_queue);
-
-       while (qlen-- && (skb = skb_dequeue(&sk->sk_receive_queue))) {
-               inet_diag_rcv_skb(skb);
-               kfree_skb(skb);
-       }
+       mutex_lock(&inet_diag_mutex);   /* one skb at a time goes through inet_diag_rcv_msg() */
+       netlink_rcv_skb(skb, &inet_diag_rcv_msg);       /* inet_diag_rcv_msg is the per-message callback */
+       mutex_unlock(&inet_diag_mutex);
 }
 
-static DEFINE_SPINLOCK(inet_diag_register_lock);
-
 int inet_diag_register(const struct inet_diag_handler *h)
 {
        const __u16 type = h->idiag_type;
@@ -879,13 +892,13 @@ int inet_diag_register(const struct inet_diag_handler *h)
        if (type >= INET_DIAG_GETSOCK_MAX)
                goto out;
 
-       spin_lock(&inet_diag_register_lock);
+       mutex_lock(&inet_diag_table_mutex);
        err = -EEXIST;
        if (inet_diag_table[type] == NULL) {
                inet_diag_table[type] = h;
                err = 0;
        }
-       spin_unlock(&inet_diag_register_lock);
+       mutex_unlock(&inet_diag_table_mutex);
 out:
        return err;
 }
@@ -898,11 +911,9 @@ void inet_diag_unregister(const struct inet_diag_handler *h)
        if (type >= INET_DIAG_GETSOCK_MAX)
                return;
 
-       spin_lock(&inet_diag_register_lock);
+       mutex_lock(&inet_diag_table_mutex);
        inet_diag_table[type] = NULL;
-       spin_unlock(&inet_diag_register_lock);
-
-       synchronize_rcu();
+       mutex_unlock(&inet_diag_table_mutex);
 }
 EXPORT_SYMBOL_GPL(inet_diag_unregister);
 
@@ -916,8 +927,8 @@ static int __init inet_diag_init(void)
        if (!inet_diag_table)
                goto out;
 
-       idiagnl = netlink_kernel_create(NETLINK_INET_DIAG, 0, inet_diag_rcv,
-                                       THIS_MODULE);
+       idiagnl = netlink_kernel_create(&init_net, NETLINK_INET_DIAG, 0,
+                                       inet_diag_rcv, NULL, THIS_MODULE);
        if (idiagnl == NULL)
                goto out_free_table;
        err = 0;
@@ -930,10 +941,11 @@ out_free_table:
 
 static void __exit inet_diag_exit(void)
 {
-       sock_release(idiagnl->sk_socket);
+       netlink_kernel_release(idiagnl);        /* idiagnl is the socket created in inet_diag_init() */
        kfree(inet_diag_table);
 }
 
 module_init(inet_diag_init);
 module_exit(inet_diag_exit);
 MODULE_LICENSE("GPL");
+MODULE_ALIAS_NET_PF_PROTO(PF_NETLINK, NETLINK_INET_DIAG);