aboutsummaryrefslogtreecommitdiffstats
path: root/broken/propagate/src/msg.c
blob: 568ff44c6dfc9be17934b1d188221b1e80cfa752 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdint.h>

#include <ctype.h>
#include <unistd.h>
#include <sys/select.h>
#include <sys/types.h>

#include "pg.h"

/*
 * Payload codec helpers, defined at the end of this file. Each takes `len`
 * input bytes at `arg`, stores a newly allocated result buffer in *out and
 * returns the result length (or -1 on failure).
 */
static int encode(uint8_t *arg, uint8_t **out, int len);
static int decode(uint8_t *arg, uint8_t **out, int len);

/*
 * Send one framed message on fd: MSG_MAGIC, then the encoded header, then
 * (only when datalen > 0 and data is non-NULL) the encoded payload.
 *
 * fd       descriptor to write to
 * type     message type, packed into the header
 * dest     destination identifier, packed into the header
 * arg      message-specific argument byte, packed into the header
 * datalen  length in bytes of the raw (not yet encoded) payload
 * data     raw payload; it is encoded here before being written
 *
 * Returns the result of the last writebuf() call (the payload write, or the
 * header write when no payload is sent), or -1 on any error.
 */
int
msg_send(int fd, int type, char dest, uint8_t arg, int datalen, uint8_t *data) {
    uint8_t *header = NULL, *data_encoded = NULL;
    int len, data_encoded_len;

    log_debug("msg_send");

    len = writebuf((uint8_t *)MSG_MAGIC, fd, sizeof(MSG_MAGIC));
    if (len < 0)
        goto err;
    /* NOTE(review): the header advertises the raw datalen while
     * data_encoded_len bytes are written below; the two agree only while
     * encode() is length-preserving — confirm against the reader side. */
    header = msg_pack_header(type, conf.me, dest, arg, datalen);
    if (!header)
        goto err;
    len = writebuf(header, fd, MSG_HEADER_SIZE_ENCODED);
    if (len < 0)
        goto err;

    /* NOTE(review): datalen > 0 with data == NULL sends a header announcing
     * a payload that never follows — confirm no caller relies on this. */
    if (datalen > 0 && data) {
        data_encoded_len = encode(data, &data_encoded, datalen);
        if (!data_encoded || data_encoded_len <= 0)
            goto err;
        /* Write the encoded payload, not the raw input buffer. */
        len = writebuf(data_encoded, fd, data_encoded_len);
        if (len < 0)
            goto err;
    }

    free(data_encoded);
    free(header);
    return len;

err:
    free(header);
    free(data_encoded);
    log_warn("msg_send: err");
    return -1;
}

/*
 * Read up to BUFMAX bytes from datafd with a single read() and send them as
 * the payload of one message on fd.
 *
 * msg_send() already encodes its payload, so the raw bytes are handed over
 * directly; encoding them here as well would encode the payload twice.
 *
 * Returns 0 on success, -1 if the read or the send fails.
 */
int
msg_send_from_fd(int fd, int type, char dest, uint8_t arg, int datafd) {
    uint8_t buf[BUFMAX];
    ssize_t buflen;

    log_debug("msg_send_from_fd");

    buflen = read(datafd, buf, sizeof(buf));
    if (buflen < 0)
        return -1;
    /* A zero-byte read is forwarded with datalen 0, i.e. header only. */
    if (msg_send(fd, type, dest, arg, (int)buflen, buf) < 0)
        return -1;

    return 0;
}

/*
 * Build a message header from the given fields and return it encoded.
 *
 * Returns a newly allocated buffer of MSG_HEADER_SIZE_ENCODED bytes that the
 * caller must free(), or NULL if datalen is out of range or encoding fails.
 */
uint8_t *
msg_pack_header(int type, char orig, char dest, uint8_t arg, int datalen) {
    struct msg_header h;
    uint8_t *data;
    int len;

    /* Same bounds that msg_unpack_header() enforces on the receiving side. */
    if (datalen < 0) {
        log_warn("msg_pack_header: datalen < 0");
        return NULL;
    }
    if (datalen > MSG_DATALEN_MAX) {
        log_warn("msg_pack_header: datalen too big %d", datalen);
        return NULL;
    }

    /* The struct's raw bytes are copied out and sent, so clear everything,
     * padding included, to avoid leaking stack contents. */
    memset(&h, 0, sizeof(h));
    h.version = MSG_VERSION;
    h.type = type;
    h.orig = orig;
    h.dest = dest;
    h.arg = arg;
    h.datalen = datalen;

    len = encode((uint8_t *)&h, &data, MSG_HEADER_SIZE);
    if (len != MSG_HEADER_SIZE_ENCODED) {
        log_warn("msg_pack_header: encoded header has invalid size %d !"
                 "This should NOT happend, as MSG_HEADER_SIZE_ENCODED should be fixed !",
                len);
        /* encode() leaves data NULL on failure; otherwise it must be freed. */
        free(data);
        return NULL;
    }

    return data;
}

#define ERR(msg...) \
{                   \
    log_warn(msg);  \
    return NULL;    \
}
struct msg_header *
msg_unpack_header(uint8_t *data) {
    uint8_t *data_decoded;
    struct msg_header *h;
    int len;
    /* keep in sync with pg.h MSG enum */
    int msg_client[] = {MSG_DATA, MSG_OK, MSG_ERR};
    int msg_server[] = {MSG_INIT, MSG_INIT_ASYNC, MSG_KILL, MSG_EXEC, MSG_READ, MSG_WRITE, MSG_DATA};
#define MSG_CLIENT_COUNT 3
#define MSG_SERVER_COUNT 7

    len = decode(data, &data_decoded, MSG_HEADER_SIZE_ENCODED);
    if (!data_decoded)
        ERR("failed to decode message header")

    h = (struct msg_header *)data_decoded;

    if (h->version <= 0 || h->version > 255)
        ERR("msg_unpack_header: invalid version %d", h->version)
    if (h->version != MSG_VERSION)
        ERR("msg_unpack_header: incompatible version %d", h->version)
    if (conf.server && (intab(msg_server, h->type, MSG_SERVER_COUNT) < 0))
        ERR("msg_unpack_header: type %d incorrect for server", h->type)
    if (!conf.server && (intab(msg_client, h->type, MSG_CLIENT_COUNT) < 0))
        ERR("msg_unpack_header: type %d incorrect for client", h->type)
    if (!isalnum(h->orig))
        ERR("msg_unpack_header: non alphanumeric originator %c", h->orig)
    if (h->orig == conf.me)
        ERR("msg_unpack_header: message pretends to come from me, ignoring")
    if (!isalnum(h->dest))
        ERR("msg_unpack_header: non alphanumeric destination %c", h->dest)
    // XXX check dest on routes ?
    if (h->datalen < 0)
        ERR("msg_unpack_header: datalen < 0")
    if (h->datalen > MSG_DATALEN_MAX)
        ERR("msg_unpack_header: datalen too big %d", h->datalen)

    return h;
}

/*
 * Read `len` encoded payload bytes from fd with readbuf() and decode them.
 *
 * On success *out receives a newly allocated decoded buffer (caller frees)
 * and the decoded length is returned. On failure -1 is returned and *out
 * is NULL.
 */
int
msg_read_data(int fd, uint8_t **out, int len) {
    uint8_t *buf;
    int buflen, decodedlen;

    /* Give callers a well-defined *out even when readbuf() fails. */
    *out = NULL;
    buflen = readbuf(fd, &buf, len);
    if (buflen < 0)
        return -1;
    decodedlen = decode(buf, out, buflen);
    /* decode() copies into its own buffer, so the raw one is released here.
     * NOTE(review): assumes readbuf() hands back a heap buffer owned by the
     * caller — confirm. */
    free(buf);
    if (decodedlen < 0)
        log_warn("msg_read_data: decoding failed");

    return decodedlen;
}

/* XXX rename to msg_fw_decode ? */
/*
 * Read `len` bytes of encoded payload from ifd, decode them with
 * msg_read_data(), and write the decoded bytes to fd.
 *
 * Returns the (non-negative) writebuf() result on success, -1 on error.
 */
int
msg_read_data_to_fd(int ifd, int fd, int len) {
    /* NULL-initialised: msg_read_data() can fail before ever storing into it. */
    uint8_t *buf = NULL;
    int buflen, writelen;

    buflen = msg_read_data(ifd, &buf, len);
    if (buflen < 0) {
        free(buf);
        return -1;
    }
    writelen = writebuf(buf, fd, buflen);
    free(buf);

    return writelen < 0 ? -1 : writelen;
}

/*
 * Encode `len` bytes from `arg` into a newly allocated buffer stored in *out.
 * The current encoding is a plain byte-for-byte copy, so the returned length
 * equals `len`. If allocation fails, a warning is logged, *out stays NULL
 * and -1 is returned.
 */
static int
encode(uint8_t *arg, uint8_t **out, int len) {
    uint8_t *copy;

    *out = NULL;
    copy = xmalloc(sizeof(uint8_t) * len);
    if (copy == NULL) {
        log_warn("encode failed");
        return -1;
    }

    memcpy(copy, arg, len);
    *out = copy;
    return len;
}

/*
 * Decode `len` bytes from `arg` into a newly allocated buffer stored in *out.
 * Like encode(), this is currently a plain byte-for-byte copy, so the
 * returned length equals `len`. If allocation fails, a warning is logged,
 * *out stays NULL and -1 is returned.
 */
static int
decode(uint8_t *arg, uint8_t **out, int len) {
    uint8_t *plain;

    *out = NULL;
    plain = xmalloc(sizeof(uint8_t) * len);
    if (plain == NULL) {
        log_warn("decode failed");
        return -1;
    }

    memcpy(plain, arg, len);
    *out = plain;
    return len;
}