import unittest
from importlib import reload
from os import environ, remove
from os.path import exists
from shutil import copyfile
from time import sleep

from cryptography.hazmat.primitives import serialization

import test.cert_wizard.revpi_cert_wizard as rcw

# Certificate and private-key locations used by every test. Each one can be
# overridden through its environment variable; otherwise throwaway files in
# /tmp are used.
PATH_CERT, PATH_KEY = (
    environ.get(variable, fallback)
    for variable, fallback in (
        ("REVPI_CERT_PATH", "/tmp/revpi-cert-wizard-testing.pem"),
        ("REVPI_KEY_PATH", "/tmp/revpi-cert-wizard-testing.key"),
    )
)


class TestRevpiCertWizard(unittest.TestCase):
    """Tests for the revpi_cert_wizard module.

    Each test runs against a freshly reloaded module pointed at PATH_CERT /
    PATH_KEY, and starts and ends with none of the test's certificate/key
    files on disk.
    """

    @staticmethod
    def _remove_test_files():
        """Delete every certificate/key file the tests in this class may create."""
        for path in (PATH_CERT, PATH_KEY, PATH_CERT + ".2", PATH_KEY + ".2"):
            if exists(path):
                remove(path)

    def setUp(self):
        self.assertIsNotNone(PATH_CERT)
        self.assertIsNotNone(PATH_KEY)

        # Point the module at the test paths via the environment and reload
        # it. Reloading re-executes its module-level code, so module
        # attributes patched by a previous test are re-initialised.
        environ["REVPI_CERT_PATH"] = PATH_CERT
        environ["REVPI_KEY_PATH"] = PATH_KEY
        reload(rcw)

        # Start without any cert/key files, including the ".2" pair written by
        # test_main_cert_and_key_mismatch.
        self._remove_test_files()

    def tearDown(self):
        # Do not leave generated key material behind between test runs.
        self._remove_test_files()

    def test_functions_without_certificate(self):
        """Test functions when no certificate is present."""
        self.assertIsNone(rcw.read_certificate(PATH_CERT))
        self.assertIsNone(rcw.read_privatekey(PATH_KEY))
        self.assertFalse(rcw.check_cert_key())

    def test_certificate_creation(self):
        """create_certificate() writes a certificate matching its private key."""
        self.assertTrue(rcw.create_certificate())
        self.assertTrue(exists(PATH_CERT))
        self.assertTrue(exists(PATH_KEY))

        cert = rcw.read_certificate(PATH_CERT)
        pkey = rcw.read_privatekey(PATH_KEY)
        self.assertIsNotNone(cert)
        self.assertIsNotNone(pkey)

        # Compare the public key bytes of the certificate and the private key
        cert_public_bytes = cert.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        pkey_public_bytes = pkey.public_key().public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        self.assertEqual(cert_public_bytes, pkey_public_bytes)
        self.assertTrue(rcw.check_cert_key())

        # TODO: Test exception handling of create_certificate

    def test_main_no_args(self):
        """main() creates a cert when a file is missing and otherwise keeps it."""
        # Generate a new cert because the PATH_CERT file is missing
        self.assertEqual(rcw.main(), 0)
        self.assertTrue(exists(PATH_CERT))
        self.assertTrue(exists(PATH_KEY))
        cert_serial = get_cert_serial()

        # Generate a new cert because the PATH_KEY file is missing
        remove(PATH_KEY)
        self.assertEqual(rcw.main(), 0)
        self.assertTrue(exists(PATH_CERT))
        self.assertTrue(exists(PATH_KEY))
        self.assertNotEqual(get_cert_serial(), cert_serial)
        cert_serial = get_cert_serial()

        # No new cert if both files exist
        self.assertEqual(rcw.main(), 0)
        self.assertEqual(get_cert_serial(), cert_serial)

    def test_main_cert_and_key_mismatch(self):
        """main() replaces a certificate that does not match its private key."""
        # First pair at the default paths.
        self.assertEqual(rcw.main(), 0)
        self.assertTrue(exists(PATH_CERT))
        self.assertTrue(exists(PATH_KEY))
        self.assertTrue(rcw.check_cert_key())
        first_serial = get_cert_serial()

        # Create a second, independent pair of cert and key files.
        second_cert = PATH_CERT + ".2"
        second_key = PATH_KEY + ".2"
        rcw.PATH_CERT = second_cert
        rcw.PATH_KEY = second_key
        self.assertEqual(rcw.main(), 0)
        second_serial = get_cert_serial()
        self.assertNotEqual(second_serial, first_serial)

        # Overwrite the second cert with the first one, so it no longer
        # matches the second private key.
        copyfile(PATH_CERT, second_cert)
        self.assertEqual(get_cert_serial(), first_serial)

        # main() must detect the mismatch and issue a brand-new certificate.
        # Checking against first_serial as well is what makes this fail if
        # main() leaves the copied (mismatching) certificate in place.
        self.assertEqual(rcw.main(), 0)
        new_serial = get_cert_serial()
        self.assertNotEqual(new_serial, first_serial)
        self.assertNotEqual(new_serial, second_serial)

    def test_main_force_flag(self):
        """main() regenerates an existing certificate only with -f."""
        self.assertTrue(rcw.create_certificate())
        cert_serial = get_cert_serial()

        # No new cert without the -f flag
        self.assertEqual(rcw.main(), 0)
        self.assertEqual(get_cert_serial(), cert_serial)

        # New cert with the -f flag
        rcw.args.force = True
        self.assertEqual(rcw.main(), 0)
        self.assertNotEqual(get_cert_serial(), cert_serial)

    def test_main_cert_expired(self):
        """main() replaces an expired certificate only with -t."""
        rcw.CERT_LIFETIME_DAYS = 0
        self.assertTrue(rcw.create_certificate())
        cert_serial = get_cert_serial()

        # With a zero-day lifetime, wait so the certificate's validity end is
        # in the past.
        sleep(1.0)

        # No new cert without the -t flag
        self.assertEqual(rcw.main(), 0)
        self.assertEqual(get_cert_serial(), cert_serial)

        # New cert with the -t flag
        rcw.args.time = True
        self.assertEqual(rcw.main(), 0)
        self.assertNotEqual(get_cert_serial(), cert_serial)

    def test_main_cert_san_mismatch(self):
        """main() replaces a cert with an outdated hostname only with -n."""
        self.assertTrue(rcw.create_certificate())
        cert_serial = get_cert_serial()

        # Change hostname to trigger mismatch
        rcw.MY_HOSTNAME = "test-cert-wizard-py"

        # No new cert without the -n flag
        self.assertEqual(rcw.main(), 0)
        self.assertEqual(get_cert_serial(), cert_serial)

        # New cert with the -n flag
        rcw.args.name = True
        self.assertEqual(rcw.main(), 0)
        self.assertNotEqual(get_cert_serial(), cert_serial)


def get_cert_serial() -> int:
    """Return the serial number of the certificate at ``rcw.PATH_CERT``.

    ``rcw.PATH_CERT`` is looked up at call time, so tests that repoint the
    module at another certificate file get that file's serial.

    Raises:
        AssertionError: if no certificate could be read from that path. This
            makes unittest report a test failure with the offending path
            instead of an opaque ``AttributeError`` on ``None``.
    """
    cert = rcw.read_certificate(rcw.PATH_CERT)
    if cert is None:
        raise AssertionError(f"no certificate could be read from {rcw.PATH_CERT}")
    return cert.serial_number
