前段时间看到一篇不错的文章《看了这篇你就会手写RPC框架了》,于是便来了兴趣对着实现了一遍,后面觉得还有很多优化的地方便对其进行了改进。
主要改动点如下:
除了Java序列化协议,增加了protobuf和kryo序列化协议,配置即用。 增加多种负载均衡算法(随机、轮询、加权轮询、平滑加权轮询),配置即用。 由原来的每个请求建立一次连接,改为建立TCP长连接,并多次复用。 RPC,即 Remote Procedure Call(远程过程调用),调用远程计算机上的服务,就像调用本地服务一样。RPC可以很好的解耦系统,如WebService就是一种基于Http协议的RPC。
调用示意图
调用示意图
总的来说,就如下几个步骤:
客户端(ServerA)执行远程方法时就调用client stub传递类名、方法名和参数等信息。 client stub会将参数等信息序列化为二进制流的形式,然后通过Sockect发送给服务端(ServerB) 服务端收到数据包后,server stub 需要进行解析反序列化为类名、方法名和参数等信息。 server stub调用对应的本地方法,并把执行结果返回给客户端 所以一个RPC框架有如下角色:
远程方法的调用方,即客户端。一个服务既可以是消费者也可以是提供者。
远程服务的提供方,即服务端。一个服务既可以是消费者也可以是提供者。
保存服务提供者的服务地址等信息,一般由zookeeper、redis等实现。
监控接口的响应时间、统计请求数量等,及时发现系统问题并发出告警通知。
本RPC框架rpc-spring-boot-starter涉及技术栈如下:
消息编解码:protostuff、kryo、java
由于代码过多,这里只讲几处改动点。
1.编写LoadBalance的实现类
负载均衡算法实现类
2.自定义注解 @LoadBalanceAno
/** * 负载均衡注解 */ @Target (ElementType.TYPE)@Retention (RetentionPolicy.RUNTIME)@Documented public @interface LoadBalanceAno { String value () default "" ; }/** * 轮询算法 */ @LoadBalanceAno (RpcConstant.BALANCE_ROUND)public class FullRoundBalance implements LoadBalance { private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class ) ; private volatile int index; @Override public synchronized Service chooseOne (List services) { // 加锁防止多线程情况下,index超出services.size() if (index == services.size()) { index = 0 ; } return services.get(index++); } }
3.新建在resource目录下META-INF/servers文件夹并创建文件
enter description here
4.RpcConfig增加配置项loadBalance
/** * @author 2YSP * @date 2020/7/26 15:13 */ @ConfigurationProperties (prefix = "sp.rpc" )public class RpcConfig { /** * 服务注册中心地址 */ private String registerAddress = "127.0.0.1:2181" ; /** * 服务暴露端口 */ private Integer serverPort = 9999 ; /** * 服务协议 */ private String protocol = "java" ; /** * 负载均衡算法 */ private String loadBalance = "random" ; /** * 权重,默认为1 */ private Integer weight = 1 ; // 省略getter setter }
5.在自动配置类RpcAutoConfiguration根据配置选择对应的算法实现类
/** * 使用spi匹配符合配置的负载均衡算法 * * @param name * @return */ private LoadBalance getLoadBalance (String name) { ServiceLoader loader = ServiceLoader.load(LoadBalance.class ) ; Iterator iterator = loader.iterator(); while (iterator.hasNext()) { LoadBalance loadBalance = iterator.next(); LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class ) ; Assert.notNull(ano, "load balance name can not be empty!" ); if (name.equals(ano.value())) { return loadBalance; } }
throw new RpcException("invalid load balance config" ); } @Bean public ClientProxyFactory proxyFactory (@Autowired RpcConfig rpcConfig) { ClientProxyFactory clientProxyFactory = new ClientProxyFactory(); // 设置服务发现着 clientProxyFactory.setServerDiscovery(new ZookeeperServerDiscovery(rpcConfig.getRegisterAddress())); // 设置支持的协议 Map supportMessageProtocols = buildSupportMessageProtocols(); clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols); // 设置负载均衡算法 LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance()); clientProxyFactory.setLoadBalance(loadBalance); // 设置网络层实现 clientProxyFactory.setNetClient(new NettyNetClient()); return clientProxyFactory; }
使用Map来缓存数据
/** * 服务发现本地缓存 */ public class ServerDiscoveryCache { /** * key: serviceName */ private static final Map> SERVER_MAP = new ConcurrentHashMap<>(); /** * 客户端注入的远程服务service class */ public static final List SERVICE_CLASS_NAMES = new ArrayList<>(); public static void put (String serviceName, List serviceList) { SERVER_MAP.put(serviceName, serviceList); } /** * 去除指定的值 * @param serviceName * @param service */ public static void remove (String serviceName, Service service) { SERVER_MAP.computeIfPresent(serviceName, (key, value) -> value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList()) ); } public static void removeAll (String serviceName) { SERVER_MAP.remove(serviceName); } public static boolean isEmpty (String serviceName) { return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0 ; } public static List get (String serviceName) { return SERVER_MAP.get(serviceName); } }
ClientProxyFactory,先查本地缓存,缓存没有再查询zookeeper。
/** * 根据服务名获取可用的服务地址列表 * @param serviceName * @return */ private List getServiceList (String serviceName) { List services; synchronized (serviceName){ if (ServerDiscoveryCache.isEmpty(serviceName)) { services = serverDiscovery.findServiceList(serviceName); if (services == null || services.size() == 0 ) { throw new RpcException("No provider available!" ); } ServerDiscoveryCache.put(serviceName, services); } else { services = ServerDiscoveryCache.get(serviceName); } } return services; }
问题: 如果服务端因为宕机或网络问题下线了,缓存却还在就会导致客户端请求已经不可用的服务端,增加请求失败率。解决方案: 由于服务端注册的是临时节点,所以如果服务端下线节点会被移除。只要监听zookeeper的子节点,如果新增或删除子节点就直接清空本地缓存即可。
DefaultRpcProcessor
/** * Rpc处理者,支持服务启动暴露,自动注入Service * @author 2YSP * @date 2020/7/26 14:46 */ public class DefaultRpcProcessor implements ApplicationListener <ContextRefreshedEvent > { @Override public void onApplicationEvent (ContextRefreshedEvent event) { // Spring启动完毕过后会收到一个事件通知 if (Objects.isNull(event.getApplicationContext().getParent())){ ApplicationContext context = event.getApplicationContext(); // 开启服务 startServer(context); // 注入Service injectService(context); } } private void injectService (ApplicationContext context) { String[] names = context.getBeanDefinitionNames(); for (String name : names){ Class> clazz = context.getType(name); if (Objects.isNull(clazz)){ continue ; } Field[] declaredFields = clazz.getDeclaredFields(); for (Field field : declaredFields){ // 找出标记了InjectService注解的属性 InjectService injectService = field.getAnnotation(InjectService.class ) ; if (injectService == null ){ continue ; } Class> fieldClass = field.getType(); Object object = context.getBean(name); field.setAccessible(true ); try { field.set(object,clientProxyFactory.getProxy(fieldClass)); } catch (IllegalAccessException e) { e.printStackTrace(); } // 添加本地服务缓存 ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName()); } } // 注册子节点监听 if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){ ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery(); ZkClient zkClient = serverDiscovery.getZkClient(); ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{ String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service" ; zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl()); }); logger.info("subscribe service zk node successfully" ); } } private void startServer (ApplicationContext context) { ... } }
ZkChildListenerImpl
/** * 子节点事件监听处理类 */ public class ZkChildListenerImpl implements IZkChildListener { private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class ) ; /** * 监听子节点的删除和新增事件 * @param parentPath /rpc/serviceName/service * @param childList * @throws Exception */ @Override public void handleChildChange (String parentPath, List childList) throws Exception { logger.debug("Child change parentPath:[{}] -- childList:[{}]" , parentPath, childList); // 只要子节点有改动就清空缓存 String[] arr = parentPath.split("/" ); ServerDiscoveryCache.removeAll(arr[2 ]); } }
这部分的改动最多,先增加新的sendRequest接口。
添加接口
实现类NettyNetClient
/** * @author 2YSP * @date 2020/7/25 20:12 */ public class NettyNetClient implements NetClient { private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class ) ; private static ExecutorService threadPool = new ThreadPoolExecutor(4 , 10 , 200 , TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000 ), new ThreadFactoryBuilder() .setNameFormat("rpcClient-%d" ) .build()); private EventLoopGroup loopGroup = new NioEventLoopGroup(4 ); /** * 已连接的服务缓存 * key: 服务地址,格式:ip:port */ public static Map connectedServerNodes = new ConcurrentHashMap<>(); @Override public byte [] sendRequest(byte [] data, Service service) throws InterruptedException { .... return respData; } @Override public RpcResponse sendRequest (RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) { String address = service.getAddress(); synchronized (address) { if (connectedServerNodes.containsKey(address)) { SendHandlerV2 handler = connectedServerNodes.get(address); logger.info("使用现有的连接" ); return handler.sendRequest(rpcRequest); } String[] addrInfo = address.split(":" ); final String serverAddress = addrInfo[0 ]; final String serverPort = addrInfo[1 ]; final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address); threadPool.submit(() -> { // 配置客户端 Bootstrap b = new Bootstrap(); b.group(loopGroup).channel(NioSocketChannel.class ) .option (ChannelOption .TCP_NODELAY , true ) .handler (new ChannelInitializer <SocketChannel >() { @Override protected void initChannel (SocketChannel socketChannel) throws Exception { ChannelPipeline pipeline = socketChannel.pipeline(); pipeline .addLast(handler); } }); // 启用客户端连接 ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort)); channelFuture.addListener(new ChannelFutureListener() { @Override public void operationComplete (ChannelFuture channelFuture) throws Exception { connectedServerNodes.put(address, handler); } }); } ); logger.info("使用新的连接。。。" ); return handler.sendRequest(rpcRequest); } } }
每次请求都会调用sendRequest()方法,用线程池异步和服务端创建TCP长连接,连接成功后将SendHandlerV2缓存到ConcurrentHashMap中方便复用,后续请求的请求地址(ip+port)如果在connectedServerNodes中存在则使用connectedServerNodes中的handler处理不再重新建立连接。
SendHandlerV2
/** * @author 2YSP * @date 2020/8/19 20:06 */ public class SendHandlerV2 extends ChannelInboundHandlerAdapter { private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class ) ; /** * 等待通道建立最大时间 */ static final int CHANNEL_WAIT_TIME = 4 ; /** * 等待响应最大时间 */ static final int RESPONSE_WAIT_TIME = 8 ; private volatile Channel channel; private String remoteAddress; private static Map> requestMap = new ConcurrentHashMap<>(); private MessageProtocol messageProtocol; private CountDownLatch latch = new CountDownLatch(1 ); public SendHandlerV2 (MessageProtocol messageProtocol,String remoteAddress) { this .messageProtocol = messageProtocol; this .remoteAddress = remoteAddress; } @Override public void channelRegistered (ChannelHandlerContext ctx) throws Exception { this .channel = ctx.channel(); latch.countDown(); } @Override public void channelActive (ChannelHandlerContext ctx) throws Exception { logger.debug("Connect to server successfully:{}" , ctx); } @Override public void channelRead (ChannelHandlerContext ctx, Object msg) throws Exception { logger.debug("Client reads message:{}" , msg); ByteBuf byteBuf = (ByteBuf) msg; byte [] resp = new byte [byteBuf.readableBytes()]; byteBuf.readBytes(resp); // 手动回收 ReferenceCountUtil.release(byteBuf); RpcResponse response = messageProtocol.unmarshallingResponse(resp); RpcFuture future = requestMap.get(response.getRequestId()); future.setResponse(response); } @Override public void exceptionCaught (ChannelHandlerContext ctx, Throwable cause) throws Exception { cause.printStackTrace(); logger.error("Exception occurred:{}" , cause.getMessage()); ctx.close(); } @Override public void channelReadComplete (ChannelHandlerContext ctx) throws Exception { ctx.flush(); } @Override public void channelInactive (ChannelHandlerContext ctx) throws Exception { super .channelInactive(ctx); logger.error("channel inactive with remoteAddress:[{}]" ,remoteAddress); NettyNetClient.connectedServerNodes.remove(remoteAddress); } @Override public void userEventTriggered (ChannelHandlerContext ctx, Object evt) throws Exception { super .userEventTriggered(ctx, evt); } public RpcResponse sendRequest (RpcRequest request) { RpcResponse response; RpcFuture future = new RpcFuture<>(); requestMap.put(request.getRequestId(), future); try { byte [] data = messageProtocol.marshallingRequest(request);
ByteBuf reqBuf = Unpooled.buffer(data.length); reqBuf.writeBytes(data); if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){ channel.writeAndFlush(reqBuf); // 等待响应 response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS); }else { throw new RpcException("establish channel time out" ); } } catch (Exception e) { throw new RpcException(e.getMessage()); } finally { requestMap.remove(request.getRequestId()); } return response; } }
RpcFuture
package cn.sp.rpc.client.net;import java.util.concurrent.*;/** * @author 2YSP * @date 2020/8/19 22:31 */ public class RpcFuture <T > implements Future <T > { private T response; /** * 因为请求和响应是一一对应的,所以这里是1 */ private CountDownLatch countDownLatch = new CountDownLatch(1 ); /** * Future的请求时间,用于计算Future是否超时 */ private long beginTime = System.currentTimeMillis(); @Override public boolean cancel (boolean mayInterruptIfRunning) { return false ; } @Override public boolean isCancelled () { return false ; } @Override public boolean isDone () { if (response != null ) { return true ; } return false ; } /** * 获取响应,直到有结果才返回 * @return * @throws InterruptedException * @throws ExecutionException */ @Override public T get () throws InterruptedException, ExecutionException { countDownLatch.await(); return response; } @Override public T get (long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { if (countDownLatch.await(timeout,unit)){ return response; } return null ; } public void setResponse (T response) { this .response = response; countDownLatch.countDown(); } public long getBeginTime () { return beginTime; } }
此处逻辑,第一次执行 SendHandlerV2#sendRequest() 时channel需要等待通道建立好之后才能发送请求,所以用CountDownLatch来控制,等待通道建立。自定义Future+requestMap缓存来实现netty的请求和阻塞等待响应,RpcRequest对象在创建时会生成一个请求的唯一标识requestId,发送请求前先将RpcFuture缓存到requestMap中,key为requestId,读取到服务端的响应信息后(channelRead方法),将响应结果放入对应的RpcFuture中。SendHandlerV2#channelInactive() 方法中,如果连接的服务端异常断开连接了,则及时清理缓存中对应的serverNode。
测试环境:
(英特尔)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz 4核 1.本地启动zookeeper 2.本地启动一个消费者,两个服务端,轮询算法 3.使用ab进行压力测试,4个线程发送10000个请求
ab -c 4 -n 10000 http://localhost:8080/test/user?id=1
测试结果 :
测试结果
从图片可以看出,10000个请求只用了11s,比之前的130+秒耗时减少了10倍以上。
首先点击右下方在看 ,再长按下方二维码关注哦,并后台回复 二维码