Allow varying the port for new-style handshakes
[nbd.git] / nbd-server.c
index a99c27b..f8028cb 100644 (file)
@@ -171,6 +171,8 @@ int modernsock=0;     /**< Socket for the modern handler. Not used
                               oldstyle is set to false (and then the
                               command-line client isn't used, gna gna) */
 char* modern_listen;     /**< listenaddr value for modernsock */
+char* modernport=NBD_DEFAULT_PORT; /**< Port number on which to listen for
+                                     new-style nbd-client connections */
 
 /**
  * Types of virtuatlization
@@ -261,6 +263,12 @@ typedef struct {
                                  is PARAM_BOOL. */
 } PARAM;
 
+/**
+ * Translate a command name into human readable form
+ *
+ * @param command The command number (after applying NBD_CMD_MASK_COMMAND)
+ * @return pointer to the command name
+ **/
 static inline const char * getcommandname(uint64_t command) {
        switch (command) {
        case NBD_CMD_READ:
@@ -354,6 +362,24 @@ static inline void readit(int f, void *buf, size_t len) {
 }
 
 /**
+ * Consume data from an FD that we don't want
+ *
+ * @param f a file descriptor
+ * @param buf a buffer
+ * @param len the number of bytes to consume
+ * @param bufsiz the size of the buffer
+ **/
+static inline void consume(int f, void * buf, size_t len, size_t bufsiz) {
+       size_t curlen;
+       while (len>0) {
+               curlen = (len>bufsiz)?bufsiz:len;
+               readit(f, buf, curlen);
+               len -= curlen;
+       }
+}
+
+
+/**
  * Write data from a buffer into a filedescriptor
  *
  * @param f a file descriptor
@@ -386,7 +412,7 @@ void usage() {
               "\t-p|--pid-file\t\tspecify a filename to write our PID to\n"
               "\t-o|--output-config\toutput a config file section for what you\n\t\t\t\tspecified on the command line, with the\n\t\t\t\tspecified section name\n"
               "\t-M|--max-connections\tspecify the maximum number of opened connections\n\n"
-              "\tif port is set to 0, stdin is used (for running from inetd)\n"
+              "\tif port is set to 0, stdin is used (for running from inetd).\n"
               "\tif file_to_export contains '%%s', it is substituted with the IP\n"
               "\t\taddress of the machine trying to connect\n" 
               "\tif ip is set, it contains the local IP address on which we're listening.\n\tif not, the server will listen on all local IP addresses\n");
@@ -753,6 +779,7 @@ GArray* parse_cfile(gchar* f, GError** e) {
                { "group",      FALSE, PARAM_STRING,    &rungroup,      0 },
                { "oldstyle",   FALSE, PARAM_BOOL,      &do_oldstyle,   1 },
                { "listenaddr", FALSE, PARAM_STRING,    &modern_listen, 0 },
+               { "port",       FALSE, PARAM_STRING,    &modernport,    0 },
        };
        PARAM* p=gp;
        int p_size=sizeof(gp)/sizeof(PARAM);
@@ -822,8 +849,8 @@ GArray* parse_cfile(gchar* f, GError** e) {
                                        }
                                        break;
                        }
-                       if(!strcmp(p[j].paramname, "port") && !strcmp(p[j].target, NBD_DEFAULT_PORT)) {
-                               g_set_error(e, errdomain, CFILE_INCORRECT_PORT, "Config file specifies default port for oldstyle export");
+                       if(!strcmp(p[j].paramname, "port") && !strcmp(p[j].target, modernport)) {
+                               g_set_error(e, errdomain, CFILE_INCORRECT_PORT, "Config file specifies new-style port for oldstyle export");
                                g_key_file_free(cfile);
                                return NULL;
                        }
@@ -1077,6 +1104,7 @@ void myseek(int handle,off_t a) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're serving for
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return The number of bytes actually written, or -1 in case of an error
  **/
 ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1141,6 +1169,12 @@ ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
 
 /**
  * Call rawexpwrite repeatedly until all data has been written.
+ *
+ * @param a The offset where the write should start
+ * @param buf The buffer to write from
+ * @param len The length of buf
+ * @param client The client we're serving for
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return 0 on success, nonzero on failure
  **/
 int rawexpwrite_fully(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1245,6 +1279,7 @@ int expread(off_t a, char *buf, size_t len, CLIENT *client) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're going to write for.
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return 0 on success, nonzero on failure
  **/
 int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1299,6 +1334,12 @@ int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
        return 0;
 }
 
+/**
+ * Flush data to a client
+ *
+ * @param client The client we're going to write for.
+ * @return 0 on success, nonzero on failure
+ **/
 int expflush(CLIENT *client) {
        gint i;
 
@@ -1527,18 +1568,21 @@ int mainloop(CLIENT *client) {
                                    (client->server->flags & F_AUTOREADONLY)) {
                                        DEBUG("[WRITE to READONLY!]");
                                        ERROR(client, reply, EPERM);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
-                               if (expwrite(request.from, buf, len, client,
+                               if (expwrite(request.from, buf, currlen, client,
                                             request.type & NBD_CMD_FLAG_FUA)) {
                                        DEBUG("Write failed: %m" );
                                        ERROR(client, reply, errno);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
-                               SEND(client->net, reply);
                                len -= currlen;
+                               request.from += currlen;
                                currlen = (len < BUFSIZE) ? len : BUFSIZE;
                        }
+                       SEND(client->net, reply);
                        DEBUG("OK!\n");
                        continue;
 
@@ -2071,7 +2115,7 @@ void open_modern(void) {
        hints.ai_socktype = SOCK_STREAM;
        hints.ai_family = AF_UNSPEC;
        hints.ai_protocol = IPPROTO_TCP;
-       e = getaddrinfo(modern_listen, NBD_DEFAULT_PORT, &hints, &ai);
+       e = getaddrinfo(modern_listen, modernport, &hints, &ai);
        if(e != 0) {
                fprintf(stderr, "getaddrinfo failed: %s\n", gai_strerror(e));
                exit(EXIT_FAILURE);