diff --git a/include/haproxy/quic_tls.h b/include/haproxy/quic_tls.h index 02e155d00..df836d510 100644 --- a/include/haproxy/quic_tls.h +++ b/include/haproxy/quic_tls.h @@ -420,14 +420,24 @@ static inline void quic_tls_ctx_secs_free(struct quic_tls_ctx *ctx) */ static inline int quic_tls_ctx_keys_alloc(struct quic_tls_ctx *ctx) { + if (ctx->rx.key) + goto write; + if (!(ctx->rx.iv = pool_alloc(pool_head_quic_tls_iv)) || - !(ctx->rx.key = pool_alloc(pool_head_quic_tls_key)) || - !(ctx->tx.iv = pool_alloc(pool_head_quic_tls_iv)) || + !(ctx->rx.key = pool_alloc(pool_head_quic_tls_key))) + goto err; + + write: + if (ctx->tx.key) + goto out; + + if (!(ctx->tx.iv = pool_alloc(pool_head_quic_tls_iv)) || !(ctx->tx.key = pool_alloc(pool_head_quic_tls_key))) goto err; ctx->rx.ivlen = ctx->tx.ivlen = QUIC_TLS_IV_LEN; ctx->rx.keylen = ctx->tx.keylen = QUIC_TLS_KEY_LEN; +out: return 1; err: diff --git a/src/xprt_quic.c b/src/xprt_quic.c index 4cfd60f87..042bd17ae 100644 --- a/src/xprt_quic.c +++ b/src/xprt_quic.c @@ -919,6 +919,9 @@ int ha_quic_set_encryption_secrets(SSL *ssl, enum ssl_encryption_level_t level, rx->md = tx->md = tls_md(cipher); rx->hp = tx->hp = tls_hp(cipher); + if (!read_secret) + goto write; + if (!quic_tls_derive_keys(rx->aead, rx->hp, rx->md, ver, rx->key, rx->keylen, rx->iv, rx->ivlen, rx->hp_key, sizeof rx->hp_key, read_secret, secret_len)) { @@ -945,6 +948,8 @@ int ha_quic_set_encryption_secrets(SSL *ssl, enum ssl_encryption_level_t level, quic_accept_push_qc(qc); } +write: + if (!write_secret) goto out; @@ -977,17 +982,17 @@ int ha_quic_set_encryption_secrets(SSL *ssl, enum ssl_encryption_level_t level, goto leave; } - memcpy(rx->secret, read_secret, secret_len); - rx->secretlen = secret_len; - memcpy(tx->secret, write_secret, secret_len); - tx->secretlen = secret_len; + if (read_secret) { + memcpy(rx->secret, read_secret, secret_len); + rx->secretlen = secret_len; + } + if (write_secret) { + memcpy(tx->secret, write_secret, secret_len); + tx->secretlen = secret_len; + } /* Initialize all the secret keys lengths */ prv_rx->secretlen = nxt_rx->secretlen = nxt_tx->secretlen = secret_len; /* Prepare the next key update */ - if (!quic_tls_key_update(qc)) { - // trace already emitted by function above - goto leave; - } } out: @@ -2258,6 +2263,11 @@ static inline int qc_provide_cdata(struct quic_enc_level *el, else { qc->state = QUIC_HS_ST_COMPLETE; } + + if (!quic_tls_key_update(qc)) { + TRACE_ERROR("quic_tls_key_update() failed", QUIC_EV_CONN_IO_CB, qc); + goto leave; + } } else { ssl_err = SSL_process_quic_post_handshake(ctx->ssl); if (ssl_err != 1) {