Shiro 配置 EhCache 缓存

1. 添加 pom.xml 依赖

<dependency>
    <groupId>net.sf.ehcache</groupId>
    <artifactId>ehcache</artifactId>
    <version>2.9.0</version>
</dependency>

2. 创建 EHCache 配置

存储在 resources/spring-shiro-ehcache.xml

<ehcache updateCheck="false" name="shiroCache">
    <defaultCache
            maxElementsInMemory="10000"
            eternal="false"
            timeToIdleSeconds="120"
            timeToLiveSeconds="120"
            overflowToDisk="false"
            diskPersistent="false"
            diskExpiryThreadIntervalSeconds="120"
            />
</ehcache>

3. 为 Shiro 配置 EHCache 缓存

resources/spring-shiro.xml 里生成 EhCacheManager 的 Bean,把 EhCacheManager 的 Bean 注入到 SecurityManager

<bean id="securityManager" class="org.apache.shiro.web.mgt.DefaultWebSecurityManager">
    <property name="realm" ref="realm"/>
    <!-- 需要使用cache的话加上这句 -->
    <property name="cacheManager" ref="shiroEhcacheManager" />
</bean>

<!-- 需要使用cache的话加上这句 -->
<bean id="shiroEhcacheManager" class="org.apache.shiro.cache.ehcache.EhCacheManager">
    <property name="cacheManagerConfigFile" value="classpath:spring-shiro-ehcache.xml" />
</bean>

4. 测试

filter.ShiroRealm.doGetAuthenticationInfo() 里打上断点

可以看到添加 EHCache 后 doGetAuthenticationInfo() 在登录成功后不会调用了,因为验证登录信息首先从 Cache 里查找,如果没有找到才去调用 doGetAuthenticationInfo() 进行登录验证。

EHCache 的说明

自动地可用的默认的 Filter 实例是被 DefaultFilter 枚举定义的,枚举的名称字段是可供配置的名称。它们是:

Filter Name Class
anon org.apache.shiro.web.filter.authc.AnonymousFilter
authc org.apache.shiro.web.filter.authc.FormAuthenticationFilter
authcBasic org.apache.shiro.web.filter.authc.BasicHttpAuthenticationFilter
logout org.apache.shiro.web.filter.authc.LogoutFilter
noSessionCreation org.apache.shiro.web.filter.session.NoSessionCreationFilter
perms org.apache.shiro.web.filter.authz.PermissionAuthorizationFilter
roles org.apache.shiro.web.filter.authz.RolesAuthorizationFilter
port org.apache.shiro.web.filter.authz.PortFilter
rest org.apache.shiro.web.filter.authz.HttpMethodPermissionFilter
ssl org.apache.shiro.web.filter.authz.SslFilter
user org.apache.shiro.web.filter.authz.UserFilter

Shiro 使用 Redis 存储 Session

Redis 是一个高速的分布式缓存

虽然配置 EhCache 提升了效率,但是,Session 仍然存储在 Server 的内存里(Shiro 默认使用 MemorySessionDAO 把 Session 存储在 ConcurrentMap 里),当有大量的用户登录后 Server 的内存就会急剧增加,而且由于 Server 之间内存里的 Session 不能共享,所以没法实现集群。为了解决这两个问题,我们本地仍然使用 EhCache 缓存 Session,但是 Session 存储在 Redis 里。

流程说明:
  1. Servlet 容器在用户浏览器首次访问后会产生 Session,并将 Session 的 ID 保存到 Cookie 中(浏览器不同 ID 不一定相同),同时 Shiro 会将该 Session 缓存到 Redis 中
  2. 用户登录认证成功后 Shiro 会修改 Session 属性,添加用户认证成功标识,并同步修改 Redis 中 Session
  3. 用户发起请求后,Shiro 会先判断本地 EhCache 缓存中是否存在该 Session,如果有,直接从本地EhCache 缓存中读取,如果没有再从 Redis 中读取 Session,并在此时判断 Session 是否认证通过,如果认证通过将该 Session 缓存到本地 EhCache 中
  4. 如果 Session 发生改变,或被删除(用户退出登录),先对 Redis 中 Session 做相应修改(修改或删除);再通过 Redis 消息通道发布缓存失效消息,通知其它节点 EhCache 失效
注意

1. 实现 SessionDAO

为了把 Session 存储到 Redis,只需要实现 Shiro 提供的 SessionDAO 就可以了。RedisSessionDAO 继承 CachingSessionDAO 来实现 SessionDAO,主要是实现 Session 的 CRUD,SerializationUtils 用来序列化和反序列化,RedisManager 用来访问 Redis。

RedisSessionDAO

package shiro;

import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.ValidatingSession;
import org.apache.shiro.session.mgt.eis.CachingSessionDAO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import util.SerializationUtils;

import java.io.Serializable;
import java.util.*;

public class RedisSessionDAO extends CachingSessionDAO {
    private static Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class);

    // 登录成功的信息存储在 session 的这个 attribute 里.
    private static final String AUTHENTICATED_SESSION_KEY =
            "org.apache.shiro.subject.support.DefaultSubjectContext_AUTHENTICATED_SESSION_KEY";

    private String keyPrefix = "shiro_redis_session:";
    private String deleteChannel = "shiro_redis_session:delete";
    private int timeToLiveSeconds = 1800; // Expiration of Jedis's key, unit: second

    private RedisManager redisManager;

    /**
     * DefaultSessionManager 创建完 session 后会调用该方法。
     * 把 session 保持到 Redis。
     * 返回 Session ID;主要此处返回的 ID.equals(session.getId())
     */
    @Override
    protected Serializable doCreate(Session session) {
        logger.debug("=> Create session with ID [{}]", session.getId());

        // 创建一个Id并设置给Session
        Serializable sessionId = this.generateSessionId(session);
        assignSessionId(session, sessionId);

        // session 由 Redis 缓存失效决定
        String key = SerializationUtils.sessionKey(keyPrefix, session);
        String value = SerializationUtils.sessionToString(session);
        redisManager.setex(key, value, timeToLiveSeconds);

        return sessionId;
    }

    /**
     * 决定从本地 Cache 还是从 Redis 读取 Session.
     * @param sessionId
     * @return
     * @throws UnknownSessionException
     */
    @Override
    public Session readSession(Serializable sessionId) throws UnknownSessionException {
        Session s = getCachedSession(sessionId);

        // 1. 如果本地缓存没有,则从 Redis 读取。
        // 2. ServerA 登录了,ServerB 没有登录但缓存里有此 session,所以从 Redis 读取而不是直接用缓存里的
        if (s == null || (
                s.getAttribute(AUTHENTICATED_SESSION_KEY) != null
                && !(Boolean) s.getAttribute(AUTHENTICATED_SESSION_KEY)
        )) {
            s = doReadSession(sessionId);
            if (s == null) {
                throw new UnknownSessionException("There is no session with id [" + sessionId + "]");
            }
            return s;
        }

        return s;
    }

    /**
     * 从 Redis 上读取 session,并缓存到本地 Cache.
     * @param sessionId
     * @return
     */
    @Override
    protected Session doReadSession(Serializable sessionId) {
        logger.debug("=> Read session with ID [{}]", sessionId);

        String value = redisManager.get(SerializationUtils.sessionKey(keyPrefix, sessionId));

        // 例如 Redis 调用 flushdb 情况了所有的数据,读到的 session 就是空的
        if (value != null) {
            Session session = SerializationUtils.sessionFromString(value);
            super.cache(session, session.getId());

            return session;
        }

        return null;
    }

    /**
     * 更新 session 到 Redis.
     * @param session
     */
    @Override
    protected void doUpdate(Session session) {
        // 如果会话过期/停止,没必要再更新了
        if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
            logger.debug("=> Invalid session.");
            return;
        }

        logger.debug("=> Update session with ID [{}]", session.getId());

        String key = SerializationUtils.sessionKey(keyPrefix, session);
        String value = SerializationUtils.sessionToString(session);
        redisManager.setex(key, value, timeToLiveSeconds);
    }

    /**
     * 从 Redis 删除 session,并且发布消息通知其它 Server 上的 Cache 删除 session.
     * @param session
     */
    @Override
    protected void doDelete(Session session) {
        logger.debug("=> Delete session with ID [{}]", session.getId());

        redisManager.del(SerializationUtils.sessionKey(keyPrefix, session));
        // 发布消息通知其它 Server 上的 cache 删除 session.
        redisManager.publish(deleteChannel, SerializationUtils.sessionIdToString(session));

        // 放在其它类里用一个 daemon 线程执行,删除 cache 中的 session
        // jedis.subscribe(new JedisPubSub() {
        //     @Override
        //     public void onMessage(String channel, String message) {
        //         // 1. deserialize message to sessionId
        //         // 2. Session session = getCachedSession(sessionId);
        //         // 3. uncache(session);
        //     }
        // }, deleteChannel);
    }

    /**
     * 取得所有有效的 session.
     * @return
     */
    @Override
    public Collection<Session> getActiveSessions() {
        logger.debug("=> Get active sessions");
        Set<String> keys = redisManager.keys(keyPrefix + "*");
        Collection<String> values = redisManager.mget(keys.toArray(new String[0]));
        List<Session> sessions = new LinkedList<Session>();

        for (String value : values) {
            sessions.add(SerializationUtils.sessionFromString(value));
        }

        return sessions;
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public void setKeyPrefix(String keyPrefix) {
        this.keyPrefix = keyPrefix;
    }

    public String getDeleteChannel() {
        return deleteChannel;
    }

    public void setDeleteChannel(String deleteChannel) {
        this.deleteChannel = deleteChannel;
    }

    public RedisManager getRedisManager() {
        return redisManager;
    }

    public void setRedisManager(RedisManager redisManager) {
        this.redisManager = redisManager;
    }

    public int getTimeToLiveSeconds() {
        return timeToLiveSeconds;
    }

    public void setTimeToLiveSeconds(int timeToLiveSeconds) {
        this.timeToLiveSeconds = timeToLiveSeconds;
    }
}

RedisManager

package shiro;

import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

import java.util.Collection;
import java.util.Collections;
import java.util.Set;

public class RedisManager {
    private String host = "127.0.0.1";
    private int port = 6379;
    private int timeout = 0; // Timeout for Jedis try to connect to redis server
    private String password = "";

    private JedisPool jedisPool   = null;

    public RedisManager(){
        init();
    }

    /**
     * Initializing jedis pool to connect to Jedis.
     */
    public void init() {
        if(password != null && !"".equals(password)) {
            jedisPool = new JedisPool(new JedisPoolConfig(), host, port, timeout, password);
        } else if (timeout != 0) {
            jedisPool = new JedisPool(new JedisPoolConfig(), host, port, timeout);
        } else {
            jedisPool = new JedisPool(new JedisPoolConfig(), host, port);
        }
    }

    public Jedis getJedis() {
        return jedisPool.getResource();
    }

    /**
     * Get value from Redis
     * @param key
     * @return
     */
    public String get(String key){
        Jedis jedis = jedisPool.getResource();

        try {
            return jedis.get(key);
        } finally {
            jedis.close();
        }
    }

    /**
     * Set value into Redis with default time to live in seconds.
     * @param key
     * @param value
     */
    public void set(String key, String value){
        Jedis jedis = jedisPool.getResource();

        try {
            jedis.set(key, value);
        } finally {
            jedis.close();
        }
    }

    /**
     * Set value into Redis with specified time to live in seconds.
     * @param key
     * @param value
     * @param timeToLiveSeconds
     */
    public void setex(String key, String value, int timeToLiveSeconds){
        Jedis jedis = jedisPool.getResource();

        try {
            jedis.setex(key, timeToLiveSeconds, value);
        } finally {
            jedis.close();
        }
    }

    /**
     * Delete key and its value from Jedis.
     * @param key
     */
    public void del(String key){
        Jedis jedis = jedisPool.getResource();

        try {
            jedis.del(key);
        } finally {
            jedis.close();
        }
    }

    /**
     * Get keys matches the given pattern.
     * @param pattern
     * @return
     */
    public Set<String> keys(String pattern){
        Jedis jedis = jedisPool.getResource();
        try {
            return jedis.keys(pattern);
        } finally {
            jedis.close();
        }
    }

    /**
     * Get multiple values for the given keys.
     * @param keys
     * @return
     */
    public Collection<String> mget(String... keys) {
        if (keys == null && keys.length == 0) {
            Collections.emptySet();
        }

        Jedis jedis = jedisPool.getResource();
        try {
            return jedis.mget(keys);
        } finally {
            jedis.close();
        }
    }

    /**
     * Publish message to channel using subscribe and publish protocol.
     * @param channel
     * @param value
     */
    public void publish(String channel, String value) {
        Jedis jedis = jedisPool.getResource();
        try {
            jedis.publish(channel, value);
        } finally {
            jedis.close();
        }
    }

    public String getHost() {
        return host;
    }

    public void setHost(String host) {
        this.host = host;
    }

    public int getPort() {
        return port;
    }

    public void setPort(int port) {
        this.port = port;
    }

    public int getTimeout() {
        return timeout;
    }

    public void setTimeout(int timeout) {
        this.timeout = timeout;
    }

    public String getPassword() {
        return password;
    }

    public void setPassword(String password) {
        this.password = password;
    }
}

SerializationUtils

package util;

import org.apache.shiro.session.Session;
import org.apache.shiro.session.mgt.SimpleSession;

import java.io.Serializable;

public class SerializationUtils {
    /**
     * 使用 sessionId 创建字符串的 key,用来在 Redis 里作为存储 Session 的 key.
     * @param prefix
     * @param sessionId
     * @return
     */
    public static String sessionKey(String prefix, Serializable sessionId) {
        return prefix + sessionId;
    }

    /**
     * 使用 session 创建字符串的 key,用来在 Redis 里作为存储 Session 的 key.
     * @param prefix
     * @param session
     * @return
     */
    public static String sessionKey(String prefix, Session session) {
        return prefix + session.getId();
    }

    /**
     * 把 sessionId 序列化为 string,因为 Redis 的 key 和 value 必须同时为 string 或者 byte[].
     * @param session
     * @return
     */
    public static String sessionIdToString(Session session) {
        byte[] content = org.apache.commons.lang3.SerializationUtils.serialize(session.getId());
        return org.apache.shiro.codec.Base64.encodeToString(content);
    }

    /**
     * 反序列化得到 sessionId.
     * @param value
     * @return
     */
    public static Serializable sessionIdFromString(String value) {
        byte[] content = org.apache.shiro.codec.Base64.decode(value);
        return org.apache.commons.lang3.SerializationUtils.deserialize(content);
    }

    /**
     * 把 session 序列化为 string,因为 Redis 的 key 和 value 必须同时为 string 或者 byte[].
     * @param value
     * @return
     */
    public static Session sessionToString(String value) {
        byte[] content = org.apache.shiro.codec.Base64.decode(value);
        return org.apache.commons.lang3.SerializationUtils.deserialize(content);
    }

    /**
     * 反序列化得到 session.
     * @param session
     * @return
     */
    public static String sessionFromString(Session session) {
        byte[] content = org.apache.commons.lang3.SerializationUtils.serialize((SimpleSession) session);
        return org.apache.shiro.codec.Base64.encodeToString(content);
    }
}

2. 配置 Session Manager

配置 shiro-spring.xml,使用我们实现的 ShiroSessionDAO 存储和读取 Session,其需要一个 RedisManager 来访问 Redis。

    <!-- 安全管理器 -->
    <bean id="securityManager" class="org.apache.shiro.web.mgt.DefaultWebSecurityManager">
        <property name="realm" ref="userRealm"/>

        <!-- 需要使用cache的话加上这句 -->
        <property name="cacheManager" ref="ehCacheManager" />
        <property name="sessionManager" ref="sessionManager" />
    </bean>

    <!-- 需要使用cache的话加上这句 -->
    <bean id="ehCacheManager" class="org.apache.shiro.cache.ehcache.EhCacheManager">
        <property name="cacheManagerConfigFile" value="classpath:ehcache-shiro.xml" />
    </bean>

    <!--保持 Session 到 Redis-->
    <bean id="redisManager" class="shiro.RedisManager"/>

    <bean id="redisSessionDAO" class="shiro.RedisSessionDAO">
        <property name="redisManager" ref="redisManager"/>
        <property name="timeToLiveSeconds" value="180"/>
    </bean>

    <bean id="sessionManager" class="org.apache.shiro.web.session.mgt.DefaultWebSessionManager">
        <property name="sessionDAO" ref="redisSessionDAO" />
    </bean>