package main import ( "flag" "fmt" "log" "os" ) var source = "" func (v *Value[T]) Step() (Expression, error) { return v, nil } func (v *Value[T]) Replace(Id, Expression) {} func (v *Quoted) Step() (Expression, error) { return v, nil } func (v *Quoted) Replace(Id, Expression) {} func (d *Declaration) Step() (Expression, error) { for _, v := range d.variables { if isValue(v.expr) { continue } n, err := v.expr.Step() if err != nil { return nil, err } v.expr = n return d, nil } for _, v := range d.variables { for _, b := range d.body { b.Replace(v.id.value, v.expr) } } if len(d.body) == 0 { return zero, nil } for i, b := range d.body { if isValue(b) { continue } n, err := b.Step() if err != nil { return nil, err } d.body[i] = n return d, nil } return d.body[len(d.body)-1], nil } func (d *Declaration) Replace(id Id, expr Expression) { for _, v := range d.variables { v.expr.Replace(id, expr) } for _, b := range d.body { b.Replace(id, expr) } } func (v *Symbol) Step() (Expression, error) { if v.expr == nil { return nil, fmt.Errorf("unbound symbol %s at %s", v.id, v.id.position) } return v.expr, nil } func (v *Symbol) Replace(id Id, expr Expression) { if v.id.value == id { v.expr = expr } } func (l *Lambda) Step() (Expression, error) { f := func(args []Expression) (Expression, error) { if len(args) != len(l.args) { return nil, fmt.Errorf("unexpected number of arguments") } for i, arg := range l.args { for _, e := range l.body { e.Replace(arg.id.value, args[i]) } } if len(l.body) == 0 { return zero, nil } for i, expr := range l.body { if isValue(expr) { continue } r, err := expr.Step() if err != nil { return nil, err } l.body[i] = r } return l.body[len(l.body)-1], nil } return &Value[Function]{f}, nil } func (l *Lambda) Replace(id Id, expr Expression) { for _, b := range l.body { b.Replace(id, expr) } } func (v *List) Step() (Expression, error) { for i, expr := range v.els { if !isValue(expr) { e, err := expr.Step() if err != nil { return nil, err } v.els[i] = e return v, nil } } if len(v.els) == 0 { return zero, nil } f, ok := v.els[0].(*Value[Function]) if !ok { return nil, fmt.Errorf("not a function") } r, err := f.value(v.els[1:]) if err != nil { return nil, err } return r, nil } func (v *List) Replace(id Id, expr Expression) { for _, e := range v.els { e.Replace(id, expr) } } func isValue(e Expression) bool { switch e.(type) { case *Value[int], *Value[float64], *Value[Function], *Quoted: return true default: return false } } func fullStep(e Expression) (Expression, error) { for !isValue(e) { n, err := e.Step() if err != nil { return nil, err } e = n } return e, nil } func stdlib(expr Expression) { functions := map[Id]Function{ "+": func(args []Expression) (Expression, error) { sum := 0 for _, arg := range args { if !isValue(arg) { panic("!") } switch x := arg.(type) { case *Value[int]: sum += x.value default: return nil, fmt.Errorf("invalid type") } } return &Value[int]{sum}, nil }, "-": func(args []Expression) (Expression, error) { v0, ok := args[0].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } v1, ok := args[1].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } return &Value[int]{v0.value - v1.value}, nil }, "*": func(args []Expression) (Expression, error) { v0, ok := args[0].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } v1, ok := args[1].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } return &Value[int]{v0.value * v1.value}, nil }, "/": func(args []Expression) (Expression, error) { v0, ok := args[0].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } v1, ok := args[1].(*Value[int]) if !ok { return nil, fmt.Errorf("integer value was expected") } return &Value[int]{v0.value / v1.value}, nil }, "/.": func(args []Expression) (Expression, error) { v0, ok := args[0].(*Value[float64]) if !ok { return nil, fmt.Errorf("float value was expected") } v1, ok := args[1].(*Value[float64]) if !ok { return nil, fmt.Errorf("float value was expected") } return &Value[float64]{float64(v0.value) / float64(v1.value)}, nil }, } for id, f := range functions { expr.Replace(id, &Value[Function]{f}) } } func main() { logger := log.New(os.Stderr, "error: ", 0) evalString := flag.String("eval", "", "evaluate an expression and return its value") flag.Parse() args := flag.Args() if len(args) > 0 { bytes, err := os.ReadFile(os.Args[1]) if err != nil { logger.Fatal(fmt.Errorf("file not found")) } source = string(bytes) } else if *evalString != "" { source = *evalString } else { logger.Fatal(fmt.Errorf("missing input file")) } tk, err := lex(source) if err != nil { logger.Fatal(err) } p, err := consume(tk) if err != nil { logger.Fatal(err) } stdlib(p) result, err := fullStep(p) if err != nil { logger.Fatal(err) } fmt.Println(result) }