/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.common.util.concurrent;

import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.ReferenceDocs;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.monitor.jvm.HotThreads;

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.LongSupplier;
import java.util.stream.Stream;

import static org.elasticsearch.core.Strings.format;

/**
 * A {@link ThreadPoolExecutor} with Elasticsearch-specific task handling.
 * <p>
 * Compared to a plain {@link ThreadPoolExecutor}, this executor:
 * <ul>
 *     <li>wraps every submitted task with {@link #wrapRunnable(Runnable)}, which by default preserves the submitter's
 *     {@link ThreadContext};</li>
 *     <li>when submitting an {@link AbstractRunnable} fails, hands the failure to its {@code onRejection} method (always followed
 *     by {@code onAfter}) instead of throwing, logging and asserting first on failures that are not
 *     {@link EsRejectedExecutionException}s;</li>
 *     <li>after each task runs, rethrows errors via {@code EsExecutors.rethrowErrors} and asserts that the worker thread is
 *     returned to the pool with the default thread context;</li>
 *     <li>optionally logs local hot threads when the work queue stays at or above a size threshold for too long, as configured by
 *     {@code EsExecutors.HotThreadsOnLargeQueueConfig};</li>
 *     <li>does not support changing the core or maximum pool size after construction.</li>
 * </ul>
 */
public class EsThreadPoolExecutor extends ThreadPoolExecutor {

    private static final Logger logger = LogManager.getLogger(EsThreadPoolExecutor.class);

    /**
     * Sentinel for {@code startTimeMillisOfLargeQueue} meaning "the queue is not currently tracked as large". A clock reading that
     * happens to equal this value is harmless: the tracking start time is simply re-recorded on the next submission.
     */
    private static final long NOT_TRACKED_TIME = -1L;

    // No-op probe used to prevent starvation of work in the work queue due to ForceQueuePolicy,
    // see https://github.com/elastic/elasticsearch/issues/124667.
    // Note: this is intentionally an anonymous class rather than a lambda, so that it is always a distinct instance that can
    // never be confused with a similar no-op lambda coming from other places. execute() recognizes the probe by identity and
    // submits it without wrapping.
    static final Runnable WORKER_PROBE = new Runnable() {
        @Override
        public void run() {}
    };

    // Used to wrap submitted tasks (see wrapRunnable) and to check that workers return to the pool with the default context.
    private final ThreadContext contextHolder;

    /**
     * Name of this executor, used in log messages (error reporting and hot-threads logging) and in {@link #toString()}.
     */
    private final String name;

    private final EsExecutors.HotThreadsOnLargeQueueConfig hotThreadsOnLargeQueueConfig;

    /**
     * Source of millisecond readings used only to measure elapsed time for hot-threads logging. Only differences between readings
     * are meaningful, so this should be a monotonic clock rather than wall-clock time.
     */
    private final LongSupplier currentTimeMillisSupplier;

    // Time at which the queue was first observed at or above the size threshold, or NOT_TRACKED_TIME.
    // There may be racing on updating this field. It's OK since hot threads logging is very coarse grained time wise
    // and can tolerate some inaccuracies.
    private volatile long startTimeMillisOfLargeQueue = NOT_TRACKED_TIME;

    // Time at which hot threads were last logged; null when hot-threads logging is disabled.
    private final AtomicLong lastLoggingTimeMillisForHotThreads;

    /**
     * Creates an executor that aborts on rejection (via {@code EsAbortPolicy}) and has hot-threads logging disabled.
     */
    EsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        ThreadFactory threadFactory,
        ThreadContext contextHolder
    ) {
        this(
            name,
            corePoolSize,
            maximumPoolSize,
            keepAliveTime,
            unit,
            workQueue,
            threadFactory,
            new EsAbortPolicy(),
            contextHolder,
            EsExecutors.HotThreadsOnLargeQueueConfig.DISABLED
        );
    }

    /**
     * Creates an executor with the given rejection handler and hot-threads configuration. Elapsed time for hot-threads logging is
     * measured with a monotonic clock.
     */
    EsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        ThreadFactory threadFactory,
        RejectedExecutionHandler handler,
        ThreadContext contextHolder,
        EsExecutors.HotThreadsOnLargeQueueConfig hotThreadsOnLargeQueueConfig
    ) {
        this(
            name,
            corePoolSize,
            maximumPoolSize,
            keepAliveTime,
            unit,
            workQueue,
            threadFactory,
            handler,
            contextHolder,
            hotThreadsOnLargeQueueConfig,
            EsThreadPoolExecutor::relativeTimeMillis
        );
    }

    /**
     * Full constructor.
     *
     * @param currentTimeMillisSupplier millisecond clock used only to measure elapsed time for hot-threads logging; tests use this to
     *                                  supply a controllable clock
     */
    @SuppressForbidden(reason = "properly rethrowing errors, see EsExecutors.rethrowErrors")
    EsThreadPoolExecutor(
        String name,
        int corePoolSize,
        int maximumPoolSize,
        long keepAliveTime,
        TimeUnit unit,
        BlockingQueue<Runnable> workQueue,
        ThreadFactory threadFactory,
        RejectedExecutionHandler handler,
        ThreadContext contextHolder,
        EsExecutors.HotThreadsOnLargeQueueConfig hotThreadsOnLargeQueueConfig,
        LongSupplier currentTimeMillisSupplier // For test to configure a custom time supplier
    ) {
        super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
        this.name = name;
        this.contextHolder = contextHolder;
        this.hotThreadsOnLargeQueueConfig = hotThreadsOnLargeQueueConfig;
        this.currentTimeMillisSupplier = currentTimeMillisSupplier;
        // Back-date the last logging time by one interval so that the first qualifying large queue may be logged immediately.
        this.lastLoggingTimeMillisForHotThreads = hotThreadsOnLargeQueueConfig.isEnabled()
            ? new AtomicLong(currentTimeMillisSupplier.getAsLong() - hotThreadsOnLargeQueueConfig.intervalInMillis())
            : null;
    }

    /**
     * Returns a monotonic millisecond reading that is suitable only for measuring elapsed time. Unlike
     * {@link System#currentTimeMillis()}, it does not jump when the wall clock is adjusted. Such jumps would otherwise suppress
     * hot-threads logging (backward jumps) or trigger it prematurely (forward jumps).
     */
    private static long relativeTimeMillis() {
        return TimeUnit.NANOSECONDS.toMillis(System.nanoTime());
    }

    /**
     * Not supported: the pool size cannot be reconfigured at runtime.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setCorePoolSize(int corePoolSize) {
        throw new UnsupportedOperationException("reconfiguration at runtime is not supported");
    }

    /**
     * Not supported: the pool size cannot be reconfigured at runtime.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setMaximumPoolSize(int maximumPoolSize) {
        throw new UnsupportedOperationException("reconfiguration at runtime is not supported");
    }

    /**
     * Wraps {@code command} (except for {@link #WORKER_PROBE}, which is submitted as-is) and submits it to the pool.
     * <p>
     * If submission throws and the wrapped task is an {@link AbstractRunnable}, the exception is passed to its
     * {@code onRejection} method and {@code onAfter} is always invoked afterwards; exceptions other than
     * {@link EsRejectedExecutionException} are additionally logged and trip an assertion. For any other task type the exception
     * is rethrown.
     * <p>
     * Hot-threads logging (if enabled) runs on the calling thread before the task is handed to the pool.
     */
    @Override
    public void execute(Runnable command) {
        final Runnable wrappedRunnable = command != WORKER_PROBE ? wrapRunnable(command) : WORKER_PROBE;

        maybeLogForLargeQueueSize();

        try {
            super.execute(wrappedRunnable);
        } catch (Exception e) {
            if (wrappedRunnable instanceof AbstractRunnable abstractRunnable) {
                try {
                    // If we are an abstract runnable we can handle the exception
                    // directly and don't need to rethrow it, but we log and assert
                    // any unexpected exception first.
                    if (e instanceof EsRejectedExecutionException == false) {
                        logException(abstractRunnable, e);
                    }
                    abstractRunnable.onRejection(e);
                } finally {
                    abstractRunnable.onAfter();
                }
            } else {
                throw e;
            }
        }
    }

    /**
     * Logs local hot threads if the queue has stayed at or above the configured size threshold for at least the configured
     * duration, and no hot threads have been logged within the configured interval. Does nothing when hot-threads logging is
     * disabled.
     */
    private void maybeLogForLargeQueueSize() {
        if (hotThreadsOnLargeQueueConfig.isEnabled() == false) {
            return;
        }

        final int queueSize = getQueue().size();
        // Compare queueSize + 1 against the threshold. The task being submitted is most likely about to be queued as well, so
        // tracking starts one task before the queue actually reaches the threshold. Nothing is logged right away, because the
        // queue must also stay large for at least the configured duration.
        if (queueSize + 1 >= hotThreadsOnLargeQueueConfig.sizeThreshold()) {
            final long startTime = startTimeMillisOfLargeQueue;
            final long now = currentTimeMillisSupplier.getAsLong();
            if (startTime == NOT_TRACKED_TIME) {
                startTimeMillisOfLargeQueue = now;
                return;
            }
            final long duration = now - startTime;
            if (duration >= hotThreadsOnLargeQueueConfig.durationThresholdInMillis()) {
                final var lastLoggingTime = lastLoggingTimeMillisForHotThreads.get();
                // The CAS ensures that only one of several concurrent submitters logs for a given interval.
                if (now - lastLoggingTime >= hotThreadsOnLargeQueueConfig.intervalInMillis()
                    && lastLoggingTimeMillisForHotThreads.compareAndSet(lastLoggingTime, now)) {
                    logger.info("start logging hot-threads for large queue size [{}] on [{}] executor", queueSize, name);
                    HotThreads.logLocalHotThreads(
                        logger,
                        Level.INFO,
                        "ThreadPoolExecutor ["
                            + name
                            + "] queue size ["
                            + queueSize
                            + "] has been over threshold for ["
                            + TimeValue.timeValueMillis(duration)
                            + "]",
                        ReferenceDocs.LOGGING
                    );
                }
            }
        } else {
            // The queue dropped below the threshold: stop tracking until it becomes large again.
            startTimeMillisOfLargeQueue = NOT_TRACKED_TIME;
        }
    }

    // package private for testing
    EsExecutors.HotThreadsOnLargeQueueConfig getHotThreadsOnLargeQueueConfig() {
        return hotThreadsOnLargeQueueConfig;
    }

    // package private for testing
    long getStartTimeMillisOfLargeQueue() {
        return startTimeMillisOfLargeQueue;
    }

    /**
     * Logs an unexpected (non-{@link EsRejectedExecutionException}) exception thrown while submitting {@code r}, then fails an
     * assertion so that such exceptions surface in tests.
     */
    // package-visible for testing
    void logException(AbstractRunnable r, Exception e) {
        logger.error(() -> format("[%s] unexpected exception when submitting task [%s] for execution", name, r), e);
        assert false : "executor throws an exception (not a rejected execution exception) before the task has been submitted " + e;
    }

    /**
     * Rethrows errors from the completed (unwrapped) task via {@code EsExecutors.rethrowErrors}, and asserts that the worker
     * thread's context has been restored to the default before the thread returns to the pool.
     */
    @Override
    protected void afterExecute(Runnable r, Throwable t) {
        super.afterExecute(r, t);
        EsExecutors.rethrowErrors(unwrap(r));
        assert assertDefaultContext(r);
    }

    // Always returns true so that it can be used as "assert assertDefaultContext(r)"; the real check is the inner assert.
    private boolean assertDefaultContext(Runnable r) {
        assert contextHolder.isDefaultContext()
            : "the thread context is not the default context and the thread ["
                + Thread.currentThread().getName()
                + "] is being returned to the pool after executing ["
                + r
                + "]";
        return true;
    }

    /**
     * Returns a stream of all pending tasks. This is similar to {@link #getQueue()} but will expose the originally submitted
     * {@link Runnable} instances rather than potentially wrapped ones.
     */
    public Stream<Runnable> getTasks() {
        return this.getQueue().stream().map(this::unwrap);
    }

    /**
     * Describes this executor: its class name, its name, its queue capacity (when the queue is a {@code SizeBlockingQueue}), any
     * subclass details from {@link #appendThreadPoolExecutorDetails(StringBuilder)}, and {@link ThreadPoolExecutor#toString()}.
     */
    @Override
    public final String toString() {
        StringBuilder b = new StringBuilder();
        b.append(getClass().getSimpleName()).append('[');
        b.append("name = ").append(name).append(", ");
        if (getQueue() instanceof SizeBlockingQueue<?> queue) {
            b.append("queue capacity = ").append(queue.capacity()).append(", ");
        }
        appendThreadPoolExecutorDetails(b);
        /*
         * ThreadPoolExecutor has some nice information in its toString but we
         * can't get at it easily without just getting the toString.
         */
        b.append(super.toString()).append(']');
        return b.toString();
    }

    /**
     * Logs the removal at trace level, then delegates to {@link ThreadPoolExecutor#remove(Runnable)}.
     */
    @Override
    public boolean remove(Runnable task) {
        logger.trace(() -> "task is removed " + task);
        return super.remove(task);
    }

    /**
     * Append details about this thread pool to the specified {@link StringBuilder}. All details should be appended as key/value pairs in
     * the form "%s = %s, "
     *
     * @param sb the {@link StringBuilder} to append to
     */
    protected void appendThreadPoolExecutorDetails(final StringBuilder sb) {}

    /**
     * Wraps a task before it is submitted. By default, this delegates to {@code ThreadContext#preserveContext} on this executor's
     * thread context.
     */
    protected Runnable wrapRunnable(Runnable command) {
        return contextHolder.preserveContext(command);
    }

    /**
     * Returns the originally submitted task for a possibly wrapped {@code runnable}, via {@code ThreadContext.unwrap}.
     */
    protected Runnable unwrap(Runnable runnable) {
        return ThreadContext.unwrap(runnable);
    }
}
