#include "../../include/mptp.h"
#include "../../include/error.h"
#include "../../include/util.h"

#include <sys/stat.h>
#include <errno.h>
#include <error.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

/*
 * Send the file at `path` over `sock` as a sequence of MPTP server packets.
 *
 * Packet sequence produced on success:
 *   1. a COOL packet whose data is the file size as a decimal string,
 *   2. zero or more COOL packets carrying up to MPTP_DATA_MAX bytes of file
 *      data each,
 *   3. an empty COOL packet initialized with the final flag set, marking the
 *      end of the transfer.
 * If the transfer is aborted before completion (open/stat/format failure, a
 * short read, or the callback returning false) a final BRUH packet is sent so
 * the peer can stop waiting. When a send itself fails no further packets are
 * attempted.
 *
 * callback (optional) is called as callback(path, bytes_sent, total, data)
 * once before the first data packet and after every data packet; returning
 * false cancels the transfer.
 *
 * Returns true only if every packet, including the terminating one, was sent.
 * Errors detected here are reported via lm_error_set(); send failures are
 * presumably recorded by lm_mptp_server_send() itself (lm_strerror() is read
 * after it fails) — NOTE(review): confirm.
 */
bool lm_mptp_sendfile(int sock, char *path, lm_mptp_transfer_callback_t callback, void *data){
  if (NULL == path){
    lm_error_set(LM_ERR_ArgNULL);
    return false;
  }

  lm_mptp_t packet;
  bool ret = false;

  // binary mode: the file contents are sent byte-for-byte
  FILE  *file = fopen(path, "rb");
  size_t total = 0, current = 0;
  size_t nread = 0; // not named "read" to avoid shadowing read(2) from <unistd.h>
  struct stat st;
  int size = -1;

  if(NULL == file){
    pdebug(__func__, "failed to open file: %s", path);
    lm_error_set(LM_ERR_SendOpenFail);
    goto end_1;
  }

  if(fstat(fileno(file), &st)<0){
    pdebug(__func__, "failed to stat file: %s", path);
    lm_error_set(LM_ERR_SendStatFail);
    goto end_1;
  }

  total = st.st_size;

  // first packet: the total size as decimal text. Format the size_t copy with
  // %zu; st.st_size is an off_t, which does not portably match %lu.
  lm_mptp_init(&packet, false, MPTP_S2C_COOL, false);
  size = snprintf(packet.data, MPTP_DATA_MAX, "%zu", total);
  if(size <= 0 || (size_t)size >= MPTP_DATA_MAX){
    // <= 0: encoding error/empty output; >= MPTP_DATA_MAX: output was truncated
    pdebug(__func__, "snprintf for stat size failed: %s", path);
    lm_error_set(LM_ERR_SendSnprintfFail);
    goto end_1;
  }

  packet.header.data_size = size;
  if(!lm_mptp_server_send(sock, &packet)){
    // the peer never learned the size; do not stream data after it
    pdebug(__func__, "failed to send the size packet for %s: %s", path, lm_strerror());
    goto end_2;
  }

  if(NULL != callback && !callback(path, current, total, data))
    goto end_1;

  // clear the packet
  lm_mptp_init(&packet, false, MPTP_S2C_COOL, false);

  while ((nread = fread(packet.data, 1, MPTP_DATA_MAX, file)) > 0) {
    packet.header.data_size = nread;
    pdebug(__func__, "sending the %lu/%lu of %s", current+nread, total, path);

    if(!lm_mptp_server_send(sock, &packet)){
      pdebug(__func__, "failed to send packet for %s (left at %lu/%lu): %s", path, current, total, lm_strerror());
      goto end_2;
    }

    current += nread;

    // pass the same size_t total as the initial callback invocation
    if(NULL != callback && !callback(path, current, total, data))
      goto end_1;

    lm_mptp_init(&packet, false, MPTP_S2C_COOL, false);
  }

  // fread() returns 0 on both EOF and error; a byte count that differs from
  // the stat() size means a read error or that the file changed underneath us
  if(current != total){
    pdebug(__func__, "failed read the entire file (left at %lu/%lu): %s", current, total, path);
    lm_error_set(LM_ERR_SendReadFail);
    goto end_1;
  }

  pdebug(__func__, "completed sending %s, sending last packet", path);
  lm_mptp_init(&packet, false, MPTP_S2C_COOL, true);
  if(!lm_mptp_server_send(sock, &packet)){
    // without the terminating packet the peer cannot tell the transfer ended
    pdebug(__func__, "failed to send the last packet for %s: %s", path, lm_strerror());
    goto end_2;
  }

  ret = true;
  goto end_2;

end_1:
  // abort path: tell the peer the transfer failed
  lm_mptp_init(&packet, false, MPTP_S2C_BRUH, true);
  lm_mptp_server_send(sock, &packet);
end_2:
  if(NULL != file)
    fclose(file);
  return ret;
}

/*
 * Receive a file from `sock` and store it at `path`.
 *
 * Expects the packet sequence that lm_mptp_sendfile() produces:
 *   1. a COOL packet whose data is the file size in decimal,
 *   2. COOL data packets, whose payloads are appended to the file,
 *   3. a COOL packet with the last flag set, which ends the transfer.
 * Every packet must pass lm_mptp_client_verify() and carry the COOL code.
 * Any existing file at `path` is removed first.
 *
 * callback (optional) is called as callback(path, bytes_received, total, data)
 * once before the first data packet and after every data packet; returning
 * false cancels the transfer.
 *
 * Returns true only if exactly `total` bytes were received and written and
 * the file was closed without error. On failure the partially written file
 * is left on disk, and lm_error_set() is called for errors detected here.
 */
bool lm_mptp_recvfile(int sock, char *path, lm_mptp_transfer_callback_t callback, void *data){
  if(NULL == path){
    lm_error_set(LM_ERR_ArgNULL);
    return false;
  }

  if(unlink(path) < 0 && errno != ENOENT){
    lm_error_set(LM_ERR_RecvDelFail);
    return false;
  }

  // binary + truncating mode: data is written verbatim, and even a file that
  // reappeared after the unlink() above is overwritten instead of appended to
  FILE *file = fopen(path, "wb");
  size_t total = 0, current = 0;
  bool ret = false;
  lm_mptp_t packet;
  char buffer[MPTP_DATA_MAX+1];
  char *endptr = NULL;
  unsigned long long parsed = 0;

  if(NULL == file){
    pdebug(__func__, "failed to open file: %s", path);
    lm_error_set(LM_ERR_RecvOpenFail);
    goto end;
  }

  if(!lm_mptp_client_recv(sock, &packet) || !lm_mptp_client_verify(&packet))
    goto end;

  if(MPTP_FLAGS_CODE(&packet) != MPTP_S2C_COOL){
    lm_error_set(LM_ERR_RecvBadCode);
    goto end;
  }

  // the size comes from the peer, so validate it fully: plain decimal digits
  // only (no sign/whitespace), no trailing junk, no overflow, fits in size_t.
  // A zero size is rejected, so empty files cannot be received —
  // NOTE(review): lm_mptp_sendfile() announces 0 for an empty file; confirm
  // this rejection is intended.
  lm_mptp_get_data(&packet, buffer);
  errno = 0;
  parsed = strtoull(buffer, &endptr, 10);
  if(buffer[0] < '0' || buffer[0] > '9' || '\0' != *endptr || ERANGE == errno ||
     0 == parsed || (unsigned long long)(size_t)parsed != parsed){
    lm_error_set(LM_ERR_RecvBadSize);
    goto end;
  }
  total = (size_t)parsed;

  if(NULL != callback)
    if(!callback(path, current, total, data))
      goto end;

  while(lm_mptp_client_recv(sock, &packet)){
    if(!lm_mptp_client_verify(&packet)){
      pdebug(__func__, "failed to verify the packet for %s (%lu/%lu)", path, current, total);
      goto end;
    }

    if(MPTP_FLAGS_CODE(&packet) != MPTP_S2C_COOL){
      pdebug(__func__, "server responded with bad status code for %s (%lu/%lu)", path, current, total);
      lm_error_set(LM_ERR_RecvBadCode);
      goto end;
    }

    if(MPTP_IS_LAST(&packet)){
      pdebug(__func__, "received the last packet for %s (%lu/%lu)", path, current, total);
      break;
    }

    // never write past the announced size: a misbehaving peer could otherwise
    // keep filling the disk until the final size check runs
    // (current <= total always holds here, so the subtraction cannot wrap)
    if((size_t)packet.header.data_size > total - current){
      pdebug(__func__, "received more data than announced for %s (%lu/%lu)", path, current, total);
      lm_error_set(LM_ERR_RecvBadSize);
      goto end;
    }

    // fwrite() may write fewer bytes than requested; any short write is a failure
    if(fwrite(packet.data, 1, packet.header.data_size, file) != (size_t)packet.header.data_size){
      pdebug(__func__, "failed to write received data for %s (%lu/%lu)", path, current, total);
      lm_error_set(LM_ERR_RecvWriteFail);
      goto end;
    }

    current += packet.header.data_size;

    if(NULL != callback)
      if(!callback(path, current, total, data))
        goto end;
  }

  if(current != total){
    if(MPTP_IS_LAST(&packet))
      pdebug(__func__, "failed to receive the entire file (left at %lu/%lu), got the last packet: %s", current, total, path);
    else
      pdebug(__func__, "failed to receive the entire file (left at %lu/%lu): %s (%s)", current, total, path, lm_strerror());
    lm_error_set(LM_ERR_RecvNotCompleted);
    goto end;
  }

  ret = true;
end:
  if(NULL != file){
    // fclose() flushes buffered writes; if that fails the file on disk is
    // incomplete, so an otherwise successful transfer must still fail
    if(0 != fclose(file) && ret){
      pdebug(__func__, "failed to flush/close the received file: %s", path);
      lm_error_set(LM_ERR_RecvWriteFail);
      ret = false;
    }
  }
  return ret;
}