diff --git a/src/main/java/org/owasp/webgoat/plugins/Plugin.java b/src/main/java/org/owasp/webgoat/plugins/Plugin.java index b0f58df7e..4752b2ce1 100644 --- a/src/main/java/org/owasp/webgoat/plugins/Plugin.java +++ b/src/main/java/org/owasp/webgoat/plugins/Plugin.java @@ -50,27 +50,24 @@ public class Plugin { } public void loadClasses(Map classes) { + ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); + PluginClassLoader pluginClassLoader = new PluginClassLoader(contextClassLoader); for (Map.Entry clazz : classes.entrySet()) { - loadClass(clazz.getKey(), clazz.getValue()); + loadClass(pluginClassLoader, clazz.getKey(), clazz.getValue()); } if (lesson == null) { throw new PluginLoadingFailure(String - .format("Lesson class not found, following classes were detected in the plugin: %s", - StringUtils.collectionToCommaDelimitedString(classes.keySet()))); + .format("Lesson class not found, following classes were detected in the plugin: %s", + StringUtils.collectionToCommaDelimitedString(classes.keySet()))); } } - private void loadClass(String name, byte[] classFile) { - ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader(); - PluginClassLoader pluginClassLoader = new PluginClassLoader(contextClassLoader, name, classFile); - try { - String realClassName = name.replaceFirst("/", "").replaceAll("/", ".").replaceAll(".class", ""); - Class clazz = pluginClassLoader.loadClass(realClassName); - if (AbstractLesson.class.isAssignableFrom(clazz)) { - this.lesson = clazz; - } - } catch (ClassNotFoundException e) { - logger.error("Unable to load class {}", name, e); + private void loadClass(PluginClassLoader pluginClassLoader, String name, byte[] classFile) { + String realClassName = name.replaceFirst("/", "").replaceAll("/", ".").replaceAll(".class", ""); + + Class clazz = pluginClassLoader.loadClass(realClassName, classFile); + if (AbstractLesson.class.isAssignableFrom(clazz)) { + this.lesson = clazz; } } @@ -97,7 +94,7 @@ public class Plugin { Files.copy(file, bos); Path propertiesPath = createPropertiesDirectory(); ResourceBundleClassLoader.setPropertiesPath(propertiesPath); - if ( reload ) { + if (reload) { Files.write(propertiesPath.resolve(file.getFileName()), bos.toByteArray(), CREATE, APPEND); } else { Files.write(propertiesPath.resolve(file.getFileName()), bos.toByteArray(), CREATE, TRUNCATE_EXISTING); @@ -117,8 +114,14 @@ public class Plugin { public void rewritePaths(Path pluginTarget) { try { - PluginFileUtils.replaceInFiles(this.lesson.getSimpleName() + "_files", pluginTarget.getFileName().toString() + "/plugin/" + this.lesson.getSimpleName() + "/lessonSolutions/en/" + this.lesson.getSimpleName() + "_files", solutionLanguageFiles.values()); - PluginFileUtils.replaceInFiles(this.lesson.getSimpleName() + "_files", pluginTarget.getFileName().toString() + "/plugin/" + this.lesson.getSimpleName() + "/lessonPlans/en/" + this.lesson.getSimpleName() + "_files", lessonPlansLanguageFiles.values()); + PluginFileUtils.replaceInFiles(this.lesson.getSimpleName() + "_files", + pluginTarget.getFileName().toString() + "/plugin/" + this.lesson + .getSimpleName() + "/lessonSolutions/en/" + this.lesson.getSimpleName() + "_files", + solutionLanguageFiles.values()); + PluginFileUtils.replaceInFiles(this.lesson.getSimpleName() + "_files", + pluginTarget.getFileName().toString() + "/plugin/" + this.lesson + .getSimpleName() + "/lessonPlans/en/" + this.lesson.getSimpleName() + "_files", + lessonPlansLanguageFiles.values()); } catch (IOException e) { throw new PluginLoadingFailure("Unable to rewrite the paths in the solutions", e); } diff --git a/src/main/java/org/owasp/webgoat/plugins/PluginClassLoader.java b/src/main/java/org/owasp/webgoat/plugins/PluginClassLoader.java index 6af81a6d3..b5796c0f0 100644 --- a/src/main/java/org/owasp/webgoat/plugins/PluginClassLoader.java +++ b/src/main/java/org/owasp/webgoat/plugins/PluginClassLoader.java @@ -1,22 +1,42 @@ package org.owasp.webgoat.plugins; +import com.google.common.base.Optional; +import com.google.common.base.Predicate; +import com.google.common.collect.FluentIterable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.ArrayList; +import java.util.List; + public class PluginClassLoader extends ClassLoader { + private final List> classes = new ArrayList<>(); private final Logger logger = LoggerFactory.getLogger(Plugin.class); - private final byte[] classFile; - public PluginClassLoader(ClassLoader parent, String nameOfClass, byte[] classFile) { - super(parent); - logger.debug("Creating class loader for {}", nameOfClass); - this.classFile = classFile; + public Class loadClass(String nameOfClass, byte[] classFile) { + Class clazz = defineClass(nameOfClass, classFile, 0, classFile.length); + classes.add(clazz); + return clazz; } - public Class findClass(String name) { + public PluginClassLoader(ClassLoader contextClassLoader) { + super(contextClassLoader); + } + + public Class findClass(final String name) throws ClassNotFoundException { logger.debug("Finding class " + name); - return defineClass(name, classFile, 0, classFile.length); + Optional> foundClass = FluentIterable.from(classes) + .firstMatch(new Predicate>() { + @Override + public boolean apply(Class clazz) { + return clazz.getName().equals(name); + } + }); + if (foundClass.isPresent()) { + return foundClass.get(); + } + throw new ClassNotFoundException("Class " + name + " not found"); } }