OpenSSL Sockets in C++ (part 4)

We've come to it at last: introducing OpenSSL to our sockets. The general structure we're going for is to add a make_secure() function to the ssl_socket class that initiates the SSL handshake. Once the handshake has completed, all of our read and write functions should call their OpenSSL counterparts instead.

First, in ssl_socket.h, we need to add a new include, two new fields, and two new functions.

#include <openssl/ssl.h>
SSL* ssl_handle;
SSL_CTX* ssl_context;
/**
 * Check to see if this socket is an encrypted socket
 */
bool is_secure() const { return ssl_handle != nullptr; }

/**
 * Perform the SSL handshake to switch all communications over
 * this socket from unencrypted to encrypted
 */
ssl_socket& make_secure();

We should also make sure to initialize the two new fields to nullptr in the constructor:

ssl_socket::ssl_socket(const std::string & _host, const std::string & _port):
    address_info(nullptr),
    connection(-1),
    host(_host),
    port(_port),
    ssl_handle(nullptr),
    ssl_context(nullptr)
{}

Now we need to write the function that does the handshake. First we need to create an OpenSSL context, which configures which version of SSL/TLS we wish to use. Next we need to create a handle, which plays much the same role as the file descriptor returned by the socket call earlier.

#include <openssl/err.h>
///...
namespace
{
    std::string get_ssl_error()
    {
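	// ERR_get_error() pops the earliest error from this thread's error queue;
	// passing a literal 0 would always produce a generic "no error" string
	return std::string(ERR_error_string(ERR_get_error(), nullptr));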
    }
}
///...
ssl_socket& ssl_socket::make_secure()
{
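    // NOTE: TLSv1_client_method() restricts the connection to TLS 1.0.
    // OpenSSL 1.1.0 and later deprecate it in favor of TLS_client_method(),
    // which negotiates the highest version both sides support.
    ssl_context = SSL_CTX_new(TLSv1_client_method());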
    if (ssl_context == nullptr)
    {
	throw ssl_socket_exception("Unable to create SSL context " + get_ssl_error());
    }

    // Create an SSL handle that we will use for reading and writing
    ssl_handle = SSL_new(ssl_context);
    if (ssl_handle == nullptr)
    {
	SSL_CTX_free(ssl_context);
	ssl_context = nullptr;
	throw ssl_socket_exception("Unable to create SSL handle " + get_ssl_error());
    }

    ///... more code goes here ...///

    return *this;
}

Now that we have a handle, we need to tie it together with the plain open socket we already have and initiate the handshake.

// Pair the SSL handle with the plain socket
if (!SSL_set_fd(ssl_handle, connection))
{
    SSL_free(ssl_handle);
    SSL_CTX_free(ssl_context);
    ssl_handle = nullptr;
    ssl_context = nullptr;
    throw ssl_socket_exception("Unable to associate SSL and plain socket " + get_ssl_error());
}

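// Finally do the SSL handshake, retrying while OpenSSL is still waiting
// on the underlying socket (the sleep needs <thread> and <chrono>)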
for (int error = SSL_connect(ssl_handle); error != 1; error = SSL_connect(ssl_handle))
{
    switch(SSL_get_error(ssl_handle, error))
    {
      case SSL_ERROR_WANT_READ:
      case SSL_ERROR_WANT_WRITE:
	std::this_thread::sleep_for(std::chrono::milliseconds(200));
	break;
      default:
	SSL_free(ssl_handle);
	SSL_CTX_free(ssl_context);
	ssl_handle = nullptr;
	ssl_context = nullptr;
	throw ssl_socket_exception("Error in SSL handshake: " + get_ssl_error());
	break;
    }
}
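The handshake loop above keeps retrying for as long as OpenSSL reports that it wants to read or write, so it never gives up on a peer that stalls. One way to bound it is sketched below using only the standard library. The names step_result and retry_with_limit are hypothetical helpers invented for this illustration; they are not part of OpenSSL or of the ssl_socket class. In the real code, the callable would presumably wrap SSL_connect and SSL_get_error.

```cpp
#include <chrono>
#include <functional>
#include <thread>

// Hypothetical outcome of one attempt at a non-blocking operation.
enum class step_result { done, retry, failed };

// Calls `step` until it reports done or failed, sleeping `delay` after
// each retry and giving up after `max_attempts` calls.
// Returns true only if `step` reported done.
inline bool retry_with_limit(const std::function<step_result()>& step,
                             int max_attempts,
                             std::chrono::milliseconds delay)
{
    for (int attempt = 0; attempt < max_attempts; ++attempt)
    {
        switch (step())
        {
          case step_result::done:
            return true;
          case step_result::failed:
            return false;
          case step_result::retry:
            std::this_thread::sleep_for(delay);
            break;
        }
    }
    return false; // ran out of attempts
}
```

A caller could then throw ssl_socket_exception when retry_with_limit returns false, instead of looping indefinitely.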

We also need to add more to the disconnect() function to free OpenSSL resources.

void ssl_socket::disconnect()
{
    if (ssl_handle != nullptr)
    {
	SSL_shutdown(ssl_handle);
	SSL_free(ssl_handle);
	ssl_handle = nullptr;
    }

    if (ssl_context != nullptr)
    {
	SSL_CTX_free(ssl_context);
	ssl_context = nullptr;
    }

    if (connection >= 0)
    {
	close(connection);
	connection = -1;
    }

    if (address_info != nullptr)
    {
	freeaddrinfo(address_info);
	address_info = nullptr;
    }
}
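disconnect() has to remember to free every resource by hand, and the error paths in make_secure() repeat the same cleanup. A possible alternative is to let std::unique_ptr with a custom deleter do the freeing. The sketch below uses a stand-in type fake_handle and a stand-in function fake_free, both invented here so that the example compiles without OpenSSL. With the real library the same pattern would presumably pair SSL with SSL_free and SSL_CTX with SSL_CTX_free.

```cpp
#include <memory>

// Stand-in for an opaque C handle such as SSL (hypothetical, for illustration).
struct fake_handle { int id; };

// Counts how many handles have been released, so the effect is observable.
inline int& released_count()
{
    static int count = 0;
    return count;
}

// Stand-in for a C-style free function such as SSL_free.
inline void fake_free(fake_handle* handle)
{
    ++released_count();
    delete handle;
}

// Deleter type that forwards to the C-style free function.
struct fake_handle_deleter
{
    void operator()(fake_handle* handle) const { fake_free(handle); }
};

using fake_handle_ptr = std::unique_ptr<fake_handle, fake_handle_deleter>;

// Creates a handle that releases itself when the owning pointer goes away.
inline fake_handle_ptr make_fake_handle(int id)
{
    return fake_handle_ptr(new fake_handle{id});
}
```

With members of this kind, an early throw would release whatever had already been created, and resetting an empty pointer would be a no-op.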

Now we have the initiation and the cleanup handled, so we need to write the SSL calls for read / write.

ssl_socket& ssl_socket::write(const uint8_t* data, size_t length)
{
    for (const uint8_t* current_position = data, * end = data + length; current_position < end; )
    {
	if (!is_secure())
	{
	    ///... unencrypted handler here ...///
	} else {
	    ssize_t sent = SSL_write(ssl_handle, current_position, end - current_position);
	    if (sent > 0)
	    {
		current_position += sent;
	    } else {
		switch(SSL_get_error(ssl_handle, sent))
		{
		  case SSL_ERROR_ZERO_RETURN: // The socket has been closed on the other end
		    disconnect();
		    throw ssl_socket_exception("The socket disconnected");
		    break;
		  case SSL_ERROR_WANT_READ:
		  case SSL_ERROR_WANT_WRITE:
		    std::this_thread::sleep_for(std::chrono::milliseconds(200));
		    break;
		  default:
		    throw ssl_socket_exception("Error sending socket: " + get_ssl_error());
		    break;
		}
	    }
	}
    }
    return *this;
}
size_t ssl_socket::read(void* buffer, size_t length)
{
    if (!is_secure())
    {
	///... unencrypted handler here ...///
    } else {
	ssize_t read_size = SSL_read(ssl_handle, buffer, length);
	if (read_size > 0)
	{
	    return read_size;
	} else {
	    switch(SSL_get_error(ssl_handle, read_size))
	    {
	      case SSL_ERROR_ZERO_RETURN: // The socket has been closed on the other end
		disconnect();
		return 0;
		break;
	      case SSL_ERROR_WANT_READ:
	      case SSL_ERROR_WANT_WRITE:
		return 0; // Read nothing
		break;
	      default:
		throw ssl_socket_exception("Error reading socket: " + get_ssl_error());
		break;
	    }
	}
    }
}
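The loop in write() exists because a single call may accept only part of the buffer. That idea can be shown in isolation with a small sketch. The names chunk_writer and write_all are hypothetical and exist only for this example; the callable stands in for SSL_write or the plain send call, and the error handling is simplified to a single exception.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <stdexcept>

// Hypothetical low-level writer: returns how many bytes it accepted
// (possibly fewer than requested), or a value <= 0 on failure.
using chunk_writer = std::function<long(const std::uint8_t*, std::size_t)>;

// Keeps calling `writer` until all `length` bytes have been accepted.
// Returns the number of calls that were needed.
inline std::size_t write_all(const chunk_writer& writer,
                             const std::uint8_t* data,
                             std::size_t length)
{
    std::size_t calls = 0;
    const std::uint8_t* current = data;
    const std::uint8_t* const end = data + length;
    while (current < end)
    {
        const long sent = writer(current, static_cast<std::size_t>(end - current));
        ++calls;
        if (sent <= 0)
        {
            throw std::runtime_error("writer failed before all data was sent");
        }
        current += sent; // advance past the bytes that were accepted
    }
    return calls;
}
```

Unlike the real write(), this sketch treats every non-positive return as fatal; the version above distinguishes "try again later" from a genuine error by asking SSL_get_error.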

Now our sockets themselves should be ready for use, though we're missing one key component: initializing and de-initializing the OpenSSL library. For this we're going to introduce a new class that initializes the OpenSSL library in its constructor and de-initializes it in its destructor. (OpenSSL 1.1.0 and later initialize and clean up the library automatically, so this class is only needed for older versions.)

/**
 * Initializes and de-initializes the OpenSSL library. This should
 * only be instantiated and destroyed once
 */
class openssl_init_handler
{
  public:
    openssl_init_handler()
    {
	SSL_load_error_strings();
	SSL_library_init();
    }
    ~openssl_init_handler()
    {
	ERR_remove_state(0);
	ERR_free_strings();
	EVP_cleanup();
	CRYPTO_cleanup_all_ex_data();
	sk_SSL_COMP_free(SSL_COMP_get_compression_methods());
    }
};

Now we can make it a static variable inside the connect() function, so that it is constructed the first time any socket connects (and therefore before any socket becomes encrypted) and destroyed at the end of the program.

ssl_socket& ssl_socket::connect()
{
    static openssl_init_handler _ssl_init_life;
    ///... the rest of the connect code here ...///
}
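The behavior being relied on here is that a function-local static is constructed the first time control passes through its declaration, and only then (since C++11 this is also thread-safe). A minimal sketch of that behavior follows; init_counter and do_connect are hypothetical stand-ins for openssl_init_handler and ssl_socket::connect(), and they count constructions instead of touching the real library.

```cpp
// Hypothetical stand-in for openssl_init_handler: counts constructions
// instead of initializing a real library.
class init_counter
{
  public:
    init_counter() { ++constructions(); }

    // Number of init_counter objects constructed so far.
    static int& constructions()
    {
        static int count = 0;
        return count;
    }
};

// Mimics a connect()-style function that needs one-time setup.
inline void do_connect()
{
    // Constructed on the first call only; later calls skip the constructor.
    static init_counter init_once;
    (void)init_once; // silence unused-variable warnings
}
```

However many times do_connect() is called, the constructor runs once, which is what makes the pattern suitable for library-wide initialization.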

Great! Now everything should be in place, so let's make two minor changes to main.cpp to point it at the https site. In the ssl_socket constructor, change "http" to "https", and add a call to make_secure() before writing to the socket.

int main(int argc, char** argv)
{
    try
    {
	ssl_socket s(HOST, "https");
	char buffer[BUFFER_SIZE];
	std::string http_query = "GET / HTTP/1.1\r\n"    \
	    "Host: " + std::string(HOST) + "\r\n\r\n";

	s.connect().make_secure().write(http_query);
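The request string in the excerpt above relies on adjacent string literals being joined before the + operator is applied. The same construction can be pulled into a small function, sketched here. build_get_request is a hypothetical helper name, not something from the original source.

```cpp
#include <string>

// Builds a minimal HTTP/1.1 GET request for "/" on `host`.
// Adjacent string literals are concatenated by the compiler, so the
// final "\r\n" "\r\n" pair becomes the blank line that ends the headers.
inline std::string build_get_request(const std::string& host)
{
    return "GET / HTTP/1.1\r\n"
           "Host: " + host + "\r\n"
           "\r\n";
}
```

With such a helper, the call in main would read s.connect().make_secure().write(build_get_request(HOST)).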

Now all that is left is compiling and testing:

$ premake4 gmake
Building configurations...
Running action 'gmake'...
Generating Makefile...
Generating sockets_part_4.make...
Done.
$ scan-build make
...
scan-build: No bugs found.
$ valgrind --leak-check=full ./sockets_part_4
...
definitely lost: 24 bytes in 1 blocks

Uh-oh! It seems there's some small part of SSL_library_init that I'm not freeing. If anyone out there knows what I'm missing, please drop me a line.

See how easy it is to use OpenSSL now that we've established a common interface between plain sockets and encrypted sockets? Now there's plenty of room to build on top of this to create interesting things (for example, auto-reconnecting sockets with their own auto-firing handshakes for protocols like IRC or XMPP). Go forth and create. The source for this post can be found here under the ISC license.
