diff --git a/include/haproxy/quic_tls-t.h b/include/haproxy/quic_tls-t.h index 95fefc486..693bc4f72 100644 --- a/include/haproxy/quic_tls-t.h +++ b/include/haproxy/quic_tls-t.h @@ -96,6 +96,7 @@ extern unsigned char initial_salt[20]; /* Key phase used for Key Update */ struct quic_tls_kp { + EVP_CIPHER_CTX *ctx; unsigned char *secret; size_t secretlen; unsigned char *iv; diff --git a/include/haproxy/quic_tls.h b/include/haproxy/quic_tls.h index d12d6e07b..8616aa31f 100644 --- a/include/haproxy/quic_tls.h +++ b/include/haproxy/quic_tls.h @@ -546,12 +546,15 @@ static inline int qc_new_isecs(struct quic_conn *qc, */ static inline void quic_tls_ku_free(struct quic_conn *qc) { + EVP_CIPHER_CTX_free(qc->ku.prv_rx.ctx); pool_free(pool_head_quic_tls_secret, qc->ku.prv_rx.secret); pool_free(pool_head_quic_tls_iv, qc->ku.prv_rx.iv); pool_free(pool_head_quic_tls_key, qc->ku.prv_rx.key); + EVP_CIPHER_CTX_free(qc->ku.nxt_rx.ctx); pool_free(pool_head_quic_tls_secret, qc->ku.nxt_rx.secret); pool_free(pool_head_quic_tls_iv, qc->ku.nxt_rx.iv); pool_free(pool_head_quic_tls_key, qc->ku.nxt_rx.key); + EVP_CIPHER_CTX_free(qc->ku.nxt_tx.ctx); pool_free(pool_head_quic_tls_secret, qc->ku.nxt_tx.secret); pool_free(pool_head_quic_tls_iv, qc->ku.nxt_tx.iv); pool_free(pool_head_quic_tls_key, qc->ku.nxt_tx.key); diff --git a/src/xprt_quic.c b/src/xprt_quic.c index e1cdc72e8..63cfc7fe7 100644 --- a/src/xprt_quic.c +++ b/src/xprt_quic.c @@ -733,6 +733,26 @@ static int quic_tls_key_update(struct quic_conn *qc) return 0; } + if (nxt_rx->ctx) { + EVP_CIPHER_CTX_free(nxt_rx->ctx); + nxt_rx->ctx = NULL; + } + + if (!quic_tls_rx_ctx_init(&nxt_rx->ctx, tls_ctx->rx.aead, nxt_rx->key)) { + TRACE_DEVEL("could not initial RX TLS cipher context", QUIC_EV_CONN_RWSEC, qc); + return 0; + } + + if (nxt_tx->ctx) { + EVP_CIPHER_CTX_free(nxt_tx->ctx); + nxt_tx->ctx = NULL; + } + + if (!quic_tls_rx_ctx_init(&nxt_tx->ctx, tls_ctx->tx.aead, nxt_tx->key)) { + TRACE_DEVEL("could not initial RX TLS cipher context", QUIC_EV_CONN_RWSEC, qc); + return 0; + } + return 1; } @@ -744,34 +764,42 @@ static void quic_tls_rotate_keys(struct quic_conn *qc) { struct quic_tls_ctx *tls_ctx = &qc->els[QUIC_TLS_ENC_LEVEL_APP].tls_ctx; unsigned char *curr_secret, *curr_iv, *curr_key; + EVP_CIPHER_CTX *curr_ctx; /* Rotate the RX secrets */ + curr_ctx = tls_ctx->rx.ctx; curr_secret = tls_ctx->rx.secret; curr_iv = tls_ctx->rx.iv; curr_key = tls_ctx->rx.key; + tls_ctx->rx.ctx = qc->ku.nxt_rx.ctx; tls_ctx->rx.secret = qc->ku.nxt_rx.secret; tls_ctx->rx.iv = qc->ku.nxt_rx.iv; tls_ctx->rx.key = qc->ku.nxt_rx.key; + qc->ku.nxt_rx.ctx = qc->ku.prv_rx.ctx; qc->ku.nxt_rx.secret = qc->ku.prv_rx.secret; qc->ku.nxt_rx.iv = qc->ku.prv_rx.iv; qc->ku.nxt_rx.key = qc->ku.prv_rx.key; + qc->ku.prv_rx.ctx = curr_ctx; qc->ku.prv_rx.secret = curr_secret; qc->ku.prv_rx.iv = curr_iv; qc->ku.prv_rx.key = curr_key; qc->ku.prv_rx.pn = tls_ctx->rx.pn; /* Update the TX secrets */ + curr_ctx = tls_ctx->tx.ctx; curr_secret = tls_ctx->tx.secret; curr_iv = tls_ctx->tx.iv; curr_key = tls_ctx->tx.key; + tls_ctx->tx.ctx = qc->ku.nxt_tx.ctx; tls_ctx->tx.secret = qc->ku.nxt_tx.secret; tls_ctx->tx.iv = qc->ku.nxt_tx.iv; tls_ctx->tx.key = qc->ku.nxt_tx.key; + qc->ku.nxt_tx.ctx = curr_ctx; qc->ku.nxt_tx.secret = curr_secret; qc->ku.nxt_tx.iv = curr_iv; qc->ku.nxt_tx.key = curr_key; @@ -1342,6 +1370,7 @@ static int qc_pkt_decrypt(struct quic_rx_packet *pkt, struct quic_enc_level *qel int ret, kp_changed; unsigned char iv[QUIC_TLS_IV_LEN]; struct quic_tls_ctx *tls_ctx = &qel->tls_ctx; + EVP_CIPHER_CTX *rx_ctx = tls_ctx->rx.ctx; unsigned char *rx_iv = tls_ctx->rx.iv; size_t rx_iv_sz = tls_ctx->rx.ivlen; unsigned char *rx_key = tls_ctx->rx.key; @@ -1360,13 +1389,15 @@ static int qc_pkt_decrypt(struct quic_rx_packet *pkt, struct quic_enc_level *qel if (!pkt->qc->ku.prv_rx.pn) return 0; - rx_iv = pkt->qc->ku.prv_rx.iv; + rx_ctx = pkt->qc->ku.prv_rx.ctx; + rx_iv = pkt->qc->ku.prv_rx.iv; rx_key = pkt->qc->ku.prv_rx.key; } else if (pkt->pn > qel->pktns->rx.largest_pn) { /* Next key phase */ kp_changed = 1; - rx_iv = pkt->qc->ku.nxt_rx.iv; + rx_ctx = pkt->qc->ku.nxt_rx.ctx; + rx_iv = pkt->qc->ku.nxt_rx.iv; rx_key = pkt->qc->ku.nxt_rx.key; } } @@ -1377,7 +1408,7 @@ static int qc_pkt_decrypt(struct quic_rx_packet *pkt, struct quic_enc_level *qel ret = quic_tls_decrypt(pkt->data + pkt->aad_len, pkt->len - pkt->aad_len, pkt->data, pkt->aad_len, - tls_ctx->rx.ctx, tls_ctx->rx.aead, rx_key, iv); + rx_ctx, tls_ctx->rx.aead, rx_key, iv); if (!ret) return 0;