#!/usr/bin/env python
# **********************************************************************
#
# Copyright (c) 2003-2013 ZeroC, Inc. All rights reserved.
#
# This copy of Ice is licensed to you under the terms described in the
# ICE_LICENSE file included in this distribution.
#
# **********************************************************************

import os
import sys
import getopt
import tempfile
import getpass
import shutil
import socket

# Locate the CA home directory. An explicit ICE_CA_HOME setting wins;
# otherwise a platform-specific default is used.
home = os.getenv("ICE_CA_HOME")
if home is None:
    # The substring on sys.platform is required because some cygwin
    # versions return variations like "cygwin_nt-4.01".
    if sys.platform == "win32" or sys.platform[:6] == "cygwin":
        home = os.path.dirname(sys.argv[0])
        home = os.path.join(home, "..", "config")
    else:
        home = os.getenv('HOME')
        if home is None:
            print("Environment variable HOME is not set.")
            sys.exit(1)
        home = os.path.join(home, ".iceca")
home = os.path.normpath(home)

# Assign through os.environ instead of calling os.putenv(): the assignment
# exports the variable to child processes (the generated openssl
# configuration files refer to $ENV::ICE_CA_HOME) and, unlike os.putenv(),
# also keeps os.environ / os.getenv() in sync for the rest of this script.
os.environ["ICE_CA_HOME"] = home

# Name of the openssl executable; it is looked up through the PATH.
opensslCmd = "openssl"
if sys.platform == "win32" or sys.platform[:6] == "cygwin":

    # Add the directory containing this script to the PATH, this is
    # also where the openssl.exe is located.
    if os.path.exists(os.path.join(os.path.dirname(sys.argv[0]), "openssl.exe")):
        os.environ["PATH"] = os.path.dirname(sys.argv[0]) + os.pathsep + os.getenv("PATH", "")

    # Setup RANDFILE to avoid errors on Vista about openssl not being
    # able to write 'Random State'. The resolved home directory is used
    # directly rather than reading ICE_CA_HOME back from the environment,
    # which yields None if the variable has not been stored in os.environ.
    os.environ["RANDFILE"] = os.path.join(home, "randfile")

# Directory holding the CA configuration, and the CA database below it.
caroot = os.path.join(home, "ca")
cadb = os.path.join(caroot, "db")

def usage():
    """Print the top-level command synopsis and exit with status 1."""
    synopsis = "usage: %s [--verbose --keep] import sign request init" % sys.argv[0]
    print(synopsis)
    sys.exit(1)

if len(sys.argv) == 1:
    usage()

# Work out the position of the command name by skipping the leading
# global options (everything that starts with "--").
script = 1
while script < len(sys.argv) and sys.argv[script].startswith("--"):
    script = script + 1

# The loop above stops at len(sys.argv) when every argument is an option,
# so a missing command has to be detected with >=. Otherwise the
# sys.argv[script] lookups further down fail with an IndexError.
if script >= len(sys.argv):
    usage()

#
# Parse the global options.
#
try:
    opts, args = getopt.getopt(sys.argv[1:script], "", [ "verbose", "keep"])
except getopt.GetoptError:
    usage()

# verbose: echo each external command before running it.
# keep: do not delete the temporary files created by the commands.
verbose = False
keep = False
for o, a in opts:
    if o == "--verbose":
        verbose = True
    if o == "--keep":
        keep = True

#
# import: convert a PEM certificate/key pair to PKCS#12 and either import
# it into a Java keystore (--java) or write it to a file (--cs).
#
if sys.argv[script] == "import":
    def usage():
        print("usage: " + sys.argv[script] + " [--overwrite] [--key-pass password] [--store-pass password] [--java alias cert key keystore ] [--cs cert key out-file]")
        sys.exit(1)

    def quote(arg):
        # Wrap a command-line argument in double quotes so that values
        # containing spaces survive the shell started by os.system().
        return '"' + arg + '"'

    def writeSecret(suffix, secret, created):
        # Store secret in a new file made by tempfile.mkstemp() and return
        # the file name. The name is recorded in created immediately so
        # the caller can always clean up. An empty secret gives an empty
        # file.
        temp, name = tempfile.mkstemp(suffix)
        created.append(name)
        try:
            if len(secret) > 0:
                os.write(temp, secret.encode("utf-8"))
        finally:
            os.close(temp)
        return name

    def removeTempFiles(names):
        # Delete the given temporary files unless --keep was requested.
        if not keep:
            for name in names:
                os.remove(name)

    try:
        opts, args = getopt.getopt(sys.argv[script+1:], "", [ "overwrite", "key-pass=", "store-pass=", "java", "cs"])
    except getopt.GetoptError:
        usage()

    java = False
    cs = False
    keyPass = None
    storePass = None
    overwrite = False
    for o, a in opts:
        if o == "--overwrite":
            overwrite = True
        elif o == "--key-pass":
            keyPass = a
        elif o == "--store-pass":
            storePass = a
        elif o == "--java":
            java = True
        elif o == "--cs":
            cs = True

    if not java and not cs:
        print(sys.argv[script] + ": one of --java or --cs must be provided")
        usage()

    if java:
        #
        # For an RPM install we use /usr/share/Ice-<full-version>. A
        # non-RPM install uses ${prefix}/lib (that is the location of the
        # script ../lib). For development purposes we also check ".".
        #
        checkLocations = [".",
                          os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "..", "lib")),
                          os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), "..", "..", "lib")),
                          "/usr/share/Ice-3.5.1"]
        for bindir in checkLocations:
            bindir = os.path.normpath(bindir)
            if os.path.exists(os.path.join(bindir, "ImportKey.class")):
                break
        else:
            print("Error: cannot locate ImportKey.class any of the following locations:\n%s" % str(checkLocations))
            sys.exit(1)

        if len(args) != 4:
            usage()

        alias, cert, key, store = args

        if keyPass is None:
            keyPass = getpass.getpass("Enter private key passphrase:")

        # The keystore password is mandatory: keep asking until one is given.
        while storePass is None or len(storePass) == 0:
            storePass = getpass.getpass("Enter keystore password:")

        caCert = os.path.join(home, "ca_cert.pem")

        #
        # We use secure temporary files to transfer the passwords to openssl
        # and the ImportKey java tool. It is necessary to create 2 files for
        # the key password because openssl fails if the same password file
        # is used twice.
        #
        # Every temporary file is recorded in tempFiles and removed in the
        # finally clause, so nothing (including the intermediate PKCS#12
        # file) is left behind when a command fails.
        #
        tempFiles = []
        try:
            if len(keyPass) > 0:
                keypassfile1 = writeSecret("keypass1", keyPass, tempFiles)
                keypassfile2 = writeSecret("keypass2", keyPass, tempFiles)
            else:
                # Java can't deal with unencrypted keystores, so the
                # intermediate PKCS#12 file gets a temporary password.
                keypassfile1 = writeSecret("keypass1", "password", tempFiles)

                # We create a file with an empty password to store the
                # results in the keystore.
                keypassfile2 = writeSecret("keypass2", "", tempFiles)

            storepassfile = writeSecret("storepass3", storePass, tempFiles)

            temp, pkcs12cert = tempfile.mkstemp(".p12", "pkcs12")
            os.close(temp)
            tempFiles.append(pkcs12cert)

            cmd = opensslCmd + " pkcs12 -in " + quote(cert) + " -inkey " + quote(key) + \
                " -export -out " + quote(pkcs12cert) + " -name " + quote(alias)
            if len(keyPass) > 0:
                cmd += " -passin " + quote("file:" + keypassfile1) + \
                    " -passout " + quote("file:" + keypassfile2)
            else:
                cmd += " -passout " + quote("file:" + keypassfile1)
            cmd += " -certfile " + quote(caCert)

            sys.stdout.write("converting to pkcs12 format... ")
            sys.stdout.flush()
            if verbose: print(cmd)
            status = os.system(cmd)
            if status != 0:
                print("openssl command failed")
                sys.exit(1)
            print("ok")

            # Use java to import the cert into the keystore.
            cmd = "java -classpath " + quote(bindir) + " ImportKey " + quote(pkcs12cert) + " " + \
                quote(alias) + " " + quote(caCert) + " " + quote(store) + " " + \
                quote(storepassfile) + " " + quote(keypassfile1)
            if len(keyPass) == 0:
                cmd += " " + quote(keypassfile2)

            sys.stdout.write("importing into the keystore...")
            sys.stdout.flush()
            if verbose: print(cmd)
            status = os.system(cmd)
            if status != 0:
                # Report the failure through the exit status as well.
                print("java command failed")
                sys.exit(1)
            print("ok")
        finally:
            # Cleanup; this also runs when sys.exit() is called above.
            removeTempFiles(tempFiles)

    if cs:
        if len(args) != 3:
            usage()
        if storePass is not None:
            print('Usage: %s:  "--store-pass" is only supported with --java' % sys.argv[script])
            sys.exit(1)

        cert, key, pkcs12cert = args
        if not overwrite and os.path.exists(pkcs12cert):
            print(pkcs12cert + ": file exists")
            sys.exit(1)

        if keyPass is None:
            keyPass = getpass.getpass("Enter private key passphrase:")

        #
        # We use secure temporary files to transfer the password to
        # openssl. It is necessary to create 2 files for the key password
        # because openssl fails if the same password file is used twice.
        #
        tempFiles = []
        try:
            cmd = opensslCmd + " pkcs12 -in " + quote(cert) + " -inkey " + quote(key) + \
                " -export -out " + quote(pkcs12cert)
            if len(keyPass) > 0:
                keypassfile1 = writeSecret("keypass1", keyPass, tempFiles)
                keypassfile2 = writeSecret("keypass2", keyPass, tempFiles)
                cmd += " -passin " + quote("file:" + keypassfile1) + \
                    " -passout " + quote("file:" + keypassfile2)
            else:
                # No passphrase: export with an empty password.
                cmd += " -passout pass:"

            sys.stdout.write("converting to pkcs12 format...")
            sys.stdout.flush()
            if verbose: print(cmd)
            status = os.system(cmd)
        finally:
            removeTempFiles(tempFiles)

        if status != 0:
            print("openssl command failed")
            sys.exit(1)
        print("ok")

    sys.exit(0)

#
# init: create the CA database, generate the openssl configuration files
# and create the self-signed CA certificate.
#
if sys.argv[script] == "init":
    def usage():
        print("usage: " + sys.argv[script] + " [--no-password] [--overwrite]")
        sys.exit(1)

    def readAnswer():
        # Read one line from stdin with surrounding whitespace removed.
        # readline() returns an empty string only at end of input; abort
        # in that case so that the prompts below cannot loop forever.
        line = sys.stdin.readline()
        if line == "":
            print("")
            print("unexpected end of input")
            sys.exit(1)
        return line.strip()

    try:
        opts, args = getopt.getopt(sys.argv[script+1:], "", [ "no-password", "overwrite"])
    except getopt.GetoptError:
        usage()

    if args:
        usage()

    print("This script will initialize your organization's Certificate Authority (CA).")
    print('The CA database will be created in "%s"' % caroot)

    nopassphrase = False
    for o, a in opts:
        if o == "--no-password":
            nopassphrase = True
        if o == "--overwrite":
            # If the CA exists then destroy it.
            if os.path.exists(caroot):
                print("Warning: running this command will destroy your existing CA setup!")
                sys.stdout.write("Do you want to continue? (y/n)")
                sys.stdout.flush()
                choice = readAnswer()
                if choice != 'y' and choice != 'Y':
                    sys.exit(1)
                shutil.rmtree(caroot)

    #
    # Check that the caroot isn't already populated.
    #
    if os.path.exists(os.path.join(cadb, "ca_key.pem")):
        print(cadb + ": CA has already been initialized.")
        print("Use the --overwrite option to re-initialize the database.")
        sys.exit(1)

    try:
        os.makedirs(cadb)
    except OSError:
        # An already existing directory is fine; any other failure (for
        # example a permission problem) must not be silently ignored.
        if not os.path.isdir(cadb):
            raise

    #
    # Initialize the CA serial and index.
    #
    with open(os.path.join(cadb, "serial"), "w") as serial:
        serial.write("01\n")

    # The index starts out as an empty file.
    open(os.path.join(cadb, "index.txt"), "w").close()

    # Construct the DN for the CA certificate. Each entry maps the short
    # RDN name to [openssl field name, prompt text, value, signing policy].
    DNelements = {
        'C':['countryName', "Country name", "", 'match'],
        'ST':['stateOrProvinceName', "State or province name", "", 'match'],
        'L':['localityName', "Locality", "", 'match'],
        'O':['organizationName', "Organization name", "GridCA-" + socket.gethostname(), 'match'],
        'OU':['organizationalUnitName', "Organizational unit name", "", 'optional'],
        'CN':['commonName', "Common name", "Grid CA", 'supplied']
    }

    # Show the subject name and let the user re-enter all of its fields
    # until the result is accepted. Only an explicit "n" rejects it.
    while True:
        print("The subject name for your CA will be ")
        print(", ".join([k + "=" + v[2] for k, v in DNelements.items() if len(v[2]) > 0]))

        sys.stdout.write("Do you want to keep this as the CA subject name? (y/n) [y]")
        sys.stdout.flush()
        if readAnswer() != 'n':
            break
        for v in DNelements.values():
            sys.stdout.write(v[1] + ":")
            sys.stdout.flush()
            v[2] = readAnswer()

    # The CA email address is mandatory: keep asking until one is given.
    while True:
        sys.stdout.write("Enter the email address of the CA: ")
        sys.stdout.flush()
        email = readAnswer()
        if len(email) > 0:
            break
    DNelements['emailAddress'] = ['emailAddress', '', email, 'optional']

    #
    # Static configuration file data. This avoids locating template files
    # and such.
    #
    # The message digest is sha256: MD5 is not collision resistant and
    # should not be used to sign certificates.
    #
    config = {\
"ca.cnf":"\
# **********************************************************************\n\
#\n\
# Copyright (c) 2003-2013 ZeroC, Inc. All rights reserved.\n\
#\n\
# This copy of Ice is licensed to you under the terms described in the\n\
# ICE_LICENSE file included in this distribution.\n\
#\n\
# **********************************************************************\n\
\n\
# Configuration file for the CA. This file is generated by iceca init.\n\
# DO NOT EDIT!\n\
\n\
###############################################################################\n\
###  Self Signed Root Certificate\n\
###############################################################################\n\
\n\
[ ca ]\n\
default_ca = ice\n\
\n\
[ ice ]\n\
default_days     = 1825    # How long certs are valid.\n\
default_md       = sha256  # The Message Digest type.\n\
preserve         = no      # Keep passed DN ordering?\n\
\n\
[ req ]\n\
default_bits        = 2048\n\
default_keyfile     = $ENV::ICE_CA_HOME/ca/db/ca_key.pem\n\
default_md          = sha256\n\
prompt              = no\n\
distinguished_name  = dn\n\
x509_extensions     = extensions\n\
\n\
[ extensions ]\n\
basicConstraints = CA:true\n\
\n\
# PKIX recommendation.\n\
subjectKeyIdentifier = hash\n\
authorityKeyIdentifier = keyid:always,issuer:always\n\
\n\
[dn]\n\
",\
"sign.cnf":"\
# **********************************************************************\n\
#\n\
# Copyright (c) 2003-2013 ZeroC, Inc. All rights reserved.\n\
#\n\
# This copy of Ice is licensed to you under the terms described in the\n\
# ICE_LICENSE file included in this distribution.\n\
#\n\
# **********************************************************************\n\
\n\
# Configuration file to sign a certificate. This file is generated by iceca init.\n\
# DO NOT EDIT!!\n\
\n\
[ ca ]\n\
default_ca = ice\n\
\n\
[ ice ]\n\
dir              = $ENV::ICE_CA_HOME/ca/db  # Where everything is kept.\n\
private_key      = $dir/ca_key.pem   # The CA Private Key.\n\
certificate      = $dir/ca_cert.pem  # The CA Certificate.\n\
database         = $dir/index.txt           # Database index file.\n\
new_certs_dir    = $dir                     # Default loc for new certs.\n\
serial           = $dir/serial              # The current serial number.\n\
certs            = $dir                     # Where issued certs are kept.\n\
RANDFILE         = $dir/.rand               # Private random number file.\n\
default_days     = 1825                     # How long certs are valid.\n\
default_md       = sha256                   # The Message Digest type.\n\
preserve         = yes                      # Keep passed DN ordering?\n\
\n\
policy           = ca_policy\n\
x509_extensions  = certificate_extensions\n\
\n\
[ certificate_extensions ]\n\
basicConstraints = CA:false\n\
\n\
# PKIX recommendation.\n\
subjectKeyIdentifier = hash\n\
authorityKeyIdentifier = keyid:always,issuer:always\n\
\n\
# ca_policy is generated by the initca script.\n\
",\
"req.cnf":"\
# **********************************************************************\n\
#\n\
# Copyright (c) 2003-2013 ZeroC, Inc. All rights reserved.\n\
#\n\
# This copy of Ice is licensed to you under the terms described in the\n\
# ICE_LICENSE file included in this distribution.\n\
#\n\
# **********************************************************************\n\
\n\
# Configuration file to request a node, registry or service\n\
# certificate. This file is generated by iceca init.\n\
# DO NOT EDIT!\n\
\n\
[ req ]\n\
default_bits        = 1024\n\
default_md          = sha256\n\
prompt              = no\n\
distinguished_name  = dn\n\
x509_extensions     = extensions\n\
\n\
[ extensions ]\n\
basicConstraints = CA:false\n\
\n\
# PKIX recommendation.\n\
subjectKeyIdentifier = hash\n\
authorityKeyIdentifier = keyid:always,issuer:always\n\
keyUsage = nonRepudiation, digitalSignature, keyEncipherment\n\
\n\
# The dn section is added by the initca script.\n\
\n\
[dn]\n\
"\
    }

    #
    # It is necessary to generate configuration files because the
    # openssl configuration files do not permit empty values.
    #
    sys.stdout.write("Generating configuration files... ")

    # ca.cnf is only needed to create the CA certificate below, so it is
    # written to a temporary file. The non-empty DN values complete the
    # trailing [dn] section of the template.
    sys.stdout.write("ca.cnf")
    sys.stdout.flush()
    temp, cacnfname = tempfile.mkstemp(".cnf", "ca")
    try:
        os.write(temp, config["ca.cnf"].encode("utf-8"))
        for v in DNelements.values():
            if len(v[2]) > 0:
                os.write(temp, (v[0] + "=" + v[2] + "\n").encode("utf-8"))
    finally:
        os.close(temp)

    # sign.cnf gets a [ ca_policy ] section listing the policy of every
    # DN field that has a value.
    cnfname = 'sign.cnf'
    sys.stdout.write(" " + cnfname)
    sys.stdout.flush()
    with open(os.path.join(caroot, cnfname), "w") as cnf:
        cnf.write(config[cnfname])
        cnf.write("[ ca_policy ]\n")
        for v in DNelements.values():
            if len(v[2]) > 0:
                cnf.write(v[0] + "=" + v[3] + "\n")

    # Don't want these RDNs in req.cnf
    del DNelements['emailAddress']
    del DNelements['CN']

    cnfname = "req.cnf"
    sys.stdout.write(" " + cnfname)
    sys.stdout.flush()
    with open(os.path.join(home, cnfname), "w") as cnf:
        cnf.write(config[cnfname])
        for v in DNelements.values():
            if len(v[2]) > 0:
                cnf.write(v[0] + "=" + v[2] + "\n")

    print(" ok")

    # Create the self-signed CA certificate; -nodes leaves the private
    # key unencrypted when --no-password was given.
    cmd = opensslCmd + ' req -config "' + cacnfname + '" -x509 -days 1825 -newkey rsa:2048 -out "' + \
            os.path.join(cadb, "ca_cert.pem") + '" -outform PEM'
    if nopassphrase:
        cmd += " -nodes"

    if verbose: print(cmd)
    status = os.system(cmd)
    if not keep: os.remove(cacnfname)
    if status != 0:
        print("openssl command failed")
        sys.exit(1)

    # Copy the new CA certificate next to req.cnf for distribution.
    shutil.copy(os.path.join(cadb, "ca_cert.pem"), home)

    print("")
    print("The CA is initialized.")
    print("")
    print("You need to distribute the following files to all machines that can")
    print("request certificates:")
    print("")
    print(os.path.join(home, "req.cnf"))
    print(os.path.join(home, "ca_cert.pem"))
    print("")
    print("These files should be placed in the user's home directory in")
    print("~/.iceca. On Windows, place these files in <ice-install>/config.")

    sys.exit(0)

#
# request: create a private key and a certificate request for the given
# common name (and optional email address).
#
if sys.argv[script] == "request":
    def usage():
        print("usage: " + sys.argv[script] + " [--overwrite] [--no-password] file common-name [email]")
        sys.exit(1)

    try:
        opts, args = getopt.getopt(sys.argv[script+1:], "", [ "overwrite", "no-password" ])
    except getopt.GetoptError:
        usage()

    nopassphrase = False
    overwrite = False
    for o, a in opts:
        if o == "--overwrite":
            overwrite = True
        elif o == "--no-password":
            nopassphrase = True

    if len(args) < 2 or len(args) > 3:
        usage()

    # Both output files are named after the first argument.
    keyfile = args[0] + "_key.pem"
    reqfile = args[0] + "_req.pem"
    if not overwrite:
        if os.path.exists(keyfile):
            print(keyfile + ": exists")
            sys.exit(1)
        if os.path.exists(reqfile):
            print(reqfile + ": exists")
            sys.exit(1)

    commonName = args[1]

    if len(args) == 3:
        email = args[2]
    else:
        email = None

    #
    # Create a temporary configuration file: the req.cnf template followed
    # by the request specific DN fields.
    #
    # open() raises instead of returning a false value, so a missing or
    # unreadable template has to be reported from the exception handler.
    #
    reqcnf = os.path.join(home, "req.cnf")
    try:
        template = open(reqcnf, "r")
    except IOError:
        print("cannot open " + reqcnf)
        sys.exit(1)

    try:
        data = template.read()
    finally:
        template.close()

    temp, tempname = tempfile.mkstemp(".cnf", "req")
    try:
        os.write(temp, data.encode("utf-8"))
        os.write(temp, ("commonName=" + commonName + "\n").encode("utf-8"))
        if email:
            os.write(temp, ("emailAddress=" + email + "\n").encode("utf-8"))
    finally:
        os.close(temp)

    # -nodes leaves the private key unencrypted when --no-password was given.
    cmd = opensslCmd + ' req -config "' + tempname + '" -new -keyout "' + keyfile + '" -out "' + reqfile + '"'
    if nopassphrase:
        cmd += " -nodes"

    if verbose: print(cmd)
    status = os.system(cmd)
    if not keep: os.remove(tempname)
    if status != 0:
        print("openssl command failed")
        sys.exit(1)

    print("")
    print("Created key: " + keyfile)
    print("Created certificate request: " + reqfile)
    print("")
    print("The certificate request must be signed by the CA. Send the certificate")
    print("request file to the CA at the following email address:")

    # Let openssl print the email address stored in the CA certificate.
    cmd = opensslCmd + ' x509 -in "' + os.path.join(home, "ca_cert.pem") + '" -email -noout'
    if verbose: print(cmd)
    os.system(cmd)

    sys.exit(0)

#
# sign: sign a certificate request with the CA, optionally adding IP and
# DNS subject alternative names.
#
if sys.argv[script] == "sign":
    def usage():
        print("usage: " + sys.argv[script] + " --in <req> --out <cert> [--ip <ip> --dns <dns>]")
        sys.exit(1)

    try:
        opts, args = getopt.getopt(sys.argv[script+1:], "", [ "in=", "out=", "ip=", "dns=" ])
    except getopt.GetoptError:
        usage()

    if args:
        usage()

    infile = None
    outfile = None
    # Collected "IP:<ip>" / "DNS:<dns>" entries, in command-line order.
    altNames = []
    for o, a in opts:
        if o == "--in":
            infile = a
        elif o == "--out":
            outfile = a
        elif o == "--ip":
            altNames.append("IP:" + a)
        elif o == "--dns":
            altNames.append("DNS:" + a)
    subjectAltName = ",".join(altNames)

    if infile is None or outfile is None:
        usage()

    #
    # Create a temporary configuration file: the sign.cnf template,
    # followed by the subjectAltName extension when one was requested.
    #
    # open() raises instead of returning a false value, so a missing or
    # unreadable template has to be reported from the exception handler.
    #
    signcnf = os.path.join(caroot, "sign.cnf")
    try:
        template = open(signcnf, "r")
    except IOError:
        print("cannot open " + signcnf)
        sys.exit(1)

    try:
        data = template.read()
    finally:
        template.close()

    temp, tempname = tempfile.mkstemp(".cnf", "sign")
    try:
        os.write(temp, data.encode("utf-8"))
        if len(subjectAltName) > 0:
            os.write(temp, ("\n[certificate_extensions]\nsubjectAltName=" + subjectAltName + "\n").encode("utf-8"))
    finally:
        os.close(temp)

    cmd = opensslCmd + ' ca -config "' + tempname + '" -in "' + infile + '" -out "' + outfile + '"'
    if verbose: print(cmd)
    status = os.system(cmd)
    if not keep: os.remove(tempname)
    if status != 0:
        print("openssl command failed")
        sys.exit(1)

    sys.exit(0)

# Every command branch above ends in sys.exit(), so reaching this point
# means the command name was not recognized: print the usage and exit.
usage()
