MINOR: quic: Atomically get/set the connection state

As ->state quic_conn struct member field is shared between threads
we must atomically get and set its value.
This commit is contained in:
Frédéric Lécaille 2021-08-18 09:16:01 +02:00 committed by Amaury Denoyelle
parent ee57444382
commit eed7a7d73b

View File

@ -594,7 +594,7 @@ static inline int quic_peer_validated_addr(struct ssl_sock_ctx *ctx)
if ((qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) || if ((qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) ||
(qc->els[QUIC_TLS_ENC_LEVEL_APP].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) || (qc->els[QUIC_TLS_ENC_LEVEL_APP].pktns->flags & QUIC_FL_PKTNS_ACK_RECEIVED) ||
(qc->state & QUIC_HS_ST_COMPLETE)) HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE)
return 1; return 1;
return 0; return 0;
@ -608,6 +608,7 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx)
struct quic_conn *qc; struct quic_conn *qc;
struct quic_pktns *pktns; struct quic_pktns *pktns;
unsigned int pto; unsigned int pto;
int handshake_complete;
TRACE_ENTER(QUIC_EV_CONN_STIMER, ctx->conn, TRACE_ENTER(QUIC_EV_CONN_STIMER, ctx->conn,
NULL, NULL, &ctx->conn->qc->path->ifae_pkts); NULL, NULL, &ctx->conn->qc->path->ifae_pkts);
@ -629,7 +630,8 @@ static inline void qc_set_timer(struct ssl_sock_ctx *ctx)
goto out; goto out;
} }
pktns = quic_pto_pktns(qc, qc->state & QUIC_HS_ST_COMPLETE, &pto); handshake_complete = HA_ATOMIC_LOAD(&qc->state) >= QUIC_HS_ST_COMPLETE;
pktns = quic_pto_pktns(qc, handshake_complete, &pto);
if (tick_isset(pto)) if (tick_isset(pto))
qc->timer = pto; qc->timer = pto;
out: out:
@ -1495,7 +1497,7 @@ static inline int qc_provide_cdata(struct quic_enc_level *el,
struct quic_rx_packet *pkt, struct quic_rx_packet *pkt,
struct quic_rx_crypto_frm *cf) struct quic_rx_crypto_frm *cf)
{ {
int ssl_err; int ssl_err, state;
struct quic_conn *qc; struct quic_conn *qc;
TRACE_ENTER(QUIC_EV_CONN_SSLDATA, ctx->conn); TRACE_ENTER(QUIC_EV_CONN_SSLDATA, ctx->conn);
@ -1511,43 +1513,44 @@ static inline int qc_provide_cdata(struct quic_enc_level *el,
TRACE_PROTO("in order CRYPTO data", TRACE_PROTO("in order CRYPTO data",
QUIC_EV_CONN_SSLDATA, ctx->conn,, cf, ctx->ssl); QUIC_EV_CONN_SSLDATA, ctx->conn,, cf, ctx->ssl);
if (qc->state < QUIC_HS_ST_COMPLETE) { state = HA_ATOMIC_LOAD(&qc->state);
if (state < QUIC_HS_ST_COMPLETE) {
ssl_err = SSL_do_handshake(ctx->ssl); ssl_err = SSL_do_handshake(ctx->ssl);
if (ssl_err != 1) { if (ssl_err != 1) {
ssl_err = SSL_get_error(ctx->ssl, ssl_err); ssl_err = SSL_get_error(ctx->ssl, ssl_err);
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
TRACE_PROTO("SSL handshake", TRACE_PROTO("SSL handshake",
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
goto out; goto out;
} }
TRACE_DEVEL("SSL handshake error", TRACE_DEVEL("SSL handshake error",
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
goto err; goto err;
} }
TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); TRACE_PROTO("SSL handshake OK", QUIC_EV_CONN_HDSHK, ctx->conn, &state);
if (objt_listener(ctx->conn->target)) if (objt_listener(ctx->conn->target))
qc->state = QUIC_HS_ST_CONFIRMED; HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CONFIRMED);
else else
qc->state = QUIC_HS_ST_COMPLETE; HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_COMPLETE);
} else { } else {
ssl_err = SSL_process_quic_post_handshake(ctx->ssl); ssl_err = SSL_process_quic_post_handshake(ctx->ssl);
if (ssl_err != 1) { if (ssl_err != 1) {
ssl_err = SSL_get_error(ctx->ssl, ssl_err); ssl_err = SSL_get_error(ctx->ssl, ssl_err);
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
TRACE_DEVEL("SSL post handshake", TRACE_DEVEL("SSL post handshake",
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
goto out; goto out;
} }
TRACE_DEVEL("SSL post handshake error", TRACE_DEVEL("SSL post handshake error",
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); QUIC_EV_CONN_HDSHK, ctx->conn, &state, &ssl_err);
goto err; goto err;
} }
TRACE_PROTO("SSL post handshake succeeded", TRACE_PROTO("SSL post handshake succeeded",
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); QUIC_EV_CONN_HDSHK, ctx->conn, &state);
} }
out: out:
@ -1938,7 +1941,7 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct
if (objt_listener(ctx->conn->target)) if (objt_listener(ctx->conn->target))
goto err; goto err;
conn->state = QUIC_HS_ST_CONFIRMED; HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_CONFIRMED);
break; break;
default: default:
goto err; goto err;
@ -1949,12 +1952,12 @@ static int qc_parse_pkt_frms(struct quic_rx_packet *pkt, struct ssl_sock_ctx *ct
* has successfully parse a Handshake packet. The Initial encryption must also * has successfully parse a Handshake packet. The Initial encryption must also
* be discarded. * be discarded.
*/ */
if (conn->state == QUIC_HS_ST_SERVER_INITIAL && if (HA_ATOMIC_LOAD(&conn->state) == QUIC_HS_ST_SERVER_INITIAL &&
pkt->type == QUIC_PACKET_TYPE_HANDSHAKE) { pkt->type == QUIC_PACKET_TYPE_HANDSHAKE) {
quic_tls_discard_keys(&conn->els[QUIC_TLS_ENC_LEVEL_INITIAL]); quic_tls_discard_keys(&conn->els[QUIC_TLS_ENC_LEVEL_INITIAL]);
quic_pktns_discard(conn->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, conn); quic_pktns_discard(conn->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, conn);
qc_set_timer(ctx); qc_set_timer(ctx);
conn->state = QUIC_HS_ST_SERVER_HANDSHAKE; HA_ATOMIC_STORE(&conn->state, QUIC_HS_ST_SERVER_HANDSHAKE);
} }
TRACE_LEAVE(QUIC_EV_CONN_PRSHPKT, ctx->conn); TRACE_LEAVE(QUIC_EV_CONN_PRSHPKT, ctx->conn);
@ -2004,7 +2007,7 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx)
TRACE_ENTER(QUIC_EV_CONN_PHPKTS, ctx->conn); TRACE_ENTER(QUIC_EV_CONN_PHPKTS, ctx->conn);
qc = ctx->conn->qc; qc = ctx->conn->qc;
if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state)) { if (!quic_get_tls_enc_levels(&tel, &next_tel, HA_ATOMIC_LOAD(&qc->state))) {
TRACE_DEVEL("unknown enc. levels", QUIC_EV_CONN_PHPKTS, ctx->conn); TRACE_DEVEL("unknown enc. levels", QUIC_EV_CONN_PHPKTS, ctx->conn);
goto err; goto err;
} }
@ -2088,12 +2091,12 @@ static int qc_prep_hdshk_pkts(struct qring *qr, struct ssl_sock_ctx *ctx)
/* Discard the Initial encryption keys as soon as /* Discard the Initial encryption keys as soon as
* a handshake packet could be built. * a handshake packet could be built.
*/ */
if (qc->state == QUIC_HS_ST_CLIENT_INITIAL && if (HA_ATOMIC_LOAD(&qc->state) == QUIC_HS_ST_CLIENT_INITIAL &&
pkt_type == QUIC_PACKET_TYPE_HANDSHAKE) { pkt_type == QUIC_PACKET_TYPE_HANDSHAKE) {
quic_tls_discard_keys(&qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]); quic_tls_discard_keys(&qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]);
quic_pktns_discard(qc->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, qc); quic_pktns_discard(qc->els[QUIC_TLS_ENC_LEVEL_INITIAL].pktns, qc);
qc_set_timer(ctx); qc_set_timer(ctx);
qc->state = QUIC_HS_ST_CLIENT_HANDSHAKE; HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_HANDSHAKE);
} }
/* Special case for Initial packets: when they have all /* Special case for Initial packets: when they have all
* been sent, select the next level. * been sent, select the next level.
@ -2478,7 +2481,7 @@ static inline void qc_rm_hp_pkts(struct quic_enc_level *el, struct ssl_sock_ctx
app_qel = &ctx->conn->qc->els[QUIC_TLS_ENC_LEVEL_APP]; app_qel = &ctx->conn->qc->els[QUIC_TLS_ENC_LEVEL_APP];
/* A server must not process incoming 1-RTT packets before the handshake is complete. */ /* A server must not process incoming 1-RTT packets before the handshake is complete. */
if (el == app_qel && objt_listener(ctx->conn->target) && if (el == app_qel && objt_listener(ctx->conn->target) &&
ctx->conn->qc->state < QUIC_HS_ST_COMPLETE) { HA_ATOMIC_LOAD(&ctx->conn->qc->state) < QUIC_HS_ST_COMPLETE) {
TRACE_PROTO("hp not removed (handshake not completed)", TRACE_PROTO("hp not removed (handshake not completed)",
QUIC_EV_CONN_ELRMHP, ctx->conn); QUIC_EV_CONN_ELRMHP, ctx->conn);
goto out; goto out;
@ -2622,9 +2625,10 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state)
ctx = context; ctx = context;
qc = ctx->conn->qc; qc = ctx->conn->qc;
qr = NULL; qr = NULL;
TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); st = HA_ATOMIC_LOAD(&qc->state);
TRACE_ENTER(QUIC_EV_CONN_HDSHK, ctx->conn, &st);
ssl_err = SSL_ERROR_NONE; ssl_err = SSL_ERROR_NONE;
if (!quic_get_tls_enc_levels(&tel, &next_tel, qc->state)) if (!quic_get_tls_enc_levels(&tel, &next_tel, st))
goto err; goto err;
qel = &qc->els[tel]; qel = &qc->els[tel];
@ -2674,15 +2678,14 @@ struct task *quic_conn_io_cb(struct task *t, void *context, unsigned int state)
goto next_level; goto next_level;
} }
out:
MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list); MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list);
TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state); TRACE_LEAVE(QUIC_EV_CONN_HDSHK, ctx->conn, &st);
return t; return t;
err: err:
if (qr) if (qr)
MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list); MT_LIST_APPEND(qc->tx.qring_list, &qr->mt_list);
TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err); TRACE_DEVEL("leaving in error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
return t; return t;
} }
@ -2768,7 +2771,7 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta
struct ssl_sock_ctx *conn_ctx; struct ssl_sock_ctx *conn_ctx;
struct quic_conn *qc; struct quic_conn *qc;
struct quic_pktns *pktns; struct quic_pktns *pktns;
int st;
conn_ctx = task->context; conn_ctx = task->context;
qc = conn_ctx->conn->qc; qc = conn_ctx->conn->qc;
@ -2786,11 +2789,12 @@ static struct task *process_timer(struct task *task, void *ctx, unsigned int sta
goto out; goto out;
} }
st = HA_ATOMIC_LOAD(&qc->state);
if (qc->path->in_flight) { if (qc->path->in_flight) {
pktns = quic_pto_pktns(qc, qc->state >= QUIC_HS_ST_COMPLETE, NULL); pktns = quic_pto_pktns(qc, st >= QUIC_HS_ST_COMPLETE, NULL);
pktns->tx.pto_probe = 1; pktns->tx.pto_probe = 1;
} }
else if (objt_server(qc->conn->target) && qc->state <= QUIC_HS_ST_COMPLETE) { else if (objt_server(qc->conn->target) && st <= QUIC_HS_ST_COMPLETE) {
struct quic_enc_level *iel = &qc->els[QUIC_TLS_ENC_LEVEL_INITIAL]; struct quic_enc_level *iel = &qc->els[QUIC_TLS_ENC_LEVEL_INITIAL];
struct quic_enc_level *hel = &qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE]; struct quic_enc_level *hel = &qc->els[QUIC_TLS_ENC_LEVEL_HANDSHAKE];
@ -2837,7 +2841,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4,
if (server) { if (server) {
struct listener *l = owner; struct listener *l = owner;
qc->state = QUIC_HS_ST_SERVER_INITIAL; HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_SERVER_INITIAL);
/* Copy the initial DCID. */ /* Copy the initial DCID. */
qc->odcid.len = dcid_len; qc->odcid.len = dcid_len;
if (qc->odcid.len) if (qc->odcid.len)
@ -2851,7 +2855,7 @@ static struct quic_conn *qc_new_conn(unsigned int version, int ipv4,
} }
/* QUIC Client (outgoing connection to servers) */ /* QUIC Client (outgoing connection to servers) */
else { else {
qc->state = QUIC_HS_ST_CLIENT_INITIAL; HA_ATOMIC_STORE(&qc->state, QUIC_HS_ST_CLIENT_INITIAL);
if (dcid_len) if (dcid_len)
memcpy(qc->dcid.data, dcid, dcid_len); memcpy(qc->dcid.data, dcid, dcid_len);
qc->dcid.len = dcid_len; qc->dcid.len = dcid_len;
@ -3000,7 +3004,8 @@ static int qc_pkt_may_rm_hp(struct quic_rx_packet *pkt,
} }
if (((*qel)->tls_ctx.rx.flags & QUIC_FL_TLS_SECRETS_SET) && if (((*qel)->tls_ctx.rx.flags & QUIC_FL_TLS_SECRETS_SET) &&
(tel != QUIC_TLS_ENC_LEVEL_APP || ctx->conn->qc->state >= QUIC_HS_ST_COMPLETE)) (tel != QUIC_TLS_ENC_LEVEL_APP ||
HA_ATOMIC_LOAD(&ctx->conn->qc->state) >= QUIC_HS_ST_COMPLETE))
return 1; return 1;
return 0; return 0;
@ -4276,14 +4281,15 @@ static int qc_conn_init(struct connection *conn, void **xprt_ctx)
SSL_set_connect_state(ctx->ssl); SSL_set_connect_state(ctx->ssl);
ssl_err = SSL_do_handshake(ctx->ssl); ssl_err = SSL_do_handshake(ctx->ssl);
if (ssl_err != 1) { if (ssl_err != 1) {
int st;
st = HA_ATOMIC_LOAD(&qc->state);
ssl_err = SSL_get_error(ctx->ssl, ssl_err); ssl_err = SSL_get_error(ctx->ssl, ssl_err);
if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) { if (ssl_err == SSL_ERROR_WANT_READ || ssl_err == SSL_ERROR_WANT_WRITE) {
TRACE_PROTO("SSL handshake", TRACE_PROTO("SSL handshake", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
} }
else { else {
TRACE_DEVEL("SSL handshake error", TRACE_DEVEL("SSL handshake error", QUIC_EV_CONN_HDSHK, ctx->conn, &st, &ssl_err);
QUIC_EV_CONN_HDSHK, ctx->conn, &qc->state, &ssl_err);
goto err; goto err;
} }
} }