petspats/pyha

View on GitHub
pyha/cores/fft/bitreversal_fftshift_avgpool/bitreversal_fftshift_avgpool.py

Summary

Maintainability
A
3 hrs
Test Coverage
import numpy as np
import pytest

from pyha import Hardware, Sfix, resize, scalb, simulate, sims_close
from pyha.common.ram import RAM
from pyha.cores import DownCounter
from pyha.common.datavalid import DataValid, NumpyToDataValid
from pyha.cores.util import toggle_bit_reverse


def build_lut(fft_size, freq_axis_decimation):
    """ Build the RAM write-address table of the core.

    Entry ``i`` is the RAM address that the ``i``-th incoming sample of a frame is
    written to. The table fixes the bit-reversal, performs fftshift, and divides by
    ``freq_axis_decimation`` so that that many bins share one address (pooling).

    Args:
        fft_size: Number of samples per frame (length of the table).
        freq_axis_decimation: Number of bins accumulated into each address.

    Returns:
        numpy integer array of length ``fft_size``.
    """
    reordered = toggle_bit_reverse(np.fft.fftshift(np.arange(fft_size)))
    return reordered // freq_axis_decimation


class BitreversalFFTshiftAVGPool(Hardware):
    """
    Bitreversal, FFTShift and AveragePooling
    ----------------------------------------

    Fixes bitreversal, performs fftshift and applies average pooling, implemented with 2 BRAM blocks.
    Internal accumulator may be saturated.

    The two RAMs are used alternately: one accumulates incoming samples (read-modify-write at the
    addresses given by ``build_lut``), while the other is read out and cleared. Their roles swap
    every ``avg_time_axis`` input frames. Accumulated sums are scaled by
    ``2 ** -int(log2(avg_freq_axis * avg_time_axis))``, so the result is an exact average only when
    that product is a power of two.

    Args:
        fft_size: Number of samples in one input FFT frame.
        avg_freq_axis: Pooling in frequency domain, decimates the data rate and has major impact on resource usage.
            Large decimations use LESS memory.
            Example, if input is 1024 point fft and avg_freq_axis is 2, then output is 512 points.
        avg_time_axis: Pooling in time domain (number of consecutive frames summed), decimates the data rate.

    TODO: this core should be unsigned...
    """
    def __init__(self, fft_size, avg_freq_axis, avg_time_axis):
        """
        Configure the core and allocate its RAMs and counters.

        Args:
            fft_size: Input frame length; the sample counter ``control`` wraps at this value.
            avg_freq_axis: Frequency pooling factor; each RAM holds ``fft_size // avg_freq_axis`` words.
            avg_time_axis: Number of frames accumulated before the two RAMs swap roles.
        """
        # Simulation inputs are wrapped as DataValid of a saturating Sfix(left=0, right=-35), i.e. 36 bits.
        self._pyha_simulation_input_callback = NumpyToDataValid(dtype=Sfix(0.0, 0, -35, overflow_style='saturate'))

        # Reject the configuration where neither axis is pooled.
        assert not (avg_freq_axis == 1 and avg_time_axis == 1)
        self.AVG_FREQ_AXIS = avg_freq_axis
        self.AVG_TIME_AXIS = avg_time_axis
        # Right-shift applied to the accumulated sum. int() truncates, so for non-power-of-two
        # products the sum is divided by the largest power of two not exceeding the product.
        self.ACCUMULATION_BITS = int(np.log2(avg_freq_axis * avg_time_axis))
        self.FFT_SIZE = fft_size
        # Maps the sample counter to the RAM write address (bit-reversal fix, fftshift, pooling).
        self.LUT = build_lut(fft_size, avg_freq_axis)

        # Counts down the frames of the current time-averaging period.
        self.time_axis_counter = self.AVG_TIME_AXIS
        # RAM roles: True -> accumulate into ram[0], read ram[1]; False -> the opposite.
        self.state = True
        self.ram = [RAM([Sfix(0.0)] * (fft_size // avg_freq_axis)),
                    RAM([Sfix(0.0)] * (fft_size // avg_freq_axis))]
        # Set by work_ram when the current sample produces an output word.
        self.out_valid = False
        # Sample counter within the current frame, 0 .. FFT_SIZE - 1.
        self.control = 0

        self.output = DataValid(Sfix())
        # NOTE(review): final_counter is not referenced by any method of this class, and '/' makes
        # its start value a float for integer arguments; confirm whether to drop it or use '//'.
        self.final_counter = DownCounter(self.FFT_SIZE / self.AVG_FREQ_AXIS + 1)
        # Gates output.valid in main(); presumably masks the first frame, before the read-side RAM
        # holds accumulated data -- confirm against DownCounter semantics.
        self.start_counter = DownCounter(fft_size + 1)

    def work_ram(self, data, write_ram, read_ram):
        """ Accumulate ``data`` into one RAM and stream out (and clear) the other.

        Args:
            data: Current input sample.
            write_ram: Index (0 or 1) of the RAM being accumulated into.
            read_ram: Index (0 or 1) of the RAM being read out and cleared.
        """
        # READ-MODIFY-WRITE
        write_index = self.LUT[self.control]
        # The read is issued for the *next* sample's address. Presumably delayed_read returns the
        # word requested on the previous call (registered BRAM read), i.e. the one at write_index.
        write_index_future = self.LUT[(self.control + 1) % self.FFT_SIZE]
        read = self.ram[write_ram].delayed_read(write_index_future)
        # Saturating add, result kept in the input sample's format.
        new_value = resize(read + data, size_res=data, overflow_style='saturate')
        self.ram[write_ram].delayed_write(write_index, new_value)

        # output stage
        # Active only in the first frame of an averaging period (time_axis_counter still at its
        # reset value) and only for the first FFT_SIZE / AVG_FREQ_AXIS samples of that frame.
        # The read value is discarded here and is presumably picked up via get_readregister() in main().
        self.out_valid = False
        if self.control < self.FFT_SIZE / self.AVG_FREQ_AXIS and self.time_axis_counter == self.AVG_TIME_AXIS:
            _ = self.ram[read_ram].delayed_read(self.control)
            self.out_valid = True

            # clear memory
            # Zero the word just read, so this RAM starts empty when it becomes the write RAM again.
            self.ram[read_ram].delayed_write(self.control, Sfix(0.0, size_res=data))

    def main(self, input):
        """
        Process one input sample.

        Args:
            input (DataValid): 36 bits, type not restricted

        Returns:
            DataValid: Output type shifted right by the bit-growth.

        """
        # Invalid input: no state is touched; the previous output data is returned with valid=False.
        if not input.valid:
            return DataValid(self.output.data, valid=False)

        # NOTE(review): assuming pyha register semantics, assignments to self.* take effect on the
        # next call, so the reads of self.control below still see the pre-increment value -- confirm.
        self.control = (self.control + 1) % self.FFT_SIZE
        if self.state:
            self.work_ram(input.data, 0, 1)
            read = self.ram[1].get_readregister()
        else:
            self.work_ram(input.data, 1, 0)
            read = self.ram[0].get_readregister()

        # Last sample of a frame: count down the averaging period; when it ends, reset the
        # counter and swap the RAM roles.
        if self.control >= self.FFT_SIZE - 1:
            next_counter = self.time_axis_counter - 1
            if next_counter == 0:
                next_counter = self.AVG_TIME_AXIS
                self.state = not self.state

            self.time_axis_counter = next_counter

        # Scale the accumulated sum down by 2**ACCUMULATION_BITS.
        self.output.data = scalb(read, -self.ACCUMULATION_BITS)
        self.start_counter.tick()
        self.output.valid = self.start_counter.is_over() and self.out_valid
        return self.output

    def model(self, inp):
        """ Floating-point reference model of the core.

        Args:
            inp: Flat sequence of samples; its length must be a multiple of FFT_SIZE.

        Returns:
            Flattened array of pooled frames, FFT_SIZE // AVG_FREQ_AXIS points each.
        """
        shaped = np.reshape(inp, (-1, self.FFT_SIZE))
        # apply bitreversal
        # (assumes toggle_bit_reverse reorders each row of a 2-D input -- confirm)
        unrev = toggle_bit_reverse(shaped)

        # fftshift
        unshift = np.fft.fftshift(unrev, axes=1)

        # average in freq axis
        # Groups of AVG_FREQ_AXIS consecutive bins -> shape (FFT_SIZE // AVG_FREQ_AXIS, frames).
        avg_y = np.split(unshift.T, len(unshift.T) // self.AVG_FREQ_AXIS)
        avg_y = np.mean(avg_y, axis=1)

        # average in time axis
        # NOTE(review): np.split requires the frame count to divide evenly into
        # len // AVG_TIME_AXIS sections. When the frame count is not a multiple of AVG_TIME_AXIS,
        # a group can average more than AVG_TIME_AXIS frames (e.g. 5 frames with
        # AVG_TIME_AXIS=4 -> one group of 5), or the split raises.
        avg_x = np.split(avg_y.T, len(avg_y.T) // self.AVG_TIME_AXIS)
        avg_x = np.mean(avg_x, axis=1)
        return avg_x.flatten()


@pytest.mark.parametrize("avg_freq_axis", [2, 8, 16])
@pytest.mark.parametrize("avg_time_axis", [1])
@pytest.mark.parametrize("fft_size", [256, 128])
@pytest.mark.parametrize("input_power", [0.1, 0.001])
def test_all(fft_size, avg_freq_axis, avg_time_axis, input_power):
    """ Simulate the core on random input and require all simulations to agree (near bit-exact).

    Only ``avg_time_axis == 1`` is parametrized; larger values would run with an unsuitable
    packet count (see note) and are therefore not listed.
    NOTE(review): for ``avg_time_axis > 1`` the packet count must be a multiple of
    ``avg_time_axis`` for the model's time grouping to average exactly that many frames;
    ``avg_time_axis + 1`` is not.
    """
    np.random.seed(0)
    packets = avg_time_axis + 1
    orig_inp = np.random.uniform(-1, 1, packets * fft_size) * input_power

    # Pre-quantize to the core's default input format, Sfix(left=0, right=-35).
    orig_inp_quant = np.vectorize(lambda x: float(Sfix(x, 0, -35)))(orig_inp)

    dut = BitreversalFFTshiftAVGPool(fft_size, avg_freq_axis, avg_time_axis)
    sim_out = simulate(dut, orig_inp_quant, pipeline_flush='auto')
    assert sims_close(sim_out, rtol=1e-30, atol=1e-30)


@pytest.mark.parametrize("avg_freq_axis", [2])
@pytest.mark.parametrize("avg_time_axis", [1])
@pytest.mark.parametrize("fft_size", [128])
@pytest.mark.parametrize("input_power", [0.001])
def test_nonstandard_input_size(fft_size, avg_freq_axis, avg_time_axis, input_power):
    """ Run the core with a custom input format instead of its default one.

    The simulation input callback is replaced so that samples are converted to a rounding
    Sfix(0, -4, -40). The MODEL, HARDWARE and RTL results must then agree.
    """
    np.random.seed(0)
    time_axis = 1  # pinned to 1 regardless of the avg_time_axis parameter
    frames = time_axis + 1

    sample_type = Sfix(0, -4, -40, round_style='round')

    core = BitreversalFFTshiftAVGPool(fft_size, avg_freq_axis, time_axis)
    core._pyha_simulation_input_callback = NumpyToDataValid(sample_type)

    noise = input_power * np.random.uniform(-1, 1, frames * fft_size)
    quantized = list(map(lambda v: float(sample_type(v)), noise))

    results = simulate(core, quantized, pipeline_flush='auto', simulations=['MODEL', 'HARDWARE', 'RTL'])
    assert sims_close(results, rtol=1e-30, atol=1e-30)