MEDIUM: ssl: Add ktls support for AWS-LC.

Add ktls support for AWS-LC. As it does not know anything
about ktls, it means extracting keys from the ssl lib, and provide them
to the kernel. At which point we can use regular recvmsg()/sendmsg()
calls.
This patch only provides support for TLS 1.2, AWS-LC provides a
different way to extract keys for TLS 1.3.
Note that this may work with BoringSSL too, but it has not been tested.
This commit is contained in:
Olivier Houchard 2025-06-19 18:44:22 +02:00 committed by Olivier Houchard
parent a903004a1a
commit 5c8fa50966
2 changed files with 278 additions and 4 deletions

View File

@ -578,6 +578,11 @@ static inline unsigned long ERR_peek_error_func(const char **func)
#endif /* HAVE_VANILLA_OPENSSL && OPENSSL_VERSION_NUMBER >= 0x3000000fL */
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
#include <openssl/hkdf.h>
#define HA_USE_KTLS
#endif /* OPENSSL_IS_BORINGSSL || OPENSSL_IS_AWSLC */
#endif /* USE_KTLS */
#endif /* _HAPROXY_OPENSSL_COMPAT_H */

View File

@ -6031,6 +6031,196 @@ static int ssl_remove_xprt(struct connection *conn, void *xprt_ctx, void *toremo
return (ctx->xprt->remove_xprt(conn, ctx->xprt_ctx, toremove_ctx, newops, newctx));
}
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
static void ssl_sock_setup_ktls(struct ssl_sock_ctx *ctx)
{
struct kinfo {
struct tls_crypto_info info;
/*
* Should be enough for key + iv + salt + seq for
* every cipher.
*/
unsigned char buf[68];
} info;
struct {
int nid;
int tls_cipher;
int key_size;
int salt_size;
int iv_size;
int seq_size;
} known_ciphers[] = {
#ifdef TLS_CIPHER_AES_GCM_128
{ NID_aes_128_gcm, TLS_CIPHER_AES_GCM_128, TLS_CIPHER_AES_GCM_128_KEY_SIZE, TLS_CIPHER_AES_GCM_128_SALT_SIZE, TLS_CIPHER_AES_GCM_128_IV_SIZE, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE },
#endif
#ifdef TLS_CIPHER_AES_GCM_256
{ NID_aes_256_gcm, TLS_CIPHER_AES_GCM_256, TLS_CIPHER_AES_GCM_256_KEY_SIZE, TLS_CIPHER_AES_GCM_256_SALT_SIZE, TLS_CIPHER_AES_GCM_256_IV_SIZE, TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE },
#endif
#ifdef TLS_CIPHER_AES_CCM_128
{ NID_aes_128_ccm, TLS_CIPHER_AES_CCM_128, TLS_CIPHER_AES_CCM_128_KEY_SIZE, TLS_CIPHER_AES_CCM_128_SALT_SIZE, TLS_CIPHER_AES_CCM_128_IV_SIZE, TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE },
#endif
#ifdef TLS_CIPHER_CHACHA20_POLY1305
{ NID_chacha20_poly1305, TLS_CIPHER_CHACHA20_POLY1305, TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE, TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE, TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE, TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE },
#endif
#if defined(TLS_CIPHER_SM4_GCM) && defined(NID_sm4_gcm)
{ NID_sm4_gcm, TLS_CIPHER_SM4_GCM, TLS_CIPHER_SM4_GCM_KEY_SIZE,cTLS_CIPHER_SM4_GCM_SALT_SIZE, TLS_CIPHER_SM4_GCM_IV_SIZE, TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE },
#endif
#if defined(TLS_CIPHER_SM4_CCM) && defined(NID_sm4_ccm)
{ NID_sm4_ccm, TLS_CIPHER_SM4_CCM, TLS_CIPHER_SM4_CCM_KEY_SIZE,cTLS_CIPHER_SM4_CCM_SALT_SIZE, TLS_CIPHER_SM4_CCM_IV_SIZE, TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE },
#endif
#if defined(TLS_CIPHER_ARIA_GCM_128) && defined(NID_aria_128_gcm)
{ NID_aria_128_gcm, TLS_CIPHER_ARIA_GCM_128, TLS_CIPHER_ARIA_GCM_128_KEY_SIZE, TLS_CIPHER_ARIA_GCM_128_SALT_SIZE, TLS_CIPHER_ARIA_GCM_128_IV_SIZE, TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE },
#endif
#if defined(TLS_CIPHER_ARIA_GCM_256) && defined(NID_aria_256_gcm)
{ NID_aria_256_gcm, TLS_CIPHER_ARIA_GCM_256, TLS_CIPHER_ARIA_GCM_256_KEY_SIZE, TLS_CIPHER_ARIA_GCM_256_SALT_SIZE, TLS_CIPHER_ARIA_GCM_256_IV_SIZE, TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE },
#endif
};
SSL *ssl = ctx->ssl;
unsigned char buf[128];
uint64_t seq;
int info_size;
int key_size, salt_size, iv_size, seq_size;
int is_tls_12;
int nid, i;
if (!(ctx->flags & SSL_SOCK_F_KTLS_ENABLED))
return;
switch (SSL_version(ctx->ssl)) {
case TLS_1_2_VERSION:
is_tls_12 = 1;
break;
default:
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
return;
}
nid = SSL_CIPHER_get_cipher_nid(SSL_get_current_cipher(ssl));
for (i = 0; i < sizeof(known_ciphers) / sizeof(known_ciphers[0]); i++) {
if (known_ciphers[i].nid == nid)
break;
}
if (i == sizeof(known_ciphers) / sizeof(known_ciphers[0])) {
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
return;
}
key_size = known_ciphers[i].key_size;
salt_size = known_ciphers[i].salt_size;
iv_size = known_ciphers[i].iv_size;
seq_size = known_ciphers[i].seq_size;
info_size = sizeof(struct tls_crypto_info) + key_size + salt_size + iv_size + seq_size;
/*
* If new ciphers are added, wy may have to increase the buffer size
*/
BUG_ON(key_size + salt_size + iv_size + seq_size > sizeof(info.buf));
info.info.version = is_tls_12 ? TLS_1_2_VERSION : TLS_1_3_VERSION;
info.info.cipher_type = known_ciphers[i].tls_cipher;
if (is_tls_12) {
unsigned char iv[iv_size];
int block_key_size = 2 * key_size + 2 * salt_size;
int i;
/*
* We may have to increase buf size if new ciphers are
* added with bigger key/salt.
*/
BUG_ON(block_key_size > sizeof(buf));
if (SSL_get_key_block_len(ssl) != block_key_size) {
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
goto out;
}
if (SSL_generate_key_block(ssl, buf, block_key_size) != 1) {
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
goto out;
}
/*
* The key block contains :
* - client key
* server key
* client salt
* server salt
*/
/*
* First, prepare the RX side
* The oldest linux versions do not support RTX, that way
* we will fail before setting the TX side.
*/
seq = SSL_get_read_sequence(ssl);
seq = my_htonll(seq);
for (i = 0; i < iv_size; i++)
iv[i] = (unsigned char)statistical_prng_range(256);
/* IV */
memcpy(&info.buf[0], &iv, iv_size);
if (!conn_is_back(ctx->conn)) {
/* Key */
memcpy(&info.buf[iv_size], &buf[0], key_size);
/* Salt */
memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size], salt_size);
} else {
/* Key */
memcpy(&info.buf[iv_size], &buf[key_size], key_size);
/* Salt */
memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size + salt_size], salt_size);
}
/* Record seq number */
memcpy(&info.buf[iv_size + key_size + salt_size], &seq, seq_size);
if (ktls_set_key(ctx, &info, info_size, 0) != 0) {
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
goto out;
}
/*
* Now do the TX side
*/
seq = SSL_get_write_sequence(ssl);
seq = my_htonll(seq);
for (i = 0; i < iv_size; i++)
iv[i] = (unsigned char)statistical_prng_range(256);
memcpy(&info.buf[0], &iv, iv_size);
if (!conn_is_back(ctx->conn)) {
/* Key */
memcpy(&info.buf[iv_size], &buf[key_size], key_size);
/* Salt */
memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size + salt_size], salt_size);
} else {
/* Key */
memcpy(&info.buf[iv_size], &buf[0], key_size);
/* Salt */
memcpy(&info.buf[iv_size + key_size], &buf[2 * key_size], salt_size);
}
memcpy(&info.buf[iv_size + key_size + salt_size], &seq, seq_size);
if (ktls_set_key(ctx, &info, info_size, 1) != 0) {
/*
* Not much we can do at this point. TLS has been
* enabled for RX, we can't disable it, we won't
* try to support only one side, so give up with
* that connection.
*/
ctx->conn->flags |= CO_FL_ERROR;
ctx->flags &= ~SSL_SOCK_F_KTLS_ENABLED;
goto out;
}
ctx->flags |= SSL_SOCK_F_KTLS_SEND | SSL_SOCK_F_KTLS_RECV;
}
out:
return;
}
#endif
#endif
struct task *ssl_sock_io_cb(struct task *t, void *context, unsigned int state)
{
struct tasklet *tl = (struct tasklet *)t;
@ -6069,6 +6259,12 @@ struct task *ssl_sock_io_cb(struct task *t, void *context, unsigned int state)
if (!(ctx->conn->flags & CO_FL_SSL_WAIT_HS)) {
/* handshake completed, leave the bulk queue */
_HA_ATOMIC_AND(&tl->state, ~TASK_HEAVY);
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
if (ctx->flags & SSL_SOCK_F_KTLS_ENABLED)
ssl_sock_setup_ktls(ctx);
#endif
#endif
}
}
/* If we had an error, or the handshake is done and I/O is available,
@ -6184,7 +6380,6 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
* EINTR too.
*/
while (count > 0) {
try = b_contig_space(buf);
if (!try)
break;
@ -6192,6 +6387,14 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
if (try > count)
try = count;
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
if (ctx->flags & SSL_SOCK_F_KTLS_RECV) {
ret = ctx->xprt->rcv_buf(ctx->conn, ctx->xprt_ctx, buf, try, NULL, NULL, 0);
} else
#endif
#endif
ret = SSL_read(ctx->ssl, b_tail(buf), try);
if (conn->flags & CO_FL_ERROR) {
@ -6199,12 +6402,33 @@ static size_t ssl_sock_to_buf(struct connection *conn, void *xprt_ctx, struct bu
goto out_error;
}
if (ret > 0) {
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
/*
* The next xprt already adjusted the buffer,
* so we should not do it.
*/
if (!(ctx->flags & SSL_SOCK_F_KTLS_RECV))
#endif
#endif
b_add(buf, ret);
done += ret;
count -= ret;
TRACE_DEVEL("Post SSL_read success", SSL_EV_CONN_RECV, conn, &ret);
}
else {
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
if (ctx->flags & SSL_SOCK_F_KTLS_RECV)
/*
* At this point the underlying xprt already
* set any connection error, and we can't
* ask the SSL lib, so we can stop now.
*/
break;
else
#endif
#endif
ret = SSL_get_error(ctx->ssl, ret);
if (ret == SSL_ERROR_WANT_WRITE) {
/* handshake is running, and it needs to enable write */
@ -6317,6 +6541,12 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
* in which case we accept to do it once again.
*/
while (count) {
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
int ktls_error = 0;
#endif
#endif
#ifdef SSL_READ_EARLY_DATA_SUCCESS
size_t written_data;
#endif
@ -6379,10 +6609,36 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
TRACE_PROTO("Write early data", SSL_EV_CONN_SEND|SSL_EV_CONN_SEND_EARLY, conn, &ret);
}
} else {
#endif
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
if (ctx->flags & SSL_SOCK_F_KTLS_SEND) {
struct buffer tmpbuf;
tmpbuf.size = b_data(buf) - done;
tmpbuf.data = tmpbuf.size;
tmpbuf.area = b_peek(buf, done);
tmpbuf.head = 0;
ret = ctx->xprt->snd_buf(ctx->conn, ctx->xprt_ctx, &tmpbuf, try, NULL, 0, (ctx->xprt_st & SSL_SOCK_SEND_MORE) ? CO_SFL_MSG_MORE : 0);
if (ret < try) {
if (errno == EINTR)
continue;
else if (!(conn->flags & CO_FL_ERROR))
ktls_error = SSL_ERROR_WANT_WRITE;
else {
ktls_error = SSL_ERROR_SSL;
}
}
} else
#endif
#endif
ret = SSL_write(ctx->ssl, b_peek(buf, done), try);
#ifdef SSL_READ_EARLY_DATA_SUCCESS
}
#endif
if (conn->flags & CO_FL_ERROR) {
/* CO_FL_ERROR may be set by ssl_sock_infocbk */
goto out_error;
@ -6396,6 +6652,13 @@ static size_t ssl_sock_from_buf(struct connection *conn, void *xprt_ctx, const s
TRACE_DEVEL("Post SSL_write success", SSL_EV_CONN_SEND, conn, &ret);
}
else {
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_BORINGSSL) || defined(OPENSSL_IS_AWSLC)
if (ctx->flags & SSL_SOCK_F_KTLS_SEND)
ret = ktls_error;
else
#endif
#endif
ret = SSL_get_error(ctx->ssl, ret);
if (ret == SSL_ERROR_WANT_WRITE) {
@ -6542,6 +6805,12 @@ static void ssl_sock_shutw(struct connection *conn, void *xprt_ctx, int clean)
TRACE_ENTER(SSL_EV_CONN_END, conn);
#ifdef HA_USE_KTLS
#if defined(OPENSSL_IS_AWSLC) || defined(OPENSSL_IS_BORINGSSL)
if (ctx->flags & (SSL_SOCK_F_KTLS_RECV | SSL_SOCK_F_KTLS_SEND))
return;
#endif
#endif
if (conn->flags & (CO_FL_WAIT_XPRT | CO_FL_SSL_WAIT_HS))
return;
conn_report_term_evt(conn, tevt_loc_xprt, xprt_tevt_type_shutw);