Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/CentralMoment.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public String getInstructions(String input1, String input2, String input3, Strin
getInputs().get(2).prepScalarInputOperand(getExecType()),
prepOutputOperand(output));
}
if( getExecType() == ExecType.CP || getExecType() == ExecType.FED ) {
if(getExecType() == ExecType.CP || getExecType() == ExecType.FED || getExecType() == ExecType.OOC) {
sb.append(OPERAND_DELIMITOR);
sb.append(_numThreads);
if ( getExecType() == ExecType.FED ){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.OOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction;
Expand Down Expand Up @@ -69,6 +70,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str
return TransposeOOCInstruction.parseInstruction(str);
case Tee:
return TeeOOCInstruction.parseInstruction(str);
case CentralMoment:
return CentralMomentOOCInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.CommonThreadPool;

Expand All @@ -49,6 +50,12 @@ protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop
_aop = aop;
}

/**
 * Creates an instruction that carries a generic operator and three input
 * operands. No aggregate operator is kept for this variant, so
 * {@code _aop} is left as {@code null}.
 *
 * @param type   OOC instruction type
 * @param op     operator forwarded to the super constructor
 * @param in1    first input operand
 * @param in2    second input operand
 * @param in3    third input operand
 * @param out    output operand
 * @param opcode instruction opcode
 * @param istr   full instruction string
 */
protected AggregateUnaryOOCInstruction(
	OOCType type,
	Operator op,
	CPOperand in1,
	CPOperand in2,
	CPOperand in3,
	CPOperand out,
	String opcode,
	String istr)
{
	super(type, op, in1, in2, in3, out, opcode, istr);
	this._aop = null;
}

public static AggregateUnaryOOCInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
InstructionUtils.checkNumFields(parts, 2);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Central moment instruction that consumes its matrix input (and optional
 * weights) block by block from stream handles, computes one partial
 * {@link CM_COV_Object} per block, merges the partials with the operator's
 * function and writes the requested result as a scalar output.
 */
public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction {

	private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
		String opcode, String str) {
		super(OOCType.CM, cm, in1, in2, in3, out, opcode, str);
	}

	/**
	 * Parses the instruction string by delegating to the CP parser and
	 * re-wrapping the result as an OOC instruction.
	 *
	 * @param str instruction string
	 * @return the corresponding OOC instruction
	 */
	public static CentralMomentOOCInstruction parseInstruction(String str) {
		CentralMomentCPInstruction cpInst = CentralMomentCPInstruction.parseInstruction(str);
		return parseInstruction(cpInst);
	}

	/**
	 * Builds an OOC instruction from an already parsed CP instruction by
	 * copying its operator, operands, opcode and instruction string.
	 *
	 * @param inst parsed CP central moment instruction
	 * @return the corresponding OOC instruction
	 */
	public static CentralMomentOOCInstruction parseInstruction(CentralMomentCPInstruction inst) {
		return new CentralMomentOOCInstruction((CMOperator) inst.getOperator(), inst.input1, inst.input2, inst.input3,
			inst.output, inst.getOpcode(), inst.getInstructionString());
	}

	/**
	 * Streams the input blocks, aggregates the per-block partial results and
	 * sets the scalar output.
	 *
	 * @param ec execution context providing the inputs and receiving the output
	 * @throws DMLRuntimeException if block sizes of values and weights differ,
	 *         if blocks remain unmatched, if no block was received at all, if
	 *         the thread is interrupted while waiting on a stream, or if any
	 *         other failure occurs while processing blocks
	 */
	@Override
	public void processInstruction(ExecutionContext ec) {
		MatrixObject matObj = ec.getMatrixObject(input1.getName());
		LocalTaskQueue<IndexedMatrixValue> qIn = matObj.getStreamHandle();

		// With two inputs the order is the second operand, with three (weighted) it is the third.
		CPOperand scalarInput = (input3 == null ? input2 : input3);
		ScalarObject order = ec.getScalarInput(scalarInput);
		final CMOperator cmOp = resolveOperator(order);

		// Weighted variant: look up the weight stream up-front so that
		// configuration errors are reported before any block is consumed.
		LocalTaskQueue<IndexedMatrixValue> wIn = null;
		if(input3 != null) {
			MatrixObject wtObj = ec.getMatrixObject(input2.getName());

			DataCharacteristics dc = ec.getDataCharacteristics(input1.getName());
			DataCharacteristics dcW = ec.getDataCharacteristics(input2.getName());

			if(dc.getBlocksize() != dcW.getBlocksize())
				throw new DMLRuntimeException("Different block sizes are not yet supported");

			wIn = wtObj.getStreamHandle();
		}

		List<CM_COV_Object> partials = new ArrayList<>();
		try {
			if(wIn == null)
				collectUnweighted(qIn, cmOp, partials);
			else
				collectWeighted(qIn, wIn, cmOp, partials);
		}
		catch(DMLRuntimeException ex) {
			// Already carries a meaningful message (e.g. unmatched blocks); do not wrap again.
			throw ex;
		}
		catch(InterruptedException ex) {
			// Restore the interrupt flag so callers further up can observe the interruption.
			Thread.currentThread().interrupt();
			throw new DMLRuntimeException(ex);
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}

		CM_COV_Object total = partials.stream()
			.reduce((arg0, arg1) -> (CM_COV_Object) cmOp.fn.execute(arg0, arg1))
			.orElseThrow(() -> new DMLRuntimeException(
				"Central moment: no input blocks received for '" + input1.getName() + "'"));

		try {
			ec.setScalarOutput(output.getName(), new DoubleObject(total.getRequiredResult(cmOp)));
		}
		catch(DMLRuntimeException ex) {
			throw ex;
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}
	}

	/**
	 * Returns the operator to use for this execution.
	 *
	 * The "order" of the central moment in the instruction can be set to
	 * INVALID when the exact value is unknown at compilation time. In that
	 * case the operator is updated with the order read at runtime.
	 */
	private CMOperator resolveOperator(ScalarObject order) {
		CMOperator cmOp = (CMOperator) _optr;
		if(cmOp.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID)
			cmOp = cmOp.setCMAggOp((int) order.getLongValue());
		return cmOp;
	}

	/** Drains the value stream and adds one partial result per block to {@code partials}. */
	private static void collectUnweighted(LocalTaskQueue<IndexedMatrixValue> values, CMOperator cmOp,
		List<CM_COV_Object> partials) throws InterruptedException
	{
		IndexedMatrixValue blk;
		while((blk = values.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) {
			// We only handle MatrixBlock, other types of MatrixValue will fail here
			partials.add(((MatrixBlock) blk.getValue()).cmOperations(cmOp));
		}
	}

	/**
	 * Joins the value and weight streams on their block indexes (hash join) and
	 * adds one partial result per matched pair to {@code partials}.
	 *
	 * Note that this may keep blocks in the maps for a while, depending on
	 * when a matching block arrives in the other stream.
	 */
	private static void collectWeighted(LocalTaskQueue<IndexedMatrixValue> values,
		LocalTaskQueue<IndexedMatrixValue> weights, CMOperator cmOp, List<CM_COV_Object> partials)
		throws InterruptedException
	{
		Map<MatrixIndexes, MatrixValue> pendingValues = new HashMap<>();
		Map<MatrixIndexes, MatrixValue> pendingWeights = new HashMap<>();

		IndexedMatrixValue val = values.dequeueTask();
		IndexedMatrixValue wt = weights.dequeueTask();

		// Alternate between both streams until each one signalled its end.
		while(val != LocalTaskQueue.NO_MORE_TASKS || wt != LocalTaskQueue.NO_MORE_TASKS) {
			if(val != LocalTaskQueue.NO_MORE_TASKS) {
				MatrixValue matchingWeights = pendingWeights.remove(val.getIndexes());

				if(matchingWeights != null)
					partials.add(((MatrixBlock) val.getValue()).cmOperations(cmOp, (MatrixBlock) matchingWeights));
				else
					pendingValues.put(val.getIndexes(), val.getValue());

				val = values.dequeueTask();
			}

			if(wt != LocalTaskQueue.NO_MORE_TASKS) {
				MatrixValue matchingValues = pendingValues.remove(wt.getIndexes());

				if(matchingValues != null)
					partials.add(((MatrixBlock) matchingValues).cmOperations(cmOp, (MatrixBlock) wt.getValue()));
				else
					pendingWeights.put(wt.getIndexes(), wt.getValue());

				wt = weights.dequeueTask();
			}
		}

		if(!pendingValues.isEmpty() || !pendingWeights.isEmpty())
			throw new DMLRuntimeException(
				"Unmatched blocks: values=" + pendingValues.size() + ", weights=" + pendingWeights.size());
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP
output = out;
}

/**
 * Creates a computation instruction with three input operands and one
 * output operand.
 *
 * @param type   OOC instruction type
 * @param op     operator forwarded to the super constructor
 * @param in1    first input operand
 * @param in2    second input operand
 * @param in3    third input operand
 * @param out    output operand
 * @param opcode instruction opcode
 * @param istr   full instruction string
 */
protected ComputationOOCInstruction(
	OOCType type,
	Operator op,
	CPOperand in1,
	CPOperand in2,
	CPOperand in3,
	CPOperand out,
	String opcode,
	String istr)
{
	super(type, op, opcode, istr);
	this.input1 = in1;
	this.input2 = in2;
	this.input3 = in3;
	this.output = out;
}

/**
 * Returns the name of this instruction's output operand.
 *
 * @return output operand name
 */
public String getOutputVariableName() {
	return this.output.getName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction {
protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName());

/**
 * Type tags for OOC instructions; each instruction stores one of these
 * in its {@code _ooctype} field.
 */
public enum OOCType {
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg,
Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, CM
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.test.functions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.io.MatrixWriter;
import org.apache.sysds.runtime.io.MatrixWriterFactory;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
import org.junit.Test;

import java.io.IOException;
import java.util.HashMap;

/**
 * Runs the CentralMoment DML script once with the {@code -ooc} flag and once
 * without it on the same generated column vector, and compares both outputs.
 */
public class CentralMomentTest extends AutomatedTestBase {
	private final static String TEST_NAME1 = "CentralMoment";
	private final static String TEST_DIR = "functions/ooc/";
	private final static String TEST_CLASS_DIR = TEST_DIR + CentralMomentTest.class.getSimpleName() + "/";
	private final static double eps = 1e-8;
	private static final String INPUT_NAME = "X";
	private static final String OUTPUT_NAME = "res";

	private final static int rows = 1871;
	private final static int maxVal = 7;
	private final static double sparsity1 = 0.65;
	private final static double sparsity2 = 0.05;

	@Override
	public void setUp() {
		TestUtils.clearAssertionInformation();
		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
	}

	@Test
	public void testCentralMoment2Dense() {
		runCentralMomentTest(2, false);
	}

	@Test
	public void testCentralMoment3Dense() {
		runCentralMomentTest(3, false);
	}

	@Test
	public void testCentralMoment4Dense() {
		runCentralMomentTest(4, false);
	}

	@Test
	public void testCentralMoment2Sparse() {
		runCentralMomentTest(2, true);
	}

	@Test
	public void testCentralMoment3Sparse() {
		runCentralMomentTest(3, true);
	}

	@Test
	public void testCentralMoment4Sparse() {
		runCentralMomentTest(4, true);
	}

	/**
	 * Executes the script for the given moment order with and without OOC
	 * and checks that both runs produce the same result.
	 *
	 * @param order  order of the central moment passed to the script
	 * @param sparse if true the input is generated with the low sparsity setting
	 */
	private void runCentralMomentTest(int order, boolean sparse) {
		final Types.ExecMode previousMode = setExecMode(Types.ExecMode.SINGLE_NODE);

		try {
			getAndLoadTestConfiguration(TEST_NAME1);

			fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME1 + ".dml";
			programArgs = buildArgs(true, order, OUTPUT_NAME);

			// Generate the single-column input and persist it as a binary matrix with metadata.
			final double density = sparse ? sparsity2 : sparsity1;
			double[][] cells = getRandomMatrix(rows, 1, 1, maxVal, density, 7);
			writeBinaryInput(DataConverter.convertToMatrixBlock(cells));

			// First run: with the ooc flag.
			runTest(true, false, null, -1);

			Assert.assertTrue("OOC wasn't used for CentralMoment",
				heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM));

			// Second run: without the ooc flag, written to a separate target output.
			programArgs = buildArgs(false, order, OUTPUT_NAME + "_target");
			runTest(true, false, null, -1);

			// Both runs must agree within eps.
			HashMap<MatrixValue.CellIndex, Double> oocResult = readDMLMatrixFromOutputDir(OUTPUT_NAME);
			HashMap<MatrixValue.CellIndex, Double> targetResult = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target");
			TestUtils.compareMatrices(oocResult, targetResult, eps, "Ret-1", "Ret-2");
		}
		catch(IOException e) {
			throw new RuntimeException(e);
		}
		finally {
			resetExecMode(previousMode);
		}
	}

	/** Assembles the program arguments, optionally including the ooc flag. */
	private String[] buildArgs(boolean ooc, int order, String outName) {
		if(ooc) {
			return new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME),
				Integer.toString(order), output(outName)};
		}
		return new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), Integer.toString(order),
			output(outName)};
	}

	/** Writes the block as a binary matrix to the test input location together with its metadata file. */
	private void writeBinaryInput(MatrixBlock block) throws IOException {
		final int blen = 1000;
		MatrixWriter binaryWriter = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY);
		binaryWriter.writeMatrixToHDFS(block, input(INPUT_NAME), rows, 1, blen, block.getNonZeros());
		HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64,
			new MatrixCharacteristics(rows, 1, blen, block.getNonZeros()), Types.FileFormat.BINARY);
	}
}
Loading
Loading