/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.transport;

import org.apache.logging.log4j.Level;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.node.VersionInformation;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.EnumSerializationTestUtils;
import org.elasticsearch.test.MockLog;
import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.test.transport.MockTransportService;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;

import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;

import static org.elasticsearch.test.MockLog.assertThatLogger;
import static org.elasticsearch.test.MockLog.awaitLogger;
import static org.elasticsearch.transport.RemoteClusterSettings.ProxyConnectionStrategySettings.PROXY_ADDRESS;
import static org.elasticsearch.transport.RemoteClusterSettings.REMOTE_CONNECTION_MODE;
import static org.elasticsearch.transport.RemoteClusterSettings.SniffConnectionStrategySettings.REMOTE_CLUSTER_SEEDS;
import static org.elasticsearch.transport.RemoteClusterSettings.toConfig;
import static org.elasticsearch.transport.RemoteConnectionStrategy.buildConnectionProfile;
import static org.mockito.Mockito.mock;

/**
 * Unit tests for {@link RemoteConnectionStrategy}: whether a strategy must be rebuilt after a settings change, how connection
 * profiles are built for each {@link RemoteConnectionStrategy.ConnectionStrategy}, enum wire serialization, and the log output
 * produced for connection attempts.
 */
public class RemoteConnectionStrategyTests extends ESTestCase {
    private static final String CLUSTER_ALIAS = "cluster-alias";

    /** One {@link LinkedProjectConfig} per connection strategy, keyed by strategy, each pointing at {@code localhost:8080}. */
    private static final Map<RemoteConnectionStrategy.ConnectionStrategy, LinkedProjectConfig> CONFIG_BY_STRATEGY = Map.of(
        RemoteConnectionStrategy.ConnectionStrategy.PROXY,
        new LinkedProjectConfig.ProxyLinkedProjectConfigBuilder(ProjectId.DEFAULT, ProjectId.DEFAULT, CLUSTER_ALIAS).proxyAddress(
            "localhost:8080"
        ).build(),
        RemoteConnectionStrategy.ConnectionStrategy.SNIFF,
        new LinkedProjectConfig.SniffLinkedProjectConfigBuilder(ProjectId.DEFAULT, ProjectId.DEFAULT, CLUSTER_ALIAS).seedNodes(
            List.of("localhost:8080")
        ).build()
    );
    private static final ThreadContext THREAD_CONTEXT = new ThreadContext(Settings.EMPTY);

    public void testStrategyChangeMeansThatStrategyMustBeRebuilt() {
        final FakeConnectionStrategy first = newProxyStrategy(
            new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class), THREAD_CONTEXT)
        );
        // The existing strategy is PROXY; the new settings switch the alias to sniff mode.
        final Settings newSettings = Settings.builder()
            .put(REMOTE_CONNECTION_MODE.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "sniff")
            .put(REMOTE_CLUSTER_SEEDS.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "127.0.0.1:9300")
            .build();
        assertTrue(first.shouldRebuildConnection(toConfig(CLUSTER_ALIAS, newSettings)));
    }

    public void testSameStrategyChangeMeansThatStrategyDoesNotNeedToBeRebuilt() {
        final FakeConnectionStrategy first = newProxyStrategy(
            new ClusterConnectionManager(Settings.EMPTY, mock(Transport.class), THREAD_CONTEXT)
        );
        // The alias stays in proxy mode; FakeConnectionStrategy#strategyMustBeRebuilt always answers false.
        final Settings newSettings = Settings.builder()
            .put(REMOTE_CONNECTION_MODE.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "proxy")
            .put(PROXY_ADDRESS.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "127.0.0.1:9300")
            .build();
        assertFalse(first.shouldRebuildConnection(toConfig(CLUSTER_ALIAS, newSettings)));
    }

    public void testChangeInConnectionProfileMeansTheStrategyMustBeRebuilt() {
        final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
            TestProfiles.LIGHT_PROFILE,
            mock(Transport.class),
            THREAD_CONTEXT
        );
        assertEquals(TimeValue.MINUS_ONE, connectionManager.getConnectionProfile().getPingInterval());
        assertEquals(Compression.Enabled.INDEXING_DATA, connectionManager.getConnectionProfile().getCompressionEnabled());
        assertEquals(Compression.Scheme.LZ4, connectionManager.getConnectionProfile().getCompressionScheme());
        final FakeConnectionStrategy first = newProxyStrategy(connectionManager);

        final Settings.Builder newBuilder = Settings.builder();
        newBuilder.put(REMOTE_CONNECTION_MODE.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "proxy");
        newBuilder.put(PROXY_ADDRESS.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(), "127.0.0.1:9300");
        // Move exactly one connection-profile setting away from the LIGHT_PROFILE values asserted above.
        final String change = randomFrom("ping", "compress", "compression_scheme");
        switch (change) {
            case "ping" -> newBuilder.put(
                RemoteClusterSettings.REMOTE_CLUSTER_PING_SCHEDULE.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(),
                TimeValue.timeValueSeconds(5)
            );
            case "compress" -> newBuilder.put(
                RemoteClusterSettings.REMOTE_CLUSTER_COMPRESS.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(),
                randomFrom(Compression.Enabled.FALSE, Compression.Enabled.TRUE)
            );
            case "compression_scheme" -> newBuilder.put(
                RemoteClusterSettings.REMOTE_CLUSTER_COMPRESSION_SCHEME.getConcreteSettingForNamespace(CLUSTER_ALIAS).getKey(),
                Compression.Scheme.DEFLATE
            );
            default -> throw new AssertionError("Unexpected option: " + change);
        }
        assertTrue(first.shouldRebuildConnection(toConfig(CLUSTER_ALIAS, newBuilder.build())));
    }

    public void testCorrectChannelNumber() {
        for (RemoteConnectionStrategy.ConnectionStrategy strategy : RemoteConnectionStrategy.ConnectionStrategy.values()) {
            // The channel count must not depend on which transport profile is used.
            final ConnectionProfile profile = buildConnectionProfile(
                CONFIG_BY_STRATEGY.get(strategy),
                randomBoolean() ? RemoteClusterPortSettings.REMOTE_CLUSTER_PROFILE : TransportSettings.DEFAULT_PROFILE
            );
            assertEquals(
                "Incorrect number of channels for " + strategy.name(),
                strategy.getNumberOfChannels(),
                profile.getNumConnections()
            );
        }
    }

    public void testTransportProfile() {
        // New rcs connection with credentials
        assertTransportProfileForEveryStrategy(RemoteClusterPortSettings.REMOTE_CLUSTER_PROFILE);
        // Legacy ones without credentials
        assertTransportProfileForEveryStrategy(TransportSettings.DEFAULT_PROFILE);
    }

    public void testConnectionStrategySerialization() {
        EnumSerializationTestUtils.assertEnumSerialization(
            RemoteConnectionStrategy.ConnectionStrategy.class,
            RemoteConnectionStrategy.ConnectionStrategy.SNIFF,
            RemoteConnectionStrategy.ConnectionStrategy.PROXY
        );
    }

    @TestLogging(
        value = "org.elasticsearch.transport.RemoteConnectionStrategyTests.FakeConnectionStrategy:DEBUG",
        reason = "logging verification"
    )
    public void testConnectionAttemptMetricsAndLogging() {
        final var originProjectId = randomUniqueProjectId();
        final var linkedProjectId = randomUniqueProjectId();
        final var alias = randomAlphanumericOfLength(10);

        try (
            var threadPool = new TestThreadPool(getClass().getName());
            var transportService = startTransport(threadPool);
            var connectionManager = new RemoteConnectionManager(
                alias,
                RemoteClusterCredentialsManager.EMPTY,
                new ClusterConnectionManager(TestProfiles.LIGHT_PROFILE, mock(Transport.class), THREAD_CONTEXT)
            )
        ) {
            for (boolean shouldConnectFail : new boolean[] { true, false }) {
                for (boolean isInitialConnectAttempt : new boolean[] { true, false }) {
                    final var strategy = new FakeConnectionStrategy(
                        originProjectId,
                        linkedProjectId,
                        alias,
                        transportService,
                        connectionManager
                    );
                    if (isInitialConnectAttempt == false) {
                        // A successful first connect turns the attempt under test into a reconnection.
                        waitForConnect(strategy);
                    }
                    strategy.setShouldConnectFail(shouldConnectFail);
                    final var expectedLogLevel = shouldConnectFail ? Level.WARN : Level.DEBUG;
                    final var expectedLogMessage = Strings.format(
                        "Origin project [%s] %s to linked project [%s] with alias [%s] on %s attempt",
                        originProjectId,
                        shouldConnectFail ? "failed to connect" : "successfully connected",
                        linkedProjectId,
                        alias,
                        isInitialConnectAttempt ? "the initial connection" : "a reconnection"
                    );
                    assertThatLogger(() -> {
                        if (shouldConnectFail) {
                            assertThrows(RuntimeException.class, () -> waitForConnect(strategy));
                        } else {
                            waitForConnect(strategy);
                        }
                    },
                        strategy.getClass(),
                        new MockLog.SeenEventExpectation(
                            "connection strategy should log at "
                                + expectedLogLevel
                                + " after a "
                                + (shouldConnectFail ? "failed" : "successful")
                                + (isInitialConnectAttempt ? " initial connection attempt" : " reconnection attempt"),
                            strategy.getClass().getCanonicalName(),
                            expectedLogLevel,
                            expectedLogMessage
                        )
                    );
                }
            }

            // Now verify connection errors when closing (node shutting down) are logged at debug and not warn.
            final var strategy = new FakeConnectionStrategy(originProjectId, linkedProjectId, alias, transportService, connectionManager);
            waitForConnect(strategy);
            strategy.setShouldConnectFail(true);
            strategy.setWaitInConnect(true);
            final var expectedLogLevel = Level.DEBUG;
            final var expectedLogMessage = Strings.format(
                "Origin project [%s] failed to connect to linked project [%s] with alias [%s] on a reconnection attempt",
                originProjectId,
                linkedProjectId,
                alias
            );
            awaitLogger(() -> assertThrows(RuntimeException.class, () -> {
                PlainActionFuture<Void> connectFuture = new PlainActionFuture<>();
                // Initiate the connection, this will block the connecting thread.
                strategy.connect(connectFuture);
                try {
                    // Close the strategy and manager (similar to RemoteClusterConnection.close()).
                    strategy.close();
                    connectionManager.close();
                } finally {
                    // Always unblock the connecting thread, even if closing failed, so it is not left parked on the latch.
                    strategy.releaseWaitingConnect();
                }
                connectFuture.actionGet();
            }),
                strategy.getClass(),
                new MockLog.SeenEventExpectation(
                    "connection strategy should log at " + expectedLogLevel + " after a failed reconnection attempt",
                    strategy.getClass().getCanonicalName(),
                    expectedLogLevel,
                    expectedLogMessage
                )
            );
        }
    }

    /**
     * Creates a PROXY-typed {@link FakeConnectionStrategy} for {@link #CLUSTER_ALIAS} on top of {@code connectionManager},
     * with no remote-cluster credentials and a mocked {@link TransportService}.
     */
    private static FakeConnectionStrategy newProxyStrategy(ClusterConnectionManager connectionManager) {
        final RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(
            CLUSTER_ALIAS,
            RemoteClusterCredentialsManager.EMPTY,
            connectionManager
        );
        return new FakeConnectionStrategy(
            CLUSTER_ALIAS,
            mock(TransportService.class),
            remoteConnectionManager,
            RemoteConnectionStrategy.ConnectionStrategy.PROXY
        );
    }

    /** Asserts, for every connection strategy, that the built connection profile carries {@code transportProfile}. */
    private static void assertTransportProfileForEveryStrategy(String transportProfile) {
        for (RemoteConnectionStrategy.ConnectionStrategy strategy : RemoteConnectionStrategy.ConnectionStrategy.values()) {
            final ConnectionProfile profile = buildConnectionProfile(CONFIG_BY_STRATEGY.get(strategy), transportProfile);
            assertEquals("Incorrect transport profile for " + strategy.name(), transportProfile, profile.getTransportProfile());
        }
    }

    /**
     * Creates, starts and opens for incoming requests a {@link MockTransportService} named {@code node1} in cluster
     * {@code cluster1}. The service is closed again if starting it fails.
     */
    private MockTransportService startTransport(ThreadPool threadPool) {
        boolean success = false;
        final Settings s = Settings.builder().put(ClusterName.CLUSTER_NAME_SETTING.getKey(), "cluster1").put("node.name", "node1").build();
        MockTransportService newService = MockTransportService.createNewService(
            s,
            VersionInformation.CURRENT,
            TransportVersion.current(),
            threadPool
        );
        try {
            newService.start();
            newService.acceptIncomingRequests();
            success = true;
            return newService;
        } finally {
            if (success == false) {
                newService.close();
            }
        }
    }

    /** Calls {@code strategy.connect} and blocks until the resulting future completes, rethrowing any failure. */
    private static void waitForConnect(RemoteConnectionStrategy strategy) {
        PlainActionFuture<Void> connectFuture = new PlainActionFuture<>();
        strategy.connect(connectFuture);
        connectFuture.actionGet();
    }

    /**
     * Test double whose {@link #connectImpl} either succeeds or fails on demand and can optionally block until
     * {@link #releaseWaitingConnect()} is called. It never asks to be rebuilt and never opens extra connections.
     */
    private static class FakeConnectionStrategy extends RemoteConnectionStrategy {

        private final ConnectionStrategy strategy;
        // Toggled by the test before the connect() call they are meant to influence.
        private boolean shouldConnectFail;
        private boolean waitInConnect;
        // Single-use gate that holds connectImpl when waitInConnect is set.
        private final CountDownLatch waitLatch;

        /** Creates a fake with a randomly chosen strategy type. */
        FakeConnectionStrategy(
            ProjectId originProjectId,
            ProjectId linkedProjectId,
            String clusterAlias,
            TransportService transportService,
            RemoteConnectionManager connectionManager
        ) {
            this(
                originProjectId,
                linkedProjectId,
                clusterAlias,
                transportService,
                connectionManager,
                randomFrom(RemoteConnectionStrategy.ConnectionStrategy.values())
            );
        }

        /** Creates a fake for the default origin and linked project. */
        FakeConnectionStrategy(
            String clusterAlias,
            TransportService transportService,
            RemoteConnectionManager connectionManager,
            RemoteConnectionStrategy.ConnectionStrategy strategy
        ) {
            this(ProjectId.DEFAULT, ProjectId.DEFAULT, clusterAlias, transportService, connectionManager, strategy);
        }

        FakeConnectionStrategy(
            ProjectId originProjectId,
            ProjectId linkedProjectId,
            String clusterAlias,
            TransportService transportService,
            RemoteConnectionManager connectionManager,
            RemoteConnectionStrategy.ConnectionStrategy strategy
        ) {
            super(switch (strategy) {
                case PROXY -> new LinkedProjectConfig.ProxyLinkedProjectConfigBuilder(originProjectId, linkedProjectId, clusterAlias)
                    .proxyAddress("localhost:8080")
                    .build();
                case SNIFF -> new LinkedProjectConfig.SniffLinkedProjectConfigBuilder(originProjectId, linkedProjectId, clusterAlias)
                    .seedNodes(List.of("localhost:8080"))
                    .build();
            }, transportService, connectionManager);
            this.strategy = strategy;
            this.shouldConnectFail = false;
            this.waitInConnect = false;
            this.waitLatch = new CountDownLatch(1);
        }

        void setShouldConnectFail(boolean shouldConnectFail) {
            this.shouldConnectFail = shouldConnectFail;
        }

        void setWaitInConnect(boolean waitInConnect) {
            this.waitInConnect = waitInConnect;
        }

        /** Opens the latch so a connectImpl call blocked by {@code waitInConnect} can proceed. */
        void releaseWaitingConnect() {
            this.waitLatch.countDown();
        }

        @Override
        protected boolean strategyMustBeRebuilt(LinkedProjectConfig config) {
            return false;
        }

        @Override
        protected ConnectionStrategy strategyType() {
            return this.strategy;
        }

        @Override
        protected boolean shouldOpenMoreConnections() {
            return false;
        }

        @Override
        protected void connectImpl(ActionListener<Void> listener) {
            if (waitInConnect) {
                safeAwait(waitLatch);
            }
            if (shouldConnectFail) {
                listener.onFailure(new RuntimeException("simulated failure"));
            } else {
                listener.onResponse(null);
            }
        }

        @Override
        protected RemoteConnectionInfo.ModeInfo getModeInfo() {
            return null;
        }
    }
}
