sauce-finder/main.go
2025-08-23 19:48:04 -03:00

310 lines
6 KiB
Go

package main
import (
	"bytes"
	"cmp"
	"context"
	"errors"
	"fmt"
	"image/color"
	// "sort"
	// "image/png"
	"io"
	"log"
	"math"
	"net/http"
	"os"
	"path"
	"slices"
	"time"

	"sauce/shared"

	"gocv.io/x/gocv"
	"gocv.io/x/gocv/contrib"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
// Package-level state shared by the HTTP handlers and the matching helpers.
var (
// port is the listen address handed to http.Server in main.
port = ":9393"
// hashes is declared here but not referenced by the code reviewed in this file.
hashes []contrib.ImgHashBase
// matcher is the brute-force descriptor matcher used by candidate.tryMatch.
matcher = gocv.NewBFMatcher()
// phash and avghash are image-hash algorithm values; in the code reviewed
// here only commented-out lines refer to them.
phash = contrib.PHash{}
avghash = contrib.AverageHash{}
// index = loadIndex()
// db is the database handle: main opens it and handleSearch queries it.
db *gorm.DB
)
// candidate pairs an indexed page with the outcome of matching it against a
// search image.
type candidate struct {
	// page is the indexed page whose descriptors are matched against the search.
	page shared.Page

	// phash and avghash are not assigned anywhere in the code reviewed here;
	// the lines that would fill them in newCandidate are commented out.
	phash   float64
	avghash float64

	// matches is the result of the last tryMatch call, sorted by distance.
	matches []gocv.DMatch

	// averageDistance is the mean Distance over matches, set by tryMatch.
	averageDistance float64

	// publication is the publication passed to newCandidate.
	publication shared.Publication
}
// loadImageFromDisk reads the image file at path with OpenCV and turns it into
// a shared.Page through shared.LoadImage.
//
// An unreadable or undecodable file is reported as an error instead of a
// panic, so callers that already log and skip failures (see loadIndex) keep
// running when a single file is broken.
func loadImageFromDisk(path string) (shared.Page, error) {
	img := gocv.IMRead(path, gocv.IMReadColor)
	if img.Empty() {
		// IMRead reports failure as an empty Mat; release it before returning.
		img.Close()
		return shared.Page{}, fmt.Errorf("cannot read image %q", path)
	}
	return shared.LoadImage(path, img)
}
// func (e *shared.Page) hashImage() {
// phash.Compute(e.image, &e.phash)
// if e.phash.Empty() {
// panic("empty")
// }
// avghash.Compute(e.image, &e.avghash)
// if e.phash.Empty() {
// panic("empty")
// }
// }
// newCandidate builds a candidate for page e belonging to publication p.
// Only the page and publication are filled in; the match fields stay at their
// zero values until tryMatch is called.
func newCandidate(e shared.Page, p shared.Publication) candidate {
	var c candidate
	c.page = e
	c.publication = p
	// The hash scores are intentionally left unset (previously disabled):
	// c.phash = phash.Compare(e.phash, b.phash) / 64
	// c.avghash = avghash.Compare(e.avghash, b.avghash) / 64
	return c
}
// tryMatch matches the descriptors of search against the descriptors of the
// candidate's page and stores the outcome on c.
//
// c.matches receives the matches sorted by ascending distance and
// c.averageDistance their mean distance. When the matcher returns no matches
// the average is set to +Inf rather than 0/0 = NaN: cmp.Compare orders NaN
// before every other value, so a NaN score would rank a page with nothing in
// common with the search as the best hit.
func (c *candidate) tryMatch(search shared.Page) {
	c.matches = matcher.Match(search.Descriptors, c.page.Descriptors)
	if len(c.matches) == 0 {
		c.averageDistance = math.Inf(1)
		return
	}

	slices.SortFunc(c.matches, func(a, b gocv.DMatch) int {
		return cmp.Compare(a.Distance, b.Distance)
	})

	var total float64
	for _, m := range c.matches {
		total += m.Distance
	}
	c.averageDistance = total / float64(len(c.matches))
}
// TODO: parallelize.
//
// loadIndex scans the "index" folder and returns one shared.Publication per
// sub-directory, titled after the directory name.
//
// For each publication the "cache" sub-directory decides how pages are loaded:
//   - if it does not exist, it is created, every file under "pages" is loaded
//     with loadImageFromDisk and its ORB data is written to the cache with
//     SaveORBtoDisk;
//   - if it exists, each cached file is read back with gocv.IMRead and used as
//     the page's Descriptors.
//
// Unreadable "pages" folders skip the publication and unreadable page files
// skip the page (both are logged); any other filesystem error panics.
func loadIndex() []shared.Publication {
	const indexFolder = "index"

	start := time.Now()
	log.Println("loading index...")

	entries, err := os.ReadDir(indexFolder)
	if err != nil {
		panic(err)
	}

	var publications []shared.Publication
	for _, entry := range entries {
		if !entry.Type().IsDir() {
			continue
		}

		var (
			pages     []shared.Page
			cachePath = path.Join(indexFolder, entry.Name(), "cache")
			pagesPath = path.Join(indexFolder, entry.Name(), "pages")
		)

		// Validate the cache: its mere existence selects the loading strategy.
		_, statErr := os.Stat(cachePath)
		switch {
		case errors.Is(statErr, os.ErrNotExist):
			// No cache yet: build it from the original page images.
			pageFiles, err := os.ReadDir(pagesPath)
			if err != nil {
				log.Println(err)
				continue
			}
			if err := os.Mkdir(cachePath, os.ModePerm); err != nil {
				panic(err)
			}
			for _, f := range pageFiles {
				page, err := loadImageFromDisk(path.Join(pagesPath, f.Name()))
				if err != nil {
					log.Println(err)
					continue
				}
				page.SaveORBtoDisk(path.Join(cachePath, f.Name()))
				pages = append(pages, page)
			}

		case statErr != nil:
			panic(statErr)

		default:
			// Cache present: read the stored descriptors instead of the images.
			cached, err := os.ReadDir(cachePath)
			if err != nil {
				panic(err)
			}
			for _, f := range cached {
				name := f.Name()
				descriptors := gocv.IMRead(path.Join(cachePath, name), gocv.IMReadAnyColor)
				pages = append(pages, shared.Page{
					Descriptors: descriptors,
					Path:        path.Join(pagesPath, name),
					Name:        name,
				})
			}
		}

		publications = append(publications, shared.Publication{
			Title: entry.Name(),
			Pages: pages,
		})
	}

	log.Println("index loaded in", time.Since(start))
	return publications
}
// drawMatches renders the keypoint matches between pages a and b side by side
// and writes the resulting image to path.
//
// At most the first 20 entries of matches are drawn; shorter slices are drawn
// in full instead of panicking on an out-of-range slice expression. A failed
// write is logged.
func drawMatches(a, b shared.Page, matches []gocv.DMatch, path string) {
	const maxDrawn = 20

	output := gocv.NewMat()
	// Mats wrap native memory, so release the scratch image explicitly.
	defer output.Close()

	gocv.DrawMatches(
		a.Image, a.Keypoints,
		b.Image, b.Keypoints,
		matches[:min(maxDrawn, len(matches))],
		&output,
		color.RGBA{R: 255}, color.RGBA{R: 255}, nil,
		gocv.NotDrawSinglePoints,
	)
	if !gocv.IMWrite(path, output) {
		log.Println("cannot write matches image", path)
	}
}
// handleSearch serves POST /search: it reads the image uploaded in the
// "search" form field, matches it against every page stored in the database
// and renders the closest pages (at most maxResults) with their publication.
//
// Failures are logged and answered with an HTTP error status instead of a
// panic: 400 when the upload is missing or cannot be decoded, 500 for database
// or descriptor problems.
func handleSearch(w http.ResponseWriter, req *http.Request) {
	const (
		maxResults = 8
		// NOTE(review): every stored DescriptorBlob is assumed to be a 500x32
		// CV8U matrix — confirm pages with fewer descriptors are never stored.
		descriptorRows = 500
		descriptorCols = 32
	)

	// fail logs the detailed error server-side and sends a short message to
	// the client.
	fail := func(status int, what string, err error) {
		log.Println(what+":", err)
		http.Error(w, what, status)
	}

	fileReader, _, err := req.FormFile("search")
	if err != nil {
		fail(http.StatusBadRequest, "reading uploaded file", err)
		return
	}
	defer fileReader.Close()

	file, err := io.ReadAll(fileReader)
	if err != nil {
		fail(http.StatusBadRequest, "reading uploaded file", err)
		return
	}

	// This error used to be overwritten by the next assignment to err.
	search, err := shared.LoadImageFromBytes(file)
	if err != nil {
		fail(http.StatusBadRequest, "decoding uploaded image", err)
		return
	}

	rows, err := db.Debug().Model(&shared.Page{}).Preload("publications").Rows()
	if err != nil {
		fail(http.StatusInternalServerError, "querying pages", err)
		return
	}
	defer rows.Close()

	var candidates []candidate
	for rows.Next() {
		var page shared.Page
		if err := db.ScanRows(rows, &page); err != nil {
			fail(http.StatusInternalServerError, "scanning page", err)
			return
		}
		// NOTE(review): the Mats created here are never closed; presumably
		// the pages handed to the template keep referencing them — verify
		// before releasing them after matching.
		page.Descriptors, err = gocv.NewMatFromBytes(descriptorRows, descriptorCols, gocv.MatTypeCV8U, page.DescriptorBlob)
		if err != nil {
			fail(http.StatusInternalServerError, "decoding page descriptors", err)
			return
		}
		c := newCandidate(page, shared.Publication{})
		c.tryMatch(search)
		candidates = append(candidates, c)
	}
	if err := rows.Err(); err != nil {
		fail(http.StatusInternalServerError, "iterating pages", err)
		return
	}

	slices.SortFunc(candidates, func(a, b candidate) int {
		return cmp.Compare(a.averageDistance, b.averageDistance)
	})

	// Clamp so that a database with fewer than maxResults pages does not
	// trigger an out-of-range slice panic.
	best := candidates[:min(maxResults, len(candidates))]

	var pages []shared.Page
	for _, c := range best {
		var pub shared.Publication
		if err := db.Where("id = ?", c.page.UserID).Find(&pub).Error; err != nil {
			fail(http.StatusInternalServerError, "loading publication", err)
			return
		}
		fmt.Println("pub:", pub)
		c.page.Publication = pub
		pages = append(pages, c.page)
	}

	// Rendering with the request context stops work when the client goes away.
	if err := layout(results(search.B64, pages)).Render(req.Context(), w); err != nil {
		log.Println("rendering results:", err)
	}
}
// main opens the SQLite database, pre-renders the home page and serves:
//
//	GET  /        the pre-rendered search form
//	POST /search  handleSearch
//	GET  /src     a pass-through fetch of the URL given in the "src" parameter
//	GET  /index/  static files from the "index" directory
//
// Startup failures terminate the process; request failures are answered with
// an HTTP error status.
func main() {
	var err error
	db, err = gorm.Open(sqlite.Open("test.db"), &gorm.Config{})
	if err != nil {
		log.Fatalf("opening database: %v", err)
	}

	// The form never changes, so it is rendered once and served from memory.
	var home bytes.Buffer
	if err := layout(form()).Render(context.Background(), &home); err != nil {
		log.Fatalf("rendering home page: %v", err)
	}

	// One client shared by all /src requests; the timeout keeps a stalled
	// upstream from pinning a handler goroutine forever.
	fetcher := &http.Client{Timeout: 30 * time.Second}

	router := http.NewServeMux()
	router.HandleFunc("GET /", func(w http.ResponseWriter, req *http.Request) {
		w.Write(home.Bytes())
	})
	router.HandleFunc("POST /search", handleSearch)

	// NOTE(review): this endpoint fetches whatever URL the client supplies and
	// relays the body, i.e. it is an open proxy (SSRF risk). Consider
	// restricting the allowed hosts before exposing it beyond localhost.
	router.HandleFunc("GET /src", func(w http.ResponseWriter, r *http.Request) {
		src := r.FormValue("src")
		if src == "" {
			http.Error(w, "missing src parameter", http.StatusBadRequest)
			return
		}

		// Tie the upstream request to the client's so it is cancelled when
		// the client disconnects.
		upstream, err := http.NewRequestWithContext(r.Context(), http.MethodGet, src, nil)
		if err != nil {
			http.Error(w, "invalid src parameter", http.StatusBadRequest)
			return
		}
		resp, err := fetcher.Do(upstream)
		if err != nil {
			log.Println("fetching src:", err)
			http.Error(w, "cannot fetch src", http.StatusBadGateway)
			return
		}
		// Close on every path, including the non-200 one below.
		defer resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			http.Error(w, "upstream returned "+resp.Status, http.StatusBadGateway)
			return
		}
		if _, err := io.Copy(w, resp.Body); err != nil {
			// Headers are already sent, so the failure can only be logged.
			log.Println("relaying src:", err)
		}
	})
	router.Handle("GET /index/", http.StripPrefix("/index/", http.FileServer(http.Dir("index"))))

	server := http.Server{
		Addr:    port,
		Handler: Logging(router),
		// Bounds how long a client may take to send request headers.
		ReadHeaderTimeout: 10 * time.Second,
	}
	fmt.Println("http://localhost" + port)
	log.Fatal(server.ListenAndServe())
}