卷积码的Viterbi译码

lzusa 发布于 2022-11-29 2 次阅读


这里采用硬判决(hard-decision)方式:BPSK 硬判决解调后,再用基于汉明距离的硬判决 Viterbi 译码;编码器为常用的 (7,5) 卷积码(码率 1/2,生成多项式为八进制 7 和 5,即 111 和 101,约束长度 3)。

#define  _CRT_SECURE_NO_WARNINGS
#include<stdio.h>
#include<stdlib.h>
#include<time.h>
#include<math.h>

#define message_length 1000 //number of message bits per frame, including the zero tail bits
#define codeword_length 2000 //number of coded bits per frame (2 coded bits per message bit)
float code_rate = (float)message_length / (float)codeword_length;

// channel coefficient
#define pi 3.14159265358979323846 //full double precision; used by the AWGN noise generator
// "Infinite" path metric for the Viterbi decoder. Kept well below INT_MAX so that
// INF + (branch Hamming distance, at most 2) cannot overflow an int.
#define INF 0x7ffffff
double N0, sgm; //noise spectral density and per-dimension noise standard deviation sqrt(N0/2), set per SNR point in main()

int state_num; //encoder memory length = number of zero tail bits appended to each message (set in statetable())

int message[message_length], codeword[codeword_length];//message and codeword
int re_codeword[codeword_length];//the hard-decision demodulated codeword
int de_message[message_length];//the decoded message

double tx_symbol[codeword_length][2];//the transmitted symbols (real, imaginary)
double rx_symbol[codeword_length][2];//the received symbols (real, imaginary)

// One branch of the encoder trellis: leaving register state `now` with input bit `in`
// moves the register to state `next` and emits the two coded bits out[0], out[1].
struct state_table
{
    int in, now, next; // input bit, current state (0..3), state after shifting `in` in
    int out[2];        // the two coded bits emitted on this branch
};
struct state_table table[8]; // 4 states x 2 input bits; filled by statetable()

// Survivor memory of the Viterbi decoder: transition[t][s] holds the index into table[]
// of the branch that won the add-compare-select into state s at trellis step t.
int transition[message_length][4];

// Simulation pipeline, in the order main() runs it.
void statetable(void);   // build the trellis branch table
void encoder(void);      // message[] -> codeword[]
void modulation(void);   // codeword[] -> tx_symbol[] (BPSK)
void channel(void);      // tx_symbol[] -> rx_symbol[] (AWGN)
void demodulation(void); // rx_symbol[] -> re_codeword[] (hard decision)
void decoder(void);      // re_codeword[] -> de_message[] (Viterbi)

/* Monte-Carlo BER simulation of the (7,5) convolutional code over BPSK/AWGN.
 * Reads a start/finish SNR (dB, stepped by 1) and a number of frames, then for every
 * SNR point encodes, transmits, hard-decision demodulates and Viterbi-decodes that many
 * random frames and prints the accumulated bit error rate.
 * Returns 0 on success, 1 on malformed input. */
int main()
{
    float SNR, start, finish;
    long int bit_error, seq, seq_num;
    double BER;
    double progress = 0.0; // initialized so the final report never reads an indeterminate value

    //generate state table
    statetable();

    //random seed (srand takes an unsigned seed)
    srand((unsigned)time(0));

    //input the SNR range and the frame number; stop on malformed input instead of
    //running the simulation with uninitialized parameters
    printf("\nEnter start SNR: ");
    if (scanf("%f", &start) != 1)
    {
        fprintf(stderr, "\nInvalid start SNR.\n");
        return 1;
    }
    printf("\nEnter finish SNR: ");
    if (scanf("%f", &finish) != 1)
    {
        fprintf(stderr, "\nInvalid finish SNR.\n");
        return 1;
    }
    printf("\nPlease input the number of message: ");
    // seq_num is a long, so it must be read with %ld (%d would be undefined behavior).
    // A non-positive count would divide by zero when computing the BER below.
    if (scanf("%ld", &seq_num) != 1 || seq_num <= 0)
    {
        fprintf(stderr, "\nThe number of messages must be a positive integer.\n");
        return 1;
    }

    for (SNR = start; SNR <= finish; SNR++)
    {
        //channel noise: N0 from the SNR (dB) and the code rate, per-dimension sigma = sqrt(N0/2)
        N0 = (1.0 / code_rate) / pow(10.0, (float)(SNR) / 10.0);
        sgm = sqrt(N0 / 2);

        bit_error = 0;

        for (seq = 1; seq <= seq_num; seq++)
        {
            //random message bits, followed by state_num zero tail bits that drive the
            //encoder back to state 0 so the decoder can trace back from state 0
            for (int i = 0; i < message_length - state_num; i++)
                message[i] = rand() % 2;
            for (int i = message_length - state_num; i < message_length; i++)
                message[i] = 0;

            encoder();      //convolutional encoder
            modulation();   //BPSK modulation
            channel();      //AWGN channel
            demodulation(); //BPSK demodulation, needed by the hard-decision Viterbi decoder
            decoder();      //convolutional (Viterbi) decoder

            //count bit errors.
            //NOTE(review): the tail bits are always decoded as 0 yet are included here and in
            //the BER denominator, which slightly lowers the reported BER.
            for (int i = 0; i < message_length; i++)
            {
                if (message[i] != de_message[i])
                    bit_error++;
            }

            progress = 100.0 * (double)seq / (double)seq_num;

            //intermediate BER; the denominator is formed in double because
            //message_length * seq can overflow a 32-bit long
            BER = (double)bit_error / ((double)message_length * (double)seq);

            printf("Progress=%2.1f, SNR=%2.1f, Bit Errors=%2ld, BER=%E\r", progress, SNR, bit_error, BER);
        }

        //final BER for this SNR point
        BER = (double)bit_error / ((double)message_length * (double)seq_num);

        printf("Progress=%2.1f, SNR=%2.1f, Bit Errors=%2ld, BER=%E\n", progress, SNR, bit_error, BER);
    }
#ifdef _WIN32
    system("pause"); // keep the console window open; "pause" only exists in the Windows shell
#endif
    return 0;
}
/* Builds the 8-entry trellis table of the rate-1/2 (7,5) encoder and sets state_num.
 * Register state encoding: bit 1 holds the most recent input u[t-1], bit 0 holds u[t-2]
 * (00 -> 0, 01 -> 1, 10 -> 2, 11 -> 3). Entry k describes leaving state k / 2 with input k % 2. */
void statetable()
{
    state_num = 2; // encoder memory: two register bits, hence two zero tail bits per message

    for (int k = 0; k < 8; k++)
    {
        int bit = k & 1;
        int state = k >> 1;
        int prev1 = (state >> 1) & 1; // u[t-1]
        int prev2 = state & 1;        // u[t-2]

        table[k].in = bit;
        table[k].now = state;
        table[k].next = (bit << 1) | prev1;    // shift the new bit in, drop u[t-2]
        table[k].out[0] = bit ^ prev1 ^ prev2; // generator 111 (octal 7)
        table[k].out[1] = bit ^ prev2;         // generator 101 (octal 5)
    }
}

/* Convolutional encoder: walks the trellis from state 0, emitting two bits of codeword[]
 * for every bit of message[]. A message bit with no matching branch leaves its codeword
 * pair and the register state untouched. */
void encoder()
{
    int state = 0; // register starts cleared
    for (int t = 0; t < message_length; t++)
    {
        int k = 0;
        while (k < 8 && !(table[k].in == message[t] && table[k].now == state))
            k++;
        if (k == 8)
            continue;

        codeword[2 * t] = table[k].out[0];
        codeword[2 * t + 1] = table[k].out[1];
        state = table[k].next;
    }
}

/* BPSK modulation onto the real axis: bit 0 -> (+1, 0), bit 1 -> (-1, 0). */
void modulation()
{
    for (int n = 0; n < codeword_length; n++)
    {
        double *sym = tx_symbol[n];
        sym[0] = 1 - 2 * codeword[n];
        sym[1] = 0;
    }
}
/* AWGN channel: rx_symbol = tx_symbol + independent N(0, sgm^2) noise on both the real
 * and imaginary component of every symbol.
 * Each noise sample uses the Box-Muller transform: a Rayleigh amplitude
 * r = sgm * sqrt(-2 ln(1 - u1)) times cos(2*pi*u2), with u1, u2 uniform on [0, 1). */
void channel()
{
    for (int i = 0; i < codeword_length; i++)
    {
        for (int j = 0; j < 2; j++)
        {
            // Divide by RAND_MAX + 1.0 in double so u lies in [0, 1): 1 - u is never 0 (no
            // clamp needed) and no resolution is lost to a float cast when RAND_MAX > 2^24.
            double u1 = rand() / ((double)RAND_MAX + 1.0);
            double r = sgm * sqrt(-2.0 * log(1.0 - u1));

            double u2 = rand() / ((double)RAND_MAX + 1.0);
            double g = r * cos(2 * pi * u2);

            rx_symbol[i][j] = tx_symbol[i][j] + g;
        }
    }
}
/* Hard-decision BPSK demodulation: each received symbol becomes the bit of the nearer
 * constellation point, +1 -> 0 and -1 -> 1 (squared Euclidean distance; ties give 1). */
void demodulation()
{
    for (int n = 0; n < codeword_length; n++)
    {
        double re = rx_symbol[n][0];
        double im = rx_symbol[n][1];
        double dist_to_zero = (re - 1) * (re - 1) + im * im; // squared distance to (+1, 0)
        double dist_to_one = (re + 1) * (re + 1) + im * im;  // squared distance to (-1, 0)
        re_codeword[n] = (dist_to_zero < dist_to_one) ? 0 : 1;
    }
}
// Trellis dimensions of the (7,5) code: 4 register states, 8 branches (state x input bit).
enum { VITERBI_STATES = 4, VITERBI_BRANCHES = 8 };

/* One add-compare-select (ACS) step of the hard-decision Viterbi decoder.
 * step          trellis step; compared against re_codeword[2*step] and re_codeword[2*step+1]
 * allow_one     nonzero: consider both input bits; zero: only input-0 branches (tail steps)
 * metrics       accumulated Hamming path metric of each state before this step
 * metrics_next  receives the best metric into each state after this step (INF if none)
 * Records the winning branch index in transition[step][next_state]. On ties the branch
 * found first (lowest current state, then lowest branch index) is kept. */
static void viterbi_acs_step(int step, int allow_one,
                             const int metrics[VITERBI_STATES], int metrics_next[VITERBI_STATES])
{
    for (int s = 0; s < VITERBI_STATES; s++)
        metrics_next[s] = INF;

    for (int j = 0; j < VITERBI_STATES; j++)
    {
        for (int k = 0; k < VITERBI_BRANCHES; k++)
        {
            if (table[k].now != j || (!allow_one && table[k].in != 0))
                continue;

            // Hamming distance between the received bit pair and this branch's output pair
            int d = (re_codeword[2 * step] != table[k].out[0])
                  + (re_codeword[2 * step + 1] != table[k].out[1]);
            int candidate = metrics[j] + d;
            int to = table[k].next;
            if (metrics_next[to] > candidate)
            {
                metrics_next[to] = candidate;
                transition[step][to] = k; // survivor branch into state `to`
            }
        }
    }
}

/* Hard-decision Viterbi decoder: decodes re_codeword[] into de_message[].
 * The encoder starts in state 0 and the last state_num message bits are zero tail bits,
 * so the forward pass starts from state 0, restricts the tail steps to input-0 branches,
 * and the traceback starts from state 0. */
void decoder()
{
    int metrics[VITERBI_STATES] = { 0, INF, INF, INF }; // only state 0 is reachable at t = 0
    int metrics_next[VITERBI_STATES];
    const int tail_start = message_length - state_num;

    // Forward pass: accumulate path metrics and record survivor branches.
    for (int i = 0; i < message_length; i++)
    {
        viterbi_acs_step(i, i < tail_start, metrics, metrics_next);
        for (int s = 0; s < VITERBI_STATES; s++)
            metrics[s] = metrics_next[s];
    }

    // Traceback: follow survivor branches backwards from state 0. The decoded bit and the
    // previous state are read from the branch table itself, not derived from its index.
    int state = 0;
    for (int i = message_length - 1; i >= 0; i--)
    {
        int k = transition[i][state];
        de_message[i] = table[k].in;
        state = table[k].now;
    }
}
看烟花已落,你我仍是陌路人
最后更新于 2022-11-29