Skip to content

Commit

Permalink
Passes counting node number crossing clock regions
Browse files Browse the repository at this point in the history
  • Loading branch information
RipperJ committed Apr 17, 2024
1 parent b6fb41b commit f404e96
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 0 deletions.
241 changes: 241 additions & 0 deletions java/CrossingCRNodeCounter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
package com.xilinx.rapidwright.examples;

import com.xilinx.rapidwright.design.blocks.PBlock;
import com.xilinx.rapidwright.design.Design;
import com.xilinx.rapidwright.device.Device;
import com.xilinx.rapidwright.device.IntentCode;
import com.xilinx.rapidwright.device.Node;
import com.xilinx.rapidwright.device.PIP;
import com.xilinx.rapidwright.device.Tile;

import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Set;
import java.util.HashSet;

public class CrossingCRNodeCounter {
private static Set<IntentCode> intentCodesConsidered = new HashSet<>(
Arrays.asList( // color_index
IntentCode.INTENT_DEFAULT, // 1
IntentCode.NODE_HLONG, // 2
IntentCode.NODE_VLONG, // 3
IntentCode.NODE_SINGLE, // 4
IntentCode.NODE_DOUBLE, // 5
IntentCode.NODE_HQUAD, // 6
IntentCode.NODE_VQUAD, // 7
IntentCode.NODE_LOCAL, // 8
IntentCode.NODE_PINBOUNCE, // 9
IntentCode.NODE_LAGUNA_DATA, // 10

IntentCode.NODE_GLOBAL_VDISTR, // ignored
IntentCode.NODE_GLOBAL_VROUTE, // ignored
IntentCode.NODE_GLOBAL_HDISTR, // ignored
IntentCode.NODE_GLOBAL_HROUTE // ignored
)
);
private static HashMap<Integer, HashMap<Integer, HashMap<String, Integer>>> allPBlockDirectionToNodeCount = new HashMap<>();
// Check whether a node has a IntentCode of our interest
// Exclude some IntentCodes including clock and control signals
private static boolean validNode(Node n) {
assert intentCodesConsidered.contains(n.getIntentCode()) : "Unexpected IntentCode: " + n.getIntentCode();
// Exclude IntentCode.NODE_GLOBAL_VROUTE and IntentCode.NODE_GLOBAL_VDISTR
if (n.getIntentCode().toString().startsWith("NODE_GLOBAL")) {
assert n.getWireName().startsWith("CLK") : "This GLOBAL node is not a clock node: " + n;
return false;
}
// Filter out IntentCode.INTENT_DEFAULT Nodes
// ? referring to https://docs.amd.com/v/u/en-US/ug573-ultrascale-memory-resources
if (n.getIntentCode() == IntentCode.INTENT_DEFAULT) {
// Exclude CLKOUT_{NORTH\d+|SOUTH\d+} and cascade i/o of BRAMs
assert n.getWireName().matches("^(CLKOUT|BRAM|HPIO|GND|VCC).*$") : "Unexpected INTENT_DEFAULT node: " + n;
return false;
}
// DSP cascade not extracted by getIntersectingNodes
// ? referring to https://docs.amd.com/v/u/en-US/ug579-ultrascale-dsp
assert !n.getWireName().contains("DSP_") : "Unexpected DSP-related node: " + n;
return true;
}

// Check whether a node in Set<Tile> a also exists in Set<Tile> b
// then, check whether the node has a valid IntentCode of our interest
// if both satisfied, add the node to the result set
public static Set<Node> getIntersectingNodes(Set<Tile> a, Set<Tile> b) {
Set<Node> nodesInA = getNodesInTiles(a);
Set<Node> nodesInB = getNodesInTiles(b);

Set<Node> both = new HashSet<>();
for (Node n : nodesInA) {
if (nodesInB.contains(n) && validNode(n)) {
both.add(n);

// For debugging
// System.out.println(n + " " + n.getIntentCode()); // + " " + a + " " + b);

// For highlight_objects in Vivado
// try {
// PrintWriter writer = new PrintWriter(new FileWriter("node_stats.txt", true));
// writer.println(n + " " + n.getIntentCode());// + " " + a + " " + b);
// writer.close();
// } catch (IOException e) {
// System.out.println("Error: " + e.getMessage());
// e.printStackTrace();
// }
}
}
return both;
}

// Get the number of nodes crossing the boundary between the current pblock and the adjacent pblock(s)
// in at most 4 directions (N, S, E, W)
public static HashMap<String, Integer> getSinglePBlockDirectionToNodeCount(Device device, int col, int row) {
// The current pblock (pCenter).
PBlock pCenter = new PBlock(device, "CLOCKREGION_X" + col + "Y" + row + ":CLOCKREGION_X" + col + "Y" + row);
// The direction (N, S, E, W) to the number of nodes crossing the boundary between the current pblock and the adjacent pblock(s).
HashMap<String, Integer> directionToNodeCount = new HashMap<String, Integer>();

// North
if (row < device.getNumOfClockRegionRows() - 1) {
PBlock pNorth = new PBlock(device, "CLOCKREGION_X" + col + "Y" + (row + 1) + ":CLOCKREGION_X" + col + "Y" + (row + 1));
Set<Node> intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pNorth.getAllTiles());
int intersectingNodesCount = intersectingNodes.size();
directionToNodeCount.put("N", intersectingNodesCount);
}
// South
if (row > 0) {
if (allPBlockDirectionToNodeCount.containsKey(col) && allPBlockDirectionToNodeCount.get(col).containsKey(row - 1)) {
directionToNodeCount.put("S", allPBlockDirectionToNodeCount.get(col).get(row - 1).get("N"));
} else {
assert false : "col " + col + "; row " + row;
PBlock pSouth = new PBlock(device, "CLOCKREGION_X" + col + "Y" + (row - 1) + ":CLOCKREGION_X" + col + "Y" + (row - 1));
Set<Node> intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pSouth.getAllTiles());
int intersectingNodesCount = intersectingNodes.size();
directionToNodeCount.put("S", intersectingNodesCount);
}
}
// East
if (col < device.getNumOfClockRegionsColumns() - 1) {
PBlock pEast = new PBlock(device, "CLOCKREGION_X" + (col + 1) + "Y" + row + ":CLOCKREGION_X" + (col + 1) + "Y" + row);
Set<Node> intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pEast.getAllTiles());
int intersectingNodesCount = intersectingNodes.size();
directionToNodeCount.put("E", intersectingNodesCount);
}
// West
if (col > 0) {
if (allPBlockDirectionToNodeCount.containsKey(col - 1) && allPBlockDirectionToNodeCount.get(col - 1).containsKey(row)) {
directionToNodeCount.put("W", allPBlockDirectionToNodeCount.get(col - 1).get(row).get("E"));
} else {
assert false : "col " + col + "; row " + row;
PBlock pWest = new PBlock(device, "CLOCKREGION_X" + (col - 1) + "Y" + row + ":CLOCKREGION_X" + (col - 1) + "Y" + row);
Set<Node> intersectingNodes = getIntersectingNodes(pCenter.getAllTiles(), pWest.getAllTiles());
int intersectingNodesCount = intersectingNodes.size();
directionToNodeCount.put("W", intersectingNodesCount);
}
}
return directionToNodeCount;
}

// Get a set of nodes in a set of tiles
public static Set<Node> getNodesInTiles(Set<Tile> tiles) {
Set<Node> nodes = new HashSet<>();
for (Tile t : tiles) {
for (PIP p : t.getPIPs()) {
Node start = p.getStartNode();
Node end = p.getEndNode();
nodes.add(start);
nodes.add(end);
}
}
nodes.remove(null);
return nodes;
}

public static HashMap<String, Integer> getSinglePBlockCrossingCRNodeCount(String dcp, int col, int row) {
Design design = null;
Device device = null;
try {
design = Design.readCheckpoint(dcp);
device = design.getDevice(); // TODO: directly get device bypassing design?
if (row < 0 || row >= device.getNumOfClockRegionRows() || col < 0 || col >= device.getNumOfClockRegionsColumns()) {
System.out.println("Invalid row or column number");
return null;
}
else {
System.out.println("Device: " + device + "; col(X): " + col + "/" + device.getNumOfClockRegionsColumns() + "; row(Y): " + row + "/" + device.getNumOfClockRegionRows());
// e.g., xcu250, col(X): 4/8, row(Y): 3/16
}
} catch (Exception e) {
System.out.println("Error: " + e.getMessage());
e.printStackTrace();
}
assert design != null : "Design is null";
assert device != null : "Device is null";

// Create the result map:
HashMap<String, Integer> directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row);
return directionToNodeCount;
}

public static HashMap<Integer, HashMap<Integer, HashMap<String, Integer>>> getAllPBlockCrossingCRNodeCount(String dcp) {
Design design = null;
Device device = null;
try {
design = Design.readCheckpoint(dcp);
device = design.getDevice(); // TODO: directly get device bypassing design?
System.out.println("Device: " + device + "; col(X): " + device.getNumOfClockRegionsColumns() + "; row(Y): " + device.getNumOfClockRegionRows());
} catch (Exception e) {
System.out.println("Error: " + e.getMessage());
e.printStackTrace();
}
assert design != null : "Design is null";
assert device != null : "Device is null";

// Create the result map:
for (int col = 0; col < device.getNumOfClockRegionsColumns(); col++) {
for (int row = 0; row < device.getNumOfClockRegionRows(); row++) {
HashMap<String, Integer> directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row);
HashMap<Integer, HashMap<String, Integer>> rowToDirectionToNodeCount = allPBlockDirectionToNodeCount.getOrDefault(col, new HashMap<>());
rowToDirectionToNodeCount.put(row, directionToNodeCount);
allPBlockDirectionToNodeCount.put(col, rowToDirectionToNodeCount);
System.out.println("CLOCKREGION: X " + col + "/" + device.getNumOfClockRegionsColumns() + "; Y " + row + "/" + device.getNumOfClockRegionRows() + " done");
}
}
return allPBlockDirectionToNodeCount;
}

public static void main(String[] args) {
if(args.length != 3) {
System.out.println("USAGE: rapidwright MetricsExtractor <.dcp> <col (X)> <row (Y)>");
return;
}
Design design = null;
Device device = null;
int col = -1, row = -1;
try {
design = Design.readCheckpoint(args[0]);
col = Integer.parseInt(args[1]);
row = Integer.parseInt(args[2]);
device = design.getDevice(); // TODO: directly get device bypassing design?
if (row < 0 || row >= device.getNumOfClockRegionRows() || col < 0 || col >= device.getNumOfClockRegionsColumns()) {
System.out.println("Invalid row or column number");
return;
}
else {
System.out.println("Device: " + device + "; col(X): " + col + "/" + device.getNumOfClockRegionsColumns() + "; row(Y): " + row + "/" + device.getNumOfClockRegionRows());
// e.g., xcu250, col(X): 4/8, row(Y): 3/16
}
} catch (Exception e) {
System.out.println("Error: " + e.getMessage());
e.printStackTrace();
}
assert design != null : "Design is null";
assert device != null : "Device is null";

// Create the result map:
HashMap<String, Integer> directionToNodeCount = getSinglePBlockDirectionToNodeCount(device, col, row);
// System.out.println("CLOCKREGION_X" + col + "Y" + row + ":CLOCKREGION_X" + col + "Y" + row);
System.out.println(directionToNodeCount);
}
}
13 changes: 13 additions & 0 deletions java/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# RapidStream DSE Engine Java Source Code

## How to Compile
1. Copy the source code in `<rapidstream-dse-base>/java` to `<RapidWright-base>/src/com/xilinx/rapidwright/examples`.
2. Register the corresponding pass in `<RapidWright-base>/src/com/xilinx/rapidwright/MainEntrypoint.java`
```=java
static {
...
addFunction("CrossingCRNodeCounter", CrossingCRNodeCounter::main);
...
}
```
3. Recompile RapidWright with `./gradlew compileJava`
88 changes: 88 additions & 0 deletions python/src/CrossingCRNodeCounter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from pathlib import Path
from typing import Literal
import jpype
import jpype.imports
import sys

def get_node_counts_between_CRs(
dcp_path: Path,
num_col: int,
num_row: int,
) -> dict[Literal["N", "S", "E", "W"], int]:
"""Extract number of nodes between a **single** clock region and its adjacent
ones (at clock region level).
Args:
dcp_path: Path to a checkpoint file which contains a grid of CRs.
num_col: Number of columns (X) of the pblock grid.
num_row: Number of rows (Y) of the pblock grid.
Returns:
A dictionrary direction to the number of nodes. For example, for Alveo
U280, num_col: 0, num_row: 0, the output is {'E': 12480, 'N': 9200}.
CR_X0Y0 should not have "SOUTH" or "WEST" keys because those slots
do not exist.
"""
jpype.startJVM(jpype.getDefaultJVMPath())
from com.xilinx.rapidwright.examples import CrossingCRNodeCounter
node_count_java = CrossingCRNodeCounter.getSinglePBlockCrossingCRNodeCount(jpype.JString(str(dcp_path)), num_col, num_row) # <java object 'java.util.HashMap'>
node_count = dict(node_count_java) # dict
'''
# enabling assertions by "-ea"
process = jpype.java.lang.Runtime.getRuntime().exec_(
"java -ea com.xilinx.rapidwright.examples.CrossingCRNodeCounter {} {} {}".format(dcp_path, num_col, num_row))
process.waitFor()
# get stdout
input_stream = process.getInputStream()
reader = jpype.java.io.BufferedReader(jpype.java.io.InputStreamReader(input_stream))
# read commandline output
line = reader.readLine()
while line is not None:
print(line)
line = reader.readLine()
'''
return node_count

def get_node_counts_between_CRs(
dcp_path: Path
) -> dict[int, dict[int, dict[Literal["N", "S", "E", "W"], int]]]:
"""Extract number of nodes between **all pairs of** pblocks (at clock region
level).
Args:
dcp_path: Path to a checkpoint file which contains a grid of pblocks. It is
guaranteed that the pblock will be named in the format of f"x{x}y{y}".
Returns:
A mapping from (col, row) to direction to the number of nodes. For example,
node_count[0][0]["N"] is the number of nodes between pblock (0, 0) and
(0, 1).
node_count[0][0] should not have "SOUTH" or "WEST" keys because those slots
do not exist.
"""
jpype.startJVM(jpype.getDefaultJVMPath(), "-ea") # enabling assertions by "-ea"
from com.xilinx.rapidwright.examples import CrossingCRNodeCounter
node_count_java = CrossingCRNodeCounter.getAllPBlockCrossingCRNodeCount(jpype.JString(str(dcp_path))) # <java object 'java.util.HashMap'>
node_count = {_:{__:dict(dict(node_count_java[_])[__]) for __ in dict(node_count_java[_]).keys()} for _ in node_count_java.keys()} # dict
return node_count

if __name__ == "__main__":
if len(sys.argv) not in {2, 4}:
print("Usage: python CrossingCRNodeExtractor.py <dcp_path> <num_col> <num_row>\n or python CrossingCRNodeExtractor.py <dcp_path>")
sys.exit(1)

try:
dcp_path = Path(sys.argv[1])
num_col = int(sys.argv[2]) if len(sys.argv) == 4 else -1
num_row = int(sys.argv[3]) if len(sys.argv) == 4 else -1
except Exception as e:
print("Error: Invalid arguments.")
sys.exit(1)

# node_counts = get_node_counts_between_CRs(dcp_path, num_col, num_row)
node_counts = get_node_counts_between_CRs(dcp_path)
print(node_counts)
3 changes: 3 additions & 0 deletions python/src/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# RapidStream DSE Engine Python Source Code

Please finish the "How to Compile" section in `<rapidstream-dse-base>/java/README.md` before running the script in `<rapidstream-dse-base>/python`.

0 comments on commit f404e96

Please sign in to comment.