1 /++
2 Internal - Low-level communications.
3 
4 Consider this module the main entry point for the low-level MySQL/MariaDB
5 protocol code. The other modules in `mysql.protocol` are mainly tools
6 to support this module.
7 
8 Previously, the code handling low-level protocol details was scattered all
9 across the library. Such functionality has been factored out into this module,
10 to be kept in one place for better encapsulation and to facilitate further
11 cleanup and refactoring.
12 
EXPECT MAJOR CHANGES to this entire `mysql.protocol` sub-package until it
settles into what will eventually become a low-level library
containing the bulk of the MySQL/MariaDB-specific code. Hang on tight...
16 
17 Next tasks for this sub-package's cleanup:
18 - Reduce this module's reliance on Connection.
19 - Abstract out a PacketStream to clean up getPacket and related functionality.
20 +/
21 module mysql.protocol.comms;
22 
23 import std.algorithm;
24 import std.array;
25 import std.conv;
26 import std.digest.sha;
27 import std.exception;
28 import std.range;
29 import std.variant;
30 
31 import mysql.connection;
32 import mysql.exceptions;
33 import mysql.logger;
34 import mysql.prepared;
35 import mysql.result;
36 
37 import mysql.protocol.constants;
38 import mysql.protocol.extra_types;
39 import mysql.protocol.packet_helpers;
40 import mysql.protocol.packets;
41 import mysql.protocol.sockets;
42 
/// Low-level comms code relating to prepared statements.
package struct ProtocolPrepared
{
	import std.conv;
	import std.datetime;
	import std.variant;
	import mysql.types;

	/++
	Builds the NULL bitmap of a COM_STMT_EXECUTE packet.

	Bit `i % 8` of byte `i / 8` is set when `inParams[i]` holds `null`;
	the bitmap is `(inParams.length + 7) / 8` bytes long.
	+/
	static ubyte[] makeBitmap(in Variant[] inParams)
	{
		auto bitmap = new ubyte[(inParams.length + 7) / 8];
		foreach (i; 0 .. inParams.length)
		{
			if (inParams[i].type == typeid(typeof(null)))
				bitmap[i / 8] |= cast(ubyte)(1 << (i % 8));
		}
		return bitmap;
	}

	/++
	Builds the fixed 14-byte start of a COM_STMT_EXECUTE packet.

	Bytes 0..4 are left zeroed for the packet header, which sendCommand fills
	in with setPacketHeader once the whole packet has been assembled. They are
	followed by the command byte, the 4-byte statement handle, the flags byte
	and a 4-byte iteration count of 1.
	+/
	static ubyte[] makePSPrefix(uint hStmt, ubyte flags = 0) pure nothrow
	{
		auto prefix = new ubyte[14];
		prefix[4] = CommandType.STMT_EXECUTE;
		hStmt.packInto(prefix[5..9]);
		prefix[9] = flags; // flags, no cursor
		prefix[10] = 1;    // iteration count (bytes 10..14) - currently always 1
		return prefix;
	}

	// Appends the little-endian encoding of `value` to `sink`.
	// Numeric values are sent little-endian regardless of the host's byte
	// order, so copying the value's raw memory is not portable.
	private static void putLE(T)(ref Appender!(ubyte[]) sink, T value)
	{
		import std.bitmanip : nativeToLittleEndian;
		ubyte[T.sizeof] bytes = nativeToLittleEndian(value);
		sink.put(bytes[]);
	}

	/++
	Encodes the bound parameter values of a COM_STMT_EXECUTE packet.

	Params:
		inParams = the bound parameter values.
		psa      = per-parameter specializations, indexed in step with `inParams`.
		           `type` overrides the SQL type inferred for bool, string and
		           byte-array values; a non-zero `chunkSize` marks long data.
		vals     = receives the concatenated encodings of all non-null values.
		longData = set to true if any parameter has a non-zero `chunkSize`.

	Returns: the parameter-type block: two bytes per parameter, the SQL type
	followed by a signedness byte (0x80 for unsigned types, 0 otherwise).

	Throws: MYX if a parameter is unbound or its type is not supported.
	+/
	static ubyte[] analyseParams(Variant[] inParams, ParameterSpecialization[] psa,
		out ubyte[] vals, out bool longData)
	{
		enum ubyte UNSIGNED = 0x80;
		enum ubyte SIGNED   = 0;

		auto types = new ubyte[inParams.length * 2];
		size_t ct = 0;

		// Records the two type-block bytes for the current parameter.
		void putType(SQLType sqlType, ubyte signedness)
		{
			types[ct++] = cast(ubyte) sqlType;
			types[ct++] = signedness;
		}

		auto valBuf = appender!(ubyte[])();
		valBuf.reserve(inParams.length * 20);

		foreach (size_t i; 0 .. inParams.length)
		{
			if (psa[i].chunkSize)
				longData = true;

			if (inParams[i].type == typeid(typeof(null)))
			{
				// NULLs are flagged in the bitmap and carry no value bytes.
				putType(SQLType.NULL, SIGNED);
				continue;
			}

			Variant v = inParams[i];
			SQLType ext = psa[i].type;
			string ts = v.type.toString();

			// A trailing '*' means the Variant holds a pointer to the value.
			bool isRef;
			if (ts[$-1] == '*')
			{
				ts = ts[0 .. $-1];
				isRef = true;
			}

			switch (ts)
			{
				case "bool":
				case "const(bool)":
				case "immutable(bool)":
				case "shared(immutable(bool))":
					putType(ext == SQLType.INFER_FROM_D_TYPE ? SQLType.BIT : ext, SIGNED);
					bool bv = isRef? *(v.get!(const(bool*))): v.get!(const(bool));
					// A length byte of 1 followed by ASCII '1' or '0'.
					valBuf.put(cast(ubyte) 1);
					valBuf.put(cast(ubyte) (bv ? '1' : '0'));
					break;
				case "byte":
				case "const(byte)":
				case "immutable(byte)":
				case "shared(immutable(byte))":
					putType(SQLType.TINY, SIGNED);
					valBuf.put(cast(ubyte) (isRef? *(v.get!(const(byte*))): v.get!(const(byte))));
					break;
				case "ubyte":
				case "const(ubyte)":
				case "immutable(ubyte)":
				case "shared(immutable(ubyte))":
					putType(SQLType.TINY, UNSIGNED);
					valBuf.put(cast(ubyte) (isRef? *(v.get!(const(ubyte*))): v.get!(const(ubyte))));
					break;
				case "short":
				case "const(short)":
				case "immutable(short)":
				case "shared(immutable(short))":
					putType(SQLType.SHORT, SIGNED);
					putLE(valBuf, isRef? *(v.get!(const(short*))): v.get!(const(short)));
					break;
				case "ushort":
				case "const(ushort)":
				case "immutable(ushort)":
				case "shared(immutable(ushort))":
					putType(SQLType.SHORT, UNSIGNED);
					putLE(valBuf, isRef? *(v.get!(const(ushort*))): v.get!(const(ushort)));
					break;
				case "int":
				case "const(int)":
				case "immutable(int)":
				case "shared(immutable(int))":
					putType(SQLType.INT, SIGNED);
					putLE(valBuf, isRef? *(v.get!(const(int*))): v.get!(const(int)));
					break;
				case "uint":
				case "const(uint)":
				case "immutable(uint)":
				case "shared(immutable(uint))":
					putType(SQLType.INT, UNSIGNED);
					putLE(valBuf, isRef? *(v.get!(const(uint*))): v.get!(const(uint)));
					break;
				case "long":
				case "const(long)":
				case "immutable(long)":
				case "shared(immutable(long))":
					putType(SQLType.LONGLONG, SIGNED);
					putLE(valBuf, isRef? *(v.get!(const(long*))): v.get!(const(long)));
					break;
				case "ulong":
				case "const(ulong)":
				case "immutable(ulong)":
				case "shared(immutable(ulong))":
					putType(SQLType.LONGLONG, UNSIGNED);
					putLE(valBuf, isRef? *(v.get!(const(ulong*))): v.get!(const(ulong)));
					break;
				case "float":
				case "const(float)":
				case "immutable(float)":
				case "shared(immutable(float))":
					putType(SQLType.FLOAT, SIGNED);
					putLE(valBuf, isRef? *(v.get!(const(float*))): v.get!(const(float)));
					break;
				case "double":
				case "const(double)":
				case "immutable(double)":
				case "shared(immutable(double))":
					putType(SQLType.DOUBLE, SIGNED);
					putLE(valBuf, isRef? *(v.get!(const(double*))): v.get!(const(double)));
					break;
				case "std.datetime.date.Date":
				case "const(std.datetime.date.Date)":
				case "immutable(std.datetime.date.Date)":
				case "shared(immutable(std.datetime.date.Date))":

				case "std.datetime.Date":
				case "const(std.datetime.Date)":
				case "immutable(std.datetime.Date)":
				case "shared(immutable(std.datetime.Date))":
					putType(SQLType.DATE, SIGNED);
					Date date = isRef? *(v.get!(const(Date*))): v.get!(const(Date));
					valBuf.put(pack(date));
					break;
				case "std.datetime.TimeOfDay":
				case "const(std.datetime.TimeOfDay)":
				case "immutable(std.datetime.TimeOfDay)":
				case "shared(immutable(std.datetime.TimeOfDay))":

				case "std.datetime.date.TimeOfDay":
				case "const(std.datetime.date.TimeOfDay)":
				case "immutable(std.datetime.date.TimeOfDay)":
				case "shared(immutable(std.datetime.date.TimeOfDay))":

				case "std.datetime.Time":
				case "const(std.datetime.Time)":
				case "immutable(std.datetime.Time)":
				case "shared(immutable(std.datetime.Time))":
					putType(SQLType.TIME, SIGNED);
					TimeOfDay time = isRef? *(v.get!(const(TimeOfDay*))): v.get!(const(TimeOfDay));
					valBuf.put(pack(time));
					break;
				case "std.datetime.date.DateTime":
				case "const(std.datetime.date.DateTime)":
				case "immutable(std.datetime.date.DateTime)":
				case "shared(immutable(std.datetime.date.DateTime))":

				case "std.datetime.DateTime":
				case "const(std.datetime.DateTime)":
				case "immutable(std.datetime.DateTime)":
				case "shared(immutable(std.datetime.DateTime))":
					putType(SQLType.DATETIME, SIGNED);
					DateTime dt = isRef? *(v.get!(const(DateTime*))): v.get!(const(DateTime));
					valBuf.put(pack(dt));
					break;
				case "mysql.types.Timestamp":
				case "const(mysql.types.Timestamp)":
				case "immutable(mysql.types.Timestamp)":
				case "shared(immutable(mysql.types.Timestamp))":
					putType(SQLType.TIMESTAMP, SIGNED);
					Timestamp tms = isRef? *(v.get!(const(Timestamp*))): v.get!(const(Timestamp));
					DateTime dt = mysql.protocol.packet_helpers.toDateTime(tms.rep);
					valBuf.put(pack(dt));
					break;
				case "char[]":
				case "const(char[])":
				case "immutable(char[])":
				case "const(char)[]":
				case "immutable(char)[]":
				case "shared(immutable(char)[])":
				case "shared(immutable(char))[]":
				case "shared(immutable(char[]))":
					putType(ext == SQLType.INFER_FROM_D_TYPE ? SQLType.VARCHAR : ext, SIGNED);
					const char[] ca = isRef? *(v.get!(const(char[]*))): v.get!(const(char[]));
					valBuf.put(packLCS(cast(void[]) ca));
					break;
				case "byte[]":
				case "const(byte[])":
				case "immutable(byte[])":
				case "const(byte)[]":
				case "immutable(byte)[]":
				case "shared(immutable(byte)[])":
				case "shared(immutable(byte))[]":
				case "shared(immutable(byte[]))":
					putType(ext == SQLType.INFER_FROM_D_TYPE ? SQLType.TINYBLOB : ext, SIGNED);
					const byte[] ba = isRef? *(v.get!(const(byte[]*))): v.get!(const(byte[]));
					valBuf.put(packLCS(cast(void[]) ba));
					break;
				case "ubyte[]":
				case "const(ubyte[])":
				case "immutable(ubyte[])":
				case "const(ubyte)[]":
				case "immutable(ubyte)[]":
				case "shared(immutable(ubyte)[])":
				case "shared(immutable(ubyte))[]":
				case "shared(immutable(ubyte[]))":
					putType(ext == SQLType.INFER_FROM_D_TYPE ? SQLType.TINYBLOB : ext, SIGNED);
					const ubyte[] uba = isRef? *(v.get!(const(ubyte[]*))): v.get!(const(ubyte[]));
					valBuf.put(packLCS(cast(void[]) uba));
					break;
				case "void":
					throw new MYX("Unbound parameter " ~ to!string(i), __FILE__, __LINE__);
				default:
					throw new MYX("Unsupported parameter type " ~ ts, __FILE__, __LINE__);
			}
		}
		vals = valBuf.data;
		return types;
	}

	/++
	Sends the data of every parameter that has a non-zero `chunkSize`, using
	STMT_SEND_LONG_DATA packets.

	Each parameter's `chunkDelegate` is called repeatedly to fill a chunk's
	payload. A return value smaller than `chunkSize` marks the final chunk,
	and 0 means there is nothing more to send. Every chunk is framed as a
	separate command with packet number 0.
	+/
	static void sendLongData(MySQLSocket socket, uint hStmt, ParameterSpecialization[] psa)
	{
		assert(psa.length <= ushort.max); // parameter number is sent as short
		foreach (size_t i, PSN psn; psa)
		{
			if (!psn.chunkSize) continue;
			uint cs = psn.chunkSize;
			uint delegate(ubyte[]) dg = psn.chunkDelegate;

			// 4 header bytes + command byte + 4-byte statement handle +
			// 2-byte parameter number, followed by up to `cs` payload bytes.
			ubyte[] chunk;
			chunk.length = cs+11;
			chunk.setPacketHeader(0 /*each chunk is separate cmd*/);
			chunk[4] = CommandType.STMT_SEND_LONG_DATA;
			hStmt.packInto(chunk[5..9]); // statement handle
			packInto(cast(ushort)i, chunk[9..11]); // parameter number

			// byte 11 on is payload
			for (;;)
			{
				uint sent = dg(chunk[11..cs+11]);
				if (sent < cs)
				{
					if (sent == 0)    // data was exact multiple of chunk size - all sent
						break;
					chunk.length = chunk.length - (cs-sent);     // trim the chunk
					sent += 7;        // adjust for non-payload bytes
					packInto!(uint, true)(cast(uint)sent, chunk[0..3]);
					socket.send(chunk);
					break;
				}
				socket.send(chunk);
			}
		}
	}

	/++
	Sends COM_STMT_EXECUTE for statement `hStmt` with the given parameters.

	Any long-data parameters (non-zero chunkSize) are streamed first with
	sendLongData. The execute packet itself is sent last.
	+/
	static void sendCommand(Connection conn, uint hStmt, PreparedStmtHeaders psh,
		Variant[] inParams, ParameterSpecialization[] psa)
	{
		conn.autoPurge();

		ubyte[] packet;
		conn.resetPacket();

		ubyte[] prefix = makePSPrefix(hStmt, 0);
		bool longData;

		if (psh.paramCount)
		{
			ubyte[] one = [ 1 ]; // new-params-bound flag: the type block follows
			ubyte[] vals;
			ubyte[] types = analyseParams(inParams, psa, vals, longData);
			ubyte[] nbm = makeBitmap(inParams);
			packet = prefix ~ nbm ~ one ~ types ~ vals;
		}
		else
			packet = prefix;

		if (longData)
			sendLongData(conn._socket, hStmt, psa);

		assert(packet.length <= uint.max);
		packet.setPacketHeader(conn.pktNumber);
		conn.bumpPacket();
		conn._socket.send(packet);
	}
}
484 
/// Everything execQueryImpl needs to send one command: either a plain SQL
/// string, or a prepared statement handle plus its bound parameters.
package(mysql) struct ExecQueryImplInfo
{
	/// true: execute the prepared statement described by the fields below;
	/// false: send `sql` as a plain text QUERY command.
	bool isPrepared;

	// For non-prepared statements:
	/// SQL text sent with CommandType.QUERY.
	const(char[]) sql;

	// For prepared statements:
	/// Server-side statement handle.
	uint hStmt;
	/// Headers of the prepared statement; `paramCount` decides whether parameters are sent.
	PreparedStmtHeaders psh;
	/// Bound parameter values.
	Variant[] inParams;
	/// Per-parameter specializations (SQL type override, long-data chunking).
	ParameterSpecialization[] psa;
}
498 
/++
Internal implementation for the exec and query functions.

Execute a one-off SQL command. This is either `info.sql` as a plain text query
or, when `info.isPrepared` is set, the prepared statement `info.hStmt` with the
parameters in `info.inParams`/`info.psa`.

Any result set can be accessed via Connection.getNextRow(), but you should really be
using the query function for such queries.

Params:
	conn = the connection to use. It is killed if anything throws.
	info = describes the command to execute.
	ra = An out parameter to receive the number of rows affected
	     (0 when the command produced a result set).

Returns: true if there was a (possibly empty) result set.

Throws: MYXProtocol if the response packet is empty or is not a valid result
set header, or whatever `enforcePacketOK` throws when the server answers with
an error packet.
+/
package(mysql) bool execQueryImpl(Connection conn, ExecQueryImplInfo info, out ulong ra)
{
	scope(failure) conn.kill();

	// Send data
	if(info.isPrepared)
	{
		logTrace("prepared SQL: %s", info.hStmt);

		ProtocolPrepared.sendCommand(conn, info.hStmt, info.psh, info.inParams, info.psa);
	}
	else
	{
		logTrace("exec query: %s", info.sql);

		conn.sendCmd(CommandType.QUERY, info.sql);
		conn._fieldCount = 0;
	}

	// Handle response.
	// Server data is validated with enforce rather than assert, so a malformed
	// response is still rejected in release builds.
	ubyte[] packet = conn.getPacket();
	enforce!MYXProtocol(packet.length > 0, "Empty response packet");

	if (packet.front == ResultPacketMarker.ok || packet.front == ResultPacketMarker.error)
	{
		conn.resetPacket();
		auto okp = OKErrorPacket(packet);

		if(okp.error) {
			logError("packet error: %s", cast(string) okp.message);
		}

		enforcePacketOK(okp);
		ra = okp.affected;
		conn._serverStatus = okp.serverStatus;
		conn._insertID = okp.insertID;
		return false;
	}

	// Otherwise this is a result set header holding the field count as a
	// length-encoded integer. Counts above 250 use a multi-byte encoding, so
	// validate the decoded value rather than the first byte.
	conn._headersPending = conn._rowsPending = true;
	conn._binaryPending = info.isPrepared;
	auto lcb = packet.consumeIfComplete!LCB();
	enforce!MYXProtocol(!lcb.isNull && !lcb.isIncomplete,
		"Malformed result set header: missing field count");
	enforce!MYXProtocol(lcb.value > 0 && lcb.value <= ushort.max,
		"Invalid field count in result set header");
	conn._fieldCount = cast(ushort)lcb.value;
	ra = 0;
	return true;
}
563 
///ditto
package(mysql) bool execQueryImpl(Connection conn, ExecQueryImplInfo info)
{
	// Convenience overload: the affected-rows count is received and discarded.
	ulong ignoredRowsAffected;
	return execQueryImpl(conn, info, ignoredRowsAffected);
}
570 
/++
Closes the server-side prepared statement `statementId` (STMT_CLOSE).

Does nothing if the connection is already closed.

Any pending result set is purged *before* the close packet is framed and the
packet counter is advanced. Purging presumably reads the remaining result
packets through getPacket, which checks each packet's sequence number against
`conn.pktNumber`. Bumping the counter first would therefore make those packets
look out of order.

The connection is killed on any failure.
+/
package(mysql) void immediateReleasePrepared(Connection conn, uint statementId)
{
	scope(failure) conn.kill();

	if(conn.closed())
		return;

	// Drain any outstanding result set first, then start a fresh packet
	// sequence for the STMT_CLOSE command.
	conn.purgeResult();
	conn.resetPacket();

	ubyte[9] packet_buf;
	ubyte[] packet = packet_buf;
	packet.setPacketHeader(conn.pktNumber);
	conn.bumpPacket();
	packet[4] = CommandType.STMT_CLOSE;
	statementId.packInto(packet[5..9]);
	conn._socket.send(packet);
	// It seems that the server does not find it necessary to send a response
	// for this command.
}
589 
// Moved here from `struct Row`
/++
Consumes the NULL bitmap of a binary result row from the front of `packet`
and decodes it into one flag per field (true = NULL).

Throws: MYXProtocol if `packet` is shorter than the bitmap.
+/
package(mysql) bool[] consumeNullBitmap(ref ubyte[] packet, uint fieldCount) pure
{
	immutable uint needed = calcBitmapLength(fieldCount);
	enforce!MYXProtocol(packet.length >= needed, "Packet too small to hold null bitmap for all fields");
	return decodeNullBitmap(packet.consume(needed), fieldCount);
}
598 
// Moved here from `struct Row`
// Number of bytes needed for a binary-row NULL bitmap covering `fieldCount`
// fields plus the two reserved leading bits, rounded up to whole bytes.
private static uint calcBitmapLength(uint fieldCount) pure nothrow
{
	enum reservedBits = 2;
	return (fieldCount + reservedBits + 7) / 8;
}
604 
// Moved here from `struct Row`
// Decodes the NULL bitmap of a binary result row. Field i is NULL when bit
// (i + 2) is set, counting from the least significant bit of the first byte.
// The first two bits are reserved and skipped.
private bool[] decodeNullBitmap(ubyte[] bitmap, uint numFields) pure nothrow
in
{
	assert(bitmap.length >= calcBitmapLength(numFields),
		"bitmap not large enough to store all null fields");
}
out(result)
{
	assert(result.length == numFields);
}
do
{
	enum reservedBits = 2;

	auto nulls = new bool[numFields];
	foreach (fieldIndex, ref isNull; nulls)
	{
		immutable bitPos = fieldIndex + reservedBits;
		isNull = ((bitmap[bitPos / 8] >> (bitPos % 8)) & 1) != 0;
	}
	return nulls;
}
649 
// Moved here from `struct Row.this`
/++
Decodes one result-set row packet into per-field values, NULL flags and names.

Params:
	conn = used to fetch further packets when a value continues past the end
	       of `packet`. It is killed if anything throws.
	packet = the row packet. It is consumed as values are decoded and extended
	         with further packets when a value is incomplete.
	rh = descriptions of the result set's fields.
	binary = true for a binary-protocol (prepared statement) row, false for a text row.
	_values = receives one value per field (null for NULL fields).
	_nulls = receives one NULL flag per field.
	_names = receives the name of each field.

Throws: MYXProtocol if a binary row is empty or does not start with the 0x00 header byte.
+/
package(mysql) void ctorRow(Connection conn, ref ubyte[] packet, ResultSetHeaders rh, bool binary,
	out Variant[] _values, out bool[] _nulls, out string[] _names)
in
{
	assert(rh.fieldCount <= uint.max);
}
do
{
	scope(failure) conn.kill();

	uint fieldCount = cast(uint)rh.fieldCount;
	_values.length = _nulls.length = _names.length = fieldCount;

	if(binary)
	{
		// There's a null byte header on a binary result sequence, followed by some bytes of bitmap
		// indicating which columns are null
		enforce!MYXProtocol(packet.length > 0 && packet.front == 0,
			"Expected null header byte for binary result row");
		packet.popFront();
		_nulls = consumeNullBitmap(packet, fieldCount);
	}

	foreach(size_t i; 0..fieldCount)
	{
		FieldDescription fd = rh[i];
		_names[i] = fd.name;

		if(binary && _nulls[i])
		{
			_values[i] = null;
			continue;
		}

		// Keep appending packets until the value is complete.
		// TODO: Support chunk delegate
		SQLValue sqlValue = packet.consumeIfComplete(fd.type, binary, fd.unsigned, fd.charSet);
		while(sqlValue.isIncomplete)
		{
			packet ~= conn.getPacket();
			sqlValue = packet.consumeIfComplete(fd.type, binary, fd.unsigned, fd.charSet);
		}
		assert(!sqlValue.isIncomplete);

		if(sqlValue.isNull)
		{
			// Only text rows encode NULL inline; binary NULLs come from the bitmap.
			assert(!binary);
			assert(!_nulls[i]);
			_nulls[i] = true;
			_values[i] = null;
		}
		else
		{
			_values[i] = sqlValue.value;
		}
	}
}
707 
708 ////// Moved here from Connection /////////////////////////////////
709 
/++
Reads one packet from the server and returns its payload.

The 4-byte header holds a 24-bit little-endian payload length followed by a
sequence byte. The sequence byte must equal `conn.pktNumber`, which is then
advanced before the payload is read.

Throws: MYXProtocol if the packet arrives out of order.
The connection is killed on any failure.
+/
package(mysql) ubyte[] getPacket(Connection conn)
{
	scope(failure) conn.kill();

	ubyte[4] header;
	conn._socket.read(header);

	// number of bytes always set as 24-bit
	immutable uint payloadLength = header[0] | (header[1] << 8) | (header[2] << 16);
	immutable ubyte sequenceByte = header[3];
	enforce!MYXProtocol(sequenceByte == conn.pktNumber, "Server packet out of order");
	conn.bumpPacket();

	auto payload = new ubyte[payloadLength];
	conn._socket.read(payload);
	assert(payload.length == payloadLength, "Wrong number of bytes read");
	return payload;
}
726 
/// Writes an already-framed packet (header included) to the socket.
package(mysql) void send(MySQLSocket _socket, const(ubyte)[] packet)
in
{
	// A packet must carry at least one byte beyond its 4-byte header.
	assert(4 < packet.length);
}
do
{
	_socket.write(packet[]);
}
736 
/// Writes a packet given as a separate header (4 bytes, or 5 when the command
/// byte is included) and payload. An empty payload is not written at all.
package(mysql) void send(MySQLSocket _socket, const(ubyte)[] header, const(ubyte)[] data)
in
{
	assert(header.length >= 4 && header.length <= 5/*command type included*/);
}
do
{
	_socket.write(header);
	if(data.length == 0)
		return;
	_socket.write(data);
}
748 
/++
Sends command `cmd` with payload `data` to the server as a single packet.

If the socket is no longer connected, the connection is reopened first,
except for QUIT, which then returns without sending anything. Otherwise
this runs `conn.autoPurge()`, resets the packet counter and sends the
command as the first packet of a new sequence.

The connection is killed on any failure.
+/
package(mysql) void sendCmd(T)(Connection conn, CommandType cmd, const(T)[] data)
in
{
	// Internal thread states. Clients shouldn't use this
	assert(cmd != CommandType.SLEEP);
	assert(cmd != CommandType.CONNECT);
	assert(cmd != CommandType.TIME);
	assert(cmd != CommandType.DELAYED_INSERT);
	assert(cmd != CommandType.CONNECT_OUT);

	// Deprecated
	assert(cmd != CommandType.CREATE_DB);
	assert(cmd != CommandType.DROP_DB);
	assert(cmd != CommandType.TABLE_DUMP);

	// cannot send more than uint.max bytes. TODO: better error message if we try?
	assert(data.length <= uint.max);
}
out
{
	// At this point we should have sent a command. The one exception is the
	// early return for QUIT on a socket that is not connected: nothing is sent
	// and the packet counter is left untouched.
	assert(conn.pktNumber == 1 || (cmd == CommandType.QUIT && !conn._socket.connected));
}
do
{
	scope(failure) conn.kill();

	conn._lastCommandID++;

	if(!conn._socket.connected)
	{
		if(cmd == CommandType.QUIT)
			return; // Don't bother reopening connection just to quit

		conn._open = Connection.OpenState.notConnected;
		conn.connect(conn._clientCapabilities);
	}

	conn.autoPurge();

	conn.resetPacket();

	ubyte[] header;
	header.length = 4 /*header*/ + 1 /*cmd*/;
	header.setPacketHeader(conn.pktNumber, cast(uint)data.length +1/*cmd byte*/);
	header[4] = cmd;
	conn.bumpPacket();

	conn._socket.send(header, cast(const(ubyte)[])data);
}
799 
/++
Reads the server's response to a command, requires it to be an OK packet and
records the server status it reports.

Params: asString = not used by this implementation.
Throws: whatever `enforcePacketOK` throws for a non-OK response.
+/
package(mysql) OKErrorPacket getCmdResponse(Connection conn, bool asString = false)
{
	ubyte[] response = conn.getPacket();
	auto result = OKErrorPacket(response);
	enforcePacketOK(result);
	conn._serverStatus = result.serverStatus;
	return result;
}
807 
/++
Builds the client's authentication (handshake response) packet.

After the 4-byte packet header the packet contains:
- the client capability flags (4 bytes),
- the max packet size field (4 bytes; this code sends 1),
- the default collation for the server version (1 byte),
- 23 zero bytes,
- the user name, \0-terminated,
- the auth token, length-prefixed (a single 0 byte if the account has no password),
- the default database, \0-terminated.

The header uses `conn.pktNumber`, which is then bumped.

Params: token = the 20-byte scrambled password (see makeToken).
+/
package(mysql) ubyte[] buildAuthPacket(Connection conn, ubyte[] token)
in
{
	assert(token.length == 20);
}
do
{
	ubyte[] packet;
	packet.reserve(4/*header*/ + 4 + 4 + 1 + 23 + conn._user.length+1 + token.length+1 + conn._db.length+1);

	// Room for the header (written last, once the size is known), the
	// capability flags and the max packet size.
	packet.length = 4 + 4 + 4;
	conn._cCaps.packInto(packet[4..8]);
	1.packInto(packet[8..12]);

	packet ~= getDefaultCollation(conn._serverVersion);

	// Statutory block of 23 zero bytes; growing the length zero-fills.
	packet.length += 23;

	packet ~= cast(const(ubyte)[]) conn._user;
	packet ~= 0; // \0

	assert(token.length <= ubyte.max);
	if(conn._pwd.length == 0)
	{
		packet ~= 0; // Omit the token if the account has no password
	}
	else
	{
		packet ~= cast(ubyte)token.length;
		packet ~= token;
	}

	packet ~= cast(const(ubyte)[]) conn._db;
	packet ~= 0; // \0

	// The server sent us a greeting with packet number 0, so we send the auth packet
	// back with the next number.
	packet.setPacketHeader(conn.pktNumber);
	conn.bumpPacket();
	return packet;
}
860 
/++
Computes the password authentication token from `password` and the server's
scramble buffer `authBuf`:

    SHA1(password) XOR SHA1(authBuf ~ SHA1(SHA1(password)))

Returns: a newly allocated 20-byte array.
+/
package(mysql) ubyte[] makeToken(string password, ubyte[] authBuf)
{
	const ubyte[20] stage1 = sha1Of(cast(const(ubyte)[]) password);
	const ubyte[20] stage2 = sha1Of(stage1[]);

	// sha1Of digests its arguments as one concatenated message.
	ubyte[20] token = sha1Of(authBuf, stage2[]);
	token[] ^= stage1[];
	return token.dup;
}
875 
/++
Get the next `mysql.result.Row` of a pending result set.

The result set's field headers are read first if they are still pending.
When the terminating EOF packet is read, the rows-pending and binary-pending
flags are cleared and a default-initialized `Row` is returned.

Throws: MYXReceived if the server sends an error packet.
The connection is killed on any failure.
+/
package(mysql) Row getNextRow(Connection conn)
{
	scope(failure) conn.kill();

	if (conn._headersPending)
	{
		conn._rsh = ResultSetHeaders(conn, conn._fieldCount);
		conn._headersPending = false;
	}

	ubyte[] packet = conn.getPacket();
	if(packet.front == ResultPacketMarker.error)
		throw new MYXReceived(OKErrorPacket(packet), __FILE__, __LINE__);

	if (packet.isEOFPacket())
	{
		conn._rowsPending = conn._binaryPending = false;
		return Row.init;
	}

	// Binary rows come from prepared statements, text rows from plain queries.
	return Row(conn, packet, conn._rsh, conn._binaryPending);
}
904 
/++
Reads the server capability flags, character set and status from the server
greeting into `conn`, consuming them from the front of `packet`.

On the wire the capability flags are split into two 16-bit halves, with the
character set and status in between. OLD_LONG_PASSWORD is always added to the
server capabilities.

Throws: MYX if the server lacks PROTOCOL41 or SECURE_CONNECTION support.
The connection is killed on any failure.
+/
package(mysql) void consumeServerInfo(Connection conn, ref ubyte[] packet)
{
	scope(failure) conn.kill();

	conn._sCaps = cast(SvrCapFlags)packet.consume!ushort(); // server_capabilities (lower bytes)
	conn._sCharSet = packet.consume!ubyte();                // server_language
	conn._serverStatus = packet.consume!ushort();           // server_status

	// server_capabilities (upper bytes). The two halves never overlap, so
	// OR-ing them in is the same as adding them.
	immutable uint upperCaps = packet.consume!ushort();
	conn._sCaps |= cast(SvrCapFlags)(upperCaps << 16);

	// Assumed to be set since v4.1.1, according to spec
	conn._sCaps |= SvrCapFlags.OLD_LONG_PASSWORD;

	enforce!MYX(conn._sCaps & SvrCapFlags.PROTOCOL41, "Server doesn't support protocol v4.1");
	enforce!MYX(conn._sCaps & SvrCapFlags.SECURE_CONNECTION, "Server doesn't support protocol v4.1 connection");
}
918 
/++
Reads and parses the server's initial greeting packet.

Stores the protocol version, server version string, thread id, capability
flags, character set and server status in `conn`. Returns the auth scramble
buffer: the first 8 bytes followed by the rest of the scramble, without its
terminating \0. This is the buffer makeToken expects.

Throws: MYX if the server sent an error packet instead of a greeting or lacks
required capabilities; MYXProtocol if the greeting is malformed.
The connection is killed on any failure.
+/
package(mysql) ubyte[] parseGreeting(Connection conn)
{
	scope(failure) conn.kill();

	ubyte[] packet = conn.getPacket();

	if (packet.length > 0 && packet[0] == ResultPacketMarker.error)
	{
		auto okp = OKErrorPacket(packet);
		enforce!MYX(!okp.error, "Connection failure: " ~ cast(string) okp.message);
	}

	conn._protocol = packet.consume!ubyte();

	// The server version is \0-terminated. countUntil returns -1 when there is
	// no terminator, which must not be passed on as a (huge) length.
	auto versionLen = packet.countUntil(0);
	enforce!MYXProtocol(versionLen >= 0, "Server version in greeting is not \\0-terminated");
	conn._serverVersion = packet.consume!string(versionLen);
	packet.skip(1); // \0 terminated _serverVersion

	conn._sThread = packet.consume!uint();

	// read first part of scramble buf
	ubyte[] authBuf = packet.consume(8).dup; // scramble_buff

	enforce!MYXProtocol(packet.consume!ubyte() == 0, "filler should always be 0");

	conn.consumeServerInfo(packet);

	packet.skip(1); // this byte supposed to be scramble length, but is actually zero
	packet.skip(10); // filler of \0

	// rest of the scramble, \0-terminated
	auto len = packet.countUntil(0);
	enforce!MYXProtocol(len >= 12, "second part of scramble buffer should be at least 12 bytes");
	authBuf ~= packet.consume(len);
	enforce!MYXProtocol(packet.consume!ubyte() == 0, "Expected \\0 terminating scramble buf");

	return authBuf;
}
960 
/++
Returns the capability flags supported by both the server and the client,
i.e. the bitwise AND of the two sets across all 32 flag bits.

The previous loop ran `uint.sizeof` (4) times instead of once per bit, so only
the four lowest flags were ever negotiated. Any higher capability the client
asked for was silently dropped.
+/
package(mysql) SvrCapFlags getCommonCapabilities(SvrCapFlags server, SvrCapFlags client) pure
{
	return cast(SvrCapFlags)(server & client);
}
975 
/++
Computes the client capability flags to use for a connection.

The result is the flags common to `serverCaps` and `capFlags`, plus PROTOCOL41
and SECURE_CONNECTION. Those two are always set because this library cannot
operate with the pre-4.1 protocol.
+/
package(mysql) SvrCapFlags setClientFlags(SvrCapFlags serverCaps, SvrCapFlags capFlags)
{
	SvrCapFlags negotiated = getCommonCapabilities(serverCaps, capFlags);

	// Forced even if the user didn't supply them.
	negotiated |= SvrCapFlags.PROTOCOL41 | SvrCapFlags.SECURE_CONNECTION;

	return negotiated;
}
987 
/++
Authenticates `conn` using the scramble buffer from the server greeting.

Sends the packet built by buildAuthPacket (with the token from makeToken) and
reads the server's reply. On success the connection is marked authenticated.

Throws: MYX, after logging the error, if the reply is an error packet.
+/
package(mysql) void authenticate(Connection conn, ubyte[] greeting)
in
{
	assert(conn._open == Connection.OpenState.connected);
}
out
{
	assert(conn._open == Connection.OpenState.authenticated);
}
do
{
	conn._socket.send(conn.buildAuthPacket(makeToken(conn._pwd, greeting)));

	auto reply = OKErrorPacket(conn.getPacket());
	if (reply.error)
		logError("Authentication failure: %s", cast(string) reply.message);
	enforce!MYX(!reply.error, "Authentication failure: " ~ cast(string) reply.message);

	conn._open = Connection.OpenState.authenticated;
}
1013 
/++
Register (prepare) an SQL statement on the server.

Sends `STMT_PREPARE` with `sql` and parses the reply. The reply is handled as
follows:
- OK reply: the returned `PreparedServerInfo` holds the statement id, the
  parameter count, the warning count, and the parameter/column field headers.
  `conn._fieldCount` is set to the column count from the reply.
- Error reply: handed to `enforcePacketOK`.
- Empty or unrecognised reply: an `MYXProtocol` is thrown.

On any failure, `conn.kill()` is called.

Params:
	conn = The connection to prepare the statement on.
	sql = The SQL text to prepare.

Returns: Server-side information about the prepared statement.

Throws: Whatever `enforcePacketOK` throws for an error reply; `MYXProtocol`
for an empty or unrecognised reply.
+/
package(mysql) PreparedServerInfo performRegister(Connection conn, const(char[]) sql)
{
	scope(failure) conn.kill();

	PreparedServerInfo info;

	conn.sendCmd(CommandType.STMT_PREPARE, sql);
	conn._fieldCount = 0;

	ubyte[] packet = conn.getPacket();
	// Guard before reading `front`, so an empty reply is a protocol error, not a range error.
	enforce!MYXProtocol(!packet.empty, "Empty response to STMT_PREPARE command");

	if(packet.front == ResultPacketMarker.ok)
	{
		packet.popFront();
		info.statementId    = packet.consume!int();
		conn._fieldCount    = packet.consume!short();
		info.numParams      = packet.consume!short();

		packet.popFront(); // one byte filler
		info.psWarnings     = packet.consume!short();

		// At this point the server also sends field specs for parameters
		// and columns if there were any of each
		info.headers = PreparedStmtHeaders(conn, conn._fieldCount, info.numParams);
	}
	else if(packet.front == ResultPacketMarker.error)
	{
		auto error = OKErrorPacket(packet);
		enforcePacketOK(error);
		// enforcePacketOK is expected to have thrown for an error packet. If it
		// did not, report a protocol error instead of halting via assert(0)
		// (which becomes a bare halt in -release builds).
		logCritical("Unexpected failure: %s", cast(string) error.message);
		throw new MYXProtocol("Error packet in response to STMT_PREPARE was not reported as an error",
			__FILE__, __LINE__);
	}
	else
	{
		throw new MYXProtocol(
			text("Unexpected response to STMT_PREPARE command (first byte ", packet.front, ")"),
			__FILE__, __LINE__);
	}

	return info;
}
1051 
1052 /++
1053 Flush any outstanding result set elements.
1054 
1055 When the server responds to a command that produces a result set, it
1056 queues the whole set of corresponding packets over the current connection.
1057 Before that `Connection` can embark on any new command, it must receive
1058 all of those packets and junk them.
1059 
1060 As of v1.1.4, this is done automatically as needed. But you can still
1061 call this manually to force a purge to occur when you want.
1062 
1063 See_Also: $(LINK http://www.mysqlperformanceblog.com/2007/07/08/mysql-net_write_timeout-vs-wait_timeout-and-protocol-notes/)
1064 +/
1065 package(mysql) ulong purgeResult(Connection conn)
1066 {
1067 	scope(failure) conn.kill();
1068 
1069 	conn._lastCommandID++;
1070 
1071 	ulong rows = 0;
1072 	if (conn._headersPending)
1073 	{
1074 		for (size_t i = 0;; i++)
1075 		{
1076 			if (conn.getPacket().isEOFPacket())
1077 			{
1078 				conn._headersPending = false;
1079 				break;
1080 			}
1081 			enforce!MYXProtocol(i < conn._fieldCount,
1082 				text("Field header count (", conn._fieldCount, ") exceeded but no EOF packet found."));
1083 		}
1084 	}
1085 	if (conn._rowsPending)
1086 	{
1087 		for (;;  rows++)
1088 		{
1089 			if (conn.getPacket().isEOFPacket())
1090 			{
1091 				conn._rowsPending = conn._binaryPending = false;
1092 				break;
1093 			}
1094 		}
1095 	}
1096 	conn.resetPacket();
1097 	return rows;
1098 }
1099 
1100 /++
1101 Get a textual report on the server status.
1102 
1103 (COM_STATISTICS)
1104 +/
1105 package(mysql) string serverStats(Connection conn)
1106 {
1107 	conn.sendCmd(CommandType.STATISTICS, []);
1108 	return cast(string) conn.getPacket();
1109 }
1110 
1111 /++
1112 Enable multiple statement commands.
1113 
1114 This can be used later if this feature was not requested in the client capability flags.
1115 
1116 Warning: This functionality is currently untested.
1117 
1118 Params:
1119 	conn = The connection.
1120 	on = Boolean value to turn the capability on or off.
1121 +/
1122 //TODO: Need to test this
1123 package(mysql) void enableMultiStatements(Connection conn, bool on)
1124 {
1125 	scope(failure) conn.kill();
1126 
1127 	ubyte[] t;
1128 	t.length = 2;
1129 	t[0] = on ? 0 : 1;
1130 	t[1] = 0;
1131 	conn.sendCmd(CommandType.STMT_OPTION, t);
1132 
1133 	// For some reason this command gets an EOF packet as response
1134 	auto packet = conn.getPacket();
1135 	enforce!MYXProtocol(packet[0] == 254 && packet.length == 5, "Unexpected response to SET_OPTION command");
1136 }
1137 
/++
Choose the default connection collation for a server version string.

MySQL >= 5.5.3 supports utf8mb4. Servers reporting a version at or above
5.5.3 get `utf8mb4_general_ci` (45); older servers get `utf8_general_ci` (33).

The version is compared lexicographically on its first three dot-separated
components (major, minor, patch). Only the leading digits of each component
are read, so suffixes such as "-log" are ignored, and strings like
"5.5.5-10.0.7-MariaDB" are handled. Missing or non-numeric components count
as 0.

Params:
	serverVersion = The version string reported by the server.

Returns: The collation id to use as the connection default.

Throws: `ConvOverflowException` if a numeric component does not fit in a `ushort`.
+/
private ubyte getDefaultCollation(string serverVersion)
{
	import std.ascii : isDigit;

	enum ubyte utf8GeneralCi = 33;
	enum ubyte utf8mb4GeneralCi = 45;
	static immutable ushort[3] firstUtf8mb4Version = [5, 5, 3];

	ushort[3] ver; // zero-initialised: absent components compare as 0
	size_t idx = 0;
	foreach (part; serverVersion.splitter('.'))
	{
		if (idx == ver.length)
			break;

		// Take only the leading run of digits, e.g. "5-10" -> "5", "20-log" -> "20".
		size_t n = 0;
		while (n < part.length && part[n].isDigit)
			n++;
		if (n > 0)
			ver[idx] = part[0 .. n].to!ushort;
		idx++;
	}

	// Lexicographic comparison: 8.0.x and 5.6.x are newer than 5.5.3 even
	// though an individual component is smaller.
	return ver[] < firstUtf8mb4Version[] ? utf8GeneralCi : utf8mb4GeneralCi;
}
1155 
unittest
{
	// Each case pairs a server version string with the expected default collation id.
	static struct Case
	{
		string serverVersion;
		ubyte expected;
	}

	immutable Case[] cases = [
		Case("5.5.3", 45),
		Case("5.5.2", 33),
		// MariaDB: https://mariadb.com/kb/en/connection/#initial-handshake-packet
		Case("5.5.5-10.0.7-MariaDB", 45),
	];

	foreach (c; cases)
		assert(getDefaultCollation(c.serverVersion) == c.expected);
}