/***************************************************************************
 *                                                                         *
 *   Copyright (C) 2025 by David C. Rankin                                 *
 *   swdev@3111skyline.com                                                 *
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License.        *
 *                                                                         *
 *   This program is distributed in the hope that it will be useful,       *
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
 *   GNU General Public License for more details.                          *
 *                                                                         *
 *   You should have received a copy of the GNU General Public License     *
 *   Version 2.0 along with this program; if not, write to the             *
 *   Free Software Foundation, Inc.,                                       *
 *   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
 *                                                                         *
 *   Online at:                                                            *
 *   http://www.gnu.org/licenses/old-licenses/gpl-2.0.en.html              *
 *                                                                         *
 ***************************************************************************/
/*
  compile:

$ gcc -Wall -Wextra -pedantic -Wshadow -std=c23 -O3 -o cidrnew cidrnew.c

  examples:

$ ./cidrnew 170.84.0.0 - 170.84.227.255

 cidr : 170.84.0.0/17
      : 170.84.128.0/18
      : 170.84.192.0/19
      : 170.84.224.0/22

or single subnet range example:

$ ./cidrnew 170.84.0.0 - 170.84.127.255

 cidr : 170.84.0.0/17

or split subnet examples:

$ ./cidrnew 76.96.0.0 - 76.159.255.255

 cidr : 76.96.0.0/11
      : 76.128.0.0/11

$ ./cidrnew 189.105.128.0 - 189.107.255.255

 cidr : 189.105.128.0/17
      : 189.106.0.0/15

or differing network ID and split subnet:

$ ./cidrnew 172.16.37.0 - 172.16.40.255

 cidr : 172.16.37.0/24
      : 172.16.38.0/23
      : 172.16.40.0/24

*/

#include <stdio.h>
#include <string.h>
#include <inttypes.h>
#include <arpa/inet.h>
#include <limits.h>


#define IPv4BYTES   4   /* octets in an IPv4 address */
#define CIDRARSZ   16   /* max CIDR entries computed for one range */

/* width of unsigned long in bits, derived from ULONG_MAX (<limits.h>)
 * rather than compiler-specific predefined macros such as __LP64__.
 * Assumes long is either 32 or 64 bits wide. */
#if ULONG_MAX > 0xffffffffUL
# define BITS_PER_LONG 64
#else
# define BITS_PER_LONG 32
#endif


/** @struct ip_t
 *
 *  @brief single 32-bit value/octets struct providing representation of
 *  an IP address as a single unsigned value or in dotted-quad array form.
 *
 *  The anonymous union (C11) overlays the four octets on the 32-bit value.
 *  iparr[0] holds the first dotted-quad octet, so .ip holds the address in
 *  network byte order. Code in this file converts it with htonl()/ntohl()
 *  before doing arithmetic on it.
 */
typedef struct {
  union {
    uint8_t iparr[IPv4BYTES];   /* octets, iparr[0] is the leftmost */
    uint32_t ip;                /* same 4 bytes as one 32-bit value */
  };
} ip_t;

/** @struct cidr_t
 *
 *  @brief this struct combines two struct ip_t with an 8-bit unsigned value
 *  to allow a CIDR block to be described with start, end and routing prefix.
 *
 *  Only prefix values 0-32 are meaningful, since code in this file computes
 *  block sizes from 32 - rpfx.
 */
typedef struct {
  ip_t start, end;      /* first and last address covered by the block */
  uint8_t rpfx;         /* routing prefix length, e.g. 24 for "/24" */
} cidr_t;


/** @brief compute the zero-based index of the most significant set bit of
 *  word, i.e. floor(log2(word)). Used to compute host bits from the number
 *  of IP addresses in a range.
 *  @param word value to return the MSB index for.
 *  @return index of the highest set bit; 0 for both word == 0 and word == 1.
 *
 *  @note portable: the search width is taken from sizeof (unsigned long),
 *  so it neither relies on the glibc-only __always_inline macro nor on a
 *  guessed BITS_PER_LONG.
 */
static inline unsigned long msb (unsigned long word)
{
  unsigned long num = 0;

  /* binary search: test the upper half of the remaining width, halving the
   * step each time (32,16,8,4,2,1 for a 64-bit long; 16,...,1 for 32-bit) */
  for (unsigned shift = sizeof word * CHAR_BIT / 2; shift > 0; shift /= 2) {
    if (word >> shift) {
      word >>= shift;
      num += shift;
    }
  }

  return num;
}


/**
 *  @brief report whether v is an exact power of two.
 *  @param v the 32-bit value to examine.
 *  @return 1 when exactly one bit of v is set, 0 otherwise (including 0).
 */
static inline int ispow2ui (uint32_t v)
{
  /* zero has no set bits, so it is never a power of two */
  if (v == 0) {
    return 0;
  }

  /* clearing the lowest set bit leaves nothing only if it was the sole bit */
  return (v & (v - 1)) == 0;
}


/** @brief returns the largest power of 2 less than or equal to v.
 *  @param v value to find power of 2 for.
 *  @return largest power of 2 <= v; 0 when v is 0, v itself when v is
 *  already a power of 2. Valid for every uint32_t, including v > 2^31.
 */
uint32_t pwr2_floor (uint32_t v)
{
  /* repeatedly clear the lowest set bit until only the highest remains;
   * no shifts are involved, so there is no shift-by-32 undefined behavior
   * for values above 2^31 */
  while (v & (v - 1)) {
    v &= v - 1;
  }

  return v;
}


/**
 *  @brief fills 's' with a nul-terminated string containing the formatted
 *  binary representation of 'v' zero padded to 'sz' characters with
 *  separator char 'sep' placed every 'seppos' characters from right to left.
 *  @param s storage for the nul-terminated binary representation.
 *  @param strsz the size of the storage provided by s.
 *  @param v value to compute binary representation of.
 *  @param sz the number of characters in the binary representation, not
 *  including the separator characters (e.g. 8 gives 01100001 for 97).
 *  Bits beyond the 64-bit width of v are rendered as '0'.
 *  @param seppos separator position every seppos chars from right to left
 *  (0 for no separators).
 *  @param sep the separator character to use (e.g. '-' or '.'  etc..).
 *  @return returns pointer to string containing formatted binary
 *  representation of v, NULL (with s untouched) if strsz is too small to
 *  hold the representation plus its nul-terminator.
 *
 *  @note binfmtstr (s, 16, 138, 8, 4, '-') yields "1000-1010" in s.
 */
char *binfmtstr (char *s, const uint32_t strsz, const uint64_t v,
                  const uint8_t sz, const uint8_t seppos, const char sep)
{
  /* required length: sz digits plus one separator between each pair of
   * adjacent groups of seppos digits (none when seppos or sz is 0) */
  uint32_t len = sz;

  if (seppos > 0 && sz > 0) {
    len += (uint32_t) (sz - 1) / seppos;
  }

  /* need len chars + nul; validate BEFORE writing anything into s */
  if (strsz == 0 || strsz <= len) {
    return NULL;
  }

  char *p = s + len;      /* advance p to nul-terminating char pos */

  *p = 0;                 /* nul-terminate s at p */

  /* fill from the right, least significant bit first */
  for (uint8_t i = 0; i < sz; i++) {
    p--;
    /* insert separator before every seppos-th digit (not before the 1st) */
    if (i > 0 && seppos > 0 && i % seppos == 0) {
      *p-- = sep;
    }
    /* shifting a uint64_t by >= 64 is undefined; such bits are 0 */
    *p = (i < 64 && ((v >> i) & 1)) ? '1' : '0';
  }

  return p;   /* beginning of representation (equals s) */
}


/** @brief convert ipv4 octects in ip to character string in buf.
 *  @param buf storage to hold converted string 16 bytes minimum to
 *  hold the nul-terminate string representation of the IP address.
 *  @param ip the struct ip_t to convert octects to string from.
 *  @return pointer to buf as convenience to allows use by assigning the
 *  return value.
 */
const char *str_ip_octets (char *buf, ip_t *ip)
{
  sprintf (buf, "%hhu.%hhu.%hhu.%hhu",
          ip->iparr[0], ip->iparr[1],
          ip->iparr[2], ip->iparr[3]);

  return buf;
}


/**
 *  @brief print network ID, start blk, end blk and end IP, one per line, in
 *  dotted-quad and in binary with '.' between the 8-bit groups.
 *  @param nwid network ID.
 *  @param sblk first address of the block.
 *  @param eblk last address of the block.
 *  @param e last address of the whole range.
 */
void prn_nwidblk (ip_t *nwid, ip_t *sblk, ip_t *eblk, ip_t *e)
{
  char  bufnwid[16], bufbs[16], bufbe[16], bufe[16],
        binnwid[64], binbs[64], binbe[64], bine[64];

  /* htonl() on the stored (network order) value yields host order; it is
   * the same byte swap as ntohl() */
  printf ("\nnwid : %-15s  %s\nblks : %-15s  %s\nblke : %-15s  %s\n"
          "end  : %-15s  %s\n",
          str_ip_octets (bufnwid, nwid),
          binfmtstr (binnwid, sizeof binnwid, htonl (nwid->ip), 32, 8, '.'),
          str_ip_octets (bufbs, sblk),
          binfmtstr (binbs, sizeof binbs, htonl (sblk->ip), 32, 8, '.'),
          str_ip_octets (bufbe, eblk),
          binfmtstr (binbe, sizeof binbe, htonl (eblk->ip), 32, 8, '.'),
          str_ip_octets (bufe, e),
          binfmtstr (bine, sizeof bine, htonl (e->ip), 32, 8, '.'));
}


/**
 *  @brief print one address as "label : dotted-quad  binary" on its own
 *  line, preceded by a blank line.
 *  @param ip address to print.
 *  @param label tag shown before the address; at most 4 characters are used.
 */
void prn_ip_w_bin (ip_t *ip, const char *label)
{
  char dotted[16] = "";
  char binary[48] = "";

  /* render both forms up front, then emit them in a single printf */
  const char *dq   = str_ip_octets (dotted, ip);
  const char *bits = binfmtstr (binary, sizeof binary, htonl (ip->ip),
                                32, 8, '.');

  printf ("\n%.4s : %-15s  %s\n", label, dq, bits);
}


/** @brief get_no_of_addr_in_range() computes the number of IP addresses
 *  in the range provided by the start and end IP parameters.
 *  @param s start IP address in range (swapped with e if e < s).
 *  @param e end IP address in range (swapped with s if e < s).
 *  @return the total number of addresses in range are returned which
 *  includes the start and end IP, otherwise 0 is returned on error (the
 *  full 2^32-address range does not fit in the uint32_t result).
 */
uint32_t get_no_of_addr_in_range (ip_t *s, ip_t *e)
{
  char bufs[16] = "";

  /* subtract in signed 64-bit so every pair of 32-bit addresses yields the
   * exact signed difference; a uint32_t difference would wrap and make
   * ranges of >= 2^31 addresses look reversed */
  int64_t rng = (int64_t) ntohl (e->ip) - (int64_t) ntohl (s->ip);

  if (rng == 0) {
    fprintf (stderr, "warning: duplicate IP (%s) - range of 1\n",
            str_ip_octets (bufs, s));
  }
  else if (rng < 0) {       /* given in reverse: swap so that s <= e */
    ip_t tmp = *s;
    *s = *e;
    *e = tmp;
    rng = -rng;
  }

  rng += 1;         /* add 1 for inclusive range */

  if (rng > UINT32_MAX) {
    fputs ("error: computed range exceeds 32-bit.\n", stderr);
    return 0;
  }

  return (uint32_t)rng;
}


/**
 *  @brief derive a network ID and subnet mask from a start address and an
 *  address count. The number of host bits is msb (naddr), i.e.
 *  floor(log2(naddr)); the mask keeps the remaining high-order bits.
 *  @param nwid   out: start address with the host bits cleared.
 *  @param snmask out: subnet mask, stored in network byte order.
 *  @param s      start address (network byte order).
 *  @param naddr  number of addresses used to size the host part.
 *  @return nwid, for convenience (both outputs are also written through
 *  the pointer parameters).
 */
ip_t *ipv4_nwid_snmask (ip_t *nwid, ip_t *snmask, ip_t *s, uint32_t naddr)
{
  const uint8_t  host_bits = msb (naddr);
  const uint32_t host_mask = UINT32_MAX << host_bits;   /* host order */

  snmask->ip = htonl (host_mask);
  nwid->ip   = s->ip & snmask->ip;

  return nwid;
}


/**
 *  @brief report a failed CIDR validation check on stderr (like every other
 *  error in this program), giving the network ID computed for the block,
 *  the block's start and end IP and its routing prefix.
 *  @param nwid network ID computed during validation.
 *  @param cidr the CIDR that failed validation.
 */
void err_ipv4_validate_cidr (ip_t *nwid, cidr_t *cidr)
{
  char buf[16] = "";

  fprintf (stderr, "error: CIDR invalid for range\n  nwid  : %s\n",
           str_ip_octets (buf, nwid));
  fprintf (stderr, "  start : %s\n", str_ip_octets (buf, &cidr->start));
  fprintf (stderr, "  end   : %s\n", str_ip_octets (buf, &cidr->end));
  fprintf (stderr, "  rpfx  : %u\n", (unsigned) cidr->rpfx);
}


/**
 *  @brief function to check that computed CIDR contains all and only IP
 *  addresses within the described subnet: the span from the network ID
 *  (derived from .start and .rpfx) to .end must fit within the block.
 *  @param nwid storage to hold the computed Network ID.
 *  @param cidr computed CIDR to be tested (not modified).
 *  @return returns 0 on success, -1 otherwise (including a prefix outside
 *  1-32, or an end address before the network ID).
 */
int ipv4_validate_cidr (ip_t *nwid, cidr_t *cidr)
{
  ip_t      snmask = { .ip = 0 };

  /* /0 would need 2^32 addresses (does not fit a uint32_t block size) and
   * > /32 is meaningless; both would make the shift below undefined */
  if (cidr->rpfx == 0 || cidr->rpfx > 32) {
    return -1;
  }

  uint32_t  blkaddr = 1u << (32 - cidr->rpfx);

  ipv4_nwid_snmask (nwid, &snmask, &cidr->start, blkaddr);

  /* span computed locally instead of via get_no_of_addr_in_range(), which
   * may swap (modify) cidr->end and warns on single-address /32 blocks */
  int64_t   span = (int64_t) ntohl (cidr->end.ip)
                   - (int64_t) ntohl (nwid->ip) + 1;

  if (span < 1 || span > (int64_t) blkaddr) {
    return -1;
  }

  return 0;
}


/**
 *  @brief function to compute CIDR given start and end IP addresses, the
 *  array of cidr_t is populated with each CIDR needed to describe range up
 *  to a maximum of nelem array elements. The range is split greedily into
 *  the largest power-of-two blocks that are aligned on their own size, so
 *  every emitted CIDR is valid even when s is not on a network boundary.
 *  @param cidr array of cidr_t to hold computed CIDR for range.
 *  @param nelem maximum number of elements in cidr.
 *  @param s start IP address in range (swapped with e if the range is
 *  reversed; updated to the start of the last block emitted).
 *  @param e end IP address in range.
 *  @return returns the number of elements added to cidr array, 0 on error.
 *  A warning is printed if nelem is too small to describe the whole range.
 */
uint8_t ipv4_to_cidr (cidr_t *cidr, uint8_t nelem, ip_t *s, ip_t *e)
{
  uint8_t   n = 0;
  ip_t      nwid = { .ip = 0 };

  /* validate the range size (0 => error already reported); this also
   * swaps *s and *e in place when the range is given in reverse */
  if (get_no_of_addr_in_range (s, e) == 0) {
    return 0;
  }

  uint32_t  cur  = ntohl (s->ip),       /* host-order start of next block */
            last = ntohl (e->ip);       /* host-order end of range */

  while (n < nelem) {
    /* addresses still to cover; 64-bit so the arithmetic cannot wrap */
    uint64_t  remain = (uint64_t) last - cur + 1;
    uint8_t   hbits  = 0;

    /* grow the block while cur stays aligned to the doubled size (the next
     * bit of cur is 0) and the doubled block still fits in the range */
    while (hbits < 31 && !((cur >> hbits) & 1u) &&
           (UINT64_C(2) << hbits) <= remain) {
      hbits += 1;
    }

    uint32_t  blkend = cur + ((1u << hbits) - 1);

    s->ip           = htonl (cur);
    cidr[n].start   = *s;
    cidr[n].end.ip  = htonl (blkend);
    cidr[n].rpfx    = 32 - hbits;

    /* TEMP test */
    if (ipv4_validate_cidr (&nwid, &cidr[n]) < 0) {
      err_ipv4_validate_cidr (&nwid, &cidr[n]);
    }

    n += 1;

    if (blkend == last) {               /* range fully described */
      return n;
    }

    cur = blkend + 1;
  }

  /* ran out of room in cidr before reaching the end of the range; never
   * write past cidr[nelem - 1] */
  fputs ("warning: CIDR array full, range truncated.\n", stderr);

  return n;
}


/**
 *  @brief wrapper function to ipv4_to_cidr() that separates the address
 *  range in src into maximal power-of-two blocks aligned on their own size,
 *  each of which is described by exactly one valid CIDR.
 *  @param cidr array of cidr_t to hold computed CIDR for range.
 *  @param nelem maximum number of elements in cidr.
 *  @param src cidr_t whose .start and .end members give the first and last
 *  IP address of the range; either order is accepted and src is not
 *  modified.
 *  @return returns the number of elements added to cidr array, 0 on error.
 */
uint8_t ipv4_to_cidr_block (cidr_t *cidr, uint8_t nelem, cidr_t *src)
{
  uint8_t   n     = 0;
  ip_t      first = src->start,
            last  = src->end;

  if (nelem == 0) {
    return 0;
  }

  /* validate the range size (0 => error already reported); this also
   * swaps first/last in place when the range is given in reverse */
  if (get_no_of_addr_in_range (&first, &last) == 0) {
    return 0;
  }

  uint32_t  cur    = ntohl (first.ip),    /* host-order start of next block */
            lastip = ntohl (last.ip);     /* host-order end of range */

  while (n < nelem) {
    uint64_t  remain = (uint64_t) lastip - cur + 1;
    unsigned  hbits  = 0;

    /* largest block starting at cur on its own size boundary that does not
     * run past the end: grow while the next bit of cur is 0 (still aligned)
     * and the doubled block still fits */
    while (hbits < 31 && !((cur >> hbits) & 1u) &&
           (UINT64_C(2) << hbits) <= remain) {
      hbits += 1;
    }

    uint32_t  blkend = cur + ((1u << hbits) - 1);
    ip_t      sblk   = { .ip = htonl (cur) },
              eblk   = { .ip = htonl (blkend) };

#ifdef DEBUG
    /* for an aligned block the network ID equals the block start */
    prn_nwidblk (&sblk, &sblk, &eblk, &last);
#endif

    /* an aligned power-of-two block is described by a single CIDR */
    uint8_t   added  = ipv4_to_cidr (&cidr[n], nelem - n, &sblk, &eblk);

    if (added == 0) {
      return 0;
    }
    n += added;

    if (blkend == lastip) {             /* range fully described */
      return n;
    }

    cur = blkend + 1;
  }

  fputs ("warning: CIDR array full, range truncated.\n", stderr);

  return n;
}


/**
 *  @brief parse 1-3 decimal digits at *pp into *out and advance *pp past
 *  them; fails if there is no digit or the value exceeds max.
 *  @return 1 on success, 0 on failure (*pp and *out untouched).
 */
static int scan_dec_field (const char **pp, unsigned max, unsigned *out)
{
  const char *p = *pp;
  unsigned    v = 0;
  int         ndig = 0;

  /* at most 3 digits: no overflow is possible and a 4th digit is left in
   * place to be rejected by the caller's separator check */
  while (*p >= '0' && *p <= '9' && ndig < 3) {
    v = v * 10 + (unsigned) (*p - '0');
    p++;
    ndig++;
  }

  if (ndig == 0 || v > max) {
    return 0;
  }

  *pp  = p;
  *out = v;

  return 1;
}

/**
 *  @brief parse command line CIDR or IP string input into cidr_t,
 *  in each case, the IP is parsed into the start member ip_t and
 *  routing prefix is only set if 's' contains a '/' character followed
 *  by a prefix of 0-32. Each octet must be 0-255 and the whole string must
 *  be consumed (no trailing characters).
 *  @param cidr pointer to storage to hold IP in .start.ip and/or CIDR;
 *  only written (whole struct) on success.
 *  @param s character string holding text to parse into IP or CIDR.
 *  @return 5 if a valid CIDR was parsed, 4 for a single valid IP, 0 if s
 *  is neither (e.g. the "-" separator between two IPs).
 */
int parse_ip_octets (cidr_t *cidr, char *s)
{
  cidr_t      tmp = { .start.ip = 0 };
  const char *p   = s;
  unsigned    val = 0;

  /* dotted quad: four 0-255 fields separated by '.' */
  for (int i = 0; i < IPv4BYTES; i++) {
    if (i > 0) {
      if (*p != '.') {
        return 0;
      }
      p++;
    }
    if (!scan_dec_field (&p, 255, &val)) {
      return 0;
    }
    tmp.start.iparr[i] = (uint8_t) val;
  }

  if (*p == '\0') {             /* plain IP address */
    *cidr = tmp;
    return 4;
  }

  /* otherwise only "/prefix" (0-32) may follow */
  if (*p != '/') {
    return 0;
  }
  p++;

  if (!scan_dec_field (&p, 32, &val) || *p != '\0') {
    return 0;
  }

  tmp.rpfx = (uint8_t) val;
  *cidr = tmp;

  return 5;
}


/**
 *  @brief walk the command-line arguments, feeding each to
 *  parse_ip_octets(), collecting either one CIDR or a start/end IP pair.
 *  Arguments that are neither (e.g. a "-" between the IPs) are skipped.
 *  @param cidr receives a parsed CIDR (.start and .rpfx), or the first and
 *  second IP of a range in .start and .end respectively.
 *  @param argc argument count from main().
 *  @param argv argument vector from main().
 *  @return 5 as soon as a CIDR is parsed, 2 once a pair of IPs has been
 *  parsed, -1 (after printing an error) if fewer than two IPs were found.
 */
int parse_cmdline_addrs (cidr_t *cidr, int argc, char **argv)
{
  int nips = 0;             /* valid IPs collected so far */

  for (int argi = 1; argi < argc && nips < 2; argi++) {
    cidr_t    parsed_arg = { .start.ip = 0 };
    const int rc = parse_ip_octets (&parsed_arg, argv[argi]);

    if (rc == 5) {          /* full CIDR: take it and stop scanning */
      *cidr = parsed_arg;
      return rc;
    }

    if (rc != 4) {          /* neither IP nor CIDR: ignore argument */
      continue;
    }

    /* first IP opens the range, second IP closes it */
    if (nips == 0) {
      cidr->start = parsed_arg.start;
    }
    else {
      cidr->end = parsed_arg.start;
    }
    nips++;
  }

  if (nips < 2) {
    fprintf (stderr, "error: invalid IP only (%d) found\n", nips);
    return -1;
  }

  return nips;
}


int main (int argc, char **argv) {

  char bufs[16] = "",
       bufe[16] = "";

  int n       = 0,    /* number of CIDR statements describing range */
      nparsed = 0;    /* return from command line parse of IPs or CIDR */
  /* struct to hold start/end IP addresses or CIDR with octets zero
   * initialized.
   */
  cidr_t cidr = { .start.ip = 0 };
  /* array of cidr_t to hold multiple CIDR necessary to describe range */
  cidr_t  cidrarr[CIDRARSZ] = {{ .start.ip = 0 }};

  if (argc < 2) { /* validate at least one argument given */
    char *p = strrchr (argv[0], '/');
    printf ("usage : ./%s  [startIP endIP] [address/CIDR]\n",
            p ? p + 1 : argv[0]);
    return 1;
  }

  /* parse command line */
  nparsed = parse_cmdline_addrs (&cidr, argc, argv);

  /* handle CIDR input to address range  or address range to CIDR */
  if (nparsed == 5) {
    /* only /0 - /32 are meaningful; anything larger would also make the
     * shift below undefined */
    if (cidr.rpfx > 32) {
      fprintf (stderr, "error: invalid routing prefix (/%u)\n",
               (unsigned) cidr.rpfx);
      return 1;
    }

    /* host-order subnet mask; /0 is special-cased because shifting a
     * 32-bit value by 32 is undefined */
    uint32_t mask = cidr.rpfx ? UINT32_MAX << (32 - cidr.rpfx) : 0;
    uint32_t nwid = ntohl (cidr.start.ip) & mask;

    /* the block the address belongs to: network ID through network ID
     * with all host bits set (e.g. 10.0.0.5/24 -> 10.0.0.0 - 10.0.0.255) */
    cidr.start.ip = htonl (nwid);
    cidr.end.ip   = htonl (nwid | ~mask);

    printf ("\naddrs : %s - %s\n",
            str_ip_octets (bufs, &cidr.start),
            str_ip_octets (bufe, &cidr.end));
  }
  else if (nparsed == 2) {

    /* fill cidr_t array with CIDR statements assigning number required to n */
    n = ipv4_to_cidr_block (cidrarr, CIDRARSZ, &cidr);

    if (n == 0) {     /* error already reported by the conversion */
      return 1;
    }

    /* iterate over each CIDR outputting result */
    for (uint8_t i = 0; i < n; i++) {
      /* convert start and end addresses to string */
      str_ip_octets (bufs, &cidrarr[i].start);
#ifdef DEBUG
      str_ip_octets (bufe, &cidrarr[i].end);
      /* output start, end, no  of addresses in range and routing prefix */
      printf ("\nstart : %s\nend   : %s\nrange : %u\nrpfx  : %u\ncidr  : %s/%u\n",
              bufs, bufe, 1u << (32 - cidrarr[i].rpfx), cidrarr[i].rpfx,
              bufs, cidrarr[i].rpfx);
#else
      if (i > 0) {
        printf ("      : %s/%u\n", bufs, cidrarr[i].rpfx);
      }
      else {
        printf ("\n cidr : %s/%u\n", bufs, cidrarr[i].rpfx);
      }
#endif
    }
  }
  else {
    fputs ("error: invalid arguments\n"
           "usage: ./prog  [startIP endIP] [address/CIDR]\n", stderr);
    return 1;
  }

  putchar ('\n');

  return 0;
}
