diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
index 2cff2de4a..a43a31ed7 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java
@@ -16,6 +16,7 @@
 package io.github.jbellis.jvector.graph;
 
+import io.github.jbellis.jvector.annotations.Experimental;
 import io.github.jbellis.jvector.annotations.VisibleForTesting;
 import io.github.jbellis.jvector.disk.RandomAccessReader;
 import io.github.jbellis.jvector.graph.ImmutableGraphIndex.NodeAtLevel;
@@ -325,6 +326,7 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
         this.parallelExecutor = parallelExecutor;
         this.graph = new OnHeapGraphIndex(maxDegrees, neighborOverflow, new VamanaDiversityProvider(scoreProvider, alpha));
+
         this.searchers = ExplicitThreadLocal.withInitial(() -> {
             var gs = new GraphSearcher(graph);
             gs.usePruning(false);
@@ -338,6 +340,58 @@ public GraphIndexBuilder(BuildScoreProvider scoreProvider,
         this.rng = new Random(0);
     }
 
+    /**
+     * Create a builder around an existing {@link MutableGraphIndex}. This is useful when a graph has just been
+     * loaded from disk (for example from an {@link io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex}) and
+     * copied into an {@link OnHeapGraphIndex}: the builder can start mutating it directly, with minimal overhead,
+     * instead of recreating the mutable {@link OnHeapGraphIndex} that a new GraphIndexBuilder would otherwise construct.
+     *
+     * @param buildScoreProvider the provider responsible for calculating build scores.
+     * @param dimension the dimension of the vectors being indexed.
+     * @param mutableGraphIndex the mutable graph index to continue building from.
+     * @param beamWidth the width of the beam used during the graph building process.
+     * @param neighborOverflow the factor determining how many additional neighbors are allowed beyond the configured limit.
+     * @param alpha the weight factor for balancing score computations.
+     * @param addHierarchy whether to add hierarchical structures while building the graph.
+     * @param refineFinalGraph whether to perform a refinement step on the final graph structure.
+     * @param simdExecutor the ForkJoinPool executor used for SIMD tasks during graph building.
+     * @param parallelExecutor the ForkJoinPool executor used for general parallelization during graph building.
+ */ + private GraphIndexBuilder(BuildScoreProvider buildScoreProvider, int dimension, MutableGraphIndex mutableGraphIndex, int beamWidth, float neighborOverflow, float alpha, boolean addHierarchy, boolean refineFinalGraph, ForkJoinPool simdExecutor, ForkJoinPool parallelExecutor) { + if (beamWidth <= 0) { + throw new IllegalArgumentException("beamWidth must be positive"); + } + if (neighborOverflow < 1.0f) { + throw new IllegalArgumentException("neighborOverflow must be >= 1.0"); + } + if (alpha <= 0) { + throw new IllegalArgumentException("alpha must be positive"); + } + + this.scoreProvider = buildScoreProvider; + this.neighborOverflow = neighborOverflow; + this.dimension = dimension; + this.alpha = alpha; + this.addHierarchy = addHierarchy; + this.refineFinalGraph = refineFinalGraph; + this.beamWidth = beamWidth; + this.simdExecutor = simdExecutor; + this.parallelExecutor = parallelExecutor; + + this.graph = mutableGraphIndex; + + this.searchers = ExplicitThreadLocal.withInitial(() -> { + var gs = new GraphSearcher(graph); + gs.usePruning(false); + return gs; + }); + + // in scratch, we store candidates in reverse order: worse candidates are first + this.naturalScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + this.concurrentScratch = ExplicitThreadLocal.withInitial(() -> new NodeArray(max(beamWidth, graph.maxDegree() + 1))); + + this.rng = new Random(0); + } + // used by Cassandra when it fine-tunes the PQ codebook public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvider newProvider) { var newBuilder = new GraphIndexBuilder(newProvider, @@ -450,13 +504,13 @@ public void cleanup() { // clean up overflowed neighbor lists parallelExecutor.submit(() -> { IntStream.range(0, graph.getIdUpperBound()).parallel().forEach(id -> { - for (int layer = 0; layer <= graph.getMaxLevel(); layer++) { + for (int level = 0; level <= graph.getMaxLevel(); level++) { graph.enforceDegree(id); } }); }).join(); - graph.allMutationsCompleted(); + graph.setAllMutationsCompleted(); } private void improveConnections(int node) { @@ -825,6 +879,9 @@ public void load(RandomAccessReader in) throws IOException { loadV3(in, size); } else { version = in.readInt(); + if (version != 4) { + throw new IOException("Unsupported version: " + version); + } loadV4(in); } } @@ -836,15 +893,18 @@ private void loadV4(RandomAccessReader in) throws IOException { } int layerCount = in.readInt(); - int entryNode = in.readInt(); var layerDegrees = new ArrayList(layerCount); + for (int level = 0; level < layerCount; level++) { + layerDegrees.add(in.readInt()); + } + + int entryNode = in.readInt(); Map nodeLevelMap = new HashMap<>(); // Read layer info for (int level = 0; level < layerCount; level++) { int layerSize = in.readInt(); - layerDegrees.add(in.readInt()); for (int i = 0; i < layerSize; i++) { int nodeId = in.readInt(); int nNeighbors = in.readInt(); @@ -860,6 +920,7 @@ private void loadV4(RandomAccessReader in) throws IOException { var ca = new NodeArray(nNeighbors); for (int j = 0; j < nNeighbors; j++) { int neighbor = in.readInt(); + float score = in.readFloat(); ca.addInOrder(neighbor, sf.similarityTo(neighbor)); } graph.connectNode(level, nodeId, ca); @@ -909,4 +970,61 @@ private void loadV3(RandomAccessReader in, int size) throws IOException { graph.updateEntryNode(new NodeAtLevel(0, entryNode)); graph.setDegrees(List.of(maxDegree)); } + + /** + * Convenience method to build a new graph from an existing one, with the addition of new nodes. 
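+     * <p>A minimal usage sketch (mirrors the accompanying test; {@code readerSupplier}, the sizes and the
+     * identity mapping are illustrative, not part of this API):
+     * <pre>{@code
+     * int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray();
+     * ImmutableGraphIndex merged = GraphIndexBuilder.buildAndMergeNewNodes(
+     *         readerSupplier.get(), allVectorsRavv, buildScoreProvider,
+     *         numBaseVectors, graphToRavvOrdMap, beamWidth, neighborOverflow, alpha, addHierarchy);
+     * }</pre>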
+     * <p>This is useful when we want to merge a new set of vectors into an existing graph that is already on disk.
+     *
+     * @param in a reader from which to read the on-heap graph.
+     * @param newVectors a superset RAVV containing the new vectors to be added to the graph as well as the old ones that are already in the graph.
+     * @param buildScoreProvider the provider responsible for calculating build scores.
+     * @param startingNodeOffset the offset in the newVectors RAVV where the new vectors start.
+     * @param graphToRavvOrdMap a mapping from graph node ids to ordinals in the newVectors RAVV.
+     * @param beamWidth the width of the beam used during the graph building process.
+     * @param overflowRatio the ratio of extra neighbors to allow temporarily when inserting a node.
+     * @param alpha the weight factor for balancing score computations.
+     * @param addHierarchy whether to add hierarchical structures while building the graph.
+     *
+     * @return the in-memory representation of the graph index.
+     * @throws IOException if an I/O error occurs during the graph loading or conversion process.
+     */
+    @Experimental
+    public static ImmutableGraphIndex buildAndMergeNewNodes(RandomAccessReader in,
+                                                            RandomAccessVectorValues newVectors,
+                                                            BuildScoreProvider buildScoreProvider,
+                                                            int startingNodeOffset,
+                                                            int[] graphToRavvOrdMap,
+                                                            int beamWidth,
+                                                            float overflowRatio,
+                                                            float alpha,
+                                                            boolean addHierarchy) throws IOException {
+
+        var diversityProvider = new VamanaDiversityProvider(buildScoreProvider, alpha);
+
+        try (MutableGraphIndex graph = OnHeapGraphIndex.load(in, overflowRatio, diversityProvider)) {
+
+            GraphIndexBuilder builder = new GraphIndexBuilder(
+                    buildScoreProvider,
+                    newVectors.dimension(),
+                    graph,
+                    beamWidth,
+                    overflowRatio,
+                    alpha,
+                    addHierarchy,
+                    true,
+                    PhysicalCoreExecutor.pool(),
+                    ForkJoinPool.commonPool()
+            );
+
+            var vv = newVectors.threadLocalSupplier();
+
+            // parallel graph construction over the merged vectors' ordinals
+            PhysicalCoreExecutor.pool().submit(() -> IntStream.range(startingNodeOffset, newVectors.size()).parallel().forEach(ord -> {
+                builder.addGraphNode(ord, vv.get().getVector(graphToRavvOrdMap[ord]));
+            })).join();
+
+            builder.cleanup();
+            return builder.getGraph();
+        }
+    }
 }
\ No newline at end of file
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java
index 2e88e6dd4..36ec49a16 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/MutableGraphIndex.java
@@ -166,5 +166,10 @@ interface MutableGraphIndex extends ImmutableGraphIndex {
      * Signals that all mutations have been completed and the graph will not be mutated any further.
      * Should be called by the builder after all mutations are completed (during cleanup).
      */
-    void allMutationsCompleted();
+    void setAllMutationsCompleted();
+
+    /**
+     * Returns true if all mutations have been completed. This is signaled by calling {@link #setAllMutationsCompleted()}.
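+     * <p>For example, {@code OnHeapGraphIndex#save} in this change uses it as a guard before serialization:
+     * {@code if (!graph.allMutationsCompleted()) throw new IllegalStateException(...)}.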
+ */ + boolean allMutationsCompleted(); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java index 7ddbf7897..29999bfde 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/OnHeapGraphIndex.java @@ -24,6 +24,8 @@ package io.github.jbellis.jvector.graph; +import io.github.jbellis.jvector.annotations.Experimental; +import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ConcurrentNeighborMap.Neighbors; import io.github.jbellis.jvector.graph.diversity.DiversityProvider; import io.github.jbellis.jvector.util.Accountable; @@ -37,9 +39,10 @@ import java.io.DataOutput; import java.io.IOException; -import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.NoSuchElementException; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicIntegerArray; @@ -367,10 +370,14 @@ public void setDegrees(List layerDegrees) { } @Override - public void allMutationsCompleted() { + public void setAllMutationsCompleted() { allMutationsCompleted = true; } + @Override + public boolean allMutationsCompleted() { + return allMutationsCompleted; + } /** * A concurrent View of the graph that is safe to search concurrently with updates and with other @@ -490,44 +497,101 @@ public String toString() { /** * Saves the graph to the given DataOutput for reloading into memory later */ + @Experimental @Deprecated - public void save(DataOutput out) { - if (deletedNodes.cardinality() > 0) { - throw new IllegalStateException("Cannot save a graph that has deleted nodes. Call cleanup() first"); - } - - try (var view = getView()) { - out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number - out.writeInt(4); // The version - - // Write graph-level properties. - out.writeInt(layers.size()); - assert view.entryNode().level == getMaxLevel(); - out.writeInt(view.entryNode().node); - - for (int level = 0; level < layers.size(); level++) { - out.writeInt(size(level)); - out.writeInt(getDegree(level)); - - // Save neighbors from the layer. - var baseLayer = layers.get(level); - baseLayer.forEach((nodeId, neighbors) -> { - try { - NodesIterator iterator = neighbors.iterator(); - out.writeInt(nodeId); - out.writeInt(iterator.size()); - for (int n = 0; n < iterator.size(); n++) { - out.writeInt(iterator.nextInt()); - } - assert !iterator.hasNext(); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - }); + public void save(DataOutput out) throws IOException { + if (!allMutationsCompleted()) { + throw new IllegalStateException("Cannot save a graph with pending mutations. Call cleanup() first"); + } + + out.writeInt(OnHeapGraphIndex.MAGIC); // the magic number + out.writeInt(4); // The version + + // Write graph-level properties. + out.writeInt(layers.size()); + for (int level = 0; level < layers.size(); level++) { + out.writeInt(getDegree(level)); + } + + var entryNode = entryPoint.get(); + assert entryNode.level == getMaxLevel(); + out.writeInt(entryNode.node); + + for (int level = 0; level < layers.size(); level++) { + out.writeInt(size(level)); + + // Save neighbors from the layer. 
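+            // Per-node record layout (format version 4): nodeId (int), neighbor count (int),
+            // then one (neighborId int, score float) pair per neighbor.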
+            var it = nodeStream(level).iterator();
+            while (it.hasNext()) {
+                int nodeId = it.nextInt();
+                var neighbors = layers.get(level).get(nodeId);
+                out.writeInt(nodeId);
+                out.writeInt(neighbors.size());
+
+                for (int n = 0; n < neighbors.size(); n++) {
+                    out.writeInt(neighbors.getNode(n));
+                    out.writeFloat(neighbors.getScore(n));
+                }
+            }
+        }
+    }
+
+    /**
+     * Loads a graph previously written by {@link #save} back into a mutable {@link OnHeapGraphIndex}.
+     */
+    @Experimental
+    @Deprecated
+    public static OnHeapGraphIndex load(RandomAccessReader in, double overflowRatio, DiversityProvider diversityProvider) throws IOException {
+        int magic = in.readInt(); // the magic number
+        if (magic != OnHeapGraphIndex.MAGIC) {
+            throw new IOException("Unsupported magic number: " + magic);
+        }
+
+        int version = in.readInt(); // The version
+        if (version != 4) {
+            throw new IOException("Unsupported version: " + version);
+        }
+
+        // Read graph-level properties.
+        int layerCount = in.readInt();
+        var layerDegrees = new ArrayList<Integer>(layerCount);
+        for (int level = 0; level < layerCount; level++) {
+            layerDegrees.add(in.readInt());
+        }
+
+        int entryNode = in.readInt();
+
+        var graph = new OnHeapGraphIndex(layerDegrees, overflowRatio, diversityProvider);
+
+        Map<Integer, Integer> nodeLevelMap = new HashMap<>();
+
+        for (int level = 0; level < layerCount; level++) {
+            int layerSize = in.readInt();
+
+            for (int i = 0; i < layerSize; i++) {
+                int nodeId = in.readInt();
+                int nNeighbors = in.readInt();
+
+                var ca = new NodeArray(nNeighbors);
+                for (int j = 0; j < nNeighbors; j++) {
+                    int neighbor = in.readInt();
+                    float score = in.readFloat();
+                    ca.addInOrder(neighbor, score);
+                }
+                graph.connectNode(level, nodeId, ca);
+                nodeLevelMap.put(nodeId, level);
             }
-        } catch (IOException e) {
-            throw new UncheckedIOException(e);
         }
+
+        for (var k : nodeLevelMap.keySet()) {
+            NodeAtLevel nal = new NodeAtLevel(nodeLevelMap.get(k), k);
+            graph.markComplete(nal);
+        }
+
+        graph.setDegrees(layerDegrees);
+        graph.updateEntryNode(new NodeAtLevel(graph.getMaxLevel(), entryNode));
+
+        return graph;
     }
 
     /**
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
index a597aa78f..161fb0f07 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java
@@ -225,10 +225,6 @@ public Set<FeatureId> getFeatureSet() {
         return features.keySet();
     }
 
-    public int getDimension() {
-        return dimension;
-    }
-
     @Override
     public int size(int level) {
         return layerInfo.get(level).size;
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
index b8ec5fa5f..0ffdf72eb 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java
@@ -25,6 +25,8 @@
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
 
+import java.util.stream.IntStream;
+
 /**
  * Encapsulates comparing node distances for GraphIndexBuilder.
  */
@@ -83,8 +85,18 @@ public interface BuildScoreProvider {
     /**
      * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
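+     * <p>For example (a sketch; {@code ravv} is any in-memory RandomAccessVectorValues):
+     * <pre>{@code
+     * BuildScoreProvider bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, VectorSimilarityFunction.EUCLIDEAN);
+     * }</pre>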
+     *
+     * Helper method for the special case where the mapping between graph node IDs and ravv ordinals is the identity function.
      */
     static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, VectorSimilarityFunction similarityFunction) {
+        return randomAccessScoreProvider(ravv, IntStream.range(0, ravv.size()).toArray(), similarityFunction);
+    }
+
+    /**
+     * Returns a BSP that performs exact score comparisons using the given RandomAccessVectorValues and VectorSimilarityFunction.
+     * {@code graphToRavvOrdMap} maps graph node IDs to ravv ordinals.
+     */
+    static BuildScoreProvider randomAccessScoreProvider(RandomAccessVectorValues ravv, int[] graphToRavvOrdMap, VectorSimilarityFunction similarityFunction) {
         // We need two sources of vectors in order to perform diversity check comparisons without
         // colliding. ThreadLocalSupplier makes this a no-op if the RAVV is actually un-shared.
         var vectors = ravv.threadLocalSupplier();
@@ -113,22 +125,22 @@ public VectorFloat<?> approximateCentroid() {
             @Override
             public SearchScoreProvider searchProviderFor(VectorFloat<?> vector) {
                 var vc = vectorsCopy.get();
-                return DefaultSearchScoreProvider.exact(vector, similarityFunction, vc);
+                return DefaultSearchScoreProvider.exact(vector, graphToRavvOrdMap, similarityFunction, vc);
             }
 
             @Override
             public SearchScoreProvider searchProviderFor(int node1) {
                 RandomAccessVectorValues randomAccessVectorValues = vectors.get();
-                var v = randomAccessVectorValues.getVector(node1);
+                var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
                 return searchProviderFor(v);
             }
 
             @Override
             public SearchScoreProvider diversityProviderFor(int node1) {
                 RandomAccessVectorValues randomAccessVectorValues = vectors.get();
-                var v = randomAccessVectorValues.getVector(node1);
+                var v = randomAccessVectorValues.getVector(graphToRavvOrdMap[node1]);
                 var vc = vectorsCopy.get();
-                return DefaultSearchScoreProvider.exact(v, similarityFunction, vc);
+                return DefaultSearchScoreProvider.exact(v, graphToRavvOrdMap, similarityFunction, vc);
             }
         };
     }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java
index 0754b39d7..de46762b2 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/DefaultSearchScoreProvider.java
@@ -78,4 +78,20 @@ public float similarityTo(int node2) {
         };
         return new DefaultSearchScoreProvider(sf);
     }
+
+    /**
+     * A SearchScoreProvider for a single-pass search based on exact similarity.
+     * Generally only suitable when your RandomAccessVectorValues is entirely in-memory,
+     * e.g. during construction.
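+     * <p>A usage sketch (the mapping and names below are illustrative, not part of this change):
+     * <pre>{@code
+     * int[] graphToRavvOrdMap = {2, 0, 1};  // graph node 0 -> ravv ordinal 2, etc.
+     * var ssp = DefaultSearchScoreProvider.exact(query, graphToRavvOrdMap, VectorSimilarityFunction.DOT_PRODUCT, ravv);
+     * float score = ssp.exactScoreFunction().similarityTo(0);  // compares query against ravv ordinal 2
+     * }</pre>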
+     */
+    public static DefaultSearchScoreProvider exact(VectorFloat<?> v, int[] graphToRavvOrdMap, VectorSimilarityFunction vsf, RandomAccessVectorValues ravv) {
+        // don't use ESF.reranker, we need thread safety here
+        var sf = new ScoreFunction.ExactScoreFunction() {
+            @Override
+            public float similarityTo(int node2) {
+                return vsf.compare(v, ravv.getVector(graphToRavvOrdMap[node2]));
+            }
+        };
+        return new DefaultSearchScoreProvider(sf);
+    }
 }
\ No newline at end of file
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
index f66a2c6e4..73e59b20f 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java
@@ -77,6 +77,8 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
      * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel
      * and split into chunks to avoid exceeding the maximum array size.
      *
+     * This is a helper method for the special case where the graph ordinals and the RAVV/PQVectors ordinals are the same (the identity mapping).
+     *
      * @param pq the ProductQuantization to use
      * @param vectorCount the number of vectors to encode
      * @param ravv the RandomAccessVectorValues to encode
      *
@@ -84,6 +86,21 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
      * @return the PQVectors instance
      */
     public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
+        return encodeAndBuild(pq, vectorCount, IntStream.range(0, vectorCount).toArray(), ravv, simdExecutor);
+    }
+
+    /**
+     * Build a PQVectors instance from the given RandomAccessVectorValues. The vectors are encoded in parallel
+     * and split into chunks to avoid exceeding the maximum array size.
+     *
+     * @param pq the ProductQuantization to use
+     * @param vectorCount the number of vectors to encode
+     * @param ordinalsMapping the mapping from graph/PQ ordinals to RAVV ordinals
+     * @param ravv the RandomAccessVectorValues to encode
+     * @param simdExecutor the ForkJoinPool to use for SIMD operations
+     * @return the PQVectors instance
+     */
+    public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vectorCount, int[] ordinalsMapping, RandomAccessVectorValues ravv, ForkJoinPool simdExecutor) {
         int compressedDimension = pq.compressedVectorSize();
         PQLayout layout = new PQLayout(vectorCount,compressedDimension);
         final ByteSequence<?>[] chunks = new ByteSequence<?>[layout.totalChunks];
@@ -98,13 +115,13 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect
         // The changes are concurrent, but because they are coordinated and do not overlap, we can use parallel streams
         // and then we are guaranteed safe publication because we join the thread after completion.
         var ravvCopy = ravv.threadLocalSupplier();
-        simdExecutor.submit(() -> IntStream.range(0, ravv.size())
+        simdExecutor.submit(() -> IntStream.range(0, ordinalsMapping.length)
                 .parallel()
                 .forEach(ordinal -> {
                     // Retrieve the slice and mutate it.
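+                    // 'ordinal' indexes the PQVectors/graph side; ordinalsMapping[ordinal] selects the source vector in the RAVV.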
var localRavv = ravvCopy.get(); var slice = PQVectors.get(chunks, ordinal, layout.fullChunkVectors, pq.getSubspaceCount()); - var vector = localRavv.getVector(ordinal); + var vector = localRavv.getVector(ordinalsMapping[ordinal]); if (vector != null) pq.encodeTo(vector, slice); else diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java new file mode 100644 index 000000000..c2cf9fc90 --- /dev/null +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/OnHeapGraphIndexTest.java @@ -0,0 +1,262 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph; + +import com.carrotsearch.randomizedtesting.RandomizedTest; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import io.github.jbellis.jvector.TestUtil; +import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.graph.diversity.VamanaDiversityProvider; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.util.Bits; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.apache.logging.log4j.Logger; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertEquals; + +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class OnHeapGraphIndexTest extends RandomizedTest { + private final static Logger log = org.apache.logging.log4j.LogManager.getLogger(OnHeapGraphIndexTest.class); + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int NUM_BASE_VECTORS = 100; + private static final int NUM_NEW_VECTORS = 100; + private static final int NUM_ALL_VECTORS = NUM_BASE_VECTORS + NUM_NEW_VECTORS; + private static final int DIMENSION = 16; + private static final int M = 8; + private static final int BEAM_WIDTH = 100; + private static final float ALPHA = 1.2f; + private static final float NEIGHBOR_OVERFLOW = 1.2f; + private static final boolean ADD_HIERARCHY = false; + private static final int TOP_K = 10; + + private Path testDirectory; + + private ArrayList> baseVectors; + private ArrayList> newVectors; + private ArrayList> 
allVectors; + private RandomAccessVectorValues baseVectorsRavv; + private RandomAccessVectorValues newVectorsRavv; + private RandomAccessVectorValues allVectorsRavv; + private VectorFloat queryVector; + private int[] groundTruthAllVectors; + private BuildScoreProvider baseBuildScoreProvider; + private BuildScoreProvider allBuildScoreProvider; + private ImmutableGraphIndex baseGraphIndex; + private ImmutableGraphIndex allGraphIndex; + + @Before + public void setup() throws IOException { + testDirectory = Files.createTempDirectory(this.getClass().getSimpleName()); + baseVectors = new ArrayList<>(NUM_BASE_VECTORS); + newVectors = new ArrayList<>(NUM_NEW_VECTORS); + allVectors = new ArrayList<>(NUM_ALL_VECTORS); + for (int i = 0; i < NUM_BASE_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + baseVectors.add(vector); + allVectors.add(vector); + } + for (int i = 0; i < NUM_NEW_VECTORS; i++) { + VectorFloat vector = createRandomVector(DIMENSION); + newVectors.add(vector); + allVectors.add(vector); + } + + // wrap the raw vectors in a RandomAccessVectorValues + baseVectorsRavv = new ListRandomAccessVectorValues(baseVectors, DIMENSION); + newVectorsRavv = new ListRandomAccessVectorValues(newVectors, DIMENSION); + allVectorsRavv = new ListRandomAccessVectorValues(allVectors, DIMENSION); + + queryVector = createRandomVector(DIMENSION); + groundTruthAllVectors = getGroundTruth(allVectorsRavv, queryVector, TOP_K, VectorSimilarityFunction.EUCLIDEAN); + + // score provider using the raw, in-memory vectors + baseBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(baseVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + allBuildScoreProvider = BuildScoreProvider.randomAccessScoreProvider(allVectorsRavv, VectorSimilarityFunction.EUCLIDEAN); + var baseGraphIndexBuilder = new GraphIndexBuilder(baseBuildScoreProvider, + baseVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + var allGraphIndexBuilder = new GraphIndexBuilder(allBuildScoreProvider, + allVectorsRavv.dimension(), + M, // graph degree + BEAM_WIDTH, // construction search depth + NEIGHBOR_OVERFLOW, // allow degree overflow during construction by this factor + ALPHA, // relax neighbor diversity requirement by this factor + ADD_HIERARCHY); // add the hierarchy + + baseGraphIndex = baseGraphIndexBuilder.build(baseVectorsRavv); + allGraphIndex = allGraphIndexBuilder.build(allVectorsRavv); + } + + @After + public void tearDown() { + TestUtil.deleteQuietly(testDirectory); + } + + + /** + * Create an {@link OnHeapGraphIndex} persist it as a {@link OnDiskGraphIndex} and reconstruct back to a mutable {@link OnHeapGraphIndex} + * Make sure that both graphs are equivalent + * @throws IOException + */ + @Test + public void testReconstructionOfOnHeapGraphIndex() throws IOException { + var graphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); + + log.info("Writing graph to {}", graphOutputPath); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, graphOutputPath); + + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new 
SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndex).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + MutableGraphIndex reconstructedOnHeapGraphIndex; + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { + reconstructedOnHeapGraphIndex = OnHeapGraphIndex.load(readerSupplier.get(), NEIGHBOR_OVERFLOW, new VamanaDiversityProvider(baseBuildScoreProvider, ALPHA)); + } + + try (var readerSupplier = new SimpleMappedReader.Supplier(graphOutputPath.toAbsolutePath()); + var onDiskGraph = OnDiskGraphIndex.load(readerSupplier)) { + TestUtil.assertGraphEquals(baseGraphIndex, onDiskGraph); + try (var onDiskView = onDiskGraph.getView()) { + validateVectors(onDiskView, baseVectorsRavv); + } + + TestUtil.assertGraphEquals(baseGraphIndex, reconstructedOnHeapGraphIndex); + TestUtil.assertGraphEquals(onDiskGraph, reconstructedOnHeapGraphIndex); + } + } + + /** + * Create {@link OnDiskGraphIndex} then append to it via {@link GraphIndexBuilder#buildAndMergeNewNodes} + * Verify that the resulting OnHeapGraphIndex is equivalent to the graph that would have been alternatively generated by bulk index into a new {@link OnDiskGraphIndex} + */ + @Test + public void testIncrementalInsertionFromOnDiskIndex() throws IOException { + var outputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName()); + var heapGraphOutputPath = testDirectory.resolve("testReconstructionOfOnHeapGraphIndex_" + baseGraphIndex.getClass().getSimpleName() + "_onHeap"); + + log.info("Writing graph to {}", outputPath); + TestUtil.writeGraph(baseGraphIndex, baseVectorsRavv, outputPath); + + log.info("Writing on-heap graph to {}", heapGraphOutputPath); + try (SimpleWriter writer = new SimpleWriter(heapGraphOutputPath.toAbsolutePath())) { + ((OnHeapGraphIndex) baseGraphIndex).save(writer); + } + + log.info("Reading on-heap graph from {}", heapGraphOutputPath); + try (var readerSupplier = new SimpleMappedReader.Supplier(heapGraphOutputPath.toAbsolutePath())) { + // We will create a trivial 1:1 mapping between the new graph and the ravv + final int[] graphToRavvOrdMap = IntStream.range(0, allVectorsRavv.size()).toArray(); + ImmutableGraphIndex reconstructedAllNodeOnHeapGraphIndex = GraphIndexBuilder.buildAndMergeNewNodes(readerSupplier.get(), allVectorsRavv, allBuildScoreProvider, NUM_BASE_VECTORS, graphToRavvOrdMap, BEAM_WIDTH, NEIGHBOR_OVERFLOW, ALPHA, ADD_HIERARCHY); + + // Verify that the recall is similar + float recallFromReconstructedAllNodeOnHeapGraphIndex = calculateRecall(reconstructedAllNodeOnHeapGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + float recallFromAllGraphIndex = calculateRecall(allGraphIndex, allBuildScoreProvider, queryVector, groundTruthAllVectors, TOP_K); + Assert.assertEquals(recallFromReconstructedAllNodeOnHeapGraphIndex, recallFromAllGraphIndex, 0.01f); + } + } + + public static void validateVectors(OnDiskGraphIndex.View view, RandomAccessVectorValues ravv) { + for (int i = 0; i < view.size(); i++) { + assertEquals("Incorrect vector at " + i, ravv.getVector(i), view.getVector(i)); + } + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } + + /** + * Get the ground truth for a query vector + * @param ravv the vectors to search + 
* @param queryVector the query vector
+     * @param topK the number of results to return
+     * @param similarityFunction the similarity function to use
+     * @return the ground truth
+     */
+    private static int[] getGroundTruth(RandomAccessVectorValues ravv, VectorFloat<?> queryVector, int topK, VectorSimilarityFunction similarityFunction) {
+        var exactResults = new ArrayList<SearchResult.NodeScore>();
+        for (int i = 0; i < ravv.size(); i++) {
+            float similarityScore = similarityFunction.compare(queryVector, ravv.getVector(i));
+            exactResults.add(new SearchResult.NodeScore(i, similarityScore));
+        }
+        exactResults.sort((a, b) -> Float.compare(b.score, a.score));
+        return exactResults.stream().limit(topK).mapToInt(nodeScore -> nodeScore.node).toArray();
+    }
+
+    private static float calculateRecall(ImmutableGraphIndex graphIndex, BuildScoreProvider buildScoreProvider, VectorFloat<?> queryVector, int[] groundTruth, int k) throws IOException {
+        try (GraphSearcher graphSearcher = new GraphSearcher(graphIndex)) {
+            SearchScoreProvider ssp = buildScoreProvider.searchProviderFor(queryVector);
+            var searchResults = graphSearcher.search(ssp, k, Bits.ALL);
+            var predicted = Arrays.stream(searchResults.getNodes()).mapToInt(nodeScore -> nodeScore.node).boxed().collect(Collectors.toSet());
+            return calculateRecall(predicted, groundTruth, k);
+        }
+    }
+
+    /**
+     * Calculate the recall for a set of predicted results
+     * @param predicted the predicted results
+     * @param groundTruth the ground truth
+     * @param k the number of results to consider
+     * @return the recall
+     */
+    private static float calculateRecall(Set<Integer> predicted, int[] groundTruth, int k) {
+        int hits = 0;
+        int actualK = Math.min(k, Math.min(predicted.size(), groundTruth.length));
+
+        // count how many of the top-actualK ground truth neighbors appear in the predicted set
+        for (int j = 0; j < actualK; j++) {
+            if (predicted.contains(groundTruth[j])) {
+                hits++;
+            }
+        }
+
+        return ((float) hits) / (float) actualK;
+    }
+}
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java
new file mode 100644
index 000000000..4942b8efb
--- /dev/null
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProviderTest.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package io.github.jbellis.jvector.graph.similarity; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class BuildScoreProviderTest { + private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport(); + + /** + * Test that the ordinal mapping is correctly applied when creating search and diversity score providers. + */ + @Test + public void testOrdinalMapping() { + final VectorSimilarityFunction vsf = VectorSimilarityFunction.DOT_PRODUCT; + + // Create test vectors + final List> vectors = new ArrayList<>(); + vectors.add(vts.createFloatVector(new float[]{1.0f, 0.0f})); + vectors.add(vts.createFloatVector(new float[]{0.0f, 1.0f})); + vectors.add(vts.createFloatVector(new float[]{-1.0f, 0.0f})); + var ravv = new ListRandomAccessVectorValues(vectors, 2); + + // Create non-identity mapping: graph node 0 -> ravv ordinal 2, graph node 1 -> ravv ordinal 0, graph node 2 -> ravv ordinal 1 + int[] graphToRavvOrdMap = {2, 0, 1}; + + var bsp = BuildScoreProvider.randomAccessScoreProvider(ravv, graphToRavvOrdMap, vsf); + + // Test that searchProviderFor(graphNode) uses the correct RAVV ordinal + var ssp0 = bsp.searchProviderFor(0); // should use ravv ordinal 2 (vector [-1, 0]) + var ssp1 = bsp.searchProviderFor(1); // should use ravv ordinal 0 (vector [1, 0]) + var ssp2 = bsp.searchProviderFor(2); // should use ravv ordinal 1 (vector [0, 1]) + + // Verify by computing similarity between graph nodes + // Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), ssp0.exactScoreFunction().similarityTo(1), 1e-6f); + + // Graph node 1 (vector 0:[1, 0]) vs graph node 0 (vector 2:[-1, 0]) + assertEquals(vsf.compare(vectors.get(0), vectors.get(2)), ssp1.exactScoreFunction().similarityTo(0), 1e-6f); + + // Graph node 2 (vector 1:[0, 1]) vs graph node 1 (vector 0:[1, 0]) + assertEquals(vsf.compare(vectors.get(1), vectors.get(0)), ssp2.exactScoreFunction().similarityTo(1), 1e-6f); + + // Test diversityProviderFor uses same mapping, Graph node 0 (vector 2:[-1, 0]) vs graph node 1 (vector 0:[1, 0]) + var dsp0 = bsp.diversityProviderFor(0); + assertEquals(vsf.compare(vectors.get(2), vectors.get(0)), dsp0.exactScoreFunction().similarityTo(1), 1e-6f); + } +} \ No newline at end of file