diff --git a/.gitignore b/.gitignore index 25335ca03..83d6e4182 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,12 @@ replay_pid* .idea *.iml +### Eclipse/JDTLS ### +.settings/ +.classpath +.project +.factorypath + ### VS Code ### .vscode/ diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java index 1ec2730ab..dd85f49f0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/IndexWriter.java @@ -19,6 +19,10 @@ import java.io.Closeable; import java.io.DataOutput; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; + /** * Interface for writing index data. @@ -30,4 +34,22 @@ public interface IndexWriter extends DataOutput, Closeable { * @throws IOException if an I/O error occurs */ long position() throws IOException; + + /** + * Writes {@code count} floats from {@code floats}, starting at index {@code offset}, as big-endian 4-byte values in a single bulk write. + * + * @param floats array holding the values to write + * @param offset index of the first float to write + * @param count number of floats to write + * @throws IOException if an I/O error occurs + */ + default void writeFloats(float[] floats, int offset, int count) throws IOException { + FloatBuffer fb = FloatBuffer.wrap(floats, offset, count); + // size by count: fb.capacity() is floats.length for a wrapped sub-range, which would append zero padding + ByteBuffer bb = ByteBuffer.allocate(count * Float.BYTES); + // DataOutput specifies BIG_ENDIAN for float + bb.order(ByteOrder.BIG_ENDIAN).asFloatBuffer().put(fb); + bb.rewind(); + write(bb.array()); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java index 251592d64..46d91f8e3 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/disk/SimpleMappedReader.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.RandomAccessFile; import java.lang.reflect.Field; +import java.nio.ByteOrder; import java.nio.MappedByteBuffer; import java.nio.channels.FileChannel; import java.nio.file.Path; @@ -73,6
+74,7 @@ public Supplier(Path path) throws IOException { throw new RuntimeException("SimpleMappedReader doesn't support files above 2GB"); } this.buffer = raf.getChannel().map(FileChannel.MapMode.READ_ONLY, 0, raf.length()); + this.buffer.order(ByteOrder.BIG_ENDIAN); this.buffer.load(); } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java index b8c86581f..03dd2ec98 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/Feature.java @@ -16,7 +16,8 @@ package io.github.jbellis.jvector.graph.disk.feature; -import java.io.DataOutput; +import io.github.jbellis.jvector.disk.IndexWriter; + import java.io.IOException; import java.util.EnumMap; import java.util.function.IntFunction; @@ -35,9 +36,9 @@ default boolean isFused() { int featureSize(); - void writeHeader(DataOutput out) throws IOException; + void writeHeader(IndexWriter out) throws IOException; - default void writeInline(DataOutput out, State state) throws IOException { + default void writeInline(IndexWriter out, State state) throws IOException { // default no-op } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java index deddb7d4f..d54630999 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedFeature.java @@ -16,10 +16,10 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.util.Accountable; -import java.io.DataOutput; import 
java.io.IOException; /** @@ -33,7 +33,7 @@ default boolean isFused() { return true; } - void writeSourceFeature(DataOutput out, State state) throws IOException; + void writeSourceFeature(IndexWriter out, State state) throws IOException; interface InlineSource extends Accountable {} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java index 5fcf59293..840650ba5 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/FusedPQ.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.ImmutableGraphIndex; import io.github.jbellis.jvector.graph.disk.CommonHeader; @@ -31,7 +32,6 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; import java.util.function.IntFunction; @@ -97,14 +97,14 @@ public ScoreFunction.ApproximateScoreFunction approximateScoreFunctionFor(Vector } @Override - public void writeHeader(DataOutput out) throws IOException { + public void writeHeader(IndexWriter out) throws IOException { pq.write(out, OnDiskGraphIndex.CURRENT_VERSION); } // this is an awkward fit for the Feature.State design since we need to // generate the fused set based on the neighbors of the node, not just the node itself @Override - public void writeInline(DataOutput out, Feature.State state_) throws IOException { + public void writeInline(IndexWriter out, Feature.State state_) throws IOException { var state = (FusedPQ.State) state_; var neighbors = state.view.getNeighborsIterator(0, state.nodeId); @@ 
-138,7 +138,7 @@ public State(ImmutableGraphIndex.View view, IntFunction> compres } @Override - public void writeSourceFeature(DataOutput out, Feature.State state_) throws IOException { + public void writeSourceFeature(IndexWriter out, Feature.State state_) throws IOException { var state = (FusedPQ.State) state_; var compressed = state.compressedVectorFunction.apply(state.nodeId); var temp = pqCodeScratch.get(); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java index 0e80bd467..005103ddb 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/InlineVectors.java @@ -16,13 +16,13 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; /** @@ -59,12 +59,12 @@ static InlineVectors load(CommonHeader header, RandomAccessReader reader) { } @Override - public void writeHeader(DataOutput out) { + public void writeHeader(IndexWriter out) { // common header contains dimension, which is sufficient } @Override - public void writeInline(DataOutput out, Feature.State state) throws IOException { + public void writeInline(IndexWriter out, Feature.State state) throws IOException { vectorTypeSupport.writeFloatVector(out, ((InlineVectors.State) state).vector); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java 
b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java index 866bd171c..e0ccbd264 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/NVQ.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; @@ -26,7 +27,6 @@ import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; @@ -70,12 +70,12 @@ static NVQ load(CommonHeader header, RandomAccessReader reader) { } @Override - public void writeHeader(DataOutput out) throws IOException { + public void writeHeader(IndexWriter out) throws IOException { nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION); } @Override - public void writeInline(DataOutput out, Feature.State state_) throws IOException { + public void writeInline(IndexWriter out, Feature.State state_) throws IOException { var state = (NVQ.State) state_; state.vector.write(out); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java index d90aee603..d3dcffb2e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedFeature.java @@ -16,12 +16,13 @@ package io.github.jbellis.jvector.graph.disk.feature; -import java.io.DataOutput; +import io.github.jbellis.jvector.disk.IndexWriter; + import java.io.IOException; public interface 
SeparatedFeature extends Feature { void setOffset(long offset); long getOffset(); - void writeSeparately(DataOutput out, State state) throws IOException; + void writeSeparately(IndexWriter out, State state) throws IOException; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java index d7cb8080b..b65370632 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedNVQ.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; @@ -25,7 +26,6 @@ import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; @@ -68,13 +68,13 @@ public int featureSize() { } @Override - public void writeHeader(DataOutput out) throws IOException { + public void writeHeader(IndexWriter out) throws IOException { nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION); out.writeLong(offset); } @Override - public void writeSeparately(DataOutput out, State state_) throws IOException { + public void writeSeparately(IndexWriter out, State state_) throws IOException { var state = (NVQ.State) state_; if (state.vector != null) { state.vector.write(out); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java index 50bcef545..017aef53f 100644 --- 
a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/feature/SeparatedVectors.java @@ -16,12 +16,12 @@ package io.github.jbellis.jvector.graph.disk.feature; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.disk.CommonHeader; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.io.UncheckedIOException; @@ -61,12 +61,12 @@ public int featureSize() { } @Override - public void writeHeader(DataOutput out) throws IOException { + public void writeHeader(IndexWriter out) throws IOException { out.writeLong(offset); } @Override - public void writeSeparately(DataOutput out, State state_) throws IOException { + public void writeSeparately(IndexWriter out, State state_) throws IOException { var state = (InlineVectors.State) state_; if (state.vector != null) { vectorTypeSupport.writeFloatVector(out, state.vector); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java index b89bd9c4c..f2d2a0cf1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.RamUsageEstimator; @@ -23,7 +24,6 @@ import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.types.VectorFloat; -import 
java.io.DataOutput; import java.io.IOException; import java.util.Arrays; import java.util.Objects; @@ -37,7 +37,7 @@ protected BQVectors(BinaryQuantization bq) { } @Override - public void write(DataOutput out, int version) throws IOException { + public void write(IndexWriter out, int version) throws IOException { // BQ centering data bq.write(out, version); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java index 3cfe950c1..f0d660301 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java @@ -16,13 +16,13 @@ package io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.util.Objects; import java.util.concurrent.ForkJoinPool; @@ -121,7 +121,7 @@ public int compressedVectorSize() { } @Override - public void write(DataOutput out, int version) throws IOException { + public void write(IndexWriter out, int version) throws IOException { out.writeInt(dimension); // We used to record the center of the dataset but this actually degrades performance. // Write a zero vector to maintain compatibility. 
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java index ee60859b7..767659148 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java @@ -16,27 +16,27 @@ package io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.Accountable; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.DataOutput; import java.io.IOException; public interface CompressedVectors extends Accountable { /** - * Write the compressed vectors to the given DataOutput - * @param out the DataOutput to write to + * Write the compressed vectors to the given IndexWriter + * @param out the IndexWriter to write to * @param version the serialization version. 
versions 2 and 3 are supported */ - void write(DataOutput out, int version) throws IOException; + void write(IndexWriter out, int version) throws IOException; /** - * Write the compressed vectors to the given DataOutput at the current serialization version + * Write the compressed vectors to the given IndexWriter at the current serialization version */ - default void write(DataOutput out) throws IOException { + default void write(IndexWriter out) throws IOException { write(out, OnDiskGraphIndex.CURRENT_VERSION); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java index bf8019d9d..a21b74060 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java @@ -16,13 +16,13 @@ package io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.util.RamUsageEstimator; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; import java.util.Objects; @@ -48,7 +48,7 @@ public int count() { } @Override - public void write(DataOutput out, int version) throws IOException + public void write(IndexWriter out, int version) throws IOException { // serializing NVQ at the given version nvq.write(out, version); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java index 093f5ad3c..aef0325b9 100644 --- 
a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.annotations.VisibleForTesting; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; @@ -27,7 +28,6 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -70,11 +70,11 @@ public ByteSequence createByteSequence(int nDimensions) { }; /** - * Writes the BitsPerDimension to DataOutput. - * @param out the DataOutput into which to write the object + * Writes the BitsPerDimension to IndexWriter. + * @param out the IndexWriter into which to write the object * @throws IOException if there is a problem writing to out. */ - public void write(DataOutput out) throws IOException { + public void write(IndexWriter out) throws IOException { out.writeInt(getInt()); } @@ -251,12 +251,12 @@ static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { } /** - * Writes the instance to a DataOutput. - * @param out DataOutput to write to + * Writes the instance to a IndexWriter. + * @param out IndexWriter to write to * @param version serialization version. 
- * @throws IOException fails if we cannot write to the DataOutput + * @throws IOException fails if we cannot write to the IndexWriter */ - public void write(DataOutput out, int version) throws IOException + public void write(IndexWriter out, int version) throws IOException { if (version > OnDiskGraphIndex.CURRENT_VERSION) { throw new IllegalArgumentException("Unsupported serialization version " + version); @@ -432,11 +432,11 @@ public static QuantizedVector createEmpty(int[][] subvectorSizesAndOffsets, Bits /** - * Write the instance to a DataOutput - * @param out the DataOutput - * @throws IOException fails if we cannot write to the DataOutput + * Write the instance to a IndexWriter + * @param out the IndexWriter + * @throws IOException fails if we cannot write to the IndexWriter */ - public void write(DataOutput out) throws IOException { + public void write(IndexWriter out) throws IOException { out.writeInt(subVectors.length); for (var sv : subVectors) { @@ -593,11 +593,11 @@ private QuantizedSubVector(ByteSequence bytes, int originalDimensions, BitsPe } /** - * Write the instance to a DataOutput - * @param out the DataOutput - * @throws IOException fails if we cannot write to the DataOutput + * Write the instance to a IndexWriter + * @param out the IndexWriter + * @throws IOException fails if we cannot write to the IndexWriter */ - public void write(DataOutput out) throws IOException { + public void write(IndexWriter out) throws IOException { bitsPerDimension.write(out); out.writeFloat(minValue); out.writeFloat(maxValue); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 4c28f7096..3c00134a4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -16,6 +16,7 @@ package 
io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; @@ -27,7 +28,6 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.util.Objects; import java.util.concurrent.ForkJoinPool; @@ -134,7 +134,7 @@ public static ImmutablePQVectors encodeAndBuild(ProductQuantization pq, int vect } @Override - public void write(DataOutput out, int version) throws IOException + public void write(IndexWriter out, int version) throws IOException { // pq codebooks pq.write(out, version); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index 79042666b..963cef0c8 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -17,6 +17,7 @@ package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.annotations.VisibleForTesting; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; @@ -29,11 +30,11 @@ import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; import java.util.List; import java.util.Objects; +import java.util.concurrent.Callable; import java.util.concurrent.ForkJoinPool; import 
java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicReference; @@ -183,12 +184,13 @@ public ProductQuantization refine(RandomAccessVectorValues ravv, } var vectors = vectorsMutable; // "effectively final" to make the closure happy - var refinedCodebooks = simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> { + Callable[]> callable = () -> IntStream.range(0, M).parallel().mapToObj(m -> { VectorFloat[] subvectors = extractSubvectors(vectors, m, subvectorSizesAndOffsets); var clusterer = new KMeansPlusPlusClusterer(subvectors, codebooks[m], anisotropicThreshold); return clusterer.cluster(anisotropicThreshold == UNWEIGHTED ? lloydsRounds : 0, anisotropicThreshold == UNWEIGHTED ? 0 : lloydsRounds); - }).toArray(VectorFloat[]::new)).join(); + }).toArray(VectorFloat[]::new); + var refinedCodebooks = simdExecutor.submit(callable).join(); return new ProductQuantization(refinedCodebooks, clusterCount, subvectorSizesAndOffsets, globalCentroid, anisotropicThreshold); } @@ -459,11 +461,12 @@ public int getClusterCount() { static VectorFloat[] createCodebooks(List> vectors, int[][] subvectorSizeAndOffset, int clusters, float anisotropicThreshold, ForkJoinPool simdExecutor) { int M = subvectorSizeAndOffset.length; - return simdExecutor.submit(() -> IntStream.range(0, M).parallel().mapToObj(m -> { + Callable[]> callable = () -> IntStream.range(0, M).parallel().mapToObj(m -> { VectorFloat[] subvectors = extractSubvectors(vectors, m, subvectorSizeAndOffset); var clusterer = new KMeansPlusPlusClusterer(subvectors, clusters, anisotropicThreshold); return clusterer.cluster(K_MEANS_ITERATIONS, anisotropicThreshold == UNWEIGHTED ? 
0 : K_MEANS_ITERATIONS); - }).toArray(VectorFloat[]::new)).join(); + }).toArray(VectorFloat[]::new); + return simdExecutor.submit(callable).join(); } /** @@ -529,7 +532,7 @@ AtomicReference> partialSquaredMagnitudes() { return partialSquaredMagnitudes; } - public void write(DataOutput out, int version) throws IOException + public void write(IndexWriter out, int version) throws IOException { if (version > OnDiskGraphIndex.CURRENT_VERSION) { throw new IllegalArgumentException("Unsupported serialization version " + version); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java index d7529498b..c5708716c 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java @@ -16,13 +16,13 @@ package io.github.jbellis.jvector.quantization; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.DataOutput; import java.io.IOException; import java.util.List; import java.util.concurrent.ForkJoinPool; @@ -55,13 +55,13 @@ default CompressedVectors encodeAll(RandomAccessVectorValues ravv) { void encodeTo(VectorFloat v, T dest); /** - * @param out DataOutput to write to + * @param out IndexWriter to write to * @param version serialization version. 
Versions 2 and 3 are supported */ - void write(DataOutput out, int version) throws IOException; + void write(IndexWriter out, int version) throws IOException; /** Write with the current serialization version */ - default void write(DataOutput out) throws IOException { + default void write(IndexWriter out) throws IOException { write(out, OnDiskGraphIndex.CURRENT_VERSION); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java index 8f3e05db5..9985e946f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayVectorProvider.java @@ -17,12 +17,12 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.FloatBuffer; @@ -59,15 +59,10 @@ public void readFloatVector(RandomAccessReader r, int size, VectorFloat vecto } @Override - public void writeFloatVector(DataOutput out, VectorFloat vector) throws IOException + public void writeFloatVector(IndexWriter out, VectorFloat vector) throws IOException { ArrayVectorFloat v = (ArrayVectorFloat)vector; - // this seems to be the only way to avoid writing float-at-a-time which is far too slow - var fb = FloatBuffer.wrap(v.get()); - var bb = ByteBuffer.allocate(fb.capacity() * Float.BYTES); - bb.asFloatBuffer().put(fb); - bb.rewind(); - out.write(bb.array()); + out.writeFloats(v.get(), 0, v.length()); } @Override @@ -97,7 +92,7 @@ public void readByteSequence(RandomAccessReader r, ByteSequence sequence) thr } @Override 
- public void writeByteSequence(DataOutput out, ByteSequence sequence) throws IOException + public void writeByteSequence(IndexWriter out, ByteSequence sequence) throws IOException { ArrayByteSequence v = (ArrayByteSequence) sequence; out.write(v.get()); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java index 409389370..a4cb4f8af 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/VectorTypeSupport.java @@ -16,9 +16,9 @@ package io.github.jbellis.jvector.vector.types; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; -import java.io.DataOutput; import java.io.IOException; public interface VectorTypeSupport { @@ -57,12 +57,12 @@ public interface VectorTypeSupport { void readFloatVector(RandomAccessReader r, int size, VectorFloat vector, int offset) throws IOException; /** - * Write the given vector to the given DataOutput. + * Write the given vector to the given IndexWriter. * @param out the output to write the vector to. * @param vector the vector to write. * @throws IOException */ - void writeFloatVector(DataOutput out, VectorFloat vector) throws IOException; + void writeFloatVector(IndexWriter out, VectorFloat vector) throws IOException; /** * Create a sequence from the given data. 
@@ -83,5 +83,5 @@ public interface VectorTypeSupport { void readByteSequence(RandomAccessReader r, ByteSequence sequence) throws IOException; - void writeByteSequence(DataOutput out, ByteSequence sequence) throws IOException; + void writeByteSequence(IndexWriter out, ByteSequence sequence) throws IOException; } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index 7e368ed56..2ba49f62f 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.example; +import io.github.jbellis.jvector.disk.BufferedRandomAccessWriter; import io.github.jbellis.jvector.disk.ReaderSupplierFactory; import io.github.jbellis.jvector.example.benchmarks.AccuracyBenchmark; import io.github.jbellis.jvector.example.benchmarks.BenchmarkTablePrinter; @@ -54,8 +55,6 @@ import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.types.VectorFloat; -import java.io.BufferedOutputStream; -import java.io.DataOutputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.io.UncheckedIOException; @@ -651,7 +650,7 @@ private static VectorCompressor getCompressor(Function> baseVectors, List> baseVectors, List> baseVectors, List> baseVectors, List> baseVectors, List< .withMapper(new OrdinalMapper.IdentityMapper(baseVectors.size() - 1)) .build(); // output for the compressed vectors - DataOutputStream pqOut = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(pqPath)))) + IndexWriter pqWriter = new BufferedRandomAccessWriter(pqPath)) { // build the index vector-at-a-time (on disk) for (int ordinal = 0; ordinal < baseVectors.size(); ordinal++) { @@ -314,7 +314,7 @@ public static void siftDiskAnnLTMWithNVQ(List> baseVectors, List< // finish writing 
the index (by filling in the edge lists) and write our completed PQVectors writer.write(Map.of()); - pqv.write(pqOut); + pqv.write(pqWriter); } // searching the index does not change diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java index a098ced7a..1ce0d81b2 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentVectorProvider.java @@ -16,12 +16,12 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.disk.IndexWriter; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; import io.github.jbellis.jvector.vector.types.VectorTypeSupport; -import java.io.DataOutput; import java.io.IOException; import java.nio.Buffer; @@ -60,7 +60,7 @@ public void readFloatVector(RandomAccessReader r, int count, VectorFloat vect } @Override - public void writeFloatVector(DataOutput out, VectorFloat vector) throws IOException + public void writeFloatVector(IndexWriter out, VectorFloat vector) throws IOException { for (int i = 0; i < vector.length(); i++) out.writeFloat(vector.get(i)); @@ -96,7 +96,7 @@ public void readByteSequence(RandomAccessReader r, ByteSequence sequence) thr @Override - public void writeByteSequence(DataOutput out, ByteSequence sequence) throws IOException + public void writeByteSequence(IndexWriter out, ByteSequence sequence) throws IOException { for (int i = 0; i < sequence.length(); i++) out.writeByte(sequence.get(i)); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java index 
3c0c0660e..456ae2d18 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java @@ -20,15 +20,14 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.junit.Test; -import java.io.DataOutputStream; import java.io.File; -import java.io.FileOutputStream; import java.util.List; import static io.github.jbellis.jvector.TestUtil.createRandomVectors; @@ -54,8 +53,8 @@ public void testSaveLoadPQ() throws Exception { // Write compressed vectors File cvFile = File.createTempFile("pqtest", ".cv"); - try (var out = new DataOutputStream(new FileOutputStream(cvFile))) { - cv.write(out); + try (var writer = new SimpleWriter(cvFile.toPath())) { + cv.write(writer); } // Read compressed vectors try (var readerSupplier = new SimpleMappedReader.Supplier(cvFile.toPath())) { @@ -79,8 +78,8 @@ public void testSaveLoadBQ() throws Exception { // Write compressed vectors File cvFile = File.createTempFile("bqtest", ".cv"); - try (var out = new DataOutputStream(new FileOutputStream(cvFile))) { - cv.write(out); + try (var writer = new SimpleWriter(cvFile.toPath())) { + cv.write(writer); } // Read compressed vectors try (var readerSupplier = new SimpleMappedReader.Supplier(cvFile.toPath())) { @@ -117,8 +116,8 @@ public void testSaveLoadNVQ() throws Exception { // Write compressed vectors File cvFile = File.createTempFile("bqtest", ".cv"); - try (var out = new DataOutputStream(new FileOutputStream(cvFile))) { - cv.write(out); + 
try (var writer = new SimpleWriter(cvFile.toPath())) { + cv.write(writer); } // Read compressed vectors try (var readerSupplier = new SimpleMappedReader.Supplier(cvFile.toPath())) { diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java index db35b52d1..992604b7e 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java @@ -19,6 +19,7 @@ import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.disk.SimpleMappedReader; +import io.github.jbellis.jvector.disk.SimpleWriter; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; @@ -28,9 +29,7 @@ import org.junit.Test; import org.junit.jupiter.api.Assertions; -import java.io.DataOutputStream; import java.io.File; -import java.io.FileOutputStream; import java.nio.file.Files; import java.util.Arrays; import java.util.List; @@ -203,8 +202,8 @@ public void testSaveLoad() throws Exception { // Write var file = File.createTempFile("pqtest", ".pq"); - try (var out = new DataOutputStream(new FileOutputStream(file))) { - pq.write(out); + try (var writer = new SimpleWriter(file.toPath())) { + pq.write(writer); } // Read try (var readerSupplier = new SimpleMappedReader.Supplier(file.toPath())) { @@ -237,8 +236,8 @@ public void testSaveVersion0() throws Exception { var pq = ProductQuantization.load(readerSupplier.get()); // re-save, emulating version 0 - try (var out = new DataOutputStream(new FileOutputStream(fileOut))) { - pq.write(out, 0); + try (var writer = new SimpleWriter(fileOut.toPath())) { 
+ pq.write(writer, 0); } } @@ -352,7 +351,7 @@ public void testPQLayoutEdgeCases() { int[][] testCases = { // Minimal cases {1, 1}, {1, 2}, - + // Power-of-2 boundaries for compressedDimension (layoutBytesPerVector changes) {10, 1}, {10, 2}, {10, 3}, {10, 4}, {10, 5}, {10, 7}, {10, 8}, {10, 9}, @@ -360,20 +359,20 @@ public void testPQLayoutEdgeCases() { {10, 31}, {10, 32}, {10, 33}, {10, 63}, {10, 64}, {10, 65}, {10, 127}, {10, 128}, {10, 129}, - + // Cases where addressableVectorsPerChunk becomes interesting {1073741823, 1}, // layoutBytesPerVector=2, addressableVectorsPerChunk=1073741823 - {1073741823, 2}, // layoutBytesPerVector=4, addressableVectorsPerChunk=536870911 + {1073741823, 2}, // layoutBytesPerVector=4, addressableVectorsPerChunk=536870911 {1073741824, 2}, // vectorCount > addressableVectorsPerChunk, creates chunks - + // Large dimension cases (small addressableVectorsPerChunk) {1000, 1024}, // layoutBytesPerVector=2048, addressableVectorsPerChunk=1048575 {2000000, 1024}, // vectorCount > addressableVectorsPerChunk - + // Integer overflow boundary cases {536870911, 4}, // layoutBytesPerVector=8, exactly fits in one chunk {536870912, 4}, // one more than above, creates multiple chunks - + // Edge case where lastChunkVectors becomes non-zero {100, 1073741824} // layoutBytesPerVector huge, addressableVectorsPerChunk=1, creates 100 chunks };