summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbeck <beck@openbsd.org>2019-11-13 04:10:38 +0000
committerbeck <beck@openbsd.org>2019-11-13 04:10:38 +0000
commit16e245ae7f37a55f059145ad7a4864aac65be61e (patch)
tree87d7128c2940cf9485093f8ce1898382b611eff3
parentunbreak ramdisks (diff)
downloadwireguard-openbsd-16e245ae7f37a55f059145ad7a4864aac65be61e.tar.xz
wireguard-openbsd-16e245ae7f37a55f059145ad7a4864aac65be61e.zip
refactor the nc pool loop to not shut down the socket early, and
to handle tls_shutdown correctly if using TLS, doing tls_shutdown correctly if we are using the -N flag ok sthen@
-rw-r--r--usr.bin/nc/netcat.c100
1 files changed, 64 insertions, 36 deletions
diff --git a/usr.bin/nc/netcat.c b/usr.bin/nc/netcat.c
index a53fe1c4cd6..1dc95e9f360 100644
--- a/usr.bin/nc/netcat.c
+++ b/usr.bin/nc/netcat.c
@@ -1,4 +1,4 @@
-/* $OpenBSD: netcat.c,v 1.210 2019/11/04 17:33:28 millert Exp $ */
+/* $OpenBSD: netcat.c,v 1.211 2019/11/13 04:10:38 beck Exp $ */
/*
* Copyright (c) 2001 Eric Jackson <ericj@monkey.org>
* Copyright (c) 2015 Bob Beck. All rights reserved.
@@ -1103,13 +1103,14 @@ void
readwrite(int net_fd, struct tls *tls_ctx)
{
struct pollfd pfd[4];
+ int gone[4] = { 0 };
int stdin_fd = STDIN_FILENO;
int stdout_fd = STDOUT_FILENO;
unsigned char netinbuf[BUFSIZE];
size_t netinbufpos = 0;
unsigned char stdinbuf[BUFSIZE];
size_t stdinbufpos = 0;
- int n, num_fds;
+ int n, num_fds, shutdown_netin, shutdown_netout;
ssize_t ret;
/* don't read from stdin if requested */
@@ -1132,17 +1133,20 @@ readwrite(int net_fd, struct tls *tls_ctx)
pfd[POLL_STDOUT].fd = stdout_fd;
pfd[POLL_STDOUT].events = 0;
+ /* used to indicate we wish to shut down the network socket */
+ shutdown_netin = shutdown_netout = 0;
+
while (1) {
/* both inputs are gone, buffers are empty, we are done */
- if (pfd[POLL_STDIN].fd == -1 && pfd[POLL_NETIN].fd == -1 &&
+ if (gone[POLL_STDIN] && gone[POLL_NETIN] &&
stdinbufpos == 0 && netinbufpos == 0)
return;
/* both outputs are gone, we can't continue */
- if (pfd[POLL_NETOUT].fd == -1 && pfd[POLL_STDOUT].fd == -1)
+ if (gone[POLL_NETOUT] && gone[POLL_STDOUT])
return;
/* listen and net in gone, queues empty, done */
- if (lflag && pfd[POLL_NETIN].fd == -1 &&
- stdinbufpos == 0 && netinbufpos == 0)
+ if (lflag && gone[POLL_NETIN] && stdinbufpos == 0
+ && netinbufpos == 0)
return;
/* help says -i is for "wait between lines sent". We read and
@@ -1151,6 +1155,12 @@ readwrite(int net_fd, struct tls *tls_ctx)
if (iflag)
sleep(iflag);
+ /* If it's gone, take it away from poll */
+ for (n = 0; n < 4; n++) {
+ if (gone[n])
+ pfd[n].events = pfd[n].revents = 0;
+ }
+
/* poll */
num_fds = poll(pfd, 4, timeout);
@@ -1165,36 +1175,36 @@ readwrite(int net_fd, struct tls *tls_ctx)
/* treat socket error conditions */
for (n = 0; n < 4; n++) {
if (pfd[n].revents & (POLLERR|POLLNVAL)) {
- pfd[n].fd = -1;
+ gone[n] = 1;
}
}
/* reading is possible after HUP */
if (pfd[POLL_STDIN].events & POLLIN &&
pfd[POLL_STDIN].revents & POLLHUP &&
!(pfd[POLL_STDIN].revents & POLLIN))
- pfd[POLL_STDIN].fd = -1;
+ gone[POLL_STDIN] = 1;
if (pfd[POLL_NETIN].events & POLLIN &&
pfd[POLL_NETIN].revents & POLLHUP &&
!(pfd[POLL_NETIN].revents & POLLIN))
- pfd[POLL_NETIN].fd = -1;
+ gone[POLL_NETIN] = 1;
if (pfd[POLL_NETOUT].revents & POLLHUP) {
if (Nflag)
- shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
- pfd[POLL_NETOUT].fd = -1;
+ shutdown_netout = 1;
+ gone[POLL_NETOUT] = 1;
}
- /* if HUP, stop watching stdout */
- if (pfd[POLL_STDOUT].revents & POLLHUP)
- pfd[POLL_STDOUT].fd = -1;
/* if no net out, stop watching stdin */
- if (pfd[POLL_NETOUT].fd == -1)
- pfd[POLL_STDIN].fd = -1;
+ if (gone[POLL_NETOUT])
+ gone[POLL_STDIN] = 1;
+
+ /* if stdout HUP's, stop watching stdout */
+ if (pfd[POLL_STDOUT].revents & POLLHUP)
+ gone[POLL_STDOUT] = 1;
/* if no stdout, stop watching net in */
- if (pfd[POLL_STDOUT].fd == -1) {
- if (pfd[POLL_NETIN].fd != -1)
- shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
- pfd[POLL_NETIN].fd = -1;
+ if (gone[POLL_STDOUT]) {
+ shutdown_netin = 1;
+ gone[POLL_NETIN] = 1;
}
/* try to read from stdin */
@@ -1206,7 +1216,7 @@ readwrite(int net_fd, struct tls *tls_ctx)
else if (ret == TLS_WANT_POLLOUT)
pfd[POLL_STDIN].events = POLLOUT;
else if (ret == 0 || ret == -1)
- pfd[POLL_STDIN].fd = -1;
+ gone[POLL_STDIN] = 1;
/* read something - poll net out */
if (stdinbufpos > 0)
pfd[POLL_NETOUT].events = POLLOUT;
@@ -1223,7 +1233,7 @@ readwrite(int net_fd, struct tls *tls_ctx)
else if (ret == TLS_WANT_POLLOUT)
pfd[POLL_NETOUT].events = POLLOUT;
else if (ret == -1)
- pfd[POLL_NETOUT].fd = -1;
+ gone[POLL_NETOUT] = 1;
/* buffer empty - remove self from polling */
if (stdinbufpos == 0)
pfd[POLL_NETOUT].events = 0;
@@ -1240,17 +1250,15 @@ readwrite(int net_fd, struct tls *tls_ctx)
else if (ret == TLS_WANT_POLLOUT)
pfd[POLL_NETIN].events = POLLOUT;
else if (ret == -1)
- pfd[POLL_NETIN].fd = -1;
+ gone[POLL_NETIN] = 1;
/* eof on net in - remove from pfd */
if (ret == 0) {
- shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
- pfd[POLL_NETIN].fd = -1;
+ gone[POLL_NETIN] = 1;
}
if (recvlimit > 0 && ++recvcount >= recvlimit) {
- if (pfd[POLL_NETIN].fd != -1)
- shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
- pfd[POLL_NETIN].fd = -1;
- pfd[POLL_STDIN].fd = -1;
+ shutdown_netin = 1;
+ gone[POLL_NETIN] = 1;
+ gone[POLL_STDIN] = 1;
}
/* read something - poll stdout */
if (netinbufpos > 0)
@@ -1272,7 +1280,7 @@ readwrite(int net_fd, struct tls *tls_ctx)
else if (ret == TLS_WANT_POLLOUT)
pfd[POLL_STDOUT].events = POLLOUT;
else if (ret == -1)
- pfd[POLL_STDOUT].fd = -1;
+ gone[POLL_STDOUT] = 1;
/* buffer empty - remove self from polling */
if (netinbufpos == 0)
pfd[POLL_STDOUT].events = 0;
@@ -1282,14 +1290,34 @@ readwrite(int net_fd, struct tls *tls_ctx)
}
/* stdin gone and queue empty? */
- if (pfd[POLL_STDIN].fd == -1 && stdinbufpos == 0) {
- if (pfd[POLL_NETOUT].fd != -1 && Nflag)
- shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
- pfd[POLL_NETOUT].fd = -1;
+ if (gone[POLL_STDIN] && stdinbufpos == 0) {
+ if (Nflag) {
+ shutdown_netin = 1;
+ shutdown_netout = 1;
+ }
+ gone[POLL_NETOUT] = 1;
}
/* net in gone and queue empty? */
- if (pfd[POLL_NETIN].fd == -1 && netinbufpos == 0) {
- pfd[POLL_STDOUT].fd = -1;
+ if (gone[POLL_NETIN] && netinbufpos == 0) {
+ if (Nflag) {
+ shutdown_netin = 1;
+ shutdown_netout = 1;
+ }
+ gone[POLL_STDOUT] = 1;
+ }
+
+ /* call tls_close if any part of the network socket is closing */
+ if ((shutdown_netin || shutdown_netout) && usetls) {
+ timeout_tls(pfd[POLL_NETIN].fd, tls_ctx, tls_close);
+ shutdown_netout = shutdown_netin = 1;
+ }
+ if (shutdown_netin) {
+ shutdown(pfd[POLL_NETIN].fd, SHUT_RD);
+ gone[POLL_NETIN] = 1;
+ }
+ if (shutdown_netout) {
+ shutdown(pfd[POLL_NETOUT].fd, SHUT_WR);
+ gone[POLL_NETOUT] = 1;
}
}
}