diff --git a/lib/golang/lib.go b/lib/golang/lib.go index e6fd6f1..aa283fb 100644 --- a/lib/golang/lib.go +++ b/lib/golang/lib.go @@ -7,7 +7,9 @@ import ( "os" "os/exec" "path/filepath" + "reflect" "slices" + "sync" "time" "github.com/go-json-experiment/json" @@ -20,19 +22,26 @@ type Config struct { } type KittenIPC struct { - cmd *exec.Cmd - cfg Config + cmd *exec.Cmd + cfg Config + localApi any socketPath string listener net.Listener conn net.Conn errCh chan error + + nextId int64 + pendingCalls map[int64]chan callResult + mu sync.Mutex } -func New(cmd *exec.Cmd, api any, cfg Config) (*KittenIPC, error) { +func New(cmd *exec.Cmd, localApi any, cfg Config) (*KittenIPC, error) { k := KittenIPC{ - cmd: cmd, - cfg: cfg, + cmd: cmd, + cfg: cfg, + localApi: localApi, + pendingCalls: make(map[int64]chan callResult), } k.socketPath = filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d.sock", os.Getpid())) @@ -117,6 +126,11 @@ type Message struct { Error string `json:"error"` } +type callResult struct { + result []any + err error +} + func (k *KittenIPC) startRcvData() { scn := bufio.NewScanner(k.conn) for scn.Scan() { @@ -133,11 +147,131 @@ func (k *KittenIPC) startRcvData() { } func (k *KittenIPC) processMsg(msg Message) { - + switch msg.Type { + case MsgCall: + k.handleCall(msg) + case MsgResponse: + k.handleResponse(msg) + } } -func (k *KittenIPC) Call() { +func (k *KittenIPC) handleCall(msg Message) { + if k.localApi == nil { + k.sendResponse(msg.Id, nil, fmt.Errorf("remote side does not accept ipc calls")) + } + localApi := reflect.ValueOf(k.localApi) + + method := localApi.MethodByName(msg.Method) + if !method.IsValid() { + k.sendResponse(msg.Id, nil, fmt.Errorf("method not found: %s", msg.Method)) + return + } + + methodType := method.Type() + argsCount := methodType.NumIn() + + if len(msg.Params) != argsCount { + k.sendResponse(msg.Id, nil, fmt.Errorf("argument count mismatch: expected %d, got %d", argsCount, len(msg.Params))) + return + } + + var args []reflect.Value + for _, param := range msg.Params { + args = append(args, reflect.ValueOf(param)) + } + + results := method.Call(args) + resVals := results[0 : len(results)-1] + resErr := results[len(results)-1] + + var res []any + for _, resVal := range resVals { + res = append(res, resVal) + } + + k.sendResponse(msg.Id, res, resErr.Interface().(error)) +} + +func (k *KittenIPC) handleResponse(msg Message) { + + k.mu.Lock() + ch, ok := k.pendingCalls[msg.Id] + if ok { + delete(k.pendingCalls, msg.Id) + } + k.mu.Unlock() + + if !ok { + k.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id)) + return + } + + var err error + if msg.Error != "" { + err = fmt.Errorf("remote error: %s", msg.Error) + } + + ch <- callResult{result: msg.Result, err: err} + close(ch) +} + +func (k *KittenIPC) sendResponse(id int64, result []any, err error) { + msg := Message{ + Type: MsgResponse, + Id: id, + Result: result, + } + + if err != nil { + msg.Error = err.Error() + } + + if err := k.sendMsg(msg); err != nil { + k.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err)) + } +} + +func (k *KittenIPC) sendMsg(msg Message) error { + + data, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("marshal message: %w", err) + } + + data = append(data, '\n') + + if _, err := k.conn.Write(data); err != nil { + return fmt.Errorf("write message: %w", err) + } + + return nil +} + +func (k *KittenIPC) Call(method string, params ...any) ([]any, error) { + k.mu.Lock() + id := k.nextId + k.nextId++ + resChan := make(chan callResult, 1) + k.pendingCalls[id] = resChan + k.mu.Unlock() + + msg := Message{ + Type: MsgCall, + Id: id, + Method: method, + Params: params, + } + + if err := k.sendMsg(msg); err != nil { + k.mu.Lock() + delete(k.pendingCalls, id) + k.mu.Unlock() + return nil, fmt.Errorf("send call: %w", err) + } + + result := <-resChan + return result.result, result.err } func (k *KittenIPC) raiseErr(err error) {