From b5f20a473110edc9f51f7384b02f1484eb18aea0 Mon Sep 17 00:00:00 2001 From: min-guk Date: Mon, 18 Nov 2024 21:31:47 +0900 Subject: [PATCH 1/5] Federated Planner MemoTable --- .../sysds/hops/fedplanner/MemoTable.java | 155 +++++++++++++++ .../federated/privacy/MemoTableTest.java | 183 ++++++++++++++++++ 2 files changed, 338 insertions(+) create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java create mode 100644 src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java new file mode 100644 index 00000000000..c7cbb1dd9a8 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; + +/** + * A Memoization Table for managing federated plans (`FedPlan`) based on + * combinations of Hops and FTypes. Each combination is mapped to a list + * of possible execution plans, allowing for pruning and optimization. + */ +public class MemoTable { + /** + * Represents a federated execution plan with its cost and associated references. + */ + public static class FedPlan { + private final Hop hopRef; // The associated Hop object + private final double cost; // Cost of this federated plan + private final List> planRefs; // References to dependent plans + + public FedPlan(Hop hopRef, double cost, List> planRefs) { + this.hopRef = hopRef; + this.cost = cost; + this.planRefs = planRefs; + } + + public double getCost() {return cost;} + } + + // Maps combinations of Hop ID and FType to lists of FedPlans + private final Map, List> hopMemoTable = new HashMap<>(); + + /** + * Adds a single FedPlan to the memo table for a given Hop and FType. + * If the entry already exists, the new FedPlan is appended to the list. + * + * @param hop The Hop object. + * @param fType The associated FType. + * @param fedPlan The FedPlan to add. + */ + public void addFedPlan(Hop hop, FTypes.FType fType, FedPlan fedPlan) { + if (contains(hop, fType)) { + List fedPlanList = get(hop, fType); + fedPlanList.add(fedPlan); + } else { + List fedPlanList = new ArrayList<>(); + fedPlanList.add(fedPlan); + hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); + } + } + + /** + * Adds multiple FedPlans to the memo table for a given Hop and FType. + * If the entry already exists, the new FedPlans are appended to the list. + * + * @param hop The Hop object. + * @param fType The associated FType. + * @param newFedPlanList The list of FedPlans to add. + */ + public void addFedPlanList(Hop hop, FTypes.FType fType, List fedPlanList) { + if (contains(hop, fType)) { + List prevFedPlanList = get(hop, fType); + prevFedPlanList.addAll(fedPlanList); + } else { + assert !fedPlanList.isEmpty() : "FedPlan list should not be empty"; + hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); + } + } + + /** + * Retrieves the list of FedPlans associated with a given Hop and FType. + * + * @param hop The Hop object. + * @param fType The associated FType. + * @return The list of FedPlans, or null if no entry exists. + */ + public List get(Hop hop, FTypes.FType fType) { + return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType)); + } + + /** + * Checks if the memo table contains an entry for a given Hop and FType. + * + * @param hop The Hop object. + * @param fType The associated FType. + * @return True if the entry exists, false otherwise. + */ + public boolean contains(Hop hop, FTypes.FType fType) { + return hopMemoTable.containsKey(new ImmutablePair<>(hop.getHopID(), fType)); + } + + /** + * Prunes the FedPlans associated with a specific Hop and FType, + * keeping only the plan with the minimum cost. + * + * @param hop The Hop object. + * @param fType The associated FType. + */ + public void prunePlan(Hop hop, FTypes.FType fType) { + prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType))); + } + + /** + * Prunes all entries in the memo table, retaining only the minimum-cost + * FedPlan for each entry. + */ + public void pruneAll() { + for (Map.Entry, List> entry : hopMemoTable.entrySet()) { + prunePlan(entry.getValue()); + } + } + + /** + * Prunes the given list of FedPlans to retain only the plan with the minimum cost. + * + * @param fedPlanList The list of FedPlans to prune. + */ + private void prunePlan(List fedPlanList) { + if (fedPlanList.size() > 1) { + // Find the FedPlan with the minimum cost + FedPlan minCostPlan = fedPlanList.stream() + .min(Comparator.comparingDouble(plan -> plan.cost)) + .orElse(null); + + // Retain only the minimum cost plan + fedPlanList.clear(); + fedPlanList.add(minCostPlan); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java new file mode 100644 index 00000000000..5e501c1257d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.federated.privacy; + +import static org.junit.Assert.*; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.fedplanner.FTypes; +import org.apache.sysds.hops.fedplanner.MemoTable; +import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.Mockito.when; + +public class MemoTableTest { + + private MemoTable memoTable; + + @Mock + private Hop mockHop1; + + @Mock + private Hop mockHop2; + + private java.util.Random rand; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + memoTable = new MemoTable(); + + // Set up unique IDs for mock Hops + when(mockHop1.getHopID()).thenReturn(1L); + when(mockHop2.getHopID()).thenReturn(2L); + + // Initialize random generator with fixed seed for reproducible tests + rand = new java.util.Random(42); + } + + @Test + public void testAddAndGetSingleFedPlan() { + // Initialize test data + List> planRefs = new ArrayList<>(); + FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); + + // Verify initial state + List result = memoTable.get(mockHop1, FTypes.FType.FULL); + assertNull("Initial FedPlan list should be null before adding any plans", result); + + // Add single FedPlan + memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); + + // Verify after addition + result = memoTable.get(mockHop1, FTypes.FType.FULL); + assertNotNull("FedPlan list should exist after adding a plan", result); + assertEquals("FedPlan list should contain exactly one plan", 1, result.size()); + assertEquals("FedPlan cost should be exactly 10.0", 10.0, result.get(0).getCost(), 0.001); + } + + @Test + public void testAddMultipleDuplicatedFedPlans() { + // Initialize test data with duplicate costs + List> planRefs = new ArrayList<>(); + List fedPlans = new ArrayList<>(); + fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique cost + fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First duplicate + fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second duplicate + + // Add multiple plans including duplicates + memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans); + + // Verify handling of duplicate plans + List result = memoTable.get(mockHop1, FTypes.FType.FULL); + assertNotNull("FedPlan list should exist after adding multiple plans", result); + assertEquals("FedPlan list should maintain all plans including duplicates", 3, result.size()); + } + + @Test + public void testContains() { + // Initialize test data + List> planRefs = new ArrayList<>(); + FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); + + // Verify initial state + assertFalse("MemoTable should not contain any entries initially", + memoTable.contains(mockHop1, FTypes.FType.FULL)); + + // Add plan and verify presence + memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); + + assertTrue("MemoTable should contain entry after adding FedPlan", + memoTable.contains(mockHop1, FTypes.FType.FULL)); + assertFalse("MemoTable should not contain entries for different Hop", + memoTable.contains(mockHop2, FTypes.FType.FULL)); + } + + @Test + public void testPrunePlanPruneAll() { + // Initialize base test data + List> planRefs = new ArrayList<>(); + // Create separate FedPlan lists for independent testing of each Hop + List fedPlans1 = new ArrayList<>(); // Plans for mockHop1 + List fedPlans2 = new ArrayList<>(); // Plans for mockHop2 + + // Generate random cost FedPlans for both Hops + double minCost = Double.MAX_VALUE; + int size = 100; + for(int i = 0; i < size; i++) { + double cost = rand.nextDouble() * 1000; // Random cost between 0 and 1000 + fedPlans1.add(new FedPlan(mockHop1, cost, planRefs)); + fedPlans2.add(new FedPlan(mockHop2, cost, planRefs)); + minCost = Math.min(minCost, cost); + } + + // Add FedPlan lists to MemoTable + memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans1); + memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, fedPlans2); + + // Test selective pruning on mockHop1 + memoTable.prunePlan(mockHop1, FTypes.FType.FULL); + + // Get results for verification + List result1 = memoTable.get(mockHop1, FTypes.FType.FULL); + List result2 = memoTable.get(mockHop2, FTypes.FType.FULL); + + // Verify selective pruning results + assertNotNull("Pruned mockHop1 should maintain a FedPlan list", result1); + assertEquals("Pruned mockHop1 should contain exactly one minimum cost plan", 1, result1.size()); + assertEquals("Pruned mockHop1's plan should have the minimum cost", minCost, result1.get(0).getCost(), 0.001); + + // Verify unpruned Hop state + assertNotNull("Unpruned mockHop2 should maintain a FedPlan list", result2); + assertEquals("Unpruned mockHop2 should maintain all original plans", size, result2.size()); + + // Add additional plans to both Hops + for(int i = 0; i < size; i++) { + double cost = rand.nextDouble() * 1000; + memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new FedPlan(mockHop1, cost, planRefs)); + memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new FedPlan(mockHop2, cost, planRefs)); + minCost = Math.min(minCost, cost); + } + + // Test global pruning + memoTable.pruneAll(); + + // Verify global pruning results + assertNotNull("mockHop1 should maintain a FedPlan list after global pruning", result1); + assertEquals("mockHop1 should contain exactly one minimum cost plan after global pruning", + 1, result1.size()); + assertEquals("mockHop1's plan should have the global minimum cost", + minCost, result1.get(0).getCost(), 0.001); + + assertNotNull("mockHop2 should maintain a FedPlan list after global pruning", result2); + assertEquals("mockHop2 should contain exactly one minimum cost plan after global pruning", + 1, result2.size()); + assertEquals("mockHop2's plan should have the global minimum cost", + minCost, result2.get(0).getCost(), 0.001); + } +} From fc8e32b16755338d9c79099261e362dd1f4ffb51 Mon Sep 17 00:00:00 2001 From: min-guk Date: Thu, 28 Nov 2024 21:22:46 +0900 Subject: [PATCH 2/5] Implement FederatedMemoTable, FederatedPlanCostEstimator, FederatedPlanCostEnumerator --- .../hops/fedplanner/FederatedMemoTable.java | 175 ++++++++++++++++++ .../FederatedPlanCostEnumerator.java | 104 +++++++++++ .../FederatedPlanCostEstimator.java | 103 +++++++++++ .../sysds/hops/fedplanner/MemoTable.java | 160 ---------------- 4 files changed, 382 insertions(+), 160 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java delete mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java new file mode 100644 index 00000000000..0d7e4876e69 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.Hop; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.ArrayList; +import java.util.Map; + +/** + * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. + * This table stores and manages different execution plan variants for each Hop and fedOutType combination, + * facilitating the optimization of federated execution plans. + */ +public class FederatedMemoTable { + /** + * Represents a collection of federated execution plan variants for a specific Hop. + * Contains cost information and references to the associated plans. + */ + public static class FedPlanVariants { + protected final Hop hopRef; // Reference to the associated Hop + protected double currentCost; // Current execution cost (compute + memory access) + protected double netTransferCost; // Network transfer cost + protected List _fedPlanVariants; // List of plan variants + + public FedPlanVariants(Hop hopRef) { + this.hopRef = hopRef; + this.currentCost = 0; + this.netTransferCost = 0; + this._fedPlanVariants = new ArrayList<>(); + } + + public void add(FedPlan fedPlan) { + _fedPlanVariants.add(fedPlan); + } + + public int size() {return _fedPlanVariants.size();} + + public FedPlan get(int index) {return _fedPlanVariants.get(index);} + + public List getFedPlanVariants() {return _fedPlanVariants;} + } + + /** + * Represents a single federated execution plan with its associated costs and dependencies. + * Contains: + * 1. currentCost: Cost of current hop (compute + input/output memory access) + * 2. cumulativeCost: Total cost including this plan and all child plans + * 3. netTransferCost: Network transfer cost for this plan + */ + public static class FedPlan { + private double cumulativeCost; // Total cost including child plans + private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) + private final FedPlanVariants fedPlanVariantList; // Reference to variant list + private List> metaChildFedPlans; // Child plan references + private List selectedFedPlans; // Selected child plans + + public FedPlan(FederatedOutput fedOutType, List> planChilds, FedPlanVariants fedPlanVariants) { + this.fedOutType = fedOutType; + this.cumulativeCost = 0; + this.metaChildFedPlans = planChilds; + this.selectedFedPlans = new ArrayList<>(); + this.fedPlanVariantList = fedPlanVariants; + } + + public Hop getHopRef() {return fedPlanVariantList.hopRef;} + + public FederatedOutput getFedOutType() {return fedOutType;} + + public double getCurrentCost() {return fedPlanVariantList.currentCost;} + + public double getNetTransferCost() {return fedPlanVariantList.netTransferCost;} + + public double getCumulativeCost() {return cumulativeCost;} + + /** + * Calculates the cost from parent's perspective based on output type compatibility. + * Returns cumulative cost if output types match, otherwise adds network transfer cost. + */ + public double getParentViewCost(FederatedOutput parentFedOutType) { + if (parentFedOutType == fedOutType){ + return cumulativeCost; + } + return cumulativeCost + fedPlanVariantList.netTransferCost; + } + + public List> getMetaChildFedPlans() {return metaChildFedPlans;} + + public void setCurrentCost(double currentCost) {fedPlanVariantList.currentCost = currentCost;} + + public void setNetTransferCost(double netTransferCost) {fedPlanVariantList.netTransferCost = netTransferCost;} + + public void setCumulativeCost(double cumulativeCost) {this.cumulativeCost = cumulativeCost;} + + public void putChildFedPlan(FedPlan childFedPlan) {selectedFedPlans.add(childFedPlan);} + } + + // Maps Hop ID and fedOutType pairs to their plan variants + private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); + + /** + * Adds a new federated plan to the memo table. + * Creates a new variant list if none exists for the given Hop and fedOutType. + * + * @param hop The Hop node + * @param fedOutType The federated output type + * @param planChilds List of child plan references + * @return The newly created FedPlan + */ + public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List> planChilds) { + long hopID = hop.getHopID(); + FedPlanVariants fedPlanVariantList; + + if (contains(hopID, fedOutType)) { + fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } else { + fedPlanVariantList = new FedPlanVariants(hop); + hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); + } + + FedPlan newPlan = new FedPlan(fedOutType, planChilds, fedPlanVariantList); + fedPlanVariantList.add(newPlan); + + return newPlan; + } + + /** + * Retrieves the minimum cost child plan considering the parent's output type. + * The cost is calculated using getParentViewCost to account for potential type mismatches. + */ + public FedPlan getMinCostChildFedPlan(long childHopID, FederatedOutput childFedOutType, FederatedOutput currentFedOutType) { + FedPlanVariants fedPlanVariantList = hopMemoTable.get(new ImmutablePair<>(childHopID, childFedOutType)); + return fedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(plan -> plan.getParentViewCost(currentFedOutType))) + .orElse(null); + } + + public FedPlanVariants getFedPlanVariantList(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); + } + + /** + * Checks if the memo table contains an entry for a given Hop and fedOutType. + * + * @param hopID The Hop ID. + * @param fedOutType The associated fedOutType. + * @return True if the entry exists, false otherwise. + */ + public boolean contains(long hopID, FederatedOutput fedOutType) { + return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java new file mode 100644 index 00000000000..da9b7891417 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -0,0 +1,104 @@ +package org.apache.sysds.hops.fedplanner; +import java.util.ArrayList; +import java.util.List; +import java.util.Comparator; +import java.util.Objects; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlanVariants; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +/** + * Enumerates and evaluates all possible federated execution plans for a given Hop DAG. + * Works with FederatedMemoTable to store plan variants and FederatedPlanCostEstimator + * to compute their costs. + */ +public class FederatedPlanCostEnumerator { + /** + * Entry point for federated plan enumeration. Creates a memo table and returns + * the minimum cost plan for the entire DAG. + */ + public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { + // Create new memo table to store all plan variants + FederatedMemoTable memoTable = new FederatedMemoTable(); + + // Recursively enumerate all possible plans + enumerateFederatedPlanCost(rootHop, memoTable); + + // Return the minimum cost plan for the root node + + return getMinCostRootFedPlan(rootHop.getHopID(), memoTable); + } + + /** + * Recursively enumerates all possible federated execution plans for a Hop DAG. + * For each node: + * 1. First processes all input nodes recursively if not already processed + * 2. Generates all possible combinations of federation types (FOUT/LOUT) for inputs + * 3. Creates and evaluates both FOUT and LOUT variants for current node with each input combination + * + * The enumeration uses a bottom-up approach where: + * - Each input combination is represented by a binary number (i) + * - Bit j in i determines whether input j is FOUT (1) or LOUT (0) + * - Total number of combinations is 2^numInputs + */ + private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) { + int numInputs = hop.getInput().size(); + + // Process all input nodes first if not already in memo table + for (Hop inputHop : hop.getInput()) { + if (!memoTable.contains(inputHop.getHopID(), FederatedOutput.FOUT) + && !memoTable.contains(inputHop.getHopID(), FederatedOutput.LOUT)) { + enumerateFederatedPlanCost(inputHop, memoTable); + } + } + + // Generate all possible input combinations using binary representation + // i represents a specific combination of FOUT/LOUT for inputs + for (int i = 0; i < (1 << numInputs); i++) { + List> planChilds = new ArrayList<>(); + + // For each input, determine if it should be FOUT or LOUT based on bit j in i + for (int j = 0; j < numInputs; j++) { + Hop inputHop = hop.getInput().get(j); + // If bit j is set (1), use FOUT; otherwise use LOUT + FederatedOutput childType = ((i & (1 << j)) != 0) ? + FederatedOutput.FOUT : FederatedOutput.LOUT; + planChilds.add(Pair.of(inputHop.getHopID(), childType)); + } + + // Create and evaluate FOUT variant for current input combination + FedPlan fOutPlan = memoTable.addFedPlan(hop, FederatedOutput.FOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable); + + // Create and evaluate LOUT variant for current input combination + FedPlan lOutPlan = memoTable.addFedPlan(hop, FederatedOutput.LOUT, planChilds); + FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable); + } + } + + /** + * Returns the minimum cost plan for the root Hop, comparing both FOUT and LOUT variants. + * Used to select the final execution plan after enumeration. + */ + private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { + FedPlanVariants fOutFedPlanVariantList = memoTable.getFedPlanVariantList(HopID, FederatedOutput.FOUT); + FedPlanVariants lOutFedPlanVariantList = memoTable.getFedPlanVariantList(HopID, FederatedOutput.LOUT); + + FedPlan minFOutFedPlan = fOutFedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + .orElse(null); + FedPlan minlOutFedPlan = lOutFedPlanVariantList._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + .orElse(null); + + if (Objects.requireNonNull(minFOutFedPlan).getCumulativeCost() + < Objects.requireNonNull(minlOutFedPlan).getCumulativeCost()) { + return minFOutFedPlan; + } + return minlOutFedPlan; + } + +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java new file mode 100644 index 00000000000..fbf745bfbc0 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -0,0 +1,103 @@ +package org.apache.sysds.hops.fedplanner; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.cost.ComputeCost; +import org.apache.sysds.hops.fedplanner.FederatedMemoTable.FedPlan; +import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; + +/** + * Cost estimator for federated execution plans. + * Calculates computation, memory access, and network transfer costs for federated operations. + * Works in conjunction with FederatedMemoTable to evaluate different execution plan variants. + */ +public class FederatedPlanCostEstimator { + // Default value is used as a reasonable estimate since we only need + // to compare relative costs between different federated plans + // Memory bandwidth for local computations (25 GB/s) + private static final double DEFAULT_MBS_MEMORY_BANDWIDTH = 25000.0; + // Network bandwidth for data transfers between federated sites (1 Gbps) + private static final double DEFAULT_MBS_NETWORK_BANDWIDTH = 125.0; + + /** + * Computes total cost of federated plan by: + * 1. Computing current node cost (if not cached) + * 2. Adding minimum-cost child plans + * 3. Including network transfer costs when needed + * + * @param currentPlan Plan to compute cost for + * @param memoTable Table containing all plan variants + */ + public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { + double cumulativeCost = 0; + Hop currentHop = currentPlan.getHopRef(); + + // Step 1: Calculate current node costs if not already computed + if (currentPlan.getCurrentCost() == 0) { + // Compute cost for current node (computation + memory access) + cumulativeCost = computeCurrentCost(currentHop); + currentPlan.setCurrentCost(cumulativeCost); + // Calculate potential network transfer cost if federation type changes + currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); + } else { + cumulativeCost = currentPlan.getCurrentCost(); + } + + // Step 2: Process each child plan and add their costs + for (Pair planRefMeta : currentPlan.getMetaChildFedPlans()) { + // Find minimum cost child plan considering federation type compatibility + // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents + // because we're selecting child plans independently for each parent + FedPlan planRef = memoTable.getMinCostChildFedPlan( + planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType()); + + // Add child plan cost (includes network transfer cost if federation types differ) + cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType()); + + // Store selected child plan + // Note: Selected plan has minimum parent view cost, not minimum cumulative cost, + // which means it highly unlikely to be found through simple pruning after enumeration + currentPlan.putChildFedPlan(planRef); + } + + // Step 3: Set final cumulative cost including current node + currentPlan.setCumulativeCost(cumulativeCost); + } + + /** + * Computes the cost for the current Hop node. + * + * @param currentHop The Hop node whose cost needs to be computed + * @return The total cost for the current node's operation + */ + private static double computeCurrentCost(Hop currentHop){ + double computeCost = ComputeCost.getHOPComputeCost(currentHop); + double inputAccessCost = computeHopMemoryAccessCost(currentHop.getInputMemEstimate()); + double ouputAccessCost = computeHopMemoryAccessCost(currentHop.getOutputMemEstimate()); + + // Compute total cost assuming: + // 1. Computation and input access can be overlapped (hence taking max) + // 2. Output access must wait for both to complete (hence adding) + return Math.max(computeCost, inputAccessCost) + ouputAccessCost; + } + + /** + * Calculates the memory access cost based on data size and memory bandwidth. + * + * @param memSize Size of data to be accessed (in bytes) + * @return Time cost for memory access (in seconds) + */ + private static double computeHopMemoryAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_MEMORY_BANDWIDTH; + } + + /** + * Calculates the network transfer cost based on data size and network bandwidth. + * Used when federation status changes between parent and child plans. + * + * @param memSize Size of data to be transferred (in bytes) + * @return Time cost for network transfer (in seconds) + */ + private static double computeHopNetworkAccessCost(double memSize) { + return memSize / (1024*1024) / DEFAULT_MBS_NETWORK_BANDWIDTH; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java deleted file mode 100644 index f11b17b9849..00000000000 --- a/src/main/java/org/apache/sysds/hops/fedplanner/MemoTable.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.hops.fedplanner; - -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FTypes.FType; -import org.apache.commons.lang3.tuple.Pair; -import org.apache.commons.lang3.tuple.ImmutablePair; - -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.ArrayList; -import java.util.Map; - -/** - * A Memoization Table for managing federated plans (`FedPlan`) based on - * combinations of Hops and FTypes. Each combination is mapped to a list - * of possible execution plans, allowing for pruning and optimization. - */ -public class MemoTable { - - // Maps combinations of Hop ID and FType to lists of FedPlans - private final Map, List> hopMemoTable = new HashMap<>(); - - /** - * Represents a federated execution plan with its cost and associated references. - */ - public static class FedPlan { - @SuppressWarnings("unused") - private final Hop hopRef; // The associated Hop object - private final double cost; // Cost of this federated plan - @SuppressWarnings("unused") - private final List> planRefs; // References to dependent plans - - public FedPlan(Hop hopRef, double cost, List> planRefs) { - this.hopRef = hopRef; - this.cost = cost; - this.planRefs = planRefs; - } - - public double getCost() { - return cost; - } - } - - /** - * Adds a single FedPlan to the memo table for a given Hop and FType. - * If the entry already exists, the new FedPlan is appended to the list. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @param fedPlan The FedPlan to add. - */ - public void addFedPlan(Hop hop, FType fType, FedPlan fedPlan) { - if (contains(hop, fType)) { - List fedPlanList = get(hop, fType); - fedPlanList.add(fedPlan); - } else { - List fedPlanList = new ArrayList<>(); - fedPlanList.add(fedPlan); - hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); - } - } - - /** - * Adds multiple FedPlans to the memo table for a given Hop and FType. - * If the entry already exists, the new FedPlans are appended to the list. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @param fedPlanList The list of FedPlans to add. - */ - public void addFedPlanList(Hop hop, FType fType, List fedPlanList) { - if (contains(hop, fType)) { - List prevFedPlanList = get(hop, fType); - prevFedPlanList.addAll(fedPlanList); - } else { - hopMemoTable.put(new ImmutablePair<>(hop.getHopID(), fType), fedPlanList); - } - } - - /** - * Retrieves the list of FedPlans associated with a given Hop and FType. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @return The list of FedPlans, or null if no entry exists. - */ - public List get(Hop hop, FType fType) { - return hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType)); - } - - /** - * Checks if the memo table contains an entry for a given Hop and FType. - * - * @param hop The Hop object. - * @param fType The associated FType. - * @return True if the entry exists, false otherwise. - */ - public boolean contains(Hop hop, FType fType) { - return hopMemoTable.containsKey(new ImmutablePair<>(hop.getHopID(), fType)); - } - - /** - * Prunes the FedPlans associated with a specific Hop and FType, - * keeping only the plan with the minimum cost. - * - * @param hop The Hop object. - * @param fType The associated FType. - */ - public void prunePlan(Hop hop, FType fType) { - prunePlan(hopMemoTable.get(new ImmutablePair<>(hop.getHopID(), fType))); - } - - /** - * Prunes all entries in the memo table, retaining only the minimum-cost - * FedPlan for each entry. - */ - public void pruneAll() { - for (Map.Entry, List> entry : hopMemoTable.entrySet()) { - prunePlan(entry.getValue()); - } - } - - /** - * Prunes the given list of FedPlans to retain only the plan with the minimum cost. - * - * @param fedPlanList The list of FedPlans to prune. - */ - private void prunePlan(List fedPlanList) { - if (fedPlanList.size() > 1) { - // Find the FedPlan with the minimum cost - FedPlan minCostPlan = fedPlanList.stream() - .min(Comparator.comparingDouble(plan -> plan.cost)) - .orElse(null); - - // Retain only the minimum cost plan - fedPlanList.clear(); - fedPlanList.add(minCostPlan); - } - } -} From ab8d36063e9477503d5a08d321380e235b2ab7c7 Mon Sep 17 00:00:00 2001 From: min-guk Date: Thu, 28 Nov 2024 21:24:25 +0900 Subject: [PATCH 3/5] Delete MemoTableTest.java --- .../federated/privacy/MemoTableTest.java | 183 ------------------ 1 file changed, 183 deletions(-) delete mode 100644 src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java diff --git a/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java deleted file mode 100644 index 5e501c1257d..00000000000 --- a/src/test/java/org/apache/sysds/test/component/federated/privacy/MemoTableTest.java +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.test.component.federated.privacy; - -import static org.junit.Assert.*; - -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FTypes; -import org.apache.sysds.hops.fedplanner.MemoTable; -import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import java.util.ArrayList; -import java.util.List; - -import static org.mockito.Mockito.when; - -public class MemoTableTest { - - private MemoTable memoTable; - - @Mock - private Hop mockHop1; - - @Mock - private Hop mockHop2; - - private java.util.Random rand; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - memoTable = new MemoTable(); - - // Set up unique IDs for mock Hops - when(mockHop1.getHopID()).thenReturn(1L); - when(mockHop2.getHopID()).thenReturn(2L); - - // Initialize random generator with fixed seed for reproducible tests - rand = new java.util.Random(42); - } - - @Test - public void testAddAndGetSingleFedPlan() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNull("Initial FedPlan list should be null before adding any plans", result); - - // Add single FedPlan - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - // Verify after addition - result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding a plan", result); - assertEquals("FedPlan list should contain exactly one plan", 1, result.size()); - assertEquals("FedPlan cost should be exactly 10.0", 10.0, result.get(0).getCost(), 0.001); - } - - @Test - public void testAddMultipleDuplicatedFedPlans() { - // Initialize test data with duplicate costs - List> planRefs = new ArrayList<>(); - List fedPlans = new ArrayList<>(); - fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique cost - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First duplicate - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second duplicate - - // Add multiple plans including duplicates - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans); - - // Verify handling of duplicate plans - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding multiple plans", result); - assertEquals("FedPlan list should maintain all plans including duplicates", 3, result.size()); - } - - @Test - public void testContains() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - assertFalse("MemoTable should not contain any entries initially", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - - // Add plan and verify presence - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - assertTrue("MemoTable should contain entry after adding FedPlan", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - assertFalse("MemoTable should not contain entries for different Hop", - memoTable.contains(mockHop2, FTypes.FType.FULL)); - } - - @Test - public void testPrunePlanPruneAll() { - // Initialize base test data - List> planRefs = new ArrayList<>(); - // Create separate FedPlan lists for independent testing of each Hop - List fedPlans1 = new ArrayList<>(); // Plans for mockHop1 - List fedPlans2 = new ArrayList<>(); // Plans for mockHop2 - - // Generate random cost FedPlans for both Hops - double minCost = Double.MAX_VALUE; - int size = 100; - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; // Random cost between 0 and 1000 - fedPlans1.add(new FedPlan(mockHop1, cost, planRefs)); - fedPlans2.add(new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Add FedPlan lists to MemoTable - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans1); - memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, fedPlans2); - - // Test selective pruning on mockHop1 - memoTable.prunePlan(mockHop1, FTypes.FType.FULL); - - // Get results for verification - List result1 = memoTable.get(mockHop1, FTypes.FType.FULL); - List result2 = memoTable.get(mockHop2, FTypes.FType.FULL); - - // Verify selective pruning results - assertNotNull("Pruned mockHop1 should maintain a FedPlan list", result1); - assertEquals("Pruned mockHop1 should contain exactly one minimum cost plan", 1, result1.size()); - assertEquals("Pruned mockHop1's plan should have the minimum cost", minCost, result1.get(0).getCost(), 0.001); - - // Verify unpruned Hop state - assertNotNull("Unpruned mockHop2 should maintain a FedPlan list", result2); - assertEquals("Unpruned mockHop2 should maintain all original plans", size, result2.size()); - - // Add additional plans to both Hops - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new FedPlan(mockHop1, cost, planRefs)); - memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Test global pruning - memoTable.pruneAll(); - - // Verify global pruning results - assertNotNull("mockHop1 should maintain a FedPlan list after global pruning", result1); - assertEquals("mockHop1 should contain exactly one minimum cost plan after global pruning", - 1, result1.size()); - assertEquals("mockHop1's plan should have the global minimum cost", - minCost, result1.get(0).getCost(), 0.001); - - assertNotNull("mockHop2 should maintain a FedPlan list after global pruning", result2); - assertEquals("mockHop2 should contain exactly one minimum cost plan after global pruning", - 1, result2.size()); - assertEquals("mockHop2's plan should have the global minimum cost", - minCost, result2.get(0).getCost(), 0.001); - } -} From a5d4020eebf553b83b5ddb8b4c511a77fc8fbc68 Mon Sep 17 00:00:00 2001 From: min-guk Date: Fri, 20 Dec 2024 21:27:01 +0900 Subject: [PATCH 4/5] Update MemoTable, Cost Estimator, Cost Enumerator --- .../hops/fedplanner/FederatedMemoTable.java | 234 +++++++++++------- .../FederatedPlanCostEnumerator.java | 23 +- .../FederatedPlanCostEstimator.java | 24 +- .../component/federated/MemoTableTest.java | 186 -------------- .../FederatedPlanCostEnumeratorTest.java | 99 ++++++++ 5 files changed, 267 insertions(+), 299 deletions(-) delete mode 100644 src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index 0d7e4876e69..049e91a638e 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -29,6 +29,8 @@ import java.util.List; import java.util.ArrayList; import java.util.Map; +import java.util.HashSet; +import java.util.Set; /** * A Memoization Table for managing federated plans (FedPlan) based on combinations of Hops and fedOutTypes. @@ -36,88 +38,6 @@ * facilitating the optimization of federated execution plans. */ public class FederatedMemoTable { - /** - * Represents a collection of federated execution plan variants for a specific Hop. - * Contains cost information and references to the associated plans. - */ - public static class FedPlanVariants { - protected final Hop hopRef; // Reference to the associated Hop - protected double currentCost; // Current execution cost (compute + memory access) - protected double netTransferCost; // Network transfer cost - protected List _fedPlanVariants; // List of plan variants - - public FedPlanVariants(Hop hopRef) { - this.hopRef = hopRef; - this.currentCost = 0; - this.netTransferCost = 0; - this._fedPlanVariants = new ArrayList<>(); - } - - public void add(FedPlan fedPlan) { - _fedPlanVariants.add(fedPlan); - } - - public int size() {return _fedPlanVariants.size();} - - public FedPlan get(int index) {return _fedPlanVariants.get(index);} - - public List getFedPlanVariants() {return _fedPlanVariants;} - } - - /** - * Represents a single federated execution plan with its associated costs and dependencies. - * Contains: - * 1. currentCost: Cost of current hop (compute + input/output memory access) - * 2. cumulativeCost: Total cost including this plan and all child plans - * 3. netTransferCost: Network transfer cost for this plan - */ - public static class FedPlan { - private double cumulativeCost; // Total cost including child plans - private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) - private final FedPlanVariants fedPlanVariantList; // Reference to variant list - private List> metaChildFedPlans; // Child plan references - private List selectedFedPlans; // Selected child plans - - public FedPlan(FederatedOutput fedOutType, List> planChilds, FedPlanVariants fedPlanVariants) { - this.fedOutType = fedOutType; - this.cumulativeCost = 0; - this.metaChildFedPlans = planChilds; - this.selectedFedPlans = new ArrayList<>(); - this.fedPlanVariantList = fedPlanVariants; - } - - public Hop getHopRef() {return fedPlanVariantList.hopRef;} - - public FederatedOutput getFedOutType() {return fedOutType;} - - public double getCurrentCost() {return fedPlanVariantList.currentCost;} - - public double getNetTransferCost() {return fedPlanVariantList.netTransferCost;} - - public double getCumulativeCost() {return cumulativeCost;} - - /** - * Calculates the cost from parent's perspective based on output type compatibility. - * Returns cumulative cost if output types match, otherwise adds network transfer cost. - */ - public double getParentViewCost(FederatedOutput parentFedOutType) { - if (parentFedOutType == fedOutType){ - return cumulativeCost; - } - return cumulativeCost + fedPlanVariantList.netTransferCost; - } - - public List> getMetaChildFedPlans() {return metaChildFedPlans;} - - public void setCurrentCost(double currentCost) {fedPlanVariantList.currentCost = currentCost;} - - public void setNetTransferCost(double netTransferCost) {fedPlanVariantList.netTransferCost = netTransferCost;} - - public void setCumulativeCost(double cumulativeCost) {this.cumulativeCost = cumulativeCost;} - - public void putChildFedPlan(FedPlan childFedPlan) {selectedFedPlans.add(childFedPlan);} - } - // Maps Hop ID and fedOutType pairs to their plan variants private final Map, FedPlanVariants> hopMemoTable = new HashMap<>(); @@ -137,12 +57,12 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List(hopID, fedOutType)); } else { - fedPlanVariantList = new FedPlanVariants(hop); + fedPlanVariantList = new FedPlanVariants(hop, fedOutType); hopMemoTable.put(new ImmutablePair<>(hopID, fedOutType), fedPlanVariantList); } - FedPlan newPlan = new FedPlan(fedOutType, planChilds, fedPlanVariantList); - fedPlanVariantList.add(newPlan); + FedPlan newPlan = new FedPlan(planChilds, fedPlanVariantList); + fedPlanVariantList.addFedPlan(newPlan); return newPlan; } @@ -151,14 +71,14 @@ public FedPlan addFedPlan(Hop hop, FederatedOutput fedOutType, List(childHopID, childFedOutType)); return fedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(plan -> plan.getParentViewCost(currentFedOutType))) + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) .orElse(null); } - public FedPlanVariants getFedPlanVariantList(long hopID, FederatedOutput fedOutType) { + public FedPlanVariants getFedPlanVariants(long hopID, FederatedOutput fedOutType) { return hopMemoTable.get(new ImmutablePair<>(hopID, fedOutType)); } @@ -172,4 +92,142 @@ public FedPlanVariants getFedPlanVariantList(long hopID, FederatedOutput fedOutT public boolean contains(long hopID, FederatedOutput fedOutType) { return hopMemoTable.containsKey(new ImmutablePair<>(hopID, fedOutType)); } + + /** + * Prunes all entries in the memo table, retaining only the minimum-cost + * FedPlan for each entry. + */ + public void pruneMemoTable() { + for (Map.Entry, FedPlanVariants> entry : hopMemoTable.entrySet()) { + List fedPlanList = entry.getValue().getFedPlanVariants(); + if (fedPlanList.size() > 1) { + // Find the FedPlan with the minimum cost + FedPlan minCostPlan = fedPlanList.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) + .orElse(null); + + // Retain only the minimum cost plan + fedPlanList.clear(); + fedPlanList.add(minCostPlan); + } + } + } + + /** + * Recursively prints a tree representation of the DAG starting from the given root FedPlan. + * Includes information about hopID, fedOutType, TotalCost, SelfCost, and NetCost for each node. + * + * @param rootFedPlan The starting point FedPlan to print + */ + public void printFedPlanTree(FedPlan rootFedPlan) { + Set visited = new HashSet<>(); + printFedPlanTreeRecursive(rootFedPlan, visited, 0, true); + } + + /** + * Helper method to recursively print the FedPlan tree. + * + * @param plan The current FedPlan to print + * @param visited Set to keep track of visited FedPlans (prevents cycles) + * @param depth The current depth level for indentation + * @param isLast Whether this node is the last child of its parent + */ + private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int depth, boolean isLast) { + if (plan == null || visited.contains(plan)) { + return; + } + + visited.add(plan); + + // Create indentation and connectors for tree visualization + String indent = " ".repeat(depth); + String prefix = depth == 0 ? "└──" : + isLast ? "└─" : "├─"; + + // Print plan information + System.out.printf("%s%sHop %d [%s] (Total: %.3f, Self: %.3f, Net: %.3f)%n", + indent, prefix, + plan.getHopRef().getHopID(), + plan.getFedOutType(), + plan.getTotalCost(), + plan.getSelfCost(), + plan.getNetTransferCost()); + + // Process child nodes + List> childRefs = plan.getChildFedPlans(); + for (int i = 0; i < childRefs.size(); i++) { + Pair childRef = childRefs.get(i); + FedPlanVariants childVariants = getFedPlanVariants(childRef.getLeft(), childRef.getRight()); + if (childVariants == null || childVariants.getFedPlanVariants().isEmpty()) + continue; + + boolean isLastChild = (i == childRefs.size() - 1); + for (FedPlan childPlan : childVariants.getFedPlanVariants()) { + printFedPlanTreeRecursive(childPlan, visited, depth + 1, isLastChild); + } + } + } + + /** + * Represents a collection of federated execution plan variants for a specific Hop. + * Contains cost information and references to the associated plans. + */ + public static class FedPlanVariants { + protected final Hop hopRef; // Reference to the associated Hop + protected double selfCost; // Current execution cost (compute + memory access) + protected double netTransferCost; // Network transfer cost + private final FederatedOutput fedOutType; // Output type (FOUT/LOUT) + protected List _fedPlanVariants; // List of plan variants + + public FedPlanVariants(Hop hopRef, FederatedOutput fedOutType) { + this.hopRef = hopRef; + this.fedOutType = fedOutType; + this.selfCost = 0; + this.netTransferCost = 0; + this._fedPlanVariants = new ArrayList<>(); + } + + public int size() {return _fedPlanVariants.size();} + public void addFedPlan(FedPlan fedPlan) {_fedPlanVariants.add(fedPlan);} + public List getFedPlanVariants() {return _fedPlanVariants;} + } + + /** + * Represents a single federated execution plan with its associated costs and dependencies. + * Contains: + * 1. selfCost: Cost of current hop (compute + input/output memory access) + * 2. totalCost: Cumulative cost including this plan and all child plans + * 3. netTransferCost: Network transfer cost for this plan to parent plan. + */ + public static class FedPlan { + private double totalCost; // Total cost including child plans + private final FedPlanVariants fedPlanVariants; // Reference to variant list + private final List> childFedPlans; // Child plan references + + public FedPlan(List> childFedPlans, FedPlanVariants fedPlanVariants) { + this.totalCost = 0; + this.childFedPlans = childFedPlans; + this.fedPlanVariants = fedPlanVariants; + } + + public void setTotalCost(double totalCost) {this.totalCost = totalCost;} + public void setSelfCost(double selfCost) {fedPlanVariants.selfCost = selfCost;} + public void setNetTransferCost(double netTransferCost) {fedPlanVariants.netTransferCost = netTransferCost;} + + public Hop getHopRef() {return fedPlanVariants.hopRef;} + public FederatedOutput getFedOutType() {return fedPlanVariants.fedOutType;} + public double getTotalCost() {return totalCost;} + public double getSelfCost() {return fedPlanVariants.selfCost;} + private double getNetTransferCost() {return fedPlanVariants.netTransferCost;} + public List> getChildFedPlans() {return childFedPlans;} + + /** + * Calculates the conditional network transfer cost based on output type compatibility. + * Returns 0 if output types match, otherwise returns the network transfer cost. + */ + public double getCondNetTransferCost(FederatedOutput parentFedOutType) { + if (parentFedOutType == getFedOutType()) return 0; + return fedPlanVariants.netTransferCost; + } + } } diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java index da9b7891417..befeec15782 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEnumerator.java @@ -20,7 +20,7 @@ public class FederatedPlanCostEnumerator { * Entry point for federated plan enumeration. Creates a memo table and returns * the minimum cost plan for the entire DAG. */ - public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { + public static FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) { // Create new memo table to store all plan variants FederatedMemoTable memoTable = new FederatedMemoTable(); @@ -28,8 +28,11 @@ public static FedPlan enumerateFederatedPlanCost(Hop rootHop) { enumerateFederatedPlanCost(rootHop, memoTable); // Return the minimum cost plan for the root node + FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable); + memoTable.pruneMemoTable(); + if (printTree) memoTable.printFedPlanTree(optimalPlan); - return getMinCostRootFedPlan(rootHop.getHopID(), memoTable); + return optimalPlan; } /** @@ -84,18 +87,18 @@ private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoT * Used to select the final execution plan after enumeration. */ private static FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) { - FedPlanVariants fOutFedPlanVariantList = memoTable.getFedPlanVariantList(HopID, FederatedOutput.FOUT); - FedPlanVariants lOutFedPlanVariantList = memoTable.getFedPlanVariantList(HopID, FederatedOutput.LOUT); + FedPlanVariants fOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.FOUT); + FedPlanVariants lOutFedPlanVariants = memoTable.getFedPlanVariants(HopID, FederatedOutput.LOUT); - FedPlan minFOutFedPlan = fOutFedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + FedPlan minFOutFedPlan = fOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) .orElse(null); - FedPlan minlOutFedPlan = lOutFedPlanVariantList._fedPlanVariants.stream() - .min(Comparator.comparingDouble(FedPlan::getCumulativeCost)) + FedPlan minlOutFedPlan = lOutFedPlanVariants._fedPlanVariants.stream() + .min(Comparator.comparingDouble(FedPlan::getTotalCost)) .orElse(null); - if (Objects.requireNonNull(minFOutFedPlan).getCumulativeCost() - < Objects.requireNonNull(minlOutFedPlan).getCumulativeCost()) { + if (Objects.requireNonNull(minFOutFedPlan).getTotalCost() + < Objects.requireNonNull(minlOutFedPlan).getTotalCost()) { return minFOutFedPlan; } return minlOutFedPlan; diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java index fbf745bfbc0..8a3590cdf15 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlanCostEstimator.java @@ -28,39 +28,33 @@ public class FederatedPlanCostEstimator { * @param memoTable Table containing all plan variants */ public static void computeFederatedPlanCost(FedPlan currentPlan, FederatedMemoTable memoTable) { - double cumulativeCost = 0; + double totalCost = 0; Hop currentHop = currentPlan.getHopRef(); // Step 1: Calculate current node costs if not already computed - if (currentPlan.getCurrentCost() == 0) { + if (currentPlan.getSelfCost() == 0) { // Compute cost for current node (computation + memory access) - cumulativeCost = computeCurrentCost(currentHop); - currentPlan.setCurrentCost(cumulativeCost); + totalCost = computeCurrentCost(currentHop); + currentPlan.setSelfCost(totalCost); // Calculate potential network transfer cost if federation type changes currentPlan.setNetTransferCost(computeHopNetworkAccessCost(currentHop.getOutputMemEstimate())); } else { - cumulativeCost = currentPlan.getCurrentCost(); + totalCost = currentPlan.getSelfCost(); } // Step 2: Process each child plan and add their costs - for (Pair planRefMeta : currentPlan.getMetaChildFedPlans()) { + for (Pair planRefMeta : currentPlan.getChildFedPlans()) { // Find minimum cost child plan considering federation type compatibility // Note: This approach might lead to suboptimal or wrong solutions when a child has multiple parents // because we're selecting child plans independently for each parent - FedPlan planRef = memoTable.getMinCostChildFedPlan( - planRefMeta.getLeft(), planRefMeta.getRight(), currentPlan.getFedOutType()); + FedPlan planRef = memoTable.getMinCostChildFedPlan(planRefMeta.getLeft(), planRefMeta.getRight()); // Add child plan cost (includes network transfer cost if federation types differ) - cumulativeCost += planRef.getParentViewCost(currentPlan.getFedOutType()); - - // Store selected child plan - // Note: Selected plan has minimum parent view cost, not minimum cumulative cost, - // which means it highly unlikely to be found through simple pruning after enumeration - currentPlan.putChildFedPlan(planRef); + totalCost += planRef.getTotalCost() + planRef.getCondNetTransferCost(currentPlan.getFedOutType()); } // Step 3: Set final cumulative cost including current node - currentPlan.setCumulativeCost(cumulativeCost); + currentPlan.setTotalCost(totalCost); } /** diff --git a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java b/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java deleted file mode 100644 index e3928c12630..00000000000 --- a/src/test/java/org/apache/sysds/test/component/federated/MemoTableTest.java +++ /dev/null @@ -1,186 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.test.component.federated; - -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.fedplanner.FTypes; -import org.apache.sysds.hops.fedplanner.MemoTable; -import org.apache.sysds.hops.fedplanner.MemoTable.FedPlan; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.Before; -import org.junit.Test; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -import java.util.ArrayList; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.when; - -public class MemoTableTest { - - private MemoTable memoTable; - - @Mock - private Hop mockHop1; - - @Mock - private Hop mockHop2; - - private java.util.Random rand; - - @Before - public void setUp() { - MockitoAnnotations.openMocks(this); - memoTable = new MemoTable(); - - // Set up unique IDs for mock Hops - when(mockHop1.getHopID()).thenReturn(1L); - when(mockHop2.getHopID()).thenReturn(2L); - - // Initialize random generator with fixed seed for reproducible tests - rand = new java.util.Random(42); - } - - @Test - public void testAddAndGetSingleFedPlan() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNull("Initial FedPlan list should be null before adding any plans", result); - - // Add single FedPlan - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - // Verify after addition - result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding a plan", result); - assertEquals("FedPlan list should contain exactly one plan", 1, result.size()); - assertEquals("FedPlan cost should be exactly 10.0", 10.0, result.get(0).getCost(), 0.001); - } - - @Test - public void testAddMultipleDuplicatedFedPlans() { - // Initialize test data with duplicate costs - List> planRefs = new ArrayList<>(); - List fedPlans = new ArrayList<>(); - fedPlans.add(new FedPlan(mockHop1, 10.0, planRefs)); // Unique cost - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // First duplicate - fedPlans.add(new FedPlan(mockHop1, 20.0, planRefs)); // Second duplicate - - // Add multiple plans including duplicates - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans); - - // Verify handling of duplicate plans - List result = memoTable.get(mockHop1, FTypes.FType.FULL); - assertNotNull("FedPlan list should exist after adding multiple plans", result); - assertEquals("FedPlan list should maintain all plans including duplicates", 3, result.size()); - } - - @Test - public void testContains() { - // Initialize test data - List> planRefs = new ArrayList<>(); - FedPlan fedPlan = new FedPlan(mockHop1, 10.0, planRefs); - - // Verify initial state - assertFalse("MemoTable should not contain any entries initially", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - - // Add plan and verify presence - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, fedPlan); - - assertTrue("MemoTable should contain entry after adding FedPlan", - memoTable.contains(mockHop1, FTypes.FType.FULL)); - assertFalse("MemoTable should not contain entries for different Hop", - memoTable.contains(mockHop2, FTypes.FType.FULL)); - } - - @Test - public void testPrunePlanPruneAll() { - // Initialize base test data - List> planRefs = new ArrayList<>(); - // Create separate FedPlan lists for independent testing of each Hop - List fedPlans1 = new ArrayList<>(); // Plans for mockHop1 - List fedPlans2 = new ArrayList<>(); // Plans for mockHop2 - - // Generate random cost FedPlans for both Hops - double minCost = Double.MAX_VALUE; - int size = 100; - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; // Random cost between 0 and 1000 - fedPlans1.add(new FedPlan(mockHop1, cost, planRefs)); - fedPlans2.add(new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Add FedPlan lists to MemoTable - memoTable.addFedPlanList(mockHop1, FTypes.FType.FULL, fedPlans1); - memoTable.addFedPlanList(mockHop2, FTypes.FType.FULL, fedPlans2); - - // Test selective pruning on mockHop1 - memoTable.prunePlan(mockHop1, FTypes.FType.FULL); - - // Get results for verification - List result1 = memoTable.get(mockHop1, FTypes.FType.FULL); - List result2 = memoTable.get(mockHop2, FTypes.FType.FULL); - - // Verify selective pruning results - assertNotNull("Pruned mockHop1 should maintain a FedPlan list", result1); - assertEquals("Pruned mockHop1 should contain exactly one minimum cost plan", 1, result1.size()); - assertEquals("Pruned mockHop1's plan should have the minimum cost", minCost, result1.get(0).getCost(), 0.001); - - // Verify unpruned Hop state - assertNotNull("Unpruned mockHop2 should maintain a FedPlan list", result2); - assertEquals("Unpruned mockHop2 should maintain all original plans", size, result2.size()); - - // Add additional plans to both Hops - for(int i = 0; i < size; i++) { - double cost = rand.nextDouble() * 1000; - memoTable.addFedPlan(mockHop1, FTypes.FType.FULL, new FedPlan(mockHop1, cost, planRefs)); - memoTable.addFedPlan(mockHop2, FTypes.FType.FULL, new FedPlan(mockHop2, cost, planRefs)); - minCost = Math.min(minCost, cost); - } - - // Test global pruning - memoTable.pruneAll(); - - // Verify global pruning results - assertNotNull("mockHop1 should maintain a FedPlan list after global pruning", result1); - assertEquals("mockHop1 should contain exactly one minimum cost plan after global pruning", - 1, result1.size()); - assertEquals("mockHop1's plan should have the global minimum cost", - minCost, result1.get(0).getCost(), 0.001); - - assertNotNull("mockHop2 should maintain a FedPlan list after global pruning", result2); - assertEquals("mockHop2 should contain exactly one minimum cost plan after global pruning", - 1, result2.size()); - assertEquals("mockHop2's plan should have the global minimum cost", - minCost, result2.get(0).getCost(), 0.001); - } -} diff --git a/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java new file mode 100644 index 00000000000..80c8d47f435 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.federated.privacy; + +import java.io.IOException; +import java.util.HashMap; + +import org.apache.sysds.hops.Hop; +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DMLTranslator; +import org.apache.sysds.parser.ParserFactory; +import org.apache.sysds.parser.ParserWrapper; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.hops.fedplanner.FederatedPlanCostEnumerator; + + +public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase +{ + private static final String TEST_DIR = "component/parfor/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; + + @Override + public void setUp() {} + + @Test + public void testDependencyAnalysis1() { runTest("parfor1.dml"); } + + @Test + public void testDependencyAnalysis3() { runTest("parfor3.dml"); } + + @Test + public void testDependencyAnalysis4() { runTest("parfor4.dml"); } + + @Test + public void testDependencyAnalysis6() { runTest("parfor6.dml"); } + + @Test + public void testDependencyAnalysis7() { runTest("parfor7.dml"); } + + + private void runTest( String scriptFilename ) { + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + try { + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + //read script + String dmlScriptString = DMLScript.readDMLScript(true, HOME + scriptFilename); + + //parsing and dependency analysis + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, new HashMap<>()); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + + /* TODO) In the current DAG, Hop's _outputMemEstimate is not initialized + // This leads to incorrect fedplan generation, so test code needs to be modified + // If needed, modify costEstimator to handle cases where _outputMemEstimate is not initialized + */ + Hop hops = prog.getStatementBlocks().get(0).getHops().get(0); + FederatedPlanCostEnumerator.enumerateFederatedPlanCost(hops, true); + } + catch (IOException e) { + e.printStackTrace(); + Assert.fail(); + } + } +} From 0fcfc6d432733ea53cb93f9153041af32c7a55e3 Mon Sep 17 00:00:00 2001 From: min-guk Date: Sat, 21 Dec 2024 03:49:16 +0900 Subject: [PATCH 5/5] Update CostEnumeratorTest, printFedPlanTreeRecursive --- .../hops/fedplanner/FederatedMemoTable.java | 77 +++++++++++++++---- .../FederatedPlanCostEnumeratorTest.java | 20 +---- .../functions/federated/privacy/cost.dml | 25 ++++++ 3 files changed, 92 insertions(+), 30 deletions(-) create mode 100644 src/test/scripts/functions/federated/privacy/cost.dml diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java index 049e91a638e..8e67f283d07 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedMemoTable.java @@ -20,10 +20,10 @@ package org.apache.sysds.hops.fedplanner; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; - import java.util.Comparator; import java.util.HashMap; import java.util.List; @@ -139,19 +139,68 @@ private void printFedPlanTreeRecursive(FedPlan plan, Set visited, int d visited.add(plan); - // Create indentation and connectors for tree visualization - String indent = " ".repeat(depth); - String prefix = depth == 0 ? "└──" : - isLast ? "└─" : "├─"; - - // Print plan information - System.out.printf("%s%sHop %d [%s] (Total: %.3f, Self: %.3f, Net: %.3f)%n", - indent, prefix, - plan.getHopRef().getHopID(), - plan.getFedOutType(), - plan.getTotalCost(), - plan.getSelfCost(), - plan.getNetTransferCost()); + Hop hop = plan.getHopRef(); + StringBuilder sb = new StringBuilder(); + + // Add FedPlan information + sb.append(String.format("(%d) ", plan.getHopRef().getHopID())) + .append(plan.getHopRef().getOpString()) + .append(" [") + .append(plan.getFedOutType()) + .append("]"); + + StringBuilder childs = new StringBuilder(); + childs.append(" ("); + boolean childAdded = false; + for( Hop input : hop.getInput()){ + childs.append(childAdded?",":""); + childs.append(input.getHopID()); + childAdded = true; + } + childs.append(")"); + if( childAdded ) + sb.append(childs.toString()); + + + sb.append(String.format(" {Total: %.1f, Self: %.1f, Net: %.1f}", + plan.getTotalCost(), + plan.getSelfCost(), + plan.getNetTransferCost())); + + // Add matrix characteristics + sb.append(" [") + .append(hop.getDim1()).append(", ") + .append(hop.getDim2()).append(", ") + .append(hop.getBlocksize()).append(", ") + .append(hop.getNnz()); + + if (hop.getUpdateType().isInPlace()) { + sb.append(", ").append(hop.getUpdateType().toString().toLowerCase()); + } + sb.append("]"); + + // Add memory estimates + sb.append(" [") + .append(OptimizerUtils.toMB(hop.getInputMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getIntermediateMemEstimate())).append(", ") + .append(OptimizerUtils.toMB(hop.getOutputMemEstimate())).append(" -> ") + .append(OptimizerUtils.toMB(hop.getMemEstimate())).append("MB]"); + + // Add reblock and checkpoint requirements + if (hop.requiresReblock() && hop.requiresCheckpoint()) { + sb.append(" [rblk, chkpt]"); + } else if (hop.requiresReblock()) { + sb.append(" [rblk]"); + } else if (hop.requiresCheckpoint()) { + sb.append(" [chkpt]"); + } + + // Add execution type + if (hop.getExecType() != null) { + sb.append(", ").append(hop.getExecType()); + } + + System.out.println(sb); // Process child nodes List> childRefs = plan.getChildFedPlans(); diff --git a/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java index 80c8d47f435..57ecac158a1 100644 --- a/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java +++ b/src/test/java/org/apache/sysds/test/component/federated/privacy/FederatedPlanCostEnumeratorTest.java @@ -39,7 +39,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase { - private static final String TEST_DIR = "component/parfor/"; + private static final String TEST_DIR = "functions/federated/privacy/"; private static final String HOME = SCRIPT_DIR + TEST_DIR; private static final String TEST_CLASS_DIR = TEST_DIR + FederatedPlanCostEnumeratorTest.class.getSimpleName() + "/"; @@ -47,20 +47,7 @@ public class FederatedPlanCostEnumeratorTest extends AutomatedTestBase public void setUp() {} @Test - public void testDependencyAnalysis1() { runTest("parfor1.dml"); } - - @Test - public void testDependencyAnalysis3() { runTest("parfor3.dml"); } - - @Test - public void testDependencyAnalysis4() { runTest("parfor4.dml"); } - - @Test - public void testDependencyAnalysis6() { runTest("parfor6.dml"); } - - @Test - public void testDependencyAnalysis7() { runTest("parfor7.dml"); } - + public void testDependencyAnalysis1() { runTest("cost.dml"); } private void runTest( String scriptFilename ) { int index = scriptFilename.lastIndexOf(".dml"); @@ -83,7 +70,8 @@ private void runTest( String scriptFilename ) { dmlt.liveVariableAnalysis(prog); dmlt.validateParseTree(prog); dmlt.constructHops(prog); - + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); /* TODO) In the current DAG, Hop's _outputMemEstimate is not initialized // This leads to incorrect fedplan generation, so test code needs to be modified // If needed, modify costEstimator to handle cases where _outputMemEstimate is not initialized diff --git a/src/test/scripts/functions/federated/privacy/cost.dml b/src/test/scripts/functions/federated/privacy/cost.dml new file mode 100644 index 00000000000..ec34d45bb65 --- /dev/null +++ b/src/test/scripts/functions/federated/privacy/cost.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +a = matrix(7,10,10); +b = a + a^2; +c = sqrt(b); +print(sum(c)); \ No newline at end of file