r231: Make client sockets non-blocking, too.
[nbd.git] / nbd-server.c
index a35aafb..281a14f 100644 (file)
@@ -82,6 +82,8 @@
 #include <dirent.h>
 #include <unistd.h>
 #include <getopt.h>
+#include <pwd.h>
+#include <grp.h>
 
 #include <glib.h>
 
 /** Where our config file actually is */
 gchar* config_file_pos;
 
+/** What user we're running as */
+gchar* runuser=NULL;
+/** What group we're running as */
+gchar* rungroup=NULL;
+
 /** Logging macros, now nothing goes to syslog unless you say ISSERVER */
 #ifdef ISSERVER
 #define msg2(a,b) syslog(a,b)
@@ -139,9 +146,22 @@ gchar* config_file_pos;
 #define F_COPYONWRITE 4          /**< flag to tell us a file is exported using
                            copyonwrite */
 #define F_AUTOREADONLY 8  /**< flag to tell us a file is set to autoreadonly */
+#define F_SPARSE 16       /**< flag set by the sparse_cow option: a copyonwrite page is mapped to its own page number in the diff file */
 GHashTable *children;
 char pidfname[256]; /**< name of our PID file */
-char default_authname[] = "/etc/nbd_server.allow"; /**< default name of allow file */
+char pidftemplate[256]; /**< template to be used for the filename of the PID file */
+char default_authname[] = SYSCONFDIR "/nbd-server/allow"; /**< default name of allow file */
+
+/**
+ * Types of virtualization: the ways in which the IP address of a client
+ * can be worked into the name of the file that is exported to it.
+ **/
+typedef enum {
+       VIRT_NONE=0,    /**< No virtualization */
+       VIRT_IPLIT,     /**< Literal IP address as part of the filename */
+       VIRT_IPHASH,    /**< Replacing all dots in an ip address by a / before
+                            doing the same as in IPLIT */
+       VIRT_CIDR,      /**< Every subnet in its own directory */
+} VIRT_STYLE;
 
 /**
  * Variables associated with a server.
@@ -156,6 +176,9 @@ typedef struct {
        unsigned int timeout;/**< how long a connection may be idle
                               (0=forever) */
        int socket;          /**< The socket of this server. */
+       VIRT_STYLE virtstyle;/**< The style of virtualization, if any */
+       uint8_t cidrlen;     /**< The length of the mask when we use
+                                 CIDR-style virtualization */
 } SERVER;
 
 /**
@@ -191,6 +214,7 @@ typedef enum {
        PARAM_STRING,           /**< This parameter is a string */
        PARAM_BOOL,             /**< This parameter is a boolean */
 } PARAM_TYPE;
+
 /**
  * Configuration file values
  **/
@@ -216,9 +240,14 @@ typedef struct {
  * @return 0 - authorization refused, 1 - OK
  **/
 int authorized_client(CLIENT *opts) {
+       const char *ERRMSG="Invalid entry '%s' in authfile '%s', so, refusing all connections.";
        FILE *f ;
-   
        char line[LINELEN]; 
+       char *tmp;
+       struct in_addr addr;
+       struct in_addr client;
+       struct in_addr cltemp;
+       int len;
 
        if ((f=fopen(opts->server->authname,"r"))==NULL) {
                msg4(LOG_INFO,"Can't open authorization file %s (%s).",
@@ -226,14 +255,35 @@ int authorized_client(CLIENT *opts) {
                return 1 ; 
        }
   
+       inet_aton(opts->clientname, &client);
+       while (fgets(line,LINELEN,f)!=NULL) {
+               if((tmp=index(line, '/'))) {
+                       char *end;
+                       long plen;
+                       uint32_t mask;
+
+                       /* Entry of the form "address/prefixlen". Check the
+                        * prefix length while the line is still intact, so
+                        * that the error message shows the whole entry. A
+                        * length outside 0..32 would make the mask below
+                        * undefined. */
+                       plen=strtol(tmp+1, &end, 10);
+                       if(end==tmp+1 || plen<0 || plen>32) {
+                               msg4(LOG_CRIT, ERRMSG, line, opts->server->authname);
+                               fclose(f);
+                               return 0;
+                       }
+                       len=(int)plen;
+                       /* Cut the line at the slash: what is left is the
+                        * address part */
+                       *tmp=0;
+                       /* inet_aton() returns 0 when the address is invalid */
+                       if(!inet_aton(line,&addr)) {
+                               msg4(LOG_CRIT, ERRMSG, line, opts->server->authname);
+                               fclose(f);
+                               return 0;
+                       }
+                       /* s_addr is kept in network byte order, so build the
+                        * netmask in host order and convert it. A prefix
+                        * length of 0 is special-cased because shifting a
+                        * 32-bit value by 32 bits is undefined. */
+                       mask=len ? htonl(0xffffffffU<<(32-len)) : 0;
+                       addr.s_addr&=mask;
+                       cltemp.s_addr=client.s_addr&mask;
+                       if(addr.s_addr == cltemp.s_addr) {
+                               fclose(f);
+                               return 1;
+                       }
+               }
                if (strncmp(line,opts->clientname,strlen(opts->clientname))==0) {
                        fclose(f);
                        return 1;
                }
        }
-       fclose(f) ;
-       return 0 ;
+       fclose(f);
+       return 0;
 }
 
 /**
@@ -245,12 +295,33 @@ int authorized_client(CLIENT *opts) {
  **/
 inline void readit(int f, void *buf, size_t len) {
        ssize_t res;
+       gboolean tried = FALSE;
+
        while (len > 0) {
                DEBUG("*");
-               if ((res = read(f, buf, len)) <= 0)
-                       err("Read failed: %m");
-               len -= res;
-               buf += res;
+               if ((res = read(f, buf, len)) <= 0) {
+                       if(!tried && errno==EAGAIN) {
+                               /* Assume the connection will work some time in
+                                * the future, but don't run away with CPU time
+                                * in case it doesn't */
+                               fd_set set;
+                               struct timeval tv;
+
+                               DEBUG("Read failed, trying again");
+                               tried=TRUE;
+                               FD_ZERO(&set);
+                               FD_SET(f, &set);
+                               tv.tv_sec=30;
+                               tv.tv_usec=0;
+                               select(f+1, &set, NULL, NULL, &tv);
+                       } else {
+                               err("Read failed: %m");
+                       }
+               } else {
+                       len -= res;
+                       buf += res;
+                       tried=FALSE;
+               }
        }
 }
 
@@ -263,12 +334,33 @@ inline void readit(int f, void *buf, size_t len) {
  **/
 inline void writeit(int f, void *buf, size_t len) {
        ssize_t res;
+       gboolean tried=FALSE;
+
        while (len > 0) {
                DEBUG("+");
-               if ((res = write(f, buf, len)) <= 0)
-                       err("Send failed: %m");
-               len -= res;
-               buf += res;
+               if ((res = write(f, buf, len)) <= 0) {
+                       if(!tried && errno==EAGAIN) {
+                               /* Assume the connection will work some time in
+                                * the future, but don't run away with CPU time
+                                * in case it doesn't */
+                               fd_set set;
+                               struct timeval tv;
+
+                               DEBUG("Write failed, trying again");
+                               tried=TRUE;
+                               FD_ZERO(&set);
+                               FD_SET(f, &set);
+                               tv.tv_sec=30;
+                               tv.tv_usec=0;
+                               select(f+1, NULL, &set, NULL, &tv);
+                       } else {
+                               err("Send failed: %m");
+                       }
+               } else {
+                       len -= res;
+                       buf += res;
+                       tried=FALSE;
+               }
        }
 }
 
@@ -278,13 +370,14 @@ inline void writeit(int f, void *buf, size_t len) {
  */
 void usage() {
        printf("This is nbd-server version " VERSION "\n");
-       printf("Usage: port file_to_export [size][kKmM] [-l authorize_file] [-r] [-m] [-c] [-a timeout_sec] [-C configuration file]\n"
+       printf("Usage: port file_to_export [size][kKmM] [-l authorize_file] [-r] [-m] [-c] [-a timeout_sec] [-C configuration file] [-p PID file name]\n"
               "\t-r|--read-only\t\tread only\n"
               "\t-m|--multi-file\t\tmultiple file\n"
               "\t-c|--copy-on-write\tcopy on write\n"
-              "\t-C|--config-file\tspecify an alternat configuration file\n"
+              "\t-C|--config-file\tspecify an alternate configuration file\n"
               "\t-l|--authorize-file\tfile with list of hosts that are allowed to\n\t\t\t\tconnect.\n"
-              "\t-a|--idle-time\t\tmaximum idle seconds; server terminates when\n\t\t\t\tidle time exceeded\n\n"
+              "\t-a|--idle-time\t\tmaximum idle seconds; server terminates when\n\t\t\t\tidle time exceeded\n"
+              "\t-p|--pid-file\t\tspecify a filename to write our PID to\n\n"
               "\tif port is set to 0, stdin is used (for running from inetd)\n"
               "\tif file_to_export contains '%%s', it is substituted with the IP\n"
               "\t\taddress of the machine trying to connect\n" );
@@ -308,6 +401,7 @@ SERVER* cmdline(int argc, char *argv[]) {
                {"authorize-file", required_argument, NULL, 'l'},
                {"idle-time", required_argument, NULL, 'a'},
                {"config-file", required_argument, NULL, 'C'},
+               {"pid-file", required_argument, NULL, 'p'},
                {0,0,0,0}
        };
        SERVER *serve;
@@ -320,7 +414,7 @@ SERVER* cmdline(int argc, char *argv[]) {
        }
        serve=g_new0(SERVER, 1);
        serve->authname = g_strdup(default_authname);
-       while((c=getopt_long(argc, argv, "-a:C:cl:mr", long_options, &i))>=0) {
+       while((c=getopt_long(argc, argv, "-a:C:cl:mrp:", long_options, &i))>=0) {
                switch (c) {
                case 1:
                        /* non-option argument */
@@ -359,6 +453,9 @@ SERVER* cmdline(int argc, char *argv[]) {
                case 'm':
                        serve->flags |= F_MULTIFILE;
                        break;
+               case 'p':
+                       /* strncpy() does not terminate the copy when optarg
+                        * fills the whole buffer, so keep one byte for the
+                        * terminating NUL and set it explicitly. */
+                       strncpy(pidftemplate, optarg, sizeof(pidftemplate)-1);
+                       pidftemplate[sizeof(pidftemplate)-1]='\0';
+                       break;
                case 'c': 
                        serve->flags |=F_COPYONWRITE;
                        break;
@@ -418,26 +515,35 @@ void remove_server(gpointer s) {
  * @param f the name of the config file
  * @param e a GError. @see CFILE_ERRORS for what error values this function can
  *     return.
- * @return a GHashTable of SERVER* pointers, with the port number as the hash
- *     key. If the config file is empty or does not exist, returns an empty
- *     GHashTable; if the config file contains an error, returns NULL, and
- *     e is set appropriately
+ * @return a GArray of SERVER structures. If the config file is empty or does
+ *     not exist, returns an empty GArray; if the config file contains an
+ *     error, returns NULL, and e is set appropriately
  **/
 GArray* parse_cfile(gchar* f, GError** e) {
        const char* DEFAULT_ERROR = "Could not parse %s in group %s: %s";
        const char* MISSING_REQUIRED_ERROR = "Could not find required value %s in group %s: %s";
        SERVER s;
-       PARAM p[] = {
+       gchar *virtstyle=NULL;
+       PARAM lp[] = {
                { "exportname", TRUE,   PARAM_STRING,   NULL, 0 },
                { "port",       TRUE,   PARAM_INT,      NULL, 0 },
                { "authfile",   FALSE,  PARAM_STRING,   NULL, 0 },
                { "timeout",    FALSE,  PARAM_INT,      NULL, 0 },
                { "filesize",   FALSE,  PARAM_INT,      NULL, 0 },
+               { "virtstyle",  FALSE,  PARAM_STRING,   NULL, 0 },
                { "readonly",   FALSE,  PARAM_BOOL,     NULL, F_READONLY },
                { "multifile",  FALSE,  PARAM_BOOL,     NULL, F_MULTIFILE },
                { "copyonwrite", FALSE, PARAM_BOOL,     NULL, F_COPYONWRITE },
+               { "autoreadonly", FALSE, PARAM_BOOL,    NULL, F_AUTOREADONLY },
+               { "sparse_cow", FALSE,  PARAM_BOOL,     NULL, F_SPARSE },
        };
-       const int p_size=8;
+       const int lp_size=11;
+       PARAM gp[] = {
+               { "user",       FALSE, PARAM_STRING,    &runuser,       0 },
+               { "group",      FALSE, PARAM_STRING,    &rungroup,      0 },
+       };
+       PARAM* p=gp;
+       int p_size=2;
        GKeyFile *cfile;
        GError *err = NULL;
        const char *err_msg=NULL;
@@ -445,9 +551,9 @@ GArray* parse_cfile(gchar* f, GError** e) {
        GArray *retval=NULL;
        gchar **groups;
        gboolean value;
-       gint i,j;
+       gint i;
+       gint j;
 
-       memset(&s, '\0', sizeof(SERVER));
        errdomain = g_quark_from_string("parse_cfile");
        cfile = g_key_file_new();
        retval = g_array_new(FALSE, TRUE, sizeof(SERVER));
@@ -463,13 +569,21 @@ GArray* parse_cfile(gchar* f, GError** e) {
                return NULL;
        }
        groups = g_key_file_get_groups(cfile, NULL);
-       for(i=1;groups[i];i++) {
-               p[0].target=&(s.exportname);
-               p[1].target=&(s.port);
-               p[2].target=&(s.authname);
-               p[3].target=&(s.timeout);
-               p[4].target=&(s.expected_size);
-               p[5].target=p[6].target=p[7].target=&(s.flags);
+       for(i=0;groups[i];i++) {
+               memset(&s, '\0', sizeof(SERVER));
+               lp[0].target=&(s.exportname);
+               lp[1].target=&(s.port);
+               lp[2].target=&(s.authname);
+               lp[3].target=&(s.timeout);
+               lp[4].target=&(s.expected_size);
+               lp[5].target=&(virtstyle);
+               lp[6].target=lp[7].target=lp[8].target=
+                               lp[9].target=lp[10].target=&(s.flags);
+               /* After the [generic] group, start parsing exports */
+               if(i==1) {
+                       p=lp;
+                       p_size=lp_size;
+               } 
                for(j=0;j<p_size;j++) {
                        g_assert(p[j].target != NULL);
                        g_assert(p[j].ptype==PARAM_INT||p[j].ptype==PARAM_STRING||p[j].ptype==PARAM_BOOL);
@@ -492,8 +606,12 @@ GArray* parse_cfile(gchar* f, GError** e) {
                                        value = g_key_file_get_boolean(cfile,
                                                        groups[i],
                                                        p[j].paramname, &err);
-                                       if(!err && value) {
-                                               *((gint*)p[j].target) |= p[j].flagval;
+                                       if(!err) {
+                                               if(value) {
+                                                       *((gint*)p[j].target) |= p[j].flagval;
+                                               } else {
+                                                       *((gint*)p[j].target) &= ~(p[j].flagval);
+                                               }
                                        }
                                        break;
                        }
@@ -516,7 +634,37 @@ GArray* parse_cfile(gchar* f, GError** e) {
                                return NULL;
                        }
                }
-               g_array_append_val(retval, s);
+               if(virtstyle) {
+                       if(!strncmp(virtstyle, "none", 4)) {
+                               s.virtstyle=VIRT_NONE;
+                       } else if(!strncmp(virtstyle, "ipliteral", 9)) {
+                               s.virtstyle=VIRT_IPLIT;
+                       } else if(!strncmp(virtstyle, "iphash", 6)) {
+                               s.virtstyle=VIRT_IPHASH;
+                       } else if(!strncmp(virtstyle, "cidrhash", 8)) {
+                               s.virtstyle=VIRT_CIDR;
+                               if(strlen(virtstyle)<10) {
+                                       g_set_error(e, errdomain, CFILE_VALUE_INVALID, "Invalid value %s for parameter virtstyle in group %s: missing length", virtstyle, groups[i]);
+                                       g_array_free(retval, TRUE);
+                                       g_key_file_free(cfile);
+                                       return NULL;
+                               }
+                               s.cidrlen=strtol(virtstyle+8, NULL, 0);
+                       } else {
+                               g_set_error(e, errdomain, CFILE_VALUE_INVALID, "Invalid value %s for parameter virtstyle in group %s", virtstyle, groups[i]);
+                               g_array_free(retval, TRUE);
+                               g_key_file_free(cfile);
+                               return NULL;
+                       }
+               } else {
+                       s.virtstyle=VIRT_IPLIT;
+               }
+               /* Don't need to free this, it's not our string */
+               virtstyle=NULL;
+               /* Don't append values for the [generic] group */
+               if(i>0) {
+                       g_array_append_val(retval, s);
+               }
        }
        return retval;
 }
@@ -848,7 +996,7 @@ int expwrite(off_t a, char *buf, size_t len, CLIENT *client) {
                        if (write(client->difffile, buf, wrlen) != wrlen) return -1 ;
                } else { /* the block is not there */
                        myseek(client->difffile,client->difffilelen*DIFFPAGESIZE) ;
-                       client->difmap[mapcnt]=client->difffilelen++ ;
+                       client->difmap[mapcnt]=(client->server->flags&F_SPARSE)?mapcnt:client->difffilelen++;
                        DEBUG3("Page %Lu is not here, we put it at %lu\n",
                               (unsigned long long)mapcnt,
                               (unsigned long)(client->difmap[mapcnt]));
@@ -874,7 +1022,7 @@ void negotiate(CLIENT *client) {
        char zeros[300];
        u64 size_host;
 
-       memset(zeros, 0, 290);
+       memset(zeros, '\0', 290);
        if (write(client->net, INIT_PASSWD, 8) < 0)
                err("Negotiation failed: %m");
        cliserv_magic = htonll(cliserv_magic);
@@ -926,8 +1074,8 @@ int mainloop(CLIENT *client) {
 
                if (request.type==NBD_CMD_DISC) {
                        msg2(LOG_INFO, "Disconnect request received.");
-                       if (client->difmap) g_free(client->difmap) ;
-                       if (client->difffile>=0) { 
+                       if (client->server->flags & F_COPYONWRITE) { 
+                               if (client->difmap) g_free(client->difmap) ;
                                close(client->difffile);
                                unlink(client->difffilename);
                                free(client->difffilename);
@@ -1108,7 +1256,8 @@ void serveconnection(CLIENT *client) {
 /**
  * Find the name of the file we have to serve. This will use g_strdup_printf
  * to put the IP address of the client inside a filename containing
- * "%s". That name is then written to client->exportname.
+ * "%s" (in the form as specified by the "virtstyle" option). That name
+ * is then written to client->exportname.
  *
  * @param net A socket connected to an nbd client
  * @param client information about the client. The IP address in human-readable
@@ -1117,14 +1266,40 @@ void serveconnection(CLIENT *client) {
  **/
 void set_peername(int net, CLIENT *client) {
        struct sockaddr_in addrin;
-       int addrinlen = sizeof( addrin );
-       char *peername ;
+       struct sockaddr_in netaddr;
+       size_t addrinlen = sizeof( addrin );
+       char *peername;
+       char *netname;
+       char *tmp;
+       int i;
 
        if (getpeername(net, (struct sockaddr *) &addrin, (socklen_t *)&addrinlen) < 0)
                err("getsockname failed: %m");
-       peername = inet_ntoa(addrin.sin_addr);
-       client->exportname=g_strdup_printf(client->server->exportname, peername);
+       peername = g_strdup(inet_ntoa(addrin.sin_addr));
+       switch(client->server->virtstyle) {
+               case VIRT_NONE:
+                       client->exportname=g_strdup(client->server->exportname);
+                       break;
+               case VIRT_IPHASH:
+                       for(i=0;i<strlen(peername);i++) {
+                               if(peername[i]=='.') {
+                                       peername[i]='/';
+                               }
+                       }
+               case VIRT_IPLIT:
+                       client->exportname=g_strdup_printf(client->server->exportname, peername);
+                       break;
+               case VIRT_CIDR:
+                       memcpy(&netaddr, &addrin, addrinlen);
+                       netaddr.sin_addr.s_addr>>=32-(client->server->cidrlen);
+                       netaddr.sin_addr.s_addr<<=32-(client->server->cidrlen);
+                       netname = inet_ntoa(netaddr.sin_addr);
+                       tmp=g_strdup_printf("%s/%s", netname, peername);
+                       client->exportname=g_strdup_printf(client->server->exportname, tmp);
+                       break;
+       }
 
+       g_free(peername);
        msg4(LOG_INFO, "connect from %s, assigned file is %s", 
             peername, client->exportname);
        client->clientname=g_strdup(peername);
@@ -1152,11 +1327,14 @@ void daemonize(SERVER* serve) {
        if(daemon(0,0)<0) {
                err("daemon");
        }
-       if(serve) {
-               snprintf(pidfname, sizeof(char)*255, "/var/run/nbd-server.%d.pid", serve->port);
-       } else {
-               strncpy(pidfname, "/var/run/nbd-server.pid", sizeof(char)*255);
+       if(!*pidftemplate) {
+               /* No -p option given: fall back to the PID file names that
+                * were used before the name became configurable */
+               if(serve) {
+                       strncpy(pidftemplate, "/var/run/nbd-server.%d.pid", sizeof(pidftemplate)-1);
+               } else {
+                       strncpy(pidftemplate, "/var/run/nbd-server.pid", sizeof(pidftemplate)-1);
+               }
+       }
+       /* NOTE(review): the template is used as a printf format with a single
+        * int argument (the port), so a template given with -p must not
+        * contain any conversion other than one %d */
+       snprintf(pidfname, sizeof(pidfname), pidftemplate, serve ? serve->port : 0);
        pidf=fopen(pidfname, "w");
        if(pidf) {
                fprintf(pidf,"%d\n", (int)getpid());
@@ -1201,7 +1379,7 @@ void setup_serve(SERVER *serve) {
                err("fcntl F_GETFL");
        }
        if (fcntl(serve->socket, F_SETFL, sock_flags | O_NONBLOCK) == -1) {
-               err("fcntl F_SETFL O_NONBLOCK");
+               err("fcntl F_SETFL O_NONBLOCK on server socket");
        }
 
        DEBUG("Waiting for connections... bind, ");
@@ -1223,7 +1401,6 @@ void setup_serve(SERVER *serve) {
        sa.sa_flags = SA_RESTART;
        if(sigaction(SIGTERM, &sa, NULL) == -1)
                err("sigaction: %m");
-       children=g_hash_table_new_full(g_int_hash, g_int_equal, NULL, destroy_pid_t);
 }
 
 /**
@@ -1235,6 +1412,7 @@ void setup_servers(GArray* servers) {
        for(i=0;i<servers->len;i++) {
                setup_serve(&(g_array_index(servers, SERVER, i)));
        }
+       children=g_hash_table_new_full(g_int_hash, g_int_equal, NULL, destroy_pid_t);
 }
 
 /**
@@ -1278,12 +1456,20 @@ int serveloop(GArray* servers) {
                        for(i=0;i<servers->len;i++) {
                                serve=&(g_array_index(servers, SERVER, i));
                                if(FD_ISSET(serve->socket, &rset)) {
+                                       int sock_flags;
+
                                        if ((net=accept(serve->socket, (struct sockaddr *) &addrin, &addrinlen)) < 0)
                                                err("accept: %m");
 
                                        client = g_malloc(sizeof(CLIENT));
                                        client->server=serve;
                                        client->exportsize=OFFT_MAX;
+                                       /* Read the flags of the freshly accepted client
+                                        * socket, not those of the listening socket, so
+                                        * that only O_NONBLOCK is added to what it has */
+                                       if ((sock_flags = fcntl(net, F_GETFL, 0)) == -1) {
+                                               err("fcntl F_GETFL");
+                                       }
+                                       if (fcntl(net, F_SETFL, sock_flags | O_NONBLOCK) == -1) {
+                                               err("fcntl F_SETFL O_NONBLOCK on client socket");
+                                       }
                                        client->net=net;
                                        set_peername(net, client);
                                        if (!authorized_client(client)) {
@@ -1326,6 +1512,24 @@ int serveloop(GArray* servers) {
 }
 
 /**
+ * Set up user-ID and/or group-ID
+ **/
+void dousers(void) {
+       struct passwd *pw;
+       struct group *gr;
+
+       /* Change the group first: once setuid() has given up root
+        * privileges, setgid() is no longer permitted. */
+       if(rungroup) {
+               /* getgrnam() returns NULL when the group does not exist */
+               gr=getgrnam(rungroup);
+               if(!gr) {
+                       msg3(LOG_DEBUG, "Could not set GID: unknown group %s", rungroup);
+               } else if(setgid(gr->gr_gid)<0) {
+                       msg3(LOG_DEBUG, "Could not set GID: %s", strerror(errno));
+               }
+       }
+       if(runuser) {
+               /* getpwnam() returns NULL when the user does not exist */
+               pw=getpwnam(runuser);
+               if(!pw) {
+                       msg3(LOG_DEBUG, "Could not set UID: unknown user %s", runuser);
+               } else if(setuid(pw->pw_uid)<0) {
+                       msg3(LOG_DEBUG, "Could not set UID: %s", strerror(errno));
+               }
+       }
+}
+
+/**
  * Main entry point...
  **/
 int main(int argc, char *argv[]) {
@@ -1338,6 +1542,8 @@ int main(int argc, char *argv[]) {
                exit(-1) ;
        }
 
+       memset(pidftemplate, '\0', 256);
+
        logging();
        config_file_pos = g_strdup(CFILE);
        serve=cmdline(argc, argv);
@@ -1378,6 +1584,7 @@ int main(int argc, char *argv[]) {
        }
        daemonize(serve);
        setup_servers(servers);
+       dousers();
        serveloop(servers);
        return 0 ;
 }