/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the "Elastic License
 * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
 * Public License v 1"; you may not use this file except in compliance with, at
 * your election, the "Elastic License 2.0", the "GNU Affero General Public
 * License v3.0 only", or the "Server Side Public License, v 1".
 */

package org.elasticsearch.search.vectors;

import org.apache.lucene.search.Query;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.index.query.AbstractQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
 * Exact knn query builder. Will iterate and score all documents that have the provided knn field in the index.
 * Useful in inner hits scoring scenarios.
 * <p>
 * The builder carries the query vector, the target field name and an optional similarity value. All three are
 * serialized over the wire and take part in equality and hashing. The builder performs no rewriting.
 */
public class ExactKnnQueryBuilder extends AbstractQueryBuilder<ExactKnnQueryBuilder> {
    public static final String NAME = "exact_knn";
    /** Name of the field the query is executed against; must map to a dense vector field at query time. */
    private final String field;
    /** The query vector; may be {@code null} (it is serialized as an optional writeable). */
    private final VectorData query;
    /** Optional similarity value forwarded to the field type when building the Lucene query; may be {@code null}. */
    private final Float vectorSimilarity;

    /**
     * Creates a query builder.
     *
     * @param query the query vector
     * @param field the field that was used for the kNN query
     * @param vectorSimilarity optional similarity value handed to the field type's {@code createExactKnnQuery};
     *                         may be {@code null}
     */
    public ExactKnnQueryBuilder(VectorData query, String field, Float vectorSimilarity) {
        this.query = query;
        this.field = field;
        this.vectorSimilarity = vectorSimilarity;
    }

    /**
     * Deserializes a builder from the transport stream.
     *
     * @param in the stream to read from; the read order here must mirror {@link #doWriteTo(StreamOutput)}
     * @throws IOException if reading from the stream fails
     */
    public ExactKnnQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.query = in.readOptionalWriteable(VectorData::new);
        this.field = in.readString();
        this.vectorSimilarity = in.readOptionalFloat();
    }

    /** @return the name of the field this query targets */
    String getField() {
        return field;
    }

    /** @return the query vector, possibly {@code null} */
    VectorData getQuery() {
        return query;
    }

    /** @return the optional similarity value, possibly {@code null} */
    Float vectorSimilarity() {
        return vectorSimilarity;
    }

    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * Writes the builder-specific state. The order (query, field, similarity) must match the
     * {@link #ExactKnnQueryBuilder(StreamInput)} constructor.
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeOptionalWriteable(query);
        out.writeString(field);
        out.writeOptionalFloat(vectorSimilarity);
    }

    /**
     * Renders the builder as an {@code exact_knn} object. The {@code similarity} key is emitted only when a value is set.
     */
    @Override
    protected void doXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject(NAME);
        builder.field("query", query);
        builder.field("field", field);
        if (vectorSimilarity != null) {
            builder.field("similarity", vectorSimilarity);
        }
        boostAndQueryNameToXContent(builder);
        builder.endObject();
    }

    /**
     * Builds the Lucene query by delegating to the dense vector field type.
     *
     * @throws IllegalArgumentException if the field is not mapped, or is mapped to something other than a dense vector
     */
    @Override
    protected Query doToQuery(SearchExecutionContext context) throws IOException {
        final MappedFieldType fieldType = context.getFieldType(field);
        if (fieldType == null) {
            throw new IllegalArgumentException("field [" + field + "] does not exist in the mapping");
        }
        if (fieldType instanceof DenseVectorFieldMapper.DenseVectorFieldType == false) {
            throw new IllegalArgumentException(
                "[" + NAME + "] queries are only supported on [" + DenseVectorFieldMapper.CONTENT_TYPE + "] fields"
            );
        }
        final DenseVectorFieldMapper.DenseVectorFieldType vectorFieldType = (DenseVectorFieldMapper.DenseVectorFieldType) fieldType;
        return vectorFieldType.createExactKnnQuery(query, vectorSimilarity);
    }

    @Override
    protected boolean doEquals(ExactKnnQueryBuilder other) {
        // All three fields are compared null-safely. The public constructor does not reject a null field,
        // so field.equals(...) could otherwise throw.
        return Objects.equals(field, other.field)
            && Objects.equals(query, other.query)
            && Objects.equals(vectorSimilarity, other.vectorSimilarity);
    }

    @Override
    protected int doHashCode() {
        // Objects.hash already treats null elements as 0. It yields the same value as the former
        // Objects.hash(field, Objects.hashCode(query), vectorSimilarity).
        return Objects.hash(field, query, vectorSimilarity);
    }

    /** Nothing to rewrite: the builder is already in its final form. */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
        return this;
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersion.minimumCompatible();
    }
}
