Skip to content

Commit

Permalink
move atomic instrumentation to application level.
Browse files Browse the repository at this point in the history
  • Loading branch information
aoli-al committed Jul 21, 2024
1 parent d114c0f commit 00d6612
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class ApplicationCodeTransformer : ClassFileTransformer {
cv = SynchronizedMethodInstrumenter(cv, false)
cv = ClassConstructorInstrumenter(cv)
cv = SleepInstrumenter(cv)
cv = AtomicOperationInstrumenter(cv)
cv = TimeInstrumenter(cv)
cv = ThreadHashCodeInstrumenter(cv)
val classVersionInstrumenter = ClassVersionInstrumenter(cv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fun instrumentClass(path: String, inputStream: InputStream): ByteArray {
cv = LockInstrumenter(cv)
cv = SystemModulesMapInstrumenter(cv)
cv = ConditionInstrumenter(cv)
cv = AtomicOperationInstrumenter(cv)
// cv = AtomicOperationInstrumenter(cv)
cv = ObjectNotifyInstrumenter(cv)
cv = UnsafeInstrumenter(cv)
cv = SkipMethodInstrumenter(cv)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import java.util.concurrent.atomic.*
import org.objectweb.asm.ClassVisitor
import org.objectweb.asm.MethodVisitor
import org.objectweb.asm.Opcodes.*
import org.objectweb.asm.Type
import org.objectweb.asm.commons.AdviceAdapter

class AtomicOperationInstrumenter(cv: ClassVisitor) : ClassVisitor(ASM9, cv) {
var className = ""
Expand All @@ -26,36 +28,59 @@ class AtomicOperationInstrumenter(cv: ClassVisitor) : ClassVisitor(ASM9, cv) {
override fun visitMethod(
access: Int,
name: String,
descriptor: String?,
descriptor: String,
signature: String?,
exceptions: Array<out String>?
): MethodVisitor {
val mv = super.visitMethod(access, name, descriptor, signature, exceptions)
val memoryType = memoryTypeFromMethodName(name)
if (atomicClasses.contains(className) &&
!atomicNonVolatileMethodNames.contains(name) &&
access and ACC_PUBLIC != 0) {
return object : MethodVisitor(ASM9, mv) {
override fun visitCode() {
super.visitCode()
val type = MemoryOpType::class.java.name.replace(".", "/")
visitVarInsn(ALOAD, 0)
visitFieldInsn(GETSTATIC, type, memoryType.name, "L$type;")
visitMethodInsn(
INVOKESTATIC,
cmu.pasta.fray.runtime.Runtime::class.java.name.replace(".", "/"),
Runtime::onAtomicOperation.name,
Utils.kFunctionToJvmMethodDescriptor(Runtime::onAtomicOperation),
false)
return object : AdviceAdapter(ASM9, mv, access, name, descriptor) {
override fun visitMethodInsn(
opcodeAndSource: Int,
owner: String,
name: String,
descriptor: String?,
isInterface: Boolean
) {
if (atomicClasses.contains(owner) && !atomicNonVolatileMethodNames.contains(name)) {

val argumentTypes = Type.getArgumentTypes(descriptor)
val paramArrayIndex = newLocal(Type.getType("[Ljava/lang/Object;"))
push(argumentTypes.size)
newArray(Type.getObjectType("java/lang/Object"))
storeLocal(paramArrayIndex)
for (i in argumentTypes.indices) { // store call parameters to an array
val type = argumentTypes[argumentTypes.size - 1 - i]
box(type)
loadLocal(paramArrayIndex)
swap()
push(i)
swap()
arrayStore(Type.getObjectType("java/lang/Object"))
}
dup()
val memoryType = memoryTypeFromMethodName(name)
getStatic(
Type.getType(MemoryOpType::class.java),
memoryType.name,
Type.getType(MemoryOpType::class.java))
invokeStatic(
Type.getType(Runtime::class.java),
Utils.kFunctionToASMMethod(Runtime::onAtomicOperation))
for (i in argumentTypes.indices.reversed()) { // load call parameters from an array
loadLocal(paramArrayIndex)
push(i)
arrayLoad(Type.getObjectType("java/lang/Object"))
unbox(argumentTypes[argumentTypes.size - 1 - i])
}
}
super.visitMethodInsn(opcodeAndSource, owner, name, descriptor, isInterface)
}
}
return mv
}

fun memoryTypeFromMethodName(name: String): MemoryOpType {
val lname = name.lowercase()
return if (lname.contains("set") || lname.contains("exchange")) {
return if (lname.contains("set") || lname.contains("exchange") || lname.contains("update")) {
MemoryOpType.MEMORY_WRITE
} else {
MemoryOpType.MEMORY_READ
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class VolatileFieldsInstrumenter(cv: ClassVisitor, private val instrumentingJdk:
}
return object : AdviceAdapter(ASM9, mv, access, name, descriptor) {
override fun visitFieldInsn(opcode: Int, owner: String, name: String, descriptor: String) {
if (owner.startsWith("cmu/pasta/fray/runtime")) {
return super.visitFieldInsn(opcode, owner, name, descriptor)
}
if (recursiveVisitClass(owner) || volatileManager.isVolatile(owner, name)) {
if (opcode == Opcodes.GETFIELD) {
dup()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,6 @@ public String runTest(Function0<Unit> exec, Scheduler scheduler, int iter) {
return logger.sb.toString();
}

// public void runTest(String methodName, String testCase) {
// String testName = this.getClass().getSimpleName();
// String expectedFile = "expected/" + testName + "_" + testCase + ".txt";
// String scheduleFile = "schedules/" + testName + "_" + testCase + ".json";
// String expected = getResourceAsString(expectedFile);
// ReplayScheduler scheduler = new ReplayScheduler(Schedule.Companion.fromString(getResourceAsString(scheduleFile), true));
// assertEquals(expected, runTest(methodName, scheduler));
// }

public String getResourceAsString(String path) {
try(InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(path)) {
BufferedReader reader = new BufferedReader(new InputStreamReader(is));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package cmu.pasta.fray.it.scheduler;

import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;


//public class AtomicTest extends IntegrationTestRunner {
// @Test
// public void testInterleaving1() {
// runTest("T1T2");
// }
//
// @Test
// public void testInterleaving2() {
// runTest("T2T1");
// }
//
// @Test
// public void testInterleaving3() {
// runTest("alternateT2T1");
// void testAtomic() {
// AtomicInteger i = new AtomicInteger(0);
//
// Thread t = new Thread(() -> {
// intei.incrementAndGet();
// });
// }
//
//}

0 comments on commit 00d6612

Please sign in to comment.