当前位置 博文首页 > 快乐的小三菊的博客:springboot + shiro 整合 redis 解决频繁访问 redis 的问题
关于频繁访问 redis,一共分为两种情况:第一种是频繁地去 redis 中读取 session;另一种是频繁地去更新 redis 中的 session。下面针对这两种情况分别给出相应的解决方案。
针对第一种情况有两种解决方式:使用本地缓存,以及从 request 中获取 session。下面分别介绍这两种方式:
解决方式一:使用本地缓存
在上一篇文章中我们用到了 RedisSessionDAO 这个类,这个类依赖了一个叫 SessionInMemory 的类。它是 shiro-redis 作者为了解决一次请求中频繁访问 redis 读取 session 而设计的方案:基于本地缓存(ThreadLocal),同一线程在一秒内(sessionInMemoryTimeout 的默认值)再次读取同一个 session 时,会直接从本地缓存中获取 session,而不再访问 redis。我们再看下这块的代码:
/**
 * Shiro SessionDAO that stores sessions in redis via {@link RedisManager}.
 *
 * <p>Reads are cached per thread for {@code sessionInMemoryTimeout} milliseconds so that the
 * many doReadSession() calls made while handling one request hit redis only once.
 * For {@link ShiroSession} instances, update() skips the redis write when only
 * lastAccessTime has changed.
 */
public class RedisSessionDAO extends AbstractSessionDAO {

    private static final Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class);

    private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:";
    private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX;

    private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;

    /**
     * doReadSession is called about 10 times during login.
     * The session is therefore cached in a ThreadLocal; sessionInMemoryTimeout is how long
     * (in milliseconds) a cached entry is trusted. The default value is 1000 milliseconds (1s).
     * Most of the time you don't need to change it.
     */
    private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;

    /*
     * Redis expire time in seconds.
     * DEFAULT_EXPIRE (-2): use session.getTimeout() converted to seconds.
     * NO_EXPIRE (-1): passed straight to RedisManager.set (presumably "never expire" -- confirm
     * against RedisManager).
     */
    private static final int DEFAULT_EXPIRE = -2;
    private static final int NO_EXPIRE = -1;

    /**
     * Please make sure expire is longer than session.getTimeout().
     */
    private int expire = DEFAULT_EXPIRE;

    private static final int MILLISECONDS_IN_A_SECOND = 1000;

    private RedisManager redisManager;

    /** Per-thread cache: session id -> (session, time it was cached). */
    private static final ThreadLocal<Map<Serializable, SessionInMemory>> sessionsInThread =
            new ThreadLocal<Map<Serializable, SessionInMemory>>();

    /**
     * Writes the session back to redis unless nothing worth persisting changed.
     *
     * <p>Invalid (expired/stopped) sessions are not written. A {@link ShiroSession} whose
     * changed flag is false (only lastAccessTime moved) is not written either.
     *
     * <p>NOTE(review): skipping lastAccessTime-only writes also means the redis TTL set in
     * saveSession() is not refreshed by plain activity. A session whose attributes never
     * change will disappear from redis roughly one timeout after its last write, even if the
     * user stays active. Confirm this is acceptable, or force a periodic write.
     *
     * @param session the session to persist
     */
    @Override
    public void update(Session session) throws UnknownSessionException {
        ShiroSession tracked = null;
        try {
            // Expired/stopped sessions don't need to be updated any more.
            if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
                return;
            }
            if (session instanceof ShiroSession) {
                tracked = (ShiroSession) session;
                // Nothing other than lastAccessTime changed: skip the redis round trip.
                if (!tracked.isChanged()) {
                    return;
                }
                // Clear the flag BEFORE serializing so the copy stored in redis carries
                // changed == false; a later request that only touches lastAccessTime
                // will then skip the write.
                tracked.setChanged(false);
            }
            this.saveSession(session);
            // Keep this thread's cached copy in step with what was just written.
            setSessionToThreadLocal(session.getId(), session);
        } catch (Exception e) {
            // The write failed: re-mark the session so the next update() retries it
            // instead of silently dropping the change.
            if (tracked != null) {
                tracked.setChanged(true);
            }
            logger.warn("update Session is failed", e);
        }
    }

    /**
     * Stores the session under {@code keyPrefix + sessionId} with the configured expiry.
     *
     * @param session the session to store; must have a non-null id
     * @throws UnknownSessionException if the session or its id is null
     */
    private void saveSession(Session session) throws UnknownSessionException {
        if (session == null || session.getId() == null) {
            logger.error("session or session id is null");
            throw new UnknownSessionException("session or session id is null");
        }
        String key = getRedisSessionKey(session.getId());
        if (expire == DEFAULT_EXPIRE) {
            // Follow the session's own timeout (milliseconds -> seconds).
            this.redisManager.set(key, session, (int) (session.getTimeout() / MILLISECONDS_IN_A_SECOND));
            return;
        }
        // Widen to long before multiplying. An int product overflows once expire exceeds
        // ~24.8 days (Integer.MAX_VALUE ms), which would corrupt the comparison below.
        long expireMillis = (long) expire * MILLISECONDS_IN_A_SECOND;
        if (expire != NO_EXPIRE && expireMillis < session.getTimeout()) {
            logger.warn("Redis session expire time: "
                    + expireMillis
                    + " is less than Session timeout: "
                    + session.getTimeout()
                    + " . It may cause some problems.");
        }
        this.redisManager.set(key, session, expire);
    }

    /**
     * Deletes the session from redis and from this thread's cache.
     * Redis errors are logged, not thrown.
     *
     * @param session the session to delete; ignored (with an error log) if it or its id is null
     */
    @Override
    public void delete(Session session) {
        if (session == null || session.getId() == null) {
            logger.error("session or session id is null");
            return;
        }
        // Drop the thread-local copy so a read within sessionInMemoryTimeout cannot
        // resurrect the deleted session.
        removeSessionFromThreadLocal(session.getId());
        try {
            redisManager.del(getRedisSessionKey(session.getId()));
        } catch (Exception e) {
            logger.error("delete session error. session id= {}", session.getId(), e);
        }
    }

    /**
     * Loads every session whose key matches {@code keyPrefix + "*"}.
     *
     * @return the sessions found; never null, possibly empty (also empty on redis errors)
     */
    @Override
    public Collection<Session> getActiveSessions() {
        Set<Session> sessions = new HashSet<Session>();
        try {
            Set<String> keys = redisManager.scan(this.keyPrefix + "*");
            if (keys != null) {
                for (String key : keys) {
                    Session s = (Session) redisManager.get(key);
                    // A key may expire between scan() and get(); don't hand callers a null.
                    if (s != null) {
                        sessions.add(s);
                    }
                }
            }
        } catch (Exception e) {
            logger.error("get active sessions error.", e);
        }
        return sessions;
    }

    /**
     * Counts keys matching {@code keyPrefix + "*"}.
     *
     * @return the count, or 0 if redis could not be queried
     */
    public Long getActiveSessionsSize() {
        Long size = 0L;
        try {
            size = redisManager.scanSize(this.keyPrefix + "*");
        } catch (Exception e) {
            logger.error("get active sessions error.", e);
        }
        return size;
    }

    /**
     * Assigns a freshly generated id to the session and stores it.
     *
     * @param session the new session
     * @return the generated session id
     * @throws UnknownSessionException if session is null
     */
    @Override
    protected Serializable doCreate(Session session) {
        if (session == null) {
            logger.error("session is null");
            throw new UnknownSessionException("session is null");
        }
        Serializable sessionId = this.generateSessionId(session);
        this.assignSessionId(session, sessionId);
        this.saveSession(session);
        return sessionId;
    }

    /**
     * Reads a session, preferring this thread's cached copy while it is still fresh.
     *
     * @param sessionId the session id; null yields null
     * @return the session, or null if it is absent or redis could not be read
     */
    @Override
    protected Session doReadSession(Serializable sessionId) {
        if (sessionId == null) {
            logger.warn("session id is null");
            return null;
        }
        Session s = getSessionFromThreadLocal(sessionId);
        if (s != null) {
            return s;
        }
        logger.debug("read session from redis");
        try {
            s = (Session) redisManager.get(getRedisSessionKey(sessionId));
            // Only cache hits: a cached null is never returned anyway.
            if (s != null) {
                setSessionToThreadLocal(sessionId, s);
            }
        } catch (Exception e) {
            logger.error("read session error. settionId= {}", sessionId, e);
        }
        return s;
    }

    /** Caches the session in the current thread, first sweeping stale entries. */
    private void setSessionToThreadLocal(Serializable sessionId, Session s) {
        Map<Serializable, SessionInMemory> sessionMap = sessionsInThread.get();
        if (sessionMap == null) {
            sessionMap = new HashMap<Serializable, SessionInMemory>();
            sessionsInThread.set(sessionMap);
        } else {
            // Worker threads are usually pooled and reused. Without this sweep the map keeps
            // one entry for every session id the thread has ever served.
            removeExpiredSessionsInMemory(sessionMap);
        }
        SessionInMemory sessionInMemory = new SessionInMemory();
        sessionInMemory.setCreateTime(new Date());
        sessionInMemory.setSession(s);
        sessionMap.put(sessionId, sessionInMemory);
    }

    /**
     * Returns this thread's cached session if it was cached less than
     * sessionInMemoryTimeout ms ago; evicts it otherwise.
     */
    private Session getSessionFromThreadLocal(Serializable sessionId) {
        Map<Serializable, SessionInMemory> sessionMap = sessionsInThread.get();
        if (sessionMap == null) {
            return null;
        }
        SessionInMemory sessionInMemory = sessionMap.get(sessionId);
        if (sessionInMemory == null) {
            return null;
        }
        if (isFresh(sessionInMemory, System.currentTimeMillis())) {
            logger.debug("read session from memory");
            return sessionInMemory.getSession();
        }
        removeSessionFromThreadLocal(sessionId);
        return null;
    }

    /** Evicts one session from this thread's cache, releasing the map when it becomes empty. */
    private void removeSessionFromThreadLocal(Serializable sessionId) {
        Map<Serializable, SessionInMemory> sessionMap = sessionsInThread.get();
        if (sessionMap == null) {
            return;
        }
        sessionMap.remove(sessionId);
        if (sessionMap.isEmpty()) {
            sessionsInThread.remove();
        }
    }

    /** Removes every entry older than sessionInMemoryTimeout from the given map. */
    private void removeExpiredSessionsInMemory(Map<Serializable, SessionInMemory> sessionMap) {
        long now = System.currentTimeMillis();
        java.util.Iterator<SessionInMemory> it = sessionMap.values().iterator();
        while (it.hasNext()) {
            if (!isFresh(it.next(), now)) {
                it.remove();
            }
        }
    }

    /** True if the entry was cached less than sessionInMemoryTimeout ms before {@code now}. */
    private boolean isFresh(SessionInMemory entry, long now) {
        return now - entry.getCreateTime().getTime() < sessionInMemoryTimeout;
    }

    private String getRedisSessionKey(Serializable sessionId) {
        return this.keyPrefix + sessionId;
    }

    public RedisManager getRedisManager() {
        return redisManager;
    }

    public void setRedisManager(RedisManager redisManager) {
        this.redisManager = redisManager;
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public void setKeyPrefix(String keyPrefix) {
        this.keyPrefix = keyPrefix;
    }

    public long getSessionInMemoryTimeout() {
        return sessionInMemoryTimeout;
    }

    public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {
        this.sessionInMemoryTimeout = sessionInMemoryTimeout;
    }

    public int getExpire() {
        return expire;
    }

    public void setExpire(int expire) {
        this.expire = expire;
    }
}
再看下 SessionInMemory 的代码组成:它只有两个成员变量,一个是当前的 session,另一个是创建时间 createTime。
import org.apache.shiro.session.Session;
import java.util.Date;
/**
 * A Session paired with the moment it was placed in the per-thread cache.
 * Used as short-lived temporary storage so that shiro does not read redis
 * several times while a single request is being handled.
 */
public class SessionInMemory {

    /** When this entry was cached. */
    private Date createTime;

    /** The cached session. */
    private Session session;

    public Date getCreateTime() {
        return this.createTime;
    }

    public void setCreateTime(Date createTime) {
        this.createTime = createTime;
    }

    public Session getSession() {
        return this.session;
    }

    public void setSession(Session session) {
        this.session = session;
    }
}
解决方式二:从 request 中获取 session
另一个更好的解决方案是重写 DefaultWebSessionManager 类的 retrieveSession() 方法。在 Web 环境下使用 shiro 时,这个 sessionKey 是 WebSessionKey 类型的,这个类有个我们很熟悉的属性:servletRequest。我们可以直接把 session 对象放进 request 里!这样在单次请求周期内都可以从 request 中取 session,而且请求结束后 request 被销毁,作用域和生命周期的问题都不需要我们考虑。所以我们需要 Override 这个 retrieveSession() 方法,为此需要使用自定义的 SessionManager,如下:
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.SessionKey;
import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
import org.apache.shiro.web.session.mgt.WebSessionKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.ServletRequest;
import java.io.Serializable;
/**
 * Session manager that remembers the session looked up during a request on the request
 * itself, so repeated retrieveSession() calls within one request read redis only once.
 * The cached value disappears together with the request.
 */
public class ShiroSessionManager extends DefaultWebSessionManager {

    private static final Logger logger = LoggerFactory.getLogger(ShiroSessionManager.class);

    /**
     * Prefix for the request-attribute name. A bare session id could collide with
     * unrelated request attributes.
     */
    private static final String SESSION_ATTRIBUTE_PREFIX =
            ShiroSessionManager.class.getName() + ".SESSION.";

    /**
     * Retrieves the session, preferring the copy cached on the current servlet request.
     *
     * @param sessionKey the key identifying the session; only {@link WebSessionKey}
     *                   instances carry a request and benefit from the cache
     * @return the session as returned by the superclass (may be null)
     * @throws UnknownSessionException propagated from the superclass lookup
     */
    @Override
    protected Session retrieveSession(SessionKey sessionKey) throws UnknownSessionException {
        Serializable sessionId = getSessionId(sessionKey);
        ServletRequest request = null;
        if (sessionKey instanceof WebSessionKey) {
            request = ((WebSessionKey) sessionKey).getServletRequest();
        }
        // Only cache when there is both a request to hold the value and an id to key it by.
        String attributeName = null;
        if (request != null && sessionId != null) {
            attributeName = SESSION_ATTRIBUTE_PREFIX + sessionId;
            Object cached = request.getAttribute(attributeName);
            if (cached instanceof Session) {
                logger.debug("read session from request");
                return (Session) cached;
            }
        }
        Session session = super.retrieveSession(sessionKey);
        if (attributeName != null) {
            request.setAttribute(attributeName, session);
        }
        return session;
    }
}
还需要记得在 ShiroConfig 中将 SessionManager 配置为自定义的 ShiroSessionManager。
当 session 数据发生变化时,就会更新 redis 中的 session,但在大多数情况下,发生变化的只是 session 中的 lastAccessTime(最后一次访问时间)字段。由于 redis 中 session 的失效是通过数据过期实现的,所以在 redis 中只更新 lastAccessTime 这个字段意义不大,反而增加了 redis 的压力。为了减少对 redis 的访问、降低网络压力,当只有这个字段发生变化时,不去更新 redis 中的 session;只有当 session 中除 lastAccessTime 以外的其他字段发生改变时才需要更新。为此可以增加一个标识位,只有标识为已修改时才让 redis 做更新,否则直接返回。(需要注意:这样一来,redis 中 session 的过期时间只在真正写入时才会刷新。如果一个 session 长时间只有 lastAccessTime 变化,它会在最后一次写入后约一个 timeout 时长从 redis 中过期,即使用户一直在活动。请根据业务评估这一点,必要时定期强制写入。)
我们需要继承 SimpleSession 并增加一个标识位 isChanged,具体代码如下所示:
import org.apache.shiro.session.mgt.SimpleSession;
import java.io.Serializable;
import java.util.Date;
import java.util.Map;
/**
 * A {@link SimpleSession} that tracks whether any field other than lastAccessTime has been
 * modified. A change to SimpleSession's lastAccessTime also triggers the SessionDAO's
 * update(); with this flag the DAO can return early when lastAccessTime is the only thing
 * that moved.
 *
 * <p>Every mutator overridden here sets the flag to {@code true}. lastAccessTime updates are
 * deliberately not overridden. The flag is only cleared from outside, via
 * {@link #setChanged(boolean)}. equals/hashCode/toString are inherited unchanged from
 * SimpleSession, so they ignore the flag.
 */
public class ShiroSession extends SimpleSession implements Serializable {

    /** true when a field other than lastAccessTime changed since the flag was last cleared. */
    private boolean isChanged;

    /** Creates an empty session; a brand-new session counts as changed. */
    public ShiroSession() {
        super();
        setChanged(true);
    }

    /**
     * Creates a session for the given host; a brand-new session counts as changed.
     *
     * @param host originating host, handed to SimpleSession
     */
    public ShiroSession(String host) {
        super(host);
        setChanged(true);
    }

    // ---- mutators that mark the session as changed ----

    @Override
    public void setId(Serializable id) {
        super.setId(id);
        setChanged(true);
    }

    @Override
    public void setStopTimestamp(Date stopTimestamp) {
        super.setStopTimestamp(stopTimestamp);
        setChanged(true);
    }

    @Override
    public void setExpired(boolean expired) {
        super.setExpired(expired);
        setChanged(true);
    }

    @Override
    public void setTimeout(long timeout) {
        super.setTimeout(timeout);
        setChanged(true);
    }

    @Override
    public void setHost(String host) {
        super.setHost(host);
        setChanged(true);
    }

    @Override
    public void setAttributes(Map<Object, Object> attributes) {
        super.setAttributes(attributes);
        setChanged(true);
    }

    @Override
    public void setAttribute(Object key, Object value) {
        super.setAttribute(key, value);
        setChanged(true);
    }

    /** Marks the session changed, then removes the attribute. */
    @Override
    public Object removeAttribute(Object key) {
        setChanged(true);
        return super.removeAttribute(key);
    }

    /** Stops the session and marks it changed. */
    @Override
    public void stop() {
        super.stop();
        setChanged(true);
    }

    /**
     * Expires the session by stopping it and then flagging it expired. Both steps go through
     * overridden methods, so the session ends up marked changed.
     */
    @Override
    protected void expire() {
        stop();
        setExpired(true);
    }

    // ---- change flag ----

    public boolean isChanged() {
        return isChanged;
    }

    public void setChanged(boolean isChanged) {
        this.isChanged = isChanged;
    }
}
编写类 ShiroSessionFactory 实现 SessionFactory 接口,并实现其中的 createSession() 方法,代码如下:
import org.apache.shiro.session.Session;
import org.apache.shiro.session.mgt.SessionContext;
import org.apache.shiro.session.mgt.SessionFactory;
import org.apache.shiro.web.session.mgt.DefaultWebSessionContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.StringUtils;
import com.cache.ShiroSession;
import javax.servlet.http.HttpServletRequest;
/**
 * {@link SessionFactory} that creates {@link ShiroSession} instances, so RedisSessionDAO can
 * skip lastAccessTime-only writes. It records the client IP as the session host.
 */
public class ShiroSessionFactory implements SessionFactory {

    private static final Logger logger = LoggerFactory.getLogger(ShiroSessionFactory.class);

    /** Context-map key under which DefaultWebSessionContext keeps the servlet request. */
    private static final String SERVLET_REQUEST_KEY =
            DefaultWebSessionContext.class.getName() + ".SERVLET_REQUEST";

    private static final String LOCAL_IP = "127.0.0.1";

    /**
     * Creates a new {@link ShiroSession}.
     *
     * <p>The host is taken from the HTTP request when the context carries one. Otherwise it
     * falls back to the host already recorded in the context, e.g. for non-web or
     * programmatic session creation.
     *
     * @param initData session creation context; may be null
     * @return a new ShiroSession (never null)
     */
    @Override
    public Session createSession(SessionContext initData) {
        ShiroSession session = new ShiroSession();
        if (initData == null) {
            return session;
        }
        Object request = initData.get(SERVLET_REQUEST_KEY);
        String host = (request instanceof HttpServletRequest)
                ? getIpAddress((HttpServletRequest) request)
                : initData.getHost();
        session.setHost(host);
        return session;
    }

    /**
     * Best-effort client IP.
     *
     * <p>Tries X-Forwarded-For (first hop only), then Proxy-Client-IP, then
     * WL-Proxy-Client-IP, then the socket's remote address. Blank, "unknown" and 127.0.0.1
     * values are skipped.
     *
     * <p>Security: these headers are client-controlled unless a trusted proxy overwrites
     * them. Do not use the result for authentication or access control.
     *
     * @param request the HTTP request; null yields null
     * @return the chosen address, or null if request is null
     */
    public static String getIpAddress(HttpServletRequest request) {
        if (request == null) {
            return null;
        }
        String ip = firstForwardedAddress(request.getHeader("x-forwarded-for"));
        if (isUnusable(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnusable(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnusable(ip)) {
            ip = request.getRemoteAddr();
        }
        return ip;
    }

    /**
     * Returns the originating client from an X-Forwarded-For value
     * ("client, proxy1, proxy2"), i.e. the first comma-separated entry, trimmed.
     */
    private static String firstForwardedAddress(String header) {
        if (header == null) {
            return null;
        }
        int comma = header.indexOf(',');
        return (comma >= 0 ? header.substring(0, comma) : header).trim();
    }

    /** True if the candidate is blank, "unknown" or the loopback address. */
    private static boolean isUnusable(String ip) {
        return !StringUtils.hasText(ip) || LOCAL_IP.equalsIgnoreCase(ip) || "unknown".equalsIgnoreCase(ip);
    }
}
将 ShiroSessionFactory 类配置到 ShiroConfig 中,并记得将其赋值给 SessionManager,代码如下:
/**
 * Session manager bean: the request-caching ShiroSessionManager wired with ShiroSessionFactory.
 * Other settings are elided in this excerpt.
 */
@Bean("sessionManager")
public SessionManager sessionManager() {
    // ....
    ShiroSessionManager sessionManager = new ShiroSessionManager();
    // Sessions must be created as ShiroSession instances. Otherwise RedisSessionDAO.update()
    // cannot tell lastAccessTime-only changes apart and writes on every update.
    sessionManager.setSessionFactory(sessionFactory());
    // ....
    return sessionManager;
}
/**
 * Exposes ShiroSessionFactory as a bean so that sessionManager() can install it.
 */
@Bean
public ShiroSessionFactory sessionFactory() {
    return new ShiroSessionFactory();
}
最后,在 RedisSessionDAO 的 update 方法中判断:若只是更改了 session 的 lastAccessTime 字段,则直接返回。代码如下:
/**
 * Writes the session back to redis unless nothing worth persisting changed.
 *
 * <p>Invalid (expired/stopped) sessions are not written. A ShiroSession whose changed flag
 * is false (only lastAccessTime moved) is not written either. If the write fails, the flag
 * is restored so the next update() retries.
 *
 * <p>NOTE(review): skipping lastAccessTime-only writes also means the redis TTL is not
 * refreshed by plain activity. Confirm this is acceptable for long-lived, read-only sessions.
 */
@Override
public void update(Session session) throws UnknownSessionException {
    ShiroSession tracked = null;
    try {
        // Expired/stopped sessions don't need to be updated any more.
        if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
            return;
        }
        if (session instanceof ShiroSession) {
            tracked = (ShiroSession) session;
            // Nothing other than lastAccessTime changed: skip the redis write.
            if (!tracked.isChanged()) {
                return;
            }
            // Clear the flag BEFORE serializing so the stored copy carries changed == false.
            tracked.setChanged(false);
        }
        this.saveSession(session);
        // Keep this thread's cached copy in step with what was just written.
        setSessionToThreadLocal(session.getId(), session);
    } catch (Exception e) {
        // The write failed: re-mark the session so the change is not silently lost.
        if (tracked != null) {
            tracked.setChanged(true);
        }
        logger.warn("update Session is failed", e);
    }
}
这里注意:update() 每次写入 redis 之前都会先把 changed 置为 false,所以经由 update() 写入 redis 的 session,其 changed 属性都是 false。唯一的例外是新建 session:doCreate() 直接调用 saveSession(),此时 changed 仍是构造函数和 setId() 设置的 true,所以它第一次被 update() 时还会写一次 redis,并在写入前把 changed 置为 false。此后再从 redis 读取该 session,如果本次请求只修改了 lastAccessTime,changed 保持为 false,update() 会直接返回而不访问 redis。只有调用了 setAttribute() 等方法使 changed 变为 true 时,才会真正更新 redis。
cs