Skip to content

Allow for configuration of session id generation and format. #204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion spring-session/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ dependencies {
'org.mockito:mockito-core:1.9.5',
"org.springframework:spring-test:$springVersion",
'org.easytesting:fest-assert:1.4',
"org.springframework.security:spring-security-core:$springSecurityVersion"
"org.springframework.security:spring-security-core:$springSecurityVersion",
'com.google.guava:guava:18.0'

jacoco "org.jacoco:org.jacoco.agent:0.7.2.201409121644:runtime"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public final class MapSession implements ExpiringSession, Serializable {
*/
public static final int DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS = 1800;

private String id = UUID.randomUUID().toString();
private String id;
private Map<String, Object> sessionAttrs = new HashMap<String, Object>();
private long creationTime = System.currentTimeMillis();
private long lastAccessedTime = creationTime;
Expand All @@ -57,9 +57,15 @@ public final class MapSession implements ExpiringSession, Serializable {
private int maxInactiveInterval = DEFAULT_MAX_INACTIVE_INTERVAL_SECONDS;

/**
* Creates a new instance
* Creates a new instance.
*
* @param id the session id. Can not be null.
*/
public MapSession() {
public MapSession(String id) {
if(id == null) {
throw new IllegalArgumentException("id cannot be null");
}
this.id = id;
}

/**
Expand All @@ -81,6 +87,17 @@ public MapSession(ExpiringSession session) {
this.creationTime = session.getCreationTime();
this.maxInactiveInterval = session.getMaxInactiveIntervalInSeconds();
}

/**
* Creates a new instance creating the session id using a random UUID.
* This is here for compatibility with older implementations of {@link SessionRepository}.
*
* @deprecated - {@link SessionRepository} classes should now use the constructor that passes an id.
*/
@Deprecated
public MapSession() {
this.id = UUID.randomUUID().toString();
}

public void setLastAccessedTime(long lastAccessedTime) {
this.lastAccessedTime = lastAccessedTime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.session;

import org.springframework.session.events.SessionDestroyedEvent;
import org.springframework.session.id.UUIDSessionIdStrategy;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -37,6 +38,11 @@ public class MapSessionRepository implements SessionRepository<ExpiringSession>
* If non-null, this value is used to override {@link ExpiringSession#setMaxInactiveIntervalInSeconds(int)}.
*/
private Integer defaultMaxInactiveInterval;

/**
* The strategy for generating the session id. Defaults to using a UUID.
*/
private SessionIdStrategy sessionIdStrategy = new UUIDSessionIdStrategy();

private final Map<String,ExpiringSession> sessions;

Expand Down Expand Up @@ -66,6 +72,17 @@ public MapSessionRepository(Map<String,ExpiringSession> sessions) {
public void setDefaultMaxInactiveInterval(int defaultMaxInactiveInterval) {
this.defaultMaxInactiveInterval = Integer.valueOf(defaultMaxInactiveInterval);
}

/**
* Set the strategy for generating new session ids.
* @param sessionIdStrategy The session id strategy. Can not be null.
*/
public void setSessionIdStrategy(SessionIdStrategy sessionIdStrategy) {
if (sessionIdStrategy == null) {
throw new IllegalArgumentException("sessionIdStrategy can not be null");
}
this.sessionIdStrategy = sessionIdStrategy;
}

public void save(ExpiringSession session) {
sessions.put(session.getId(), new MapSession(session));
Expand All @@ -90,10 +107,11 @@ public void delete(String id) {
}

public ExpiringSession createSession() {
ExpiringSession result = new MapSession();
ExpiringSession result = new MapSession(sessionIdStrategy.createSessionId());
if(defaultMaxInactiveInterval != null) {
result.setMaxInactiveIntervalInSeconds(defaultMaxInactiveInterval);
}
return result;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package org.springframework.session;

/**
* An interface
*
* @author Art Gramlich
*/
public interface SessionIdStrategy {

/**
* Creates a new session id.
*
* @return the new session id
*/
String createSessionId();

}
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
import org.springframework.session.ExpiringSession;
import org.springframework.session.MapSession;
import org.springframework.session.Session;
import org.springframework.session.SessionIdStrategy;
import org.springframework.session.SessionRepository;
import org.springframework.session.events.SessionDestroyedEvent;
import org.springframework.session.id.UUIDSessionIdStrategy;
import org.springframework.session.web.http.SessionRepositoryFilter;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -174,6 +176,12 @@ public class RedisOperationsSessionRepository implements SessionRepository<Redis
* If non-null, this value is used to override the default value for {@link RedisSession#setMaxInactiveIntervalInSeconds(int)}.
*/
private Integer defaultMaxInactiveInterval;

/**
* The strategy for generating the session id. Defaults to using a UUID.
*/
private SessionIdStrategy sessionIdStrategy = new UUIDSessionIdStrategy();


/**
* Allows creating an instance and uses a default {@link RedisOperations} for both managing the session and the expirations.
Expand Down Expand Up @@ -207,6 +215,10 @@ public void setDefaultMaxInactiveInterval(int defaultMaxInactiveInterval) {
this.defaultMaxInactiveInterval = defaultMaxInactiveInterval;
}

public void setSessionIdStrategy(SessionIdStrategy sessionIdStrategy) {
this.sessionIdStrategy = sessionIdStrategy;
}

public void save(RedisSession session) {
session.saveDelta();
}
Expand Down Expand Up @@ -234,8 +246,7 @@ private RedisSession getSession(String id, boolean allowExpired) {
if(entries.isEmpty()) {
return null;
}
MapSession loaded = new MapSession();
loaded.setId(id);
MapSession loaded = new MapSession(id);
for(Map.Entry<Object,Object> entry : entries.entrySet()) {
String key = (String) entry.getKey();
if(CREATION_TIME_ATTR.equals(key)) {
Expand Down Expand Up @@ -271,7 +282,7 @@ public void delete(String sessionId) {
}

public RedisSession createSession() {
RedisSession redisSession = new RedisSession();
RedisSession redisSession = new RedisSession(sessionIdStrategy.createSessionId());
if(defaultMaxInactiveInterval != null) {
redisSession.setMaxInactiveIntervalInSeconds(defaultMaxInactiveInterval);
}
Expand Down Expand Up @@ -336,8 +347,8 @@ final class RedisSession implements ExpiringSession {
/**
* Creates a new instance ensuring to mark all of the new attributes to be persisted in the next save operation.
*/
RedisSession() {
this(new MapSession());
RedisSession(String id) {
this(new MapSession(id));
delta.put(CREATION_TIME_ATTR, getCreationTime());
delta.put(MAX_INACTIVE_ATTR, getMaxInactiveIntervalInSeconds());
delta.put(LAST_ACCESSED_ATTR, getLastAccessedTime());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package org.springframework.session.id;

/**
* {@link SessionIdEncoder} that encodes as a base32 extended hex with no padding.
*
* @author Art Gramlich
*/
public class Base32HexSessionIdEncoder implements SessionIdEncoder {

private static final char[] UPPER_DIGITS = "0123456789ABCDEFGHIJKLMNOPQRSTUV".toCharArray();
private static final char[] LOWER_DIGITS = "0123456789abcdefghijklmnopqrstuv".toCharArray();

private char[] digits;

/**
* Creates a new encoder with lower case digits.
*
*/
public Base32HexSessionIdEncoder() {
this(false);
}

/**
* Creates a new encoder.
*
* @param lowerCaseDigits - If true, lower case digits are used.
*/
public Base32HexSessionIdEncoder(boolean lowerCaseDigits) {
setLowerCaseDigits(lowerCaseDigits);
}

/**
* Indicates if upper or lower case digits should be used.
*
* @param lowerCaseDigits - If true, lower case digits are used.
*/
public void setLowerCaseDigits(boolean lowerCaseDigits) {
digits = lowerCaseDigits ? LOWER_DIGITS : UPPER_DIGITS;
}

/**
* Encode the session id in a base32hex format.
*
* @param bytes the bytes representing a session id.
* @return The session id as a string.
*/
public String encode(byte[] bytes) {
int numberOfBytes = bytes.length;
int numberOfTotalBits = (numberOfBytes * 8);
int numberOfTotalDigits = (numberOfTotalBits / 5) + (numberOfTotalBits % 5 == 0 ? 0 : 1);
StringBuilder id = new StringBuilder(numberOfTotalDigits);
long fiveByteGroup;
for (int i=0; i< numberOfBytes; i+=5) {
int bytesInGroup = Math.min(numberOfBytes - i, 5);
int digitsInGroup = ((bytesInGroup * 8) / 5) + (bytesInGroup == 5 ? 0 : 1);
fiveByteGroup = 0;
for (int j=0; j<5; j++) {
byte b = (j >= bytesInGroup ? (byte)0 : bytes[i+j]);
long bits = (b & 0xffL) << (8*(4-j));
fiveByteGroup = fiveByteGroup | bits;
}
for (int j=0; j<digitsInGroup; j++) {
int digit = (int)(0x1fL & (fiveByteGroup >>> (5*(7-j))));
id.append(digits[digit]);
}
}
return id.toString();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.springframework.session.id;

/**
* {@link SessionIdEncoder} that encodes as hexidecimal digits.
*
* @author Art Gramlich
*/
public class HexSessionIdEncoder implements SessionIdEncoder {

private static final char[] LOWER_DIGITS = {
'0', '1', '2', '3', '4', '5', '6', '7',
'8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};

private static final char[] UPPER_DIGITS = {
'0', '1', '2', '3', '4', '5', '6', '7',
'8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};

private char[] digits;


/**
* Creates a new encoder with lower case digits.
*
*/
public HexSessionIdEncoder() {
this(false);
}

/**
* Creates a new encoder.
*
* @param lowerCaseDigits - If true, lower case digits are used.
*/
public HexSessionIdEncoder(boolean lowerCaseDigits) {
setLowerCaseDigits(lowerCaseDigits);
}

/**
* Indicates if upper or lower case digits should be used.
*
* @param lowerCaseDigits - If true, lower case digits are used.
*/
public void setLowerCaseDigits(boolean lowerCaseDigits) {
digits = lowerCaseDigits ? LOWER_DIGITS : UPPER_DIGITS;
}

/**
* Encode the session id as hex digits.
*
* @param bytes the bytes representing a session id.
* @return The bytes as hexidecimal digits.
*/
public String encode(byte[] bytes) {
StringBuilder id = new StringBuilder(bytes.length * 2);
for (int i=0; i < bytes.length; i++) {
id.append(digits[(0xF0 & bytes[i]) >>> 4]);
id.append(digits[(0x0F & bytes[i])]);
}
return id.toString();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.springframework.session.id;

import java.security.SecureRandom;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

import org.springframework.session.SessionIdStrategy;

public class SecureRandomSessionIdStrategy implements SessionIdStrategy {

private final Queue<SecureRandomGenerator> randomGenerators = new ConcurrentLinkedQueue<SecureRandomGenerator>();

private int maxIterations = 100000;
private int byteLength = 32;
private SessionIdEncoder encoder = new HexSessionIdEncoder(true);

public String createSessionId() {
SecureRandomGenerator generator = randomGenerators.poll();
if (generator == null) {
generator = new SecureRandomGenerator();
}
byte[] bytes = new byte[byteLength];
generator.generate(bytes);
randomGenerators.add(generator);
String id = encoder.encode(bytes);
return id;
}

public void setByteLength(int byteLength) {
this.byteLength = byteLength;
}

public void setEncoder(SessionIdEncoder encoder) {
this.encoder = encoder;
}

private class SecureRandomGenerator {
private SecureRandom secureRandom;
private int iteration;

private SecureRandomGenerator() {
secureRandom = new SecureRandom();
secureRandom.nextInt();
}

public void generate(byte[] bytes) {
secureRandom.nextBytes(bytes);
iteration++;
if (iteration == maxIterations) {
secureRandom.setSeed(secureRandom.nextLong());
iteration = 0;
}
}
}

}
Loading