/*  This file (parser.c) is part of Spencer's Assembler.

	Spencer's Assembler is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA

	Copyright 2006, Spencer Putt*/
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "defines.h"
#include "main.h"

/* NOTE(review): kTotal_commands is not referenced in this file. */
#define kTotal_commands     32
/* Number of characters in operators[] / white_space[] (excluding the NUL).
   len_white is not referenced in this file. */
#define len_op      (sizeof(operators)-1)
#define len_white   (sizeof(white_space)-1)

/* Parse error codes returned by separate_terms, parse_expr and resolve_macro.
   handle_parse_ec reports code N using error_strings[N - 6], so these must
   stay contiguous, start at 6, and keep the same order as that table. */
#define eParenthesis    6 
#define eOperator       7
#define eTerm           8
#define eLabel          9
#define eMacro          10
#define eCharacter      11

int total_labels = 0,total_defines = 0;
/* Operator characters. The position of a character in this string is the
   operator code separate_terms stores, and eval_level uses it to index
   op_table, so the order must match that table:
   + - * <(shl) >(shr) | & / % ^ */
const char operators[] = "+-*<>|&/%^";
const char white_space[] = "\t ()";
/* Both tables are kept sorted by name and searched with do_search_names. */
label_struct label_array[MAX_LABELS];
define_struct define_array[MAX_DEFINES];
/* NOTE(review): only ever set to FALSE in this file (parse_string), so the
   was_error check there appears to be dead — confirm. */
static BOOL was_error = FALSE;
/* Deferred parse errors: filled by post_parse_error, reported by
   show_parse_errors and released by flush_parse_errors. */
parse_error parse_errors[32];
int current_parse_error = 0;
int parse_ec = 0;

/* Skips leading blanks (spaces and tabs).
   Returns a pointer to the first other character of str, or NULL when str is
   NULL or holds nothing but blanks. */
char* next_char(char* str) {
    if (str == NULL) return NULL;
    for (; *str == ' ' || *str == '\t'; ++str)
        ;
    return (*str != '\0') ? str : (char*) NULL;
}


void post_parse_error(char *error, char* error_text) {
    if (parse_errors[current_parse_error].error = malloc( strlen(error) + 1 )) {
        strcpy(parse_errors[current_parse_error].error, error);
    }
    if (parse_errors[current_parse_error].error_text = malloc( strlen(error_text) + 1 )) {
        strcpy(parse_errors[current_parse_error].error_text, error_text);
    }
    parse_errors[current_parse_error].line = line_number;
    current_parse_error++;
}

BOOL flush_parse_errors(void) {
    //printf("Current parse error: %d\n",current_parse_error);
    if (!current_parse_error) return FALSE;
    
    while (current_parse_error) {
        free(parse_errors[--current_parse_error].error);
        free(parse_errors[current_parse_error].error_text);
    }
    return TRUE;
}

BOOL show_parse_errors(void) {
    int temp_error = current_parse_error;
    int line_backup = line_number;
    int i;
    if (!current_parse_error) return FALSE;
    for (i = 0; i < current_parse_error; i++) {
        line_number = parse_errors[i].line;
        post_error(parse_errors[i].error, parse_errors[i].error_text);
    }
    line_number = line_backup;
    return flush_parse_errors();
}

/* Turns a parse error code into a queued error message
 * "<Kind> error in %s", with str as the text.
 *
 * Nothing is queued during pass one. Codes outside the known range are
 * ignored. The known range is 6 (eParenthesis) through 11 (eCharacter) and
 * matches error_strings below. Previously, codes above the range indexed
 * past the table.
 */
void handle_parse_ec(int error_code, char *str) {
    static const char error_strings[][32] = {"Parenthesis","Operator","Term","Label","Macro","Character"};
    const int first_code = 6;   /* == eParenthesis */
    const int code_count = (int)(sizeof error_strings / sizeof error_strings[0]);
    char error_base[256];

    if (pass_one) return;
    if (error_code < first_code || error_code >= first_code + code_count) return;

    /* "%%s" leaves a literal %s in the message for the later formatter. */
    snprintf(error_base, sizeof error_base, "%s error in %%s", error_strings[error_code - first_code]);
    post_parse_error(error_base, str);
}

/* Evaluates a complete expression string.
   Returns 0 for an empty string. Otherwise returns the value from
   eval_string, or -1 if was_error was raised during evaluation. Note that
   eval_string itself also returns -1 on a parse error. */
int parse_string(char* str) {
    int value;

    if (*str == '\0') return 0;

    was_error = FALSE;
    value = eval_string(str);
    return was_error ? -1 : value;
}

/* Splits an infix expression into the flat list that eval_level consumes.
 *
 * On success exprs holds operand, operator, operand, ... in order:
 *  - Operands carry .value = the term's value. The value comes from
 *    parse_expr, or is the negation of a parenthesised sub-expression for
 *    "-(...)", evaluated via parse_string.
 *  - Operators carry .value = index into operators[].
 *  - Every entry's .weight is the parenthesis depth at that point.
 * Four descending kTerminator sentinels are written after the last operand.
 *
 * Blanks are dropped from terms. Calls of the form NAME(args) are kept whole
 * in the term buffer so parse_expr can resolve them as macros.
 *
 * Returns 0 on success or an e* parse error code.
 *
 * NOTE(review): buffer[128] and exprs (64 entries in eval_string) are written
 * without bounds checks, so long or numerous terms overflow them.
 * NOTE(review): when parse_string fails for a "-(...)" term this returns 0
 * (success) with exprs only partially filled and no terminators — looks
 * unintended; confirm.
 */
int separate_terms(char *str, num_expr *exprs) {
    int i,expr_i,op_i,expr_length,result;
    int parindicator = 0,expr_level;
    char buffer[128],this_char;
    
    for (expr_length = 0, expr_i = 0, i = 0; str[i]; i++) {
        this_char = str[i];
        //putchar(str[i]);
        //putchar('\n');
        if (this_char == '(') parindicator++;
        else if (this_char == ')') {
            parindicator--;
            if (parindicator < 0) return eParenthesis;
        } else {
            /* Check to see if it's an operator */
            for (op_i = 0; op_i < len_op && operators[op_i] != this_char; op_i++);
            if (op_i == len_op) {
                /* It's not an operator, write it to the buffer */
                if (!expr_length) {
                    //printf("Setting level to %d\n",parindicator);
                    exprs[expr_i].weight = expr_level = parindicator;
                }
                /* Record the initial weight after first volley of parenthesis */
                if (parindicator == expr_level) {
                    if (this_char != ' ' && this_char != '\t') buffer[expr_length++] = this_char;
                } else if (parindicator > expr_level) {
                    /* There was an internal parenthesis, it's probably a macro.
                       Copy everything up to the matching ')' into the term. */
                    buffer[expr_length++] = '(';
                    buffer[expr_length++] = this_char;
                    while (parindicator > expr_level && str[i]) {
                        i++;
                        if (str[i] == '(') parindicator++;
                        else if (str[i]==')') parindicator--;
                        buffer[expr_length++] = str[i];
                    }
                    if (!str[i]) return eParenthesis;
                /* Term was interrupted by ) parenthesis */
                } else if (this_char != ' ' && this_char != '\t') return eParenthesis;
            } else {
                if (expr_length) {
                    /* Parse the extracted term */
                    
                    buffer[expr_length] = 0;
                    //printf("Term to parse: %s\n",buffer);
                    if (result = parse_expr(buffer, expr_length, &exprs[expr_i++].value)) return result;
                    expr_length = 0;
                
                    /* We have hit an operator.  Handle it, then go on*/
                    exprs[expr_i].weight = parindicator;
                    exprs[expr_i++].value = op_i;
                } else {
                    /* There was either no expression, or there was a double operator */
                    switch(this_char) {
                        case '+': {
                            /* "++" becomes the one-character term "+" (a
                               reusable-label reference in parse_expr); a
                               lone unary '+' is ignored. */
                            char *nextptr = next_char(str+i+1);
                            if (nextptr && *nextptr == '+' && !expr_length) {
                                buffer[expr_length++] = '+';
                                exprs[expr_i].weight = expr_level = parindicator;
                                i++;
                            }
                            break;
                        }
                        case '%':
                            /* '%' in operand position starts a binary literal. */
                            expr_level = parindicator;
                            buffer[expr_length++] = this_char;
                            break;
                        /* Second character of "<<" / ">>": already recorded. */
                        case '<': break;
                        case '>': break;
                        case '-': {
                            int next_i;
                            /* Get the next non-white char, index */
                            for (next_i = i+1; str[next_i] && 
                                (str[next_i]==' ' || str[next_i]=='\t'); next_i++);
                            if (str[next_i] != '(') {
                                /* Otherwise, this is an errant operator */
                                buffer[expr_length++] = '-';
                                exprs[expr_i].weight = expr_level = parindicator;
                                if (!str[next_i]) return eOperator;
                                else if (str[next_i] == '-') {
                                    buffer[expr_length++] = '-';
                                    i++;
                                }
                            } else {
                                /* This case is a - followed by a '(', e.g. -(10+12)
                                Parse the expr in the parenthesis, and negate it */
                                int inpind = 0,result;
                                do {
                                    if (str[next_i]=='(') inpind++;
                                    else if (str[next_i]==')') inpind--;
                                    buffer[expr_length++] = str[next_i++];
                                } while (inpind && str[next_i]);
                                buffer[expr_length] = 0;
                                
                                exprs[expr_i].weight = parindicator;
                                result = parse_string(buffer);
                                if (result != -1) exprs[expr_i++].value = -result;
                                else return 0;
                                
                                if (!(str = next_char(&str[next_i]))) goto separation_done;
                                
                                for (op_i = 0; op_i < len_op && operators[op_i]!= str[0]; op_i++);
                                if (op_i < len_op) {
                                    /* It IS an operator next */
                                    exprs[expr_i].weight = parindicator;
                                    exprs[expr_i++].value = op_i;
                                }
                                /* Term error if there isn't an operator */
                                else return eTerm;
                                /* str now points at that operator; the loop's
                                   i++ resumes scanning just after it. */
                                i = expr_length = 0;
                            } break;
                        }
                        default: return eOperator;
                    }
                }
            }
        }
    }
    //first write the previous term to the expr buffer
    if (!expr_length) return eOperator;
    buffer[expr_length] = 0;
    //printf("Inserting term: %d: %s\n",exprs[expr_i].weight,buffer);
    if (result = parse_expr(buffer, expr_length, &exprs[expr_i].value)) return result;
    expr_i++;
separation_done:
    exprs[expr_i].weight = kTerminator;
    exprs[expr_i+1].weight = kTerminator-1;
    exprs[expr_i+2].weight = kTerminator-2;
    exprs[expr_i+3].weight = kTerminator-3;
    return 0;
}
    
    
/* Tokenizes str with separate_terms and evaluates the result with eval_level.
   On a tokenizing error the error is reported via handle_parse_ec and -1 is
   returned. */
int eval_string(char* str) {
    num_expr exprs[64];
    int status = separate_terms(str, exprs);

    if (status != 0) {
        handle_parse_ec(status, str);
        return -1;
    }
    return eval_level(exprs, -1);
}


/* Binary operator implementations dispatched by eval_level through op_table.
 *
 * Assembler expressions can contain any values, so none of these may invoke
 * undefined behaviour:
 *  - +, - and * are computed in unsigned int, so overflow wraps. Converting
 *    back to int is implementation-defined; it is a two's-complement wrap on
 *    common compilers.
 *  - x / 0 and x % 0 evaluate to 0 instead of trapping. INT_MIN / -1 wraps.
 *  - Shift counts that are negative or >= the width of int give 0 (or -1
 *    for >> of a negative value), as if every bit were shifted out.
 */
int f_mult  (int a1, int a2)   {return (int)((unsigned int)a1 * (unsigned int)a2);}
int f_div   (int a1, int a2)   {
    if (a2 == 0) return 0;
    if (a2 == -1) return (int)(0u - (unsigned int)a1);   /* avoids INT_MIN / -1 */
    return a1 / a2;
}
int f_add   (int a1, int a2)   {return (int)((unsigned int)a1 + (unsigned int)a2);}
int f_sub   (int a1, int a2)   {return (int)((unsigned int)a1 - (unsigned int)a2);}
int f_or    (int a1, int a2)   {return a1 | a2;}
int f_and   (int a1, int a2)   {return a1 & a2;}
int f_xor   (int a1, int a2)   {return a1 ^ a2;}
int f_shl   (int a1, int a2)   {
    if (a2 < 0 || a2 >= (int)(sizeof(int) * CHAR_BIT)) return 0;
    return (int)((unsigned int)a1 << a2);   /* unsigned: shifting a negative int left is UB */
}
int f_shr   (int a1, int a2)   {
    if (a2 < 0 || a2 >= (int)(sizeof(int) * CHAR_BIT)) return a1 < 0 ? -1 : 0;
    return a1 >> a2;
}
int f_mod   (int a1, int a2)   {
    if (a2 == 0 || a2 == -1) return 0;   /* x % -1 is 0; INT_MIN % -1 is UB */
    return a1 % a2;
}

typedef int (*arithp)(int,int);

/* Evaluates the flattened list produced by separate_terms.
 *
 * exprs holds operands at even positions and operators at odd positions.
 * For operators, .value is an index into op_table, which is ordered like
 * operators[]. .weight is the parenthesis depth. There is no precedence
 * table: an operator is applied left to right at its own weight. When the
 * operand after it is deeper, that operand is first evaluated by a
 * recursive call. The descending kTerminator sentinels end the list.
 *
 * base is the weight of the operator that caused this recursion, or -1 for
 * the outermost call. Returns the accumulated value for this level.
 *
 * The read cursor (expr_i) and recursion depth (nested) are statics shared
 * by every level, so this is not reentrant. Both return to 0 when the
 * outermost call returns, on the error path as well.
 */
int eval_level(num_expr* exprs, int base) {
    static unsigned int expr_i = 0,nested = 0;
    int acc;
    static const arithp op_table[len_op] = {&f_add,&f_sub,&f_mult,&f_shl,&f_shr,&f_or,&f_and,&f_div,&f_mod,&f_xor};

    nested++;

    acc = exprs[expr_i].value;
    while (1) {
        if (expr_i > 28) {
            /* Too many terms. Every recursion level unwinds through this
               branch, because expr_i stays > 28 until it is reset. Only the
               outermost level reports the error and resets the cursor.
               Previously nested was never decremented here, so every later
               evaluation started from a stale cursor. */
            if (--nested == 0) {
                expr_i = 0;
                handle_parse_ec(eTerm, "parsed string");
            }
            return acc;
        }
        expr_i+=2;
        if (exprs[expr_i].weight == exprs[expr_i-1].weight) {
            /* Operand at the operator's own depth: apply it if the operator
               belongs to this level, otherwise let the caller handle it. */
            if (exprs[expr_i-1].weight > base) acc = op_table[exprs[expr_i-1].value](acc,exprs[expr_i].value);
            else goto eval_base;
        } else if (exprs[expr_i].weight > exprs[expr_i-1].weight) {
            /* Deeper operand: evaluate the nested group first. */
            arithp this_op = op_table[exprs[expr_i-1].value];
            acc = this_op(acc, eval_level(exprs, exprs[expr_i-1].weight) );
        } else {
            if (exprs[expr_i].weight > base) {
                acc = op_table[exprs[expr_i-1].value](acc, exprs[expr_i].value);
            } else {
            eval_base:
                /* Leaving this level: step back so the caller re-reads the
                   operator/operand pair, or reset at the outermost level. */
                if (--nested) expr_i-=2;
                else if (base==-1) expr_i = 0;
                return acc;
            }
        }
    }
}

/* Converts a string of hexadecimal digits (0-9, A-F, a-f, no prefix) to its
 * value. An empty string yields 0. Characters outside those ranges are not
 * rejected; they contribute whatever their offset from '0', 'A' or 'a' is.
 *
 * The value is accumulated in unsigned int, so literals such as 80000000
 * wrap instead of overflowing a signed int (undefined behaviour). Converting
 * the result back to int is implementation-defined.
 */
int conv_hex(char* str) {
    unsigned int acc = 0;
    char c;

    while ((c = *str++) != '\0') {
        unsigned int digit;

        if (c <= '9') digit = (unsigned int)(c - '0');
        else if (c <= 'A' + 9) digit = (unsigned int)(c - ('A' - 10));
        else digit = (unsigned int)(c - ('a' - 10));

        acc = (acc << 4) + digit;
    }
    return (int) acc;
}

/* Converts a string of decimal digits to its value. An empty string yields 0.
 * Non-digit characters are not rejected; they contribute (c - '0').
 *
 * The value is accumulated in unsigned int to avoid signed-overflow undefined
 * behaviour on long literals. Converting back to int is
 * implementation-defined.
 */
int conv_dec(char* str) {
    unsigned int acc = 0;

    for (; *str; str++) acc = acc * 10u + (unsigned int)(*str - '0');
    return (int) acc;
}

/* Converts a string of '0'/'1' characters to its value. Any character other
 * than '1' counts as a 0 bit, and an empty string yields 0.
 *
 * The value is accumulated in unsigned int, so literals longer than the width
 * of int wrap instead of overflowing a signed int (undefined behaviour).
 */
int conv_binary(char* str) {
    unsigned int acc = 0;

    for (; *str; str++) acc = (acc << 1) | (unsigned int)(*str == '1');
    return (int) acc;
}

int parse_expr(char *str, int len, int *result) { 
    switch (str[0]) {
        case '-':
            //puts("Reusable parsing");
            if (len > 1) {
                if (str[1]=='_' && str[2]==0) {
                    *result = (unsigned int) reusables[total_reusables-1];
                } else if (str[1]=='-' && str[2]=='_' && str[3]==0) {
                    *result = (unsigned int) reusables[total_reusables-2];
                } else {
                    parse_expr(str+1,len,result);
                    *result = -(*result);
                }
            } else return eOperator;
            break;
        case '~':
            if (len > 1) {
                parse_expr(str+1, len, result);
                *result = ~*result;
            } else return eOperator;
            break;
        case '$':
            if (len > 1) *result = conv_hex(str+1);
            else *result = program_counter;
            break;
        case '%':
            *result = conv_binary(str+1);
            break;
        case '\'':
            /* Suggests a case like this: 'a' */
            if (str[1] != '\\') {
                if (str[1] != '\'') {
                    *result = str[1];
                    if (!next_char(str+2)) return eCharacter;
                } else *result = ' ';
            }
            /* A control character, '\n','\0','\"' */
            else {
                switch (toupper(str[2])) {
                    case 'N':   *result = '\n'; break;
                    case '\\':  *result = '\\'; break;
                    case '0':   *result = 0;    break;
                    case '"':   *result = '"';  break;
                    case 'R':   *result = '\r'; break;
                    case 'T':   *result = '\t'; break;
                    case '\'':  *result = '\''; break;
                    case '#':   *result = rand()&0xFF; break;
                    default:    *result = str[2];
                }
            }
            break;
        case '"':
            *result = str[1];
            break;
        case '_':
        case '+':
            //printf("Reusable parsing\n");
            if (str[1]==0) {
                //printf("Resturning reusable\n");
                *result = (unsigned int) reusables[total_reusables];
                break;
            } else if (str[1]=='_' && str[2]==0) {
                *result =(unsigned int) reusables[total_reusables+1];
                break;
            }
        default:
            //handle distinction between labels and numbers here
            if (str[0] < 'A' && str[0]!='_') {
                switch (str[len-1]) {
                    case 'h':
                    case 'H':
                        str[len-1] = 0;
                        *result = conv_hex(str);
                        break;
                    case 'b':
                    case 'B':
                        str[len-1] = 0;
                        *result = conv_binary(str);
                        //handle binary
                        break;
                    default:
                        //it's a standard number
                        *result = conv_dec(str);
                }
            } else {
                int i, error;
                if ((i = search_labels(str)) != -1) {
                    *result = label_array[i].value;
                } else {
                    //printf("Failed to find label %s.\n",str);
                    //system("PAUSE");
                    if (error = resolve_macro(NULL, str,result, FALSE)) return error;
                    else {
                       // puts("Found the macro.");
                    }
                }
            }
    }
    return 0;
}


/* Binary search over name_array[0 .. total_names-1] for the entry named str.
 *
 * The array must be sorted in strncmp(..., 256) order. Only the first 256
 * characters of a name are significant.
 *
 * On a match, stores its index in *result and returns TRUE. Otherwise stores
 * the index at which str should be inserted to keep the array sorted, and
 * returns FALSE. *result is now written on every path; previously one
 * not-found branch could leave it untouched, and callers then used an
 * uninitialized index.
 *
 * NOTE(review): search_defines passes define_array cast to label_struct*,
 * which indexes correctly only if the two structs have the same size —
 * confirm against defines.h/main.h.
 */
BOOL do_search_names(int *result, char *str, label_struct* name_array, int total_names) {
    int low = 0, high = total_names - 1;

    while (low <= high) {
        int mid = (low + high) / 2;
        int cmp = strncmp(str, name_array[mid].name, 256);

        if (cmp == 0) {
            *result = mid;
            return TRUE;
        }
        if (cmp > 0) low = mid + 1;
        else high = mid - 1;
    }
    /* All entries below low sort before str, and all entries above high sort after it. */
    *result = low;
    return FALSE;
}

/* Looks up a label by name, upper-casing str in place first.
   Returns the label's index in label_array, or -1 if it is not defined. */
int search_labels(char *str) {
    int index;

    upcase(str);
    return do_search_names(&index, str, label_array, total_labels) ? index : -1;
}

/* Inserts a label into the sorted label_array, upper-casing its name in place.
 *
 * - If a label with the same name exists with a different value, an "already
 *   used" error is posted, label->name is freed, and the existing entry takes
 *   the new value.
 * - If it exists with the same value, nothing changes. Previously this path
 *   fell off the end of the function without returning a value.
 * - Otherwise *label is copied into the array at its sorted position.
 *
 * Returns the label's index in label_array. Returns -1, after posting an
 * error, if the table is full; previously that case wrote past the array.
 *
 * NOTE(review): label->name is freed only on the conflicting-value path. If
 * ownership of the name passes to this function, the equal-value and
 * table-full paths leak it — confirm with callers.
 */
int insert_label(label_struct *label) {
    int i,op_i;
    char already_used_error[] = "Label name %s has already been used.";
    char table_full_error[] = "Too many labels; %s was not added.";

    upcase(label->name);

    if (do_search_names(&i,label->name,label_array,total_labels)) {
        if (label->value != label_array[i].value) {
            post_error(already_used_error,label->name);
            free(label->name);
            label_array[i].value = label->value;
        }
        return i;
    }

    if (total_labels >= (int)(sizeof label_array / sizeof label_array[0])) {
        post_error(table_full_error, label->name);
        return -1;
    }

    /* Shift later entries up to open slot i. */
    for (op_i = total_labels; op_i > i; op_i--) {
        label_array[op_i] = label_array[op_i-1];
    }
    label_array[i] = *label;
    total_labels++;
    return i;
}

/* Expands a define/macro reference such as NAME or NAME(arg, ...).
 *
 * src is only read; what is AT src must not be modified.
 * Each argument is bound by calling insert_define(param, arg). The parameter
 * names come from the text stored after the macro name (and its NUL) in the
 * define's name field. These bindings stay in define_array afterwards.
 *
 * command == TRUE : the macro body is assembled as a line into outFile.
 * command == FALSE: the body is expanded through line_expansion. Leading '#'
 *                   directive lines are executed, and the first other line
 *                   is evaluated into *result.
 *
 * Returns 0 on success, or eLabel if NAME is not a known define.
 */
int resolve_macro(FILE* outFile, char *src, int *result, BOOL command) {
    char arg_buf[256], macro_buf[strlen(src) + 1], name_buf[256];
    char *first_arg;
    int mac_index, i;

    /* The macro name is everything before the first '('. */
    for (i = 0; src[i] && src[i] != '('; i++) macro_buf[i] = src[i];
    macro_buf[i] = 0;

    mac_index = search_defines( pack_string(macro_buf) );
    if (mac_index == -1) {
        return eLabel;
    }

    /* next_char returns NULL for "NAME(" with nothing after it; treat that as no arguments. */
    first_arg = (src[i] != '\0') ? next_char(src + i + 1) : NULL;
    if (first_arg && *first_arg != ')') {
        /* If there are arguments, turn them into defines */
        char *arg_ptr = src + i + 1;
        char *input_ptr = define_array[mac_index].name + strlen(macro_buf) + 1;

        while ((input_ptr = get_level_word( next_char(input_ptr), name_buf, ','),
                arg_ptr   = get_level_word( next_char(arg_ptr), arg_buf, ','))) {
            insert_define(name_buf, arg_buf);
        }

        /* insert_define keeps define_array sorted by shifting entries up, so
           the macro itself may have moved. Look it up again before using it. */
        mac_index = search_defines(macro_buf);
        if (mac_index == -1) return eLabel;
    }

    {
        /* Room for a leading blank, the body, and the NUL terminator.
           The old size (+1) was one byte short in the command path. */
        char buffer[strlen(define_array[mac_index].define) + 2];
        char *line;

        if (command) {
            buffer[0] = ' ';
            strcpy(buffer + 1, define_array[mac_index].define);
            assemble_line(outFile, buffer);
        } else {
            line_expansion(define_array[mac_index].define, LE_CREATE);

            get_line(NULL, buffer);
            /* Run leading directive lines. A blank line is not a directive
               (next_char returns NULL for it; this used to be dereferenced). */
            while ((line = next_char(buffer)) != NULL && *line == '#') {
                do_assemble_line(NULL, buffer);
                get_line(NULL, buffer);
            }

            *result = parse_string(buffer);

            line_expansion(NULL, LE_DESTROY);
        }
    }
    return 0;
}

/* Looks up a define by name, packing str in place with pack_string first
   (only when at least one define exists).
   Returns the define's index in define_array, or -1 if it is not defined. */
int search_defines(char* str) {
    int found_at;

    if (total_defines == 0) return -1;

    pack_string(str);
    return do_search_names(&found_at, str, (label_struct*) define_array, total_defines)
           ? found_at
           : -1;
}

int insert_define(char* name, char* define) {
    int di;
    
    pack_string(name);
    upcase(define);
    //printf("Inserting %s %s\n",name,define);
    if (!(next_char(define) && !strncmp(next_char(define),"EVAL(",5))) {
    /* search to see if it is currently defined, and upcase */
        if (!do_search_names(&di, name, (label_struct*) define_array, total_defines)) {
            int op_i;
            /* New define, allocate space for the entry */
            //printf("Inserting: '%s' '%s' at %d\n",name,define,di);
            for (op_i = total_defines; op_i > di; op_i--) define_array[op_i] = define_array[op_i-1];

            define_array[di].name = (char*) malloc(strlen(name) + 1);
            strcpy(define_array[di].name, name);
            define_array[di].define = (char*) malloc(strlen(define) + 1);
            strcpy(define_array[di].define, define);
            total_defines++;
        } else {
            #define max_replace_terms 64
            const char separators[] = "+-*<>|&/%^\\, ";
            char terms[max_replace_terms][64];
            int i = 0, expr_i = 0, expr_length = 0;
            char result[strlen(define) + (strlen(name) * max_replace_terms)];

            do {
                int sep_i;

                for (sep_i = 0; sep_i < (sizeof(separators)-1) && separators[sep_i] != define[i]; sep_i++);
                if (sep_i < (sizeof(separators)-1)) {
                    terms[expr_length++][expr_i] = 0;
                    terms[expr_length][expr_i = 0] = separators[sep_i];
                    terms[expr_length++][1] = 0;
                } else {
                    if (!(terms[expr_length][expr_i++] = define[i])) expr_length++;
                }
            } while (define[i++]);
            
            result[0] = 0;
            for (i = 0; i < expr_length; i++) {
                if (!strcmp(name, terms[i])) strcat(result, define_array[di].define);
                else strcat(result, terms[i]);
            }
            
            define_array[di].define = (char*) realloc(define_array[di].define, strlen(result) + 1);
            strcpy(define_array[di].define, result);        
        }
        return di;
     } else {
        define = next_char(define) + 5;
        char hex_table[16] = "0123457689ABCDEF";
        char eval_buf[256], *ptr, result_buf[16];
        unsigned int result, i, eval_i;
        
        ptr = get_level_word(define, eval_buf, ')');
        result = parse_string(eval_buf);
        
        i = 0;
        while (result) {
            result_buf[i++] = hex_table[ result & 0xF ];
            result >>= 4;
        }
        eval_buf[0] = '$';
        eval_i = 1;
        while (i--) eval_buf[eval_i++] = result_buf[i];
        if (ptr) strcpy(eval_buf + eval_i, ptr);

        return insert_define(name, eval_buf);
    }
}
        
