aboutsummaryrefslogtreecommitdiffstats
path: root/net/bpfilter/bpfilter_kern.c
blob: c0fcde910a7ad75277765aaf7d7476cdc7dc861e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// SPDX-License-Identifier: GPL-2.0
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/init.h>
#include <linux/module.h>
#include <linux/umh.h>
#include <linux/bpfilter.h>
#include <linux/sched.h>
#include <linux/sched/signal.h>
#include <linux/fs.h>
#include <linux/file.h>
#include "msgfmt.h"

/*
 * Markers delimiting the bpfilter usermode helper image. They are defined
 * outside this file (presumably emitted by the build alongside this module
 * — confirm). start_umh() hands the bytes between them to
 * fork_usermode_blob().
 */
extern char bpfilter_umh_start;
extern char bpfilter_umh_end;

/*
 * since ip_getsockopt() can run in parallel, serialize access to umh:
 * held around the request/reply exchange on the helper's pipes and around
 * stopping the helper.
 */
static DEFINE_MUTEX(bpfilter_lock);

/*
 * Send SIGKILL to the bpfilter usermode helper identified by
 * bpfilter_ops.info.pid.
 *
 * Does nothing when bpfilter_ops.stop is already set. This function does not
 * set bpfilter_ops.stop itself.
 *
 * find_vpid() must be called with rcu_read_lock() (or tasklist_lock) held,
 * because the struct pid it returns is only RCU-protected. get_pid_task()
 * takes a reference on the task, so the task can be signalled after the RCU
 * read-side section ends. That reference is dropped with put_task_struct().
 */
static void shutdown_umh(void)
{
	struct task_struct *tsk;

	if (bpfilter_ops.stop)
		return;

	rcu_read_lock();
	tsk = get_pid_task(find_vpid(bpfilter_ops.info.pid), PIDTYPE_PID);
	rcu_read_unlock();

	if (tsk) {
		force_sig(SIGKILL, tsk);
		put_task_struct(tsk);
	}
}

/*
 * Kill the usermode helper without taking bpfilter_lock. The callers in this
 * file already hold it. Compiled down to a no-op when CONFIG_INET is off.
 */
static void __stop_umh(void)
{
	if (!IS_ENABLED(CONFIG_INET))
		return;

	shutdown_umh();
}

static void stop_umh(void)
{
	mutex_lock(&bpfilter_lock);
	__stop_umh();
	mutex_unlock(&bpfilter_lock);
}

/*
 * Forward one sockopt request to the bpfilter usermode helper over its pipe
 * and wait for the reply.
 *
 * @sk:      originating socket (not used here)
 * @optname: option number, sent to the helper as req.cmd
 * @optval:  user buffer; only its address (not its contents) is sent
 * @optlen:  length of @optval, sent as req.len
 * @is_set:  sent as req.is_set (presumably set- vs get-sockopt; confirm
 *           against callers)
 *
 * The whole exchange runs under bpfilter_lock, so concurrent callers cannot
 * interleave requests and replies on the pipes.
 *
 * Return: the helper's reply.status. Returns -EFAULT when no helper pid is
 * recorded, or when the pipe write or read transfers a short count. In the
 * short-count case the helper is also killed.
 */
static int __bpfilter_process_sockopt(struct sock *sk, int optname,
				      char __user *optval,
				      unsigned int optlen, bool is_set)
{
	struct mbox_request req;
	struct mbox_reply reply;
	loff_t pos;
	ssize_t n;
	int ret = -EFAULT;

	req.is_set = is_set;
	req.pid = current->pid;
	req.cmd = optname;
	req.addr = (long __force __user)optval;
	req.len = optlen;
	mutex_lock(&bpfilter_lock);
	if (!bpfilter_ops.info.pid)
		goto out;

	/*
	 * __kernel_write() reads and updates *pos, so it must be initialized.
	 * Previously it was passed uninitialized.
	 */
	pos = 0;
	n = __kernel_write(bpfilter_ops.info.pipe_to_umh, &req, sizeof(req),
			   &pos);
	if (n != sizeof(req)) {
		pr_err("write fail %zd\n", n);
		goto stop;
	}

	pos = 0;
	n = kernel_read(bpfilter_ops.info.pipe_from_umh, &reply, sizeof(reply),
			&pos);
	if (n != sizeof(reply)) {
		pr_err("read fail %zd\n", n);
		goto stop;
	}

	ret = reply.status;
	goto out;

stop:
	/* The pipe exchange failed: kill the helper; ret is still -EFAULT. */
	__stop_umh();
out:
	mutex_unlock(&bpfilter_lock);
	return ret;
}

static int start_umh(void)
{
	int err;

	/* fork usermode process */
	err = fork_usermode_blob(&bpfilter_umh_start,
				 &bpfilter_umh_end - &bpfilter_umh_start,
				 &bpfilter_ops.info);
	if (err)
		return err;
	bpfilter_ops.stop = false;
	pr_info("Loaded bpfilter_umh pid %d\n", bpfilter_ops.info.pid);

	/* health check that usermode process started correctly */
	if (__bpfilter_process_sockopt(NULL, 0, NULL, 0, 0) != 0) {
		stop_umh();
		return -EFAULT;
	}

	return 0;
}

/*
 * Module init: start the usermode helper and, when CONFIG_INET is enabled,
 * publish this module's sockopt and restart callbacks in bpfilter_ops.
 *
 * Return: -EFAULT if bpfilter_ops.stop is already false (a helper is
 * presumably running). Otherwise, the result of start_umh().
 */
static int __init load_umh(void)
{
	int err;

	if (!bpfilter_ops.stop)
		return -EFAULT;

	err = start_umh();
	if (err)
		return err;

	if (IS_ENABLED(CONFIG_INET)) {
		bpfilter_ops.sockopt = &__bpfilter_process_sockopt;
		bpfilter_ops.start = &start_umh;
	}

	return 0;
}

/*
 * Module exit. When CONFIG_INET is enabled, clear the callbacks that
 * load_umh() published in bpfilter_ops. Then kill the usermode helper under
 * bpfilter_lock.
 *
 * NOTE(review): the callback pointers are cleared without any lock visible
 * in this file; confirm that callers of bpfilter_ops.start/.sockopt are
 * serialized against module unload elsewhere.
 */
static void __exit fini_umh(void)
{
	if (IS_ENABLED(CONFIG_INET)) {
		bpfilter_ops.start = NULL;
		bpfilter_ops.sockopt = NULL;
	}

	mutex_lock(&bpfilter_lock);
	__stop_umh();
	mutex_unlock(&bpfilter_lock);
}
/* Register the module's load/unload entry points and its license. */
module_init(load_umh);
module_exit(fini_umh);
MODULE_LICENSE("GPL");