aboutsummaryrefslogtreecommitdiffstats
path: root/mm/mempolicy.c
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--mm/mempolicy.c23
1 files changed, 20 insertions, 3 deletions
diff --git a/mm/mempolicy.c b/mm/mempolicy.c
index 6e867a8dcca9..65df28d7cc89 100644
--- a/mm/mempolicy.c
+++ b/mm/mempolicy.c
@@ -1263,6 +1263,7 @@ static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
unsigned long maxnode)
{
unsigned long k;
+ unsigned long t;
unsigned long nlongs;
unsigned long endmask;
@@ -1279,11 +1280,17 @@ static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
else
endmask = (1UL << (maxnode % BITS_PER_LONG)) - 1;
- /* When the user specified more nodes than supported just check
- if the non supported part is all zero. */
+ /*
+ * When the user specified more nodes than supported just check
+ * if the non supported part is all zero.
+ *
+ * If maxnode have more longs than MAX_NUMNODES, check
+ * the bits in that area first. And then go through to
+ * check the rest bits which equal or bigger than MAX_NUMNODES.
+ * Otherwise, just check bits [MAX_NUMNODES, maxnode).
+ */
if (nlongs > BITS_TO_LONGS(MAX_NUMNODES)) {
for (k = BITS_TO_LONGS(MAX_NUMNODES); k < nlongs; k++) {
- unsigned long t;
if (get_user(t, nmask + k))
return -EFAULT;
if (k == nlongs - 1) {
@@ -1296,6 +1303,16 @@ static int get_nodes(nodemask_t *nodes, const unsigned long __user *nmask,
endmask = ~0UL;
}
+ if (maxnode > MAX_NUMNODES && MAX_NUMNODES % BITS_PER_LONG != 0) {
+ unsigned long valid_mask = endmask;
+
+ valid_mask &= ~((1UL << (MAX_NUMNODES % BITS_PER_LONG)) - 1);
+ if (get_user(t, nmask + nlongs - 1))
+ return -EFAULT;
+ if (t & valid_mask)
+ return -EINVAL;
+ }
+
if (copy_from_user(nodes_addr(*nodes), nmask, nlongs*sizeof(unsigned long)))
return -EFAULT;
nodes_addr(*nodes)[nlongs-1] &= endmask;