SUNRPC: Deal with the lack of a SYN_SENT sk->sk_state_change callback...
[linux-flexiantxendom0-natty.git] / net / sunrpc / xprtsock.c
index b7cd8cc..3e0b5f1 100644 (file)
@@ -210,7 +210,8 @@ struct sock_xprt {
         * State of TCP reply receive
         */
        __be32                  tcp_fraghdr,
-                               tcp_xid;
+                               tcp_xid,
+                               tcp_calldir;
 
        u32                     tcp_offset,
                                tcp_reclen;
@@ -709,6 +710,8 @@ static void xs_reset_transport(struct sock_xprt *transport)
        if (sk == NULL)
                return;
 
+       transport->srcport = 0;
+
        write_lock_bh(&sk->sk_callback_lock);
        transport->inet = NULL;
        transport->sock = NULL;
@@ -769,12 +772,11 @@ static void xs_destroy(struct rpc_xprt *xprt)
 
        dprintk("RPC:       xs_destroy xprt %p\n", xprt);
 
-       cancel_rearming_delayed_work(&transport->connect_worker);
+       cancel_delayed_work_sync(&transport->connect_worker);
 
        xs_close(xprt);
        xs_free_peer_addresses(xprt);
-       kfree(xprt->slot);
-       kfree(xprt);
+       xprt_free(xprt);
        module_put(THIS_MODULE);
 }
 
@@ -799,7 +801,7 @@ static void xs_udp_data_ready(struct sock *sk, int len)
        u32 _xid;
        __be32 *xp;
 
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
        dprintk("RPC:       xs_udp_data_ready...\n");
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
@@ -851,7 +853,7 @@ static void xs_udp_data_ready(struct sock *sk, int len)
  dropit:
        skb_free_datagram(sk, skb);
  out:
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc)
@@ -927,7 +929,7 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport,
 {
        size_t len, used;
        u32 offset;
-       __be32  calldir;
+       char *p;
 
        /*
         * We want transport->tcp_offset to be 8 at the end of this routine
@@ -936,26 +938,33 @@ static inline void xs_tcp_read_calldir(struct sock_xprt *transport,
         * transport->tcp_offset is 4 (after having already read the xid).
         */
        offset = transport->tcp_offset - sizeof(transport->tcp_xid);
-       len = sizeof(calldir) - offset;
+       len = sizeof(transport->tcp_calldir) - offset;
        dprintk("RPC:       reading CALL/REPLY flag (%Zu bytes)\n", len);
-       used = xdr_skb_read_bits(desc, &calldir, len);
+       p = ((char *) &transport->tcp_calldir) + offset;
+       used = xdr_skb_read_bits(desc, p, len);
        transport->tcp_offset += used;
        if (used != len)
                return;
        transport->tcp_flags &= ~TCP_RCV_READ_CALLDIR;
-       transport->tcp_flags |= TCP_RCV_COPY_CALLDIR;
-       transport->tcp_flags |= TCP_RCV_COPY_DATA;
        /*
         * We don't yet have the XDR buffer, so we will write the calldir
         * out after we get the buffer from the 'struct rpc_rqst'
         */
-       if (ntohl(calldir) == RPC_REPLY)
+       switch (ntohl(transport->tcp_calldir)) {
+       case RPC_REPLY:
+               transport->tcp_flags |= TCP_RCV_COPY_CALLDIR;
+               transport->tcp_flags |= TCP_RCV_COPY_DATA;
                transport->tcp_flags |= TCP_RPC_REPLY;
-       else
+               break;
+       case RPC_CALL:
+               transport->tcp_flags |= TCP_RCV_COPY_CALLDIR;
+               transport->tcp_flags |= TCP_RCV_COPY_DATA;
                transport->tcp_flags &= ~TCP_RPC_REPLY;
-       dprintk("RPC:       reading %s CALL/REPLY flag %08x\n",
-                       (transport->tcp_flags & TCP_RPC_REPLY) ?
-                               "reply for" : "request with", calldir);
+               break;
+       default:
+               dprintk("RPC:       invalid request message type\n");
+               xprt_force_disconnect(&transport->xprt);
+       }
        xs_tcp_check_fraghdr(transport);
 }
 
@@ -975,12 +984,10 @@ static inline void xs_tcp_read_common(struct rpc_xprt *xprt,
                /*
                 * Save the RPC direction in the XDR buffer
                 */
-               __be32  calldir = transport->tcp_flags & TCP_RPC_REPLY ?
-                                       htonl(RPC_REPLY) : 0;
-
                memcpy(rcvbuf->head[0].iov_base + transport->tcp_copied,
-                       &calldir, sizeof(calldir));
-               transport->tcp_copied += sizeof(calldir);
+                       &transport->tcp_calldir,
+                       sizeof(transport->tcp_calldir));
+               transport->tcp_copied += sizeof(transport->tcp_calldir);
                transport->tcp_flags &= ~TCP_RCV_COPY_CALLDIR;
        }
 
@@ -1223,7 +1230,7 @@ static void xs_tcp_data_ready(struct sock *sk, int bytes)
 
        dprintk("RPC:       xs_tcp_data_ready...\n");
 
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        if (xprt->shutdown)
@@ -1242,7 +1249,7 @@ static void xs_tcp_data_ready(struct sock *sk, int bytes)
                read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv);
        } while (read > 0);
 out:
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 /*
@@ -1295,18 +1302,19 @@ static void xs_tcp_state_change(struct sock *sk)
 {
        struct rpc_xprt *xprt;
 
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        dprintk("RPC:       xs_tcp_state_change client %p...\n", xprt);
-       dprintk("RPC:       state %x conn %d dead %d zapped %d\n",
+       dprintk("RPC:       state %x conn %d dead %d zapped %d sk_shutdown %d\n",
                        sk->sk_state, xprt_connected(xprt),
                        sock_flag(sk, SOCK_DEAD),
-                       sock_flag(sk, SOCK_ZAPPED));
+                       sock_flag(sk, SOCK_ZAPPED),
+                       sk->sk_shutdown);
 
        switch (sk->sk_state) {
        case TCP_ESTABLISHED:
-               spin_lock_bh(&xprt->transport_lock);
+               spin_lock(&xprt->transport_lock);
                if (!xprt_test_and_set_connected(xprt)) {
                        struct sock_xprt *transport = container_of(xprt,
                                        struct sock_xprt, xprt);
@@ -1320,7 +1328,7 @@ static void xs_tcp_state_change(struct sock *sk)
 
                        xprt_wake_pending_tasks(xprt, -EAGAIN);
                }
-               spin_unlock_bh(&xprt->transport_lock);
+               spin_unlock(&xprt->transport_lock);
                break;
        case TCP_FIN_WAIT1:
                /* The client initiated a shutdown of the socket */
@@ -1336,7 +1344,6 @@ static void xs_tcp_state_change(struct sock *sk)
        case TCP_CLOSE_WAIT:
                /* The server initiated a shutdown of the socket */
                xprt_force_disconnect(xprt);
-       case TCP_SYN_SENT:
                xprt->connect_cookie++;
        case TCP_CLOSING:
                /*
@@ -1358,7 +1365,7 @@ static void xs_tcp_state_change(struct sock *sk)
                xs_sock_mark_closed(xprt);
        }
  out:
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 /**
@@ -1369,7 +1376,7 @@ static void xs_error_report(struct sock *sk)
 {
        struct rpc_xprt *xprt;
 
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
        if (!(xprt = xprt_from_sock(sk)))
                goto out;
        dprintk("RPC:       %s client %p...\n"
@@ -1377,7 +1384,7 @@ static void xs_error_report(struct sock *sk)
                        __func__, xprt, sk->sk_err);
        xprt_wake_pending_tasks(xprt, -EAGAIN);
 out:
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 static void xs_write_space(struct sock *sk)
@@ -1409,13 +1416,13 @@ static void xs_write_space(struct sock *sk)
  */
 static void xs_udp_write_space(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
 
        /* from net/core/sock.c:sock_def_write_space */
        if (sock_writeable(sk))
                xs_write_space(sk);
 
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 /**
@@ -1430,13 +1437,13 @@ static void xs_udp_write_space(struct sock *sk)
  */
 static void xs_tcp_write_space(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
+       read_lock_bh(&sk->sk_callback_lock);
 
        /* from net/core/stream.c:sk_stream_write_space */
        if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk))
                xs_write_space(sk);
 
-       read_unlock(&sk->sk_callback_lock);
+       read_unlock_bh(&sk->sk_callback_lock);
 }
 
 static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt)
@@ -1509,7 +1516,7 @@ static void xs_set_port(struct rpc_xprt *xprt, unsigned short port)
        xs_update_peer_port(xprt);
 }
 
-static unsigned short xs_get_srcport(struct sock_xprt *transport, struct socket *sock)
+static unsigned short xs_get_srcport(struct sock_xprt *transport)
 {
        unsigned short port = transport->srcport;
 
@@ -1518,7 +1525,7 @@ static unsigned short xs_get_srcport(struct sock_xprt *transport, struct socket
        return port;
 }
 
-static unsigned short xs_next_srcport(struct sock_xprt *transport, struct socket *sock, unsigned short port)
+static unsigned short xs_next_srcport(struct sock_xprt *transport, unsigned short port)
 {
        if (transport->srcport != 0)
                transport->srcport = 0;
@@ -1528,23 +1535,18 @@ static unsigned short xs_next_srcport(struct sock_xprt *transport, struct socket
                return xprt_max_resvport;
        return --port;
 }
-
-static int xs_bind4(struct sock_xprt *transport, struct socket *sock)
+static int xs_bind(struct sock_xprt *transport, struct socket *sock)
 {
-       struct sockaddr_in myaddr = {
-               .sin_family = AF_INET,
-       };
-       struct sockaddr_in *sa;
+       struct sockaddr_storage myaddr;
        int err, nloop = 0;
-       unsigned short port = xs_get_srcport(transport, sock);
+       unsigned short port = xs_get_srcport(transport);
        unsigned short last;
 
-       sa = (struct sockaddr_in *)&transport->srcaddr;
-       myaddr.sin_addr = sa->sin_addr;
+       memcpy(&myaddr, &transport->srcaddr, transport->xprt.addrlen);
        do {
-               myaddr.sin_port = htons(port);
-               err = kernel_bind(sock, (struct sockaddr *) &myaddr,
-                                               sizeof(myaddr));
+               rpc_set_port((struct sockaddr *)&myaddr, port);
+               err = kernel_bind(sock, (struct sockaddr *)&myaddr,
+                               transport->xprt.addrlen);
                if (port == 0)
                        break;
                if (err == 0) {
@@ -1552,48 +1554,23 @@ static int xs_bind4(struct sock_xprt *transport, struct socket *sock)
                        break;
                }
                last = port;
-               port = xs_next_srcport(transport, sock, port);
+               port = xs_next_srcport(transport, port);
                if (port > last)
                        nloop++;
        } while (err == -EADDRINUSE && nloop != 2);
-       dprintk("RPC:       %s %pI4:%u: %s (%d)\n",
-                       __func__, &myaddr.sin_addr,
-                       port, err ? "failed" : "ok", err);
-       return err;
-}
 
-static int xs_bind6(struct sock_xprt *transport, struct socket *sock)
-{
-       struct sockaddr_in6 myaddr = {
-               .sin6_family = AF_INET6,
-       };
-       struct sockaddr_in6 *sa;
-       int err, nloop = 0;
-       unsigned short port = xs_get_srcport(transport, sock);
-       unsigned short last;
-
-       sa = (struct sockaddr_in6 *)&transport->srcaddr;
-       myaddr.sin6_addr = sa->sin6_addr;
-       do {
-               myaddr.sin6_port = htons(port);
-               err = kernel_bind(sock, (struct sockaddr *) &myaddr,
-                                               sizeof(myaddr));
-               if (port == 0)
-                       break;
-               if (err == 0) {
-                       transport->srcport = port;
-                       break;
-               }
-               last = port;
-               port = xs_next_srcport(transport, sock, port);
-               if (port > last)
-                       nloop++;
-       } while (err == -EADDRINUSE && nloop != 2);
-       dprintk("RPC:       xs_bind6 %pI6:%u: %s (%d)\n",
-               &myaddr.sin6_addr, port, err ? "failed" : "ok", err);
+       if (myaddr.ss_family == AF_INET)
+               dprintk("RPC:       %s %pI4:%u: %s (%d)\n", __func__,
+                               &((struct sockaddr_in *)&myaddr)->sin_addr,
+                               port, err ? "failed" : "ok", err);
+       else
+               dprintk("RPC:       %s %pI6:%u: %s (%d)\n", __func__,
+                               &((struct sockaddr_in6 *)&myaddr)->sin6_addr,
+                               port, err ? "failed" : "ok", err);
        return err;
 }
 
+
 #ifdef CONFIG_DEBUG_LOCK_ALLOC
 static struct lock_class_key xs_key[2];
 static struct lock_class_key xs_slock_key[2];
@@ -1615,6 +1592,18 @@ static inline void xs_reclassify_socket6(struct socket *sock)
        sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC",
                &xs_slock_key[1], "sk_lock-AF_INET6-RPC", &xs_key[1]);
 }
+
+static inline void xs_reclassify_socket(int family, struct socket *sock)
+{
+       switch (family) {
+       case AF_INET:
+               xs_reclassify_socket4(sock);
+               break;
+       case AF_INET6:
+               xs_reclassify_socket6(sock);
+               break;
+       }
+}
 #else
 static inline void xs_reclassify_socket4(struct socket *sock)
 {
@@ -1623,8 +1612,37 @@ static inline void xs_reclassify_socket4(struct socket *sock)
 static inline void xs_reclassify_socket6(struct socket *sock)
 {
 }
+
+static inline void xs_reclassify_socket(int family, struct socket *sock)
+{
+}
 #endif
 
+static struct socket *xs_create_sock(struct rpc_xprt *xprt,
+               struct sock_xprt *transport, int family, int type, int protocol)
+{
+       struct socket *sock;
+       int err;
+
+       err = __sock_create(xprt->xprt_net, family, type, protocol, &sock, 1);
+       if (err < 0) {
+               dprintk("RPC:       can't create %d transport socket (%d).\n",
+                               protocol, -err);
+               goto out;
+       }
+       xs_reclassify_socket(family, sock);
+
+       err = xs_bind(transport, sock);
+       if (err) {
+               sock_release(sock);
+               goto out;
+       }
+
+       return sock;
+out:
+       return ERR_PTR(err);
+}
+
 static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
 {
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
@@ -1654,82 +1672,23 @@ static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
        xs_udp_do_set_buffer_size(xprt);
 }
 
-/**
- * xs_udp_connect_worker4 - set up a UDP socket
- * @work: RPC transport to connect
- *
- * Invoked by a work queue tasklet.
- */
-static void xs_udp_connect_worker4(struct work_struct *work)
-{
-       struct sock_xprt *transport =
-               container_of(work, struct sock_xprt, connect_worker.work);
-       struct rpc_xprt *xprt = &transport->xprt;
-       struct socket *sock = transport->sock;
-       int err, status = -EIO;
-
-       if (xprt->shutdown)
-               goto out;
-
-       /* Start by resetting any existing state */
-       xs_reset_transport(transport);
-
-       err = sock_create_kern(PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock);
-       if (err < 0) {
-               dprintk("RPC:       can't create UDP transport socket (%d).\n", -err);
-               goto out;
-       }
-       xs_reclassify_socket4(sock);
-
-       if (xs_bind4(transport, sock)) {
-               sock_release(sock);
-               goto out;
-       }
-
-       dprintk("RPC:       worker connecting xprt %p via %s to "
-                               "%s (port %s)\n", xprt,
-                       xprt->address_strings[RPC_DISPLAY_PROTO],
-                       xprt->address_strings[RPC_DISPLAY_ADDR],
-                       xprt->address_strings[RPC_DISPLAY_PORT]);
-
-       xs_udp_finish_connecting(xprt, sock);
-       status = 0;
-out:
-       xprt_clear_connecting(xprt);
-       xprt_wake_pending_tasks(xprt, status);
-}
-
-/**
- * xs_udp_connect_worker6 - set up a UDP socket
- * @work: RPC transport to connect
- *
- * Invoked by a work queue tasklet.
- */
-static void xs_udp_connect_worker6(struct work_struct *work)
+static void xs_udp_setup_socket(struct work_struct *work)
 {
        struct sock_xprt *transport =
                container_of(work, struct sock_xprt, connect_worker.work);
        struct rpc_xprt *xprt = &transport->xprt;
        struct socket *sock = transport->sock;
-       int err, status = -EIO;
+       int status = -EIO;
 
        if (xprt->shutdown)
                goto out;
 
        /* Start by resetting any existing state */
        xs_reset_transport(transport);
-
-       err = sock_create_kern(PF_INET6, SOCK_DGRAM, IPPROTO_UDP, &sock);
-       if (err < 0) {
-               dprintk("RPC:       can't create UDP transport socket (%d).\n", -err);
-               goto out;
-       }
-       xs_reclassify_socket6(sock);
-
-       if (xs_bind6(transport, sock) < 0) {
-               sock_release(sock);
+       sock = xs_create_sock(xprt, transport,
+                       xs_addr(xprt)->sa_family, SOCK_DGRAM, IPPROTO_UDP);
+       if (IS_ERR(sock))
                goto out;
-       }
 
        dprintk("RPC:       worker connecting xprt %p via %s to "
                                "%s (port %s)\n", xprt,
@@ -1748,12 +1707,12 @@ out:
  * We need to preserve the port number so the reply cache on the server can
  * find our cached RPC replies when we get around to reconnecting.
  */
-static void xs_abort_connection(struct rpc_xprt *xprt, struct sock_xprt *transport)
+static void xs_abort_connection(struct sock_xprt *transport)
 {
        int result;
        struct sockaddr any;
 
-       dprintk("RPC:       disconnecting xprt %p to reuse port\n", xprt);
+       dprintk("RPC:       disconnecting xprt %p to reuse port\n", transport);
 
        /*
         * Disconnect the transport socket by doing a connect operation
@@ -1763,26 +1722,42 @@ static void xs_abort_connection(struct rpc_xprt *xprt, struct sock_xprt *transpo
        any.sa_family = AF_UNSPEC;
        result = kernel_connect(transport->sock, &any, sizeof(any), 0);
        if (!result)
-               xs_sock_mark_closed(xprt);
+               xs_sock_mark_closed(&transport->xprt);
        else
                dprintk("RPC:       AF_UNSPEC connect return code %d\n",
                                result);
 }
 
-static void xs_tcp_reuse_connection(struct rpc_xprt *xprt, struct sock_xprt *transport)
+static void xs_tcp_reuse_connection(struct sock_xprt *transport)
 {
        unsigned int state = transport->inet->sk_state;
 
-       if (state == TCP_CLOSE && transport->sock->state == SS_UNCONNECTED)
-               return;
-       if ((1 << state) & (TCPF_ESTABLISHED|TCPF_SYN_SENT))
-               return;
-       xs_abort_connection(xprt, transport);
+       if (state == TCP_CLOSE && transport->sock->state == SS_UNCONNECTED) {
+               /* we don't need to abort the connection if the socket
+                * hasn't undergone a shutdown
+                */
+               if (transport->inet->sk_shutdown == 0)
+                       return;
+               dprintk("RPC:       %s: TCP_CLOSEd and sk_shutdown set to %d\n",
+                               __func__, transport->inet->sk_shutdown);
+       }
+       if ((1 << state) & (TCPF_ESTABLISHED|TCPF_SYN_SENT)) {
+               /* we don't need to abort the connection if the socket
+                * hasn't undergone a shutdown
+                */
+               if (transport->inet->sk_shutdown == 0)
+                       return;
+               dprintk("RPC:       %s: ESTABLISHED/SYN_SENT "
+                               "sk_shutdown set to %d\n",
+                               __func__, transport->inet->sk_shutdown);
+       }
+       xs_abort_connection(transport);
 }
 
 static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
 {
        struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt);
+       int ret = -ENOTCONN;
 
        if (!transport->inet) {
                struct sock *sk = sock->sk;
@@ -1814,12 +1789,22 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
        }
 
        if (!xprt_bound(xprt))
-               return -ENOTCONN;
+               goto out;
 
        /* Tell the socket layer to start connecting... */
        xprt->stat.connect_count++;
        xprt->stat.connect_start = jiffies;
-       return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+       ret = kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK);
+       switch (ret) {
+       case 0:
+       case -EINPROGRESS:
+               /* SYN_SENT! */
+               xprt->connect_cookie++;
+               if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO)
+                       xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO;
+       }
+out:
+       return ret;
 }
 
 /**
@@ -1830,12 +1815,12 @@ static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock)
  *
  * Invoked by a work queue tasklet.
  */
-static void xs_tcp_setup_socket(struct rpc_xprt *xprt,
-               struct sock_xprt *transport,
-               struct socket *(*create_sock)(struct rpc_xprt *,
-                       struct sock_xprt *))
+static void xs_tcp_setup_socket(struct work_struct *work)
 {
+       struct sock_xprt *transport =
+               container_of(work, struct sock_xprt, connect_worker.work);
        struct socket *sock = transport->sock;
+       struct rpc_xprt *xprt = &transport->xprt;
        int status = -EIO;
 
        if (xprt->shutdown)
@@ -1843,7 +1828,8 @@ static void xs_tcp_setup_socket(struct rpc_xprt *xprt,
 
        if (!sock) {
                clear_bit(XPRT_CONNECTION_ABORT, &xprt->state);
-               sock = create_sock(xprt, transport);
+               sock = xs_create_sock(xprt, transport,
+                               xs_addr(xprt)->sa_family, SOCK_STREAM, IPPROTO_TCP);
                if (IS_ERR(sock)) {
                        status = PTR_ERR(sock);
                        goto out;
@@ -1854,7 +1840,7 @@ static void xs_tcp_setup_socket(struct rpc_xprt *xprt,
                abort_and_exit = test_and_clear_bit(XPRT_CONNECTION_ABORT,
                                &xprt->state);
                /* "close" the socket, preserving the local port */
-               xs_tcp_reuse_connection(xprt, transport);
+               xs_tcp_reuse_connection(transport);
 
                if (abort_and_exit)
                        goto out_eagain;
@@ -1903,84 +1889,6 @@ out:
        xprt_wake_pending_tasks(xprt, status);
 }
 
-static struct socket *xs_create_tcp_sock4(struct rpc_xprt *xprt,
-               struct sock_xprt *transport)
-{
-       struct socket *sock;
-       int err;
-
-       /* start from scratch */
-       err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
-       if (err < 0) {
-               dprintk("RPC:       can't create TCP transport socket (%d).\n",
-                               -err);
-               goto out_err;
-       }
-       xs_reclassify_socket4(sock);
-
-       if (xs_bind4(transport, sock) < 0) {
-               sock_release(sock);
-               goto out_err;
-       }
-       return sock;
-out_err:
-       return ERR_PTR(-EIO);
-}
-
-/**
- * xs_tcp_connect_worker4 - connect a TCP socket to a remote endpoint
- * @work: RPC transport to connect
- *
- * Invoked by a work queue tasklet.
- */
-static void xs_tcp_connect_worker4(struct work_struct *work)
-{
-       struct sock_xprt *transport =
-               container_of(work, struct sock_xprt, connect_worker.work);
-       struct rpc_xprt *xprt = &transport->xprt;
-
-       xs_tcp_setup_socket(xprt, transport, xs_create_tcp_sock4);
-}
-
-static struct socket *xs_create_tcp_sock6(struct rpc_xprt *xprt,
-               struct sock_xprt *transport)
-{
-       struct socket *sock;
-       int err;
-
-       /* start from scratch */
-       err = sock_create_kern(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock);
-       if (err < 0) {
-               dprintk("RPC:       can't create TCP transport socket (%d).\n",
-                               -err);
-               goto out_err;
-       }
-       xs_reclassify_socket6(sock);
-
-       if (xs_bind6(transport, sock) < 0) {
-               sock_release(sock);
-               goto out_err;
-       }
-       return sock;
-out_err:
-       return ERR_PTR(-EIO);
-}
-
-/**
- * xs_tcp_connect_worker6 - connect a TCP socket to a remote endpoint
- * @work: RPC transport to connect
- *
- * Invoked by a work queue tasklet.
- */
-static void xs_tcp_connect_worker6(struct work_struct *work)
-{
-       struct sock_xprt *transport =
-               container_of(work, struct sock_xprt, connect_worker.work);
-       struct rpc_xprt *xprt = &transport->xprt;
-
-       xs_tcp_setup_socket(xprt, transport, xs_create_tcp_sock6);
-}
-
 /**
  * xs_connect - connect a socket to a remote endpoint
  * @task: address of RPC task that manages state of connect request
@@ -2240,6 +2148,31 @@ static struct rpc_xprt_ops bc_tcp_ops = {
        .print_stats            = xs_tcp_print_stats,
 };
 
+static int xs_init_anyaddr(const int family, struct sockaddr *sap)
+{
+       static const struct sockaddr_in sin = {
+               .sin_family             = AF_INET,
+               .sin_addr.s_addr        = htonl(INADDR_ANY),
+       };
+       static const struct sockaddr_in6 sin6 = {
+               .sin6_family            = AF_INET6,
+               .sin6_addr              = IN6ADDR_ANY_INIT,
+       };
+
+       switch (family) {
+       case AF_INET:
+               memcpy(sap, &sin, sizeof(sin));
+               break;
+       case AF_INET6:
+               memcpy(sap, &sin6, sizeof(sin6));
+               break;
+       default:
+               dprintk("RPC:       %s: Bad address family\n", __func__);
+               return -EAFNOSUPPORT;
+       }
+       return 0;
+}
+
 static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args,
                                      unsigned int slot_table_size)
 {
@@ -2251,27 +2184,28 @@ static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args,
                return ERR_PTR(-EBADF);
        }
 
-       new = kzalloc(sizeof(*new), GFP_KERNEL);
-       if (new == NULL) {
+       xprt = xprt_alloc(args->net, sizeof(*new), slot_table_size);
+       if (xprt == NULL) {
                dprintk("RPC:       xs_setup_xprt: couldn't allocate "
                                "rpc_xprt\n");
                return ERR_PTR(-ENOMEM);
        }
-       xprt = &new->xprt;
-
-       xprt->max_reqs = slot_table_size;
-       xprt->slot = kcalloc(xprt->max_reqs, sizeof(struct rpc_rqst), GFP_KERNEL);
-       if (xprt->slot == NULL) {
-               kfree(xprt);
-               dprintk("RPC:       xs_setup_xprt: couldn't allocate slot "
-                               "table\n");
-               return ERR_PTR(-ENOMEM);
-       }
 
+       new = container_of(xprt, struct sock_xprt, xprt);
        memcpy(&xprt->addr, args->dstaddr, args->addrlen);
        xprt->addrlen = args->addrlen;
        if (args->srcaddr)
                memcpy(&new->srcaddr, args->srcaddr, args->addrlen);
+       else {
+               int err;
+               err = xs_init_anyaddr(args->dstaddr->sa_family,
+                                       (struct sockaddr *)&new->srcaddr);
+               if (err != 0) {
+                       /* Don't leak the xprt obtained from xprt_alloc() above */
+                       xprt_free(xprt);
+                       return ERR_PTR(err);
+               }
+       }
 
        return xprt;
 }
@@ -2293,6 +2224,7 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args)
        struct sockaddr *addr = args->dstaddr;
        struct rpc_xprt *xprt;
        struct sock_xprt *transport;
+       struct rpc_xprt *ret;
 
        xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries);
        if (IS_ERR(xprt))
@@ -2318,7 +2250,7 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args)
                        xprt_set_bound(xprt);
 
                INIT_DELAYED_WORK(&transport->connect_worker,
-                                       xs_udp_connect_worker4);
+                                       xs_udp_setup_socket);
                xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP);
                break;
        case AF_INET6:
@@ -2326,12 +2258,12 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args)
                        xprt_set_bound(xprt);
 
                INIT_DELAYED_WORK(&transport->connect_worker,
-                                       xs_udp_connect_worker6);
+                                       xs_udp_setup_socket);
                xs_format_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6);
                break;
        default:
-               kfree(xprt);
-               return ERR_PTR(-EAFNOSUPPORT);
+               ret = ERR_PTR(-EAFNOSUPPORT);
+               goto out_err;
        }
 
        if (xprt_bound(xprt))
@@ -2346,10 +2278,10 @@ static struct rpc_xprt *xs_setup_udp(struct xprt_create *args)
 
        if (try_module_get(THIS_MODULE))
                return xprt;
-
-       kfree(xprt->slot);
-       kfree(xprt);
-       return ERR_PTR(-EINVAL);
+       ret = ERR_PTR(-EINVAL);
+out_err:
+       xprt_free(xprt);
+       return ret;
 }
 
 static const struct rpc_timeout xs_tcp_default_timeout = {
@@ -2368,6 +2300,7 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
        struct sockaddr *addr = args->dstaddr;
        struct rpc_xprt *xprt;
        struct sock_xprt *transport;
+       struct rpc_xprt *ret;
 
        xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries);
        if (IS_ERR(xprt))
@@ -2391,7 +2324,7 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
                        xprt_set_bound(xprt);
 
                INIT_DELAYED_WORK(&transport->connect_worker,
-                                       xs_tcp_connect_worker4);
+                                       xs_tcp_setup_socket);
                xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP);
                break;
        case AF_INET6:
@@ -2399,12 +2332,12 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
                        xprt_set_bound(xprt);
 
                INIT_DELAYED_WORK(&transport->connect_worker,
-                                       xs_tcp_connect_worker6);
+                                       xs_tcp_setup_socket);
                xs_format_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6);
                break;
        default:
-               kfree(xprt);
-               return ERR_PTR(-EAFNOSUPPORT);
+               ret = ERR_PTR(-EAFNOSUPPORT);
+               goto out_err;
        }
 
        if (xprt_bound(xprt))
@@ -2420,10 +2353,10 @@ static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args)
 
        if (try_module_get(THIS_MODULE))
                return xprt;
-
-       kfree(xprt->slot);
-       kfree(xprt);
-       return ERR_PTR(-EINVAL);
+       ret = ERR_PTR(-EINVAL);
+out_err:
+       xprt_free(xprt);
+       return ret;
 }
 
 /**
@@ -2437,7 +2370,17 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args)
        struct rpc_xprt *xprt;
        struct sock_xprt *transport;
        struct svc_sock *bc_sock;
+       struct rpc_xprt *ret;
 
+       if (args->bc_xprt->xpt_bc_xprt) {
+               /*
+                * This server connection already has a backchannel
+                * export; we can't create a new one, as we wouldn't be
+                * able to match replies based on xid any more.  So,
+                * reuse the already-existing one:
+                */
+                return args->bc_xprt->xpt_bc_xprt;
+       }
        xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries);
        if (IS_ERR(xprt))
                return xprt;
@@ -2454,16 +2397,6 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args)
        xprt->reestablish_timeout = 0;
        xprt->idle_timeout = 0;
 
-       /*
-        * The backchannel uses the same socket connection as the
-        * forechannel
-        */
-       xprt->bc_xprt = args->bc_xprt;
-       bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt);
-       bc_sock->sk_bc_xprt = xprt;
-       transport->sock = bc_sock->sk_sock;
-       transport->inet = bc_sock->sk_sk;
-
        xprt->ops = &bc_tcp_ops;
 
        switch (addr->sa_family) {
@@ -2476,19 +2409,28 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args)
                                   RPCBIND_NETID_TCP6);
                break;
        default:
-               kfree(xprt);
-               return ERR_PTR(-EAFNOSUPPORT);
+               ret = ERR_PTR(-EAFNOSUPPORT);
+               goto out_err;
        }
 
-       if (xprt_bound(xprt))
-               dprintk("RPC:       set up xprt to %s (port %s) via %s\n",
-                               xprt->address_strings[RPC_DISPLAY_ADDR],
-                               xprt->address_strings[RPC_DISPLAY_PORT],
-                               xprt->address_strings[RPC_DISPLAY_PROTO]);
-       else
-               dprintk("RPC:       set up xprt to %s (autobind) via %s\n",
-                               xprt->address_strings[RPC_DISPLAY_ADDR],
-                               xprt->address_strings[RPC_DISPLAY_PROTO]);
+       dprintk("RPC:       set up xprt to %s (port %s) via %s\n",
+                       xprt->address_strings[RPC_DISPLAY_ADDR],
+                       xprt->address_strings[RPC_DISPLAY_PORT],
+                       xprt->address_strings[RPC_DISPLAY_PROTO]);
+
+       /*
+        * Once we've associated a backchannel xprt with a connection,
+        * we want to keep it around as long as the connection
+        * lasts, in case we need to start using it for a backchannel
+        * again; this reference won't be dropped until bc_xprt is
+        * destroyed.
+        */
+       xprt_get(xprt);
+       args->bc_xprt->xpt_bc_xprt = xprt;
+       xprt->bc_xprt = args->bc_xprt;
+       bc_sock = container_of(args->bc_xprt, struct svc_sock, sk_xprt);
+       transport->sock = bc_sock->sk_sock;
+       transport->inet = bc_sock->sk_sk;
 
        /*
         * Since we don't want connections for the backchannel, we set
@@ -2499,9 +2441,11 @@ static struct rpc_xprt *xs_setup_bc_tcp(struct xprt_create *args)
 
        if (try_module_get(THIS_MODULE))
                return xprt;
-       kfree(xprt->slot);
-       kfree(xprt);
-       return ERR_PTR(-EINVAL);
+       xprt_put(xprt);
+       ret = ERR_PTR(-EINVAL);
+out_err:
+       xprt_free(xprt);
+       return ret;
 }
 
 static struct xprt_class       xs_udp_transport = {
@@ -2564,7 +2508,8 @@ void cleanup_socket_xprt(void)
        xprt_unregister_transport(&xs_bc_tcp_transport);
 }
 
-static int param_set_uint_minmax(const char *val, struct kernel_param *kp,
+static int param_set_uint_minmax(const char *val,
+               const struct kernel_param *kp,
                unsigned int min, unsigned int max)
 {
        unsigned long num;
@@ -2579,34 +2524,37 @@ static int param_set_uint_minmax(const char *val, struct kernel_param *kp,
        return 0;
 }
 
-static int param_set_portnr(const char *val, struct kernel_param *kp)
+static int param_set_portnr(const char *val, const struct kernel_param *kp)
 {
        return param_set_uint_minmax(val, kp,
                        RPC_MIN_RESVPORT,
                        RPC_MAX_RESVPORT);
 }
 
-static int param_get_portnr(char *buffer, struct kernel_param *kp)
-{
-       return param_get_uint(buffer, kp);
-}
+static struct kernel_param_ops param_ops_portnr = {
+       .set = param_set_portnr,
+       .get = param_get_uint,
+};
+
 #define param_check_portnr(name, p) \
        __param_check(name, p, unsigned int);
 
 module_param_named(min_resvport, xprt_min_resvport, portnr, 0644);
 module_param_named(max_resvport, xprt_max_resvport, portnr, 0644);
 
-static int param_set_slot_table_size(const char *val, struct kernel_param *kp)
+static int param_set_slot_table_size(const char *val,
+                                    const struct kernel_param *kp)
 {
        return param_set_uint_minmax(val, kp,
                        RPC_MIN_SLOT_TABLE,
                        RPC_MAX_SLOT_TABLE);
 }
 
-static int param_get_slot_table_size(char *buffer, struct kernel_param *kp)
-{
-       return param_get_uint(buffer, kp);
-}
+static struct kernel_param_ops param_ops_slot_table_size = {
+       .set = param_set_slot_table_size,
+       .get = param_get_uint,
+};
+
 #define param_check_slot_table_size(name, p) \
        __param_check(name, p, unsigned int);