r30: Winnbd code from Folkert van Heusden <folkert@vanheusden.com>
[nbd.git] / winnbd / nbdsrvr.cpp
#include <windows.h>
#include <stdio.h>
#include <stdlib.h>
3
/* TCP port the server listens on; assigned once in main() from argv[2]. */
int portnr;
/* Path of the file exported to clients; assigned once in main() from argv[1]
 * and opened separately by every connection thread. */
char *filename;
6
7 int READ(SOCKET sh, UCHAR *whereto, int howmuch)
8 {
9         int pnt = 0;
10
11 #ifdef _DEBUG
12         printf("read: %d bytes requested\n", howmuch);
13 #endif
14
15         while(howmuch > 0)
16         {
17                 int nread = recv(sh, (char *)&whereto[pnt], howmuch, 0);
18                 if (nread == 0)
19                         break;
20                 if (nread == SOCKET_ERROR)
21                 {
22                         fprintf(stderr, "Connection dropped. Error: %d\n", WSAGetLastError());
23                         break;
24                 }
25
26                 pnt += nread;
27                 howmuch -= nread;
28         }
29
30         return pnt;
31 }
32
33 int WRITE(SOCKET sh, UCHAR *wherefrom, int howmuch)
34 {
35         int pnt = 0;
36
37         while(howmuch > 0)
38         {
39                 int nwritten = send(sh, (char *)&wherefrom[pnt], howmuch, 0);
40                 if (nwritten == 0)
41                         break;
42                 if (nwritten == SOCKET_ERROR)
43                 {
44                         fprintf(stderr, "Connection dropped. Error: %d\n", WSAGetLastError());
45                         break;
46                 }
47
48                 pnt += nwritten;
49                 howmuch -= nwritten;
50         }
51
52         return pnt;
53 }
54
55 BOOL getu32(SOCKET sh, ULONG *val)
56 {
57         UCHAR buffer[4];
58
59         if (READ(sh, buffer, 4) != 4)
60                 return FALSE;
61
62         *val = (buffer[0] << 24) + (buffer[1] << 16) + (buffer[2] << 8) + (buffer[3]);
63
64         return TRUE;
65 }
66
67 BOOL putu32(SOCKET sh, ULONG value)
68 {
69         UCHAR buffer[4];
70
71         buffer[0] = (value >> 24) & 255;
72         buffer[1] = (value >> 16) & 255;
73         buffer[2] = (value >>  8) & 255;
74         buffer[3] = (value      ) & 255;
75
76         if (WRITE(sh, buffer, 4) != 4)
77                 return FALSE;
78         else
79                 return TRUE;
80 }
81
82 DWORD WINAPI draad(LPVOID data)
83 {
84         SOCKET sockh = (SOCKET)data;
85         HANDLE fh;
86         char neg = 1;
87
88         // open file 'filename'
89         fh = CreateFile(filename, GENERIC_READ | GENERIC_WRITE, FILE_SHARE_READ | FILE_SHARE_WRITE, NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
90         if (fh == INVALID_HANDLE_VALUE)
91         {
92                 fprintf(stderr, "Error opening file %s: %d\n", filename, GetLastError());
93         }
94
95         for(;fh != INVALID_HANDLE_VALUE;)
96         {
97                 UCHAR handle[9];
98                 ULONG magic, from, len, type, dummy;
99
100                 /* negotiating time? */
101                 if (neg)
102                 {
103                                 printf("Negotiating...\n");
104                                 if (WRITE(sockh, (unsigned char *)"NBDMAGIC", 8) != 8)
105                                 {
106                                         fprintf(stderr, "Failed to send magic string\n");
107                                         break;
108                                 }
109
110                                 // some other magic value
111                                 unsigned char magic[8];
112                                 magic[0] = 0x00;
113                                 magic[1] = 0x00;
114                                 magic[2] = 0x42;
115                                 magic[3] = 0x02;
116                                 magic[4] = 0x81;
117                                 magic[5] = 0x86;
118                                 magic[6] = 0x12;
119                                 magic[7] = 0x53;
120                                 if (WRITE(sockh, magic, 8) != 8)
121                                 {
122                                         fprintf(stderr, "Failed to send 2nd magic string\n");
123                                         break;
124                                 }
125
126                                 // send size of file
127                                 unsigned char exportsize[8];
128                                 DWORD fsize = GetFileSize(fh, NULL);
129                                 if (fsize == 0xFFFFFFFF)
130                                 {
131                                         fprintf(stderr, "Failed to get filesize. Error: %d\n", GetLastError());
132                                         break;
133                                 }
134                                 exportsize[7] = (fsize      ) & 255;
135                                 exportsize[6] = (fsize >>  8) & 255;
136                                 exportsize[5] = (fsize >> 16) & 255;
137                                 exportsize[4] = (fsize >> 24) & 255;
138                                 exportsize[3] = (fsize >> 32) & 255;
139                                 exportsize[2] = (fsize >> 40) & 255;
140                                 exportsize[1] = (fsize >> 48) & 255;
141                                 exportsize[0] = (fsize >> 56) & 255;
142 #ifdef _DEBUG
143                                 printf("File is %ld bytes\n", fsize);
144 #endif
145                                 if (WRITE(sockh, exportsize, 8) != 8)
146                                 {
147                                         fprintf(stderr, "Failed to send filesize\n");
148                                         break;
149                                 }
150                                 
151                                 // send a couple of zeros */
152                                 unsigned char buffer[128];
153                                 memset(buffer, 0x00, 128);
154                                 if (WRITE(sockh, buffer, 128) != 128)
155                                 {
156                                         fprintf(stderr, "Failed to send a couple of 0x00s\n");
157                                         break;
158                                 }
159
160                                 printf("Started!\n");
161                                 neg = 0;
162                 }
163
164                 if (getu32(sockh, &magic) == FALSE ||   // 0x12560953
165                         getu32(sockh, &type)  == FALSE ||       // 0=read,1=write
166                         READ(sockh, handle, 8) != 8    ||       // handle
167                         getu32(sockh, &dummy) == FALSE ||       // ... high word of offset
168                         getu32(sockh, &from)  == FALSE ||       // offset
169                         getu32(sockh, &len)   == FALSE)         // length
170                 {
171                         fprintf(stderr, "Failed to read from socket\n");
172                         break;
173                 }
174
175 #ifdef _DEBUG
176                 handle[8] = 0x00;
177                 printf("Magic:    %lx\n", magic);
178                 printf("Offset:   %ld\n", from);
179                 printf("Len:      %ld\n", len);
180                 printf("Handle:   %s\n", handle);
181                 printf("Req.type: %ld (%s)\n\n", type, type?"write":"read");
182 #endif
183
184                 // verify protocol
185                 if (magic != 0x25609513)
186                 {
187                         fprintf(stderr, "Unexpected protocol version! (got: %lx, expected: 0x25609513)\n", magic);
188                         break;
189                 }
190
191                 // seek to 'from'
192                 if (SetFilePointer(fh, from, NULL, FILE_BEGIN) == 0xFFFFFFFF)
193                 {
194                         fprintf(stderr, "Error seeking in file %s to position %d: %d\n", filename, from, GetLastError());
195                         break;
196                 }
197
198                 if (type == 1)  // write
199                 {
200                         while(len > 0)
201                         {
202                                 DWORD dummy;
203                                 UCHAR buffer[32768];
204                                 // read from socket
205                                 int nb = recv(sockh, (char *)buffer, min(len, 32768), 0);
206                                 if (nb == 0)
207                                         break;
208
209                                 // write to file;
210                                 if (WriteFile(fh, buffer, nb, &dummy, NULL) == 0)
211                                 {
212                                         fprintf(stderr, "Failed to write to %s: %d\n", filename, GetLastError());
213                                         break;
214                                 }
215                                 if (dummy != nb)
216                                 {
217                                         fprintf(stderr, "Failed to write to %s: %d (written: %d, requested to write: %d)\n", filename, GetLastError(), dummy, nb);
218                                         break;
219                                 }
220
221                                 len -= nb;
222                         }
223                         if (len)        // connection was closed
224                         {
225                                 fprintf(stderr, "Connection was dropped while receiving data\n");
226                                 break;
227                         }
228
229                         // send 'ack'
230                         if (putu32(sockh, 0x67446698) == FALSE ||
231                                 putu32(sockh, 0) == FALSE ||
232                                 WRITE(sockh, handle, 8) != 8)
233                         {
234                                 fprintf(stderr, "Failed to send through socket\n");
235                                 break;
236                         }
237                 }
238                 else if (type == 0)
239                 {
240                         // send 'ack'
241                         if (putu32(sockh, 0x67446698) == FALSE ||
242                                 putu32(sockh, 0) == FALSE ||
243                                 WRITE(sockh, handle, 8) != 8)
244                         {
245                                 fprintf(stderr, "Failed to send through socket\n");
246                                 break;
247                         }
248
249                         while(len > 0)
250                         {
251                                 DWORD dummy;
252                                 UCHAR buffer[32768];
253                                 int nb = min(len, 32768);
254                                 int pnt = 0;
255
256                                 // read nb to buffer;
257                                 if (ReadFile(fh, buffer, nb, &dummy, NULL) == 0)
258                                 {
259                                         fprintf(stderr, "Failed to read from %s: %d\n", filename, GetLastError());
260                                         break;
261                                 }
262                                 if (dummy != nb)
263                                 {
264                                         fprintf(stderr, "Failed to read from %s: %d\n", filename, GetLastError());
265                                         break;
266                                 }
267
268                                 // send through socket
269                                 if (WRITE(sockh, buffer, nb) != nb) // connection was closed
270                                 {
271                                         fprintf(stderr, "Connection dropped while sending block\n");
272                                         break;
273                                 }
274
275                                 len -= nb;
276                         }
277                         if (len)        // connection was closed
278                                 break;
279                 }
280                 else
281                 {
282                         printf("Unexpected commandtype: %d\n", type);
283                         break;
284                 }
285         }
286
287         // close file
288         if (CloseHandle(fh) == 0)
289         {
290                 fprintf(stderr, "Failed to close handle: %d\n", GetLastError());
291         }
292
293         closesocket(sockh);
294
295         ExitThread(0);
296
297         return 0;
298 }
299         
300 int main(int argc, char *argv[])
301 {
302         SOCKET newconnh;
303         WSADATA WSAData;
304
305         printf("nbdsrvr v0.1, (C) 2003 by folkert@vanheusden.com\n");
306
307         if (argc != 3)
308         {
309                 fprintf(stderr, "Usage: %s file portnr\n", argv[0]);
310                 return 1;
311         }
312         filename = argv[1];
313         portnr = atoi(argv[2]);
314
315         // initialize WinSock library
316         (void)WSAStartup(0x101, &WSAData); 
317         
318         // create listener socket
319         newconnh= socket(AF_INET, SOCK_STREAM, 0);
320         if (newconnh == INVALID_SOCKET)
321                 return -1;
322
323         // bind
324         struct sockaddr_in      ServerAddr;
325         int     ServerAddrLen;
326         ServerAddrLen = sizeof(ServerAddr);
327         memset((char *)&ServerAddr, '\0', ServerAddrLen);
328         ServerAddr.sin_family = AF_INET;
329         ServerAddr.sin_addr.s_addr = htonl(INADDR_ANY);
330         ServerAddr.sin_port = htons(portnr);
331         if (bind(newconnh, (struct sockaddr *)&ServerAddr, ServerAddrLen) == -1)
332                 return -1;
333
334         // listen
335         if (listen(newconnh, 5) == -1)
336                 return -1;
337
338         for(;;)
339         {
340                 SOCKET clienth;
341                 struct sockaddr_in      clientaddr;
342                 int     clientaddrlen;
343
344                 clientaddrlen = sizeof(clientaddr);
345
346                 /* accept a connection */
347                 clienth = accept(newconnh, (struct sockaddr *)&clientaddr, &clientaddrlen);
348
349                 if (clienth != INVALID_SOCKET)
350                 {
351                         printf("Connection made with %s\n", inet_ntoa(clientaddr.sin_addr));
352
353                         DWORD tid;
354                         HANDLE th = CreateThread(NULL, 0, draad, (void *)clienth, 0, &tid);
355                 }
356         }
357
358         return 0;
359 }