强曰为道
与天地相似,故不违。知周乎万物,而道济天下,故不过。旁行而不流,乐天知命,故不忧。
文档目录

TCP/UDP 网络协议教程 / 12-UDP 编程实战

12 - UDP 编程实战

12.1 UDP 基础收发

UDP 服务器

"""UDP 回显服务器"""
import socket

def udp_echo_server(host='0.0.0.0', port=9999):
    """Run a UDP echo server forever.

    Binds to (host, port) and sends every received datagram back to the
    address it came from. recvfrom reads at most 4096 bytes per datagram.
    The socket is closed when the loop is left by an exception (bind
    failure, socket error, KeyboardInterrupt, ...).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((host, port))

        print(f"UDP 服务器启动: {host}:{port}")

        while True:
            data, addr = sock.recvfrom(4096)
            print(f"收到来自 {addr}: {data}")
            sock.sendto(data, addr)  # echo the datagram back to its sender

if __name__ == '__main__':
    udp_echo_server()

UDP 客户端

"""UDP 客户端"""
import socket

def udp_client(host='127.0.0.1', port=9999, count=5, timeout=3):
    """Send numbered greetings to a UDP echo server and print the replies.

    For each of `count` messages ("Hello UDP 0", "Hello UDP 1", ...) one
    datagram is sent to (host, port), then one reply is awaited for at most
    `timeout` seconds. A missing reply prints a timeout notice and the loop
    continues, since UDP does not guarantee delivery.

    Replies are not matched to requests: a reply that arrives after its
    timeout will be read by the next iteration.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)

        for i in range(count):
            message = f"Hello UDP {i}".encode()
            sock.sendto(message, (host, port))

            try:
                data, _addr = sock.recvfrom(4096)
            except socket.timeout:
                print("接收超时")
                continue
            # errors='replace' so a non-UTF-8 reply cannot crash the client.
            print(f"收到: {data.decode(errors='replace')}")

# Demo entry point: executes immediately whenever this snippet runs (no
# __main__ guard) and expects a UDP echo server on 127.0.0.1:9999, such as
# udp_echo_server().
udp_client()

12.2 UDP 广播

"""局域网广播发现"""
import socket
import json
import time

def _local_ip_for(peer_addr):
    """Return the local IPv4 address the OS would use to reach *peer_addr*.

    connect() on a UDP socket transmits nothing; it only makes the kernel
    select a route, after which getsockname() reports that route's source
    address. If no route can be selected, fall back to resolving this
    host's name.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        try:
            probe.connect(peer_addr)
            return probe.getsockname()[0]
        except OSError:
            return socket.gethostbyname(socket.gethostname())

def broadcast_server(port=37020):
    """Answer LAN discovery broadcasts forever.

    Listens on UDP `port` on all interfaces. Every datagram whose payload is
    exactly b'DISCOVER' gets a JSON reply {"name", "ip", "port"} sent back
    to its sender. Any other datagram is ignored.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(('0.0.0.0', port))

        while True:
            data, addr = sock.recvfrom(4096)
            if data != b'DISCOVER':
                continue
            info = json.dumps({
                'name': 'MyServer',
                # Address of the interface that faces this requester. The
                # hostname often resolves to a loopback address (127.0.1.1
                # on Debian-based systems), which is useless to a LAN peer.
                'ip': _local_ip_for(addr),
                'port': 8080,
            })
            sock.sendto(info.encode(), addr)

def broadcast_discover(port=37020, timeout=2.0, broadcast_addr='255.255.255.255'):
    """Broadcast a discovery probe and collect the JSON replies.

    Args:
        port: UDP port the discovery servers listen on.
        timeout: total seconds to keep collecting replies. This is an
            overall deadline, so a stream of replies cannot keep discovery
            running forever.
        broadcast_addr: destination of the probe. The default is the
            limited broadcast address. A subnet-directed broadcast (e.g.
            '192.168.1.255') targets a specific interface's network.

    Returns:
        List of decoded reply objects, in arrival order. Replies that are
        not valid UTF-8 JSON are skipped.
    """
    servers = []
    deadline = time.monotonic() + timeout

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
        sock.sendto(b'DISCOVER', (broadcast_addr, port))

        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            sock.settimeout(remaining)
            try:
                data, _addr = sock.recvfrom(4096)
            except socket.timeout:
                break
            try:
                info = json.loads(data.decode())
            except ValueError:  # includes UnicodeDecodeError / JSONDecodeError
                continue
            servers.append(info)
            print(f"发现: {info}")

    return servers

# servers = broadcast_discover()

12.3 UDP 多播

"""UDP 多播发送和接收"""
import socket
import struct
import threading

# Multicast group and UDP port shared by multicast_sender and multicast_receiver.
# NOTE(review): for LAN-private experiments the administratively scoped range
# 239.0.0.0/8 (RFC 2365) is the conventional choice; confirm 224.1.1.1 is intended.
MCAST_GROUP = '224.1.1.1'
MCAST_PORT = 5007

def multicast_sender(count=10, interval=1.0, ttl=2):
    """Send numbered messages to the multicast group MCAST_GROUP:MCAST_PORT.

    Args:
        count: number of messages to send ("Multicast #0", "Multicast #1", ...).
        interval: seconds to sleep after each message.
        ttl: IP_MULTICAST_TTL, i.e. how many router hops the datagrams may
            cross (1 keeps them on the local subnet).
    """
    import time  # this snippet's own imports do not include time

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, ttl)

        for i in range(count):
            msg = f"Multicast #{i}".encode()
            sock.sendto(msg, (MCAST_GROUP, MCAST_PORT))
            print(f"发送: {msg.decode()}")
            time.sleep(interval)

def multicast_receiver():
    """Join MCAST_GROUP and print every datagram received on MCAST_PORT, forever.

    Payloads are decoded as UTF-8 with undecodable bytes replaced, so a stray
    binary datagram cannot crash the loop. The socket is closed if the loop
    exits via an exception.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP) as sock:
        # Allow several receivers on this host to bind the same port.
        # NOTE(review): BSD/macOS may additionally need SO_REUSEPORT.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Binding '' accepts datagrams for this port on any local address,
        # not only those addressed to the group.
        sock.bind(('', MCAST_PORT))

        # Join the group: struct ip_mreq = group address + local interface
        # (0.0.0.0 lets the kernel choose the interface).
        mreq = struct.pack('4s4s',
                           socket.inet_aton(MCAST_GROUP),
                           socket.inet_aton('0.0.0.0'))
        sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, mreq)

        print(f"监听多播 {MCAST_GROUP}:{MCAST_PORT}")

        while True:
            data, addr = sock.recvfrom(4096)
            print(f"收到来自 {addr}: {data.decode(errors='replace')}")

12.4 应用层可靠性 UDP

简单可靠 UDP 协议

"""应用层可靠传输"""
import socket
import struct
import time
import threading

# Wire format: [seq 4B][ack 4B][flags 1B][payload length 2B][payload],
# header fields in network byte order ('!').
HEADER_FMT = '!IIBH'
HEADER_SIZE = struct.calcsize(HEADER_FMT)

# Bit flags carried in the 1-byte flags field.
FLAG_DATA = 0x01
FLAG_ACK = 0x02
FLAG_FIN = 0x04  # defined by the format; ReliableUDP does not act on it

class ReliableUDP:
    """Minimal acknowledgement + retransmission layer over a UDP socket.

    Each DATA packet gets a per-instance 32-bit sequence number and stays in
    `unacked` until the peer's ACK for that number is received by
    `recv_reliable`. The caller must call `check_retransmissions()`
    periodically to resend packets not ACKed within `rto` seconds (fixed,
    no backoff). Delivery order is not guaranteed, but duplicate DATA
    packets (caused by retransmission after a lost ACK) are delivered only
    once.
    """

    def __init__(self, sock=None):
        self.sock = sock if sock is not None else socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.seq_num = 0
        self.unacked = {}  # {seq: (packet, last_send_monotonic_time, dest)}
        self.rto = 1.0
        self.lock = threading.Lock()
        # (peer_addr, seq) pairs already returned by recv_reliable. Grows
        # without bound; a production protocol would use a receive window.
        self._delivered = set()

    def send_reliable(self, data: bytes, dest: tuple) -> int:
        """Send `data` to `dest`, track it for retransmission, return its seq.

        Raises struct.error if `data` exceeds 65535 bytes (2-byte length field).
        """
        with self.lock:
            # Read seq_num and build the header under the lock; doing it
            # outside let two threads stamp the same seq on different packets.
            seq = self.seq_num
            packet = struct.pack(HEADER_FMT, seq, 0, FLAG_DATA, len(data)) + data
            self.unacked[seq] = (packet, time.monotonic(), dest)
            self.sock.sendto(packet, dest)
            # The seq field is 32 bits; wrap rather than overflow struct.pack.
            self.seq_num = (seq + 1) & 0xFFFFFFFF
        return seq

    def recv_reliable(self) -> tuple:
        """Block until a new DATA payload arrives; return (payload, addr).

        ACK packets read meanwhile clear their `unacked` entry. Every DATA
        packet is ACKed, duplicates included (a duplicate means our earlier
        ACK was lost), but each (addr, seq) is returned only once. Datagrams
        too short for a header, or whose payload is shorter than the length
        field claims, are dropped unACKed. Raises socket.timeout if the
        underlying socket has a timeout and it expires.
        """
        while True:
            data, addr = self.sock.recvfrom(65535)
            if len(data) < HEADER_SIZE:
                continue

            seq, ack, flags, data_len = struct.unpack(HEADER_FMT, data[:HEADER_SIZE])
            payload = data[HEADER_SIZE:HEADER_SIZE + data_len]
            if len(payload) != data_len:
                continue  # truncated; the sender will retransmit

            if flags & FLAG_DATA:
                ack_pkt = struct.pack(HEADER_FMT, 0, seq, FLAG_ACK, 0)
                self.sock.sendto(ack_pkt, addr)
                key = (addr, seq)
                if key in self._delivered:
                    continue  # retransmitted copy of a payload already delivered
                self._delivered.add(key)
                return payload, addr
            if flags & FLAG_ACK:
                with self.lock:
                    self.unacked.pop(ack, None)

    def check_retransmissions(self):
        """Resend every unACKed packet whose last transmission is older than `rto`."""
        now = time.monotonic()
        with self.lock:
            for seq, (pkt, sent_at, dest) in list(self.unacked.items()):
                if now - sent_at > self.rto:
                    self.sock.sendto(pkt, dest)
                    self.unacked[seq] = (pkt, now, dest)

Go-Back-N 实现

"""Go-Back-N 滑动窗口"""
class GBNSender:
    """Sender side of a Go-Back-N sliding window.

    At most `window_size` packets may be outstanding (sent but not yet
    covered by a cumulative ACK). The caller passes ACK numbers to `on_ack`
    and calls `retransmit_all` periodically. When the oldest outstanding
    packet has waited longer than `rto` seconds, every outstanding packet is
    resent.
    """

    def __init__(self, sock, window_size=4):
        self.sock = sock
        self.window_size = window_size
        self.base = 0        # oldest unacknowledged sequence number
        self.next_seq = 0    # sequence number the next new packet will get
        self.packets = {}    # seq -> (payload, dest) for outstanding packets
        self.timers = {}     # seq -> time.monotonic() of its last transmission
        self.rto = 1.0

    def _transmit(self, seq, data, dest):
        """Frame `data` as DATA packet `seq` and send it to `dest`."""
        header = struct.pack(HEADER_FMT, seq, 0, FLAG_DATA, len(data))
        self.sock.sendto(header + data, dest)

    def send(self, data: bytes, dest: tuple) -> bool:
        """Send `data` as a new packet; return False (nothing sent) if the window is full."""
        if self.next_seq >= self.base + self.window_size:
            return False

        seq = self.next_seq
        self._transmit(seq, data, dest)
        self.packets[seq] = (data, dest)
        self.timers[seq] = time.monotonic()
        self.next_seq += 1
        return True

    def on_ack(self, ack_num: int):
        """Process a cumulative ACK: every packet up to `ack_num` was received.

        Stale ACKs (below `base`) and ACKs for packets never sent (at or
        beyond `next_seq`) are ignored. Accepting the latter would move
        `base` past `next_seq` and corrupt the window.
        """
        if not self.base <= ack_num < self.next_seq:
            return
        for seq in range(self.base, ack_num + 1):
            self.packets.pop(seq, None)
            self.timers.pop(seq, None)
        self.base = ack_num + 1

    def retransmit_all(self):
        """GBN timeout: resend the whole window if the oldest packet timed out."""
        if self.base >= self.next_seq or not self.timers:
            return
        now = time.monotonic()
        if now - self.timers[min(self.timers)] <= self.rto:
            return
        for seq in range(self.base, self.next_seq):
            if seq in self.packets:
                data, dest = self.packets[seq]
                self._transmit(seq, data, dest)
                self.timers[seq] = now

12.5 UDP 打洞 (NAT Traversal)

"""
UDP 打洞原理:
两个 NAT 后的主机通过第三方服务器交换地址信息
然后直接通信

    A ──→ NAT_A ──→ Server ←── NAT_B ←── B
                    交换地址
    A ──→ NAT_A ──→ NAT_B ←── B
              直接通信
"""
import socket
import json

# 中继服务器
def relay_server(port=9999):
    """Rendezvous server for UDP hole punching; runs forever.

    Protocol (JSON objects over UDP):
      {"type": "register", "name": N}  -> record the sender's (ip, port), as
                                          seen by this server (i.e. after NAT),
                                          under N
      {"type": "query", "target": N}   -> reply {"type": "peer", "ip", "port"}
                                          with N's recorded address; no reply
                                          if N is unknown
    Malformed datagrams (invalid UTF-8/JSON, not an object, missing keys,
    unhashable names) are dropped, so one bad packet cannot stop the relay.
    """
    clients = {}  # name -> (ip, port) as observed by this server

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.bind(('0.0.0.0', port))

        while True:
            data, addr = sock.recvfrom(4096)
            try:
                msg = json.loads(data.decode())

                if msg['type'] == 'register':
                    clients[msg['name']] = addr
                    print(f"注册: {msg['name']} = {addr}")

                elif msg['type'] == 'query':
                    target = clients.get(msg['target'])
                    if target:
                        # Tell the querier the target's external address.
                        sock.sendto(json.dumps({
                            'type': 'peer',
                            'ip': target[0],
                            'port': target[1]
                        }).encode(), addr)
            except (ValueError, KeyError, TypeError):
                continue

# 客户端
def hole_punch_client(name, target, server_addr, timeout=None):
    """Register with the relay, learn `target`'s public address, punch a hole.

    Sends a "register" and a "query" message to `server_addr`, waits for the
    relay's {"type": "peer", "ip", "port"} reply, then sends one
    b'hole punch' datagram to that address so this side's NAT creates a
    mapping toward the peer. Datagrams that are not a valid peer reply
    (e.g. the peer's own b'hole punch' arriving first) are skipped.

    Args:
        name: name to register under.
        target: registered name of the peer to look up.
        server_addr: (host, port) of the relay server.
        timeout: per-receive timeout in seconds. None (the default) blocks
            indefinitely. A relay that ignores unknown targets would then
            hang this call forever.

    Returns:
        (sock, peer_addr): the open, already-punched socket and the peer's
        (ip, port).

    Raises:
        socket.timeout: no relay reply within `timeout`. The socket is
        closed before any exception propagates.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.settimeout(timeout)

        # Register, then ask for the target's address.
        sock.sendto(json.dumps({
            'type': 'register',
            'name': name
        }).encode(), server_addr)
        sock.sendto(json.dumps({
            'type': 'query',
            'target': target
        }).encode(), server_addr)

        while True:
            data, _ = sock.recvfrom(4096)
            try:
                reply = json.loads(data.decode())
                if reply['type'] == 'peer':
                    peer_addr = (reply['ip'], reply['port'])
                    break
            except (ValueError, KeyError, TypeError):
                continue

        # Punch: the outgoing datagram opens our NAT mapping toward the peer.
        sock.sendto(b'hole punch', peer_addr)
    except BaseException:
        sock.close()
        raise

    return sock, peer_addr

12.6 UDP 性能优化

"""
UDP 性能优化技巧:
1. 增大缓冲区
2. 控制包大小(避免分片)
3. 批量发送
4. 应用层限速
"""
import socket

def optimize_udp_socket(sock, bufsize=4 * 1024 * 1024):
    """Enlarge a UDP socket's kernel send and receive buffers.

    Larger buffers absorb bursts, so fewer datagrams are dropped when the
    application reads or writes in batches.

    Args:
        sock: the socket to tune.
        bufsize: requested size in bytes for both SO_SNDBUF and SO_RCVBUF
            (default 4 MiB).

    Returns:
        (sndbuf, rcvbuf) as reported by getsockopt afterwards. The kernel may
        silently cap the request (Linux: net.core.wmem_max / rmem_max) or
        adjust it (Linux doubles it for bookkeeping), so check these values
        if the exact size matters.
    """
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, bufsize)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, bufsize)
    return (
        sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF),
        sock.getsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF),
    )

# Largest UDP payload that fits a standard Ethernet frame without IP
# fragmentation: 1500-byte MTU - 20-byte IPv4 header - 8-byte UDP header.
# Paths with a smaller MTU (PPPoE, VPN tunnels) or IPv6 need less.
SAFE_UDP_SIZE = 1472

def send_safe(sock, data: bytes, dest: tuple):
    """Send `data` to `dest` without any datagram exceeding SAFE_UDP_SIZE.

    A payload that already fits is sent as one raw datagram. Larger payloads
    are split into chunks, each prefixed with a 12-byte header
    [seq 4B][offset 4B][total length 4B] (network byte order) for
    reassembly. The header counts toward the datagram size, so each chunk
    carries at most SAFE_UDP_SIZE - 12 payload bytes.

    NOTE(review): raw single datagrams carry no header, so a receiver cannot
    tell them apart from chunks by content alone.
    """
    import struct  # this snippet's own imports only bring in socket

    if len(data) <= SAFE_UDP_SIZE:
        sock.sendto(data, dest)
        return

    header = struct.Struct('!III')
    chunk_size = SAFE_UDP_SIZE - header.size
    total = len(data)
    for seq, offset in enumerate(range(0, total, chunk_size)):
        chunk = data[offset:offset + chunk_size]
        sock.sendto(header.pack(seq, offset, total) + chunk, dest)

12.7 注意事项

⚠️ 没有流控:发送太快可能淹没接收方,应用层需要自己限速

⚠️ 没有拥塞控制:大量 UDP 流量可能挤占网络带宽

⚠️ 包大小限制:超过路径 MTU 会触发 IP 分片,而任一分片丢失都会导致整个数据报丢失,因此分片越多丢包概率越高

⚠️ NAT 超时:UDP NAT 映射超时快,需要定期保活

⚠️ ICMP 端口不可达:对已 connect 的 UDP socket,若发送到未监听的端口,对端返回的 ICMP 端口不可达会在后续的 send/recv 调用中表现为 ConnectionRefusedError(ECONNREFUSED)

12.8 扩展阅读


下一章:13 - QUIC 协议 - QUIC 原理、0-RTT、连接迁移