Skip to content
11 changes: 11 additions & 0 deletions examples/client/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,9 @@ static THREAD_RET readInput(void* in)
ret = wolfSSH_stream_send(args->ssh, buf, sz);
wc_UnLockMutex(&args->lock);
if (ret <= 0) {
if (ret == WS_REKEYING) {
continue;
}
fprintf(stderr, "Couldn't send data\n");
return THREAD_RET_SUCCESS;
}
Expand Down Expand Up @@ -472,8 +475,16 @@ static THREAD_RET readPeer(void* in)
continue;
}
#endif /* WOLFSSH_AGENT */
else if (ret == WS_REKEYING) {
wolfSSH_worker(args->ssh, NULL);
ret = 0;
}
}
else if (ret != WS_EOF) {
if (ret == 0) {
bytes = 0;
continue;
}
err_sys("Stream read failed.");
}
}
Expand Down
5 changes: 4 additions & 1 deletion examples/echoserver/echoserver.c
Original file line number Diff line number Diff line change
Expand Up @@ -1416,8 +1416,11 @@ static int sftp_worker(thread_ctx_t* threadCtx)
}
else if (ret < 0) {
error = wolfSSH_get_error(ssh);
if (error == WS_EOF)
if (error == WS_EOF) {
/* shutdown is happening, clear peek error */
ret = 0;
break;
}
}

if (ret == WS_FATAL_ERROR && error == 0) {
Expand Down
87 changes: 80 additions & 7 deletions examples/sftpclient/sftpclient.c
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,13 @@ static int doCmds(func_args* args)

/* check directory is valid */
do {
while (ret == WS_REKEYING || ssh->error == WS_REKEYING) {
ret = wolfSSH_worker(ssh, NULL);
if (ret != WS_SUCCESS && ret == WS_FATAL_ERROR) {
ret = wolfSSH_get_error(ssh);
}
}

ret = wolfSSH_SFTP_STAT(ssh, pt, &atrb);
err = wolfSSH_get_error(ssh);
} while ((err == WS_WANT_READ || err == WS_WANT_WRITE)
Expand Down Expand Up @@ -828,6 +835,13 @@ static int doCmds(func_args* args)

/* update permissions */
do {
while (ret == WS_REKEYING || ssh->error == WS_REKEYING) {
ret = wolfSSH_worker(ssh, NULL);
if (ret != WS_SUCCESS && ret == WS_FATAL_ERROR) {
ret = wolfSSH_get_error(ssh);
}
}

ret = wolfSSH_SFTP_CHMOD(ssh, pt, mode);
err = wolfSSH_get_error(ssh);
} while ((err == WS_WANT_READ || err == WS_WANT_WRITE)
Expand Down Expand Up @@ -878,6 +892,13 @@ static int doCmds(func_args* args)
}

do {
while (ret == WS_REKEYING || ssh->error == WS_REKEYING) {
ret = wolfSSH_worker(ssh, NULL);
if (ret != WS_SUCCESS && ret == WS_FATAL_ERROR) {
ret = wolfSSH_get_error(ssh);
}
}

ret = wolfSSH_SFTP_RMDIR(ssh, pt);
err = wolfSSH_get_error(ssh);
} while ((err == WS_WANT_READ || err == WS_WANT_WRITE)
Expand Down Expand Up @@ -924,6 +945,13 @@ static int doCmds(func_args* args)
}

do {
while (ret == WS_REKEYING || ssh->error == WS_REKEYING) {
ret = wolfSSH_worker(ssh, NULL);
if (ret != WS_SUCCESS && ret == WS_FATAL_ERROR) {
ret = wolfSSH_get_error(ssh);
}
}

ret = wolfSSH_SFTP_Remove(ssh, pt);
err = wolfSSH_get_error(ssh);
} while ((err == WS_WANT_READ || err == WS_WANT_WRITE)
Expand Down Expand Up @@ -1119,7 +1147,7 @@ static int doCmds(func_args* args)
/* alternate main loop for the autopilot get/receive */
static int doAutopilot(int cmd, char* local, char* remote)
{
int err;
int err = 0;
int ret = WS_SUCCESS;
char fullpath[128] = ".";
WS_SFTPNAME* name = NULL;
Expand Down Expand Up @@ -1156,6 +1184,12 @@ static int doAutopilot(int cmd, char* local, char* remote)
}

do {
if (err == WS_REKEYING || err == WS_WINDOW_FULL) { /* handle rekeying state */
do {
ret = wolfSSH_worker(ssh, NULL);
} while (ret == WS_REKEYING);
}

if (cmd == AUTOPILOT_PUT) {
ret = wolfSSH_SFTP_Put(ssh, local, fullpath, 0, NULL);
}
Expand All @@ -1164,7 +1198,8 @@ static int doAutopilot(int cmd, char* local, char* remote)
}
err = wolfSSH_get_error(ssh);
} while ((err == WS_WANT_READ || err == WS_WANT_WRITE ||
err == WS_CHAN_RXD || err == WS_REKEYING) &&
err == WS_CHAN_RXD || err == WS_REKEYING ||
err == WS_WINDOW_FULL) &&
ret == WS_FATAL_ERROR);

if (ret != WS_SUCCESS) {
Expand Down Expand Up @@ -1452,14 +1487,52 @@ THREAD_RETURN WOLFSSH_THREAD sftpclient_test(void* args)

WFREE(workingDir, NULL, DYNAMIC_TYPE_TMP_BUFFER);
if (ret == WS_SUCCESS) {
if (wolfSSH_shutdown(ssh) != WS_SUCCESS) {
int rc;
rc = wolfSSH_get_error(ssh);
int err;
ret = wolfSSH_shutdown(ssh);

/* peer hung up, stop trying to shutdown */
if (ret == WS_SOCKET_ERROR_E) {
ret = 0;
}

err = wolfSSH_get_error(ssh);
if (err != WS_SOCKET_ERROR_E &&
(err == WS_WANT_READ || err == WS_WANT_WRITE)) {
int maxAttempt = 10; /* make 10 attempts max before giving up */
int attempt;

for (attempt = 0; attempt < maxAttempt; attempt++) {
ret = wolfSSH_worker(ssh, NULL);
err = wolfSSH_get_error(ssh);

/* peer successfully closed down gracefully */
if (ret == WS_CHANNEL_CLOSED) {
ret = 0;
break;
}

/* peer hung up, stop shutdown */
if (ret == WS_SOCKET_ERROR_E) {
ret = 0;
break;
}

if (err == WS_WANT_READ || err == WS_WANT_WRITE) {
/* Wanting read or wanting write. Clear ret. */
ret = 0;
}
else {
break;
}
}

if (rc != WS_SOCKET_ERROR_E && rc != WS_EOF)
printf("error with wolfSSH_shutdown()\n");
if (attempt == maxAttempt) {
printf("SFTP client gave up on gracefull shutdown,"
"closing the socket\n");
}
}
}

WCLOSESOCKET(sockFd);
wolfSSH_free(ssh);
wolfSSH_CTX_free(ctx);
Expand Down
57 changes: 54 additions & 3 deletions src/internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,40 @@ static void HandshakeInfoFree(HandshakeInfo* hs, void* heap)
}


/* RFC 4253 section 7.1, Once having sent SSH_MSG_KEXINIT the only messages
* that can be sent are 1-19 (except SSH_MSG_SERVICE_REQUEST and
* SSH_MSG_SERVICE_ACCEPT), 20-29 (except SSH_MSG_KEXINIT again), and 30-49
*/
INLINE static int IsMessageAllowedKeying(WOLFSSH *ssh, byte msg)
{
if (ssh->isKeying == 0) {
return 1;
}

/* case of service request or accept in 1-19 */
if (msg == MSGID_SERVICE_REQUEST || msg == MSGID_SERVICE_ACCEPT) {
WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg);
ssh->error = WS_REKEYING;
return 0;
}

/* case of resending SSH_MSG_KEXINIT */
if (msg == MSGID_KEXINIT) {
WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg);
ssh->error = WS_REKEYING;
return 0;
}

/* case where message id greater than 49 */
if (msg >= MSGID_USERAUTH_REQUEST) {
WLOG(WS_LOG_DEBUG, "Message ID %u not allowed by during rekeying", msg);
ssh->error = WS_REKEYING;
return 0;
}
return 1;
}


#ifndef NO_WOLFSSH_SERVER
INLINE static int IsMessageAllowedServer(WOLFSSH *ssh, byte msg)
{
Expand Down Expand Up @@ -673,8 +707,14 @@ INLINE static int IsMessageAllowedClient(WOLFSSH *ssh, byte msg)
#endif /* NO_WOLFSSH_CLIENT */


INLINE static int IsMessageAllowed(WOLFSSH *ssh, byte msg)
/* 'state' argument is for if trying to send a message or receive one.
* Returns 1 if allowed 0 if not allowed. */
INLINE static int IsMessageAllowed(WOLFSSH *ssh, byte msg, byte state)
{
if (state == WS_MSG_SEND && !IsMessageAllowedKeying(ssh, msg)) {
return 0;
}

#ifndef NO_WOLFSSH_SERVER
if (ssh->ctx->side == WOLFSSH_ENDPOINT_SERVER) {
return IsMessageAllowedServer(ssh, msg);
Expand Down Expand Up @@ -5905,7 +5945,6 @@ static int DoNewKeys(WOLFSSH* ssh, byte* buf, word32 len, word32* idx)
HandshakeInfoFree(ssh->handshake, ssh->ctx->heap);
ssh->handshake = NULL;
WLOG(WS_LOG_DEBUG, "Keying completed");

if (ssh->ctx->keyingCompletionCb)
ssh->ctx->keyingCompletionCb(ssh->keyingCompletionCtx);
}
Expand Down Expand Up @@ -9309,7 +9348,7 @@ static int DoPacket(WOLFSSH* ssh, byte* bufferConsumed)
return WS_OVERFLOW_E;
}

if (!IsMessageAllowed(ssh, msg)) {
if (!IsMessageAllowed(ssh, msg, WS_MSG_RECV)) {
return WS_MSGID_NOT_ALLOWED_E;
}

Expand Down Expand Up @@ -15649,6 +15688,12 @@ int SendChannelEof(WOLFSSH* ssh, word32 peerChannelId)
if (ssh == NULL)
ret = WS_BAD_ARGUMENT;

if (ret == WS_SUCCESS) {
if (!IsMessageAllowed(ssh, MSGID_CHANNEL_EOF, WS_MSG_SEND)) {
ret = WS_MSGID_NOT_ALLOWED_E;
}
}

if (ret == WS_SUCCESS) {
channel = ChannelFind(ssh, peerChannelId, WS_CHANNEL_ID_PEER);
if (channel == NULL)
Expand Down Expand Up @@ -16077,6 +16122,12 @@ int SendChannelWindowAdjust(WOLFSSH* ssh, word32 channelId,
if (ssh == NULL)
ret = WS_BAD_ARGUMENT;

if (ret == WS_SUCCESS) {
if (!IsMessageAllowed(ssh, MSGID_CHANNEL_WINDOW_ADJUST, WS_MSG_SEND)) {
ret = WS_MSGID_NOT_ALLOWED_E;
}
}

channel = ChannelFind(ssh, channelId, WS_CHANNEL_ID_SELF);
if (channel == NULL) {
WLOG(WS_LOG_DEBUG, "Invalid channel");
Expand Down
14 changes: 12 additions & 2 deletions src/ssh.c
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,11 @@ int wolfSSH_stream_read(WOLFSSH* ssh, byte* buf, word32 bufSz)
return WS_ERROR;
}

if (ssh->isKeying) {
ssh->error = WS_REKEYING;
return WS_FATAL_ERROR;
}

inputBuffer = &ssh->channelList->inputBuffer;
ssh->error = WS_SUCCESS;

Expand Down Expand Up @@ -1164,7 +1169,7 @@ int wolfSSH_stream_read(WOLFSSH* ssh, byte* buf, word32 bufSz)
}

/* update internal input buffer based on data read */
if (ret == WS_SUCCESS) {
if (ret == WS_SUCCESS && !ssh->isKeying) {
int n;

n = min(bufSz, inputBuffer->length - inputBuffer->idx);
Expand Down Expand Up @@ -1196,7 +1201,7 @@ int wolfSSH_stream_send(WOLFSSH* ssh, byte* buf, word32 bufSz)

if (ssh->isKeying) {
ssh->error = WS_REKEYING;
return WS_REKEYING;
return WS_FATAL_ERROR;
}

bytesTxd = SendChannelData(ssh, ssh->channelList->channel, buf, bufSz);
Expand Down Expand Up @@ -2901,6 +2906,11 @@ int wolfSSH_ChannelRead(WOLFSSH_CHANNEL* channel, byte* buf, word32 bufSz)
if (channel == NULL || buf == NULL || bufSz == 0)
return WS_BAD_ARGUMENT;

if (channel->ssh->isKeying) {
channel->ssh->error = WS_REKEYING;
return WS_REKEYING;
}

bufSz = _ChannelRead(channel, buf, bufSz);

WLOG(WS_LOG_DEBUG, "Leaving wolfSSH_ChannelRead(), bytesRxd = %d",
Expand Down
Loading