zxl
/
CTT
forked from Cal/CTT
1
0
Fork 0
CTT/Unity/Assets/Model/Module/Network/KService.cs

596 lines
20 KiB
C#
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Runtime.InteropServices;
namespace ET
{
/// <summary>
/// First-byte packet flags for the UDP datagrams exchanged by <c>KService</c>.
/// </summary>
public static class KcpProtocalType
{
    /// <summary>Connection request: [flag][remoteConn:4][localConn:4][optional real-address text].</summary>
    public const byte SYN = 0x01;

    /// <summary>Reply to SYN: [flag][sender localConn:4][sender remoteConn:4], exactly 9 bytes.</summary>
    public const byte ACK = 0x02;

    /// <summary>Disconnect notice: [flag][conn:4][conn:4][error:4], exactly 13 bytes.</summary>
    public const byte FIN = 0x03;

    /// <summary>Data packet: [flag][conn:4] followed by the KCP segment, at least 9 bytes.</summary>
    public const byte MSG = 0x04;

    // The router flags below are not handled in KService.Recv; presumably they are consumed by a
    // router/relay component elsewhere — confirm before relying on their meaning.
    public const byte RouterReconnect = 0x0A;
    public const byte RouterAck = 0x0B;
    public const byte RouterSYN = 0x0C;
}
/// <summary>
/// Role tag stored on a <c>KService</c> by its constructors.
/// NOTE(review): presumably Outer = client-facing and Inner = server-to-server traffic — confirm.
/// </summary>
public enum ServiceType
{
    Outer = 0,
    Inner = 1,
}
public sealed class KService: AService
{
    /// <summary>
    /// Value of TimeHelper.ClientNow() when this service was constructed; the base for <see cref="TimeNow"/>.
    /// </summary>
    public long StartTime;

    /// <summary>
    /// TimeHelper.ClientNow() - StartTime, truncated to uint. Presumably milliseconds: the connect
    /// timeout below treats 10 * 1000 as ten seconds. The original author marks this as thread-safe.
    /// </summary>
    public uint TimeNow
    {
        get
        {
            return (uint) (TimeHelper.ClientNow() - this.StartTime);
        }
    }

    private Socket socket;

    #region Callbacks

    static KService()
    {
        // Native KCP logging is left disabled; enable with Kcp.KcpSetLog(KcpLog) when debugging.
        //Kcp.KcpSetLog(KcpLog);
        Kcp.KcpSetoutput(KcpOutput);
    }

    // Scratch buffer for KcpLog; longer native log lines are truncated to this size.
    private static readonly byte[] logBuffer = new byte[1024];

#if ENABLE_IL2CPP
    // IL2CPP builds the reverse-P/Invoke wrapper from this delegate type, so it must match KcpLog's signature.
    [AOT.MonoPInvokeCallback(typeof(KcpLog))]
#endif
    private static void KcpLog(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
    {
        try
        {
            // Clamp so an oversized message is truncated instead of making Marshal.Copy throw.
            int count = Math.Min(len, logBuffer.Length);
            Marshal.Copy(bytes, logBuffer, 0, count);
            Log.Info(logBuffer.ToStr(0, count));
        }
        catch (Exception e)
        {
            Log.Error(e);
        }
    }

#if ENABLE_IL2CPP
    [AOT.MonoPInvokeCallback(typeof(KcpOutput))]
#endif
    /// <summary>
    /// Native KCP output callback: forwards <paramref name="len"/> bytes to the KChannel registered in
    /// KChannel.kChannels under the key <c>(uint) user</c>. Always reports all bytes as consumed.
    /// </summary>
    private static int KcpOutput(IntPtr bytes, int len, IntPtr kcp, IntPtr user)
    {
        try
        {
            KChannel kChannel = KChannel.kChannels[(uint) user];
            kChannel.Output(bytes, len);
        }
        catch (Exception e)
        {
            Log.Error(e);
            return len;
        }
        return len;
    }

    #endregion

    #region Main thread

    /// <summary>
    /// Creates a service bound to <paramref name="ipEndPoint"/> (listening side) with enlarged socket buffers.
    /// </summary>
    public KService(ThreadSynchronizationContext threadSynchronizationContext, IPEndPoint ipEndPoint, ServiceType serviceType)
    {
        this.ServiceType = serviceType;
        this.ThreadSynchronizationContext = threadSynchronizationContext;
        this.StartTime = TimeHelper.ClientNow();
        this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        // Buffer sizes are left at the OS default on macOS
        // (presumably because it rejects buffers this large — confirm).
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
        {
            this.socket.SendBufferSize = Kcp.OneM * 64;
            this.socket.ReceiveBufferSize = Kcp.OneM * 64;
        }
        this.socket.Bind(ipEndPoint);
        this.DisableUdpConnReset();
    }

    /// <summary>
    /// Creates a service bound to an ephemeral local port (connecting side).
    /// </summary>
    public KService(ThreadSynchronizationContext threadSynchronizationContext, ServiceType serviceType)
    {
        this.ServiceType = serviceType;
        this.ThreadSynchronizationContext = threadSynchronizationContext;
        this.StartTime = TimeHelper.ClientNow();
        this.socket = new Socket(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
        // As a client there is no need to change the send/receive buffer sizes.
        this.socket.Bind(new IPEndPoint(IPAddress.Any, 0));
        this.DisableUdpConnReset();
    }

    // On Windows, turn off SIO_UDP_CONNRESET so an ICMP "port unreachable" from one peer is not
    // reported as a connection reset on this shared UDP socket.
    private void DisableUdpConnReset()
    {
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            return;
        }
        const uint IOC_IN = 0x80000000;
        const uint IOC_VENDOR = 0x18000000;
        uint SIO_UDP_CONNRESET = IOC_IN | IOC_VENDOR | 12;
        this.socket.IOControl((int) SIO_UDP_CONNRESET, new[] { Convert.ToByte(false) }, null);
    }

    /// <summary>
    /// Points channel <paramref name="id"/> at a new remote address; no-op if the channel does not exist.
    /// </summary>
    public void ChangeAddress(long id, IPEndPoint address)
    {
#if NET_THREAD
        this.ThreadSynchronizationContext.Post(() =>
        {
#endif
            KChannel kChannel = this.Get(id);
            if (kChannel == null)
            {
                return;
            }
            Log.Info($"channel change address: {id} {address}");
            kChannel.RemoteAddress = address;
#if NET_THREAD
        }
        );
#endif
    }

    #endregion

    #region Network thread

    // All channels by channel id.
    private readonly Dictionary<long, KChannel> idChannels = new Dictionary<long, KChannel>();
    // All channels by their local conn id (used to route ACK/FIN/MSG packets).
    private readonly Dictionary<long, KChannel> localConnChannels = new Dictionary<long, KChannel>();
    // Accepted-but-not-yet-connected channels by the peer's conn id; pruned in RemoveConnectTimeoutChannels.
    private readonly Dictionary<long, KChannel> waitConnectChannels = new Dictionary<long, KChannel>();
    private readonly List<long> waitRemoveChannels = new List<long>();

    // Shared receive buffer; also reused to build outgoing ACK/FIN packets.
    private readonly byte[] cache = new byte[8192];

    // Source address of the most recent datagram, filled in by ReceiveFrom.
    private EndPoint ipEndPoint = new IPEndPoint(IPAddress.Any, 0);

    private readonly Random random = new Random(Guid.NewGuid().GetHashCode());

    // Channels to update on the next Update() call (filled by AddToUpdateNextTime and TimerOut).
    private HashSet<long> updateChannels = new HashSet<long>();
    // Spare set swapped with updateChannels in Update(); holds the batch currently being processed.
    private HashSet<long> updatingChannels = new HashSet<long>();

    // Scheduled update time -> channel ids.
    private readonly MultiMap<long, long> timeId = new MultiMap<long, long>();
    private readonly List<long> timeOutTime = new List<long>();

    // Lower bound on the earliest key in timeId (long.MaxValue when nothing is pending), so TimerOut
    // can skip scanning the MultiMap when nothing is due.
    private long minTime = long.MaxValue;

    public override bool IsDispose()
    {
        return this.socket == null;
    }

    /// <summary>
    /// Removes every channel and closes the socket. Safe to call more than once.
    /// </summary>
    public override void Dispose()
    {
        if (this.socket == null)
        {
            return;
        }
        foreach (long channelId in this.idChannels.Keys.ToArray())
        {
            this.Remove(channelId);
        }
        this.socket.Close();
        this.socket = null;
    }

    // Copies the shared ipEndPoint so a channel does not keep a reference that later receives overwrite.
    private IPEndPoint CloneAddress()
    {
        IPEndPoint ip = (IPEndPoint) this.ipEndPoint;
        return new IPEndPoint(ip.Address, ip.Port);
    }

    /// <summary>
    /// Drains all pending datagrams from the socket and dispatches them by their first-byte flag.
    /// </summary>
    private void Recv()
    {
        if (this.socket == null)
        {
            return;
        }

        // Re-check the socket each iteration: handling a packet may presumably dispose this service.
        while (this.socket != null && this.socket.Available > 0)
        {
            int messageLength;
            try
            {
                messageLength = this.socket.ReceiveFrom(this.cache, ref this.ipEndPoint);
            }
            catch (SocketException e)
            {
                Log.Error($"kservice recv error: {e.SocketErrorCode}\n{e}");
                // A datagram larger than the cache has already been consumed and truncated, so keep
                // draining; for any other error stop for this frame rather than spin on the socket.
                if (e.SocketErrorCode == SocketError.MessageSize)
                {
                    continue;
                }
                break;
            }

            // Shorter than the flag byte: not a valid message.
            if (messageLength < 1)
            {
                continue;
            }

            byte flag = this.cache[0];
            uint remoteConn = 0;
            uint localConn = 0;
            try
            {
                KChannel kChannel = null;
                switch (flag)
                {
#if NOT_CLIENT
                    case KcpProtocalType.SYN: // connection request (accept side)
                    {
                        // Layout: [flag][remoteConn:4][localConn:4][optional real-address text].
                        if (messageLength < 9)
                        {
                            break;
                        }

                        string realAddress = null;
                        if (messageLength > 9)
                        {
                            realAddress = this.cache.ToStr(9, messageLength - 9);
                        }
                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);

                        // A retried SYN finds the channel created by the first one.
                        this.waitConnectChannels.TryGetValue(remoteConn, out kChannel);
                        if (kChannel == null)
                        {
                            localConn = CreateRandomLocalConn(this.random);
                            // The localConn is already taken: ignore and wait for the next SYN.
                            if (this.localConnChannels.ContainsKey(localConn))
                            {
                                break;
                            }
                            long id = this.CreateAcceptChannelId(localConn);
                            if (this.idChannels.ContainsKey(id))
                            {
                                break;
                            }

                            // Parse the forwarded address before registering anything, so a malformed
                            // value cannot leave a half-registered channel that OnAccept never saw.
                            IPEndPoint parsedRealAddress = realAddress == null? null : NetworkHelper.ToIPEndPoint(realAddress);

                            kChannel = new KChannel(id, localConn, remoteConn, this.socket, this.CloneAddress(), this);
                            this.idChannels.Add(kChannel.Id, kChannel);
                            this.waitConnectChannels.Add(kChannel.RemoteConn, kChannel); // removed once connected or timed out
                            this.localConnChannels.Add(kChannel.LocalConn, kChannel);
                            kChannel.RealAddress = realAddress;
                            IPEndPoint realEndPoint = realAddress == null? kChannel.RemoteAddress : parsedRealAddress;
                            this.OnAccept(kChannel.Id, realEndPoint);
                        }
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }
                        // The address differs from the one in the first SYN: skip.
                        if (kChannel.RealAddress != realAddress)
                        {
                            Log.Error($"kchannel syn address diff: {kChannel.Id} {kChannel.RealAddress} {realAddress}");
                            break;
                        }
                        try
                        {
                            // Reply with ACK = [flag][our localConn][peer's conn], reusing the receive buffer.
                            byte[] buffer = this.cache;
                            buffer.WriteTo(0, KcpProtocalType.ACK);
                            buffer.WriteTo(1, kChannel.LocalConn);
                            buffer.WriteTo(5, kChannel.RemoteConn);
                            Log.Info($"kservice syn: {kChannel.Id} {remoteConn} {localConn}");
                            this.socket.SendTo(buffer, 0, 9, SocketFlags.None, kChannel.RemoteAddress);
                        }
                        catch (Exception e)
                        {
                            Log.Error(e);
                            kChannel.OnError(ErrorCode.ERR_SocketCantSend);
                        }
                        break;
                    }
#endif
                    case KcpProtocalType.ACK: // reply to our connect
                        // Exactly 9 bytes: [flag][remoteConn:4][localConn:4].
                        if (messageLength != 9)
                        {
                            break;
                        }
                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel != null)
                        {
                            Log.Info($"kservice ack: {kChannel.Id} {remoteConn} {localConn}");
                            kChannel.RemoteConn = remoteConn;
                            kChannel.HandleConnnect();
                        }
                        break;
                    case KcpProtocalType.FIN: // disconnect
                        // Exactly 13 bytes: [flag][remoteConn:4][localConn:4][error:4].
                        if (messageLength != 13)
                        {
                            break;
                        }
                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        int error = BitConverter.ToInt32(this.cache, 9);
                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel == null)
                        {
                            break;
                        }
                        // Check remoteConn so a third party cannot close the channel.
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }
                        Log.Info($"kservice recv fin: {kChannel.Id} {localConn} {remoteConn} {error}");
                        kChannel.OnError(ErrorCode.ERR_PeerDisconnect);
                        break;
                    case KcpProtocalType.MSG: // data
                        // At least 9 bytes: [flag][remoteConn:4][localConn:4]...
                        if (messageLength < 9)
                        {
                            break;
                        }
                        remoteConn = BitConverter.ToUInt32(this.cache, 1);
                        localConn = BitConverter.ToUInt32(this.cache, 5);
                        kChannel = this.GetByLocalConn(localConn);
                        if (kChannel == null)
                        {
                            // Unknown channel: tell the peer to disconnect.
                            this.Disconnect(localConn, remoteConn, ErrorCode.ERR_KcpNotFoundChannel, (IPEndPoint) this.ipEndPoint, 1);
                            break;
                        }
                        // Check remoteConn so a third party cannot inject data.
                        if (kChannel.RemoteConn != remoteConn)
                        {
                            break;
                        }
                        // The payload handed on starts at byte 5, i.e. it includes the localConn field.
                        kChannel.HandleRecv(this.cache, 5, messageLength - 5);
                        break;
                }
            }
            catch (Exception e)
            {
                Log.Error($"kservice error: {flag} {remoteConn} {localConn}\n{e}");
            }
        }
    }

    private KChannel Get(long id)
    {
        KChannel channel;
        this.idChannels.TryGetValue(id, out channel);
        return channel;
    }

    private KChannel GetByLocalConn(uint localConn)
    {
        KChannel channel;
        this.localConnChannels.TryGetValue(localConn, out channel);
        return channel;
    }

    /// <summary>
    /// Creates (connect side) the channel <paramref name="id"/> to <paramref name="address"/> unless it exists.
    /// The low 32 bits of the id are used as the channel's localConn.
    /// </summary>
    protected override void Get(long id, IPEndPoint address)
    {
        if (this.idChannels.ContainsKey(id))
        {
            return;
        }
        try
        {
            uint localConn = (uint)((ulong) id & uint.MaxValue);
            // Check before registering so a colliding localConn cannot leave the channel in idChannels
            // but missing from localConnChannels (mirrors the check on the SYN path).
            if (this.localConnChannels.ContainsKey(localConn))
            {
                Log.Error($"kservice get error: localConn already in use: {id} {localConn}");
                return;
            }
            KChannel channel = new KChannel(id, localConn, this.socket, address, this);
            this.idChannels.Add(id, channel);
            this.localConnChannels.Add(channel.LocalConn, channel);
        }
        catch (Exception e)
        {
            Log.Error($"kservice get error: {id}\n{e}");
        }
    }

    /// <summary>
    /// Unregisters channel <paramref name="id"/> from every lookup table and disposes it.
    /// </summary>
    public override void Remove(long id)
    {
        if (!this.idChannels.TryGetValue(id, out KChannel kChannel))
        {
            return;
        }
        Log.Info($"kservice remove channel: {id} {kChannel.LocalConn} {kChannel.RemoteConn}");
        this.idChannels.Remove(id);
        this.localConnChannels.Remove(kChannel.LocalConn);
        // Only drop the wait entry if it belongs to this channel (another channel may share the key).
        if (this.waitConnectChannels.TryGetValue(kChannel.RemoteConn, out KChannel waitChannel)
            && waitChannel.LocalConn == kChannel.LocalConn)
        {
            this.waitConnectChannels.Remove(kChannel.RemoteConn);
        }
        kChannel.Dispose();
    }

    /// <summary>
    /// Sends <paramref name="times"/> copies of a 13-byte FIN packet to <paramref name="address"/>.
    /// Uses the shared cache buffer; errors are logged, not thrown.
    /// </summary>
    private void Disconnect(uint localConn, uint remoteConn, int error, IPEndPoint address, int times)
    {
        try
        {
            if (this.socket == null)
            {
                return;
            }
            byte[] buffer = this.cache;
            buffer.WriteTo(0, KcpProtocalType.FIN);
            buffer.WriteTo(1, localConn);
            buffer.WriteTo(5, remoteConn);
            buffer.WriteTo(9, (uint) error);
            for (int i = 0; i < times; ++i)
            {
                this.socket.SendTo(buffer, 0, 13, SocketFlags.None, address);
            }
        }
        catch (Exception e)
        {
            Log.Error($"Disconnect error {localConn} {remoteConn} {error} {address} {e}");
        }
        Log.Info($"channel send fin: {localConn} {remoteConn} {address} {error}");
    }

    protected override void Send(long channelId, long actorId, MemoryStream stream)
    {
        KChannel channel = this.Get(channelId);
        if (channel == null)
        {
            return;
        }
        channel.Send(actorId, stream);
    }

    /// <summary>
    /// Schedules channel <paramref name="id"/> for update: at the next Update() when
    /// <paramref name="time"/> is 0, otherwise once TimeNow reaches <paramref name="time"/>.
    /// </summary>
    public void AddToUpdateNextTime(long time, long id)
    {
        if (time == 0)
        {
            this.updateChannels.Add(id);
            return;
        }
        if (time < this.minTime)
        {
            this.minTime = time;
        }
        this.timeId.Add(time, id);
    }

    /// <summary>
    /// Per-frame tick: receive packets, promote due timers, update scheduled channels, prune pending connects.
    /// </summary>
    public override void Update()
    {
        this.Recv();
        this.TimerOut();

        // Swap the two sets before iterating. Code reached from kChannel.Update() can call the public
        // AddToUpdateNextTime(0, id); with the swap that lands in the fresh set for the next frame
        // instead of modifying the set being enumerated.
        HashSet<long> dueChannels = this.updateChannels;
        this.updateChannels = this.updatingChannels;
        this.updatingChannels = dueChannels;

        foreach (long id in dueChannels)
        {
            KChannel kChannel = this.Get(id);
            if (kChannel == null)
            {
                continue;
            }
            if (kChannel.Id == 0)
            {
                continue;
            }
            kChannel.Update();
        }
        dueChannels.Clear();

        this.RemoveConnectTimeoutChannels();
    }

    // Drops entries from waitConnectChannels once connected or after a 10 second connect timeout.
    private void RemoveConnectTimeoutChannels()
    {
        this.waitRemoveChannels.Clear();
        uint timeNow = this.TimeNow;
        foreach (KeyValuePair<long, KChannel> kv in this.waitConnectChannels)
        {
            KChannel kChannel = kv.Value;
            if (kChannel == null)
            {
                Log.Error($"RemoveConnectTimeoutChannels not found kchannel: {kv.Key}");
                continue;
            }
            // Remove right away once connected, or after the 10 second connect timeout.
            if (kChannel.IsConnected || timeNow > kChannel.CreateTime + 10 * 1000)
            {
                this.waitRemoveChannels.Add(kv.Key);
            }
        }
        foreach (long channelId in this.waitRemoveChannels)
        {
            this.waitConnectChannels.Remove(channelId);
        }
    }

    // Moves channels whose scheduled time has passed from timeId into updateChannels.
    // Assumes MultiMap enumerates keys in ascending order, as the early break requires.
    private void TimerOut()
    {
        if (this.timeId.Count == 0)
        {
            return;
        }
        uint timeNow = this.TimeNow;
        if (timeNow < this.minTime)
        {
            return;
        }

        this.timeOutTime.Clear();
        // Reset so the bound is exact again: either the first not-yet-due key, or "nothing pending".
        this.minTime = long.MaxValue;
        foreach (KeyValuePair<long, List<long>> kv in this.timeId)
        {
            long k = kv.Key;
            if (k > timeNow)
            {
                this.minTime = k;
                break;
            }
            this.timeOutTime.Add(k);
        }

        foreach (long k in this.timeOutTime)
        {
            foreach (long v in this.timeId[k])
            {
                this.updateChannels.Add(v);
            }
            this.timeId.Remove(k);
        }
    }

    #endregion
}
}