TCP/UDP 网络协议教程 / 12-UDP 编程实战
12 - UDP 编程实战
12.1 UDP 基础收发
UDP 服务器
"""UDP 回显服务器"""
import socket
def udp_echo_server(host='0.0.0.0', port=9999, bufsize=65535):
    """Echo every received UDP datagram back to its sender.

    Binds to (*host*, *port*) and loops until interrupted. The ``with`` block
    closes the socket however the loop is left (e.g. Ctrl+C).

    *bufsize* is the receive buffer per datagram. It defaults to 65535 so that
    no IPv4 UDP payload is truncated. A smaller buffer makes ``recvfrom``
    drop the excess bytes (or fail, depending on the platform), and then the
    echo would no longer match what the client sent.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))
        print(f"UDP 服务器启动: {host}:{port}")
        while True:
            data, addr = sock.recvfrom(bufsize)
            print(f"收到来自 {addr}: {data}")
            sock.sendto(data, addr)  # echo the datagram back unchanged


if __name__ == '__main__':
    udp_echo_server()
UDP 客户端
"""UDP 客户端"""
import socket
def udp_client(host='127.0.0.1', port=9999, count=5, timeout=3):
    """Send *count* numbered datagrams to (*host*, *port*) and print each echo.

    Each reply is awaited for at most *timeout* seconds. A missed reply is
    reported and the client moves on to the next message. The defaults
    reproduce the original hard-coded behaviour: 127.0.0.1:9999, five
    messages and a 3 s timeout. The socket is closed even if an unexpected
    error escapes the loop.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        for i in range(count):
            message = f"Hello UDP {i}".encode()
            sock.sendto(message, (host, port))
            try:
                data, _ = sock.recvfrom(4096)
                print(f"收到: {data.decode()}")
            except socket.timeout:
                print("接收超时")


# Only run the demo when executed as a script, not when imported.
if __name__ == '__main__':
    udp_client()
12.2 UDP 广播
"""局域网广播发现"""
import socket
import json
import time
def broadcast_server(port=37020, name='MyServer', service_port=8080):
    """Answer LAN discovery broadcasts with a JSON description of this service.

    Listens on UDP *port* on all interfaces. Every datagram equal to
    ``b'DISCOVER'`` gets a unicast reply to its sender containing
    ``{"name": name, "ip": <local IP>, "port": service_port}``. Any other
    datagram is ignored. The loop runs until interrupted, and the ``with``
    block closes the socket on the way out.

    The defaults for *name* and *service_port* reproduce the original
    hard-coded values.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0', port))
        while True:
            data, addr = sock.recvfrom(4096)
            if data != b'DISCOVER':
                continue
            # NOTE(review): gethostbyname(gethostname()) depends on the local
            # hosts/DNS configuration and may yield a loopback address on some
            # systems. Clients can fall back to the reply's source address.
            info = json.dumps({
                'name': name,
                'ip': socket.gethostbyname(socket.gethostname()),
                'port': service_port,
            })
            sock.sendto(info.encode(), addr)
def broadcast_discover(port=37020, timeout=2):
    """Broadcast a discovery request and return every server that answers.

    Sends ``b'DISCOVER'`` to 255.255.255.255:*port*, then collects replies
    until no datagram arrives for *timeout* seconds. Each reply is decoded as
    UTF-8 JSON, printed, and appended to the returned list. A reply that is
    not valid UTF-8 JSON is skipped, so one bad responder cannot abort the
    whole discovery. The socket is closed even if sending fails.
    """
    servers = []
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.settimeout(timeout)
        sock.sendto(b'DISCOVER', ('255.255.255.255', port))
        while True:
            try:
                data, _ = sock.recvfrom(4096)
            except socket.timeout:
                break  # quiet for `timeout` seconds: discovery is over
            try:
                info = json.loads(data.decode())
            except ValueError:  # UnicodeDecodeError and JSONDecodeError both subclass it
                continue
            servers.append(info)
            print(f"发现: {info}")
    return servers
# servers = broadcast_discover()
12.3 UDP 多播
"""UDP 多播发送和接收"""
import socket
import struct
import threading
# Multicast group address and UDP port shared by the sender and the receiver
# below. (For private applications, the administratively scoped 239.0.0.0/8
# range from RFC 2365 is the conventional choice of group address.)
MCAST_GROUP, MCAST_PORT = '224.1.1.1', 5007
def multicast_sender(count=10, interval=1.0):
    """Send *count* numbered datagrams to MCAST_GROUP:MCAST_PORT.

    Sends one datagram every *interval* seconds. A multicast TTL of 2 lets
    packets cross at most one router. The defaults reproduce the original
    10 messages at 1-second spacing. The socket is closed even if sending
    fails.
    """
    import time  # imported locally: this section's own imports do not include time

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, 2)
        for i in range(count):
            msg = f"Multicast #{i}".encode()
            sock.sendto(msg, (MCAST_GROUP, MCAST_PORT))
            print(f"发送: {msg.decode()}")
            time.sleep(interval)
def multicast_receiver():
    """Join MCAST_GROUP and print every datagram received on MCAST_PORT.

    Binds to MCAST_PORT on all local addresses, then joins the group with
    ``IP_ADD_MEMBERSHIP``. Runs until interrupted, and the ``with`` block
    closes the socket on the way out. Payloads are decoded as UTF-8, with
    invalid bytes replaced, so a stray binary datagram cannot crash the loop.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('', MCAST_PORT))
        # ip_mreq struct: 4-byte group address + 4-byte local interface
        # address. INADDR_ANY (0.0.0.0) lets the kernel pick the interface.
        mreq = struct.pack('4s4s',
                           socket.inet_aton(MCAST_GROUP),
                           socket.inet_aton('0.0.0.0'))
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)
        print(f"监听多播 {MCAST_GROUP}:{MCAST_PORT}")
        while True:
            data, addr = sock.recvfrom(4096)
            print(f"收到来自 {addr}: {data.decode(errors='replace')}")
12.4 应用层可靠性 UDP
简单可靠 UDP 协议
"""应用层可靠传输"""
import socket
import struct
import time
import threading
# Wire format: [seq 4B][ack 4B][flags 1B][payload length 2B][payload]
# '!' means network byte order with no padding, so HEADER_SIZE is 11 bytes.
HEADER_FMT = '!IIBH'
HEADER_SIZE = struct.calcsize(HEADER_FMT)
FLAG_DATA = 0x01  # packet carries application data and must be ACKed
FLAG_ACK = 0x02   # packet acknowledges the seq carried in its ack field
FLAG_FIN = 0x04   # defined for the protocol; not referenced by the code shown here


class ReliableUDP:
    """Per-packet ACK + timeout retransmission layered on a UDP socket.

    Every DATA packet sent with ``send_reliable`` is kept in ``unacked``
    until the peer's ACK is processed by ``recv_reliable``. The caller must
    call ``check_retransmissions`` periodically to resend packets older than
    ``rto`` seconds.

    On the receiving side, each DATA packet is ACKed (duplicates included,
    since a duplicate means the sender missed our ACK). A packet is handed to
    the application only the first time its (peer address, seq) is seen.
    Delivery order is not guaranteed.

    NOTE(review): duplicate tracking assumes each peer address numbers its
    packets from 0 for the lifetime of this object. A peer that restarts on
    the same address with fresh sequence numbers would be treated as
    sending duplicates.
    """

    def __init__(self, sock=None):
        self.sock = sock if sock is not None else socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.seq_num = 0
        self.unacked = {}  # {seq: (packet incl. header, last send time, dest)}
        self.rto = 1.0
        self.lock = threading.Lock()
        # Duplicate filter: addr -> [next_expected_seq, set of delivered seqs
        # above next_expected_seq]. The low-water mark keeps the set small.
        self._delivered = {}

    def send_reliable(self, data: bytes, dest: tuple) -> int:
        """Send *data* to *dest* and track it until ACKed; return its seq.

        Raises ``struct.error`` if *data* exceeds 65535 bytes (the length
        field is 2 bytes), and propagates ``OSError`` from ``sendto``. In both
        cases nothing is recorded and the sequence number is not consumed.
        """
        with self.lock:
            seq = self.seq_num
            header = struct.pack(HEADER_FMT, seq, 0, FLAG_DATA, len(data))
            packet = header + data
            # Send before recording: a packet the OS refused (e.g. too large)
            # must not linger in `unacked`, where it would be retried forever.
            # The lock is held, so an ACK cannot be processed before the entry exists.
            self.sock.sendto(packet, dest)
            self.unacked[seq] = (packet, time.time(), dest)
            self.seq_num += 1
        return seq

    def _mark_delivered(self, addr, seq) -> bool:
        """Record *seq* from *addr*; return False if it was already delivered.

        Caller must hold ``self.lock``.
        """
        state = self._delivered.setdefault(addr, [0, set()])
        next_expected, above = state
        if seq < next_expected or seq in above:
            return False
        above.add(seq)
        # Slide the low-water mark over any now-contiguous run of seqs.
        while next_expected in above:
            above.remove(next_expected)
            next_expected += 1
        state[0] = next_expected
        return True

    def recv_reliable(self) -> tuple:
        """Block until a new DATA packet arrives; return ``(payload, addr)``.

        ACK packets read along the way clear the matching ``unacked`` entry.
        Packets shorter than the header, or whose payload is shorter than the
        declared length, are dropped without an ACK so the sender retries.
        """
        while True:
            data, addr = self.sock.recvfrom(65535)
            if len(data) < HEADER_SIZE:
                continue
            seq, ack, flags, data_len = struct.unpack_from(HEADER_FMT, data)
            payload = data[HEADER_SIZE:HEADER_SIZE + data_len]
            if flags & FLAG_DATA:
                if len(payload) != data_len:
                    continue  # truncated/corrupt datagram
                ack_pkt = struct.pack(HEADER_FMT, 0, seq, FLAG_ACK, 0)
                self.sock.sendto(ack_pkt, addr)
                with self.lock:
                    is_new = self._mark_delivered(addr, seq)
                if is_new:
                    return payload, addr
            elif flags & FLAG_ACK:
                with self.lock:
                    self.unacked.pop(ack, None)

    def check_retransmissions(self):
        """Resend every unACKed packet older than ``rto`` and restart its timer."""
        now = time.time()
        with self.lock:
            for seq in list(self.unacked):
                pkt, send_time, dest = self.unacked[seq]
                if now - send_time > self.rto:
                    self.sock.sendto(pkt, dest)
                    self.unacked[seq] = (pkt, now, dest)
Go-Back-N 实现
"""Go-Back-N 滑动窗口"""
class GBNSender:
    """Go-Back-N sender: fixed-size window, cumulative ACKs, resend-all on timeout.

    ``on_ack(n)`` acknowledges every sequence number up to and including
    ``n``. When the oldest outstanding packet has waited longer than ``rto``
    seconds, ``retransmit_all`` resends the whole window. The caller drives
    both methods; this class never reads from the socket itself.
    """

    def __init__(self, sock, window_size=4):
        self.sock = sock
        self.window_size = window_size
        self.base = 0       # lowest unacknowledged sequence number
        self.next_seq = 0   # sequence number the next send() will use
        self.packets = {}   # seq -> (data, dest)
        self.timers = {}    # seq -> time of last (re)transmission
        self.rto = 1.0

    def send(self, data: bytes, dest: tuple) -> bool:
        """Send *data* if the window has room; return False when it is full."""
        if self.next_seq >= self.base + self.window_size:
            return False  # window full
        seq = self.next_seq
        header = struct.pack(HEADER_FMT, seq, 0, FLAG_DATA, len(data))
        self.sock.sendto(header + data, dest)
        self.packets[seq] = (data, dest)
        self.timers[seq] = time.time()
        self.next_seq += 1
        return True

    def on_ack(self, ack_num: int):
        """Handle a cumulative ACK covering every seq <= *ack_num*.

        Stale ACKs (below ``base``) and ACKs for sequence numbers never sent
        (>= ``next_seq``) are ignored. Accepting the latter would push
        ``base`` past ``next_seq``, stranding later packets outside the
        retransmission range for good.
        """
        if not self.base <= ack_num < self.next_seq:
            return
        for seq in range(self.base, ack_num + 1):
            self.packets.pop(seq, None)
            self.timers.pop(seq, None)
        self.base = ack_num + 1

    def retransmit_all(self):
        """GBN timeout: resend every outstanding packet in the window.

        Fires when the lowest-numbered outstanding packet's timer exceeds
        ``rto``. All resent packets get their timer restarted together.
        """
        now = time.time()
        if self.base >= self.next_seq:
            return  # nothing outstanding
        oldest = min(self.timers, default=None)
        if oldest is None or now - self.timers[oldest] <= self.rto:
            return
        for seq in range(self.base, self.next_seq):
            if seq in self.packets:
                data, dest = self.packets[seq]
                header = struct.pack(HEADER_FMT, seq, 0, FLAG_DATA, len(data))
                self.sock.sendto(header + data, dest)
                self.timers[seq] = now
12.5 UDP 打洞 (NAT Traversal)
"""
UDP 打洞原理:
两个 NAT 后的主机通过第三方服务器交换地址信息
然后直接通信
A ──→ NAT_A ──→ Server ←── NAT_B ←── B
交换地址
A ──→ NAT_A ──→ NAT_B ←── B
直接通信
"""
import socket
import json
# 中继服务器
def relay_server(port=9999):
    """Rendezvous server for UDP hole punching (JSON over UDP).

    * ``{"type": "register", "name": N}`` records the datagram's source
      address (the client's externally visible, post-NAT address) under N.
    * ``{"type": "query", "target": N}`` replies to the querier with
      ``{"type": "peer", "ip": ..., "port": ...}`` if N is registered.
      Unknown targets get no reply.

    Only the querier learns the target's address, so both peers must query
    for each other. Malformed datagrams (not UTF-8 JSON, not an object,
    missing or unhashable fields) are ignored instead of crashing the
    server. Runs until interrupted.
    """
    clients = {}  # name -> (ip, port)
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind(('0.0.0.0', port))
        while True:
            data, addr = sock.recvfrom(4096)
            try:
                msg = json.loads(data.decode())
                kind = msg['type']
                if kind == 'register':
                    clients[msg['name']] = addr
                    print(f"注册: {msg['name']} = {addr}")
                elif kind == 'query':
                    target = clients.get(msg['target'])
                    if target:
                        # Tell the querier the target's external address.
                        sock.sendto(json.dumps({
                            'type': 'peer',
                            'ip': target[0],
                            'port': target[1]
                        }).encode(), addr)
            except (ValueError, KeyError, TypeError):
                continue  # malformed request from an untrusted sender: drop it
# 客户端
def hole_punch_client(name, target, server_addr, timeout=None, retry_interval=1.0):
    """Register with the relay, look up *target*, and punch a hole towards it.

    Sends a ``register`` and a ``query`` request to *server_addr*. Both are
    re-sent every *retry_interval* seconds without a reply, because a UDP
    request may be lost or *target* may not have registered yet. Incoming
    datagrams that are not a JSON object with ``ip`` and ``port`` are
    ignored; for example, the peer's own hole-punch packet may arrive first.
    Once the peer address is known, one ``b'hole punch'`` datagram is sent
    to it.

    Returns ``(sock, peer_addr)``. The socket is left open, in blocking mode,
    for direct peer traffic. With *timeout* set, raises ``socket.timeout``
    if no usable reply arrives in time. The default (None) waits
    indefinitely. The socket is closed if an error escapes.
    """
    import time  # imported locally: this section's own imports do not include time

    register = json.dumps({'type': 'register', 'name': name}).encode()
    query = json.dumps({'type': 'query', 'target': target}).encode()
    deadline = None if timeout is None else time.monotonic() + timeout
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.sendto(register, server_addr)
        sock.sendto(query, server_addr)
        while True:
            wait = retry_interval
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    raise socket.timeout(f"no peer info for {target!r} from {server_addr}")
                wait = min(wait, remaining)
            sock.settimeout(wait)
            try:
                data, _ = sock.recvfrom(4096)
            except socket.timeout:
                # Request lost or target not registered yet: ask again.
                sock.sendto(register, server_addr)
                sock.sendto(query, server_addr)
                continue
            try:
                peer = json.loads(data.decode())
                peer_addr = (peer['ip'], peer['port'])
            except (ValueError, KeyError, TypeError):
                continue  # not a relay reply
            break
        sock.settimeout(None)  # hand back a blocking socket
        # Punch: outbound traffic opens our NAT mapping towards the peer.
        sock.sendto(b'hole punch', peer_addr)
        return sock, peer_addr
    except BaseException:
        sock.close()
        raise
12.6 UDP 性能优化
"""
UDP 性能优化技巧:
1. 增大缓冲区
2. 控制包大小(避免分片)
3. 批量发送
4. 应用层限速
"""
import socket
def optimize_udp_socket(sock, buffer_size=4 * 1024 * 1024):
    """Request larger kernel send and receive buffers for *sock*.

    Asks for *buffer_size* bytes (default 4 MiB) for both SO_SNDBUF and
    SO_RCVBUF. The kernel may silently cap the request (on Linux at
    net.core.wmem_max / rmem_max) or adjust it (Linux doubles it for
    bookkeeping). The effective sizes are therefore read back and returned
    as ``(sndbuf, rcvbuf)`` so callers can verify what they got.
    """
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, buffer_size)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, buffer_size)
    return (sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF),
            sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF))
# Largest UDP payload that fits a 1500-byte Ethernet MTU without IP
# fragmentation: 1500 - 20 (IPv4 header, no options) - 8 (UDP header).
SAFE_UDP_SIZE = 1472


def send_safe(sock, data: bytes, dest: tuple):
    """Send *data* to *dest* with no datagram exceeding SAFE_UDP_SIZE bytes.

    A payload that fits is sent as-is, without a header. A larger payload is
    split into chunks, each prefixed with a 12-byte header
    ``[seq 4B][offset 4B][total length 4B]`` in network byte order. The chunk
    size is reduced by the header length so that every datagram, header
    included, stays within SAFE_UDP_SIZE.

    NOTE(review): small payloads carry no header, so a receiver cannot tell
    them from chunks by format alone. A real protocol should frame both the
    same way.
    """
    import struct  # imported locally: this section's own imports only include socket

    if len(data) <= SAFE_UDP_SIZE:
        sock.sendto(data, dest)
        return
    header_fmt = '!III'
    chunk_size = SAFE_UDP_SIZE - struct.calcsize(header_fmt)
    for seq, offset in enumerate(range(0, len(data), chunk_size)):
        header = struct.pack(header_fmt, seq, offset, len(data))
        sock.sendto(header + data[offset:offset + chunk_size], dest)
12.7 注意事项
⚠️ 没有流控:发送太快可能淹没接收方,应用层需要自己限速
⚠️ 没有拥塞控制:大量 UDP 流量可能挤占网络带宽
⚠️ 包大小限制:超过路径 MTU 的数据报会在 IP 层分片,任意一个分片丢失都会导致整个数据报被丢弃,因此分片越多,丢包概率越高
⚠️ NAT 超时:UDP NAT 映射超时快,需要定期保活
⚠️ ICMP 端口不可达:UDP socket 调用 connect 后,若向未监听的端口发送数据,对端返回的 ICMP 端口不可达错误会在后续的 send/recv 调用中以 ConnectionRefusedError 的形式抛出
12.8 扩展阅读
- KCP - 可靠 UDP 协议
- QUIC - 基于 UDP 的可靠传输
- RFC 8085 - UDP Usage Guidelines
下一章:13 - QUIC 协议 - QUIC 原理、0-RTT、连接迁移