H1387XD
H1387XD

Reputation: 22

Why is my code inside an if block being called, no matter the condition?

I created a programming language, but I can't stop yacc from executing the code inside an if block:

def t_error(t):
    # Lexer error hook: called when no token rule matches at the current
    # position. PLY passes the whole untokenized remainder of the input in
    # t.value, so only its first character is the offending one.
    print(f"unexpected token: {t.value[0]}")
    # Abort with a non-zero status so the failure is visible to the caller.
    sys.exit(1)

# Build the lexer. NOTE(review): `lex` is presumably ply.lex (its import is
# outside this excerpt), which collects the t_* rules of the calling module.
lexer = lex.lex()
# Global symbol table shared by the grammar actions: variable name -> value.
variables = {}

def p_program(p):
    '''
    program : statements
    '''
    # The docstring above is the grammar rule read by yacc.
    # A program's semantic value is that of its statement list.
    statements_value = p[1]
    p[0] = statements_value

def p_statements(p):
    '''
    statements : statement
               | statement statements
    '''
    # Both alternatives yield the value of the leading statement; the value
    # of any trailing `statements` (p[2]) is not propagated.
    p[0] = p[1]

def p_expression_literals(p):
    '''
    expression : ID
              | INT
              | FLOAT
              | BOOL
              | STRING
    '''
    # A string operand that is a key of the global `variables` table evaluates
    # to the stored value; every other operand is passed through unchanged.
    operand = p[1]
    is_known_name = operand in variables and isinstance(operand, str)
    p[0] = variables[operand] if is_known_name else operand

def p_expression_cmp(p):
    '''
    expression : expression CMP expression
    '''
    # Evaluates the comparison as soon as the rule is reduced; p[2] is the
    # operator text and p[1] / p[3] are the already-evaluated operands.
    # NOTE(review): which spellings the CMP token can carry is not visible
    # here; an operator missing from this table still yields None.
    comparisons = {
        "==": lambda lhs, rhs: lhs == rhs,
        "!=": lambda lhs, rhs: lhs != rhs,
        "<": lambda lhs, rhs: lhs < rhs,
        "<=": lambda lhs, rhs: lhs <= rhs,
        ">": lambda lhs, rhs: lhs > rhs,
        ">=": lambda lhs, rhs: lhs >= rhs,
    }
    compare = comparisons.get(p[2])
    p[0] = compare(p[1], p[3]) if compare is not None else None

def p_statement_assign(p):
    '''
    statement : ID EQ expression
    '''
    # Runs as soon as the rule is reduced: the already-evaluated right-hand
    # side is stored in the global `variables` table under the ID's text.
    name, value = p[1], p[3]
    variables[name] = value

def p_expression_binop(p):
    '''
    expression : expression ADD expression
               | expression SUB expression
               | expression MUL expression
               | expression DIV expression
    '''
    # Applies the arithmetic operator named by p[2] to the two operand values
    # immediately. An operator text not listed here leaves p[0] untouched.
    arithmetic = {
        "+": lambda left, right: left + right,
        "-": lambda left, right: left - right,
        "*": lambda left, right: left * right,
        "/": lambda left, right: left / right,
    }
    operation = arithmetic.get(p[2])
    if operation is not None:
        p[0] = operation(p[1], p[3])

def p_expression_group(p):
    '''
    expression : LP expression RP
               | LP statement RP
    '''
    # Parentheses only group: the value of whatever sits between them (p[2])
    # becomes the value of the whole expression.
    inner = p[2]
    p[0] = inner

def p_scope_block(p):
    '''
    block : LCB statements RCB
    '''
    # A block's value is the value of the statements between its delimiters.
    body = p[2]
    p[0] = body
def c_print(arg):
    """Print *arg* as text with all leading and trailing double quotes removed."""
    text = str(arg).strip('"')
    print(text)
def p_print(p):
    '''
    statement : PRINT LP expression RP
    '''
    # Prints immediately, at the moment this rule is reduced; p[3] is the
    # already-evaluated argument. No value is assigned to p[0].
    argument = p[3]
    c_print(argument)

def p_if_statement(p):
    '''
    statement : IF expression block
              | IF expression block ELSE block
    '''
    # Chooses which block's *value* becomes the value of the if statement.
    # NOTE: this cannot gate execution. Actions such as p_print and
    # p_statement_assign perform their side effects when their own rules are
    # reduced, which happens while the blocks are being parsed, i.e. before
    # this function runs.
    condition = p[2]
    has_else = len(p) == 6  # six slots only when the ELSE alternative matched
    if condition:
        p[0] = p[3]
    elif has_else:
        p[0] = p[5]

def p_error(p):
    # Parser error hook. `p` is the offending token, or None when the input
    # ended before the grammar was satisfied, so there is no token to report.
    if p is None:
        print("Syntax error at EOF")
    else:
        # Wrap the value in a 1-tuple so a tuple-valued token cannot be
        # mistaken for the %-format argument list.
        print("Syntax error at '%s'" % (p.value,))


# Build the parser from the p_* rules, then parse the program text.
# Because the grammar actions above evaluate as they are reduced, parsing
# also runs the program.
# NOTE(review): `code` is defined outside this excerpt.
parser = yacc.yacc()
parser.parse(code)

When I run an if statement, the block is executed regardless of the condition.

Upvotes: -1

Views: 43

Answers (1)

Bart Kiers
Bart Kiers

Reputation: 170158

As mentioned in the comments: don't evaluate your code immediately. Instead, build a parse/syntax tree and evaluate only the nodes you want to. Here's a quick demo of how that could be done:

import ply.lex as lex
import ply.yacc as yacc
from nodes import *

# Keywords: source spelling -> token type. Consulted by t_IDENTIFIER so that
# keywords are not lexed as plain identifiers.
reserved = {
    'if': 'IF',
    'else': 'ELSE',
    'print': 'PRINT',
    'prompt': 'PROMPT'
}

# Every token type the lexer can produce: the fixed ones plus the keywords.
tokens = (['LPAREN', 'RPAREN', 'LBRACE', 'RBRACE', 'PLUS', 'EQ', 'NUMBER', 'STRING', 'ASSIGN',  'IDENTIFIER'] +
          list(reserved.values()))

# Simple tokens given as regex strings.
# NOTE: PLY tries string-defined rules in order of decreasing regex length,
# so '==' (t_EQ) is matched before '=' (t_ASSIGN).
t_LPAREN = r'\('
t_RPAREN = r'\)'
t_LBRACE = r'\{'
t_RBRACE = r'\}'
t_PLUS = r'\+'
t_EQ = r'=='
t_ASSIGN = r'='

def t_NUMBER(t):
    r'\d+'
    # The docstring above is the token's regex. The matched digit text is
    # replaced by its integer value before the token is handed on.
    digits = t.value
    t.value = int(digits)
    return t

def t_STRING(t):
    r'"[^"]*"'
    # The docstring above is the token's regex. Drop the first and last
    # character (the surrounding quotes) so the value is the bare text.
    quoted = str(t.value)
    t.value = quoted[1:-1]
    return t

def t_IDENTIFIER(t):
    r'[a-zA-Z_][a-zA-Z0-9_]*'
    # The docstring above is the token's regex. Words listed in `reserved`
    # get their keyword token type; every other word stays an IDENTIFIER.
    word = t.value
    t.type = reserved.get(word, 'IDENTIFIER')
    return t

def t_SPACE(t):
    r'\s+'
    # The docstring above is the token's regex. Returning nothing makes the
    # lexer discard the matched whitespace.
    return None

def t_error(t):
    # Lexer error hook: report the first character of the unmatched input,
    # then step over that single character so lexing can continue.
    offending = t.value[0]
    print(f"Illegal character '{offending}'")
    t.lexer.skip(1)

# Build the lexer from the token rules defined above.
lexer = lex.lex()

def p_program(p):
    '''program : statement
               | program statement'''
    # Left-recursive statement list. A single statement starts a ProgramNode;
    # each further statement produces a new ProgramNode holding the earlier
    # statements followed by the new one.
    if len(p) != 2:
        earlier = p[1].statements
        p[0] = ProgramNode(earlier + [p[2]])
    else:
        p[0] = ProgramNode([p[1]])

def p_block(p):
    '''block : LBRACE program RBRACE'''
    # A block is the program node found between the braces.
    body = p[2]
    p[0] = body

def p_statement(p):
    '''statement : if_statement
                 | print_statement
                 | assignment'''
    # Pure pass-through: a statement is whichever concrete node was matched.
    node = p[1]
    p[0] = node

def p_assignment(p):
    '''assignment : IDENTIFIER ASSIGN expression'''
    # Build the node only; the store happens when the node is evaluated.
    target, value_node = p[1], p[3]
    p[0] = AssignNode(target, value_node)

def p_prompt(p):
    '''prompt : PROMPT LPAREN STRING RPAREN'''
    # p[3] is the prompt text (t_STRING has already removed its quotes).
    message = p[3]
    p[0] = PromptStmt(message)

def p_if_statement(p):
    '''if_statement : IF LPAREN expression RPAREN block
                    | IF LPAREN expression RPAREN block ELSE block'''
    # The else part is optional. len(p) is 8 only for the alternative that
    # carries an ELSE block; a missing else is represented by an empty
    # program, whose evaluation does nothing.
    else_body = p[7] if len(p) == 8 else ProgramNode([])
    p[0] = IfStmtNode(p[3], p[5], else_body)

def p_print_statement(p):
    '''print_statement : PRINT LPAREN expression RPAREN'''
    # Build the node only; printing happens when the node is evaluated.
    argument = p[3]
    p[0] = PrintNode(argument)

def p_expression(p):
    '''expression : cmp_expression'''
    # Entry point of the expression hierarchy; passes the node through.
    node = p[1]
    p[0] = node

def p_cmp_expression(p):
    '''cmp_expression : add_expression EQ add_expression
                      | add_expression'''
    # Without an operator the operand node is passed through; otherwise an
    # ExprNode is built from the operator text and both operands.
    if len(p) != 4:
        p[0] = p[1]
        return
    left, operator, right = p[1], p[2], p[3]
    p[0] = ExprNode(operator, left, right)

def p_add_expression(p):
    '''add_expression : add_expression PLUS atom
                      | atom'''
    # Left-recursive so additions can be chained: `1 + 2 + 3` becomes
    # ((1 + 2) + 3). With `atom PLUS atom` only a single `+` was accepted.
    if len(p) == 4:
        p[0] = ExprNode(p[2], p[1], p[3])
    else:
        p[0] = p[1]

def p_atom_expression(p):
    '''atom : LPAREN expression RPAREN
            | IDENTIFIER
            | NUMBER
            | STRING
            | prompt'''
    # A parenthesised expression is the only three-symbol alternative.
    if len(p) == 4:
        p[0] = p[2]
        return
    # p.slice exposes the matched grammar symbol, whose type tells the
    # single-symbol alternatives apart.
    symbol_type = p.slice[1].type
    if symbol_type == 'IDENTIFIER':
        p[0] = VarNode(p[1])
    elif symbol_type in ('NUMBER', 'STRING'):
        p[0] = AtomNode(p[1])
    else:
        # `prompt` already produced a node; pass it through.
        p[0] = p[1]

def p_error(p):
    # Parser error hook; `p` is falsy (None) when the input ended too early.
    if p:
        message = f"Syntax error at '{p.value}'"
    else:
        message = "Syntax error at EOF"
    print(message)

# Build the parser from the p_* grammar rules defined above.
parser = yacc.yacc()

# Demo: parse a small program into a tree first, then evaluate the tree.
if __name__ == "__main__":
    test_input = '''
        answer = prompt("What is 40 + 2? ")
        
        if (40 + 2 == answer) {
          print("That is correct!")
        }
        else {
          print("Not correct :(")
        }
    '''
    # Parsing only builds the tree; no statement is executed yet.
    result = parser.parse(test_input)
    # Evaluation starts from an empty variable table.
    result.eval({})

The nodes.py file could look like this:

import re
from abc import ABC, abstractmethod

class AstNode(ABC):
    """Abstract base of every syntax-tree node: each node can be evaluated."""

    @abstractmethod
    def eval(self, variables):
        """Evaluate this node using the *variables* table (name -> value).

        Subclasses must override this; the base implementation returns None.
        """


class ProgramNode(AstNode):
    """An ordered list of statement nodes."""

    def __init__(self, statements):
        self.statements = statements

    def eval(self, variables):
        """Evaluate each statement in order with the same *variables* table."""
        for node in self.statements:
            node.eval(variables)


class AssignNode(AstNode):
    """Assignment of an expression's value to a variable name."""

    def __init__(self, variable, value):
        self.variable = variable
        self.value = value

    def eval(self, variables):
        """Evaluate the value node and store the result under the name."""
        result = self.value.eval(variables)
        variables[self.variable] = result


class PromptStmt(AstNode):
    """Reads a line from the user after showing a prompt text."""

    def __init__(self, prompt):
        self.prompt = prompt

    def eval(self, variables):
        """Return the user's answer: an int if it is all digits, else the text."""
        answer = input(self.prompt)
        if re.match(r'^\d+$', answer):
            return int(answer)
        return answer


class IfStmtNode(AstNode):
    """Conditional statement with an optional else branch."""

    def __init__(self, if_condition, if_body, else_body=None):
        """Store the condition and branches.

        *else_body* may be None (the default) for an `if` with no else part.
        """
        self.if_condition = if_condition
        self.if_body = if_body
        self.else_body = else_body

    def eval(self, variables):
        """Evaluate the condition, then only the branch it selects."""
        if self.if_condition.eval(variables):
            self.if_body.eval(variables)
        elif self.else_body is not None:
            # Nothing to run when the statement has no else branch.
            self.else_body.eval(variables)


class PrintNode(AstNode):
    """Statement that prints the value of an expression."""

    def __init__(self, expression):
        self.expression = expression

    def eval(self, variables):
        """Evaluate the expression and print the result."""
        value = self.expression.eval(variables)
        print(value)


class ExprNode(AstNode):
    """Binary expression: an operator text with left and right operand nodes."""

    def __init__(self, op, lhs, rhs):
        self.op = op
        self.lhs = lhs
        self.rhs = rhs

    def eval(self, variables):
        """Return `lhs == rhs` or `lhs + rhs`, depending on the operator.

        Raises NotImplementedError for any other operator, before either
        operand is evaluated.
        """
        if self.op != '==' and self.op != '+':
            raise NotImplementedError(f'Not yet implemented: {self.op}')
        # Operands are evaluated left to right.
        left = self.lhs.eval(variables)
        right = self.rhs.eval(variables)
        if self.op == '==':
            return left == right
        return left + right


class VarNode(AstNode):
    """Reference to a variable by name."""

    def __init__(self, name):
        self.name = name

    def eval(self, variables):
        """Look the name up in *variables* (a plain dict raises KeyError if unset)."""
        value = variables[self.name]
        return value


class AtomNode(AstNode):
    """Literal value as produced by the lexer (NUMBER or STRING token value)."""

    def __init__(self, atom):
        self.atom = atom

    def eval(self, variables):
        """Return the literal itself; *variables* is not consulted."""
        literal = self.atom
        return literal

    def __str__(self):
        """Return the text form of the literal."""
        literal = self.atom
        return str(literal)

Upvotes: 0

Related Questions