#!/usr/bin/python3
import argparse
import os
from subprocess import run, check_call, check_output
import tempfile
import pathlib
import platform
import glob
import shutil
import re
import sys
import hashlib
from collections import namedtuple
from enum import Enum, auto
import json


class ModTable:
    """Dictionary-like table whose string keys ignore the '_' / '-' distinction.

    Every key is normalised by replacing '_' with '-' before it is stored
    or looked up, so 'foo_bar' and 'foo-bar' address the same entry.
    Only string keys are accepted; anything else raises TypeError.
    """

    def __init__(self):
        self._table = {}

    @staticmethod
    def _normalize(key):
        """Return the canonical (dashed) form of `key`.

        Raises:
            TypeError: if `key` is not a string.
        """
        if not isinstance(key, str):
            raise TypeError(f"key must be str, not {type(key).__name__}")
        return key.replace('_', '-')

    def get(self, key, default=None):
        """Return the value stored for `key`, or `default` if absent."""
        return self._table.get(self._normalize(key), default)

    def __contains__(self, key):
        # Uses the same validation as the other accessors so that a
        # non-string key fails with TypeError rather than AttributeError.
        return self._normalize(key) in self._table

    def __getitem__(self, key):
        return self._table[self._normalize(key)]

    def __setitem__(self, key, value):
        self._table[self._normalize(key)] = value


class ModuleDb:
    """Index of the kernel modules described by a modules directory.

    The constructor reads `modules.builtin`, `modules.builtin.modinfo`,
    `modules.alias`, `modules.symbols` and `modules.dep` from the given
    directory and walks it looking for `.ko` files. Modules can then be
    looked up by name or alias, their dependencies enumerated, and the
    modules requested for installation tracked so that redundant
    requests can be reported.
    """

    # builtin: True when the module is compiled into the kernel
    # aliases: alias names that resolve to the module
    # symbols: aliases listed for the module in modules.symbols
    ModuleInfo = namedtuple('ModuleInfo', ('builtin', 'aliases', 'symbols'))
    alias_line = re.compile(r"^alias +(?P<alias>[^ ]+) +(?P<target>[^ ]+)$")
    modinfo_line = re.compile(r"^(?P<name>[^.]+)[.](?P<arg>[^.]+)=(?P<value>.+)$")

    class Installed(Enum):
        # Requested directly by a configuration file
        EXPLICIT = auto()
        # Pulled in as a dependency of an explicitly requested module
        IMPLICIT = auto()

    @staticmethod
    def __each_alias(path):
        """Yield (alias, target) for every `alias <alias> <target>` line in `path`."""
        with open(path, "r") as f:
            data = f.read()
        for line in data.splitlines():
            m = ModuleDb.alias_line.fullmatch(line)
            if m:
                yield m['alias'], m['target']

    def __init__(self, module_path):
        """Load the module database from `module_path`.

        Args:
            module_path (str): Directory containing the modules.* index
                files and the module files themselves.
        """
        self._modules = ModTable()
        self._alias = ModTable()
        self._deps = ModTable()
        self._installed = ModTable()
        # Targets already reported as unknown, to warn only once per module
        not_found = set()

        with open(os.path.join(module_path, "modules.builtin"), "r") as f:
            for line in f.read().splitlines():
                name = os.path.splitext(os.path.basename(line))[0]
                self._modules[name] = ModuleDb.ModuleInfo(True, [], [])

        # Entries are NUL-separated "<module>.<field>=<value>" records
        with open(os.path.join(module_path, "modules.builtin.modinfo"), "r") as f:
            for line in f.read().split('\0'):
                m = ModuleDb.modinfo_line.fullmatch(line)
                if m and m['arg'] == 'alias':
                    original = self._modules.get(m['name'], ModuleDb.ModuleInfo(True, [], []))
                    self._modules[m['name']] = ModuleDb.ModuleInfo(original.builtin,
                                                                   original.aliases + [m['value']],
                                                                   original.symbols)
                    # Only alias records define an alias; other fields
                    # (license, description, ...) must not be registered.
                    self._alias[m['value']] = m['name']

        for root, dirs, files in os.walk(module_path):
            for f in files:
                if f.endswith('.ko'):
                    name, _ = os.path.splitext(f)
                elif '.ko.' in f:
                    # Module file with an extra suffix after ".ko"
                    name, _ = f.split('.ko.', maxsplit=1)
                else:
                    continue
                self._modules[name] = ModuleDb.ModuleInfo(False, [], [])

        for alias, target in ModuleDb.__each_alias(os.path.join(module_path, "modules.alias")):
            try:
                original = self._modules[target]
            except KeyError:
                if target not in not_found:
                    print(f"modules.alias: cannot find module {target}", file=sys.stderr)
                    not_found.add(target)
            else:
                self._modules[target] = ModuleDb.ModuleInfo(original.builtin,
                                                            original.aliases + [alias],
                                                            original.symbols)
                self._alias[alias] = target

        for alias, target in ModuleDb.__each_alias(os.path.join(module_path, "modules.symbols")):
            try:
                original = self._modules[target]
            except KeyError:
                if target not in not_found:
                    print(f"modules.symbols: cannot find module {target}", file=sys.stderr)
                    not_found.add(target)
            else:
                self._modules[target] = ModuleDb.ModuleInfo(original.builtin,
                                                            original.aliases,
                                                            original.symbols + [alias])

        # Each line is "<module path>: <dependency path> ..."
        with open(os.path.join(module_path, "modules.dep"), "r") as f:
            for line in f.read().splitlines():
                mod, deps = line.split(':', 1)
                mod = os.path.splitext(os.path.basename(mod))[0]
                for dep in deps.split():
                    dep = os.path.splitext(os.path.basename(dep))[0]
                    if mod not in self._deps:
                        self._deps[mod] = []
                    self._deps[mod].append(dep)

    def mark_installed(self, mod, source):
        """Record that `source` explicitly requested the installation of `mod`.

        Warnings are printed to stderr when the module is unknown, was
        already requested, or is (or has) a dependency that makes one of
        the requests redundant. All dependencies of the module are
        recorded as implicitly installed.

        Args:
            mod (str): Module name or alias.
            source (str): Where the request comes from (used in messages).
        """
        real_mod = self.solve_alias(mod)
        if real_mod is None:
            print(f"WARNING: Module {mod} is not known in {source}", file=sys.stderr)
            return
        if real_mod in self._installed:
            old_mode, old_origin, old_source = self._installed[real_mod]
            if old_mode == ModuleDb.Installed.IMPLICIT:
                print(f"WARNING: Module {mod} installed by {source}, but is dependency of {old_origin} installed by {old_source}", file=sys.stderr)
            elif old_mode == ModuleDb.Installed.EXPLICIT:
                print(f"WARNING: Module {mod} installed by {source}, but it was already installed by {old_source} - you could remove it from one of the sources", file=sys.stderr)
                return
            else:
                print(f"INTERNAL ERROR: unexpected value in ModuleDb: {old_mode}",
                      file=sys.stderr)
                sys.exit(1)
        # Store under the resolved name: that is the key every lookup uses,
        # so a module requested through an alias is still recognised later.
        self._installed[real_mod] = (ModuleDb.Installed.EXPLICIT, mod, source)
        for dep in self.all_deps(real_mod):
            if dep in self._installed:
                dep_mode, dep_origin, dep_source = self._installed[dep]
                if dep_mode == ModuleDb.Installed.EXPLICIT:
                    print(f"WARNING: Module {dep} installed by {dep_source}, but is dependency of {mod} installed by {source}", file=sys.stderr)
            else:
                self._installed[dep] = (ModuleDb.Installed.IMPLICIT, mod, source)

    def all_deps(self, mod):
        """Yield every direct and indirect dependency of `mod` exactly once.

        Names are yielded in their dashed form. Nothing is yielded when
        `mod` is unknown.
        """
        real_mod = self.solve_alias(mod)
        if real_mod is None:
            return
        # Work on a copy: the traversal pops from and extends the queue,
        # which must not alter the lists stored in the dependency table.
        queue = list(self._deps.get(real_mod, []))
        done = set()
        while queue:
            m = queue.pop().replace('_', '-')
            if m in done:
                continue
            done.add(m)
            yield m
            queue.extend(self._deps.get(m, []))

    def solve_alias(self, name):
        """Return the module name `name` resolves to, or None if unknown.

        Alias chains are followed; a chain that loops back on itself
        resolves to None.
        """
        seen = set()
        while name not in self._modules:
            if name in seen or name not in self._alias:
                return None
            seen.add(name)
            name = self._alias[name]
        return name

    def __contains__(self, name):
        return self.solve_alias(name) is not None

    def __getitem__(self, name):
        """Return the ModuleInfo for a module name or alias.

        Raises:
            TypeError: if `name` is not a string.
            KeyError: if `name` does not resolve to a known module.
        """
        if not isinstance(name, str):
            raise TypeError
        real_name = self.solve_alias(name)
        if real_name is None:
            raise KeyError(name)
        return self._modules[real_name]


def path_join_make_rel_paths(path, *paths):
    """Join `paths` onto `path`, treating every one of `paths` as relative.

    os.path.join() discards all previous components as soon as it meets
    an absolute one. That is not wanted when using --root, because many
    defaults are absolute, so the second and later components are made
    relative (to '/') before joining.
    """
    return os.path.join(
        path,
        *(os.path.relpath(os.path.join('/', component), '/')
          for component in paths),
    )


# Compiled name patterns (dashed form) of modules that are reported as
# "probably automatically loaded" when listed in a configuration file.
known_automatic_modules = [
    re.compile(pattern)
    for pattern in (
        r'crypto-.*',
        r'algif-.*',
        r'fs-.*',
        r'nls-.*',
        r'dm-.*',
        r'md-.*',
    )
]

def add_modules_from_file(dest_d, kernel_root, modules_d, fw_d, conf_file, db,
                          warn_discoverable=False):
    """Install the kernel modules listed in `conf_file`.

    The file holds one module name per line; blank lines and lines
    starting with '#' or ';' are skipped. Diagnostics about each module
    are printed to stderr, non-builtin modules are recorded in `db`, and
    dracut-install is run for every listed name (as an optional module,
    so this also happens for names `db` does not know).

    Args:
        dest_d (str): Path to sysroot
        kernel_root (str): Path to kernel directory
        modules_d (str): Path to module directory
        fw_d (str): Path to firmware directory
        conf_file (str): Path to configuration file
        db (ModuleDb): A ModuleDb instance which contains information about modules from `modules_d`
        warn_discoverable (bool): Whether to warn of modules that can be discovered through aliases

    """

    def probably_automatic(name):
        # The known patterns are written with dashes
        dashed = name.replace('_', '-')
        return any(r.fullmatch(dashed) for r in known_automatic_modules)

    with open(conf_file) as f:
        for entry in f:
            module = entry.strip()
            # Nothing to do for blank lines and comments
            if not module or module[0] in ("#", ";"):
                continue

            try:
                mod = db[module]
            except KeyError:
                print(f"NOTE: {conf_file}: Module {module} not found", file=sys.stderr)
                if warn_discoverable and probably_automatic(module):
                    print(f"WARNING: {conf_file}: Module {module} is probably automatically loaded", file=sys.stderr)
            else:
                has_aliases = len(mod.aliases) > 0
                if mod.builtin:
                    print(f"NOTE: {conf_file}: Module {module} is builtin", file=sys.stderr)
                if warn_discoverable:
                    if has_aliases:
                        print(f"WARNING: {conf_file}: Module {module} has aliases:", file=sys.stderr)
                        for alias in mod.aliases:
                            print(f" * {alias}", file=sys.stderr)
                    elif probably_automatic(module):
                        print(f"WARNING: {conf_file}: Module {module} is probably automatically loaded", file=sys.stderr)

                # A module that has symbols but no alias is likely to be loaded through aliases
                if not has_aliases and len(mod.symbols) > 0:
                    print(f"WARNING: {conf_file}: Module {module} exports symbols:", file=sys.stderr)
                    for symbol in mod.symbols:
                        print(f" * {symbol}", file=sys.stderr)
                if not mod.builtin:
                    db.mark_installed(module, conf_file)

            install_cmd = [
                "/usr/lib/dracut/dracut-install",
                "-D", dest_d,
                "-r", kernel_root,
                "--kerneldir", modules_d,
                "--firmwaredirs", fw_d,
                "--module",
                "--optional",
                module,
            ]
            check_call(install_cmd)


def install_files(files, dest_dir, sysroot):
    """Copy `files` into `dest_dir` with dracut-install.

    dracut-install is run with --ldd and --all, with LD_PRELOAD emptied
    in its environment.

    Args:
        files (list): Paths to install.
        dest_dir (str): Destination tree.
        sysroot (str): Root the files are taken from.
    """
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    sysroot_args = [] if sysroot == "/" else ["-r", sysroot]
    base_cmd = ["/usr/lib/dracut/dracut-install", "-D", dest_dir, "--ldd", "--all"]
    check_call(base_cmd + sysroot_args + files,
               env=dict(os.environ, LD_PRELOAD=""))


def install_file_to_path(file, dest_dir, dest_path, sysroot):
    """Install `file` into `dest_dir` under the name `dest_path`.

    This is useful if the destination path inside dest_dir is different
    from the current file path. dracut-install is run with --ldd and
    with LD_PRELOAD emptied in its environment.
    """
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    sysroot_args = [] if sysroot == "/" else ["-r", sysroot]
    base_cmd = ["/usr/lib/dracut/dracut-install", "-D", dest_dir, "--ldd"]
    check_call(base_cmd + sysroot_args + [file, dest_path],
               env=dict(os.environ, LD_PRELOAD=""))


def package_files(pkgs, sysroot):
    """Return, as a list, the files contained in a list of deb packages.

    The listing comes from `dpkg -L`, run against the dpkg database
    found under `sysroot`.
    """
    admindir_opt = "--admindir=" + sysroot + "/var/lib/dpkg/"
    listing = check_output(["dpkg", admindir_opt, "-L", *pkgs])
    return listing.decode("utf-8").splitlines()


def install_systemd_files(dest_dir, sysroot, ubuntu_release):
    """Install the systemd and udev files needed in the initramfs.

    The file lists of the systemd packages and of udev are filtered down
    to the interesting paths, installed into `dest_dir`, and then some
    local fix-ups are applied (plymouth ordering in the ask-password
    units, generation of the udev hardware database).

    Args:
        dest_dir (str): Destination tree.
        sysroot (str): Root the packages are installed in.
        ubuntu_release (sequence): (major, minor) release numbers;
            systemd-cryptsetup is included from 24.10 on.
    """
    # Build list of files and directories

    syd_pkgs = ["systemd", "systemd-sysv"]
    if release_higher_or_equal(ubuntu_release, (24, 10)):
        syd_pkgs.append("systemd-cryptsetup")
    lines = package_files(syd_pkgs, sysroot)
    # From systemd, we pull
    # * units configuration
    # * Executables
    # * Module load options
    # * Configuration of kernel parameters
    # * udev rules
    # * Configuration for systemd-tmpfiles
    # TODO: some of this can be cleaned up
    to_include = re.compile(r"^(/usr)?/lib/systemd/system|"
                            r"/usr/share/dbus-1/system-services/org.freedesktop.systemd1.service|"
                            r"/usr/share/dbus-1/system.d/org.freedesktop.systemd1.conf|"
                            r".*/modprobe\.d/|"
                            r".*/sysctl\.d/|"
                            r".*/rules\.d/|"
                            r".*/tmpfiles\.d/|"
                            r".*bin/|"
                            r".*sbin/"
                            )
    files = [i for i in lines if to_include.match(i)]

    lines = package_files(["udev"], sysroot)
    # From udev, we pull
    # * Executables
    # * systemd configuration (units, tmpfiles)
    # * udev rules
    # * hwdb
    # TODO: some of this can be cleaned up
    to_include = re.compile(r".*bin/|"
                            r".*lib/|"
                            r".*rules\.d/"
                            )
    files += [i for i in lines if to_include.match(i)]

    files.append("/var/lib/systemd/")

    # Filter out some units we don't want
    to_remove = re.compile(".*systemd-gpt-auto-generator|"
                           ".*proc-sys-fs-binfmt_misc.automount|"
                           ".*systemd-pcrphase.*|"
                           ".*systemd-tpm2-setup.*")
    filtered = [i for i in files if not to_remove.match(i)]
    # Install
    install_files(filtered, dest_dir, sysroot)

    # Local modifications
    # This hack should be removed with PR#113
    # Drop plymouth-start.service from the After= lines of the
    # systemd-ask-password-* units, deleting lines that end up empty.
    # sed is called with an argument list (no shell), so a dest_dir
    # containing spaces or shell metacharacters cannot break the command.
    units_subpath = "usr/lib/systemd/system/systemd-ask-password-*"
    units_pattern = os.path.join(dest_dir, units_subpath)
    # As a shell would, fall back to the literal pattern when nothing
    # matches, so that sed still fails loudly on the missing units.
    units = glob.glob(os.path.join(glob.escape(dest_dir), units_subpath)) \
        or [units_pattern]
    check_call(["sed", "-i",
                r"/^After=/{;s, *plymouth-start[.]service *, ,;/^After= *$/d;}"]
               + units)
    # Generate hw database (/usr/lib/udev/hwdb.bin) for udev and
    # remove redundant definitions after that.
    check_call(["systemd-hwdb", "--root", dest_dir,
                "update", "--usr", "--strict"])
    shutil.rmtree(os.path.join(dest_dir, "usr/lib/udev/hwdb.d"))


def install_busybox(dest_dir, sysroot):
    """Install busybox as usr/bin/busybox and symlink its applets to it.

    The binary is taken from /usr/lib/initramfs-tools/bin/busybox. One
    symlink named after each wanted applet is created in usr/bin of
    `dest_dir`, pointing at "busybox".
    """
    install_file_to_path("/usr/lib/initramfs-tools/bin/busybox", dest_dir,
                         "usr/bin/busybox", sysroot)
    # Commands we want for busybox (static for more control and to help
    # with cross-building).
    applets = (
        "[ [[ acpid arch ascii ash awk base32 basename "
        "blockdev cat chmod chroot chvt clear cmp cp "
        "crc32 cut date deallocvt deluser devmem df dirname "
        "du dumpkmap echo egrep env expr false fbset fgrep "
        "find fold fstrim grep gunzip gzip hostname hwclock "
        "i2ctransfer ifconfig ip kill ln loadfont loadkmap ls "
        "lzop mkdir mkfifo mknod mkswap mktemp more "
        "mv nuke openvt pidof printf ps pwd readlink "
        "reset rm rmdir run-init sed seq setkeycodes "
        "sh sleep sort stat static-sh stty switch_root sync "
        "tail tee test touch tr true ts tty uname "
        "uniq wc wget which yes"
    ).split()
    bin_dir = os.path.join(dest_dir, "usr/bin")
    for applet in applets:
        os.symlink("busybox", os.path.join(bin_dir, applet))


def install_misc(dest_dir, sysroot, deb_arch):
    """Install miscellaneous binaries, libraries and data files.

    Besides copying the files, this creates the fsck.ext4 link, resolves
    the lazily loaded dependencies of the plymouth plugins, builds the
    ld cache and creates the /sbin compatibility symlinks.

    Args:
        dest_dir (str): Destination tree.
        sysroot (str): Root the files are taken from.
        deb_arch (str): Multiarch directory name used under /usr/lib.
    """

    def sysroot_glob(pattern):
        # Expand the absolute `pattern` inside sysroot and return the
        # matches in the same form as the literal entries of `files`
        # (absolute paths as seen from inside sysroot). Globbing on the
        # running system instead would list files that may not exist in
        # sysroot when the two differ.
        anchored = os.path.join(glob.escape(sysroot), pattern.lstrip("/"))
        return ["/" + os.path.relpath(match, sysroot)
                for match in glob.glob(anchored)]

    # dmsetup rules
    rules = package_files(["dmsetup"], sysroot)
    to_include = re.compile(r".*rules.d/")
    rules = [i for i in rules if to_include.match(i)]
    install_files(rules, dest_dir, sysroot)

    plymouth_plugins = [
        "/usr/lib/" + deb_arch + "/plymouth/label-freetype.so",
        "/usr/lib/" + deb_arch + "/plymouth/script.so",
        "/usr/lib/" + deb_arch + "/plymouth/two-step.so",
    ]
    plymouth_renderers = sysroot_glob(
        "/usr/lib/" + deb_arch + "/plymouth/renderers/*.so")

    files = [
        "/usr/bin/kmod",
        "/usr/bin/mount",
        "/usr/sbin/sulogin",
        "/usr/bin/tar",
        "/usr/lib/" + deb_arch + "/libgcc_s.so.1",
        "/usr/lib/" + deb_arch + "/libkmod.so.2",
        "/usr/lib/systemd/system/dbus.service",
        "/usr/lib/systemd/system/dbus.socket",
        "/usr/lib/systemd/system/plymouth-start.service",
        "/usr/lib/systemd/system/plymouth-switch-root.service",
        "/usr/lib/systemd/systemd-bootchart",
        "/usr/sbin/cryptsetup",
        "/usr/sbin/dmsetup",
        "/usr/sbin/e2fsck",
        "/usr/sbin/fsck",
        "/usr/bin/umount",
        "/usr/sbin/fsck.vfat",
        "/usr/sbin/mkfs.ext4",
        "/usr/sbin/mkfs.vfat",
        "/usr/sbin/sfdisk",
        "/usr/bin/dbus-daemon",
        "/usr/bin/mountpoint",
        "/usr/bin/partx",
        "/usr/bin/plymouth",
        "/usr/bin/unsquashfs",
    ]
    files += plymouth_plugins
    files += [
        "/usr/sbin/depmod",
        "/usr/sbin/insmod",
        "/usr/sbin/lsmod",
        "/usr/sbin/modinfo",
        "/usr/sbin/modprobe",
        "/usr/sbin/plymouthd",
        "/usr/sbin/rmmod",
        "/usr/share/dbus-1/system.conf",
        "/usr/share/libdrm/amdgpu.ids",
    ]
    files += sysroot_glob("/lib/" + deb_arch + "/libnss_compat.so.*")
    files += sysroot_glob("/lib/" + deb_arch + "/libnss_files.so.*")
    files += plymouth_renderers
    files += sysroot_glob("/usr/share/plymouth/themes/bgrt/*")
    files += sysroot_glob("/usr/share/plymouth/themes/spinner/*")
    install_files(files, dest_dir, sysroot)
    # Links for fsck
    os.symlink("e2fsck", os.path.join(dest_dir, "usr/sbin", "fsck.ext4"))
    # Get deps for shared objects that do not have the exec bit set and
    # that are loaded with dlopen (which means that they are not pulled
    # in by the --ldd option; use --resolvelazy instead).
    proc_env = os.environ.copy()
    proc_env["LD_PRELOAD"] = ""
    to_resolve = plymouth_plugins + plymouth_renderers
    cmd = [
        "/usr/lib/dracut/dracut-install",
        "-D", dest_dir,
        "--resolvelazy",
    ]
    # dracut-install has some bugs if you set "/" as sysroot, avoid that
    if sysroot != "/":
        cmd += ["-r", sysroot]
    check_call(cmd + to_resolve, env=proc_env)
    # Build ld cache (make sure /etc is there before)
    etc_d = os.path.join(dest_dir, "etc")
    os.makedirs(etc_d, exist_ok=True)
    check_call(["ldconfig", "-r", dest_dir])

    # /sbin/modprobe is a static path in the kernel
    os.makedirs(os.path.join(dest_dir, "sbin"))
    os.symlink("../usr/sbin/modprobe", os.path.join(dest_dir, "sbin", "modprobe"))
    # FIXME: systemd is configured with the wrong path to dmsetup
    os.symlink("../usr/sbin/dmsetup", os.path.join(dest_dir, "sbin", "dmsetup"))
    # FIXME: snap-bootstrap uses /sbin/reboot
    os.symlink("../usr/bin/systemctl", os.path.join(dest_dir, "sbin", "reboot"))


class AptRepo:
    """An apt repository: its URL plus the suites and components in use."""

    def __init__(self, url):
        # Both lists start empty and are filled in by the code that
        # builds the repository description.
        self.suites, self.components = [], []
        self.url = url


def find_apt_repos(sysroot):
    """Return the apt repositories configured in `sysroot`.

    The output of `apt-cache policy`, pointed at the sources and list
    state found under `sysroot`, is parsed into AptRepo objects. Entries
    for the same URL whose suites use the same set of components are
    merged into one AptRepo listing all those suites.

    Args:
        sysroot (str): Root whose apt configuration is inspected.

    Returns:
        list: AptRepo instances.
    """
    repos = {}
    sources = os.path.join(sysroot, "etc/apt/sources.list")
    sources_d = os.path.join(sysroot, "etc/apt/sources.list.d")
    state_d = os.path.join(sysroot, "var/lib/apt/lists")
    repo_lines = check_output(["apt-cache", "policy",
                               "-o", "Dir::Etc::SourceList="+sources,
                               "-o", "Dir::Etc::SourceParts="+sources_d,
                               "-o", "Dir::State::Lists="+state_d]).decode("utf-8").splitlines()
    # Repository lines look like "<priority> <uri> <suite>/<component> ..."
    uri_line = re.compile(r"^ *-?[0-9]+ (http|mirror|file|cdrom|ftp|copy|rsh|ssh)")
    uri_start = re.compile(r"^ *-?[0-9]+ ")
    for r in repo_lines:
        if not uri_line.match(r):
            continue
        fields = uri_start.sub("", r).split(" ")
        if len(fields) < 2:
            continue
        # suite/comp in the line
        suite_comp = fields[1].split("/")
        if len(suite_comp) != 2:
            continue
        suite, component = suite_comp
        # Key will be url + suite (we only have one suite per line)
        key = (fields[0], suite)
        if key in repos:
            repos[key].components.append(component)
        else:
            repo = AptRepo(fields[0])
            repo.suites = [suite]
            repo.components = [component]
            repos[key] = repo
    # Now merge suites that share components. Only entries of the same
    # URL may be merged, otherwise suites would be attributed to the
    # wrong repository and the other URL would be lost.
    repo_ls = list(repos.values())
    merged = set()
    final_repos = []
    for i, repo in enumerate(repo_ls):
        if i in merged:
            continue
        components = sorted(repo.components)
        for j in range(i+1, len(repo_ls)):
            other = repo_ls[j]
            if j in merged or other.url != repo.url:
                continue
            if components == sorted(other.components):
                repo.suites += other.suites
                merged.add(j)
        final_repos.append(repo)

    return final_repos


def pkgs_used_in_initrd(dest_dir, sysroot):
    MD5_HEX_CHARS = 2*hashlib.md5().digest_size

    # Get list of files we have in the initramfs and calculate their md5sum
    initrd_digests = []
    for dirpath, dirs, files in os.walk(dest_dir):
        for f in files:
            path = os.path.join(dirpath, f)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            f_stats = os.stat(path)
            if f_stats.st_size == 0:
                continue
            dig = hashlib.file_digest(open(path, "rb"), "md5").hexdigest()
            initrd_digests.append(dig)

    # Read digests from dpkg database
    root_digests = {}
    dpkg_d = os.path.join(sysroot, "var/lib/dpkg/info")
    for dirpath, dirs, files in os.walk(dpkg_d):
        for f in files:
            if not f.endswith(".md5sums"):
                continue
            pkg = f.removesuffix(".md5sums")
            # Files from libraries usually are of form pkg_name:arch, consider
            # that case
            pkg = pkg.split(":")[0]
            with open(os.path.join(dirpath, f), "r") as digests_f:
                for ln in digests_f.readlines():
                    if len(ln) >= MD5_HEX_CHARS:
                        root_digests[ln[:MD5_HEX_CHARS]] = pkg

    # Now find matches
    pkgs = set()
    for d in initrd_digests:
        pkg = root_digests.get(d)
        if pkg is not None:
            pkgs.add(pkg)

    # Files we take from systemd-sysv are symlinks and have been ignored
    pkgs.add("systemd-sysv")
    # Return sorted list
    return sorted(pkgs)


def create_initrd_pkg_list(dest_dir, sysroot):
    """Write usr/share/doc/dpkg.yaml describing the origin of the initramfs.

    The file lists the active apt repositories of `sysroot` followed by
    the name and version of every package that has files in `dest_dir`.

    Args:
        dest_dir (str): Tree holding the initramfs contents.
        sysroot (str): Root the packages are installed in.
    """
    # Find used repos
    repos = find_apt_repos(sysroot)
    # and packages used to build the initramfs
    pkgs = pkgs_used_in_initrd(dest_dir, sysroot)

    # Write active repositories first so we can have an idea of where
    # things come from
    chunks = ["package-repositories:\n"]
    for r in repos:
        chunks.append("- type: apt\n"
                      f"  url: {r.url}\n"
                      f"  suites: [ {', '.join(r.suites)} ]\n"
                      f"  components: [ {', '.join(r.components)} ]\n")

    # Now write package information
    chunks.append("packages:\n")
    chunks.append(
        check_output(["dpkg-query", "--admindir=" + sysroot + "/var/lib/dpkg/",
                      "--show", "--showformat=- ${binary:Package}=${Version}\\n"] +
                     pkgs).decode("utf-8"))

    # Write package list in yaml format. The directory may already be
    # part of the tree, which must not make the build fail.
    docs_dir = os.path.join(dest_dir, "usr", "share", "doc")
    os.makedirs(docs_dir, exist_ok=True)
    with open(os.path.join(docs_dir, "dpkg.yaml"), "w") as pkg_list:
        pkg_list.write("".join(chunks))


# verify_missing_dlopen looks at the dlopen notes of ELF binaries (as
# reported by the dlopen-notes tool) to find libraries that are not in
# the dynamic section and that will be loaded dynamically with dlopen
# when needed.
# See https://systemd.io/ELF_DLOPEN_METADATA/
def verify_missing_dlopen(destdir, libdir):
    """Check that the libraries ELF files may dlopen are present.

    Every regular ELF file under `destdir` is inspected with
    `dlopen-notes`. A feature is satisfied when at least one of its
    sonames exists in `libdir` inside `destdir`; otherwise all of its
    sonames are reported as missing, with the highest priority seen for
    each soname.

    Args:
        destdir (str): Tree to inspect.
        libdir (str): Absolute library directory, as seen from inside
            `destdir`.

    Returns:
        bool: False if a "required" or "recommended" soname is missing,
        True otherwise.
    """
    # Higher rank wins when the same soname is missing for several features
    priority_rank = {"suggested": 1, "recommended": 2, "required": 3}
    dest_libdir = os.path.join(destdir, os.path.relpath(libdir, "/"))

    missing = {}
    for dirpath, dirs, files in os.walk(destdir):
        for f in files:
            path = os.path.join(dirpath, f)
            if os.path.islink(path) or not os.path.isfile(path):
                continue
            with open(path, 'rb') as b:
                if b.read(4) != b'\x7fELF':
                    continue
            out = check_output(["dlopen-notes", path])
            # Lines starting with '#' are not part of the JSON document
            json_doc = b'\n'.join(s for s in out.splitlines() if s[:1] != b'#')
            for dep in json.loads(json_doc):
                sonames = dep["soname"]
                priority = dep["priority"]
                # dest_libdir already includes destdir; joining destdir a
                # second time would break for a relative destdir.
                if any(os.path.exists(os.path.join(dest_libdir, soname))
                       for soname in sonames):
                    continue
                # We did not find any library.
                # In this case we need to mark all sonames as
                # missing. This is required because some features
                # may have common subset of sonames and those
                # features might have different priorities.
                for soname in sonames:
                    current_priority = missing.get(soname)
                    if (current_priority in priority_rank and
                            priority_rank.get(priority, 0) <= priority_rank[current_priority]):
                        continue
                    missing[soname] = priority

    fatal = False
    if missing:
        print("WARNING: These sonames are missing:", file=sys.stderr)
        for m, priority in missing.items():
            print(f" * {m} ({priority})", file=sys.stderr)
            if priority in ["required", "recommended"]:
                fatal = True
        if fatal:
            print("WARNING: Some missing sonames are required or recommended. Failing.", file=sys.stderr)

    return not fatal


def get_deb_arch(sysroot):
    """Return DEB_HOST_MULTIARCH as reported by dpkg-architecture.

    dpkg-architecture is run with DPKG_DATADIR pointing at the dpkg data
    directory of `sysroot`; the first line of its output is returned.
    """
    env = dict(os.environ, DPKG_DATADIR=sysroot + "/usr/share/dpkg")
    query = ["dpkg-architecture", "-q", "DEB_HOST_MULTIARCH"]
    answer = check_output(query, env=env).decode("utf-8")
    return answer.splitlines()[0]


def get_ubuntu_release(sysroot):
    """Return the release of `sysroot` as a [major, minor] list of ints.

    The numbers come from the VERSION_ID line of etc/os-release. If no
    such line is found, an error is printed and the program exits with
    status 1.
    """
    version_id = re.compile('^VERSION_ID="([0-9]+)[.]([0-9]+)"$', re.MULTILINE)
    with open(os.path.join(sysroot, "etc/os-release"), "r") as f:
        found = version_id.search(f.read())
    if not found:
        print("ERROR: cannot find Ubuntu release", file=sys.stderr)
        sys.exit(1)
    major, minor = found.groups()
    return [int(major), int(minor)]


def release_higher_or_equal(release1, release2):
    """Return True if `release1` is the same as or newer than `release2`.

    Both arguments are (major, minor) sequences; only their first two
    items are compared, major first.
    """
    # Build tuples explicitly so that a list can be compared with a tuple
    first = (release1[0], release1[1])
    second = (release2[0], release2[1])
    return first >= second


def create_initrd(parser, args):
    """Build the initramfs image described by `args`.

    Kernel modules and firmware are copied to a temporary directory, the
    initramfs tree is assembled next to them (busybox, systemd/udev,
    miscellaneous files, the selected feature skeletons and their kernel
    modules), a package manifest is generated, and finally the tree is
    packed as a zstd-compressed cpio archive appended to any early cpio
    archives found in the skeleton.

    Args:
        parser: Argument parser; not used by this function.
        args: Parsed arguments. The attributes read here are kerneldir,
            kernelver, root, skeleton, firmwaredir, output and features;
            several of them are rewritten in place.
    """
    # TODO generate microcode instead of shipping in debian package
    rootfs = "/"
    if not args.kerneldir:
        args.kerneldir = "/lib/modules/%s" % args.kernelver
    # With --root, all the paths are taken relative to that directory
    if args.root:
        args.skeleton = path_join_make_rel_paths(args.root, args.skeleton)
        args.kerneldir = path_join_make_rel_paths(args.root, args.kerneldir)
        args.firmwaredir = path_join_make_rel_paths(args.root, args.firmwaredir)
        args.output = path_join_make_rel_paths(args.root, args.output)
        rootfs = args.root
    # The output file name gets the kernel version as suffix
    if args.kernelver:
        args.output = "-".join([args.output, args.kernelver])
    with tempfile.TemporaryDirectory(suffix=".ubuntu-core-initramfs") as d:
        deb_arch = get_deb_arch(rootfs)
        ubuntu_release = get_ubuntu_release(rootfs)

        # Stage modules and firmware under <tmp>/kernel/usr/lib, which is
        # then used as the root for installing kernel modules.
        kernel_root = os.path.join(d, "kernel")
        modules = os.path.join(kernel_root, "usr", "lib", "modules")
        os.makedirs(modules, exist_ok=True)
        modules = os.path.join(modules, args.kernelver)
        check_call(["cp", "-ar", args.kerneldir, modules])

        firmware = os.path.join(kernel_root, "usr", "lib", "firmware")
        check_call(["cp", "-ar", args.firmwaredir, firmware])

        db = ModuleDb(modules)
        # <tmp>/main is the tree that becomes the initramfs
        main = os.path.join(d, "main")
        os.makedirs(main, exist_ok=True)
        # copy busybox first so we get already the shell interpreter we
        # want (busybox) instead of dracut-install pulling the systemd
        # default (dash) when it pulls a shell script later.
        install_busybox(main, rootfs)
        # Copy systemd bits
        install_systemd_files(main, rootfs, ubuntu_release)
        # Other miscellaneous stuff
        install_misc(main, rootfs, deb_arch)
        # Copy features
        for feature in args.features:
            # Add feature files
            feature_path = os.path.join(args.skeleton, feature)
            if os.path.isdir(feature_path):
                check_call(["cp", "-aT", feature_path, main])
            # Add feature kernel modules
            extra_modules = os.path.join(args.skeleton, "modules", feature,
                                         "extra-modules.conf")
            if os.path.exists(extra_modules):
                add_modules_from_file(main, kernel_root, modules, firmware, extra_modules, db)

        # TODO xnox: fips actually needs additional runtime dependencies than
        # this, and currently forked in fips PPA. we need to figure out how to
        # make ubuntu-core-initramfs-fips a reality.
        # NOTE(review): the glob below is expanded on the running system,
        # not inside rootfs - confirm that this is intended with --root.
        if "fips" in args.features:
            install_files([
                    "/usr/bin/kcapi-hasher",
                    "/usr/bin/.kcapi-hasher.hmac",
            ] + glob.glob("/usr/lib/*/.libkcapi.so.*.hmac"), main, rootfs)

        # Update epoch
        pathlib.Path("%s/main/usr/lib/clock-epoch" % d).touch()
        # Should iterate all the .conf drop ins
        for module_load in glob.iglob("%s/usr/lib/modules-load.d/*.conf" % main):
            add_modules_from_file(main, kernel_root, modules, firmware, module_load, db,
                                  warn_discoverable=True)

        # Copy the module index files that are needed as they are, then
        # let depmod generate the rest for the modules that were installed.
        for modulesf in ["modules.order", "modules.builtin", "modules.builtin.bin", "modules.builtin.modinfo"]:
            check_call(
                [
                    "/usr/lib/dracut/dracut-install",
                    "-D",
                    main,
                    "-r", kernel_root,
                    os.path.join(modules, modulesf),
                    os.path.join("usr/lib/modules", args.kernelver, modulesf),
                ]
            )
        check_call(["depmod", "-a", "-b", main, args.kernelver])

        # The dlopen check is only done from 24.10 on; a failed check
        # aborts the build.
        if release_higher_or_equal(ubuntu_release, (24, 10)):
            if not verify_missing_dlopen(main, os.path.join("/usr/lib", deb_arch)):
                sys.exit(1)

        # Create manifest with packages with files included in the initramfs
        create_initrd_pkg_list(main, rootfs)

        # Finally, write it
        with open(args.output, "wb") as output:
            # Early cpio archives from the skeleton go first, verbatim
            for early in glob.iglob("%s/early/*.cpio" % args.skeleton):
                with open(early, "rb") as f:
                    shutil.copyfileobj(f, output)
            # The whole compressed archive is captured in memory before
            # being written to the output file.
            output.write(
                run(
                    "find . | cpio --create --quiet --format='newc' --owner=0:0 | zstd -1 -T0",
                    cwd=main,
                    capture_output=True,
                    shell=True,
                    check=True,
                ).stdout)


def create_efi(parser, args):
    """Build a kernel EFI application with ukify and, unless disabled, sign it.

    Args:
        parser: the argparse parser, used only to report usage errors.
        args: parsed ``create-efi`` arguments. The path attributes
            (``os_release``, ``stub``, ``kernel``, ``initrd``, ``key``,
            ``cert``, ``output``) are rewritten in place when ``--root``
            and/or ``--kernelver`` are set.

    Raises:
        SystemExit: via ``parser.error`` when ``--root`` is given without
            ``--stub``.
        subprocess.CalledProcessError: when ukify or sbsign fails.
    """
    if args.root:
        if not args.stub:
            parser.error("--stub is mandatory when --root is given")
        # Re-anchor every path argument below the alternate root.
        for attr in ("os_release", "stub", "kernel", "initrd",
                     "key", "cert", "output"):
            setattr(args, attr,
                    path_join_make_rel_paths(args.root, getattr(args, attr)))
    if args.kernelver:
        args.kernel = "-".join([args.kernel, args.kernelver])
        args.initrd = "-".join([args.initrd, args.kernelver])
        args.output = "-".join([args.output, args.kernelver])

    # Temporary file holding the decompressed kernel (aarch64 only). It is
    # deleted when closed, so it has to stay open until ukify has read it.
    raw_kernel = None
    try:
        # Decompression does not depend on whether a version suffix was
        # appended, so it is attempted regardless of --kernelver.
        if platform.machine() == "aarch64":
            import gzip
            raw_kernel = tempfile.NamedTemporaryFile(mode='wb')
            try:
                with gzip.open(args.kernel, 'rb') as comp_kernel:
                    shutil.copyfileobj(comp_kernel, raw_kernel)
                raw_kernel.flush()
                args.kernel = raw_kernel.name
            except gzip.BadGzipFile:
                # Not a gzip file: hand the kernel to ukify unchanged.
                pass

        # TODO: add --splash
        ukify_cmd = [
            '/usr/lib/systemd/ukify', 'build',
            '--linux', args.kernel,
            '--initrd', args.initrd,
            '--output', args.output,
            '--os-release', '@{}'.format(args.os_release),
        ]

        # Optional ukify switches, added only when a value was supplied.
        if args.kernelver:
            ukify_cmd += ['--uname', args.kernelver]
        if args.stub:
            ukify_cmd += ['--stub', args.stub]
        if args.efi_arch:
            ukify_cmd += ['--efi-arch', args.efi_arch]
        if args.cmdline:
            ukify_cmd += ['--cmdline', args.cmdline]
        check_call(ukify_cmd)
    finally:
        # Remove the decompressed kernel deterministically instead of
        # relying on garbage collection.
        if raw_kernel is not None:
            raw_kernel.close()

    if not args.unsigned:
        # Sign in place: sbsign reads and writes the same output file.
        check_call(
            [
                "sbsign",
                "--key",
                args.key,
                "--cert",
                args.cert,
                "--output",
                args.output,
                args.output,
            ]
        )
    os.chmod(args.output, 0o644)


def main():
    """Parse the command line and run the selected sub-command.

    Two sub-commands are defined, ``create-efi`` and ``create-initrd``. The
    chosen one is dispatched to the module-level function of the same name
    (dashes replaced by underscores), called with ``(parser, args)``.
    """
    # Release of the running kernel; default for both --kernelver options.
    kernelver = check_output(
        ["uname", "-r"], universal_newlines=True
    ).strip()
    # Architecture specific default feature set (space separated).
    features = {"x86_64": "main server"}.get(platform.machine(), "main")
    parser = argparse.ArgumentParser()
    subparser = parser.add_subparsers(dest="subcmd", required=True)

    efi_parser = subparser.add_parser(
        "create-efi", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    efi_parser.add_argument("--root", help="path to root")
    efi_parser.add_argument("--stub", help="path to stub")
    efi_parser.add_argument("--efi-arch", help="efi architecture name")
    efi_parser.add_argument("--os-release", help="path to os-release", default="/etc/os-release")
    efi_parser.add_argument("--kernel", help="path to kernel", default="/boot/vmlinuz")
    efi_parser.add_argument(
        "--kernelver", help="kernel version suffix", default=kernelver
    )
    efi_parser.add_argument(
        "--initrd", help="path to initramfs", default="/boot/ubuntu-core-initramfs.img"
    )
    efi_parser.add_argument(
        "--cmdline", help="commandline to embed (can be overridden with LoadOptions)"
    )
    efi_parser.add_argument(
        "--unsigned", help="do not sign efi app", default=False, action="store_true"
    )
    efi_parser.add_argument(
        "--key",
        help="signing key",
        default="/usr/lib/ubuntu-core-initramfs/snakeoil/PkKek-1-snakeoil.key",
    )
    efi_parser.add_argument(
        "--cert",
        help="certificate",
        default="/usr/lib/ubuntu-core-initramfs/snakeoil/PkKek-1-snakeoil.pem",
    )
    efi_parser.add_argument(
        "--output", help="path to output", default="/boot/kernel.efi"
    )

    initrd_parser = subparser.add_parser(
        "create-initrd", formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    initrd_parser.add_argument("--root", help="path to root", default="")
    initrd_parser.add_argument(
        "--skeleton", help="path to skeleton", default="/usr/lib/ubuntu-core-initramfs"
    )
    initrd_parser.add_argument(
        "--feature",
        dest="features",
        help="list of features to enable, if unspecified defaults to %s" % features,
        nargs="+"
    )
    initrd_parser.add_argument(
        "--kerneldir", help="path to kernel modules dir default /lib/modules/KERNELVER"
    )
    # The default already comes from the single `uname -r` call above; no
    # second subprocess invocation / set_defaults() is needed.
    initrd_parser.add_argument("--kernelver", help="kernel version", default=kernelver)
    initrd_parser.add_argument(
        "--firmwaredir", help="path to kernel firmware", default="/lib/firmware"
    )
    initrd_parser.add_argument(
        "--output", help="path to output", default="/boot/ubuntu-core-initramfs.img"
    )

    args = parser.parse_args()
    if args.subcmd == "create-initrd" and not args.features:
        # Set arch specific default set of features
        args.features = features.split()
        # For generic kernels, enable more server drivers too
        if args.kernelver.endswith("-generic") and "server" not in args.features:
            args.features.append("server")
        if args.kernelver.endswith("-fips"):
            args.features.append("fips")
    # Look up the handler by name: "create-efi" -> create_efi, and so on.
    globals()[args.subcmd.replace("-", "_")](parser, args)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
