Add missing break
[nbd.git] / nbd-server.c
index a9098a5..aea5d2f 100644 (file)
@@ -165,12 +165,18 @@ char pidfname[256]; /**< name of our PID file */
 char pidftemplate[256]; /**< template to be used for the filename of the PID file */
 char default_authname[] = SYSCONFDIR "/nbd-server/allow"; /**< default name of allow file */
 
+#define NEG_INIT       (1 << 0)
+#define NEG_OLD                (1 << 1)
+#define NEG_MODERN     (1 << 2)
+
 int modernsock=0;        /**< Socket for the modern handler. Not used
                               if a client was only specified on the
                               command line; only port used if
                               oldstyle is set to false (and then the
                               command-line client isn't used, gna gna) */
 char* modern_listen;     /**< listenaddr value for modernsock */
+char* modernport=NBD_DEFAULT_PORT; /**< Port number on which to listen for
+                                     new-style nbd-client connections */
 
 /**
  * Types of virtuatlization
@@ -360,6 +366,24 @@ static inline void readit(int f, void *buf, size_t len) {
 }
 
 /**
+ * Read and discard unwanted data from an FD, one buffer-sized chunk at a time
+ *
+ * @param f a file descriptor to read the unwanted bytes from
+ * @param buf a scratch buffer handed to readit() for every chunk
+ * @param len the total number of bytes to consume
+ * @param bufsiz the size of buf; caps how much is read per readit() call
+ **/
+static inline void consume(int f, void * buf, size_t len, size_t bufsiz) {
+       size_t curlen;
+       while (len>0) {
+               curlen = (len>bufsiz)?bufsiz:len; /* never ask for more than buf can hold */
+               readit(f, buf, curlen);
+               len -= curlen;
+       }
+}
+
+
+/**
  * Write data from a buffer into a filedescriptor
  *
  * @param f a file descriptor
@@ -392,7 +416,7 @@ void usage() {
               "\t-p|--pid-file\t\tspecify a filename to write our PID to\n"
               "\t-o|--output-config\toutput a config file section for what you\n\t\t\t\tspecified on the command line, with the\n\t\t\t\tspecified section name\n"
               "\t-M|--max-connections\tspecify the maximum number of opened connections\n\n"
-              "\tif port is set to 0, stdin is used (for running from inetd)\n"
+              "\tif port is set to 0, stdin is used (for running from inetd).\n"
               "\tif file_to_export contains '%%s', it is substituted with the IP\n"
               "\t\taddress of the machine trying to connect\n" 
               "\tif ip is set, it contains the local IP address on which we're listening.\n\tif not, the server will listen on all local IP addresses\n");
@@ -759,6 +783,7 @@ GArray* parse_cfile(gchar* f, GError** e) {
                { "group",      FALSE, PARAM_STRING,    &rungroup,      0 },
                { "oldstyle",   FALSE, PARAM_BOOL,      &do_oldstyle,   1 },
                { "listenaddr", FALSE, PARAM_STRING,    &modern_listen, 0 },
+               { "port",       FALSE, PARAM_STRING,    &modernport,    0 },
        };
        PARAM* p=gp;
        int p_size=sizeof(gp)/sizeof(PARAM);
@@ -768,7 +793,9 @@ GArray* parse_cfile(gchar* f, GError** e) {
        GQuark errdomain;
        GArray *retval=NULL;
        gchar **groups;
-       gboolean value;
+       gboolean bval;
+       gint ival;
+       gchar* sval;
        gchar* startgroup;
        gint i;
        gint j;
@@ -802,25 +829,29 @@ GArray* parse_cfile(gchar* f, GError** e) {
                        g_assert(p[j].ptype==PARAM_INT||p[j].ptype==PARAM_STRING||p[j].ptype==PARAM_BOOL);
                        switch(p[j].ptype) {
                                case PARAM_INT:
-                                       *((gint*)p[j].target) =
-                                               g_key_file_get_integer(cfile,
+                                       ival = g_key_file_get_integer(cfile,
                                                                groups[i],
                                                                p[j].paramname,
                                                                &err);
+                                       if(!err) {
+                                               *((gint*)p[j].target) = ival;
+                                       }
                                        break;
                                case PARAM_STRING:
-                                       *((gchar**)p[j].target) =
-                                               g_key_file_get_string(cfile,
+                                       sval = g_key_file_get_string(cfile,
                                                                groups[i],
                                                                p[j].paramname,
                                                                &err);
+                                       if(!err) {
+                                               *((gchar**)p[j].target) = sval;
+                                       }
                                        break;
                                case PARAM_BOOL:
-                                       value = g_key_file_get_boolean(cfile,
+                                       bval = g_key_file_get_boolean(cfile,
                                                        groups[i],
                                                        p[j].paramname, &err);
                                        if(!err) {
-                                               if(value) {
+                                               if(bval) {
                                                        *((gint*)p[j].target) |= p[j].flagval;
                                                } else {
                                                        *((gint*)p[j].target) &= ~(p[j].flagval);
@@ -828,11 +859,6 @@ GArray* parse_cfile(gchar* f, GError** e) {
                                        }
                                        break;
                        }
-                       if(!strcmp(p[j].paramname, "port") && !strcmp(p[j].target, NBD_DEFAULT_PORT)) {
-                               g_set_error(e, errdomain, CFILE_INCORRECT_PORT, "Config file specifies default port for oldstyle export");
-                               g_key_file_free(cfile);
-                               return NULL;
-                       }
                        if(err) {
                                if(err->code == G_KEY_FILE_ERROR_KEY_NOT_FOUND) {
                                        if(!p[j].required) {
@@ -1340,7 +1366,7 @@ int expflush(CLIENT *client) {
  *
  * @param client The client we're negotiating with.
  **/
-CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
+CLIENT* negotiate(int net, CLIENT *client, GArray* servers, int phase) {
        char zeros[128];
        uint64_t size_host;
        uint32_t flags = NBD_FLAG_HAS_FLAGS;
@@ -1348,14 +1374,14 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
        uint64_t magic;
 
        memset(zeros, '\0', sizeof(zeros));
-       if(!client || !client->modern) {
+       if(phase & NEG_INIT) {
                /* common */
                if (write(net, INIT_PASSWD, 8) < 0) {
                        err_nonfatal("Negotiation failed: %m");
                        if(client)
                                exit(EXIT_FAILURE);
                }
-               if(!client || client->modern) {
+               if(phase & NEG_MODERN) {
                        /* modern */
                        magic = htonll(opts_magic);
                } else {
@@ -1364,11 +1390,11 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
                }
                if (write(net, &magic, sizeof(magic)) < 0) {
                        err_nonfatal("Negotiation failed: %m");
-                       if(client)
+                       if(phase & NEG_OLD)
                                exit(EXIT_FAILURE);
                }
        }
-       if(!client) {
+       if ((phase & NEG_MODERN) && (phase & NEG_INIT)) {
                /* modern */
                uint32_t reserved;
                uint32_t opt;
@@ -1431,7 +1457,7 @@ CLIENT* negotiate(int net, CLIENT *client, GArray* servers) {
                flags |= NBD_FLAG_SEND_FUA;
        if (client->server->flags & F_ROTATIONAL)
                flags |= NBD_FLAG_ROTATIONAL;
-       if (!client->modern) {
+       if (phase & NEG_OLD) {
                /* oldstyle */
                flags = htonl(flags);
                if (write(client->net, &flags, 4) < 0)
@@ -1472,7 +1498,7 @@ int mainloop(CLIENT *client) {
 #ifdef DODBG
        int i = 0;
 #endif
-       negotiate(client->net, client, NULL);
+       negotiate(client->net, client, NULL, client->modern ? NEG_MODERN : (NEG_OLD | NEG_INIT));
        DEBUG("Entering request loop!\n");
        reply.magic = htonl(NBD_REPLY_MAGIC);
        reply.error = 0;
@@ -1547,15 +1573,18 @@ int mainloop(CLIENT *client) {
                                    (client->server->flags & F_AUTOREADONLY)) {
                                        DEBUG("[WRITE to READONLY!]");
                                        ERROR(client, reply, EPERM);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
-                               if (expwrite(request.from, buf, len, client,
+                               if (expwrite(request.from, buf, currlen, client,
                                             request.type & NBD_CMD_FLAG_FUA)) {
                                        DEBUG("Write failed: %m" );
                                        ERROR(client, reply, errno);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
                                len -= currlen;
+                               request.from += currlen;
                                currlen = (len < BUFSIZE) ? len : BUFSIZE;
                        }
                        SEND(client->net, reply);
@@ -1908,7 +1937,7 @@ int serveloop(GArray* servers) {
                        if(FD_ISSET(modernsock, &rset)) {
                                if((net=accept(modernsock, (struct sockaddr *) &addrin, &addrinlen)) < 0)
                                        err("accept: %m");
-                               client = negotiate(net, NULL, servers);
+                               client = negotiate(net, NULL, servers, NEG_INIT | NEG_MODERN);
                                if(!client) {
                                        err_nonfatal("negotiation failed");
                                        close(net);
@@ -2091,7 +2120,7 @@ void open_modern(void) {
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_family = AF_UNSPEC;
        hints.ai_protocol = IPPROTO_TCP;
-       e = getaddrinfo(modern_listen, NBD_DEFAULT_PORT, &hints, &ai);
+       e = getaddrinfo(modern_listen, modernport, &hints, &ai);
        if(e != 0) {
                fprintf(stderr, "getaddrinfo failed: %s\n", gai_strerror(e));
                exit(EXIT_FAILURE);
@@ -2244,6 +2273,7 @@ void glib_message_syslog_redirect(const gchar *log_domain,
         break;
       case G_LOG_LEVEL_DEBUG:
         level=LOG_DEBUG;
+       break;
       default:
         level=LOG_ERR;
     }