#!/usr/bin/python3
"""
    Copyright (C) 2021  Michael Ablassmeier <abi@grinser.de>

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import sys
import json
import signal
import logging
import argparse
from argparse import Namespace
from time import time
from typing import List, Union, BinaryIO, Any
from datetime import datetime
from functools import partial
from threading import current_thread
from concurrent.futures import ThreadPoolExecutor, as_completed
from libvirt import virDomain
from libvirtnbdbackup import argopt
from libvirtnbdbackup import __version__
from libvirtnbdbackup import nbdcli
from libvirtnbdbackup import extenthandler
from libvirtnbdbackup import qemu
from libvirtnbdbackup import virt
from libvirtnbdbackup.virt.client import DomainDisk
from libvirtnbdbackup import output
from libvirtnbdbackup.common import common as lib
from libvirtnbdbackup.common.processinfo import processInfo
from libvirtnbdbackup.logcount import logCount
from libvirtnbdbackup.sparsestream import streamer
from libvirtnbdbackup.sparsestream import types
from libvirtnbdbackup import exceptions
from libvirtnbdbackup.ssh.exceptions import sshError
from libvirtnbdbackup.nbdcli.exceptions import NbdClientException
from libvirtnbdbackup.qemu.exceptions import ProcessError
from libvirtnbdbackup.virt.exceptions import (
    startBackupFailed,
    domainNotFound,
    connectionFailed,
)
from libvirtnbdbackup.output.exceptions import OutputException


def handleSignal(
    args: Namespace,
    domObj: virDomain,
    virtClient: virt.client,
    log: Any,
    signum: int,
    _,
) -> None:
    """Signal handler: report the received signal, stop the backup job
    and terminate the process.

    The backup job is only stopped via `virtClient.stopBackup` if
    `args.offline` is not True. The process always exits with code 1.
    The last positional parameter (stack frame passed by the signal
    module) is ignored.
    """
    log.error("Caught signal: %s", signum)
    log.error("Cleanup: Stopping backup job")

    domainOffline = args.offline is True
    if not domainOffline:
        virtClient.stopBackup(domObj)

    sys.exit(1)


def checkForeign(
    args: Namespace,
    virtClient: virt.client,
    domObj: virDomain,
) -> bool:
    """Verify the domain has no checkpoints created by other tools.

    The check is only executed for backup levels full, inc and diff,
    for all other levels the function returns True right away.

    Returns True if no foreign checkpoint has been reported.
    Raises exceptions.ForeignCeckpointError (after logging the details)
    if a foreign checkpoint was found.
    """
    checkedLevels = ("full", "inc", "diff")
    foreignCheckpoint = (
        virtClient.hasforeignCheckpoint(domObj, lib.checkpointName)
        if args.level in checkedLevels
        else None
    )

    if foreignCheckpoint:
        logging.fatal("Foreign checkpoint found: [%s]", foreignCheckpoint)
        logging.fatal("This checkpoint has not been created by this utility.")
        logging.fatal(
            "To ensure backup chain consistency, "
            "remove existing checkpoints "
            "and start a new backup chain by creating a full backup."
        )
        raise exceptions.ForeignCeckpointError

    return True


def setOfflineArguments(args: Namespace, domObj: virDomain) -> None:
    """Detect whether the domain is inactive and adjust the backup
    options accordingly.

    Sets `args.offline` to True if `domObj.isActive()` returns 0,
    otherwise to False. For an inactive domain a requested "full"
    backup is changed to level "copy".
    """
    args.offline = False
    if domObj.isActive() != 0:
        return

    if args.level == "full":
        logging.warning("Domain is offline, resetting backup options.")
        args.level = "copy"
        logging.info("Backup level: [%s].", args.level)

    args.offline = True


def hasPartial(args: Namespace) -> bool:
    """Report whether the target directory contains a partial backup.

    Only evaluated for backup levels inc and diff and only if the
    backup is not written to stdout. Returns True (after logging an
    error) if `lib.partialBackup` reports a partial backup, so the
    caller can abort; returns False in every other case.
    """
    if args.level not in ("inc", "diff"):
        return False
    if args.stdout is not False:
        return False
    if lib.partialBackup(args) is not True:
        return False

    logging.error("Partial backup found in target directory: [%s]", args.output)
    logging.error("One of the last backups seems to have failed.")
    logging.error("Consider re-executing full backup.")
    return True


def readCheckpointFile(cFile: str) -> List[str]:
    """Read the list of checkpoints stored in the checkpoint file.

    The file is expected to contain a JSON document, it is read in
    binary mode and decoded before parsing.

    Returns the decoded JSON content, or an empty list if the file does
    not exist.
    Raises exceptions.ReadCheckpointsError if the file can not be opened
    or if its content can not be decoded or parsed; the original error
    is chained as cause.
    """
    # Use the os module imported by this file instead of reaching for it
    # through the lib module's namespace.
    if not os.path.exists(cFile):
        return []

    try:
        with output.openfile(cFile, "rb") as fh:
            return json.loads(fh.read().decode())
    except OutputException as e:
        raise exceptions.ReadCheckpointsError(
            f"Failed to read checkpoint file: [{e}]"
        ) from e
    except (json.decoder.JSONDecodeError, UnicodeDecodeError) as e:
        # A file with undecodable bytes is just as invalid as one with
        # broken JSON syntax, report both the same way.
        raise exceptions.ReadCheckpointsError(
            f"Invalid checkpoint file: [{e}]"
        ) from e


def handleCheckpoints(
    args: Namespace,
    virtClient: virt.client,
    domObj: virDomain,
) -> None:
    """Prepare checkpoint information for the requested backup level.

    Reads the existing checkpoints from `<output>/<domain>.cpt`,
    redefines them for online domains, removes them before a full
    backup and works out the name of the checkpoint to create and of
    its parent for inc/diff backups.

    The result is stored as namespace `args.cpt` (attributes `name`,
    `parent` and `file`) so it can be passed around with the arguments.

    Raises subclasses of the checkpoint exceptions from the exceptions
    module: RedefineCheckpointError, RemoveCheckpointError and
    NoCheckpointsFound (inc/diff without existing checkpoints).
    """
    cptName: str = f"{lib.checkpointName}.0"
    cptParent: str = ""
    cptFile: str = f"{args.output}/{args.domain}.cpt"
    existing: List[str] = readCheckpointFile(cptFile)

    if args.offline is False and virtClient.redefineCheckpoints(domObj, args) is False:
        raise exceptions.RedefineCheckpointError("Failed to redefine checkpoints.")

    logging.info("Checkpoint handling.")
    if args.level == "full":
        # A full backup starts a new chain: drop everything known so far.
        if existing:
            logging.info("Removing all existent checkpoints before full backup.")
            removed = virtClient.removeAllCheckpoints(
                domObj, existing, args, lib.checkpointName
            )
            if not removed:
                raise exceptions.RemoveCheckpointError("Failed to remove checkpoint.")
            os.remove(cptFile)
            existing = []
        elif len(existing) < 1:
            removed = virtClient.removeAllCheckpoints(
                domObj, None, args, lib.checkpointName
            )
            if not removed:
                raise exceptions.RemoveCheckpointError("Failed to remove checkpoint.")
            existing = []

    needsParent = args.level in ("inc", "diff")
    if existing and needsParent:
        nextId = len(existing)
        cptParent = existing[-1]
        cptName = f"{lib.checkpointName}.{nextId}"
        logging.info("Next checkpoint id: [%s].", nextId)
        logging.info("Parent checkpoint name [%s].", cptParent)

        if args.offline is True:
            # No new checkpoint can be created, reuse the latest one.
            logging.info("Offline backup, using latest checkpoint, saving only delta.")
            cptName = cptParent

    if args.level == "diff":
        logging.info("Diff backup: saving delta since checkpoint: [%s].", cptParent)

    if needsParent and len(existing) < 1:
        raise exceptions.NoCheckpointsFound(
            "No existing checkpoints found, execute full backup first."
        )

    if args.level in ("full", "inc"):
        logging.info("Using checkpoint name: [%s].", cptName)

    args.cpt = Namespace(name=cptName, parent=cptParent, file=cptFile)

    logging.debug("Checkpoint info: [%s].", vars(args.cpt))


def saveCheckpointFile(args: Namespace) -> None:
    """Append the checkpoint created by this run to the checkpoint file.

    Reads the list stored in `args.cpt.file`, appends `args.cpt.name`
    and writes the JSON encoded list back to the same file.

    Raises:
        exceptions.CheckpointException: (or a subclass of it) raised by
            readCheckpointFile while reading the existing file; it is
            passed on unchanged so its message is kept.
        exceptions.SaveCheckpointError: if writing the file fails.
    """
    try:
        checkpoints = readCheckpointFile(args.cpt.file)
        checkpoints.append(args.cpt.name)
        with output.openfile(args.cpt.file, "wb") as cFw:
            cFw.write(json.dumps(checkpoints).encode())
    except OutputException as e:
        # Attach the reason to the raised exception: callers log the
        # exception itself, an exception without arguments would show
        # up as empty message.
        raise exceptions.SaveCheckpointError(
            f"Failed to write checkpoint file: [{e}]"
        ) from e


def getFileStream(
    args: Namespace, repository: output.target
) -> Union[output.target.Directory, output.target.Zip]:
    """Create the output handler matching the selected output mode.

    For regular backups a `Directory` object for `args.output` is
    returned. If the backup is written to stdout a `Zip` object is
    returned instead; in that case `args.output` is reset to "./" and
    the amount of workers is limited to one.
    """
    if args.stdout is not False:
        zipTarget = repository.Zip()
        args.output = "./"
        args.worker = 1
        return zipTarget

    return repository.Directory(args.output)


def startBackupJob(
    args: Namespace,
    virtClient: virt.client,
    domObj: virDomain,
    disks: List[DomainDisk],
) -> bool:
    """Ask libvirt to start the backup job for the given disks.

    Returns True if the job has been started. If the client raises
    startBackupFailed the error is logged and False is returned.
    """
    logging.info("Starting backup job.")
    try:
        virtClient.startBackup(args, domObj, disks)
    except startBackupFailed as e:
        logging.error(e)
        return False

    logging.debug("Backup job started.")
    return True


def main() -> None:
    """Command line entry point: parse arguments and execute the backup.

    Rough order of operation:
      1. define and parse the command line options,
      2. set up output target and logging,
      3. validate the option combination and resolve backup level "auto",
      4. connect to libvirt, look up the domain and its disks,
      5. prepare checkpoints and start the backup job (online domains),
      6. backup all disks using a thread pool,
      7. stop the backup job, save additional files and set the exit code.

    The function terminates the process via sys.exit: code 1 on errors,
    code 2 if --strict is set and warnings were counted, code 0 for the
    early-exit options (killonly, printonly, threshold not reached,
    startonly).
    """
    # Command line definition.
    parser = argparse.ArgumentParser(
        description="Backup libvirt/qemu virtual machines",
        epilog=(
            "Examples:\n"
            "   # full backup of domain 'webvm' with all attached disks:\n"
            "\t%(prog)s -d webvm -l full -o /backup/\n"
            "   # incremental backup:\n"
            "\t%(prog)s -d webvm -l inc -o /backup/\n"
            "   # differential backup:\n"
            "\t%(prog)s -d webvm -l diff -o /backup/\n"
            "   # full backup, exclude disk 'vda':\n"
            "\t%(prog)s -d webvm -l full -x vda -o /backup/\n"
            "   # full backup, backup only disk 'vdb':\n"
            "\t%(prog)s -d webvm -l full -i vdb -o /backup/\n"
            "   # full backup, compression enabled:\n"
            "\t%(prog)s -d webvm -l full -z -o /backup/\n"
            "   # full backup, create archive:\n"
            "\t%(prog)s -d webvm -l full -o - > backup.zip\n"
            "   # full backup of vm operating on remote libvirtd:\n"
            "\t%(prog)s -U qemu+ssh://root@remotehost/system "
            "--ssh-user root -d webvm -l full -o /backup/\n"
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )

    opt = parser.add_argument_group("General options")
    opt.add_argument("-d", "--domain", required=True, type=str, help="Domain to backup")
    opt.add_argument(
        "-l",
        "--level",
        default="copy",
        choices=["copy", "full", "inc", "diff", "auto"],
        type=str,
        help="Backup level. (default: %(default)s)",
    )
    opt.add_argument(
        "-t",
        "--type",
        default="stream",
        type=str,
        choices=["stream", "raw"],
        help="Output type: stream or raw. (default: %(default)s)",
    )
    opt.add_argument(
        "-r",
        "--raw",
        default=False,
        action="store_true",
        help="Include full provisioned disk images in backup. (default: %(default)s)",
    )
    opt.add_argument(
        "-o", "--output", required=True, type=str, help="Output target directory"
    )
    opt.add_argument(
        "-C",
        "--checkpointdir",
        required=False,
        default=None,
        type=str,
        help="Persistent libvirt checkpoint storage directory",
    )
    opt.add_argument(
        "-S",
        "--scratchdir",
        default="/var/tmp",
        required=False,
        type=str,
        help="Target dir for temporary scratch file. (default: %(default)s)",
    )
    opt.add_argument(
        "-i",
        "--include",
        default=None,
        type=str,
        help="Backup only disk with target dev name (-i vda)",
    )
    opt.add_argument(
        "-x",
        "--exclude",
        default=None,
        type=str,
        help="Exclude disk(s) with target dev name (-x vda,vdb)",
    )
    opt.add_argument(
        "-f",
        "--socketfile",
        default=f"/var/tmp/virtnbdbackup.{os.getpid()}",
        type=str,
        help="Use specified file for NBD Server socket (default: %(default)s)",
    )
    opt.add_argument(
        "-n",
        "--noprogress",
        default=False,
        help="Disable progress bar",
        action="store_true",
    )
    opt.add_argument(
        "-z",
        "--compress",
        default=False,
        type=int,
        const=2,
        nargs="?",
        help="Compress with lz4 compression level. (default: %(default)s)",
        action="store",
    )
    opt.add_argument(
        "-w",
        "--worker",
        type=int,
        default=None,
        help=(
            "Amount of concurrent workers used "
            "to backup multiple disks. (default: amount of disks)"
        ),
    )
    opt.add_argument(
        "-F",
        "--freeze-mountpoint",
        type=str,
        default=None,
        help=(
            "If qemu agent available, freeze only filesystems on specified mountpoints within"
            " virtual machine (default: all)"
        ),
    )
    opt.add_argument(
        "-e",
        "--strict",
        default=False,
        help=(
            "Change exit code if warnings occur during backup operation. "
            "(default: %(default)s)"
        ),
        action="store_true",
    )
    opt.add_argument(
        "-T",
        "--threshold",
        type=int,
        default=None,
        help=("Execute backup only if threshold is reached."),
    )
    remopt = parser.add_argument_group("Remote Backup options")
    argopt.addRemoteArgs(remopt)
    logopt = parser.add_argument_group("Logging options")
    logopt.add_argument(
        "-L",
        "--syslog",
        default=False,
        action="store_true",
        help="Additionally send log messages to syslog (default: %(default)s)",
    )
    debopt = parser.add_argument_group("Debug options")
    debopt.add_argument(
        "-q",
        "--qemu",
        default=False,
        action="store_true",
        help="Use Qemu tools to query extents.",
    )
    debopt.add_argument(
        "-s",
        "--startonly",
        default=False,
        help="Only initialize backup job via libvirt, do not backup any data",
        action="store_true",
    )
    debopt.add_argument(
        "-k",
        "--killonly",
        default=False,
        help="Kill any running block job",
        action="store_true",
    )
    debopt.add_argument(
        "-p",
        "--printonly",
        default=False,
        help="Quit after printing estimated checkpoint size.",
        action="store_true",
    )
    argopt.addDebugArgs(debopt)

    repository = output.target()
    args = lib.argparse(parser)

    # Output "-" means: write the backup to stdout. The additional
    # attributes set here are used by the helper functions later on.
    args.stdout = args.output == "-"
    args.sshClient = None
    args.diskInfo = []

    try:
        fileStream = getFileStream(args, repository)
    except OutputException as e:
        logging.error("Cant open output file: [%s]", e)
        sys.exit(1)

    # Values below 1 make no sense for the amount of workers.
    if args.worker is not None and args.worker < 1:
        args.worker = 1

    # Logging setup, exit if lib.getLogFile returns a falsy value.
    now = datetime.now().strftime("%m%d%Y%H%M%S")
    logFile = f"{args.output}/backup.{args.level}.{now}.log"
    fileLog = lib.getLogFile(logFile) or sys.exit(1)

    # The counter is evaluated at the end of the function to decide
    # about the exit code (count.errors / count.warnings).
    counter = logCount()
    lib.configLogger(args, fileLog, counter)
    lib.printVersion(__version__)

    logging.info("Backup level: [%s]", args.level)
    if args.compress:
        logging.info("Compression enabled, level [%s]", args.compress)

    # Option combinations not supported if output is stdout.
    if args.stdout is True and args.type == "raw":
        logging.error("Output type raw not supported to stdout.")
        sys.exit(1)

    if args.stdout is True and args.raw is True:
        logging.error("Saving raw images to stdout is not supported.")
        sys.exit(1)

    # Resolve level "auto": full backup into an empty target, otherwise
    # incremental (requires an existing full backup). For all other
    # levels the target is checked via lib.targetIsEmpty.
    if lib.targetIsEmpty(args) and args.level == "auto":
        logging.info("Backup mode auto, target folder is empty: executing full backup.")
        args.level = "full"
    elif not lib.targetIsEmpty(args) and args.level == "auto":
        if not lib.hasFullBackup(args):
            logging.error(
                "Cant execute switch to auto incremental backup: "
                "target folder seems not to include full backup."
            )
            sys.exit(1)
        logging.info("Backup mode auto: executing incremental backup.")
        args.level = "inc"
    elif not args.stdout and not args.startonly and not args.killonly:
        if not lib.targetIsEmpty(args):
            logging.error("Target directory already contains full or copy backup.")
            sys.exit(1)

    if args.raw is True and args.level in ("inc", "diff"):
        logging.warning(
            "Raw disks cant be included during incremental or differential backup."
        )
        logging.warning("Excluding raw disks.")
        args.raw = False

    if args.type == "raw" and args.level in ("inc", "diff"):
        logging.error(
            "Stream format raw does not support incremental or differential backup."
        )
        sys.exit(1)

    if hasPartial(args):
        sys.exit(1)

    # Default checkpoint directory is located below the output directory.
    if not args.checkpointdir:
        args.checkpointdir = f"{args.output}/checkpoints"
    else:
        logging.info("Store checkpoints in: [%s]", args.checkpointdir)

    # NOTE(review): return value unused, presumably called to create the
    # directory - confirm in output.target.
    repository.Directory(args.checkpointdir)

    # Connect to libvirt and look up the domain.
    try:
        virtClient = virt.client(args)
        domObj = virtClient.getDomain(args.domain)
    except domainNotFound as e:
        logging.error("%s", e)
        sys.exit(1)
    except connectionFailed as e:
        logging.error("Cant connect libvirt daemon: [%s]", e)
        sys.exit(1)

    logging.info("Libvirt library version: [%s]", virtClient.libvirtVersion)

    if virtClient.hasIncrementalEnabled(domObj) is False:
        logging.error("Domain is missing required incremental-backup capability.")
        sys.exit(1)

    # checkForeign logs the details itself before raising.
    try:
        checkForeign(args, virtClient, domObj)
    except exceptions.CheckpointException:
        sys.exit(1)

    setOfflineArguments(args, domObj)
    if args.offline is True and args.startonly is True:
        logging.error("Virtual machine is currently offline")
        logging.error("Virtual machine must be active for this function.")
        sys.exit(1)

    # From here on SIGINT attempts to stop the backup job before exiting.
    signal.signal(
        signal.SIGINT, partial(handleSignal, args, domObj, virtClient, logging)
    )

    vmConfig = virtClient.getDomainConfig(domObj)
    disks: List[DomainDisk] = virtClient.getDomainDisks(args, vmConfig)
    args.info = virtClient.getDomainInfo(vmConfig)

    if args.level != "copy" and lib.hasQcowDisks(disks) is False:
        args.level = "copy"
        logging.info("Only raw disks attached, switching to backup mode copy.")

    if not disks:
        logging.error("Unable to detect disks suitable for backup.")
        # Config and boot files are saved even if no disk can be backed up.
        saveFiles(args, vmConfig, disks, fileStream, logFile)
        sys.exit(1)
    if (
        not args.killonly
        and not args.offline
        and virtClient.blockJobActive(domObj, disks)
    ):
        logging.error("Detected an active backup operation for running domain.")
        logging.error("Check with [virsh domjobinfo %s]", args.domain)
        sys.exit(1)

    logging.info(
        "Backup will save [%s] attached disks.",
        len(disks),
    )
    # Never use more workers than there are disks to backup.
    if args.worker is None or args.worker > int(len(disks)):
        args.worker = int(len(disks))
    logging.info("Concurrent backup processes: [%s]", args.worker)

    # Debug option: only stop a running backup job, then exit.
    if args.killonly is True:
        logging.info("Stopping backup job")
        if not virtClient.stopBackup(domObj):
            sys.exit(1)
        sys.exit(0)

    # Sets args.cpt (name, parent, file) used below.
    try:
        handleCheckpoints(args, virtClient, domObj)
    except exceptions.CheckpointException as errmsg:
        logging.error(errmsg)
        sys.exit(1)

    # NOTE(review): if --printonly is set but there is no parent
    # checkpoint or the domain is offline, execution continues with the
    # regular backup - confirm this is intended.
    if args.printonly and args.cpt.parent and not args.offline:
        size = virtClient.getCheckpointSize(domObj, args.cpt.parent)
        logging.info("Estimated checkpoint backup size: [%s] Bytes", size)
        sys.exit(0)

    # Skip the backup if the size reported for the parent checkpoint is
    # below the requested threshold.
    if args.threshold and args.cpt.parent and not args.offline:
        size = virtClient.getCheckpointSize(domObj, args.cpt.parent)
        if size < args.threshold:
            logging.info(
                "Backup size [%s] does not meet required threshold [%s], skipping backup.",
                size,
                args.threshold,
            )
            sys.exit(0)

    # A non-empty remoteHost indicates a remote backup, which requires
    # an ssh session.
    if virtClient.remoteHost != "":
        args.sshClient = lib.sshSession(args, virtClient.remoteHost)
        if not args.sshClient:
            logging.error("Remote backup detected but ssh session setup failed")
            sys.exit(1)
        logging.info(
            "Remote: NDB Endpoint socket: [%s:%s]",
            virtClient.remoteHost,
            args.socketfile,
        )
        if args.offline is True:
            logging.info(
                "Remote ports used for backup: [%s-%s]",
                args.nbd_port,
                args.nbd_port + args.worker,
            )
    else:
        logging.info("Local NDB Endpoint socket: [%s]", args.socketfile)

    # Online domains: start the backup job via libvirt.
    if args.offline is not True:
        logging.info("Temporary scratch file target directory: [%s]", args.scratchdir)
        repository.Directory(args.scratchdir)
        if not startBackupJob(args, virtClient, domObj, disks):
            sys.exit(1)

    # Levels full and inc on online domains created a new checkpoint:
    # record it in the checkpoint file and back it up.
    if args.level not in ("copy", "diff") and args.offline is False:
        logging.info("Started backup job with checkpoint, saving information.")
        try:
            saveCheckpointFile(args)
        except exceptions.CheckpointException as e:
            logging.error("Extending checkpoint file failed: [%s]", e)
            sys.exit(1)
        if not virtClient.backupCheckpoint(args, domObj):
            virtClient.stopBackup(domObj)
            sys.exit(1)

    if args.startonly is True:
        logging.info("Started backup job for debugging, exiting.")
        sys.exit(0)

    # Backup the disks concurrently, one task per disk. Failures are
    # only logged here: the backup job still has to be stopped below and
    # the exit code is derived from the error counter.
    try:
        with ThreadPoolExecutor(max_workers=args.worker) as executor:
            futures = {
                executor.submit(
                    backupDisk, args, disk, count, fileStream, virtClient
                ): disk
                for count, disk in enumerate(disks)
            }
            for future in as_completed(futures):
                if future.result() is not True:
                    raise exceptions.DiskBackupFailed("Backup of one disk failed")
    except exceptions.BackupException as e:
        logging.error("Disk backup failed: [%s]", e)
    except sshError as e:
        logging.error("Remote Disk backup failed: [%s]", e)
    except Exception as e:  # pylint: disable=broad-except
        logging.fatal("Unknown Exception during backup: %s", e)
        logging.exception(e)

    if args.offline is False:
        logging.info("Backup jobs finished, stopping backup task.")
        virtClient.stopBackup(domObj)

    saveFiles(args, vmConfig, disks, fileStream, logFile)

    if domObj.autostart() == 1:
        backupAutoStart(args, domObj)

    # NOTE(review): on this exit path the ssh session opened above is
    # not disconnected, the disconnect call follows after it.
    if counter.count.errors > 0:
        logging.error("Error during backup")
        sys.exit(1)

    if args.sshClient:
        args.sshClient.disconnect()

    if counter.count.warnings > 0 and args.strict is True:
        logging.info(
            "[%s] Warnings detected during backup operation, forcing exit code 2",
            counter.count.warnings,
        )
        sys.exit(2)

    logging.info("Finished successfully")


def saveFiles(
    args: Namespace,
    vmConfig: str,
    disks: List[DomainDisk],
    fileStream: Union[output.target.Directory, output.target.Zip],
    logFile: str,
):
    """Save the files accompanying the disk data.

    Stores the virtual machine configuration, the additional boot
    files listed in `args.info` and the image information for every
    disk whose format starts with "qcow". If the backup is written to
    stdout, the saved files are added to the zip archive afterwards.
    """
    savedConfig = backupConfig(args, vmConfig)
    backupBootConfig(args)

    for qcowDisk in (d for d in disks if d.format.startswith("qcow")):
        backupDiskInfo(args, qcowDisk)

    if args.stdout is not True:
        return

    addFiles(args, savedConfig, fileStream, logFile)


def addFiles(args: Namespace, configFile: Union[str, None], zipStream, logFile: str):
    """Add the files saved next to the disk data to the zip archive.

    Written in this order: the vm config (if one was saved), for
    levels full and inc the checkpoint file and the content of the
    checkpoint directory, the files listed in `args.info`, the qcow
    image information files listed in `args.diskInfo` and finally the
    backup log.
    """

    def store(*names: str) -> None:
        """Pass the given file name(s) to the archive's write method."""
        zipStream.zipStream.write(*names)

    if configFile is not None:
        logging.info("Adding vm config to zipfile")
        store(configFile, configFile)

    if args.level in ("full", "inc"):
        logging.info("Adding checkpoint info to zipfile")
        store(args.cpt.file, args.cpt.file)
        for dirname, _, files in os.walk(args.checkpointdir):
            store(dirname)
            for filename in files:
                store(os.path.join(dirname, filename))

    for setting, val in args.info.items():
        logging.info(
            "Adding additional [%s] setting file [%s] to zipfile", setting, val
        )
        store(val, os.path.basename(val))

    for infoFile in args.diskInfo:
        logging.info("Adding QCOW image format file [%s] to zipfile", infoFile)
        store(infoFile, os.path.basename(infoFile))

    logging.info("Adding backup log [%s] to zipfile", logFile)
    store(logFile, logFile)


def getIdent(args: Namespace) -> Union[str, int]:
    """Return the identifier used as part of the target file names.

    - level "copy": the string "copy",
    - level "diff": the current unix timestamp,
    - otherwise the checkpoint name (`args.cpt.name`), falling back to
      the current unix timestamp if no checkpoint information is set.
    """
    if args.level == "copy":
        return "copy"
    if args.level == "diff":
        return int(time())

    try:
        return args.cpt.name
    except AttributeError:
        return int(time())


def backupConfig(args: Namespace, vmConfig: str) -> Union[str, None]:
    """Write the domain XML configuration to the output directory.

    The file is named `vmconfig.<ident>.xml`. Returns the path of the
    written file, or None if writing failed (the error is logged).
    """
    target = f"{args.output}/vmconfig.{getIdent(args)}.xml"
    logging.info("Saving VM config to: [%s]", target)
    try:
        with output.openfile(target, "wb") as handle:
            handle.write(vmConfig.encode())
    except OutputException as e:
        logging.error("Failed to save VM config: [%s]", e)
        return None

    return target


def backupAutoStart(args: Namespace, domObj: virDomain) -> None:
    """Record that the virtual machine is configured for autostart.

    Creates the marker file `autostart.<ident>` with content "True" in
    the output directory. A failure to write the file is logged as
    warning only. The domain object is currently not used.
    """
    logging.info("Autostart setting configured for virtual machine.")
    marker = f"{args.output}/autostart.{getIdent(args)}"
    content = b"True"
    try:
        with output.openfile(marker, "wb") as handle:
            handle.write(content)
    except OutputException as e:
        logging.warning("Failed to save autostart information: [%s]", e)


def backupBootConfig(args: Namespace) -> None:
    """Copy the additional boot files listed in `args.info` to the
    output directory.

    Unless the backup level is "copy", the identifier returned by
    getIdent is appended to the target file name. After copying, the
    entry in `args.info` is replaced by the path of the copy.
    """
    for setting, val in args.info.items():
        tFile = f"{args.output}/{os.path.basename(val)}"
        if args.level != "copy":
            tFile = f"{tFile}.{getIdent(args)}"

        logging.info("Save additional boot config [%s] to: [%s]", setting, tFile)
        lib.copy(args, val, tFile)
        # Only the value of an existing key changes, which is safe
        # while iterating over the dictionary.
        args.info[setting] = tFile


def backupDiskInfo(args: Namespace, disk: DomainDisk):
    """Query the image information of a disk and save it to
    `<target>.<ident>.qcow.json` in the output directory.

    Failures to query or to save the information are logged as
    warnings only. If the backup is written to stdout, the path of the
    saved file is appended to `args.diskInfo`.
    """
    try:
        imageInfo = qemu.util("").info(disk.path, args.sshClient)
    except (ProcessError, sshError) as errmsg:
        logging.warning("Failed to read qcow image info: [%s]", errmsg)
        return

    infoFile = f"{args.output}/{disk.target}.{getIdent(args)}.qcow.json"
    try:
        with output.openfile(infoFile, "wb") as handle:
            handle.write(imageInfo.out.encode())
    except OutputException as e:
        logging.warning("Failed to save qcow image config: [%s]", e)
        return

    logging.info("Saved qcow image config to: [%s]", infoFile)
    if args.stdout is True:
        args.diskInfo.append(infoFile)


def setMetaContext(args: Namespace, disk: DomainDisk) -> str:
    """Return the meta context to request from the NBD server.

    For backup levels other than inc and diff an empty string is
    returned. Otherwise the context refers to a dirty bitmap: the one
    named like the checkpoint for offline backups, `backup-<target>`
    for online backups.
    """
    if args.level not in ("inc", "diff"):
        return ""

    bitmapName = args.cpt.name if args.offline is True else f"backup-{disk.target}"
    metaContext = f"qemu:dirty-bitmap:{bitmapName}"
    logging.info("INC/DIFF backup: set context to [%s]", metaContext)

    return metaContext


def setStreamType(args: Namespace, disk: DomainDisk) -> str:
    """Return the stream type used for the given disk: always "raw"
    for disks of format raw, the type requested via `args.type` for
    all other formats."""
    return "raw" if disk.format == "raw" else args.type


def setTargetFile(args: Namespace, disk: DomainDisk, ext: str = "data"):
    """Build the name of the target file for a disk.

    Returns a tuple: the final file name and the name with suffix
    ".partial" appended.

    - full/copy: `<output>/<target>.<level>.<ext>`, raw disks always
      use level "copy",
    - inc/diff: `<output>/<target>.<level>.<ident>.<ext>`,
    - any other level results in an empty file name.
    """
    if args.level in ("full", "copy"):
        level = "copy" if disk.format == "raw" else args.level
        targetFile = f"{args.output}/{disk.target}.{level}.{ext}"
    elif args.level in ("inc", "diff"):
        ident = getIdent(args)
        targetFile = f"{args.output}/{disk.target}.{args.level}.{ident}.{ext}"
    else:
        targetFile = ""

    return targetFile, f"{targetFile}.partial"


def getWriter(
    args: Namespace, fileStream, targetFile: str, targetFilePartial: str
) -> BinaryIO:
    """Open the file the backup data is written to.

    If the backup is written to stdout, the base name of the target
    file is opened within the zip archive, otherwise the partial file
    is opened.
    """
    if args.stdout is True:
        archiveName = os.path.basename(targetFile)
        logging.info("Writing data to zip archive.")
        return fileStream.open(archiveName)

    logging.info("Write data to target file: [%s].", targetFilePartial)
    return fileStream.open(targetFilePartial)


def getExtentHandler(args: Namespace, nbdClient):
    """Create the extent handler used to query the block status.

    By default the handler queries the extents through the NBD client
    itself. With option `args.qemu` set, a qemu utility object created
    for the client's export name is used instead.
    """
    extentSource = nbdClient
    if args.qemu:
        logging.info("Using qemu tools to query extents")
        extentSource = qemu.util(nbdClient.cType.exportName)

    return extenthandler.ExtentHandler(extentSource, nbdClient.cType)


def renamePartial(targetFilePartial: str, targetFile: str) -> None:
    """Rename the .partial file to the final target file name once the
    backup has finished.

    Raises exceptions.DiskBackupFailed if the file can not be renamed.
    """
    try:
        os.rename(targetFilePartial, targetFile)
    except OSError as e:
        reason = f"Failed to rename file: [{e}]"
        raise exceptions.DiskBackupFailed(reason) from e


def startOfflineNBD(
    args: Namespace, disk: DomainDisk, remoteHost: str, port: int
) -> processInfo:
    """Start the NBD server process used to backup an offline domain.

    If `remoteHost` is empty the server is started locally, listening
    on socket `<socketfile>.<target>`. Otherwise it is started via the
    remote variant using the given port. For inc/diff backups the
    checkpoint name is passed as bitmap.

    Returns the process information of the started server.
    """
    bitMap: str = args.cpt.name if args.level in ("inc", "diff") else ""
    socket = f"{args.socketfile}.{disk.target}"

    if remoteHost == "":
        logging.info("Offline backup, starting local NDB server, socket: [%s]", socket)
        localProc = qemu.util(disk.target).startBackupNbdServer(
            disk.format, disk.path, socket, bitMap
        )
        logging.info("Local NDB Service started, PID: [%s]", localProc.pid)
        return localProc

    logging.info(
        "Offline backup, starting remote NDB server, socket: [%s:%s], port: [%s]",
        remoteHost,
        socket,
        port,
    )
    remoteProc = qemu.util(disk.target).startRemoteBackupNbdServer(
        args, disk.format, disk.path, bitMap, port
    )
    logging.info("Remote NDB server started, PID: [%s].", remoteProc.pid)
    return remoteProc


def truncate(writer, size: int) -> None:
    """Resize the target file to the given size and move the file
    position back to the start.

    Raises exceptions.DiskBackupWriterException if one of the
    operations fails with an OSError.
    """
    try:
        writer.truncate(size)
        writer.seek(0)
    except OSError as e:
        reason = f"Failed to truncate target file: [{e}]"
        raise exceptions.DiskBackupWriterException(reason) from e


def connectNbd(
    args: Namespace,
    disk: DomainDisk,
    metaContext: str,
    remoteIP: str,
    port: int,
    virtClient: virt.client,
):
    """Connect to the NBD endpoint serving the given disk.

    For remote backups (`virtClient.remoteHost` is not empty) a TCP
    connection to `remoteIP:port` is used, otherwise the connection is
    made via unix socket. Offline backups use one socket per disk,
    named `<socketfile>.<target>`.

    Returns the value returned by the NBD client's connect method.
    Raises exceptions.DiskBackupFailed if the connection fails, the
    original NbdClientException is chained as cause.
    """
    socket = args.socketfile
    if args.offline is True:
        socket = f"{args.socketfile}.{disk.target}"

    cType: Union[nbdcli.TCP, nbdcli.Unix]
    if virtClient.remoteHost != "":
        cType = nbdcli.TCP(disk.target, metaContext, remoteIP, args.tls, port)
    else:
        cType = nbdcli.Unix(disk.target, metaContext, socket)

    nbdClient = nbdcli.client(cType)

    try:
        return nbdClient.connect()
    except NbdClientException as e:
        # Chain the cause explicitly, as done for the other
        # DiskBackupFailed errors raised in this file.
        raise exceptions.DiskBackupFailed(
            f"NBD endpoint: [{cType}]: connection failed: [{e}]"
        ) from e


def backupDisk(
    args: Namespace,
    disk: DomainDisk,
    count: int,
    fileStream,
    virtClient: virt.client,
):
    """Backup domain disk data.

    Connects to the NBD endpoint for the disk (starting an NBD server first
    if `args.offline` is set), queries the extents and writes them to the
    target writer, either as full provisioned raw image or as thin
    provisioned sparse stream. Unless writing to stdout, the partial target
    file is renamed to its final name once all data is written.

    Parameters:
        args: parsed options. Note that `args.noprogress` is set to True on
            this object if an inc/diff backup finds no dirty blocks.
        disk: disk to backup; its target name is also set as thread name.
        count: disk index, added to `args.nbd_port` for offline backups
            and handed to the progress bar.
        fileStream: passed through to getWriter.
        virtClient: libvirt client; its `remoteHost` decides whether the
            backup is treated as remote.

    Returns True, both after a finished backup and if no extents were
    reported by the extent handler.

    Raises exceptions.DiskBackupFailed if the NBD server can't be started
    or the NBD connection fails; exceptions.DiskBackupWriterException if
    the raw target can't be truncated.
    """
    stream = streamer.SparseStream(types)
    sTypes = types.SparseStreamTypes()
    current_thread().name = disk.target
    streamType = setStreamType(args, disk)
    metaContext = setMetaContext(args, disk)
    nbdProc: processInfo
    remoteIP: str = virtClient.remoteHost
    port: int = args.nbd_port
    if args.nbd_ip != "":
        remoteIP = args.nbd_ip

    if args.offline is True:
        # Each disk gets its own NBD server, thus its own port.
        port = args.nbd_port + count
        try:
            nbdProc = startOfflineNBD(args, disk, virtClient.remoteHost, port)
        except ProcessError as errmsg:
            logging.error(errmsg)
            raise exceptions.DiskBackupFailed("Failed to start NBD server.")

    connection = connectNbd(args, disk, metaContext, remoteIP, port, virtClient)

    def releaseNbd() -> None:
        """Disconnect from the NBD endpoint and stop a locally started
        NBD server. Shared by the early return and the regular exit so
        neither path leaves the connection or server process behind."""
        connection.disconnect()
        if args.offline is True and virtClient.remoteHost == "":
            logging.info("Stopping NBD Service.")
            lib.killProc(nbdProc.pid)

    extentHandler = getExtentHandler(args, connection)
    extents = extentHandler.queryBlockStatus()
    diskSize = connection.nbd.get_size()

    if extents is None:
        logging.error("No extents found.")
        # Nothing is written, but the connection (and for local offline
        # backups the NBD server process) still has to be released.
        releaseNbd()
        return True

    thinBackupSize = sum(extent.length for extent in extents if extent.data is True)
    logging.info("Got %s extents to backup.", len(extents))
    logging.debug("%s", lib.dumpExtentJson(extents))
    logging.info("%s bytes disk size", diskSize)
    logging.info("%s bytes of data extents to backup", thinBackupSize)

    if args.level in ("inc", "diff") and thinBackupSize == 0:
        logging.info("No dirty blocks found")
        # NOTE(review): this changes the args object handed in by the
        # caller, so the setting is visible to everything else using it.
        args.noprogress = True

    targetFile, targetFilePartial = setTargetFile(args, disk)
    writer = getWriter(args, fileStream, targetFile, targetFilePartial)

    if streamType == "raw":
        logging.info("Creating full provisioned raw backup image")
        truncate(writer, diskSize)
    else:
        logging.info("Creating thin provisioned stream backup image")
        metadata = stream.dumpMetadata(
            args,
            diskSize,
            thinBackupSize,
            disk,
        )
        # Stream starts with a META frame followed by the metadata payload.
        stream.writeFrame(writer, sTypes.META, 0, len(metadata))
        writer.write(metadata)
        writer.write(sTypes.TERM)

    progressBar = lib.progressBar(
        thinBackupSize, f"saving disk {disk.target}", args, count=count
    )
    # Collected only for compressed streams: plain size for blocks written
    # at once, {size: chunk sizes} for blocks written in chunks.
    compressedSizes: List[Any] = []
    for save in extents:
        if save.data is True:
            if streamType == "stream":
                stream.writeFrame(writer, sTypes.DATA, save.offset, save.length)
                logging.debug(
                    "Read data from: start %s, length: %s", save.offset, save.length
                )

            cSizes = None

            # Extents reaching the maximum NBD request size are transferred
            # in chunks.
            if save.length >= connection.maxRequestSize:
                logging.debug(
                    "Chunked data read from: start %s, length: %s",
                    save.offset,
                    save.length,
                )
                size, cSizes = lib.writeChunk(
                    writer,
                    save,
                    connection,
                    streamType,
                    args.compress,
                )
            else:
                size = lib.writeBlock(
                    writer,
                    save,
                    connection,
                    streamType,
                    args.compress,
                )
                if streamType == "raw":
                    # NOTE(review): size is only evaluated for stream
                    # output below; the seek itself is kept as is.
                    size = writer.seek(save.offset)

            if streamType == "stream":
                writer.write(sTypes.TERM)
                if args.compress:
                    logging.debug("Compressed size: %s", size)
                    if cSizes:
                        compressedSizes.append({size: cSizes})
                    else:
                        compressedSizes.append(size)
                else:
                    # Uncompressed blocks must be written in full length.
                    assert size == save.length

            progressBar.update(save.length)
        else:
            if streamType == "raw":
                # Skip the hole in the raw image.
                writer.seek(save.offset)
            elif streamType == "stream" and args.level not in ("inc", "diff"):
                # ZERO frames are only written for non inc/diff backups.
                stream.writeFrame(writer, sTypes.ZERO, save.offset, save.length)
    if streamType == "stream":
        stream.writeFrame(writer, sTypes.STOP, 0, 0)
        if args.compress:
            stream.writeCompressionTrailer(writer, compressedSizes)

    progressBar.close()
    logging.debug("Closing write handle.")
    writer.close()
    releaseNbd()

    if args.stdout is False:
        if args.noprogress is True:
            logging.info(
                "Backup of disk [%s] finished, file: [%s]", disk.target, targetFile
            )
        renamePartial(targetFilePartial, targetFile)

    return True


# Run the command line entry point only if executed as script, not on import.
if __name__ == "__main__":
    main()
