diff --git a/java/CrossingCRNodeCounter.java b/java/CrossingCRNodeCounter.java new file mode 100644 index 0000000..1743340 --- /dev/null +++ b/java/CrossingCRNodeCounter.java @@ -0,0 +1,241 @@ +package com.xilinx.rapidwright.examples; + +import com.xilinx.rapidwright.design.blocks.PBlock; +import com.xilinx.rapidwright.design.Design; +import com.xilinx.rapidwright.device.Device; +import com.xilinx.rapidwright.device.IntentCode; +import com.xilinx.rapidwright.device.Node; +import com.xilinx.rapidwright.device.PIP; +import com.xilinx.rapidwright.device.Tile; + +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; + +public class CrossingCRNodeCounter { + private static Set intentCodesConsidered = new HashSet<>( + Arrays.asList( // color_index + IntentCode.INTENT_DEFAULT, // 1 + IntentCode.NODE_HLONG, // 2 + IntentCode.NODE_VLONG, // 3 + IntentCode.NODE_SINGLE, // 4 + IntentCode.NODE_DOUBLE, // 5 + IntentCode.NODE_HQUAD, // 6 + IntentCode.NODE_VQUAD, // 7 + IntentCode.NODE_LOCAL, // 8 + IntentCode.NODE_PINBOUNCE, // 9 + IntentCode.NODE_LAGUNA_DATA, // 10 + + IntentCode.NODE_GLOBAL_VDISTR, // ignored + IntentCode.NODE_GLOBAL_VROUTE, // ignored + IntentCode.NODE_GLOBAL_HDISTR, // ignored + IntentCode.NODE_GLOBAL_HROUTE // ignored + ) + ); + private static HashMap>> allPBlockDirectionToNodeCount = new HashMap<>(); + // Check whether a node has a IntentCode of our interest + // Exclude some IntentCodes including clock and control signals + private static boolean validNode(Node n) { + assert intentCodesConsidered.contains(n.getIntentCode()) : "Unexpected IntentCode: " + n.getIntentCode(); + // Exclude IntentCode.NODE_GLOBAL_VROUTE and IntentCode.NODE_GLOBAL_VDISTR + if (n.getIntentCode().toString().startsWith("NODE_GLOBAL")) { + assert n.getWireName().startsWith("CLK") : "This GLOBAL node is not a clock node: " + n; + return false; + } + // Filter out IntentCode.INTENT_DEFAULT Nodes + // ? referring to https://docs.amd.com/v/u/en-US/ug573-ultrascale-memory-resources + if (n.getIntentCode() == IntentCode.INTENT_DEFAULT) { + // Exclude CLKOUT_{NORTH\d+|SOUTH\d+} and cascade i/o of BRAMs + assert n.getWireName().matches("^(CLKOUT|BRAM|HPIO|GND|VCC).*$") : "Unexpected INTENT_DEFAULT node: " + n; + return false; + } + // DSP cascade not extracted by getIntersectingNodes + // ? referring to https://docs.amd.com/v/u/en-US/ug579-ultrascale-dsp + assert !n.getWireName().contains("DSP_") : "Unexpected DSP-related node: " + n; + return true; + } + + // Check whether a node in Set a also exists in Set b + // then, check whether the node has a valid IntentCode of our interest + // if both satisfied, add the node to the result set + public static Set getIntersectingNodes(Set a, Set b) { + Set nodesInA = getNodesInTiles(a); + Set nodesInB = getNodesInTiles(b); + + Set both = new HashSet<>(); + for (Node n : nodesInA) { + if (nodesInB.contains(n) && validNode(n)) { + both.add(n); + + // For debugging + // System.out.println(n + " " + n.getIntentCode()); // + " " + a + " " + b); + + // For highlight_objects in Vivado + // try { + // PrintWriter writer = new PrintWriter(new FileWriter("node_stats.txt", true)); + // writer.println(n + " " + n.getIntentCode());// + " " + a + " " + b); + // writer.close(); + // } catch (IOException e) { + // System.out.println("Error: " + e.getMessage()); + // e.printStackTrace(); + // } + } + } + return both; + } + + // Get the number of nodes crossing the boundary between the current pblock and the adjacent pblock(s) + // in at most 4 directions (N, S, E, W) + public static HashMap getSinglePBlockDirectionToNodeCount(Device device, int col, int row) { + // The current pblock (pCenter). + PBlock pCenter = new PBlock(device, "CLOCKREGION_X" + col + "Y" + row + ":CLOCKREGION_X" + col + "Y" + row); + // The direction (N, S, E, W) to the number of nodes crossing the boundary between the current pblock and the adjacent pblock(s). + HashMap directionToNodeCount = new HashMap(); + + // North + if (row < device.getNumOfClockRegionRows() - 1) { + PBlock pNorth = new PBlock(device, "CLOCKREGION_X" + col + "Y" + (row + 1) + ":CLOCKREGION_X" + col + "Y" + (row + 1)); + Set intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pNorth.getAllTiles()); + int intersectingNodesCount = intersectingNodes.size(); + directionToNodeCount.put("N", intersectingNodesCount); + } + // South + if (row > 0) { + if (allPBlockDirectionToNodeCount.containsKey(col) && allPBlockDirectionToNodeCount.get(col).containsKey(row - 1)) { + directionToNodeCount.put("S", allPBlockDirectionToNodeCount.get(col).get(row - 1).get("N")); + } else { + assert false : "col " + col + "; row " + row; + PBlock pSouth = new PBlock(device, "CLOCKREGION_X" + col + "Y" + (row - 1) + ":CLOCKREGION_X" + col + "Y" + (row - 1)); + Set intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pSouth.getAllTiles()); + int intersectingNodesCount = intersectingNodes.size(); + directionToNodeCount.put("S", intersectingNodesCount); + } + } + // East + if (col < device.getNumOfClockRegionsColumns() - 1) { + PBlock pEast = new PBlock(device, "CLOCKREGION_X" + (col + 1) + "Y" + row + ":CLOCKREGION_X" + (col + 1) + "Y" + row); + Set intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pEast.getAllTiles()); + int intersectingNodesCount = intersectingNodes.size(); + directionToNodeCount.put("E", intersectingNodesCount); + } + // West + if (col > 0) { + if (allPBlockDirectionToNodeCount.containsKey(col - 1) && allPBlockDirectionToNodeCount.get(col - 1).containsKey(row)) { + directionToNodeCount.put("W", allPBlockDirectionToNodeCount.get(col - 1).get(row).get("E")); + } else { + assert false : "col " + col + "; row " + row; + PBlock pWest = new PBlock(device, "CLOCKREGION_X" + (col - 1) + "Y" + row + ":CLOCKREGION_X" + (col - 1) + "Y" + row); + Set intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pWest.getAllTiles()); + int intersectingNodesCount = intersectingNodes.size(); + directionToNodeCount.put("W", intersectingNodesCount); + } + } + return directionToNodeCount; + } + + // Get a set of nodes in a set of tiles + public static Set getNodesInTiles(Set tiles) { + Set nodes = new HashSet<>(); + for (Tile t : tiles) { + for (PIP p : t.getPIPs()) { + Node start = p.getStartNode(); + Node end = p.getEndNode(); + nodes.add(start); + nodes.add(end); + } + } + nodes.remove(null); + return nodes; + } + + public static HashMap getSinglePBlockCrossingCRNodeCount(String dcp, int col, int row) { + Design design = null; + Device device = null; + try { + design = Design.readCheckpoint(dcp); + device = design.getDevice(); // TODO: directly get device bypassing design? + if (row < 0 || row >= device.getNumOfClockRegionRows() || col < 0 || col >= device.getNumOfClockRegionsColumns()) { + System.out.println("Invalid row or column number"); + return null; + } + else { + System.out.println("Device: " + device + "; col(X): " + col + "/" + device.getNumOfClockRegionsColumns() + "; row(Y): " + row + "/" + device.getNumOfClockRegionRows()); + // e.g., xcu250, col(X): 4/8, row(Y): 3/16 + } + } catch (Exception e) { + System.out.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + assert design != null : "Design is null"; + assert device != null : "Device is null"; + + // Create the result map: + HashMap directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row); + return directionToNodeCount; + } + + public static HashMap>> getAllPBlockCrossingCRNodeCount(String dcp) { + Design design = null; + Device device = null; + try { + design = Design.readCheckpoint(dcp); + device = design.getDevice(); // TODO: directly get device bypassing design? + System.out.println("Device: " + device + "; col(X): " + device.getNumOfClockRegionsColumns() + "; row(Y): " + device.getNumOfClockRegionRows()); + } catch (Exception e) { + System.out.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + assert design != null : "Design is null"; + assert device != null : "Device is null"; + + // Create the result map: + for (int col = 0; col < device.getNumOfClockRegionsColumns(); col++) { + for (int row = 0; row < device.getNumOfClockRegionRows(); row++) { + HashMap directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row); + HashMap> rowToDirectionToNodeCount = allPBlockDirectionToNodeCount.getOrDefault(col, new HashMap<>()); + rowToDirectionToNodeCount.put(row, directionToNodeCount); + allPBlockDirectionToNodeCount.put(col, rowToDirectionToNodeCount); + System.out.println("CLOCKREGION: X " + col + "/" + device.getNumOfClockRegionsColumns() + "; Y " + row + "/" + device.getNumOfClockRegionRows() + " done"); + } + } + return allPBlockDirectionToNodeCount; + } + + public static void main(String[] args) { + if(args.length != 3) { + System.out.println("USAGE: rapidwright MetricsExtractor <.dcp> "); + return; + } + Design design = null; + Device device = null; + int col = -1, row = -1; + try { + design = Design.readCheckpoint(args[0]); + col = Integer.parseInt(args[1]); + row = Integer.parseInt(args[2]); + device = design.getDevice(); // TODO: directly get device bypassing design? + if (row < 0 || row >= device.getNumOfClockRegionRows() || col < 0 || col >= device.getNumOfClockRegionsColumns()) { + System.out.println("Invalid row or column number"); + return; + } + else { + System.out.println("Device: " + device + "; col(X): " + col + "/" + device.getNumOfClockRegionsColumns() + "; row(Y): " + row + "/" + device.getNumOfClockRegionRows()); + // e.g., xcu250, col(X): 4/8, row(Y): 3/16 + } + } catch (Exception e) { + System.out.println("Error: " + e.getMessage()); + e.printStackTrace(); + } + assert design != null : "Design is null"; + assert device != null : "Device is null"; + + // Create the result map: + HashMap directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row); + // System.out.println("CLOCKREGION_X" + col + "Y" + row + ":CLOCKREGION_X" + col + "Y" + row); + System.out.println(directionToNodeCount); + } +} \ No newline at end of file diff --git a/java/README.md b/java/README.md new file mode 100644 index 0000000..8fad9d5 --- /dev/null +++ b/java/README.md @@ -0,0 +1,13 @@ +# RapidStream DSE Engine Java Source Code + +## How to Compile +1. Copy the source code in `/java` to `/src/com/xilinx/rapidwright/examples`. +2. Register the corresponding pass in `/src/com/xilinx/rapidwright/MainEntrypoint.java` + ```=java + static { + ... + addFunction("CrossingCRNodeCounter", CrossingCRNodeCounter::main); + ... + } + ``` +3. Recompile RapidWright with `./gradlew compileJava` \ No newline at end of file diff --git a/python/src/CrossingCRNodeCounter.py b/python/src/CrossingCRNodeCounter.py new file mode 100644 index 0000000..ccd5628 --- /dev/null +++ b/python/src/CrossingCRNodeCounter.py @@ -0,0 +1,88 @@ +from pathlib import Path +from typing import Literal +import jpype +import jpype.imports +import sys + +def get_node_counts_between_CRs( + dcp_path: Path, + num_col: int, + num_row: int, +) -> dict[Literal["N", "S", "E", "W"], int]: + """Extract number of nodes between a **single** clock region and its adjacent + ones (at clock region level). + + Args: + dcp_path: Path to a checkpoint file which contains a grid of CRs. + num_col: Number of columns (X) of the pblock grid. + num_row: Number of rows (Y) of the pblock grid. + + Returns: + A dictionrary direction to the number of nodes. For example, for Alveo + U280, num_col: 0, num_row: 0, the output is {'E': 12480, 'N': 9200}. + + CR_X0Y0 should not have "SOUTH" or "WEST" keys because those slots + do not exist. + """ + jpype.startJVM(jpype.getDefaultJVMPath()) + from com.xilinx.rapidwright.examples import CrossingCRNodeCounter + node_count_java = CrossingCRNodeCounter.getSinglePBlockCrossingCRNodeCount(jpype.JString(str(dcp_path)), num_col, num_row) # + node_count = dict(node_count_java) # dict + ''' + # enabling assertions by "-ea" + process = jpype.java.lang.Runtime.getRuntime().exec_( + "java -ea com.xilinx.rapidwright.examples.CrossingCRNodeCounter {} {} {}".format(dcp_path, num_col, num_row)) + process.waitFor() + + # get stdout + input_stream = process.getInputStream() + reader = jpype.java.io.BufferedReader(jpype.java.io.InputStreamReader(input_stream)) + + # read commandline output + line = reader.readLine() + while line is not None: + print(line) + line = reader.readLine() + ''' + return node_count + +def get_node_counts_between_CRs( + dcp_path: Path +) -> dict[int, dict[int, dict[Literal["N", "S", "E", "W"], int]]]: + """Extract number of nodes between **all pairs of** pblocks (at clock region + level). + + Args: + dcp_path: Path to a checkpoint file which contains a grid of pblocks. It is + guaranteed that the pblock will be named in the format of f"x{x}y{y}". + + Returns: + A mapping from (col, row) to direction to the number of nodes. For example, + node_count[0][0]["N"] is the number of nodes between pblock (0, 0) and + (0, 1). + + node_count[0][0] should not have "SOUTH" or "WEST" keys because those slots + do not exist. + """ + jpype.startJVM(jpype.getDefaultJVMPath(), "-ea") # enabling assertions by "-ea" + from com.xilinx.rapidwright.examples import CrossingCRNodeCounter + node_count_java = CrossingCRNodeCounter.getAllPBlockCrossingCRNodeCount(jpype.JString(str(dcp_path))) # + node_count = {_:{__:dict(dict(node_count_java[_])[__]) for __ in dict(node_count_java[_]).keys()} for _ in node_count_java.keys()} # dict + return node_count + +if __name__ == "__main__": + if len(sys.argv) not in {2, 4}: + print("Usage: python CrossingCRNodeExtractor.py \n or python CrossingCRNodeExtractor.py ") + sys.exit(1) + + try: + dcp_path = Path(sys.argv[1]) + num_col = int(sys.argv[2]) if len(sys.argv) == 4 else -1 + num_row = int(sys.argv[3]) if len(sys.argv) == 4 else -1 + except Exception as e: + print("Error: Invalid arguments.") + sys.exit(1) + + # node_counts = get_node_counts_between_CRs(dcp_path, num_col, num_row) + node_counts = get_node_counts_between_CRs(dcp_path) + print(node_counts) \ No newline at end of file diff --git a/python/src/README.md b/python/src/README.md new file mode 100644 index 0000000..caa36e0 --- /dev/null +++ b/python/src/README.md @@ -0,0 +1,3 @@ +# RapidStream DSE Engine Python Source Code + +Please finish the "How to Compile" section in `/java/README.md` before running the script in `/python`. \ No newline at end of file