changeset 128:21a5b37c6f35

Improved handshake: session resumed and multithreaded
author Sebastien Decugis <sdecugis@nict.go.jp>
date Mon, 18 Aug 2008 14:54:40 +0900
parents 6e655f5ae656
children f185f65e213f
files extensions/sec_tls_gnutls/gnutls_sctp_wrapper.c
diffstat 1 files changed, 215 insertions(+), 44 deletions(-) [+]
line wrap: on
line diff
--- a/extensions/sec_tls_gnutls/gnutls_sctp_wrapper.c	Fri Aug 15 17:53:04 2008 +0900
+++ b/extensions/sec_tls_gnutls/gnutls_sctp_wrapper.c	Mon Aug 18 14:54:40 2008 +0900
@@ -113,6 +113,13 @@
 	int		count;		/* The number of messages in the queue */
 } bq_t;
 
+/* To hold session state for session resuming */
+typedef struct {
+	char 		key[GNUTLS_MAX_SESSION_ID];
+	size_t 		keylen;
+	gnutls_datum_t 	data;
+} sr_t;
+
 /* forward declaration */
 struct _tls_sctpw_ctx;
 
@@ -122,8 +129,9 @@
 	struct _tls_sctpw_ctx	*ctx;	 /* Pointer to the head of the context object */
 	bq_t		 	 queue;	 /* buffer queue of the encrypted chunks of data received and demux'd */
 	gnutls_session_t 	 session;/* The gnutls session corresponding to this pair of streams. */
+	sr_t			*resume; /* pointer to data to resume a session */
 	pthread_t		 th;	 /* After handshake, the thread that receives and decrypts the data */
-	int			 th_run; /* Is the thread created? */
+	int			 th_st;  /* Is the thread created? 0: no; 1: yes, running; 2: yes, terminated */
 } pairinfo_t;
 
 /* State of the wrapper context */
@@ -496,19 +504,13 @@
 		if (ret < 0) {
 			/* The socket is no more valid */
 			TRACE_DEBUG(INFO, "tls_sctp_recv failed");
-			
-			/* we should signal a control thread here */
-			
-			return NULL;
+			goto error;
 		}
 		
 		if ((ret == 0) && ((event == RCV_ERROR) || (event == RCV_SHUTDOWN))) {
 			/* The socket is no more valid */
 			TRACE_DEBUG(INFO, "tls_sctp_recv received a termination notification");
-			
-			/* we should signal a control thread here */
-			
-			return NULL;
+			goto error;
 		}
 		
 		if (ret == 0)
@@ -524,16 +526,125 @@
 		reti = bq_post( &(ctx->pairs[new->buf.streamid].queue), new);
 		if (reti != 0) {
 			TRACE_DEBUG(INFO, "bq_post failed");
-			
-			/* we should signal a control thread here */
-			
-			return NULL;
+			goto error;
 		}
 	}
+error:
+				
+	/* we should signal a control thread here */
+	
+	ctx->hdr.state = SCTPW_CTX_BROKEN;
+
 	return NULL;
+
+}
+
+/* For session resuming in the server */
+static int sr_store (void *dbf, gnutls_datum_t key, gnutls_datum_t data)
+{
+	sr_t *sr = (sr_t *)dbf;
+	int ret = 0;
+	
+	/* We don't really do anything here, the session we are interested in is already stored. */
+	/* we just return 0 if it is the same data, and -1 otherwise */
+	
+	if ((key.size != sr->keylen) ||	memcmp (key.data, sr->key, key.size))
+		ret = -1;
+		
+	if ((data.size != sr->data.size) || memcmp (data.data, sr->data.data, data.size))
+		ret = -1;
+	
+	return ret;
+}
+static int sr_remove (void *dbf, gnutls_datum_t key)
+{
+	sr_t *sr = (sr_t *)dbf;
+	
+	/* Again, we don't do anything */
+	return -1;
+}
+static gnutls_datum_t sr_fetch (void *dbf, gnutls_datum_t key)
+{
+	gnutls_datum_t res = { NULL, 0 };
+	sr_t *sr = (sr_t *)dbf;
+	
+	/* This is the only db function we use, to return the session of pair[0] if the id matches */
+	
+	/* Check if the session matches the session of pair 0; we don't resume other sessions */
+	if ((key.size == sr->keylen) && memcmp (key.data, sr->key, key.size) == 0)
+	{
+		/* It's the same id, so we can actually resume */
+		res.size = sr->data.size;
+
+		res.data = gnutls_malloc (res.size);
+		if (res.data == NULL)
+			return res;
+
+		memcpy (res.data, sr->data.data, res.size);
+	}
+	return res;
 }
 
-/* Receive data from a pair of streams, once TLS handshake is done */
+/* Perform session resuming */
+static void * sctpw_sess_resume (void * arg)
+{
+	int ret = 0;
+	pairinfo_t * pair = (pairinfo_t *)arg;
+	ctx_hdr_t * hdr;
+
+	
+	hdr = &(pair->ctx->hdr);
+	
+	if (hdr->side) {
+		/* We are client side */
+		gnutls_transport_ptr_t ptr_bkp;
+		
+		/* Backup the transport pointer of the session */
+		ptr_bkp = gnutls_transport_get_ptr(pair->session);
+
+		/* Use the data from the session on stream 0 to perform resuming */
+		ret = gnutls_session_set_data(pair->session, pair->resume->data.data, pair->resume->data.size);
+		if (ret != GNUTLS_E_SUCCESS) {
+			TRACE_DEBUG(INFO, "gnutls_session_set_data failed on stream %d: %s", pair->stream, gnutls_strerror(ret));
+			goto error;
+		} else {
+			TRACE_DEBUG(FULL, "data set: %llx...(%d bytes)", *(long long *)pair->resume->data.data, pair->resume->data.size);
+		}
+		
+		/* Now restore the transport pointer */
+		gnutls_transport_set_ptr(pair->session, ptr_bkp);
+
+	} else {
+		/* We are server side */
+		
+		/* The resume will be trigged during handshake. We provide the "fetch" function for this */
+		gnutls_db_set_retrieve_function(pair->session, sr_fetch);
+		
+		/* The following functions are not useful but gnutls seems to need them... */
+		gnutls_db_set_remove_function (pair->session, sr_remove);
+		gnutls_db_set_store_function (pair->session, sr_store);
+		
+		/* The functions need this argument: */
+		gnutls_db_set_ptr (pair->session, (void *)pair->resume);
+	}
+	
+	/* And finaly perform the handshake */
+	TRACE_DEBUG(FULL, "Starting handshake on stream %d", pair->stream);
+	ret = gnutls_handshake( pair->session );
+	if (ret != GNUTLS_E_SUCCESS) {
+		TRACE_DEBUG(INFO, "Resumed handshake failed on stream %d: %s", pair->stream, gnutls_strerror(ret));
+		goto error;
+	}
+	
+	TRACE_DEBUG(FULL, "Handshake complete on stream %d, %s.", pair->stream, gnutls_session_is_resumed(pair->session) ? "session resumed" : "session NOT resumed");
+	
+	return (void *)0;
+error:
+	return (void *)-1;
+}
+	
+
+/* Receive data from a pair of streams, once TLS handshake is done.  One thread per pair of streams */
 #define BUF_SIZE	2048
 static void * sctpw_recv_data (void * arg)
 {
@@ -564,7 +675,7 @@
 		
 		ret = gnutls_record_recv ( pair->session, new->buf.data, BUF_SIZE );
 		if (ret <= 0) {
-			TRACE_DEBUG(INFO, "gnutls_record_recv failed (%d)", ret);
+			TRACE_DEBUG(INFO, "gnutls_record_recv (stream %d) failed (%d)", pair->stream, ret);
 			goto term;
 		}
 		
@@ -579,9 +690,11 @@
 		
 	}
 term:
-	/* We should signal a control thread here */
-	
-	hdr->state = SCTPW_CTX_BROKEN;
+	if (ret != 0) {
+		/* We should signal a control thread here... */
+		TRACE_DEBUG(FULL, "Receiving thread of pair %d terminating...", pair->stream);
+	}
+	pair->th_st = 2;
 	return NULL;
 }
 
@@ -602,8 +715,8 @@
 	_sctx = (tls_sctpw_ctx_t *) *sctx;
 	
 	if (graceful && (_sctx->hdr.state != SCTPW_CTX_OPEN)) {
-		TRACE_DEBUG(INFO, "Invalid parameter");
-		return EINVAL;
+		/* force destroy here */
+		graceful = 0;
 	}
 	
 	/* First, stop the demux thread */
@@ -633,11 +746,14 @@
 		}
 		
 		/* Terminate (cancel) the thread */
-		if (_sctx->pairs[i].th_run) {
-			ret = pthread_cancel(_sctx->pairs[i].th);
-			if (ret != 0) {
-				TRACE_DEBUG(INFO, "pthread_cancel failed: %s", strerror(ret));
-				return ret;
+		if (_sctx->pairs[i].th_st) {
+			
+			if (_sctx->pairs[i].th_st == 1) {
+				ret = pthread_cancel(_sctx->pairs[i].th);
+				if (ret != 0) {
+					TRACE_DEBUG(INFO, "pthread_cancel failed: %s", strerror(ret));
+					return ret;
+				}
 			}
 			
 			ret = pthread_join(_sctx->pairs[i].th, NULL);
@@ -646,7 +762,7 @@
 				return ret;
 			}
 			
-			_sctx->pairs[i].th_run = 0;
+			_sctx->pairs[i].th_st = 0;
 			
 			/* Destroy the gnutls session */
 			gnutls_deinit (_sctx->pairs[i].session);
@@ -780,7 +896,8 @@
 int tls_sctpw_handshake( tls_sctpw_ctx * sctx, int (*checkcb)(gnutls_session_t session, void * parm), void * parm )
 {
 	int ret = 0;
-	int i;
+	int i, errs=0;
+	sr_t resume_data;
 	
 	tls_sctpw_ctx_t * _sctx = (tls_sctpw_ctx_t *) sctx;
 	
@@ -791,25 +908,79 @@
 		return EINVAL;
 	}
 	
-	/* For each pair of stream  -- we could create separate threads to process all handshakes in parallel */
-	for ( i= 0; i < _sctx->hdr.nbstrm; i++) {
-		/* handshake */
-		ret = gnutls_handshake( _sctx->pairs[i].session );
-		if (ret < 0) {
-			TRACE_DEBUG(INFO, "Handshake failed on stream %d: %s", i, gnutls_strerror(ret));
+	/* First, perform a full handshake on the first pair of streams */
+	ret = gnutls_handshake( _sctx->pairs[0].session );
+	if (ret != GNUTLS_E_SUCCESS) {
+		TRACE_DEBUG(INFO, "Full handshake failed on stream 0: %s", gnutls_strerror(ret));
+		return ret;
+	}
+
+	TRACE_DEBUG(FULL, "Handskake complete on stream pair 0");
+
+	if (checkcb != NULL) {
+		/* Now verifying the credentials */
+		ret = (*checkcb)(_sctx->pairs[0].session, parm);
+		if (ret != 0) {
+			TRACE_DEBUG(INFO, "The callback to verify the credentials returned an error: %d", ret);
 			return ret;
 		}
-		
-		TRACE_DEBUG(FULL, "Handskake complete on stream pair %d", i);
+	}
+	
+	/* Ok, stream 0 is ready for use. Now retrieve the session data to perform session resuming */
+	ret = gnutls_session_get_data2(_sctx->pairs[0].session, &resume_data.data);
+	if (ret != GNUTLS_E_SUCCESS) {
+		TRACE_DEBUG(INFO, "gnutls_session_get_data2 failed: %s", gnutls_strerror(ret));
+		return ret;
+	}
+	/* We also need the session id (for the server) */
+	resume_data.keylen = GNUTLS_MAX_SESSION_ID;
+	ret = gnutls_session_get_id(_sctx->pairs[0].session, resume_data.key, &resume_data.keylen);
+	if (ret != GNUTLS_E_SUCCESS) {
+		TRACE_DEBUG(INFO, "gnutls_session_get_id failed: %s", gnutls_strerror(ret));
+		return ret;
+	}
+	
+	/* Resume the session on all other stream pairs, in parallel */
+	for ( i= 1; i < _sctx->hdr.nbstrm; i++) {
+		_sctx->pairs[i].resume = &resume_data;
 		
-		if (checkcb != NULL) {
-			/* Now verifying the credentials */
-			ret = (*checkcb)(_sctx->pairs[i].session, parm);
-			if (ret != 0) {
-				TRACE_DEBUG(INFO, "The callback to verify the credentials returned an error: %d", ret);
-				return ret;
-			}
+		/* create the thread that resumes the session */
+		ret = pthread_create( &(_sctx->pairs[i].th), NULL, sctpw_sess_resume, (void *)&(_sctx->pairs[i]));
+		if (ret != 0) {
+			TRACE_DEBUG(INFO, "pthread_create failed: %s", strerror(ret));
+			return ret;
+		}
+		_sctx->pairs[i].th_st = 1;
+	}
+	
+	/* Wait for the process to complete */
+	for ( i= 1; i < _sctx->hdr.nbstrm; i++) {
+		void * th_ret;
+		
+		/* create the thread that resumes the session */
+		ret = pthread_join(_sctx->pairs[i].th, &th_ret);
+		if (ret != 0) {
+			TRACE_DEBUG(INFO, "pthread_join failed: %s", strerror(ret));
+			return ret;
 		}
+		_sctx->pairs[i].th_st = 0;
+		
+		/* Check return value of the exchange */
+		if ( th_ret != (void *)0 )
+			errs ++ ;
+	}
+	
+	/* Deallocate the resources of session resuming */
+	gnutls_free(resume_data.data.data);
+	
+	/* If we had some errors: return an error */
+	if (errs != 0) {
+		TRACE_DEBUG(INFO, "Session could not be resumed on %d pairs of streams", errs);
+		return ECONNABORTED;
+	}
+	
+	/* For each pair of stream, now create the receiver thread */
+	for ( i= 0; i < _sctx->hdr.nbstrm; i++) {
 	
 		/* create the thread that receives the data */
 		ret = pthread_create( &(_sctx->pairs[i].th), NULL, sctpw_recv_data, (void *)&(_sctx->pairs[i]));
@@ -817,7 +988,7 @@
 			TRACE_DEBUG(INFO, "pthread_create failed: %s", strerror(ret));
 			return ret;
 		}
-		_sctx->pairs[i].th_run = 1;
+		_sctx->pairs[i].th_st = 1;
 	}
 	
 	/* Mark the state as open now */
@@ -1219,7 +1390,7 @@
 	TRACE_ENTRY("%d %p", argc, argv);
 	
 	/* Initialize gnutls */
- 	// gcry_control (GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread);
+ 	gcry_control (GCRYCTL_SET_THREAD_CBS, &gcry_threads_pthread);
 	gnutls_global_init ();
 	gnutls_global_set_log_function(my_gnutls_log);
 	gnutls_global_set_log_level(0);
"Welcome to our mercurial repository"