r22: Reap zombies; make errormsg bit more helpful when connection is
[nbd.git] / nbd-server.c
1 /*
2  * Network Block Device - server
3  *
4  * Copyright 1996-1998 Pavel Machek, distribute under GPL
5  *  <pavel@atrey.karlin.mff.cuni.cz>
6  * Copyright 2002 Anton Altaparmakov <aia21@cam.ac.uk>
7  *
8  * Version 1.0 - hopefully 64-bit-clean
9  * Version 1.1 - merging enhancements from Josh Parsons, <josh@coombs.anu.edu.au>
10  * Version 1.2 - autodetect size of block devices, thanx to Peter T. Breuer" <ptb@it.uc3m.es>
11  * Version 1.5 - can compile on Unix systems that don't have 64 bit integer
12  *      type, or don't have 64 bit file offsets by defining FS_32BIT
13  *      in compile options for nbd-server *only*. This can be done
14  *      with make FSCHOICE=-DFS_32BIT nbd-server. (I don't have the
15  *      original autoconf input file, or I would make it a configure
16  *      option.) Ken Yap <ken@nlc.net.au>.
17  * Version 1.6 - fix autodetection of block device size and really make 64 bit
18  *      clean on 32 bit machines. Anton Altaparmakov <aia21@cam.ac.uk>
19  * Version 2.0 - Version synchronised with client
20  * Version 2.1 - Reap zombie client processes when they exit. Removed
21  *      (uncommented) the _IO magic, it's no longer necessary.
22  */
23
24 #define VERSION "2.1"
25 #define GIGA (1*1024*1024*1024)
26
27 #include <sys/types.h>
28 #include <sys/socket.h>
29 #include <sys/stat.h>
30 #include <sys/wait.h>           /* wait */
31 #include <signal.h>             /* sigaction */
32 #include <netinet/tcp.h>
33 #include <netinet/in.h>         /* sockaddr_in, htons, in_addr */
34 #include <netdb.h>              /* hostent, gethostby*, getservby* */
35 #include <syslog.h>
36 #include <unistd.h>
37 #include <stdio.h>
38 #include <stdlib.h>
39 #include <string.h>
40 #include <fcntl.h>
41 #include <arpa/inet.h>
42 #include <strings.h>
43
44 //#define _IO(a,b)
45 // #define ISSERVER
46 #define MY_NAME "nbd_server"
47
48 /* Authorization file should contain lines with IP addresses of 
49    clients authorized to use the server. If it does not exist,
50    access is permitted. */
51 #define AUTH_FILE "nbd_server.allow"
52
53 #include "cliserv.h"
54 //#undef _IO
55 /* Deep magic: ioctl.h defines _IO macro (at least on linux) */
56
57
58 /* Debugging macros, now nothing goes to syslog unless you say ISSERVER */
59 #ifdef ISSERVER
60 #define msg2(a,b) syslog(a,b)
61 #define msg3(a,b,c) syslog(a,b,c)
62 #define msg4(a,b,c,d) syslog(a,b,c,d)
63 #else
64 #define msg2(a,b) do { fprintf(stderr,b) ; fputs("\n",stderr) ; } while(0) 
65 #define msg3(a,b,c) do { fprintf(stderr,b,c); fputs("\n",stderr) ; } while(0) 
66 #define msg4(a,b,c,d) do { fprintf(stderr,b,c,d); fputs("\n",stderr) ; } while(0)
67 #endif
68
69
70 #include <sys/ioctl.h>
71 #include <sys/mount.h>          /* For BLKGETSIZE */
72
73 #ifdef  FS_32BIT
74 typedef u32             fsoffset_t;
75 #define htonll          htonl
76 #define ntohll          ntohl
77 #else
78 typedef u64             fsoffset_t;
79 #endif
80
81
82 //#define DODBG
83 #ifdef DODBG
84 #define DEBUG( a ) printf( a )
85 #define DEBUG2( a,b ) printf( a,b )
86 #define DEBUG3( a,b,c ) printf( a,b,c )
87 #else
88 #define DEBUG( a )
89 #define DEBUG2( a,b ) 
90 #define DEBUG3( a,b,c ) 
91 #endif
92
93 #if     defined(HAVE_LLSEEK) && !defined(sun)
94 /* Solaris already has llseek defined in unistd.h */
95 extern long long llseek(unsigned int, long long, unsigned int);
96 #endif
97
98 void serveconnection(int net);
99 void set_peername(int net,char *clientname);
100
101 #define LINELEN 256 
102 char difffilename[256];
103 unsigned int timeout = 0;
104
105 int authorized_client(char *name)
106 /* 0 - authorization refused, 1 - OK 
107   authorization file contains one line per machine, no wildcards
108 */
109 { FILE *f ;
110    
111   char line[LINELEN] ; 
112
113   if ((f=fopen(AUTH_FILE,"r"))==NULL)
114     { msg4(LOG_INFO,"Can't open authorization file %s (%s).",
115            AUTH_FILE,strerror(errno)) ;
116       return 1 ; 
117     }
118   
119   while (fgets(line,LINELEN,f)!=NULL) {
120     if (strncmp(line,name,strlen(name))==0) { fclose(f)  ; return 1 ; }
121   }
122   fclose(f) ;
123   return 0 ;
124 }
125
126
127 inline void readit(int f, void *buf, int len)
128 {
129         int res;
130         while (len > 0) {
131                 DEBUG("*");
132                 if ((res = read(f, buf, len)) <= 0)
133                         err("Read failed: %m");
134                 len -= res;
135                 buf += res;
136         }
137 }
138
139 inline void writeit(int f, void *buf, int len)
140 {
141         int res;
142         while (len > 0) {
143                 DEBUG("+");
144                 if ((res = write(f, buf, len)) <= 0)
145                         err("Write failed: %m");
146                 len -= res;
147                 buf += res;
148         }
149 }
150
151 int port;                       /* Port I'm listening at */
152 char *exportname;               /* File I'm exporting */
153 fsoffset_t exportsize = (fsoffset_t)-1; /* ...and its length */
154 fsoffset_t hunksize = (fsoffset_t)-1;
155 int flags = 0;
156 int export[1024];
157 int difffile=-1 ;
158 u32 difffilelen=0 ; /* number of pages in difffile */
159 u32 *difmap=NULL ;
160 char clientname[256] ;
161
162
163 #define DIFFPAGESIZE 4096 /* diff file uses those chunks */
164
165 #define F_READONLY 1
166 #define F_MULTIFILE 2 
167 #define F_COPYONWRITE 4
168
169 void cmdline(int argc, char *argv[])
170 {
171         int i;
172
173         if (argc < 3) {
174                 printf("This is nbd-server version " VERSION "\n");     
175                 printf("Usage: port file_to_export [size][kKmM] [-r] [-m] [-c] [-a timeout_sec]\n"
176                        "        -r read only\n"
177                        "        -m multiple file\n"
178                        "        -c copy on write\n"
179                        "        -a maximum idle seconds, terminates when idle time exceeded\n"
180                        "        if port is set to 0, stdin is used (for running from inetd)\n"
181                        "        if file_to_export contains '%%s', it is substituted with IP\n"
182                        "                address of machine trying to connect\n" );
183                 exit(0);
184         }
185         port = atoi(argv[1]);
186         for (i = 3; i < argc; i++) {
187                 if (*argv[i] == '-') {
188                         switch (argv[i][1]) {
189                         case 'r':
190                                 flags |= F_READONLY;
191                                 break;
192                         case 'm':
193                                 flags |= F_MULTIFILE;
194                                 hunksize = 1*GIGA;
195                                 break;
196                         case 'c': flags |=F_COPYONWRITE;
197                                 break;
198                         case 'a': 
199                                 if (i+1<argc) {
200                                         timeout = atoi(argv[i+1]);
201                                         i++;
202                                 } else {
203                                         fprintf(stderr, "timeout requires argument\n");
204                                         exit(1);
205                                 }
206                         }
207                 } else {
208                         fsoffset_t es;
209                         int last = strlen(argv[i])-1;
210                         char suffix = argv[i][last];
211                         if (suffix == 'k' || suffix == 'K' ||
212                             suffix == 'm' || suffix == 'M')
213                                 argv[i][last] = '\0';
214                         es = (fsoffset_t)atol(argv[i]);
215                         switch (suffix) {
216                                 case 'm':
217                                 case 'M':  es <<= 10;
218                                 case 'k':
219                                 case 'K':  es <<= 10;
220                                 default :  break;
221                         }
222                         exportsize = es;
223                 }
224         }
225
226         exportname = argv[2];
227 }
228
229 void sigchld_handler(int s)
230 {
231         while(wait(NULL) > 0);
232 }
233
234 void connectme(int port)
235 {
236         struct sockaddr_in addrin;
237         struct sigaction sa;
238         int addrinlen = sizeof(addrin);
239         int net, sock, newpid;
240 #ifndef sun
241         int yes=1;
242 #else
243         char yes='1';
244 #endif
245
246         if ((sock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) < 0)
247                 err("socket: %m");
248
249         /* lose the pesky "Address already in use" error message */
250         if (setsockopt(sock,SOL_SOCKET,SO_REUSEADDR,&yes,sizeof(int)) == -1) {
251                 err("setsockopt");
252         }
253
254         DEBUG("Waiting for connections... bind, ");
255         addrin.sin_family = AF_INET;
256         addrin.sin_port = htons(port);
257         addrin.sin_addr.s_addr = 0;
258         if (bind(sock, (struct sockaddr *) &addrin, addrinlen) < 0)
259                 err("bind: %m");
260         DEBUG("listen, ");
261         if (listen(sock, 1) < 0)
262                 err("listen: %m");
263         DEBUG("accept, ");
264         sa.sa_handler = sigchld_handler;
265         sigemptyset(&sa.sa_mask);
266         sa.sa_flags = SA_RESTART;
267         if(sigaction(SIGCHLD, &sa, NULL) == -1)
268                 err("sigaction: %m");
269         for(;;) { /* infinite loop */
270           if ((net = accept(sock, (struct sockaddr *) &addrin, &addrinlen)) < 0)
271             err("accept: %m");
272
273           set_peername(net,clientname) ;
274           if (!authorized_client(clientname)) {
275             msg2(LOG_INFO,"Unauthorized client") ;
276             close(net) ;
277             continue ;
278           }
279           msg2(LOG_INFO,"Authorized client") ;
280           if ((newpid=fork())<0) {
281             msg3(LOG_INFO,"Could not fork (%s)",strerror(errno)) ;
282             close(net) ;
283             continue ;
284           }
285           if (newpid>0) { /* parent */
286             close(net) ; continue ; }
287           /* child */
288           close(sock) ;
289           msg2(LOG_INFO,"Starting to serve") ;
290           serveconnection(net) ;        
291         }
292 }
293
294 #define SEND writeit( net, &reply, sizeof( reply ));
295 #define ERROR { reply.error = htonl(-1); SEND; reply.error = 0; lastpoint = -1; }
296
297 fsoffset_t lastpoint = (fsoffset_t)-1;
298
299 void maybeseek(int handle, fsoffset_t a)
300 {
301         if (a > exportsize)
302                 err("Can not happen\n");
303         if (lastpoint != a) {
304 #if     defined(HAVE_LLSEEK) && !defined(FS_32BIT)
305                 if (llseek(handle, a, SEEK_SET) < 0)
306 #else
307                 if (lseek(handle, (long)a, SEEK_SET) < 0)
308 #endif
309                         err("Can not seek locally!\n");
310                 lastpoint = a;
311         } else {
312                 DEBUG("@");
313         }
314 }
315
316 void myseek(int handle,fsoffset_t a)
317 {
318 #if HAVE_LLSEEK && !defined(FS_32BIT)
319   if (llseek(handle, a, SEEK_SET) < 0)
320 #else
321   if (lseek(handle, (long)a, SEEK_SET) < 0)
322 #endif 
323     err("Can not seek locally!\n");
324 }
325
326 char pagebuf[DIFFPAGESIZE] ;
327
328
329 int rawexpread(fsoffset_t a, char *buf, int len)
330 {
331   maybeseek(export[a/hunksize], a%hunksize);
332   return (read(export[a/hunksize], buf, len) != len);
333 }
334
335 int expread(fsoffset_t a, char *buf, int len)
336 {
337         int rdlen, offset;
338         fsoffset_t mapcnt, mapl, maph, pagestart;
339  
340   if (flags & F_COPYONWRITE) {
341     DEBUG3("Asked to read %d bytes at %Lu.\n", len, (unsigned long long)a);
342
343     mapl=a/DIFFPAGESIZE ; maph=(a+len-1)/DIFFPAGESIZE ;
344
345     for (mapcnt=mapl;mapcnt<=maph;mapcnt++) {
346       pagestart=mapcnt*DIFFPAGESIZE ;
347       offset=a-pagestart ;
348       rdlen=(len<DIFFPAGESIZE-offset) ? len : DIFFPAGESIZE-offset ;
349       if (difmap[mapcnt]!=(u32)(-1)) { /* the block is already there */
350         DEBUG3("Page %Lu is at %lu\n", (unsigned long long)mapcnt,
351                         (unsigned long)difmap[mapcnt]);
352         myseek(difffile,difmap[mapcnt]*DIFFPAGESIZE+offset) ;
353         if (read(difffile, buf, rdlen) != rdlen) return -1 ;
354       } else { /* the block is not there */
355         DEBUG2("Page %Lu is not here, we read the original one\n",
356                         (unsigned long long)mapcnt) ;
357         if (rawexpread(a,buf,rdlen)) return -1 ;
358       }
359       len-=rdlen ; a+=rdlen ; buf+=rdlen ;
360     }
361   } else return rawexpread(a,buf,len) ;
362   return 0 ;
363 }
364
365 int rawexpwrite(fsoffset_t a, char *buf, int len)
366 {
367         maybeseek(export[a/hunksize], a%hunksize);
368         return (write(export[a/hunksize], buf, len) != len);
369 }
370
371
372 int expwrite(fsoffset_t a, char *buf, int len)
373 {  u32 mapcnt,mapl,maph ; int wrlen,rdlen ; 
374    fsoffset_t pagestart ; int offset ;
375
376   if (flags & F_COPYONWRITE) {
377     DEBUG3("Asked to write %d bytes at %Lu.\n", len, (unsigned long long)a);
378
379     mapl=a/DIFFPAGESIZE ; maph=(a+len-1)/DIFFPAGESIZE ;
380
381     for (mapcnt=mapl;mapcnt<=maph;mapcnt++) {
382       pagestart=mapcnt*DIFFPAGESIZE ;
383       offset=a-pagestart ;
384       wrlen=(len<DIFFPAGESIZE-offset) ? len : DIFFPAGESIZE-offset ;
385
386       if (difmap[mapcnt]!=(u32)(-1)) { /* the block is already there */
387         DEBUG3("Page %Lu is at %lu\n", (unsigned long long)mapcnt,
388                         (unsigned long)difmap[mapcnt]) ;
389         myseek(difffile,difmap[mapcnt]*DIFFPAGESIZE+offset) ;
390         if (write(difffile, buf, wrlen) != wrlen) return -1 ;
391       } else { /* the block is not there */
392         myseek(difffile,difffilelen*DIFFPAGESIZE) ;
393         difmap[mapcnt]=difffilelen++ ;
394         DEBUG3("Page %Lu is not here, we put it at %lu\n",
395                         (unsigned long long)mapcnt,
396                         (unsigned long)difmap[mapcnt]);
397         rdlen=DIFFPAGESIZE ;
398         if (rdlen+pagestart%hunksize>hunksize) 
399           rdlen=hunksize-(pagestart%hunksize) ;
400         if (rawexpread(pagestart,pagebuf,rdlen)) return -1 ;
401         memcpy(pagebuf+offset,buf,wrlen) ;
402         if (write(difffile,pagebuf,DIFFPAGESIZE)!=DIFFPAGESIZE) return -1 ;
403       }                                             
404       len-=wrlen ; a+=wrlen ; buf+=wrlen ;
405     }
406   } else return(rawexpwrite(a,buf,len)); 
407   return 0 ;
408 }
409
410 int mainloop(int net)
411 {
412         struct nbd_request request;
413         struct nbd_reply reply;
414         char zeros[300];
415         int i = 0;
416         fsoffset_t size_host;
417
418         memset(zeros, 0, 290);
419         if (write(net, INIT_PASSWD, 8) < 0)
420                 err("Negotiation failed: %m");
421 #ifndef FS_32BIT
422         cliserv_magic = htonll(cliserv_magic);
423 #endif
424         if (write(net, &cliserv_magic, sizeof(cliserv_magic)) < 0)
425                 err("Negotiation failed: %m");
426         size_host = htonll(exportsize);
427 #ifdef  FS_32BIT
428         if (write(net, zeros, 4) < 0 || write(net, &size_host, 4) < 0)
429 #else
430         if (write(net, &size_host, 8) < 0)
431 #endif
432                 err("Negotiation failed: %m");
433         if (write(net, zeros, 128) < 0)
434                 err("Negotiation failed: %m");
435
436         DEBUG("Entering request loop!\n");
437         reply.magic = htonl(NBD_REPLY_MAGIC);
438         reply.error = 0;
439         while (1) {
440 #define BUFSIZE (1024*1024)
441                 char buf[BUFSIZE];
442                 int len;
443
444 #ifdef DODBG
445                 i++;
446                 printf("%d: ", i);
447 #endif
448
449                 if (timeout) 
450                         alarm(timeout);
451                 readit(net, &request, sizeof(request));
452                 request.from = ntohll(request.from);
453                 request.type = ntohl(request.type);
454
455                 if (request.type==2) { /* Disconnect request */
456                   if (difmap) free(difmap) ;
457                   if (difffile>=0) { 
458                      close(difffile) ; unlink(difffilename) ; }
459                   err("Disconnect request received.") ;
460                 }
461
462                 len = ntohl(request.len);
463
464                 if (request.magic != htonl(NBD_REQUEST_MAGIC))
465                         err("Not enough magic.");
466                 if (len > BUFSIZE)
467                         err("Request too big!");
468 #ifdef DODBG
469                 printf("%s from %Lu (%Lu) len %d, ", request.type ? "WRITE" :
470                                 "READ", (unsigned long long)request.from,
471                                 (unsigned long long)request.from / 512, len);
472 #endif
473                 memcpy(reply.handle, request.handle, sizeof(reply.handle));
474                 if (((request.from + len) > exportsize) ||
475                     ((flags & F_READONLY) && request.type)) {
476                         DEBUG("[RANGE!]");
477                         ERROR;
478                         continue;
479                 }
480                 if (request.type==1) {  /* WRITE */
481                         DEBUG("wr: net->buf, ");
482                         readit(net, buf, len);
483                         DEBUG("buf->exp, ");
484                         if (expwrite(request.from, buf, len)) {
485                                 DEBUG("Write failed: %m" );
486                                 ERROR;
487                                 continue;
488                         }
489                         lastpoint += len;
490                         SEND;
491                         continue;
492                 }
493                 /* READ */
494
495                 DEBUG("exp->buf, ");
496                 if (expread(request.from, buf + sizeof(struct nbd_reply), len)) {
497                         lastpoint = -1;
498                         DEBUG("Read failed: %m");
499                         ERROR;
500                         continue;
501                 }
502                 lastpoint += len;
503
504                 DEBUG("buf->net, ");
505                 memcpy(buf, &reply, sizeof(struct nbd_reply));
506                 writeit(net, buf, len + sizeof(struct nbd_reply));
507                 DEBUG("OK!\n");
508         }
509 }
510
511 char exportname2[1024];
512
513 void set_peername(int net,char *clientname)
514 {
515         struct sockaddr_in addrin;
516         int addrinlen = sizeof( addrin );
517         char *peername ;
518
519         if (getpeername( net, (struct sockaddr *) &addrin, &addrinlen ) < 0)
520                 err("getsockname failed: %m");
521         peername = inet_ntoa(addrin.sin_addr);
522         sprintf(exportname2, exportname, peername);
523
524         msg4(LOG_INFO, "connect from %s, assigned file is %s", peername, exportname2);
525         strncpy(clientname,peername,255) ;
526 }
527
528 fsoffset_t size_autodetect(int export)
529 {
530         fsoffset_t es;
531         u32 es32;
532         struct stat stat_buf;
533         int error;
534
535         DEBUG("looking for export size with lseek SEEK_END\n");
536         es = (fsoffset_t)lseek(export, 0, SEEK_END);
537         if ((signed long long)es > 0LL)
538                 return es;
539
540         DEBUG("looking for export size with fstat\n");
541         stat_buf.st_size = 0;
542         error = fstat(export, &stat_buf);
543         if (!error && stat_buf.st_size > 0)
544                 return (fsoffset_t)stat_buf.st_size;
545
546 #ifdef BLKGETSIZE
547         DEBUG("looking for export size with ioctl BLKGETSIZE\n");
548         if (!ioctl(export, BLKGETSIZE, &es32) && es32) {
549                 es = (fsoffset_t)es32 * (fsoffset_t)512;
550                 return es;
551         }
552 #endif
553         err("Could not find size of exported block device: %m");
554         return (fsoffset_t)-1;
555 }
556
557 int main(int argc, char *argv[])
558 {
559         int net;
560         fsoffset_t i;
561
562         if (sizeof( struct nbd_request )!=28) {
563                 fprintf(stderr,"Bad size of structure. Alignment problems?\n");
564                 exit(-1) ;
565         }
566         logging();
567         cmdline(argc, argv);
568         
569         if (!port) return 1 ;
570         connectme(port); /* serve infinitely */
571         return 0 ;
572 }
573
574
575 void serveconnection(int net) 
576 {   
577   u64 i ;
578
579   for (i=0; i<exportsize; i+=hunksize) {
580     char exportname3[1024];
581     
582     sprintf(exportname3, exportname2, i/hunksize);
583     printf( "Opening %s\n", exportname3 );
584     if ((export[i/hunksize] = open(exportname3, (flags & F_READONLY) ? O_RDONLY : O_RDWR)) == -1)
585       err("Could not open exported file: %m");
586     }
587         
588     if (exportsize == (fsoffset_t)-1) {
589         exportsize = size_autodetect(export[0]);
590     }
591     if (exportsize > ((fsoffset_t)-1 >> 1)) {
592 #ifdef HAVE_LLSEEK
593         if ((exportsize >> 10) > ((fsoffset_t)-1 >> 1))
594                 msg3(LOG_INFO, "size of exported file/device is %LuMB",
595                                 (unsigned long long)(exportsize >> 20));
596         else
597                 msg3(LOG_INFO, "size of exported file/device is %LuKB",
598                                 (unsigned long long)(exportsize >> 10));
599     }
600 #else
601         err("Size of exported file is too big\n");
602     }
603 #endif
604     else
605         msg3(LOG_INFO, "size of exported file/device is %Lu",
606                         (unsigned long long)exportsize);
607
608     if (flags & F_COPYONWRITE) {
609       sprintf(difffilename,"%s-%s-%d.diff",exportname2,clientname,
610               (int)getpid()) ;
611       msg3(LOG_INFO,"About to create map and diff file %s",difffilename) ;
612       difffile=open(difffilename,O_RDWR | O_CREAT | O_TRUNC,0600) ;
613       if (difffile<0) err("Could not create diff file (%m)") ;
614       if ((difmap=calloc(exportsize/DIFFPAGESIZE,sizeof(u32)))==NULL)
615           err("Could not allocate memory") ;
616       for (i=0;i<exportsize/DIFFPAGESIZE;i++) difmap[i]=(u32)-1 ;         
617     }
618     
619     setmysockopt(net);
620       
621     mainloop(net);
622 }