Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
}
else {
while(c < points.length && points[c].o == of) {
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
c++;
}
of = it.next();
Expand All @@ -696,7 +696,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re
}

while(of == last && c < points.length && points[c].o == of) {
_dict.put(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
_dict.putSparse(sr, _data.getIndex(it.getDataIndex()), points[c].r, nCol, _colIndexes);
c++;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,7 @@ public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret,

while(of < last && c < points.length) {
if(points[c].o == of) {
c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
of = it.next();
}
else if(points[c].o < of)
Expand All @@ -848,23 +848,52 @@ else if(points[c].o < of)
while(c < points.length && points[c].o < last)
c++;

c = processRow(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));
c = processRowSparse(points, sr, nCol, c, of, _data.getIndex(it.getDataIndex()));

}

@Override
protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) {
throw new NotImplementedException();
final DenseBlock dr = ret.getDenseBlock();
final int nCol = _colIndexes.size();
final AIterator it = _indexes.getIterator();
final int last = _indexes.getOffsetToLast();
int c = 0;
int of = it.value();

while(of < last && c < points.length) {
if(points[c].o == of) {
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
of = it.next();
}
else if(points[c].o < of)
c++;
else
of = it.next();
}
// increment the c pointer until it is pointing at least to last point or is done.
while(c < points.length && points[c].o < last)
c++;
c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex()));
}

private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
while(c < points.length && points[c].o == of) {
_dict.putSparse(sr, did, points[c].r, nCol, _colIndexes);
c++;
}
return c;
}

private int processRow(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) {
private int processRowDense(P[] points, final DenseBlock dr, final int nCol, int c, int of, final int did) {
while(c < points.length && points[c].o == of) {
_dict.put(sr, did, points[c].r, nCol, _colIndexes);
_dict.putDense(dr, did, points[c].r, nCol, _colIndexes);
c++;
}
return c;
}


public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(super.toString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.io.Serializable;

import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
Expand Down Expand Up @@ -87,8 +88,17 @@ public static void correctNan(double[] res, IColIndex colIndexes) {
}

@Override
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
for(int i = 0; i < nCol; i++)
sb.append(rowOut, columns.get(i), getValue(idx, i, nCol));
}

@Override
public void putDense(DenseBlock dr, int idx, int rowOut, int nCol, IColIndex columns) {
double[] dv = dr.values(rowOut);
int off = dr.pos(rowOut);
for(int i = 0; i < nCol; i++)
dv[off + columns.get(i)] += getValue(idx, i, nCol);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
Expand Down Expand Up @@ -989,6 +990,18 @@ public void TSMMToUpperTriangleSparseScaling(SparseBlock left, IColIndex rowsLef
* @param nCol The number of columns in the dictionary
* @param columns The columns to output into.
*/
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns);

/**
* Put the row specified into the sparse block, via append calls.
*
* @param db The dense block to put into
* @param idx The dictionary index to put in.
* @param rowOut The row in the sparse block to put it into
* @param nCol The number of columns in the dictionary
* @param columns The columns to output into.
*/
public void putDense(DenseBlock db, int idx, int rowOut, int nCol, IColIndex columns);


}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.io.Serializable;

import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
Expand Down Expand Up @@ -526,7 +527,12 @@ public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, IColIndex
}

@Override
public void put(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
public void putSparse(SparseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
throw new RuntimeException(errMessage);
}

@Override
public void putDense(DenseBlock sb, int idx, int rowOut, int nCol, IColIndex columns) {
throw new RuntimeException(errMessage);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,13 @@ public void MMDictScalingSparse() {
}

@Test(expected = Exception.class)
public void put() {
d.put(null, 1, 1, 1, null);
public void putDense() {
d.putDense(null, 1, 1, 1, null);
}

@Test(expected = Exception.class)
public void putSparse() {
d.putSparse(null, 1, 1, 1, null);
}

@Test
Expand Down
Loading