Fixed saving lesson tracker with reloadable classloader

This commit is contained in:
Nanne Baars 2017-03-22 15:51:57 +01:00
parent 259fd19c1b
commit 53d30e2274
2 changed files with 14 additions and 9 deletions

View File

@ -42,6 +42,7 @@ import java.util.List;
@Getter @Getter
public class Assignment implements Serializable { public class Assignment implements Serializable {
private static final long serialVersionUID = 5410058267505412928L;
@NonNull @NonNull
private final String name; private final String name;
@NonNull @NonNull

View File

@ -2,16 +2,14 @@
package org.owasp.webgoat.session; package org.owasp.webgoat.session;
import com.google.common.collect.Maps; import com.google.common.collect.Maps;
import com.google.common.io.ByteStreams;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.owasp.webgoat.lessons.AbstractLesson; import org.owasp.webgoat.lessons.AbstractLesson;
import org.owasp.webgoat.lessons.Assignment; import org.owasp.webgoat.lessons.Assignment;
import org.springframework.core.serializer.DefaultDeserializer; import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import java.io.File; import java.io.*;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors; import java.util.stream.Collectors;
@ -100,23 +98,29 @@ public class UserTracker {
public Map<String, LessonTracker> load() { public Map<String, LessonTracker> load() {
File file = new File(webgoatHome, user + ".progress"); File file = new File(webgoatHome, user + ".progress");
Map<String, LessonTracker> storage = Maps.newHashMap();
if (file.exists() && file.isFile()) { if (file.exists() && file.isFile()) {
try { try {
DefaultDeserializer deserializer = new DefaultDeserializer(Thread.currentThread().getContextClassLoader()); DefaultDeserializer deserializer = new DefaultDeserializer(Thread.currentThread().getContextClassLoader());
return (Map<String, LessonTracker>) deserializer.deserialize(new FileInputStream(file)); try (FileInputStream fis = new FileInputStream(file)) {
byte[] b = ByteStreams.toByteArray(fis);
storage = (Map<String, LessonTracker>) deserializer.deserialize(new ByteArrayInputStream(b));
}
} catch (Exception e) { } catch (Exception e) {
log.error("Unable to read the progress file, creating a new one..."); log.error("Unable to read the progress file, creating a new one...");
} }
} }
return Maps.newHashMap(); return storage;
} }
@SneakyThrows @SneakyThrows
private void save(Map<String, LessonTracker> storage) { private void save(Map<String, LessonTracker> storage) {
File file = new File(webgoatHome, user + ".progress"); File file = new File(webgoatHome, user + ".progress");
DefaultSerializer serializer = new DefaultSerializer();
serializer.serialize(storage, new FileOutputStream(file)); try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(file))) {
objectOutputStream.writeObject(storage);
objectOutputStream.flush();
}
} }