// IntelliJKdbxDatabase.java

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.identity.implementation.intellij;

import com.azure.identity.CredentialUnavailableException;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
import org.xml.sax.SAXException;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
import javax.xml.xpath.XPath;
import javax.xml.xpath.XPathConstants;
import javax.xml.xpath.XPathExpressionException;
import javax.xml.xpath.XPathFactory;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Locale;
import java.util.zip.GZIPInputStream;

/**
 * In-memory view of a KDBX database: the encrypted stream is decrypted, its XML payload is loaded into a DOM
 * document, protected values are decrypted in place, and entries can then be searched by text.
 */
public class IntelliJKdbxDatabase {

    // NOTE(review): the javax.xml.xpath.XPath Javadoc states that an XPath object is not thread-safe.
    // This shared instance therefore assumes the class is not used concurrently - confirm against callers.
    private static final XPath XPATH = XPathFactory.newInstance().newXPath();

    private static final String STANDARD_PROPERTY_NAME_USER_NAME = "UserName";
    private static final String STANDARD_PROPERTY_NAME_PASSWORD = "Password";
    private static final String STANDARD_PROPERTY_NAME_URL = "URL";
    private static final String STANDARD_PROPERTY_NAME_TITLE = "Title";
    private static final String STANDARD_PROPERTY_NAME_NOTES = "Notes";

    // The two 32-bit values that must open the file (same values as the decimal literals previously inlined).
    private static final int KDBX_SIGNATURE_1 = 0x9AA2D903;
    private static final int KDBX_SIGNATURE_2 = 0xB54BFB67;
    // Only the upper 16 bits of the version word take part in the version check.
    private static final int FILE_VERSION_MASK = 0xFFFF0000;
    private static final int MAX_SUPPORTED_FILE_VERSION = 0x00030000;
    // Number of bytes compared against the header's init bytes right after decryption.
    private static final int INIT_BYTES_LENGTH = 32;

    private static final String DISALLOW_DOCTYPE_FEATURE = "http://apache.org/xml/features/disallow-doctype-decl";

    private final Element rootElement;

    /**
     * Creates a database view rooted at the given group element.
     *
     * @param document the parsed document; currently not retained by this class.
     * @param rootElement the group element searches start from.
     */
    IntelliJKdbxDatabase(Document document, Element rootElement) {
        this.rootElement = rootElement;
    }

    /**
     * Decrypts and loads a KDBX database from the given stream.
     *
     * @param encryptedDatabaseStream the raw database stream, positioned at the start of the file.
     * @param databasePassword the password the database key is derived from.
     * @return the loaded database, rooted at {@code /KeePassFile/Root/Group}.
     * @throws IOException if reading the stream fails.
     * @throws CredentialUnavailableException if the root group lookup fails.
     * @throws IllegalStateException if the header, the stream start bytes or the XML payload are invalid.
     */
    public static IntelliJKdbxDatabase parse(InputStream encryptedDatabaseStream, String databasePassword) throws IOException {
        IntelliJKdbxMetadata kdbxMetadata = new IntelliJKdbxMetadata();
        InputStream decryptedInputStream =
            decryptInputStream(getDatabaseKey(databasePassword), kdbxMetadata, encryptedDatabaseStream);

        // The metadata is only populated once the header has been read by decryptInputStream.
        Salsa20 salsa20 = IntelliJCryptoUtil.createSalsa20CryptoEngine(kdbxMetadata.getEncryptionKey());
        Document document = loadDatabase(decryptedInputStream, salsa20);

        try {
            Element root = (Element) XPATH.evaluate("/KeePassFile/Root/Group", document, XPathConstants.NODE);
            return new IntelliJKdbxDatabase(document, root);
        } catch (XPathExpressionException e) {
            throw new CredentialUnavailableException("Error loading the database", e);
        }
    }

    /**
     * Derives the database key as SHA-256 applied twice to the UTF-8 bytes of the password.
     */
    private static byte[] getDatabaseKey(String databasePassword) {
        MessageDigest sha256 = IntelliJCryptoUtil.getMessageDigestSHA256();
        byte[] firstPass = sha256.digest(databasePassword.getBytes(StandardCharsets.UTF_8));
        // digest() resets the MessageDigest, so the same instance can be reused for the second pass.
        return sha256.digest(firstPass);
    }

    /**
     * Reads the header into {@code kdbxMetadata}, then wraps the remaining stream in the decrypting,
     * block-reading and (unless the compression flag is NONE) gunzipping streams.
     */
    private static InputStream decryptInputStream(byte[] key, IntelliJKdbxMetadata kdbxMetadata, InputStream inputStream)
        throws IOException {
        parseDatabaseMetadata(kdbxMetadata, inputStream);
        InputStream decryptedInputStream = IntelliJCryptoUtil.createDecryptedStream(key, inputStream, kdbxMetadata);
        validateInitialBytes(kdbxMetadata, decryptedInputStream);
        HashedBlockInputStream blockInputStream = new HashedBlockInputStream(decryptedInputStream, true);

        boolean uncompressed = kdbxMetadata.getDatabaseCompressionFlags()
            .equals(IntelliJKdbxMetadata.DatabaseCompressionFlags.NONE);
        if (uncompressed) {
            return blockInputStream;
        }
        return new GZIPInputStream(blockInputStream);
    }

    /**
     * Reads the file header from {@code inputStream} and stores each recognised field in {@code kdbxMetadata}.
     * All header bytes pass through a SHA-256 digest whose result is stored as the header hash.
     *
     * @return the same {@code kdbxMetadata} instance that was passed in.
     * @throws IllegalStateException if the signature, version or a header type is not accepted.
     */
    private static IntelliJKdbxMetadata parseDatabaseMetadata(IntelliJKdbxMetadata kdbxMetadata, InputStream inputStream)
        throws IOException {
        MessageDigest digest = IntelliJCryptoUtil.getMessageDigestSHA256();
        LittleEndianDataInputStream header =
            new LittleEndianDataInputStream(new DigestInputStream(inputStream, digest));

        int sig1 = header.readInt();
        int sig2 = header.readInt();
        if (sig1 != KDBX_SIGNATURE_1 || sig2 != KDBX_SIGNATURE_2) {
            throw new IllegalStateException("Magic number did not match");
        }
        if ((header.readInt() & FILE_VERSION_MASK) > MAX_SUPPORTED_FILE_VERSION) {
            throw new IllegalStateException("File version did not match");
        }

        byte headerType;
        while ((headerType = header.readByte()) != 0) {
            switch (headerType) {
                case 1:
                    // Field value is consumed but not stored.
                    readByteArray(header);
                    break;
                case 2:
                    kdbxMetadata.setCipherUuid(readByteArray(header));
                    break;
                case 3:
                    kdbxMetadata.setDatabaseCompressionFlags(readInt(header));
                    break;
                case 4:
                    kdbxMetadata.setBaseSeed(readByteArray(header));
                    break;
                case 5:
                    kdbxMetadata.setTransformSeed(readByteArray(header));
                    break;
                case 6:
                    kdbxMetadata.setTransformRounds(readLong(header));
                    break;
                case 7:
                    kdbxMetadata.setEncryptionIv(readByteArray(header));
                    break;
                case 8:
                    kdbxMetadata.setEncryptionKey(readByteArray(header));
                    break;
                case 9:
                    kdbxMetadata.setInitBytes(readByteArray(header));
                    break;
                case 10:
                    kdbxMetadata.setEncryptionAlgorithm(readInt(header));
                    break;
                default:
                    throw new IllegalStateException("Unknown File Header");
            }
        }

        // Header type 0 ended the loop; its length-prefixed payload still has to be consumed so that it
        // is part of the digest and the stream is positioned after the header.
        readByteArray(header);
        kdbxMetadata.setHeaderHash(digest.digest());
        return kdbxMetadata;
    }

    /**
     * Reads a 16-bit length followed by that many bytes.
     */
    private static byte[] readByteArray(LittleEndianDataInputStream ledis) throws IOException {
        // Mask to interpret the 16-bit length as unsigned; a raw short above Short.MAX_VALUE would be
        // negative and make the array allocation fail with NegativeArraySizeException.
        int fieldLength = ledis.readShort() & 0xFFFF;
        byte[] value = new byte[fieldLength];
        ledis.readFully(value);
        return value;
    }

    /**
     * Compares the first bytes of the decrypted stream with the init bytes recorded in the header.
     *
     * @throws IllegalStateException if they differ.
     */
    private static void validateInitialBytes(IntelliJKdbxMetadata kdbxMetadata, InputStream decryptedInputStream)
        throws IOException {
        byte[] streamStart = new byte[INIT_BYTES_LENGTH];
        new LittleEndianDataInputStream(decryptedInputStream).readFully(streamStart);
        if (!Arrays.equals(streamStart, kdbxMetadata.getInitBytes())) {
            throw new IllegalStateException("Inconsistent stream start bytes. This usually means the credentials were wrong.");
        }
    }

    /**
     * Reads a length-prefixed int field; the declared length must be exactly 4.
     */
    private static int readInt(LittleEndianDataInputStream ledis) throws IOException {
        short length = ledis.readShort();
        if (length != 4) {
            throw new IllegalStateException("Int required but length was " + length);
        }
        return ledis.readInt();
    }

    /**
     * Reads a length-prefixed long field; the declared length must be exactly 8.
     */
    private static long readLong(LittleEndianDataInputStream ledis) throws IOException {
        short length = ledis.readShort();
        if (length != 8) {
            throw new IllegalStateException("Long required but length was " + length);
        }
        return ledis.readLong();
    }

    /**
     * Parses the XML payload and decrypts, in document order, every element carrying
     * {@code Protected='True'}: its Base64 text is decoded, decrypted with {@code salsa20Engine}, written back
     * as UTF-8 text, and the {@code Protected} attribute is removed.
     *
     * @param inputStream the decrypted XML stream.
     * @param salsa20Engine the engine used to decrypt protected values.
     * @return the document with protected values replaced by their decrypted text.
     * @throws IOException if reading the stream fails.
     * @throws IllegalStateException if the parser cannot be created, the XML is malformed (including any
     *     DOCTYPE declaration when the parser supports rejecting it) or the XPath lookup fails.
     */
    public static Document loadDatabase(InputStream inputStream, Salsa20 salsa20Engine) throws IOException {
        try {
            DocumentBuilder builder = newHardenedDocumentBuilderFactory().newDocumentBuilder();
            Document doc = builder.parse(inputStream);
            NodeList protectedContent =
                (NodeList) XPATH.evaluate("//*[@Protected='True']", doc, XPathConstants.NODESET);

            for (int i = 0; i < protectedContent.getLength(); ++i) {
                Element element = (Element) protectedContent.item(i);
                byte[] encrypted = Base64.getDecoder()
                    .decode(element.getTextContent().getBytes(StandardCharsets.UTF_8));
                byte[] plain = IntelliJCryptoUtil.decrypt(encrypted, salsa20Engine);
                element.setTextContent(new String(plain, StandardCharsets.UTF_8));
                element.removeAttribute("Protected");
            }

            return doc;
        } catch (ParserConfigurationException e) {
            throw new IllegalStateException("Instantiating Document Builder", e);
        } catch (SAXException e) {
            throw new IllegalStateException("Parsing exception", e);
        } catch (XPathExpressionException e) {
            throw new IllegalStateException("XPath Exception", e);
        }
    }

    /**
     * Creates a DocumentBuilderFactory with secure processing enabled and, where the parser supports the
     * feature, DOCTYPE declarations rejected, so the payload cannot pull in external entities.
     */
    private static DocumentBuilderFactory newHardenedDocumentBuilderFactory() throws ParserConfigurationException {
        DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
        factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
        try {
            factory.setFeature(DISALLOW_DOCTYPE_FEATURE, true);
        } catch (ParserConfigurationException ignored) {
            // Best effort: a parser that does not recognise this feature still has secure processing enabled.
        }
        return factory;
    }

    /**
     * Evaluates {@code elementPath} relative to {@code parentElement} and returns the first matching element.
     *
     * @param create when {@code true} and nothing matches, the path is created below {@code parentElement}.
     * @return the element, or {@code null} if nothing matches and {@code create} is {@code false}.
     */
    private static Element getElement(String elementPath, Element parentElement, boolean create) {
        try {
            Element found = (Element) XPATH.evaluate(elementPath, parentElement, XPathConstants.NODE);
            if (found != null || !create) {
                return found;
            }
            return buildHierarchicalPath(elementPath, parentElement);
        } catch (XPathExpressionException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Walks the {@code /}-separated path below {@code startElement}, appending a new child element for every
     * segment that does not exist yet, and returns the element for the last segment.
     */
    private static Element buildHierarchicalPath(String elementPath, Element startElement) {
        Element current = startElement;
        try {
            for (String segment : elementPath.split("/")) {
                Element next = (Element) XPATH.evaluate(segment, current, XPathConstants.NODE);
                if (next == null) {
                    next = (Element) current.appendChild(current.getOwnerDocument().createElement(segment));
                }
                current = next;
            }
        } catch (XPathExpressionException e) {
            throw new IllegalStateException(e);
        }
        return current;
    }

    /**
     * Sets the text content of the element at {@code elementPath}, creating the path if it does not exist.
     *
     * @return the element whose content was set.
     */
    static Element setElementContent(String elementPath, Element parentElement, String value) {
        Element target = getElement(elementPath, parentElement, true);
        target.setTextContent(value);
        return target;
    }

    /**
     * Tells whether the title, notes, URL or user name of {@code baseEntry} contains {@code text},
     * ignoring case.
     */
    public static boolean match(Element baseEntry, String text) {
        return matchString(getProperty(baseEntry, STANDARD_PROPERTY_NAME_TITLE), text)
            || matchString(getProperty(baseEntry, STANDARD_PROPERTY_NAME_NOTES), text)
            || matchString(getProperty(baseEntry, STANDARD_PROPERTY_NAME_URL), text)
            || matchString(getProperty(baseEntry, STANDARD_PROPERTY_NAME_USER_NAME), text);
    }

    /**
     * Case-insensitive "contains" check.
     *
     * @return {@code false} if either argument is {@code null}; otherwise whether {@code property}
     *     contains {@code toMatch} once both are lower-cased.
     */
    public static boolean matchString(String property, String toMatch) {
        if (property == null || toMatch == null) {
            return false;
        }
        // Locale.ROOT keeps the comparison independent of the machine locale; with a Turkish default
        // locale "TITLE" lower-cases to a dotless i and would not match "title".
        return property.toLowerCase(Locale.ROOT).contains(toMatch.toLowerCase(Locale.ROOT));
    }

    /**
     * Depth-first search for the first entry matching {@code toMatch}: the group's own entries are checked
     * first, then each child group in turn.
     *
     * @return the password property of the first matching entry, or {@code null} if none was found.
     */
    private String getDatabaseEntryValue(Element dbRootGroup, String toMatch) {
        for (Element entry : getElements("Entry", dbRootGroup)) {
            if (match(entry, toMatch)) {
                return getProperty(entry, STANDARD_PROPERTY_NAME_PASSWORD);
            }
        }

        for (Element group : getGroups(dbRootGroup)) {
            // Propagate a hit from a nested group instead of dropping it.
            String nested = getDatabaseEntryValue(group, toMatch);
            if (nested != null) {
                return nested;
            }
        }

        return null;
    }

    /**
     * Looks up the password of the first entry, in this database's root group or any nested group, whose
     * title, notes, URL or user name contains {@code toMatch}.
     *
     * @return the password, or {@code null} if no entry matches.
     */
    public String getDatabaseEntryValue(String toMatch) {
        return getDatabaseEntryValue(rootElement, toMatch);
    }

    /**
     * Returns the text content of the element at {@code elementPath}, or {@code null} if there is none.
     */
    static String getElementContent(String elementPath, Element parentElement) {
        Element target = getElement(elementPath, parentElement, false);
        if (target == null) {
            return null;
        }
        return target.getTextContent();
    }

    /**
     * Returns the {@code Value} text of the {@code String} child of {@code element} whose {@code Key} equals
     * {@code name}, or {@code null} if there is no such child.
     *
     * @param name the key to look up; it is embedded in a single-quoted XPath literal, so it must not
     *     contain a single quote.
     */
    public static String getProperty(Element element, String name) {
        Element property = getElement(String.format("String[Key/text()='%s']", name), element, false);
        if (property == null) {
            return null;
        }
        return getElementContent("Value", property);
    }

    /**
     * Returns every element matching {@code elementPath} relative to {@code parentElement}, in the order the
     * XPath node set reports them.
     */
    static List<Element> getElements(String elementPath, Element parentElement) {
        try {
            NodeList nodes = (NodeList) XPATH.evaluate(elementPath, parentElement, XPathConstants.NODESET);
            int count = nodes.getLength();
            List<Element> result = new ArrayList<>(count);
            for (int i = 0; i < count; ++i) {
                result.add((Element) nodes.item(i));
            }
            return result;
        } catch (XPathExpressionException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Returns the direct {@code Group} children of {@code rootGroup}.
     */
    public static List<Element> getGroups(Element rootGroup) {
        return getElements("Group", rootGroup);
    }
}