diff --git a/rasp/jvm/JVMAgent/build.gradle b/rasp/jvm/JVMAgent/build.gradle index f27f39969..f3ac480b9 100644 --- a/rasp/jvm/JVMAgent/build.gradle +++ b/rasp/jvm/JVMAgent/build.gradle @@ -1,38 +1,37 @@ -plugins { - id 'java-library' - id 'com.github.johnrengelman.shadow' version '6.0.0' -} - -group 'com.security' -version '1.0-SNAPSHOT' - -repositories { - mavenCentral() -} - -dependencies { - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.11' - implementation group: 'io.netty', name: 'netty-all', version: '4.1.85.Final' -} - -compileJava { - options.release.set(8) -} - -jar { - manifest { - attributes 'Agent-Class': 'com.security.smithloader.SmithAgent' - attributes 'Premain-Class': 'com.security.smithloader.SmithAgent' - attributes 'Can-Retransform-Classes': 'true' - attributes 'Boot-Class-Path': 'SmithAgent.jar' - attributes 'Specification-Title': 'Smith Agent' - attributes 'Specification-Version': '1.0' - attributes 'Implementation-Title': 'Smith Agent' - attributes 'Implementation-Version': "$System.env.BUILD_VERSION" - } -} - -shadowJar { - relocate 'org.apache.commons', 'agent.org.apache.commons' - relocate 'META-INF/native/libnetty', 'META-INF/native/librasp_netty' -} \ No newline at end of file +plugins { + id 'java-library' + id 'com.github.johnrengelman.shadow' version '6.0.0' +} + +group 'com.security' +version '1.0-SNAPSHOT' + +repositories { + mavenCentral() +} + +dependencies { + implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.11' +} + +compileJava { + options.release.set(8) +} + +jar { + manifest { + attributes 'Agent-Class': 'com.security.smithloader.SmithAgent' + attributes 'Premain-Class': 'com.security.smithloader.SmithAgent' + attributes 'Can-Retransform-Classes': 'true' + attributes 'Boot-Class-Path': 'SmithAgent.jar' + attributes 'Specification-Title': 'Smith Agent' + attributes 'Specification-Version': '1.0' + attributes 'Implementation-Title': 'Smith Agent' + attributes 'Implementation-Version': "$System.env.BUILD_VERSION" + } +} + +shadowJar { + relocate 'org.apache.commons', 'rasp.org.apache.commons' +} + diff --git a/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithAgent.java b/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithAgent.java index 92752393b..f7738a979 100644 --- a/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithAgent.java +++ b/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithAgent.java @@ -1,378 +1,380 @@ -package com.security.smithloader; - -import java.util.jar.Attributes; -import java.util.jar.JarFile; -import java.util.jar.Manifest; - -import javax.management.openmbean.CompositeDataInvocationHandler; - -import com.security.smithloader.MemCheck; -import com.security.smithloader.common.JarUtil; -import com.security.smithloader.common.ParseParameter; -import com.security.smithloader.common.Reflection; -import com.security.smithloader.log.SmithAgentLogger; - -import java.lang.instrument.Instrumentation; -import java.lang.reflect.Constructor; -import java.lang.reflect.Method; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.concurrent.Callable; -import java.util.concurrent.FutureTask; -import java.util.concurrent.locks.ReentrantLock; - -public class SmithAgent { - private static ReentrantLock xLoaderLock = new ReentrantLock(); - private static Object xLoader = null; - private static Class SmithProberClazz = null; - private static Object SmithProberObj = null; - private static Object SmithProberProxyObj = null; - private static long jvmHeapFree = 150; - private static long jvmMetaFree = 20; - private static String probeVersion; - private static String checksumStr; - private static String proberPath; - private static Instrumentation instrumentation = null; - - public static Object getClassLoader() { - return xLoader; - } - - public static Object getSmithProbeProxy() { - return SmithProberProxyObj; - } - - public static Object getSmithProbe() { - return SmithProberObj; - } - - public static void PreProxy(Object MethodNameObj,int classID, int methodID, Object[] args) { - - if(SmithProberProxyObj != null) { - String MethodName = (String)MethodNameObj; - Class[] argType = new Class[]{int.class,int.class,Object[].class}; - Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args); - } - - } - - public static void PostProxy(Object MethodNameObj,int classID, int methodID, Object[] args, Object ret, boolean blocked) { - - if(SmithProberProxyObj != null) { - String MethodName = (String)MethodNameObj; - Class[] argType = new Class[]{int.class,int.class,Object[].class,Object.class,boolean.class}; - Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args,ret,blocked); - } - } - - public static Object ExceptionProxy(Object MethodNameObj,int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { - - if(SmithProberProxyObj != null) { - String MethodName = (String)MethodNameObj; - Class[] argType = new Class[]{int.class,int.class,Object[].class,Object.class}; - return Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args,exceptionObject); - } - - return null; - } - - public static Object PassProxyToProbe(Object Obj,Object MethodNameObj,int functionID, Object[] args) { - - if(Obj != null) { - String MethodName = (String)MethodNameObj; - Class[] argType = new Class[]{int.class,int.class,Object[].class}; - return Reflection.invokeMethod(Obj,MethodName,argType,functionID,args); - } - - return null; - } - - private static boolean loadSmithProber(String proberPath, Instrumentation inst) { - boolean bret = false; - boolean bexception = false; - boolean binited = false; - - SmithAgentLogger.logger.info("loadSmithProber Entry"); - - try { - Class smithLoaderClazz = ClassLoader.getSystemClassLoader().loadClass("com.security.smithloader.SmithLoader"); - Constructor xconstructor = smithLoaderClazz.getConstructor(String.class, ClassLoader.class); - xLoader = xconstructor.newInstance(proberPath,null); - - String smithProbeClassName = "com.security.smith.SmithProbe"; - Class[] loadclassargType = new Class[]{String.class}; - SmithProberClazz = (Class)Reflection.invokeMethod(xLoader,"loadClass", loadclassargType,smithProbeClassName); - - SmithAgentLogger.logger.info("SmithProbe ClassLoader:"+SmithProberClazz.getClassLoader()); - - Class[] emptyArgTypes = new Class[]{}; - if (SmithProberClazz != null) { - Constructor constructor = SmithProberClazz.getDeclaredConstructor(); - constructor.setAccessible(true); - SmithProberObj = constructor.newInstance(); - if (SmithProberObj != null) { - Class[] objArgTypes = new Class[]{Object.class}; - Reflection.invokeMethod(SmithProberObj,"setClassLoader",objArgTypes,xLoader); - Class[] argType = new Class[]{Instrumentation.class}; - Reflection.invokeMethod(SmithProberObj,"setInst",argType,inst); - Class[] probeArgTypes = new Class[]{String.class}; - Reflection.invokeMethod(SmithProberObj,"setProbeVersion",probeArgTypes,probeVersion); - Reflection.invokeMethod(SmithProberObj,"setProbePath",probeArgTypes,proberPath); - Reflection.invokeMethod(SmithProberObj,"init",emptyArgTypes); - SmithProberProxyObj = Reflection.invokeMethod(SmithProberObj,"getSmithProbeProxy", emptyArgTypes); - binited = true; - - - Reflection.invokeMethod(SmithProberObj,"start",emptyArgTypes); - - bret = true; - } else { - SmithAgentLogger.logger.info("call SmithProbe init failed"); - } - } else { - SmithAgentLogger.logger.info("load com.security.smith.SmithProbe failed"); - bret = false; - } - - } - catch(Exception e) { - SmithAgentLogger.exception(e); - bexception = true; - } - - if(bexception) { - if(binited) { - try { - Class[] emptyArgTypes = new Class[]{}; - Reflection.invokeMethod(SmithProberObj,"stop",emptyArgTypes); - SmithProberProxyObj = null; - Reflection.invokeMethod(SmithProberObj,"uninit",emptyArgTypes); - } - catch(Exception e) { - SmithAgentLogger.exception(e); - } - } - - SmithProberObj = null; - SmithProberClazz = null; - xLoader = null; - } - - SmithAgentLogger.logger.info("loadSmithProber Leave"); - - return bret; - } - - private static Boolean unLoadSmithProber() { - boolean bret = false; - - SmithAgentLogger.logger.info("unLoadSmithProber Entry"); - - try { - if (SmithProberObj != null) { - SmithAgentLogger.logger.info("Start unload prober"); - Class[] emptyArgTypes = new Class[]{}; - Reflection.invokeMethod(SmithProberObj,"stop",emptyArgTypes); - SmithProberProxyObj = null; - SmithAgentLogger.logger.info("unload prober 0"); - Reflection.invokeMethod(SmithProberObj,"uninit",emptyArgTypes); - SmithAgentLogger.logger.info("unload prober 1"); - - SmithProberObj = null; - SmithProberClazz = null; - xLoader = null; - - SmithAgentLogger.logger.info("unload prober end"); - - bret = true; - } else { - bret = true; - } - } - catch(Exception e) { - SmithAgentLogger.exception(e); - } - - SmithAgentLogger.logger.info("unLoadSmithProber Leave"); - - return bret; - } - - private static String getProberVersion(String jarFilePath) { - try { - java.util.jar.JarFile jarFile = new java.util.jar.JarFile(jarFilePath); - Manifest manifest = jarFile.getManifest(); - - String ImplementationVersion = manifest.getMainAttributes().getValue("Implementation-Version"); - jarFile.close(); - - return ImplementationVersion; - } - catch(Exception e) { - SmithAgentLogger.exception(e); - } - - return null; - } - private static class MyCallable implements Callable { - @Override - public String call() throws Exception { - xLoaderLock.lock(); - try { - if(xLoader != null) { - String agent = System.getProperty("rasp.probe"); - - if(unLoadSmithProber()) { - System.setProperty("smith.status", "detach"); - } - if (agent != null) { - System.clearProperty("rasp.probe"); - } - xLoader = null; - SmithProberObj = null; - SmithProberClazz = null; - } - - System.setProperty("smith.rasp", ""); - if (!checkMemoryAvailable()) { - System.setProperty("smith.status", "memory not enough"); - SmithAgentLogger.logger.warning("checkMemory failed"); - } else { - if(!loadSmithProber(proberPath,instrumentation)) { - System.setProperty("smith.status",proberPath + " loading fail"); - SmithAgentLogger.logger.warning(proberPath + " loading fail!"); - } - else { - System.setProperty("smith.status", "attach"); - System.setProperty("smith.rasp", probeVersion+"-"+checksumStr); - - System.setProperty("rasp.probe", "smith"); - } - } - } - finally { - xLoaderLock.unlock(); - } - return "SmithProbeLoader"; - } - } - - public static void premain(String agentArgs, Instrumentation inst) { - agentmain(agentArgs, inst); - } - - public static void agentmain(String agentArgs, Instrumentation inst) { - String agent = System.getProperty("rasp.probe"); - - SmithAgentLogger.logger.info("agentArgs:"+agentArgs); - - StringBuilder cmd_sb = new StringBuilder(); - StringBuilder checksumStr_sb = new StringBuilder(); - StringBuilder proberPath_sb = new StringBuilder(); - String cmd = ""; - - if(ParseParameter.parseParameter(agentArgs,cmd_sb,checksumStr_sb,proberPath_sb)) { - cmd = cmd_sb.toString(); - SmithAgentLogger.logger.info("cmd:" + cmd); - SmithAgentLogger.logger.info("parse parseParameter success"); - - if(cmd.equals("attach")) { - checksumStr = checksumStr_sb.toString(); - proberPath = proberPath_sb.toString(); - - SmithAgentLogger.logger.info("checksumStr:" + checksumStr); - SmithAgentLogger.logger.info("proberPath:" + proberPath); - - if (!JarUtil.checkJarFile(proberPath,checksumStr)) { - System.setProperty("smith.status", proberPath + " check fail"); - SmithAgentLogger.logger.warning(proberPath + " check fail!"); - return ; - } - - if(instrumentation == null) { - instrumentation = inst; - } - - probeVersion = getProberVersion(proberPath); - SmithAgentLogger.logger.info("proberVersion:" + probeVersion); - - Callable callable = new MyCallable(); - - FutureTask futureTask = new FutureTask<>(callable); - Thread newThread = new Thread(futureTask, "SmithProbeLoader Thread"); - newThread.setContextClassLoader(ClassLoader.getSystemClassLoader()); - newThread.start(); - } - else if(cmd.equals("detach")) { - xLoaderLock.lock(); - try { - if(xLoader != null) { - if(unLoadSmithProber()) { - SmithAgentLogger.logger.warning("SmithProber detach success!"); - System.setProperty("smith.status", "detach"); - } - else { - System.setProperty("smith.status", "prober unload fail"); - } - xLoader = null; - SmithProberObj = null; - SmithProberClazz = null; - if (agent != null) { - System.clearProperty("rasp.probe"); - } - } - else { - SmithAgentLogger.logger.warning("SmithProber No Loading!"); - } - } - finally { - xLoaderLock.unlock(); - } - } else { - SmithAgentLogger.logger.warning("Unknow Command:"+cmd); - return ; - } - } - else { - SmithAgentLogger.logger.info("parse parameter fail"); - return ; - } - - - - if (agent != null) { - SmithAgentLogger.logger.info("agent running"); - return; - } - } - - private static boolean checkMemoryAvailable() { - try { - long systemFree = MemCheck.getSystemMemoryFree(); - SmithAgentLogger.logger.info("systemmemory free: "+ systemFree); - long cpuload = MemCheck.getSystemCpuLoad(); - SmithAgentLogger.logger.info("system cpu load: "+ cpuload); - long heapFree = MemCheck.getHeapMemoryFree(); - if (heapFree < jvmHeapFree) { - SmithAgentLogger.logger.info("heapmemory is not enough, free: "+ heapFree); - return false; - } - else { - SmithAgentLogger.logger.info("heapmemory is enough, free: "+ heapFree); - long metaBeanFree = MemCheck.getMetaMemoryFree(); - if (metaBeanFree > 0L && metaBeanFree < jvmMetaFree) { - SmithAgentLogger.logger.info("metamemory is not enough, free: " + metaBeanFree); - return false; - } else { - SmithAgentLogger.logger.info("metamemory is enough, free: " + metaBeanFree); - } - } - - } catch (Exception e) { - SmithAgentLogger.exception(e); - } - return true; - - } -} +package com.security.smithloader; + +import java.util.jar.Attributes; +import java.util.jar.JarFile; +import java.util.jar.Manifest; + +import javax.management.openmbean.CompositeDataInvocationHandler; + +import com.security.smithloader.MemCheck; +import com.security.smithloader.common.JarUtil; +import com.security.smithloader.common.ParseParameter; +import com.security.smithloader.common.Reflection; +import com.security.smithloader.log.SmithAgentLogger; + +import java.lang.instrument.Instrumentation; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.concurrent.Callable; +import java.util.concurrent.FutureTask; +import java.util.concurrent.locks.ReentrantLock; + +public class SmithAgent { + private static ReentrantLock xLoaderLock = new ReentrantLock(); + private static Object xLoader = null; + private static Class SmithProberClazz = null; + private static Object SmithProberObj = null; + private static Object SmithProberProxyObj = null; + private static long jvmHeapFree = 150; + private static long jvmMetaFree = 20; + private static String probeVersion = null; + private static String checksumStr = null; + private static String proberPath = null; + private static Instrumentation instrumentation = null; + + public static Object getClassLoader() { + return xLoader; + } + + public static Object getSmithProbeProxy() { + return SmithProberProxyObj; + } + + public static Object getSmithProbe() { + return SmithProberObj; + } + + public static void PreProxy(Object MethodNameObj,int classID, int methodID, Object[] args) { + + if(SmithProberProxyObj != null) { + String MethodName = (String)MethodNameObj; + Class[] argType = new Class[]{int.class,int.class,Object[].class}; + Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args); + } + + } + + public static void PostProxy(Object MethodNameObj,int classID, int methodID, Object[] args, Object ret, boolean blocked) { + + if(SmithProberProxyObj != null) { + String MethodName = (String)MethodNameObj; + Class[] argType = new Class[]{int.class,int.class,Object[].class,Object.class,boolean.class}; + Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args,ret,blocked); + } + } + + public static Object ExceptionProxy(Object MethodNameObj,int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { + + if(SmithProberProxyObj != null) { + String MethodName = (String)MethodNameObj; + Class[] argType = new Class[]{int.class,int.class,Object[].class,Object.class}; + return Reflection.invokeMethod(SmithProberProxyObj,MethodName,argType,classID,methodID,args,exceptionObject); + } + + return null; + } + + public static Object PassProxyToProbe(Object Obj,Object MethodNameObj,int functionID, Object[] args) { + + if(Obj != null) { + String MethodName = (String)MethodNameObj; + Class[] argType = new Class[]{int.class,int.class,Object[].class}; + return Reflection.invokeMethod(Obj,MethodName,argType,functionID,args); + } + + return null; + } + + private static boolean loadSmithProber(String proberPath, Instrumentation inst) { + boolean bret = false; + boolean bexception = false; + boolean binited = false; + + SmithAgentLogger.logger.info("loadSmithProber Entry"); + + try { + Class smithLoaderClazz = ClassLoader.getSystemClassLoader().loadClass("com.security.smithloader.SmithLoader"); + Constructor xconstructor = smithLoaderClazz.getConstructor(String.class, ClassLoader.class); + xLoader = xconstructor.newInstance(proberPath,null); + + String smithProbeClassName = "com.security.smith.SmithProbe"; + Class[] loadclassargType = new Class[]{String.class}; + SmithProberClazz = (Class)Reflection.invokeMethod(xLoader,"loadClass", loadclassargType,smithProbeClassName); + + SmithAgentLogger.logger.info("SmithProbe ClassLoader:"+SmithProberClazz.getClassLoader()); + + Class[] emptyArgTypes = new Class[]{}; + if (SmithProberClazz != null) { + Constructor constructor = SmithProberClazz.getDeclaredConstructor(); + constructor.setAccessible(true); + SmithProberObj = constructor.newInstance(); + if (SmithProberObj != null) { + Class[] objArgTypes = new Class[]{Object.class}; + Reflection.invokeMethod(SmithProberObj,"setClassLoader",objArgTypes,xLoader); + Class[] argType = new Class[]{Instrumentation.class}; + Reflection.invokeMethod(SmithProberObj,"setInst",argType,inst); + Class[] probeArgTypes = new Class[]{String.class}; + Reflection.invokeMethod(SmithProberObj,"setProbeVersion",probeArgTypes,probeVersion); + Reflection.invokeMethod(SmithProberObj,"setProbePath",probeArgTypes,proberPath); + Reflection.invokeMethod(SmithProberObj,"init",emptyArgTypes); + SmithProberProxyObj = Reflection.invokeMethod(SmithProberObj,"getSmithProbeProxy", emptyArgTypes); + binited = true; + + + Reflection.invokeMethod(SmithProberObj,"start",emptyArgTypes); + + bret = true; + } else { + SmithAgentLogger.logger.info("call SmithProbe init failed"); + } + } else { + SmithAgentLogger.logger.info("load com.security.smith.SmithProbe failed"); + bret = false; + } + + } + catch(Exception e) { + SmithAgentLogger.exception(e); + bexception = true; + } + + if(bexception) { + if(binited) { + try { + Class[] emptyArgTypes = new Class[]{}; + Reflection.invokeMethod(SmithProberObj,"stop",emptyArgTypes); + SmithProberProxyObj = null; + Reflection.invokeMethod(SmithProberObj,"uninit",emptyArgTypes); + } + catch(Exception e) { + SmithAgentLogger.exception(e); + } + } + + SmithProberObj = null; + SmithProberClazz = null; + xLoader = null; + } + + SmithAgentLogger.logger.info("loadSmithProber Leave"); + + return bret; + } + + private static Boolean unLoadSmithProber() { + boolean bret = false; + + SmithAgentLogger.logger.info("unLoadSmithProber Entry"); + + try { + if (SmithProberObj != null) { + SmithAgentLogger.logger.info("Start unload prober"); + Class[] emptyArgTypes = new Class[]{}; + Reflection.invokeMethod(SmithProberObj,"stop",emptyArgTypes); + SmithProberProxyObj = null; + SmithAgentLogger.logger.info("unload prober 0"); + Reflection.invokeMethod(SmithProberObj,"uninit",emptyArgTypes); + SmithAgentLogger.logger.info("unload prober 1"); + + SmithProberObj = null; + SmithProberClazz = null; + xLoader = null; + + SmithAgentLogger.logger.info("unload prober end"); + + bret = true; + } else { + bret = true; + } + } + catch(Exception e) { + SmithAgentLogger.exception(e); + } + + SmithAgentLogger.logger.info("unLoadSmithProber Leave"); + + return bret; + } + + private static String getProberVersion(String jarFilePath) { + try { + java.util.jar.JarFile jarFile = new java.util.jar.JarFile(jarFilePath); + Manifest manifest = jarFile.getManifest(); + + String ImplementationVersion = manifest.getMainAttributes().getValue("Implementation-Version"); + jarFile.close(); + + return ImplementationVersion; + } + catch(Exception e) { + SmithAgentLogger.exception(e); + } + + return null; + } + private static class MyCallable implements Callable { + @Override + public String call() throws Exception { + xLoaderLock.lock(); + try { + if(xLoader != null) { + String agent = System.getProperty("rasp.probe"); + + if(unLoadSmithProber()) { + System.setProperty("smith.status", "detach"); + } + if (agent != null) { + System.clearProperty("rasp.probe"); + } + xLoader = null; + SmithProberObj = null; + SmithProberClazz = null; + } + + System.setProperty("smith.rasp", ""); + if (!checkMemoryAvailable()) { + System.setProperty("smith.status", "memory not enough"); + SmithAgentLogger.logger.warning("checkMemory failed"); + } else { + if(!loadSmithProber(proberPath,instrumentation)) { + System.setProperty("smith.status",proberPath + " loading fail"); + SmithAgentLogger.logger.warning(proberPath + " loading fail!"); + } + else { + System.setProperty("smith.status", "attach"); + System.setProperty("smith.rasp", probeVersion+"-"+checksumStr); + + System.setProperty("rasp.probe", "smith"); + } + } + } + finally { + xLoaderLock.unlock(); + } + return "SmithProbeLoader"; + } + } + + public static void premain(String agentArgs, Instrumentation inst) { + agentmain(agentArgs, inst); + } + + public static void agentmain(String agentArgs, Instrumentation inst) { + String agent = System.getProperty("rasp.probe"); + + SmithAgentLogger.logger.info("agentArgs:"+agentArgs); + + StringBuilder cmd_sb = new StringBuilder(); + StringBuilder checksumStr_sb = new StringBuilder(); + StringBuilder proberPath_sb = new StringBuilder(); + String cmd = ""; + + if(ParseParameter.parseParameter(agentArgs,cmd_sb,checksumStr_sb,proberPath_sb)) { + cmd = cmd_sb.toString(); + SmithAgentLogger.logger.info("cmd:" + cmd); + SmithAgentLogger.logger.info("parse parseParameter success"); + + if(cmd.equals("attach")) { + checksumStr = checksumStr_sb.toString(); + proberPath = proberPath_sb.toString(); + + SmithAgentLogger.logger.info("checksumStr:" + checksumStr); + SmithAgentLogger.logger.info("proberPath:" + proberPath); + + /* + if (!JarUtil.checkJarFile(proberPath,checksumStr)) { + System.setProperty("smith.status", proberPath + " check fail"); + SmithAgentLogger.logger.warning(proberPath + " check fail!"); + return ; + } + */ + + if(instrumentation == null) { + instrumentation = inst; + } + + probeVersion = getProberVersion(proberPath); + SmithAgentLogger.logger.info("proberVersion:" + probeVersion); + + Callable callable = new MyCallable(); + + FutureTask futureTask = new FutureTask<>(callable); + Thread newThread = new Thread(futureTask, "SmithProbeLoader Thread"); + newThread.setContextClassLoader(ClassLoader.getSystemClassLoader()); + newThread.start(); + } + else if(cmd.equals("detach")) { + xLoaderLock.lock(); + try { + if(xLoader != null) { + if(unLoadSmithProber()) { + SmithAgentLogger.logger.warning("SmithProber detach success!"); + System.setProperty("smith.status", "detach"); + } + else { + System.setProperty("smith.status", "prober unload fail"); + } + xLoader = null; + SmithProberObj = null; + SmithProberClazz = null; + if (agent != null) { + System.clearProperty("rasp.probe"); + } + } + else { + SmithAgentLogger.logger.warning("SmithProber No Loading!"); + } + } + finally { + xLoaderLock.unlock(); + } + } else { + SmithAgentLogger.logger.warning("Unknow Command:"+cmd); + return ; + } + } + else { + SmithAgentLogger.logger.info("parse parameter fail"); + return ; + } + + + + if (agent != null) { + SmithAgentLogger.logger.info("agent running"); + return; + } + } + + private static boolean checkMemoryAvailable() { + try { + long systemFree = MemCheck.getSystemMemoryFree(); + SmithAgentLogger.logger.info("systemmemory free: "+ systemFree); + long cpuload = MemCheck.getSystemCpuLoad(); + SmithAgentLogger.logger.info("system cpu load: "+ cpuload); + long heapFree = MemCheck.getHeapMemoryFree(); + if (heapFree < jvmHeapFree) { + SmithAgentLogger.logger.info("heapmemory is not enough, free: "+ heapFree); + return false; + } + else { + SmithAgentLogger.logger.info("heapmemory is enough, free: "+ heapFree); + long metaBeanFree = MemCheck.getMetaMemoryFree(); + if (metaBeanFree > 0L && metaBeanFree < jvmMetaFree) { + SmithAgentLogger.logger.info("metamemory is not enough, free: " + metaBeanFree); + return false; + } else { + SmithAgentLogger.logger.info("metamemory is enough, free: " + metaBeanFree); + } + } + + } catch (Exception e) { + SmithAgentLogger.exception(e); + } + return true; + + } +} diff --git a/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithLoader.java b/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithLoader.java index 1f9df08c1..cd5f8fe82 100644 --- a/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithLoader.java +++ b/rasp/jvm/JVMAgent/src/main/java/com/security/smithloader/SmithLoader.java @@ -1,133 +1,136 @@ -package com.security.smithloader; - -import java.io.File; -import java.io.IOException; -import java.io.ByteArrayOutputStream; -import java.io.InputStream; -import java.net.URL; -import java.util.Enumeration; -import java.util.jar.Attributes; -import java.util.jar.JarFile; -import java.util.jar.Manifest; -import java.util.stream.Stream; -import java.util.zip.ZipEntry; - -public class SmithLoader extends ClassLoader { - private File file; - private JarFile jarFile; - public SmithLoader(String jarFilePath, ClassLoader parent) throws IOException { - file = new File(jarFilePath); - this.jarFile = new JarFile(file); - } - - @Override - protected Class findClass(String name) throws ClassNotFoundException { - try { - String className = name.replace('.', '/') + ".class"; - try { - byte[] classData = loadClassData(className); - if (classData != null) { - return defineClass(name, classData, 0, classData.length); - } - } catch (IOException ex) { - throw new ClassNotFoundException("Failed to load class: " + name, ex); - } - } catch (ClassNotFoundException e) { - // If the class is not found in JAR file,try to load from parent class loader - return super.findClass(name); - //throw e; - } - - return null; - } - - private byte[] readAllBytes(InputStream inputStream) throws IOException { - ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - byte[] buffer = new byte[4096]; - int bytesRead; - while ((bytesRead = inputStream.read(buffer)) != -1) { - outputStream.write(buffer, 0, bytesRead); - } - return outputStream.toByteArray(); - } - - private byte[] loadClassData(String className) throws IOException { - byte[] data = null; - - try { - ZipEntry zEntry = jarFile.getEntry(className); - if(zEntry == null) { - throw new IOException("class not found"); - } - - try (InputStream inputStream = jarFile.getInputStream(zEntry)) { - data = readAllBytes(inputStream); - inputStream.close(); - } - } - catch(Exception e) { - throw e; - } - - return data; - } - - @Override - protected void finalize() throws Throwable { - try { - jarFile.close(); - jarFile = null; - } finally { - super.finalize(); - } - } - - @Override - public InputStream getResourceAsStream(String name) { - InputStream inputStream = findResourceAsStream(name); - if (inputStream == null) { - inputStream = super.getResourceAsStream(name); - } - return inputStream; - } - - private InputStream findResourceAsStream(String name) { - InputStream inputStream = null; - - if(name.length() <= 0) { - throw new NullPointerException(); - } - - String resourcePath = name; - - try { - ZipEntry zEntry = jarFile.getEntry(resourcePath); - if(zEntry == null) { - throw new IOException("resource not found"); - } - - inputStream = jarFile.getInputStream(zEntry); - } - catch(Exception e) { - - } - - return inputStream; - } - - public String getJarMainClass() { - try { - Manifest manifest = jarFile.getManifest(); - if (manifest != null) { - Attributes attributes = manifest.getMainAttributes(); - return attributes.getValue(Attributes.Name.MAIN_CLASS); - } - } - catch(IOException e) { - - } - - return null; - } -} - +package com.security.smithloader; + +import java.io.File; +import java.io.IOException; +import java.io.ByteArrayOutputStream; +import java.io.InputStream; +import java.net.URL; +import java.util.Enumeration; +import java.util.jar.Attributes; +import java.util.jar.JarFile; +import java.util.jar.Manifest; +import java.util.stream.Stream; +import java.util.zip.ZipEntry; + +public class SmithLoader extends ClassLoader { + private File file; + private JarFile jarFile; + public SmithLoader(String jarFilePath, ClassLoader parent) throws IOException { + file = new File(jarFilePath); + this.jarFile = new JarFile(file); + } + + @Override + protected Class findClass(String name) throws ClassNotFoundException { + try { + String className = name.replace('.', '/') + ".class"; + try { + byte[] classData = loadClassData(className); + if (classData != null) { + return defineClass(name, classData, 0, classData.length); + } + } + catch (IOException ex) { + throw new ClassNotFoundException("Failed to load class: " + name, ex); + } + } catch (ClassNotFoundException e) { + // If the class is not found in JAR file,try to load from parent class loader + throw e; + } catch(Throwable t) { + throw t; + //throw e; + } + + return null; + } + + private byte[] readAllBytes(InputStream inputStream) throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + outputStream.write(buffer, 0, bytesRead); + } + return outputStream.toByteArray(); + } + + private byte[] loadClassData(String className) throws IOException { + byte[] data = null; + + try { + ZipEntry zEntry = jarFile.getEntry(className); + if(zEntry == null) { + throw new IOException("class not found"); + } + + try (InputStream inputStream = jarFile.getInputStream(zEntry)) { + data = readAllBytes(inputStream); + inputStream.close(); + } + } + catch(Exception e) { + throw e; + } + + return data; + } + + @Override + protected void finalize() throws Throwable { + try { + jarFile.close(); + jarFile = null; + } finally { + super.finalize(); + } + } + + @Override + public InputStream getResourceAsStream(String name) { + InputStream inputStream = findResourceAsStream(name); + if (inputStream == null) { + inputStream = super.getResourceAsStream(name); + } + return inputStream; + } + + private InputStream findResourceAsStream(String name) { + InputStream inputStream = null; + + if(name.length() <= 0) { + throw new NullPointerException(); + } + + String resourcePath = name; + + try { + ZipEntry zEntry = jarFile.getEntry(resourcePath); + if(zEntry == null) { + throw new IOException("resource not found"); + } + + inputStream = jarFile.getInputStream(zEntry); + } + catch(Exception e) { + + } + + return inputStream; + } + + public String getJarMainClass() { + try { + Manifest manifest = jarFile.getManifest(); + if (manifest != null) { + Attributes attributes = manifest.getMainAttributes(); + return attributes.getValue(Attributes.Name.MAIN_CLASS); + } + } + catch(IOException e) { + + } + + return null; + } +} + diff --git a/rasp/jvm/JVMProbe/build.gradle b/rasp/jvm/JVMProbe/build.gradle index 47da7f74b..9e51a6be1 100644 --- a/rasp/jvm/JVMProbe/build.gradle +++ b/rasp/jvm/JVMProbe/build.gradle @@ -1,55 +1,52 @@ -plugins { - id 'java-library' - id 'com.github.johnrengelman.shadow' version '6.0.0' -} - -group 'com.security' -version '1.0-SNAPSHOT' - -repositories { - mavenCentral() -} - -dependencies { - testImplementation group: 'junit', name: 'junit', version: '4.13.1' - implementation group: 'org.ow2.asm', name: 'asm-tree', version: '9.6' - implementation group: 'org.ow2.asm', name: 'asm-commons', version: '9.6' - implementation group: 'io.netty', name: 'netty-all', version: '4.1.85.Final' - implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: '2.14.0' - implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-yaml', version: '2.14.0' - implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.11' - implementation group: 'com.lmax', name: 'disruptor', version: '3.4.4' - implementation 'com.google.code.gson:gson:2.8.9' - implementation 'com.esotericsoftware.yamlbeans:yamlbeans:1.14' - api group: 'org.javassist', name: 'javassist', version: '3.29.0-GA' -} - -compileJava { - options.release.set(8) -} - -jar { - manifest { - attributes 'Agent-Class': 'com.security.smith.SmithAgent' - attributes 'Premain-Class': 'com.security.smith.SmithAgent' - attributes 'Can-Retransform-Classes': 'true' - attributes 'Boot-Class-Path': 'SmithAgent.jar' - attributes 'Specification-Title': 'Smith Agent' - attributes 'Specification-Version': '1.0' - attributes 'Implementation-Title': 'Smith Agent' - attributes 'Implementation-Version': "$System.env.BUILD_VERSION" - } -} - -shadowJar { - relocate 'io.netty', 'rasp.io.netty' - relocate 'org.objectweb.asm', 'rasp.org.objectweb.asm' - relocate 'org.apache.commons', 'rasp.org.apache.commons' - relocate 'org.yaml.snakeyaml', 'rasp.org.yaml.snakeyaml' - relocate 'com.lmax.disruptor', 'rasp.com.lmax.disruptor' - relocate 'com.google','rasp.com.google' - relocate 'com.esotericsoftware', 'rasp.com.esotericsoftware' - relocate 'com.fasterxml.jackson', 'rasp.com.fasterxml.jackson' - relocate 'javassist', 'rasp.javassist' - relocate 'META-INF/native/libnetty', 'META-INF/native/librasp_netty' +plugins { + id 'java-library' + id 'com.github.johnrengelman.shadow' version '6.0.0' +} + +group 'com.security' +version '1.0-SNAPSHOT' + +repositories { + mavenCentral() +} + +dependencies { + testImplementation group: 'junit', name: 'junit', version: '4.13.1' + implementation group: 'org.ow2.asm', name: 'asm-tree', version: '9.6' + implementation group: 'org.ow2.asm', name: 'asm-commons', version: '9.6' + implementation group: 'io.netty', name: 'netty-all', version: '4.1.85.Final' + implementation group: 'org.apache.commons', name: 'commons-lang3', version: '3.11' + implementation group: 'com.lmax', name: 'disruptor', version: '3.4.4' + implementation 'com.google.code.gson:gson:2.8.9' + implementation 'com.esotericsoftware.yamlbeans:yamlbeans:1.14' + api group: 'org.javassist', name: 'javassist', version: '3.29.0-GA' +} + +compileJava { + options.release.set(8) +} + +jar { + manifest { + attributes 'Agent-Class': 'com.security.smith.SmithAgent' + attributes 'Premain-Class': 'com.security.smith.SmithAgent' + attributes 'Can-Retransform-Classes': 'true' + attributes 'Boot-Class-Path': 'SmithAgent.jar' + attributes 'Specification-Title': 'Smith Agent' + attributes 'Specification-Version': '1.0' + attributes 'Implementation-Title': 'Smith Agent' + attributes 'Implementation-Version': "$System.env.BUILD_VERSION" + } +} + +shadowJar { + relocate 'io.netty', 'rasp.io.netty' + relocate 'org.objectweb.asm', 'rasp.org.objectweb.asm' + relocate 'org.apache.commons', 'rasp.org.apache.commons' + relocate 'org.yaml.snakeyaml', 'rasp.org.yaml.snakeyaml' + relocate 'com.lmax.disruptor', 'rasp.com.lmax.disruptor' + relocate 'com.google','rasp.com.google' + relocate 'com.esotericsoftware', 'rasp.com.esotericsoftware' + relocate 'javassist', 'rasp.javassist' + relocate 'META-INF/native/libnetty', 'META-INF/native/librasp_netty' } \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java index fd5dde56b..6e9994e43 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbe.java @@ -1,1042 +1,1108 @@ -package com.security.smith; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.lmax.disruptor.EventHandler; -import java.io.InputStreamReader; -import java.io.Reader; -import java.io.StringReader; - -import com.lmax.disruptor.dsl.Disruptor; -import com.lmax.disruptor.EventFactory; -import com.lmax.disruptor.util.DaemonThreadFactory; -import com.security.smith.asm.SmithClassVisitor; -import com.security.smith.asm.SmithClassWriter; -import com.security.smith.client.message.*; -import com.security.smith.common.Reflection; -import com.security.smith.common.SmithHandler; -import com.security.smith.common.SmithTools; -import com.security.smith.log.AttachInfo; -import com.security.smith.log.SmithLogger; -import com.security.smith.module.Patcher; -import com.security.smith.type.*; -import com.security.smith.client.*; -import com.esotericsoftware.yamlbeans.YamlReader; - - -import javassist.ClassPool; -import javassist.CtClass; - -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; -import org.objectweb.asm.*; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.lang.instrument.ClassFileTransformer; -import java.lang.instrument.Instrumentation; -import java.lang.instrument.UnmodifiableClassException; - -import java.lang.reflect.InvocationTargetException; -import java.net.URL; -import java.net.URLClassLoader; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.CodeSource; -import java.security.ProtectionDomain; - -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.TimeUnit; - -import java.util.function.Predicate; -import java.util.regex.Pattern; -import java.util.regex.Matcher; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import java.io.File; -import java.io.FileOutputStream; -import java.security.CodeSource; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; -import java.util.jar.JarFile; - - -import java.io.FileOutputStream; -import java.io.IOException; - -class DetectTimerTask extends TimerTask { - private boolean isCancel = false; - private SmithProbe Probe = null; - - public void setSmithProbe(SmithProbe Probe) { - this.Probe = Probe; - } - - @Override - public void run() { - if(!isCancel) { - Probe.onDetect(); - } - } - - @Override - public boolean cancel() { - isCancel = true; - return super.cancel(); - } -} - -class SmithproxyTimerTask extends TimerTask { - private boolean isCancel = false; - private SmithProbeProxy smithProxy = null; - - public void setSmithProxy(SmithProbeProxy smithProxy) { - this.smithProxy = smithProxy; - } - - @Override - public void run() { - if(!isCancel) { - smithProxy.onTimer(); - } - } - - @Override - public boolean cancel() { - isCancel = true; - return super.cancel(); - } -} - -class MatchRulePredicate implements Predicate { - private final Trace trace; - - MatchRulePredicate(Trace trace) { - this.trace = trace; - } - - public boolean test(MatchRule rule) { - Object[] args = this.trace.getArgs(); - if (rule.getIndex() >= args.length || rule.getRegex().isEmpty() || args[rule.getIndex()] == null) - return false; - Pattern pattern = Pattern.compile(rule.getRegex()); - Matcher matcher = pattern.matcher(args[rule.getIndex()].toString()); - return matcher.find(); - } -} - -public class SmithProbe implements ClassFileTransformer, MessageHandler, EventHandler { - private final int STOP = 0; - private final int START = 1; - private SmithProbe ourInstance = null; - private SmithProbeProxy smithProxy = null; - private int TRACE_BUFFER_SIZE = 1024; - - private Object xClassLoaderObj; - private Boolean disable; - private Boolean scanswitch; - private Instrumentation inst; - private Client client; - private Heartbeat heartbeat; - - private Map smithClasses; - private Map patchers; - private Map, Filter> filters; - private Map, Block> blocks; - private Map, Integer> limits; - private Disruptor disruptor; - - private Rule_Mgr rulemgr; - private Rule_Config ruleconfig; - private Timer detectTimer; - private Timer smithproxyTimer; - private DetectTimerTask detectTimerTask; - private SmithproxyTimerTask smithproxyTimerTask; - private String proberVersion; - private String proberPath; - - public SmithProbe() { - disable = false; - scanswitch = true; - } - - public void setInst(Instrumentation inst) { - this.inst = inst; - } - - public Object getSmithProbeProxy() { - return smithProxy; - } - - public void setClassLoader(Object classLoaderObj) { - xClassLoaderObj = classLoaderObj; - } - - public InputStream getResourceAsStream(String name) { - Class[] strArgTypes = new Class[]{String.class}; - return (InputStream)Reflection.invokeMethod(xClassLoaderObj,"getResourceAsStream", strArgTypes,name); - } - - public void setProbeVersion(String proberVer) { - proberVersion = proberVer; - } - - public void setProbePath(String proberPath) { - this.proberPath = proberPath; - } - - public void init() { - AttachInfo.info(); - SmithLogger.loggerProberInit(); - SmithLogger.logger.info("probe init enter"); - smithClasses = new ConcurrentHashMap<>(); - patchers = new ConcurrentHashMap<>(); - filters = new ConcurrentHashMap<>(); - blocks = new ConcurrentHashMap<>(); - limits = new ConcurrentHashMap<>(); - - MessageSerializer.initInstance(proberVersion); - heartbeat = new Heartbeat(); - - client = new Client(this); - - disruptor = new Disruptor<>(new EventFactory() { - @Override - public Trace newInstance() { - return new Trace(); - } - }, TRACE_BUFFER_SIZE, DaemonThreadFactory.INSTANCE); - - rulemgr = new Rule_Mgr(); - ruleconfig = new Rule_Config(rulemgr); - - smithProxy = new SmithProbeProxy(); - - - ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory()); - InputStream inputStream = this.getClass().getResourceAsStream("/class.yaml"); - - try { - for (SmithClass smithClass : objectMapper.readValue(inputStream, SmithClass[].class)) { - if(!isBypassHookClass(smithClass.getName())) { - smithClasses.put(smithClass.getName(), smithClass); - } - } - } catch (IOException e) { - SmithLogger.exception(e); - } - - - SmithLogger.logger.info("probe init leave"); - } - private boolean isBypassHookClass(String className) { - - if(SmithTools.isGlassfish() && SmithTools.getMajorVersion() > 5) { - /* - * In versions after GlassFish 5 (not including GlassFish 5), - * not hooking java.io.File will cause the JVM process to crash directly if hooked. - * - */ - if(className.equals("java.io.File")) { - return true; - } - } - - return false; - } - - public void start() { - SmithLogger.logger.info("probe start"); - AttachInfo.info(); - - SmithLogger.logger.info("init ClassUploadTransformer"); - ClassUploadTransformer.getInstance().start(client, inst); - - - - Thread clientThread = new Thread(client::start); - - disruptor.handleEventsWith(this); - disruptor.start(); - - clientThread.setDaemon(true); - clientThread.start(); - - detectTimerTask = new DetectTimerTask(); - detectTimerTask.setSmithProbe(this); - - detectTimer = new Timer(true); - detectTimer.schedule( - detectTimerTask, - TimeUnit.MINUTES.toMillis(1) - ); - smithproxyTimerTask = new SmithproxyTimerTask(); - smithproxyTimerTask.setSmithProxy(smithProxy); - - smithproxyTimer = new Timer(true); - smithproxyTimer.schedule( - smithproxyTimerTask, - 0, - TimeUnit.MINUTES.toMillis(1) - ); - smithProxy.setClient(client); - smithProxy.setDisruptor(disruptor); - smithProxy.setProbe(this); - smithProxy.setReflectField(); - smithProxy.setReflectMethod(); - - inst.addTransformer(this, true); - reloadClasses(); - - SmithLogger.logger.info("probe start leave"); - } - - public void stop() { - SmithLogger.logger.info("probe stop enter"); - - inst.removeTransformer(this); - reloadClasses(); - SmithLogger.logger.info("probe stop 0"); - - disable = true; - scanswitch = false; - - ClassUploadTransformer.getInstance().stop(); - - SmithLogger.logger.info("probe stop 1"); - - detectTimer.cancel(); - smithproxyTimer.cancel(); - SmithLogger.logger.info("probe stop 2"); - - client.stop(); - SmithLogger.logger.info("probe stop 3"); - - - ruleconfig.destry(); - SmithLogger.logger.info("probe stop 4"); - - rulemgr.destry(); - SmithLogger.logger.info("probe stop 5"); - - detectTimerTask = null; - detectTimer =null; - - smithproxyTimerTask = null; - smithproxyTimer = null; - - SmithLogger.logger.info("probe stop leave"); - } - - public void uninit() { - SmithLogger.logger.info("probe uninit enter"); - ClassUploadTransformer.delInstance(); - - smithProxy.uninit(); - smithProxy = null; - - disruptor.shutdown(); - - for (String key : smithClasses.keySet()) { - SmithClass smithClass = smithClasses.get(key); - smithClass.clear(); - smithClasses.remove(key); - } - smithClasses.clear(); - smithClasses = null; - for (String key : patchers.keySet()) { - patchers.remove(key); - } - patchers.clear(); - patchers = null; - filters.clear(); - filters = null; - for (Pair key : blocks.keySet()) { - Block value = blocks.get(key); - value.removeAll(); - blocks.remove(key); - } - blocks.clear(); - blocks = null; - limits.clear(); - limits = null; - SmithLogger.logger.info("probe uninit 0"); - - disruptor = null; - ruleconfig = null; - rulemgr = null; - client = null; - - heartbeat = null; - inst = null; - ourInstance = null; - proberVersion = null; - proberPath = null; - MessageSerializer.delInstance(); - - SmithLogger.logger.info("probe uninit leave"); - SmithLogger.loggerProberUnInit(); - - } - - private void reloadClasses() { - reloadClasses(smithClasses.keySet()); - } - - - private String getJarPath(Class clazz) { - CodeSource codeSource = clazz.getProtectionDomain().getCodeSource(); - if (codeSource != null) { - URL location = codeSource.getLocation(); - try { - File file = new File(location.toURI()); - return file.getAbsolutePath(); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - return null; - } - - private String[] addJarclassns = { - "org.apache.felix.framework.BundleWiringImpl$BundleClassLoader" - }; - - private Set addedJarset = Collections.synchronizedSet(new HashSet<>()); - - public void checkNeedAddJarPath(Class clazz,Instrumentation inst) { - try { - String cn = clazz.getName(); - for (String name : addJarclassns) { - if(cn.equals(name)) { - try { - String jarFile = getJarPath(clazz); - if(jarFile != null && !addedJarset.contains(jarFile)) { - SmithLogger.logger.info("add "+ name + " jarpath:"+jarFile); - inst.appendToSystemClassLoaderSearch(new JarFile(jarFile)); - addedJarset.add(jarFile); - } - }catch(Exception e) { - SmithLogger.exception(e); - } - } - } - } - catch(Exception e) { - SmithLogger.exception(e); - } - } - - public void checkNeedAddJarPaths(Class[] cls,Instrumentation inst) { - for (Class cx : cls) { - checkNeedAddJarPath(cx,inst); - } - } - - private void reloadClasses(Collection classes) { - Class[] loadedClasses = inst.getAllLoadedClasses(); - - List> resultList = new ArrayList<>(); - for (Class loadedClass : loadedClasses) { - if (classes.contains(loadedClass.getName())) { - resultList.add(loadedClass); - } - } - Class[] cls = resultList.toArray(new Class[0]); - - SmithLogger.logger.info("reload: " + Arrays.toString(cls)); - - checkNeedAddJarPaths(cls,inst); - - try { - inst.retransformClasses(cls); - } catch (UnmodifiableClassException e) { - SmithLogger.exception(e); - } - } - - @Override - public void onEvent(Trace trace, long sequence, boolean endOfBatch) { - Filter filter = filters.get(new ImmutablePair<>(trace.getClassID(), trace.getMethodID())); - - if (filter == null) { - client.write(Operate.TRACE, trace); - return; - } - - MatchRulePredicate pred = new MatchRulePredicate(trace); - - MatchRule[] include = filter.getInclude(); - MatchRule[] exclude = filter.getExclude(); - - if (include.length > 0 && Arrays.stream(include).noneMatch(pred)) - return; - - if (exclude.length > 0 && Arrays.stream(exclude).anyMatch(pred)) - return; - - client.write(Operate.TRACE, trace); - } - - public void printClassfilter(ClassFilter data) { - SmithLogger.logger.info("------------------------------------------------------------------------"); - SmithLogger.logger.info("className:" + data.getClassName()); - SmithLogger.logger.info("classPath:" + data.getClassPath()); - SmithLogger.logger.info("interfaceName:" + data.getInterfacesName()); - SmithLogger.logger.info("classLoaderName:" + data.getClassLoaderName()); - SmithLogger.logger.info("parentClassName:" + data.getParentClassName()); - SmithLogger.logger.info("parentClassLoaderName:" + data.getParentClassLoaderName()); - - } - - public InputStream byteArrayToInputStream(byte[] bytes) throws IOException { - if(bytes == null) { - return null; - } - - ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); - - return inputStream; - } - - private void checkClassFilter(ClassLoader loader, String className, byte[] classfileBuffer) { - CtClass ctClass = null; - - try { - if (className == null && classfileBuffer == null) { - SmithLogger.logger.info("nononono className == null && classfileBuffer == null"); - return ; - } - - String className_std = ""; - - ClassPool pool = ClassPool.getDefault(); - - if (className != null) { - className_std = className; - if (className.startsWith("com/security/smith") || className.startsWith("rasp/")) { - return ; - } - - ctClass = pool.makeClass(className); - } - else { - className_std = ""; - - InputStream inputS = byteArrayToInputStream(classfileBuffer); - - ctClass = pool.makeClass(inputS); - } - - if (ctClass == null) { - return ; - } else { - className_std = ctClass.getName(); - } - - if (className_std != "") { - className_std = className_std.replace("/", "."); - } - - ClassFilter classFilter = new ClassFilter(); - if (loader != null) { - classFilter.setClassLoaderName(loader.getClass().getName()); - } - classFilter.setClassName(className_std); - - - try { - if (!ctClass.isInterface()) { - classFilter.setInterfacesName(SmithHandler.getCtClassInterfaces(ctClass)); - } - classFilter.setClassPath(SmithHandler.getCtClassPath(ctClass)); - CtClass superClass = null; - try { - superClass = ctClass.getSuperclass(); - } catch(Exception e) { - // SmithLogger.exception(e); - } - - String superClassName = superClass != null ? superClass.getName() : ""; - classFilter.setParentClassName(superClassName); - - if (superClass != null) { - ClassLoader parentClassLoader = superClass.getClassPool().getClassLoader(); - String parentClassLoaderName = parentClassLoader != null ? parentClassLoader.getClass().getName() : ""; - classFilter.setParentClassLoaderName(parentClassLoaderName); - } - } catch (Exception e) { - SmithLogger.exception(e); - } - - long rule_id = rulemgr.matchRule(classFilter); - if(rule_id != -1) { - - classFilter.setRuleId(rule_id); - classFilter.setTransId(); - classFilter.setStackTrace(Thread.currentThread().getStackTrace()); - - client.write(Operate.SCANCLASS, classFilter); - SmithLogger.logger.info("send metadata: " + classFilter.toString()); - Thread.sleep(1000); - sendByte(classfileBuffer, classFilter.getTransId()); - } - } catch(Exception e) { - SmithLogger.exception(e); - } - finally { - if(ctClass != null) { - ctClass.detach(); - } - } - } - - boolean hasExceptionHook(Map methodMap) { - Iterator> iterator = methodMap.entrySet().iterator(); - while (iterator.hasNext()) { - Map.Entry entry = iterator.next(); - SmithMethod value = entry.getValue(); - String exceptionHookName = value.getExceptionHook(); - if(exceptionHookName != null && exceptionHookName.length() > 1 && exceptionHookName != "") { - return true; - } - } - - return false; - } - - - @Override - public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) { - if (disable) - return null; - - // if(scanswitch) { - // checkClassFilter(loader, className,classfileBuffer); - // } - - Type classType = null; - SmithClass smithClass = null; - try { - classType = Type.getObjectType(className); - smithClass = smithClasses.get(classType.getClassName()); - } catch (Exception e) { - //SmithLogger.exception(e); - } - - if (smithClass == null) { - - ClassReader cr = new ClassReader(classfileBuffer); - - if (className == null) { - className = cr.getClassName(); - classType = Type.getObjectType(className); - } - String[] interfaces = cr.getInterfaces(); - String superClass = cr.getSuperName(); - - try { - String[] combined; - if (superClass != null) { - combined = new String[interfaces.length + 1]; - System.arraycopy(interfaces, 0, combined, 0, interfaces.length); - combined[interfaces.length] = superClass; - } else { - combined = interfaces; - } - - for (String interName : combined) { - if (SmithHandler.checkInterfaceNeedTran(interName)) { - Type interfaceType = Type.getObjectType(interName); - smithClass = smithClasses.get(interfaceType.getClassName()); - break; - } - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - - if (smithClass == null) { - return null; - } - } - - try { - Map methodMap = new HashMap<>(); - List methods = smithClass.getMethods(); - - for (SmithMethod method : methods) { - String key = method.getName() + method.getDesc(); - methodMap.put(key, method); - } - - SmithLogger.logger.info("transform: " + classType.getClassName()); - ClassReader classReader = new ClassReader(classfileBuffer); - - ClassWriter classWriter; - if(!hasExceptionHook(methodMap)) { - classWriter = new SmithClassWriter(ClassWriter.COMPUTE_MAXS); - } - else { - classWriter = new SmithClassWriter(ClassWriter.COMPUTE_FRAMES); - } - ClassVisitor classVisitor = new SmithClassVisitor( - Opcodes.ASM9, - classWriter, - smithClass.getId(), - classType, - methodMap - ); - - - classReader.accept(classVisitor, ClassReader.EXPAND_FRAMES); - - return classWriter.toByteArray(); - } catch (Throwable e) { - SmithLogger.exception(e); - } - - return null; - } - - @Override - public void onConfig(String config) { - SmithLogger.logger.info("on config: " + config); - - Set classes = new HashSet<>(smithClasses.keySet()); - - smithClasses.clear(); - - try { - YamlReader yamlReader = new YamlReader(new StringReader(config)); - for (SmithClass smithClass : yamlReader.read(SmithClass[].class)) { - smithClasses.put(smithClass.getName(), smithClass); - } - } catch (IOException e) { - SmithLogger.exception(e); - } - - classes.addAll(smithClasses.keySet()); - reloadClasses(classes); - } - - @Override - public void onControl(int action) { - SmithLogger.logger.info("on control: " + action); - disable = action == STOP; - reloadClasses(); - } - - @Override - public void onDetect() { - SmithLogger.logger.info("on detect"); - - Set jars = new HashSet<>(); - - for (Class cl : inst.getAllLoadedClasses()) { - CodeSource codeSource = cl.getProtectionDomain().getCodeSource(); - - if (codeSource == null) - continue; - - Jar jar = new Jar(); - - URL url = codeSource.getLocation(); - - if (url == null) - continue; - - jar.setPath(url.toString()); - - if (jars.contains(jar)) - continue; - - Package pkg = cl.getPackage(); - - if (pkg == null) - continue; - - jar.setSpecificationTitle(pkg.getSpecificationTitle()); - jar.setSpecificationVersion(pkg.getSpecificationVersion()); - jar.setImplementationTitle(pkg.getImplementationTitle()); - jar.setImplementationVersion(pkg.getImplementationVersion()); - - jars.add(jar); - } - - client.write(Operate.DETECT, Collections.singletonMap("jars", jars)); - } - - @Override - public void onFilter(FilterConfig config) { - filters.clear(); - - for (Filter filter : config.getFilters()) { - filters.put( - new ImmutablePair<>(filter.getClassID(), filter.getMethodID()), - filter - ); - } - - heartbeat.setFilter(config.getUUID()); - } - - @Override - public void onBlock(BlockConfig config) { - blocks.clear(); - - for (Block block : config.getBlocks()) { - blocks.put( - new ImmutablePair<>(block.getClassID(), block.getMethodID()), - block - ); - } - - heartbeat.setBlock(config.getUUID()); - } - - @Override - public void onLimit(LimitConfig config) { - limits.clear(); - - for (Limit limit : config.getLimits()) { - limits.put( - new ImmutablePair<>(limit.getClassID(), limit.getMethodID()), - limit.getQuota() - ); - } - - heartbeat.setLimit(config.getUUID()); - } - - @Override - public void onPatch(PatchConfig config) { - if (config == null || config.getPatches() == null || config.getPatches().length == 0) { - SmithLogger.logger.info("patch may not be download, so not update heartbeat"); - return ; - } - for (Patch patch : config.getPatches()) { - SmithLogger.logger.info("install patch: " + patch.getClassName()); - - if (patchers.containsKey(patch.getClassName())) { - SmithLogger.logger.info("ignore installed patch: " + patch.getClassName()); - continue; - } - - try (URLClassLoader loader = new URLClassLoader(new URL[]{patch.getUrl()})) { - Patcher patcher = loader.loadClass(patch.getClassName()) - .asSubclass(Patcher.class) - .getConstructor(Instrumentation.class) - .newInstance(inst); - - patcher.install(); - - patchers.put(patch.getClassName(), patcher); - } catch (IOException | ClassNotFoundException | NoSuchMethodException | InvocationTargetException | - InstantiationException | IllegalAccessException e) { - SmithLogger.exception(e); - } - } - - Set active = Arrays.stream(config.getPatches()).map(Patch::getClassName).collect(Collectors.toSet()); - - for (String name : patchers.keySet()) { - if (!active.contains(name)) { - SmithLogger.logger.info("uninstall patch: " + name); - Patcher patcher = patchers.remove(name); - if (patcher == null) - continue; - patcher.uninstall(); - } - } - - heartbeat.setPatch(config.getUUID()); - } - - - @Override - public boolean setRuleVersion(Rule_Version ruleVersion) { - boolean bresult = false; - - try { - bresult = ruleconfig.setVersion(ruleVersion.getRule_version()); - heartbeat.setClassFilterVersion(ruleVersion.getClass_filter_version()); - } - catch(Exception e) { - SmithLogger.exception(e); - } - - return bresult; - } - - @Override - public boolean OnAddRule(Rule_Data ruleData) { - boolean bresult = false; - - try { - bresult = ruleconfig.addRuleData(ruleData); - } - catch(Exception e) { - SmithLogger.exception(e); - } - - return bresult; - } - - @Override - public boolean OnAddRule(String rulejson) { - boolean bresult = false; - - try { - bresult = ruleconfig.setRuleConfig(rulejson); - } - catch(Exception e) { - SmithLogger.exception(e); - } - - return bresult; - } - - - /* scan all class */ - @Override - public void onScanAllClass() { - if (scanswitch == false) { - return; - } - scanswitch = false; - - try { - Class[] loadedClasses = inst.getAllLoadedClasses(); - - for (Class clazz : loadedClasses) { - try { - - String className = clazz.getName(); - if (className.startsWith("rasp.") || className.startsWith("com.security.smith") || className.startsWith("java.lang.invoke.LambdaForm")) { - continue; - } - - if(classIsSended(clazz)) { - continue; - } - - ClassFilter classFilter = new ClassFilter(); - - SmithHandler.queryClassFilter(clazz, classFilter); - long rule_id = -1; - if (!SmithHandler.checkClassMemshell(clazz)) { - rule_id = rulemgr.matchRule(classFilter); - if (rule_id == -1) - continue; - } - classFilter.setTransId(); - classFilter.setRuleId(rule_id); - classFilter.setStackTrace(Thread.currentThread().getStackTrace()); - - client.write(Operate.SCANCLASS, classFilter); - SmithLogger.logger.info("send metadata: " + classFilter.toString()); - sendClass(clazz, classFilter.getTransId()); - - } catch(Exception e) { - SmithLogger.exception(e); - } - } - } catch(Exception e) { - SmithLogger.exception(e); - } - finally { - scanswitch = true; - SmithLogger.logger.info("scan all class finished"); - } - } - - /* - * send class file - */ - public void sendClass(Class clazz, String transId) { - if (clazz == null || transId == null) { - return; - } - try { - ClassUploadTransformer.getInstance().sendClass(clazz, transId); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public boolean classIsSended(Class clazz) { - try { - return ClassUploadTransformer.getInstance().classIsSended(clazz.hashCode()); - } catch (Exception e) { - SmithLogger.exception(e); - } - - return false; - } - - /* - * send CtClass file - */ - private void sendByte(byte[] data, String transId) { - if (data == null) { - return; - } - int length = data.length; - ClassUpload classUpload = new ClassUpload(); - classUpload.setTransId(transId); - - // client.write(Operate.CLASSDUMP, classUpload); - - // int packetSize = 1024; \ - // int totalPackets = (data.length + packetSize - 1) / packetSize; - //for (int i = 0; i < totalPackets; i++) { - //int offset = i * packetSize; - classUpload.setByteTotalLength(length); - //classUpload.setByteOffset(offset); - classUpload.setByteLength(length); - //int send_length = Math.min(packetSize, data.length - offset); - classUpload.setClassData(data); - - client.write(Operate.CLASSUPLOAD, classUpload); - SmithLogger.logger.info("send classdata: " + classUpload.toString()); - //} - } - - public Heartbeat getHeartbeat() { - return heartbeat; - } - - public void addDisacrdCount() { - int discrad_count = this.heartbeat.getDiscardCount(); - discrad_count++; - this.heartbeat.setDiscardCount(discrad_count); - } - - public Map, Integer> getLimits() { - return limits; - } - - public Map, Block> GetBlocks() { - return blocks; - } - - public Map, Filter> GetFiltes() { - return filters; - } - - public Client getClient() { - return client; - } - - public Disruptor getDisruptor() { - return disruptor; - } - -} +package com.security.smith; + +import com.lmax.disruptor.EventHandler; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.StringReader; + +import com.lmax.disruptor.dsl.Disruptor; +import com.lmax.disruptor.EventFactory; +import com.lmax.disruptor.util.DaemonThreadFactory; +import com.security.smith.asm.SmithClassVisitor; +import com.security.smith.asm.SmithClassWriter; +import com.security.smith.client.message.*; +import com.security.smith.common.Reflection; +import com.security.smith.common.SmithHandler; +import com.security.smith.common.SmithTools; +import com.security.smith.log.AttachInfo; +import com.security.smith.log.SmithLogger; +import com.security.smith.module.Patcher; +import com.security.smith.type.*; +import com.security.smith.client.*; +import com.esotericsoftware.yamlbeans.YamlReader; + + +import javassist.ClassPool; +import javassist.CtClass; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; +import org.objectweb.asm.*; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.lang.instrument.ClassFileTransformer; +import java.lang.instrument.Instrumentation; +import java.lang.instrument.UnmodifiableClassException; + +import java.lang.reflect.InvocationTargetException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.CodeSource; +import java.security.ProtectionDomain; + +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +import java.util.function.Predicate; +import java.util.regex.Pattern; +import java.util.regex.Matcher; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import java.io.File; +import java.io.FileOutputStream; +import java.security.CodeSource; +import java.util.jar.JarFile; + + +import java.io.FileOutputStream; +import java.io.IOException; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonArray; +import com.google.gson.GsonBuilder; +import com.security.smith.client.message.*; + +import java.io.File; +import java.io.FileOutputStream; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; + +class DetectTimerTask extends TimerTask { + private boolean isCancel = false; + private SmithProbe Probe = null; + + public void setSmithProbe(SmithProbe Probe) { + this.Probe = Probe; + } + + @Override + public void run() { + if(!isCancel) { + Probe.onDetect(); + } + } + + @Override + public boolean cancel() { + isCancel = true; + return super.cancel(); + } +} + +class SmithproxyTimerTask extends TimerTask { + private boolean isCancel = false; + private SmithProbeProxy smithProxy = null; + + public void setSmithProxy(SmithProbeProxy smithProxy) { + this.smithProxy = smithProxy; + } + + @Override + public void run() { + if(!isCancel) { + smithProxy.onTimer(); + } + } + + @Override + public boolean cancel() { + isCancel = true; + return super.cancel(); + } +} + +class MatchRulePredicate implements Predicate { + private final Trace trace; + + MatchRulePredicate(Trace trace) { + this.trace = trace; + } + + public boolean test(MatchRule rule) { + Object[] args = this.trace.getArgs(); + if (rule.getIndex() >= args.length || rule.getRegex().isEmpty() || args[rule.getIndex()] == null) + return false; + Pattern pattern = Pattern.compile(rule.getRegex()); + Matcher matcher = pattern.matcher(args[rule.getIndex()].toString()); + return matcher.find(); + } +} + +public class SmithProbe implements ClassFileTransformer, MessageHandler, EventHandler { + private final int STOP = 0; + private final int START = 1; + private SmithProbe ourInstance = null; + private SmithProbeProxy smithProxy = null; + private int TRACE_BUFFER_SIZE = 1024; + + private Object xClassLoaderObj; + private Boolean disable; + private Boolean scanswitch; + private Instrumentation inst; + private Client client; + private Heartbeat heartbeat; + + private Map smithClasses; + private Map patchers; + private Map, Filter> filters; + private Map, Block> blocks; + private Map, Integer> limits; + private Disruptor disruptor; + + private Rule_Mgr rulemgr; + private Rule_Config ruleconfig; + private Timer detectTimer; + private Timer smithproxyTimer; + private DetectTimerTask detectTimerTask; + private SmithproxyTimerTask smithproxyTimerTask; + private String proberVersion; + private String proberPath; + + public SmithProbe() { + disable = false; + scanswitch = true; + } + + public void setInst(Instrumentation inst) { + this.inst = inst; + } + + public Object getSmithProbeProxy() { + return smithProxy; + } + + public void setClassLoader(Object classLoaderObj) { + xClassLoaderObj = classLoaderObj; + } + + public InputStream getResourceAsStream(String name) { + Class[] strArgTypes = new Class[]{String.class}; + return (InputStream)Reflection.invokeMethod(xClassLoaderObj,"getResourceAsStream", strArgTypes,name); + } + + public void setProbeVersion(String proberVer) { + proberVersion = proberVer; + } + + public void setProbePath(String proberPath) { + this.proberPath = proberPath; + } + + public void init() { + AttachInfo.info(); + SmithLogger.loggerProberInit(); + SmithLogger.logger.info("probe init enter"); + smithClasses = new ConcurrentHashMap<>(); + patchers = new ConcurrentHashMap<>(); + filters = new ConcurrentHashMap<>(); + blocks = new ConcurrentHashMap<>(); + limits = new ConcurrentHashMap<>(); + + MessageEncoder.initInstance(); + MessageSerializer.initInstance(proberVersion); + MessageDecoder.initInstance(); + + heartbeat = new Heartbeat(); + + try { + client = new Client(this); + } + catch(Throwable e) { + SmithLogger.exception(e); + } + + disruptor = new Disruptor<>(new EventFactory() { + @Override + public Trace newInstance() { + return new Trace(); + } + }, TRACE_BUFFER_SIZE, DaemonThreadFactory.INSTANCE); + + rulemgr = new Rule_Mgr(); + ruleconfig = new Rule_Config(rulemgr); + + smithProxy = new SmithProbeProxy(); + + InputStream inputStream = getResourceAsStream("class.yaml"); + + if(inputStream != null) { + SmithLogger.logger.info("find class.yaml"); + try { + Reader xreader = new InputStreamReader(inputStream); + YamlReader yamlReader = new YamlReader(xreader); + for (SmithClass smithClass : yamlReader.read(SmithClass[].class)) { + smithClasses.put(smithClass.getName(), smithClass); + } + } catch (IOException e) { + SmithLogger.exception(e); + } + } + else { + SmithLogger.logger.info("not find class.yaml"); + } + + SmithLogger.logger.info("probe init leave"); + } + private boolean isBypassHookClass(String className) { + + if(SmithTools.isGlassfish() && SmithTools.getMajorVersion() > 5) { + /* + * In versions after GlassFish 5 (not including GlassFish 5), + * not hooking java.io.File will cause the JVM process to crash directly if hooked. + * + */ + if(className.equals("java.io.File")) { + return true; + } + } + + return false; + } + + public void start() { + SmithLogger.logger.info("probe start"); + AttachInfo.info(); + + SmithLogger.logger.info("init ClassUploadTransformer"); + ClassUploadTransformer.getInstance().start(client, inst); + + Thread clientThread = new Thread(client::start); + + disruptor.handleEventsWith(this); + disruptor.start(); + + clientThread.setDaemon(true); + clientThread.start(); + + detectTimerTask = new DetectTimerTask(); + detectTimerTask.setSmithProbe(this); + + detectTimer = new Timer(true); + detectTimer.schedule( + detectTimerTask, + TimeUnit.MINUTES.toMillis(1) + ); + smithproxyTimerTask = new SmithproxyTimerTask(); + smithproxyTimerTask.setSmithProxy(smithProxy); + + smithproxyTimer = new Timer(true); + smithproxyTimer.schedule( + smithproxyTimerTask, + 0, + TimeUnit.MINUTES.toMillis(1) + ); + smithProxy.setClient(client); + smithProxy.setDisruptor(disruptor); + smithProxy.setProbe(this); + smithProxy.setReflectField(); + smithProxy.setReflectMethod(); + + inst.addTransformer(this, true); + reloadClasses(); + + SmithLogger.logger.info("probe start leave"); + } + + public void stop() { + SmithLogger.logger.info("probe stop enter"); + + inst.removeTransformer(this); + reloadClasses(); + SmithLogger.logger.info("Transformer stop"); + + disable = true; + scanswitch = false; + + ClassUploadTransformer.getInstance().stop(); + + SmithLogger.logger.info("Upload Transformer stop"); + + detectTimer.cancel(); + smithproxyTimer.cancel(); + SmithLogger.logger.info("detect Timer stop"); + + client.stop(); + SmithLogger.logger.info("client stop"); + + ruleconfig.destry(); + SmithLogger.logger.info("ruleconfig stop"); + + rulemgr.destry(); + SmithLogger.logger.info("rulemgr stop"); + + detectTimerTask = null; + detectTimer =null; + + smithproxyTimerTask = null; + smithproxyTimer = null; + + SmithLogger.logger.info("probe stop leave"); + } + + public void uninit() { + SmithLogger.logger.info("probe uninit enter"); + ClassUploadTransformer.delInstance(); + + smithProxy.uninit(); + smithProxy = null; + + disruptor.shutdown(); + + for (String key : smithClasses.keySet()) { + SmithClass smithClass = smithClasses.get(key); + smithClass.clear(); + smithClasses.remove(key); + } + smithClasses.clear(); + smithClasses = null; + for (String key : patchers.keySet()) { + patchers.remove(key); + } + patchers.clear(); + patchers = null; + filters.clear(); + filters = null; + for (Pair key : blocks.keySet()) { + Block value = blocks.get(key); + value.removeAll(); + blocks.remove(key); + } + blocks.clear(); + blocks = null; + limits.clear(); + limits = null; + SmithLogger.logger.info("probe uninit 0"); + + disruptor = null; + ruleconfig = null; + rulemgr = null; + client = null; + + heartbeat = null; + inst = null; + ourInstance = null; + proberVersion = null; + proberPath = null; + MessageSerializer.delInstance(); + + MessageEncoder.delInstance(); + MessageSerializer.delInstance(); + MessageDecoder.delInstance(); + SmithLogger.logger.info("probe uninit leave"); + SmithLogger.loggerProberUnInit(); + + } + + private void reloadClasses() { + reloadClasses(smithClasses.keySet()); + } + + + private String getJarPath(Class clazz) { + CodeSource codeSource = clazz.getProtectionDomain().getCodeSource(); + if (codeSource != null) { + URL location = codeSource.getLocation(); + try { + File file = new File(location.toURI()); + return file.getAbsolutePath(); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + return null; + } + + private String[] addJarclassns = { + "org.apache.felix.framework.BundleWiringImpl$BundleClassLoader" + }; + + private Set addedJarset = Collections.synchronizedSet(new HashSet<>()); + + public void checkNeedAddJarPath(Class clazz,Instrumentation inst) { + try { + String cn = clazz.getName(); + for (String name : addJarclassns) { + if(cn.equals(name)) { + try { + String jarFile = getJarPath(clazz); + if(jarFile != null && !addedJarset.contains(jarFile)) { + SmithLogger.logger.info("add "+ name + " jarpath:"+jarFile); + inst.appendToSystemClassLoaderSearch(new JarFile(jarFile)); + addedJarset.add(jarFile); + } + }catch(Exception e) { + SmithLogger.exception(e); + } + } + } + } + catch(Exception e) { + SmithLogger.exception(e); + } + } + + public void checkNeedAddJarPaths(Class[] cls,Instrumentation inst) { + for (Class cx : cls) { + checkNeedAddJarPath(cx,inst); + } + } + + private void reloadClasses(Collection classes) { + Class[] loadedClasses = inst.getAllLoadedClasses(); + + List> resultList = new ArrayList<>(); + for (Class loadedClass : loadedClasses) { + if (classes.contains(loadedClass.getName())) { + resultList.add(loadedClass); + } + } + Class[] cls = resultList.toArray(new Class[0]); + + SmithLogger.logger.info("reload: " + Arrays.toString(cls)); + + checkNeedAddJarPaths(cls,inst); + + try { + inst.retransformClasses(cls); + } catch (UnmodifiableClassException e) { + SmithLogger.exception(e); + } + } + + @Override + public void onEvent(Trace trace, long sequence, boolean endOfBatch) { + Filter filter = filters.get(new ImmutablePair<>(trace.getClassID(), trace.getMethodID())); + + if (filter == null) { + Gson gson = new GsonBuilder() + .registerTypeAdapter(Trace.class, new TraceSerializer()) + .registerTypeAdapter(Trace.class, new TraceDeserializer()) + .create(); + JsonElement jsonElement = gson.toJsonTree(trace); + + client.write(Operate.TRACE, jsonElement); + return; + } + + MatchRulePredicate pred = new MatchRulePredicate(trace); + + MatchRule[] include = filter.getInclude(); + MatchRule[] exclude = filter.getExclude(); + + if (include.length > 0 && Arrays.stream(include).noneMatch(pred)) + return; + + if (exclude.length > 0 && Arrays.stream(exclude).anyMatch(pred)) + return; + + Gson gson = new GsonBuilder() + .registerTypeAdapter(Trace.class, new TraceSerializer()) + .registerTypeAdapter(Trace.class, new TraceDeserializer()) + .create(); + JsonElement jsonElement = gson.toJsonTree(trace); + + client.write(Operate.TRACE, jsonElement); + } + + public void printClassfilter(ClassFilter data) { + SmithLogger.logger.info("------------------------------------------------------------------------"); + SmithLogger.logger.info("className:" + data.getClassName()); + SmithLogger.logger.info("classPath:" + data.getClassPath()); + SmithLogger.logger.info("interfaceName:" + data.getInterfacesName()); + SmithLogger.logger.info("classLoaderName:" + data.getClassLoaderName()); + SmithLogger.logger.info("parentClassName:" + data.getParentClassName()); + SmithLogger.logger.info("parentClassLoaderName:" + data.getParentClassLoaderName()); + + } + + public InputStream byteArrayToInputStream(byte[] bytes) throws IOException { + if(bytes == null) { + return null; + } + + ByteArrayInputStream inputStream = new ByteArrayInputStream(bytes); + + return inputStream; + } + + private void checkClassFilter(ClassLoader loader, String className, byte[] classfileBuffer) { + CtClass ctClass = null; + + try { + if (className == null && classfileBuffer == null) { + SmithLogger.logger.info("nononono className == null && classfileBuffer == null"); + return ; + } + + String className_std = ""; + + ClassPool pool = ClassPool.getDefault(); + + if (className != null) { + className_std = className; + if (className.startsWith("com/security/smith") || className.startsWith("rasp/")) { + return ; + } + + ctClass = pool.makeClass(className); + } + else { + className_std = ""; + + InputStream inputS = byteArrayToInputStream(classfileBuffer); + + ctClass = pool.makeClass(inputS); + } + + if (ctClass == null) { + return ; + } else { + className_std = ctClass.getName(); + } + + if (className_std != "") { + className_std = className_std.replace("/", "."); + } + + ClassFilter classFilter = new ClassFilter(); + if (loader != null) { + classFilter.setClassLoaderName(loader.getClass().getName()); + } + classFilter.setClassName(className_std); + + + try { + if (!ctClass.isInterface()) { + classFilter.setInterfacesName(SmithHandler.getCtClassInterfaces(ctClass)); + } + classFilter.setClassPath(SmithHandler.getCtClassPath(ctClass)); + CtClass superClass = null; + try { + superClass = ctClass.getSuperclass(); + } catch(Exception e) { + // SmithLogger.exception(e); + } + + String superClassName = superClass != null ? superClass.getName() : ""; + classFilter.setParentClassName(superClassName); + + if (superClass != null) { + ClassLoader parentClassLoader = superClass.getClassPool().getClassLoader(); + String parentClassLoaderName = parentClassLoader != null ? parentClassLoader.getClass().getName() : ""; + classFilter.setParentClassLoaderName(parentClassLoaderName); + } + } catch (Exception e) { + SmithLogger.exception(e); + } + + long rule_id = rulemgr.matchRule(classFilter); + if(rule_id != -1) { + + classFilter.setRuleId(rule_id); + classFilter.setTransId(); + classFilter.setStackTrace(Thread.currentThread().getStackTrace()); + + Gson gson = new GsonBuilder() + .registerTypeAdapter(ClassFilter.class, new ClassFilterSerializer()) + .registerTypeAdapter(ClassFilter.class, new ClassFilterDeserializer()) + .create(); + JsonElement jsonElement = gson.toJsonTree(classFilter); + + client.write(Operate.SCANCLASS, jsonElement); + SmithLogger.logger.info("send metadata: " + classFilter.toString()); + Thread.sleep(1000); + sendByte(classfileBuffer, classFilter.getTransId()); + } + } catch(Exception e) { + SmithLogger.exception(e); + } + finally { + if(ctClass != null) { + ctClass.detach(); + } + } + } + + boolean hasExceptionHook(Map methodMap) { + Iterator> iterator = methodMap.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry entry = iterator.next(); + SmithMethod value = entry.getValue(); + String exceptionHookName = value.getExceptionHook(); + if(exceptionHookName != null && exceptionHookName.length() > 1 && exceptionHookName != "") { + return true; + } + } + + return false; + } + + + @Override + public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, ProtectionDomain protectionDomain, byte[] classfileBuffer) { + if (disable) + return null; + + // if(scanswitch) { + // checkClassFilter(loader, className,classfileBuffer); + // } + + Type classType = null; + SmithClass smithClass = null; + try { + classType = Type.getObjectType(className); + smithClass = smithClasses.get(classType.getClassName()); + } catch (Exception e) { + //SmithLogger.exception(e); + } + + if (smithClass == null) { + + ClassReader cr = new ClassReader(classfileBuffer); + + if (className == null) { + className = cr.getClassName(); + classType = Type.getObjectType(className); + } + String[] interfaces = cr.getInterfaces(); + String superClass = cr.getSuperName(); + + try { + String[] combined; + if (superClass != null) { + combined = new String[interfaces.length + 1]; + System.arraycopy(interfaces, 0, combined, 0, interfaces.length); + combined[interfaces.length] = superClass; + } else { + combined = interfaces; + } + + for (String interName : combined) { + if (SmithHandler.checkInterfaceNeedTran(interName)) { + Type interfaceType = Type.getObjectType(interName); + smithClass = smithClasses.get(interfaceType.getClassName()); + break; + } + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + + if (smithClass == null) { + return null; + } + } + + try { + Map methodMap = new HashMap<>(); + List methods = smithClass.getMethods(); + + for (SmithMethod method : methods) { + String key = method.getName() + method.getDesc(); + methodMap.put(key, method); + } + + SmithLogger.logger.info("transform: " + classType.getClassName()); + ClassReader classReader = new ClassReader(classfileBuffer); + + ClassWriter classWriter; + if(!hasExceptionHook(methodMap)) { + classWriter = new SmithClassWriter(ClassWriter.COMPUTE_MAXS); + } + else { + classWriter = new SmithClassWriter(ClassWriter.COMPUTE_FRAMES); + } + ClassVisitor classVisitor = new SmithClassVisitor( + Opcodes.ASM9, + classWriter, + smithClass.getId(), + classType, + methodMap + ); + + classReader.accept(classVisitor, ClassReader.EXPAND_FRAMES); + + return classWriter.toByteArray(); + } catch (Throwable e) { + SmithLogger.exception(e); + } + + return null; + } + + @Override + public void onConfig(String config) { + SmithLogger.logger.info("on config: " + config); + + Set classes = new HashSet<>(smithClasses.keySet()); + + smithClasses.clear(); + + try { + YamlReader yamlReader = new YamlReader(new StringReader(config)); + for (SmithClass smithClass : yamlReader.read(SmithClass[].class)) { + smithClasses.put(smithClass.getName(), smithClass); + } + } catch (IOException e) { + SmithLogger.exception(e); + } + + classes.addAll(smithClasses.keySet()); + reloadClasses(classes); + } + + @Override + public void onControl(int action) { + SmithLogger.logger.info("on control: " + action); + disable = action == STOP; + reloadClasses(); + } + + public static JsonElement convertJarsToJsonElement(Set jars) { + Gson gson = new Gson(); + + JsonArray jarsArray = new JsonArray(); + for (Jar jar : jars) { + JsonObject jarObj = new JsonObject(); + jarObj.addProperty("path", jar.getPath()); + jarObj.addProperty("implementationTitle", jar.getImplementationTitle()); + jarObj.addProperty("implementationVersion", jar.getImplementationVersion()); + jarObj.addProperty("specificationTitle", jar.getSpecificationTitle()); + jarObj.addProperty("specificationVersion", jar.getSpecificationVersion()); + jarsArray.add(jarObj); + } + + JsonObject jsonObj = new JsonObject(); + jsonObj.add("jars", jarsArray); + + return jsonObj; + } + + @Override + public void onDetect() { + SmithLogger.logger.info("on detect"); + + Set jars = new HashSet<>(); + + for (Class cl : inst.getAllLoadedClasses()) { + CodeSource codeSource = cl.getProtectionDomain().getCodeSource(); + + if (codeSource == null) + continue; + + Jar jar = new Jar(); + + URL url = codeSource.getLocation(); + + if (url == null) + continue; + + jar.setPath(url.toString()); + + if (jars.contains(jar)) + continue; + + Package pkg = cl.getPackage(); + + if (pkg == null) + continue; + + jar.setSpecificationTitle(pkg.getSpecificationTitle()); + jar.setSpecificationVersion(pkg.getSpecificationVersion()); + jar.setImplementationTitle(pkg.getImplementationTitle()); + jar.setImplementationVersion(pkg.getImplementationVersion()); + + jars.add(jar); + } + + JsonElement jsonElement = convertJarsToJsonElement(jars); + + client.write(Operate.DETECT, jsonElement); + } + + @Override + public void onFilter(FilterConfig config) { + filters.clear(); + + for (Filter filter : config.getFilters()) { + filters.put( + new ImmutablePair<>(filter.getClassID(), filter.getMethodID()), + filter + ); + } + + heartbeat.setFilter(config.getUUID()); + } + + @Override + public void onBlock(BlockConfig config) { + blocks.clear(); + + for (Block block : config.getBlocks()) { + blocks.put( + new ImmutablePair<>(block.getClassID(), block.getMethodID()), + block + ); + } + + heartbeat.setBlock(config.getUUID()); + } + + @Override + public void onLimit(LimitConfig config) { + limits.clear(); + + for (Limit limit : config.getLimits()) { + limits.put( + new ImmutablePair<>(limit.getClassID(), limit.getMethodID()), + limit.getQuota() + ); + } + + heartbeat.setLimit(config.getUUID()); + } + + @Override + public void onPatch(PatchConfig config) { + if (config == null || config.getPatches() == null || config.getPatches().length == 0) { + SmithLogger.logger.info("patch may not be download, so not update heartbeat"); + return ; + } + for (Patch patch : config.getPatches()) { + SmithLogger.logger.info("install patch: " + patch.getClassName()); + + if (patchers.containsKey(patch.getClassName())) { + SmithLogger.logger.info("ignore installed patch: " + patch.getClassName()); + continue; + } + + try (URLClassLoader loader = new URLClassLoader(new URL[]{patch.getUrl()})) { + Patcher patcher = loader.loadClass(patch.getClassName()) + .asSubclass(Patcher.class) + .getConstructor(Instrumentation.class) + .newInstance(inst); + + patcher.install(); + + patchers.put(patch.getClassName(), patcher); + } catch (IOException | ClassNotFoundException | NoSuchMethodException | InvocationTargetException | + InstantiationException | IllegalAccessException e) { + SmithLogger.exception(e); + } + } + + Set active = Arrays.stream(config.getPatches()).map(Patch::getClassName).collect(Collectors.toSet()); + + for (String name : patchers.keySet()) { + if (!active.contains(name)) { + SmithLogger.logger.info("uninstall patch: " + name); + Patcher patcher = patchers.remove(name); + if (patcher == null) + continue; + patcher.uninstall(); + } + } + + heartbeat.setPatch(config.getUUID()); + } + + + @Override + public boolean setRuleVersion(Rule_Version ruleVersion) { + boolean bresult = false; + + try { + bresult = ruleconfig.setVersion(ruleVersion.getRule_version()); + heartbeat.setClassFilterVersion(ruleVersion.getClass_filter_version()); + } + catch(Exception e) { + SmithLogger.exception(e); + } + + return bresult; + } + + @Override + public boolean OnAddRule(Rule_Data ruleData) { + boolean bresult = false; + + try { + bresult = ruleconfig.addRuleData(ruleData); + } + catch(Exception e) { + SmithLogger.exception(e); + } + + return bresult; + } + + @Override + public boolean OnAddRule(String rulejson) { + boolean bresult = false; + + try { + bresult = ruleconfig.setRuleConfig(rulejson); + } + catch(Exception e) { + SmithLogger.exception(e); + } + + return bresult; + } + + + /* scan all class */ + @Override + public void onScanAllClass() { + if (scanswitch == false) { + return; + } + scanswitch = false; + + try { + Class[] loadedClasses = inst.getAllLoadedClasses(); + + for (Class clazz : loadedClasses) { + try { + + String className = clazz.getName(); + if (className.startsWith("rasp.") || className.startsWith("com.security.smith") || className.startsWith("java.lang.invoke.LambdaForm")) { + continue; + } + + if(classIsSended(clazz)) { + continue; + } + + ClassFilter classFilter = new ClassFilter(); + + SmithHandler.queryClassFilter(clazz, classFilter); + long rule_id = -1; + if (!SmithHandler.checkClassMemshell(clazz)) { + rule_id = rulemgr.matchRule(classFilter); + if (rule_id == -1) + continue; + } + classFilter.setTransId(); + classFilter.setRuleId(rule_id); + classFilter.setStackTrace(Thread.currentThread().getStackTrace()); + + Gson gson = new GsonBuilder() + .registerTypeAdapter(ClassFilter.class, new ClassFilterSerializer()) + .registerTypeAdapter(ClassFilter.class, new ClassFilterDeserializer()) + .create(); + JsonElement jsonElement = gson.toJsonTree(classFilter); + + client.write(Operate.SCANCLASS, jsonElement); + SmithLogger.logger.info("send metadata: " + classFilter.toString()); + sendClass(clazz, classFilter.getTransId()); + + } catch(Exception e) { + SmithLogger.exception(e); + } + } + } catch(Exception e) { + SmithLogger.exception(e); + } + finally { + scanswitch = true; + SmithLogger.logger.info("scan all class finished"); + } + } + + /* + * send class file + */ + public void sendClass(Class clazz, String transId) { + if (clazz == null || transId == null) { + return; + } + try { + ClassUploadTransformer.getInstance().sendClass(clazz, transId); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public boolean classIsSended(Class clazz) { + try { + return ClassUploadTransformer.getInstance().classIsSended(clazz.hashCode()); + } catch (Exception e) { + SmithLogger.exception(e); + } + + return false; + } + + /* + * send CtClass file + */ + private void sendByte(byte[] data, String transId) { + if (data == null) { + return; + } + int length = data.length; + ClassUpload classUpload = new ClassUpload(); + classUpload.setTransId(transId); + + // client.write(Operate.CLASSDUMP, classUpload); + + // int packetSize = 1024; \ + // int totalPackets = (data.length + packetSize - 1) / packetSize; + //for (int i = 0; i < totalPackets; i++) { + //int offset = i * packetSize; + classUpload.setByteTotalLength(length); + //classUpload.setByteOffset(offset); + classUpload.setByteLength(length); + //int send_length = Math.min(packetSize, data.length - offset); + classUpload.setClassData(data); + + Gson gson = new Gson(); + JsonElement jsonElement = gson.toJsonTree(classUpload); + + client.write(Operate.CLASSUPLOAD, jsonElement); + SmithLogger.logger.info("send classdata: " + classUpload.toString()); + //} + } + + public Heartbeat getHeartbeat() { + return heartbeat; + } + + public void addDisacrdCount() { + int discrad_count = this.heartbeat.getDiscardCount(); + discrad_count++; + this.heartbeat.setDiscardCount(discrad_count); + } + + public Map, Integer> getLimits() { + return limits; + } + + public Map, Block> GetBlocks() { + return blocks; + } + + public Map, Filter> GetFiltes() { + return filters; + } + + public Client getClient() { + return client; + } + + public Disruptor getDisruptor() { + return disruptor; + } + +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbeProxy.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbeProxy.java index 2655831b9..e6c489d4d 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbeProxy.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/SmithProbeProxy.java @@ -1,932 +1,946 @@ -package com.security.smith; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; -import java.util.Objects; -import java.util.concurrent.atomic.AtomicIntegerArray; -import java.util.regex.Pattern; -import java.util.regex.Matcher; -import java.util.stream.Stream; - -import org.apache.commons.lang3.tuple.ImmutablePair; -import org.apache.commons.lang3.tuple.Pair; - -import com.lmax.disruptor.InsufficientCapacityException; -import com.lmax.disruptor.RingBuffer; -import com.lmax.disruptor.dsl.Disruptor; -import com.security.smith.client.Client; -import com.security.smith.client.Operate; -import com.security.smith.client.message.ClassFilter; -import com.security.smith.client.message.Heartbeat; -import com.security.smith.client.message.MatchRule; -import com.security.smith.client.message.Trace; -import com.security.smith.client.message.Block; -import com.security.smith.common.Reflection; -import com.security.smith.common.SmithHandler; -import com.security.smith.log.SmithLogger; - -public class SmithProbeProxy { - private final int CLASS_MAX_ID = 30; - private final int METHOD_MAX_ID = 20; - private final int DEFAULT_QUOTA = 12000; - - private SmithProbe SmithProbeObj = null; - private AtomicIntegerArray[] quotas; - private Disruptor disruptor; - private Client client; - private boolean stopX; - private Map reflectField = new HashMap<>(); - private Map reflectMethod = new HashMap<>(); - - public InheritableThreadLocal localfilterConfig = new InheritableThreadLocal() { - @Override - protected Object initialValue() { - return null; - } - }; - - public InheritableThreadLocal localfilterDef = new InheritableThreadLocal() { - @Override - protected Object initialValue() { - return null; - } - }; - - public InheritableThreadLocal needFoundfilterDef = new InheritableThreadLocal() { - @Override - protected Object initialValue() { - return null; - } - }; - - public InheritableThreadLocal jettyDeploying = new InheritableThreadLocal() { - @Override - protected Boolean initialValue() { - return false; - } - }; - - private boolean removeThreadLocalFormThread(Object threadObj,Object threadLocalObj) { - boolean bret = false; - boolean usegetMap = false; - - if(threadObj == null || - threadLocalObj == null) { - return false; - } - - try { - String className = threadLocalObj.getClass().getSuperclass().getName(); - if(className.contains("java.lang.InheritableThreadLocal")) { - Class[] argType_remove = new Class[]{Thread.class}; - bret = Reflection.invokeSuperSuperMethodNoReturn(threadLocalObj,"remove",argType_remove,threadObj); - } - else if(className.contains("java.lang.ThreadLocal")) { - Class[] argType_remove = new Class[]{Thread.class}; - bret = Reflection.invokeSuperMethodNoReturn(threadLocalObj,"remove",argType_remove,threadObj); - } - } - catch(Throwable t) { - } - - if(!bret) { - try { - Class[] argType_getMap = new Class[]{Thread.class}; - Object threadlocalMap = Reflection.invokeSuperMethod(threadLocalObj,"getMap",argType_getMap,threadObj); - if(threadlocalMap != null) { - Class[] argType_remove = new Class[]{ThreadLocal.class}; - bret = Reflection.invokeMethodNoReturn(threadlocalMap,"remove",argType_remove,threadLocalObj); - - } - } - catch(Throwable t) { - SmithLogger.exception(t); - } - } - - return bret; - } - - private void RemoveThreadLocalVar() { - int activeCount = Thread.activeCount(); - Thread[] threads = new Thread[activeCount+100]; - int count = Thread.enumerate(threads); - for (int i = 0; i < count; i++) { - removeThreadLocalFormThread(threads[i], localfilterConfig); - removeThreadLocalFormThread(threads[i], localfilterDef); - removeThreadLocalFormThread(threads[i], needFoundfilterDef); - removeThreadLocalFormThread(threads[i], jettyDeploying); - } - } - - public SmithProbeProxy() { - stopX = false; - //quotas = Stream.generate(() -> new AtomicIntegerArray(METHOD_MAX_ID)).limit(CLASS_MAX_ID).toArray(AtomicIntegerArray[]::new); - - quotas = new AtomicIntegerArray[CLASS_MAX_ID]; - for (int i = 0; i < CLASS_MAX_ID; i++) { - quotas[i] = new AtomicIntegerArray(METHOD_MAX_ID); - } - } - - public void uninit() { - this.client = null; - this.disruptor = null; - for(int i = 0;i < this.quotas.length;i++) { - this.quotas[i] = null; - } - this.quotas = null; - this.SmithProbeObj = null; - RemoveThreadLocalVar(); - - localfilterConfig = null; - localfilterDef = null; - needFoundfilterDef = null; - jettyDeploying = null; - } - - public void setProbe(SmithProbe SmithProbeObj) { - this.SmithProbeObj = SmithProbeObj; - } - - public void setClient(Client client) { - this.client = client; - } - - public void setDisruptor(Disruptor disruptor) { - this.disruptor = disruptor; - } - - public void setReflectField() { - String[] values1 = {"theUnsafe", "unsafe", "fieldFilterMap", "methodFilterMap"}; - String[] values2 = {"launchMechanism"}; - String[] values3 = {"handlerMap", "adaptedInterceptors"}; - String[] values4 = {"context"}; - String[] values5 = {"delegate"}; - String[] values6 = {"handlerAdapters", "handlerMappings"}; - String[] values7 = {"chain"}; - String[] values8 = {"httpUpgradeProtocols"}; - String[] values9 = {"executor"}; - String[] values10 = {"connector"}; - - reflectField.put("*", values1); - reflectField.put("java.lang.UNIXProcess", values2); - reflectField.put("java.lang.ProcessImpl", values2); - reflectField.put("org.springframework.web.servlet.handler.AbstractUrlHandlerMapping", values3); - reflectField.put("org.apache.catalina.core.ApplicationContext", values4); - reflectField.put("org.springframework.context.ApplicationListener", values5); - reflectField.put("org.springframework.web.servlet.DispatcherServlet", values6); - reflectField.put("org.springframework.web.server.handler.FilteringWebHandler", values7); - reflectField.put("org.apache.coyote.http11.AbstractHttp11Protocol", values8); - reflectField.put("org.apache.tomcat.util.net.AbstractEndpoint", values9); - reflectField.put("org.apache.catalina.connector.CoyoteAdapter", values10); - } - - public void setReflectMethod() { - - String[] values1 = {"*"}; - String[] values2 = {"load"}; - String[] values3 = {"forkAndExec"}; - String[] values4 = {"create"}; - String[] values5 = {"defineClass"}; - reflectMethod.put("java.lang.Unsafe", values1); - reflectMethod.put("java.lang.ClassLoader$NativeLibrary", values2); - reflectMethod.put("java.lang.UNIXProcess", values3); - reflectMethod.put("java.lang.ProcessImpl", values4); - reflectMethod.put("java.lang.ClassLoader", values5); - } - - public Map getReflectMethod() { - return this.reflectMethod; - } - - public Map getReflectField() { - return this.reflectField; - } - - public boolean checkReflectFeildEvil(String classname, String fieldname) { - if (classname == null || fieldname == null) { - return false; - } - Map refieldMap = getReflectField(); - if (refieldMap == null) { - return false; - } - if (refieldMap.containsKey(classname)) { - String[] values = refieldMap.get(classname); - for (String value : values) { - if (value.equals(fieldname) || value.equals("*")) { - return true; - } - } - } else { - String[] values = refieldMap.get("*"); - if (values == null) { - return false; - } - for (String value : values) { - if (value.equals(fieldname) || value.equals("*")) { - return true; - } - } - } - return false; - } - - - public boolean checkReflectMethodEvil(String classname, String methodname) { - if (classname == null || methodname == null) { - return false; - } - Map refieldMap = getReflectMethod(); - if (refieldMap == null) { - return false; - } - if (refieldMap.containsKey(classname)) { - String[] values = refieldMap.get(classname); - for (String value : values) { - if (value.equals(methodname) || value.equals("*")) { - return true; - } - } - } else { - String[] values = refieldMap.get("*"); - if (values == null) { - return false; - } - for (String value : values) { - if (value.equals(methodname) || value.equals("*")) { - return true; - } - } - } - return false; - } - - public void detect(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - - Map, Block> blocks = SmithProbeObj.GetBlocks(); - if (blocks == null) - return; - Block block = blocks.get(new ImmutablePair<>(classID, methodID)); - - if (block == null) - return; - - MatchRule[] rules = block.getRules(); - boolean isBlocked = false; - - for (MatchRule rule : rules) { - if (rule != null) { - if (rule.getIndex() >= args.length || args[rule.getIndex()] == null || rule.getRegex() == null) { - continue; - } - - Pattern pattern = Pattern.compile(rule.getRegex()); - Matcher matcher = pattern.matcher(args[rule.getIndex()].toString()); - - if (matcher.find()) { - isBlocked = true; - break; - } - } - } - - if (isBlocked) { - throw new SecurityException("API blocked by RASP"); - } - } - - public void trace(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if (classID >= CLASS_MAX_ID || methodID >= METHOD_MAX_ID || stopX) - return; - - while (true) { - int quota = quotas[classID].get(methodID); - - if (quota <= 0) { - SmithProbeObj.addDisacrdCount(); - return; - } - - if (quotas[classID].compareAndSet(methodID, quota, quota - 1)) - break; - } - if (disruptor == null) { - SmithProbeObj.addDisacrdCount(); - return; - } - RingBuffer ringBuffer = disruptor.getRingBuffer(); - - try { - long sequence = ringBuffer.tryNext(); - - Trace trace = ringBuffer.get(sequence); - - trace.setClassID(classID); - trace.setMethodID(methodID); - trace.setBlocked(blocked); - trace.setRet(ret); - trace.setArgs(args); - trace.setStackTrace(Thread.currentThread().getStackTrace()); - - ringBuffer.publish(sequence); - } catch (InsufficientCapacityException ignored) { - SmithProbeObj.addDisacrdCount(); - } - } - - public void sendMetadataObject(Object obj) { - if(stopX) { - return; - } - - if (obj != null) { - sendMetadataClass(obj.getClass()); - } - } - - public void sendMetadataClass(Class cla) { - if (cla == null || stopX) { - return; - } - - if(SmithProbeObj.classIsSended(cla)) { - return ; - } - - ClassFilter classFilter = new ClassFilter(); - SmithHandler.queryClassFilter(cla, classFilter); - classFilter.setTransId(); - classFilter.setRuleId(-1); - classFilter.setStackTrace(Thread.currentThread().getStackTrace()); - if (client != null) { - client.write(Operate.SCANCLASS, classFilter); - SmithLogger.logger.info("send metadata: " + classFilter.toString()); - SmithProbeObj.sendClass(cla, classFilter.getTransId()); - } - } - - public void checkAddServletPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkAddServlet pre_hook call success"); - - if (args.length < 3) { - return; - } - try { - Object context = args[0]; - String name = (String)args[2]; - if (context != null) { - Class[] argTypes = new Class[]{String.class}; - Object wrapper = Reflection.invokeMethod(context, "findChild", argTypes, name); - if(wrapper != null) { - Class[] emptyArgTypes = new Class[]{}; - - Object servlet = Reflection.invokeMethod(wrapper, "getServlet", emptyArgTypes); - sendMetadataObject(servlet); - } - } - - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - private Object getFilterFromConfig(Object filterConfig) { - if (filterConfig == null) { - return null; - } - Object filter = null; - try { - filter = Reflection.getField(filterConfig, "filter"); - } catch (Exception e) { - SmithLogger.exception(e); - } - return filter; - } - - private Class getFilterFromLoader(Object context, String filterName) { - Class filter = null; - if (context == null || filterName == null) - return filter; - ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); - if (classLoader == null) - classLoader = context.getClass().getClassLoader(); - try { - filter = classLoader.loadClass(filterName); - } catch (Exception e) { - } - return filter; - } - - public void checkAddFilterPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkAddFilter pre_hook call success"); - if (args.length < 2) { - return; - } - try { - Object filterdef = args[1]; - Object filter = null; - Class filterClass = null; - if (filterdef != null) { - Class[] emptyArgTypes = new Class[]{}; - filter = Reflection.invokeMethod(filterdef, "getFilter", emptyArgTypes); - String filterName = ""; - if (filter == null) { - // Godzilla filter check - if (localfilterDef != null && localfilterConfig != null && filterdef == localfilterDef.get()) { - filter = getFilterFromConfig(localfilterConfig.get()); - } else { - filterName = (String)Reflection.invokeMethod(filterdef, "getFilterClass", emptyArgTypes); - filterClass = getFilterFromLoader(args[0], filterName); - } - } - if (filter != null || filterClass != null) { - Class clazz = null; - if (filterClass != null) { - clazz = filterClass; - } else { - clazz = filter.getClass(); - } - - sendMetadataObject(clazz); - } else { - needFoundfilterDef.set(filterdef); - } - } - } catch (Exception e) { - SmithLogger.exception(e); - } - } - public void checkFilterConfigPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkAddFilter post_hook call success"); - if (ret == null || args.length < 2) { - return; - } - try { - localfilterConfig.set(ret); - localfilterDef.set(args[1]); - - // shiro filter check - if (needFoundfilterDef != null && needFoundfilterDef.get() == args[1]) { - Object filter = getFilterFromConfig(ret); - sendMetadataObject(filter); - } - } catch(Exception e) { - SmithLogger.exception(e); - } - } - - public void checkAddValvePre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - if (args.length < 2) { - return; - } - try { - Object valve = args[1]; - sendMetadataObject(valve); - - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public void checkAddListenerPre(int classID, int methodID, Object[] args) { - checkAddValvePre(classID, methodID, args); - } - - public void checkWebSocketPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("check WebSocketPre"); - if (args.length < 2) { - return; - } - try { - Object ws = args[1]; - Class[] emptyArgTypes = new Class[]{}; - Class endpointCla = (Class)Reflection.invokeMethod(ws, "getEndpointClass", emptyArgTypes); - sendMetadataClass(endpointCla); - - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public void onTimer() { - Heartbeat heartbeat = SmithProbeObj.getHeartbeat(); - if (client != null) - client.write(Operate.HEARTBEAT, heartbeat); - - Map, Integer> limits = SmithProbeObj.getLimits(); - - for (int i = 0; i < CLASS_MAX_ID; i++) { - for (int j = 0; j < METHOD_MAX_ID; j++) { - Integer quota = limits.get(new ImmutablePair<>(i, j)); - - if (quota == null) { - quotas[i].set(j, DEFAULT_QUOTA); - continue; - } - quotas[i].set(j, quota); - } - } - } - - public void checkResinAddServletPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - if (args.length < 2) { - return; - } - try { - Object servletMapping = args[1]; - if (servletMapping != null) { - Class[] emptyArgTypes = new Class[]{}; - Class servletClass = (Class)Reflection.invokeMethod(servletMapping, "getServletClass", emptyArgTypes); - sendMetadataClass(servletClass); - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - } - - /* - * check resin servlet - */ - public void checkResinAddServletPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - if (args.length < 2) { - return; - } - try { - Object servletMapping = args[1]; - if (servletMapping != null) { - Class[] emptyArgTypes = new Class[]{}; - Class servletClass = (Class)Reflection.invokeMethod(servletMapping, "getServletClass", emptyArgTypes); - sendMetadataClass(servletClass); - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - } - - /* - * check resin add filter memshell - */ - public void checkResinAddFilterPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkResinAddFilter pre_hook call success"); - if (args.length < 2) { - return; - } - try { - Object filterdef = args[1]; - if (filterdef != null) { - Class[] emptyArgTypes = new Class[]{}; - Class filterCla = (Class)Reflection.invokeMethod(filterdef, "getFilterClass", emptyArgTypes); - sendMetadataClass(filterCla); - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - - } - - public void checkResinWebSocketPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkResinWebSocket pre_hook call success"); - if (args.length < 3) { - return; - } - try { - Object weblistener = args[2]; - if (weblistener != null) { - sendMetadataObject(weblistener); - } - } catch (Exception e) { - SmithLogger.exception(e); - } - } - /* - * check jetty version 9 add filter/servlet memshell - * TODO: add url check - */ - public void checkJettyMemshellPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkJettyMemshellPre pre_hook call success"); - if (jettyDeploying != null && jettyDeploying.get() == true) { - return; - } - if (args.length < 2) { - return; - } - try { - Class newclass = (Class)args[1]; - sendMetadataClass(newclass); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - /* - * check Jetty 9.4 Listener memshell - */ - public void checkJettyListenerPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkJettyListenerPre pre_hook call success"); - if (args.length < 2) { - return; - } - try { - Object listener = args[1]; - sendMetadataObject(listener); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - /* - * used for listener check - */ - public void cehckJettyDeployPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - if (jettyDeploying != null) { - jettyDeploying.set(true); - } - } - - /* user for check ServerEndpointConfig init */ - public void checkWebSocketConfigPre(int classID, int metodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkWebSocketConfigPre called"); - try { - if (args.length < 2) { - return; - } - Class websocket = (Class)args[0]; - sendMetadataClass(websocket); - - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - /* - * used for listener check - */ - public void checkJettyDeployPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - if (jettyDeploying != null) { - jettyDeploying.set(false); - } - } - - /* - * check spring controller memshell - */ - public void checkSpringControllerPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - if (args.length < 3) { - return; - } - try { - Object controller = args[2]; - sendMetadataObject(controller); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - /* - * check spring Interceptor memshell - */ - public void checkSpringInterceptorPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - if (args.length < 1) { - return; - } - try { - Object interceptor = args[0]; - sendMetadataObject(interceptor); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public void checkMemshellInitPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - //SmithLogger.logger.info("checkMemshellInitPost call success"); - if (ret != null) { - try { - sendMetadataObject(ret); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - } - - private boolean checkIsRaspClass(String classname) { - - if (((classname.startsWith("com.security.smith.") || - classname.startsWith("com.security.smithloader.") || - classname.startsWith("rasp.io")) || - classname.startsWith("rasp.org") || - classname.startsWith("rasp.com") || - classname.startsWith("rasp.javassist"))) { - return true; - } - - return false; - } - - /* - * used for wildfly ModuleClassLoader findClass hook - */ - - public Object processWildflyClassLoaderException(int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { - if(exceptionObject instanceof ClassNotFoundException) { - String classname = (String) args[1]; - - if(checkIsRaspClass(classname)) { - return (Object)Class.forName(classname); - } - } - - return null; - } - - /* - - public ServletHandler addServlet(ServletInfo servletInfo) - - - */ - - public void checkWildflyaddServletPre(int classID, int methodID, Object[] args) { - if(stopX) { - return; - } - SmithLogger.logger.info("checkWildflyaddServlet pre_hook call success"); - if(args.length < 2) { - return ; - } - - try { - Object servletInfo = args[1]; - if(servletInfo != null) { - Class servletClass = (Class)Reflection.getField(servletInfo,"servletClass"); - String servletName = (String)Reflection.getField(servletInfo,"name"); - - if(servletName != null) { - if (servletClass != null) { - sendMetadataObject(servletClass); - } else { - SmithLogger.logger.warning("can't find "+servletName); - } - } - } - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - /* - - public ManagedFilter addFilter(FilterInfo filterInfo) - - - */ - - public void checkWildflyaddFilterPre(int classID, int methodID, Object[] args) { - if (stopX) { - return; - } - SmithLogger.logger.info("checkWildflyaddFilter pre_hook call success"); - if(args.length < 2) { - return ; - } - - try { - Object filterInfo = args[1]; - if(filterInfo != null) { - Class filterClass = (Class)Reflection.getField(filterInfo,"filterClass"); - String filterName = (String)Reflection.getField(filterInfo,"name"); - - if(filterName != null) { - if (filterClass != null) { - sendMetadataObject(filterClass); - } else { - SmithLogger.logger.warning("can't find "+filterName); - } - } - } - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public void handleReflectField(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - if (args.length < 2) { - return ; - } - try { - Class clas = (Class)args[0]; - String reflectClass = clas.getName(); - String feild = (String)args[1]; - if (reflectClass.startsWith("com.security.smith") || reflectClass.startsWith("rasp.")) { - return ; - } else { - if (checkReflectFeildEvil(reflectClass, feild)) { - trace(classID, methodID, args, ret, blocked); - } - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - } - - public void handleReflectMethod(int classID, int methodID, Object[] args, Object ret, boolean blocked) { - if(stopX) { - return; - } - if (args.length < 2) { - return ; - } - try { - Class clas = (Class)args[0]; - String reflectClass = clas.getName(); - String feild = (String)args[1]; - if (reflectClass.startsWith("com.security.smith") || reflectClass.startsWith("rasp.")) { - return ; - } else { - if (checkReflectMethodEvil(reflectClass, feild)) { - trace(classID, methodID, args, ret, blocked); - } - } - } catch (Throwable e) { - SmithLogger.exception(e); - } - } - - /* - * used for glassfish org.apache.felix.framework.BundleWiringImpl$BundleClassLoader findClass loadClass hook - */ - - public Object processGlassfishClassLoaderfindClassException(int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { - //SmithLogger.logger.info("processGlassfishClassLoaderfindClass Exception_hook call success"); - if(exceptionObject instanceof ClassNotFoundException) { - String classname = (String) args[1]; - //SmithLogger.logger.info("processGlassfishClassLoaderfindClass find class:"+classname); - if(checkIsRaspClass(classname)) { - return (Object)Class.forName(classname); - } - } - - return null; - } -} - +package com.security.smith; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicIntegerArray; +import java.util.regex.Pattern; +import java.util.regex.Matcher; +import java.util.stream.Stream; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; + +import com.lmax.disruptor.InsufficientCapacityException; +import com.lmax.disruptor.RingBuffer; +import com.lmax.disruptor.dsl.Disruptor; +import com.security.smith.client.Client; +import com.security.smith.client.Operate; +import com.security.smith.client.message.*; +import com.security.smith.client.message.ClassFilter; +import com.security.smith.client.message.Heartbeat; +import com.security.smith.client.message.MatchRule; +import com.security.smith.client.message.Trace; +import com.security.smith.client.message.Block; +import com.security.smith.common.Reflection; +import com.security.smith.common.SmithHandler; +import com.security.smith.log.SmithLogger; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.GsonBuilder; +public class SmithProbeProxy { + private final int CLASS_MAX_ID = 30; + private final int METHOD_MAX_ID = 20; + private final int DEFAULT_QUOTA = 12000; + + private SmithProbe SmithProbeObj = null; + private AtomicIntegerArray[] quotas; + private Disruptor disruptor; + private Client client; + private boolean stopX; + private Map reflectField = new HashMap<>(); + private Map reflectMethod = new HashMap<>(); + + public InheritableThreadLocal localfilterConfig = new InheritableThreadLocal() { + @Override + protected Object initialValue() { + return null; + } + }; + + public InheritableThreadLocal localfilterDef = new InheritableThreadLocal() { + @Override + protected Object initialValue() { + return null; + } + }; + + public InheritableThreadLocal needFoundfilterDef = new InheritableThreadLocal() { + @Override + protected Object initialValue() { + return null; + } + }; + + public InheritableThreadLocal jettyDeploying = new InheritableThreadLocal() { + @Override + protected Boolean initialValue() { + return false; + } + }; + + private boolean removeThreadLocalFormThread(Object threadObj,Object threadLocalObj) { + boolean bret = false; + boolean usegetMap = false; + + if(threadObj == null || + threadLocalObj == null) { + return false; + } + + try { + String className = threadLocalObj.getClass().getSuperclass().getName(); + if(className.contains("java.lang.InheritableThreadLocal")) { + Class[] argType_remove = new Class[]{Thread.class}; + bret = Reflection.invokeSuperSuperMethodNoReturn(threadLocalObj,"remove",argType_remove,threadObj); + } + else if(className.contains("java.lang.ThreadLocal")) { + Class[] argType_remove = new Class[]{Thread.class}; + bret = Reflection.invokeSuperMethodNoReturn(threadLocalObj,"remove",argType_remove,threadObj); + } + } + catch(Throwable t) { + } + + if(!bret) { + try { + Class[] argType_getMap = new Class[]{Thread.class}; + Object threadlocalMap = Reflection.invokeSuperMethod(threadLocalObj,"getMap",argType_getMap,threadObj); + if(threadlocalMap != null) { + Class[] argType_remove = new Class[]{ThreadLocal.class}; + bret = Reflection.invokeMethodNoReturn(threadlocalMap,"remove",argType_remove,threadLocalObj); + + } + } + catch(Throwable t) { + SmithLogger.exception(t); + } + } + + return bret; + } + + private void RemoveThreadLocalVar() { + int activeCount = Thread.activeCount(); + Thread[] threads = new Thread[activeCount+100]; + int count = Thread.enumerate(threads); + for (int i = 0; i < count; i++) { + removeThreadLocalFormThread(threads[i], localfilterConfig); + removeThreadLocalFormThread(threads[i], localfilterDef); + removeThreadLocalFormThread(threads[i], needFoundfilterDef); + removeThreadLocalFormThread(threads[i], jettyDeploying); + } + } + + public SmithProbeProxy() { + stopX = false; + + quotas = new AtomicIntegerArray[CLASS_MAX_ID]; + for (int i = 0; i < CLASS_MAX_ID; i++) { + quotas[i] = new AtomicIntegerArray(METHOD_MAX_ID); + } + } + + public void uninit() { + this.client = null; + this.disruptor = null; + for(int i = 0;i < this.quotas.length;i++) { + this.quotas[i] = null; + } + this.quotas = null; + this.SmithProbeObj = null; + RemoveThreadLocalVar(); + + reflectField.clear(); + reflectField = null; + + reflectMethod.clear(); + reflectMethod = null; + + localfilterConfig = null; + localfilterDef = null; + needFoundfilterDef = null; + jettyDeploying = null; + } + + public void setProbe(SmithProbe SmithProbeObj) { + this.SmithProbeObj = SmithProbeObj; + } + + public void setClient(Client client) { + this.client = client; + } + + public void setDisruptor(Disruptor disruptor) { + this.disruptor = disruptor; + } + + public void setReflectField() { + String[] values1 = {"theUnsafe", "unsafe", "fieldFilterMap", "methodFilterMap"}; + String[] values2 = {"launchMechanism"}; + String[] values3 = {"handlerMap", "adaptedInterceptors"}; + String[] values4 = {"context"}; + String[] values5 = {"delegate"}; + String[] values6 = {"handlerAdapters", "handlerMappings"}; + String[] values7 = {"chain"}; + String[] values8 = {"httpUpgradeProtocols"}; + String[] values9 = {"executor"}; + String[] values10 = {"connector"}; + + reflectField.put("*", values1); + reflectField.put("java.lang.UNIXProcess", values2); + reflectField.put("java.lang.ProcessImpl", values2); + reflectField.put("org.springframework.web.servlet.handler.AbstractUrlHandlerMapping", values3); + reflectField.put("org.apache.catalina.core.ApplicationContext", values4); + reflectField.put("org.springframework.context.ApplicationListener", values5); + reflectField.put("org.springframework.web.servlet.DispatcherServlet", values6); + reflectField.put("org.springframework.web.server.handler.FilteringWebHandler", values7); + reflectField.put("org.apache.coyote.http11.AbstractHttp11Protocol", values8); + reflectField.put("org.apache.tomcat.util.net.AbstractEndpoint", values9); + reflectField.put("org.apache.catalina.connector.CoyoteAdapter", values10); + } + + public void setReflectMethod() { + + String[] values1 = {"*"}; + String[] values2 = {"load"}; + String[] values3 = {"forkAndExec"}; + String[] values4 = {"create"}; + String[] values5 = {"defineClass"}; + reflectMethod.put("java.lang.Unsafe", values1); + reflectMethod.put("java.lang.ClassLoader$NativeLibrary", values2); + reflectMethod.put("java.lang.UNIXProcess", values3); + reflectMethod.put("java.lang.ProcessImpl", values4); + reflectMethod.put("java.lang.ClassLoader", values5); + } + + public Map getReflectMethod() { + return this.reflectMethod; + } + + public Map getReflectField() { + return this.reflectField; + } + + public boolean checkReflectFeildEvil(String classname, String fieldname) { + if (classname == null || fieldname == null) { + return false; + } + Map refieldMap = getReflectField(); + if (refieldMap == null) { + return false; + } + if (refieldMap.containsKey(classname)) { + String[] values = refieldMap.get(classname); + for (String value : values) { + if (value.equals(fieldname) || value.equals("*")) { + return true; + } + } + } else { + String[] values = refieldMap.get("*"); + if (values == null) { + return false; + } + for (String value : values) { + if (value.equals(fieldname) || value.equals("*")) { + return true; + } + } + } + return false; + } + + + public boolean checkReflectMethodEvil(String classname, String methodname) { + if (classname == null || methodname == null) { + return false; + } + Map refieldMap = getReflectMethod(); + if (refieldMap == null) { + return false; + } + if (refieldMap.containsKey(classname)) { + String[] values = refieldMap.get(classname); + for (String value : values) { + if (value.equals(methodname) || value.equals("*")) { + return true; + } + } + } else { + String[] values = refieldMap.get("*"); + if (values == null) { + return false; + } + for (String value : values) { + if (value.equals(methodname) || value.equals("*")) { + return true; + } + } + } + return false; + } + + public void detect(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + + Map, Block> blocks = SmithProbeObj.GetBlocks(); + if (blocks == null) + return; + Block block = blocks.get(new ImmutablePair<>(classID, methodID)); + + if (block == null) + return; + + MatchRule[] rules = block.getRules(); + boolean isBlocked = false; + + for (MatchRule rule : rules) { + if (rule != null) { + if (rule.getIndex() >= args.length || args[rule.getIndex()] == null || rule.getRegex() == null) { + continue; + } + + Pattern pattern = Pattern.compile(rule.getRegex()); + Matcher matcher = pattern.matcher(args[rule.getIndex()].toString()); + + if (matcher.find()) { + isBlocked = true; + break; + } + } + } + + if (isBlocked) { + throw new SecurityException("API blocked by RASP"); + } + } + + public void trace(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if (classID >= CLASS_MAX_ID || methodID >= METHOD_MAX_ID || stopX) + return; + + while (true) { + int quota = quotas[classID].get(methodID); + + if (quota <= 0) { + SmithProbeObj.addDisacrdCount(); + return; + } + + if (quotas[classID].compareAndSet(methodID, quota, quota - 1)) + break; + } + if (disruptor == null) { + SmithProbeObj.addDisacrdCount(); + return; + } + RingBuffer ringBuffer = disruptor.getRingBuffer(); + + try { + long sequence = ringBuffer.tryNext(); + + Trace trace = ringBuffer.get(sequence); + + trace.setClassID(classID); + trace.setMethodID(methodID); + trace.setBlocked(blocked); + trace.setRet(ret); + trace.setArgs(args); + trace.setStackTrace(Thread.currentThread().getStackTrace()); + + ringBuffer.publish(sequence); + } catch (InsufficientCapacityException ignored) { + SmithProbeObj.addDisacrdCount(); + } + } + + public void sendMetadataObject(Object obj) { + if(stopX) { + return; + } + + if (obj != null) { + sendMetadataClass(obj.getClass()); + } + } + + public void sendMetadataClass(Class cla) { + if (cla == null || stopX) { + return; + } + + if(SmithProbeObj.classIsSended(cla)) { + return ; + } + + ClassFilter classFilter = new ClassFilter(); + SmithHandler.queryClassFilter(cla, classFilter); + classFilter.setTransId(); + classFilter.setRuleId(-1); + classFilter.setStackTrace(Thread.currentThread().getStackTrace()); + if (client != null) { + Gson gson = new GsonBuilder() + .registerTypeAdapter(ClassFilter.class, new ClassFilterSerializer()) + .registerTypeAdapter(ClassFilter.class, new ClassFilterDeserializer()) + .create(); + JsonElement jsonElement = gson.toJsonTree(classFilter); + client.write(Operate.SCANCLASS, jsonElement); + SmithLogger.logger.info("send metadata: " + classFilter.toString()); + SmithProbeObj.sendClass(cla, classFilter.getTransId()); + } + } + + public void checkAddServletPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkAddServlet pre_hook call success"); + + if (args.length < 3) { + return; + } + try { + Object context = args[0]; + String name = (String)args[2]; + if (context != null) { + Class[] argTypes = new Class[]{String.class}; + Object wrapper = Reflection.invokeMethod(context, "findChild", argTypes, name); + if(wrapper != null) { + Class[] emptyArgTypes = new Class[]{}; + + Object servlet = Reflection.invokeMethod(wrapper, "getServlet", emptyArgTypes); + sendMetadataObject(servlet); + } + } + + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + private Object getFilterFromConfig(Object filterConfig) { + if (filterConfig == null) { + return null; + } + Object filter = null; + try { + filter = Reflection.getField(filterConfig, "filter"); + } catch (Exception e) { + SmithLogger.exception(e); + } + return filter; + } + + private Class getFilterFromLoader(Object context, String filterName) { + Class filter = null; + if (context == null || filterName == null) + return filter; + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + if (classLoader == null) + classLoader = context.getClass().getClassLoader(); + try { + filter = classLoader.loadClass(filterName); + } catch (Exception e) { + } + return filter; + } + + public void checkAddFilterPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkAddFilter pre_hook call success"); + if (args.length < 2) { + return; + } + try { + Object filterdef = args[1]; + Object filter = null; + Class filterClass = null; + if (filterdef != null) { + Class[] emptyArgTypes = new Class[]{}; + filter = Reflection.invokeMethod(filterdef, "getFilter", emptyArgTypes); + String filterName = ""; + if (filter == null) { + // Godzilla filter check + if (localfilterDef != null && localfilterConfig != null && filterdef == localfilterDef.get()) { + filter = getFilterFromConfig(localfilterConfig.get()); + } else { + filterName = (String)Reflection.invokeMethod(filterdef, "getFilterClass", emptyArgTypes); + filterClass = getFilterFromLoader(args[0], filterName); + } + } + if (filter != null || filterClass != null) { + Class clazz = null; + if (filterClass != null) { + clazz = filterClass; + } else { + clazz = filter.getClass(); + } + + sendMetadataObject(clazz); + } else { + needFoundfilterDef.set(filterdef); + } + } + } catch (Exception e) { + SmithLogger.exception(e); + } + } + public void checkFilterConfigPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkAddFilter post_hook call success"); + if (ret == null || args.length < 2) { + return; + } + try { + localfilterConfig.set(ret); + localfilterDef.set(args[1]); + + // shiro filter check + if (needFoundfilterDef != null && needFoundfilterDef.get() == args[1]) { + Object filter = getFilterFromConfig(ret); + sendMetadataObject(filter); + } + } catch(Exception e) { + SmithLogger.exception(e); + } + } + + public void checkAddValvePre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + if (args.length < 2) { + return; + } + try { + Object valve = args[1]; + sendMetadataObject(valve); + + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public void checkAddListenerPre(int classID, int methodID, Object[] args) { + checkAddValvePre(classID, methodID, args); + } + + public void checkWebSocketPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("check WebSocketPre"); + if (args.length < 2) { + return; + } + try { + Object ws = args[1]; + Class[] emptyArgTypes = new Class[]{}; + Class endpointCla = (Class)Reflection.invokeMethod(ws, "getEndpointClass", emptyArgTypes); + sendMetadataClass(endpointCla); + + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public void onTimer() { + Heartbeat heartbeat = SmithProbeObj.getHeartbeat(); + if (client != null) + client.write(Operate.HEARTBEAT, heartbeat.toJsonElement()); + + Map, Integer> limits = SmithProbeObj.getLimits(); + + for (int i = 0; i < CLASS_MAX_ID; i++) { + for (int j = 0; j < METHOD_MAX_ID; j++) { + Integer quota = limits.get(new ImmutablePair<>(i, j)); + + if (quota == null) { + quotas[i].set(j, DEFAULT_QUOTA); + continue; + } + quotas[i].set(j, quota); + } + } + } + + public void checkResinAddServletPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + if (args.length < 2) { + return; + } + try { + Object servletMapping = args[1]; + if (servletMapping != null) { + Class[] emptyArgTypes = new Class[]{}; + Class servletClass = (Class)Reflection.invokeMethod(servletMapping, "getServletClass", emptyArgTypes); + sendMetadataClass(servletClass); + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + } + + /* + * check resin servlet + */ + public void checkResinAddServletPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + if (args.length < 2) { + return; + } + try { + Object servletMapping = args[1]; + if (servletMapping != null) { + Class[] emptyArgTypes = new Class[]{}; + Class servletClass = (Class)Reflection.invokeMethod(servletMapping, "getServletClass", emptyArgTypes); + sendMetadataClass(servletClass); + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + } + + /* + * check resin add filter memshell + */ + public void checkResinAddFilterPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkResinAddFilter pre_hook call success"); + if (args.length < 2) { + return; + } + try { + Object filterdef = args[1]; + if (filterdef != null) { + Class[] emptyArgTypes = new Class[]{}; + Class filterCla = (Class)Reflection.invokeMethod(filterdef, "getFilterClass", emptyArgTypes); + sendMetadataClass(filterCla); + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + + } + + public void checkResinWebSocketPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkResinWebSocket pre_hook call success"); + if (args.length < 3) { + return; + } + try { + Object weblistener = args[2]; + if (weblistener != null) { + sendMetadataObject(weblistener); + } + } catch (Exception e) { + SmithLogger.exception(e); + } + } + /* + * check jetty version 9 add filter/servlet memshell + * TODO: add url check + */ + public void checkJettyMemshellPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkJettyMemshellPre pre_hook call success"); + if (jettyDeploying != null && jettyDeploying.get() == true) { + return; + } + if (args.length < 2) { + return; + } + try { + Class newclass = (Class)args[1]; + sendMetadataClass(newclass); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + /* + * check Jetty 9.4 Listener memshell + */ + public void checkJettyListenerPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkJettyListenerPre pre_hook call success"); + if (args.length < 2) { + return; + } + try { + Object listener = args[1]; + sendMetadataObject(listener); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + /* + * used for listener check + */ + public void cehckJettyDeployPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + if (jettyDeploying != null) { + jettyDeploying.set(true); + } + } + + /* user for check ServerEndpointConfig init */ + public void checkWebSocketConfigPre(int classID, int metodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkWebSocketConfigPre called"); + try { + if (args.length < 2) { + return; + } + Class websocket = (Class)args[0]; + sendMetadataClass(websocket); + + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + /* + * used for listener check + */ + public void checkJettyDeployPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + if (jettyDeploying != null) { + jettyDeploying.set(false); + } + } + + /* + * check spring controller memshell + */ + public void checkSpringControllerPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + if (args.length < 3) { + return; + } + try { + Object controller = args[2]; + sendMetadataObject(controller); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + /* + * check spring Interceptor memshell + */ + public void checkSpringInterceptorPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + if (args.length < 1) { + return; + } + try { + Object interceptor = args[0]; + sendMetadataObject(interceptor); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public void checkMemshellInitPost(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + //SmithLogger.logger.info("checkMemshellInitPost call success"); + if (ret != null) { + try { + sendMetadataObject(ret); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + } + + private boolean checkIsRaspClass(String classname) { + + if (((classname.startsWith("com.security.smith.") || + classname.startsWith("com.security.smithloader.") || + classname.startsWith("rasp.io")) || + classname.startsWith("rasp.org") || + classname.startsWith("rasp.com") || + classname.startsWith("rasp.javassist"))) { + return true; + } + + return false; + } + + /* + * used for wildfly ModuleClassLoader findClass hook + */ + + public Object processWildflyClassLoaderException(int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { + if(exceptionObject instanceof ClassNotFoundException) { + String classname = (String) args[1]; + + if(checkIsRaspClass(classname)) { + return (Object)Class.forName(classname); + } + } + + return null; + } + + /* + + public ServletHandler addServlet(ServletInfo servletInfo) + + + */ + + public void checkWildflyaddServletPre(int classID, int methodID, Object[] args) { + if(stopX) { + return; + } + SmithLogger.logger.info("checkWildflyaddServlet pre_hook call success"); + if(args.length < 2) { + return ; + } + + try { + Object servletInfo = args[1]; + if(servletInfo != null) { + Class servletClass = (Class)Reflection.getField(servletInfo,"servletClass"); + String servletName = (String)Reflection.getField(servletInfo,"name"); + + if(servletName != null) { + if (servletClass != null) { + sendMetadataObject(servletClass); + } else { + SmithLogger.logger.warning("can't find "+servletName); + } + } + } + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + /* + + public ManagedFilter addFilter(FilterInfo filterInfo) + + + */ + + public void checkWildflyaddFilterPre(int classID, int methodID, Object[] args) { + if (stopX) { + return; + } + SmithLogger.logger.info("checkWildflyaddFilter pre_hook call success"); + if(args.length < 2) { + return ; + } + + try { + Object filterInfo = args[1]; + if(filterInfo != null) { + Class filterClass = (Class)Reflection.getField(filterInfo,"filterClass"); + String filterName = (String)Reflection.getField(filterInfo,"name"); + + if(filterName != null) { + if (filterClass != null) { + sendMetadataObject(filterClass); + } else { + SmithLogger.logger.warning("can't find "+filterName); + } + } + } + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public void handleReflectField(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + if (args.length < 2) { + return ; + } + try { + Class clas = (Class)args[0]; + String reflectClass = clas.getName(); + String feild = (String)args[1]; + if (reflectClass.startsWith("com.security.smith") || reflectClass.startsWith("rasp.")) { + return ; + } else { + if (checkReflectFeildEvil(reflectClass, feild)) { + trace(classID, methodID, args, ret, blocked); + } + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + } + + public void handleReflectMethod(int classID, int methodID, Object[] args, Object ret, boolean blocked) { + if(stopX) { + return; + } + if (args.length < 2) { + return ; + } + try { + Class clas = (Class)args[0]; + String reflectClass = clas.getName(); + String feild = (String)args[1]; + if (reflectClass.startsWith("com.security.smith") || reflectClass.startsWith("rasp.")) { + return ; + } else { + if (checkReflectMethodEvil(reflectClass, feild)) { + trace(classID, methodID, args, ret, blocked); + } + } + } catch (Throwable e) { + SmithLogger.exception(e); + } + } + + /* + * used for glassfish org.apache.felix.framework.BundleWiringImpl$BundleClassLoader findClass loadClass hook + */ + + public Object processGlassfishClassLoaderfindClassException(int classID, int methodID, Object[] args,Object exceptionObject) throws Throwable { + //SmithLogger.logger.info("processGlassfishClassLoaderfindClass Exception_hook call success"); + if(exceptionObject instanceof ClassNotFoundException) { + String classname = (String) args[1]; + //SmithLogger.logger.info("processGlassfishClassLoaderfindClass find class:"+classname); + if(checkIsRaspClass(classname)) { + return (Object)Class.forName(classname); + } + } + + return null; + } +} + diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/asm/SmithMethodVisitor.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/asm/SmithMethodVisitor.java index 02876ecdf..48de9c42d 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/asm/SmithMethodVisitor.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/asm/SmithMethodVisitor.java @@ -1,371 +1,348 @@ -package com.security.smith.asm; - -//import com.security.smith.SmithProbeProxy; -import com.security.smith.processor.*; -import org.apache.commons.lang3.ArrayUtils; -import org.objectweb.asm.Label; -import org.objectweb.asm.MethodVisitor; -import org.objectweb.asm.Opcodes; -import org.objectweb.asm.Type; -import org.objectweb.asm.commons.AdviceAdapter; -import org.objectweb.asm.commons.Method; -import java.util.function.Function; - -import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; - -public class SmithMethodVisitor extends AdviceAdapter { - private final int classID; - private final int methodID; - private final Type classType; - private final boolean canBlock; - private final boolean isStatic; - private final boolean isConstructor; - private final int returnVariable; - private final int argumentsVariable; - private final Label start; - private final Label end; - private final Label handler; - private String preHook; - private String postHook; - private String exceptionHook; - private String xHook; - - private final Map> smithProcessors = new HashMap>() {{ - put("byte[]", ByteArrayProcessor.class); - put("int[]", IntegerArrayProcessor.class); - put("java.net.ProtocolFamily", ProtocolFamilyProcessor.class); - put("java.io.FileDescriptor", FileDescriptorProcessor.class); - put("java.net.URL[]", ObjectArrayProcessor.class); - put("java.net.DatagramPacket", DatagramPacketProcessor.class); - put("java.net.DatagramSocket", DatagramSocketProcessor.class); - put("java.lang.String[]", ObjectArrayProcessor.class); - put("java.lang.Process", ProcessProcessor.class); - put("java.lang.UNIXProcess", ProcessProcessor.class); - put("java.lang.ProcessImpl", ProcessProcessor.class); - put("java.net.InetAddress[]", ObjectArrayProcessor.class); - }}; - - protected SmithMethodVisitor(int api, Type classType, int classID, int methodID, boolean canBlock, MethodVisitor methodVisitor, int access, String name, String descriptor, String pre_hook, String post_hook,String exception_hook) { - super(api, methodVisitor, access, name, descriptor); - - this.classType = classType; - this.classID = classID; - this.methodID = methodID; - this.canBlock = canBlock; - this.preHook = pre_hook; - this.postHook = post_hook; - this.exceptionHook = exception_hook; - - start = new Label(); - end = new Label(); - handler = new Label(); - - argumentsVariable = newLocal(Type.getType(Object[].class)); - returnVariable = newLocal(Type.getType(Object.class)); - - isConstructor = name.equals(""); - isStatic = (access & Opcodes.ACC_STATIC) != 0; - } - - @Override - public void loadArgArray() { - int reserved = isStatic || isConstructor ? 0 : 1; - Type[] argumentTypes = Type.getArgumentTypes(methodDesc); - - push(argumentTypes.length + reserved); - newArray(Type.getType(Object.class)); - - if (reserved > 0) { - dup(); - push(0); - loadThis(); - processObject(classType.getClassName()); - arrayStore(Type.getType(Object.class)); - } - - for (int i = 0; i < argumentTypes.length; i++) { - dup(); - push(i + reserved); - loadArg(i); - box(argumentTypes[i]); - processObject(argumentTypes[i].getClassName()); - arrayStore(Type.getType(Object.class)); - } - } - - @Override - protected void onMethodEnter() { - super.onMethodEnter(); - - visitTryCatchBlock(start, end, handler, Type.getInternalName(Exception.class)); - - loadArgArray(); - storeLocal(argumentsVariable); - - visitInsn(ACONST_NULL); - storeLocal(returnVariable); - - mark(start); - - if (preHook == null || preHook == "") { - if (!canBlock) { - return; - } else { - preHook = "detect"; - } - } - - push(preHook); - push(classID); - push(methodID); - loadLocal(argumentsVariable); - - invokeStatic( - Type.getType("Lcom/security/smithloader/SmithAgent;"), - new Method( - "PreProxy", - Type.VOID_TYPE, - new Type[]{ - Type.getType(Object.class), - Type.INT_TYPE, - Type.INT_TYPE, - Type.getType(Object[].class) - } - ) - ); - } - - @Override - protected void onMethodExit(int opcode) { - super.onMethodExit(opcode); - - if (opcode == ATHROW) - return; - - Type returnType = Type.getReturnType(methodDesc); - - if (opcode == RETURN) { - if (isConstructor) { - loadThis(); - processObject(classType.getClassName()); - } else { - visitInsn(ACONST_NULL); - } - } else if (opcode == ARETURN) { - dup(); - processObject(returnType.getClassName()); - } else { - if (opcode == LRETURN || opcode == DRETURN) { - dup2(); - } else { - dup(); - } - - box(returnType); - } - - storeLocal(returnVariable); - - if (postHook == null || postHook == "") { - postHook = "trace"; - } - - push(postHook); - push(classID); - push(methodID); - loadLocal(argumentsVariable); - loadLocal(returnVariable); - push(false); - - invokeStatic( - Type.getType("Lcom/security/smithloader/SmithAgent;"), - new Method( - "PostProxy", - Type.VOID_TYPE, - new Type[]{ - Type.getType(Object.class), - Type.INT_TYPE, - Type.INT_TYPE, - Type.getType(Object[].class), - Type.getType(Object.class), - Type.BOOLEAN_TYPE - } - ) - ); - } - - class TypeMapper implements Function { - @Override - public Object apply(Type t) { - switch (t.getSort()) { - case Type.BOOLEAN: - case Type.CHAR: - case Type.BYTE: - case Type.SHORT: - case Type.INT: - return Opcodes.INTEGER; - case Type.FLOAT: - return Opcodes.FLOAT; - case Type.ARRAY: - case Type.OBJECT: - return t.getInternalName(); - case Type.LONG: - return Opcodes.LONG; - case Type.DOUBLE: - return Opcodes.DOUBLE; - default: - throw new AssertionError(); - } - } -} - @Override - public void visitMaxs(final int maxStack, final int maxLocals) { - mark(end); - mark(handler); - - if (exceptionHook == null || exceptionHook == "") { - - if (xHook == null || xHook == "") { - xHook = "trace"; - } - - Type[] types = Type.getArgumentTypes(methodDesc); - - if (!isStatic) { - types = ArrayUtils.addFirst(types, classType); - } - - /* - Object[] local = Arrays.stream(types).map(t -> { - switch (t.getSort()) { - case Type.BOOLEAN: - case Type.CHAR: - case Type.BYTE: - case Type.SHORT: - case Type.INT: - return Opcodes.INTEGER; - case Type.FLOAT: - return Opcodes.FLOAT; - case Type.ARRAY: - case Type.OBJECT: - return t.getInternalName(); - case Type.LONG: - return Opcodes.LONG; - case Type.DOUBLE: - return Opcodes.DOUBLE; - default: - throw new AssertionError(); - } - }).toArray(); -*/ - - Function typeMapper = new TypeMapper(); - Object[] local = Arrays.stream(types).map(typeMapper).toArray(); - - visitFrame( - Opcodes.F_NEW, - local.length, - local, - 1, - new Object[]{Type.getInternalName(Exception.class)} - ); - - storeLocal(returnVariable + 1, Type.getType(Exception.class)); - - push(xHook); - push(classID); - push(methodID); - loadLocal(argumentsVariable); - visitInsn(Opcodes.ACONST_NULL); - - if (!canBlock) { - push(false); - } else { - loadLocal(returnVariable + 1); - instanceOf(Type.getType(SecurityException.class)); - } - - invokeStatic( - Type.getType("Lcom/security/smithloader/SmithAgent;"), - new Method( - "PostProxy", - Type.VOID_TYPE, - new Type[]{ - Type.getType(Object.class), - Type.INT_TYPE, - Type.INT_TYPE, - Type.getType(Object[].class), - Type.getType(Object.class), - Type.BOOLEAN_TYPE - } - ) - ); - - loadLocal(returnVariable + 1); - throwException(); - - super.visitMaxs(maxStack, maxLocals); - } - else { - int newLocal = newLocal(Type.getType(Exception.class)); - int retId = newLocal(Type.getType(Object.class)); - - storeLocal(newLocal,Type.getType(Exception.class)); - loadLocal(newLocal); - - push(exceptionHook); - push(classID); - push(methodID); - loadLocal(argumentsVariable); - loadLocal(newLocal); - - invokeStatic( - Type.getType("Lcom/security/smithloader/SmithAgent;"), - new Method( - "ExceptionProxy", - Type.getType(Object.class), - new Type[]{ - Type.getType(Object.class), - Type.INT_TYPE, - Type.INT_TYPE, - Type.getType(Object[].class), - Type.getType(Object.class) - } - ) - ); - - mv.visitVarInsn(ASTORE, retId); - mv.visitVarInsn(ALOAD, retId); - Label label_if = new Label(); - mv.visitJumpInsn(IFNULL, label_if); - mv.visitVarInsn(ALOAD, retId); - mv.visitTypeInsn(CHECKCAST, "java/lang/Class"); - mv.visitInsn(ARETURN); - mv.visitLabel(label_if); - loadLocal(newLocal); - throwException(); - - super.visitMaxs(maxStack, maxLocals); - } - } - - void processObject(String name) { - Class processor = smithProcessors.get(name); - - if (processor == null) - return; - - invokeStatic( - Type.getType(processor), - new Method( - "transform", - Type.getType(Object.class), - new Type[]{ - Type.getType(Object.class) - } - ) - ); - } -} +package com.security.smith.asm; + +//import com.security.smith.SmithProbeProxy; +import com.security.smith.processor.*; +import org.apache.commons.lang3.ArrayUtils; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; +import org.objectweb.asm.commons.AdviceAdapter; +import org.objectweb.asm.commons.Method; +import java.util.function.Function; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +public class SmithMethodVisitor extends AdviceAdapter { + private final int classID; + private final int methodID; + private final Type classType; + private final boolean canBlock; + private final boolean isStatic; + private final boolean isConstructor; + private final int returnVariable; + private final int argumentsVariable; + private final Label start; + private final Label end; + private final Label handler; + private String preHook; + private String postHook; + private String exceptionHook; + private String xHook; + + private final Map> smithProcessors = new HashMap>() {{ + put("byte[]", ByteArrayProcessor.class); + put("int[]", IntegerArrayProcessor.class); + put("java.net.ProtocolFamily", ProtocolFamilyProcessor.class); + put("java.io.FileDescriptor", FileDescriptorProcessor.class); + put("java.net.URL[]", ObjectArrayProcessor.class); + put("java.net.DatagramPacket", DatagramPacketProcessor.class); + put("java.net.DatagramSocket", DatagramSocketProcessor.class); + put("java.lang.String[]", ObjectArrayProcessor.class); + put("java.lang.Process", ProcessProcessor.class); + put("java.lang.UNIXProcess", ProcessProcessor.class); + put("java.lang.ProcessImpl", ProcessProcessor.class); + put("java.net.InetAddress[]", ObjectArrayProcessor.class); + }}; + + protected SmithMethodVisitor(int api, Type classType, int classID, int methodID, boolean canBlock, MethodVisitor methodVisitor, int access, String name, String descriptor, String pre_hook, String post_hook,String exception_hook) { + super(api, methodVisitor, access, name, descriptor); + + this.classType = classType; + this.classID = classID; + this.methodID = methodID; + this.canBlock = canBlock; + this.preHook = pre_hook; + this.postHook = post_hook; + this.exceptionHook = exception_hook; + + start = new Label(); + end = new Label(); + handler = new Label(); + + argumentsVariable = newLocal(Type.getType(Object[].class)); + returnVariable = newLocal(Type.getType(Object.class)); + + isConstructor = name.equals(""); + isStatic = (access & Opcodes.ACC_STATIC) != 0; + } + + @Override + public void loadArgArray() { + int reserved = isStatic || isConstructor ? 0 : 1; + Type[] argumentTypes = Type.getArgumentTypes(methodDesc); + + push(argumentTypes.length + reserved); + newArray(Type.getType(Object.class)); + + if (reserved > 0) { + dup(); + push(0); + loadThis(); + processObject(classType.getClassName()); + arrayStore(Type.getType(Object.class)); + } + + for (int i = 0; i < argumentTypes.length; i++) { + dup(); + push(i + reserved); + loadArg(i); + box(argumentTypes[i]); + processObject(argumentTypes[i].getClassName()); + arrayStore(Type.getType(Object.class)); + } + } + + @Override + protected void onMethodEnter() { + super.onMethodEnter(); + + visitTryCatchBlock(start, end, handler, Type.getInternalName(Exception.class)); + + loadArgArray(); + storeLocal(argumentsVariable); + + visitInsn(ACONST_NULL); + storeLocal(returnVariable); + + mark(start); + + if (preHook == null || preHook == "") { + if (!canBlock) { + return; + } else { + preHook = "detect"; + } + } + + push(preHook); + push(classID); + push(methodID); + loadLocal(argumentsVariable); + + invokeStatic( + Type.getType("Lcom/security/smithloader/SmithAgent;"), + new Method( + "PreProxy", + Type.VOID_TYPE, + new Type[]{ + Type.getType(Object.class), + Type.INT_TYPE, + Type.INT_TYPE, + Type.getType(Object[].class) + } + ) + ); + } + + @Override + protected void onMethodExit(int opcode) { + super.onMethodExit(opcode); + + if (opcode == ATHROW) { + return; + } + + Type returnType = Type.getReturnType(methodDesc); + + if (opcode == RETURN) { + if (isConstructor) { + loadThis(); + processObject(classType.getClassName()); + } else { + visitInsn(ACONST_NULL); + } + } else if (opcode == ARETURN) { + dup(); + processObject(returnType.getClassName()); + } else { + if (opcode == LRETURN || opcode == DRETURN) { + dup2(); + } else { + dup(); + } + + box(returnType); + } + + storeLocal(returnVariable); + + if (postHook == null || postHook == "") { + postHook = "trace"; + } + + push(postHook); + push(classID); + push(methodID); + loadLocal(argumentsVariable); + loadLocal(returnVariable); + push(false); + + invokeStatic( + Type.getType("Lcom/security/smithloader/SmithAgent;"), + new Method( + "PostProxy", + Type.VOID_TYPE, + new Type[]{ + Type.getType(Object.class), + Type.INT_TYPE, + Type.INT_TYPE, + Type.getType(Object[].class), + Type.getType(Object.class), + Type.BOOLEAN_TYPE + } + ) + ); + } + + class TypeMapper implements Function { + @Override + public Object apply(Type t) { + switch (t.getSort()) { + case Type.BOOLEAN: + case Type.CHAR: + case Type.BYTE: + case Type.SHORT: + case Type.INT: + return Opcodes.INTEGER; + case Type.FLOAT: + return Opcodes.FLOAT; + case Type.ARRAY: + case Type.OBJECT: + return t.getInternalName(); + case Type.LONG: + return Opcodes.LONG; + case Type.DOUBLE: + return Opcodes.DOUBLE; + default: + throw new AssertionError(); + } + } +} + @Override + public void visitMaxs(final int maxStack, final int maxLocals) { + mark(end); + mark(handler); + + if (exceptionHook == null || exceptionHook == "") { + + if (xHook == null || xHook == "") { + xHook = "trace"; + } + + Type[] types = Type.getArgumentTypes(methodDesc); + + if (!isStatic) { + types = ArrayUtils.addFirst(types, classType); + } + + Function typeMapper = new TypeMapper(); + Object[] local = Arrays.stream(types).map(typeMapper).toArray(); + + visitFrame( + Opcodes.F_NEW, + local.length, + local, + 1, + new Object[]{Type.getInternalName(Exception.class)} + ); + + storeLocal(returnVariable + 1, Type.getType(Exception.class)); + + push(xHook); + push(classID); + push(methodID); + loadLocal(argumentsVariable); + visitInsn(Opcodes.ACONST_NULL); + + if (!canBlock) { + push(false); + } else { + loadLocal(returnVariable + 1); + instanceOf(Type.getType(SecurityException.class)); + } + + invokeStatic( + Type.getType("Lcom/security/smithloader/SmithAgent;"), + new Method( + "PostProxy", + Type.VOID_TYPE, + new Type[]{ + Type.getType(Object.class), + Type.INT_TYPE, + Type.INT_TYPE, + Type.getType(Object[].class), + Type.getType(Object.class), + Type.BOOLEAN_TYPE + } + ) + ); + + loadLocal(returnVariable + 1); + throwException(); + + super.visitMaxs(maxStack, maxLocals); + } + else { + int newLocal = newLocal(Type.getType(Exception.class)); + int retId = newLocal(Type.getType(Object.class)); + + storeLocal(newLocal,Type.getType(Exception.class)); + loadLocal(newLocal); + + push(exceptionHook); + push(classID); + push(methodID); + loadLocal(argumentsVariable); + loadLocal(newLocal); + + invokeStatic( + Type.getType("Lcom/security/smithloader/SmithAgent;"), + new Method( + "ExceptionProxy", + Type.getType(Object.class), + new Type[]{ + Type.getType(Object.class), + Type.INT_TYPE, + Type.INT_TYPE, + Type.getType(Object[].class), + Type.getType(Object.class) + } + ) + ); + + mv.visitVarInsn(ASTORE, retId); + mv.visitVarInsn(ALOAD, retId); + Label label_if = new Label(); + mv.visitJumpInsn(IFNULL, label_if); + mv.visitVarInsn(ALOAD, retId); + mv.visitTypeInsn(CHECKCAST, "java/lang/Class"); + mv.visitInsn(ARETURN); + mv.visitLabel(label_if); + loadLocal(newLocal); + throwException(); + + super.visitMaxs(maxStack, maxLocals); + } + } + + void processObject(String name) { + Class processor = smithProcessors.get(name); + + if (processor == null) + return; + + invokeStatic( + Type.getType(processor), + new Method( + "transform", + Type.getType(Object.class), + new Type[]{ + Type.getType(Object.class) + } + ) + ); + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/ClassUploadTransformer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/ClassUploadTransformer.java index 7f71958b6..2d07aadd5 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/ClassUploadTransformer.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/ClassUploadTransformer.java @@ -1,379 +1,383 @@ -package com.security.smith.client; - -import java.lang.instrument.Instrumentation; -import java.lang.instrument.ClassFileTransformer; -import java.lang.Class; - -import java.security.ProtectionDomain; -import java.time.Instant; -import java.util.LinkedHashMap; -import java.util.LinkedList; -import java.util.Map; -import java.util.Queue; -import java.util.concurrent.Semaphore; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.ReentrantLock; - -import com.security.smith.client.message.ClassFilter; -import com.security.smith.client.message.ClassUpload; -import com.security.smith.common.ProcessHelper; -import com.security.smith.log.SmithLogger; - -import java.lang.instrument.IllegalClassFormatException; -import java.lang.management.ManagementFactory; - -public class ClassUploadTransformer implements ClassFileTransformer,Runnable { - private static ClassUploadTransformer ourInstance = new ClassUploadTransformer(); - - public final static int MAX_DUMP_CLASS_SIZE = 1024 * 1024 * 10; - - public final static int MAX_HASH_SIZE = 1024*2; - - /** - * class for dump; - */ - private Class clazzToUpload = null; - /** - * transId for dump; - */ - private String transId = null; - - /* - * client to send class - */ - private Client client = null; - private Instrumentation inst = null; - private Thread uploadClassThread = null; - - private boolean started = false; - private boolean notifyStop = false; - - /** - * send class info Queue lock - */ - - private ReentrantLock classToUploadLock = null; - - private Condition classToUploadcondition = null; - - /* - * class hash cache lock - */ - - private ReadWriteLock classHashCachelock = null; - - /* - * class hash cache - */ - - private LinkedHashMap classHashCache = null; - - /** - * send class info Queue - */ - - private Queue classToUploadList = null; - - public static ClassUploadTransformer getInstance() { - if(ourInstance == null) { - ourInstance = new ClassUploadTransformer(); - } - - return ourInstance; - } - - public static void delInstance() { - if(ourInstance != null) { - ourInstance.stop(); - ourInstance = null; - } - } - - public ClassUploadTransformer() { - - } - - public void run() { - SmithLogger.logger.info("class upload thread run enter"); - - try { - while(true) { - SendClassInfo info = getUploadClassInfo_Wait(); - if(notifyStop) { - SmithLogger.logger.info("class upload thread stop!"); - break; - } - - if(info == null) { - continue; - } - - try { - this.clazzToUpload = info.clazz; - this.transId = info.transId; - - if (inst.isModifiableClass(info.clazz) && !info.clazz.getName().startsWith("java.lang.invoke.LambdaForm")) { - try { - inst.retransformClasses(info.clazz); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - } - finally { - this.clazzToUpload = null; - this.transId = null; - info = null; - } - } - } - catch(Exception e) { - SmithLogger.exception(e); - } - - SmithLogger.logger.info("class upload thread run leave"); - } - - public synchronized boolean start(Client client,Instrumentation inst) { - SmithLogger.logger.info("start enter"); - - if(!started) { - try { - this.client = client; - this.inst = inst; - - inst.addTransformer(ourInstance, true); - SmithLogger.logger.info("addTransformer success"); - this.classToUploadLock = new ReentrantLock(); - this.classToUploadcondition = this.classToUploadLock.newCondition(); - this.classHashCachelock = new ReentrantReadWriteLock(); - this.classHashCache = new LinkedHashMap(MAX_HASH_SIZE, 0.75f, true) { - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return size() > MAX_HASH_SIZE; - }; - }; - this.classToUploadList = new LinkedList<>(); - - SmithLogger.logger.info("init ClassUploadTransformer Var success"); - - uploadClassThread = new Thread(ourInstance); - - uploadClassThread.start(); - - SmithLogger.logger.info("Start uploadClassThread success"); - - started = true; - } - catch(Exception e) { - SmithLogger.exception(e); - inst.removeTransformer(ourInstance); - this.classToUploadcondition = null; - this.classToUploadLock = null; - this.classHashCachelock = null; - this.classHashCache = null; - this.classToUploadList = null; - this.uploadClassThread = null; - this.client = null; - this.inst = null; - } - } - - SmithLogger.logger.info("start leave"); - - return started; - } - - public synchronized boolean stop() { - SmithLogger.logger.info("stop enter"); - - if(started) { - try { - started = false; - inst.removeTransformer(ourInstance); - - SmithLogger.logger.info("removeTransformer success"); - - SmithLogger.logger.info("clear classHashCache"); - classHashCachelock.writeLock().lock(); - try { - classHashCache.clear(); - } finally { - classHashCachelock.writeLock().unlock(); - } - - SmithLogger.logger.info("notify thread stop"); - classToUploadLock.lock(); - try { - notifyStop = true; - classToUploadcondition.signal(); - } - finally { - classToUploadLock.unlock(); - } - - SmithLogger.logger.info("wait thread stop"); - uploadClassThread.join(); - SmithLogger.logger.info("upload thread stoped"); - - SmithLogger.logger.info("clear classToUploadList"); - classToUploadLock.lock(); - try { - classToUploadList.clear(); - } - finally { - classToUploadLock.unlock(); - } - - this.uploadClassThread = null; - this.client = null; - this.inst = null; - this.classToUploadcondition = null; - this.classToUploadLock = null; - this.classHashCachelock = null; - this.classHashCache = null; - this.classToUploadList = null; - } - catch(Exception e) { - SmithLogger.exception(e); - } - } - - SmithLogger.logger.info("stop leave"); - - return !started; - } - - public boolean classIsSended(int hashcode) { - boolean isSended = false; - classHashCachelock.readLock().lock(); - try { - isSended = classHashCache.containsKey(hashcode); - } finally { - classHashCachelock.readLock().unlock(); - } - - return isSended; - } - - private SendClassInfo getUploadClassInfo_Wait() { - SendClassInfo ret = null; - boolean exceptioned = false; - - classToUploadLock.lock(); - try { - if(classToUploadList.isEmpty()) { - try { - classToUploadcondition.await(); - } - catch(InterruptedException e) { - exceptioned = true; - SmithLogger.exception(e); - } - } - - if(!exceptioned && !classToUploadList.isEmpty()) { - ret = classToUploadList.poll(); - } - } - finally { - classToUploadLock.unlock(); - } - - return ret; - } - - private boolean addUploadClassInfo(Class classToUpload, String transId) { - boolean ret = false; - - try { - SendClassInfo info = new SendClassInfo(classToUpload,transId); - if(info != null) { - SmithLogger.logger.info("upload Class:" + classToUpload + " transId:"+transId); - classToUploadLock.lock(); - try { - classToUploadList.add(info); - classToUploadcondition.signal(); - } - finally { - classToUploadLock.unlock(); - } - - classHashCachelock.writeLock().lock(); - try { - classHashCache.put(classToUpload.hashCode(), 1); - } finally { - classHashCachelock.writeLock().unlock(); - } - - ret = true; - } - } - catch(Exception e) { - SmithLogger.exception(e); - } - - return ret; - } - - public boolean sendClass(Class clazz, String transId) { - if(!started) { - return false; - } - - try { - return addUploadClassInfo(clazz,transId); - } - catch(Exception e) { - SmithLogger.exception(e); - } - - return false; - } - - @Override - public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, - ProtectionDomain protectionDomain, byte[] classfileBuffer) - throws IllegalClassFormatException { - - if (clazzToUpload == classBeingRedefined) { - sendClass(classBeingRedefined, classfileBuffer); - } - - return null; - } - - /** - * dump Class; - * @param clazz - * @param data - */ - private void sendClass(Class clazz, byte[] data) { - - try { - if(clazz != null && - data != null && - data.length < MAX_DUMP_CLASS_SIZE) { - - int length = data.length; - ClassUpload classUpload = new ClassUpload(); - classUpload.setTransId(transId); - - classUpload.setByteTotalLength(length); - classUpload.setByteLength(length); - classUpload.setClassData(data); - - if (client != null) { - SmithLogger.logger.info("send classdata: " + classUpload.toString()); - client.write(Operate.CLASSUPLOAD, classUpload); - } - } - - } catch (Exception e) { - SmithLogger.exception(e); - } - } - -} +package com.security.smith.client; + +import java.lang.instrument.Instrumentation; +import java.lang.instrument.ClassFileTransformer; +import java.lang.Class; + +import java.security.ProtectionDomain; +import java.time.Instant; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.Semaphore; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import com.security.smith.client.message.ClassFilter; +import com.security.smith.client.message.ClassUpload; +import com.security.smith.common.ProcessHelper; +import com.security.smith.log.SmithLogger; + +import java.lang.instrument.IllegalClassFormatException; +import java.lang.management.ManagementFactory; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; + +public class ClassUploadTransformer implements ClassFileTransformer,Runnable { + private static ClassUploadTransformer ourInstance = new ClassUploadTransformer(); + + public final static int MAX_DUMP_CLASS_SIZE = 1024 * 1024 * 10; + + public final static int MAX_HASH_SIZE = 1024*2; + + /** + * class for dump; + */ + private Class clazzToUpload = null; + /** + * transId for dump; + */ + private String transId = null; + + /* + * client to send class + */ + private Client client = null; + private Instrumentation inst = null; + private Thread uploadClassThread = null; + + private boolean started = false; + private boolean notifyStop = false; + + /** + * send class info Queue lock + */ + + private ReentrantLock classToUploadLock = null; + + private Condition classToUploadcondition = null; + + /* + * class hash cache lock + */ + + private ReadWriteLock classHashCachelock = null; + + /* + * class hash cache + */ + + private LinkedHashMap classHashCache = null; + + /** + * send class info Queue + */ + + private Queue classToUploadList = null; + + public static ClassUploadTransformer getInstance() { + if(ourInstance == null) { + ourInstance = new ClassUploadTransformer(); + } + + return ourInstance; + } + + public static void delInstance() { + if(ourInstance != null) { + ourInstance.stop(); + ourInstance = null; + } + } + + public ClassUploadTransformer() { + + } + + public void run() { + SmithLogger.logger.info("class upload thread run enter"); + + try { + while(true) { + SendClassInfo info = getUploadClassInfo_Wait(); + if(notifyStop) { + SmithLogger.logger.info("class upload thread stop!"); + break; + } + + if(info == null) { + continue; + } + + try { + this.clazzToUpload = info.clazz; + this.transId = info.transId; + + if (inst.isModifiableClass(info.clazz) && !info.clazz.getName().startsWith("java.lang.invoke.LambdaForm")) { + try { + inst.retransformClasses(info.clazz); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + } + finally { + this.clazzToUpload = null; + this.transId = null; + info = null; + } + } + } + catch(Exception e) { + SmithLogger.exception(e); + } + + SmithLogger.logger.info("class upload thread run leave"); + } + + public synchronized boolean start(Client client,Instrumentation inst) { + SmithLogger.logger.info("start enter"); + + if(!started) { + try { + this.client = client; + this.inst = inst; + + inst.addTransformer(ourInstance, true); + SmithLogger.logger.info("addTransformer success"); + this.classToUploadLock = new ReentrantLock(); + this.classToUploadcondition = this.classToUploadLock.newCondition(); + this.classHashCachelock = new ReentrantReadWriteLock(); + this.classHashCache = new LinkedHashMap(MAX_HASH_SIZE, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > MAX_HASH_SIZE; + }; + }; + this.classToUploadList = new LinkedList<>(); + + SmithLogger.logger.info("init ClassUploadTransformer Var success"); + + uploadClassThread = new Thread(ourInstance); + + uploadClassThread.start(); + + SmithLogger.logger.info("Start uploadClassThread success"); + + started = true; + } + catch(Exception e) { + SmithLogger.exception(e); + inst.removeTransformer(ourInstance); + this.classToUploadcondition = null; + this.classToUploadLock = null; + this.classHashCachelock = null; + this.classHashCache = null; + this.classToUploadList = null; + this.uploadClassThread = null; + this.client = null; + this.inst = null; + } + } + + SmithLogger.logger.info("start leave"); + + return started; + } + + public synchronized boolean stop() { + SmithLogger.logger.info("stop enter"); + + if(started) { + try { + started = false; + inst.removeTransformer(ourInstance); + + SmithLogger.logger.info("removeTransformer success"); + + SmithLogger.logger.info("clear classHashCache"); + classHashCachelock.writeLock().lock(); + try { + classHashCache.clear(); + } finally { + classHashCachelock.writeLock().unlock(); + } + + SmithLogger.logger.info("notify thread stop"); + classToUploadLock.lock(); + try { + notifyStop = true; + classToUploadcondition.signal(); + } + finally { + classToUploadLock.unlock(); + } + + SmithLogger.logger.info("wait thread stop"); + uploadClassThread.join(); + SmithLogger.logger.info("upload thread stoped"); + + SmithLogger.logger.info("clear classToUploadList"); + classToUploadLock.lock(); + try { + classToUploadList.clear(); + } + finally { + classToUploadLock.unlock(); + } + + this.uploadClassThread = null; + this.client = null; + this.inst = null; + this.classToUploadcondition = null; + this.classToUploadLock = null; + this.classHashCachelock = null; + this.classHashCache = null; + this.classToUploadList = null; + } + catch(Exception e) { + SmithLogger.exception(e); + } + } + + SmithLogger.logger.info("stop leave"); + + return !started; + } + + public boolean classIsSended(int hashcode) { + boolean isSended = false; + classHashCachelock.readLock().lock(); + try { + isSended = classHashCache.containsKey(hashcode); + } finally { + classHashCachelock.readLock().unlock(); + } + + return isSended; + } + + private SendClassInfo getUploadClassInfo_Wait() { + SendClassInfo ret = null; + boolean exceptioned = false; + + classToUploadLock.lock(); + try { + if(classToUploadList.isEmpty()) { + try { + classToUploadcondition.await(); + } + catch(InterruptedException e) { + exceptioned = true; + SmithLogger.exception(e); + } + } + + if(!exceptioned && !classToUploadList.isEmpty()) { + ret = classToUploadList.poll(); + } + } + finally { + classToUploadLock.unlock(); + } + + return ret; + } + + private boolean addUploadClassInfo(Class classToUpload, String transId) { + boolean ret = false; + + try { + SendClassInfo info = new SendClassInfo(classToUpload,transId); + if(info != null) { + SmithLogger.logger.info("upload Class:" + classToUpload + " transId:"+transId); + classToUploadLock.lock(); + try { + classToUploadList.add(info); + classToUploadcondition.signal(); + } + finally { + classToUploadLock.unlock(); + } + + classHashCachelock.writeLock().lock(); + try { + classHashCache.put(classToUpload.hashCode(), 1); + } finally { + classHashCachelock.writeLock().unlock(); + } + + ret = true; + } + } + catch(Exception e) { + SmithLogger.exception(e); + } + + return ret; + } + + public boolean sendClass(Class clazz, String transId) { + if(!started) { + return false; + } + + try { + return addUploadClassInfo(clazz,transId); + } + catch(Exception e) { + SmithLogger.exception(e); + } + + return false; + } + + @Override + public byte[] transform(ClassLoader loader, String className, Class classBeingRedefined, + ProtectionDomain protectionDomain, byte[] classfileBuffer) + throws IllegalClassFormatException { + + if (clazzToUpload == classBeingRedefined) { + sendClass(classBeingRedefined, classfileBuffer); + } + + return null; + } + + /** + * dump Class; + * @param clazz + * @param data + */ + private void sendClass(Class clazz, byte[] data) { + + try { + if(clazz != null && + data != null && + data.length < MAX_DUMP_CLASS_SIZE) { + + int length = data.length; + ClassUpload classUpload = new ClassUpload(); + classUpload.setTransId(transId); + + classUpload.setByteTotalLength(length); + classUpload.setByteLength(length); + classUpload.setClassData(data); + + if (client != null) { + Gson gson = new Gson(); + JsonElement jsonElement = gson.toJsonTree(classUpload); + client.write(Operate.CLASSUPLOAD, jsonElement); + } + } + + } catch (Exception e) { + SmithLogger.exception(e); + } + } + +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Client.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Client.java index d53a2b641..192951c37 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Client.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Client.java @@ -1,311 +1,325 @@ -package com.security.smith.client; - -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.PropertyNamingStrategies; -import com.security.smith.client.message.*; -import com.security.smith.common.ProcessHelper; -import com.security.smith.log.SmithLogger; -import io.netty.bootstrap.Bootstrap; -import io.netty.channel.*; -import io.netty.channel.epoll.EpollDomainSocketChannel; -import io.netty.channel.epoll.EpollEventLoopGroup; -import io.netty.channel.unix.DomainSocketAddress; -import io.netty.channel.unix.DomainSocketChannel; -import io.netty.util.concurrent.DefaultThreadFactory; -import io.netty.util.concurrent.GenericFutureListener; - -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; - -import com.google.gson.Gson; - -interface EventHandler { - void onReconnect(); - void onMessage(Message message); -} - -public class Client implements EventHandler { - private static final int EVENT_LOOP_THREADS = 1; - private static final int RECONNECT_SCHEDULE = 60; - private static final String SOCKET_PATH = "/var/run/smith_agent.sock"; - private static final String MESSAGE_DIRECTORY = "/var/run/elkeid_rasp"; - - private Channel channel; - private boolean stopX; - private MessageHandler messageHandler; - private EpollEventLoopGroup group; - private ChannelFuture cf; - private GenericFutureListener connectListener = (ChannelFuture f) -> { - if (!f.isSuccess()) { - if(!stopX) { - f.channel().eventLoop().schedule(this::onReconnect, RECONNECT_SCHEDULE, TimeUnit.SECONDS); - } - } - }; - - public Client(MessageHandler messageHandler) { - // note: linux use epoll, mac use kqueue - this.stopX = false; - this.messageHandler = messageHandler; - this.group = new EpollEventLoopGroup(EVENT_LOOP_THREADS, new DefaultThreadFactory(getClass(), true)); - } - - public void start() { - SmithLogger.logger.info("probe client start"); - - try { - Bootstrap b = new Bootstrap(); - b.group(group) - .channel(EpollDomainSocketChannel.class) - .handler(new ChannelInitializer() { - @Override - public void initChannel(DomainSocketChannel ch) { - ChannelPipeline p = ch.pipeline(); - - p.addLast(new MessageDecoder()); - p.addLast(new MessageEncoder()); - p.addLast(new ClientHandlerAdapter(Client.this)); - } - }); - - cf = b.connect(new DomainSocketAddress(SOCKET_PATH)).addListener(connectListener); - - channel = cf.sync().channel(); - - channel.closeFuture().sync(); - } catch (Exception e) { - SmithLogger.exception(e); - } - } - - public void stop() { - stopX = true; - group.shutdownGracefully(); - messageHandler = null; - group = null; - channel.close(); - channel = null; - cf.removeListener(connectListener); - cf = null; - connectListener = null; - } - - public void write(int operate, Object object) { - if (channel == null || !channel.isActive() || !channel.isWritable()) - return; - - try { - //TODO must send metadata first - if (operate == Operate.CLASSUPLOAD) { - Thread.sleep(1000); - } - } catch (Exception e) { - - } - ObjectMapper objectMapper = new ObjectMapper() - .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE); - - - Message message = new Message(); - - message.setOperate(operate); - message.setData(objectMapper.valueToTree(object)); - - channel.writeAndFlush(message); - } - - @Override - public void onReconnect() { - SmithLogger.logger.info("reconnect"); - - readMessage(); - new Thread(this::start).start(); - } - - @Override - public void onMessage(Message message) { - switch (message.getOperate()) { - case Operate.EXIT: - SmithLogger.logger.info("exit"); - break; - - case Operate.HEARTBEAT: - SmithLogger.logger.info("heartbeat"); - break; - - case Operate.CONFIG: - SmithLogger.logger.info("config"); - messageHandler.onConfig(message.getData().get("config").asText()); - break; - - case Operate.CONTROL: - SmithLogger.logger.info("control"); - messageHandler.onControl(message.getData().get("action").asInt()); - break; - - case Operate.DETECT: - SmithLogger.logger.info("detect"); - messageHandler.onDetect(); - break; - - case Operate.FILTER: { - SmithLogger.logger.info("filter: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - FilterConfig config = gson.fromJson(message.getData().toString(), FilterConfig.class); - messageHandler.onFilter(config); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - - case Operate.BLOCK: { - SmithLogger.logger.info("block: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - BlockConfig config = gson.fromJson(message.getData().toString(), BlockConfig.class); - messageHandler.onBlock(config); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - - case Operate.LIMIT: { - SmithLogger.logger.info("limit: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - LimitConfig config = gson.fromJson(message.getData().toString(), LimitConfig.class); - messageHandler.onLimit(config); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - - case Operate.PATCH: { - SmithLogger.logger.info("patch: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - PatchConfig config = gson.fromJson(message.getData().toString(), PatchConfig.class); - messageHandler.onPatch(config); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - case Operate.CLASSFILTERSTART: { - SmithLogger.logger.info("rule upload start: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - Rule_Version ruleVersion = gson.fromJson(message.getData().toString(), Rule_Version.class); - messageHandler.setRuleVersion(ruleVersion); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - case Operate.CLASSFILTER: { - SmithLogger.logger.info("rule upload: " + message.getData().toString()); - - try { - Gson gson = new Gson(); - Rule_Data ruleData = gson.fromJson(message.getData().toString(), Rule_Data.class); - messageHandler.OnAddRule(ruleData); - } catch (Exception e) { - SmithLogger.exception(e); - } - - break; - } - case Operate.CLASSFILTEREND: { - SmithLogger.logger.info("class filter config receive finish, start to scan all class"); - Thread scanAllClassThread = new Thread(messageHandler::onScanAllClass); - scanAllClassThread.setDaemon(true); - scanAllClassThread.start(); - } - } - } - - private void readMessage() { - Path path = Paths.get(MESSAGE_DIRECTORY, String.format("%d.json", ProcessHelper.getCurrentPID())); - - if (!Files.exists(path)) { - SmithLogger.logger.info("message file not exist: " + path); - return; - } - - SmithLogger.logger.info("read message file: " + path); - - ObjectMapper objectMapper = new ObjectMapper() - .setPropertyNamingStrategy(PropertyNamingStrategies.SNAKE_CASE) - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - - try { - for (Message message : objectMapper.readValue(path.toFile(), Message[].class)) - onMessage(message); - - Files.delete(path); - } catch (IOException e) { - SmithLogger.exception(e); - } - } - - static class ClientHandlerAdapter extends ChannelInboundHandlerAdapter { - private EventHandler eventHandler; - - ClientHandlerAdapter(EventHandler eventHandler) { - this.eventHandler = eventHandler; - } - - public void closeHandler() { - this.eventHandler = null; - } - - @Override - public void channelInactive(ChannelHandlerContext ctx) throws Exception { - super.channelInactive(ctx); - SmithLogger.logger.info("channel inactive"); - - ctx.channel().eventLoop().schedule( - eventHandler::onReconnect, - RECONNECT_SCHEDULE, - TimeUnit.SECONDS - ); - } - - @Override - public void channelActive(ChannelHandlerContext ctx) throws Exception { - super.channelActive(ctx); - SmithLogger.logger.info("channel active"); - } - - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - eventHandler.onMessage((Message) msg); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - SmithLogger.exception(cause); - ctx.close(); - } - } -} +package com.security.smith.client; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.security.smith.client.message.*; +import com.security.smith.common.ProcessHelper; +import com.security.smith.log.SmithLogger; +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.*; +import io.netty.channel.epoll.EpollDomainSocketChannel; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.unix.DomainSocketAddress; +import io.netty.channel.unix.DomainSocketChannel; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.GenericFutureListener; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.nio.charset.StandardCharsets; +import java.util.stream.Collectors; + +interface EventHandler { + void onReconnect(); + void onMessage(Message message); +} + +public class Client implements EventHandler { + private static final int EVENT_LOOP_THREADS = 1; + private static final int RECONNECT_SCHEDULE = 60; + private static final String SOCKET_PATH = "/var/run/smith_agent.sock"; + private static final String MESSAGE_DIRECTORY = "/var/run/elkeid_rasp"; + + private Channel channel; + private boolean stopX; + private MessageHandler messageHandler; + private EpollEventLoopGroup group; + private ChannelFuture cf; + private GenericFutureListener connectListener = (ChannelFuture f) -> { + if (!f.isSuccess()) { + if(!stopX) { + f.channel().eventLoop().schedule(this::onReconnect, RECONNECT_SCHEDULE, TimeUnit.SECONDS); + } + } + }; + + public Client(MessageHandler messageHandler) { + // note: linux use epoll, mac use kqueue + this.stopX = false; + this.messageHandler = messageHandler; + this.group = new EpollEventLoopGroup(EVENT_LOOP_THREADS, new DefaultThreadFactory(getClass(), true)); + } + + public void start() { + SmithLogger.logger.info("probe client start"); + + try { + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(EpollDomainSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + public void initChannel(DomainSocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + + p.addLast(new MessageDecoder()); + p.addLast(new MessageEncoder()); + p.addLast(new ClientHandlerAdapter(Client.this)); + } + }); + + cf = b.connect(new DomainSocketAddress(SOCKET_PATH)).addListener(connectListener); + + channel = cf.sync().channel(); + + channel.closeFuture().sync(); + } catch (Exception e) { + SmithLogger.exception(e); + } + } + + public void stop() { + stopX = true; + group.shutdownGracefully(); + messageHandler = null; + group = null; + channel.close(); + channel = null; + cf.removeListener(connectListener); + cf = null; + connectListener = null; + } + + public void write(int operate, Object object) { + if (channel == null || !channel.isActive() || !channel.isWritable()) + return; + + try { + //TODO must send metadata first + if (operate == Operate.CLASSUPLOAD) { + Thread.sleep(1000); + } + } catch (Exception e) { + + } + + Message message = new Message(); + message.setOperate(operate); + + message.setData((JsonElement)object); + + channel.writeAndFlush(message); + } + + @Override + public void onReconnect() { + SmithLogger.logger.info("reconnect"); + + readMessage(); + new Thread(this::start).start(); + } + + @Override + public void onMessage(Message message) { + switch (message.getOperate()) { + case Operate.EXIT: + SmithLogger.logger.info("exit"); + break; + + case Operate.HEARTBEAT: + SmithLogger.logger.info("heartbeat"); + break; + + case Operate.CONFIG: + SmithLogger.logger.info("config"); + JsonElement configElement = message.getData().getAsJsonObject().get("config"); + if (configElement != null && configElement.isJsonPrimitive()) { + String config = configElement.getAsString(); + messageHandler.onConfig(config); + } + //messageHandler.onConfig(message.getData().get("config").asText()); + break; + + case Operate.CONTROL: + SmithLogger.logger.info("control"); + JsonElement actionElement = message.getData().getAsJsonObject().get("action"); + if (actionElement != null && actionElement.isJsonPrimitive() && actionElement.getAsJsonPrimitive().isNumber()) { + int action = actionElement.getAsInt(); + messageHandler.onControl(action); + } + //messageHandler.onControl(message.getData().get("action").asInt()); + break; + + case Operate.DETECT: + SmithLogger.logger.info("detect"); + messageHandler.onDetect(); + break; + + case Operate.FILTER: { + SmithLogger.logger.info("filter: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + FilterConfig config = gson.fromJson(message.getData().toString(), FilterConfig.class); + messageHandler.onFilter(config); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + + case Operate.BLOCK: { + SmithLogger.logger.info("block: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + BlockConfig config = gson.fromJson(message.getData().toString(), BlockConfig.class); + messageHandler.onBlock(config); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + + case Operate.LIMIT: { + SmithLogger.logger.info("limit: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + LimitConfig config = gson.fromJson(message.getData().toString(), LimitConfig.class); + messageHandler.onLimit(config); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + + case Operate.PATCH: { + SmithLogger.logger.info("patch: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + PatchConfig config = gson.fromJson(message.getData().toString(), PatchConfig.class); + messageHandler.onPatch(config); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + case Operate.CLASSFILTERSTART: { + SmithLogger.logger.info("rule upload start: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + Rule_Version ruleVersion = gson.fromJson(message.getData().toString(), Rule_Version.class); + messageHandler.setRuleVersion(ruleVersion); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + case Operate.CLASSFILTER: { + SmithLogger.logger.info("rule upload: " + message.getData().toString()); + + try { + Gson gson = new Gson(); + Rule_Data ruleData = gson.fromJson(message.getData().toString(), Rule_Data.class); + messageHandler.OnAddRule(ruleData); + } catch (Exception e) { + SmithLogger.exception(e); + } + + break; + } + case Operate.CLASSFILTEREND: { + SmithLogger.logger.info("class filter config receive finish, start to scan all class"); + Thread scanAllClassThread = new Thread(messageHandler::onScanAllClass); + scanAllClassThread.setDaemon(true); + scanAllClassThread.start(); + } + } + } + + private void readMessage() { + Path path = Paths.get(MESSAGE_DIRECTORY, String.format("%d.json", ProcessHelper.getCurrentPID())); + + if (!Files.exists(path)) { + SmithLogger.logger.info("message file not exist: " + path); + return; + } + + SmithLogger.logger.info("read message file: " + path); + + try { + String json = Files.lines(path, StandardCharsets.UTF_8) + .collect(Collectors.joining("\n")); + JsonElement jsonElement = JsonParser.parseString(json); + + if (jsonElement.isJsonArray()) { + JsonArray jsonArray = jsonElement.getAsJsonArray(); + for (JsonElement element : jsonArray) { + if (element.isJsonObject()) { + Message message = new Gson().fromJson(element, Message.class); + onMessage(message); + } + } + } + + Files.delete(path); + } catch (IOException e) { + SmithLogger.exception(e); + } + } + + static class ClientHandlerAdapter extends ChannelInboundHandlerAdapter { + private EventHandler eventHandler; + + ClientHandlerAdapter(EventHandler eventHandler) { + this.eventHandler = eventHandler; + } + + public void closeHandler() { + this.eventHandler = null; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + super.channelInactive(ctx); + SmithLogger.logger.info("channel inactive"); + + ctx.channel().eventLoop().schedule( + eventHandler::onReconnect, + RECONNECT_SCHEDULE, + TimeUnit.SECONDS + ); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + SmithLogger.logger.info("channel active"); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + eventHandler.onMessage((Message) msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + SmithLogger.exception(cause); + ctx.close(); + } + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Message.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Message.java index 82f18d579..b2d7f13ec 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Message.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Message.java @@ -1,108 +1,32 @@ -package com.security.smith.client; - -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.databind.deser.std.StdDeserializer; -import com.fasterxml.jackson.databind.ser.std.StdSerializer; -import io.netty.buffer.ByteBuf; -import io.netty.channel.ChannelHandlerContext; -import io.netty.handler.codec.MessageToByteEncoder; -import io.netty.handler.codec.ReplayingDecoder; - -import java.io.IOException; -import java.lang.management.ManagementFactory; -import java.nio.ByteBuffer; -import java.time.Instant; -import java.util.List; - -@JsonSerialize(using = MessageSerializer.class) -@JsonDeserialize(using = MessageDeserializer.class) -public class Message { - static final int PROTOCOL_HEADER_SIZE = 4; - static final int MAX_PAYLOAD_SIZE = 10240; - - private int operate; - private JsonNode data; - - int getOperate() { - return operate; - } - - public void setOperate(int operate) { - this.operate = operate; - } - - public JsonNode getData() { - return data; - } - - public void setData(JsonNode data) { - this.data = data; - } -} - -class MessageDeserializer extends StdDeserializer { - protected MessageDeserializer() { - super(Message.class); - } - - protected MessageDeserializer(Class vc) { - super(vc); - } - - @Override - public Message deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - JsonNode node = p.getCodec().readTree(p); - - Message message = new Message(); - - message.setOperate(node.get("message_type").asInt()); - message.setData(node.get("data")); - - return message; - } -} - -class MessageEncoder extends MessageToByteEncoder { - @Override - protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) throws JsonProcessingException { - ObjectMapper objectMapper = new ObjectMapper(); - - byte[] payload = objectMapper.writeValueAsBytes(msg); - int payloadSize = payload.length; - - ByteBuffer buffer = ByteBuffer.allocate(payloadSize + Message.PROTOCOL_HEADER_SIZE); - - buffer.putInt(payloadSize); - buffer.put(payload); - - buffer.flip(); - - out.writeBytes(buffer); - } -} - -class MessageDecoder extends ReplayingDecoder { - @Override - protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws IOException { - long payloadSize = in.readUnsignedInt(); - - if (payloadSize > Message.MAX_PAYLOAD_SIZE) - return; - - byte[] buffer = new byte[(int) payloadSize]; - in.readBytes(buffer); - - Message message = new ObjectMapper().readValue(buffer, Message.class); - - if (message != null) - out.add(message); - } -} +package com.security.smith.client; + +import java.lang.reflect.Type; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; + + +public class Message { + static final int PROTOCOL_HEADER_SIZE = 4; + static final int MAX_PAYLOAD_SIZE = 10240; + + private int operate; + private JsonElement data; + + int getOperate() { + return operate; + } + + public void setOperate(int operate) { + this.operate = operate; + } + + public JsonElement getData() { + return data; + } + + public void setData(JsonElement data) { + this.data = data; + } +} + diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDecoder.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDecoder.java new file mode 100644 index 000000000..7184caeaf --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDecoder.java @@ -0,0 +1,50 @@ + +package com.security.smith.client; + +import java.lang.reflect.Type; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ReplayingDecoder; +import io.netty.buffer.ByteBuf; +import com.security.smith.client.MessageSerializer; +import com.security.smith.client.MessageDeserializer; + +import java.io.IOException; +import java.util.List; + +public class MessageDecoder extends ReplayingDecoder { + private static Gson gson = null; + + public static void delInstance() { + gson = null; + } + + public static void initInstance() { + gson = new GsonBuilder() + .registerTypeAdapter(Message.class, new MessageSerializer()) + .registerTypeAdapter(Message.class, new MessageDeserializer()) + .create(); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws IOException { + try { + long payloadSize = in.readUnsignedInt(); + if (payloadSize > Message.MAX_PAYLOAD_SIZE) + return; + + byte[] buffer = new byte[(int) payloadSize]; + in.readBytes(buffer); + + String msg = new String(buffer); + Message message = gson.fromJson(msg,Message.class); + if (message != null) + out.add(message); + } + catch(Throwable e) { + e.printStackTrace(); + } + + } +} \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDeserializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDeserializer.java new file mode 100644 index 000000000..0c497d1c9 --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageDeserializer.java @@ -0,0 +1,22 @@ +package com.security.smith.client; + +import java.lang.reflect.Type; +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonParseException; + +public class MessageDeserializer implements JsonDeserializer { + @Override + public Message deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) + throws JsonParseException { + JsonObject obj = json.getAsJsonObject(); + Message message = new Message(); + message.setOperate(obj.get("message_type").getAsInt()); + message.setData(obj.get("data")); + return message; + } +} \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageEncoder.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageEncoder.java new file mode 100644 index 000000000..4ad8b8621 --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageEncoder.java @@ -0,0 +1,45 @@ +package com.security.smith.client; + +import java.lang.reflect.Type; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.buffer.ByteBuf; +import java.nio.ByteBuffer; +import com.security.smith.client.MessageSerializer; +import com.security.smith.client.MessageDeserializer; + +public class MessageEncoder extends MessageToByteEncoder { + private static Gson gson = null; + + public static void delInstance() { + gson = null; + } + + public static void initInstance() { + gson = new GsonBuilder() + .registerTypeAdapter(Message.class, new MessageSerializer()) + .registerTypeAdapter(Message.class, new MessageDeserializer()) + .create(); + } + + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, ByteBuf out) { + try { + byte[] payload = gson.toJson(msg).getBytes(); + int payloadSize = payload.length; + + ByteBuffer buffer = ByteBuffer.allocate(payloadSize + Message.PROTOCOL_HEADER_SIZE); + buffer.putInt(payloadSize); + buffer.put(payload); + buffer.flip(); + + out.writeBytes(buffer); + } + catch(Throwable e) { + e.printStackTrace(); + } + + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageSerializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageSerializer.java index 349ab15ba..53836fada 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageSerializer.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/MessageSerializer.java @@ -1,51 +1,52 @@ -package com.security.smith.client; - -import com.fasterxml.jackson.databind.ser.std.StdSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.core.JsonGenerator; -import java.lang.management.ManagementFactory; -import java.io.IOException; -import java.time.Instant; -import com.security.smith.common.ProcessHelper; - -public class MessageSerializer extends StdSerializer { - static private int pid; - static private String jvmVersion; - static private String probeVersion; - - public static void initInstance(String probeVer) { - pid = ProcessHelper.getCurrentPID(); - jvmVersion = ManagementFactory.getRuntimeMXBean().getSpecVersion(); - probeVersion = probeVer; - } - - public static void delInstance() { - probeVersion = null; - jvmVersion = null; - pid = 0; - } - - protected MessageSerializer() { - super(Message.class); - } - - protected MessageSerializer(Class t) { - super(t); - } - - @Override - public void serialize(Message value, JsonGenerator gen, SerializerProvider provider) throws IOException { - gen.writeStartObject(); - gen.writeNumberField("message_type", value.getOperate()); - - gen.writeNumberField("pid", pid); - gen.writeStringField("runtime", "JVM"); - gen.writeStringField("runtime_version", jvmVersion); - gen.writeStringField("probe_version", probeVersion); - gen.writeNumberField("time", Instant.now().getEpochSecond()); - - gen.writeObjectField("data", value.getData()); - - gen.writeEndObject(); - } -} +package com.security.smith.client; + +import java.lang.reflect.Type; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; + +import com.security.smith.common.ProcessHelper; +import java.lang.management.ManagementFactory; +import java.time.Instant; + + +public class MessageSerializer implements JsonSerializer { + static private int pid; + static private String jvmVersion; + static private String probeVersion; + + public static void initInstance(String probeVer) { + pid = ProcessHelper.getCurrentPID(); + jvmVersion = ManagementFactory.getRuntimeMXBean().getSpecVersion(); + probeVersion = probeVer; + } + + public static void delInstance() { + jvmVersion = null; + probeVersion = null; + jvmVersion = null; + pid = 0; + } + + public static void initInstance() { + pid = ProcessHelper.getCurrentPID(); + jvmVersion = ManagementFactory.getRuntimeMXBean().getSpecVersion(); + probeVersion = MessageSerializer.class.getPackage().getImplementationVersion(); + } + + @Override + public JsonElement serialize(Message message, Type typeOfSrc, JsonSerializationContext context) { + JsonObject obj = new JsonObject(); + obj.addProperty("message_type", message.getOperate()); + obj.add("data", context.serialize(message.getData())); + obj.addProperty("pid", pid); + obj.addProperty("runtime", "JVM"); + obj.addProperty("runtime_version", jvmVersion); + obj.addProperty("probe_version", probeVersion); + obj.addProperty("time", Instant.now().getEpochSecond()); + return obj; + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Rule_Config.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Rule_Config.java index d403c68bb..7878e0040 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Rule_Config.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/Rule_Config.java @@ -2,14 +2,12 @@ import com.security.smith.log.SmithLogger; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.gson.Gson; import java.util.ArrayList; import java.util.concurrent.ExecutionException; public class Rule_Config { - private ObjectMapper m_objectMapper = new ObjectMapper(); private Rule_Json m_Rule_Json; private Rule_Mgr m_Rule_Mgr; @@ -42,7 +40,6 @@ public boolean setVersion(int rule_version) { public void destry() { try { m_Rule_Json = null; - m_objectMapper = null; m_Rule_Mgr.delRule_all(); m_Rule_Mgr = null; } catch (Exception e) { @@ -97,15 +94,14 @@ public boolean addRuleData(Rule_Data ruleData) { public boolean setRuleConfig(String JsonRule) { boolean bresult = false; + Gson gson = new Gson(); if(m_Rule_Mgr == null) { return false; } try { - m_objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES,false); - - m_Rule_Json = m_objectMapper.readValue(JsonRule, Rule_Json.class); + m_Rule_Json = gson.fromJson(JsonRule, Rule_Json.class); if(m_Rule_Json != null && m_Rule_Json.getRule().length > 0) { int rule_version = m_Rule_Json.getRule_version(); Rule_Item[] rule = m_Rule_Json.getRule(); diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilter.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilter.java index bc1646522..c71278e2c 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilter.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilter.java @@ -1,110 +1,106 @@ -package com.security.smith.client.message; - - -import java.time.Instant; -import java.util.UUID; - -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - - -public class ClassFilter { - private String transId = null; - private String className = ""; - private String classPath = ""; - private String interfacesName = ""; - private String classLoaderName = ""; - private String parentClassName = ""; - private String parentClassLoaderName = ""; - private long ruleId; - - @JsonSerialize(converter = StackTraceConverter.class) - private StackTraceElement[] stackTrace = {}; - - public String getTransId() { - return transId; - } - - public void setTransId() { - UUID uniqueId = UUID.randomUUID(); - transId = uniqueId.toString().replace("-", ""); - } - - public String getClassName() { - return className; - } - - public void setClassName(String className) { - this.className = className; - } - - public String getClassPath() { - return classPath; - } - - public void setClassPath(String classPath) { - this.classPath = classPath; - } - - public String getInterfacesName() { - return interfacesName; - } - - public void setInterfacesName(String interfacesName) { - this.interfacesName = interfacesName; - } - public String getClassLoaderName() { - return classLoaderName; - } - - public void setClassLoaderName(String classLoaderName) { - this.classLoaderName = classLoaderName; - } - - - public String getParentClassName() { - return parentClassName; - } - - public void setParentClassName(String parentClassName) { - this.parentClassName = parentClassName; - } - - public String getParentClassLoaderName() { - return parentClassLoaderName; - } - - public void setParentClassLoaderName(String parentClassLoaderName) { - this.parentClassLoaderName = parentClassLoaderName; - } - - public long getRuleId() { - return ruleId; - } - - public void setRuleId(long ruleId) { - this.ruleId = ruleId; - } - - public StackTraceElement[] getStackTrace() { - return stackTrace; - } - - public void setStackTrace(StackTraceElement[] stackTrace) { - this.stackTrace = stackTrace; - } - @Override - public String toString() { - return "{" + - "transId: '" + transId + '\'' + - ", className: '" + className + '\'' + - ", classPath: '" + classPath + '\'' + - ", interfacesName: '" + interfacesName + '\'' + - ", classLoaderName: '" + classLoaderName + '\'' + - ", parentClassName: '" + parentClassName + '\'' + - ", parentClassLoaderName: '" + parentClassLoaderName + '\'' + - ", ruleId: " + ruleId + - ", timestamp: " + Instant.now().getEpochSecond() + - '}'; - } - +package com.security.smith.client.message; + +import com.google.gson.annotations.SerializedName; +import java.time.Instant; +import java.util.UUID; + +public class ClassFilter { + private String transId = null; + private String className = ""; + private String classPath = ""; + private String interfacesName = ""; + private String classLoaderName = ""; + private String parentClassName = ""; + private String parentClassLoaderName = ""; + private long ruleId; + @SerializedName("stackTrace") + private StackTraceElement[] stackTrace = {}; + + public String getTransId() { + return transId; + } + + public void setTransId() { + UUID uniqueId = UUID.randomUUID(); + transId = uniqueId.toString().replace("-", ""); + } + + public String getClassName() { + return className; + } + + public void setClassName(String className) { + this.className = className; + } + + public String getClassPath() { + return classPath; + } + + public void setClassPath(String classPath) { + this.classPath = classPath; + } + + public String getInterfacesName() { + return interfacesName; + } + + public void setInterfacesName(String interfacesName) { + this.interfacesName = interfacesName; + } + + public String getClassLoaderName() { + return classLoaderName; + } + + public void setClassLoaderName(String classLoaderName) { + this.classLoaderName = classLoaderName; + } + + public String getParentClassName() { + return parentClassName; + } + + public void setParentClassName(String parentClassName) { + this.parentClassName = parentClassName; + } + + public String getParentClassLoaderName() { + return parentClassLoaderName; + } + + public void setParentClassLoaderName(String parentClassLoaderName) { + this.parentClassLoaderName = parentClassLoaderName; + } + + public long getRuleId() { + return ruleId; + } + + public void setRuleId(long ruleId) { + this.ruleId = ruleId; + } + + public StackTraceElement[] getStackTrace() { + return stackTrace; + } + + public void setStackTrace(StackTraceElement[] stackTrace) { + this.stackTrace = stackTrace; + } + + @Override + public String toString() { + return "{" + + "transId: '" + transId + '\'' + + ", className: '" + className + '\'' + + ", classPath: '" + classPath + '\'' + + ", interfacesName: '" + interfacesName + '\'' + + ", classLoaderName: '" + classLoaderName + '\'' + + ", parentClassName: '" + parentClassName + '\'' + + ", parentClassLoaderName: '" + parentClassLoaderName + '\'' + + ", ruleId: " + ruleId + + ", timestamp: " + Instant.now().getEpochSecond() + + '}'; + } } \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterDeserializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterDeserializer.java new file mode 100644 index 000000000..03348ece7 --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterDeserializer.java @@ -0,0 +1,46 @@ +package com.security.smith.client.message; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.security.smith.client.message.ClassFilter; +import com.security.smith.client.message.ClassFilter; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import java.lang.reflect.Type; + +public class ClassFilterDeserializer implements com.google.gson.JsonDeserializer { + @Override + public ClassFilter deserialize(JsonElement json, Type typeOfT, com.google.gson.JsonDeserializationContext context) { + JsonObject jsonObject = json.getAsJsonObject(); + ClassFilter filter = new ClassFilter(); + filter.setTransId(); + filter.setClassName(jsonObject.getAsJsonPrimitive("class_name").getAsString()); + filter.setClassPath(jsonObject.getAsJsonPrimitive("class_path").getAsString()); + filter.setInterfacesName(jsonObject.getAsJsonPrimitive("interfaces_name").getAsString()); + filter.setClassLoaderName(jsonObject.getAsJsonPrimitive("class_Loader_name").getAsString()); + filter.setParentClassName(jsonObject.getAsJsonPrimitive("parent_Class_name").getAsString()); + filter.setParentClassLoaderName(jsonObject.getAsJsonPrimitive("parent_class_Loader_name").getAsString()); + filter.setRuleId(jsonObject.getAsJsonPrimitive("rule_id").getAsInt()); + filter.setStackTrace(convertStackTrace(context.deserialize(jsonObject.get("stackTrace"), String[].class))); + return filter; + } + + private StackTraceElement[] convertStackTrace(String[] stackTrace) { + if (stackTrace == null) + return new StackTraceElement[0]; + + StackTraceElement[] result = new StackTraceElement[stackTrace.length]; + for (int i = 0; i < stackTrace.length; i++) { + String[] parts = stackTrace[i].split(","); + if (parts.length != 4) { + continue; + } + String className = parts[0].trim(); + String methodName = parts[1].trim(); + String fileName = parts[2].trim(); + int lineNumber = Integer.parseInt(parts[3].trim()); + result[i] = new StackTraceElement(className, methodName, fileName, lineNumber); + } + return result; + } +} \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterSerializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterSerializer.java new file mode 100644 index 000000000..067f377b5 --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/ClassFilterSerializer.java @@ -0,0 +1,42 @@ +package com.security.smith.client.message; + +import com.google.gson.JsonObject; +import com.google.gson.JsonElement; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; +import com.security.smith.client.message.ClassFilter; +import com.security.smith.client.message.Trace; + +import java.lang.reflect.Type; +import java.util.Arrays; + +public class ClassFilterSerializer implements JsonSerializer { + @Override + public JsonElement serialize(ClassFilter src, Type typeOfSrc, JsonSerializationContext context) { + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("trans_id", src.getTransId()); + jsonObject.addProperty("class_name", src.getClassName()); + jsonObject.addProperty("class_path", src.getClassPath()); + jsonObject.addProperty("interfaces_name", src.getInterfacesName()); + jsonObject.addProperty("class_Loader_name", src.getClassLoaderName()); + jsonObject.addProperty("parent_Class_name", src.getParentClassName()); + jsonObject.addProperty("parent_class_Loader_name", src.getParentClassLoaderName()); + jsonObject.addProperty("rule_id", src.getRuleId()); + jsonObject.add("stackTrace", context.serialize(convertStackTrace(src.getStackTrace()))); + return jsonObject; + } + + private String[] convertStackTrace(StackTraceElement[] stackTrace) { + if (stackTrace.length <= 2) + return null; + + StackTraceElement[] elements = Arrays.copyOfRange(stackTrace, 2, stackTrace.length); + String[] result = new String[elements.length]; + + for (int i = 0; i < elements.length; i++) { + result[i] = elements[i].toString(); + } + + return result; + } +} \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Heartbeat.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Heartbeat.java index 59adaaeea..a2210fee8 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Heartbeat.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Heartbeat.java @@ -1,58 +1,90 @@ -package com.security.smith.client.message; - -public class Heartbeat { - private String filter; - private String block; - private String limit; - private String patch; - private String class_filter_version; - private int discard_count; - - public String getFilter() { - return filter; - } - - public void setFilter(String filter) { - this.filter = filter; - } - - public String getBlock() { - return block; - } - - public void setBlock(String block) { - this.block = block; - } - - public String getLimit() { - return limit; - } - - public void setLimit(String limit) { - this.limit = limit; - } - - public String getPatch() { - return patch; - } - - public void setPatch(String patch) { - this.patch = patch; - } - - public String getClassFilterVersion() { - return class_filter_version; - } - - public void setClassFilterVersion(String classFilterVersion) { - this.class_filter_version = classFilterVersion; - } - - public synchronized int getDiscardCount() { - return discard_count; - } - - public synchronized void setDiscardCount(int discard_count) { - this.discard_count = discard_count; - } -} +package com.security.smith.client.message; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParser; + +public class Heartbeat { + private String filter; + private String block; + private String limit; + private String patch; + private String class_filter_version; + private int discard_count; + + public String getFilter() { + return filter; + } + + public void setFilter(String filter) { + this.filter = filter; + } + + public String getBlock() { + return block; + } + + public void setBlock(String block) { + this.block = block; + } + + public String getLimit() { + return limit; + } + + public void setLimit(String limit) { + this.limit = limit; + } + + public String getPatch() { + return patch; + } + + public void setPatch(String patch) { + this.patch = patch; + } + + public String getClassFilterVersion() { + return class_filter_version; + } + + public void setClassFilterVersion(String classFilterVersion) { + this.class_filter_version = classFilterVersion; + } + + public synchronized int getDiscardCount() { + return discard_count; + } + + public synchronized void setDiscardCount(int discard_count) { + this.discard_count = discard_count; + } + + public static Heartbeat fromJson(String json) { + Gson gson = new Gson(); + return gson.fromJson(json, Heartbeat.class); + } + + public String toJson() { + Gson gson = new Gson(); + return gson.toJson(this); + } + + public static Heartbeat fromJsonObject(JsonObject jsonObject) { + Gson gson = new Gson(); + return gson.fromJson(jsonObject, Heartbeat.class); + } + + public JsonObject toJsonObject() { + Gson gson = new Gson(); + JsonElement jsonElement = gson.toJsonTree(this); + return jsonElement.getAsJsonObject(); + } + + public JsonElement toJsonElement() { + Gson gson = new Gson(); + JsonElement jsonElement = gson.toJsonTree(this); + return jsonElement; + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Trace.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Trace.java index c6e42a1c1..33cafb10b 100644 --- a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Trace.java +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/Trace.java @@ -1,123 +1,76 @@ -package com.security.smith.client.message; - -import com.fasterxml.jackson.databind.annotation.JsonSerialize; -import com.fasterxml.jackson.databind.util.StdConverter; - -import java.util.Arrays; - -public class Trace { - private int classID; - private int methodID; - private boolean blocked; - private String policyID; - - @JsonSerialize(converter = RetConverter.class) - private Object ret; - - @JsonSerialize(converter = ArgsConverter.class) - private Object[] args; - - @JsonSerialize(converter = StackTraceConverter.class) - private StackTraceElement[] stackTrace; - - public int getClassID() { - return classID; - } - - public void setClassID(int classID) { - this.classID = classID; - } - - public int getMethodID() { - return methodID; - } - - public void setMethodID(int methodID) { - this.methodID = methodID; - } - - public boolean isBlocked() { - return blocked; - } - - public void setBlocked(boolean blocked) { - this.blocked = blocked; - } - - public String getPolicyID() { - return policyID; - } - - public void setPolicyID(String policyID) { - this.policyID = policyID; - } - - public Object getRet() { - return ret; - } - - public void setRet(Object ret) { - this.ret = ret; - } - - public Object[] getArgs() { - return args; - } - - public void setArgs(Object[] args) { - this.args = args; - } - - public StackTraceElement[] getStackTrace() { - return stackTrace; - } - - public void setStackTrace(StackTraceElement[] stackTrace) { - this.stackTrace = stackTrace; - } -} - -class RetConverter extends StdConverter { - @Override - public String convert(Object value) { - return String.valueOf(value); - } -} - -class StackTraceConverter extends StdConverter { - @Override - public String[] convert(StackTraceElement[] value) { - if (value.length <= 2) - return null; - - /* - return Arrays.stream(Arrays.copyOfRange(value, 2, value.length)) - .map(StackTraceElement::toString) - .toArray(String[]::new); - */ - - StackTraceElement[] elements = Arrays.copyOfRange(value, 2, value.length); - String[] result = new String[elements.length]; - - for (int i = 0; i < elements.length; i++) { - result[i] = elements[i].toString(); - } - - return result; - } -} - -class ArgsConverter extends StdConverter { - @Override - public String[] convert(Object[] value) { - //return Arrays.stream(value).map(String::valueOf).toArray(String[]::new); - - String[] result = new String[value.length]; - - for (int i = 0; i < value.length; i++) { - result[i] = String.valueOf(value[i]); - } - - return result; - } -} +package com.security.smith.client.message; + +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; + + +import java.util.Arrays; + +public class Trace { + private int classID; + private int methodID; + private boolean blocked; + private String policyID; + + private Object ret; + private Object[] args; + private StackTraceElement[] stackTrace; + + public int getClassID() { + return classID; + } + + public void setClassID(int classID) { + this.classID = classID; + } + + public int getMethodID() { + return methodID; + } + + public void setMethodID(int methodID) { + this.methodID = methodID; + } + + public boolean isBlocked() { + return blocked; + } + + public void setBlocked(boolean blocked) { + this.blocked = blocked; + } + + public String getPolicyID() { + return policyID; + } + + public void setPolicyID(String policyID) { + this.policyID = policyID; + } + + public Object getRet() { + return ret; + } + + public void setRet(Object ret) { + this.ret = ret; + } + + public Object[] getArgs() { + return args; + } + + public void setArgs(Object[] args) { + this.args = args; + } + + public StackTraceElement[] getStackTrace() { + return stackTrace; + } + + public void setStackTrace(StackTraceElement[] stackTrace) { + this.stackTrace = stackTrace; + } +} diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceDeserializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceDeserializer.java new file mode 100644 index 000000000..237465e2d --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceDeserializer.java @@ -0,0 +1,42 @@ +package com.security.smith.client.message; + +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import java.lang.reflect.Type; + +public class TraceDeserializer implements com.google.gson.JsonDeserializer { + @Override + public Trace deserialize(JsonElement json, Type typeOfT, com.google.gson.JsonDeserializationContext context) { + JsonObject jsonObject = json.getAsJsonObject(); + Trace trace = new Trace(); + trace.setClassID(jsonObject.getAsJsonPrimitive("class_id").getAsInt()); + trace.setMethodID(jsonObject.getAsJsonPrimitive("method_id").getAsInt()); + trace.setBlocked(jsonObject.getAsJsonPrimitive("blocked").getAsBoolean()); + trace.setPolicyID(jsonObject.getAsJsonPrimitive("policyID").getAsString()); + trace.setRet(context.deserialize(jsonObject.get("ret"), Object.class)); + trace.setArgs(context.deserialize(jsonObject.get("args"), Object[].class)); + trace.setStackTrace(convertStackTrace(context.deserialize(jsonObject.get("stackTrace"), String[].class))); + return trace; + } + + private StackTraceElement[] convertStackTrace(String[] stackTrace) { + if (stackTrace == null) + return new StackTraceElement[0]; + + StackTraceElement[] result = new StackTraceElement[stackTrace.length]; + for (int i = 0; i < stackTrace.length; i++) { + String[] parts = stackTrace[i].split(","); + if (parts.length != 4) { + continue; + } + String className = parts[0].trim(); + String methodName = parts[1].trim(); + String fileName = parts[2].trim(); + int lineNumber = Integer.parseInt(parts[3].trim()); + result[i] = new StackTraceElement(className, methodName, fileName, lineNumber); + } + return result; + } +} \ No newline at end of file diff --git a/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceSerializer.java b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceSerializer.java new file mode 100644 index 000000000..0101def49 --- /dev/null +++ b/rasp/jvm/JVMProbe/src/main/java/com/security/smith/client/message/TraceSerializer.java @@ -0,0 +1,52 @@ +package com.security.smith.client.message; + +import com.google.gson.JsonObject; +import com.google.gson.JsonElement; +import com.google.gson.JsonSerializationContext; +import com.google.gson.JsonSerializer; +import com.security.smith.client.message.Trace; + +import java.lang.reflect.Type; +import java.util.Arrays; + +public class TraceSerializer implements JsonSerializer { + @Override + public JsonElement serialize(Trace src, Type typeOfSrc, JsonSerializationContext context) { + JsonObject jsonObject = new JsonObject(); + jsonObject.addProperty("class_id", src.getClassID()); + jsonObject.addProperty("method_id", src.getMethodID()); + jsonObject.addProperty("blocked", src.isBlocked()); + jsonObject.addProperty("policyID", src.getPolicyID()); + jsonObject.add("ret",context.serialize(convertRet(src.getRet()))); + jsonObject.add("args",context.serialize(convertArgs(src.getArgs()))); + jsonObject.add("stackTrace", context.serialize(convertStackTrace(src.getStackTrace()))); + return jsonObject; + } + + private String convertRet(Object value) { + return String.valueOf(value); + } + + private String[] convertArgs(Object[] value) { + String[] result = new String[value.length]; + for (int i = 0; i < value.length; i++) { + result[i] = String.valueOf(value[i]); + } + return result; +// return Arrays.stream(value).map(String::valueOf).toArray(String[]::new); + } + + private String[] convertStackTrace(StackTraceElement[] stackTrace) { + if (stackTrace.length <= 2) + return null; + + StackTraceElement[] elements = Arrays.copyOfRange(stackTrace, 2, stackTrace.length); + String[] result = new String[elements.length]; + + for (int i = 0; i < elements.length; i++) { + result[i] = elements[i].toString(); + } + + return result; + } +} \ No newline at end of file diff --git a/rasp/librasp/src/jvm.rs b/rasp/librasp/src/jvm.rs index e58ad047f..8aff3535e 100644 --- a/rasp/librasp/src/jvm.rs +++ b/rasp/librasp/src/jvm.rs @@ -1,240 +1,363 @@ -use anyhow::{anyhow, Result}; - -use log::*; -use regex::Regex; -use std::process::Command; -use std::fs; -use std::time::Duration; -use crate::async_command::run_async_process; -use crate::process::ProcessInfo; -use crate::runtime::{ProbeCopy, ProbeState, ProbeStateInspect}; -use crate::settings::{self, RASP_VERSION}; -use lazy_static::lazy_static; - -lazy_static! { - static ref RASP_JAVA_CHECKSUMSTR: String = { - match fs::read_to_string(settings::RASP_JAVA_CHECKSUM_PATH()) { - Ok(content) => content, - Err(e) => { - error!("failed to get Java check sum, err: {}, path: {},java may not be attach success", e, settings::RASP_JAVA_CHECKSUM_PATH()); - String::new() - } - } - }; -} -pub struct JVMProbeState {} - -impl ProbeStateInspect for JVMProbeState { - fn inspect_process(process_info: &ProcessInfo) -> Result { - match prop(process_info.pid) { - Ok(state) => { - Ok(state) - } - Err(_) => { - Ok(ProbeState::NotAttach) - } - } - } -} - -pub struct JVMProbe {} - -impl ProbeCopy for JVMProbe { - fn names() -> (Vec, Vec) { - ( - [ - settings::RASP_JAVA_JATTACH_BIN(), - settings::RASP_JAVA_PROBE_BIN(), - settings::RASP_JAVA_AGENT_BIN(), - ] - .to_vec(), - [].to_vec(), - ) - } -} - -pub fn java_attach(pid: i32) -> Result { - let java_attach = settings::RASP_JAVA_JATTACH_BIN(); - let agent = settings::RASP_JAVA_AGENT_BIN(); - let probe_param = format!("{}={};{};{};", agent, "attach", *RASP_JAVA_CHECKSUMSTR, settings::RASP_JAVA_PROBE_BIN()); - match run_async_process(Command::new(java_attach).args(&[ - pid.to_string().as_str(), - "load", - "instrument", - "false", - probe_param.as_str(), - ])) { - Ok((es, out, err)) => { - if out.len() != 0 { - info!("{}", &out); - } - if err.len() != 0 { - info!("{}", &err); - } - let es_code = match es.code() { - Some(ec) => ec, - None => { - return Err(anyhow!("get status code failed: {}", pid)); - } - }; - if es_code == 0 { - std::thread::sleep(Duration::from_millis(500)); - match check_result(pid, "attach") { - Ok(_) => { - return Ok(true); - } - Err(e) => { - return Err(anyhow!(e.to_string())); - } - } - } else { - let msg = format!( - "jvm attach exit code {} {} {} {}", - es_code, pid, &out, &err - ); - error!("{}", msg); - Err(anyhow!("{}", msg)) - } - } - Err(e) => { - Err(anyhow!(e.to_string())) - } - } -} - -pub fn jcmd(pid: i32, cmd: &'static str) -> Result> { - let java_attach = settings::RASP_JAVA_JATTACH_BIN(); - match run_async_process(Command::new(java_attach).args(&[ - pid.to_string().as_str(), - "jcmd", - cmd, - ])) { - Ok((_, out, err)) => { - // if out.len() != 0 { - // info!("{}", &out); - // } - if err.len() != 0 { - info!("pid: {}, {}", pid, &err); - } - return Ok(out.into()) - } - Err(e) => { - Err(anyhow!(e.to_string())) - } - } -} - -pub fn vm_version(pid: i32) -> Result { - return match jcmd(pid, "VM.version") { - Ok(stdout) => { - let response = String::from_utf8(stdout).unwrap_or(String::new()); - let re = Regex::new(r"JDK (\d+)\.").unwrap(); - let version = match re.captures(&response) { - Some(c) => c.get(1).map_or("", |m| m.as_str()), - None => return Err(anyhow!(String::from("can not find version"))), - }; - let version_number = match version.parse::() { - Ok(vm) => vm, - Err(e) => return Err(anyhow!(e.to_string())), - }; - Ok(version_number) - } - Err(e) => Err(anyhow!(e)), - }; -} - -pub fn prop(pid: i32) -> Result { - return match jcmd(pid, " VM.system_properties") { - Ok(stdout) => { - let response = String::from_utf8_lossy(&stdout); - let re = Regex::new(r"rasp\.probe").unwrap(); - if re.is_match(&response) { - info!("found rasp.probe"); - let re: Regex = Regex::new(r"smith\.rasp=(.*)").unwrap(); - - if let Some(captures) = re.captures(&response) { - if let Some(value_match) = captures.get(1) { - let check_str = value_match.as_str().trim().to_string(); - let unescaped_string = check_str.replace(r"\=", "="); - info!("found checkstr: {}", unescaped_string); - if unescaped_string != format!("{}-{}", RASP_VERSION, *RASP_JAVA_CHECKSUMSTR) { - return Ok(ProbeState::AttachedVersionNotMatch); - } - } - } - return Ok(ProbeState::Attached); - } - return Ok(ProbeState::NotAttach); - } - Err(e) => Err(anyhow!(e)), - }; -} - -pub fn check_result(pid: i32, need_status: &str) -> Result { - return match jcmd(pid, " VM.system_properties") { - Ok(stdout) => { - let response = String::from_utf8_lossy(&stdout); - let re: Regex = Regex::new(r"smith\.status=(.*)").unwrap(); - if let Some(captures) = re.captures(&response) { - if let Some(value_match) = captures.get(1) { - let check_status = value_match.as_str().trim().to_string(); - info!("found smith.status: {}", check_status); - if check_status != need_status { - return Err(anyhow!(check_status)); - } - } - } - Ok(true) - } - Err(e) => {Err(anyhow!(e))} - } -} - -pub fn java_detach(pid: i32) -> Result { - let java_detach = settings::RASP_JAVA_JATTACH_BIN(); - let agent = settings::RASP_JAVA_AGENT_BIN(); - let probe_param = format!("{}={};", agent, "detach"); - match run_async_process(Command::new(java_detach).args(&[ - pid.to_string().as_str(), - "load", - "instrument", - "false", - probe_param.as_str(), - ])) { - Ok((es, out, err)) => { - if out.len() != 0 { - info!("{}", &out); - } - if err.len() != 0 { - info!("{}", &err); - } - let es_code = match es.code() { - Some(ec) => ec, - None => { - return Err(anyhow!("get status code failed: {}", pid)); - } - }; - if es_code == 0 { - std::thread::sleep(Duration::from_millis(500)); - match check_result(pid, "detach") { - Ok(_) => { - return Ok(true); - } - Err(e) => { - return Err(anyhow!(e.to_string())); - } - } - } else { - let msg = format!( - "jvm detach exit code {} {} {} {}", - es_code, pid, &out, &err - ); - error!("{}", msg); - Err(anyhow!("{}", msg)) - } - } - Err(e) => { - return Err(anyhow!(e.to_string())); - } - } +use log::*; +use regex::Regex; +use std::process::Command; +use std::fs; +use std::path::Path; +use std::path::PathBuf; +use std::time::Duration; +use crate::async_command::run_async_process; +use crate::process::ProcessInfo; +use crate::runtime::{ProbeCopy, ProbeState, ProbeStateInspect}; +use crate::settings::{self, RASP_VERSION}; +use lazy_static::lazy_static; +use anyhow::{anyhow, Result, Result as AnyhowResult}; +use fs_extra::file::{copy as file_copy, remove as file_remove, CopyOptions as FileCopyOptions}; + +lazy_static! { + static ref RASP_JAVA_CHECKSUMSTR: String = { + match fs::read_to_string(settings::RASP_JAVA_CHECKSUM_PATH()) { + Ok(content) => content, + Err(e) => { + error!("failed to get Java check sum, err: {}, path: {},java may not be attach success", e, settings::RASP_JAVA_CHECKSUM_PATH()); + String::new() + } + } + }; +} +pub struct JVMProbeState {} + +impl ProbeStateInspect for JVMProbeState { + fn inspect_process(process_info: &ProcessInfo) -> Result { + match prop(process_info.pid) { + Ok(state) => { + Ok(state) + } + Err(_) => { + Ok(ProbeState::NotAttach) + } + } + } +} + +pub struct JVMProbe {} + +impl ProbeCopy for JVMProbe { + fn names() -> (Vec, Vec) { + ( + [ + settings::RASP_JAVA_JATTACH_BIN(), + settings::RASP_JAVA_PROBE_BIN(), + settings::RASP_JAVA_AGENT_BIN(), + ] + .to_vec(), + [].to_vec(), + ) + } +} + +pub struct JVMProbeNativeLib {} + +impl ProbeCopy for JVMProbeNativeLib { + #[cfg(all(target_os = "linux"))] + fn names() -> (Vec, Vec) { + ( + [ + settings::RASP_JAVA_NETTY_EPOLL_SO(), + ] + .to_vec(), + [].to_vec(), + ) + } + + #[cfg(all(target_os = "macos"))] + fn names() -> (Vec, Vec) { + ( + [ + settings::RASP_JAVA_NETTY_KQUEUQ_SO_MAC(), + settings::RASP_JAVA_NETTY_DNS_SO_MAC(), + ] + .to_vec(), + [].to_vec(), + ) + } +} + +pub fn parse_java_library_path(input: &str) -> Result, anyhow::Error> { + let xinput = input.replace("\\:", ":"); + let paths: Vec<&str> = xinput.split(":").collect(); + let mut result = Vec::with_capacity(paths.len()); + + for path in paths { + let path_buf = { + let path_str = path.to_string(); + PathBuf::from(path_str) + }; + if path_buf.exists() { + result.push(path_buf); + } else { + // Ignore non-existent paths + continue; + } + } + + Ok(result) +} + +fn copy_file_probe(from:String,to:String) -> AnyhowResult<()> { + let options = FileCopyOptions::new(); + return match file_copy(from.clone(), to.clone(), &options) { + Ok(_) => Ok(()), + Err(e) => { + warn!("can not copy: {}", e); + Err(anyhow!( + "copy failed: from {} to {}: {}", + from, + to, + e + )) + } + } +} + +fn get_last_filename(path: &str) -> Option { + Path::new(path) + .file_name() + .and_then(|name| name.to_str()) + .map(|name| name.to_string()) +} + +pub fn copy_probe_nativelib(pid:i32,dst_root:String) -> AnyhowResult<()> { + let java_library_path = jcmd(pid, " VM.system_properties").and_then(|output| { + let output_str = String::from_utf8_lossy(&output); + let lines: Vec<&str> = output_str.split("\n").collect(); + let java_library_path_line = lines.iter().find(|line| line.starts_with("java.library.path=")); + if let Some(line) = java_library_path_line { + let path = line.trim_start_matches("java.library.path="); + match parse_java_library_path(path) { + Ok(parsed_paths) => { + println!("Java library paths:{:?}",parsed_paths); + for from in JVMProbeNativeLib::names().0.iter() { + let src_path = from.clone(); + if let Some(soname) = get_last_filename(&src_path) { + let mut bIsExist = false; + println!("Last filename: {}", soname); + for path in parsed_paths.clone() { + let mut path_str = format!("{}{}",dst_root,path.display()); + let path_buf: PathBuf = path_str.into(); + println!(" {} exist", path_buf.display()); + if path_buf.join(&soname).exists() { + println!("{} exist",soname); + bIsExist = true; + break; + } + } + + if !bIsExist { + let path = parsed_paths[0].clone(); + + let dst_path = format!("{}{}/{}",dst_root,path.display(),soname); + println!("copy {} to {}",src_path,dst_path); + copy_file_probe(src_path,dst_path); + } + } + } + } + Err(e) => { + info!("parse java library path failed: {}", e); + } + } + + Ok(0) + } else { + Err(anyhow::anyhow!("java.library.path not found in output")) + } + }); + + Ok(()) +} + +pub fn java_attach(pid: i32) -> Result { + let java_attach = settings::RASP_JAVA_JATTACH_BIN(); + let agent = settings::RASP_JAVA_AGENT_BIN(); + let probe_param = format!("{}={};{};{};", agent, "attach", *RASP_JAVA_CHECKSUMSTR, settings::RASP_JAVA_PROBE_BIN()); + match run_async_process(Command::new(java_attach).args(&[ + pid.to_string().as_str(), + "load", + "instrument", + "false", + probe_param.as_str(), + ])) { + Ok((es, out, err)) => { + if out.len() != 0 { + info!("{}", &out); + } + if err.len() != 0 { + info!("{}", &err); + } + let es_code = match es.code() { + Some(ec) => ec, + None => { + return Err(anyhow!("get status code failed: {}", pid)); + } + }; + if es_code == 0 { + std::thread::sleep(Duration::from_millis(500)); + match check_result(pid, "attach") { + Ok(_) => { + return Ok(true); + } + Err(e) => { + return Err(anyhow!(e.to_string())); + } + } + } else { + let msg = format!( + "jvm attach exit code {} {} {} {}", + es_code, pid, &out, &err + ); + error!("{}", msg); + Err(anyhow!("{}", msg)) + } + } + Err(e) => { + Err(anyhow!(e.to_string())) + } + } +} + +pub fn jcmd(pid: i32, cmd: &'static str) -> Result> { + let java_attach = settings::RASP_JAVA_JATTACH_BIN(); + match run_async_process(Command::new(java_attach).args(&[ + pid.to_string().as_str(), + "jcmd", + cmd, + ])) { + Ok((_, out, err)) => { + // if out.len() != 0 { + // info!("{}", &out); + // } + if err.len() != 0 { + info!("pid: {}, {}", pid, &err); + } + return Ok(out.into()) + } + Err(e) => { + Err(anyhow!(e.to_string())) + } + } +} + +pub fn vm_version(pid: i32) -> Result { + return match jcmd(pid, "VM.version") { + Ok(stdout) => { + let response = String::from_utf8(stdout).unwrap_or(String::new()); + let re = Regex::new(r"JDK (\d+)\.").unwrap(); + let version = match re.captures(&response) { + Some(c) => c.get(1).map_or("", |m| m.as_str()), + None => return Err(anyhow!(String::from("can not find version"))), + }; + let version_number = match version.parse::() { + Ok(vm) => vm, + Err(e) => return Err(anyhow!(e.to_string())), + }; + Ok(version_number) + } + Err(e) => Err(anyhow!(e)), + }; +} + +pub fn prop(pid: i32) -> Result { + return match jcmd(pid, " VM.system_properties") { + Ok(stdout) => { + let response = String::from_utf8_lossy(&stdout); + let re = Regex::new(r"rasp\.probe").unwrap(); + if re.is_match(&response) { + info!("found rasp.probe"); + let re: Regex = Regex::new(r"smith\.rasp=(.*)").unwrap(); + + if let Some(captures) = re.captures(&response) { + if let Some(value_match) = captures.get(1) { + let check_str = value_match.as_str().trim().to_string(); + let unescaped_string = check_str.replace(r"\=", "="); + info!("found checkstr: {}", unescaped_string); + if unescaped_string != format!("{}-{}", RASP_VERSION, *RASP_JAVA_CHECKSUMSTR) { + return Ok(ProbeState::AttachedVersionNotMatch); + } + } + } + return Ok(ProbeState::Attached); + } + return Ok(ProbeState::NotAttach); + } + Err(e) => Err(anyhow!(e)), + }; +} + +pub fn check_result(pid: i32, need_status: &str) -> Result { + return match jcmd(pid, " VM.system_properties") { + Ok(stdout) => { + let response = String::from_utf8_lossy(&stdout); + let re: Regex = Regex::new(r"smith\.status=(.*)").unwrap(); + if let Some(captures) = re.captures(&response) { + if let Some(value_match) = captures.get(1) { + let check_status = value_match.as_str().trim().to_string(); + info!("found smith.status: {}", check_status); + if check_status != need_status { + return Err(anyhow!(check_status)); + } + } + } + Ok(true) + } + Err(e) => {Err(anyhow!(e))} + } +} + +pub fn java_detach(pid: i32) -> Result { + let java_detach = settings::RASP_JAVA_JATTACH_BIN(); + let agent = settings::RASP_JAVA_AGENT_BIN(); + let probe_param = format!("{}={};", agent, "detach"); + match run_async_process(Command::new(java_detach).args(&[ + pid.to_string().as_str(), + "load", + "instrument", + "false", + probe_param.as_str(), + ])) { + Ok((es, out, err)) => { + if out.len() != 0 { + info!("{}", &out); + } + if err.len() != 0 { + info!("{}", &err); + } + let es_code = match es.code() { + Some(ec) => ec, + None => { + return Err(anyhow!("get status code failed: {}", pid)); + } + }; + if es_code == 0 { + std::thread::sleep(Duration::from_millis(500)); + match check_result(pid, "detach") { + Ok(_) => { + return Ok(true); + } + Err(e) => { + return Err(anyhow!(e.to_string())); + } + } + } else { + let msg = format!( + "jvm detach exit code {} {} {} {}", + es_code, pid, &out, &err + ); + error!("{}", msg); + Err(anyhow!("{}", msg)) + } + } + Err(e) => { + return Err(anyhow!(e.to_string())); + } + } } \ No newline at end of file diff --git a/rasp/librasp/src/manager.rs b/rasp/librasp/src/manager.rs index a045701ff..43cbeadf9 100644 --- a/rasp/librasp/src/manager.rs +++ b/rasp/librasp/src/manager.rs @@ -1,971 +1,974 @@ -use std::collections::HashMap; -use std::ffi::OsStr; -use std::fs; -use std::path::Path; - -use anyhow::{anyhow, Result, Result as AnyhowResult}; -use crossbeam::channel::Sender; -use fs_extra::dir::{copy, create_all, CopyOptions}; -use fs_extra::file::{copy as file_copy, CopyOptions as FileCopyOptions}; -use libraspserver::proto::{PidMissingProbeConfig, ProbeConfigData}; -use log::*; - -use crate::cpython::{python_attach, CPythonProbe, CPythonProbeState}; -use crate::golang::{golang_attach, GolangProbe, GolangProbeState}; -use crate::jvm::{java_attach, java_detach, JVMProbe, JVMProbeState}; -use crate::nodejs::{nodejs_attach, NodeJSProbe}; -use crate::php::{php_attach, PHPProbeState}; -use crate::{ - comm::{Control, EbpfMode, ProcessMode, RASPComm, ThreadMode, check_need_mount}, - process::ProcessInfo, - runtime::{ProbeCopy, ProbeState, ProbeStateInspect, RuntimeInspect}, - settings, -}; - -pub struct RASPManager { - pub namespace_tracer: MntNamespaceTracer, - pub thread_comm: Option, - pub process_comm: Option, - pub ebpf_comm: Option, - pub runtime_dir: bool, -} - -impl RASPManager { - // comm - pub fn start_comm( - &mut self, - process_info: &ProcessInfo, - result_sender: Sender, - _server_log_level: String, - _server_ctrl: Control, - ) -> AnyhowResult<()> { - debug!("starting comm with probe, target pid: {}", process_info.pid); - let mnt_namespace = process_info.get_mnt_ns()?; - let nspid = if let Some(nspid) = ProcessInfo::read_nspid(process_info.pid)? { - nspid - } else { - process_info.pid - }; - if let Some(comm) = self.thread_comm.as_mut() { - comm.start_comm( - process_info.pid, - &mnt_namespace, - result_sender, - HashMap::new(), - )?; - } else if let Some(comm) = self.process_comm.as_mut() { - let mut patch_field = HashMap::new(); - let sid = process_info.sid; - let pid = process_info.pid; - let ppid = process_info.ppid; - let tgid = process_info.tgid; - let exe = process_info.exe_path.clone().unwrap_or("".to_string()); - let cmdline = process_info.cmdline.clone().unwrap_or("".to_string()); - patch_field.insert("sid", sid.to_string()); - patch_field.insert("pid", pid.to_string()); - patch_field.insert("nspid", nspid.to_string()); - patch_field.insert("ppid", ppid.to_string()); - patch_field.insert("tgid", tgid.to_string()); - patch_field.insert("argv", cmdline); - patch_field.insert("exe", exe); - patch_field.insert("ruid", process_info.ruid.to_string()); - patch_field.insert("rgid", process_info.rgid.to_string()); - patch_field.insert("euid", process_info.euid.to_string()); - patch_field.insert("egid", process_info.egid.to_string()); - patch_field.insert("suid", process_info.suid.to_string()); - patch_field.insert("sgid", process_info.sgid.to_string()); - patch_field.insert("fuid", process_info.fuid.to_string()); - patch_field.insert("fgid", process_info.fgid.to_string()); - debug!("update patch_field: {:?}", patch_field); - - // check reopen - if let Some(opened) = self.namespace_tracer.server_state(&mnt_namespace) { - if opened { - debug!("reusing stated server, mnt ns: {}", &mnt_namespace); - if let Some(runner) = comm.mnt_namesapce_server_map.get_mut(&mnt_namespace) { - runner.update_patch_field(patch_field); - } - return Ok(()); - } - } else { - comm.start_comm( - process_info.pid, - &mnt_namespace, - result_sender.clone(), - patch_field, - )?; - } - } else { - return Err(anyhow!("both thread && process comm mode not init")); - } - self.namespace_tracer - .add(mnt_namespace.clone(), process_info.pid); - self.namespace_tracer.server_state_on(mnt_namespace); - Ok(()) - } - - pub fn stop_comm(&mut self, process_info: &ProcessInfo) -> Result<()> { - let mnt_namespace = if let Some(ref ns) = process_info.namespace_info { - match ns.mnt.clone() { - Some(mnt_ns) => mnt_ns, - None => { - return Err(anyhow!("process mnt ns empty: {}", process_info.pid)); - } - } - } else { - return Err(anyhow!("fetch process ns failed: {}", process_info.pid)); - }; - let kill_check = self - .namespace_tracer - .server_state_off(&mnt_namespace, process_info.pid); - // kill server - if kill_check { - if let Some(comm) = self.thread_comm.as_mut() { - comm.stop_comm(process_info.pid, &mnt_namespace)?; - } - if let Some(comm) = self.process_comm.as_mut() { - comm.stop_comm(process_info.pid, &mnt_namespace)?; - } - } - Ok(()) - } - - pub fn patch_message_handle( - &self, - valid_messages: &mut Vec, - pid: i32, - ) -> AnyhowResult> { - for valid_m in valid_messages.iter_mut() { - if let Some(patches) = valid_m.data.patches.as_mut() { - let mut delete_index = Vec::new(); - for (index, patch) in patches.iter_mut().enumerate() { - if patch.path.is_none() { - delete_index.push(index); - continue; - } - if !self.runtime_dir { - warn!( - "due to missing runtime dir, patch ignored: {}", - patch.class_name - ); - delete_index.push(index); - continue; - } - let path_path_str = patch.path.clone().unwrap(); - let patch_path = Path::new(&path_path_str); - // check patch exist - if !patch_path.exists() { - delete_index.push(index); - continue; - } else { - let patch_file_name = patch_path - .file_name() - .unwrap_or(OsStr::new("")) - .to_string_lossy(); - if patch_file_name == "" { - delete_index.push(index); - continue; - } - let dest_path = format!("/proc/{}/root", pid); - match self.copy_file_from_to_dest(path_path_str.clone(), dest_path.clone()) - { - Ok(_) => { - patch.path = None; - patch.url = Some("file:///var/run/elkeid-agent/rasp/".to_string()); - } - Err(e) => { - error!("copy patch failed: {}", e); - delete_index.push(index); - continue; - } - } - } - } - for index in delete_index.iter() { - patches.remove(*index); - } - } - } - Ok(valid_messages.clone()) - } - - pub fn send_message_to_probe( - &mut self, - pid: i32, - mnt_namespace: &String, - message: &String, - ) -> AnyhowResult<()> { - // try to write probe to dir - let nspid = ProcessInfo::read_nspid(pid)?.ok_or(anyhow!("can not fetch nspid: {}", pid))?; - debug!("send messages to probe: {} {} {}", pid, nspid, &message); - // send through sock - let mut messages: Vec = - serde_json::from_str(message)?; - let mut valid_messages: Vec = Vec::new(); - if messages.len() <= 0 { - for message_type in [6, 7, 8, 9, 12, 13, 14] { - messages.push(PidMissingProbeConfig { - message_type, - data: ProbeConfigData::empty(message_type)?, - }) - } - } - for m in messages.iter() { - if let Some(uuid) = &m.data.uuid { - if uuid == "" { - valid_messages.push(PidMissingProbeConfig { - message_type: m.message_type, - data: ProbeConfigData::empty(m.message_type)?, - }); - } else { - let _ = match serde_json::to_string(&m) { - Ok(s) => s, - Err(e) => { - warn!("failed to convert json to string: {:?} {}", m, e); - continue; - } - }; - valid_messages.push(m.clone()); - } - } - else { - let _ = match serde_json::to_string(&m) { - Ok(s) => s, - Err(e) => { - warn!("failed to convert json to string: {:?} {}", m, e); - continue; - } - }; - valid_messages.push(m.clone()); - } - } - // handle patches - let valid_messages = self.patch_message_handle(&mut valid_messages, pid)?; - for valid_m in valid_messages.iter() { - let m_string = match serde_json::to_string(&valid_m) { - Ok(s) => s, - Err(e) => { - warn!("failed to convert json to string: {:?} {}", valid_m, e); - continue; - } - }; - debug!("sending message: {}", m_string); - if let Some(comm) = self.thread_comm.as_mut() { - comm.send_message_to_probe(pid, mnt_namespace, &m_string)?; - } else if let Some(comm) = self.process_comm.as_mut() { - comm.send_message_to_probe(pid, mnt_namespace, &m_string)?; - } else { - return Err(anyhow!("both thread && process comm mode not init")); - } - } - - serde_json::to_string(&valid_messages)?; - //self.write_message_to_config_file(pid, nspid, valid_messages_string)?; - - Ok(()) - } -} - -pub const PROCESS_BALACK: &'static [&'static str] = &[ - // why don't care - "/usr/sbin", - "/lib/systemd", - "/pause", - "/bin", - "/sbin", -]; - -pub enum BPFSelect { - FORCE, - FIRST, - SECOND, - DISABLE, -} - -impl RASPManager { - // Inspect - pub fn inspect(&mut self, process_info: &ProcessInfo) -> Result<()> { - let exe_path = if let Some(p) = &process_info.exe_path { - p.clone() - } else { - return Err(anyhow!( - "missing exe path during inspect: {}", - process_info.pid - )); - }; - info!("process exe: {}", exe_path); - for proces_black_name in PROCESS_BALACK.iter() { - if exe_path.starts_with(proces_black_name) { - info!("process hit black list: {}", &proces_black_name); - return Err(anyhow!("inspecting process hit black list")); - } - } - Ok(()) - } - pub fn runtime_inspect(&mut self, process_info: &mut ProcessInfo) -> Result { - let runtime = ProcessInfo::inspect_from_process_info(process_info)?; - // update runtime - process_info.runtime = runtime; - Ok(true) - } - // Attach - pub fn attach(&mut self, process_info: &ProcessInfo, bpf: BPFSelect) -> Result<()> { - if process_info.runtime.is_none() { - let msg = "attaching to unknow runtime process"; - error!("{}", msg); - return Err(anyhow!(msg)); - } - let environ = match process_info.environ.clone() { - Some(e) => e, - None => return Err(anyhow!("can not fetch envrion {}", process_info.pid)), - }; - let namespace = process_info.namespace_info.as_ref().unwrap(); - let mnt_namespace = namespace.mnt.as_ref().unwrap(); - let runtime_info = &process_info.runtime.clone().unwrap(); - let root_dir = format!("/proc/{}/root", process_info.pid); - let pid = process_info.pid; - ProcessInfo::read_nspid(pid)?.ok_or(anyhow!("can not read nspid: {}", pid))?; - // delete config - // self.delete_config_file(pid, nspid)?; - let attach_result = match runtime_info.name { - "JVM" => match JVMProbeState::inspect_process(process_info)? { - ProbeState::Attached => { - info!("JVM attached process {}", pid); - Ok(true) - } - ProbeState::NotAttach => { - if self.can_copy(mnt_namespace) { - for from in JVMProbe::names().0.iter() { - self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; - } - for from in JVMProbe::names().1.iter() { - self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; - } - } - java_attach(process_info.pid) - } - ProbeState::AttachedVersionNotMatch => { - let mut diff_ns:bool = false; - match check_need_mount(mnt_namespace) { - Ok(value) => { - diff_ns = value; - if diff_ns { - let to = format!("{}{}",root_dir.clone(), settings::RASP_JAVA_AGENT_BIN()); - self.copy_file_from_to_dest(settings::RASP_JAVA_JATTACH_BIN(), root_dir.clone()); - self.copy_file_from_to_dest(settings::RASP_JAVA_AGENT_BIN(), root_dir.clone()); - info!("copy from jattach/SmithAgent.jar to {}", to.clone()); - } - } - Err(e) => { - warn!( - "check_need_mount failed, {}", e - ); - } - - } - - match java_detach(pid) { - Ok(result) => { - if self.can_copy(mnt_namespace) { - for from in JVMProbe::names().0.iter() { - self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; - } - for from in JVMProbe::names().1.iter() { - self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; - } - } - java_attach(pid) - } - Err(e) => { - //process_info.tracing_state = ProbeState::Attached; - Err(anyhow!(e)) - } - } - - } - }, - "CPython" => match CPythonProbeState::inspect_process(process_info)? { - ProbeState::Attached => { - info!("CPython attached process"); - Ok(true) - } - ProbeState::NotAttach => { - if self.can_copy(mnt_namespace) { - for from in CPythonProbe::names().0.iter() { - self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; - } - for from in CPythonProbe::names().1.iter() { - self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; - } - } - python_attach(process_info.pid) - } - ProbeState::AttachedVersionNotMatch => { - let msg = format!("not support CPython update version now"); - error!("{}", msg); - Err(anyhow!(msg)) - } - }, - "Golang" => match GolangProbeState::inspect_process(process_info)? { - ProbeState::Attached => { - info!("Golang attached process"); - Ok(true) - } - ProbeState::NotAttach => { - let mut golang_attach = |pid: i32, bpf: bool| -> AnyhowResult { - if bpf { - if let Some(bpf_manager) = self.ebpf_comm.as_mut() { - bpf_manager.attach(pid) - } else { - Err(anyhow!( - "FORCE BPF attach failed, golang ebpf daemon not running" - )) - } - } else { - if self.can_copy(mnt_namespace) { - for from in GolangProbe::names().0.iter() { - self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; - } - for from in GolangProbe::names().1.iter() { - self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; - } - } - golang_attach(pid) - } - }; - match bpf { - BPFSelect::FORCE => golang_attach(pid, true), - BPFSelect::DISABLE => golang_attach(pid, false), - BPFSelect::FIRST => { - let bpf_result = golang_attach(pid, true); - match bpf_result { - Ok(true) => Ok(true), - Ok(false) => { - warn!("FIRST BPF attach failed, trying golang attach"); - golang_attach(pid, false) - } - Err(e) => { - warn!("FIRST BPF attach failed: {}, trying golang attach", e); - golang_attach(pid, false) - } - } - } - BPFSelect::SECOND => { - let golang_attach_result = golang_attach(pid, false); - match golang_attach_result { - Ok(true) => Ok(true), - Ok(false) => { - warn!("golang attach failed, trying BPF attach"); - golang_attach(pid, true) - } - Err(e) => { - warn!("golang attach faild: {}, trying BPF attach", e); - golang_attach(pid, true) - } - } - } - } - } - ProbeState::AttachedVersionNotMatch => { - let msg = format!("not support Golang update version now"); - error!("{}", msg); - Err(anyhow!(msg)) - } - }, - "NodeJS" => { - if self.can_copy(mnt_namespace) { - for from in NodeJSProbe::names().0.iter() { - self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; - } - for from in NodeJSProbe::names().1.iter() { - self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; - } - } - - let process_exe_file = process_info - .exe_path - .clone() - .ok_or(anyhow!("process exe path not found: {}", pid))?; - nodejs_attach(pid, &environ, &process_exe_file) - } - "PHP" => match PHPProbeState::inspect_process(&process_info)? { - ProbeState::Attached => { - info!("PHP attached process"); - Ok(true) - } - ProbeState::NotAttach => php_attach(process_info, runtime_info.version.clone()), - ProbeState::AttachedVersionNotMatch => { - let msg = format!("not support PHP update version now"); - error!("{}", msg); - Err(anyhow!(msg)) - } - }, - _ => { - let msg = format!("can not attach to runtime: `{}`", runtime_info.name); - error!("{}", msg); - return Err(anyhow!(msg)); - } - }; - match attach_result { - Ok(success) => { - if !success { - let msg = format!("attach failed: {:?}", process_info); - error!("{}", msg); - Err(anyhow!(msg)) - } else { - Ok(()) - } - } - Err(e) => Err(anyhow!(e)), - } - } - - pub fn detach(&mut self, process_info: &ProcessInfo) -> Result<()> { - if let Some(runtime) = process_info.runtime.clone() { - if runtime.name != "JVM" { - let msg = "attaching to not support runtime process"; - error!("{}, runtime: {}", msg, runtime); - return Err(anyhow!(msg)); - } - } else { - let msg = "attaching to unknow runtime process"; - error!("{}", msg); - return Err(anyhow!(msg)); - } - match java_detach(process_info.pid) { - Ok(success) => { - if !success { - let msg = format!("detach failed: {:?}", process_info); - error!("{}", msg); - Err(anyhow!(msg)) - } else { - Ok(()) - } - } - Err(e) => Err(anyhow!(e)), - } - } -} - -impl RASPManager { - pub fn init( - comm_mode: &str, - log_level: String, - ctrl: Control, - message_sender: Sender, - bind_path: String, - linking_to: Option, - using_mount: bool, - ebpf_mode: BPFSelect, - ) -> AnyhowResult { - Self::clean_prev_lib()?; - let runtime_dir = match Self::create_elkeid_rasp_dir( - &String::from("/var/run/elkeid-agent"), - &String::from("/rasp/com/security/patch"), - ) { - Ok(_) => true, - Err(e) => { - warn!("create runtime dir failed, due to: {}", e); - false - } - }; - let ebpf_manager = |ebpf_mode: BPFSelect, ctrl: Control| -> Option { - match ebpf_mode { - BPFSelect::DISABLE => None, - _ => match EbpfMode::new(ctrl) { - Ok(mut em) => { - match em.start_server() { - Ok(_) => Some(em), - Err(e) => { - error!("start golang eBPF daemon failed: {}", e); - None - } - } - }, - Err(e) => { - error!("start golang eBPF daemon not support this machine: {}", e); - None - } - }, - } - }; - match comm_mode { - "thread" => Ok(RASPManager { - thread_comm: Some(ThreadMode::new( - log_level, - ctrl.clone(), - message_sender.clone(), - bind_path, - linking_to, - using_mount, - )?), - namespace_tracer: MntNamespaceTracer::new(), - process_comm: None, - ebpf_comm: ebpf_manager(ebpf_mode, ctrl), - runtime_dir, - }), - - "server" => Ok(RASPManager { - process_comm: Some(ProcessMode::new(log_level, ctrl.clone())), - namespace_tracer: MntNamespaceTracer::new(), - thread_comm: None, - ebpf_comm: ebpf_manager(ebpf_mode, ctrl), - runtime_dir, - }), - _ => Err(anyhow!("{} is not a vaild comm mode", comm_mode)), - } - } - - fn create_elkeid_rasp_dir( - agent_runtime_path: &String, - rasp_runtime_path: &String, - ) -> AnyhowResult<()> { - info!("create rasp runtime path: {}", rasp_runtime_path); - // dose Agent create `agent_runtime_path`? - if !Path::new(agent_runtime_path).exists() { - return Err(anyhow!( - "can not found agent runtime path: {}", - agent_runtime_path - )); - } - let rasp_runtime_path_full = format!("{}{}", agent_runtime_path, rasp_runtime_path); - let path = Path::new(&rasp_runtime_path_full); - if path.exists() { - return Ok(()); - } - match fs_extra::dir::create_all(&rasp_runtime_path_full, false) { - Ok(_) => {} - Err(e) => { - warn!("create dir failed: {} {}", rasp_runtime_path_full, e); - } - }; - if !path.exists() { - return Err(anyhow!( - "can not create rasp runtime dir: {}", - rasp_runtime_path_full - )); - } - Ok(()) - } - - fn clean_prev_lib() -> AnyhowResult<()> { - info!("cleaning previous lib dir"); - for entry in read_dir("./")? { - let filename = entry.file_name().to_string_lossy().to_string(); - if filename.contains("lib-") && !filename.contains(settings::RASP_VERSION) { - info!("remove perv libs: {}", filename); - fs_extra::dir::remove(format!("./{}", filename))? - } - } - Ok(()) - } - - pub fn copy_to_dest(&self, dest_root: String) -> Result<()> { - let cwd_path = std::env::current_dir()?; - let cwd = cwd_path.to_str().unwrap(); - debug!("current dir: {}", cwd); - // check namespace before copy - match create_all(format!("{}{}", dest_root, cwd), false) { - Ok(_) => {} - Err(e) => { - warn!("create failed: {:?}", e); - } - }; - let mut options = CopyOptions::new(); - options.overwrite = true; - return match copy( - format!("{}/lib", cwd), - format!("{}/{}/", dest_root, cwd), - &options, - ) { - Ok(_) => Ok(()), - Err(e) => { - warn!("can not copy: {}", e); - Err(anyhow!("copy failed: {}", dest_root)) - } - }; - } - pub fn create_dir_if_not_exist(&self, dir: String, dest_root: String) -> AnyhowResult<()> { - let target = format!("{}{}", dest_root, dir); - if Path::new(&target).exists() { - return Ok(()); - } - create_all(format!("{}{}", dest_root, dir), true)?; - Ok(()) - } - pub fn copy_file_from_to_dest(&self, from: String, dest_root: String) -> AnyhowResult<()> { - let target = format!("{}/{}", dest_root, from); - if Path::new(&target).exists() { - return Ok(()); - } - let dir = Path::new(&from).parent().unwrap(); - self.create_dir_if_not_exist(dir.to_str().unwrap().to_string(), dest_root.clone())?; - let options = FileCopyOptions::new(); - debug!( - "copy file: {} {}", - from.clone(), - format!("{}/{}", dest_root, from) - ); - return match file_copy(from.clone(), format!("{}/{}", dest_root, from), &options) { - Ok(_) => Ok(()), - Err(e) => { - warn!("can not copy: {}", e); - Err(anyhow!( - "copy failed: from {} to {}: {}", - from, - format!("{}/{}", dest_root, from), - e - )) - } - }; - } - - pub fn remove_dir_from_to_dest(dest_root: String) -> AnyhowResult<()> { - if Path::new(&dest_root).exists() { - return match std::fs::remove_dir_all(dest_root.clone()) { - Ok(_) => { - info!("remove file: {}", dest_root); - Ok(()) - } - Err(e) => { - warn!("can not remove: {}", e); - Err(anyhow!( - "remove failed: dir {}, err: {}", - dest_root.clone(), e)) - } - } - } - return Ok(()); - } - - pub fn copy_dir_from_to_dest(&self, from: String, dest_root: String) -> AnyhowResult<()> { - let target = format!("{}{}", dest_root, from); - if Path::new(&target).exists() { - return Ok(()); - } - let dir = Path::new(&from).parent().unwrap(); - self.create_dir_if_not_exist(dir.to_str().unwrap().to_string(), dest_root.clone())?; - let mut options = CopyOptions::new(); - options.copy_inside = true; - debug!( - "copy dir: {} {}", - from.clone(), - format!("{}/{}", dest_root, from) - ); - return match copy(from.clone(), format!("{}/{}", dest_root, from), &options) { - Ok(_) => Ok(()), - Err(e) => { - warn!("can not copy: {}", e); - Err(anyhow!( - "copy failed: from {} to {}: {}", - from, - format!("{}/{}", dest_root, from), - e - )) - } - }; - } - pub fn copy_to_target_dir(&self, pid: i32, mnt_namespace: &String) -> Result<()> { - // check namespace first - if let Some(tracing) = self.namespace_tracer.server_state(&mnt_namespace) { - if tracing { - return Ok(()); - } - } - let root_dir = format!("/proc/{}/root", pid); - self.copy_to_dest(root_dir) - } - pub fn can_copy(&self, _mnt_namespace: &String) -> bool { - // !self.namespace_tracer.server_state(&mnt_namespace).is_some() - true - } - pub fn root_dir(pid: i32) -> String { - format!("/proc/{}/root", pid) - } -} - -pub struct MntNamespaceTracer { - /// {: ([, ...], )} - tracer: HashMap, bool)>, -} - -impl MntNamespaceTracer { - pub fn new() -> Self { - Self { - tracer: HashMap::, bool)>::new(), - } - } - pub fn add(&mut self, mnt_namespace: String, pid: i32) { - if let Some(value) = self.tracer.get_mut(&mnt_namespace) { - if value.0.contains(&pid) { - warn!( - "trying insert duplicate pid in mnt_namespace hashmap: {} {}", - mnt_namespace, pid - ); - } else { - value.0.push(pid) - } - } else { - let mut new_pid_vec = Vec::new(); - new_pid_vec.push(pid); - self.tracer.insert(mnt_namespace, (new_pid_vec, false)); - } - } - - pub fn delete_namespace(&mut self, mnt_namespace: String) { - self.tracer.remove(&mnt_namespace); - } - - pub fn delete_pid(&mut self, mnt_namespace: String, pid: i32) { - if let Some(value) = self.tracer.get_mut(&mnt_namespace) { - let index = value.0.iter().position(|x| *x == pid); - if let Some(i) = index { - value.0.remove(i); - } - } - } - - pub fn server_state(&self, mnt_namespace: &String) -> Option { - if let Some(value) = self.tracer.get(mnt_namespace) { - return Some(value.1); - } - None - } - pub fn server_state_on(&mut self, mnt_namespace: String) { - if let Some(value) = self.tracer.get_mut(&mnt_namespace) { - value.1 = true - } - } - /// return boolean value for kill server process or not - pub fn server_state_off(&mut self, mnt_namespace: &String, pid: i32) -> bool { - if let Some(value) = self.tracer.get_mut(mnt_namespace) { - if value.0.contains(&pid) { - let index = value.0.iter().position(|x| *x == pid); - if let Some(i) = index { - value.0.remove(i); - } - } - if value.0.len() == 0 { - self.tracer.remove(mnt_namespace); - return true; - } - } - return false; - } -} - -impl RASPManager { - /* - pub fn write_message_to_config_file( - &self, - pid: i32, - nspid: i32, - message: String, - ) -> AnyhowResult<()> { - let config_dir = "/var/run/elkeid_rasp"; - let config_path = format!("{}/{}.json", config_dir, nspid); - let config_path_bak = format!("{}.bak", config_path); - debug!("write message to {} {}", config_path_bak, message); - crate::async_command::run_async_process( - Command::new(crate::settings::RASP_NS_ENTER_BIN()).args([ - "-m", - "-t", - pid.to_string().as_str(), - "sh", - "-c", - "PATH=/bin:/usr/bin:/sbin", - format!( - "mkdir -p {} && echo '{}' > {} && mv {} {}", - config_dir, message, config_path_bak, config_path_bak, config_path - ) - .as_str(), - ]), - )?; - let ns_thread = thread::Builder::new().spawn(move || -> AnyhowResult<()> { - debug!("switch namespace"); - libraspserver::ns::switch_namespace(pid); - if !Path::new(&config_dir).exists() { - fs_extra::dir::create(config_dir, true)?; - } - fs_extra::file::write_all(&config_path_bak, message.as_str())?; - let mut option = fs_extra::file::CopyOptions::new(); - option.overwrite = true; - fs_extra::file::move_file(config_path_bak, config_path, &option)?; - Ok(()) - }).unwrap(); - ns_thread.join()?; - - Ok(()) - } - - pub fn delete_config_file(&self, pid: i32, nspid: i32) -> AnyhowResult<()> { - let config_path = format!("/var/run/elkeid_rasp/{}.json", nspid); - if Path::new(&config_path).exists() { - crate::async_command::run_async_process( - Command::new(crate::settings::RASP_NS_ENTER_BIN()).args([ - "-m", - "-t", - pid.to_string().as_str(), - "sh", - "-c", - format!("rm {}", config_path).as_str(), - ]), - )?; - } - Ok(()) - } - */ -} - -fn read_dir

(path: P) -> AnyhowResult> -where - P: AsRef, -{ - fs::read_dir(&path) - .map_err(|err| anyhow!("Failed to read file '{:?}': {}", path.as_ref(), err))? - .map(|entry| { - entry.map_err(|err| anyhow!("Failed to read file '{:?}': {}", path.as_ref(), err)) - }) - .collect() -} - -#[cfg(test)] -mod tests { - use super::*; - use libraspserver::proto::ProbeConfigPatch; - - #[test] - fn patch_message() { - let fake_patch = ProbeConfigPatch { - class_name: "CVE202144228".to_string(), - url: Some("file:///var/run/elkeid_rasp/".into()), - path: Some("/run/elkeid_rasp/com/security/patch/CVE202144228".into()), - sum_hash: None, - }; - let mut fake_patches = Vec::new(); - fake_patches.push(fake_patch); - let mut fake_configs = Vec::new(); - fake_configs.push(PidMissingProbeConfig { - message_type: 9, - data: ProbeConfigData { - uuid: "fake".to_string(), - blocks: None, - filters: None, - limits: None, - patches: Some(fake_patches), - }, - }); - let fake_manager = RASPManager { - namespace_tracer: MntNamespaceTracer::new(), - thread_comm: None, - process_comm: None, - runtime_dir: false, - }; - println!("{:?}", fake_configs); - let _ = fake_manager - .patch_message_handle(&mut fake_configs, 35432) - .unwrap(); - let result = 2 + 2; - assert_eq!(result, 4); - } -} +use std::collections::HashMap; +use std::ffi::OsStr; +use std::fs; +use std::path::Path; + +use anyhow::{anyhow, Result, Result as AnyhowResult}; +use crossbeam::channel::Sender; +use fs_extra::dir::{copy, create_all, CopyOptions}; +use fs_extra::file::{copy as file_copy, CopyOptions as FileCopyOptions}; +use libraspserver::proto::{PidMissingProbeConfig, ProbeConfigData}; +use log::*; + +use crate::cpython::{python_attach, CPythonProbe, CPythonProbeState}; +use crate::golang::{golang_attach, GolangProbe, GolangProbeState}; +use crate::jvm::{copy_probe_nativelib,java_attach, java_detach, JVMProbe, JVMProbeState}; +use crate::nodejs::{nodejs_attach, NodeJSProbe}; +use crate::php::{php_attach, PHPProbeState}; +use crate::{ + comm::{Control, EbpfMode, ProcessMode, RASPComm, ThreadMode, check_need_mount}, + process::ProcessInfo, + runtime::{ProbeCopy, ProbeState, ProbeStateInspect, RuntimeInspect}, + settings, +}; + +pub struct RASPManager { + pub namespace_tracer: MntNamespaceTracer, + pub thread_comm: Option, + pub process_comm: Option, + pub ebpf_comm: Option, + pub runtime_dir: bool, +} + +impl RASPManager { + // comm + pub fn start_comm( + &mut self, + process_info: &ProcessInfo, + result_sender: Sender, + _server_log_level: String, + _server_ctrl: Control, + ) -> AnyhowResult<()> { + debug!("starting comm with probe, target pid: {}", process_info.pid); + let mnt_namespace = process_info.get_mnt_ns()?; + let nspid = if let Some(nspid) = ProcessInfo::read_nspid(process_info.pid)? { + nspid + } else { + process_info.pid + }; + if let Some(comm) = self.thread_comm.as_mut() { + comm.start_comm( + process_info.pid, + &mnt_namespace, + result_sender, + HashMap::new(), + )?; + } else if let Some(comm) = self.process_comm.as_mut() { + let mut patch_field = HashMap::new(); + let sid = process_info.sid; + let pid = process_info.pid; + let ppid = process_info.ppid; + let tgid = process_info.tgid; + let exe = process_info.exe_path.clone().unwrap_or("".to_string()); + let cmdline = process_info.cmdline.clone().unwrap_or("".to_string()); + patch_field.insert("sid", sid.to_string()); + patch_field.insert("pid", pid.to_string()); + patch_field.insert("nspid", nspid.to_string()); + patch_field.insert("ppid", ppid.to_string()); + patch_field.insert("tgid", tgid.to_string()); + patch_field.insert("argv", cmdline); + patch_field.insert("exe", exe); + patch_field.insert("ruid", process_info.ruid.to_string()); + patch_field.insert("rgid", process_info.rgid.to_string()); + patch_field.insert("euid", process_info.euid.to_string()); + patch_field.insert("egid", process_info.egid.to_string()); + patch_field.insert("suid", process_info.suid.to_string()); + patch_field.insert("sgid", process_info.sgid.to_string()); + patch_field.insert("fuid", process_info.fuid.to_string()); + patch_field.insert("fgid", process_info.fgid.to_string()); + debug!("update patch_field: {:?}", patch_field); + + // check reopen + if let Some(opened) = self.namespace_tracer.server_state(&mnt_namespace) { + if opened { + debug!("reusing stated server, mnt ns: {}", &mnt_namespace); + if let Some(runner) = comm.mnt_namesapce_server_map.get_mut(&mnt_namespace) { + runner.update_patch_field(patch_field); + } + return Ok(()); + } + } else { + comm.start_comm( + process_info.pid, + &mnt_namespace, + result_sender.clone(), + patch_field, + )?; + } + } else { + return Err(anyhow!("both thread && process comm mode not init")); + } + self.namespace_tracer + .add(mnt_namespace.clone(), process_info.pid); + self.namespace_tracer.server_state_on(mnt_namespace); + Ok(()) + } + + pub fn stop_comm(&mut self, process_info: &ProcessInfo) -> Result<()> { + let mnt_namespace = if let Some(ref ns) = process_info.namespace_info { + match ns.mnt.clone() { + Some(mnt_ns) => mnt_ns, + None => { + return Err(anyhow!("process mnt ns empty: {}", process_info.pid)); + } + } + } else { + return Err(anyhow!("fetch process ns failed: {}", process_info.pid)); + }; + let kill_check = self + .namespace_tracer + .server_state_off(&mnt_namespace, process_info.pid); + // kill server + if kill_check { + if let Some(comm) = self.thread_comm.as_mut() { + comm.stop_comm(process_info.pid, &mnt_namespace)?; + } + if let Some(comm) = self.process_comm.as_mut() { + comm.stop_comm(process_info.pid, &mnt_namespace)?; + } + } + Ok(()) + } + + pub fn patch_message_handle( + &self, + valid_messages: &mut Vec, + pid: i32, + ) -> AnyhowResult> { + for valid_m in valid_messages.iter_mut() { + if let Some(patches) = valid_m.data.patches.as_mut() { + let mut delete_index = Vec::new(); + for (index, patch) in patches.iter_mut().enumerate() { + if patch.path.is_none() { + delete_index.push(index); + continue; + } + if !self.runtime_dir { + warn!( + "due to missing runtime dir, patch ignored: {}", + patch.class_name + ); + delete_index.push(index); + continue; + } + let path_path_str = patch.path.clone().unwrap(); + let patch_path = Path::new(&path_path_str); + // check patch exist + if !patch_path.exists() { + delete_index.push(index); + continue; + } else { + let patch_file_name = patch_path + .file_name() + .unwrap_or(OsStr::new("")) + .to_string_lossy(); + if patch_file_name == "" { + delete_index.push(index); + continue; + } + let dest_path = format!("/proc/{}/root", pid); + match self.copy_file_from_to_dest(path_path_str.clone(), dest_path.clone()) + { + Ok(_) => { + patch.path = None; + patch.url = Some("file:///var/run/elkeid-agent/rasp/".to_string()); + } + Err(e) => { + error!("copy patch failed: {}", e); + delete_index.push(index); + continue; + } + } + } + } + for index in delete_index.iter() { + patches.remove(*index); + } + } + } + Ok(valid_messages.clone()) + } + + pub fn send_message_to_probe( + &mut self, + pid: i32, + mnt_namespace: &String, + message: &String, + ) -> AnyhowResult<()> { + // try to write probe to dir + let nspid = ProcessInfo::read_nspid(pid)?.ok_or(anyhow!("can not fetch nspid: {}", pid))?; + debug!("send messages to probe: {} {} {}", pid, nspid, &message); + // send through sock + let mut messages: Vec = + serde_json::from_str(message)?; + let mut valid_messages: Vec = Vec::new(); + if messages.len() <= 0 { + for message_type in [6, 7, 8, 9, 12, 13, 14] { + messages.push(PidMissingProbeConfig { + message_type, + data: ProbeConfigData::empty(message_type)?, + }) + } + } + for m in messages.iter() { + if let Some(uuid) = &m.data.uuid { + if uuid == "" { + valid_messages.push(PidMissingProbeConfig { + message_type: m.message_type, + data: ProbeConfigData::empty(m.message_type)?, + }); + } else { + let _ = match serde_json::to_string(&m) { + Ok(s) => s, + Err(e) => { + warn!("failed to convert json to string: {:?} {}", m, e); + continue; + } + }; + valid_messages.push(m.clone()); + } + } + else { + let _ = match serde_json::to_string(&m) { + Ok(s) => s, + Err(e) => { + warn!("failed to convert json to string: {:?} {}", m, e); + continue; + } + }; + valid_messages.push(m.clone()); + } + } + // handle patches + let valid_messages = self.patch_message_handle(&mut valid_messages, pid)?; + for valid_m in valid_messages.iter() { + let m_string = match serde_json::to_string(&valid_m) { + Ok(s) => s, + Err(e) => { + warn!("failed to convert json to string: {:?} {}", valid_m, e); + continue; + } + }; + debug!("sending message: {}", m_string); + if let Some(comm) = self.thread_comm.as_mut() { + comm.send_message_to_probe(pid, mnt_namespace, &m_string)?; + } else if let Some(comm) = self.process_comm.as_mut() { + comm.send_message_to_probe(pid, mnt_namespace, &m_string)?; + } else { + return Err(anyhow!("both thread && process comm mode not init")); + } + } + + serde_json::to_string(&valid_messages)?; + //self.write_message_to_config_file(pid, nspid, valid_messages_string)?; + + Ok(()) + } +} + +pub const PROCESS_BALACK: &'static [&'static str] = &[ + // why don't care + "/usr/sbin", + "/lib/systemd", + "/pause", + "/bin", + "/sbin", +]; + +pub enum BPFSelect { + FORCE, + FIRST, + SECOND, + DISABLE, +} + +impl RASPManager { + // Inspect + pub fn inspect(&mut self, process_info: &ProcessInfo) -> Result<()> { + let exe_path = if let Some(p) = &process_info.exe_path { + p.clone() + } else { + return Err(anyhow!( + "missing exe path during inspect: {}", + process_info.pid + )); + }; + info!("process exe: {}", exe_path); + for proces_black_name in PROCESS_BALACK.iter() { + if exe_path.starts_with(proces_black_name) { + info!("process hit black list: {}", &proces_black_name); + return Err(anyhow!("inspecting process hit black list")); + } + } + Ok(()) + } + pub fn runtime_inspect(&mut self, process_info: &mut ProcessInfo) -> Result { + let runtime = ProcessInfo::inspect_from_process_info(process_info)?; + // update runtime + process_info.runtime = runtime; + Ok(true) + } + // Attach + pub fn attach(&mut self, process_info: &ProcessInfo, bpf: BPFSelect) -> Result<()> { + if process_info.runtime.is_none() { + let msg = "attaching to unknow runtime process"; + error!("{}", msg); + return Err(anyhow!(msg)); + } + let environ = match process_info.environ.clone() { + Some(e) => e, + None => return Err(anyhow!("can not fetch envrion {}", process_info.pid)), + }; + let namespace = process_info.namespace_info.as_ref().unwrap(); + let mnt_namespace = namespace.mnt.as_ref().unwrap(); + let runtime_info = &process_info.runtime.clone().unwrap(); + let root_dir = format!("/proc/{}/root", process_info.pid); + let pid = process_info.pid; + ProcessInfo::read_nspid(pid)?.ok_or(anyhow!("can not read nspid: {}", pid))?; + // delete config + // self.delete_config_file(pid, nspid)?; + let attach_result = match runtime_info.name { + "JVM" => match JVMProbeState::inspect_process(process_info)? { + ProbeState::Attached => { + info!("JVM attached process {}", pid); + Ok(true) + } + ProbeState::NotAttach => { + if self.can_copy(mnt_namespace) { + for from in JVMProbe::names().0.iter() { + self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; + } + for from in JVMProbe::names().1.iter() { + self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; + } + } + copy_probe_nativelib(process_info.pid,root_dir.clone())?; + + java_attach(process_info.pid) + } + ProbeState::AttachedVersionNotMatch => { + let mut diff_ns:bool = false; + match check_need_mount(mnt_namespace) { + Ok(value) => { + diff_ns = value; + if diff_ns { + let to = format!("{}{}",root_dir.clone(), settings::RASP_JAVA_AGENT_BIN()); + self.copy_file_from_to_dest(settings::RASP_JAVA_JATTACH_BIN(), root_dir.clone()); + self.copy_file_from_to_dest(settings::RASP_JAVA_AGENT_BIN(), root_dir.clone()); + info!("copy from jattach/SmithAgent.jar to {}", to.clone()); + } + } + Err(e) => { + warn!( + "check_need_mount failed, {}", e + ); + } + + } + + match java_detach(pid) { + Ok(result) => { + if self.can_copy(mnt_namespace) { + for from in JVMProbe::names().0.iter() { + self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; + } + for from in JVMProbe::names().1.iter() { + self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; + } + } + copy_probe_nativelib(process_info.pid,root_dir.clone())?; + java_attach(pid) + } + Err(e) => { + //process_info.tracing_state = ProbeState::Attached; + Err(anyhow!(e)) + } + } + + } + }, + "CPython" => match CPythonProbeState::inspect_process(process_info)? { + ProbeState::Attached => { + info!("CPython attached process"); + Ok(true) + } + ProbeState::NotAttach => { + if self.can_copy(mnt_namespace) { + for from in CPythonProbe::names().0.iter() { + self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; + } + for from in CPythonProbe::names().1.iter() { + self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; + } + } + python_attach(process_info.pid) + } + ProbeState::AttachedVersionNotMatch => { + let msg = format!("not support CPython update version now"); + error!("{}", msg); + Err(anyhow!(msg)) + } + }, + "Golang" => match GolangProbeState::inspect_process(process_info)? { + ProbeState::Attached => { + info!("Golang attached process"); + Ok(true) + } + ProbeState::NotAttach => { + let mut golang_attach = |pid: i32, bpf: bool| -> AnyhowResult { + if bpf { + if let Some(bpf_manager) = self.ebpf_comm.as_mut() { + bpf_manager.attach(pid) + } else { + Err(anyhow!( + "FORCE BPF attach failed, golang ebpf daemon not running" + )) + } + } else { + if self.can_copy(mnt_namespace) { + for from in GolangProbe::names().0.iter() { + self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; + } + for from in GolangProbe::names().1.iter() { + self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; + } + } + golang_attach(pid) + } + }; + match bpf { + BPFSelect::FORCE => golang_attach(pid, true), + BPFSelect::DISABLE => golang_attach(pid, false), + BPFSelect::FIRST => { + let bpf_result = golang_attach(pid, true); + match bpf_result { + Ok(true) => Ok(true), + Ok(false) => { + warn!("FIRST BPF attach failed, trying golang attach"); + golang_attach(pid, false) + } + Err(e) => { + warn!("FIRST BPF attach failed: {}, trying golang attach", e); + golang_attach(pid, false) + } + } + } + BPFSelect::SECOND => { + let golang_attach_result = golang_attach(pid, false); + match golang_attach_result { + Ok(true) => Ok(true), + Ok(false) => { + warn!("golang attach failed, trying BPF attach"); + golang_attach(pid, true) + } + Err(e) => { + warn!("golang attach faild: {}, trying BPF attach", e); + golang_attach(pid, true) + } + } + } + } + } + ProbeState::AttachedVersionNotMatch => { + let msg = format!("not support Golang update version now"); + error!("{}", msg); + Err(anyhow!(msg)) + } + }, + "NodeJS" => { + if self.can_copy(mnt_namespace) { + for from in NodeJSProbe::names().0.iter() { + self.copy_file_from_to_dest(from.clone(), root_dir.clone())?; + } + for from in NodeJSProbe::names().1.iter() { + self.copy_dir_from_to_dest(from.clone(), root_dir.clone())?; + } + } + + let process_exe_file = process_info + .exe_path + .clone() + .ok_or(anyhow!("process exe path not found: {}", pid))?; + nodejs_attach(pid, &environ, &process_exe_file) + } + "PHP" => match PHPProbeState::inspect_process(&process_info)? { + ProbeState::Attached => { + info!("PHP attached process"); + Ok(true) + } + ProbeState::NotAttach => php_attach(process_info, runtime_info.version.clone()), + ProbeState::AttachedVersionNotMatch => { + let msg = format!("not support PHP update version now"); + error!("{}", msg); + Err(anyhow!(msg)) + } + }, + _ => { + let msg = format!("can not attach to runtime: `{}`", runtime_info.name); + error!("{}", msg); + return Err(anyhow!(msg)); + } + }; + match attach_result { + Ok(success) => { + if !success { + let msg = format!("attach failed: {:?}", process_info); + error!("{}", msg); + Err(anyhow!(msg)) + } else { + Ok(()) + } + } + Err(e) => Err(anyhow!(e)), + } + } + + pub fn detach(&mut self, process_info: &ProcessInfo) -> Result<()> { + if let Some(runtime) = process_info.runtime.clone() { + if runtime.name != "JVM" { + let msg = "attaching to not support runtime process"; + error!("{}, runtime: {}", msg, runtime); + return Err(anyhow!(msg)); + } + } else { + let msg = "attaching to unknow runtime process"; + error!("{}", msg); + return Err(anyhow!(msg)); + } + match java_detach(process_info.pid) { + Ok(success) => { + if !success { + let msg = format!("detach failed: {:?}", process_info); + error!("{}", msg); + Err(anyhow!(msg)) + } else { + Ok(()) + } + } + Err(e) => Err(anyhow!(e)), + } + } +} + +impl RASPManager { + pub fn init( + comm_mode: &str, + log_level: String, + ctrl: Control, + message_sender: Sender, + bind_path: String, + linking_to: Option, + using_mount: bool, + ebpf_mode: BPFSelect, + ) -> AnyhowResult { + Self::clean_prev_lib()?; + let runtime_dir = match Self::create_elkeid_rasp_dir( + &String::from("/var/run/elkeid-agent"), + &String::from("/rasp/com/security/patch"), + ) { + Ok(_) => true, + Err(e) => { + warn!("create runtime dir failed, due to: {}", e); + false + } + }; + let ebpf_manager = |ebpf_mode: BPFSelect, ctrl: Control| -> Option { + match ebpf_mode { + BPFSelect::DISABLE => None, + _ => match EbpfMode::new(ctrl) { + Ok(mut em) => { + match em.start_server() { + Ok(_) => Some(em), + Err(e) => { + error!("start golang eBPF daemon failed: {}", e); + None + } + } + }, + Err(e) => { + error!("start golang eBPF daemon not support this machine: {}", e); + None + } + }, + } + }; + match comm_mode { + "thread" => Ok(RASPManager { + thread_comm: Some(ThreadMode::new( + log_level, + ctrl.clone(), + message_sender.clone(), + bind_path, + linking_to, + using_mount, + )?), + namespace_tracer: MntNamespaceTracer::new(), + process_comm: None, + ebpf_comm: ebpf_manager(ebpf_mode, ctrl), + runtime_dir, + }), + + "server" => Ok(RASPManager { + process_comm: Some(ProcessMode::new(log_level, ctrl.clone())), + namespace_tracer: MntNamespaceTracer::new(), + thread_comm: None, + ebpf_comm: ebpf_manager(ebpf_mode, ctrl), + runtime_dir, + }), + _ => Err(anyhow!("{} is not a vaild comm mode", comm_mode)), + } + } + + fn create_elkeid_rasp_dir( + agent_runtime_path: &String, + rasp_runtime_path: &String, + ) -> AnyhowResult<()> { + info!("create rasp runtime path: {}", rasp_runtime_path); + // dose Agent create `agent_runtime_path`? + if !Path::new(agent_runtime_path).exists() { + return Err(anyhow!( + "can not found agent runtime path: {}", + agent_runtime_path + )); + } + let rasp_runtime_path_full = format!("{}{}", agent_runtime_path, rasp_runtime_path); + let path = Path::new(&rasp_runtime_path_full); + if path.exists() { + return Ok(()); + } + match fs_extra::dir::create_all(&rasp_runtime_path_full, false) { + Ok(_) => {} + Err(e) => { + warn!("create dir failed: {} {}", rasp_runtime_path_full, e); + } + }; + if !path.exists() { + return Err(anyhow!( + "can not create rasp runtime dir: {}", + rasp_runtime_path_full + )); + } + Ok(()) + } + + fn clean_prev_lib() -> AnyhowResult<()> { + info!("cleaning previous lib dir"); + for entry in read_dir("./")? { + let filename = entry.file_name().to_string_lossy().to_string(); + if filename.contains("lib-") && !filename.contains(settings::RASP_VERSION) { + info!("remove perv libs: {}", filename); + fs_extra::dir::remove(format!("./{}", filename))? + } + } + Ok(()) + } + + pub fn copy_to_dest(&self, dest_root: String) -> Result<()> { + let cwd_path = std::env::current_dir()?; + let cwd = cwd_path.to_str().unwrap(); + debug!("current dir: {}", cwd); + // check namespace before copy + match create_all(format!("{}{}", dest_root, cwd), false) { + Ok(_) => {} + Err(e) => { + warn!("create failed: {:?}", e); + } + }; + let mut options = CopyOptions::new(); + options.overwrite = true; + return match copy( + format!("{}/lib", cwd), + format!("{}/{}/", dest_root, cwd), + &options, + ) { + Ok(_) => Ok(()), + Err(e) => { + warn!("can not copy: {}", e); + Err(anyhow!("copy failed: {}", dest_root)) + } + }; + } + pub fn create_dir_if_not_exist(&self, dir: String, dest_root: String) -> AnyhowResult<()> { + let target = format!("{}{}", dest_root, dir); + if Path::new(&target).exists() { + return Ok(()); + } + create_all(format!("{}{}", dest_root, dir), true)?; + Ok(()) + } + pub fn copy_file_from_to_dest(&self, from: String, dest_root: String) -> AnyhowResult<()> { + let target = format!("{}/{}", dest_root, from); + if Path::new(&target).exists() { + return Ok(()); + } + let dir = Path::new(&from).parent().unwrap(); + self.create_dir_if_not_exist(dir.to_str().unwrap().to_string(), dest_root.clone())?; + let options = FileCopyOptions::new(); + debug!( + "copy file: {} {}", + from.clone(), + format!("{}/{}", dest_root, from) + ); + return match file_copy(from.clone(), format!("{}/{}", dest_root, from), &options) { + Ok(_) => Ok(()), + Err(e) => { + warn!("can not copy: {}", e); + Err(anyhow!( + "copy failed: from {} to {}: {}", + from, + format!("{}/{}", dest_root, from), + e + )) + } + }; + } + + pub fn remove_dir_from_to_dest(dest_root: String) -> AnyhowResult<()> { + if Path::new(&dest_root).exists() { + return match std::fs::remove_dir_all(dest_root.clone()) { + Ok(_) => { + info!("remove file: {}", dest_root); + Ok(()) + } + Err(e) => { + warn!("can not remove: {}", e); + Err(anyhow!( + "remove failed: dir {}, err: {}", + dest_root.clone(), e)) + } + } + } + return Ok(()); + } + + pub fn copy_dir_from_to_dest(&self, from: String, dest_root: String) -> AnyhowResult<()> { + let target = format!("{}{}", dest_root, from); + if Path::new(&target).exists() { + return Ok(()); + } + let dir = Path::new(&from).parent().unwrap(); + self.create_dir_if_not_exist(dir.to_str().unwrap().to_string(), dest_root.clone())?; + let mut options = CopyOptions::new(); + options.copy_inside = true; + debug!( + "copy dir: {} {}", + from.clone(), + format!("{}/{}", dest_root, from) + ); + return match copy(from.clone(), format!("{}/{}", dest_root, from), &options) { + Ok(_) => Ok(()), + Err(e) => { + warn!("can not copy: {}", e); + Err(anyhow!( + "copy failed: from {} to {}: {}", + from, + format!("{}/{}", dest_root, from), + e + )) + } + }; + } + pub fn copy_to_target_dir(&self, pid: i32, mnt_namespace: &String) -> Result<()> { + // check namespace first + if let Some(tracing) = self.namespace_tracer.server_state(&mnt_namespace) { + if tracing { + return Ok(()); + } + } + let root_dir = format!("/proc/{}/root", pid); + self.copy_to_dest(root_dir) + } + pub fn can_copy(&self, _mnt_namespace: &String) -> bool { + // !self.namespace_tracer.server_state(&mnt_namespace).is_some() + true + } + pub fn root_dir(pid: i32) -> String { + format!("/proc/{}/root", pid) + } +} + +pub struct MntNamespaceTracer { + /// {: ([, ...], )} + tracer: HashMap, bool)>, +} + +impl MntNamespaceTracer { + pub fn new() -> Self { + Self { + tracer: HashMap::, bool)>::new(), + } + } + pub fn add(&mut self, mnt_namespace: String, pid: i32) { + if let Some(value) = self.tracer.get_mut(&mnt_namespace) { + if value.0.contains(&pid) { + warn!( + "trying insert duplicate pid in mnt_namespace hashmap: {} {}", + mnt_namespace, pid + ); + } else { + value.0.push(pid) + } + } else { + let mut new_pid_vec = Vec::new(); + new_pid_vec.push(pid); + self.tracer.insert(mnt_namespace, (new_pid_vec, false)); + } + } + + pub fn delete_namespace(&mut self, mnt_namespace: String) { + self.tracer.remove(&mnt_namespace); + } + + pub fn delete_pid(&mut self, mnt_namespace: String, pid: i32) { + if let Some(value) = self.tracer.get_mut(&mnt_namespace) { + let index = value.0.iter().position(|x| *x == pid); + if let Some(i) = index { + value.0.remove(i); + } + } + } + + pub fn server_state(&self, mnt_namespace: &String) -> Option { + if let Some(value) = self.tracer.get(mnt_namespace) { + return Some(value.1); + } + None + } + pub fn server_state_on(&mut self, mnt_namespace: String) { + if let Some(value) = self.tracer.get_mut(&mnt_namespace) { + value.1 = true + } + } + /// return boolean value for kill server process or not + pub fn server_state_off(&mut self, mnt_namespace: &String, pid: i32) -> bool { + if let Some(value) = self.tracer.get_mut(mnt_namespace) { + if value.0.contains(&pid) { + let index = value.0.iter().position(|x| *x == pid); + if let Some(i) = index { + value.0.remove(i); + } + } + if value.0.len() == 0 { + self.tracer.remove(mnt_namespace); + return true; + } + } + return false; + } +} + +impl RASPManager { + /* + pub fn write_message_to_config_file( + &self, + pid: i32, + nspid: i32, + message: String, + ) -> AnyhowResult<()> { + let config_dir = "/var/run/elkeid_rasp"; + let config_path = format!("{}/{}.json", config_dir, nspid); + let config_path_bak = format!("{}.bak", config_path); + debug!("write message to {} {}", config_path_bak, message); + crate::async_command::run_async_process( + Command::new(crate::settings::RASP_NS_ENTER_BIN()).args([ + "-m", + "-t", + pid.to_string().as_str(), + "sh", + "-c", + "PATH=/bin:/usr/bin:/sbin", + format!( + "mkdir -p {} && echo '{}' > {} && mv {} {}", + config_dir, message, config_path_bak, config_path_bak, config_path + ) + .as_str(), + ]), + )?; + let ns_thread = thread::Builder::new().spawn(move || -> AnyhowResult<()> { + debug!("switch namespace"); + libraspserver::ns::switch_namespace(pid); + if !Path::new(&config_dir).exists() { + fs_extra::dir::create(config_dir, true)?; + } + fs_extra::file::write_all(&config_path_bak, message.as_str())?; + let mut option = fs_extra::file::CopyOptions::new(); + option.overwrite = true; + fs_extra::file::move_file(config_path_bak, config_path, &option)?; + Ok(()) + }).unwrap(); + ns_thread.join()?; + + Ok(()) + } + + pub fn delete_config_file(&self, pid: i32, nspid: i32) -> AnyhowResult<()> { + let config_path = format!("/var/run/elkeid_rasp/{}.json", nspid); + if Path::new(&config_path).exists() { + crate::async_command::run_async_process( + Command::new(crate::settings::RASP_NS_ENTER_BIN()).args([ + "-m", + "-t", + pid.to_string().as_str(), + "sh", + "-c", + format!("rm {}", config_path).as_str(), + ]), + )?; + } + Ok(()) + } + */ +} + +fn read_dir

(path: P) -> AnyhowResult> +where + P: AsRef, +{ + fs::read_dir(&path) + .map_err(|err| anyhow!("Failed to read file '{:?}': {}", path.as_ref(), err))? + .map(|entry| { + entry.map_err(|err| anyhow!("Failed to read file '{:?}': {}", path.as_ref(), err)) + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use libraspserver::proto::ProbeConfigPatch; + + #[test] + fn patch_message() { + let fake_patch = ProbeConfigPatch { + class_name: "CVE202144228".to_string(), + url: Some("file:///var/run/elkeid_rasp/".into()), + path: Some("/run/elkeid_rasp/com/security/patch/CVE202144228".into()), + sum_hash: None, + }; + let mut fake_patches = Vec::new(); + fake_patches.push(fake_patch); + let mut fake_configs = Vec::new(); + fake_configs.push(PidMissingProbeConfig { + message_type: 9, + data: ProbeConfigData { + uuid: "fake".to_string(), + blocks: None, + filters: None, + limits: None, + patches: Some(fake_patches), + }, + }); + let fake_manager = RASPManager { + namespace_tracer: MntNamespaceTracer::new(), + thread_comm: None, + process_comm: None, + runtime_dir: false, + }; + println!("{:?}", fake_configs); + let _ = fake_manager + .patch_message_handle(&mut fake_configs, 35432) + .unwrap(); + let result = 2 + 2; + assert_eq!(result, 4); + } +} diff --git a/rasp/librasp/src/settings.rs b/rasp/librasp/src/settings.rs index 19215f842..229b04a12 100644 --- a/rasp/librasp/src/settings.rs +++ b/rasp/librasp/src/settings.rs @@ -1,136 +1,166 @@ -use std::env::current_dir; -// use anyhow::{Result as AnyhowResult, anyhow}; - -pub const RASP_VERSION: &'static str = "1.0.0.1"; - -pub fn RASP_BASE_DIR() -> String { - match current_dir() { - Ok(p) => match p.to_str() { - Some(p) => String::from(p), - None => { - panic!("can not fetch current dir from env"); - } - }, - Err(e) => panic!("can not fetch current dir from env: {}", e), - } -} - -pub fn RASP_NS_ENTER_BIN() -> String { - format!("{}{}", RASP_BASE_DIR(), "/nsenter") -} -pub fn RASP_MOUNT_SCRIPT_BIN() -> String { - format!("{}{}", RASP_BASE_DIR(), "/NSMount") -} - -pub fn RASP_LIB_DIR() -> String { - format!("{}{}", RASP_BASE_DIR(), format!("/lib-{}", RASP_VERSION)) -} - -pub fn RASP_SERVER_BIN() -> String { - format!("{}{}", RASP_LIB_DIR(), "/rasp_server") -} - -pub fn RASP_PANGOLIN() -> String { - format!("{}{}", RASP_LIB_DIR(), "/pangolin") -} - -// Golang - -pub fn RASP_GOLANG() -> String { - format!("{}{}", RASP_LIB_DIR(), "/golang/go_probe") -} -pub fn RASP_GOLANG_EBPF(version: &String) -> String { - format!("{}/{}{}", RASP_LIB_DIR(), "/golang/go_probe_ebpf", version) -} - -// Python -pub fn RASP_PYTHON_LOADER() -> String { - format!("{}{}", RASP_LIB_DIR(), "/python/python_loader") -} - -pub fn RASP_PYTHON_DIR() -> String { - format!("{}{}", RASP_LIB_DIR(), "/python/rasp") -} - -pub fn RASP_PYTHON_ENTRY() -> String { - format!("{}{}", RASP_LIB_DIR(), "/python/entry.py") -} - -// JAVA -pub fn RASP_JAVA_JATTACH_BIN() -> String { - format!("{}{}", RASP_LIB_DIR(), "/java/jattach") -} - -pub fn RASP_JAVA_PROBE_BIN() -> String { - format!("{}{}", RASP_LIB_DIR(), "/java/SmithProbe.jar") -} - -pub fn RASP_JAVA_CHECKSUM_PATH() -> String { - format!("{}{}", RASP_LIB_DIR(), "/java/checksum.data") -} - -pub fn RASP_JAVA_AGENT_BIN() -> String { - format!("{}{}", RASP_LIB_DIR(), "/java/SmithAgent.jar") -} - -pub fn RASP_JAVA_DIR() -> String { - format!("{}{}", RASP_LIB_DIR(), "/java") -} -// NodeJS - -pub fn RASP_NODEJS_DIR() -> String { - format!("{}{}", RASP_LIB_DIR(), "/node") -} - -pub fn RASP_NODEJS_INJECTOR() -> String { - format!("{}{}", RASP_NODEJS_DIR(), "/injector.js") -} - -pub fn RASP_NODEJS_ENTRY() -> String { - format!("{}{}", RASP_NODEJS_DIR(), "/smith") -} - -#[allow(non_snake_case)] -pub fn RASP_PHP_PROBE(major: &str, miner: &str, zts: bool) -> Option<(String, String)> { - if match major { - "5" => match miner { - "3" => true, - "4" => true, - "5" => true, - "6" => true, - _ => false, - }, - "7" => match miner { - "0" => true, - "2" => true, - "3" => true, - "4" => true, - _ => false, - }, - "8" => match miner { - "0" => true, - "1" => true, - _ => false, - }, - _ => false, - } { - if zts { - Some(( - format!( - "{}/php/libphp_probe-{}.{}-zts.so", - RASP_LIB_DIR(), - major, - miner - ), - format!("libphp_probe-{}.{}-zts.so", major, miner), - )) - } else { - Some(( - format!("{}/php/libphp_probe-{}.{}.so", RASP_LIB_DIR(), major, miner), - format!("libphp_probe-{}.{}.so", major, miner), - )) - } - } else { - None - } -} +use std::env::current_dir; +// use anyhow::{Result as AnyhowResult, anyhow}; + +pub const RASP_VERSION: &'static str = "1.0.0.1"; + +pub fn RASP_BASE_DIR() -> String { + match current_dir() { + Ok(p) => match p.to_str() { + Some(p) => String::from(p), + None => { + panic!("can not fetch current dir from env"); + } + }, + Err(e) => panic!("can not fetch current dir from env: {}", e), + } +} + +pub fn RASP_NS_ENTER_BIN() -> String { + format!("{}{}", RASP_BASE_DIR(), "/nsenter") +} +pub fn RASP_MOUNT_SCRIPT_BIN() -> String { + format!("{}{}", RASP_BASE_DIR(), "/NSMount") +} + +pub fn RASP_LIB_DIR() -> String { + format!("{}{}", RASP_BASE_DIR(), format!("/lib-{}", RASP_VERSION)) +} + +pub fn RASP_SERVER_BIN() -> String { + format!("{}{}", RASP_LIB_DIR(), "/rasp_server") +} + +pub fn RASP_PANGOLIN() -> String { + format!("{}{}", RASP_LIB_DIR(), "/pangolin") +} + +// Golang + +pub fn RASP_GOLANG() -> String { + format!("{}{}", RASP_LIB_DIR(), "/golang/go_probe") +} +pub fn RASP_GOLANG_EBPF(version: &String) -> String { + format!("{}/{}{}", RASP_LIB_DIR(), "/golang/go_probe_ebpf", version) +} + +// Python +pub fn RASP_PYTHON_LOADER() -> String { + format!("{}{}", RASP_LIB_DIR(), "/python/python_loader") +} + +pub fn RASP_PYTHON_DIR() -> String { + format!("{}{}", RASP_LIB_DIR(), "/python/rasp") +} + +pub fn RASP_PYTHON_ENTRY() -> String { + format!("{}{}", RASP_LIB_DIR(), "/python/entry.py") +} + +// JAVA +pub fn RASP_JAVA_JATTACH_BIN() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/jattach") +} + +pub fn RASP_JAVA_PROBE_BIN() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/SmithProbe.jar") +} + +#[cfg(all(target_os = "linux", target_arch = "x86_64"))] +pub fn RASP_JAVA_NETTY_EPOLL_SO() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_transport_native_epoll_x86_64.so") +} + +#[cfg(all(target_os = "linux", target_arch = "aarch64"))] +pub fn RASP_JAVA_NETTY_EPOLL_SO() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_transport_native_epoll_aarch_64.so") +} + +#[cfg(all(target_os = "macos", target_arch = "x86_64"))] +pub fn RASP_JAVA_NETTY_KQUEUQ_SO_MAC() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_transport_native_kqueue_x86_64.jnilib") +} + +#[cfg(all(target_os = "macos", target_arch = "x86_64"))] +pub fn RASP_JAVA_NETTY_DNS_SO_MAC() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_resolver_dns_native_macos_x86_64.jnilib") +} + +#[cfg(all(target_os = "macos", target_arch = "aarch64"))] +pub fn RASP_JAVA_NETTY_KQUEUQ_SO_MAC() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_transport_native_kqueue_aarch_64.jnilib") +} + +#[cfg(all(target_os = "macos", target_arch = "aarch64"))] +pub fn RASP_JAVA_NETTY_DNS_SO_MAC() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/nativelib/librasp_netty_resolver_dns_native_macos_aarch_64.jnilib") +} + +pub fn RASP_JAVA_CHECKSUM_PATH() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/checksum.data") +} + +pub fn RASP_JAVA_AGENT_BIN() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java/SmithAgent.jar") +} + +pub fn RASP_JAVA_DIR() -> String { + format!("{}{}", RASP_LIB_DIR(), "/java") +} + +// NodeJS +pub fn RASP_NODEJS_DIR() -> String { + format!("{}{}", RASP_LIB_DIR(), "/node") +} + +pub fn RASP_NODEJS_INJECTOR() -> String { + format!("{}{}", RASP_NODEJS_DIR(), "/injector.js") +} + +pub fn RASP_NODEJS_ENTRY() -> String { + format!("{}{}", RASP_NODEJS_DIR(), "/smith") +} + +#[allow(non_snake_case)] +pub fn RASP_PHP_PROBE(major: &str, miner: &str, zts: bool) -> Option<(String, String)> { + if match major { + "5" => match miner { + "3" => true, + "4" => true, + "5" => true, + "6" => true, + _ => false, + }, + "7" => match miner { + "0" => true, + "2" => true, + "3" => true, + "4" => true, + _ => false, + }, + "8" => match miner { + "0" => true, + "1" => true, + _ => false, + }, + _ => false, + } { + if zts { + Some(( + format!( + "{}/php/libphp_probe-{}.{}-zts.so", + RASP_LIB_DIR(), + major, + miner + ), + format!("libphp_probe-{}.{}-zts.so", major, miner), + )) + } else { + Some(( + format!("{}/php/libphp_probe-{}.{}.so", RASP_LIB_DIR(), major, miner), + format!("libphp_probe-{}.{}.so", major, miner), + )) + } + } else { + None + } +}