diff options
Diffstat (limited to 'tools/testing/selftests/net/lib/py')
-rw-r--r-- | tools/testing/selftests/net/lib/py/__init__.py | 9 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/consts.py | 9 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/ksft.py | 280 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/netns.py | 49 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/nsim.py | 135 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/utils.py | 212 | ||||
-rw-r--r-- | tools/testing/selftests/net/lib/py/ynl.py | 58 |
7 files changed, 752 insertions, 0 deletions
diff --git a/tools/testing/selftests/net/lib/py/__init__.py b/tools/testing/selftests/net/lib/py/__init__.py new file mode 100644 index 000000000000..8697bd27dc30 --- /dev/null +++ b/tools/testing/selftests/net/lib/py/__init__.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0 + +from .consts import KSRC +from .ksft import * +from .netns import NetNS, NetNSEnter +from .nsim import * +from .utils import * +from .ynl import NlError, YnlFamily, EthtoolFamily, NetdevFamily, RtnlFamily, RtnlAddrFamily +from .ynl import NetshaperFamily diff --git a/tools/testing/selftests/net/lib/py/consts.py b/tools/testing/selftests/net/lib/py/consts.py new file mode 100644 index 000000000000..f518ce79d82c --- /dev/null +++ b/tools/testing/selftests/net/lib/py/consts.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: GPL-2.0 + +import sys +from pathlib import Path + +KSFT_DIR = (Path(__file__).parent / "../../..").resolve() +KSRC = (Path(__file__).parent / "../../../../../..").resolve() + +KSFT_MAIN_NAME = Path(sys.argv[0]).with_suffix("").name diff --git a/tools/testing/selftests/net/lib/py/ksft.py b/tools/testing/selftests/net/lib/py/ksft.py new file mode 100644 index 000000000000..61287c203b6e --- /dev/null +++ b/tools/testing/selftests/net/lib/py/ksft.py @@ -0,0 +1,280 @@ +# SPDX-License-Identifier: GPL-2.0 + +import builtins +import functools +import inspect +import signal +import sys +import time +import traceback +from .consts import KSFT_MAIN_NAME +from .utils import global_defer_queue + +KSFT_RESULT = None +KSFT_RESULT_ALL = True +KSFT_DISRUPTIVE = True + + +class KsftFailEx(Exception): + pass + + +class KsftSkipEx(Exception): + pass + + +class KsftXfailEx(Exception): + pass + + +class KsftTerminate(KeyboardInterrupt): + pass + + +def ksft_pr(*objs, **kwargs): + print("#", *objs, **kwargs) + + +def _fail(*args): + global KSFT_RESULT + KSFT_RESULT = False + + stack = inspect.stack() + started = False + for frame in reversed(stack[2:]): + # Start printing from the test case function + if not started: + if frame.function == 'ksft_run': + started = True + continue + + ksft_pr("Check| At " + frame.filename + ", line " + str(frame.lineno) + + ", in " + frame.function + ":") + ksft_pr("Check| " + frame.code_context[0].strip()) + ksft_pr(*args) + + +def ksft_eq(a, b, comment=""): + global KSFT_RESULT + if a != b: + _fail("Check failed", a, "!=", b, comment) + + +def ksft_ne(a, b, comment=""): + global KSFT_RESULT + if a == b: + _fail("Check failed", a, "==", b, comment) + + +def ksft_true(a, comment=""): + if not a: + _fail("Check failed", a, "does not eval to True", comment) + + +def ksft_in(a, b, comment=""): + if a not in b: + _fail("Check failed", a, "not in", b, comment) + + +def ksft_not_in(a, b, comment=""): + if a in b: + _fail("Check failed", a, "in", b, comment) + + +def ksft_is(a, b, comment=""): + if a is not b: + _fail("Check failed", a, "is not", b, comment) + + +def ksft_ge(a, b, comment=""): + if a < b: + _fail("Check failed", a, "<", b, comment) + + +def ksft_lt(a, b, comment=""): + if a >= b: + _fail("Check failed", a, ">=", b, comment) + + +class ksft_raises: + def __init__(self, expected_type): + self.exception = None + self.expected_type = expected_type + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + _fail(f"Expected exception {str(self.expected_type.__name__)}, none raised") + elif self.expected_type != exc_type: + _fail(f"Expected exception {str(self.expected_type.__name__)}, raised {str(exc_type.__name__)}") + self.exception = exc_val + # Suppress the exception if its the expected one + return self.expected_type == exc_type + + +def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""): + end = time.monotonic() + deadline + while True: + if cond(): + return + if time.monotonic() > end: + _fail("Waiting for condition timed out", comment) + return + time.sleep(sleep) + + +def ktap_result(ok, cnt=1, case="", comment=""): + global KSFT_RESULT_ALL + KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok + + res = "" + if not ok: + res += "not " + res += "ok " + res += str(cnt) + " " + res += KSFT_MAIN_NAME + if case: + res += "." + str(case.__name__) + if comment: + res += " # " + comment + print(res) + + +def ksft_flush_defer(): + global KSFT_RESULT + + i = 0 + qlen_start = len(global_defer_queue) + while global_defer_queue: + i += 1 + entry = global_defer_queue.pop() + try: + entry.exec_only() + except: + ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!") + tb = traceback.format_exc() + for line in tb.strip().split('\n'): + ksft_pr("Defer Exception|", line) + KSFT_RESULT = False + + +def ksft_disruptive(func): + """ + Decorator that marks the test as disruptive (e.g. the test + that can down the interface). Disruptive tests can be skipped + by passing DISRUPTIVE=False environment variable. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not KSFT_DISRUPTIVE: + raise KsftSkipEx(f"marked as disruptive") + return func(*args, **kwargs) + return wrapper + + +def ksft_setup(env): + """ + Setup test framework global state from the environment. + """ + + def get_bool(env, name): + value = env.get(name, "").lower() + if value in ["yes", "true"]: + return True + if value in ["no", "false"]: + return False + try: + return bool(int(value)) + except: + raise Exception(f"failed to parse {name}") + + if "DISRUPTIVE" in env: + global KSFT_DISRUPTIVE + KSFT_DISRUPTIVE = get_bool(env, "DISRUPTIVE") + + return env + + +def _ksft_intr(signum, frame): + # ksft runner.sh sends 2 SIGTERMs in a row on a timeout + # if we don't ignore the second one it will stop us from handling cleanup + global term_cnt + term_cnt += 1 + if term_cnt == 1: + raise KsftTerminate() + else: + ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...") + + +def ksft_run(cases=None, globs=None, case_pfx=None, args=()): + cases = cases or [] + + if globs and case_pfx: + for key, value in globs.items(): + if not callable(value): + continue + for prefix in case_pfx: + if key.startswith(prefix): + cases.append(value) + break + + global term_cnt + term_cnt = 0 + prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr) + + totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0} + + print("TAP version 13") + print("1.." + str(len(cases))) + + global KSFT_RESULT + cnt = 0 + stop = False + for case in cases: + KSFT_RESULT = True + cnt += 1 + comment = "" + cnt_key = "" + + try: + case(*args) + except KsftSkipEx as e: + comment = "SKIP " + str(e) + cnt_key = 'skip' + except KsftXfailEx as e: + comment = "XFAIL " + str(e) + cnt_key = 'xfail' + except BaseException as e: + stop |= isinstance(e, KeyboardInterrupt) + tb = traceback.format_exc() + for line in tb.strip().split('\n'): + ksft_pr("Exception|", line) + if stop: + ksft_pr(f"Stopping tests due to {type(e).__name__}.") + KSFT_RESULT = False + cnt_key = 'fail' + + ksft_flush_defer() + + if not cnt_key: + cnt_key = 'pass' if KSFT_RESULT else 'fail' + + ktap_result(KSFT_RESULT, cnt, case, comment=comment) + totals[cnt_key] += 1 + + if stop: + break + + signal.signal(signal.SIGTERM, prev_sigterm) + + print( + f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0" + ) + + +def ksft_exit(): + global KSFT_RESULT_ALL + sys.exit(0 if KSFT_RESULT_ALL else 1) diff --git a/tools/testing/selftests/net/lib/py/netns.py b/tools/testing/selftests/net/lib/py/netns.py new file mode 100644 index 000000000000..8e9317044eef --- /dev/null +++ b/tools/testing/selftests/net/lib/py/netns.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: GPL-2.0 + +from .utils import ip +import ctypes +import random +import string + +libc = ctypes.cdll.LoadLibrary('libc.so.6') + + +class NetNS: + def __init__(self, name=None): + if name: + self.name = name + else: + self.name = ''.join(random.choice(string.ascii_lowercase) for _ in range(8)) + ip('netns add ' + self.name) + + def __del__(self): + if self.name: + ip('netns del ' + self.name) + self.name = None + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_tb): + self.__del__() + + def __str__(self): + return self.name + + def __repr__(self): + return f"NetNS({self.name})" + + +class NetNSEnter: + def __init__(self, ns_name): + self.ns_path = f"/run/netns/{ns_name}" + + def __enter__(self): + self.saved = open("/proc/thread-self/ns/net") + with open(self.ns_path) as ns_file: + libc.setns(ns_file.fileno(), 0) + return self + + def __exit__(self, exc_type, exc_value, traceback): + libc.setns(self.saved.fileno(), 0) + self.saved.close() diff --git a/tools/testing/selftests/net/lib/py/nsim.py b/tools/testing/selftests/net/lib/py/nsim.py new file mode 100644 index 000000000000..1a8cbe9acc48 --- /dev/null +++ b/tools/testing/selftests/net/lib/py/nsim.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: GPL-2.0 + +import errno +import json +import os +import random +import re +import time +from .utils import cmd, ip + + +class NetdevSim: + """ + Class for netdevsim netdevice and its attributes. + """ + + def __init__(self, nsimdev, port_index, ifname, ns=None): + # In case udev renamed the netdev to according to new schema, + # check if the name matches the port_index. + nsimnamere = re.compile(r"eni\d+np(\d+)") + match = nsimnamere.match(ifname) + if match and int(match.groups()[0]) != port_index + 1: + raise Exception("netdevice name mismatches the expected one") + + self.ifname = ifname + self.nsimdev = nsimdev + self.port_index = port_index + self.ns = ns + self.dfs_dir = "%s/ports/%u/" % (nsimdev.dfs_dir, port_index) + ret = ip("-j link show dev %s" % ifname, ns=ns) + self.dev = json.loads(ret.stdout)[0] + self.ifindex = self.dev["ifindex"] + + def dfs_write(self, path, val): + self.nsimdev.dfs_write(f'ports/{self.port_index}/' + path, val) + + +class NetdevSimDev: + """ + Class for netdevsim bus device and its attributes. + """ + @staticmethod + def ctrl_write(path, val): + fullpath = os.path.join("/sys/bus/netdevsim/", path) + with open(fullpath, "w") as f: + f.write(val) + + def dfs_write(self, path, val): + fullpath = os.path.join(f"/sys/kernel/debug/netdevsim/netdevsim{self.addr}/", path) + with open(fullpath, "w") as f: + f.write(val) + + def __init__(self, port_count=1, queue_count=1, ns=None): + # nsim will spawn in init_net, we'll set to actual ns once we switch it there + self.ns = None + + if not os.path.exists("/sys/bus/netdevsim"): + cmd("modprobe netdevsim") + + addr = random.randrange(1 << 15) + while True: + try: + self.ctrl_write("new_device", "%u %u %u" % (addr, port_count, queue_count)) + except OSError as e: + if e.errno == errno.ENOSPC: + addr = random.randrange(1 << 15) + continue + raise e + break + self.addr = addr + + # As probe of netdevsim device might happen from a workqueue, + # so wait here until all netdevs appear. + self.wait_for_netdevs(port_count) + + if ns: + cmd(f"devlink dev reload netdevsim/netdevsim{addr} netns {ns.name}") + self.ns = ns + + cmd("udevadm settle", ns=self.ns) + ifnames = self.get_ifnames() + + self.dfs_dir = "/sys/kernel/debug/netdevsim/netdevsim%u/" % addr + + self.nsims = [] + for port_index in range(port_count): + self.nsims.append(self._make_port(port_index, ifnames[port_index])) + + self.removed = False + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_tb): + """ + __exit__ gets called at the end of a "with" block. + """ + self.remove() + + def _make_port(self, port_index, ifname): + return NetdevSim(self, port_index, ifname, self.ns) + + def get_ifnames(self): + ifnames = [] + listdir = cmd(f"ls /sys/bus/netdevsim/devices/netdevsim{self.addr}/net/", + ns=self.ns).stdout.split() + for ifname in listdir: + ifnames.append(ifname) + ifnames.sort() + return ifnames + + def wait_for_netdevs(self, port_count): + timeout = 5 + timeout_start = time.time() + + while True: + try: + ifnames = self.get_ifnames() + except FileNotFoundError as e: + ifnames = [] + if len(ifnames) == port_count: + break + if time.time() < timeout_start + timeout: + continue + raise Exception("netdevices did not appear within timeout") + + def remove(self): + if not self.removed: + self.ctrl_write("del_device", "%u" % (self.addr, )) + self.removed = True + + def remove_nsim(self, nsim): + self.nsims.remove(nsim) + self.ctrl_write("devices/netdevsim%u/del_port" % (self.addr, ), + "%u" % (nsim.port_index, )) diff --git a/tools/testing/selftests/net/lib/py/utils.py b/tools/testing/selftests/net/lib/py/utils.py new file mode 100644 index 000000000000..34470d65d871 --- /dev/null +++ b/tools/testing/selftests/net/lib/py/utils.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: GPL-2.0 + +import errno +import json as _json +import os +import random +import re +import select +import socket +import subprocess +import time + + +class CmdExitFailure(Exception): + def __init__(self, msg, cmd_obj): + super().__init__(msg) + self.cmd = cmd_obj + + +def fd_read_timeout(fd, timeout): + rlist, _, _ = select.select([fd], [], [], timeout) + if rlist: + return os.read(fd, 1024) + else: + raise TimeoutError("Timeout waiting for fd read") + + +class cmd: + """ + Execute a command on local or remote host. + + Use bkg() instead to run a command in the background. + """ + def __init__(self, comm, shell=True, fail=True, ns=None, background=False, + host=None, timeout=5, ksft_wait=None): + if ns: + comm = f'ip netns exec {ns} ' + comm + + self.stdout = None + self.stderr = None + self.ret = None + self.ksft_term_fd = None + + self.comm = comm + if host: + self.proc = host.cmd(comm) + else: + # ksft_wait lets us wait for the background process to fully start, + # we pass an FD to the child process, and wait for it to write back. + # Similarly term_fd tells child it's time to exit. + pass_fds = () + env = os.environ.copy() + if ksft_wait is not None: + rfd, ready_fd = os.pipe() + wait_fd, self.ksft_term_fd = os.pipe() + pass_fds = (ready_fd, wait_fd, ) + env["KSFT_READY_FD"] = str(ready_fd) + env["KSFT_WAIT_FD"] = str(wait_fd) + + self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, pass_fds=pass_fds, + env=env) + if ksft_wait is not None: + os.close(ready_fd) + os.close(wait_fd) + msg = fd_read_timeout(rfd, ksft_wait) + os.close(rfd) + if not msg: + raise Exception("Did not receive ready message") + if not background: + self.process(terminate=False, fail=fail, timeout=timeout) + + def process(self, terminate=True, fail=None, timeout=5): + if fail is None: + fail = not terminate + + if self.ksft_term_fd: + os.write(self.ksft_term_fd, b"1") + if terminate: + self.proc.terminate() + stdout, stderr = self.proc.communicate(timeout) + self.stdout = stdout.decode("utf-8") + self.stderr = stderr.decode("utf-8") + self.proc.stdout.close() + self.proc.stderr.close() + self.ret = self.proc.returncode + + if self.proc.returncode != 0 and fail: + if len(stderr) > 0 and stderr[-1] == "\n": + stderr = stderr[:-1] + raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" % + (self.proc.args, stdout, stderr), self) + + +class bkg(cmd): + """ + Run a command in the background. + + Examples usage: + + Run a command on remote host, and wait for it to finish. + This is usually paired with wait_port_listen() to make sure + the command has initialized: + + with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc: + ... + + Run a command and expect it to let us know that it's ready + by writing to a special file descriptor passed via KSFT_READY_FD. + Command will be terminated when we exit the context manager: + + with bkg("my_binary", ksft_wait=5): + """ + def __init__(self, comm, shell=True, fail=None, ns=None, host=None, + exit_wait=False, ksft_wait=None): + super().__init__(comm, background=True, + shell=shell, fail=fail, ns=ns, host=host, + ksft_wait=ksft_wait) + self.terminate = not exit_wait and not ksft_wait + self.check_fail = fail + + if shell and self.terminate: + print("# Warning: combining shell and terminate is risky!") + print("# SIGTERM may not reach the child on zsh/ksh!") + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_tb): + return self.process(terminate=self.terminate, fail=self.check_fail) + + +global_defer_queue = [] + + +class defer: + def __init__(self, func, *args, **kwargs): + global global_defer_queue + + if not callable(func): + raise Exception("defer created with un-callable object, did you call the function instead of passing its name?") + + self.func = func + self.args = args + self.kwargs = kwargs + + self._queue = global_defer_queue + self._queue.append(self) + + def __enter__(self): + return self + + def __exit__(self, ex_type, ex_value, ex_tb): + return self.exec() + + def exec_only(self): + self.func(*self.args, **self.kwargs) + + def cancel(self): + self._queue.remove(self) + + def exec(self): + self.cancel() + self.exec_only() + + +def tool(name, args, json=None, ns=None, host=None): + cmd_str = name + ' ' + if json: + cmd_str += '--json ' + cmd_str += args + cmd_obj = cmd(cmd_str, ns=ns, host=host) + if json: + return _json.loads(cmd_obj.stdout) + return cmd_obj + + +def ip(args, json=None, ns=None, host=None): + if ns: + args = f'-netns {ns} ' + args + return tool('ip', args, json=json, host=host) + + +def ethtool(args, json=None, ns=None, host=None): + return tool('ethtool', args, json=json, ns=ns, host=host) + + +def rand_port(type=socket.SOCK_STREAM): + """ + Get a random unprivileged port. + """ + with socket.socket(socket.AF_INET6, type) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5): + end = time.monotonic() + deadline + + pattern = f":{port:04X} .* " + if proto == "tcp": # for tcp protocol additionally check the socket state + pattern += "0A" + pattern = re.compile(pattern) + + while True: + data = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout + for row in data.split("\n"): + if pattern.search(row): + return + if time.monotonic() > end: + raise Exception("Waiting for port listen timed out") + time.sleep(sleep) diff --git a/tools/testing/selftests/net/lib/py/ynl.py b/tools/testing/selftests/net/lib/py/ynl.py new file mode 100644 index 000000000000..6329ae805abf --- /dev/null +++ b/tools/testing/selftests/net/lib/py/ynl.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: GPL-2.0 + +import sys +from pathlib import Path +from .consts import KSRC, KSFT_DIR +from .ksft import ksft_pr, ktap_result + +# Resolve paths +try: + if (KSFT_DIR / "kselftest-list.txt").exists(): + # Running in "installed" selftests + tools_full_path = KSFT_DIR + SPEC_PATH = KSFT_DIR / "net/lib/specs" + + sys.path.append(tools_full_path.as_posix()) + from net.lib.ynl.pyynl.lib import YnlFamily, NlError + else: + # Running in tree + tools_full_path = KSRC / "tools" + SPEC_PATH = KSRC / "Documentation/netlink/specs" + + sys.path.append(tools_full_path.as_posix()) + from net.ynl.pyynl.lib import YnlFamily, NlError +except ModuleNotFoundError as e: + ksft_pr("Failed importing `ynl` library from kernel sources") + ksft_pr(str(e)) + ktap_result(True, comment="SKIP") + sys.exit(4) + +# +# Wrapper classes, loading the right specs +# Set schema='' to avoid jsonschema validation, it's slow +# +class EthtoolFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('ethtool.yaml')).as_posix(), + schema='', recv_size=recv_size) + + +class RtnlFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('rt-link.yaml')).as_posix(), + schema='', recv_size=recv_size) + +class RtnlAddrFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('rt-addr.yaml')).as_posix(), + schema='', recv_size=recv_size) + +class NetdevFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('netdev.yaml')).as_posix(), + schema='', recv_size=recv_size) + +class NetshaperFamily(YnlFamily): + def __init__(self, recv_size=0): + super().__init__((SPEC_PATH / Path('net_shaper.yaml')).as_posix(), + schema='', recv_size=recv_size) |