/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.indices;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.LRUQueryCache;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCache;
import org.apache.lucene.search.QueryCachingPolicy;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.lucene.ShardCoreKeyMap;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Predicates;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.index.cache.query.QueryCacheStats;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.ShardId;

import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Predicate;
import java.util.function.Supplier;

/**
 * Node-level query cache backed by a Lucene {@link LRUQueryCache}. Cached entries are keyed by segment core key; this class maps those
 * core keys back to the {@link ShardId} they belong to (via {@link ShardCoreKeyMap}) so that cache statistics can be reported per shard.
 * RAM used by the cached query keys themselves is tracked separately as "shared" RAM and apportioned across shards on demand.
 */
public class IndicesQueryCache implements QueryCache, Closeable {

    private static final Logger logger = LogManager.getLogger(IndicesQueryCache.class);

    public static final Setting<ByteSizeValue> INDICES_CACHE_QUERY_SIZE_SETTING = Setting.memorySizeSetting(
        "indices.queries.cache.size",
        "10%",
        Property.NodeScope
    );
    // mostly a way to prevent queries from being the main source of memory usage
    // of the cache
    public static final Setting<Integer> INDICES_CACHE_QUERY_COUNT_SETTING = Setting.intSetting(
        "indices.queries.cache.count",
        10_000,
        1,
        Property.NodeScope
    );
    // enables caching on all segments instead of only the larger ones, for testing only
    public static final Setting<Boolean> INDICES_QUERIES_CACHE_ALL_SEGMENTS_SETTING = Setting.boolSetting(
        "indices.queries.cache.all_segments",
        false,
        Property.NodeScope
    );

    private final LRUQueryCache cache;
    private final ShardCoreKeyMap shardKeyMap = new ShardCoreKeyMap();
    private final Map<ShardId, Stats> shardStats = new ConcurrentHashMap<>();
    // RAM used by cached query keys, maintained by onQueryCache/onQueryEviction; not attributed to any single shard
    private volatile long sharedRamBytesUsed;

    // This is a hack for the fact that the close listener for the
    // ShardCoreKeyMap will be called before onDocIdSetEviction
    // See onDocIdSetEviction for more info
    private final Map<Object, StatsAndCount> stats2 = Collections.synchronizedMap(new IdentityHashMap<>());

    /**
     * Calculates a map of {@link ShardId} to {@link Long} which contains the calculated share of the {@link IndicesQueryCache} shared ram
     * size for a given shard (that is, the sum of all the longs is the size of the indices query cache). Since many shards will not
     * participate in the cache, shards whose calculated share is zero will not be contained in the map at all. As a consequence, the
     * correct pattern for using the returned map will be via {@link Map#getOrDefault(Object, Object)} with a {@code defaultValue} of
     * {@code 0L}.
     * @param indicesService the service whose shards and query cache are inspected
     * @return an unmodifiable map from {@link ShardId} to the calculated share of the query cache's shared RAM size for each shard,
     *         omitting shards with a zero share
     */
    public static Map<ShardId, Long> getSharedRamSizeForAllShards(IndicesService indicesService) {
        // The lookup does not depend on the shard, so do it once instead of once per shard.
        final IndicesQueryCache queryCache = indicesService.getIndicesQueryCache();
        if (queryCache == null) {
            // Without a query cache every share is zero, and zero shares are never stored in the map.
            return Collections.emptyMap();
        }
        final CacheTotals cacheTotals = getCacheTotalsForAllShards(indicesService);
        final Map<ShardId, Long> shardIdToSharedRam = new HashMap<>();
        for (IndexService indexService : indicesService) {
            for (IndexShard indexShard : indexService) {
                final ShardId shardId = indexShard.shardId();
                final long sharedRam = queryCache.getSharedRamSizeForShard(shardId, cacheTotals);
                // as a size optimization, only store non-zero values in the map
                if (sharedRam > 0L) {
                    shardIdToSharedRam.put(shardId, sharedRam);
                }
            }
        }
        return Collections.unmodifiableMap(shardIdToSharedRam);
    }

    /**
     * Returns the number of doc id sets currently cached for the given shard.
     * @param shardId the shard to look up
     * @return the shard's current cache size, or {@code 0} if no stats have been recorded for it
     */
    public long getCacheSizeForShard(ShardId shardId) {
        Stats stats = shardStats.get(shardId);
        return stats != null ? stats.cacheSize : 0L;
    }

    /**
     * Returns the RAM, in bytes, used by cached query keys. This RAM is not owned by any single shard; see
     * {@link #getSharedRamSizeForShard(ShardId, CacheTotals)} for how it is apportioned.
     */
    public long getSharedRamBytesUsed() {
        return sharedRamBytesUsed;
    }

    public IndicesQueryCache(Settings settings) {
        final ByteSizeValue size = INDICES_CACHE_QUERY_SIZE_SETTING.get(settings);
        final int count = INDICES_CACHE_QUERY_COUNT_SETTING.get(settings);
        logger.debug("using [node] query cache with size [{}] max filter count [{}]", size, count);
        if (INDICES_QUERIES_CACHE_ALL_SEGMENTS_SETTING.get(settings)) {
            // Use the default skip_caching_factor (i.e., 10f) in Lucene
            cache = new ElasticsearchLRUQueryCache(count, size.getBytes(), Predicates.always(), 10f);
        } else {
            cache = new ElasticsearchLRUQueryCache(count, size.getBytes());
        }
        sharedRamBytesUsed = 0;
    }

    private static QueryCacheStats toQueryCacheStatsSafe(@Nullable Stats stats) {
        return stats == null ? new QueryCacheStats() : stats.toQueryCacheStats();
    }

    /**
     * Computes the total cache size in bytes, and the total shard count in the cache for all shards.
     * @param indicesService the IndicesService instance to retrieve cache information from
     * @return A CacheTotals object containing the computed total number of items in the cache and the number of shards seen in the cache
     */
    public static CacheTotals getCacheTotalsForAllShards(IndicesService indicesService) {
        IndicesQueryCache queryCache = indicesService.getIndicesQueryCache();
        boolean hasQueryCache = queryCache != null;
        long totalItemsInCache = 0L;
        int shardCount = 0;
        for (final IndexService indexService : indicesService) {
            for (final IndexShard indexShard : indexService) {
                final var shardId = indexShard.shardId();
                long cacheSize = hasQueryCache ? queryCache.getCacheSizeForShard(shardId) : 0L;
                shardCount++;
                assert cacheSize >= 0 : "Unexpected cache size of " + cacheSize + " for shard " + shardId;
                totalItemsInCache += cacheSize;
            }
        }
        return new CacheTotals(totalItemsInCache, shardCount);
    }

    /**
     * This method computes the shared RAM size in bytes for the given shard.
     * @param shardId The shard to compute the shared RAM size for.
     * @param cacheTotals Shard totals computed in {@link #getCacheTotalsForAllShards(IndicesService)}.
     * @return the shared RAM size in bytes allocated to the given shard, or 0 if unavailable
     */
    public long getSharedRamSizeForShard(ShardId shardId, CacheTotals cacheTotals) {
        // read the volatile field once so that every step below works with the same value
        final long sharedRam = getSharedRamBytesUsed();
        if (sharedRam == 0L) {
            return 0L;
        }

        int shardCount = cacheTotals.shardCount();
        if (shardCount == 0) {
            // Sometimes it's not possible to do this when there are no shard entries at all, which can happen as the shared ram usage can
            // extend beyond the closing of all shards.
            return 0L;
        }
        /*
         * We have some shared ram usage that we try to distribute proportionally to the number of segment-requests in the cache for each
         * shard.
         */
        long totalItemsInCache = cacheTotals.totalItemsInCache();
        long itemsInCacheForShard = getCacheSizeForShard(shardId);
        final long additionalRamBytesUsed;
        if (totalItemsInCache == 0) {
            // all shards have zero cache footprint, so we apportion the size of the shared bytes equally across all shards
            additionalRamBytesUsed = Math.round((double) sharedRam / shardCount);
        } else {
            /*
             * Some shards have nonzero cache footprint, so we apportion the size of the shared bytes proportionally to the number of
             * segment-requests in the cache for this shard (the number and size of documents associated with those requests is irrelevant
             * for this calculation).
             * Note that this was a somewhat arbitrary decision. Calculating it by number of documents might have been better. Calculating
             * it by number of documents weighted by size would also be good, but possibly more expensive. But the decision to attribute
             * memory proportionally to the number of segment-requests was made a long time ago, and we're sticking with that here for the
             * sake of consistency and backwards compatibility.
             */
            additionalRamBytesUsed = Math.round((double) sharedRam * itemsInCacheForShard / totalItemsInCache);
        }
        assert additionalRamBytesUsed >= 0L : additionalRamBytesUsed;
        return additionalRamBytesUsed;
    }

    public record CacheTotals(long totalItemsInCache, int shardCount) {}

    /**
     * Get usage statistics for the given shard.
     * @param shard the shard to report on
     * @param precomputedSharedRamBytesUsed supplies this shard's share of the shared RAM, which is added to the shard's own RAM usage
     * @return the shard's stats, or empty stats plus the shared share if nothing was recorded for the shard
     */
    public QueryCacheStats getStats(ShardId shard, Supplier<Long> precomputedSharedRamBytesUsed) {
        final QueryCacheStats queryCacheStats = toQueryCacheStatsSafe(shardStats.get(shard));
        queryCacheStats.addRamBytesUsed(precomputedSharedRamBytesUsed.get());
        return queryCacheStats;
    }

    /**
     * Delegates caching to the underlying {@link LRUQueryCache} and wraps the result so that every leaf reader it sees is registered in
     * the {@link ShardCoreKeyMap}. Any existing wrappers are stripped first so that weights are never wrapped twice.
     */
    @Override
    public Weight doCache(Weight weight, QueryCachingPolicy policy) {
        while (weight instanceof CachingWeightWrapper wrapper) {
            weight = wrapper.in;
        }
        final Weight in = cache.doCache(weight, policy);
        // We wrap the weight to track the readers it sees and map them with
        // the shards they belong to
        return new CachingWeightWrapper(in);
    }

    private class CachingWeightWrapper extends Weight {

        private final Weight in;

        protected CachingWeightWrapper(Weight in) {
            super(in.getQuery());
            this.in = in;
        }

        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            shardKeyMap.add(context.reader());
            return in.explain(context, doc);
        }

        @Override
        public int count(LeafReaderContext context) throws IOException {
            shardKeyMap.add(context.reader());
            return in.count(context);
        }

        @Override
        public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
            shardKeyMap.add(context.reader());
            return in.scorerSupplier(context);
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            return in.isCacheable(ctx);
        }
    }

    /**
     * Clear all entries that belong to the given index.
     * @param index the name of the index whose segment core keys should be evicted
     */
    public void clearIndex(String index) {
        final Set<Object> coreCacheKeys = shardKeyMap.getCoreKeysForIndex(index);
        for (Object coreKey : coreCacheKeys) {
            cache.clearCoreCacheKey(coreKey);
        }

        // This cache stores two things: filters, and doc id sets. Calling
        // clear only removes the doc id sets, but if we reach the situation
        // that the cache does not contain any DocIdSet anymore, then it
        // probably means that the user wanted to remove everything.
        if (cache.getCacheSize() == 0) {
            cache.clear();
        }
    }

    /**
     * Clears the underlying cache. Assertions check that all shards have already been closed, so no per-shard state remains.
     */
    @Override
    public void close() {
        assert shardKeyMap.size() == 0 : shardKeyMap.size();
        assert shardStats.isEmpty() : shardStats.keySet();
        assert stats2.isEmpty() : stats2;

        // This cache stores two things: filters, and doc id sets. At this time
        // we only know that there are no more doc id sets, but we still track
        // recently used queries, which we want to reclaim.
        cache.clear();
    }

    /** Mutable per-shard counters; updated from the {@link ElasticsearchLRUQueryCache} callbacks. */
    private static class Stats {

        final ShardId shardId;
        volatile long ramBytesUsed;
        volatile long hitCount;
        volatile long missCount;
        volatile long cacheCount;
        volatile long cacheSize;

        Stats(ShardId shardId) {
            this.shardId = shardId;
        }

        QueryCacheStats toQueryCacheStats() {
            return new QueryCacheStats(ramBytesUsed, hitCount, missCount, cacheCount, cacheSize);
        }

        @Override
        public String toString() {
            return "{shardId="
                + shardId
                + ", ramBytesUsed="
                + ramBytesUsed
                + ", hitCount="
                + hitCount
                + ", missCount="
                + missCount
                + ", cacheCount="
                + cacheCount
                + ", cacheSize="
                + cacheSize
                + "}";
        }
    }

    /** Associates a segment core key with its shard's {@link Stats} and the number of entries still cached for that segment. */
    private static class StatsAndCount {
        volatile int count;
        final Stats stats;

        StatsAndCount(Stats stats) {
            this.stats = stats;
            this.count = 0;
        }

        @Override
        public String toString() {
            return "{stats=" + stats + " ,count=" + count + "}";
        }
    }

    private static boolean empty(Stats stats) {
        if (stats == null) {
            return true;
        }
        return stats.cacheSize == 0 && stats.ramBytesUsed == 0;
    }

    /**
     * Drops the statistics of a closed shard. The shard is expected (and asserted) to have no remaining cached entries or RAM usage.
     * @param shardId the shard that was closed
     */
    public void onClose(ShardId shardId) {
        assert empty(shardStats.get(shardId));
        shardStats.remove(shardId);
    }

    private class ElasticsearchLRUQueryCache extends LRUQueryCache {

        ElasticsearchLRUQueryCache(int maxSize, long maxRamBytesUsed, Predicate<LeafReaderContext> leavesToCache, float skipFactor) {
            super(maxSize, maxRamBytesUsed, leavesToCache, skipFactor);
        }

        ElasticsearchLRUQueryCache(int maxSize, long maxRamBytesUsed) {
            super(maxSize, maxRamBytesUsed);
        }

        private Stats getStats(Object coreKey) {
            final ShardId shardId = shardKeyMap.getShardId(coreKey);
            if (shardId == null) {
                return null;
            }
            return shardStats.get(shardId);
        }

        private Stats getOrCreateStats(Object coreKey) {
            return shardStats.computeIfAbsent(shardKeyMap.getShardId(coreKey), Stats::new);
        }

        // It's ok to not protect these callbacks by a lock since it is
        // done in LRUQueryCache
        @Override
        protected void onClear() {
            super.onClear();
            for (Stats stats : shardStats.values()) {
                // don't throw away hit/miss
                stats.cacheSize = 0;
                stats.ramBytesUsed = 0;
            }
            stats2.clear();
            sharedRamBytesUsed = 0;
        }

        @Override
        protected void onQueryCache(Query filter, long ramBytesUsed) {
            super.onQueryCache(filter, ramBytesUsed);
            sharedRamBytesUsed += ramBytesUsed;
        }

        @Override
        protected void onQueryEviction(Query filter, long ramBytesUsed) {
            super.onQueryEviction(filter, ramBytesUsed);
            sharedRamBytesUsed -= ramBytesUsed;
        }

        @Override
        protected void onDocIdSetCache(Object readerCoreKey, long ramBytesUsed) {
            super.onDocIdSetCache(readerCoreKey, ramBytesUsed);
            final Stats shardStats = getOrCreateStats(readerCoreKey);
            shardStats.cacheSize += 1;
            shardStats.cacheCount += 1;
            shardStats.ramBytesUsed += ramBytesUsed;

            StatsAndCount statsAndCount = stats2.computeIfAbsent(readerCoreKey, ignored -> new StatsAndCount(shardStats));
            statsAndCount.count += 1;
        }

        @Override
        protected void onDocIdSetEviction(Object readerCoreKey, int numEntries, long sumRamBytesUsed) {
            super.onDocIdSetEviction(readerCoreKey, numEntries, sumRamBytesUsed);
            // onDocIdSetEviction might sometimes be called with a number
            // of entries equal to zero if the cache for the given segment
            // was already empty when the close listener was called
            if (numEntries > 0) {
                // We can't use ShardCoreKeyMap here because its core closed
                // listener is called before the listener of the cache which
                // triggers this eviction. So instead we use stats2 that
                // we only evict when nothing is cached anymore on the segment
                // instead of relying on close listeners
                final StatsAndCount statsAndCount = stats2.get(readerCoreKey);
                final Stats shardStats = statsAndCount.stats;
                shardStats.cacheSize -= numEntries;
                shardStats.ramBytesUsed -= sumRamBytesUsed;
                statsAndCount.count -= numEntries;
                if (statsAndCount.count == 0) {
                    stats2.remove(readerCoreKey);
                }
            }
        }

        @Override
        protected void onHit(Object readerCoreKey, Query filter) {
            super.onHit(readerCoreKey, filter);
            // NOTE(review): assumes stats already exist for this core key (created by onMiss/onDocIdSetCache); getStats may return null
            final Stats shardStats = getStats(readerCoreKey);
            shardStats.hitCount += 1;
        }

        @Override
        protected void onMiss(Object readerCoreKey, Query filter) {
            super.onMiss(readerCoreKey, filter);
            final Stats shardStats = getOrCreateStats(readerCoreKey);
            shardStats.missCount += 1;
        }
    }
}
