aboutsummaryrefslogtreecommitdiffstats
path: root/tools/testing/selftests/bpf/progs/test_tcp_check_syncookie_kern.c
blob: d8803dfa8d32f0016327c62dbffd35061e1c2598 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2018 Facebook
// Copyright (c) 2019 Cloudflare

#include <string.h>

#include <linux/bpf.h>
#include <linux/pkt_cls.h>
#include <linux/if_ether.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <sys/socket.h>
#include <linux/tcp.h>

#include "bpf_helpers.h"
#include "bpf_endian.h"

/*
 * Three-slot array of __u32 values filled in by check_syncookie():
 *   key 0 - ack_seq - 1 of a packet that bpf_tcp_check_syncookie() accepted
 *   key 1 - cookie returned by bpf_tcp_gen_syncookie()
 *   key 2 - MSS packed alongside that generated cookie
 * NOTE(review): presumably read back by the user-space half of the test.
 */
struct bpf_map_def SEC("maps") results = {
	.type = BPF_MAP_TYPE_ARRAY,
	.max_entries = 3,
	.key_size = sizeof(__u32),
	.value_size = sizeof(__u32),
};

/*
 * Request a SYN cookie for @tcph when it is a bare SYN (SYN set, ACK clear)
 * whose TCP header is exactly 24 bytes: the 20-byte fixed header plus one
 * 4-byte option, which the test expects to be MSS.
 *
 * Returns whatever bpf_tcp_gen_syncookie() returns, or 0 when the packet is
 * not a candidate or its full header would extend past @data_end.
 */
static __always_inline __s64 gen_syncookie(void *data_end, struct bpf_sock *sk,
					   void *iph, __u32 ip_size,
					   struct tcphdr *tcph)
{
	__u32 hdr_len;

	/* Cookies are only generated for a pure SYN. */
	if (!tcph->syn || tcph->ack)
		return 0;

	hdr_len = tcph->doff * 4;

	/* 20-byte base header + a single 4-byte MSS option. */
	if (hdr_len != 24)
		return 0;

	/* The helper is handed @hdr_len bytes at @tcph; they must be in the packet. */
	if ((void *)tcph + hdr_len > data_end)
		return 0;

	return bpf_tcp_gen_syncookie(sk, iph, ip_size, tcph, hdr_len);
}

static __always_inline void check_syncookie(void *ctx, void *data,
					    void *data_end)
{
	struct bpf_sock_tuple tup;
	struct bpf_sock *sk;
	struct ethhdr *ethh;
	struct iphdr *ipv4h;
	struct ipv6hdr *ipv6h;
	struct tcphdr *tcph;
	int ret;
	__u32 key_mss = 2;
	__u32 key_gen = 1;
	__u32 key = 0;
	__s64 seq_mss;

	ethh = data;
	if (ethh + 1 > data_end)
		return;

	switch (bpf_ntohs(ethh->h_proto)) {
	case ETH_P_IP:
		ipv4h = data + sizeof(struct ethhdr);
		if (ipv4h + 1 > data_end)
			return;

		if (ipv4h->ihl != 5)
			return;

		tcph = data + sizeof(struct ethhdr) + sizeof(struct iphdr);
		if (tcph + 1 > data_end)
			return;

		tup.ipv4.saddr = ipv4h->saddr;
		tup.ipv4.daddr = ipv4h->daddr;
		tup.ipv4.sport = tcph->source;
		tup.ipv4.dport = tcph->dest;

		sk = bpf_skc_lookup_tcp(ctx, &tup, sizeof(tup.ipv4),
					BPF_F_CURRENT_NETNS, 0);
		if (!sk)
			return;

		if (sk->state != BPF_TCP_LISTEN)
			goto release;

		seq_mss = gen_syncookie(data_end, sk, ipv4h, sizeof(*ipv4h),
					tcph);

		ret = bpf_tcp_check_syncookie(sk, ipv4h, sizeof(*ipv4h),
					      tcph, sizeof(*tcph));
		break;

	case ETH_P_IPV6:
		ipv6h = data + sizeof(struct ethhdr);
		if (ipv6h + 1 > data_end)
			return;

		if (ipv6h->nexthdr != IPPROTO_TCP)
			return;

		tcph = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
		if (tcph + 1 > data_end)
			return;

		memcpy(tup.ipv6.saddr, &ipv6h->saddr, sizeof(tup.ipv6.saddr));
		memcpy(tup.ipv6.daddr, &ipv6h->daddr, sizeof(tup.ipv6.daddr));
		tup.ipv6.sport = tcph->source;
		tup.ipv6.dport = tcph->dest;

		sk = bpf_skc_lookup_tcp(ctx, &tup, sizeof(tup.ipv6),
					BPF_F_CURRENT_NETNS, 0);
		if (!sk)
			return;

		if (sk->state != BPF_TCP_LISTEN)
			goto release;

		seq_mss = gen_syncookie(data_end, sk, ipv6h, sizeof(*ipv6h),
					tcph);

		ret = bpf_tcp_check_syncookie(sk, ipv6h, sizeof(*ipv6h),
					      tcph, sizeof(*tcph));
		break;

	default:
		return;
	}

	if (seq_mss > 0) {
		__u32 cookie = (__u32)seq_mss;
		__u32 mss = seq_mss >> 32;

		bpf_map_update_elem(&results, &key_gen, &cookie, 0);
		bpf_map_update_elem(&results, &key_mss, &mss, 0);
	}

	if (ret == 0) {
		__u32 cookie = bpf_ntohl(tcph->ack_seq) - 1;

		bpf_map_update_elem(&results, &key, &cookie, 0);
	}

release:
	bpf_sk_release(sk);
}

/*
 * TC (clsact) entry point: run the syncookie checks over the skb's packet
 * data, then always let the packet continue with TC_ACT_OK.
 */
SEC("clsact/check_syncookie")
int check_syncookie_clsact(struct __sk_buff *skb)
{
	void *pkt_start = (void *)(long)skb->data;
	void *pkt_end = (void *)(long)skb->data_end;

	check_syncookie(skb, pkt_start, pkt_end);

	return TC_ACT_OK;
}

/*
 * XDP entry point: run the syncookie checks over the frame's data, then
 * always hand the frame on with XDP_PASS.
 */
SEC("xdp/check_syncookie")
int check_syncookie_xdp(struct xdp_md *ctx)
{
	void *pkt_start = (void *)(long)ctx->data;
	void *pkt_end = (void *)(long)ctx->data_end;

	check_syncookie(ctx, pkt_start, pkt_end);

	return XDP_PASS;
}

/*
 * License string placed in the "license" ELF section of this object.
 * NOTE(review): presumably must stay GPL-compatible so the loader permits
 * GPL-only helpers such as bpf_skc_lookup_tcp() — confirm before changing.
 */
char _license[] SEC("license") = "GPL";