package main import ( "flag" "fmt" "log" "os" "reflect" ) var source = "" func isValue(e Expression) bool { _, ok := e.(Value) return ok } func fullStep(e Expression) (Expression, error) { for !isValue(e) { n, err := e.Step() if err != nil { return nil, err } e = n } return e, nil } func builtinFunctions(expr Expression) { functions := map[Id]BuiltInFunction{ // "quote": { // t: &FunctionType{ // ret: &PrimInteger, // params: , // }, // f: func(e []Expression) (Expression, error) { // if len(e) == 1 { // return &Quoted{e[0]}, nil // } // return &Quoted{&List{e}}, nil // }, // }, "+": { t: &FunctionType{ ret: &PrimInteger, params: []Type{&PrimInteger, &PrimInteger}, }, f: func(args []Expression) (Expression, error) { var sum IntValue = 0 for _, arg := range args { if !isValue(arg) { panic("not a value") } switch x := arg.(type) { case *IntValue: sum += *x default: panic("type error") } } return &sum, nil }, }, // "-": func(args []Expression) (Expression, error) { // v0, ok := args[0].(*IntValue) // if !ok { // return nil, fmt.Errorf("integer value was expected") // } // v1, ok := args[1].(*IntValue) // if !ok { // return nil, fmt.Errorf("integer value was expected") // } // var result IntValue = (*v0 - *v1) // return &result, nil // }, // "=": func(args []Expression) (Expression, error) { // if reflect.TypeOf(args[0]) != reflect.TypeOf(args[1]) { // return nil, fmt.Errorf("unmatched types") // } // v0 := reflect.ValueOf(args[0]) // v1 := reflect.ValueOf(args[1]) // var result BoolValue = BoolValue(v0.Equal(v1)) // return &result, nil // }, } for id, f := range functions { expr.Substitute(id, &f) } } func main() { logger := log.New(os.Stderr, "error: ", 0) evalString := flag.String("eval", "", "evaluate an expression and return its value") flag.Parse() args := flag.Args() if len(args) > 0 { bytes, err := os.ReadFile(os.Args[1]) if err != nil { logger.Fatal(fmt.Errorf("file not found")) } source = string(bytes) } else if *evalString != "" { source = *evalString } else { logger.Fatal(fmt.Errorf("missing input file")) } tk, err := lex(source) if err != nil { logger.Fatal(err) } p, err := consume(tk) if err != nil { logger.Fatal(err) } builtinFunctions(p) _, err = p.TypeCheck(make(TypeEnvironment)) if err != nil { logger.Fatal(err) } result, err := fullStep(p) if err != nil { logger.Fatal(err) } fmt.Println(reflect.TypeOf(result), result) }