inet_diag: fix inet_diag_bc_audit(), CVE-2011-2213
[linux-flexiantxendom0-natty.git] / net / ipv4 / inet_diag.c
index cb73fde..65c23d9 100644 (file)
@@ -14,6 +14,7 @@
 #include <linux/types.h>
 #include <linux/fcntl.h>
 #include <linux/random.h>
+#include <linux/slab.h>
 #include <linux/cache.h>
 #include <linux/init.h>
 #include <linux/time.h>
@@ -116,10 +117,10 @@ static int inet_csk_diag_fill(struct sock *sk,
        r->id.idiag_cookie[0] = (u32)(unsigned long)sk;
        r->id.idiag_cookie[1] = (u32)(((unsigned long)sk >> 31) >> 1);
 
-       r->id.idiag_sport = inet->sport;
-       r->id.idiag_dport = inet->dport;
-       r->id.idiag_src[0] = inet->rcv_saddr;
-       r->id.idiag_dst[0] = inet->daddr;
+       r->id.idiag_sport = inet->inet_sport;
+       r->id.idiag_dport = inet->inet_dport;
+       r->id.idiag_src[0] = inet->inet_rcv_saddr;
+       r->id.idiag_dst[0] = inet->inet_daddr;
 
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
        if (r->idiag_family == AF_INET6) {
@@ -368,7 +369,7 @@ static int inet_diag_bc_run(const void *bc, int len,
                        yes = entry->sport >= op[1].no;
                        break;
                case INET_DIAG_BC_S_LE:
-                       yes = entry->dport <= op[1].no;
+                       yes = entry->sport <= op[1].no;
                        break;
                case INET_DIAG_BC_D_GE:
                        yes = entry->dport >= op[1].no;
@@ -424,7 +425,7 @@ static int inet_diag_bc_run(const void *bc, int len,
                        bc += op->no;
                }
        }
-       return (len == 0);
+       return len == 0;
 }
 
 static int valid_cc(const void *bc, int len, int cc)
@@ -436,7 +437,7 @@ static int valid_cc(const void *bc, int len, int cc)
                        return 0;
                if (cc == len)
                        return 1;
-               if (op->yes < 4)
+               if (op->yes < 4 || op->yes & 3)
                        return 0;
                len -= op->yes;
                bc  += op->yes;
@@ -446,11 +447,11 @@ static int valid_cc(const void *bc, int len, int cc)
 
 static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
 {
-       const unsigned char *bc = bytecode;
+       const void *bc = bytecode;
        int  len = bytecode_len;
 
        while (len > 0) {
-               struct inet_diag_bc_op *op = (struct inet_diag_bc_op *)bc;
+               const struct inet_diag_bc_op *op = bc;
 
 //printk("BC: %d %d %d {%d} / %d\n", op->code, op->yes, op->no, op[1].no, len);
                switch (op->code) {
@@ -461,22 +462,20 @@ static int inet_diag_bc_audit(const void *bytecode, int bytecode_len)
                case INET_DIAG_BC_S_LE:
                case INET_DIAG_BC_D_GE:
                case INET_DIAG_BC_D_LE:
-                       if (op->yes < 4 || op->yes > len + 4)
-                               return -EINVAL;
                case INET_DIAG_BC_JMP:
-                       if (op->no < 4 || op->no > len + 4)
+                       if (op->no < 4 || op->no > len + 4 || op->no & 3)
                                return -EINVAL;
                        if (op->no < len &&
                            !valid_cc(bytecode, bytecode_len, len - op->no))
                                return -EINVAL;
                        break;
                case INET_DIAG_BC_NOP:
-                       if (op->yes < 4 || op->yes > len + 4)
-                               return -EINVAL;
                        break;
                default:
                        return -EINVAL;
                }
+               if (op->yes < 4 || op->yes > len + 4 || op->yes & 3)
+                       return -EINVAL;
                bc  += op->yes;
                len -= op->yes;
        }
@@ -489,9 +488,11 @@ static int inet_csk_diag_dump(struct sock *sk,
 {
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
                struct inet_diag_entry entry;
-               struct rtattr *bc = (struct rtattr *)(r + 1);
+               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
+                                                         sizeof(*r),
+                                                         INET_DIAG_REQ_BYTECODE);
                struct inet_sock *inet = inet_sk(sk);
 
                entry.family = sk->sk_family;
@@ -504,14 +505,14 @@ static int inet_csk_diag_dump(struct sock *sk,
                } else
 #endif
                {
-                       entry.saddr = &inet->rcv_saddr;
-                       entry.daddr = &inet->daddr;
+                       entry.saddr = &inet->inet_rcv_saddr;
+                       entry.daddr = &inet->inet_daddr;
                }
-               entry.sport = inet->num;
-               entry.dport = ntohs(inet->dport);
+               entry.sport = inet->inet_num;
+               entry.dport = ntohs(inet->inet_dport);
                entry.userlocks = sk->sk_userlocks;
 
-               if (!inet_diag_bc_run(RTA_DATA(bc), RTA_PAYLOAD(bc), &entry))
+               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
                        return 0;
        }
 
@@ -526,9 +527,11 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
 {
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
                struct inet_diag_entry entry;
-               struct rtattr *bc = (struct rtattr *)(r + 1);
+               const struct nlattr *bc = nlmsg_find_attr(cb->nlh,
+                                                         sizeof(*r),
+                                                         INET_DIAG_REQ_BYTECODE);
 
                entry.family = tw->tw_family;
 #if defined(CONFIG_IPV6) || defined (CONFIG_IPV6_MODULE)
@@ -547,7 +550,7 @@ static int inet_twsk_diag_dump(struct inet_timewait_sock *tw,
                entry.dport = ntohs(tw->tw_dport);
                entry.userlocks = 0;
 
-               if (!inet_diag_bc_run(RTA_DATA(bc), RTA_PAYLOAD(bc), &entry))
+               if (!inet_diag_bc_run(nla_data(bc), nla_len(bc), &entry))
                        return 0;
        }
 
@@ -584,7 +587,7 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,
        if (tmo < 0)
                tmo = 0;
 
-       r->id.idiag_sport = inet->sport;
+       r->id.idiag_sport = inet->inet_sport;
        r->id.idiag_dport = ireq->rmt_port;
        r->id.idiag_src[0] = ireq->loc_addr;
        r->id.idiag_dst[0] = ireq->rmt_addr;
@@ -617,7 +620,7 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        struct inet_diag_req *r = NLMSG_DATA(cb->nlh);
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct listen_sock *lopt;
-       struct rtattr *bc = NULL;
+       const struct nlattr *bc = NULL;
        struct inet_sock *inet = inet_sk(sk);
        int j, s_j;
        int reqnum, s_reqnum;
@@ -637,9 +640,10 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
        if (!lopt || !lopt->qlen)
                goto out;
 
-       if (cb->nlh->nlmsg_len > 4 + NLMSG_SPACE(sizeof(*r))) {
-               bc = (struct rtattr *)(r + 1);
-               entry.sport = inet->num;
+       if (nlmsg_attrlen(cb->nlh, sizeof(*r))) {
+               bc = nlmsg_find_attr(cb->nlh, sizeof(*r),
+                                    INET_DIAG_REQ_BYTECODE);
+               entry.sport = inet->inet_num;
                entry.userlocks = sk->sk_userlocks;
        }
 
@@ -671,8 +675,8 @@ static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk,
                                        &ireq->rmt_addr;
                                entry.dport = ntohs(ireq->rmt_port);
 
-                               if (!inet_diag_bc_run(RTA_DATA(bc),
-                                                   RTA_PAYLOAD(bc), &entry))
+                               if (!inet_diag_bc_run(nla_data(bc),
+                                                     nla_len(bc), &entry))
                                        continue;
                        }
 
@@ -732,7 +736,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
                                        continue;
                                }
 
-                               if (r->id.idiag_sport != inet->sport &&
+                               if (r->id.idiag_sport != inet->inet_sport &&
                                    r->id.idiag_sport)
                                        goto next_listen;
 
@@ -797,10 +801,10 @@ skip_listen_ht:
                                goto next_normal;
                        if (!(r->idiag_states & (1 << sk->sk_state)))
                                goto next_normal;
-                       if (r->id.idiag_sport != inet->sport &&
+                       if (r->id.idiag_sport != inet->inet_sport &&
                            r->id.idiag_sport)
                                goto next_normal;
-                       if (r->id.idiag_dport != inet->dport &&
+                       if (r->id.idiag_dport != inet->inet_dport &&
                            r->id.idiag_dport)
                                goto next_normal;
                        if (inet_csk_diag_dump(sk, skb, cb) < 0) {