From 5654e0f65c216662e84f85fee3d212ff823de198 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 14 Aug 2024 11:57:33 +0200 Subject: [PATCH] JAVA 17 BWARE COMMIT debug no skip decodeRecode decoder passthough?? except hash columns do not subtract one to String inconsistency debug error remove hash columns in decode fix decode ? parallel convert decoder dummy sparse faster number parsing handle new exception number format exception reduce error logging fix ? check fix handle hash for init metadata decode dummy parallel decode +k for parallel cleanup GetCategorical Map k added stack quicksort dimensions log the failing operation support solve mote info Compressed Remove Empty decompress selection matrix fall back to uncompressed Compressed remove empty columns disable decompressing for not supported ops disable decompressing RMM for SDC disable the decompressing RMM --- .../org/apache/sysds/common/Builtins.java | 1 + .../java/org/apache/sysds/common/Opcodes.java | 2 + .../java/org/apache/sysds/common/Types.java | 1 + .../java/org/apache/sysds/hops/BinaryOp.java | 41 +- src/main/java/org/apache/sysds/hops/Hop.java | 11 + .../java/org/apache/sysds/hops/UnaryOp.java | 34 +- .../parser/BuiltinFunctionExpression.java | 16 + .../apache/sysds/parser/DMLTranslator.java | 3 + .../compress/CompressedMatrixBlock.java | 43 +- .../CompressedMatrixBlockFactory.java | 24 +- .../runtime/compress/colgroup/AColGroup.java | 85 +++- .../compress/colgroup/AColGroupValue.java | 3 +- .../compress/colgroup/ADictBasedColGroup.java | 1 + .../runtime/compress/colgroup/ASDCZero.java | 19 + .../compress/colgroup/ColGroupConst.java | 18 +- .../compress/colgroup/ColGroupDDC.java | 49 ++- .../compress/colgroup/ColGroupDDCFOR.java | 35 ++ .../compress/colgroup/ColGroupEmpty.java | 17 + .../runtime/compress/colgroup/ColGroupIO.java | 4 +- .../colgroup/ColGroupLinearFunctional.java | 15 + .../compress/colgroup/ColGroupOLE.java | 16 +- .../compress/colgroup/ColGroupRLE.java | 15 + .../compress/colgroup/ColGroupSDC.java | 68 +++- .../compress/colgroup/ColGroupSDCFOR.java | 63 +++ .../compress/colgroup/ColGroupSDCSingle.java | 65 ++- .../colgroup/ColGroupSDCSingleZeros.java | 75 +++- .../compress/colgroup/ColGroupSDCZeros.java | 98 ++++- .../colgroup/ColGroupUncompressed.java | 74 ++-- .../colgroup/ColGroupUncompressedArray.java | 16 + .../dictionary/AIdentityDictionary.java | 6 + .../colgroup/dictionary/DeltaDictionary.java | 11 + .../colgroup/dictionary/Dictionary.java | 72 ++++ .../colgroup/dictionary/IDictionary.java | 20 + .../dictionary/IdentityDictionary.java | 7 +- .../dictionary/IdentityDictionarySlice.java | 6 + .../dictionary/MatrixBlockDictionary.java | 43 +- .../colgroup/dictionary/PlaceHolderDict.java | 11 + .../colgroup/dictionary/QDictionary.java | 11 + .../compress/colgroup/mapping/AMapToData.java | 38 ++ .../compress/colgroup/offset/AIterator.java | 4 +- .../compress/colgroup/offset/AOffset.java | 50 ++- .../compress/colgroup/offset/OffsetEmpty.java | 4 + .../compress/lib/CLALibBinaryCellOp.java | 340 ++++++++++++---- .../runtime/compress/lib/CLALibCompAgg.java | 2 +- .../runtime/compress/lib/CLALibMMChain.java | 6 + .../compress/lib/CLALibRemoveEmpty.java | 107 +++++ .../compress/lib/CLALibRightMultBy.java | 14 +- .../runtime/compress/lib/CLALibSort.java | 37 ++ .../runtime/compress/lib/CLALibTSMM.java | 43 +- .../compress/utils/HashMapIntToInt.java | 380 ++++++++++++++++++ .../federated/FederatedWorkloadAnalyzer.java | 53 ++- .../frame/data/columns/DoubleArray.java | 2 +- .../frame/data/columns/HashMapToInt.java | 23 ++ .../frame/data/columns/StringArray.java | 4 +- .../frame/data/lib/MatrixBlockFromFrame.java | 64 ++- .../runtime/functionobjects/Builtin.java | 3 +- .../instructions/cp/BinaryCPInstruction.java | 2 + .../cp/BinaryFrameScalarCPInstruction.java | 130 ++++++ .../cp/BinaryMatrixMatrixCPInstruction.java | 9 +- .../cp/ParameterizedBuiltinCPInstruction.java | 2 +- .../cp/VariableCPInstruction.java | 3 +- .../data/LibAggregateUnarySpecialization.java | 148 +++++++ .../runtime/matrix/data/LibMatrixMult.java | 10 + .../runtime/matrix/data/LibMatrixReorg.java | 63 ++- .../runtime/matrix/data/MatrixBlock.java | 87 +++- .../runtime/transform/decode/Decoder.java | 32 +- .../runtime/transform/decode/DecoderBin.java | 68 +++- .../transform/decode/DecoderComposite.java | 32 +- .../transform/decode/DecoderDummycode.java | 117 ++++-- .../transform/decode/DecoderFactory.java | 28 +- .../transform/decode/DecoderPassThrough.java | 23 +- .../transform/decode/DecoderRecode.java | 49 ++- .../encode/ColumnEncoderFeatureHash.java | 6 +- .../org/apache/sysds/utils/DoubleParser.java | 9 +- .../java/org/apache/sysds/test/TestUtils.java | 20 + .../compress/CompressedMatrixTest.java | 50 +++ .../colgroup/ColGroupNegativeTests.java | 37 ++ .../encoding/EncodeSampleMultiColTest.java | 3 + .../encoding/EncodeSampleUnbalancedTest.java | 4 + .../compress/offset/CustomOffsetTest.java | 94 ++++- .../compress/configuration/CompressForce.java | 2 +- .../test/functions/misc/ToStringTest.java | 92 +++++ .../transform/GetCategoricalMaskTest.java | 167 ++++++++ .../scripts/functions/misc/ToString12.dml | 24 ++ .../transform/GetCategoricalMaskTest.dml | 37 ++ 85 files changed, 3207 insertions(+), 414 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java create mode 100644 src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java create mode 100644 src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java create mode 100644 src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java create mode 100644 src/test/scripts/functions/misc/ToString12.dml create mode 100644 src/test/scripts/functions/transform/GetCategoricalMaskTest.dml diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index 4feab311c76..f53e6c91d0a 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 251f773a18c..0419a0e056b 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -197,6 +197,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c2832aeb8cd 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -639,6 +639,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression UNION_DISTINCT(false); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index a3ddb45ea6d..73e3c5fac86 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -853,7 +863,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 86749d44c1c..675fbb380a1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1045,6 +1045,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1629,6 +1635,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 34da36dd13c..e16896b869b 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..ab0c7993b4e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(1, -1); + output.setBlocksize( id.getBlocksize()); + output.setValueType(ValueType.FP64); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 092fbffe36d..949e67a62cc 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2821,6 +2821,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java index 48637595741..ef5d2630390 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlock.java @@ -58,12 +58,14 @@ import org.apache.sysds.runtime.compress.lib.CLALibMMChain; import org.apache.sysds.runtime.compress.lib.CLALibMatrixMult; import org.apache.sysds.runtime.compress.lib.CLALibMerge; -import org.apache.sysds.runtime.compress.lib.CLALibReplace; +import org.apache.sysds.runtime.compress.lib.CLALibRemoveEmpty; import org.apache.sysds.runtime.compress.lib.CLALibReorg; +import org.apache.sysds.runtime.compress.lib.CLALibReplace; import org.apache.sysds.runtime.compress.lib.CLALibReshape; import org.apache.sysds.runtime.compress.lib.CLALibRexpand; import org.apache.sysds.runtime.compress.lib.CLALibScalar; import org.apache.sysds.runtime.compress.lib.CLALibSlice; +import org.apache.sysds.runtime.compress.lib.CLALibSort; import org.apache.sysds.runtime.compress.lib.CLALibSquash; import org.apache.sysds.runtime.compress.lib.CLALibTSMM; import org.apache.sysds.runtime.compress.lib.CLALibTernaryOp; @@ -101,6 +103,7 @@ import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.utils.DMLCompressionStatistics; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; +import org.apache.sysds.utils.stats.Timing; public class CompressedMatrixBlock extends MatrixBlock { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlock.class.getName()); @@ -475,16 +478,20 @@ public void readFields(DataInput in) throws IOException { } public static CompressedMatrixBlock read(DataInput in) throws IOException { + Timing t = new Timing(); int rlen = in.readInt(); int clen = in.readInt(); long nonZeros = in.readLong(); boolean overlappingColGroups = in.readBoolean(); List groups = ColGroupIO.readGroups(in, rlen); - return new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + CompressedMatrixBlock ret = new CompressedMatrixBlock(rlen, clen, nonZeros, overlappingColGroups, groups); + LOG.debug("Compressed read serialization time: " + t.stop()); + return ret; } @Override public void write(DataOutput out) throws IOException { + Timing t = new Timing(); final long estimateUncompressed = nonZeros > 0 ? MatrixBlock.estimateSizeOnDisk(rlen, clen, nonZeros) : Long.MAX_VALUE; final long estDisk = nonZeros > 0 ? getExactSizeOnDisk() : Long.MAX_VALUE; @@ -512,6 +519,7 @@ public void write(DataOutput out) throws IOException { out.writeLong(nonZeros); out.writeBoolean(overlappingColGroups); ColGroupIO.writeGroups(out, _colGroups); + LOG.debug("Compressed write serialization time: " + t.stop()); } /** @@ -611,14 +619,6 @@ public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixVal public MatrixBlock transposeSelfMatrixMultOperations(MatrixBlock out, MMTSJType tstype, int k) { // check for transpose type if(tstype == MMTSJType.LEFT) { - if(isEmpty()) - return new MatrixBlock(clen, clen, true); - // create output matrix block - if(out == null) - out = new MatrixBlock(clen, clen, false); - else - out.reset(clen, clen, false); - out.allocateDenseBlock(); CLALibTSMM.leftMultByTransposeSelf(this, out, k); return out; } @@ -846,9 +846,8 @@ public CM_COV_Object covOperations(COVOperator op, MatrixBlock that, MatrixBlock } @Override - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { - MatrixBlock right = getUncompressed(weights); - return getUncompressed("sortOperations").sortOperations(right, result); + public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result, int k) { + return CLALibSort.sort(this, weights, result, k); } @Override @@ -871,9 +870,7 @@ public MatrixBlock groupedAggOperations(MatrixValue tgt, MatrixValue wghts, Matr @Override public MatrixBlock removeEmptyOperations(MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { - printDecompressWarning("removeEmptyOperations"); - MatrixBlock tmp = getUncompressed(); - return tmp.removeEmptyOperations(ret, rows, emptyReturn, select); + return CLALibRemoveEmpty.rmempty(this, ret, rows, emptyReturn, select); } @Override @@ -1202,8 +1199,8 @@ public void examSparsity(boolean allowCSR, int k) { } @Override - public void sparseToDense(int k) { - // do nothing + public MatrixBlock sparseToDense(int k) { + return this; // do nothing } @Override @@ -1236,16 +1233,6 @@ public double interQuartileMean() { return getUncompressed("interQuartileMean").interQuartileMean(); } - @Override - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { - return getUncompressed("pickValues").pickValues(quantiles, ret); - } - - @Override - public double pickValue(double quantile, boolean average) { - return getUncompressed("pickValue").pickValue(quantile, average); - } - @Override public double sumWeightForQuantile() { return getUncompressed("sumWeightForQuantile").sumWeightForQuantile(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 4c48effb4df..f082d1ffc3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -64,6 +64,8 @@ public class CompressedMatrixBlockFactory { private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName()); + private static final Object asyncCompressLock = new Object(); + /** Timing object to measure the time of each phase in the compression */ private final Timing time = new Timing(true); /** Compression statistics gathered throughout the compression */ @@ -181,21 +183,23 @@ public static Future compressAsync(ExecutionContext ec, String varName) { } public static Future compressAsync(ExecutionContext ec, String varName, InstructionTypeCounter ins) { - LOG.debug("Compressing Async"); final ExecutorService pool = CommonThreadPool.get(); // We have to guarantee that a thread pool is allocated. return CompletableFuture.runAsync(() -> { // method call or code to be async try { CacheableData data = ec.getCacheableData(varName); - if(data instanceof MatrixObject) { - MatrixObject mo = (MatrixObject) data; - MatrixBlock mb = mo.acquireReadAndRelease(); - MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mo.acquireReadAndRelease(), ins).getLeft(); - if(mbc instanceof CompressedMatrixBlock) { - ExecutionContext.createCacheableData(mb); - mo.acquireModify(mbc); - mo.release(); - mbc.sum(); // calculate sum to forcefully materialize counts + synchronized(asyncCompressLock){ // synchronize on the data object to not allow multiple compressions of the same matrix. + if(data instanceof MatrixObject) { + LOG.debug("Compressing Async"); + MatrixObject mo = (MatrixObject) data; + MatrixBlock mb = mo.acquireReadAndRelease(); + MatrixBlock mbc = CompressedMatrixBlockFactory.compress(mb, ins).getLeft(); + if(mbc instanceof CompressedMatrixBlock) { + ExecutionContext.createCacheableData(mb); + mo.acquireModify(mbc); + mo.release(); + mbc.sum(); // calculate sum to forcefully materialize counts + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index ec502d6d122..fd59447b3d9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -29,9 +29,9 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex.SliceResult; @@ -41,6 +41,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -401,8 +402,9 @@ public final AColGroup rightMultByMatrix(MatrixBlock right) { * @param cru The right hand side column upper * @param nRows The number of rows in this column group */ - public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru){ - throw new NotImplementedException("not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); + public void rightDecompressingMult(MatrixBlock right, MatrixBlock ret, int rl, int ru, int nRows, int crl, int cru) { + throw new NotImplementedException( + "not supporting right Decompressing Multiply on class: " + this.getClass().getSimpleName()); } /** @@ -806,7 +808,7 @@ public final void selectionMultiply(MatrixBlock selection, P[] points, MatrixBlo else denseSelection(selection, points, ret, rl, ru); } - + /** * Get an approximate sparsity of this column group * @@ -972,6 +974,15 @@ public AColGroup[] splitReshapePushDown(final int multiplier, final int nRow, fi return splitReshape(multiplier, nRow, nColOrg); } + /** + * Sort the values of the column group according to double < > operations and return as another compressed group. + * + * This sorting assumes that the column group is sorted independently of everything else. + * + * @return The sorted group + */ + public abstract AColGroup sort(); + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -981,4 +992,70 @@ public String toString() { sb.append(_colIndexes); return sb.toString(); } + + /** + * Return a new column group containing only the selected rows in the given boolean vector. + * + * Whenever possible only modify the index structure, not the dictionary of the column groups. + * + * @param selectV The selection vector + * @param rOut The number of rows in the output + * @return The new column group + */ + public abstract AColGroup removeEmptyRows(boolean[] selectV, int rOut); + + /** + * Return a new column group containing only the selected columns in the given boolean vector. + * + * Whenever possible only modify the column index, and reduce the dictionaries of the column groups. + * + * @param selectV The selection vector + * @return The new column group + */ + public AColGroup removeEmptyCols(boolean[] selectV) { + if(!inSelection(selectV)) + return null; + + final IntArrayList selectedColumns = new IntArrayList(); + final IntArrayList newIDs = new IntArrayList(); + int idx = 0; + int idxOwn = 0; + final int end = Math.min(selectV.length, _colIndexes.get(_colIndexes.size() - 1) + 1); + for(int i = 0; i < end; i++) { + + if(i == _colIndexes.get(idxOwn)) { + if(selectV[i]) { + selectedColumns.appendValue(idxOwn); + newIDs.appendValue(idx); + } + idxOwn++; + } + if(selectV[i]) + idx++; + } + + final IColIndex newColumnIDs = ColIndexFactory.create(newIDs); + if(newColumnIDs.size() == _colIndexes.size()) + return copyAndSet(newColumnIDs); + else + return removeEmptyColsSubset(newColumnIDs, selectedColumns); + } + + /** + * Using the selection of columns, slice out those and return in a new column group with the given column indexes. + * Ideally this method should only modify the dictionaries. + * + * @param newColumnIDs the new column indexes + * @param selectedColumns The selected columns of this column group (guaranteed < current number of columns) + * @return A new Column group + */ + protected abstract AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns); + + private boolean inSelection(boolean[] selection) { + for(int i = 0; i < _colIndexes.size(); i++) { + if(selection[_colIndexes.get(i)]) + return true; + } + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java index 0cde289b30f..4f53d8b912b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroupValue.java @@ -59,8 +59,6 @@ public int getNumValues() { * produce an overhead in cases where the count is calculated, but the overhead will be limited to number of distinct * tuples in the dictionary. * - * The returned counts always contains the number of zero tuples as well if there are some contained, even if they - * are not materialized. * * @return The count of each value in the MatrixBlock. */ @@ -212,6 +210,7 @@ public void clear() { counts = null; } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java index 8f2f0b46055..d114f029df8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ADictBasedColGroup.java @@ -402,4 +402,5 @@ protected IDictionary combineDictionaries(int nCol, List right) { public double getSparsity() { return _dict.getSparsity(); } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 3de98a1c23f..30de5e120c5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -203,6 +203,22 @@ private final void leftMultByMatrixNoPreAggRowsDense(MatrixBlock mb, double[] re */ protected abstract void multiplyScalar(double v, double[] resV, int offRet, AIterator it); + public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC, AIterator it) { + if(_dict instanceof MatrixBlockDictionary) { + final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; + final MatrixBlock mb = md.getMatrixBlock(); + // The dictionary is never empty. + if(mb.isInSparseFormat()) + // TODO make sparse decompression where the iterator is known in argument + decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock()); + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, mb.getDenseBlockValues(), + it); + } + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(sb, rl, ru, offR, offC, _dict.getValues(), it); + } + public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC, AIterator it) { if(_dict instanceof MatrixBlockDictionary) { final MatrixBlockDictionary md = (MatrixBlockDictionary) _dict; @@ -223,6 +239,9 @@ public void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, _dict.getValues(), it); } + public abstract void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock db, int rl, int ru, + int offR, int offC, double[] values, AIterator it); + public abstract void decompressToDenseBlockDenseDictionaryWithProvidedIterator(DenseBlock db, int rl, int ru, int offR, int offC, double[] values, AIterator it); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 21c6a0e1d80..cd0f52ebb8d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -46,6 +46,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -527,7 +528,7 @@ public CM_COV_Object centralMoment(CMOperator op, int nRows) { @Override public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); - if(d == null){ + if(d == null) { if(max <= 0) return null; return ColGroupEmpty.create(max); @@ -758,4 +759,19 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) protected boolean allowShallowIdentityRightMult() { return true; } + + @Override + public AColGroup sort() { + return this; + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return this; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupConst.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols())); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index fc82c58e16b..40b5d41e3ea 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -26,8 +26,6 @@ import java.util.List; import java.util.concurrent.ExecutorService; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; @@ -52,6 +50,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -67,6 +66,9 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; import org.jboss.netty.handler.codec.compression.CompressionException; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). */ @@ -668,7 +670,8 @@ private void defaultRightDecompressingMult(MatrixBlock right, MatrixBlock ret, i } } - final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, DoubleVector vVec) { + final void vectMM(double aa, double[] b, double[] c, int endT, int jd, int crl, int cru, int offOut, int k, int vLen, + DoubleVector vVec) { vVec = vVec.broadcast(aa); final int offj = k * jd; final int end = endT + offj; @@ -1091,6 +1094,41 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E return res; } + @Override + public AColGroup sort() { + // TODO restore support for run length encoding to exploit the runs + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDC.create(_colIndexes, _dict, m, counts); + + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDC.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null); + } + + @Override + protected boolean allowShallowIdentityRightMult() { + return true; + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupDDC.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -1100,9 +1138,4 @@ public String toString() { return sb.toString(); } - @Override - protected boolean allowShallowIdentityRightMult() { - return true; - } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 70191a27936..3f30dc8f0aa 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -546,6 +547,40 @@ protected boolean allowShallowIdentityRightMult() { return false; } + @Override + public AColGroup sort() { + // TODO restore support for run length encoding. + + int[] counts = getCounts(); + // get the sort index + int[] r = _dict.sort(); + + AMapToData m = MapToFactory.create(_data.size(), counts.length); + int off = 0; + for(int i = 0; i < counts.length; i++) { + for(int j = 0; j < counts[r[i]]; j++) { + m.set(off++, r[i]); + } + } + + return ColGroupDDCFOR.create(_colIndexes, _dict, m, counts, _reference); + + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + return ColGroupDDCFOR.create(_colIndexes, _dict, _data.removeEmpty(selectV, rOut), null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupDDCFOR.create(newColumnIDs, _dict.sliceColumns(selectedColumns, getNumCols()), _data, null, ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index ba547a8d7aa..aa4d8428dd1 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -476,4 +477,20 @@ public AColGroup combineWithSameIndex(int nRow, int nCol, List right) return new ColGroupEmpty(combinedIndex); } + + @Override + public AColGroup sort(){ + return this; + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut){ + return this; + } + + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + return new ColGroupEmpty(newColumnIDs); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java index 91442281317..1091ae36890 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java @@ -94,9 +94,7 @@ public static long getExactSizeOnDisk(List colGroups) { } ret += grp.getExactSizeOnDisk(); } - if(LOG.isWarnEnabled()) - LOG.warn(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); - + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java index 45b4fbeb026..ed4a8d03030 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java @@ -32,6 +32,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -740,4 +741,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index ea6d0f34c2a..dded0e9f520 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -26,15 +26,16 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.bitmap.ABitmap; -import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; +import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -731,5 +732,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 2b4b23792e3..560af40bcf4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.RLEScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -1190,4 +1191,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + throw new NotImplementedException(); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index 1270823bfdc..8e4d23baaa8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -42,6 +42,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -508,10 +509,10 @@ protected static AColGroup rexpandCols(int max, boolean ignore, boolean cast, in AOffset indexes, AMapToData data, int[] counts, int def, int nVal) { if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -873,6 +874,69 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDC.create(_colIndexes, _numRows, _dict, _defaultTuple, o, m, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDC.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDC.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 41fb7ac5709..15661e86ad0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -39,6 +39,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -620,6 +621,68 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCFOR.create(_colIndexes, _numRows, _dict, o, m, counts, _reference); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCFOR.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null, _reference); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _reference[selectedColumns.get(i)]; + } + return ColGroupSDCFOR.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), _indexes, _data, null, + ref); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index fa5772c0c3e..ed0cafd07b4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; @@ -469,10 +470,10 @@ public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { IDictionary d = _dict.rexpandCols(max, ignore, cast, _colIndexes.size()); final int def = (int) _defaultTuple[0]; if(d == null) { - if(def <= 0){ + if(def <= 0) { if(max > 0) return ColGroupEmpty.create(max); - else + else return null; } else if(def > max && max > 0) @@ -718,6 +719,66 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + final double def = _defaultTuple[0]; + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= def) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingle.create(_colIndexes, _numRows, _dict, _defaultTuple, o, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingle.create(_colIndexes, rOut, _dict, _defaultTuple, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] ref = new double[selectedColumns.size()]; + for(int i = 0; i < selectedColumns.size(); i++) { + ref[i] = _defaultTuple[selectedColumns.get(i)]; + } + return ColGroupSDCSingle.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), ref, + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index 9efd0c41098..79db48492bf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -40,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; @@ -109,10 +110,8 @@ protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int return; else if(it.value() >= ru) return; - // _indexes.cacheIterator(it, ru); else { decompressToDenseBlockDenseDictionaryWithProvidedIterator(db, rl, ru, offR, offC, values, it); - // _indexes.cacheIterator(it, ru); } } @@ -238,7 +237,8 @@ protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); + return; + // _indexes.cacheIterator(it, ru); else if(ru > last) { final int apos = sb.pos(0); final int alen = sb.size(0) + apos; @@ -277,8 +277,15 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int nCol = _colIndexes.size(); final int lastOff = _indexes.getOffsetToLast(); int row = offR + it.value(); @@ -963,7 +970,7 @@ protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock re protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { throw new NotImplementedException(); } - + protected void decompressToDenseBlockTransposedSparseDictionary(DenseBlock db, int rl, int ru, SparseBlock sb) { throw new NotImplementedException(); } @@ -1043,6 +1050,62 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + int nondefault = 0; + for(int i = 0; i < r.length; i++) { + if(defIdx == counts.length && _dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + } + nondefault += counts[i]; + } + + int defaultLength = _numRows - nondefault; + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCSingleZeros.create(_colIndexes, _numRows, _dict, o, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO optimize by not constructing boolean array. + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + return ColGroupSDCSingleZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + + return ColGroupSDCSingleZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 69e0f776383..f3c82684f99 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -45,6 +45,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -184,8 +185,7 @@ private final void decompressToDenseBlockDenseDictionaryPostAllCols(DenseBlock d final double[] c = db.values(idx); final int off = db.pos(idx); final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); if(it.value() == lastOff) return; it.next(); @@ -301,13 +301,19 @@ private void decompressToDenseBlockDenseDictionaryPreAllCols(DenseBlock db, int final double[] c = db.values(idx); final int off = db.pos(idx) + offC; final int offDict = _data.getIndex(it.getDataIndex()) * nCol; - for(int j = 0; j < nCol; j++) - c[off + j] += values[offDict + j]; + decompressSingleRow(values, nCol, c, off, offDict); it.next(); } } + private static void decompressSingleRow(double[] values, final int nCol, final double[] c, final int off, + final int offDict) { + final int end = nCol + off; + for(int j = off, k = offDict; j < end; j++, k++) + c[j] += values[k]; + } + @Override protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) { @@ -438,8 +444,16 @@ protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, i if(it == null) return; else if(it.value() >= ru) - _indexes.cacheIterator(it, ru); - else if(ru > _indexes.getOffsetToLast()) { + return; + else + decompressToSparseBlockDenseDictionaryWithProvidedIterator(ret, rl, ru, offR, offC, values, it); + + } + + @Override + public void decompressToSparseBlockDenseDictionaryWithProvidedIterator(SparseBlock ret, int rl, int ru, int offR, + int offC, double[] values, final AIterator it) { + if(ru > _indexes.getOffsetToLast()) { final int lastOff = _indexes.getOffsetToLast(); final int nCol = _colIndexes.size(); while(true) { @@ -467,7 +481,6 @@ else if(ru > _indexes.getOffsetToLast()) { } _indexes.cacheIterator(it, ru); } - } @Override @@ -899,7 +912,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return super.morph(ct, nRow); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { final SparseBlock sr = ret.getSparseBlock(); @@ -942,14 +954,14 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret of = it.next(); } else if(points[c].o < of) - c++; + c++; else of = it.next(); - } - // increment the c pointer until it is pointing at least to last point or is done. - while(c < points.length && points[c].o < last) - c++; - c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); + } + // increment the c pointer until it is pointing at least to last point or is done. + while(c < points.length && points[c].o < last) + c++; + c = processRowDense(points, dr, nCol, c, of, _data.getIndex(it.getDataIndex())); } private int processRowSparse(P[] points, final SparseBlock sr, final int nCol, int c, int of, final int did) { @@ -1078,6 +1090,64 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { return res; } + @Override + public AColGroup sort() { + if(getNumCols() > 1) + throw new NotImplementedException(); + // TODO restore support for run length encoding. + + final int[] counts = getCounts(); + // get the sort index + final int[] r = _dict.sort(); + + // find default value position. + // todo use binary search for minor improvements. + int defIdx = counts.length; + for(int i = 0; i < r.length; i++) { + if(_dict.getValue(r[i], 0, 1) >= 0) { + defIdx = i; + break; + } + } + + int nondefault = _data.size(); + int defaultLength = _numRows - nondefault; + AMapToData m = MapToFactory.create(nondefault, counts.length); + int[] offsets = new int[nondefault]; + + int off = 0; + for(int i = 0; i < counts.length; i++) { + if(i < defIdx) { + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off; + m.set(off++, r[i]); + } + } + else {// if( i >= defIdx){ + for(int j = 0; j < counts[r[i]]; j++) { + offsets[off] = off + defaultLength; + m.set(off++, r[i]); + } + } + } + + AOffset o = OffsetFactory.createOffset(offsets); + return ColGroupSDCZeros.create(_colIndexes, _numRows, _dict, o, m, counts); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + final RemoveEmptyOffsetsTmp offsetTmp = _indexes.removeEmptyRows(selectV, rOut); + final AMapToData nm = _data.removeEmpty(offsetTmp.select); + return ColGroupSDCZeros.create(_colIndexes, rOut, _dict, offsetTmp.retOffset, nm, null); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + return ColGroupSDCZeros.create(newColumnIDs, _numRows, _dict.sliceColumns(selectedColumns, getNumCols()), + _indexes, _data, null); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 1c3bce2e16c..8efffc4878c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.EstimationFactors; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; @@ -82,7 +83,8 @@ public class ColGroupUncompressed extends AColGroup { /** * Do not use this constructor of column group uncompressed, instead use the create constructor. - * @param mb The contained data. + * + * @param mb The contained data. * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { @@ -92,14 +94,15 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { /** * Do not use this constructor of column group quantization-fused uncompressed, instead use the create constructor. - * @param mb The contained data. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix - * @param colIndexes Column indexes for this Columngroup + * + * @param mb The contained data. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Column indexes for this Columngroup */ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { super(colIndexes); - // Apply scaling and flooring - // TODO: Use internal matrix prod + // Apply scaling and flooring + // TODO: Use internal matrix prod for(int r = 0; r < mb.getNumRows(); r++) { double scaleFactor = scaleFactors.length == 1 ? scaleFactors[0] : scaleFactors[r]; for(int c = 0; c < mb.getNumColumns(); c++) { @@ -108,7 +111,8 @@ protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes, double[] sc } } _data = mb; - } + } + /** * Create an Uncompressed Matrix Block, where the columns are offset by col indexes. * @@ -130,9 +134,9 @@ public static AColGroup create(MatrixBlock mb, IColIndex colIndexes) { * * It is assumed that the size of the colIndexes and number of columns in mb is matching. * - * @param mb The MB / data to contain in the uncompressed column - * @param colIndexes The column indexes for the group - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param mb The MB / data to contain in the uncompressed column + * @param colIndexes The column indexes for the group + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return An Uncompressed Column group */ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, double[] scaleFactors) { @@ -147,14 +151,15 @@ public static AColGroup createQuantized(MatrixBlock mb, IColIndex colIndexes, do /** * Main constructor for a quantization-fused uncompressed ColGroup. * - * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. - * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is - * called - * @param transposed Says if the input matrix raw block have been transposed. - * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix + * @param colIndexes Indices (relative to the current block) of the columns that this column group represents. + * @param rawBlock The uncompressed block; uncompressed data must be present at the time that the constructor is + * called + * @param transposed Says if the input matrix raw block have been transposed. + * @param scaleFactors For quantization-fused compression, scale factors per row, or a single value for entire matrix * @return AColGroup. */ - public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, double[] scaleFactors) { + public static AColGroup createQuantized(IColIndex colIndexes, MatrixBlock rawBlock, boolean transposed, + double[] scaleFactors) { // special cases if(rawBlock.isEmptyBlock(false)) // empty input @@ -187,22 +192,24 @@ else if(!transposed && colIndexes.size() == rawBlock.getNumColumns()) final int n = colIndexes.size(); if(transposed) { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[j])); } } else { - if (scaleFactors.length == 1) { + if(scaleFactors.length == 1) { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[0])); - } else { + } + else { for(int i = 0; i < m; i++) for(int j = 0; j < n; j++) mb.appendValue(i, j, Math.floor(rawBlock.get(i, colIndexes.get(j)) * scaleFactors[i])); @@ -1075,7 +1082,6 @@ public AColGroup morph(CompressionType ct, int nRow) { return comp.get(0).copyAndSet(_colIndexes); } - @Override public void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1092,7 +1098,6 @@ protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret denseSelectionDenseColumnGroup(selection, ret, rl, ru); } - private void sparseSelectionSparseColumnGroup(MatrixBlock selection, MatrixBlock ret, int rl, int ru) { final SparseBlock sb = selection.getSparseBlock(); @@ -1192,7 +1197,7 @@ public AColGroup reduceCols() { else return new ColGroupUncompressed(mb, ColIndexFactory.createI(0)); } - + @Override public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { if(_data.isInSparseFormat()) @@ -1289,11 +1294,30 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { for(int i = 0; i < multiplier; i++) for(int j = 0; j < s; j++) newColumns[i * s + j] = _colIndexes.get(j) + nColOrg * i; - MatrixBlock newData = _data.reshape(nRow/ multiplier, s * multiplier, true); - return new AColGroup[]{create(newData,ColIndexFactory.create(newColumns))}; + MatrixBlock newData = _data.reshape(nRow / multiplier, s * multiplier, true); + return new AColGroup[] {create(newData, ColIndexFactory.create(newColumns))}; // throw new NotImplementedException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort() { + return new ColGroupUncompressed(_data.sortOperations(), _colIndexes); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + MatrixBlock tmp = new MatrixBlock(); + tmp = LibMatrixReorg.removeEmptyRows(_data, tmp, false, false, selectV, rOut); + return ColGroupUncompressed.create(_colIndexes, tmp, false); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + double[] vals = MatrixBlockDictionary.sliceColumns(_data, selectedColumns); + MatrixBlock ret = new MatrixBlock(_data.getNumRows(), selectedColumns.size(), vals); + return ColGroupUncompressed.create(newColumnIDs, ret, false); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 31e29341645..a0a9dd46306 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -19,11 +19,13 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -282,4 +284,18 @@ public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { throw new UnsupportedOperationException("Unimplemented method 'splitReshape'"); } + @Override + public AColGroup sort(){ + throw new NotImplementedException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + throw new NotImplementedException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns){ + throw new NotImplementedException("Unimplemented method 'removeEmptyColumns'"); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java index 17b382f06ad..a7e715b59b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/AIdentityDictionary.java @@ -19,6 +19,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; public abstract class AIdentityDictionary extends ACachingMBDictionary { @@ -74,4 +75,9 @@ public double[] productAllRowsToDoubleWithDefault(double[] defaultTuple) { ret[ret.length - 1] *= defaultTuple[i]; return ret; } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java index d67ab95f824..7bb6132deb4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DeltaDictionary.java @@ -23,6 +23,7 @@ import java.io.IOException; import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -121,4 +122,14 @@ public boolean equals(IDictionary o) { public IDictionary clone() { throw new NotImplementedException(); } + + @Override + public int[] sort(){ + throw new NotImplementedException(); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java index 939b48bf424..bd7dc98fe55 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java @@ -28,10 +28,13 @@ import java.util.HashSet; import java.util.Set; +import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -1341,4 +1344,73 @@ public IDictionary append(double[] row) { return new Dictionary(retV); } + @Override + public int[] sort() { + return sort(_values); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + // TODO: make specialized version for this. + return getMBDict(nCol).sliceColumns(selectedColumns, nCol); + } + + protected static int[] sort(double[] values) { + int[] indices = new int[values.length]; + for(int i = 0; i < indices.length; i++) { + indices[i] = i; + } + + // quicksort with stack + int[] stack = new int[values.length]; + + int top = -1; + stack[++top] = 0; + stack[++top] = values.length - 1; + + while(top >= 0) { + int high = stack[top--]; + int low = stack[top--]; + + if(low < high) { + + int pivotIndex = partition(indices, values, low, high); + // Left side + if(pivotIndex - 1 > low) { + stack[++top] = low; + stack[++top] = pivotIndex - 1; + } + + // Right side + if(pivotIndex + 1 < high) { + stack[++top] = pivotIndex + 1; + stack[++top] = high; + } + } + } + + return indices; + } + + private static int partition(int[] indices, double[] values, int low, int high) { + double pivotValue = values[indices[high]]; + int i = low - 1; + + for(int j = low; j < high; j++) { + if(values[indices[j]] <= pivotValue) { + i++; + swap(indices, i, j); + } + } + + swap(indices, i + 1, high); + return i + 1; + } + + private static void swap(int[] arr, int i, int j) { + int tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java index dddea0eec7a..28ec8ebf207 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java @@ -25,6 +25,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.functionobjects.Builtin; @@ -1051,4 +1052,23 @@ public IDictionary rightMMPreAggSparse(int numVals, SparseBlock b, IColIndex thi * @return The nonzero count of each column in the dictionary. */ public int[] countNNZZeroColumns(int[] counts); + + /** + * Sort the values of this dictionary via an index of how the values mapped previously. + * + * In practice this design means we can reuse the previous dictionary for the resulting column group + * + * @return The sorted index. + */ + public int[] sort(); + + /** + * Slice out the selected columns given of this encoded group. + * + * @param selectedColumns The columns to slice out and return as a new matrix. + * @param nCol The number of columns in this dictionary. + * @return The returned matrix + */ + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol); + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java index 40e1b065653..c2540de959a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java @@ -27,6 +27,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockFactory; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -540,9 +541,13 @@ public String getString(int colIndexes) { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrix of size: " + nRowCol + " with empty: " + withEmpty; } - } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java index df702524d55..c7f642edfd0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java @@ -27,6 +27,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -310,6 +311,11 @@ public String getString(int colIndexes) { return toString(); } + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol){ + return getMBDict().sliceColumns(selectedColumns, nCol); + } + @Override public String toString() { return "IdentityMatrixSlice of size: " + nRowCol + " l " + l + " u " + u; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java index 24776f3adc4..ef0fc4aacd5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java @@ -27,8 +27,6 @@ import java.util.Arrays; import java.util.Set; -import jdk.incubator.vector.DoubleVector; -import jdk.incubator.vector.VectorSpecies; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.indexes.ArrayIndex; @@ -36,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.RangeIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.SingleIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.TwoIndex; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; @@ -61,6 +60,9 @@ import org.apache.sysds.runtime.matrix.operators.ScalarOperator; import org.apache.sysds.runtime.matrix.operators.UnaryOperator; +import jdk.incubator.vector.DoubleVector; +import jdk.incubator.vector.VectorSpecies; + public class MatrixBlockDictionary extends ADictionary { private static final long serialVersionUID = 2535887782150955098L; @@ -2801,4 +2803,41 @@ private void SparseAdd(int sPos, int sEnd, double[] ret, int offOut, int[] sIdx, } } + @Override + public int[] sort() { + if(_data.getNumColumns() > 1) + throw new RuntimeException("Not supported sort on multicolumn dictionaries"); + _data.sparseToDense(); + + return Dictionary.sort(_data.getDenseBlockValues()); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + + final double[] ret = sliceColumns(_data, selectedColumns); + + return new Dictionary(ret); + } + + public static double[] sliceColumns(MatrixBlock mb, IntArrayList selectedColumns) { + //TODO: Optimize to allow sparse outputs. and change output type to MatrixBlock. + final int outC = selectedColumns.size(); + if((long) mb.getNumRows() * outC > (long) Integer.MAX_VALUE) + throw new NotImplementedException("Not supported large output blocks for slicing dictionary columns"); + mb.sparseToDense(); + final DenseBlock db = mb.getDenseBlock(); + final double[] ret = new double[mb.getNumRows() * outC]; + + for(int i = 0; i < mb.getNumRows(); i++) { + double[] vals = db.values(i); + int offIn = db.pos(i); + int offOut = i * outC; + for(int j = 0; j < outC; j++) { + ret[offOut + j] = vals[offIn + selectedColumns.get(j)]; + } + } + return ret; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java index f5746647a37..ec3ead4a68b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java @@ -23,6 +23,7 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.io.IOUtilFunctions; public class PlaceHolderDict extends ADictionary { @@ -101,4 +102,14 @@ public DictType getDictType() { throw new RuntimeException("invalid to get dictionary type for PlaceHolderDict"); } + @Override + public int[] sort() { + throw new RuntimeException("Invalid call"); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + throw new RuntimeException("Invalid call"); + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java index 6802d920b49..9e2fa4bf1d7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java @@ -23,6 +23,8 @@ import java.io.DataOutput; import java.io.IOException; +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.functionobjects.Builtin; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.MemoryEstimates; @@ -277,4 +279,13 @@ public MatrixBlockDictionary createMBDict(int nCol) { return new MatrixBlockDictionary(mb); } + @Override + public int[] sort() { + throw new NotImplementedException(); + } + + @Override + public IDictionary sliceColumns(IntArrayList selectedColumns, int nCol) { + return getMBDict().sliceColumns(selectedColumns, nCol); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index 5fc2acaea7a..79cc219f2e2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -30,6 +30,7 @@ import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup; @@ -39,6 +40,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffsetIterator; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -1041,4 +1043,40 @@ public String toString() { sb.append("]"); return sb.toString(); } + + public AMapToData removeEmpty(final boolean[] selectV, final int rOut) { + try{ + + final AMapToData ret = MapToFactory.create(rOut, getUnique()); + final int s = size(); + int t = 0; + for(int i = 0; i < s; i++) + if(selectV[i] == true) + ret.set(t++, getIndex(i)); + + return ret; + } + catch(ArrayIndexOutOfBoundsException e){ + + int trueCount = 0; + for(boolean a : selectV){ + if(a) trueCount ++; + } + throw new DMLRuntimeException("actual number of true values " + trueCount + " vs argument " + rOut,e); + } + } + + /** + * Use the offsets of the select vector to choose which values to keep. + * + * @param select The row indexes to keep + * @return A New MapToData + */ + public AMapToData removeEmpty(IntArrayList select) { + final int s = select.size(); + final AMapToData ret = MapToFactory.create(s, getUnique()); + for(int i = 0; i < s; i++) + ret.set(i, getIndex(select.get(i))); + return ret; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java index 45c78dd3abd..a809afccd3d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AIterator.java @@ -71,8 +71,8 @@ public boolean isNotOver(int ub) { /** * Get the current data index associated with the index returned from value. * - * This index points to a position int the mapToData object, that then inturn can be used to lookup the dictionary - * entry in ADictionary. + * This index points to a position in the AMapToData object, that can be used to lookup the dictionary entry in + * ADictionary. * * @return The Data Index. */ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java index 8930074eb0e..bae6ae57cad 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/AOffset.java @@ -569,7 +569,7 @@ public OffsetSliceInfo slice(int l, int u) { else return new OffsetSliceInfo(0, s, moveIndex(l)); } - else if (u < first) + else if(u < first) return EMPTY_SLICE; final AIterator it = getIteratorSkipCache(l); @@ -764,6 +764,43 @@ public AOffset reverse(int numRows) { return OffsetFactory.createOffset(newOff); } + public RemoveEmptyOffsetsTmp removeEmptyRows(boolean[] selectV, int rOut) { + IntArrayList newOff = new IntArrayList(); + IntArrayList selectMTmp = new IntArrayList(); + + final AIterator it = getIterator(); + final int last = getOffsetToLast(); + int t = 0; + int o = 0; + while(it.value() < last) { + while(t < it.value()) { + if(selectV[t]) + o++; + t++; + } + if(selectV[it.value()]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + o++; + t++; + } + it.next(); + } + while(t < last) { + if(selectV[t]) + o++; + t++; + } + if(selectV[last]) { + newOff.appendValue(o); + selectMTmp.appendValue(it.getDataIndex()); + } + + // throw new RuntimeException("\n\n\n" + Arrays.toString(selectV) + " \n\n " + this + "\n\n " + newOff + " \n " + + // selectMTmp + "\n\n " + "\n\n "); + return new RemoveEmptyOffsetsTmp(OffsetFactory.createOffset(newOff), selectMTmp); + } + /** * Offset slice info containing the start and end index an offset that contains the slice, and an new AOffset * containing only the sliced elements @@ -793,6 +830,16 @@ public String toString() { } + public static final class RemoveEmptyOffsetsTmp { + public final AOffset retOffset; + public final IntArrayList select; + + protected RemoveEmptyOffsetsTmp(AOffset retOffset, IntArrayList select) { + this.retOffset = retOffset; + this.select = select; + } + } + private static class OffsetCache { private final AIterator it; private final int row; @@ -824,4 +871,5 @@ public String toString() { return "r" + row + " d " + dataIndex + " o " + offIndex + "\n"; } } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java index 73264c84767..5c410026587 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetEmpty.java @@ -76,6 +76,10 @@ public int getOffsetToLast() { public long getInMemorySize() { return estimateInMemorySize(); } + @Override + public boolean equals(AOffset b) { + return b instanceof OffsetEmpty; + } public static long estimateInMemorySize() { return 16; // object header diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java index ce52bcd23fd..26779215306 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinaryCellOp.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; +import org.apache.sysds.runtime.compress.utils.HashMapIntToInt; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.DenseBlockFP64; import org.apache.sysds.runtime.data.SparseBlock; @@ -55,7 +56,6 @@ import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowScalar; import org.apache.sysds.runtime.data.SparseRowVector; -import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.functionobjects.Divide; import org.apache.sysds.runtime.functionobjects.Minus; import org.apache.sysds.runtime.functionobjects.Multiply; @@ -77,7 +77,7 @@ public final class CLALibBinaryCellOp { private static final Log LOG = LogFactory.getLog(CLALibBinaryCellOp.class.getName()); - public static final int DECOMPRESSION_BLEN = 16384; + public static final int DECOMPRESSION_BLEN = 16384 / 2; private CLALibBinaryCellOp() { // empty private constructor. @@ -86,7 +86,7 @@ private CLALibBinaryCellOp() { public static MatrixBlock binaryOperationsRight(BinaryOperator op, CompressedMatrixBlock m1, MatrixBlock that) { try { - op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); + op = LibMatrixBincell.replaceOpWithSparseSafeIfApplicable(m1, that, op); if((that.getNumRows() == 1 && that.getNumColumns() == 1) || that.isEmpty()) { ScalarOperator sop = new RightScalarOperator(op.fn, that.get(0, 0), op.getNumThreads()); @@ -113,7 +113,7 @@ public static MatrixBlock binaryOperationsLeft(BinaryOperator op, CompressedMatr return selectProcessingBasedOnAccessType(op, m1, that, atype, true); } catch(Exception e) { - throw new DMLRuntimeException("Failed Left Binary Compressed Operation", e); + throw new DMLRuntimeException("Failed Left Binary Compressed Operation: " + op, e); } } @@ -122,8 +122,8 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp BinaryAccessType atype = LibMatrixBincell.getBinaryAccessTypeExtended(m1, that); if(isDoubleCompressedOpApplicable(m1, that)) return doubleCompressedBinaryOp(op, m1, (CompressedMatrixBlock) that); - if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() - && that.getInMemorySize() < m1.getInMemorySize() ) { + if(that instanceof CompressedMatrixBlock && that.getNumColumns() == m1.getNumColumns() && + that.getInMemorySize() < m1.getInMemorySize()) { MatrixBlock m1uc = CompressedMatrixBlock.getUncompressed(m1, "Decompressing left side in BinaryOps"); return selectProcessingBasedOnAccessType(op, (CompressedMatrixBlock) that, m1uc, atype, true); } @@ -135,16 +135,15 @@ private static MatrixBlock binaryOperationsRightFiltered(BinaryOperator op, Comp } private static boolean isDoubleCompressedOpApplicable(CompressedMatrixBlock m1, MatrixBlock that) { - return that instanceof CompressedMatrixBlock - && !m1.isOverlapping() - && m1.getColGroups().get(0) instanceof ColGroupDDC - && !((CompressedMatrixBlock) that).isOverlapping() - && ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC - && ((IMapToDataGroup) m1.getColGroups().get(0)).getMapToData() == - ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); + return that instanceof CompressedMatrixBlock && !m1.isOverlapping() && + m1.getColGroups().get(0) instanceof ColGroupDDC && !((CompressedMatrixBlock) that).isOverlapping() && + ((CompressedMatrixBlock) that).getColGroups().get(0) instanceof ColGroupDDC && + ((IMapToDataGroup) m1.getColGroups().get(0)) + .getMapToData() == ((IMapToDataGroup) ((CompressedMatrixBlock) that).getColGroups().get(0)).getMapToData(); } - private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, CompressedMatrixBlock m2) { + private static CompressedMatrixBlock doubleCompressedBinaryOp(BinaryOperator op, CompressedMatrixBlock m1, + CompressedMatrixBlock m2) { LOG.debug("Double Compressed BinaryOp"); AColGroup left = m1.getColGroups().get(0); AColGroup right = m2.getColGroups().get(0); @@ -201,6 +200,7 @@ private static MatrixBlock mvCol(BinaryOperator op, CompressedMatrixBlock m1, Ma // Column vector access MatrixBlock d_compressed = m1.getCachedDecompressed(); if(d_compressed != null) { + LOG.debug("Using cached decompressed for Matrix column vector compressed operation"); if(left) throw new NotImplementedException("Binary row op left is not supported for Uncompressed Matrix, " + "Implement support for VMr in MatrixBlock Binary Cell operations"); @@ -416,17 +416,24 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock Pair tuple = evaluateSparsityMVCol(m1, m2, op, left); double estSparsity = tuple.getKey(); double estNnzPerRow = tuple.getValue(); - boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, (long) (estSparsity * nRows * nCols)); + boolean shouldBeSparseOut = MatrixBlock.evalSparseFormatInMemory(nRows, nCols, + (long) (estSparsity * nRows * nCols)); // currently also jump into that case if estNnzPerRow == 0 - if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction){ - return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, left) : - binaryMVComparisonColMultiCompressed(m1, m2, op, left); + if(estNnzPerRow <= 2 && nCols <= 31 && op.fn instanceof ValueComparisonFunction) { + return k <= 1 ? binaryMVComparisonColSingleThreadCompressed(m1, m2, op, + left) : binaryMVComparisonColMultiCompressed(m1, m2, op, left); } MatrixBlock ret = new MatrixBlock(nRows, nCols, shouldBeSparseOut, -1).allocateBlock(); if(shouldBeSparseOut) { - if(k <= 1) + if(!m1.isOverlapping() && MatrixBlock.evalSparseFormatInMemory(nRows, nCols, m1.getNonZeros())) { + if(k <= 1) + nnz = binaryMVColSingleThreadSparseSparse(m1, m2, op, left, ret); + else + nnz = binaryMVColMultiThreadSparseSparse(m1, m2, op, left, ret); + } + else if(k <= 1) nnz = binaryMVColSingleThreadSparse(m1, m2, op, left, ret); else nnz = binaryMVColMultiThreadSparse(m1, m2, op, left, ret); @@ -438,7 +445,7 @@ private static MatrixBlock mvColCompressed(CompressedMatrixBlock m1, MatrixBlock nnz = binaryMVColMultiThreadDense(m1, m2, op, left, ret); } - if(op.fn instanceof ValueComparisonFunction) { + if(op.fn instanceof ValueComparisonFunction) { // potentially empty or filled. if(nnz == (long) nRows * nCols)// all was 1 return CompressedMatrixBlockFactory.createConstant(nRows, nCols, 1.0); else if(nnz == 0) // all was 0 -> return empty. @@ -452,19 +459,19 @@ else if(nnz == 0) // all was 0 -> return empty. } private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) { + BinaryOperator op, boolean left) { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); // get indicators (one-hot-encoded comparison results) - BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); + BinaryMVColTaskCompressed task = new BinaryMVColTaskCompressed(m1, m2, 0, nRows, op, left); long nnz = task.call(); int[] indicators = task._ret; // map each unique indicator to an index - HashMapToInt hm = new HashMapToInt<>(nCols*3); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 3); int[] colMap = new int[nRows]; - for(int i = 0; i < m1.getNumRows(); i++){ + for(int i = 0; i < m1.getNumRows(); i++) { int nextId = hm.size(); int id = hm.putIfAbsentI(indicators[i], nextId); colMap[i] = id == -1 ? nextId : id; @@ -477,37 +484,39 @@ private static MatrixBlock binaryMVComparisonColSingleThreadCompressed(Compresse return getCompressedMatrixBlock(m1, colMap, hm.size(), outMb, nRows, nCols, nnz); } - private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, SparseBlockMCSR out) { + private static void fillSparseBlockFromIndicatorFromIndicatorInt(int numCol, Integer indicator, Integer rix, + SparseBlockMCSR out) { ArrayList colIndices = new ArrayList<>(8); - for (int c = numCol - 1; c >= 0; c--) { + for(int c = numCol - 1; c >= 0; c--) { if(indicator <= 0) break; - if(indicator % 2 == 1){ + if(indicator % 2 == 1) { colIndices.add(c); } indicator = indicator >> 1; } SparseRow row = null; - if(colIndices.size() > 1){ + if(colIndices.size() > 1) { double[] vals = new double[colIndices.size()]; Arrays.fill(vals, 1); int[] indices = new int[colIndices.size()]; - for (int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) + for(int i = 0, j = colIndices.size() - 1; i < colIndices.size(); i++, j--) indices[i] = colIndices.get(j); row = new SparseRowVector(vals, indices); - } else if(colIndices.size() == 1){ + } + else if(colIndices.size() == 1) { row = new SparseRowScalar(colIndices.get(0), 1.0); } out.set(rix, row, false); } private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrixBlock m1, MatrixBlock m2, - BinaryOperator op, boolean left) throws Exception { + BinaryOperator op, boolean left) throws Exception { final int nRows = m1.getNumRows(); final int nCols = m1.getNumColumns(); final int k = op.getNumThreads(); - final int blkz = nRows / k; + final int blkz = Math.max((nRows + k) / k, 1000); // get indicators (one-hot-encoded comparison results) long nnz = 0; @@ -518,14 +527,11 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix tasks.add(new BinaryMVColTaskCompressed(m1, m2, i, Math.min(nRows, i + blkz), op, left)); } List> futures = pool.invokeAll(tasks); - HashMapToInt hm = new HashMapToInt<>(nCols*2); + HashMapIntToInt hm = new HashMapIntToInt(nCols * 2); int[] colMap = new int[nRows]; - for(Future f : futures) - nnz += f.get(); - // map each unique indicator to an index - mergeMVColTaskResults(tasks, blkz, hm, colMap); + nnz = mergeMVColTaskResults(futures, tasks, blkz, hm, colMap); // decode the unique indicator ints to SparseVectors MatrixBlock outMb = getMCSRMatrixBlock(hm, nCols); @@ -539,48 +545,53 @@ private static MatrixBlock binaryMVComparisonColMultiCompressed(CompressedMatrix } - private static void mergeMVColTaskResults(ArrayList tasks, int blkz, HashMapToInt hm, int[] colMap) { - + private static long mergeMVColTaskResults(List> futures, ArrayList tasks, + int blkz, HashMapIntToInt hm, int[] colMap) throws InterruptedException, ExecutionException { + long nnz = 0; for(int j = 0; j < tasks.size(); j++) { + nnz += futures.get(j).get(); // ensure task was finished. int[] indicators = tasks.get(j)._ret; - int offset = j* blkz; - - final int remainders = indicators.length % 8; - final int endVecLen = indicators.length - remainders; - for (int i = 0; i < endVecLen; i+= 8) { - colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); - colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); - colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); - colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); - colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); - colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); - colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); - colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); + int offset = j * blkz; - } - for (int i = 0; i < remainders; i++) { - colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); - } + mergeMVColUnrolled(hm, colMap, indicators, offset); } + return nnz; } + private static void mergeMVColUnrolled(HashMapIntToInt hm, int[] colMap, int[] indicators, int offset) { + final int remainders = indicators.length % 8; + final int endVecLen = indicators.length - remainders; + for(int i = 0; i < endVecLen; i += 8) { + colMap[offset + i] = hm.putIfAbsentReturnVal(indicators[i], hm.size()); + colMap[offset + i + 1] = hm.putIfAbsentReturnVal(indicators[i + 1], hm.size()); + colMap[offset + i + 2] = hm.putIfAbsentReturnVal(indicators[i + 2], hm.size()); + colMap[offset + i + 3] = hm.putIfAbsentReturnVal(indicators[i + 3], hm.size()); + colMap[offset + i + 4] = hm.putIfAbsentReturnVal(indicators[i + 4], hm.size()); + colMap[offset + i + 5] = hm.putIfAbsentReturnVal(indicators[i + 5], hm.size()); + colMap[offset + i + 6] = hm.putIfAbsentReturnVal(indicators[i + 6], hm.size()); + colMap[offset + i + 7] = hm.putIfAbsentReturnVal(indicators[i + 7], hm.size()); - private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, - int mapSize, MatrixBlock outMb, int nRows, int nCols, long nnz) { + } + for(int i = 0; i < remainders; i++) { + colMap[offset + endVecLen + i] = hm.putIfAbsentReturnVal(indicators[endVecLen + i], hm.size()); + } + } + + private static CompressedMatrixBlock getCompressedMatrixBlock(CompressedMatrixBlock m1, int[] colMap, int mapSize, + MatrixBlock outMb, int nRows, int nCols, long nnz) { final IColIndex i = ColIndexFactory.create(0, m1.getNumColumns()); final AMapToData map = MapToFactory.create(m1.getNumRows(), colMap, mapSize); final AColGroup rgroup = ColGroupDDC.create(i, MatrixBlockDictionary.create(outMb), map, null); final ArrayList groups = new ArrayList<>(1); groups.add(rgroup); - return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); + return new CompressedMatrixBlock(nRows, nCols, nnz, false, groups); } - private static MatrixBlock getMCSRMatrixBlock(HashMapToInt hm, int nCols) { + private static MatrixBlock getMCSRMatrixBlock(HashMapIntToInt hm, int nCols) { // decode the unique indicator ints to SparseVectors SparseBlockMCSR out = new SparseBlockMCSR(hm.size()); - hm.forEach((indicator, rix) -> - fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); - return new MatrixBlock(hm.size(), nCols, -1, out); + hm.forEach((indicator, rix) -> fillSparseBlockFromIndicatorFromIndicatorInt(nCols, indicator, rix, out)); + return new MatrixBlock(hm.size(), nCols, -1, out); } private static long binaryMVColSingleThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, @@ -599,6 +610,14 @@ private static long binaryMVColSingleThreadSparse(CompressedMatrixBlock m1, Matr return nnz; } + private static long binaryMVColSingleThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) { + final int nRows = m1.getNumRows(); + long nnz = 0; + nnz += new BinaryMVColTaskSparseSparse(m1, m2, ret, 0, nRows, op, left).call(); + return nnz; + } + private static long binaryMVColMultiThreadDense(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left, MatrixBlock ret) throws Exception { final int nRows = m1.getNumRows(); @@ -641,6 +660,27 @@ private static long binaryMVColMultiThreadSparse(CompressedMatrixBlock m1, Matri return nnz; } + private static long binaryMVColMultiThreadSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, + boolean left, MatrixBlock ret) throws Exception { + final int nRows = m1.getNumRows(); + final int k = op.getNumThreads(); + final int blkz = Math.max(nRows / k, 64); + long nnz = 0; + final ExecutorService pool = CommonThreadPool.get(op.getNumThreads()); + try { + final ArrayList> tasks = new ArrayList<>(); + for(int i = 0; i < nRows; i += blkz) { + tasks.add(new BinaryMVColTaskSparseSparse(m1, m2, ret, i, Math.min(nRows, i + blkz), op, left)); + } + for(Future f : pool.invokeAll(tasks)) + nnz += f.get(); + } + finally { + pool.shutdown(); + } + return nnz; + } + private static MatrixBlock mmCompressed(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, boolean left) throws Exception { final int nCols = m1.getNumColumns(); @@ -724,8 +764,8 @@ private static class BinaryMVColTaskCompressed implements Callable { private MatrixBlock tmp; - protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, - BinaryOperator op, boolean left) { + protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, int rl, int ru, BinaryOperator op, + boolean left) { _m1 = m1; _m2 = m2; _op = op; @@ -738,21 +778,21 @@ protected BinaryMVColTaskCompressed(CompressedMatrixBlock m1, MatrixBlock m2, in @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); long nnz = 0; if(!_left) - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDense(rl, ru, retIxOff); tmp.reset(); } else - for (int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen){ + for(int rl = _rl, retIxOff = 0; rl < _ru; rl += _blklen, retIxOff += _blklen) { int ru = Math.min(rl + _blklen, _ru); decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); nnz += processDenseLeft(rl, ru, retIxOff); @@ -770,18 +810,24 @@ private final long processDense(final int rl, final int ru, final int retIxOffse for(int row = rl, retIx = retIxOffset; row < ru; row++, retIx++) { final double vr = _m2Dense[row]; final int tmpOff = (row - rl) * nCol; - int indicatorVector = 0; - for(int col = 0; col < nCol; col++) { - indicatorVector = indicatorVector << 1; - int indicator = _compFn.compare(_tmpDense[tmpOff + col], vr) ? 1 : 0; - indicatorVector += indicator; - nnz += indicator; - } - _ret[retIx] = indicatorVector; + nnz = processRow(nCol, _tmpDense, nnz, retIx, vr, tmpOff); } return nnz; } + private final long processRow(final int nCol, final double[] _tmpDense, long nnz, int retIx, final double vr, + final int tmpOff) { + int indicatorVector = 0; + for(int col = tmpOff; col < nCol + tmpOff; col++) { + indicatorVector = indicatorVector << 1; + int indicator = _compFn.compare(_tmpDense[col], vr) ? 1 : 0; + indicatorVector += indicator; + nnz += indicator; + } + _ret[retIx] = indicatorVector; + return nnz; + } + private final long processDenseLeft(final int rl, final int ru, final int retIxOffset) { final int nCol = _m1.getNumColumns(); final double[] _tmpDense = tmp.getDenseBlockValues(); @@ -847,7 +893,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { // unsafe decompress, since we count nonzeros afterwards. final DenseBlock db = _ret.getDenseBlock(); decompressToSubBlock(rl, ru, db, groups, its); @@ -887,7 +934,7 @@ private void processRow(final int ncol, final double[] ret, final int posR, fina private void processRowLeft(final int ncol, final double[] ret, final int posR, final double vr) { for(int col = 0; col < ncol; col++) - ret[posR + col] = _op.fn.execute(vr,ret[posR + col]); + ret[posR + col] = _op.fn.execute(vr, ret[posR + col]); } } @@ -917,8 +964,8 @@ protected BinaryMVColTaskSparse(CompressedMatrixBlock m1, MatrixBlock m2, Matrix @Override public Long call() { - tmp = allocateTempUncompressedBlock(_m1.getNumColumns()); - final int _blklen = tmp.getNumRows(); + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlock(_blklen, _m1.getNumColumns()); final List groups = _m1.getColGroups(); final AIterator[] its = getIterators(groups, _rl); if(!_left) @@ -936,7 +983,8 @@ private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { decompressToTmpBlock(rl, ru, tmp.getDenseBlock(), groups, its); processDenseLeft(rl, ru); tmp.reset(); @@ -971,8 +1019,107 @@ private final void processDenseLeft(final int rl, final int ru) { } } - private static MatrixBlock allocateTempUncompressedBlock(int cols) { - MatrixBlock out = new MatrixBlock(Math.max(DECOMPRESSION_BLEN / cols, 64), cols, false); + private static class BinaryMVColTaskSparseSparse implements Callable { + private final int _rl; + private final int _ru; + private final CompressedMatrixBlock _m1; + private final MatrixBlock _m2; + private final MatrixBlock _ret; + private final BinaryOperator _op; + + private MatrixBlock tmp; + + private boolean _left; + + protected BinaryMVColTaskSparseSparse(CompressedMatrixBlock m1, MatrixBlock m2, MatrixBlock ret, int rl, int ru, + BinaryOperator op, boolean left) { + _m1 = m1; + _m2 = m2; + _ret = ret; + _op = op; + _rl = rl; + _ru = ru; + _left = left; + } + + @Override + public Long call() { + final int _blklen = Math.max(DECOMPRESSION_BLEN / _m1.getNumColumns(), 64); + tmp = allocateTempUncompressedBlockSparse(_blklen, _m1.getNumColumns()); + final List groups = _m1.getColGroups(); + final AIterator[] its = getIterators(groups, _rl); + if(!_left) + for(int r = _rl; r < _ru; r += _blklen) + processBlock(r, Math.min(r + _blklen, _ru), groups, its); + else + for(int r = _rl; r < _ru; r += _blklen) + processBlockLeft(r, Math.min(r + _blklen, _ru), groups, its); + return _ret.recomputeNonZeros(_rl, _ru - 1); + } + + private final void processBlock(final int rl, final int ru, final List groups, final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDense(rl, ru); + tmp.reset(); + } + + private final void processBlockLeft(final int rl, final int ru, final List groups, + final AIterator[] its) { + decompressToTmpBlock(rl, ru, tmp.getSparseBlock(), groups, its); + processDenseLeft(rl, ru); + tmp.reset(); + } + + private final void processDense(final int rl, final int ru) { + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl); + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(aval[j], vr)); + } + } + + } + } + + private final void processDenseLeft(final int rl, final int ru) { + final int nCol = _m1.getNumColumns(); + final SparseBlock sb = _ret.getSparseBlock(); + final SparseBlock _tmpSparse = tmp.getSparseBlock(); + final double[] _m2Dense = _m2.getDenseBlockValues(); + for(int row = rl; row < ru; row++) { + final double vr = _m2Dense[row]; + final int tmpOff = (row - rl) * nCol; + if(!_tmpSparse.isEmpty(tmpOff)){ + int[] aoff = _tmpSparse.indexes(tmpOff); + double[] aval = _tmpSparse.values(tmpOff); + int apos = _tmpSparse.pos(tmpOff); + int alen = apos + _tmpSparse.size(tmpOff); + for(int j = apos; j < alen; j++){ + sb.append(row, aoff[j], _op.fn.execute(vr,aval[j])); + } + } + } + } + } + + private static MatrixBlock allocateTempUncompressedBlock(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, false); + out.allocateBlock(); + return out; + } + + private static MatrixBlock allocateTempUncompressedBlockSparse(int blklen, int cols) { + MatrixBlock out = new MatrixBlock(blklen, cols, true); out.allocateBlock(); return out; } @@ -1199,6 +1346,25 @@ protected static void decompressToTmpBlock(final int rl, final int ru, final Den } } + protected static void decompressToTmpBlock(final int rl, final int ru, final SparseBlock db, + final List groups, final AIterator[] its) { + Timing time = new Timing(true); + for(int i = 0; i < groups.size(); i++) { + final AColGroup g = groups.get(i); + if(g.getCompType() == CompressionType.SDC) + ((ASDCZero) g).decompressToSparseBlock(db, rl, ru, -rl, 0, its[i]); + else + g.decompressToSparseBlock(db, rl, ru, -rl, 0); + } + + if(DMLScript.STATISTICS) { + final double t = time.stop(); + DMLCompressionStatistics.addDecompressToBlockTime(t, 1); + if(LOG.isTraceEnabled()) + LOG.trace("decompressed block w/ k=" + 1 + " in " + t + "ms."); + } + } + protected static AIterator[] getIterators(final List groups, final int rl) { final AIterator[] its = new AIterator[groups.size()]; for(int i = 0; i < groups.size(); i++) { @@ -1210,8 +1376,8 @@ protected static AIterator[] getIterators(final List groups, final in return its; } - private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, BinaryOperator op, - boolean left) { + private static Pair evaluateSparsityMVCol(CompressedMatrixBlock m1, MatrixBlock m2, + BinaryOperator op, boolean left) { final List groups = m1.getColGroups(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); @@ -1247,7 +1413,7 @@ private static Pair evaluateSparsityMVCol(CompressedMatrixBlock for(int r = 0; r < sampleRow; r++) { final double m = m2v[r]; final int off = r * sampleCol; - for(int c = 0; c < sampleCol; c++){ + for(int c = 0; c < sampleCol; c++) { int outVal = op.fn.execute(dv[off + c], m) != 0 ? 1 : 0; nnz += outVal; nnzPerRow[r] += outVal; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 99693635a9b..948a78f96af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -486,7 +486,7 @@ private static List> generateUnaryAggregateOverlappingFuture final ArrayList tasks = new ArrayList<>(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); - final int blklen = Math.max(64, nRow / k); + final int blklen = Math.max(64, (nRow + k) / k); final List groups = m1.getColGroups(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); if(shouldFilter) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..cc7953f8c5d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java new file mode 100644 index 00000000000..89f639e4ab8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRemoveEmpty.java @@ -0,0 +1,107 @@ +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; + +public class CLALibRemoveEmpty { + protected static final Log LOG = LogFactory.getLog(CLALibRemoveEmpty.class.getName()); + + /** + * CP rmempty operation (single input, single output matrix) + * + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. + * @param emptyReturn Return row/column of zeros for empty input. + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock, can be a different object that the caller used. + */ + public static MatrixBlock rmempty(CompressedMatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = LibMatrixReorg.rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null) + return ret2; + + if(rows) + return rmEmptyRows(in, ret, emptyReturn, select); + else + return rmEmptyCols(in, ret, emptyReturn, select); + } + + private static MatrixBlock rmEmptyCols(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, false, emptyReturn, select, ret); + + int cOut = (int) select.getNonZeros(); + if(cOut == -1) + cOut = (int) select.recomputeNonZeros(); + if(cOut == 0){ + ret.reset(in.getNumRows(), !emptyReturn ? 0 : 1); + return ret; + } + + final boolean[] selectV = DataConverter + .convertToBooleanVector(CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty")); + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + for(int i = 0; i < inG.size(); i++) { + AColGroup tmp = inG.get(i).removeEmptyCols(selectV); + if(tmp != null) + retG.add(tmp); + } + return new CompressedMatrixBlock(in.getNumRows(), cOut, -1, in.isOverlapping(), retG); + + } + + private static MatrixBlock rmEmptyRows(CompressedMatrixBlock in, MatrixBlock ret, boolean emptyReturn, + MatrixBlock select) { + if(select == null) + return fallback(in, true, emptyReturn, select, ret); + + select = CompressedMatrixBlock.getUncompressed(select, "decompressing selection in rmempty"); + + int rOut = (int) select.getNonZeros(); + if(rOut == -1) + rOut = (int) select.recomputeNonZeros(); + if(rOut == 0){ + ret.reset(!emptyReturn ? 0 : 1, in.getNumColumns()); + return ret; + } + + // TODO: add optimization to avoid linear scan and make selectV indexes, if selection is small relative to number + // of rows + // TODO: add decompress to boolean vector. + final boolean[] selectV = DataConverter.convertToBooleanVector(select); + + + + final List inG = in.getColGroups(); + final List retG = new ArrayList<>(inG.size()); + for(int i = 0; i < inG.size(); i++) { + retG.add(inG.get(i).removeEmptyRows(selectV, rOut)); + } + + return new CompressedMatrixBlock(rOut, in.getNumColumns(), -1, in.isOverlapping(), retG); + } + + private static MatrixBlock fallback(CompressedMatrixBlock in, boolean rows, boolean emptyReturn, MatrixBlock select, + MatrixBlock ret) { + LOG.warn("Decompressing because: removeEmptyOperations with select: " + (select != null) + " rows: " + rows); + MatrixBlock tmp = CompressedMatrixBlock.getUncompressed(in); + MatrixBlock select2 = CompressedMatrixBlock.getUncompressed(select); + return LibMatrixReorg.rmemptyUnsafe(tmp, ret, rows, emptyReturn, select2); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index f14d6833d95..ce06262b9a5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -31,6 +31,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -71,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc if(m2 instanceof CompressedMatrixBlock) m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k); - if(betterIfDecompressed(m1)) { - // perform uncompressed multiplication. - return decompressingMatrixMult(m1, m2, k); - } + // if(betterIfDecompressed(m1)) { + // // perform uncompressed multiplication. + // return decompressingMatrixMult(m1, m2, k); + // } if(!allowOverlap) { LOG.trace("Overlapping output not allowed in call to Right MM"); @@ -143,7 +145,9 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat private static boolean betterIfDecompressed(CompressedMatrixBlock m) { for(AColGroup g : m.getColGroups()) { - if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) { + // TODO add subpport for decompressing RMM to ASDC and ASDCZero + if(!(g instanceof ColGroupUncompressed || g instanceof ASDC || g instanceof ASDCZero) && + g.getNumValues() * 2 >= m.getNumRows()) { return true; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java new file mode 100644 index 00000000000..c793e84ebef --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibSort.java @@ -0,0 +1,37 @@ +package org.apache.sysds.runtime.compress.lib; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; + +public class CLALibSort { + + public static MatrixBlock sort(CompressedMatrixBlock mb, MatrixValue weights, MatrixBlock result, int k) { + // force uncompressed weights + weights = CompressedMatrixBlock.getUncompressed(weights); + + if(mb.getNumColumns() == 1 && mb.getColGroups().size() == 1 && weights == null) { + return sortSingleCol(mb, k); + } + + // fallback to uncompressed. + return CompressedMatrixBlock// + .getUncompressed(mb, "sortOperations")// + .sortOperations(weights, result); + } + + private static MatrixBlock sortSingleCol(CompressedMatrixBlock mb, int k) { + + AColGroup g = mb.getColGroups().get(0); + + AColGroup r = g.sort(); + + List rg = new ArrayList<>(); + rg.add(r); + return new CompressedMatrixBlock(mb.getNumRows(), mb.getNumColumns(), mb.getNonZeros(), false, rg); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..d0396b63810 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -42,6 +43,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,24 +56,32 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); - if(groups.size() >= numColumns) { + if(groups.size() >= numColumns || containsUncompressedColGroup(groups)) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { final double[] constV = new double[numColumns]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); tsmmColGroups(filteredGroups, ret, numRows, overlapping, k); - addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV); + addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV, k); } else { @@ -77,17 +90,23 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; + } + + private static boolean containsUncompressedColGroup(List groups) { + for(AColGroup g : groups) + if(g instanceof ColGroupUncompressed) + return true; + return false; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, - double[] constV) { + double[] constV, int k) { final double[] retV = result.getDenseBlockValues(); final double[] filteredColSum = CLALibUtils.getColSum(filteredGroups, nCols, nRows); addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +155,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java new file mode 100644 index 00000000000..29650048509 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapIntToInt.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.utils; + +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +public class HashMapIntToInt implements Map { + + static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; + static final float DEFAULT_LOAD_FACTOR = 0.75f; + + protected Node[] buckets; + + protected int size; + + public HashMapIntToInt(int capacity) { + alloc(Math.max(capacity, DEFAULT_INITIAL_CAPACITY)); + } + + protected void alloc(int size) { + Node[] tmp = (Node[]) new Node[size]; + buckets = tmp; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public boolean containsKey(Object key) { + return getI((Integer) key) != -1; + } + + @Override + public boolean containsValue(Object value) { + if(value instanceof Integer) { + for(Entry v : this.entrySet()) { + if(v.getValue().equals(value)) + return true; + } + } + return false; + + } + + @Override + public Integer get(Object key) { + final int i = getI((Integer) key); + if(i != -1) + return i; + else + return null; + } + + public int getI(int key) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b != null) { + do { + if(key == b.key) + return b.value; + } + while((b = b.next) != null); + } + return -1; + + } + + public int hash(int key) { + return Math.abs(Integer.hashCode(key) % buckets.length); + } + + @Override + public Integer put(Integer key, Integer value) { + int i = putI(key, value); + if(i != -1) + return i; + else + return null; + } + + @Override + public Integer putIfAbsent(Integer key, Integer value) { + int i = putIfAbsentI(key, value); + if(i != -1) + return i; + else + return null; + } + + public int putIfAbsentI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return putIfAbsentBucket(ix, key, value); + + } + + public int putIfAbsentReturnVal(int key, int value) { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + public int putIfAbsentReturnValHash(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + + } + + private int putIfAbsentBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private int putIfAbsentBucketReturnval(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(b.key == key) + return b.value; + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return value; + } + b = b.next; + } + } + + public int putI(int key, int value) { + + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + + } + + private int createBucket(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return -1; + } + + private int createBucketReturnVal(int ix, int key, int value) { + buckets[ix] = new Node(key, value, null); + size++; + return value; + } + + private int addToBucket(int ix, int key, int value) { + Node b = buckets[ix]; + while(true) { + if(key == b.key) { + int tmp = b.getValue(); + b.setValue(value); + return tmp; + } + if(b.next == null) { + b.setNext(new Node(key, value, null)); + size++; + resize(); + return -1; + } + b = b.next; + } + } + + private void resize() { + if(size > buckets.length * DEFAULT_LOAD_FACTOR) { + + Node[] tmp = (Node[]) new Node[buckets.length * 2]; + Node[] oldBuckets = buckets; + buckets = tmp; + size = 0; + for(Node n : oldBuckets) { + if(n != null) + do { + put(n.key, n.value); + } + while((n = n.next) != null); + } + + } + } + + @Override + public Integer remove(Object key) { + throw new UnsupportedOperationException("Unimplemented method 'remove'"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("Unimplemented method 'putAll'"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("Unimplemented method 'clear'"); + } + + @Override + public Set keySet() { + throw new UnsupportedOperationException("Unimplemented method 'keySet'"); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException("Unimplemented method 'values'"); + } + + @Override + public Set> entrySet() { + return new EntrySet(); + } + + @Override + public void forEach(BiConsumer action) { + + for(Node n : buckets) { + if(n != null) { + do { + action.accept(n.key, n.value); + } + while((n = n.next) != null); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(size() * 3); + this.forEach((k, v) -> { + sb.append("(" + k + "→" + v + ")"); + }); + return sb.toString(); + } + + private static class Node implements Entry { + final int key; + int value; + Node next; + + Node(int key, int value, Node next) { + this.key = key; + this.value = value; + this.next = next; + } + + public final void setNext(Node n) { + next = n; + } + + @Override + public Integer getKey() { + return key; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Integer setValue(Integer value) { + return this.value = value; + } + } + + private final class EntrySet extends AbstractSet> { + + @Override + public int size() { + return size; + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + } + + private final class EntryIterator implements Iterator> { + Node next; + int bucketId = 0; + + protected EntryIterator() { + + for(; bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public Entry next() { + + Node e = next; + + if(e.next != null) + next = e.next; + else { + for(; ++bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + if(bucketId >= buckets.length) + next = null; + } + + return e; + } + + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java index fc0aa3b1a29..4940dd801b3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java @@ -27,9 +27,18 @@ import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.KahanPlus; +import org.apache.sysds.runtime.functionobjects.Mean; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction; +import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; public class FederatedWorkloadAnalyzer { protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName()); @@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) { } public void compressRun(ExecutionContext ec, long tid) { - if(counter >= compressRunFrequency ){ + if(counter >= compressRunFrequency) { counter = 0; get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V)); } @@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, ComputationCPInstruction cpIns) { // TODO: Count transitive closure via lineage + // TODO: add more operations if(cpIns instanceof AggregateBinaryCPInstruction) { final String n1 = cpIns.input1.getName(); MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1); @@ -81,15 +91,48 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap mm, long id) { @@ -117,8 +160,8 @@ private static boolean validSize(int nRow, int nCol) { return nRow > 90 && nRow >= nCol; } - @Override - public String toString(){ + @Override + public String toString() { StringBuilder sb = new StringBuilder(); sb.append(this.getClass().getSimpleName()); sb.append(" Counter: "); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 99cce9f9e97..972a2893fd8 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -377,7 +377,7 @@ public static double parseDouble(String value) { return Double.POSITIVE_INFINITY; else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) return Double.NEGATIVE_INFINITY; - throw new DMLRuntimeException(e); + throw e; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index b26695e5797..84e4e89a420 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -174,6 +174,29 @@ public int putIfAbsentReturnVal(K key, int value) { } + + public int putIfAbsentReturnValHash(K key, int value) { + + if(key == null) { + if(nullV == -1) { + size++; + nullV = value; + return -1; + } + else + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucketReturnVal(ix, key, value); + else + return putIfAbsentBucketReturnval(ix, key, value); + } + + } + private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; while(true) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..292fcb52bf5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,7 +607,6 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { @@ -617,7 +616,8 @@ private static double getAsDouble(String s) { else if(ls.equals("false") || ls.equals("f")) return 0; else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw e; // for efficiency + // throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index 032afe2cd7c..987d14106ac 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -32,11 +32,17 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; -public interface MatrixBlockFromFrame { +public class MatrixBlockFromFrame { public static final Log LOG = LogFactory.getLog(MatrixBlockFromFrame.class.getName()); public static final int blocksizeIJ = 32; + public static Boolean WARNED_FOR_FAILED_CAST = false; + + private MatrixBlockFromFrame(){ + // private constructor for code coverage. + } + /** * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. @@ -94,10 +100,25 @@ else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseForma } private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { - if(mb.getDenseBlock().isContiguous()) - return convertContiguous(frame, mb, n, rl, ru); - else - return convertGeneric(frame, mb, n, rl, ru); + try { + + if(mb.getDenseBlock().isContiguous()) + return convertContiguous(frame, mb, n, rl, ru); + else + return convertGeneric(frame, mb, n, rl, ru); + } + catch(NumberFormatException | DMLRuntimeException e) { + synchronized(WARNED_FOR_FAILED_CAST){ + if(!WARNED_FOR_FAILED_CAST) { + LOG.error( + "Failed to convert to Matrix because of number format errors, falling back to NaN on incompatible cells", + e); + WARNED_FOR_FAILED_CAST = true; + } + } + return convertSafeCast(frame, mb, n, rl, ru); + + } } private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception { @@ -169,4 +190,37 @@ private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final } return lnnz; } + + private static long convertSafeCast(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { + final DenseBlock c = mb.getDenseBlock(); + long lnnz = 0; + for(int bi = rl; bi < ru; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, ru); + int bjmin = Math.min(bj + blocksizeIJ, n); + lnnz = convertBlockSafeCast(frame, lnnz, c, bi, bj, bimin, bjmin); + } + } + return lnnz; + } + + private static long convertBlockSafeCast(final FrameBlock frame, long lnnz, final DenseBlock c, final int rl, + final int cl, final int ru, final int cu) { + for(int i = rl; i < ru; i++) { + final double[] cvals = c.values(i); + final int cpos = c.pos(i); + for(int j = cl; j < cu; j++) { + try { + lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + } + catch(NumberFormatException | DMLRuntimeException e) { + lnnz += 1; + cvals[cpos + j] = Double.NaN; + } + } + } + return lnnz; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 39735be62e0..eed2c58f78c 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; @@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK); } protected Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 28b8775ebd5..86184f47be6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME) return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str); + else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR) + return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX) return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str); else diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java new file mode 100644 index 00000000000..99b3c1a3b13 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysds.common.Builtins; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; +import org.apache.sysds.runtime.transform.TfUtils.TfMethod; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONObject; + +public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { + // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, + String opcode, String istr) { + super(CPType.Binary, op, in1, in2, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + // get input frames + FrameBlock inBlock1 = ec.getFrameInput(input1.getName()); + ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true); + if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) { + processGetCategorical(ec, inBlock1, spec); + } + else { + throw new DMLRuntimeException("Unsupported operation"); + } + + // Release the memory occupied by input frames + ec.releaseFrameInput(input1.getName()); + } + + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + + // MatrixBlock ret = new MatrixBlock(); + int nCol = f.getNumColumns(); + + JSONObject jSpec = new JSONObject(spec.getStringValue()); + + if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) { + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); + } + + String recode = TfMethod.RECODE.toString(); + String dummycode = TfMethod.DUMMYCODE.toString(); + + int[] lengths = new int[nCol]; + // assume all columns encode to at least one column. + Arrays.fill(lengths, 1); + boolean[] categorical = new boolean[nCol]; + + if(jSpec.containsKey(recode)) { + JSONArray a = jSpec.getJSONArray(recode); + for(Object aa : a) { + int av = (Integer) aa - 1; + categorical[av] = true; + } + } + + if(jSpec.containsKey(dummycode)) { + JSONArray a = jSpec.getJSONArray(dummycode); + for(Object aa : a) { + int av = (Integer) aa - 1; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + lengths[av] = ndist; + categorical[av] = true; + } + } + + // get total size after mapping + + int sumLengths = 0; + for(int i : lengths) { + sumLengths += i; + } + + MatrixBlock ret = new MatrixBlock(1, sumLengths, false); + ret.allocateDenseBlock(); + int off = 0; + for(int i = 0; i < lengths.length; i++) { + for(int j = 0; j < lengths[i]; j++) { + ret.set(0, off++, categorical[i] ? 1 : 0); + } + } + + ec.setMatrixOutput(output.getName(), ret); + + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java index 2ec23037385..d76dbe0d45e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryMatrixMatrixCPInstruction.java @@ -80,8 +80,15 @@ public void processInstruction(ExecutionContext ec) { retBlock = inBlock1; } else { - if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) && !compressedLeft && !compressedRight) + if(LibCommonsMath.isSupportedMatrixMatrixOperation(getOpcode()) ){ + if(compressedLeft) + inBlock1 = CompressedMatrixBlock.getUncompressed(inBlock1, getOpcode()); + + if(compressedRight) + inBlock2 = CompressedMatrixBlock.getUncompressed(inBlock2, getOpcode()); + retBlock = LibCommonsMath.matrixMatrixOperations(inBlock1, inBlock2, getOpcode()); + } else { // Perform computation using input matrices, and produce the result matrix BinaryOperator bop = (BinaryOperator) _optr; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..e53958ac4b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) { // compute transformdecode Decoder decoder = DecoderFactory .createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); - FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema())); + FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism()); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); // release locks diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 5dd8e55e821..93390cc686d 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.io.FileFormatProperties; @@ -918,7 +919,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { switch( getInput1().getDataType() ) { case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); - MatrixBlock out = DataConverter.convertToMatrixBlock(fin); + MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); break; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java new file mode 100644 index 00000000000..79f08cb353a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibAggregateUnarySpecialization.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + +public class LibAggregateUnarySpecialization { + protected static final Log LOG = LogFactory.getLog(LibAggregateUnarySpecialization.class.getName()); + + public static void aggregateUnary(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.sparseSafe) + sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + else + denseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + } + + private static void sparseAggregateUnaryHelp(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, + int blen, MatrixIndexes indexesIn) { + // initialize result + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + + if(mb.sparse && mb.sparseBlock != null) { + SparseBlock a = mb.sparseBlock; + for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) { + if(a.isEmpty(r)) + continue; + int apos = a.pos(r); + int alen = a.size(r); + int[] aix = a.indexes(r); + double[] aval = a.values(r); + for(int i = apos; i < apos + alen; i++) { + tempCellIndex.set(r, aix[i]); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, aval[i], + buffer); + } + } + } + else if(!mb.sparse && mb.denseBlock != null) { + DenseBlock a = mb.getDenseBlock(); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, a.get(i, j), + buffer); + } + } + } + + private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, + mb.get(i, j), buffer); + } + } + + private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, + double newvalue, KahanObject buffer) { + if(aggOp.existsCorrection()) { + if(aggOp.correction == CorrectionLocationType.LASTROW || + aggOp.correction == CorrectionLocationType.LASTCOLUMN) { + int corRow = row, corCol = column; + if(aggOp.correction == CorrectionLocationType.LASTROW)// extra row + corRow++; + else if(aggOp.correction == CorrectionLocationType.LASTCOLUMN) + corCol++; + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + } + else if(aggOp.correction == CorrectionLocationType.NONE) { + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + } + else// for mean + { + int corRow = row, corCol = column; + int countRow = row, countCol = column; + if(aggOp.correction == CorrectionLocationType.LASTTWOROWS) { + countRow++; + corRow += 2; + } + else if(aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS) { + countCol++; + corCol += 2; + } + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + buffer._sum = result.get(row, column); + buffer._correction = result.get(corRow, corCol); + double count = result.get(countRow, countCol) + 1.0; + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); + result.set(row, column, buffer._sum); + result.set(corRow, corCol, buffer._correction); + result.set(countRow, countCol, count); + } + + } + else { + newvalue = aggOp.increOp.fn.execute(result.get(row, column), newvalue); + result.set(row, column, newvalue); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java index af702cb7fad..3113850ec80 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java @@ -3234,6 +3234,11 @@ private static void matrixMultWDivMMDense(MatrixBlock mW, MatrixBlock mU, Matrix DenseBlock x = (mX==null) ? null : mX.getDenseBlock(); DenseBlock c = ret.getDenseBlock(); + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } + //approach: iterate over non-zeros of w, selective mm computation //cache-conscious blocking: due to blocksize constraint (default 1000), //a blocksize of 16 allows to fit blocks of UV into L2 cache (256KB) @@ -3380,6 +3385,11 @@ private static void matrixMultWDivMMGeneric(MatrixBlock mW, MatrixBlock mU, Matr //output always in dense representation DenseBlock c = ret.getDenseBlock(); + + if(c == null){ + ret.allocateDenseBlock(); + c = ret.getDenseBlock(); + } //approach: iterate over non-zeros of w, selective mm computation if( mW.sparse ) //SPARSE diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 90ea445be8d..1c0535f3b36 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -844,16 +844,34 @@ public static List reshape(IndexedMatrixValue in, DataCharac } /** - * CP rmempty operation (single input, single output matrix) + * CP rmempty operation (single input, single output matrix) * - * @param in input matrix - * @param ret output matrix - * @param rows ? - * @param emptyReturn return row/column of zeros for empty input - * @param select ? - * @return matrix block + * @param in The input matrix + * @param ret The output matrix + * @param rows If we are removing based on rows, or columns. + * @param emptyReturn Return row/column of zeros for empty input + * @param select An optional selection vector, to remove based on rather than empty rows or columns + * @return The result MatrixBlock */ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select) { + if(ret == null) + ret = new MatrixBlock(); + MatrixBlock ret2 = rmemptyEarlyAbort(in, ret, rows, emptyReturn, select); + if(ret2 != null ) + return ret2; + // core removeEmpty + return rmemptyUnsafe(in, ret, rows, emptyReturn, select); + } + + public static MatrixBlock rmemptyUnsafe(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, + MatrixBlock select) { + if( rows ) + return removeEmptyRows(in, ret, select, emptyReturn); + else // cols + return removeEmptyColumns(in, ret, select, emptyReturn); + } + + public static MatrixBlock rmemptyEarlyAbort(MatrixBlock in, MatrixBlock ret, boolean rows, boolean emptyReturn, MatrixBlock select){ //check for empty inputs //(the semantics of removeEmpty are that for an empty m-by-n matrix, the output //is an empty 1-by-n or m-by-1 matrix because we don't allow matrices with dims 0) @@ -870,12 +888,8 @@ public static MatrixBlock rmempty(MatrixBlock in, MatrixBlock ret, boolean rows, if( select != null && (select.nonZeros == (rows?in.rlen:in.clen)) ) { return in; } - - // core removeEmpty - if( rows ) - return removeEmptyRows(in, ret, select, emptyReturn); - else //cols - return removeEmptyColumns(in, ret, select, emptyReturn); + + return null; } /** @@ -3500,6 +3514,25 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr rlen2 = (int)select.getNonZeros(); } + return removeEmptyRows(in, ret, emptyReturn, select == null, flags, rlen2); + } + + /** + * Remove selected rows, based on the boolean array given. Note this function is internal use only, and require a + * boolean vector to be constructed first. + * + * @param in Input to remove rows from + * @param ret Output to assign the result into + * @param emptyReturn If the output is allowed to be empty. + * @param selectNull If the original caller did not have a selection matrix. + * @param flags The boolean selection vector to specify which rows to keep. + * @param rlen2 The number of true values in the flags argument. + * @return Another reference to the ret matrix input argument. + */ + public static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, boolean emptyReturn, boolean selectNull, + boolean[] flags, int rlen2) { + final int m = in.rlen; + final int n = in.clen; //Step 2: reset result and copy rows //dense stays dense if correct input representation (but robust for any input), //sparse might be dense/sparse @@ -3509,7 +3542,7 @@ private static MatrixBlock removeEmptyRows(MatrixBlock in, MatrixBlock ret, Matr if( in.isEmptyBlock(false) ) return ret; - if( SHALLOW_COPY_REORG && m == rlen2 && select == null ) { + if( SHALLOW_COPY_REORG && m == rlen2 && selectNull ) { // the condition m==rlen2 is not enough with non-empty 1-row input but empty // 1-row select vector because if emptyReturn should output a single empty row ret.sparse = in.sparse; @@ -3552,7 +3585,7 @@ else if( !in.sparse && !ret.sparse ) //DENSE <- DENSE } //check sparsity - ret.nonZeros = (select==null) ? + ret.nonZeros = (selectNull) ? in.nonZeros : ret.recomputeNonZeros(); ret.examSparsity(); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 3dd8b2ad3b4..56095ab3d05 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -1315,7 +1315,7 @@ public void examSparsity(boolean allowCSR, int k) { else if( !sparse && sparseDst ) denseToSparse(allowCSR, k); } - + public static boolean evalSparseFormatInMemory(DataCharacteristics dc) { return evalSparseFormatInMemory(dc.getRows(), dc.getCols(), dc.getNonZeros()); } @@ -1387,12 +1387,13 @@ public void denseToSparse(boolean allowCSR, int k){ LibMatrixDenseToSparse.denseToSparse(this, allowCSR, k); } - public final void sparseToDense() { - sparseToDense(1); + public final MatrixBlock sparseToDense() { + return sparseToDense(1); } - public void sparseToDense(int k) { + public MatrixBlock sparseToDense(int k) { LibMatrixSparseToDense.sparseToDense(this, k); + return this; } /** @@ -2954,13 +2955,14 @@ public boolean isShallowSerialize(boolean inclConvert) { boolean sparseDst = evalSparseFormatOnDisk(); return !sparse || !sparseDst || (sparse && sparseBlock instanceof SparseBlockCSR) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD - <= getExactSerializedSize()) - || (sparse && sparseBlock instanceof SparseBlockMCSR - && nonZeros < Integer.MAX_VALUE //CSR constraint - && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE - && !isUltraSparseSerialize(sparseDst)); + || (sparse && sparseBlock instanceof SparseBlockMCSR); + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && getInMemorySize() / MAX_SHALLOW_SERIALIZE_OVERHEAD + // <= getExactSerializedSize()) + // || (sparse && sparseBlock instanceof SparseBlockMCSR + // && nonZeros < Integer.MAX_VALUE //CSR constraint + // && inclConvert && CONVERT_MCSR_TO_CSR_ON_DEEP_SERIALIZE + // && !isUltraSparseSerialize(sparseDst)); } @Override @@ -4650,7 +4652,7 @@ public final MatrixBlock sortOperations(MatrixValue weights){ return sortOperations(weights, null); } - public MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { + public final MatrixBlock sortOperations(MatrixValue weights, MatrixBlock result) { return sortOperations(weights, result, 1); } @@ -4754,7 +4756,17 @@ public static double computeIQMCorrection(double sum, double sum_wt, return (sum + q25Part*q25Val - q75Part*q75Val) / (sum_wt*0.5); } - public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { + /** + * Pick the quantiles out of this matrix. If this matrix contains two columns it is weighted quantile picking. + * If a single column it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantiles The quantiles to pick + * @param ret The result matrix + * @return The result matrix + */ + public final MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret) { return pickValues(quantiles, ret, false); } @@ -4778,17 +4790,56 @@ public MatrixBlock pickValues(MatrixValue quantiles, MatrixValue ret, boolean av return output; } - + + /** + * Pick the median quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public double median() { double sum_wt = sumWeightForQuantile(); return pickValue(0.5, sum_wt%2==0); } - + + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @return The quantile + */ public final double pickValue(double quantile){ return pickValue(quantile, false); } - public double pickValue(double quantile, boolean average) { + /** + * Pick a specific quantile from this matrix. if this matrix is two columns, it is weighted picking else it is unweighted. + * + * Note the values are assumed to be sorted + * + * @param quantile The quantile to pick + * @param average If the quantile is averaged. + * @return The quantile + */ + public final double pickValue(double quantile, boolean average) { + if(this.getNumColumns() == 1) + return pickUnweightedValue(quantile, average); + return pickWeightedValue(quantile, average); + } + + private double pickUnweightedValue(double quantile, boolean average) { + double pos = quantile * rlen; + if(average && (int) pos != pos) + return (get((int) Math.floor(pos), 0) + get(Math.min(rlen - 1, (int) Math.ceil(pos)), 0)) / 2; + else + return get(Math.min(rlen - 1, (int) Math.round(pos)), 0); + } + + private double pickWeightedValue(double quantile, boolean average) { double sum_wt = sumWeightForQuantile(); // do averaging only if it is asked for; and sum_wt is even @@ -5342,8 +5393,8 @@ public MatrixBlock ctableSeqOperations(MatrixValue thatMatrix, double thatScalar * (i1,j1,v2) from input2 (that) * (w) from scalar_input3 (scalarThat2) * - * @param thatMatrix matrix value - * @param thatScalar scalar double + * @param thatMatrix matrix value, the vector to encode via table + * @param thatScalar scalar double, w, that is the weight to multiply on the encoded values * @param resultBlock result matrix block * @return resultBlock */ diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..c9fcc23990a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,15 +44,18 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; - public DecoderBin() { - super(null, null); - } + // public DecoderBin() { + // super(null, null); + // } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..dff85e72dc6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -50,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - public DecoderComposite() { super(null, null); } + // public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..debce027680 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,98 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - _clPos[j] + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +157,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..c2de3ec1df3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - public DecoderPassThrough() { super(null, null); } + // public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..1cf0b7c4b3f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,12 +47,11 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; - public DecoderRecode() { - super(null, null); - } + // public DecoderRecode() { + // super(null, null); + // } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -129,27 +128,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override diff --git a/src/main/java/org/apache/sysds/utils/DoubleParser.java b/src/main/java/org/apache/sysds/utils/DoubleParser.java index 9c77a3e95c8..c0122f8061f 100644 --- a/src/main/java/org/apache/sysds/utils/DoubleParser.java +++ b/src/main/java/org/apache/sysds/utils/DoubleParser.java @@ -184,7 +184,7 @@ public interface DoubleParser { 0x8e679c2f5e44ff8fL}; public static double parseFloatingPointLiteral(String str, int offset, int endIndex) { - if(endIndex > 100) + if(endIndex > 100)// long string return Double.parseDouble(str); // Skip leading whitespace int index = skipWhitespace(str, offset, endIndex); @@ -197,9 +197,10 @@ public static double parseFloatingPointLiteral(String str, int offset, int endIn } // Parse NaN or Infinity (this occurs rarely) - if(ch >= 'I') - return Double.parseDouble(str); - else if(str.charAt(endIndex - 1) >= 'a') + // : is the first character after numbers. + // 0 is the first number. + // we use the last position, since this is not allowed to be other values than a number. + if(str.charAt(endIndex - 1) > '9' || str.charAt(endIndex - 1) < '0') return Double.parseDouble(str); final double val = parseDecFloatLiteral(str, index, offset, endIndex); diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index e470dd82539..07abdbaff26 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; +import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; @@ -2941,6 +2942,25 @@ public static void writeTestScalar(String file, double value) { } } + + /** + * Write scalar to file + * + * @param file File to write to + * @param value Value to write + */ + public static void writeTestScalar(String file, String value) { + try { + DataOutputStream out = new DataOutputStream(new FileOutputStream(file)); + try(PrintWriter pw = new PrintWriter(out)) { + pw.println(value); + } + } + catch(IOException e) { + fail("unable to write test scalar (" + file + "): " + e.getMessage()); + } + } + /** * Write scalar to file * diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java index d36c6167cf7..9d5976a8903 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedMatrixTest.java @@ -687,4 +687,54 @@ public void toRDDAndBack(int blen) { fail(e.getMessage()); } } + + @Test + public void removeEmptyOperationsBase1() { + removeEmptyOperations(false, false, null); + } + + @Test + public void removeEmptyOperationsBase2() { + removeEmptyOperations(true, false, null); + } + + @Test + public void removeEmptyOperationsBase3() { + removeEmptyOperations(false, true, null); + } + + @Test + public void removeEmptyOperationsBase4() { + removeEmptyOperations(true, true, null); + } + + @Test + public void removeEmptyOperationsSelect1() { + if(rows < 5000) { + MatrixBlock s = TestUtils.generateTestMatrixBlock(rows, 1, 1, 1, 0.05, 321); + removeEmptyOperations(true, false, s); + } + } + + @Test + public void removeEmptyOperationsSelect2() { + if(rows < 5000) { + MatrixBlock s = TestUtils.generateTestMatrixBlock(1, cols, 1, 1, 0.5, 321); + removeEmptyOperations(false, false, s); + } + } + + public void removeEmptyOperations(boolean rows, boolean emptyReturn, MatrixBlock select) { + try { + MatrixBlock a = cmb.removeEmptyOperations(null, rows, emptyReturn, select); + MatrixBlock b = mb.removeEmptyOperations(null, rows, emptyReturn, select); + compareResultMatrices(b, a, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + + } + } diff --git a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java index c3efeea4014..36dab4191ee 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupNegativeTests.java @@ -49,6 +49,7 @@ import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockMCSR; @@ -468,6 +469,24 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } private class FakeDictBasedColGroup extends ADictBasedColGroup { @@ -777,5 +796,23 @@ public AColGroup[] splitReshapePushDown(int multiplier, int nRow, int nColOrg, E // TODO Auto-generated method stub throw new UnsupportedOperationException("Unimplemented method 'splitReshapePushDown'"); } + + @Override + public AColGroup sort() { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'sort'"); + } + + @Override + public AColGroup removeEmptyRows(boolean[] selectV, int rOut) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyRows'"); + } + + @Override + protected AColGroup removeEmptyColsSubset(IColIndex newColumnIDs, IntArrayList selectedColumns) { + // TODO Auto-generated method stub + throw new UnsupportedOperationException("Unimplemented method 'removeEmptyColsSubset'"); + } } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java index 194f581121a..a5bd3cebfb0 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleMultiColTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.fail; +import org.apache.sysds.runtime.compress.CompressionSettingsBuilder; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.junit.Test; @@ -115,6 +116,8 @@ public void testJoinWithSecondSubpartLeft() { private void partJoinVerification(IEncode er) { boolean incorrectUnique = e.getUnique() != er.getUnique(); + er.extractFacts(10000, 1.0, 1.0, new CompressionSettingsBuilder().create()); + if(incorrectUnique) { StringBuilder sb = new StringBuilder(); sb.append("\nFailed joining sub parts to recreate whole."); diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java index 182bd7fa37e..5a298f145ec 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java @@ -81,6 +81,10 @@ public static Collection data() { // Both Sparse and end dense joined tests.add(createT(1, 0.2, 10, 10, 0.1, 2, 1000, 1231521)); + + tests.add(createT(1, 1.0, 100, 1, 1.0, 10, 10000, 132)); + tests.add(createT(1, 1.0, 1000, 1, 1.0, 10, 10000, 132)); + return tests; } diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java index 2e901eeb14d..3755365c018 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/offset/CustomOffsetTest.java @@ -28,13 +28,14 @@ import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset; import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.OffsetSliceInfo; +import org.apache.sysds.runtime.compress.colgroup.offset.AOffset.RemoveEmptyOffsetsTmp; import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.junit.Test; public class CustomOffsetTest { protected static final Log LOG = LogFactory.getLog(CustomOffsetTest.class.getName()); - static{ + static { CompressedMatrixBlock.debug = true; } @@ -96,4 +97,95 @@ public void printCache() { String s = off.toString(); assertTrue(s.contains("CacheRow")); } + + @Test + public void removeEmptyRows1() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, false, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows2() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, true, false, false, false}, 0); + assertEquals(1, t.select.size()); + assertEquals(1, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows3() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, false}, 0); + assertEquals(2, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(2, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1}), t.retOffset); + } + + @Test + public void removeEmptyRows4() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, true, true, false, false, true}, 0); + assertEquals(3, t.select.size()); + assertEquals(0, t.select.get(0)); + assertEquals(1, t.select.get(1)); + assertEquals(4, t.select.get(2)); + assertEquals(3, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0, 1, 2}), t.retOffset); + } + + @Test + public void removeEmptyRows5() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {0}), t.retOffset); + } + + @Test + public void removeEmptyRows6() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {2}), t.retOffset); + } + + @Test + public void removeEmptyRows7() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, true, true, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(2, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {3}), t.retOffset); + } + + @Test + public void removeEmptyRows8() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {true, false, false, false, false, true}, 0); + assertEquals(1, t.select.size()); + assertEquals(4, t.select.get(0)); + assertEquals(1, t.retOffset.getSize()); + assertEquals(OffsetFactory.createOffset(new int[] {1}), t.retOffset); + } + + @Test + public void removeEmptyRowsEmpty() { + AOffset of = OffsetFactory.createOffset(new int[] {1, 2, 3, 4, 5}); + RemoveEmptyOffsetsTmp t = of.removeEmptyRows(new boolean[] {false, false, false, false, false, false}, 0); + assertEquals(0, t.select.size()); + assertEquals(OffsetFactory.createOffset(new int[] {}), t.retOffset); + } } diff --git a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java index c6d52a70a51..872ec79c1f1 100644 --- a/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java +++ b/src/test/java/org/apache/sysds/test/functions/compress/configuration/CompressForce.java @@ -49,7 +49,7 @@ protected String getTestDir() { @Test public void testTranspose_CP() { - runTest(1500, 20, 1, 1, ExecType.CP, "transpose"); + runTest(1500, 20, 2, 1, ExecType.CP, "transpose"); } @Test diff --git a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java index ee6a2953980..18ca2fbc454 100644 --- a/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java +++ b/src/test/java/org/apache/sysds/test/functions/misc/ToStringTest.java @@ -270,4 +270,96 @@ protected void toStringTestHelper(ExecMode platform, String testName, String exp DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; } } + + @Test + public void testPrintWithDecimal(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "22"; + String expectedOutput = "22.00\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal2(){ + String testName = "ToString12"; + + String decimalPoints = "2"; + String value = "5.244058388023880"; + String expectedOutput = "5.24\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal3(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "5.244058388023880"; + String expectedOutput = "5.2440583880\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal4(){ + String testName = "ToString12"; + + String decimalPoints = "4"; + String value = "5.244058388023880"; + String expectedOutput = "5.2441\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + + @Test + public void testPrintWithDecimal5(){ + String testName = "ToString12"; + + String decimalPoints = "10"; + String value = "0.000000008023880"; + String expectedOutput = "0.0000000080\n"; + + addTestConfiguration(testName, new TestConfiguration(TEST_CLASS_DIR, testName)); + toStringTestHelper2(ExecMode.SINGLE_NODE, testName, expectedOutput, decimalPoints, value); + } + + protected void toStringTestHelper2(ExecMode platform, String testName, String expectedOutput, String decimalPoints, String value) { + ExecMode platformOld = rtplatform; + + rtplatform = platform; + boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; + if (rtplatform == ExecMode.SPARK) + DMLScript.USE_LOCAL_SPARK_CONFIG = true; + try { + // Create and load test configuration + getAndLoadTestConfiguration(testName); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testName + ".dml"; + programArgs = new String[]{"-args", output(OUTPUT_NAME), value, decimalPoints}; + + // Run DML and R scripts + runTest(true, false, null, -1); + + // Compare output strings + String output = TestUtils.readDMLString(output(OUTPUT_NAME)); + TestUtils.compareScalars(expectedOutput, output); + } + finally { + // Reset settings + rtplatform = platformOld; + DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld; + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java new file mode 100644 index 00000000000..30681f373e4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class GetCategoricalMaskTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName()); + + private final static String TEST_NAME1 = "GetCategoricalMaskTest"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeApplyTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testRecode() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 1, 1.0); + String spec = "{\"ids\": true, \"recode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testRecode2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1}); + + String spec = "{\"ids\": true, \"recode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash3() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 7, new double[] {1, 1, 1, 0, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + + @Test + public void testHybrid1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 9, new double[] {1, 1, 1, 0, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHybrid2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.BOOLEAN,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 10, new double[] {1, 1, 1, 1,1, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,2,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception { + try { + + getAndLoadTestConfiguration(TEST_NAME1); + + String inF = input("F-In"); + String inS = input("spec"); + + TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV); + TestUtils.writeTestScalar(input("spec"), spec); + + String out = output("ret"); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-args", inF, inS, out, expected.getNumColumns() + ""}; + + runTest(true, false, null, -1); + + MatrixBlock result = TestUtils.readBinary(out); + + TestUtils.compareMatrices(expected, result, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + +} diff --git a/src/test/scripts/functions/misc/ToString12.dml b/src/test/scripts/functions/misc/ToString12.dml new file mode 100644 index 00000000000..4f120630b75 --- /dev/null +++ b/src/test/scripts/functions/misc/ToString12.dml @@ -0,0 +1,24 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = matrix($2, rows=1, cols=1) +str = toString(X, rows=3, cols=3, decimal=$3) +write(str, $1) diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml new file mode 100644 index 00000000000..5d7bb35a250 --- /dev/null +++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($1, data_type="frame", format="csv"); + +jspec = read($2, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +Cm = getCategoricalMask(M, jspec) +expectedColumns = $4 +if(ncol(Cm) != expectedColumns){ + stop("Wrong number of metadata columns in categorical mask") +} +# print mean to verify that Cm is a matrix, not a Frame according to compiler +print(mean(Cm)) + +write(Cm, $3, format="csv"); +