Fix oversize writes to write to correct area of disk
[nbd.git] / nbd-server.c
index 3343d9d..41b847a 100644 (file)
@@ -262,6 +262,28 @@ typedef struct {
 } PARAM;
 
 /**
+ * Translate a command name into human readable form
+ *
+ * @param command The command number (after applying NBD_CMD_MASK_COMMAND)
+ * @return pointer to the command name
+ **/
+static inline const char * getcommandname(uint64_t command) {
+       switch (command) {
+       case NBD_CMD_READ:
+               return "NBD_CMD_READ";
+       case NBD_CMD_WRITE:
+               return "NBD_CMD_WRITE";
+       case NBD_CMD_DISC:
+               return "NBD_CMD_DISC";
+       case NBD_CMD_FLUSH:
+               return "NBD_CMD_FLUSH";
+       default:
+               break;
+       }
+       return "UNKNOWN";
+}
+
+/**
  * Check whether a client is allowed to connect. Works with an authorization
  * file which contains one line per machine, no wildcards.
  *
@@ -338,6 +360,24 @@ static inline void readit(int f, void *buf, size_t len) {
 }
 
 /**
+ * Consume data from an FD that we don't want
+ *
+ * @param f a file descriptor
+ * @param buf a buffer
+ * @param len the number of bytes to consume
+ * @param bufsiz the size of the buffer
+ **/
+static inline void consume(int f, void * buf, size_t len, size_t bufsiz) {
+       size_t curlen;
+       while (len>0) {
+               curlen = (len>bufsiz)?bufsiz:len;
+               readit(f, buf, curlen);
+               len -= curlen;
+       }
+}
+
+
+/**
  * Write data from a buffer into a filedescriptor
  *
  * @param f a file descriptor
@@ -1061,6 +1101,7 @@ void myseek(int handle,off_t a) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're serving for
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return The number of bytes actually written, or -1 in case of an error
  **/
 ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1125,6 +1166,12 @@ ssize_t rawexpwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
 
 /**
  * Call rawexpwrite repeatedly until all data has been written.
+ *
+ * @param a The offset where the write should start
+ * @param buf The buffer to write from
+ * @param len The length of buf
+ * @param client The client we're serving for
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return 0 on success, nonzero on failure
  **/
 int rawexpwrite_fully(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1229,6 +1276,7 @@ int expread(off_t a, char *buf, size_t len, CLIENT *client) {
  * @param buf The buffer to write from
  * @param len The length of buf
  * @param client The client we're going to write for.
+ * @param fua Flag to indicate 'Force Unit Access'
  * @return 0 on success, nonzero on failure
  **/
 int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
@@ -1283,6 +1331,12 @@ int expwrite(off_t a, char *buf, size_t len, CLIENT *client, int fua) {
        return 0;
 }
 
+/**
+ * Flush data previously written on behalf of a client (handles NBD_CMD_FLUSH)
+ *
+ * @param client The client whose written data should be flushed.
+ * @return 0 on success, nonzero on failure
+ **/
 int expflush(CLIENT *client) {
        gint i;
 
@@ -1458,32 +1512,15 @@ int mainloop(CLIENT *client) {
                request.from = ntohll(request.from);
                request.type = ntohl(request.type);
                command = request.type & NBD_CMD_MASK_COMMAND;
-
-               if (command==NBD_CMD_DISC) {
-                       msg2(LOG_INFO, "Disconnect request received.");
-                       if (client->server->flags & F_COPYONWRITE) { 
-                               if (client->difmap) g_free(client->difmap) ;
-                               close(client->difffile);
-                               unlink(client->difffilename);
-                               free(client->difffilename);
-                       }
-                       go_on=FALSE;
-                       continue;
-               }
-
                len = ntohl(request.len);
 
+               DEBUG("%s from %llu (%llu) len %d, ", getcommandname(command),
+                               (unsigned long long)request.from,
+                               (unsigned long long)request.from / 512, (unsigned int)len);
+
                if (request.magic != htonl(NBD_REQUEST_MAGIC))
                        err("Not enough magic.");
-               if (len > BUFSIZE - sizeof(struct nbd_reply)) {
-                       currlen = BUFSIZE - sizeof(struct nbd_reply);
-                       msg2(LOG_INFO, "oversized request (this is not a problem)");
-               } else {
-                       currlen = len;
-               }
-               DEBUG("%s from %llu (%llu) len %d, ", command ? "WRITE" :
-                               "READ", (unsigned long long)request.from,
-                               (unsigned long long)request.from / 512, (unsigned int)len);
+
                memcpy(reply.handle, request.handle, sizeof(reply.handle));
 
                if ((command==NBD_CMD_WRITE) || (command==NBD_CMD_READ)) {
@@ -1498,9 +1535,28 @@ int mainloop(CLIENT *client) {
                                ERROR(client, reply, EINVAL);
                                continue;
                        }
+
+                       currlen = len;
+                       if (currlen > BUFSIZE - sizeof(struct nbd_reply)) {
+                               currlen = BUFSIZE - sizeof(struct nbd_reply);
+                               msg2(LOG_INFO, "oversized request (this is not a problem)");
+                       }
                }
 
-               if (command==NBD_CMD_WRITE) {
+               switch (command) {
+
+               case NBD_CMD_DISC:
+                       msg2(LOG_INFO, "Disconnect request received.");
+                       if (client->server->flags & F_COPYONWRITE) { 
+                               if (client->difmap) g_free(client->difmap) ;
+                               close(client->difffile);
+                               unlink(client->difffilename);
+                               free(client->difffilename);
+                       }
+                       go_on=FALSE;
+                       continue;
+
+               case NBD_CMD_WRITE:
                        DEBUG("wr: net->buf, ");
                        while(len > 0) {
                                readit(client->net, buf, currlen);
@@ -1509,23 +1565,25 @@ int mainloop(CLIENT *client) {
                                    (client->server->flags & F_AUTOREADONLY)) {
                                        DEBUG("[WRITE to READONLY!]");
                                        ERROR(client, reply, EPERM);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
-                               if (expwrite(request.from, buf, len, client,
+                               if (expwrite(request.from, buf, currlen, client,
                                             request.type & NBD_CMD_FLAG_FUA)) {
                                        DEBUG("Write failed: %m" );
                                        ERROR(client, reply, errno);
+                                       consume(client->net, buf, len-currlen, BUFSIZE);
                                        continue;
                                }
-                               SEND(client->net, reply);
-                               DEBUG("OK!\n");
                                len -= currlen;
+                               request.from += currlen;
                                currlen = (len < BUFSIZE) ? len : BUFSIZE;
                        }
+                       SEND(client->net, reply);
+                       DEBUG("OK!\n");
                        continue;
-               }
 
-               if (command==NBD_CMD_FLUSH) {
+               case NBD_CMD_FLUSH:
                        DEBUG("fl: ");
                        if (expflush(client)) {
                                DEBUG("Flush failed: %m");
@@ -1535,9 +1593,8 @@ int mainloop(CLIENT *client) {
                        SEND(client->net, reply);
                        DEBUG("OK!\n");
                        continue;
-               }
 
-               if (command==NBD_CMD_READ) {
+               case NBD_CMD_READ:
                        DEBUG("exp->buf, ");
                        memcpy(buf, &reply, sizeof(struct nbd_reply));
                        if (client->transactionlogfd != -1)
@@ -1561,9 +1618,11 @@ int mainloop(CLIENT *client) {
                        }
                        DEBUG("OK!\n");
                        continue;
-               }
 
-               DEBUG ("Ignoring unknown command\n");
+               default:
+                       DEBUG ("Ignoring unknown command\n");
+                       continue;
+               }
        }
        return 0;
 }