Type check

This commit is contained in:
silva guimaraes 2025-06-01 15:05:45 -03:00
parent 8f5b16878f
commit 0cd6040a1f
7 changed files with 625 additions and 256 deletions

2
.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
testdata
tmp

334
expression.go Normal file
View file

@ -0,0 +1,334 @@
package main
import (
"fmt"
"maps"
)
type Expression interface {
Step() (Expression, error)
Substitute(Id, Expression)
TypeCheck(TypeEnvironment) (Type, error)
}
type Declaration struct {
variables []*Symbol
body []Expression
}
func (d *Declaration) Step() (Expression, error) {
for _, v := range d.variables {
if isValue(v.expr) {
continue
}
n, err := v.expr.Step()
if err != nil {
return nil, err
}
v.expr = n
return d, nil
}
for _, v := range d.variables {
for _, b := range d.body {
b.Substitute(v.id.value, v.expr)
}
}
if len(d.body) == 0 {
return &zero, nil
}
for i, b := range d.body {
if isValue(b) {
continue
}
n, err := b.Step()
if err != nil {
return nil, err
}
d.body[i] = n
return d, nil
}
return d.body[len(d.body)-1], nil
}
func (d *Declaration) Substitute(id Id, expr Expression) {
for _, v := range d.variables {
v.expr.Substitute(id, expr)
}
for _, b := range d.body {
b.Substitute(id, expr)
}
}
func (d *Declaration) TypeCheck(env TypeEnvironment) (Type, error) {
env = maps.Clone(env)
for _, v := range d.variables {
typ, err := v.expr.TypeCheck(env)
if err != nil {
return nil, err
}
env[v.id.value] = typ
}
var body []Type = make([]Type, len(d.body))
for i, e := range d.body {
typ, err := e.TypeCheck(env)
if err != nil {
return nil, err
}
body[i] = typ
}
if len(body) > 0 {
return body[len(body)-1], nil
} else {
return &PrimVoid, nil
}
}
type Conditional struct {
predicate Expression
ifTrue Expression
ifFalse Expression
}
func (c *Conditional) Step() (Expression, error) {
n, err := c.predicate.Step()
if err != nil {
return nil, err
}
c.predicate = n
if !isValue(n) {
return c, nil
}
b, ok := n.(*BoolValue)
if !ok {
return nil, fmt.Errorf("boolean value was expected")
}
if *b {
n, err = c.ifTrue.Step()
if err != nil {
return nil, err
}
c.ifTrue = n
if isValue(n) {
return n, nil
} else {
return c, nil
}
} else {
n, err = c.ifFalse.Step()
if err != nil {
return nil, err
}
c.ifTrue = n
if isValue(n) {
return n, nil
} else {
return c, nil
}
}
}
func (c *Conditional) Substitute(id Id, expr Expression) {
c.predicate.Substitute(id, expr)
c.ifTrue.Substitute(id, expr)
c.ifFalse.Substitute(id, expr)
}
func (c *Conditional) TypeCheck(env TypeEnvironment) (Type, error) {
typ, err := c.predicate.TypeCheck(env)
if err != nil {
return nil, err
}
if !typ.Equal(&PrimBoolean) { // TODO: truthiness
return nil, fmt.Errorf("expected boolean expression as predicate")
}
typ1, err := c.ifTrue.TypeCheck(env)
if err != nil {
return nil, err
}
typ2, err := c.ifFalse.TypeCheck(env)
if err != nil {
return nil, err
}
if !typ1.Equal(typ2) {
return nil, fmt.Errorf("type mismatch between conditionals. can't coerce %v and %v", typ1, typ2)
}
return typ1, nil
}
type BuiltInFunction struct {
t *FunctionType
f func([]Expression) (Expression, error)
}
func (b *BuiltInFunction) HasValue() {}
func (b *BuiltInFunction) Substitute(Id, Expression) {}
func (b *BuiltInFunction) TypeCheck(TypeEnvironment) (Type, error) { return b.t, nil }
func (b *BuiltInFunction) Step() (Expression, error) { panic("not implemented") }
func (b BuiltInFunction) String() string { return "(lambda (..) (..))" }
// func (f *FunctionValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimBoolean, nil }
type Lambda struct {
t *FunctionType
params []*Symbol
body []Expression
}
func (l *Lambda) Step() (Expression, error) {
for _, arg := range l.params {
if arg.expr == nil {
panic("!")
}
for _, e := range l.body {
e.Substitute(arg.id.value, arg.expr)
}
}
if len(l.body) == 0 {
return &zero, nil
}
for i, expr := range l.body {
if isValue(expr) {
continue
}
r, err := expr.Step()
if err != nil {
return nil, err
}
l.body[i] = r
return l, nil
}
return l.body[len(l.body)-1], nil
}
func (l *Lambda) Substitute(id Id, expr Expression) {
for _, b := range l.body {
b.Substitute(id, expr)
}
}
func (l *Lambda) TypeCheck(env TypeEnvironment) (Type, error) {
env = maps.Clone(env)
for _, p := range l.params {
typ, err := p.expr.TypeCheck(env)
if err != nil {
return nil, err
}
env[p.id.value] = typ
}
var body []Type = make([]Type, len(l.body))
for i, e := range l.body {
typ, err := e.TypeCheck(env)
if err != nil {
return nil, err
}
body[i] = typ
}
if len(body) > 0 {
return body[len(body)-1], nil
} else {
return &PrimVoid, nil
}
}
type Symbol struct {
id *Identifier
expr Expression
}
func (v *Symbol) Step() (Expression, error) {
if v.expr == nil {
return nil, fmt.Errorf("unbound symbol %s at %s", v.id, v.id.position)
}
return v.expr, nil
}
func (v *Symbol) Substitute(id Id, expr Expression) {
if v.id.value == id {
v.expr = expr
}
}
func (s *Symbol) TypeCheck(env TypeEnvironment) (Type, error) {
if typ, ok := env[s.id.value]; ok {
return typ, nil
}
if s.expr == nil {
return nil, fmt.Errorf("undeclared symbol %s at %s", s.id, s.id.position)
}
return s.expr.TypeCheck(env)
}
type List struct {
els []Expression
}
func (v *List) Step() (Expression, error) {
for i, expr := range v.els {
if !isValue(expr) {
e, err := expr.Step()
if err != nil {
return nil, err
}
v.els[i] = e
return v, nil
}
}
if len(v.els) == 0 {
return &zero, nil
}
args := v.els[1:]
switch x := v.els[0].(type) {
case *Lambda:
for i, p := range x.params {
p.expr = args[i]
}
return x, nil
case *BuiltInFunction:
return x.f(args)
default:
return nil, fmt.Errorf("not a function")
}
}
func (v *List) Substitute(id Id, expr Expression) {
for _, e := range v.els {
e.Substitute(id, expr)
}
}
func (v *List) TypeCheck(env TypeEnvironment) (Type, error) {
var et []Type
for _, e := range v.els {
typ, err := e.TypeCheck(env)
if err != nil {
return nil, err
}
et = append(et, typ)
}
if len(et) == 0 {
return &PrimAny, nil
}
tf, ok := et[0].(*FunctionType)
if !ok {
return nil, fmt.Errorf("expected function call. got %v", et[0])
}
args := et[1:]
// if !isVariadic(tf.params[len(tf.params)-1]) {
// }
if len(args) != len(tf.params) {
return nil, fmt.Errorf("parameter and arguments number mismatch.")
}
for i := range args {
if !tf.params[i].Equal(args[i]) {
return nil, fmt.Errorf("parameter type mismatch. expected %v got %v",
tf.params[i], args[i])
}
}
return tf.ret, nil
// for _, typ := range et {
// if !typ.Equal(et[0]) {
// return nil, fmt.Errorf("array type mismatch")
// }
// }
// return &ListType{et[0]}, nil
}

30
lex.go
View file

@ -3,6 +3,7 @@ package main
import (
"fmt"
"regexp"
"slices"
"strconv"
"strings"
)
@ -37,9 +38,9 @@ type Id string
type Token token[string]
type Int token[int]
type IntToken token[int]
type Float token[float64]
type FloatToken token[float64]
type Identifier token[Id]
@ -52,13 +53,16 @@ var (
SingleQuote = &Token{value: "'"}
Let = &Token{value: "let"}
LambdaTok = &Token{value: "lambda"}
If = &Token{value: "if"}
TrueTok = &Token{value: "t"}
NilTok = &Token{value: "nil"}
)
func (t Token) String() string {
return fmt.Sprint(t.value)
}
func (t Int) String() string {
func (t IntToken) String() string {
return fmt.Sprint(t.value)
}
@ -75,6 +79,9 @@ var knownTokens = []*Token{
SingleQuote,
Let,
LambdaTok,
If,
TrueTok,
NilTok,
}
func (t *Token) Equals(a Lexeme) bool {
@ -86,8 +93,8 @@ func (t *Token) Equals(a Lexeme) bool {
}
}
func (t *Int) Equals(a Lexeme) bool { return false }
func (t *Float) Equals(a Lexeme) bool { return false }
func (t *IntToken) Equals(a Lexeme) bool { return false }
func (t *FloatToken) Equals(a Lexeme) bool { return false }
func (t *Identifier) Equals(a Lexeme) bool {
switch x := a.(type) {
@ -110,6 +117,15 @@ func lex(source string) (*Input, error) {
outer:
for len(source) > 0 {
p := position{pos: pos, line: line, column: column, filepath: "stdin"}
if source[0] == ';' {
idx := slices.Index([]byte(source), '\n')
if idx < 0 {
break
} else {
source = source[idx:]
continue
}
}
for _, try := range knownTokens {
if !strings.HasPrefix(source, try.value) {
continue
@ -137,7 +153,7 @@ outer:
a := source[index[0]:index[1]]
conv, err := strconv.ParseFloat(a, 64)
if err == nil {
var i = &Float{position: p, value: conv}
var i = &FloatToken{position: p, value: conv}
tokens = append(tokens, i)
l := len(a)
source = source[l:]
@ -151,7 +167,7 @@ outer:
a := source[index[0]:index[1]]
conv, err := strconv.Atoi(a)
if err == nil {
var i = &Int{position: p, value: conv}
var i = &IntToken{position: p, value: conv}
tokens = append(tokens, i)
l := len(a)
source = source[l:]

259
main.go
View file

@ -5,144 +5,14 @@ import (
"fmt"
"log"
"os"
"reflect"
)
var source = ""
func (v *Value[T]) Step() (Expression, error) { return v, nil }
func (v *Value[T]) Replace(Id, Expression) {}
func (v *Quoted) Step() (Expression, error) { return v, nil }
func (v *Quoted) Replace(Id, Expression) {}
func (d *Declaration) Step() (Expression, error) {
for _, v := range d.variables {
if isValue(v.expr) {
continue
}
n, err := v.expr.Step()
if err != nil {
return nil, err
}
v.expr = n
return d, nil
}
for _, v := range d.variables {
for _, b := range d.body {
b.Replace(v.id.value, v.expr)
}
}
if len(d.body) == 0 {
return zero, nil
}
for i, b := range d.body {
if isValue(b) {
continue
}
n, err := b.Step()
if err != nil {
return nil, err
}
d.body[i] = n
return d, nil
}
return d.body[len(d.body)-1], nil
}
func (d *Declaration) Replace(id Id, expr Expression) {
for _, v := range d.variables {
v.expr.Replace(id, expr)
}
for _, b := range d.body {
b.Replace(id, expr)
}
}
func (v *Symbol) Step() (Expression, error) {
if v.expr == nil {
return nil, fmt.Errorf("unbound symbol %s at %s", v.id, v.id.position)
}
return v.expr, nil
}
func (v *Symbol) Replace(id Id, expr Expression) {
if v.id.value == id {
v.expr = expr
}
}
func (l *Lambda) Step() (Expression, error) {
f := func(args []Expression) (Expression, error) {
if len(args) != len(l.args) {
return nil, fmt.Errorf("unexpected number of arguments")
}
for i, arg := range l.args {
for _, e := range l.body {
e.Replace(arg.id.value, args[i])
}
}
if len(l.body) == 0 {
return zero, nil
}
for i, expr := range l.body {
if isValue(expr) {
continue
}
r, err := expr.Step()
if err != nil {
return nil, err
}
l.body[i] = r
}
return l.body[len(l.body)-1], nil
}
return &Value[Function]{f}, nil
}
func (l *Lambda) Replace(id Id, expr Expression) {
for _, b := range l.body {
b.Replace(id, expr)
}
}
func (v *List) Step() (Expression, error) {
for i, expr := range v.els {
if !isValue(expr) {
e, err := expr.Step()
if err != nil {
return nil, err
}
v.els[i] = e
return v, nil
}
}
if len(v.els) == 0 {
return zero, nil
}
f, ok := v.els[0].(*Value[Function])
if !ok {
return nil, fmt.Errorf("not a function")
}
r, err := f.value(v.els[1:])
if err != nil {
return nil, err
}
return r, nil
}
func (v *List) Replace(id Id, expr Expression) {
for _, e := range v.els {
e.Replace(id, expr)
}
}
func isValue(e Expression) bool {
switch e.(type) {
case *Value[int], *Value[float64], *Value[Function], *Quoted:
return true
default:
return false
}
_, ok := e.(Value)
return ok
}
func fullStep(e Expression) (Expression, error) {
@ -156,70 +26,65 @@ func fullStep(e Expression) (Expression, error) {
return e, nil
}
func stdlib(expr Expression) {
functions := map[Id]Function{
"+": func(args []Expression) (Expression, error) {
sum := 0
for _, arg := range args {
if !isValue(arg) {
panic("!")
func builtinFunctions(expr Expression) {
functions := map[Id]BuiltInFunction{
// "quote": {
// t: &FunctionType{
// ret: &PrimInteger,
// params: ,
// },
// f: func(e []Expression) (Expression, error) {
// if len(e) == 1 {
// return &Quoted{e[0]}, nil
// }
// return &Quoted{&List{e}}, nil
// },
// },
"+": {
t: &FunctionType{
ret: &PrimInteger,
params: []Type{&PrimInteger, &PrimInteger},
},
f: func(args []Expression) (Expression, error) {
var sum IntValue = 0
for _, arg := range args {
if !isValue(arg) {
panic("not a value")
}
switch x := arg.(type) {
case *IntValue:
sum += *x
default:
panic("type error")
}
}
switch x := arg.(type) {
case *Value[int]:
sum += x.value
default:
return nil, fmt.Errorf("invalid type")
}
}
return &Value[int]{sum}, nil
},
"-": func(args []Expression) (Expression, error) {
v0, ok := args[0].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
v1, ok := args[1].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
return &Value[int]{v0.value - v1.value}, nil
},
"*": func(args []Expression) (Expression, error) {
v0, ok := args[0].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
v1, ok := args[1].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
return &Value[int]{v0.value * v1.value}, nil
},
"/": func(args []Expression) (Expression, error) {
v0, ok := args[0].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
v1, ok := args[1].(*Value[int])
if !ok {
return nil, fmt.Errorf("integer value was expected")
}
return &Value[int]{v0.value / v1.value}, nil
},
"/.": func(args []Expression) (Expression, error) {
v0, ok := args[0].(*Value[float64])
if !ok {
return nil, fmt.Errorf("float value was expected")
}
v1, ok := args[1].(*Value[float64])
if !ok {
return nil, fmt.Errorf("float value was expected")
}
return &Value[float64]{float64(v0.value) / float64(v1.value)}, nil
return &sum, nil
},
},
// "-": func(args []Expression) (Expression, error) {
// v0, ok := args[0].(*IntValue)
// if !ok {
// return nil, fmt.Errorf("integer value was expected")
// }
// v1, ok := args[1].(*IntValue)
// if !ok {
// return nil, fmt.Errorf("integer value was expected")
// }
// var result IntValue = (*v0 - *v1)
// return &result, nil
// },
// "=": func(args []Expression) (Expression, error) {
// if reflect.TypeOf(args[0]) != reflect.TypeOf(args[1]) {
// return nil, fmt.Errorf("unmatched types")
// }
// v0 := reflect.ValueOf(args[0])
// v1 := reflect.ValueOf(args[1])
// var result BoolValue = BoolValue(v0.Equal(v1))
// return &result, nil
// },
}
for id, f := range functions {
expr.Replace(id, &Value[Function]{f})
expr.Substitute(id, &f)
}
}
@ -247,10 +112,14 @@ func main() {
if err != nil {
logger.Fatal(err)
}
stdlib(p)
builtinFunctions(p)
_, err = p.TypeCheck(make(TypeEnvironment))
if err != nil {
logger.Fatal(err)
}
result, err := fullStep(p)
if err != nil {
logger.Fatal(err)
}
fmt.Println(result)
fmt.Println(reflect.TypeOf(result), result)
}

114
parse.go
View file

@ -8,56 +8,12 @@ import (
var (
noMatch = errors.New("no match")
zero = &Value[int]{0}
)
type Function func([]Expression) (Expression, error)
type Expression interface {
Step() (Expression, error)
Replace(Id, Expression)
}
type List struct {
els []Expression
}
type Value[T interface{ int | float64 | Function }] struct {
value T
// rng Range
}
type Quoted struct {
expr Expression
}
type Declaration struct {
variables []*Symbol
body []Expression
}
type Lambda struct {
args []*Symbol
body []Expression
}
type Symbol struct {
id *Identifier
expr Expression
}
// type Progn struct {
// *List
// }
func (v Value[T]) String() string {
return fmt.Sprint(v.value)
}
func (v Quoted) String() string {
return fmt.Sprintf("'%s", v.expr)
}
func (v List) String() string {
var s []string
for _, e := range v.els {
@ -68,7 +24,7 @@ func (v List) String() string {
func (v Lambda) String() string {
var args List
for _, arg := range v.args {
for _, arg := range v.params {
args.els = append(args.els, arg)
}
var s []string
@ -78,10 +34,6 @@ func (v Lambda) String() string {
return fmt.Sprintf("(lambda %v %v)", &args, strings.Join(s, " "))
}
func (v Function) String() string {
return "(lambda (..) (..))"
}
func (t Symbol) String() string {
return fmt.Sprintf("%v", t.id.value)
}
@ -109,6 +61,17 @@ func parseSymbol(in *Input) (Expression, error) {
return nil, noMatch
}
switch x := l.(type) {
case *Token:
switch {
case x.Equals(TrueTok):
v := BoolValue(true)
return &v, nil
case x.Equals(NilTok):
v := BoolValue(false)
return &v, nil
default:
return nil, noMatch
}
case *Identifier:
return &Symbol{id: x}, nil
default:
@ -122,10 +85,12 @@ func parseValue(in *Input) (Expression, error) {
return nil, noMatch
}
switch x := l.(type) {
case *Int:
return &Value[int]{x.value}, nil
case *Float:
return &Value[float64]{x.value}, nil
case *IntToken:
r := IntValue(x.value)
return &r, nil
case *FloatToken:
r := FloatValue(x.value)
return &r, nil
default:
return nil, noMatch
}
@ -172,6 +137,10 @@ func parseLet(in *Input) (Expression, error) {
return nil, fmt.Errorf("unmatched parenthesis starting at %s", openingParens.Position())
}
decl := in.Take(idx)
_, ok = in.Pop()
if !ok {
return nil, noMatch
}
tok, ok := decl.Pop()
if !ok {
return nil, noMatch
@ -254,7 +223,43 @@ func parseLambda(in *Input) (Expression, error) {
if !ok {
return nil, noMatch
}
return &Lambda{args: args, body: body}, nil
return &Lambda{params: args, body: body}, nil
}
func parseConditional(in *Input) (Expression, error) {
openingParens, ok := in.Pop()
if !ok {
return nil, noMatch
}
if !openingParens.Equals(Lparen) {
return nil, noMatch
}
idx, ok := in.Find(Lparen, Rparen)
if !ok {
return nil, fmt.Errorf("unmatched parenthesis starting at %s", openingParens.Position())
}
decl := in.Take(idx)
_, ok = in.Pop()
tok, ok := decl.Pop()
if !ok {
return nil, noMatch
}
if !tok.Equals(If) {
return nil, noMatch
}
predicate, err := consume(decl)
if err != nil {
return nil, err
}
ifTrue, err := consume(decl)
if err != nil {
return nil, err
}
ifFalse, err := consume(decl)
if err != nil {
return nil, err
}
return &Conditional{predicate: predicate, ifTrue: ifTrue, ifFalse: ifFalse}, nil
}
type ParserFunc func(*Input) (Expression, error)
@ -263,6 +268,7 @@ func consume(in *Input) (n Expression, err error) {
var parseFunctions = []ParserFunc{
parseQuotedObject,
parseLambda,
parseConditional,
parseLet,
parseList,
parseSymbol,

94
type.go Normal file
View file

@ -0,0 +1,94 @@
package main
import "fmt"
type Type interface {
Equal(Type) bool
}
type PrimitiveType string
type ListType struct{ t Type }
type QuotedType struct{ t Type }
type FunctionType struct {
ret Type
params []Type
}
// type VariadicListType ListType
type TypeEnvironment map[Id]Type
var (
PrimVoid PrimitiveType = "void"
PrimAny PrimitiveType = "any"
PrimBoolean PrimitiveType = "boolean"
PrimInteger PrimitiveType = "integer"
PrimFloat PrimitiveType = "float"
)
func (a *PrimitiveType) Equal(b Type) bool {
switch b := b.(type) {
case *PrimitiveType:
return a == b
default:
return false
}
}
func (p PrimitiveType) String() string {
return string(p)
}
func (a *ListType) Equal(b Type) bool {
switch b := b.(type) {
case *ListType:
return a.t.Equal(b.t)
default:
return false
}
}
func (l ListType) String() string {
return fmt.Sprintf("[list %v]", l.t)
}
func (a *QuotedType) Equal(b Type) bool {
switch b := b.(type) {
case *QuotedType:
return a.t.Equal(b.t)
default:
return false
}
}
func (q QuotedType) String() string {
return fmt.Sprintf("[quote %v]", q.t)
}
func (fa *FunctionType) Equal(b Type) bool {
fb, ok := b.(*FunctionType)
if !ok {
return false
}
if len(fa.params) != len(fb.params) {
return false
}
for i := range fa.params {
if !fa.params[i].Equal(fb.params[i]) {
return false
}
}
return fa.ret.Equal(fb.ret)
}
// func (fa *VariadicListType) Equal(b Type) bool {
// return false
// }
//
// func isVariadic(t Type) bool {
// _, ok := t.(*VariadicListType)
// return ok
// }

48
value.go Normal file
View file

@ -0,0 +1,48 @@
package main
import "fmt"
var (
zero IntValue = 0
True BoolValue = true
Frue BoolValue = false
)
type Value interface {
Expression
HasValue()
}
type IntValue int
func (i *IntValue) HasValue() {}
func (i *IntValue) Step() (Expression, error) { return i, nil }
func (i *IntValue) Substitute(Id, Expression) {}
func (i *IntValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimInteger, nil }
func (i *IntValue) String() string { return fmt.Sprint(*i) }
type FloatValue float32
func (f *FloatValue) HasValue() {}
func (f *FloatValue) Step() (Expression, error) { return f, nil }
func (f *FloatValue) Substitute(Id, Expression) {}
func (f *FloatValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimFloat, nil }
func (f *FloatValue) String() string { return fmt.Sprint(*f) }
type BoolValue bool
func (b *BoolValue) HasValue() {}
func (b *BoolValue) Step() (Expression, error) { return b, nil }
func (b *BoolValue) Substitute(Id, Expression) {}
func (b *BoolValue) TypeCheck(TypeEnvironment) (Type, error) { return &PrimBoolean, nil }
func (b BoolValue) String() string { return fmt.Sprint(bool(b)) }
type Quoted struct{ expr Expression }
func (q *Quoted) HasValue() {}
func (q *Quoted) Step() (Expression, error) { return q, nil }
func (q *Quoted) Substitute(Id, Expression) {}
func (q Quoted) String() string { return fmt.Sprintf("'%v", q.expr) }
func (q *Quoted) TypeCheck(env TypeEnvironment) (Type, error) {
return &QuotedType{&PrimVoid}, nil
}