diff --git a/src/internal.c b/src/internal.c index edab14eb4..db70d0f6f 100644 --- a/src/internal.c +++ b/src/internal.c @@ -1096,7 +1096,7 @@ WOLFSSH* SshInit(WOLFSSH* ssh, WOLFSSH_CTX* ctx) ssh->fs = NULL; ssh->acceptState = ACCEPT_BEGIN; ssh->clientState = CLIENT_BEGIN; - ssh->isKeying = 1; + ssh->isKeying = 0; /* initial state of not keying yet */ ssh->authId = ID_USERAUTH_PUBLICKEY; ssh->supportedAuth[0] = ID_USERAUTH_PUBLICKEY; ssh->supportedAuth[1] = ID_USERAUTH_PASSWORD; @@ -4058,6 +4058,15 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) ret = WS_BAD_ARGUMENT; } + if (ret == WS_SUCCESS) { + /* Check if already in process of keying and error out if so. */ + if (ssh->isKeying & WOLFSSH_PEER_IS_KEYING) { + WLOG(WS_LOG_ERROR, + "Already in keying process and got KEX init"); + ret = WS_INVALID_STATE_E; + } + } + /* * I don't need to save what the client sends here. I should decode * each list into a local array of IDs, and pick the one the peer is @@ -4067,6 +4076,8 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) */ if (ret == WS_SUCCESS) { + /* Set peer is keying flag after receiving SSH_MSG_KEX_INIT */ + ssh->isKeying |= WOLFSSH_PEER_IS_KEYING; if (ssh->handshake == NULL) { ssh->handshake = HandshakeInfoNew(ssh->ctx->heap); if (ssh->handshake == NULL) { @@ -4327,7 +4338,8 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) byte scratchLen[LENGTH_SZ]; word32 strSz = 0; - if (!ssh->isKeying) { + /* respond with KEX Init message if not having initiated the keying */ + if ((ssh->isKeying & WOLFSSH_SELF_IS_KEYING) == 0) { WLOG(WS_LOG_DEBUG, "Keying initiated"); ret = SendKexInit(ssh); } @@ -5881,6 +5893,13 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) if (ssh == NULL || ssh->handshake == NULL) ret = WS_BAD_ARGUMENT; + if (ret == WS_SUCCESS) { + if (ssh->isKeying & WOLFSSH_SELF_IS_KEYING) { + WLOG(WS_LOG_ERROR, "Keying failed"); + ret = WS_INVALID_STATE_E; + } + } + if (ret == WS_SUCCESS) { ssh->peerEncryptId = ssh->handshake->encryptId; ssh->peerMacId = ssh->handshake->macId; @@ -5941,7 +5960,9 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) if (ret == WS_SUCCESS) { ssh->rxCount = 0; ssh->highwaterFlag = 0; - ssh->isKeying = 0; + + /* Clear peer is keying flag */ + ssh->isKeying &= ~WOLFSSH_PEER_IS_KEYING; HandshakeInfoFree(ssh->handshake, ssh->ctx->heap); ssh->handshake = NULL; WLOG(WS_LOG_DEBUG, "Keying completed"); @@ -9405,7 +9426,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed) case MSGID_KEXINIT: WLOG(WS_LOG_DEBUG, "Decoding MSGID_KEXINIT"); ret = DoKexInit(ssh, buf + idx, payloadSz, &payloadIdx); - if (ssh->isKeying == 1 && + if (ssh->isKeying && ssh->connectState == CONNECT_SERVER_CHANNEL_REQUEST_DONE) { if (ssh->handshake->kexId == ID_DH_GEX_SHA256) { #if !defined(WOLFSSH_NO_DH) && !defined(WOLFSSH_NO_DH_GEX_SHA256) @@ -10501,7 +10522,8 @@ int SendKexInit(WOLFSSH* ssh) } if (ret == WS_SUCCESS) { - ssh->isKeying = 1; + /* Set self is keying flag since we started sending the KEX init msg */ + ssh->isKeying |= WOLFSSH_SELF_IS_KEYING; if (ssh->handshake == NULL) { ssh->handshake = HandshakeInfoNew(ssh->ctx->heap); if (ssh->handshake == NULL) { @@ -12534,9 +12556,13 @@ int SendNewKeys(WOLFSSH* ssh) ssh->txCount = 0; } - if (ret == WS_SUCCESS) + if (ret == WS_SUCCESS) { ret = wolfSSH_SendPacket(ssh); + /* Clear self is keying flag */ + ssh->isKeying &= ~WOLFSSH_SELF_IS_KEYING; + } + WLOG(WS_LOG_DEBUG, "Leaving SendNewKeys(), ret = %d", ret); return ret; } diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 1b7dada16..6df5f1147 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -473,6 +473,11 @@ enum NameIdType { #define WOLFSSH_KEY_QUANTITY_REQ 1 #endif +/* Keep track of keying state for both sides of the connection. + * WOLFSSH_SELF_IS_KEYING gets set on sending KEX init and + * WOLFSSH_PEER_IS_KEYING gets set on receiving KEX init */ +#define WOLFSSH_PEER_IS_KEYING 0x01 +#define WOLFSSH_SELF_IS_KEYING 0x02 WOLFSSH_LOCAL byte NameToId(const char* name, word32 nameSz); WOLFSSH_LOCAL const char* IdToName(byte id);