[PATCH 1/3] riscv: optimized memcpy

Nick Kossifidis mick at ics.forth.gr
Tue Jan 30 04:11:44 PST 2024


On 1/28/24 13:10, Jisheng Zhang wrote:
> +
> +void *__memcpy(void *dest, const void *src, size_t count)
> +{
> +	union const_types s = { .as_u8 = src };
> +	union types d = { .as_u8 = dest };
> +	int distance = 0;
> +
> +	if (count < MIN_THRESHOLD)
> +		goto copy_remainder;
> +
> +	if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS)) {
> +		/* Copy a byte at time until destination is aligned. */
> +		for (; d.as_uptr & WORD_MASK; count--)
> +			*d.as_u8++ = *s.as_u8++;
> +
> +		distance = s.as_uptr & WORD_MASK;
> +	}
> +
> +	if (distance) {
> +		unsigned long last, next;
> +
> +		/*
> +		 * s is distance bytes ahead of d, and d just reached
> +		 * the alignment boundary. Move s backward to word align it
> +		 * and shift data to compensate for distance, in order to do
> +		 * word-by-word copy.
> +		 */
> +		s.as_u8 -= distance;
> +
> +		next = s.as_ulong[0];
> +		for (; count >= BYTES_LONG; count -= BYTES_LONG) {
> +			last = next;
> +			next = s.as_ulong[1];
> +
> +			d.as_ulong[0] = last >> (distance * 8) |
> +					next << ((BYTES_LONG - distance) * 8);
> +
> +			d.as_ulong++;
> +			s.as_ulong++;
> +		}
> +
> +		/* Restore s with the original offset. */
> +		s.as_u8 += distance;
> +	} else {
> +		/*
> +		 * If the source and dest lower bits are the same, do a simple
> +		 * aligned copy.
> +		 */
> +		size_t aligned_count = count & ~(BYTES_LONG * 8 - 1);
> +
> +		__memcpy_aligned(d.as_ulong, s.as_ulong, aligned_count);
> +		d.as_u8 += aligned_count;
> +		s.as_u8 += aligned_count;
> +		count &= BYTES_LONG * 8 - 1;
> +	}
> +
> +copy_remainder:
> +	while (count--)
> +		*d.as_u8++ = *s.as_u8++;
> +
> +	return dest;
> +}
> +EXPORT_SYMBOL(__memcpy);
> +

We could also implement memcmp this way, e.g.:

/*
 * memcmp - compare two memory areas.
 * @s1:  first area
 * @s2:  second area
 * @len: number of bytes to compare
 *
 * Returns 0 when the first @len bytes are equal (also when either pointer
 * is NULL, both pointers are the same, or @len is 0). Otherwise returns the
 * difference between the first pair of differing bytes, each taken as
 * unsigned char, i.e. negative if @s1 sorts before @s2 and positive if after.
 *
 * Strategy: walk byte-wise until b is word aligned, then compare one
 * unsigned long at a time. The word loops are only used to skip over equal
 * data; as soon as two words differ, both cursors are left at the start of
 * that word and the final byte loop locates the exact differing byte.
 *
 * NOTE(review): union const_data, WORD_SIZE and WORD_MASK are declared
 * outside this snippet; presumably WORD_SIZE == sizeof(unsigned long) and
 * WORD_MASK == WORD_SIZE - 1 -- confirm against their definitions.
 */
int
memcmp(const void *s1, const void *s2, size_t len)
{
	union const_data a = { .as_bytes = s1 };
	union const_data b = { .as_bytes = s2 };
	size_t remaining = len;
	size_t a_offt;

	/* Nothing to do */
	if (!s1 || !s2 || s1 == s2 || !len)
		return 0;

	/* Too short for the word-wise path to pay off. */
	if (len < 2 * WORD_SIZE)
		goto trailing_fw;

	/*
	 * Compare a byte at a time until b is word aligned. On a mismatch the
	 * cursors are left untouched so the byte loop below reports it.
	 */
	for (; b.as_uptr & WORD_MASK; remaining--) {
		if (*a.as_bytes != *b.as_bytes)
			goto trailing_fw;
		a.as_bytes++;
		b.as_bytes++;
	}

	a_offt = a.as_uptr & WORD_MASK;
	if (!a_offt) {
		/* Both cursors are word aligned: plain word-by-word compare. */
		for (; remaining >= WORD_SIZE; remaining -= WORD_SIZE) {
			if (*a.as_ulong != *b.as_ulong)
				break;
			/* Advance only past words known to be equal. */
			a.as_ulong++;
			b.as_ulong++;
		}
	} else {
		unsigned long a_cur, a_next;

		/*
		 * a is a_offt bytes past a word boundary while b is aligned.
		 * Move a back to the boundary and rebuild each unaligned word
		 * from two aligned loads. The shift directions below assume a
		 * little-endian byte order.
		 */
		a.as_bytes -= a_offt;
		a_next = *a.as_ulong;
		for (; remaining >= WORD_SIZE; remaining -= WORD_SIZE) {
			a_cur = a_next;
			a_next = a.as_ulong[1];

			if ((a_cur >> (a_offt * 8) |
			     a_next << ((WORD_SIZE - a_offt) * 8)) != *b.as_ulong)
				break;

			/* Advance only past words known to be equal. */
			a.as_ulong++;
			b.as_ulong++;
		}

		/* Restore a with the original offset (exactly once). */
		a.as_bytes += a_offt;
	}

  trailing_fw:
	/*
	 * Byte-wise tail: handles short inputs, the leftover bytes after the
	 * word loops, and pinpoints the differing byte after a word mismatch.
	 */
	for (; remaining; remaining--) {
		unsigned char a_byte = *a.as_bytes++;
		unsigned char b_byte = *b.as_bytes++;

		if (a_byte != b_byte)
			return (int)a_byte - (int)b_byte;
	}

	return 0;
}

Regards,
Nick



More information about the linux-riscv mailing list