Handle failed negotiation on modern socket
[nbd.git] / nbd-server.c
index b47da86..5c7cbc0 100644 (file)
@@ -320,7 +320,7 @@ int authorized_client(CLIENT *opts) {
  * @param buf a buffer
  * @param len the number of bytes to be read
  **/
-inline void readit(int f, void *buf, size_t len) {
+static inline void readit(int f, void *buf, size_t len) {
        ssize_t res;
        while (len > 0) {
                DEBUG("*");
@@ -342,7 +342,7 @@ inline void readit(int f, void *buf, size_t len) {
  * @param buf a buffer containing data
  * @param len the number of bytes to be written
  **/
-inline void writeit(int f, void *buf, size_t len) {
+static inline void writeit(int f, void *buf, size_t len) {
        ssize_t res;
        while (len > 0) {
                DEBUG("+");
@@ -599,8 +599,9 @@ SERVER* dup_serve(SERVER *s) {
                serve->authname = strdup(s->authname);
 
        serve->flags = s->flags;
-       serve->socket = serve->socket;
-       serve->socket_family = serve->socket_family;
+       serve->socket = s->socket;
+       serve->socket_family = s->socket_family;
+       serve->virtstyle = s->virtstyle;
        serve->cidrlen = s->cidrlen;
 
        if(s->prerun)
@@ -699,21 +700,21 @@ GArray* parse_cfile(gchar* f, GError** e) {
        SERVER s;
        gchar *virtstyle=NULL;
        PARAM lp[] = {
-               { "exportname", TRUE,   PARAM_STRING,   NULL, 0 },
-               { "port",       TRUE,   PARAM_INT,      NULL, 0 },
-               { "authfile",   FALSE,  PARAM_STRING,   NULL, 0 },
-               { "filesize",   FALSE,  PARAM_INT,      NULL, 0 },
-               { "virtstyle",  FALSE,  PARAM_STRING,   NULL, 0 },
-               { "prerun",     FALSE,  PARAM_STRING,   NULL, 0 },
-               { "postrun",    FALSE,  PARAM_STRING,   NULL, 0 },
-               { "readonly",   FALSE,  PARAM_BOOL,     NULL, F_READONLY },
-               { "multifile",  FALSE,  PARAM_BOOL,     NULL, F_MULTIFILE },
-               { "copyonwrite", FALSE, PARAM_BOOL,     NULL, F_COPYONWRITE },
-               { "sparse_cow", FALSE,  PARAM_BOOL,     NULL, F_SPARSE },
-               { "sdp",        FALSE,  PARAM_BOOL,     NULL, F_SDP },
-               { "sync",       FALSE,  PARAM_BOOL,     NULL, F_SYNC },
-               { "listenaddr", FALSE,  PARAM_STRING,   NULL, 0 },
-               { "maxconnections", FALSE, PARAM_INT,   NULL, 0 },
+               { "exportname", TRUE,   PARAM_STRING,   &(s.exportname),        0 },
+               { "port",       TRUE,   PARAM_INT,      &(s.port),              0 },
+               { "authfile",   FALSE,  PARAM_STRING,   &(s.authname),          0 },
+               { "filesize",   FALSE,  PARAM_INT,      &(s.expected_size),     0 },
+               { "virtstyle",  FALSE,  PARAM_STRING,   &(virtstyle),           0 },
+               { "prerun",     FALSE,  PARAM_STRING,   &(s.prerun),            0 },
+               { "postrun",    FALSE,  PARAM_STRING,   &(s.postrun),           0 },
+               { "readonly",   FALSE,  PARAM_BOOL,     &(s.flags),             F_READONLY },
+               { "multifile",  FALSE,  PARAM_BOOL,     &(s.flags),             F_MULTIFILE },
+               { "copyonwrite", FALSE, PARAM_BOOL,     &(s.flags),             F_COPYONWRITE },
+               { "sparse_cow", FALSE,  PARAM_BOOL,     &(s.flags),             F_SPARSE },
+               { "sdp",        FALSE,  PARAM_BOOL,     &(s.flags),             F_SDP },
+               { "sync",       FALSE,  PARAM_BOOL,     &(s.flags),             F_SYNC },
+               { "listenaddr", FALSE,  PARAM_STRING,   &(s.listenaddr),        0 },
+               { "maxconnections", FALSE, PARAM_INT,   &(s.max_connections),   0 },
        };
        const int lp_size=sizeof(lp)/sizeof(PARAM);
        PARAM gp[] = {
@@ -753,18 +754,6 @@ GArray* parse_cfile(gchar* f, GError** e) {
        groups = g_key_file_get_groups(cfile, NULL);
        for(i=0;groups[i];i++) {
                memset(&s, '\0', sizeof(SERVER));
-               lp[0].target=&(s.exportname);
-               lp[1].target=&(s.port);
-               lp[2].target=&(s.authname);
-               lp[3].target=&(s.expected_size);
-               lp[4].target=&(virtstyle);
-               lp[5].target=&(s.prerun);
-               lp[6].target=&(s.postrun);
-               lp[7].target=lp[8].target=lp[9].target=
-                               lp[10].target=lp[11].target=
-                               lp[12].target=&(s.flags);
-               lp[13].target=&(s.listenaddr);
-               lp[14].target=&(s.max_connections);
 
                /* After the [generic] group, start parsing exports */
                if(i==1) {
@@ -1308,9 +1297,11 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
                                client->exportsize = OFFT_MAX;
                                client->net = net;
                                client->modern = TRUE;
+                               free(name);
                                return client;
                        }
                }
+               free(name);
                return NULL;
        }
        /* common */
@@ -1364,7 +1355,10 @@ int mainloop(CLIENT *client) {
        reply.error = 0;
        while (go_on) {
                char buf[BUFSIZE];
+               char* p;
                size_t len;
+               size_t currlen;
+               size_t writelen;
 #ifdef DODBG
                i++;
                printf("%d: ", i);
@@ -1389,8 +1383,12 @@ int mainloop(CLIENT *client) {
 
                if (request.magic != htonl(NBD_REQUEST_MAGIC))
                        err("Not enough magic.");
-               if (len > BUFSIZE - sizeof(struct nbd_reply))
-                       err("Request too big!");
+               if (len > BUFSIZE - sizeof(struct nbd_reply)) {
+                       currlen = BUFSIZE - sizeof(struct nbd_reply);
+                       msg2(LOG_INFO, "oversized request (this is not a problem)");
+               } else {
+                       currlen = len;
+               }
 #ifdef DODBG
                printf("%s from %llu (%llu) len %d, ", request.type ? "WRITE" :
                                "READ", (unsigned long long)request.from,
@@ -1411,35 +1409,47 @@ int mainloop(CLIENT *client) {
 
                if (request.type==NBD_CMD_WRITE) {
                        DEBUG("wr: net->buf, ");
-                       readit(client->net, buf, len);
-                       DEBUG("buf->exp, ");
-                       if ((client->server->flags & F_READONLY) ||
-                           (client->server->flags & F_AUTOREADONLY)) {
-                               DEBUG("[WRITE to READONLY!]");
-                               ERROR(client, reply, EPERM);
-                               continue;
-                       }
-                       if (expwrite(request.from, buf, len, client)) {
-                               DEBUG("Write failed: %m" );
-                               ERROR(client, reply, errno);
-                               continue;
+                       while(len > 0) {
+                               readit(client->net, buf, currlen);
+                               DEBUG("buf->exp, ");
+                               if ((client->server->flags & F_READONLY) ||
+                                   (client->server->flags & F_AUTOREADONLY)) {
+                                       DEBUG("[WRITE to READONLY!]");
+                                       ERROR(client, reply, EPERM);
+                                       continue;
+                               }
+                               if (expwrite(request.from, buf, len, client)) {
+                                       DEBUG("Write failed: %m" );
+                                       ERROR(client, reply, errno);
+                                       continue;
+                               }
+                               SEND(client->net, reply);
+                               DEBUG("OK!\n");
+                               len -= currlen;
+                               currlen = (len < BUFSIZE) ? len : BUFSIZE;
                        }
-                       SEND(client->net, reply);
-                       DEBUG("OK!\n");
                        continue;
                }
                /* READ */
 
                DEBUG("exp->buf, ");
-               if (expread(request.from, buf + sizeof(struct nbd_reply), len, client)) {
-                       DEBUG("Read failed: %m");
-                       ERROR(client, reply, errno);
-                       continue;
-               }
-
-               DEBUG("buf->net, ");
                memcpy(buf, &reply, sizeof(struct nbd_reply));
-               writeit(client->net, buf, len + sizeof(struct nbd_reply));
+               p = buf + sizeof(struct nbd_reply);
+               writelen = currlen + sizeof(struct nbd_reply);
+               while(len > 0) {
+                       if (expread(request.from, p, currlen, client)) {
+                               DEBUG("Read failed: %m");
+                               ERROR(client, reply, errno);
+                               continue;
+                       }
+
+                       DEBUG("buf->net, ");
+                       writeit(client->net, buf, writelen);
+                       len -= currlen;
+                       currlen = (len < BUFSIZE) ? len : BUFSIZE;
+                       p = buf;
+                       writelen = currlen;
+               }
                DEBUG("OK!\n");
        }
        return 0;
@@ -1726,7 +1736,7 @@ int serveloop(GArray* servers) {
                memcpy(&rset, &mset, sizeof(fd_set));
                if(select(max+1, &rset, NULL, NULL, NULL)>0) {
                        int net = 0;
-                       SERVER* serve;
+                       SERVER* serve=NULL;
 
                        DEBUG("accept, ");
                        if(FD_ISSET(modernsock, &rset)) {
@@ -1737,7 +1747,9 @@ int serveloop(GArray* servers) {
                                        err_nonfatal("negotiation failed");
                                        close(net);
                                        net=0;
+                                       continue;
                                }
+                               serve = client->server;
                        }
                        for(i=0;i<servers->len && !net;i++) {
                                serve=&(g_array_index(servers, SERVER, i));