#include "../../include/error.h" #include "../../include/mptp.h" #include "../../include/url.h" #include #include #include #include #include bool lm_mptp_init(lm_mptp_t *packet, bool is_request, uint8_t code, bool is_last) { bzero(&packet->header, sizeof(packet->header)); bzero(packet->data, MPTP_DATA_MAX); bzero(packet->host, MPTP_HOST_MAX); if (code > MPTP_CODE_MAX) { lm_error_set(LM_ERR_MPTPBadCode); return false; } packet->header.flags |= (MPTP_VERSION_SUPPORTED << 8); if (is_request) packet->header.flags |= (MPTP_REQUEST << 7); else packet->header.flags |= (MPTP_RESPONSE << 7); packet->header.flags |= (code << 3); if (is_last || is_request) packet->header.flags |= (1 << 2); else packet->header.flags |= (0 << 2); return true; } bool lm_mptp_verify(lm_mptp_t *packet) { if (NULL == packet) { lm_error_set(LM_ERR_ArgNULL); return false; } if (MPTP_FLAGS_VERSION(packet) != MPTP_VERSION_SUPPORTED) { lm_error_set(LM_ERR_MPTPBadVersion); return false; } if (packet->header.data_size > MPTP_DATA_MAX || packet->header.data_size < 0) { lm_error_set(LM_ERR_MPTPBadData); return false; } if (MPTP_FLAGS_CODE(packet) > MPTP_CODE_MAX || MPTP_FLAGS_CODE(packet) < 0) { lm_error_set(LM_ERR_MPTPBadCode); return false; } return true; } bool lm_mptp_socket_opts(int sock){ struct timeval timeout; int flags = 1; bzero(&timeout, sizeof(timeout)); timeout.tv_sec = MPTP_TIMEOUT; timeout.tv_usec = 0; if (setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)) < 0) { lm_error_set(LM_ERR_MPTPSetsockopt); lm_mptp_close(sock); return false; } if (setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, &timeout, sizeof(timeout)) < 0) { lm_error_set(LM_ERR_MPTPSetsockopt); lm_mptp_close(sock); return false; } if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &flags, sizeof(int)) < 0) { lm_error_set(LM_ERR_MPTPSetsockopt); lm_mptp_close(sock); return false; } if (setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, &flags, sizeof(int)) < 0) { lm_error_set(LM_ERR_MPTPSetsockopt); lm_mptp_close(sock); return false; } if (setsockopt(sock, SOL_TCP, TCP_NODELAY, &flags, sizeof(int)) < 0) { lm_error_set(LM_ERR_MPTPSetsockopt); lm_mptp_close(sock); return false; } return true; } int lm_mptp_socket(char *addr, uint16_t port, struct sockaddr *saddr) { if (NULL == addr || NULL == saddr) { lm_error_set(LM_ERR_ArgNULL); return -1; } struct addrinfo hints, *res, *cur; int sock = 0, status = 0, family = -1; bzero(&hints, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = AF_INET; if ((status = getaddrinfo(addr, NULL, &hints, &res)) < 0) { lm_error_set(LM_ERR_MPTPHostFail); return -1; } for (cur = res; cur != NULL; cur = cur->ai_next) { switch (cur->ai_family) { case AF_INET: family = cur->ai_family; struct sockaddr_in *ipv4 = (struct sockaddr_in *)cur->ai_addr; ipv4->sin_port = htons(port); memcpy(saddr, cur->ai_addr, sizeof(struct sockaddr)); break; case AF_INET6: family = cur->ai_family; struct sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)cur->ai_addr; ipv6->sin6_port = htons(port); memcpy(saddr, cur->ai_addr, sizeof(struct sockaddr)); break; } if (family != -1) break; } freeaddrinfo(res); if (family == -1) { lm_error_set(LM_ERR_MPTPHostFail); return -1; } if ((sock = socket(family, SOCK_STREAM, 0)) < 0) { lm_error_set(LM_ERR_MPTPSocketFail); return -1; } return sock; } void lm_mptp_close(int sock) { close(sock); } bool lm_mptp_set_host(lm_mptp_t *packet, char *host) { size_t size = strlen(host); if (size > MPTP_HOST_MAX || size < 0) { lm_error_set(LM_ERR_MPTPBadHost); return false; } // do NOT copy the NULL terminator packet->header.host_size = size; memcpy(packet->host, host, size); return true; } bool lm_mptp_get_host(lm_mptp_t *packet, char *host) { if (packet->header.host_size > MPTP_HOST_MAX || packet->header.host_size < 0) { host = NULL; lm_error_set(LM_ERR_BadHost); return false; } memcpy(host, packet->host, packet->header.host_size); host[packet->header.host_size] = 0; return true; } bool lm_mptp_set_data(lm_mptp_t *packet, char *data, size_t size) { if (size > MPTP_DATA_MAX || size < 0) { lm_error_set(LM_ERR_MPTPBadData); return false; } packet->header.data_size = size; mempcpy(packet->data, data, size); return true; } bool lm_mptp_get_data(lm_mptp_t *packet, char *data) { if (packet->header.data_size > MPTP_DATA_MAX || packet->header.data_size < 0) { data = NULL; lm_error_set(LM_ERR_BadHost); return false; } memcpy(data, packet->data, packet->header.data_size); data[packet->header.data_size] = 0; return true; } void lm_mptp_copy(lm_mptp_t *dst, lm_mptp_t *src) { memcpy(&dst->header, &src->header, sizeof(dst->header)); memcpy(&dst->host, &src->data, sizeof(src->host)); memcpy(&dst->data, &src->data, sizeof(src->data)); } bool lm_mptp_recv(int sock, lm_mptp_t *packet) { if (recv(sock, &packet->header, sizeof(packet->header), 0) <= 0) { if (ETIMEDOUT == errno || EAGAIN == errno) { lm_error_set(LM_ERR_MPTPTimeout); return false; } lm_error_set(LM_ERR_MPTPRecvFail, strerror(errno)); return false; } packet->header.flags = ntohs(packet->header.flags); // packet->header.host_size = ntohs(packet->header.host_size); // packet->header.data_size = ntohs(packet->header.data_size); if (packet->header.host_size <= MPTP_HOST_MAX && packet->header.host_size != 0){ if(recv(sock, packet->host, packet->header.host_size, 0) <= 0){ if (ETIMEDOUT == errno || EAGAIN == errno) { lm_error_set(LM_ERR_MPTPTimeout); return false; } lm_error_set(LM_ERR_MPTPRecvFail, strerror(errno)); return false; } } if (packet->header.data_size <= MPTP_DATA_MAX && packet->header.data_size != 0){ if(recv(sock, packet->data, packet->header.data_size, 0) <= 0){ if (ETIMEDOUT == errno || EAGAIN == errno) { lm_error_set(LM_ERR_MPTPTimeout); return false; } lm_error_set(LM_ERR_MPTPRecvFail, strerror(errno)); return false; } } return true; }