//*************************************
//Parser.c
// by Spencer Putt
// January 2005
//*************************************

#include "defines.h"
#include <limits.h>
#include <stdio.h>
#include <string.h>

// One slot of the flattened expression built by eval_string. Slots alternate
// operand, operator, operand, ...; the slot after the last one carries
// weight kTerminator.
typedef struct {
    short weight;   // parenthesis nesting depth of the slot (or kTerminator)
    int value;      // operand value, or for an operator slot its index in `operators`
} num_expr;

// A symbol-table entry. label_array holds these sorted by name (insert_label
// keeps the order; do_search_labels binary-searches it).
typedef struct {
    char *name;             // upcased label name; the table stores this pointer, not a copy
    unsigned short value;   // the label's 16-bit value
} label_struct;

// A name / replacement-text pair stored in define_array.
// NOTE(review): not referenced in this part of the file; semantics presumably
// defined by the macro handling elsewhere - confirm.
typedef struct {
    char *name;
    char *define;
} define_struct;

// total_labels: number of entries in use in label_array.
// total_defines: presumably the count for define_array (not used in this part of the file).
int total_labels = 0,total_defines = 0;
// Operator characters. An operator's index here is what eval_string stores in
// num_expr.value and what eval_level uses to pick its function from op_table,
// so the order here and there must stay in sync.
const char operators[] = "*/+-|&^<>";
// Characters dropped while eval_string collects the text of a term.
const char white_space[] = "\t ([])\n\r";
// Symbol table, kept sorted by (upcased) name for binary search.
label_struct label_array[MAX_LABELS];
define_struct define_array[MAX_DEFINES];
// Set when an error is reported during evaluation; cleared by parse_string.
static BOOL was_error = FALSE;
// Error-reporting hook, printf by default. NOTE: the messages passed to it in
// this file carry no trailing newline.
int (*post_error)(const char*,...) = &printf;

/* PROTOTYPES */
int conv_binary(char*);
int f_mult(int a1, int a2);
int f_div(int a1, int a2);
int f_add(int a1, int a2);
int f_sub(int a1, int a2);
int f_or(int a1, int a2);
int f_and(int a1, int a2);
int f_xor(int a1, int a2);
int f_shl(int a1, int a2);
int f_shr(int a1, int a2);

int parse_string(char*);
int eval_string(char*);
int parse_expr(char*);
int eval_level(num_expr*);
int search_labels(char*);
char* next_char(char*);
int insert_label(label_struct*);
BOOL resolve_macro(char *, int *, BOOL, FILE*);


/* TAKEN FROM MAIN.C */
extern BOOL pass_one;
extern FILE* output_file;
/* explicit int: an implicit-int declaration is not valid in C99 and later */
extern int line_number;

void assemble_line(FILE*,char*);
void upcase(char*);
/* END OF STUFF TAKEN FROM MAIN */



//Evaluates the expression in str.
//Returns the value truncated to 16 bits (0..0xFFFF), 0 for an empty or
//all-whitespace string, or -1 if an error was reported while evaluating.
//NOTE: str is handed to eval_string, which modifies it in place.
int parse_string(char* str) {
    int temp;

    if (!*str) {
        return 0;
    }
    str = next_char(str);
    if (!str) {
        // only whitespace: treat it like an empty string rather than
        // passing NULL on to eval_string
        return 0;
    }
    was_error = FALSE;
    temp = eval_string(str);
    if (was_error) {
        return -1;
    }
    // Results that do not fit in 16 bits are silently truncated; no overflow
    // diagnostic is issued.
    return (unsigned short) temp;
}
    
//Evaluates an arithmetic expression string and returns its value.
//
//Phase 1 counts the operators to size the slot array. Phase 2 splits str
//into alternating operand/operator slots in exprs, tagging each slot with
//its parenthesis depth as its weight. Phase 3 repeatedly takes the deepest
//group of same-weight slots, evaluates it left to right with eval_level and
//collapses it into one operand one level shallower, until depth 0 remains.
//
//On error a message is sent through post_error, was_error is set and -1 is
//returned. str is modified in place ("<<"/">>" become " <"/" >", and a "-("
//grouping inserts an extra closing bracket), so it must be writable and, for
//"-(", have room for one more character after its terminator.
int eval_string(char* str) {
    int i,op_i,expr_length,expr_terms,total_terms;
    unsigned int temp;
    short parindicator = 0,ind = 0;
    char buffer[128],this_char,*temp_ptr;
    num_expr exprs[32];

    //count up the terms: each operator adds an operator slot and an operand slot.
    for (expr_terms = 1, i = 0; str[i]; i++) {
        for (op_i = 0; op_i < len_op; op_i++) {
            if (operators[op_i] == str[i]) break;
        }
        if (op_i < len_op) {
            expr_terms++;
            //truncate << or >> to < or > in source string. Only skip the
            //second character when it really is the doubled operator, so a
            //trailing '<' or '>' cannot step the index past the terminator.
            if (op_i >= len_op-2 && str[i+1] == str[i]) str[i++] = ' ';
        }
    }
    total_terms = expr_terms + expr_terms - 1;

    //exprs needs total_terms slots plus the terminator slot at index total_terms.
    if (total_terms >= (int)(sizeof exprs / sizeof exprs[0])) {
        post_error("Expression too complex.");
        was_error = TRUE;
        return -1;
    }

    // copy each term to its structure, along with a level weight
    // based on nested parenthesis
    expr_length = 0;
    for (temp = 0, i = 0; str[i]; i++) {
        this_char = str[i];
        if (this_char == '(' || this_char == '[') {
            parindicator++;
        } else if (this_char == ')' || this_char == ']') {
            parindicator--;
        } else {
            for (op_i = 0; op_i < len_op && operators[op_i]!=this_char; op_i++);
            if (op_i < len_op) {
                //at this point we've hit an operator
                //first write the previous term to the expr buffer
                if (expr_length) {
                    buffer[expr_length] = 0;
                    exprs[temp++].value = parse_expr(buffer);

                    //then handle the operator
                    exprs[temp].weight = parindicator; //this can differ from the term above
                    exprs[temp++].value = op_i;
                    expr_length = 0;
                } else {
                    //an operator with no term before it: only a unary '-'
                    //or '+' is accepted. Two slots were counted for it above;
                    //unary +/- applied to a plain number gives them back.
                    if (this_char == '-') {
                        temp_ptr = next_char(str+i+1);
                        if (!temp_ptr) {
                            post_error("Unexpected operator at end of line.");
                            was_error = TRUE;
                            return -1;
                        }
                        if (*temp_ptr == '(' || *temp_ptr == '[') {
                            //"-(...)" is evaluated as "-1 * (...)": the -1 and
                            //the multiply use the two slots counted for '-',
                            //one level deeper, and that extra level is closed
                            //by duplicating the matching closing bracket.
                            // skip the whitespace
                            i = temp_ptr - str - 1;
                            ind = 1;
                            parindicator++;
                            for (op_i = i+2; str[op_i] && ind; op_i++) {
                                if (str[op_i]=='(' || str[op_i]=='[') {
                                    ind++;
                                } else if (str[op_i]==')' || str[op_i]==']') {
                                    ind--;
                                }
                            }
                            if (ind) {
                                // post an error
                                post_error("Parenthesis error in grouping.");
                                was_error = TRUE;
                                return -1;
                            }
                            // shift remaining chars in string to the right,
                            // duplicating the matching closing bracket.
                            // NOTE: this writes str[strlen(str)+1].
                            for (ind = strlen(str) + 1; ind >= op_i; ind--) {
                                str[ind] = str[ind-1];
                            }
                            exprs[temp].weight = parindicator;
                            exprs[temp++].value = -1;
                            exprs[temp].weight = (short) parindicator;
                            exprs[temp++].value = 0; //multiply
                        } else {
                            // in this case, the '-' is paired with the next number
                            total_terms-=2;
                            exprs[temp].weight = (short) parindicator;
                            buffer[expr_length++] = this_char;
                        }
                    } else if (this_char == '+') {
                        total_terms-=2;
                    } else {
                        if (!i) {
                            post_error("Operator first error.");
                        } else {
                            post_error("Double operator exception.");
                        }
                        was_error = TRUE;
                        return -1;
                    }
                }
            } else {
                //collect the term, omitting buffer white space (\t,' ',([,)])
                for (op_i = 0; op_i < len_white && this_char != white_space[op_i]; op_i++);
                if (op_i >= len_white) {
                    //keep room for the terminator and for the '_' that
                    //search_labels may append to the term
                    if (expr_length >= (int)sizeof buffer - 2) {
                        post_error("Term too long.");
                        was_error = TRUE;
                        return -1;
                    }
                    if (!expr_length) exprs[temp].weight = (short) parindicator;
                    buffer[expr_length++] = this_char;
                }
            }
        }
    }
    //write the final term to the expr buffer
    if (!expr_length) {
        post_error("Operator at end of line.");
        was_error = TRUE;
        return -1;
    }
    buffer[expr_length] = 0;
    exprs[temp].value = parse_expr(buffer);

    //even if there's a parenthesis error, it's still salvageable
    if (parindicator) {
        post_error("Parenthesis error overall.");
    }

    //attach a terminator to end final parsing
    exprs[total_terms].weight = kTerminator;

    //now run the highest priority parenthesis level terms left to right
    for(;;) {
        //find the deepest operand slot (operands sit at even indices)
        //NOTE(review): weight is a short and temp is unsigned, so a negative
        //weight (left by an unbalanced ')') converts to a huge unsigned value
        //and is picked as the deepest level here.
        for (i = temp = parindicator = 0; i < total_terms; i+=2) {
            if (exprs[i].weight > temp) {
                parindicator = i;
                temp = exprs[i].weight;
            }
        }
        //nothing nested is left: evaluate the final level
        if (!temp) {
            return eval_level(&exprs[parindicator]);
        }

        //evaluate the deepest level and move its result one level up
        exprs[parindicator].value = eval_level(&exprs[parindicator]);
        exprs[parindicator].weight--;
        //drop the slots just folded into exprs[parindicator] by shifting the
        //rest (terminator included) down over them
        for (op_i = parindicator+1; exprs[op_i].weight == temp; op_i++);
        for (parindicator++;op_i<=total_terms;op_i++,parindicator++) {
            exprs[parindicator] = exprs[op_i];
        }
        total_terms = parindicator-1;
    }
}

//Skips leading spaces, tabs and newlines in str.
//Returns a pointer to the first other character, or NULL when nothing but
//those characters (or nothing at all) remains. Carriage returns are not
//skipped.
char* next_char(char* str) {
    char *cursor = str;

    while (*cursor == ' ' || *cursor == '\t' || *cursor == '\n') {
        ++cursor;
    }
    return (*cursor != '\0') ? cursor : NULL;
}


//Evaluates one run of same-weight slots strictly left to right (there is no
//operator precedence within a level). expr must point at an operand slot;
//expr[1], expr[3], ... are operator slots (indices into `operators`) and
//expr[2], expr[4], ... operands. The run ends at the first operator slot whose
//weight differs from expr->weight. Returns the accumulated value.
int eval_level(num_expr* expr) {
    //Indexed by operator position in `operators` ("*/+-|&^<>"); built once.
    static int (*const op_table[])(int, int) = {&f_mult,&f_div,&f_add,&f_sub,&f_or,&f_and,&f_xor,&f_shl,&f_shr};
    int acc,current_weight;

    current_weight = expr->weight;
    acc = expr->value;

    while ((expr+1)->weight == current_weight) {
        acc = op_table[(expr+1)->value](acc,(expr+2)->value);
        expr+=2;
    }

    return acc;
}

//Operator implementations, in the same order as `operators` ("*/+-|&^<>") so
//eval_level can index them by operator position.
//+, - and * are computed in unsigned arithmetic and converted back, so they
//wrap (two's complement on common compilers) instead of hitting signed-overflow
//undefined behavior. Division by zero yields 0 and INT_MIN / -1 wraps, both of
//which would otherwise be undefined. Shift counts outside [0, bits in int)
//give 0 for << and the sign fill (0 or -1) for >> instead of being undefined.
int f_mult(int a1, int a2)  {return (int) ((unsigned) a1 * (unsigned) a2);}
int f_div(int a1, int a2)   {
    if (a2 == 0) return 0;                            // x / 0 is undefined behavior
    if (a2 == -1) return (int) (0u - (unsigned) a1);  // INT_MIN / -1 overflows
    return a1 / a2;
}
int f_add(int a1, int a2)   {return (int) ((unsigned) a1 + (unsigned) a2);}
int f_sub(int a1, int a2)   {return (int) ((unsigned) a1 - (unsigned) a2);}
int f_or(int a1, int a2)    {return a1 | a2;}
int f_and(int a1, int a2)   {return a1 & a2;}
int f_xor(int a1, int a2)   {return a1 ^ a2;}
int f_shl(int a1, int a2)   {
    if ((unsigned) a2 >= sizeof(int) * CHAR_BIT) return 0;  // negative or too-wide count
    return (int) ((unsigned) a1 << a2);
}
int f_shr(int a1, int a2)   {
    if ((unsigned) a2 >= sizeof(int) * CHAR_BIT) return a1 < 0 ? -1 : 0;
    return a1 >> a2;  // negative a1: implementation-defined (arithmetic on common compilers)
}

//Converts a single term (no binary operators) to its integer value.
//Supported forms:
//  -term   negation            ~term   bitwise complement
//  %1010   binary              'c      character code of c
//  0FFh    hex (h/H suffix)    1010b   binary (b/B suffix)
//  123     decimal             NAME    label value (0 if not found)
//A term whose first character is below 'A' is treated as a number; anything
//else is looked up as a label.
//NOTE: str may be modified in place (a suffix is stripped; the label lookup
//upcases it and may append '_').
//Returns 0 for an empty string and for a number sscanf cannot parse.
int parse_expr(char *str) {
    int temp = 0,i,len;
    unsigned int hex_value = 0;

    len = strlen(str);
    if (!len) {
        // nothing to parse; also avoids reading str[-1] below
        return 0;
    }
    if (str[0] == '-') {
        if (len > 1) {
            return -parse_expr(str+1);
        }
        //post a warning
        post_error("Rogue operator.");
        return -1;
    }
    if (str[0] == '~') {
        if (len > 1) {
            return ~parse_expr(str+1);
        }
        post_error("Invalid ~ usage.");
        return -1;
    }

    switch (str[0]) {
        case '%':
            //binary with % prefix
            temp = conv_binary(str+1);
            break;
        case '\'':
            //character constant
            temp = str[1];
            break;
        default:
            //handle distinction between labels and numbers here
            if (str[0] < 'A') {
                switch (str[len-1]) {
                    case 'h':
                    case 'H':
                        str[len-1] = 0;
                        // %x needs an unsigned int*; hex_value stays 0 if nothing parses
                        sscanf(str,"%x",&hex_value);
                        temp = (int) hex_value;
                        break;
                    case 'b':
                    case 'B':
                        //binary with b suffix
                        str[len-1] = 0;
                        temp = conv_binary(str);
                        break;
                    default:
                        //it's a standard number; temp stays 0 if nothing parses
                        sscanf(str,"%d",&temp);
                        break;
                }
            } else {
                i = search_labels(str);
                temp = (i != -1) ? label_array[i].value : 0;
            }
            break;
    }
    return temp;
}

//Converts a string of binary digits to its value. Each '1' contributes a 1
//bit and every other character a 0 bit (the input is not validated).
//Accumulates in unsigned arithmetic, so inputs longer than the width of int
//wrap instead of overflowing a signed int (undefined behavior).
int conv_binary(char* str) {
    unsigned int acc = 0;

    for (; *str; str++) {
        acc = (acc << 1) | (unsigned int) (*str == '1');
    }
    return (int) acc;
}

//Binary-searches the sorted label_array for str.
//Side effect: ASCII lowercase letters in str are upcased in place so the
//lookup matches the upcased names in the table.
//Returns the index of the matching label or, when there is no match, the
//index at which str would be inserted to keep the table sorted
//(0..total_labels).
int do_search_labels(char *str) {
    const int shift = 'a' - 'A';
    char *p;
    int lo, hi, mid, cmp;

    for (p = str; *p; p++) {
        if (*p >= 'a' && *p <= 'z') {
            *p -= shift;
        }
    }

    lo = 0;
    hi = total_labels - 1;
    while (lo <= hi) {
        mid = lo + (hi - lo) / 2;
        cmp = strncmp(str, label_array[mid].name, 256);
        if (cmp == 0) {
            return mid;
        }
        if (cmp > 0) {
            lo = mid + 1;
        } else {
            hi = mid - 1;
        }
    }
    // every entry before lo compares less than str, every entry from lo on greater
    return lo;
}

//Looks up str in the label table; str is upcased in place by do_search_labels.
//If str itself is not present, the lookup is retried once with "_" appended.
//Returns the matching label's index, or -1 (after printing a message) if
//neither form is present.
//NOTE: the retry appends "_" to str in place, so the caller's buffer must
//have room for one more character. The string is restored on failure and
//keeps the "_" when that form is found.
int search_labels(char *str) {
    int i;
    size_t len;

    if (!total_labels) {
        return -1;
    }
    i = do_search_labels(str);
    if (i < total_labels && !strcmp(label_array[i].name, str)) {
        return i;
    }

    //not found as written: try exactly once with a trailing underscore
    len = strlen(str);
    strcat(str, "_");
    i = do_search_labels(str);
    if (i < total_labels && !strcmp(label_array[i].name, str)) {
        return i;
    }

    str[len] = '\0';   //restore the caller's string before reporting it
    printf("Label %s not found.\n", str);
    return -1;
}

//Inserts *label into label_array, keeping the array sorted by name so that
//do_search_labels can binary-search it. label->name is upcased in place
//(via upcase() for the first two entries, otherwise by do_search_labels), and
//the pointer itself, not a copy of the string, is stored in the table.
//If a label with the same name is already present, only its value is updated.
//Returns the label's index in label_array, or -1 (after printing a message)
//when the table is full.
int insert_label(label_struct *label) {
    const int capacity = (int)(sizeof label_array / sizeof label_array[0]);
    int i,op_i;

    if (!total_labels) {
        upcase(label->name);
        i = 0;
    } else if (total_labels == 1) {
        upcase(label->name);
        i = (strncmp(label->name,label_array[0].name,256) > 0) ? 1 : 0;
    } else {
        i = do_search_labels(label->name);
    }

    //already defined: refresh its value instead of adding a duplicate entry
    if (i < total_labels && !strcmp(label->name,label_array[i].name)) {
        label_array[i].value = label->value;
        return i;
    }

    if (total_labels >= capacity) {
        printf("Label table full: cannot add %s.\n", label->name);
        return -1;
    }

    //open a gap at i by shifting the later entries up one slot
    for (op_i = total_labels; op_i > i; op_i--) {
        label_array[op_i] = label_array[op_i-1];
    }
    label_array[i].name  = label->name;
    label_array[i].value = label->value;
    total_labels++;
    return i;
}
