/*
 * propagate: program entry point plus the poll()-based event loop that
 * reads protocol messages and dispatches them to the cmd_* handlers below.
 *
 * XXX we should take inspiration from netcat:
 *   http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/nc/
 *   see readwrite(), atomicio()
 */

/*
 * NOTE(review): the header names of the nine system includes below are
 * missing in this copy (the "<...>" parts look like they were stripped as
 * markup).  Code in this file needs at least getopt/optind/optarg,
 * poll/struct pollfd, struct sockaddr_storage/socklen_t, printf, exit,
 * free/realloc, memcmp/strlen and uint8_t -- restore the real list from
 * version control.
 */
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include "pg.h"

extern char *__progname;

static int usage();
static void process(void);
static int pfds_create(struct pollfd *);
static int cmd_init(struct conn *, int, struct msg_header *);
static int cmd_exec(struct conn *, int, struct msg_header *);
static int cmd_kill(struct conn *, int, struct msg_header *);
static int cmd_read(struct conn *, int, struct msg_header *);
static int cmd_write(struct conn *, int, struct msg_header *);
static int cmd_data_conn(struct conn *, int, struct msg_header *);
static int cmd_data_route(struct route *, int, struct msg_header *);
static int cmd_ok(struct route *, int, struct msg_header *);
static int cmd_err(struct route *, int, struct msg_header *);

/*
 * Entry point.
 *
 * Server mode (-l): this host is 'B' and a LISTENER_UNIX listener is added
 * on "/tmp/propagate_sock"; no positional argument is accepted.
 * Client mode: this host is 'A', a ROUTE_PROC route to 'B' is added using
 * "nc 127.0.0.1 3333", and the single positional argument `command` is sent
 * as MSG_EXEC to the destination (-t, default 'B').
 * Both modes then enter process(), which does not return.
 *
 * Options: -d sets logstdout (passed to log_init) and clears deamonize,
 * -h prints usage, -l selects server mode, -t sets the destination host,
 * -v raises the log level (repeatable).
 */
int main(int argc, char *argv[])
{
	int opt;
	int loglevel = 1, logstdout = 0;
	int deamonize = 0;
	int server = 0;
	char dest = 0;
	char *cmd = NULL;

	/*
	 * NOTE(review): 't' has no ':' in the optstring, so getopt() does not
	 * set optarg for -t and `optarg[0]` below reads an unset (typically
	 * NULL) pointer.  The optstring should be "dhlt:v".
	 */
	while ((opt = getopt(argc, argv, "dhltv")) != -1) {
		switch (opt) {
		case 'd':
			logstdout = 1;
			deamonize = 0;
			break;
		case 'h':
			usage();
			break;
		case 'l':
			server = 1;
			break;
		case 't':
			dest = optarg[0];
			break;
		case 'v':
			loglevel++;
			break;
		default:
			usage();
			break;
		}
	}
	argc -= optind;
	argv += optind;

	/*
	 * Server takes no positional argument, client takes exactly one.
	 * NOTE(review): conf.client_dest is only assigned below (client mode),
	 * so unless it is initialized elsewhere the third test never rejects
	 * "-l -t X"; it probably meant `dest`.
	 */
	if ((server && argc != 0) || (!server && argc != 1) || (server && conf.client_dest))
		usage();
	if (!server) {
		cmd = argv[0];
		if (dest == 0)
			dest = 'B';
		conf.client_dest = dest;
	}

	log_init(loglevel, logstdout);
	log_warn("** Starting propagate v%.1f", VERSION);
	log_warn("** using message version %d", MSG_VERSION);
	log_warn("** loglevel %d", loglevel);

	conf.server = server;
	LIST_INIT(&routes);
	routes_count = 0;
	LIST_INIT(&listeners);
	listeners_count = 0;

	/* XXX load conf -- until then the endpoints below are hard-coded */
	if (server) {
		conf.me = 'B';
		listener_add(LISTENER_UNIX, "/tmp/propagate_sock");
	} else {
		conf.me = 'A';
		route_add('B', ROUTE_PROC, "nc 127.0.0.1 3333", 0, NULL, 0);
	}

	if (server)
		log_info("starting server %c", conf.me);
	else
		log_info("running client %c, dest %c, command %s", conf.me, conf.client_dest, cmd);
	/* NOTE(review): nothing in this file sets deamonize to 1, so this is dead. */
	if (deamonize)
		log_info("XXX deamonize not implemented yet");

	/* Client: ask the destination host to execute `command`. */
	if (!server)
		send_cmd(conf.client_dest, MSG_EXEC, 0, (uint8_t *)cmd, strlen(cmd));
	process();
	return 0;
}

/* Print the command-line synopsis and exit(1); never returns to the caller. */
static int usage()
{
	printf("Usage: %s [-dhv] [-t destination] command\n", __progname);
	printf("       %s -l [-dhv]\n", __progname);
	exit(1);
}

/*
 * Event loop; never returns.
 *
 * Each iteration rebuilds the pollfd set with pfds_create(), polls with
 * POLL_TIMEOUT, then handles ready descriptors:
 *  - command output on a connection (the branch condition is in the
 *    damaged span below; presumably a ready c->exec.fd[0]): buffered with
 *    listener_conn_exec_bufferize() when c->async, otherwise forwarded as
 *    MSG_DATA to c->orig;
 *  - stdin (client only): buffered with route_bufferize() when the route is
 *    async, otherwise forwarded as MSG_DATA to the route's destination;
 *  - anything else is read as a protocol message (magic, then encoded
 *    header) and is either forwarded with route_fw() when hdr->dest is not
 *    conf.me, or dispatched by hdr->type through cmd_conn[] (fd is a known
 *    connection) or cmd_route[] (sender is a known route).
 */
static void process(void)
{
	struct listener *l;
	struct conn *c;
	struct route *r;
	uint8_t *buf;
	struct msg_header *hdr;
	struct sockaddr_storage cliaddr;
	socklen_t slen = sizeof(cliaddr);
	int len, n, fd, cfd;
	struct pollfd pfds[POLLER_MAX];
	int pfds_count;
	/*
	 * keep in sync with pg.h MSG enum
	 * Handler tables indexed by hdr->type; NULL marks a type that has no
	 * handler on that side (conn side: last two slots; route side: first
	 * six slots).
	 */
	int (*cmd_conn[MSG_MAX+1]) (struct conn *, int, struct msg_header *) = {
		&cmd_init, &cmd_init, &cmd_kill, &cmd_exec, &cmd_read, &cmd_write,
		&cmd_data_conn, NULL, NULL
	};
	int (*cmd_route[MSG_MAX+1]) (struct route *, int, struct msg_header *) = {
		NULL, NULL, NULL, NULL, NULL, NULL, &cmd_data_route, &cmd_ok, &cmd_err
	};

	/* OLD TODO
	 * fork frontend and open pipe_fe
	 * poll()
	 * - pipe_fe input -> write stdout
	 * - stdin -> append send_buf
	 * sigalarm() every second
	 * - MSG_READ and write on stdout
	 * - if size(send_buf) > 0: MSG_WRITE send_buf
	 * handle network errors / retransmission in fe
	 * MSG_[READ|WRITE] do not expect MSG_OK, but warn on MSG_ERROR
	 * MSG_OK only for MSG_INIT and MSG_EXEC
	 */
	for (;;) {
		pfds_count = pfds_create(pfds);
		n = poll(pfds, pfds_count, POLL_TIMEOUT);
		log_tmp("end of poll, %d fds on %d", n, pfds_count);
		if (n < 0) {
			log_warn("polling error"); // XXX fatal ?
			continue;
		}
		/*
		 * NOTE(review): THIS COPY IS DAMAGED.  Text was lost between
		 * "for (n=0; n" and "async)" below -- apparently everything
		 * from a '<' up to the '>' of "c->async" was stripped, like the
		 * #include names.  The missing part presumably looped over
		 * pfds[0..pfds_count), set `fd`, accepted on listener sockets
		 * (l, cliaddr, slen and cfd are otherwise unused) and tested for
		 * command output on a connection; it also contained the opening
		 * braces matched by the closing ones at the end of this
		 * function.  Restore it from version control -- do not guess.
		 */
		for (n=0; nasync)
				listener_conn_exec_bufferize(c, fd);
			else
				msg_send_from_fd(c->fd, MSG_DATA, c->orig, 0, fd);
		} else if (fd == fileno(stdin)) {
			log_debug("got stdin data");
			/*
			 * NOTE(review): unlike the dispatch code below, the
			 * route_find() result is not checked for NULL here.
			 */
			r = route_find(conf.client_dest);
			if (r->proc.async)
				route_bufferize(r, fd);
			else
				msg_send_from_fd(r->proc.fd[1], MSG_DATA, r->dest, 0, fd);
		} else {
			log_debug("got command");
			/*
			 * NOTE(review): if readbuf() can return a negative error
			 * (cmd_exec treats <= 0 as failure), the comparison with
			 * the unsigned sizeof converts len to a huge value and the
			 * error slips through to memcmp().  "%4x" is also handed a
			 * pointer (buf), which is undefined behaviour.
			 */
			len = readbuf(fd, &buf, sizeof(MSG_MAGIC));
			if (len < sizeof(MSG_MAGIC)) {
				log_info("Magic number too short");
				continue;
			}
			if (memcmp(buf, MSG_MAGIC, sizeof(MSG_MAGIC))) {
				log_info("Invalid magic number %4x", buf);
				continue;
			}
			len = readbuf(fd, &buf, MSG_HEADER_SIZE_ENCODED);
			if (len < MSG_HEADER_SIZE_ENCODED) {
				log_warn("Message header too short");
				continue;
			}
			hdr = msg_unpack_header(buf);
			if (!hdr) {
				log_warn("Invalid message header");
				continue;
			}
			/* Not for us: relay the message towards hdr->dest. */
			if (hdr->dest != conf.me) {
				route_fw(fd, hdr, buf);
				free(hdr);
				continue;
			}
			/*
			 * NOTE(review): unless msg_unpack_header() already
			 * validates it, hdr->type is not checked against MSG_MAX,
			 * and the tables have NULL slots, so a peer can make us
			 * index out of bounds or call through NULL.  Each `break`
			 * also skips free(hdr) below (leak) and stops scanning
			 * pfds for this round -- the latter may be intentional,
			 * since handlers can change the lists pfds was built from.
			 */
			if ((c = listener_conn_find(fd))) {
				cmd_conn[hdr->type](c, fd, hdr);
				break;
			} else if ((r = route_find(hdr->orig))) {
				cmd_route[hdr->type](r, fd, hdr);
				break;
			} else
				log_warn("host %c has no reference ! ignoring command...", hdr->orig);
			free(hdr);
		}
	}
	}
}

/*
 * Fill pfds with every descriptor to watch for POLLIN and return how many
 * were added: stdin (client only), r->proc.fd[0] of each ROUTE_PROC route,
 * each listener socket, each connection socket and, when a command is
 * running on a connection (c->exec.cmd set), its c->exec.fd[0].
 *
 * NOTE(review): ADDFD never checks count against POLLER_MAX (the size of the
 * array process() passes in), so enough routes/connections overflow pfds.
 */
static int pfds_create(struct pollfd *pfds)
{
	struct route *r;
	struct listener *l;
	struct conn *c;
	int count = 0;

/* Append one POLLIN entry; expands to a bare block, so no trailing ';'. */
#define ADDFD(myfd) { \
	pfds[count].fd = myfd; \
	pfds[count].events = POLLIN; \
	count++; \
}

	if (!conf.server)
		ADDFD(fileno(stdin))
	LIST_FOREACH(r, &routes, entry) {
		switch (r->type) {
		case ROUTE_PROC:
			ADDFD(r->proc.fd[0])
		}
	}
	LIST_FOREACH(l, &listeners, entry) {
		ADDFD(l->sock)
		LIST_FOREACH(c, &l->conns, entry) {
			ADDFD(c->fd)
			if (c->exec.cmd)
				ADDFD(c->exec.fd[0])
		}
	}
	return count;
}

/*
 * MSG_INIT / MSG_INIT_ASYNC received on connection c.
 *
 * Replies MSG_ERR and returns -1 if c is already CONN_READY.  Otherwise,
 * if the sender already has a connection on another fd, calls
 * listener_conn_move(oldc, c); then records the sender in c->orig, marks
 * c CONN_READY, sets c->async from the message type, replies MSG_OK and
 * returns 0.
 */
static int cmd_init(struct conn *c, int fd, struct msg_header *hdr)
{
	struct conn *oldc;

	log_debug("received INIT from %c :)", hdr->orig);
	if (c->state == CONN_READY) {
		log_warn("received INIT on an already open connection !");
		send_cmd(hdr->orig, MSG_ERR, 0, NULL, 0);
		return -1;
	}
	/* Check whether this client already has a connection open on
	 * another fd. */
	oldc = listener_conn_find_orig(hdr->orig);
	if (oldc)
		listener_conn_move(oldc, c);
	c->orig = hdr->orig;
	c->state = CONN_READY;
	switch (hdr->type) {
	case MSG_INIT:
		c->async = 0;
		break;
	case MSG_INIT_ASYNC:
		c->async = 1;
		break;
	}
	send_cmd(hdr->orig, MSG_OK, 0, NULL, 0);
	return 0;
}

/*
 * MSG_EXEC: read hdr->datalen bytes of command line from fd, split it on
 * spaces with explode() and start it on c with listener_conn_exec().
 * Replies MSG_OK and returns 0 on success, MSG_ERR and -1 on failure.
 *
 * NOTE(review): cmd and argv are freed only on the error path -- confirm
 * listener_conn_exec() takes ownership on success, otherwise they leak.
 * There is also no check that c->state == CONN_READY (cf. cmd_init).
 */
static int cmd_exec(struct conn *c, int fd, struct msg_header *hdr)
{
	char *cmd = NULL;
	int cmdlen;
	char **argv = NULL;
	int argc;

	cmdlen = readbuf(fd, (uint8_t **)&cmd, hdr->datalen);
	if (cmdlen <= 0 || !cmd)
		goto err;
	argv = explode(cmd, cmdlen, " ", &argc);
	if (!argv)
		goto err;
	if (listener_conn_exec(c, argv[0], argv) < 0)
		goto err;
	send_cmd(hdr->orig, MSG_OK, 0, NULL, 0);
	return 0;

err:
	log_debug("exec failed");
	if (cmd)
		free(cmd);
	if (argv)
		free(argv);
	send_cmd(hdr->orig, MSG_ERR, 0, NULL, 0);
	return -1;
}

/* Kill request: call listener_conn_exec_kill(c), then always reply MSG_OK. */
static int cmd_kill(struct conn *c, int fd, struct msg_header *hdr)
{
	listener_conn_exec_kill(c);
	send_cmd(hdr->orig, MSG_OK, 0, NULL, 0);
	return 0;
}

/*
 * Read request (async connections only): send the buffered command output
 * c->exec.async_writebuf to c->orig as MSG_DATA, then shrink the buffer by
 * send_cmd()'s return value.  Replies MSG_ERR and returns -1 when no
 * command is running, c is not async, or send_cmd() fails.
 *
 * NOTE(review): this assumes send_cmd() returns the number of payload bytes
 * sent -- confirm.  Even so, realloc() to a smaller size keeps the *first*
 * bytes, i.e. the ones just sent: on a partial send the unsent tail is lost
 * and sent data is replayed; the tail should be memmove()d to the front
 * first.  When everything was sent this is realloc(p, 0) (implementation-
 * defined), and a NULL return on failure overwrites and leaks the buffer.
 */
static int cmd_read(struct conn *c, int fd, struct msg_header *hdr)
{
	int len;

	if (!c->exec.cmd) {
		log_warn("cmd_read: no exec in progress !");
		goto err;
	}
	if (c->async == 0) {
		log_warn("cmd_read: not in async mode !");
		goto err;
	}
	len = send_cmd(c->orig, MSG_DATA, 0, c->exec.async_writebuf, c->exec.async_writebuf_size);
	if (len < 0)
		goto err;
	c->exec.async_writebuf_size -= len;
	c->exec.async_writebuf = realloc(c->exec.async_writebuf, c->exec.async_writebuf_size);
	return 0;

err:
	send_cmd(c->orig, MSG_ERR, 0, NULL, 0);
	return -1;
}

/*
 * Write request (async connections only): copy hdr->datalen payload bytes
 * from fd to the running command's c->exec.fd[1].  Returns 0, or replies
 * MSG_ERR and returns -1.
 *
 * NOTE(review): the two early error paths do not consume the payload, so
 * presumably those hdr->datalen bytes are later misread as the next
 * message's magic/header.
 */
static int cmd_write(struct conn *c, int fd, struct msg_header *hdr)
{
	if (!c->exec.cmd) {
		log_warn("cmd_write: no exec in progress !");
		goto err;
	}
	if (c->async == 0) {
		log_warn("cmd_write: not in async mode !");
		goto err;
	}
	if (msg_read_data_to_fd(fd, c->exec.fd[1], hdr->datalen) < 0) {
		goto err;
	}
	return 0;

err:
	send_cmd(c->orig, MSG_ERR, 0, NULL, 0);
	return -1;
}

/*
 * MSG_DATA from a connection: copy hdr->datalen payload bytes from fd to
 * c->exec.fd[1].  Returns 0, or -1 if the copy fails (no reply is sent).
 *
 * NOTE(review): unlike cmd_write, c->exec.cmd is not checked, so data that
 * arrives while no command runs is written to whatever exec.fd[1] holds.
 */
static int cmd_data_conn(struct conn *c, int fd, struct msg_header *hdr)
{
	log_debug("received DATA from conn %c !", hdr->orig);
	if (msg_read_data_to_fd(fd, c->exec.fd[1], hdr->datalen) < 0) {
		log_warn("cmd_data: recvwrite failed");
		return -1;
	}
	return 0;
}

/*
 * MSG_DATA from a route: copy hdr->datalen payload bytes from fd to stdout.
 * Returns 0, or -1 if the copy fails.
 */
static int cmd_data_route(struct route *r, int fd, struct msg_header *hdr)
{
	log_debug("received DATA from route %c !", hdr->orig);
	if (msg_read_data_to_fd(fd, fileno(stdout), hdr->datalen) < 0) {
		log_warn("cmd_data: recvwrite failed");
		return -1;
	}
	return 0;
}

/* MSG_OK from a route: only logged. */
static int cmd_ok(struct route *r, int fd, struct msg_header *hdr)
{
	log_info("received OK from %c !", hdr->orig);
	return 0;
}

/* MSG_ERR from a route: only logged. */
static int cmd_err(struct route *r, int fd, struct msg_header *hdr)
{
	log_info("received ERROR from %c !", hdr->orig);
	return 0;
}