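/* \file dasm.cpp
 *
 * Assembler driver for a MIPS-style instruction set: parses an assembly
 * source file, expands pseudo-instructions, resolves labels and branch
 * targets, and writes the encoded instruction words as hex text.
 */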
#include <ctype.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <iostream>
#include <string>
using namespace std;
#include "strtoken.h"
#include "asm.h"

// Global program state shared by the passes in this file:
//   code    - Instruction nodes, one per non-blank source line, plus any
//             instructions inserted when pseudo-ops are expanded
//   symbols - Symbol nodes for labels, added by parse() via insert_symbol()
List code, symbols;

// defined in mips.cpp
// Returns 0 for mnemonics that are not branches/jumps. process_branches()
// stores an absolute target for results > 1 and an offset relative to the
// instruction's own address for any other non-zero result.
int is_branch_or_jump(string str);


// Forward declarations of functions defined later in this file.
// (show_symbols is not called by main, which uses symbols.show() instead.)
void parse(FILE *in);
void process_branches();
void pseudo(Node *np);
void show_symbols(FILE *out);
void map_symbols();
void process_symbols();

/*
 * Entry point:  dasm source [ outfile ]
 *
 * Parses the assembly file `source`, expands pseudo-instructions, assigns
 * addresses, resolves symbolic operands and branch/jump targets, then writes
 * the encoded instruction words to `outfile` (stdout when omitted).
 *
 * In text mode (the default) each word is printed as 8 hex digits, eight
 * words per line. The .text/.data/.end directives become "inst"/"data"/"end"
 * marker lines. Afterwards the instruction listing (code.show()) and, when
 * labels exist, the symbol table are shown.
 *
 * Returns EXIT_FAILURE on usage or file-open errors, EXIT_SUCCESS otherwise.
 */
int main(int argc, char *argv[])
{
	FILE *in, *out;
	string filename;
	errno_t nerr;
	bool binary=false;	// true: write raw 4-byte words instead of hex text
	if (argc<2) {
		cout << "usage: dasm source [ outfile ]\n";
		return EXIT_FAILURE;
	}
	filename = argv[1];
	nerr = fopen_s(&in,filename.c_str(),"r");
	if (nerr) {
		cout << "file: " << filename << " not found\n";
		return EXIT_FAILURE;
	}
	if (argc<3) out = stdout;
	else {
		filename = argv[2];
		// raw words must bypass text-mode newline translation
		nerr = fopen_s(&out,filename.c_str(),binary ? "wb" : "wt");
		if (nerr) {
			cout << "file: " << filename << " creation error\n";
			fclose(in);	// don't leak the already-open source file
			return EXIT_FAILURE;
		}
	}
	parse(in);
	fclose(in);

	Node *np;
	// process pseudo instructions; expansions are inserted after np and are
	// therefore visited by this same loop
	for (np=code.first(); np; np=np->next) pseudo(np);

	map_symbols();		// assign word addresses to instructions and labels
	process_symbols();	// symbolic last operands -> symbol addresses
	process_branches();	// branch/jump targets (relative offsets for branches)

	// output the encoded words
	int count = 0;	// words written since the last section marker
	for (np=code.first(); np; np=np->next) {
		Instruction *ip = (Instruction *) np;
		if (ip->opname.empty()) continue;	// no operation on this line (e.g. label only)
		if (ip->is_directive()) {
			// finish a partially filled hex line; a full line already ended in '\n'
			if (!binary && count%8!=0) fprintf(out,"\n");
			count=0;
			if (ip->opname.compare(".text")==0) fprintf(out,"inst\n");
			else if (ip->opname.compare(".data")==0) fprintf(out,"data\n");
			else if (ip->opname.compare(".end")==0) fprintf(out,"end\n");
			continue;
		}
		int n = ip->code();
		ip->value = n;
		if (binary) fwrite(&n,sizeof n,1,out);
		else {
			fprintf(out,"%08x ",(unsigned int) n);	// %x expects an unsigned argument
			if (count%8==7) fprintf(out,"\n");
		}
		count++;
	}
	if (!binary && count%8!=0) fprintf(out,"\n");
	if (out!=stdout) fclose(out);	// flush and release the output file

	// show results
	code.show();
	if (symbols.count()>0) {
		printf("\n\nSymbol Table\n\n");
		symbols.show();
		printf("\n");
	}

	return EXIT_SUCCESS;
}

/*
 * Report whether a NUL-terminated line holds nothing but space characters.
 * An empty string counts as blank. Tabs or any other character make the
 * line non-blank.
 *
 * Returns 1 for a blank line, 0 otherwise.
 */
int is_blank(char *buf)
{
	for (const char *cursor = buf; *cursor != '\0'; ++cursor) {
		if (*cursor != ' ')
			return 0;	// first non-space decides it
	}
	return 1;
}

/*
 * Decode a register operand such as "$5", "r5" or "$sp".
 *
 * Leading spaces are skipped. A numeric register is '$' or 'r' followed by
 * a decimal number 0..31. The number must not run straight into a letter,
 * digit or '_', so labels like "result" or "r2d2" are not mistaken for
 * registers. After '$' the conventional MIPS names are also accepted:
 * $zero, $at, $v0-$v1, $a0-$a3, $t0-$t9, $s0-$s8, $k0-$k1, $gp, $sp,
 * $fp and $ra. Trailing punctuation such as ')' (from "4($sp)") is ignored.
 *
 * On success stores the register number in `field` and returns 1.
 * Otherwise returns -1 and leaves `field` unchanged. The old code's
 * unchecked sscanf returned 1 here with `field` never written.
 */
int decode_register(char *token, int &field)
{
	static const char *const names[32] = {
		"zero", "at", "v0", "v1", "a0", "a1", "a2", "a3",
		"t0",   "t1", "t2", "t3", "t4", "t5", "t6", "t7",
		"s0",   "s1", "s2", "s3", "s4", "s5", "s6", "s7",
		"t8",   "t9", "k0", "k1", "gp", "sp", "fp", "ra"
	};
	while (*token==' ') token++;
	const char prefix = *token;
	if (prefix!='$' && prefix!='r') return -1;
	token++;

	if (isdigit((unsigned char) *token)) {
		char *end;
		long num = strtol(token, &end, 10);
		// "r2d2", "$5abc": digits followed by more identifier text is a word, not a register
		if (isalnum((unsigned char) *end) || *end=='_') return -1;
		if (num<0 || num>31) return -1;
		field = (int) num;
		return 1;
	}

	// symbolic names are only recognised after '$' ("ra" alone is a label)
	if (prefix!='$') return -1;
	size_t len = 0;
	while (isalnum((unsigned char) token[len])) len++;
	if (token[len]=='_') return -1;
	for (int i=0; i<32; i++) {
		if (strlen(names[i])==len && strncmp(token, names[i], len)==0) {
			field = i;
			return 1;
		}
	}
	if (len==2 && strncmp(token, "s8", 2)==0) {	// $s8 is an alias of $fp
		field = 30;
		return 1;
	}
	return -1;
}

/*
 * Decode an integer literal. The accepted form is optional leading
 * whitespace, an optional sign, then either a hexadecimal number with a
 * "0x"/"0X" prefix or a decimal number. Characters after the digits are
 * ignored (e.g. " 4" split from "4($sp)"). Hex values up to 0xffffffff wrap
 * into the 32-bit int `value`, so "0xffffffff" yields -1.
 *
 * Returns 1 and stores the number in `value` on success. Returns 0 and
 * leaves `value` unchanged when no digits are found.
 */
int decode_number(string token, int &value)
{
	const char *start = token.c_str();
	const char *p = start;
	while (isspace((unsigned char) *p)) p++;
	if (*p=='+' || *p=='-') p++;
	// Only a prefix at the front selects hex. The old token.find("0x") matched
	// "0x" anywhere (reading "fade0x" as 0xfade) and missed the "0X" spelling.
	const bool hex = p[0]=='0' && (p[1]=='x' || p[1]=='X');
	char *end;
	unsigned long bits;
	if (hex) bits = strtoul(start, &end, 16);	// unsigned: 0xffffffff must not clamp
	else bits = (unsigned long) strtol(start, &end, 10);
	if (end==start) return 0;	// no digits consumed
	value = (int) (unsigned int) bits;
	return 1;
}

bool insert_symbol(string str)
{
	// see if symbol is in table
	Node *np;
	Symbol *sp;
	for (np=symbols.first(); np; np=np->next) {
		sp = (Symbol *) np;
		if (sp->compare(str)==0) return false; // symbol present
	}
	// add new symbol to table
	sp = new Symbol(str);
	symbols.append(sp);
	return true; 
}

void map_symbols()
{
	Node *np;
	int count=0;
	for (np=code.first(); np; np=np->next) {
		Instruction *ip = (Instruction *) np;
		int type = ip->type();
		if (type==0) continue; // is directive
		ip->addr = count;
		Symbol *sp = ip->label;
		if (sp) sp->addr = count;
		if (type>0) count++;
	}
}

int find_symbol(string str)
{
	Node *np;
	for (np=symbols.first(); np; np=np->next) {
		Symbol *sp = (Symbol *) np;
		if (sp->compare(str)==0) return sp->addr;
	}
	return -1;
}

void show_symbols(FILE *out)
{
	Node *np;
	Symbol *sp;
	if (symbols.count()==0) return;
	fprintf(out,"\nSymbol Table\n\n");
	for (np=symbols.first(); np;  np=np->next) {
		sp = (Symbol *) np;
		fprintf(out,"%4x %s\n",sp->addr,sp->name.c_str());
	}
}


/*
 * Read assembly source from `in` and append one Instruction per non-blank
 * line to the global `code` list.
 *
 * Each line is normalized first:
 *   - tabs become spaces;
 *   - the line ends at the first CR/LF;
 *   - '(' becomes ',', so "4($sp)" splits into two operands;
 *   - text from '#' on is dropped via strqtrm (presumably quote-aware;
 *     verify in strtoken.h).
 *
 * A "name:" prefix defines a label in `symbols`. The next space-delimited
 * word becomes the operation name. The remaining comma-separated fields
 * become operands:
 *   - register (type 1);
 *   - number (type 2);
 *   - symbol reference (type 3).
 * Empty fields are discarded. Finally every instruction's argument type is
 * computed, and argtype 121 instructions have their operands reordered by
 * swap121().
 *
 * Lines longer than the buffer are split by fgets. The remainder is parsed
 * as a separate line.
 */
void parse(FILE *in)
{
	char buf[512];
	char *inp, *next, *tok;
	int field, n;
	string token;
	/*
	 * Go through input file, line by line. Testing fgets' result (rather
	 * than feof after the read) keeps a last line that has no trailing
	 * newline, and stops on a read error instead of re-parsing stale data.
	 */
	while (fgets(buf,sizeof buf,in)) {
		for (inp=buf; *inp; inp++) {
			if (*inp=='\t') *inp = ' ';
			if (*inp=='\n' || *inp=='\r') *inp = 0;	// CR: files with CRLF endings
			if (*inp=='(') *inp = ',';
		}
		strqtrm(buf,'#');
		if (is_blank(buf)) continue;
		next = buf;
		/*
		 * parse current line
		 */
		Instruction* ins = new Instruction;
		code.append(ins);
		ins->src = buf;
		while (next) {
			inp = next;
			// check for label
			next = strqtrm(inp,':');
			if (next) {
				strtrm(inp,' ');
				token = inp;
				if (insert_symbol(token)) {
					ins->label = (Symbol *) symbols.last();
				}
				else fprintf(stderr,"duplicate label: %s\n",inp);
			}
			else {
				token = "";
				next = inp;
			}
			tok = strtoken(next,' ');
			if (tok) {
				ins->opname = tok;
			}
			// remaining text: comma-separated operands
			while (next) {
				inp = next;
				next = strqtrm(inp,',');
				Operand *op  = new Operand;
				ins->operands.append(op);
				n = decode_register(inp,field);
				if (n>0) {
					op->type = 1;
					op->value = field;
				}
				else {
					token = inp;
					n = decode_number(token,field);
					if (n>0) {
						op->type = 2;
						op->value = field;
					}
					else if (*inp) {
						tok = strtoken(inp,' ');
						if (tok) {
							op->type = 3;
							op->ref = tok;
						}
						else op->type = 9;
					}
					else op->type = 9;
				}
				// type 9 marks an empty field; drop it
				if (op->type==9) ins->operands.remove(op);
			}
		}
	}
	Node *np;
	for (np=code.first(); np; np=np->next) {
		Instruction *ip = (Instruction *) np;
		ip->set_argtype();
		if (ip->argtype==121) ip->swap121();
	}
}

void process_symbols()
{
	Node *np;
	for (np=code.first(); np; np=np->next) {
		Instruction *ip = (Instruction *) np;
		Operand *op = (Operand *) ip->operands.last();
		if (!op) continue;
		if (op->type==3) {
			int targ = find_symbol(op->ref);
			if (targ>=0) {
				op->value = targ;
			}
			else {
				ip->show();
				printf("*** symbol not defined\n");
			}
		}
	}
}
				

void process_branches()
{
	int n = 0;
	Node *np;
	for (np=code.first(); np; np=np->next) {
		Instruction *ip = (Instruction *) np;
		n = is_branch_or_jump(ip->opname);
		if (!n) continue;
		Operand * op = (Operand *) ip->operands.last();
		if (!op) continue;
		if (op->type==3) {
			int targ = find_symbol(op->ref);
			if (targ>=0) {
				op->value = n>1? targ: targ - ip->addr;
			}
			else {
				ip->show();
				printf("*** branch target not defined\n");
			}
		}
	}
}

/*
 * Build a stand-alone Instruction named `str` with up to three operands,
 * then compute its argument type. Null operands are skipped. The new
 * instruction is not linked into `code`; the caller inserts it.
 *
 * Takes `const char *` so string literals ("bne", "ori", ...) can be passed
 * without the literal-to-char* conversion that C++11 removed.
 */
Instruction *new_Instruction(const char *str, Operand *op1, Operand *op2, Operand *op3)
{
	// The Instruction constructor may predate const-correctness (possibly
	// takes char*); it is assumed not to modify the name. Verify in asm.h.
	Instruction *ip = new Instruction(const_cast<char *>(str));
	if (op1) ip->operands.append(op1);
	if (op2) ip->operands.append(op2);
	if (op3) ip->operands.append(op3);
	ip->set_argtype();
	return ip;
}

/*
 * Original mutable-buffer signature, kept so any existing declaration or
 * caller using char* still links. Forwards to the const char* version.
 */
Instruction *new_Instruction(char *str, Operand *op1, Operand *op2, Operand *op3)
{
	return new_Instruction(const_cast<const char *>(str), op1, op2, op3);
}


/*
 * Expand a pseudo-instruction in place. `np` must be an Instruction node
 * of `code`.
 *
 * In-place rewrites:
 *   li/la rt, v      -> ori  rt, $0, v   (no lui is emitted; presumably only
 *                                         values fitting ori's immediate work)
 *   b label          -> bgez $0, label
 *   move rd, rs      -> or   rd, rs, $0
 *   ret              -> jr   $31 (plus any operand already present)
 *   not rd, rs       -> nor  rd, rs, $0
 *   neg rd, rs       -> sub  rd, $0, rs
 *   sgtu rd, rs, rt  -> sltu rd, rt, rs
 *
 * Expansions that insert instructions after this one in `code`:
 *   blt rs, rt, L    -> slt $1, rs, rt ; bne $1, $0, L
 *   bge rs, rt, L    -> slt $1, rs, rt ; beq $1, $0, L
 *   seq rd, rs, rt   -> beq rs, rt, 3 ; ori rd,$0,0 ; beq $0,$0,2 ; ori rd,$0,1
 *
 * Any other opname is left untouched. Malformed pseudo-ops missing a
 * required operand are also left untouched, rather than dereferencing
 * null; encoding will then report them.
 */
void pseudo(Node *np)
{
	Instruction *ip = (Instruction *) np;
	Instruction *ipnew;
	Operand *op;
	if (ip->opname.compare("li")==0 || ip->opname.compare("la")==0) {
		op = new Operand(1,0);
		ip->operands.insert(op,1);
		ip->opname = "ori";
	}
	else if (ip->opname.compare("b")==0) {
		ip->opname = "bgez";
		op = new Operand(1,0);
		ip->operands.prepend(op);
	}
	else if (ip->opname.compare("move")==0) {
		// explicit $0 source, as for not/neg below, so the encoder sees a plain 3-register or
		ip->opname = "or";
		op = new Operand(1,0);
		ip->operands.append(op);
	}
	else if (ip->opname.compare("ret")==0) {
		ip->opname = "jr";
		op = (Operand *) ip->operands.first();
		if (op) {
			printf("--- return %d %d\n",op->type, op->value);
			op->show();
		}
		op = new Operand(1,31);
		ip->operands.prepend(op);
	}
	else if (ip->opname.compare("blt")==0) {
		/*
		 *[004000e4] 0062082a  slt $1, $3, $2           ; 26: blt $3,$2,end 
		 *[004000e8] 14200002  bne $1, $0, 8 [end-0x004000e8]
		 */
		op = (Operand *) ip->operands.last();
		if (!op) return;	// no branch target to move
		ip->opname = "slt";
		ip->operands.extract(op);
		ip->operands.prepend(new Operand(1,1));
		ipnew = new_Instruction("bne",new Operand(1,1),new Operand(1,0), op);
		code.after(ipnew,ip);
	}
	else if (ip->opname.compare("bge")==0) {
		/*
		 *[004000b4] 0043082a  slt $1, $2, $3           ; 20: bge $2,$3,end 
		 *[004000b8] 1020000e  beq $1, $0, 56 [end-0x004000b8] 
		 */
		op = (Operand *) ip->operands.last();
		if (!op) return;	// no branch target to move
		ip->opname = "slt";
		ip->operands.extract(op);
		ip->operands.prepend(new Operand(1,1));
		ipnew = new_Instruction("beq",new Operand(1,1),new Operand(1,0), op);
		code.after(ipnew,ip);
	}
	else if (ip->opname.compare("not")==0) {
		ip->opname = "nor";
		op = new Operand(1,0);
		ip->operands.append(op);
	}
	else if (ip->opname.compare("neg")==0) {
		ip->opname = "sub";
		op = new Operand(1,0);
		ip->operands.insert(op,1);
	}
	else if (ip->opname.compare("seq")==0) {
		/*
		[00400048] 10620003  beq $3, $2, 12           ; 11: seq $8,$2,$3 
		[0040004c] 34080000  ori $8, $0, 0            
		[00400050] 10000002  beq $0, $0, 8            
		[00400054] 34080001  ori $8, $0, 1
		*/
		op = (Operand *) ip->operands.first();
		if (!op) return;	// no destination register
		ip->opname = "beq";
		int rd = op->value;
		ip->operands.remove(op);
		op = new Operand(2,3);
		ip->operands.append(op);
		// add instructions (in reverse order, each goes directly after ip)
		ipnew = new_Instruction("ori",new Operand(1,rd),new Operand(1,0),new Operand(2,1));
		code.after(ipnew,ip);
		ipnew  = new_Instruction("beq",new Operand(1,0),new Operand(1,0),new Operand(2,2));
		code.after(ipnew,ip);
		ipnew = new_Instruction("ori",new Operand(1,rd),new Operand(1,0),new Operand(2,0));
		code.after(ipnew,ip);
	}
	else if (ip->opname.compare("sgtu")==0) {
		/* 
		[00400078] 0062502b  sltu $10, $3, $2         ; 14: sgtu $10,$2,$3
		*/
		op = (Operand *) ip->operands[1];
		if (!op) return;	// fewer than two operands
		ip->opname = "sltu";
		// move rs behind rt: rd, rs, rt -> rd, rt, rs
		ip->operands.extract(op);
		ip->operands.append(op);
	}
	else return;
	ip->set_argtype();
}
