/*
 * Decompiled with CFR 0.152.
 */
package org.apache.cassandra.index.sai;

import io.github.jbellis.jvector.util.Bits;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeSet;
import org.apache.cassandra.db.ReadCommand;
import org.apache.cassandra.index.sai.disk.PrimaryKeyMap;
import org.apache.cassandra.index.sai.disk.v1.segment.SegmentMetadata;
import org.apache.cassandra.index.sai.disk.v1.vector.DiskAnn;
import org.apache.cassandra.index.sai.disk.v1.vector.OnDiskOrdinalsMap;
import org.apache.cassandra.index.sai.disk.v1.vector.OnHeapGraph;
import org.apache.cassandra.index.sai.utils.PrimaryKey;

public class VectorQueryContext {
    private final int limit;
    private TreeSet<PrimaryKey> shadowedPrimaryKeys;

    public VectorQueryContext(ReadCommand readCommand) {
        this.limit = readCommand.limits().count();
    }

    public int limit() {
        return this.limit;
    }

    public void recordShadowedPrimaryKeys(Set<PrimaryKey> keys) {
        if (this.shadowedPrimaryKeys == null) {
            this.shadowedPrimaryKeys = new TreeSet();
        }
        this.shadowedPrimaryKeys.addAll(keys);
    }

    public boolean shouldInclude(long sstableRowId, PrimaryKeyMap primaryKeyMap) {
        return this.shadowedPrimaryKeys == null || !this.shadowedPrimaryKeys.contains(primaryKeyMap.primaryKeyFromRowId(sstableRowId));
    }

    public boolean shouldInclude(PrimaryKey pk) {
        return this.shadowedPrimaryKeys == null || !this.shadowedPrimaryKeys.contains(pk);
    }

    public boolean containsShadowedPrimaryKey(PrimaryKey primaryKey) {
        return this.shadowedPrimaryKeys != null && this.shadowedPrimaryKeys.contains(primaryKey);
    }

    public NavigableSet<PrimaryKey> getShadowedPrimaryKeys() {
        if (this.shadowedPrimaryKeys == null) {
            return Collections.emptyNavigableSet();
        }
        return this.shadowedPrimaryKeys;
    }

    public Bits bitsetForShadowedPrimaryKeys(OnHeapGraph<PrimaryKey> graph) {
        if (this.shadowedPrimaryKeys == null) {
            return null;
        }
        return new IgnoredKeysBits(graph, this.shadowedPrimaryKeys);
    }

    public Bits bitsetForShadowedPrimaryKeys(SegmentMetadata metadata, PrimaryKeyMap primaryKeyMap, DiskAnn graph) throws IOException {
        HashSet<Integer> ignoredOrdinals = null;
        try (OnDiskOrdinalsMap.OrdinalsView ordinalsView = graph.getOrdinalsView();){
            for (PrimaryKey primaryKey : this.getShadowedPrimaryKeys()) {
                int segmentRowId;
                long sstableRowId;
                if (primaryKey.compareTo(metadata.minKey) < 0 || primaryKey.compareTo(metadata.maxKey) > 0 || (sstableRowId = primaryKeyMap.rowIdFromPrimaryKey(primaryKey)) == Long.MAX_VALUE || (segmentRowId = Math.toIntExact(sstableRowId - metadata.rowIdOffset)) < 0) continue;
                if ((long)segmentRowId > metadata.maxSSTableRowId) {
                    break;
                }
                int ordinal = ordinalsView.getOrdinalForRowId(segmentRowId);
                if (ordinal < 0) continue;
                if (ignoredOrdinals == null) {
                    ignoredOrdinals = new HashSet<Integer>();
                }
                ignoredOrdinals.add(ordinal);
            }
        }
        if (ignoredOrdinals == null) {
            return null;
        }
        return new IgnoringBits(ignoredOrdinals, metadata);
    }

    private static class IgnoredKeysBits
    implements Bits {
        private final OnHeapGraph<PrimaryKey> graph;
        private final NavigableSet<PrimaryKey> ignored;

        public IgnoredKeysBits(OnHeapGraph<PrimaryKey> graph, NavigableSet<PrimaryKey> ignored) {
            this.graph = graph;
            this.ignored = ignored;
        }

        public boolean get(int ordinal) {
            Collection<PrimaryKey> keys = this.graph.keysFromOrdinal(ordinal);
            return keys.stream().anyMatch(k -> !this.ignored.contains(k));
        }

        public int length() {
            return this.graph.size();
        }
    }

    private static class IgnoringBits
    implements Bits {
        private final Set<Integer> ignoredOrdinals;
        private final int length;

        public IgnoringBits(Set<Integer> ignoredOrdinals, SegmentMetadata metadata) {
            this.ignoredOrdinals = ignoredOrdinals;
            this.length = 1 + Math.toIntExact(metadata.maxSSTableRowId - metadata.rowIdOffset);
        }

        public boolean get(int index) {
            return !this.ignoredOrdinals.contains(index);
        }

        public int length() {
            return this.length;
        }
    }
}

