Add missing break
[nbd.git] / nbd-server.c
index 3343d9d..aea5d2f 100644 (file)
@@ -165,12 +165,18 @@ char pidfname[256]; /**< name of our PID file */
 char pidftemplate[256]; /**< template to be used for the filename of the PID file */
 char default_authname[] = SYSCONFDIR "/nbd-server/allow"; /**< default name of allow file */
 
+#define NEG_INIT       (1 << 0) /**< negotiate() phase: send INIT_PASSWD and the magic */
+#define NEG_OLD                (1 << 1) /**< negotiate() phase: oldstyle handshake */
+#define NEG_MODERN     (1 << 2) /**< negotiate() phase: newstyle (modern) handshake */
+
 int modernsock=0;        /**< Socket for the modern handler. Not used
                               if a client was only specified on the
                               command line; only port used if
                               oldstyle is set to false (and then the
                               command-line client isn't used, gna gna) */
 char* modern_listen;     /**< listenaddr value for modernsock */
+char* modernport=NBD_DEFAULT_PORT; /**< Port number on which to listen for
+                                     new-style nbd-client connections (config key "port") */
 
 /**
  * Types of virtuatlization
@@ -262,6 +268,28 @@ typedef struct {
 } PARAM;
 
 /**
+ * Translate a command name into human readable form
+ *
+ * @param command The command number (after applying NBD_CMD_MASK_COMMAND)
+ * @return pointer to a string literal naming the command, or "UNKNOWN"
+ **/
+static inline const char * getcommandname(uint64_t command) {
+       switch (command) {
+       case NBD_CMD_READ:
+               return "NBD_CMD_READ";
+       case NBD_CMD_WRITE:
+               return "NBD_CMD_WRITE";
+       case NBD_CMD_DISC:
+               return "NBD_CMD_DISC";
+       case NBD_CMD_FLUSH:
+               return "NBD_CMD_FLUSH";
+       default:
+               break;
+       }
+       return "UNKNOWN";
+}
+
+/**
+/**
  * Check whether a client is allowed to connect. Works with an authorization
  * file which contains one line per machine, no wildcards.
  *
@@ -338,6 +366,24 @@ static inline void readit(int f, void *buf, size_t len) {
 }
 
 /**
+ * Read and discard data that we don't want from an FD
+ *
+ * @param f a file descriptor
+ * @param buf a scratch buffer of at least bufsiz bytes; its contents are overwritten
+ * @param len the number of bytes to consume
+ * @param bufsiz the size of buf (largest chunk read at once); must be nonzero if len > 0
+ **/
+static inline void consume(int f, void * buf, size_t len, size_t bufsiz) {
+       size_t curlen;
+       while (len>0) {
+               curlen = (len>bufsiz)?bufsiz:len;
+               readit(f, buf, curlen);
+               len -= curlen;
+       }
+}
+
+
+/**
   * Write data from a buffer into a filedescriptor
   *
   * @param f a file descriptor
@@ -370,7 +416,7 @@ void usage() {
               "\t-p|--pid-file\t\tspecify a filename to write our PID to\n"
               "\t-o|--output-config\toutput a config file section for what you\n\t\t\t\tspecified on the command line, with the\n\t\t\t\tspecified section name\n"
               "\t-M|--max-connections\tspecify the maximum number of opened connections\n\n"
-              "\tif port is set to 0, stdin is used (for running from inetd)\n"
+              "\tif port is set to 0, stdin is used (for running from inetd).\n"
               "\tif file_to_export contains '%%s', it is substituted with the IP\n"
               "\t\taddress of the machine trying to connect\n" 
               "\tif ip is set, it contains the local IP address on which we're listening.\n\tif not, the server will listen on all local IP addresses\n");
@@ -737,6 +783,7 @@ GArray* parse_cfile(gchar* f, GError** e) {
                { "group",      FALSE, PARAM_STRING,    &rungroup,      0 },
                { "oldstyle",   FALSE, PARAM_BOOL,      &do_oldstyle,   1 },
                { "listenaddr", FALSE, PARAM_STRING,    &modern_listen, 0 },
+               { "port",       FALSE, PARAM_STRING,    &modernport,    0 }, /* listen port for modernsock */
        };
        PARAM* p=gp;
        int p_size=sizeof(gp)/sizeof(PARAM);
@@ -746,7 +793,9 @@ GArray* parse_cfile(gchar* f, GError** e) {
        GQuark errdomain;
        GArray *retval=NULL;
        gchar **groups;
-       gboolean value;
+       gboolean bval;
+       gint ival;
+       gchar* sval;
        gchar* startgroup;
        gint i;
        gint j;
@@ -780,25 +829,29 @@ GArray* parse_cfile(gchar* f, GError** e) {
                        g_assert(p[j].ptype==PARAM_INT||p[j].ptype==PARAM_STRING||p[j].ptype==PARAM_BOOL);
                        switch(p[j].ptype) {
                                case PARAM_INT:
-                                       *((gint*)p[j].target) =
-                                               g_key_file_get_integer(cfile,
+                                       ival = g_key_file_get_integer(cfile,
                                                                groups[i],
                                                                p[j].paramname,
                                                                &err);
+                                       if(!err) { /* on error, keep the target's previous/default value */
+                                               *((gint*)p[j].target) = ival;
+                                       }
                                        break;
                                case PARAM_STRING:
-                                       *((gchar**)p[j].target) =
-                                               g_key_file_get_string(cfile,
+                                       sval = g_key_file_get_string(cfile,
                                                                groups[i],
                                                                p[j].paramname,
                                                                &err);
+                                       if(!err) { /* on error, keep the target's previous/default value */
+                                               *((gchar**)p[j].target) = sval;
+                                       }
                                        break;
                                case PARAM_BOOL:
-                                       value = g_key_file_get_boolean(cfile,
+                                       bval = g_key_file_get_boolean(cfile,
                                                        groups[i],
                                                        p[j].paramname, &err);
                                        if(!err) {
-                                               if(value) {
+                                               if(bval) {
                                                        *((gint*)p[j].target) |= p[j].flagval;
                                                } else {
                                                        *((gint*)p[j].target) &= ~(p[j].flagval);
@@ -806,11 +859,6 @@ GArray* parse_cfile(gchar* f, GError** e) {
                                        }
                                        break;
                        }
-                       if(!strcmp(p[j].paramname, "port") && !strcmp(p[j].target, NBD_DEFAULT_PORT)) {
-                               g_set_error(e, errdomain, CFILE_INCORRECT_PORT, "Config file specifies default port for oldstyle export");
-                               g_key_file_free(cfile);
-                               return NULL;
-                       }
                        if(err) {
                                if(err->code == G_KEY_FILE_ERROR_KEY_NOT_FOUND) {
                                        if(!p[j].required) {
@@ -1061,6 +1109,7 @@ void myseek(int handle,off_t a) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're serving for
+ * @param fua Nonzero to request 'Force Unit Access' semantics for this write
  * @return The number of bytes actually written, or -1 in case of an error
  **/
 ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1125,6 +1174,12 @@ ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
 
 /**
  * Call rawexpwrite repeatedly until all data has been written.
+ *
+ * @param a The offset where the write should start
+ * @param buf The buffer to write from
+ * @param len The length of buf
+ * @param client The client we're serving for
+ * @param fua Nonzero to request 'Force Unit Access' semantics for this write
  * @return 0 on success, nonzero on failure
  **/
 int rawexpwrite_fully(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1229,6 +1284,7 @@ int expread(off_t a, char *buf, size_t len, CLIENT *client) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're going to write for.
+ * @param fua Nonzero if the request carried NBD_CMD_FLAG_FUA ('Force Unit Access')
  * @return 0 on success, nonzero on failure
  **/
 int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1283,6 +1339,12 @@ int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
        return 0;
 }
 
+/**
+ * Handle a flush request (NBD_CMD_FLUSH) for a client
+ *
+ * @param client The client that sent the flush request.
+ * @return 0 on success, nonzero on failure
+ **/
 int expflush(CLIENT *client) {
        gint i;
 
@@ -1304,7 +1366,8 @@ int expflush(CLIENT *client) {
  *
  * @param client The client we're negotiating with.
+ * @param phase NEG_* bitmask: NEG_INIT sends the greeting, NEG_OLD/NEG_MODERN pick the handshake
  **/
-CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
+CLIENT* negotiate(int net, CLIENT *client, GArray* servers, int phase) {
        char zeros[128];
        uint64_t size_host;
        uint32_t flags = NBD_FLAG_HAS_FLAGS;
@@ -1312,14 +1375,14 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
        uint64_t magic;
 
        memset(zeros, '\0', sizeof(zeros));
-       if(!client || !client->modern) {
+       if(phase & NEG_INIT) {
                /* common */
                if (write(net, INIT_PASSWD, 8) < 0) {
                        err_nonfatal("Negotiation failed: %m");
-                       if(client)
+                       if(phase & NEG_OLD)
                                exit(EXIT_FAILURE);
                }
-               if(!client || client->modern) {
+               if(phase & NEG_MODERN) {
                        /* modern */
                        magic = htonll(opts_magic);
                } else {
@@ -1328,11 +1391,11 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
                }
                if (write(net, &magic, sizeof(magic)) < 0) {
                        err_nonfatal("Negotiation failed: %m");
-                       if(client)
+                       if(phase & NEG_OLD)
                                exit(EXIT_FAILURE);
                }
        }
-       if(!client) {
+       if ((phase & NEG_MODERN) && (phase & NEG_INIT)) {
                /* modern */
                uint32_t reserved;
                uint32_t opt;
@@ -1395,7 +1458,7 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
                flags |= NBD_FLAG_SEND_FUA;
        if (client->server->flags & F_ROTATIONAL)
                flags |= NBD_FLAG_ROTATIONAL;
-       if (!client->modern) {
+       if (phase & NEG_OLD) {
                /* oldstyle */
                flags = htonl(flags);
                if (write(client->net, &flags, 4) < 0)
@@ -1436,7 +1498,7 @@ int mainloop(CLIENT *client) {
 #ifdef DODBG
        int i = 0;
 #endif
-       negotiate(client->net, client, NULL);
+       negotiate(client->net, client, NULL, client->modern ? NEG_MODERN : (NEG_OLD | NEG_INIT));
        DEBUG("Entering request loop!\n");
        reply.magic = htonl(NBD_REPLY_MAGIC);
        reply.error = 0;
@@ -1458,32 +1520,15 @@ int mainloop(CLIENT *client) {
                request.from = ntohll(request.from);
                request.type = ntohl(request.type);
                command = request.type & NBD_CMD_MASK_COMMAND;
-
-               if (command==NBD_CMD_DISC) {
-                       msg2(LOG_INFO, "Disconnect request received.");
-                       if (client->server->flags & F_COPYONWRITE) { 
-                               if (client->difmap) g_free(client->difmap) ;
-                               close(client->difffile);
-                               unlink(client->difffilename);
-                               free(client->difffilename);
-                       }
-                       go_on=FALSE;
-                       continue;
-               }
-
                len = ntohl(request.len);
 
+               DEBUG("%s from %llu (%llu) len %u, ", getcommandname(command),
+                               (unsigned long long)request.from,
+                               (unsigned long long)request.from / 512, (unsigned int)len);
+
                if (request.magic != htonl(NBD_REQUEST_MAGIC))
                        err("Not enough magic.");
-               if (len > BUFSIZE - sizeof(struct nbd_reply)) {
-                       currlen = BUFSIZE - sizeof(struct nbd_reply);
-                       msg2(LOG_INFO, "oversized request (this is not a problem)");
-               } else {
-                       currlen = len;
-               }
-               DEBUG("%s from %llu (%llu) len %d, ", command ? "WRITE" :
-                               "READ", (unsigned long long)request.from,
-                               (unsigned long long)request.from / 512, (unsigned int)len);
+
                memcpy(reply.handle, request.handle, sizeof(reply.handle));
 
                if ((command==NBD_CMD_WRITE) || (command==NBD_CMD_READ)) {
@@ -1498,9 +1543,28 @@ int mainloop(CLIENT *client) {
                                ERROR(client, reply, EINVAL);
                                continue;
                        }
+
+                       currlen = len;
+                       if (currlen > BUFSIZE - sizeof(struct nbd_reply)) {
+                               currlen = BUFSIZE - sizeof(struct nbd_reply);
+                               msg2(LOG_INFO, "oversized request (this is not a problem)");
+                       }
                }
 
-               if (command==NBD_CMD_WRITE) {
+               switch (command) {
+
+               case NBD_CMD_DISC:
+                       msg2(LOG_INFO, "Disconnect request received.");
+                       if (client->server->flags & F_COPYONWRITE) {
+                               g_free(client->difmap); /* g_free(NULL) is a no-op */
+                               close(client->difffile);
+                               unlink(client->difffilename);
+                               free(client->difffilename);
+                       }
+                       go_on=FALSE;
+                       continue;
+
+               case NBD_CMD_WRITE:
                        DEBUG("wr: net->buf, ");
                        while(len > 0) {
                                readit(client->net, buf, currlen);
@@ -1509,23 +1573,27 @@ int mainloop(CLIENT *client) {
                                    (client->server->flags & F_AUTOREADONLY)) {
                                        DEBUG("[WRITE to READONLY!]");
                                        ERROR(client, reply, EPERM);
-                                       continue;
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
+                                       break;
                                }
-                               if (expwrite(request.from, buf, len, client,
+                               if (expwrite(request.from, buf, currlen, client,
                                             request.type & NBD_CMD_FLAG_FUA)) {
                                        DEBUG("Write failed: %m" );
                                        ERROR(client, reply, errno);
-                                       continue;
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
+                                       break;
                                }
-                               SEND(client->net, reply);
-                               DEBUG("OK!\n");
                                len -= currlen;
+                               request.from += currlen;
                                currlen = (len < BUFSIZE) ? len : BUFSIZE;
                        }
+                       if (len > 0)
+                               continue; /* write aborted: ERROR() already sent the error reply */
+                       SEND(client->net, reply);
+                       DEBUG("OK!\n");
                        continue;
-               }

-               if (command==NBD_CMD_FLUSH) {
+               case NBD_CMD_FLUSH:
                        DEBUG("fl: ");
                        if (expflush(client)) {
                                DEBUG("Flush failed: %m");
@@ -1535,9 +1603,8 @@ int mainloop(CLIENT *client) {
                        SEND(client->net, reply);
                        DEBUG("OK!\n");
                        continue;
-               }

-               if (command==NBD_CMD_READ) {
+               case NBD_CMD_READ:
                        DEBUG("exp->buf, ");
                        memcpy(buf, &reply, sizeof(struct nbd_reply));
                        if (client->transactionlogfd != -1)
@@ -1561,9 +1628,11 @@ int mainloop(CLIENT *client) {
                        }
                        DEBUG("OK!\n");
                        continue;
-               }

-               DEBUG ("Ignoring unknown command\n");
+               default:
+                       DEBUG ("Ignoring unknown command\n");
+                       continue;
+               }
        }
        return 0;
 }
@@ -1870,7 +1937,7 @@ int serveloop(GArray* servers) {
                        if(FD_ISSET(modernsock, &rset)) {
                                if((net=accept(modernsock, (struct sockaddr *) &addrin, &addrinlen)) < 0)
                                        err("accept: %m");
-                               client = negotiate(net, NULL, servers);
+                               client = negotiate(net, NULL, servers, NEG_INIT | NEG_MODERN);
                                if(!client) {
                                        err_nonfatal("negotiation failed");
                                        close(net);
@@ -2053,7 +2120,7 @@ void open_modern(void) {
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_family = AF_UNSPEC;
        hints.ai_protocol = IPPROTO_TCP;
-       e = getaddrinfo(modern_listen, NBD_DEFAULT_PORT, &hints, &ai);
+       e = getaddrinfo(modern_listen, modernport, &hints, &ai);
        if(e != 0) {
                fprintf(stderr, "getaddrinfo failed: %s\n", gai_strerror(e));
                exit(EXIT_FAILURE);
@@ -2206,6 +2273,7 @@ void glib_message_syslog_redirect(const gchar *log_domain,
         break;
       case G_LOG_LEVEL_DEBUG:
         level=LOG_DEBUG;
+       break; /* without this, debug messages fell through to LOG_ERR */
       default:
         level=LOG_ERR;
     }