From a6f505ccdc6ae1f21001daf23cfb4c81ce392f1c Mon Sep 17 00:00:00 2001 From: e-strauss <92718421+e-strauss@users.noreply.github.com> Date: Tue, 3 Dec 2024 00:58:31 +0100 Subject: [PATCH] missing method implementations in ColGroupSDCZeros --- .../compress/colgroup/ColGroupSDC.java | 4 +- .../compress/colgroup/ColGroupSDCZeros.java | 39 ++++++++++++++++--- .../colgroup/dictionary/ADictionary.java | 12 +++++- .../colgroup/dictionary/IDictionary.java | 15 ++++++- .../colgroup/dictionary/PlaceHolderDict.java | 8 +++- .../dictionary/PlaceHolderDictTest.java | 9 ++++- 6 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index ea4f2fb5811..e78bea93a2e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -683,7 +683,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re } else { while(c < points.length && points[c].o == of) { - _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + _dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); c++; } of = it.next(); @@ -696,7 +696,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re } while(of == last && c < points.length && points[c].o == of) { - _dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); + _dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes); c++; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index c1e081f2533..d0d8d160a5a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -836,7 +836,7 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, while(of < last && c < points.length) { if(points[c].o == of) { - c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); + c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); of = it.next(); } else if(points[c].o < of) @@ -848,23 +848,52 @@ else if(points[c].o < of) while(c < points.length && points[c].o < last) c++; - c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); + c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex())); } @Override protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { - throw new NotImplementedException(); + final DenseBlock dr = ret.getDenseBlock(); + final int nCol = _colIndexes.size(); + final AIterator it = _indexes.getIterator(); + final int last = _indexes.getOffsetToLast(); + int c = 0; + int of = it.value(); + + while(of < last && c < points.length) { + if(points[c].o == of) { + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + of = it.next(); + } + else if(points[c].o < of) + c++; + else + of = it.next(); + } + // increment the c pointer until it is pointing at least to last point or is done. 
+ while(c < points.length && points[c].o < last) + c++; + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + } + + private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { + while(c < points.length && points[c].o == of) { + _dict.putSparse(sr, did, points[c].r, nCol, _colIndexes); + c++; + } + return c; } - private int processRow(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { + private int processRowDense(P[] points, final DenseBlock dr, final int nCol, int c, int of, final int did) { while(c < points.length && points[c].o == of) { - _dict.put(sr, did, points[c].r, nCol, _colIndexes); + _dict.putDense(dr, did, points[c].r, nCol, _colIndexes); c++; } return c; } + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(super.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java index d41e2675f57..7d88573e3a4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/ADictionary.java @@ -22,6 +22,7 @@ import java.io.Serializable; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; @@ -87,8 +88,17 @@ public static void correctNan(double[] res, IColIndex colIndexes) { } @Override - public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { for(int i = 0; i < nCol; i++) sb.append(rowOut, columns.get(i), getValue(idx, i, 
nCol)); } + + @Override + public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) { + double[] dv = dr.values(rowOut); + int off = dr.pos(rowOut); + for(int i = 0; i < nCol; i++) + dv[off + columns.get(i)] += getValue(idx, i, nCol); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index a7a74775be7..bfe4ef23c33 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ValueFunction; @@ -989,6 +990,18 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef * @param nCol The number of columns in the dictionary * @param columns The columns to output into. */ - public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns); + public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns); + + /** + * Put the row specified into the dense block, by adding its values to the output row. + * + * @param db The dense block to put into + * @param idx The dictionary index to put in. + * @param rowOut The row in the dense block to put it into + * @param nCol The number of columns in the dictionary + * @param columns The columns to output into. 
+ */ + public void putDense(DenseBlock db, int idx, int rowOut, int nCol, IColIndex columns); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index 88a7be26194..68a3fb3fac2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -25,6 +25,7 @@ import java.io.Serializable; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.ValueFunction; @@ -526,7 +527,12 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex } @Override - public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { + throw new RuntimeException(errMessage); + } + + @Override + public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) { throw new RuntimeException(errMessage); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java index 88e5d8adcc3..5a112a800c3 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/dictionary/PlaceHolderDictTest.java @@ -490,8 +490,13 @@ public void MMDictScalingSparse() { } @Test(expected = Exception.class) - public void put() { - d.put(null, 1, 1, 1, null); + public void putDense() { + d.putDense(null, 1, 1, 1, null); + } + + @Test(expected 
= Exception.class) + public void putSparse() { + d.putSparse(null, 1, 1, 1, null); } @Test