/*
 * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.  Oracle designates this
 * particular file as subject to the "Classpath" exception as provided
 * by Oracle in the LICENSE file that accompanied this code.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package jdk.internal.net.http.quic;

import javax.crypto.KeyGenerator;
import javax.crypto.Mac;
import java.nio.ByteBuffer;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.atomic.AtomicLong;

import static jdk.internal.net.http.quic.QuicConnectionId.MAX_CONNECTION_ID_LENGTH;

/**
 * A class to generate connection id bytes.
 * This algorithm is specific to our implementation - it's not defined
 * in any RFC (connection id bytes are free form).
 * For the purpose of validation we encode the length of
 * the connection id into the connection id bytes.
 * For the purpose of uniqueness we encode a unique id (token).
 * The rest of the connection id are random bytes.
 * <p>
 * Layout of a generated connection id, before scrambling:
 * <ul>
 *     <li>byte 0: the connection id length in the 5 high-order bits,
 *         and the number of token bytes minus one in the 3 low-order bits;</li>
 *     <li>bytes 1..n: the token, big-endian, on n = 1..8 bytes;</li>
 *     <li>remaining bytes: random.</li>
 * </ul>
 * Every byte of the connection id is then XOR-ed with a random
 * scrambler that is generated once per factory instance.
 */
public class QuicConnectionIdFactory {
    private static final Random RANDOM = new SecureRandom();
    private static final String CLIENT_DESC = "QuicClientConnectionId";
    private static final String SERVER_DESC = "QuicServerConnectionId";

    // Smallest connection id length generated or accepted by this factory:
    // one header byte plus room for up to eight token bytes.
    private static final int MIN_CONNECTION_ID_LENGTH = 9;

    // Counter from which connection id tokens are drawn.
    private final AtomicLong tokens = new AtomicLong();
    // Set once the token counter has overflowed; from then on
    // isValidToken can no longer reject tokens as "not yet issued".
    private volatile boolean wrapped;
    // Per-factory random mask XOR-ed over every generated connection id.
    private final byte[] scrambler;
    // HMAC-SHA256 key used to derive stateless reset tokens.
    private final Key statelessTokenKey;
    // Description used by QuicLocalConnectionId.toString().
    private final String simpleDesc;
    // Length of every connection id generated by this factory, chosen at
    // random in [MIN_CONNECTION_ID_LENGTH, MAX_CONNECTION_ID_LENGTH].
    private final int connectionIdLength = RANDOM.nextInt(MIN_CONNECTION_ID_LENGTH, MAX_CONNECTION_ID_LENGTH+1);

    /**
     * {@return a new factory for client-side connection ids}
     */
    public static QuicConnectionIdFactory getClient() {
        return new QuicConnectionIdFactory(CLIENT_DESC);
    }

    /**
     * {@return a new factory for server-side connection ids}
     */
    public static QuicConnectionIdFactory getServer() {
        return new QuicConnectionIdFactory(SERVER_DESC);
    }

    private QuicConnectionIdFactory(String simpleDesc) {
        this.simpleDesc = simpleDesc;
        byte[] temp = new byte[MAX_CONNECTION_ID_LENGTH];
        RANDOM.nextBytes(temp);
        scrambler = temp;
        try {
            KeyGenerator kg = KeyGenerator.getInstance("HmacSHA256");
            statelessTokenKey = kg.generateKey();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException("HmacSHA256 key generator not available", e);
        }
    }

    /**
     * The connection ID length used by this Quic instance.
     * This is the source connection id length for outgoing packets,
     * and the destination connection id length for incoming packets.
     * @return the connection ID length used by this instance
     */
    public int connectionIdLength() {
        return connectionIdLength;
    }

    /**
     * Creates a new connection ID for a connection.
     * @return a new connection ID
     */
    public QuicConnectionId newConnectionId() {
        long token = newToken();
        return new QuicLocalConnectionId(token, simpleDesc,
                newConnectionId(connectionIdLength, token));
    }

    /**
     * Quick validation to see if the buffer can contain a connection
     * id generated by this instance. The byte buffer is expected to have
     * its {@linkplain ByteBuffer#position() position} set at the start
     * of the connection id, and its {@linkplain ByteBuffer#limit() limit}
     * at the end. In other words, {@code buffer.remaining()} should
     * indicate the connection id length.
     * <p> This method does not advance the buffer position, and
     * returns a connection id that wraps the given buffer.
     * The returned connection id is only safe to use as long as
     * the buffer is not modified.
     * <p> It is usually only used temporarily as a lookup key
     * to locate an existing {@code QuicConnection}.
     *
     * @param buffer A buffer that delimits a connection id.
     * @return a new QuicConnectionId if the buffer can contain
     *         a connection id generated by this instance, {@code null}
     *         otherwise.
     */
    public QuicConnectionId unsafeConnectionIdFor(ByteBuffer buffer) {
        int expectedLength = connectionIdLength;

        int remaining = buffer.remaining();
        if (remaining < MIN_CONNECTION_ID_LENGTH) return null;
        if (remaining != expectedLength) return null;

        // The connection id starts at the buffer position, which is not
        // necessarily 0: read relative to it, not at absolute index 0.
        byte first = buffer.get(buffer.position());
        int len = extractConnectionIdLength(first);
        if (len < MIN_CONNECTION_ID_LENGTH) return null;
        if (len > MAX_CONNECTION_ID_LENGTH) return null;
        if (len != expectedLength) return null;

        long token = peekConnectionIdToken(buffer);
        if (!isValidToken(token)) return null;
        var cid = new QuicLocalConnectionId(buffer, token, simpleDesc);
        assert cid.length() == expectedLength;
        return cid;
    }

    /**
     * Returns a stateless reset token for the given connection ID.
     * The token is the first 16 bytes of the HMAC-SHA256 of the
     * connection id bytes, keyed with a per-factory random key.
     * @param connectionId connection ID
     * @return stateless reset token for the given connection ID
     * @throws IllegalArgumentException if the connection ID was not generated by this factory
     */
    public byte[] statelessTokenFor(QuicConnectionId connectionId) {
        if (!(connectionId instanceof QuicLocalConnectionId)) {
            throw new IllegalArgumentException("Not a locally-generated connection ID");
        }
        Mac mac;
        try {
            mac = Mac.getInstance("HmacSHA256");
            mac.init(statelessTokenKey);
        } catch (NoSuchAlgorithmException | InvalidKeyException e) {
            throw new RuntimeException("HmacSHA256 is not available", e);
        }
        byte[] result = mac.doFinal(connectionId.getBytes());
        return Arrays.copyOf(result, 16);
    }

    /**
     * Returns the next token to embed in a connection id.
     * If the underlying counter has overflowed, the negative value is
     * folded back into the non-negative range ({@code -token - 1}) and
     * this factory is marked as wrapped.
     * <p> Visible for testing.
     * @return a non-negative token
     */
    public long newToken() {
        var token = tokens.incrementAndGet();
        if (token < 0) {
            token = -token - 1;
            wrapped = true;
        }
        return token;
    }

    /**
     * Encodes {@code token} into new, scrambled connection id bytes.
     * <p> Visible for testing.
     * @param length the requested length, clamped to
     *               [{@code MIN_CONNECTION_ID_LENGTH}, {@code MAX_CONNECTION_ID_LENGTH}]
     * @param token  the token to encode; a negative value is folded to
     *               {@code -token - 1}
     * @return the scrambled connection id bytes
     */
    public byte[] newConnectionId(int length, long token) {
        length = Math.clamp(length, MIN_CONNECTION_ID_LENGTH, MAX_CONNECTION_ID_LENGTH);
        assert length <= MAX_CONNECTION_ID_LENGTH;
        assert length >= MIN_CONNECTION_ID_LENGTH;
        byte[] bytes = new byte[length];
        RANDOM.nextBytes(bytes);

        if (token < 0) token = -token - 1;
        assert token >= 0;
        // len is the number of token bytes minus one (0..7)
        int len = variableLengthLength(token);
        assert len < 8;

        // header byte: id length in the high 5 bits, len in the low 3 bits
        bytes[0] = (byte) ((length << 3) & 0xF8);
        bytes[0] = (byte) (bytes[0] | len);
        assert (bytes[0] & 0x07) == len;
        assert ((bytes[0] & 0xFF) >> 3) == length :
                "%s != %s".formatted(bytes[0] & 0xFF, length);
        // token bytes, big-endian, starting at index 1
        int shift = 8 * len;
        for (int i = 0; i <= len; i++) {
            assert shift <= 56;
            bytes[i + 1] = (byte) ((token >> shift) & 0xFF);
            shift -= 8;
        }
        // scramble the whole connection id
        for (int i = 0; i < length; i++) {
            bytes[i] = (byte) ((bytes[i] & 0xFF) ^ (scrambler[i] & 0xFF));
        }

        assert length == getConnectionIdLength(bytes);
        assert token == getConnectionIdToken(bytes);
        return bytes;
    }

    /**
     * Decodes the connection id length stored in the header byte of
     * scrambled connection id bytes.
     * <p> Visible for testing.
     * @param bytes scrambled connection id bytes
     * @return the encoded connection id length
     */
    public int getConnectionIdLength(byte[] bytes) {
        assert bytes.length >= MIN_CONNECTION_ID_LENGTH;
        var length = extractConnectionIdLength(bytes[0]);
        assert length <= MAX_CONNECTION_ID_LENGTH;
        return length;
    }

    /**
     * Decodes the token stored in scrambled connection id bytes.
     * <p> Visible for testing.
     * @param bytes scrambled connection id bytes
     * @return the encoded token
     */
    public long getConnectionIdToken(byte[] bytes) {
        assert bytes.length >= MIN_CONNECTION_ID_LENGTH;
        // share the decoding logic with the ByteBuffer-based lookup path
        long token = peekConnectionIdToken(ByteBuffer.wrap(bytes));
        assert token >= 0;
        return token;
    }

    // Decodes the token of the scrambled connection id that starts at the
    // buffer's position, without moving the position. The result may be
    // negative when the buffer does not hold an id generated by this factory.
    private long peekConnectionIdToken(ByteBuffer bytes) {
        assert bytes.remaining() >= MIN_CONNECTION_ID_LENGTH;
        int start = bytes.position();
        int len = extractTokenLength(bytes.get(start));
        long token = 0;
        int shift = len * 8;
        for (int i = 0; i <= len; i++) {
            assert shift >= 0;
            assert shift <= 56;
            int j = i + 1;
            long l = ((bytes.get(start + j) & 0xFF) ^ (scrambler[j] & 0xFF)) & 0xFF;
            l = l << shift;
            token += l;
            shift -= 8;
        }
        return token;
    }

    // Returns false if the token is negative or, while the counter has not
    // wrapped, greater than the last token handed out by newToken().
    private boolean isValidToken(long token) {
        if (token < 0) return false;
        long prevToken = tokens.get();
        boolean hasWrapped = prevToken < 0 || this.wrapped;
        // if `tokens` has wrapped, we can say nothing...
        // otherwise, a valid token cannot be greater than the
        // last token that was distributed
        if (!hasWrapped) {
            return token <= prevToken;
        }
        return true;
    }

    // Unscrambles the header byte and returns its high 5 bits (id length).
    private int extractConnectionIdLength(byte b) {
        var bits = ((b & 0xFF) ^ (scrambler[0] & 0xFF)) & 0xFF;
        bits = bits >> 3;
        return bits;
    }

    // Unscrambles the header byte and returns its low 3 bits
    // (number of token bytes minus one).
    private int extractTokenLength(byte b) {
        var bits = ((b & 0xFF) ^ (scrambler[0] & 0xFF)) & 0xFF;
        return bits & 0x07;
    }

    // Returns the number of bytes needed to encode a non-negative token,
    // minus one (so 0..7).
    private static int variableLengthLength(long token) {
        assert token >= 0;
        int len = 0;
        int shift = 0;
        for (int i = 1; i < 8; i++) {
            shift += 8;
            if ((token >> shift) == 0) break;
            len++;
        }
        assert len < 8;
        return len;
    }

    /**
     * Checks if {@code connId} looks like a connection ID we could possibly generate.
     * If it does, returns a stateless reset datagram.
     * @param connId the destination connection id that was received on the packet
     * @param length maximum length of the stateless reset packet; values above
     *               43 are capped to 43
     * @return stateless reset datagram payload, or null if {@code length} is
     *         below 21 or {@code connId} was not generated by this factory
     */
    public ByteBuffer statelessReset(ByteBuffer connId, int length) {
        // 43 bytes max:
        //  first byte bits 01xx xxxx
        //  followed by random bytes
        //  terminated by 16 bytes reset token
        length = Math.min(length, 43);
        if (length < 21) { // minimum QUIC short datagram length
            return null;
        }

        var cid = (QuicLocalConnectionId)unsafeConnectionIdFor(connId);
        if (cid != null) {
            var localToken = statelessTokenFor(cid);
            assert localToken != null;
            ByteBuffer buf = ByteBuffer.allocate(length);
            buf.put((byte)(0x40 + RANDOM.nextInt(0x40)));
            // fill everything between the first byte and the 16-byte token
            byte[] random = new byte[length - 17];
            RANDOM.nextBytes(random);
            buf.put(random);
            buf.put(localToken);
            assert !buf.hasRemaining() : buf.remaining();
            buf.flip();
            return buf;
        }
        return null;
    }

    // A connection id generated by this instance.
    private static final class QuicLocalConnectionId extends QuicConnectionId {
        private final long token;
        private final String simpleDesc;

        // Connection Ids created with this constructor are safer
        // to use in maps as the buffer wraps a safe byte array in
        // this constructor.
        private QuicLocalConnectionId(long token, String simpleDesc, byte[] bytes) {
            super(ByteBuffer.wrap(bytes));
            this.token = token;
            this.simpleDesc = simpleDesc;
        }

        // Connection Ids created with this constructor are only
        // safe to use as long as the caller abstains from mutating
        // the provided byte buffer.
        // Typically, they will be transiently used to look up some
        // connection in a map indexed by a connection id.
        private QuicLocalConnectionId(ByteBuffer buffer, long token, String simpleDesc) {
            super(buffer);
            assert token >= 0;
            this.token = token;
            this.simpleDesc = simpleDesc;
        }

        @Override
        public String toString() {
            return "%s(length=%s, token=%s, hash=%s)"
                    .formatted(simpleDesc, length(), token, hashCode);
        }
    }
}
