new: add mptp packet verify functions

This commit is contained in:
ngn
2024-06-22 07:03:17 +03:00
parent 19c98b763d
commit c8e45097fe
12 changed files with 275 additions and 110 deletions

View File

@ -17,8 +17,10 @@ char *lm_strerror() {
{.code = LM_ERR_URLHostLarge, .desc = _("URL hostname is too large") },
{.code = LM_ERR_URLPathLarge, .desc = _("URL path is too large") },
{.code = LM_ERR_URLBadHost, .desc = _("URL does not have a valid hostname") },
{.code = LM_ERR_URLBadPort, .desc = _("URL does not have a valid port number") },
{.code = LM_ERR_URLBadPath, .desc = _("URL does not have a valid path") },
{.code = LM_ERR_URLBadPort, .desc = _("URL does not contain a hostname with a valid port number")},
{.code = LM_ERR_BadPort, .desc = _("hostname does not contain a valid port number")},
{.code = LM_ERR_BadHost, .desc = _("hostname is not valid")},
{.code = LM_ERR_URLPortUnknown, .desc = _("URL protocol port number is unknown") },
{.code = LM_ERR_URLEnd, .desc = _("URL is incomplete") },
{.code = LM_ERR_PoolNoSupport, .desc = _("pool does not support the specified protocol") },
@ -30,11 +32,15 @@ char *lm_strerror() {
{.code = LM_ERR_MPTPConnectFail, .desc = _("failed to connect to the MPTP host") },
{.code = LM_ERR_MPTPRecvFail, .desc = _("failed receive MPTP data from host") },
{.code = LM_ERR_MPTPSendFail, .desc = _("failed send MPTP data to host") },
{.code = LM_ERR_MPTPBadChunk, .desc = _("MPTP data chunk is too large") },
{.code = LM_ERR_MPTPBadChunk, .desc = _("MPTP data chunk size is invalid") },
{.code = LM_ERR_MPTPSetsockopt, .desc = _("failed to set MPTP socket options") },
{.code = LM_ERR_MPTPTimeout, .desc = _("MPTP connection timed out") },
{.code = LM_ERR_MPTPBindFail, .desc = _("failed to bind MPTP socket") },
{.code = LM_ERR_ArgNULL, .desc = _("required argument is a NULL pointer") },
{.code = LM_ERR_MPTPNotRequest, .desc = _("not a MPTP request") },
{.code = LM_ERR_MPTPNotResponse, .desc = _("not a MPTP response") },
{.code = LM_ERR_MPTPNotLast, .desc = _("MPTP request last flag is not set") },
{.code = LM_ERR_NoPort, .desc = _("host port not specified") },
};
for (int i = 0; i < sizeof(errors) / sizeof(lm_error_desc_t); i++) {

View File

@ -27,7 +27,7 @@ bool lm_mptp_packet_init(lm_mptp_t *packet, bool is_request, uint8_t code, bool
if (is_request)
packet->header.flags |= (MPTP_REQUEST << 7);
else
packet->header.flags |= (MPTP_REQUEST << 7);
packet->header.flags |= (MPTP_RESPONSE << 7);
packet->header.flags |= (code << 4);
@ -39,6 +39,30 @@ bool lm_mptp_packet_init(lm_mptp_t *packet, bool is_request, uint8_t code, bool
return true;
}
bool lm_mptp_verify(lm_mptp_t *packet){
if (NULL == packet) {
lm_error_set(LM_ERR_ArgNULL);
return false;
}
if (MPTP_FLAGS_VERSION(packet) != MPTP_VERSION_SUPPORTED) {
lm_error_set(LM_ERR_MPTPBadVersion);
return false;
}
if (packet->header.size > MPTP_CHUNK_MAX || packet->header.size < 0) {
lm_error_set(LM_ERR_MPTPBadChunk);
return false;
}
if (MPTP_FLAGS_CODE(packet) > MPTP_CODE_MAX || MPTP_FLAGS_CODE(packet) < 0) {
lm_error_set(LM_ERR_MPTPBadCode);
return false;
}
return true;
}
int lm_mptp_socket(char *addr, uint16_t port, struct sockaddr *saddr) {
if (NULL == addr || NULL == saddr) {
lm_error_set(LM_ERR_ArgNULL);
@ -129,6 +153,18 @@ int lm_mptp_client_connect(char *addr, uint16_t port) {
return sock;
}
bool lm_mptp_client_verify(lm_mptp_t *packet) {
if (!lm_mptp_verify(packet))
return false;
if (MPTP_IS_REQUEST(packet)) {
lm_error_set(LM_ERR_MPTPNotResponse);
return false;
}
return true;
}
bool lm_mptp_client_send(int sock, lm_mptp_t *packet) {
if (NULL == packet) {
lm_error_set(LM_ERR_ArgNULL);
@ -210,6 +246,23 @@ int lm_mptp_server_listen(char *addr, uint16_t port) {
return sock;
}
bool lm_mptp_server_verify(lm_mptp_t *packet) {
if (!lm_mptp_verify(packet))
return false;
if (!MPTP_IS_REQUEST(packet)) {
lm_error_set(LM_ERR_MPTPNotRequest);
return false;
}
if(!MPTP_IS_LAST(packet)){
lm_error_set(LM_ERR_MPTPNotLast);
return false;
}
return true;
}
bool lm_mptp_server_recv(int sock, lm_mptp_t *packet, struct sockaddr *addr) {
if (NULL == packet || NULL == addr) {
lm_error_set(LM_ERR_ArgNULL);

View File

@ -2,8 +2,9 @@
#include "../include/error.h"
#include "../include/mptp.h"
#include "../include/util.h"
#include <stdio.h>
#include <sys/socket.h>
#include <stdlib.h>
#include <stdio.h>
void lm_pool_free(lm_pool_t *pool) {
lm_url_free(&pool->url);
@ -12,23 +13,32 @@ void lm_pool_free(lm_pool_t *pool) {
lm_pool_t *lm_pool_new(char *name, char *url) {
lm_pool_t *pool = malloc(sizeof(lm_pool_t));
if (!lm_url_parse(&pool->url, url)) {
pool->available = true;
pool->name = name;
if (!lm_url_init(&pool->url, url)) {
free(pool);
return NULL;
}
if(pool->url.empty)
return pool;
if (!eq(pool->url.protocol, "mptp")) {
lm_error_set(LM_ERR_PoolNoSupport);
lm_pool_free(pool);
return NULL;
}
pool->available = true;
pool->name = name;
return pool;
}
void lm_pool_test(lm_pool_t *pool) {
if(pool->url.empty){
pool->available = false;
return;
}
lm_mptp_t packet;
lm_mptp_packet_init(&packet, true, MPTP_C2S_PING, true);
@ -48,7 +58,12 @@ void lm_pool_test(lm_pool_t *pool) {
goto end;
}
pool->available = MPTP_FLAGS_TYPE(&packet) == MPTP_S2C_PONG;
if (!lm_mptp_client_verify(&packet)){
pool->available = false;
goto end;
}
pool->available = MPTP_FLAGS_CODE(&packet) == MPTP_S2C_PONG;
end:
lm_mptp_close(sock);
return;
@ -60,7 +75,8 @@ bool lm_ctx_pool_add(lm_ctx_t *ctx, char *name, char *url) {
return false;
pdebug(ctx, __func__, "pool name is %s", pool->name);
pdebug(ctx, __func__, "pool URL is %s://%s:%d%s", pool->url.protocol, pool->url.host, pool->url.port, pool->url.path);
if(!pool->url.empty)
pdebug(ctx, __func__, "pool URL is %s://%s:%d%s", pool->url.protocol, pool->url.host, pool->url.port, pool->url.path);
if (NULL == ctx->pools) {
ctx->pools = pool;
@ -121,3 +137,38 @@ void lm_ctx_pool_test(lm_ctx_t *ctx) {
cur = cur->next;
}
}
bool lm_ctx_pool_serve(lm_ctx_t *ctx, char *addr) {
struct sockaddr saddr;
char *host = NULL;
lm_mptp_t packet;
uint16_t port;
int sock;
if(!parse_host(addr, host, &port))
return false;
if(port == 0){
lm_error_set(LM_ERR_NoPort);
return false;
}
if((sock = lm_mptp_server_listen(addr, port)) < 0)
return false;
while(lm_mptp_server_recv(sock, &packet, &saddr)){
if(!lm_mptp_verify(&packet)){
pdebug(ctx, __func__, "skipping invalid packet: %s", lm_strerror());
continue;
}
switch (MPTP_FLAGS_CODE(&packet)) {
case MPTP_C2S_PING:
lm_mptp_packet_init(&packet, false, MPTP_S2C_PONG, true);
lm_mptp_server_send(sock, &packet, &saddr);
break;
}
}
return true;
}

View File

@ -2,6 +2,7 @@
#include "../include/error.h"
#include "../include/util.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@ -14,7 +15,7 @@ typedef enum lm_state {
URL_PATH_3 = 3,
} lm_state_t;
uint16_t lm_url_default(char *protocol) {
uint16_t lm_url_default_port(char *protocol) {
if (eq(protocol, "ftp"))
return 21;
else if (eq(protocol, "ftps"))
@ -28,13 +29,20 @@ uint16_t lm_url_default(char *protocol) {
return 0;
}
bool lm_url_parse(lm_url_t *url, char *str) {
bool lm_url_init(lm_url_t *url, char *str) {
// clear out every variable
bzero(url->protocol, sizeof(url->protocol));
url->empty = true;
url->host = NULL;
url->path = NULL;
url->port = 0;
if(NULL == str)
return true;
// str is not NULL
url->empty = false;
// stores the string size
size_t strl = 0, index = 0, pos = 0;
@ -46,7 +54,7 @@ bool lm_url_parse(lm_url_t *url, char *str) {
}
lm_state_t state = URL_PROTOCOL_0;
char buffer[strl + 1], *save; // temporary buffer, strok_r save pointer
char buffer[strl + 1]; // temporary buffer, strok_r save pointer
bool ret = false; // return value
// clear out the temporary buffer
@ -182,30 +190,34 @@ bool lm_url_parse(lm_url_t *url, char *str) {
goto end;
}
if (strtok_r(url->host, ":", &save) == NULL) {
lm_error_set(LM_ERR_URLBadHost);
goto end;
}
char *portc = strtok_r(NULL, ":", &save);
if (NULL == portc) {
url->port = lm_url_default(url->protocol);
if (url->port == 0) {
if(parse_host(url->host, url->host, &url->port)){
if(url->port != 0){
ret = true;
goto end;
}
url->port = lm_url_default_port(url->protocol);
if(url->port == 0){
lm_error_set(LM_ERR_URLPortUnknown);
goto end;
}
ret = true;
goto end;
}
int port = atoi(portc);
if (port <= 0 || port > UINT16_MAX) {
switch (lm_error()) {
case LM_ERR_BadHost:
lm_error_set(LM_ERR_URLBadHost);
break;
case LM_ERR_BadPort:
lm_error_set(LM_ERR_URLBadPort);
goto end;
}
break;
url->port = port;
ret = true;
default:
break;
}
end:
if (!ret && NULL != url->host)

View File

@ -1,5 +1,7 @@
#include "../include/util.h"
#include "../include/types.h"
#include "../include/error.h"
#include <error.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
@ -43,3 +45,27 @@ bool contains(char *str, char s) {
return true;
return false;
}
bool parse_host(char *addr, char *host, uint16_t *port){
char *save = NULL, *portc = NULL;
int portint = 0;
if((host = strtok_r(addr, ":", &save)) == NULL){
lm_error_set(LM_ERR_BadHost);
return false;
}
if((portc = strtok_r(NULL, ":", &save)) == NULL){
*port = 0;
return true;
}
portint = atoi(portc);
if(portint <= 0 || portint > UINT16_MAX){
lm_error_set(LM_ERR_BadPort);
return false;
}
*port = portint;
return true;
}