#!/usr/bin/env python3
"""This script is designed to generate self-signed TLS (Transport Layer
Security) certificates for encrypting sensitive information during
transmission.
TLS certificates are essential for ensuring secure communication and
preventing unauthorized access to data transmitted over the network.
"""
# Program version, reported by the --version command-line option.
__version__ = "2.0.0"

import argparse
import datetime
import os
import socket
import sys
from os import environ
from sys import stderr, stdout
from typing import Optional
from uuid import uuid4

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

# Validity period of a newly generated certificate: ten years, in days.
CERT_LIFETIME_DAYS = 10 * 365
# RSA modulus size (bits) for newly generated private keys.
KEY_LENGTH = 4096
# Lower-cased host name of this machine; used as the certificate CN and as a SAN entry.
MY_HOSTNAME = socket.gethostname().lower()
# First label of the host name plus ".local" (the Avahi / mDNS name); also a SAN entry.
MY_AVAHINAME = f"{MY_HOSTNAME.partition('.')[0]}.local"
# Output locations; each can be overridden through an environment variable.
PATH_CERT = environ.get("REVPI_CERT_PATH", "/etc/ssl/certs/revpi-self-signed.pem")
PATH_KEY = environ.get("REVPI_KEY_PATH", "/etc/ssl/private/revpi-self-signed.key")

# Command-line interface; the epilog describes what happens without options.
parser = argparse.ArgumentParser(
    description=__doc__,
    epilog=(
        "Default behaviour: no arguments. Check if certificate and key exist "
        "and if not, create the certificate and key."
    ),
)

# Boolean switches, registered in this order so --help lists them the same way.
# Each entry: (option strings, destination attribute, help text); all default to False.
_SWITCHES = (
    (("-n", "--name"), "name", "create TLS-Certificate only when hostname not correct"),
    (("-t", "--time"), "time", "create TLS-Certificate only when date is expired"),
    (("-f", "--force"), "force", "create new TLS-Certificate overwriting any existing certificate"),
    (("-v",), "verbose", "increase output verbosity"),
)
for _option_strings, _dest, _help in _SWITCHES:
    parser.add_argument(
        *_option_strings,
        dest=_dest,
        action="store_true",
        default=False,
        help=_help,
    )
# Keep the module namespace free of the temporary loop helpers.
del _SWITCHES, _option_strings, _dest, _help

parser.add_argument(
    "--version",
    action="version",
    version=f"%(prog)s {__version__}",
)

# Parsed at import time; the functions below read `args` and `verbose` as globals.
args = parser.parse_args()
verbose = args.verbose


def check_cert_key() -> bool:
    """
    Tell whether the private key in PATH_KEY belongs to the certificate in PATH_CERT.

    Both files are loaded via read_certificate()/read_privatekey(); if either
    cannot be loaded the pair is reported as not matching. Otherwise the
    SubjectPublicKeyInfo PEM encoding of the certificate's public key is
    compared with that of the public key derived from the private key.

    Returns:
        True if both public keys are byte-identical, False otherwise.
    """
    certificate = read_certificate(PATH_CERT)
    private_key = read_privatekey(PATH_KEY)
    if certificate is None or private_key is None:
        return False

    def spki_pem(public_key):
        # Serialise a public key into a canonical, comparable byte string.
        return public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )

    is_match = spki_pem(certificate.public_key()) == spki_pem(private_key.public_key())

    if verbose:
        stdout.write(
            "Key and certificate match.\n"
            if is_match
            else "Key and certificate do not match.\n"
        )

    return is_match


def _write_private_key_file(path: str, data: bytes) -> None:
    """
    Write *data* to *path* so that only the file owner can read or write it (0600).

    The file is opened without truncation, its mode is forced to 0600 and only
    then is it truncated and filled. Key material is therefore never written
    into a file readable by group/others, and an existing file keeps its old
    content if its permissions cannot be changed.

    Raises:
        OSError: if the file cannot be opened, chmod-ed or written.
    """
    fd = os.open(path, os.O_WRONLY | os.O_CREAT, 0o600)
    with os.fdopen(fd, "wb") as key_file:
        # The mode given to os.open() only applies when the file is created,
        # so tighten the permissions of a pre-existing file explicitly.
        os.fchmod(key_file.fileno(), 0o600)
        key_file.truncate(0)
        key_file.write(data)


def create_certificate() -> bool:
    """
    Create and write a self-signed certificate and its private key as PEM files.

    A new RSA key of KEY_LENGTH bits is generated. The certificate uses
    MY_HOSTNAME as common name, lists MY_HOSTNAME and MY_AVAHINAME as DNS
    subject alternative names, and is valid for CERT_LIFETIME_DAYS from now.
    The key is written first to PATH_KEY (owner-only permissions), then the
    certificate to PATH_CERT. If the certificate cannot be written, the key
    file just written is removed again.

    Returns:
        True on success, False on any failure (details are written to stderr).
    """
    stdout.write("Generate certificate... ")
    # Generating a large RSA key can take a while; show the progress message now
    # instead of leaving it in the output buffer.
    stdout.flush()

    # RSA key generation (technically part of cryptography's "hazmat" layer,
    # but it is the standard way to do it).
    private_key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=KEY_LENGTH,
    )

    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, MY_HOSTNAME),
        x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, "Revolution Pi"),
    ])

    # dict.fromkeys() drops duplicates while keeping order: both names are
    # identical when the host name itself is "<name>.local".
    san_entries = [x509.DNSName(name) for name in dict.fromkeys((MY_HOSTNAME, MY_AVAHINAME))]

    now = datetime.datetime.now(datetime.timezone.utc)
    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(private_key.public_key())
        .serial_number(uuid4().int)
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=CERT_LIFETIME_DAYS))
        .add_extension(
            x509.SubjectAlternativeName(san_entries),
            critical=False,
        )
        .add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=False,
        )
        .sign(private_key, hashes.SHA256())
    )

    try:
        # Serialise via the high-level methods of the key and certificate objects.
        key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )
        cert_pem = cert.public_bytes(serialization.Encoding.PEM)
    except Exception as e:
        stderr.write(f"\nCould not create certificate-key-pair: {e}\n")
        return False

    # Save certificate and private key to files.
    try:
        # First, write the key file; its folder usually requires higher permissions.
        _write_private_key_file(PATH_KEY, key_pem)
    except OSError as e:
        stderr.write(f"\nCould not write private key file '{PATH_KEY}': {e}\n")
        return False
    try:
        with open(PATH_CERT, "wb") as f:
            f.write(cert_pem)
    except OSError as e:
        stderr.write(f"\nCould not write certificate file '{PATH_CERT}': {e}\n")

        # Delete the key file written in the step before; a failure here must
        # not mask the original error with an unhandled exception.
        try:
            os.remove(PATH_KEY)
        except OSError as remove_error:
            stderr.write(f"Could not remove private key file '{PATH_KEY}': {remove_error}\n")

        return False

    stdout.write(" done.\n")
    return True


def read_certificate(path: str) -> Optional[x509.Certificate]:
    """
    Load a PEM-encoded X.509 certificate from *path*.

    Returns:
        The parsed certificate, or None if the file cannot be read or parsed.
        In the failure case a diagnostic line is written to stderr.
    """
    try:
        with open(path, "rb") as pem_file:
            pem_data = pem_file.read()
        certificate = x509.load_pem_x509_certificate(pem_data)
    except Exception as exc:
        stderr.write(f"Error reading certificate '{path}': {exc}\n")
        return None
    return certificate


def read_privatekey(path: str) -> Optional[rsa.RSAPrivateKey]:
    """
    Load an unencrypted PEM-encoded private key from *path*.

    NOTE(review): the annotation says RSA, but no key-type check is done here;
    whatever key object load_pem_private_key() returns is passed through.

    Returns:
        The parsed private key, or None if the file cannot be read or parsed.
        In the failure case a diagnostic line is written to stderr.
    """
    try:
        with open(path, "rb") as pem_file:
            pem_data = pem_file.read()
        private_key = serialization.load_pem_private_key(pem_data, password=None)
    except Exception as exc:
        stderr.write(f"Error reading private key '{path}': {exc}\n")
        return None
    return private_key


# With --time, renew the certificate once it is this close to expiring.
_RENEW_MARGIN = datetime.timedelta(days=8)


def _validity_utc(cert: "x509.Certificate") -> "tuple[datetime.datetime, datetime.datetime]":
    """
    Return the certificate's (not_valid_before, not_valid_after) as aware UTC datetimes.

    cryptography >= 42 provides ``not_valid_before_utc`` / ``not_valid_after_utc``.
    Older releases only offer naive datetimes that represent UTC, so those are
    tagged with UTC explicitly.
    """
    try:
        return cert.not_valid_before_utc, cert.not_valid_after_utc
    except AttributeError:
        utc = datetime.timezone.utc
        return (
            cert.not_valid_before.replace(tzinfo=utc),
            cert.not_valid_after.replace(tzinfo=utc),
        )


def main() -> int:
    """
    Decide whether a new certificate/key pair is needed and create it if so.

    Checks, in order:
    1. certificate file missing, key file missing, or key not matching the
       certificate -> always regenerate;
    2. ``--force`` -> regenerate;
    3. ``--time`` -> regenerate if the certificate expires within _RENEW_MARGIN;
    4. ``--name`` -> regenerate if MY_HOSTNAME or MY_AVAHINAME is missing from
       the certificate's DNS subject alternative names.

    Returns:
        Process exit status: 0 if nothing had to be done or creation succeeded,
        1 if creation failed or the certificate could not be read.
    """
    if not os.path.exists(PATH_CERT):
        stdout.write("Certificate did not exist...\n")
        return int(not create_certificate())

    if not os.path.exists(PATH_KEY):
        stdout.write("Key did not exist...\n")
        return int(not create_certificate())

    if not check_cert_key():
        stdout.write("Keyfile does not fit certificate file.\n")
        return int(not create_certificate())

    if args.force:
        return int(not create_certificate())

    if args.time:
        cert = read_certificate(PATH_CERT)
        if cert is None:
            # Could not read the certificate file
            return 1

        # Compare timezone-aware UTC values so the result does not depend on
        # the local timezone of the machine.
        not_before, not_after = _validity_utc(cert)
        now = datetime.datetime.now(datetime.timezone.utc)

        if now > not_after - _RENEW_MARGIN:
            tense = "expired" if now > not_after else "expires"
            stdout.write(f"The certificate provided {tense} on {not_after}.\n")
            return int(not create_certificate())
        elif verbose:
            stdout.write(f"Information:\nValid from: {not_before}\nExpires on: {not_after}\n\n")

    if args.name:
        cert = read_certificate(PATH_CERT)
        if cert is None:
            return 1

        expected_sans = {MY_HOSTNAME, MY_AVAHINAME}

        try:
            ext = cert.extensions.get_extension_for_class(x509.SubjectAlternativeName)
            found_sans = {name.value.lower() for name in ext.value if isinstance(name, x509.DNSName)}
        except x509.ExtensionNotFound:
            found_sans = set()

        missing_san = expected_sans - found_sans
        if missing_san:
            # Sorted for deterministic output (set iteration order is arbitrary).
            stdout.write(f"Missing SAN entries '{', '.join(sorted(missing_san))}' - Generating new certificate\n")
            return int(not create_certificate())

    stdout.write(f"Certificate not modified. To get more info {parser.prog} -h\n")
    return 0


if __name__ == "__main__":
    # Run the CLI and use main()'s return value as the process exit status.
    raise SystemExit(main())
