Backport of pre-connect Hyper-V code
[freerdp-ubuntu-pcb-backport.git] / libfreerdp-core / nego.c
index 7eb810b..64bde26 100644 (file)
@@ -27,6 +27,8 @@
 
 #include "nego.h"
 
+#include "transport.h"
+
 static const char* const NEGO_STATE_STRINGS[] =
 {
        "NEGO_STATE_INITIAL",
@@ -44,6 +46,8 @@ static const char PROTOCOL_SECURITY_STRINGS[3][4] =
        "NLA"
 };
 
+boolean nego_security_connect(rdpNego* nego);
+
 /**
  * Negotiate protocol security and connect.
  * @param nego
@@ -61,7 +65,31 @@ boolean nego_connect(rdpNego* nego)
                else if (nego->enabled_protocols[PROTOCOL_RDP] > 0)
                        nego->state = NEGO_STATE_RDP;
                else
+               {
+                       DEBUG_NEGO("No security protocol is enabled");
                        nego->state = NEGO_STATE_FAIL;
+               }
+
+               if (!nego->security_layer_negotiation_enabled)
+               {
+                       DEBUG_NEGO("Security Layer Negotiation is disabled");
+                       nego->enabled_protocols[PROTOCOL_NLA] = 0;
+                       nego->enabled_protocols[PROTOCOL_TLS] = 0;
+                       nego->enabled_protocols[PROTOCOL_RDP] = 0;
+                       if(nego->state == NEGO_STATE_NLA)
+                               nego->enabled_protocols[PROTOCOL_NLA] = 1;
+                       else if (nego->state == NEGO_STATE_TLS)
+                               nego->enabled_protocols[PROTOCOL_TLS] = 1;
+                       else if (nego->state == NEGO_STATE_RDP)
+                               nego->enabled_protocols[PROTOCOL_RDP] = 1;
+               }
+
+               if(!nego_send_preconnection_pdu(nego))
+               {
+                       DEBUG_NEGO("Failed to send preconnection information");
+                       nego->state = NEGO_STATE_FINAL;
+                       return false;
+               }
        }
 
        do
@@ -93,9 +121,35 @@ boolean nego_connect(rdpNego* nego)
                nego->transport->settings->encryption_level = ENCRYPTION_LEVEL_CLIENT_COMPATIBLE;
        }
 
+       /* finally connect security layer (if not already done) */
+       if(!nego_security_connect(nego))
+       {
+               DEBUG_NEGO("Failed to connect with %s security", PROTOCOL_SECURITY_STRINGS[nego->selected_protocol]);
+               return false;
+       }
+
        return true;
 }
 
+/**
+ * Connect the security layer chosen for this connection on top of the
+ * TCP connection, unless it is already connected.
+ *
+ * - With security layer negotiation enabled, the protocol selected during
+ *   negotiation (nego->selected_protocol) is connected. Picking the first
+ *   locally enabled protocol instead would, for example, start NLA although
+ *   the server selected TLS or standard RDP security.
+ * - With negotiation disabled, the first enabled protocol is connected
+ *   (nego_connect() leaves at most one protocol enabled in that mode).
+ *
+ * @param nego pointer to the negotiation structure
+ * @return true if the security layer is connected, false otherwise
+ *         (always false while the TCP layer is not connected)
+ */
+boolean nego_security_connect(rdpNego* nego)
+{
+       if (!nego->tcp_connected)
+       {
+               nego->security_connected = false;
+       }
+       else if (!nego->security_connected)
+       {
+               if (nego->security_layer_negotiation_enabled)
+               {
+                       /* honour the server's choice, not our own preference order */
+                       if (nego->selected_protocol == PROTOCOL_NLA)
+                               nego->security_connected = transport_connect_nla(nego->transport);
+                       else if (nego->selected_protocol == PROTOCOL_TLS)
+                               nego->security_connected = transport_connect_tls(nego->transport);
+                       else if (nego->selected_protocol == PROTOCOL_RDP)
+                               nego->security_connected = transport_connect_rdp(nego->transport);
+               }
+               else
+               {
+                       /* direct mode: no negotiation result exists, use the configured protocol */
+                       if (nego->enabled_protocols[PROTOCOL_NLA] > 0)
+                               nego->security_connected = transport_connect_nla(nego->transport);
+                       else if (nego->enabled_protocols[PROTOCOL_TLS] > 0)
+                               nego->security_connected = transport_connect_tls(nego->transport);
+                       else if (nego->enabled_protocols[PROTOCOL_RDP] > 0)
+                               nego->security_connected = transport_connect_rdp(nego->transport);
+               }
+       }
+       return nego->security_connected;
+}
+
 /**
  * Connect TCP layer.
  * @param nego
@@ -104,21 +158,25 @@ boolean nego_connect(rdpNego* nego)
 
 boolean nego_tcp_connect(rdpNego* nego)
 {
-       if (nego->tcp_connected == 0)
-       {
-               if (transport_connect(nego->transport, nego->hostname, nego->port) == false)
-               {
-                       nego->tcp_connected = 0;
-                       return false;
-               }
-               else
-               {
-                       nego->tcp_connected = 1;
-                       return true;
-               }
-       }
+       if (!nego->tcp_connected)
+               nego->tcp_connected = transport_connect(nego->transport, nego->hostname, nego->port);
+       return nego->tcp_connected;
+}
 
-       return true;
+/**
+ * Open the TCP connection (if not open yet). In direct mode, i.e. with
+ * security layer negotiation disabled, the security layer is connected
+ * right away as well.
+ * @param nego pointer to the negotiation structure
+ * @return true if TCP (and, in direct mode, the security layer) is connected
+ */
+
+boolean nego_transport_connect(rdpNego* nego)
+{
+       boolean connected = nego_tcp_connect(nego);
+
+       if (!connected || nego->security_layer_negotiation_enabled)
+               return connected;
+
+       return nego_security_connect(nego);
 }
 
 /**
@@ -127,16 +185,66 @@ boolean nego_tcp_connect(rdpNego* nego)
  * @return
  */
 
-int nego_tcp_disconnect(rdpNego* nego)
+int nego_transport_disconnect(rdpNego* nego)
 {
+       /* Close the transport (only if the TCP connection is open) and clear both
+        * connection flags, so that the next nego_transport_connect() reconnects
+        * TCP and the security layer from scratch. Always returns 1. */
        if (nego->tcp_connected)
                transport_disconnect(nego->transport);
 
        nego->tcp_connected = 0;
+       /* the security layer cannot survive the loss of the TCP connection */
+       nego->security_connected = 0;
        return 1;
 }
 
 /**
+ * Send preconnection information if enabled.
+ * @param nego
+ * @return
+ */
+
+boolean nego_send_preconnection_pdu(rdpNego* nego)
+{
+       STREAM* s;
+       uint32 cbSize;
+       UNICONV* uniconv;
+       uint16 cchPCB_times2 = 0;
+       char* wszPCB = NULL;
+
+       if(!nego->send_preconnection_pdu)
+               return true;
+
+       DEBUG_NEGO("Sending preconnection PDU");
+       if(!nego_tcp_connect(nego))
+               return false;
+
+       /* it's easier to always send the version 2 PDU, and it's just 2 bytes overhead */
+       cbSize = PRECONNECTION_PDU_V2_MIN_SIZE;
+       if(nego->preconnection_blob) {
+               uniconv = freerdp_uniconv_new();
+               wszPCB = freerdp_uniconv_out(uniconv, nego->preconnection_blob, &cchPCB_times2);
+               freerdp_uniconv_free(uniconv);
+               cchPCB_times2 += 2; /* zero-termination */
+               cbSize += cchPCB_times2;
+       }
+
+       s = transport_send_stream_init(nego->transport, cbSize);
+       stream_write_uint32(s, cbSize); /* cbSize */
+       stream_write_uint32(s, 0); /* Flags */
+       stream_write_uint32(s, PRECONNECTION_PDU_V2); /* Version */
+       stream_write_uint32(s, nego->preconnection_id); /* Id */
+       stream_write_uint16(s, cchPCB_times2 / 2); /* cchPCB */
+       if(wszPCB)
+       {
+               stream_write(s, wszPCB, cchPCB_times2); /* wszPCB */
+               xfree(wszPCB);
+       }
+
+       if (transport_write(nego->transport, s) < 0)
+               return false;
+
+       return true;
+}
+
+/**
  * Attempt negotiating NLA + TLS security.
  * @param nego
  */
@@ -147,7 +255,7 @@ void nego_attempt_nla(rdpNego* nego)
 
        DEBUG_NEGO("Attempting NLA security");
 
-       if (!nego_tcp_connect(nego))
+       if (!nego_transport_connect(nego))
        {
                nego->state = NEGO_STATE_FAIL;
                return;
@@ -167,7 +275,7 @@ void nego_attempt_nla(rdpNego* nego)
 
        if (nego->state != NEGO_STATE_FINAL)
        {
-               nego_tcp_disconnect(nego);
+               nego_transport_disconnect(nego);
 
                if (nego->enabled_protocols[PROTOCOL_TLS] > 0)
                        nego->state = NEGO_STATE_TLS;
@@ -189,7 +297,7 @@ void nego_attempt_tls(rdpNego* nego)
 
        DEBUG_NEGO("Attempting TLS security");
 
-       if (!nego_tcp_connect(nego))
+       if (!nego_transport_connect(nego))
        {
                nego->state = NEGO_STATE_FAIL;
                return;
@@ -209,7 +317,7 @@ void nego_attempt_tls(rdpNego* nego)
 
        if (nego->state != NEGO_STATE_FINAL)
        {
-               nego_tcp_disconnect(nego);
+               nego_transport_disconnect(nego);
 
                if (nego->enabled_protocols[PROTOCOL_RDP] > 0)
                        nego->state = NEGO_STATE_RDP;
@@ -229,7 +337,7 @@ void nego_attempt_rdp(rdpNego* nego)
 
        DEBUG_NEGO("Attempting RDP security");
 
-       if (!nego_tcp_connect(nego))
+       if (!nego_transport_connect(nego))
        {
                nego->state = NEGO_STATE_FAIL;
                return;
@@ -258,7 +366,7 @@ boolean nego_recv_response(rdpNego* nego)
        STREAM* s = transport_recv_stream_init(nego->transport, 1024);
        if (transport_read(nego->transport, s) < 0)
                return false;
-       return nego_recv(nego->transport, s, nego->transport->recv_extra);
+       return nego_recv(nego->transport, s, nego);
 }
 
 /**
@@ -662,6 +770,18 @@ void nego_set_target(rdpNego* nego, char* hostname, int port)
 }
 
 /**
+ * Enable security layer negotiation.
+ * @param nego pointer to the negotiation structure
+ * @param enable_rdp whether to enable security layer negotiation (true for enabled, false for disabled)
+ */
+
+void nego_set_negotiation_enabled(rdpNego* nego, boolean security_layer_negotiation_enabled)
+{
+       DEBUG_NEGO("Enabling security layer negotiation: %s", security_layer_negotiation_enabled ? "true" : "false");
+       nego->security_layer_negotiation_enabled = security_layer_negotiation_enabled;
+}
+
+/**
  * Enable RDP security protocol.
  * @param nego pointer to the negotiation structure
  * @param enable_rdp whether to enable normal RDP protocol (true for enabled, false for disabled)
@@ -718,3 +838,36 @@ void nego_set_cookie(rdpNego* nego, char* cookie)
 {
        nego->cookie = cookie;
 }
+
+/**
+ * Choose whether nego_connect() sends a preconnection PDU
+ * (via nego_send_preconnection_pdu()) before the security negotiation starts.
+ * @param nego pointer to the negotiation structure
+ * @param send_pcpdu true to send the preconnection PDU, false to skip it
+ */
+
+void nego_set_send_preconnection_pdu(rdpNego* nego, boolean send_pcpdu)
+{
+       nego->send_preconnection_pdu = send_pcpdu;
+}
+
+/**
+ * Store the preconnection identifier.
+ * It is written to the Id field of the preconnection PDU.
+ * @param nego pointer to the negotiation structure
+ * @param id preconnection identifier
+ */
+
+void nego_set_preconnection_id(rdpNego* nego, uint32 id)
+{
+       nego->preconnection_id = id;
+}
+
+/**
+ * Store the preconnection blob string.
+ * Only the pointer is kept (the string is not copied), so the caller must
+ * keep it valid until the preconnection PDU has been sent. NULL means no blob.
+ * @param nego pointer to the negotiation structure
+ * @param blob preconnection blob, converted to UTF-16 when the PDU is sent
+ */
+
+void nego_set_preconnection_blob(rdpNego* nego, char* blob)
+{
+       nego->preconnection_blob = blob;
+}