From 3f4e95901b43bd5f734e7c2b1594d3cd702aa9b0 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 23 Oct 2025 14:24:23 +0200 Subject: [PATCH 1/4] [SYSTEMDS-3927] Out-of-core centralMoment operations --- .../org/apache/sysds/lops/CentralMoment.java | 2 +- .../instructions/OOCInstructionParser.java | 3 + .../ooc/AggregateUnaryOOCInstruction.java | 7 + .../ooc/CentralMomentOOCInstruction.java | 150 ++++++++++++++++++ .../ooc/ComputationOOCInstruction.java | 8 + .../instructions/ooc/OOCInstruction.java | 2 +- .../test/functions/ooc/CentralMomentTest.java | 142 +++++++++++++++++ .../ooc/CentralMomentWeightsTest.java | 148 +++++++++++++++++ .../scripts/functions/ooc/CentralMoment.dml | 25 +++ .../functions/ooc/CentralMomentWeights.dml | 26 +++ 10 files changed, 511 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java create mode 100644 src/test/scripts/functions/ooc/CentralMoment.dml create mode 100644 src/test/scripts/functions/ooc/CentralMomentWeights.dml diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java b/src/main/java/org/apache/sysds/lops/CentralMoment.java index f2048f7e5cc..b8907fd79f3 100644 --- a/src/main/java/org/apache/sysds/lops/CentralMoment.java +++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java @@ -97,7 +97,7 @@ public String getInstructions(String input1, String input2, String input3, Strin getInputs().get(2).prepScalarInputOperand(getExecType()), prepOutputOperand(output)); } - if( getExecType() == ExecType.CP || getExecType() == ExecType.FED ) { + if(getExecType() == ExecType.CP || getExecType() == ExecType.FED || getExecType() == ExecType.OOC) { sb.append(OPERAND_DELIMITOR); sb.append(_numThreads); if ( getExecType() == ExecType.FED ){ diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index c98301dcc60..9c0f0f2e0f4 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.AggregateUnaryOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.BinaryOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.CentralMomentOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; import org.apache.sysds.runtime.instructions.ooc.ReblockOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TSMMOOCInstruction; @@ -69,6 +70,8 @@ public static OOCInstruction parseSingleInstruction(InstructionType ooctype, Str return TransposeOOCInstruction.parseInstruction(str); case Tee: return TeeOOCInstruction.parseInstruction(str); + case CentralMoment: + return CentralMomentOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index b71cdaaeb51..c01fb3fa376 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.Operator; import 
org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -49,6 +50,12 @@ protected AggregateUnaryOOCInstruction(OOCType type, AggregateUnaryOperator auop _aop = aop; } + protected AggregateUnaryOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, + CPOperand out, String opcode, String istr) { + super(type, op, in1, in2, in3, out, opcode, istr); + _aop = null; + } + public static AggregateUnaryOOCInstruction parseInstruction(String str) { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); InstructionUtils.checkNumFields(parts, 2); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java new file mode 100644 index 00000000000..49fc2b71f43 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.cp.*; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.runtime.matrix.operators.CMOperator; + +import java.util.*; + +public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction { + + private CentralMomentOOCInstruction(CMOperator cm, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, + String opcode, String str) { + super(OOCType.CM, cm, in1, in2, in3, out, opcode, str); + } + + public static CentralMomentOOCInstruction parseInstruction(String str) { + CentralMomentCPInstruction cpInst = CentralMomentCPInstruction.parseInstruction(str); + return parseInstruction(cpInst); + } + + public static CentralMomentOOCInstruction parseInstruction(CentralMomentCPInstruction inst) { + return new CentralMomentOOCInstruction((CMOperator) inst.getOperator(), inst.input1, inst.input2, inst.input3, + inst.output, inst.getOpcode(), inst.getInstructionString()); + } + + @Override + public void processInstruction(ExecutionContext ec) { + String output_name = output.getName(); + + /* + * The "order" of the central moment in the instruction can + * be set to INVALID when the exact value is unknown at + * compilation time. We first need to determine the exact + * order and update the CMOperator, if needed. + */ + + MatrixObject matObj = ec.getMatrixObject(input1.getName()); + LocalTaskQueue qIn = matObj.getStreamHandle(); + + CPOperand scalarInput = (input3 == null ? 
input2 : input3); + ScalarObject order = ec.getScalarInput(scalarInput); + + CMOperator cm_op = ((CMOperator) _optr); + if(cm_op.getAggOpType() == CMOperator.AggregateOperationTypes.INVALID) + cm_op = cm_op.setCMAggOp((int) order.getLongValue()); + + CMOperator finalCm_op = cm_op; + + List cmObjs = new ArrayList<>(); + + if(input3 == null) { + try { + IndexedMatrixValue tmp = null; + + while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { + // We only handle MatrixBlock, other types of MatrixValue will fail here + cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op)); + } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + } + else { + // Here we use a hash join approach + // Note that this may keep blocks in the cache for a while, depending on when a matching block arrives in the stream + MatrixObject wtObj = ec.getMatrixObject(input2.getName()); + LocalTaskQueue wIn = wtObj.getStreamHandle(); + + try { + IndexedMatrixValue tmp = qIn.dequeueTask(); + IndexedMatrixValue tmpW = wIn.dequeueTask(); + Map left = new HashMap<>(); + Map right = new HashMap<>(); + + boolean cont = tmp != LocalTaskQueue.NO_MORE_TASKS || tmpW != LocalTaskQueue.NO_MORE_TASKS; + + while(cont) { + cont = false; + + if(tmp != LocalTaskQueue.NO_MORE_TASKS) { + MatrixValue weights = right.remove(tmp.getIndexes()); + + if(weights != null) + cmObjs.add(((MatrixBlock) tmp.getValue()).cmOperations(cm_op, (MatrixBlock) weights)); + else + left.put(tmp.getIndexes(), tmp.getValue()); + + tmp = qIn.dequeueTask(); + cont = tmp != LocalTaskQueue.NO_MORE_TASKS; + } + + if(tmpW != LocalTaskQueue.NO_MORE_TASKS) { + MatrixValue q = left.remove(tmpW.getIndexes()); + + if(q != null) + cmObjs.add(((MatrixBlock) q).cmOperations(cm_op, (MatrixBlock) tmpW.getValue())); + else + right.put(tmpW.getIndexes(), tmpW.getValue()); + + tmpW = wIn.dequeueTask(); + cont |= tmpW != LocalTaskQueue.NO_MORE_TASKS; + } + } + + if (!left.isEmpty() || !right.isEmpty()) + throw new 
DMLRuntimeException("Unmatched blocks: values=" + left.size() + ", weights=" + right.size()); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + } + + Optional res = cmObjs.stream() + .reduce((arg0, arg1) -> (CM_COV_Object) finalCm_op.fn.execute(arg0, arg1)); + + try { + ec.setScalarOutput(output_name, new DoubleObject(res.get().getRequiredResult(finalCm_op))); + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java index 5552017493e..4dcdffcb0dc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java @@ -42,6 +42,14 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP output = out; } + protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = in2; + input3 = in3; + output = out; + } + public String getOutputVariableName() { return output.getName(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index ff9046d490b..5b1c7666612 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -33,7 +33,7 @@ public abstract class OOCInstruction extends Instruction { protected static final Log LOG = LogFactory.getLog(OOCInstruction.class.getName()); public enum OOCType { - Reblock, Tee, Binary, Unary, AggregateUnary, AggregateBinary, MAPMM, MMTSJ, Reorg, + Reblock, Tee, Binary, Unary, AggregateUnary, 
AggregateBinary, MAPMM, MMTSJ, Reorg, CM } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java new file mode 100644 index 00000000000..00b880e749b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; + +public class CentralMomentTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "CentralMoment"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CentralMomentTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1871; + private final static int maxVal = 7; + private final static double sparsity1 = 0.65; + private final static double sparsity2 = 0.05; + private final int k = 1; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testCentralMoment2Dense() { + runCentralMomentTest(2, false); + } + + @Test + public void testCentralMoment3Dense() { + runCentralMomentTest(3, false); + } + + @Test + public void testCentralMoment4Dense() { + runCentralMomentTest(4, false); + } + + @Test + public void testCentralMoment2Sparse() { + 
runCentralMomentTest(2, true); + } + + @Test + public void testCentralMoment3Sparse() { + runCentralMomentTest(3, true); + } + + @Test + public void testCentralMoment4Sparse() { + runCentralMomentTest(4, true); + } + + private void runCentralMomentTest(int order, boolean sparse) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), + Integer.toString(order), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, 1, 1000, A_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, 1, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + //check Central Moment OOC + Assert.assertTrue("OOC wasn't used for CentralMoment", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), Integer.toString(order), + output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + HashMap ret1 = readDMLMatrixFromOutputDir(OUTPUT_NAME); + HashMap ret2 = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target"); + TestUtils.compareMatrices(ret1, ret2, eps, "Ret-1", "Ret-2"); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java new file mode 100644 index 00000000000..2e8293b2691 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; + +public class CentralMomentWeightsTest extends AutomatedTestBase { + private final static String TEST_NAME1 = "CentralMomentWeights"; + private final static String TEST_DIR = "functions/ooc/"; + private final static String TEST_CLASS_DIR = TEST_DIR + CentralMomentWeightsTest.class.getSimpleName() + "/"; + private final static double eps = 1e-8; + private static final String INPUT_NAME = "X"; + private static final String INPUT_NAME_W = "W"; + private static final String OUTPUT_NAME = "res"; + + private final static int rows = 1871; + private final static int maxVal = 7; + private final static double sparsity1 = 0.65; + private final static double sparsity2 = 0.05; + private final int k = 1; + + @Override + public void setUp() { + 
TestUtils.clearAssertionInformation(); + TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1); + addTestConfiguration(TEST_NAME1, config); + } + + @Test + public void testCentralMoment2Dense() { + runCentralMomentTest(2, false); + } + + @Test + public void testCentralMoment3Dense() { + runCentralMomentTest(3, false); + } + + @Test + public void testCentralMoment4Dense() { + runCentralMomentTest(4, false); + } + + @Test + public void testCentralMoment2Sparse() { + runCentralMomentTest(2, true); + } + + @Test + public void testCentralMoment3Sparse() { + runCentralMomentTest(3, true); + } + + @Test + public void testCentralMoment4Sparse() { + runCentralMomentTest(4, true); + } + + private void runCentralMomentTest(int order, boolean sparse) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME1); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", input(INPUT_NAME), input(INPUT_NAME_W), + Integer.toString(order), output(OUTPUT_NAME)}; + + // 1. Generate the data in-memory as MatrixBlock objects + double[][] A_data = getRandomMatrix(rows, 1, 1, maxVal, sparse ? sparsity2 : sparsity1, 7); + double[][] W_data = getRandomMatrix(rows, 1, 0, 1, 1.0, 7); + + // 2. Convert the double arrays to MatrixBlock objects + MatrixBlock A_mb = DataConverter.convertToMatrixBlock(A_data); + MatrixBlock W_mb = DataConverter.convertToMatrixBlock(W_data); + + // 3. Create a binary matrix writer + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + + // 4. 
Write matrix A to a binary SequenceFile + writer.writeMatrixToHDFS(A_mb, input(INPUT_NAME), rows, 1, 1000, A_mb.getNonZeros()); + writer.writeMatrixToHDFS(W_mb, input(INPUT_NAME_W), rows, 1, 1000, W_mb.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, 1, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_W + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, 1, 1000, A_mb.getNonZeros()), Types.FileFormat.BINARY); + + runTest(true, false, null, -1); + + //check Central Moment OOC + Assert.assertTrue("OOC wasn't used for CentralMoment", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.CM)); + + //compare results + + // rerun without ooc flag + programArgs = new String[] {"-explain", "-stats", "-args", input(INPUT_NAME), input(INPUT_NAME_W), + Integer.toString(order), output(OUTPUT_NAME + "_target")}; + runTest(true, false, null, -1); + + // compare matrices + HashMap ret1 = readDMLMatrixFromOutputDir(OUTPUT_NAME); + HashMap ret2 = readDMLMatrixFromOutputDir(OUTPUT_NAME + "_target"); + TestUtils.compareMatrices(ret1, ret2, eps, "Ret-1", "Ret-2"); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/CentralMoment.dml b/src/test/scripts/functions/ooc/CentralMoment.dml new file mode 100644 index 00000000000..d5fda5a6a9d --- /dev/null +++ b/src/test/scripts/functions/ooc/CentralMoment.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +A = read($1); +s = moment(A, $2); +m = as.matrix(s); + +write(m, $3, format="text"); \ No newline at end of file diff --git a/src/test/scripts/functions/ooc/CentralMomentWeights.dml b/src/test/scripts/functions/ooc/CentralMomentWeights.dml new file mode 100644 index 00000000000..a8f24d45686 --- /dev/null +++ b/src/test/scripts/functions/ooc/CentralMomentWeights.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +A = read($1); +W = read($2); +s = moment(A, W, $3); +m = as.matrix(s); + +write(m, $4, format="text"); \ No newline at end of file From f63d25fa22d6c9136f9f3c5da806ef6aab6ff05f Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 23 Oct 2025 14:26:48 +0200 Subject: [PATCH 2/4] Removing redundancies --- .../runtime/instructions/ooc/CentralMomentOOCInstruction.java | 2 +- .../org/apache/sysds/test/functions/ooc/CentralMomentTest.java | 1 - .../sysds/test/functions/ooc/CentralMomentWeightsTest.java | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java index 49fc2b71f43..8c2039bd928 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java @@ -76,7 +76,7 @@ public void processInstruction(ExecutionContext ec) { if(input3 == null) { try { - IndexedMatrixValue tmp = null; + IndexedMatrixValue tmp; while((tmp = qIn.dequeueTask()) != LocalTaskQueue.NO_MORE_TASKS) { // We only handle MatrixBlock, other types of MatrixValue will fail here diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java index 00b880e749b..79f05421ad5 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentTest.java @@ -50,7 +50,6 @@ public class CentralMomentTest extends AutomatedTestBase { private final static int maxVal = 7; private final static double sparsity1 = 0.65; private final static double sparsity2 = 0.05; - private final int k = 1; @Override public void setUp() { diff --git 
a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java index 2e8293b2691..994f84526f9 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CentralMomentWeightsTest.java @@ -51,7 +51,6 @@ public class CentralMomentWeightsTest extends AutomatedTestBase { private final static int maxVal = 7; private final static double sparsity1 = 0.65; private final static double sparsity2 = 0.05; - private final int k = 1; @Override public void setUp() { From 009aacb8c1694e763aa7b890ee901bf196fbd6e2 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 23 Oct 2025 16:04:57 +0200 Subject: [PATCH 3/4] Ensure matching block-sizes --- .../instructions/ooc/CentralMomentOOCInstruction.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java index 8c2039bd928..c1bf7af2e9a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java @@ -29,6 +29,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.matrix.data.MatrixValue; import org.apache.sysds.runtime.matrix.operators.CMOperator; +import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.*; @@ -91,6 +92,13 @@ public void processInstruction(ExecutionContext ec) { // Here we use a hash join approach // Note that this may keep blocks in the cache for a while, depending on when a matching block arrives in the stream MatrixObject wtObj = ec.getMatrixObject(input2.getName()); + + DataCharacteristics dc = ec.getDataCharacteristics(input1.getName()); + 
DataCharacteristics dcW = ec.getDataCharacteristics(input2.getName()); + + if (dc.getBlocksize() != dcW.getBlocksize()) + throw new DMLRuntimeException("Different block sizes are not yet supported"); + LocalTaskQueue wIn = wtObj.getStreamHandle(); try { From cd166afca4c039f62db37c65fe93039646574935 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 23 Oct 2025 16:20:27 +0200 Subject: [PATCH 4/4] Fix wildcard imports --- .../ooc/CentralMomentOOCInstruction.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java index c1bf7af2e9a..9c122662c2c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CentralMomentOOCInstruction.java @@ -23,7 +23,11 @@ import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; -import org.apache.sysds.runtime.instructions.cp.*; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction; +import org.apache.sysds.runtime.instructions.cp.DoubleObject; +import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; @@ -31,7 +35,11 @@ import org.apache.sysds.runtime.matrix.operators.CMOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; -import java.util.*; +import java.util.ArrayList; +import java.util.HashMap; +import 
java.util.List; +import java.util.Map; +import java.util.Optional; public class CentralMomentOOCInstruction extends AggregateUnaryOOCInstruction {