nbd-tester-client: ignore SIGPIPE so we pick up and print the error
[nbd.git] / nbd-tester-client.c
1 /*
2  * Test client to test the NBD server. Doesn't do anything useful, except
3  * checking that the server does, actually, work.
4  *
5  * Note that the only 'real' test is to check the client against a kernel. If
6  * it works here but does not work in the kernel, then that's most likely a bug
7  * in this program and/or in nbd-server.
8  *
9  * Copyright(c) 2006  Wouter Verhelst
10  *
11  * This program is Free Software; you can redistribute it and/or modify it
12  * under the terms of the GNU General Public License as published by the Free
13  * Software Foundation, in version 2.
14  *
15  * This program is distributed in the hope that it will be useful, but WITHOUT
16  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
17  * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
18  * more details.
19  *
20  * You should have received a copy of the GNU General Public License along with
21  * this program; if not, write to the Free Software Foundation, Inc., 51
22  * Franklin St, Fifth Floor, Boston, MA  02110-1301 USA
23  */
24 #include <stdlib.h>
25 #include <stdio.h>
26 #include <stdbool.h>
27 #include <string.h>
28 #include <sys/time.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/stat.h>
32 #include <sys/mman.h>
33 #include <fcntl.h>
34 #include <syslog.h>
35 #include <unistd.h>
36 #include "config.h"
37 #include "lfs.h"
38 #include <netinet/in.h>
39 #include <glib.h>
40
41 #define MY_NAME "nbd-tester-client"
42 #include "cliserv.h"
43
44 static gchar errstr[1024];
45 const static int errstr_len=1024;
46
47 static uint64_t size;
48
49 static gchar * transactionlog = "nbd-tester-client.tr";
50
/*
 * How far setup_connection() takes the handshake. The stages are ordered:
 * setup_connection() compares ctype with '<', so each stage also performs
 * every earlier one.
 */
typedef enum {
	CONNECTION_TYPE_NONE,		/* do not connect at all */
	CONNECTION_TYPE_CONNECT,	/* TCP connect only */
	CONNECTION_TYPE_INIT_PASSWD,	/* ... then read and check INIT_PASSWD */
	CONNECTION_TYPE_CLISERV,	/* ... then read and check the magic */
	CONNECTION_TYPE_FULL,		/* ... then finish negotiation (size, flags) */
} CONNECTION_TYPE;

/* How close_connection() shuts a connection down. */
typedef enum {
	CONNECTION_CLOSE_PROPERLY,	/* send NBD_CMD_DISC, then close the socket */
	CONNECTION_CLOSE_FAST,		/* just close the socket */
} CLOSE_TYPE;

/* A request kept on a doubly linked rclist. NOTE(review): the meaning of
 * seq and orighandle is defined by code outside this view (presumably
 * integrity_test) -- confirm there. */
struct reqcontext {
	uint64_t seq;
	char orighandle[8];
	struct nbd_request req;
	struct reqcontext * next;
	struct reqcontext * prev;
};

/* Doubly linked list of reqcontext items, maintained by rclist_unlink()
 * and rclist_addtail(). */
struct rclist {
	struct reqcontext * head;
	struct reqcontext * tail;
	int numitems;
};

/* One buffer of outgoing bytes, as used by addbuffer()/writebuffer(). */
struct chunk {
	char * buffer;		/* start of the allocation (passed to free()) */
	char * readptr;		/* next byte writebuffer() will send */
	char * writeptr;	/* where addbuffer() appends new bytes */
	uint64_t space;		/* free bytes remaining after writeptr */
	uint64_t length;	/* bytes appended but not yet sent */
	struct chunk * next;
	struct chunk * prev;
};

/* Doubly linked list of chunks, maintained by chunklist_unlink() and
 * chunklist_addtail(). */
struct chunklist {
	struct chunk * head;
	struct chunk * tail;
	int numitems;
};
93
94 void rclist_unlink(struct rclist * l, struct reqcontext * p) {
95         if (p && l) {
96                 struct reqcontext * prev = p->prev;
97                 struct reqcontext * next = p->next;
98                 
99                 /* Fix link to previous */
100                 if (prev)
101                         prev->next = next;
102                 else
103                         l->head = next;
104                 
105                 if (next)
106                         next->prev = prev;
107                 else
108                         l->tail = prev;
109
110                 p->prev = NULL;
111                 p->next = NULL;
112                 l->numitems--;
113         }                                                       
114 }                                                                       
115
116 /* Add a new list item to the tail */
117 void rclist_addtail(struct rclist * l, struct reqcontext * p) {
118         if (!p || !l)
119                 return;
120         if (l->tail) {
121                 if (l->tail->next)
122                         g_warning("addtail found list tail has a next pointer");
123                 l->tail->next = p;
124                 p->next = NULL;
125                 p->prev = l->tail;
126                 l->tail = p;
127         } else {
128                 if (l->head)
129                         g_warning("addtail found no list tail but a list head");
130                 l->head = p;
131                 l->tail = p;
132                 p->prev = NULL;
133                 p->next = NULL;
134         }
135         l->numitems++;
136 }
137
138 void chunklist_unlink(struct chunklist * l, struct chunk * p) {
139         if (p && l) {
140                 struct chunk * prev = p->prev;
141                 struct chunk * next = p->next;
142                 
143                 /* Fix link to previous */
144                 if (prev)
145                         prev->next = next;
146                 else
147                         l->head = next;
148                 
149                 if (next)
150                         next->prev = prev;
151                 else
152                         l->tail = prev;
153
154                 p->prev = NULL;
155                 p->next = NULL;
156                 l->numitems--;
157         }                                                       
158 }                                                                       
159
160 /* Add a new list item to the tail */
161 void chunklist_addtail(struct chunklist * l, struct chunk * p) {
162         if (!p || !l)
163                 return;
164         if (l->tail) {
165                 if (l->tail->next)
166                         g_warning("addtail found list tail has a next pointer");
167                 l->tail->next = p;
168                 p->next = NULL;
169                 p->prev = l->tail;
170                 l->tail = p;
171         } else {
172                 if (l->head)
173                         g_warning("addtail found no list tail but a list head");
174                 l->head = p;
175                 l->tail = p;
176                 p->prev = NULL;
177                 p->next = NULL;
178         }
179         l->numitems++;
180 }
181
182 /* Add some new bytes to a chunklist */
183 void addbuffer(struct chunklist * l, void * data, uint64_t len) {
184         void * buf;
185         uint64_t size = 64*1024;
186         struct chunk * pchunk;
187
188         while (len>0)
189         {
190                 /* First see if there is a current chunk, and if it has space */
191                 if (l->tail && l->tail->space) {
192                         uint64_t towrite = len;
193                         if (towrite > l->tail->space)
194                                 towrite = l->tail->space;
195                         memcpy(l->tail->writeptr, data, towrite);
196                         l->tail->length += towrite;
197                         l->tail->space -= towrite;
198                         l->tail->writeptr += towrite;
199                         len -= towrite;
200                         data += towrite;
201                 }
202
203                 if (len>0) {
204                         /* We still need to write more, so prepare a new chunk */
205                         if ((NULL == (buf = malloc(size))) || (NULL == (pchunk = calloc(1, sizeof(struct chunk))))) {
206                                 g_critical("Out of memory");
207                                 exit (1);
208                         }
209
210                         pchunk->buffer = buf;
211                         pchunk->readptr = buf;
212                         pchunk->writeptr = buf;
213                         pchunk->space = size;
214                         chunklist_addtail(l, pchunk);
215                 }
216         }
217
218 }
219
220 /* returns 0 on success, -1 on failure */
221 int writebuffer(int fd, struct chunklist * l) {
222
223         struct chunk * pchunk = NULL;
224         int res;
225         if (!l)
226                 return 0;
227
228         while (!pchunk)
229         {
230                 pchunk = l->head;
231                 if (!pchunk)
232                         return 0;
233                 if (!(pchunk->length) || !(pchunk->readptr)) {
234                         chunklist_unlink(l, pchunk);
235                         free(pchunk->buffer);
236                         free(pchunk);
237                         pchunk = NULL;
238                 }
239         }
240         
241         /* OK we have a chunk with some data in */
242         res = write(fd, pchunk->readptr, pchunk->length);
243         if (res==0)
244                 errno = EAGAIN;
245         if (res<=0)
246                 return -1;
247         pchunk->length -= res;
248         pchunk->readptr += res;
249         if (!pchunk->length) {
250                 chunklist_unlink(l, pchunk);
251                 free(pchunk->buffer);
252                 free(pchunk);
253         }
254         return 0;
255 }
256
257
258
/* Bits for the testflags argument of the *_test() functions. */
#define TEST_WRITE (1<<0)	/* issue writes instead of reads */
#define TEST_FLUSH (1<<1)	/* also send FUA writes and FLUSH requests; throughput_test clears it without TEST_WRITE */
261
/*
 * Store x - y in *result, with result->tv_usec normalised into
 * [0, 1000000). Inputs whose tv_usec is out of range are accepted.
 * Returns 1 if the difference is negative, 0 otherwise.
 * Unlike the glibc-manual routine this derives from, *y is left
 * unchanged: the carry is applied to a local copy.
 */
int timeval_subtract (struct timeval *result, struct timeval *x,
		      struct timeval *y) {
	struct timeval sub = *y;

	/* Perform the carry for the later subtraction by adjusting sub. */
	if (x->tv_usec < sub.tv_usec) {
		long nsec = (sub.tv_usec - x->tv_usec) / 1000000 + 1;
		sub.tv_usec -= 1000000 * nsec;
		sub.tv_sec += nsec;
	}
	/* '>=' so an exact one-second microsecond difference is carried too,
	 * instead of yielding tv_usec == 1000000. */
	if (x->tv_usec - sub.tv_usec >= 1000000) {
		long nsec = (x->tv_usec - sub.tv_usec) / 1000000;
		sub.tv_usec += 1000000 * nsec;
		sub.tv_sec -= nsec;
	}

	result->tv_sec = x->tv_sec - sub.tv_sec;
	result->tv_usec = x->tv_usec - sub.tv_usec;

	return x->tv_sec < sub.tv_sec;
}
281
/* Return x - y in seconds as a double, using timeval_subtract() to do the
 * subtraction. */
double timeval_diff_to_double (struct timeval * x, struct timeval * y) {
	struct timeval delta;

	timeval_subtract(&delta, x, y);
	return (double)delta.tv_sec + (double)delta.tv_usec / 1000000.0;
}
287
288 static inline int read_all(int f, void *buf, size_t len) {
289         ssize_t res;
290         size_t retval=0;
291
292         while(len>0) {
293                 if((res=read(f, buf, len)) <=0) {
294                         if (!res)
295                                 errno=EAGAIN;
296                         snprintf(errstr, errstr_len, "Read failed: %s", strerror(errno));
297                         return -1;
298                 }
299                 len-=res;
300                 buf+=res;
301                 retval+=res;
302         }
303         return retval;
304 }
305
306 static inline int write_all(int f, void *buf, size_t len) {
307         ssize_t res;
308         size_t retval=0;
309
310         while(len>0) {
311                 if((res=write(f, buf, len)) <=0) {
312                         if (!res)
313                                 errno=EAGAIN;
314                         snprintf(errstr, errstr_len, "Write failed: %s", strerror(errno));
315                         return -1;
316                 }
317                 len-=res;
318                 buf+=res;
319                 retval+=res;
320         }
321         return retval;
322 }
323
/*
 * Error-checking wrappers around read_all()/write_all(). If the transfer
 * returns <= 0, the printf-style trailing arguments are formatted into
 * errstr and control jumps to label `whereto`. The *_ERR_RT variants also
 * store `rval` into the caller's local variable `retval`.
 * Each expands to a single statement (do { ... } while (0)), so it is safe
 * as the unbraced body of an if/else. Standard C99 __VA_ARGS__ replaces
 * the former GNU named-variadic `errmsg...` form.
 */
#define READ_ALL_ERRCHK(f, buf, len, whereto, ...) \
	do { \
		if (read_all((f), (buf), (len)) <= 0) { \
			snprintf(errstr, errstr_len, __VA_ARGS__); \
			goto whereto; \
		} \
	} while (0)
#define READ_ALL_ERR_RT(f, buf, len, whereto, rval, ...) \
	do { \
		if (read_all((f), (buf), (len)) <= 0) { \
			snprintf(errstr, errstr_len, __VA_ARGS__); \
			retval = (rval); \
			goto whereto; \
		} \
	} while (0)

#define WRITE_ALL_ERRCHK(f, buf, len, whereto, ...) \
	do { \
		if (write_all((f), (buf), (len)) <= 0) { \
			snprintf(errstr, errstr_len, __VA_ARGS__); \
			goto whereto; \
		} \
	} while (0)
#define WRITE_ALL_ERR_RT(f, buf, len, whereto, rval, ...) \
	do { \
		if (write_all((f), (buf), (len)) <= 0) { \
			snprintf(errstr, errstr_len, __VA_ARGS__); \
			retval = (rval); \
			goto whereto; \
		} \
	} while (0)
329
330 int setup_connection(gchar *hostname, int port, gchar* name, CONNECTION_TYPE ctype, int* serverflags) {
331         int sock;
332         struct hostent *host;
333         struct sockaddr_in addr;
334         char buf[256];
335         uint64_t mymagic = (name ? opts_magic : cliserv_magic);
336         u64 tmp64;
337         uint32_t tmp32 = 0;
338
339         sock=0;
340         if(ctype<CONNECTION_TYPE_CONNECT)
341                 goto end;
342         if((sock=socket(PF_INET, SOCK_STREAM, IPPROTO_TCP))<0) {
343                 strncpy(errstr, strerror(errno), errstr_len);
344                 goto err;
345         }
346         setmysockopt(sock);
347         if(!(host=gethostbyname(hostname))) {
348                 strncpy(errstr, strerror(errno), errstr_len);
349                 goto err_open;
350         }
351         addr.sin_family=AF_INET;
352         addr.sin_port=htons(port);
353         addr.sin_addr.s_addr=*((int *) host->h_addr);
354         if((connect(sock, (struct sockaddr *)&addr, sizeof(addr))<0)) {
355                 strncpy(errstr, strerror(errno), errstr_len);
356                 goto err_open;
357         }
358         if(ctype<CONNECTION_TYPE_INIT_PASSWD)
359                 goto end;
360         READ_ALL_ERRCHK(sock, buf, strlen(INIT_PASSWD), err_open, "Could not read INIT_PASSWD: %s", strerror(errno));
361         if(strlen(buf)==0) {
362                 snprintf(errstr, errstr_len, "Server closed connection");
363                 goto err_open;
364         }
365         if(strncmp(buf, INIT_PASSWD, strlen(INIT_PASSWD))) {
366                 snprintf(errstr, errstr_len, "INIT_PASSWD does not match");
367                 goto err_open;
368         }
369         if(ctype<CONNECTION_TYPE_CLISERV)
370                 goto end;
371         READ_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err_open, "Could not read cliserv_magic: %s", strerror(errno));
372         tmp64=ntohll(tmp64);
373         if(tmp64 != mymagic) {
374                 strncpy(errstr, "mymagic does not match", errstr_len);
375                 goto err_open;
376         }
377         if(ctype<CONNECTION_TYPE_FULL)
378                 goto end;
379         if(!name) {
380                 READ_ALL_ERRCHK(sock, &size, sizeof(size), err_open, "Could not read size: %s", strerror(errno));
381                 size=ntohll(size);
382                 READ_ALL_ERRCHK(sock, buf, 128, err_open, "Could not read data: %s", strerror(errno));
383                 goto end;
384         }
385         /* flags */
386         READ_ALL_ERRCHK(sock, buf, sizeof(uint16_t), err_open, "Could not read reserved field: %s", strerror(errno));
387         /* reserved field */
388         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err_open, "Could not write reserved field: %s", strerror(errno));
389         /* magic */
390         tmp64 = htonll(opts_magic);
391         WRITE_ALL_ERRCHK(sock, &tmp64, sizeof(tmp64), err_open, "Could not write magic: %s", strerror(errno));
392         /* name */
393         tmp32 = htonl(NBD_OPT_EXPORT_NAME);
394         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err_open, "Could not write option: %s", strerror(errno));
395         tmp32 = htonl((uint32_t)strlen(name));
396         WRITE_ALL_ERRCHK(sock, &tmp32, sizeof(tmp32), err_open, "Could not write name length: %s", strerror(errno));
397         WRITE_ALL_ERRCHK(sock, name, strlen(name), err_open, "Could not write name:: %s", strerror(errno));
398         READ_ALL_ERRCHK(sock, &size, sizeof(size), err_open, "Could not read size: %s", strerror(errno));
399         size = ntohll(size);
400         uint16_t flags;
401         READ_ALL_ERRCHK(sock, &flags, sizeof(uint16_t), err_open, "Could not read flags: %s", strerror(errno));
402         flags = ntohs(flags);
403         *serverflags = flags;
404         READ_ALL_ERRCHK(sock, buf, 124, err_open, "Could not read reserved zeroes: %s", strerror(errno));
405         goto end;
406 err_open:
407         close(sock);
408 err:
409         sock=-1;
410 end:
411         return sock;
412 }
413
414 int close_connection(int sock, CLOSE_TYPE type) {
415         struct nbd_request req;
416         u64 counter=0;
417
418         switch(type) {
419                 case CONNECTION_CLOSE_PROPERLY:
420                         req.magic=htonl(NBD_REQUEST_MAGIC);
421                         req.type=htonl(NBD_CMD_DISC);
422                         memcpy(&(req.handle), &(counter), sizeof(counter));
423                         counter++;
424                         req.from=0;
425                         req.len=0;
426                         if(write(sock, &req, sizeof(req))<0) {
427                                 snprintf(errstr, errstr_len, "Could not write to socket: %s", strerror(errno));
428                                 return -1;
429                         }
430                 case CONNECTION_CLOSE_FAST:
431                         if(close(sock)<0) {
432                                 snprintf(errstr, errstr_len, "Could not close socket: %s", strerror(errno));
433                                 return -1;
434                         }
435                         break;
436                 default:
437                         g_critical("Your compiler is on crack!"); /* or I am buggy */
438                         return -1;
439         }
440         return 0;
441 }
442
443 int read_packet_check_header(int sock, size_t datasize, long long int curhandle) {
444         struct nbd_reply rep;
445         int retval=0;
446         char buf[datasize];
447
448         READ_ALL_ERR_RT(sock, &rep, sizeof(rep), end, -1, "Could not read reply header: %s", strerror(errno));
449         rep.magic=ntohl(rep.magic);
450         rep.error=ntohl(rep.error);
451         if(rep.magic!=NBD_REPLY_MAGIC) {
452                 snprintf(errstr, errstr_len, "Received package with incorrect reply_magic. Index of sent packages is %lld (0x%llX), received handle is %lld (0x%llX). Received magic 0x%lX, expected 0x%lX", (long long int)curhandle, (long long unsigned int)curhandle, (long long int)*((u64*)rep.handle), (long long unsigned int)*((u64*)rep.handle), (long unsigned int)rep.magic, (long unsigned int)NBD_REPLY_MAGIC);
453                 retval=-1;
454                 goto end;
455         }
456         if(rep.error) {
457                 snprintf(errstr, errstr_len, "Received error from server: %ld (0x%lX). Handle is %lld (0x%llX).", (long int)rep.error, (long unsigned int)rep.error, (long long int)(*((u64*)rep.handle)), (long long unsigned int)*((u64*)rep.handle));
458                 retval=-1;
459                 goto end;
460         }
461         if (datasize)
462                 READ_ALL_ERR_RT(sock, &buf, datasize, end, -1, "Could not read data: %s", strerror(errno));
463
464 end:
465         return retval;
466 }
467
468 int oversize_test(gchar* hostname, int port, char* name, int sock,
469                   char sock_is_open, char close_sock, int testflags) {
470         int retval=0;
471         struct nbd_request req;
472         struct nbd_reply rep;
473         int i=0;
474         int serverflags = 0;
475         pid_t G_GNUC_UNUSED mypid = getpid();
476         char buf[((1024*1024)+sizeof(struct nbd_request)/2)<<1];
477         bool got_err;
478
479         /* This should work */
480         if(!sock_is_open) {
481                 if((sock=setup_connection(hostname, port, name, CONNECTION_TYPE_FULL, &serverflags))<0) {
482                         g_warning("Could not open socket: %s", errstr);
483                         retval=-1;
484                         goto err;
485                 }
486         }
487         req.magic=htonl(NBD_REQUEST_MAGIC);
488         req.type=htonl(NBD_CMD_READ);
489         req.len=htonl(1024*1024);
490         memcpy(&(req.handle),&i,sizeof(i));
491         req.from=htonll(i);
492         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1, "Could not write request: %s", strerror(errno));
493         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
494         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1, "Could not read reply header: %s", strerror(errno));
495         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1, "Could not read data: %s", strerror(errno));
496         if(rep.error) {
497                 snprintf(errstr, errstr_len, "Received unexpected error: %d", rep.error);
498                 retval=-1;
499                 goto err;
500         } else {
501                 printf("OK\n");
502         }
503         /* This probably should not work */
504         i++; req.from=htonll(i);
505         req.len = htonl(ntohl(req.len) + sizeof(struct nbd_request) / 2);
506         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1, "Could not write request: %s", strerror(errno));
507         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
508         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1, "Could not read reply header: %s", strerror(errno));
509         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1, "Could not read data: %s", strerror(errno));
510         if(rep.error) {
511                 printf("Received expected error\n");
512                 got_err=true;
513         } else {
514                 printf("OK\n");
515                 got_err=false;
516         }
517         /* ... unless this works, too */
518         i++; req.from=htonll(i);
519         req.len = htonl(ntohl(req.len) << 1);
520         WRITE_ALL_ERR_RT(sock, &req, sizeof(req), err, -1, "Could not write request: %s", strerror(errno));
521         printf("%d: testing oversized request: %d: ", getpid(), ntohl(req.len));
522         READ_ALL_ERR_RT(sock, &rep, sizeof(struct nbd_reply), err, -1, "Could not read reply header: %s", strerror(errno));
523         READ_ALL_ERR_RT(sock, &buf, ntohl(req.len), err, -1, "Could not read data: %s", strerror(errno));
524         if(rep.error) {
525                 printf("error\n");
526         } else {
527                 printf("OK\n");
528         }
529         if((rep.error && !got_err) || (!rep.error && got_err)) {
530                 printf("Received unexpected error\n");
531                 retval=-1;
532         }
533   err:
534         return retval;
535 }
536
537 int throughput_test(gchar* hostname, int port, char* name, int sock,
538                     char sock_is_open, char close_sock, int testflags) {
539         long long int i;
540         char writebuf[1024];
541         struct nbd_request req;
542         int requests=0;
543         fd_set set;
544         struct timeval tv;
545         struct timeval start;
546         struct timeval stop;
547         double timespan;
548         double speed;
549         char speedchar[2] = { '\0', '\0' };
550         int retval=0;
551         int serverflags = 0;
552         signed int do_write=TRUE;
553         pid_t mypid = getpid();
554
555
556         if (!(testflags & TEST_WRITE))
557                 testflags &= ~TEST_FLUSH;
558
559         memset (writebuf, 'X', 1024);
560         size=0;
561         if(!sock_is_open) {
562                 if((sock=setup_connection(hostname, port, name, CONNECTION_TYPE_FULL, &serverflags))<0) {
563                         g_warning("Could not open socket: %s", errstr);
564                         retval=-1;
565                         goto err;
566                 }
567         }
568         if ((testflags & TEST_FLUSH) && ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
569                                          != (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))) {
570                 snprintf(errstr, errstr_len, "Server did not supply flush capability flags");
571                 retval = -1;
572                 goto err_open;
573         }
574         req.magic=htonl(NBD_REQUEST_MAGIC);
575         req.len=htonl(1024);
576         if(gettimeofday(&start, NULL)<0) {
577                 retval=-1;
578                 snprintf(errstr, errstr_len, "Could not measure start time: %s", strerror(errno));
579                 goto err_open;
580         }
581         for(i=0;i+1024<=size;i+=1024) {
582                 if(do_write) {
583                         int sendfua = (testflags & TEST_FLUSH) && (((i>>10) & 15) == 3);
584                         int sendflush = (testflags & TEST_FLUSH) && (((i>>10) & 15) == 11);
585                         req.type=htonl((testflags & TEST_WRITE)?NBD_CMD_WRITE:NBD_CMD_READ);
586                         if (sendfua)
587                                 req.type = htonl(NBD_CMD_WRITE | NBD_CMD_FLAG_FUA);
588                         memcpy(&(req.handle),&i,sizeof(i));
589                         req.from=htonll(i);
590                         if (write_all(sock, &req, sizeof(req)) <0) {
591                                 retval=-1;
592                                 goto err_open;
593                         }
594                         if (testflags & TEST_WRITE) {
595                                 if (write_all(sock, writebuf, 1024) <0) {
596                                         retval=-1;
597                                         goto err_open;
598                                 }
599                         }
600                         printf("%d: Requests(+): %d\n", (int)mypid, ++requests);
601                         if (sendflush) {
602                                 long long int j = i ^ (1LL<<63);
603                                 req.type = htonl(NBD_CMD_FLUSH);
604                                 memcpy(&(req.handle),&j,sizeof(j));
605                                 req.from=0;
606                                 if (write_all(sock, &req, sizeof(req)) <0) {
607                                         retval=-1;
608                                         goto err_open;
609                                 }
610                                 printf("%d: Requests(+): %d\n", (int)mypid, ++requests);
611                         }
612                 }
613                 do {
614                         FD_ZERO(&set);
615                         FD_SET(sock, &set);
616                         tv.tv_sec=0;
617                         tv.tv_usec=0;
618                         select(sock+1, &set, NULL, NULL, &tv);
619                         if(FD_ISSET(sock, &set)) {
620                                 /* Okay, there's something ready for
621                                  * reading here */
622                                 if(read_packet_check_header(sock, (testflags & TEST_WRITE)?0:1024, i)<0) {
623                                         retval=-1;
624                                         goto err_open;
625                                 }
626                                 printf("%d: Requests(-): %d\n", (int)mypid, --requests);
627                         }
628                 } while FD_ISSET(sock, &set);
629                 /* Now wait until we can write again or until a second have
630                  * passed, whichever comes first*/
631                 FD_ZERO(&set);
632                 FD_SET(sock, &set);
633                 tv.tv_sec=1;
634                 tv.tv_usec=0;
635                 do_write=select(sock+1,NULL,&set,NULL,&tv);
636                 if(!do_write) printf("Select finished\n");
637                 if(do_write<0) {
638                         snprintf(errstr, errstr_len, "select: %s", strerror(errno));
639                         retval=-1;
640                         goto err_open;
641                 }
642         }
643         /* Now empty the read buffer */
644         do {
645                 FD_ZERO(&set);
646                 FD_SET(sock, &set);
647                 tv.tv_sec=0;
648                 tv.tv_usec=0;
649                 select(sock+1, &set, NULL, NULL, &tv);
650                 if(FD_ISSET(sock, &set)) {
651                         /* Okay, there's something ready for
652                          * reading here */
653                         read_packet_check_header(sock, (testflags & TEST_WRITE)?0:1024, i);
654                         printf("%d: Requests(-): %d\n", (int)mypid, --requests);
655                 }
656         } while (requests);
657         if(gettimeofday(&stop, NULL)<0) {
658                 retval=-1;
659                 snprintf(errstr, errstr_len, "Could not measure end time: %s", strerror(errno));
660                 goto err_open;
661         }
662         timespan=timeval_diff_to_double(&stop, &start);
663         speed=size/timespan;
664         if(speed>1024) {
665                 speed=speed/1024.0;
666                 speedchar[0]='K';
667         }
668         if(speed>1024) {
669                 speed=speed/1024.0;
670                 speedchar[0]='M';
671         }
672         if(speed>1024) {
673                 speed=speed/1024.0;
674                 speedchar[0]='G';
675         }
676         g_message("%d: Throughput %s test (%s flushes) complete. Took %.3f seconds to complete, %.3f%sib/s", (int)getpid(), (testflags & TEST_WRITE)?"write":"read", (testflags & TEST_FLUSH)?"with":"without", timespan, speed, speedchar);
677
678 err_open:
679         if(close_sock) {
680                 close_connection(sock, CONNECTION_CLOSE_PROPERLY);
681         }
682 err:
683         return retval;
684 }
685
/*
 * Fill the 512-byte buffer 'buf' with pseudo-random data that depends only
 * on seq and blknum, so the same pair always regenerates the same bytes.
 * The first 64-bit word (native byte order) is
 * blknum ^ (seq << 32) ^ (seq >> 32); for small values that is blknum in
 * the low half and seq in the high half. Later words are derived from it.
 * seq == 0 yields an all-zero block.
 */
static inline void makebuf(char *buf, uint64_t seq, uint64_t blknum) {
	uint64_t x = ((uint64_t)blknum) ^ (seq << 32) ^ (seq >> 32);
	size_t i;

	if (!seq) {
		memset(buf, 0, 512);
		return;
	}
	for (i = 0; i < 512/sizeof(uint64_t); i++) {
		unsigned int s;
		/* memcpy: buf has no alignment guarantee for a uint64_t store */
		memcpy(buf + i * sizeof(x), &x, sizeof(x));
		x+=0xFEEDA1ECDEADBEEFULL+i+(((uint64_t)i)<<56);
		s = x & 63;
		/* Shifting a 64-bit value by 64 is undefined. For s == 0 use x
		 * itself, which is what x86 produced for the old x >> (64 - s). */
		x = x ^ (x<<s) ^ (s ? (x>>(64-s)) : x) ^ 0xAA55AA55AA55AA55ULL ^ seq;
	}
}
707                 
/* Return 0 if the 512 bytes at buf match what makebuf() generates for
 * (seq, blknum), -1 otherwise. */
static inline int checkbuf(char *buf, uint64_t seq, uint64_t blknum) {
	char expected[512];

	makebuf(expected, seq, blknum);
	if (memcmp(expected, buf, 512) != 0)
		return -1;
	return 0;
}
713
/*
 * Debug aid: when built with DEBUG_COMMANDS, print text, the decoded name
 * of the network-order command word, whether NBD_CMD_FLAG_FUA is set, and
 * the host-order value. Without DEBUG_COMMANDS this is a no-op.
 */
static inline void dumpcommand(char * text, uint32_t command)
{
#ifdef DEBUG_COMMANDS
	uint32_t cmd = ntohl(command);
	char * name;

	switch (cmd & NBD_CMD_MASK_COMMAND) {
	case NBD_CMD_READ:	name = "NBD_CMD_READ";	break;
	case NBD_CMD_WRITE:	name = "NBD_CMD_WRITE";	break;
	case NBD_CMD_DISC:	name = "NBD_CMD_DISC";	break;
	case NBD_CMD_FLUSH:	name = "NBD_CMD_FLUSH";	break;
	default:		name = "UNKNOWN";	break;
	}
	printf("%s: %s [%s] (0x%08x)\n", text, name,
	       (cmd & NBD_CMD_FLAG_FUA) ? "FUA" : "NONE", cmd);
#endif
}
743
744 /* return an unused handle */
745 uint64_t getrandomhandle(GHashTable *phash) {
746         uint64_t handle = 0;
747         int i;
748         do {
749                 /* RAND_MAX may be as low as 2^15 */
750                 for (i= 1 ; i<=5; i++)
751                         handle ^= random() ^ (handle << 15); 
752         } while (g_hash_table_lookup(phash, &handle));
753         return handle;
754 }
755
756 int integrity_test(gchar* hostname, int port, char* name, int sock,
757                    char sock_is_open, char close_sock, int testflags) {
758         struct nbd_reply rep;
759         fd_set rset;
760         fd_set wset;
761         struct timeval tv;
762         struct timeval start;
763         struct timeval stop;
764         double timespan;
765         double speed;
766         char speedchar[2] = { '\0', '\0' };
767         int retval=0;
768         int serverflags = 0;
769         pid_t G_GNUC_UNUSED mypid = getpid();
770         int blkhashfd = -1;
771         char *blkhashname=NULL;
772         uint32_t *blkhash = NULL;
773         int logfd=-1;
774         uint64_t seq=1;
775         uint64_t processed=0;
776         uint64_t printer=0;
777         uint64_t xfer=0;
778         int readtransactionfile = 1;
779         struct rclist txqueue={NULL, NULL, 0};
780         struct rclist inflight={NULL, NULL, 0};
781         struct chunklist txbuf={NULL, NULL, 0};
782
783         GHashTable *handlehash = g_hash_table_new(g_int64_hash, g_int64_equal);
784
785         size=0;
786         if(!sock_is_open) {
787                 if((sock=setup_connection(hostname, port, name, CONNECTION_TYPE_FULL, &serverflags))<0) {
788                         g_warning("Could not open socket: %s", errstr);
789                         retval=-1;
790                         goto err;
791                 }
792         }
793
794         if ((serverflags & (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
795             != (NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA))
796                 g_warning("Server flags do not support FLUSH and FUA - these may error");
797
798 #ifdef HAVE_MKSTEMP
799         blkhashname=strdup("/tmp/blkarray-XXXXXX");
800         if (!blkhashname || (-1 == (blkhashfd = mkstemp(blkhashname)))) {
801                 g_warning("Could not open temp file: %s", strerror(errno));
802                 retval=-1;
803                 goto err;
804         }
805 #else
806         /* use tmpnam here to avoid further feature test nightmare */
807         if (-1 == (blkhashfd = open(blkhashname=strdup(tmpnam(NULL)),
808                                     O_CREAT | O_RDWR,
809                                     S_IRUSR|S_IWUSR|S_IRGRP|S_IROTH))) {
810                 g_warning("Could not open temp file: %s", strerror(errno));
811                 retval=-1;
812                 goto err;
813         }
814 #endif
815         /* Ensure space freed if we die */
816         if (-1 == unlink(blkhashname)) {
817                 g_warning("Could not unlink temp file: %s", strerror(errno));
818                 retval=-1;
819                 goto err;
820         }
821
822         if (-1 == lseek(blkhashfd, (off_t)((size>>9)<<2), SEEK_SET)) {
823                 g_warning("Could not llseek temp file: %s", strerror(errno));
824                 retval=-1;
825                 goto err;
826         }
827
828         if (-1 == write(blkhashfd, "\0", 1)) {
829                 g_warning("Could not write temp file: %s", strerror(errno));
830                 retval=-1;
831                 goto err;
832         }
833
834         if (NULL == (blkhash = mmap(NULL,
835                                     (size>>9)<<2,
836                                     PROT_READ | PROT_WRITE,
837                                     MAP_SHARED,
838                                     blkhashfd,
839                                     0))) {
840                 g_warning("Could not mmap temp file: %s", strerror(errno));
841                 retval=-1;
842                 goto err;
843         }
844
845         if (-1 == (logfd = open(transactionlog, O_RDONLY)))
846         {
847                 g_warning("Could open log file: %s", strerror(errno));
848                 retval=-1;
849                 goto err;
850         }
851                 
852         if(gettimeofday(&start, NULL)<0) {
853                 retval=-1;
854                 snprintf(errstr, errstr_len, "Could not measure start time: %s", strerror(errno));
855                 goto err_open;
856         }
857
858         while (readtransactionfile || txqueue.numitems || txbuf.numitems || inflight.numitems) {
859                 int ret;
860
861                 uint32_t magic;
862                 uint32_t command;
863                 uint64_t from;
864                 uint32_t len;
865                 struct reqcontext * prc;
866
867                 *errstr=0;
868
869                 FD_ZERO(&wset);
870                 FD_ZERO(&rset);
871                 if (readtransactionfile)
872                         FD_SET(logfd, &rset);
873                 if (txqueue.numitems || txbuf.numitems)
874                         FD_SET(sock, &wset);
875                 if (inflight.numitems)
876                         FD_SET(sock, &rset);
877                 tv.tv_sec=5;
878                 tv.tv_usec=0;
879                 ret = select(1+((sock>logfd)?sock:logfd), &rset, &wset, NULL, &tv);
880                 if (ret == 0) {
881                         retval=-1;
882                         snprintf(errstr, errstr_len, "Timeout reading from socket");
883                         goto err_open;
884                 } else if (ret<0) {
885                         g_warning("Could not mmap temp file: %s", errstr);
886                         retval=-1;
887                         goto err;
888                 }
889                 /* We know we've got at least one thing to do here then */
890
891                 /* Get a command from the transaction log */
892                 if (FD_ISSET(logfd, &rset)) {
893                         
894                         /* Read a request or reply from the transaction file */
895                         READ_ALL_ERRCHK(logfd,
896                                         &magic,
897                                         sizeof(magic),
898                                         err_open,
899                                         "Could not read transaction log: %s",
900                                         strerror(errno));
901                         magic = ntohl(magic);
902                         switch (magic) {
903                         case NBD_REQUEST_MAGIC:
904                                 if (NULL == (prc = calloc(1, sizeof(struct reqcontext)))) {
905                                         retval=-1;
906                                         snprintf(errstr, errstr_len, "Could not allocate request");
907                                         goto err_open;
908                                 }
909                                 READ_ALL_ERRCHK(logfd,
910                                                 sizeof(magic)+(char *)&(prc->req),
911                                                 sizeof(struct nbd_request)-sizeof(magic),
912                                                 err_open,
913                                                 "Could not read transaction log: %s",
914                                                 strerror(errno));
915                                 prc->req.magic = htonl(NBD_REQUEST_MAGIC);
916                                 memcpy(prc->orighandle, prc->req.handle, 8);
917                                 prc->seq=seq++;
918                                 if ((ntohl(prc->req.type) & NBD_CMD_MASK_COMMAND) == NBD_CMD_DISC) {
919                                         /* no more to read; don't enqueue as no reply
920                                          * we will disconnect manually at the end
921                                          */
922                                         readtransactionfile = 0;
923                                         free (prc);
924                                 } else {
925                                         dumpcommand("Enqueuing command", prc->req.type);
926                                         rclist_addtail(&txqueue, prc);
927                                 }
928                                 prc = NULL;
929                                 break;
930                         case NBD_REPLY_MAGIC:
931                                 READ_ALL_ERRCHK(logfd,
932                                                 sizeof(magic)+(char *)(&rep),
933                                                 sizeof(struct nbd_reply)-sizeof(magic),
934                                                 err_open,
935                                                 "Could not read transaction log: %s",
936                                                 strerror(errno));
937
938                                 if (rep.error) {
939                                         retval=-1;
940                                         snprintf(errstr, errstr_len, "Transaction log file contained errored transaction");
941                                         goto err_open;
942                                 }
943                                         
944                                 /* We do not need to consume data on a read reply as there is
945                                  * none in the log */
946                                 break;
947                         default:
948                                 retval=-1;
949                                 snprintf(errstr, errstr_len, "Could not measure start time: %08x", magic);
950                                 goto err_open;
951                         }
952                 }
953
954                 /* See if we have a write we can do */
955                 if (FD_ISSET(sock, &wset))
956                 {
957                         if (!(txqueue.head) && !(txbuf.head))
958                                 g_warning("Socket write FD set but we shouldn't have been interested");
959
960                         /* If there is no buffered data, generate some */
961                         if (!(txbuf.head) && (NULL != (prc = txqueue.head)))
962                         {
963                                 rclist_unlink(&txqueue, prc);
964                                 rclist_addtail(&inflight, prc);
965                                 
966                                 if (ntohl(prc->req.magic) != NBD_REQUEST_MAGIC) {
967                                         retval=-1;
968                                         g_warning("Asked to write a reply without a magic number");
969                                         goto err_open;
970                                 }
971                                         
972                                 dumpcommand("Sending command", prc->req.type);
973                                 command = ntohl(prc->req.type);
974                                 from = ntohll(prc->req.from);
975                                 len = ntohl(prc->req.len);
976                                 /* we rewrite the handle as they otherwise may not be unique */
977                                 *((uint64_t*)(prc->req.handle))=getrandomhandle(handlehash);
978                                 g_hash_table_insert(handlehash, prc->req.handle, prc);
979                                 addbuffer(&txbuf, &(prc->req), sizeof(struct nbd_request));
980                                 switch (command & NBD_CMD_MASK_COMMAND) {
981                                 case NBD_CMD_WRITE:
982                                         xfer+=len;
983                                         while (len > 0) {
984                                                 uint64_t blknum = from>>9;
985                                                 char dbuf[512];
986                                                 if (from>=size) {
987                                                         snprintf(errstr, errstr_len, "offset %llx beyond size %llx",
988                                                                  (long long int) from, (long long int)size);
989                                                         goto err_open;
990                                                 }
991                                                 /* work out what we should be writing */
992                                                 makebuf(dbuf, prc->seq, blknum);
993                                                 addbuffer(&txbuf, dbuf, 512);
994                                                 from += 512;
995                                                 len -= 512;
996                                         }
997                                         break;
998                                 case NBD_CMD_READ:
999                                         xfer+=len;
1000                                         break;
1001                                 case NBD_CMD_DISC:
1002                                 case NBD_CMD_FLUSH:
1003                                         break;
1004                                 default:
1005                                         retval=-1;
1006                                         snprintf(errstr, errstr_len, "Incomprehensible command: %08x", command);
1007                                         goto err_open;
1008                                         break;
1009                                 }
1010                                 
1011                                 prc = NULL;
1012                         }
1013
1014                         /* there should be some now */
1015                         if (writebuffer(sock, &txbuf)<0) {
1016                                 retval=-1;
1017                                 snprintf(errstr, errstr_len, "Failed to write to socket buffer: %s", strerror(errno));
1018                                 goto err_open;
1019                         }
1020                         
1021                 }
1022
1023                 /* See if there is a reply to be processed from the socket */
1024                 if(FD_ISSET(sock, &rset)) {
1025                         /* Okay, there's something ready for
1026                          * reading here */
1027                         
1028                         READ_ALL_ERRCHK(sock,
1029                                         &rep,
1030                                         sizeof(struct nbd_reply),
1031                                         err_open,
1032                                         "Could not read from server socket: %s",
1033                                         strerror(errno));
1034                         
1035                         if (rep.magic != htonl(NBD_REPLY_MAGIC)) {
1036                                 retval=-1;
1037                                 snprintf(errstr, errstr_len, "Bad magic from server");
1038                                 goto err_open;
1039                         }
1040                         
1041                         if (rep.error) {
1042                                 retval=-1;
1043                                 snprintf(errstr, errstr_len, "Server errored a transaction");
1044                                 goto err_open;
1045                         }
1046                                 
1047                         prc = g_hash_table_lookup(handlehash, rep.handle);
1048                         if (!prc) {
1049                                 retval=-1;
1050                                 snprintf(errstr, errstr_len, "Unrecognised handle in reply: 0x%llX", *(long long unsigned int*)(rep.handle));
1051                                 goto err_open;
1052                         }
1053                         if (!g_hash_table_remove(handlehash, rep.handle)) {
1054                                 retval=-1;
1055                                 snprintf(errstr, errstr_len, "Could not remove handle from hash: 0x%llX", *(long long unsigned int*)(rep.handle));
1056                                 goto err_open;
1057                         }
1058
1059                         if (prc->req.magic != htonl(NBD_REQUEST_MAGIC)) {
1060                                 retval=-1;
1061                                 snprintf(errstr, errstr_len, "Bad magic in inflight data: %08x", prc->req.magic);
1062                                 goto err_open;
1063                         }
1064                         
1065                         dumpcommand("Processing reply to command", prc->req.type);
1066                         command = ntohl(prc->req.type);
1067                         from = ntohll(prc->req.from);
1068                         len = ntohl(prc->req.len);
1069                         
1070                         switch (command & NBD_CMD_MASK_COMMAND) {
1071                         case NBD_CMD_READ:
1072                                 while (len > 0) {
1073                                         uint64_t blknum = from>>9;
1074                                         char dbuf[512];
1075                                         if (from>=size) {
1076                                                 snprintf(errstr, errstr_len, "offset %llx beyond size %llx",
1077                                                          (long long int) from, (long long int)size);
1078                                                 goto err_open;
1079                                         }
1080                                         READ_ALL_ERRCHK(sock,
1081                                                         dbuf,
1082                                                         512,
1083                                                         err_open,
1084                                                         "Could not read data: %s",
1085                                                         strerror(errno));
1086                                         /* work out what we was written */
1087                                         if (checkbuf(dbuf, blkhash[blknum], blknum))
1088                                         {
1089                                                 retval=-1;
1090                                                 snprintf(errstr, errstr_len, "Bad reply data: seq %08x", blkhash[blknum]);
1091                                                 goto err_open;
1092                                                 
1093                                         }
1094                                         from += 512;
1095                                         len -= 512;
1096                                 }
1097                                 break;
1098                         case NBD_CMD_WRITE:
1099                                 /* subsequent reads should get data with this seq*/
1100                                 while (len > 0) {
1101                                         uint64_t blknum = from>>9;
1102                                         blkhash[blknum]=(uint32_t)(prc->seq);
1103                                         from += 512;
1104                                         len -= 512;
1105                                 }
1106                                 break;
1107                         default:
1108                                 break;
1109                         }
1110                         
1111                         processed++;
1112                         rclist_unlink(&inflight, prc);
1113                         prc->req.magic=0; /* so a duplicate reply is detected */
1114                         free(prc);
1115                 }
1116
1117                 if (!(printer++ % 10000) || !(readtransactionfile || txqueue.numitems || inflight.numitems) )
1118                         printf("%d: Seq %08lld Queued: %08d Inflight: %08d Done: %08lld\n",
1119                                (int)mypid,
1120                                (long long int) seq,
1121                                txqueue.numitems,
1122                                inflight.numitems,
1123                                (long long int) processed);
1124
1125         }
1126
1127         if (gettimeofday(&stop, NULL)<0) {
1128                 retval=-1;
1129                 snprintf(errstr, errstr_len, "Could not measure end time: %s", strerror(errno));
1130                 goto err_open;
1131         }
1132         timespan=timeval_diff_to_double(&stop, &start);
1133         speed=xfer/timespan;
1134         if(speed>1024) {
1135                 speed=speed/1024.0;
1136                 speedchar[0]='K';
1137         }
1138         if(speed>1024) {
1139                 speed=speed/1024.0;
1140                 speedchar[0]='M';
1141         }
1142         if(speed>1024) {
1143                 speed=speed/1024.0;
1144                 speedchar[0]='G';
1145         }
1146         g_message("%d: Integrity %s test complete. Took %.3f seconds to complete, %.3f%sib/s", (int)getpid(), (testflags & TEST_WRITE)?"write":"read", timespan, speed, speedchar);
1147
1148 err_open:
1149         if(close_sock) {
1150                 close_connection(sock, CONNECTION_CLOSE_PROPERLY);
1151         }
1152 err:
1153         if (size && blkhash)
1154                 munmap(blkhash, (size>>9)<<2);
1155
1156         if (blkhashfd != -1)
1157                 close (blkhashfd);
1158
1159         if (logfd != -1)
1160                 close (logfd);
1161
1162         if (blkhashname)
1163                 free(blkhashname);
1164
1165         if (*errstr)
1166                 g_warning("%s",errstr);
1167
1168         g_hash_table_destroy(handlehash);
1169
1170         return retval;
1171 }
1172
1173 typedef int (*testfunc)(gchar*, int, char*, int, char, char, int);
1174
1175 int main(int argc, char**argv) {
1176         gchar *hostname;
1177         long int p = 0;
1178         char* name = NULL;
1179         int sock=0;
1180         int c;
1181         bool want_port = TRUE;
1182         int nonopt=0;
1183         int testflags=0;
1184         testfunc test = throughput_test;
1185
1186         /* Ignore SIGPIPE as we want to pick up the error from write() */
1187         signal (SIGPIPE, SIG_IGN);
1188
1189         if(argc<3) {
1190                 g_message("%d: Not enough arguments", (int)getpid());
1191                 g_message("%d: Usage: %s <hostname> <port>", (int)getpid(), argv[0]);
1192                 g_message("%d: Or: %s <hostname> -N <exportname> [<port>]", (int)getpid(), argv[0]);
1193                 exit(EXIT_FAILURE);
1194         }
1195         logging();
1196         while((c=getopt(argc, argv, "-N:t:owfi"))>=0) {
1197                 switch(c) {
1198                         case 1:
1199                                 switch(nonopt) {
1200                                         case 0:
1201                                                 hostname=g_strdup(optarg);
1202                                                 nonopt++;
1203                                                 break;
1204                                         case 1:
1205                                                 p=(strtol(argv[2], NULL, 0));
1206                                                 if(p==LONG_MIN||p==LONG_MAX) {
1207                                                         g_critical("Could not parse port number: %s", strerror(errno));
1208                                                         exit(EXIT_FAILURE);
1209                                                 }
1210                                                 break;
1211                                 }
1212                                 break;
1213                         case 'N':
1214                                 name=g_strdup(optarg);
1215                                 if(!p) {
1216                                         p = 10809;
1217                                 }
1218                                 want_port = false;
1219                                 break;
1220                         case 't':
1221                                 transactionlog=g_strdup(optarg);
1222                                 break;
1223                         case 'o':
1224                                 test=oversize_test;
1225                                 break;
1226                         case 'w':
1227                                 testflags|=TEST_WRITE;
1228                                 break;
1229                         case 'f':
1230                                 testflags|=TEST_FLUSH;
1231                                 break;
1232                         case 'i':
1233                                 test=integrity_test;
1234                                 break;
1235                 }
1236         }
1237
1238         if(test(hostname, (int)p, name, sock, FALSE, TRUE, testflags)<0) {
1239                 g_warning("Could not run test: %s", errstr);
1240                 exit(EXIT_FAILURE);
1241         }
1242
1243         return 0;
1244 }