#include #include #include #include #include #include #include #include #include "pg.h" static int encode(uint8_t *, uint8_t **, int); static int decode(uint8_t *, uint8_t **, int); int msg_send(int fd, int type, char dest, uint8_t arg, int datalen, uint8_t *data) { uint8_t *header = NULL, *data_encoded = NULL; int len, data_encoded_len; log_debug("msg_send"); len = writebuf((uint8_t *)MSG_MAGIC, fd, sizeof(MSG_MAGIC)); if (len < 0) goto err; header = msg_pack_header(type, conf.me, dest, arg, datalen); if (!header) goto err; len = writebuf(header, fd, MSG_HEADER_SIZE_ENCODED); if (len < 0) goto err; if (datalen > 0 && data) { data_encoded_len = encode(data, &data_encoded, datalen); if (!data_encoded || data_encoded_len <= 0) goto err; len = writebuf(data, fd, data_encoded_len); if (len < 0) goto err; free(data_encoded); } free(header); return len; err: if (header) free(header); if (data_encoded) free(data_encoded); log_warn("msg_send: err"); return -1; } int msg_send_from_fd(int fd, int type, char dest, uint8_t arg, int datafd) { uint8_t buf[BUFMAX]; uint8_t *encoded; int buflen, encodedlen; log_debug("msg_send_from_fd"); buflen = read(datafd, buf, sizeof(buf)); if (buflen < 0) return -1; encodedlen = encode(buf, &encoded, buflen); if (encodedlen < 0) return -1; msg_send(fd, type, dest, arg, encodedlen, encoded); free(encoded); return 0; } uint8_t * msg_pack_header(int type, char orig, char dest, uint8_t arg, int datalen) { struct msg_header h; uint8_t *data; int len; if (datalen > MSG_DATALEN_MAX) { log_warn("msg_pack_header: datalen too big %d", datalen); return NULL; } h.version = MSG_VERSION; h.type = type; h.orig = orig; h.dest = dest; h.arg = arg; h.datalen = datalen; len = encode((uint8_t *)&h, &data, MSG_HEADER_SIZE); if (len != MSG_HEADER_SIZE_ENCODED) { log_warn("msg_pack_header: encoded header has invalid size %d !" "This should NOT happend, as MSG_HEADER_SIZE_ENCODED should be fixed !", len); return NULL; } return data; } #define ERR(msg...) \ { \ log_warn(msg); \ return NULL; \ } struct msg_header * msg_unpack_header(uint8_t *data) { uint8_t *data_decoded; struct msg_header *h; int len; /* keep in sync with pg.h MSG enum */ int msg_client[] = {MSG_DATA, MSG_OK, MSG_ERR}; int msg_server[] = {MSG_INIT, MSG_INIT_ASYNC, MSG_KILL, MSG_EXEC, MSG_READ, MSG_WRITE, MSG_DATA}; #define MSG_CLIENT_COUNT 3 #define MSG_SERVER_COUNT 7 len = decode(data, &data_decoded, MSG_HEADER_SIZE_ENCODED); if (!data_decoded) ERR("failed to decode message header") h = (struct msg_header *)data_decoded; if (h->version <= 0 || h->version > 255) ERR("msg_unpack_header: invalid version %d", h->version) if (h->version != MSG_VERSION) ERR("msg_unpack_header: incompatible version %d", h->version) if (conf.server && (intab(msg_server, h->type, MSG_SERVER_COUNT) < 0)) ERR("msg_unpack_header: type %d incorrect for server", h->type) if (!conf.server && (intab(msg_client, h->type, MSG_CLIENT_COUNT) < 0)) ERR("msg_unpack_header: type %d incorrect for client", h->type) if (!isalnum(h->orig)) ERR("msg_unpack_header: non alphanumeric originator %c", h->orig) if (h->orig == conf.me) ERR("msg_unpack_header: message pretends to come from me, ignoring") if (!isalnum(h->dest)) ERR("msg_unpack_header: non alphanumeric destination %c", h->dest) // XXX check dest on routes ? if (h->datalen < 0) ERR("msg_unpack_header: datalen < 0") if (h->datalen > MSG_DATALEN_MAX) ERR("msg_unpack_header: datalen too big %d", h->datalen) return h; } int msg_read_data(int fd, uint8_t **out, int len) { uint8_t *buf; int buflen, decodedlen; buflen = readbuf(fd, &buf, len); if (buflen < 0) return -1; decodedlen = decode(buf, out, buflen); if (decodedlen < 0) log_warn("msg_read_data: decoding failed"); return decodedlen; } /* XXX rename to msg_fw_decode ? */ int msg_read_data_to_fd(int ifd, int fd, int len) { uint8_t *buf; int buflen, writelen; buflen = msg_read_data(ifd, &buf, len); if (buflen < 0) goto err; writelen = writebuf(buf, fd, buflen); if (writelen < 0) goto err; free(buf); return writelen; err: if (buf) free(buf); return -1; } static int encode(uint8_t *arg, uint8_t **out, int len) { uint8_t *res; int outlen; outlen = len; *out = NULL; res = xmalloc(sizeof(uint8_t) * outlen); if (!res) goto err; memcpy(res, arg, len); *out = res; return outlen; err: log_warn("encode failed"); return -1; } static int decode(uint8_t *arg, uint8_t **out, int len) { uint8_t *res; int outlen; outlen = len; *out = NULL; res = xmalloc(sizeof(uint8_t) * outlen); if (!res) goto err; memcpy(res, arg, len); *out = res; return outlen; err: log_warn("decode failed"); return -1; }