#include <stddef.h>
#include <stdio.h>
#include <string.h>
#include <stdbool.h>

/* Width of unsigned long in bits; every mask and shift below assumes 64. */
#define BITS_PER_LONG 64
/*
 * Branch-hint placeholder: expands to the condition itself. The expansion
 * is parenthesized so the condition stays intact inside a larger
 * expression such as !unlikely(a == b).
 */
#define unlikely(x) (x)
/* Paste the UL suffix onto an integer literal. */
#define UL(x) (x##UL)
/*
 * GENMASK(h, l) - unsigned long with bits l..h (inclusive) set.
 * Both h and l must be below BITS_PER_LONG, otherwise a shift count
 * is out of range.
 */
#define GENMASK(h, l) \
	(((~UL(0)) - (UL(1) << (l)) + 1) & \
	 (~UL(0) >> (BITS_PER_LONG - 1 - (h))))

/* Index of the word that holds bit number nr. */
#define BIT_WORD(nr)		((nr) / BITS_PER_LONG)
/* Every bit from position (start % BITS_PER_LONG) upward is set. */
#define BITMAP_FIRST_WORD_MASK(start) (~0UL << ((start) & (BITS_PER_LONG - 1)))
/*
 * The low (nbits % BITS_PER_LONG) bits are set; a multiple of
 * BITS_PER_LONG (including 0) yields all ones.
 */
#define BITMAP_LAST_WORD_MASK(nbits) (~0UL >> (-(nbits) & (BITS_PER_LONG - 1)))


/*
 * my_bitmap_write - store the low @nbits bits of @value at bit @start of @map.
 * @map:   bitmap, an array of unsigned long words
 * @value: bits to store; anything at or above bit @nbits is ignored
 * @start: index of the first bit to write, counted from bit 0 of map[0]
 * @nbits: number of bits to write, 1..BITS_PER_LONG
 *
 * The field may straddle two adjacent words. Bits of @map outside
 * [@start, @start + @nbits) keep their previous value. A request with
 * @nbits == 0 or @nbits > BITS_PER_LONG is ignored, because GENMASK()
 * would otherwise be evaluated with an out-of-range shift count.
 */
void my_bitmap_write(unsigned long *map, unsigned long value,
		     unsigned long start, unsigned long nbits)
{
	unsigned long w, end;

	if (unlikely(nbits == 0 || nbits > BITS_PER_LONG))
		return;

	value &= GENMASK(nbits - 1, 0);

	map += BIT_WORD(start);
	start %= BITS_PER_LONG;
	/* Position of the last written bit, relative to bit 0 of the first word. */
	end = start + nbits - 1;

	/*
	 * First word: punch a hole for the field. When the field spills into
	 * the next word, everything from @start upward is replaced, so only
	 * the low @start bits survive (@start is nonzero in that case).
	 */
	w = *map & (end < BITS_PER_LONG ? ~GENMASK(end, start) : BITMAP_LAST_WORD_MASK(start));
	*map = w | (value << start);

	if (end < BITS_PER_LONG)
		return;

	/*
	 * Second word: the field occupies the low (end + 1 - BITS_PER_LONG)
	 * bits. Clear exactly those and preserve everything above them; the
	 * mask has to be inverted, otherwise the bits above the field are
	 * wiped instead.
	 */
	w = *++map & ~BITMAP_LAST_WORD_MASK(end + 1 - BITS_PER_LONG);
	*map = w | (value >> (BITS_PER_LONG - start));
}

/*
 * bitmap_write - write an @nbits-wide value into @map at bit @start.
 * @map:   bitmap, an array of unsigned long words
 * @value: value to store; bits at and above @nbits are discarded
 * @start: bit index of the least significant bit of the field
 * @nbits: field width in bits, 1..BITS_PER_LONG
 *
 * The single-word case and the case where the field straddles map[index]
 * and map[index + 1] are handled as two separate paths. Bits outside the
 * field keep their previous value. Nothing is written when @nbits is 0 or
 * exceeds BITS_PER_LONG: GENMASK(@nbits - 1, 0) has no defined result for
 * such widths.
 */
void bitmap_write(unsigned long *map, unsigned long value,
		  unsigned long start, unsigned long nbits)
{
	size_t index = BIT_WORD(start);
	unsigned long offset = start % BITS_PER_LONG;
	/* Bits available in map[index] from @offset to the top of the word. */
	unsigned long space = BITS_PER_LONG - offset;
	unsigned long mask;

	if (unlikely(!nbits || nbits > BITS_PER_LONG))
		return;

	mask = GENMASK(nbits - 1, 0);
	value &= mask;

	if (space >= nbits) {
		/* The whole field lives in one word. */
		map[index] &= ~(mask << offset);
		map[index] |= value << offset;
		return;
	}

	/* Low part: replaces everything from @offset to the top of the word. */
	map[index] &= ~BITMAP_FIRST_WORD_MASK(start);
	map[index] |= value << offset;
	/* High part: lands in the low (@nbits - space) bits of the next word. */
	map[index + 1] &= ~BITMAP_LAST_WORD_MASK(start + nbits);
	map[index + 1] |= (value >> space);
}

/*
 * bitmap_write_new - variant of bitmap_write() with a shared first-word path.
 * @map:   bitmap, an array of unsigned long words
 * @value: value to store; bits at and above @nbits are discarded
 * @start: bit index of the least significant bit of the field
 * @nbits: field width in bits, 1..BITS_PER_LONG
 *
 * The first word is always updated by the same two statements; only the
 * clearing mask depends on whether the field fits. A second word is touched
 * only when the field does not fit. Bits outside the field keep their
 * previous value. Nothing is written when @nbits is 0 or exceeds
 * BITS_PER_LONG, for which GENMASK(@nbits - 1, 0) has no defined result.
 */
void bitmap_write_new(unsigned long *map, unsigned long value,
		      unsigned long start, unsigned long nbits)
{
	unsigned long offset;
	unsigned long space;
	unsigned long mask;
	size_t index;
	bool fit;

	if (unlikely(!nbits || nbits > BITS_PER_LONG))
		return;

	mask = GENMASK(nbits - 1, 0);
	value &= mask;
	offset = start % BITS_PER_LONG;
	/* Bits available in map[index] from @offset to the top of the word. */
	space = BITS_PER_LONG - offset;
	index =	BIT_WORD(start);
	fit = space >= nbits;

	/*
	 * Fitting field: clear just its @nbits bits. Otherwise clear from
	 * @offset to the top of the word.
	 */
	map[index] &= (fit ? ~(mask << offset) : ~BITMAP_FIRST_WORD_MASK(start));
	map[index] |= value << offset;
	if (fit)
		return;

	/* Keep every bit of the next word at or above the end of the field. */
	map[index + 1] &= BITMAP_FIRST_WORD_MASK(start + nbits);
	map[index + 1] |= (value >> space);
}

/*
 * bitmap_write_new_shift - bitmap_write_new() with the clearing mask built
 * in place by GENMASK() rather than by shifting a low mask.
 * @map:   bitmap, an array of unsigned long words
 * @value: value to store; bits at and above @nbits are discarded
 * @start: bit index of the least significant bit of the field
 * @nbits: field width in bits, 1..BITS_PER_LONG
 *
 * Bits outside the field keep their previous value. Nothing is written
 * when @nbits is 0 or exceeds BITS_PER_LONG, for which
 * GENMASK(@nbits - 1, 0) has no defined result.
 */
void bitmap_write_new_shift(unsigned long *map, unsigned long value,
		      unsigned long start, unsigned long nbits)
{
	unsigned long offset;
	unsigned long space;
	size_t index;
	bool fit;

	if (unlikely(!nbits || nbits > BITS_PER_LONG))
		return;

	value &= GENMASK(nbits - 1, 0);
	offset = start % BITS_PER_LONG;
	/* Bits available in map[index] from @offset to the top of the word. */
	space = BITS_PER_LONG - offset;
	index =	BIT_WORD(start);
	fit = space >= nbits;

	/*
	 * GENMASK(nbits - 1 + offset, offset) is evaluated only when the field
	 * fits, so its upper bound never reaches BITS_PER_LONG.
	 */
	map[index] &= (fit ? ~(GENMASK(nbits - 1 + offset, offset)) : ~BITMAP_FIRST_WORD_MASK(start));
	map[index] |= value << offset;
	if (fit)
		return;

	/* Keep every bit of the next word at or above the end of the field. */
	map[index + 1] &= BITMAP_FIRST_WORD_MASK(start + nbits);
	map[index + 1] |= (value >> space);
}

/* Number of unsigned long words in every bitmap handled by this harness. */
#define MAPSIZE 3

/*
 * print_map - dump a MAPSIZE-word bitmap on a single line of stdout.
 * @map: bitmap to print; exactly MAPSIZE words are read
 * @c:   label printed in front of the words, followed by ": "
 *
 * Words are printed from map[0] upward in lowercase hex, each followed
 * by a space, and the line is terminated with a newline.
 */
void print_map(unsigned long *map, const char *c)
{
	const unsigned long *word = map;
	const unsigned long *const stop = map + MAPSIZE;

	printf("%s: ", c);
	while (word != stop)
		printf("%lx ", *word++);
	printf("\n");
}

/*
 * COMPARE(fn1, fn2) - cross-check two bitmap-write implementations.
 *
 * Calls fn1 and fn2 on two identically prepared MAPSIZE-word bitmaps and
 * reports any difference between the results together with both maps.
 * The variables value, start and nbits are taken by name from the scope
 * in which the macro is expanded.
 *
 * Every comparison runs twice, once on an all-zeros and once on an
 * all-ones background. The all-ones pass is what exposes an
 * implementation that clears bits outside the written field; on a zeroed
 * map such a bug leaves no trace.
 */
#define COMPARE(fn1, fn2)						\
do {									\
	const int cmp_fill[] = { 0x00, 0xff };				\
	size_t cmp_pass;						\
									\
	for (cmp_pass = 0;						\
	     cmp_pass < sizeof(cmp_fill) / sizeof(cmp_fill[0]);		\
	     cmp_pass++) {						\
		unsigned long one[MAPSIZE], two[MAPSIZE];		\
		int res;						\
									\
		memset(one, cmp_fill[cmp_pass], sizeof(one));		\
		memset(two, cmp_fill[cmp_pass], sizeof(two));		\
		fn1(one, value, start, nbits);				\
		fn2(two, value, start, nbits);				\
		res = memcmp(one, two, sizeof(one));			\
		if (res) {						\
			printf(#fn1 " vs. " #fn2 ": [%lu, %lu]: %d\n",	\
				start, nbits, res);			\
			print_map(one, #fn1);				\
			print_map(two, #fn2);				\
		}							\
	}								\
} while (0)

/*
 * test - check every alternative implementation against bitmap_write()
 * for one combination of arguments.
 * @value: value handed to each implementation
 * @start: first bit to write
 * @nbits: number of bits to write
 *
 * Each COMPARE() line runs the two named functions and prints a report
 * when the resulting maps differ. COMPARE() reads value, start and nbits
 * by name, so the parameter names here must not change.
 */
void test(unsigned long value, unsigned long start, unsigned long nbits)
{
	/* bitmap_write() is the first function of every pair. */
	COMPARE(bitmap_write, bitmap_write_new);
	COMPARE(bitmap_write, my_bitmap_write);
	COMPARE(bitmap_write, bitmap_write_new_shift);
}

/*
 * main - sweep every start position and field width through test().
 *
 * @start covers 0..(MAPSIZE - 1) * BITS_PER_LONG inclusive and @nbits
 * covers 0..BITS_PER_LONG inclusive. A write that begins at the upper
 * bound of @start has offset 0, so it stays inside the last word; a larger
 * @start could make a straddling write reach past map[MAPSIZE - 1].
 *
 * Mismatches are only reported on stdout; the exit status is always 0.
 */
int main(void)
{
	unsigned long value = 0xfafafafafafafafaUL;
	unsigned long start, nbits;

	for (start = 0; start <= (MAPSIZE-1)*BITS_PER_LONG; start++) {
		for (nbits = 0; nbits <= BITS_PER_LONG; nbits++) {
			test(value, start, nbits);
		}
	}
	return 0;
}
