diff --git a/socket.c b/socket.c index 3551228..4d5c578 100644 --- a/socket.c +++ b/socket.c @@ -737,6 +737,17 @@ init_message_nonaddress(SCK_Message *message) /* ================================================== */ +static int +match_cmsg(struct cmsghdr *cmsg, int level, int type, size_t length) +{ + if (cmsg->cmsg_type == type && cmsg->cmsg_level == level && + (length == 0 || cmsg->cmsg_len == CMSG_LEN(length))) + return 1; + return 0; +} + +/* ================================================== */ + static int process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, SCK_Message *message) @@ -795,7 +806,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { #ifdef HAVE_IN_PKTINFO - if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + if (match_cmsg(cmsg, IPPROTO_IP, IP_PKTINFO, sizeof (struct in_pktinfo))) { struct in_pktinfo ipi; if (message->addr_type != SCK_ADDR_IP) @@ -807,7 +818,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, message->if_index = ipi.ipi_ifindex; } #elif defined(IP_RECVDSTADDR) - if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_RECVDSTADDR) { + if (match_cmsg(cmsg, IPPROTO_IP, IP_RECVDSTADDR, sizeof (struct in_addr))) { struct in_addr addr; if (message->addr_type != SCK_ADDR_IP) @@ -820,7 +831,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef HAVE_IN6_PKTINFO - if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { + if (match_cmsg(cmsg, IPPROTO_IPV6, IPV6_PKTINFO, sizeof (struct in6_pktinfo))) { struct in6_pktinfo ipi; if (message->addr_type != SCK_ADDR_IP) @@ -835,7 +846,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef SCM_TIMESTAMP - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMP) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMP, sizeof (struct timeval))) { struct timeval tv; memcpy(&tv, CMSG_DATA(cmsg), sizeof (tv)); @@ -844,14 +855,15 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef SCM_TIMESTAMPNS - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPNS) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPNS, sizeof (message->timestamp.kernel))) { memcpy(&message->timestamp.kernel, CMSG_DATA(cmsg), sizeof (message->timestamp.kernel)); } #endif #ifdef HAVE_LINUX_TIMESTAMPING #ifdef HAVE_LINUX_TIMESTAMPING_OPT_PKTINFO - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING_PKTINFO) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING_PKTINFO, + sizeof (struct scm_ts_pktinfo))) { struct scm_ts_pktinfo ts_pktinfo; memcpy(&ts_pktinfo, CMSG_DATA(cmsg), sizeof (ts_pktinfo)); @@ -860,7 +872,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, } #endif - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING, sizeof (struct scm_timestamping))) { struct scm_timestamping ts3; memcpy(&ts3, CMSG_DATA(cmsg), sizeof (ts3)); @@ -868,8 +880,9 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, message->timestamp.hw = ts3.ts[2]; } - if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) || - (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) { + if ((match_cmsg(cmsg, SOL_IP, IP_RECVERR, 0) || + match_cmsg(cmsg, SOL_IPV6, IPV6_RECVERR, 0)) && + cmsg->cmsg_len >= CMSG_LEN(sizeof (struct sock_extended_err))) { struct sock_extended_err err; memcpy(&err, CMSG_DATA(cmsg), sizeof (err)); @@ -882,7 +895,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, } #endif - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_RIGHTS, 0)) { if (!(flags & SCK_FLAG_MSG_DESCRIPTOR) || cmsg->cmsg_len != CMSG_LEN(sizeof (int))) { int i, fd;