    快乐的小三菊's blog: springboot + shiro with redis: resolving frequent access ...

    Posted: 2021-07-13 16:08

    Background:

    Frequent access to redis comes in two forms: the first is frequently reading the session from redis; the second is frequently updating the session in redis. A solution for each case is given below.

    Case 1: frequently reading the session from redis

    There are two ways to handle the first case: use a local cache, or read the session from the request. Each is described below:

    Solution 1: use a local cache

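    In the previous post we used the RedisSessionDAO class, which depends on a class called SessionInMemory. This is the shiro-redis author's answer to a single request reading the session from redis over and over: it keeps a local cache in a ThreadLocal, and any read that happens within one second of the copy being cached is served from that local cache rather than from redis. Let's look at this code again: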

    public class RedisSessionDAO extends AbstractSessionDAO {
    
        private static Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class);
    
        private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:";
        private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX;
    
        private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;
        /**
         * doReadSession is called about 10 times during login.
         * The session is kept in a ThreadLocal to avoid this; sessionInMemoryTimeout is how long that copy stays valid.
         * The default value is 1000 milliseconds (1s).
         * Most of the time, you don't need to change it.
         */
        private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;
    
        /**
         * expire time in seconds
         */
        private static final int DEFAULT_EXPIRE = -2;
        private static final int NO_EXPIRE = -1;
    
        /**
         * Please make sure expire is longer than session.getTimeout()
         */
        private int expire = DEFAULT_EXPIRE;
    
        private static final int MILLISECONDS_IN_A_SECOND = 1000;
    
        private RedisManager redisManager;
        private static final ThreadLocal<Map<Serializable, SessionInMemory>> sessionsInThread = new ThreadLocal<>();
    
        @Override
        public void update(Session session) throws UnknownSessionException {
            // No need to update if the session has expired or been stopped
            try {
                if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
                    return;
                }
    
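            if (session instanceof ShiroSession) {
                // Skip the write if no significant field (anything other than lastAccessTime) has changed
                ShiroSession ss = (ShiroSession) session;
                if (!ss.isChanged()) {
                    return;
                }
                // Reaching this point means a setter such as setAttribute was called;
                // always reset the flag to false before the session is written to redis
                ss.setChanged(false);
            }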
    
                this.saveSession(session);
            } catch (Exception e) {
                logger.warn("update Session is failed", e);
            }
        }
    
        /**
         * save session
         * @param session
         * @throws UnknownSessionException
         */
        private void saveSession(Session session) throws UnknownSessionException {
            if (session == null || session.getId() == null) {
                logger.error("session or session id is null");
                throw new UnknownSessionException("session or session id is null");
            }
            String key = getRedisSessionKey(session.getId());
            if (expire == DEFAULT_EXPIRE) {
                this.redisManager.set(key, session, (int) (session.getTimeout() / MILLISECONDS_IN_A_SECOND));
                return;
            }
            if (expire != NO_EXPIRE && expire * MILLISECONDS_IN_A_SECOND < session.getTimeout()) {
                logger.warn("Redis session expire time: "
                        + (expire * MILLISECONDS_IN_A_SECOND)
                        + " is less than Session timeout: "
                        + session.getTimeout()
                        + " . It may cause some problems.");
            }
            this.redisManager.set(key, session, expire);
        }
    
        @Override
        public void delete(Session session) {
            if (session == null || session.getId() == null) {
                logger.error("session or session id is null");
                return;
            }
            try {
                redisManager.del(getRedisSessionKey(session.getId()));
            } catch (Exception e) {
                logger.error("delete session error. session id= {}",session.getId());
            }
        }
    
        @Override
        public Collection<Session> getActiveSessions() {
            Set<Session> sessions = new HashSet<Session>();
            try {
                Set<String> keys = redisManager.scan(this.keyPrefix + "*");
                if (keys != null && keys.size() > 0) {
                    for (String key:keys) {
                        Session s = (Session) redisManager.get(key);
                        sessions.add(s);
                    }
                }
            } catch (Exception e) {
                logger.error("get active sessions error.");
            }
            return sessions;
        }
    
        public Long getActiveSessionsSize() {
            Long size = 0L;
            try {
                size = redisManager.scanSize(this.keyPrefix + "*");
            } catch (Exception e) {
                logger.error("get active sessions error.");
            }
            return size;
        }
    
        @Override
        protected Serializable doCreate(Session session) {
            if (session == null) {
                logger.error("session is null");
                throw new UnknownSessionException("session is null");
            }
            Serializable sessionId = this.generateSessionId(session);
            this.assignSessionId(session, sessionId);
            this.saveSession(session);
            return sessionId;
        }
    
        @Override
        protected Session doReadSession(Serializable sessionId) {
            if (sessionId == null) {
                logger.warn("session id is null");
                return null;
            }
            Session s = getSessionFromThreadLocal(sessionId);
    
            if (s != null) {
                return s;
            }
    
            logger.debug("read session from redis");
            try {
                s = (Session) redisManager.get(getRedisSessionKey(sessionId));
                setSessionToThreadLocal(sessionId, s);
            } catch (Exception e) {
                logger.error("read session error. sessionId= {}", sessionId, e);
            }
            return s;
        }
        
        // Store the session in the ThreadLocal
        private void setSessionToThreadLocal(Serializable sessionId, Session s) {
            Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
            if (sessionMap == null) {
                sessionMap = new HashMap<Serializable, SessionInMemory>();
                sessionsInThread.set(sessionMap);
            }
            SessionInMemory sessionInMemory = new SessionInMemory();
            sessionInMemory.setCreateTime(new Date());
            sessionInMemory.setSession(s);
            sessionMap.put(sessionId, sessionInMemory);
        }
        // Read the session back from the ThreadLocal, if it is still fresh
        private Session getSessionFromThreadLocal(Serializable sessionId) {
            Session s = null;
    
            if (sessionsInThread.get() == null) {
                return null;
            }
    
            Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
            SessionInMemory sessionInMemory = sessionMap.get(sessionId);
            if (sessionInMemory == null) {
                return null;
            }
            Date now = new Date();
            long duration = now.getTime() - sessionInMemory.getCreateTime().getTime();
            // Check the elapsed time; if it is below the configured timeout, serve from the local cache
            if (duration < sessionInMemoryTimeout) {
                s = sessionInMemory.getSession();
                logger.debug("read session from memory");
            } else {
                sessionMap.remove(sessionId);
            }
    
            return s;
        }
    
        private String getRedisSessionKey(Serializable sessionId) {
            return this.keyPrefix + sessionId;
        }
    
        public RedisManager getRedisManager() {
            return redisManager;
        }
    
        public void setRedisManager(RedisManager redisManager) {
            this.redisManager = redisManager;
        }
    
        public String getKeyPrefix() {
            return keyPrefix;
        }
    
        public void setKeyPrefix(String keyPrefix) {
            this.keyPrefix = keyPrefix;
        }
    
        public long getSessionInMemoryTimeout() {
            return sessionInMemoryTimeout;
        }
    
        public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {
            this.sessionInMemoryTimeout = sessionInMemoryTimeout;
        }
    
        public int getExpire() {
            return expire;
        }
    
        public void setExpire(int expire) {
            this.expire = expire;
        }
    }
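
    The expiry handling inside saveSession() can be isolated into two small pure functions, which makes the rule behind the warning easier to see. This is a minimal sketch rather than the library's code: SessionTtl and both method names are hypothetical, and the multiplication is done in long on the assumption that a very large expire value should not overflow int.

```java
/** Hypothetical helper mirroring the expiry rules of saveSession() above. */
final class SessionTtl {
    static final int DEFAULT_EXPIRE = -2; // "use the session's own timeout"
    static final int NO_EXPIRE = -1;      // "never expire the redis key"
    private static final long MILLISECONDS_IN_A_SECOND = 1000L;

    private SessionTtl() {
    }

    /** TTL (in seconds) that would be handed to redis for this session. */
    static int redisTtlSeconds(int expire, long sessionTimeoutMillis) {
        if (expire == DEFAULT_EXPIRE) {
            // Session timeout is in milliseconds, redis TTL is in seconds.
            return (int) (sessionTimeoutMillis / MILLISECONDS_IN_A_SECOND);
        }
        return expire;
    }

    /** True when the redis key would disappear before the Shiro session times out. */
    static boolean expiresBeforeSession(int expire, long sessionTimeoutMillis) {
        if (expire == DEFAULT_EXPIRE || expire == NO_EXPIRE) {
            return false;
        }
        // Multiply in long to avoid int overflow for large expire values.
        return expire * MILLISECONDS_IN_A_SECOND < sessionTimeoutMillis;
    }
}
```

    With a 30-minute session timeout (1 800 000 ms), the default setting yields a TTL of 1800 seconds, while an explicit expire of 600 is flagged because the key would vanish 20 minutes before the session itself times out.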

    Now look at SessionInMemory. It has only two member variables: the current session and its creation time, createTime.

    import org.apache.shiro.session.Session;
    
    import java.util.Date;
    
    /**
     * Use ThreadLocal as a temporary storage of Session, so that shiro wouldn't keep read redis several times while a request coming.
     */
    public class SessionInMemory {
        private Session session;
        private Date createTime;
    
        public Session getSession() {
            return session;
        }
    
        public void setSession(Session session) {
            this.session = session;
        }
    
        public Date getCreateTime() {
            return createTime;
        }
    
        public void setCreateTime(Date createTime) {
            this.createTime = createTime;
        }
    }
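
    The caching idea used by doReadSession() can be reduced to a few lines that run without Shiro or redis. The sketch below is an illustration under stated assumptions, not the shiro-redis implementation: ThreadLocalTimedCache, its loader parameter and the explicit nowMillis argument are all hypothetical, and the clock is passed in only so the behaviour is easy to check.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/** Hypothetical per-thread cache whose entries are valid for ttlMillis. */
final class ThreadLocalTimedCache<K, V> {

    /** Plays the role of SessionInMemory: a value plus its creation time. */
    private static final class Entry<T> {
        final T value;
        final long createdAtMillis;

        Entry(T value, long createdAtMillis) {
            this.value = value;
            this.createdAtMillis = createdAtMillis;
        }
    }

    private final ThreadLocal<Map<K, Entry<V>>> perThread = ThreadLocal.withInitial(HashMap::new);
    private final long ttlMillis;
    private final AtomicInteger loads = new AtomicInteger(); // counts backend reads

    ThreadLocalTimedCache(long ttlMillis) {
        this.ttlMillis = ttlMillis;
    }

    /** Returns the cached value if it is younger than ttlMillis, else calls the loader. */
    V get(K key, long nowMillis, Function<K, V> loader) {
        Map<K, Entry<V>> cache = perThread.get();
        Entry<V> entry = cache.get(key);
        if (entry != null && nowMillis - entry.createdAtMillis < ttlMillis) {
            return entry.value; // served from the thread-local copy
        }
        V loaded = loader.apply(key); // stands in for the redis read
        loads.incrementAndGet();
        cache.put(key, new Entry<>(loaded, nowMillis));
        return loaded;
    }

    int loads() {
        return loads.get();
    }

    /** Drops this thread's entries; useful when threads are pooled and reused. */
    void clear() {
        perThread.remove();
    }
}
```

    With a timeout of 1000 ms, a second read at 999 ms is answered locally, while a read at exactly 1000 ms goes back to the loader, matching the `duration < sessionInMemoryTimeout` check above.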

    Solution 2: read the session from the request

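    A better solution is to override the retrieveSession() method of DefaultWebSessionManager. When shiro is used in a web environment, the sessionKey is a WebSessionKey, and that class has a very familiar property: servletRequest. We can stash the session object directly in the request! Within a single request cycle we can then read the session back from the request, and because the request is destroyed when the request ends, we do not need to worry about scope or lifecycle. So we need to override retrieveSession(), which means using a custom SessionManager, as follows: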

    import org.apache.shiro.session.Session;
    import org.apache.shiro.session.UnknownSessionException;
    import org.apache.shiro.session.mgt.SessionKey;
    import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
    import org.apache.shiro.web.session.mgt.WebSessionKey;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    
    import javax.servlet.ServletRequest;
    import java.io.Serializable;
    
    /**
     * @description: avoids hitting redis several times within a single request
     */
    public class ShiroSessionManager extends DefaultWebSessionManager {
    
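        private static final Logger logger = LoggerFactory.getLogger(ShiroSessionManager.class);
        /**
         * Retrieve the session.
         * Avoids hitting redis several times within a single request.
         * @param sessionKey key identifying the session (a WebSessionKey in a web environment)
         * @return the session
         * @throws UnknownSessionException
         */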
        @Override
        protected Session retrieveSession(SessionKey sessionKey) throws UnknownSessionException {
            Serializable sessionId = getSessionId(sessionKey);
    
            ServletRequest request = null;
            if (sessionKey instanceof WebSessionKey) {
                request = ((WebSessionKey) sessionKey).getServletRequest();
            }
    
            if (request != null && null != sessionId) {
                Object sessionObj = request.getAttribute(sessionId.toString());
                if (sessionObj != null) {
                    logger.debug("read session from request");
                    return (Session) sessionObj;
                }
            }
    
            Session session = super.retrieveSession(sessionKey);
            if (request != null && null != sessionId) {
                request.setAttribute(sessionId.toString(), session);
            }
            return session;
        }
    }

    Remember to configure ShiroConfig so that the SessionManager is the custom ShiroSessionManager.
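
    The effect of retrieveSession() can be reproduced without a servlet container. The following is only a sketch under assumptions: FakeRequest is a hypothetical stand-in for ServletRequest attributes, and the "redis read" is simulated by building a string, so the only thing it demonstrates is that the cached copy lives and dies with the request object.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of caching a session on the request for the request's lifetime. */
final class RequestScopedLookup {

    /** Stand-in for the attribute methods of ServletRequest. */
    static final class FakeRequest {
        private final Map<String, Object> attributes = new HashMap<>();

        Object getAttribute(String name) {
            return attributes.get(name);
        }

        void setAttribute(String name, Object value) {
            attributes.put(name, value);
        }
    }

    private int backendReads = 0; // how many times "redis" was consulted

    String retrieve(FakeRequest request, String sessionId) {
        if (request != null && sessionId != null) {
            Object cached = request.getAttribute(sessionId);
            if (cached != null) {
                return (String) cached; // same request: no backend read
            }
        }
        backendReads++;
        String session = "session:" + sessionId; // simulated redis read
        if (request != null && sessionId != null) {
            request.setAttribute(sessionId, session);
        }
        return session;
    }

    int backendReads() {
        return backendReads;
    }
}
```

    Two lookups on the same request cost one backend read; a new request starts with an empty attribute map and therefore reads again, so nothing has to be evicted by hand.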

    Case 2: frequently updating the session in redis

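    Whenever session data changes, the session in redis is updated, but in most cases the only thing that changes is the session's lastAccessTime field (the time of the last access). Because session expiry in redis is implemented through key expiry, updating only lastAccessTime in redis adds little value and simply increases the load on redis. To reduce redis access and network pressure, we skip the redis update when this is the only field that has changed. When any field other than lastAccessTime changes, the update still has to happen. To tell the two situations apart we add a flag: redis is updated only when the flag marks the session as modified; otherwise update returns immediately.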

    We wrap SimpleSession in a subclass and add a flag, isChanged. The code is as follows:

    import org.apache.shiro.session.mgt.SimpleSession;
    
    import java.io.Serializable;
    import java.util.Date;
    import java.util.Map;
    
    /**
     * SimpleSession also calls the SessionDAO update method when only lastAccessTime changes.
     * This subclass adds a flag so that update can return immediately in that case.
     */
    public class ShiroSession extends SimpleSession implements Serializable {
        // true when any field other than lastAccessTime has changed
        private boolean isChanged = false;
    
        public ShiroSession() {
            super();
            this.setChanged(true);
        }
    
        public ShiroSession(String host) {
            super(host);
            this.setChanged(true);
        }
    
    
        @Override
        public void setId(Serializable id) {
            super.setId(id);
            this.setChanged(true);
        }
    
        @Override
        public void setStopTimestamp(Date stopTimestamp) {
            super.setStopTimestamp(stopTimestamp);
            this.setChanged(true);
        }
    
        @Override
        public void setExpired(boolean expired) {
            super.setExpired(expired);
            this.setChanged(true);
        }
    
        @Override
        public void setTimeout(long timeout) {
            super.setTimeout(timeout);
            this.setChanged(true);
        }
    
        @Override
        public void setHost(String host) {
            super.setHost(host);
            this.setChanged(true);
        }
    
        @Override
        public void setAttributes(Map<Object, Object> attributes) {
            super.setAttributes(attributes);
            this.setChanged(true);
        }
    
        @Override
        public void setAttribute(Object key, Object value) {
            super.setAttribute(key, value);
            this.setChanged(true);
        }
    
        @Override
        public Object removeAttribute(Object key) {
            this.setChanged(true);
            return super.removeAttribute(key);
        }
    
        /**
         * Stop the session
         */
        @Override
        public void stop() {
            super.stop();
            this.setChanged(true);
        }
    
        /**
         * Mark the session as expired
         */
        @Override
        protected void expire() {
            this.stop();
            this.setExpired(true);
        }
    
        public boolean isChanged() {
            return isChanged;
        }
    
        public void setChanged(boolean isChanged) {
            this.isChanged = isChanged;
        }
    
        @Override
        public boolean equals(Object obj) {
            return super.equals(obj);
        }
    
        @Override
        protected boolean onEquals(SimpleSession ss) {
            return super.onEquals(ss);
        }
    
        @Override
        public int hashCode() {
            return super.hashCode();
        }
    
        @Override
        public String toString() {
            return super.toString();
        }
    }

    Write a ShiroSessionFactory class that implements the SessionFactory interface and its createSession() method, as follows:

    import org.apache.shiro.session.Session;
    import org.apache.shiro.session.mgt.SessionContext;
    import org.apache.shiro.session.mgt.SessionFactory;
    import org.apache.shiro.web.session.mgt.DefaultWebSessionContext;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.util.StringUtils;
    
    import com.cache.ShiroSession;
    
    import javax.servlet.http.HttpServletRequest;
    
    
    public class ShiroSessionFactory implements SessionFactory {
        private static final Logger logger = LoggerFactory.getLogger(ShiroSessionFactory.class);
    
        @Override
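        public Session createSession(SessionContext initData) {
            ShiroSession session = new ShiroSession();
            // The servlet request is only present for web-originated sessions, so guard against null
            Object request = initData == null ? null
                    : initData.get(DefaultWebSessionContext.class.getName() + ".SERVLET_REQUEST");
            if (request instanceof HttpServletRequest) {
                session.setHost(getIpAddress((HttpServletRequest) request));
            }
            return session;
        }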
    
        public static String getIpAddress(HttpServletRequest request) {
            String localIP = "127.0.0.1";
            String ip = request.getHeader("x-forwarded-for");
            if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getHeader("WL-Proxy-Client-IP");
            }
            if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {
                ip = request.getRemoteAddr();
            }
            return ip;
        }
    }
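
    One detail that getIpAddress() does not cover: when a request passes through several proxies, the x-forwarded-for header commonly carries a comma-separated list with the original client first. The helper below is a sketch of how that list could be reduced to a single address; ForwardedFor and firstClientIp are hypothetical names, and whether the header can be trusted at all depends on the proxy setup, which is assumed here.

```java
/** Hypothetical helper: pick the first usable address from an x-forwarded-for value. */
final class ForwardedFor {

    private ForwardedFor() {
    }

    /**
     * @param headerValue raw header value, e.g. "203.0.113.7, 10.0.0.1" (may be null)
     * @param fallback    value to use when the header holds nothing usable,
     *                    e.g. the result of request.getRemoteAddr()
     */
    static String firstClientIp(String headerValue, String fallback) {
        if (headerValue == null) {
            return fallback;
        }
        for (String part : headerValue.split(",")) {
            String candidate = part.trim();
            // Skip blanks and the literal "unknown" that some proxies insert.
            if (!candidate.isEmpty() && !"unknown".equalsIgnoreCase(candidate)) {
                return candidate;
            }
        }
        return fallback;
    }
}
```

    Inside getIpAddress() this could be applied to the x-forwarded-for value before the existing fallbacks are tried, so that session.setHost() receives one address instead of the whole list.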

    Register ShiroSessionFactory in ShiroConfig and remember to assign it to the SessionManager, as follows:

        @Bean("sessionManager")
        public SessionManager sessionManager() {
            // ....
            ShiroSessionManager sessionManager = new ShiroSessionManager();
            sessionManager.setSessionFactory(sessionFactory());
            // ....
        }

        @Bean
        public ShiroSessionFactory sessionFactory() {
            return new ShiroSessionFactory();
        }

    Finally, in the update method of RedisSessionDAO, return immediately if only the session's lastAccessTime field has changed. The code is as follows:

        @Override
        public void update(Session session) throws UnknownSessionException {
            // No need to update if the session has expired or been stopped
            try {
                if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
                    return;
                }
    
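            if (session instanceof ShiroSession) {
                // Skip the write if no significant field (anything other than lastAccessTime) has changed
                ShiroSession ss = (ShiroSession) session;
                if (!ss.isChanged()) {
                    return;
                }
                // Reaching this point means a setter such as setAttribute was called;
                // always reset the flag to false before the session is written to redis
                ss.setChanged(false);
            }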
    
                this.saveSession(session);
            } catch (Exception e) {
                logger.warn("update Session is failed", e);
            }
        }

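    Note: at the moment the session is written to redis, its changed property must already be false. If it were stored as true, a later request that only touched lastAccessTime would not return early, because the session read back from redis would still carry changed = true. Reaching the redis write at all means that setAttributes() or a similar method was called, so the flag is reset to false just before the write. The next time the session is read from redis its changed property is false; if only lastAccessTime is modified after that, the flag stays false and redis is not touched.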
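
    The lifecycle of the flag can be traced with a stripped-down model. This is a sketch under assumptions, not the classes above: FakeSession and CountingDao are hypothetical stand-ins for ShiroSession and RedisSessionDAO, and a counter replaces the actual redis write.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical walk-through of the changed-flag logic. */
final class DirtyFlagDemo {

    /** Simplified session: attributes, lastAccessTime and the flag. */
    static final class FakeSession {
        private final Map<String, Object> attributes = new HashMap<>();
        private long lastAccessTime;
        private boolean changed = true; // a new session must be written once

        /** Like SimpleSession.touch(): deliberately leaves the flag alone. */
        void touch(long nowMillis) {
            this.lastAccessTime = nowMillis;
        }

        void setAttribute(String key, Object value) {
            attributes.put(key, value);
            this.changed = true;
        }

        long getLastAccessTime() {
            return lastAccessTime;
        }

        boolean isChanged() {
            return changed;
        }

        void setChanged(boolean changed) {
            this.changed = changed;
        }
    }

    /** Counts how many updates would actually reach redis. */
    static final class CountingDao {
        private int writes = 0;

        void update(FakeSession session) {
            if (!session.isChanged()) {
                return; // only lastAccessTime moved: skip redis
            }
            session.setChanged(false); // reset first, so the stored copy is clean
            writes++;                  // stands in for saveSession(session)
        }

        int writes() {
            return writes;
        }
    }
}
```

    Creating a session and updating it produces one write; touching it any number of times produces none; calling setAttribute() and updating produces exactly one more.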
