This commit is contained in:
Egor Aristov 2026-03-28 14:33:46 +03:00
parent 687364767a
commit b7f547839a
34 changed files with 1268 additions and 896 deletions

BIN
example/golang/simple Executable file

Binary file not shown.

View File

@ -7,3 +7,10 @@ require (
golang.org/x/sync v0.17.0
golang.org/x/text v0.30.0
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/stretchr/testify v1.11.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,6 +1,16 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e h1:Lf/gRkoycfOBPa42vU2bbgPurFong6zXeFtPoxholzU=
github.com/go-json-experiment/json v0.0.0-20251027170946-4849db3c2f7e/go.mod h1:uNVvRXArCGbZ508SxYYTC5v1JWoz2voff5pm25jU1Ok=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,14 +1,18 @@
package api
import "efprojects.com/kitten-ipc/types"
type ValType string
// todo check TInt size < 64
// todo check not float
const (
TNoType ValType = ""
TInt ValType = "int"
TString ValType = "string"
TBool ValType = "bool"
TBlob ValType = "blob"
)
type Val struct {
Name string
Type types.ValType
Children []Val
Name string
Type ValType
}
type Method struct {

View File

@ -0,0 +1,19 @@
package common
import (
"fmt"
"os"
)
func WriteFile(destFile string, content []byte) error {
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("open destination file: %w", err)
}
defer f.Close()
if _, err := f.Write(content); err != nil {
return fmt.Errorf("write file: %w", err)
}
return nil
}

View File

@ -4,18 +4,15 @@ import (
"bytes"
"fmt"
"go/format"
"os"
"strings"
"text/template"
_ "embed"
"efprojects.com/kitten-ipc/kitcom/internal/api"
"efprojects.com/kitten-ipc/types"
"efprojects.com/kitten-ipc/kitcom/internal/common"
)
// todo: check int overflow
// todo: check float is whole
//go:embed gogen.tmpl
var templateString string
@ -34,43 +31,41 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
Api: apis,
}
const defaultReceiver = "self"
tpl := template.New("gogen")
tpl = tpl.Funcs(map[string]any{
"receiver": func(name string) string {
return defaultReceiver
return strings.ToLower(name[:1])
},
"typedef": func(t types.ValType) (string, error) {
td, ok := map[types.ValType]string{
types.TInt: "int",
types.TString: "string",
types.TBool: "bool",
types.TBlob: "[]byte",
"typedef": func(t api.ValType) (string, error) {
td, ok := map[api.ValType]string{
api.TInt: "int",
api.TString: "string",
api.TBool: "bool",
api.TBlob: "[]byte",
}[t]
if !ok {
return "", fmt.Errorf("cannot generate type %v", t)
}
return td, nil
},
"convtype": func(valDef string, t types.ValType) (string, error) {
td, ok := map[types.ValType]string{
types.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
types.TString: fmt.Sprintf("%s.(string)", valDef),
types.TBool: fmt.Sprintf("%s.(bool)", valDef),
types.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
"convtype": func(valDef string, t api.ValType) (string, error) {
td, ok := map[api.ValType]string{
api.TInt: fmt.Sprintf("int(%s.(float64))", valDef),
api.TString: fmt.Sprintf("%s.(string)", valDef),
api.TBool: fmt.Sprintf("%s.(bool)", valDef),
api.TBlob: fmt.Sprintf("%s.([]byte)", valDef),
}[t]
if !ok {
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
}
return td, nil
},
"zerovalue": func(t types.ValType) (string, error) {
v, ok := map[types.ValType]string{
types.TInt: "0",
types.TString: `""`,
types.TBool: "false",
types.TBlob: "[]byte{}",
"zerovalue": func(t api.ValType) (string, error) {
v, ok := map[api.ValType]string{
api.TInt: "0",
api.TString: `""`,
api.TBool: "false",
api.TBlob: "[]byte{}",
}[t]
if !ok {
return "", fmt.Errorf("cannot generate zero value for type %v", t)
@ -86,27 +81,14 @@ func (g *GoApiGenerator) Generate(apis *api.Api, destFile string) error {
return fmt.Errorf("execute template: %w", err)
}
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
formatted, err := format.Source(buf.Bytes())
if err != nil {
return fmt.Errorf("format source: %w", err)
}
if err := common.WriteFile(destFile, formatted); err != nil {
return fmt.Errorf("write file: %w", err)
}
return nil
}
func (g *GoApiGenerator) writeDest(destFile string, bytes []byte) error {
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("open destination file: %w", err)
}
defer f.Close()
formatted, err := format.Source(bytes)
if err != nil {
return fmt.Errorf("format source: %w", err)
}
if _, err := f.Write(formatted); err != nil {
return fmt.Errorf("write formatted source: %w", err)
}
return nil
}

View File

@ -28,7 +28,9 @@ func ({{ $e.Name | receiver }} *{{ $e.Name }}) {{ $mtd.Name }}(
if err != nil {
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }} failed: %w", err)
}
_ = results
if len(results) < {{ len $mtd.Ret }} {
return {{ range $mtd.Ret }}{{ .Type | zerovalue }}, {{ end }} fmt.Errorf("call to {{ $e.Name }}.{{ $mtd.Name }}: expected {{ len $mtd.Ret }} results, got %d", len(results))
}
{{ range $i, $ret := $mtd.Ret }}
{{ if eq $ret.Type "blob" }}
results[{{ $i }}], err = base64.StdEncoding.DecodeString(results[{{ $i }}].(string))

View File

@ -9,7 +9,6 @@ import (
"efprojects.com/kitten-ipc/kitcom/internal/api"
"efprojects.com/kitten-ipc/kitcom/internal/common"
"efprojects.com/kitten-ipc/types"
)
var decorComment = regexp.MustCompile(`^//\s?kittenipc:api$`)
@ -142,11 +141,11 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
case *ast.Ident:
switch paramType.Name {
case "int":
val.Type = types.TInt
val.Type = api.TInt
case "string":
val.Type = types.TString
val.Type = api.TString
case "bool":
val.Type = types.TBool
val.Type = api.TBool
case "error":
if returning {
return nil, nil
@ -161,7 +160,7 @@ func fieldToVal(param *ast.Field, returning bool) (*api.Val, error) {
case *ast.Ident:
switch elementType.Name {
case "byte":
val.Type = types.TBlob
val.Type = api.TBlob
default:
return nil, fmt.Errorf("parameter type %s is not supported yet", elementType.Name)
}

View File

@ -0,0 +1,43 @@
package golang
import (
"testing"
"efprojects.com/kitten-ipc/kitcom/internal/api"
"efprojects.com/kitten-ipc/kitcom/internal/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestGoParser(t *testing.T) {
parser := &GoApiParser{Parser: &common.Parser{}}
parser.AddFile("../../../example/golang/main.go")
result, err := parser.Parse()
require.NoError(t, err)
require.Len(t, result.Endpoints, 1)
ep := result.Endpoints[0]
assert.Equal(t, "GoIpcApi", ep.Name)
require.Len(t, ep.Methods, 2)
// Div method
div := ep.Methods[0]
assert.Equal(t, "Div", div.Name)
require.Len(t, div.Params, 2)
assert.Equal(t, api.TInt, div.Params[0].Type)
assert.Equal(t, "a", div.Params[0].Name)
assert.Equal(t, api.TInt, div.Params[1].Type)
assert.Equal(t, "b", div.Params[1].Name)
require.Len(t, div.Ret, 1)
assert.Equal(t, api.TInt, div.Ret[0].Type)
// XorData method
xor := ep.Methods[1]
assert.Equal(t, "XorData", xor.Name)
require.Len(t, xor.Params, 2)
assert.Equal(t, api.TBlob, xor.Params[0].Type)
assert.Equal(t, api.TBlob, xor.Params[1].Type)
require.Len(t, xor.Ret, 1)
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
}

View File

@ -4,14 +4,13 @@ import (
"bytes"
"fmt"
"log"
"os"
"os/exec"
"text/template"
_ "embed"
"efprojects.com/kitten-ipc/kitcom/internal/api"
"efprojects.com/kitten-ipc/types"
"efprojects.com/kitten-ipc/kitcom/internal/common"
)
//go:embed tsgen.tmpl
@ -31,24 +30,24 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
tpl := template.New("tsgen")
tpl = tpl.Funcs(map[string]any{
"typedef": func(t types.ValType) (string, error) {
td, ok := map[types.ValType]string{
types.TInt: "number",
types.TString: "string",
types.TBool: "boolean",
types.TBlob: "Buffer",
"typedef": func(t api.ValType) (string, error) {
td, ok := map[api.ValType]string{
api.TInt: "number",
api.TString: "string",
api.TBool: "boolean",
api.TBlob: "Buffer",
}[t]
if !ok {
return "", fmt.Errorf("cannot generate type %v", t)
}
return td, nil
},
"convtype": func(valDef string, t types.ValType) (string, error) {
td, ok := map[types.ValType]string{
types.TInt: fmt.Sprintf("%s as number", valDef),
types.TString: fmt.Sprintf("%s as string", valDef),
types.TBool: fmt.Sprintf("%s as boolean", valDef),
types.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
"convtype": func(valDef string, t api.ValType) (string, error) {
td, ok := map[api.ValType]string{
api.TInt: fmt.Sprintf("%s as number", valDef),
api.TString: fmt.Sprintf("%s as string", valDef),
api.TBool: fmt.Sprintf("%s as boolean", valDef),
api.TBlob: fmt.Sprintf("Buffer.from(%s, 'base64')", valDef),
}[t]
if !ok {
return "", fmt.Errorf("cannot convert type %v for val %s", t, valDef)
@ -64,24 +63,10 @@ func (g *TypescriptApiGenerator) Generate(apis *api.Api, destFile string) error
return fmt.Errorf("execute template: %w", err)
}
if err := g.writeDest(destFile, buf.Bytes()); err != nil {
if err := common.WriteFile(destFile, buf.Bytes()); err != nil {
return fmt.Errorf("write file: %w", err)
}
return nil
}
func (g *TypescriptApiGenerator) writeDest(destFile string, bytes []byte) error {
f, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return fmt.Errorf("open destination file: %w", err)
}
defer f.Close()
if _, err := f.Write(bytes); err != nil {
return fmt.Errorf("write formatted source: %w", err)
}
prettierCmd := exec.Command("npx", "prettier", destFile, "--write")
if out, err := prettierCmd.CombinedOutput(); err != nil {
log.Printf("Prettier returned error: %v", err)

View File

@ -12,7 +12,6 @@ import (
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/core"
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/parser"
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/tspath"
"efprojects.com/kitten-ipc/types"
)
type TypescriptApiParser struct {
@ -109,25 +108,25 @@ func (p *TypescriptApiParser) parseFile(sourceFilePath string) ([]api.Endpoint,
return endpoints, nil
}
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (types.ValType, error) {
func (p *TypescriptApiParser) fieldToVal(typ *ast.TypeNode) (api.ValType, error) {
switch typ.Kind {
case ast.KindNumberKeyword:
return types.TInt, nil
return api.TInt, nil
case ast.KindStringKeyword:
return types.TString, nil
return api.TString, nil
case ast.KindBooleanKeyword:
return types.TBool, nil
return api.TBool, nil
case ast.KindTypeReference:
refNode := typ.AsTypeReferenceNode()
ident := refNode.TypeName.AsIdentifier()
switch ident.Text {
case "Buffer":
return types.TBlob, nil
return api.TBlob, nil
default:
return types.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
return api.TNoType, fmt.Errorf("reference type %s is not supported yet", ident.Text)
}
default:
return types.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
return api.TNoType, fmt.Errorf("type %s is not supported yet", typ.Kind)
}
}

View File

@ -0,0 +1,46 @@
package ts
import (
"path/filepath"
"testing"
"efprojects.com/kitten-ipc/kitcom/internal/api"
"efprojects.com/kitten-ipc/kitcom/internal/common"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestTsParser(t *testing.T) {
parser := &TypescriptApiParser{Parser: &common.Parser{}}
absPath, err := filepath.Abs("../../../example/ts/src/index.ts")
require.NoError(t, err)
parser.AddFile(absPath)
result, err := parser.Parse()
require.NoError(t, err)
require.Len(t, result.Endpoints, 1)
ep := result.Endpoints[0]
assert.Equal(t, "TsIpcApi", ep.Name)
require.Len(t, ep.Methods, 2)
// Div method
div := ep.Methods[0]
assert.Equal(t, "Div", div.Name)
require.Len(t, div.Params, 2)
assert.Equal(t, api.TInt, div.Params[0].Type)
assert.Equal(t, "a", div.Params[0].Name)
assert.Equal(t, api.TInt, div.Params[1].Type)
assert.Equal(t, "b", div.Params[1].Name)
require.Len(t, div.Ret, 1)
assert.Equal(t, api.TInt, div.Ret[0].Type)
// XorData method
xor := ep.Methods[1]
assert.Equal(t, "XorData", xor.Name)
require.Len(t, xor.Params, 2)
assert.Equal(t, api.TBlob, xor.Params[0].Type)
assert.Equal(t, api.TBlob, xor.Params[1].Type)
require.Len(t, xor.Ret, 1)
assert.Equal(t, api.TBlob, xor.Ret[0].Type)
}

60
lib/golang/child.go Normal file
View File

@ -0,0 +1,60 @@
package golang
import (
"context"
"fmt"
"net"
"os"
)
type ChildIPC struct {
*ipcCommon
}
func NewChild(localApis ...any) (*ChildIPC, error) {
c := ChildIPC{
ipcCommon: &ipcCommon{
localApis: mapTypeNames(localApis),
pendingCalls: make(map[int64]*pendingCall),
errCh: make(chan error, 1),
ctx: context.Background(),
},
}
socketPath := socketPathFromArgs()
if socketPath == "" {
return nil, fmt.Errorf("ipc socket path is missing")
}
c.socketPath = socketPath
return &c, nil
}
func (c *ChildIPC) Start() error {
conn, err := net.Dial("unix", c.socketPath)
if err != nil {
return fmt.Errorf("connect to parent socket: %w", err)
}
c.conn = conn
go c.readConn()
return nil
}
func (c *ChildIPC) Wait() error {
err := <-c.errCh
if err != nil {
return fmt.Errorf("ipc error: %w", err)
}
return nil
}
// socketPathFromArgs parses --ipc-socket from os.Args without calling flag.Parse(),
// which would interfere with the host application's flag handling.
func socketPathFromArgs() string {
for i, arg := range os.Args {
if arg == ipcSocketArg && i+1 < len(os.Args) {
return os.Args[i+1]
}
}
return ""
}

261
lib/golang/common.go Normal file
View File

@ -0,0 +1,261 @@
package golang
import (
"bufio"
"context"
"encoding/json"
"fmt"
"net"
"reflect"
"strings"
"sync"
"sync/atomic"
)
type IpcCommon interface {
Call(method string, params ...any) (Vals, error)
ConvType(needType, gotType reflect.Type, arg any) any
}
type callResult struct {
vals Vals
err error
}
type pendingCall struct {
resultChan chan callResult
}
type ipcCommon struct {
localApis map[string]any
socketPath string
conn net.Conn
errCh chan error
nextId int64
pendingCalls map[int64]*pendingCall
processingCalls atomic.Int64
stopRequested atomic.Bool
mu sync.Mutex
writeMu sync.Mutex
ctx context.Context
}
func (ipc *ipcCommon) readConn() {
scn := bufio.NewScanner(ipc.conn)
scn.Buffer(nil, maxMessageLength)
for scn.Scan() {
var msg Message
msgBytes := scn.Bytes()
if err := json.Unmarshal(msgBytes, &msg); err != nil {
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
break
}
ipc.processMsg(msg)
}
if err := scn.Err(); err != nil {
ipc.raiseErr(err)
}
}
func (ipc *ipcCommon) processMsg(msg Message) {
switch msg.Type {
case MsgCall:
go ipc.handleCall(msg)
case MsgResponse:
ipc.handleResponse(msg)
}
}
func (ipc *ipcCommon) sendMsg(msg Message) error {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal message: %w", err)
}
data = append(data, '\n')
ipc.writeMu.Lock()
_, writeErr := ipc.conn.Write(data)
ipc.writeMu.Unlock()
if writeErr != nil {
return fmt.Errorf("write message: %w", writeErr)
}
return nil
}
func (ipc *ipcCommon) handleCall(msg Message) {
if ipc.stopRequested.Load() {
return
}
ipc.processingCalls.Add(1)
defer ipc.processingCalls.Add(-1)
defer func() {
if err := recover(); err != nil {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
}
}()
method, err := ipc.findMethod(msg.Method)
if err != nil {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
return
}
argsCount := method.Type().NumIn()
if len(msg.Args) != argsCount {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
return
}
var args []reflect.Value
for i, arg := range msg.Args {
paramType := method.Type().In(i)
argType := reflect.TypeOf(arg)
arg = ipc.ConvType(paramType, argType, arg)
args = append(args, reflect.ValueOf(arg))
}
allResultVals := method.Call(args)
retResultVals := allResultVals[0 : len(allResultVals)-1]
errResultVals := allResultVals[len(allResultVals)-1]
var results []any
for _, resVal := range retResultVals {
results = append(results, resVal.Interface())
}
var resErr error
if !errResultVals.IsNil() {
resErr = errResultVals.Interface().(error)
}
ipc.sendResponse(msg.Id, results, resErr)
}
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
parts := strings.Split(methodName, ".")
if len(parts) != 2 {
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
}
endpointName, methodName := parts[0], parts[1]
localApi, ok := ipc.localApis[endpointName]
if !ok {
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
}
method := reflect.ValueOf(localApi).MethodByName(methodName)
if !method.IsValid() {
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
}
return method, nil
}
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
msg := Message{
Type: MsgResponse,
Id: id,
Result: result,
}
if err != nil {
msg.Error = err.Error()
}
if err := ipc.sendMsg(msg); err != nil {
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
}
}
func (ipc *ipcCommon) handleResponse(msg Message) {
ipc.mu.Lock()
call, ok := ipc.pendingCalls[msg.Id]
if ok {
delete(ipc.pendingCalls, msg.Id)
}
ipc.mu.Unlock()
if !ok {
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
return
}
var res callResult
if msg.Error == "" {
res = callResult{vals: msg.Result}
} else {
res = callResult{err: fmt.Errorf("remote error: %s", msg.Error)}
}
call.resultChan <- res
close(call.resultChan)
}
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
if ipc.conn == nil {
return nil, fmt.Errorf("ipc is not connected to remote process socket")
}
if ipc.stopRequested.Load() {
return nil, fmt.Errorf("ipc is stopping")
}
ipc.mu.Lock()
id := ipc.nextId
ipc.nextId++
call := &pendingCall{
resultChan: make(chan callResult, 1),
}
ipc.pendingCalls[id] = call
ipc.mu.Unlock()
for i := range params {
params[i] = ipc.serialize(params[i])
}
msg := Message{
Type: MsgCall,
Id: id,
Method: method,
Args: params,
}
if err := ipc.sendMsg(msg); err != nil {
ipc.mu.Lock()
delete(ipc.pendingCalls, id)
ipc.mu.Unlock()
return nil, fmt.Errorf("send call: %w", err)
}
select {
case result := <-call.resultChan:
return result.vals, result.err
case <-ipc.ctx.Done():
ipc.mu.Lock()
delete(ipc.pendingCalls, id)
ipc.mu.Unlock()
return nil, ipc.ctx.Err()
}
}
func (ipc *ipcCommon) raiseErr(err error) {
select {
case ipc.errCh <- err:
default:
}
}
func (ipc *ipcCommon) closeConn() {
_ = ipc.conn.Close()
ipc.mu.Lock()
pending := ipc.pendingCalls
ipc.pendingCalls = make(map[int64]*pendingCall)
ipc.mu.Unlock()
for _, call := range pending {
call.resultChan <- callResult{err: fmt.Errorf("call cancelled due to ipc termination")}
close(call.resultChan)
}
}

43
lib/golang/common_test.go Normal file
View File

@ -0,0 +1,43 @@
package golang
import (
"testing"
"github.com/stretchr/testify/assert"
)
type testEndpoint struct{}
func (e *testEndpoint) Hello(name string) (string, error) {
return "hello " + name, nil
}
func TestFindMethod(t *testing.T) {
ipc := &ipcCommon{
localApis: mapTypeNames([]any{&testEndpoint{}}),
}
t.Run("valid method", func(t *testing.T) {
method, err := ipc.findMethod("testEndpoint.Hello")
assert.NoError(t, err)
assert.True(t, method.IsValid())
})
t.Run("invalid format - no dot", func(t *testing.T) {
_, err := ipc.findMethod("Hello")
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid method")
})
t.Run("unknown endpoint", func(t *testing.T) {
_, err := ipc.findMethod("Unknown.Hello")
assert.Error(t, err)
assert.Contains(t, err.Error(), "endpoint not found")
})
t.Run("unknown method", func(t *testing.T) {
_, err := ipc.findMethod("testEndpoint.Unknown")
assert.Error(t, err)
assert.Contains(t, err.Error(), "method not found")
})
}

View File

@ -2,4 +2,10 @@ module efprojects.com/kitten-ipc
go 1.25.1
require github.com/samber/mo v1.16.0
require github.com/stretchr/testify v1.11.1
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

View File

@ -1,2 +1,10 @@
github.com/samber/mo v1.16.0 h1:qpEPCI63ou6wXlsNDMLE0IIN8A+devbGX/K1xdgr4b4=
github.com/samber/mo v1.16.0/go.mod h1:DlgzJ4SYhOh41nP1L9kh9rDNERuf8IqWSAs+gj2Vxag=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,550 +0,0 @@
package golang
import (
"bufio"
"context"
"encoding/base64"
"encoding/json"
"errors"
"flag"
"fmt"
"math/rand"
"net"
"os"
"os/exec"
"path/filepath"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"efprojects.com/kitten-ipc/types"
"github.com/samber/mo"
)
const ipcSocketArg = "--ipc-socket"
const maxMessageLength = 1 * 1024 * 1024 * 1024 // 1 gigabyte
type StdioMode int
type MsgType int
type Vals []any
const (
MsgCall MsgType = 1
MsgResponse MsgType = 2
)
type Message struct {
Type MsgType `json:"type"`
Id int64 `json:"id"`
Method string `json:"method"`
Args Vals `json:"args"`
Result Vals `json:"result"`
Error string `json:"error"`
}
type IpcCommon interface {
Call(method string, params ...any) (Vals, error)
ConvType(needType reflect.Type, gotType reflect.Type, arg any) any
}
type pendingCall struct {
resultChan chan mo.Result[Vals]
resultType reflect.Type
}
type ipcCommon struct {
localApis map[string]any
socketPath string
conn net.Conn
errCh chan error
nextId int64
pendingCalls map[int64]*pendingCall
processingCalls atomic.Int64
stopRequested atomic.Bool
mu sync.Mutex
writeMu sync.Mutex
ctx context.Context
}
func (ipc *ipcCommon) readConn() {
scn := bufio.NewScanner(ipc.conn)
scn.Buffer(nil, maxMessageLength)
for scn.Scan() {
var msg Message
msgBytes := scn.Bytes()
//log.Println(string(msgBytes))
if err := json.Unmarshal(msgBytes, &msg); err != nil {
ipc.raiseErr(fmt.Errorf("unmarshal message: %w", err))
break
}
ipc.processMsg(msg)
}
if err := scn.Err(); err != nil {
ipc.raiseErr(err)
}
}
func (ipc *ipcCommon) processMsg(msg Message) {
switch msg.Type {
case MsgCall:
go ipc.handleCall(msg)
case MsgResponse:
ipc.handleResponse(msg)
}
}
func (ipc *ipcCommon) sendMsg(msg Message) error {
data, err := json.Marshal(msg)
if err != nil {
return fmt.Errorf("marshal message: %w", err)
}
data = append(data, '\n')
ipc.writeMu.Lock()
_, writeErr := ipc.conn.Write(data)
ipc.writeMu.Unlock()
if writeErr != nil {
return fmt.Errorf("write message: %w", writeErr)
}
return nil
}
func (ipc *ipcCommon) handleCall(msg Message) {
if ipc.stopRequested.Load() {
return
}
ipc.processingCalls.Add(1)
defer ipc.processingCalls.Add(-1)
defer func() {
if err := recover(); err != nil {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("handle call panicked: %s", err))
}
}()
method, err := ipc.findMethod(msg.Method)
if err != nil {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("find method: %w", err))
return
}
argsCount := method.Type().NumIn()
if len(msg.Args) != argsCount {
ipc.sendResponse(msg.Id, nil, fmt.Errorf("args count mismatch: expected %d, got %d", argsCount, len(msg.Args)))
return
}
var args []reflect.Value
for i, arg := range msg.Args {
paramType := method.Type().In(i)
argType := reflect.TypeOf(arg)
arg = ipc.ConvType(paramType, argType, arg)
args = append(args, reflect.ValueOf(arg))
}
allResultVals := method.Call(args)
retResultVals := allResultVals[0 : len(allResultVals)-1]
errResultVals := allResultVals[len(allResultVals)-1]
var results []any
for _, resVal := range retResultVals {
results = append(results, resVal.Interface())
}
var resErr error
if !errResultVals.IsNil() {
resErr = errResultVals.Interface().(error)
}
ipc.sendResponse(msg.Id, results, resErr)
}
func (ipc *ipcCommon) findMethod(methodName string) (reflect.Value, error) {
parts := strings.Split(methodName, ".")
if len(parts) != 2 {
return reflect.Value{}, fmt.Errorf("invalid method: %s", methodName)
}
endpointName, methodName := parts[0], parts[1]
localApi, ok := ipc.localApis[endpointName]
if !ok {
return reflect.Value{}, fmt.Errorf("endpoint not found: %s", endpointName)
}
method := reflect.ValueOf(localApi).MethodByName(methodName)
if !method.IsValid() {
return reflect.Value{}, fmt.Errorf("method not found: %s", methodName)
}
return method, nil
}
func (ipc *ipcCommon) sendResponse(id int64, result []any, err error) {
msg := Message{
Type: MsgResponse,
Id: id,
Result: result,
}
if err != nil {
msg.Error = err.Error()
}
if err := ipc.sendMsg(msg); err != nil {
ipc.raiseErr(fmt.Errorf("send response for id=%d: %w", id, err))
}
}
func (ipc *ipcCommon) handleResponse(msg Message) {
ipc.mu.Lock()
call, ok := ipc.pendingCalls[msg.Id]
if ok {
delete(ipc.pendingCalls, msg.Id)
}
ipc.mu.Unlock()
if !ok {
ipc.raiseErr(fmt.Errorf("received response for unknown call id: %d", msg.Id))
return
}
var res mo.Result[Vals]
if msg.Error == "" {
res = mo.Ok[Vals](msg.Result)
} else {
res = mo.Err[Vals](fmt.Errorf("remote error: %s", msg.Error))
}
call.resultChan <- res
close(call.resultChan)
}
func (ipc *ipcCommon) Call(method string, params ...any) (Vals, error) {
if ipc.conn == nil {
return nil, fmt.Errorf("ipc is not connected to remote process socket")
}
if ipc.stopRequested.Load() {
return nil, fmt.Errorf("ipc is stopping")
}
ipc.mu.Lock()
id := ipc.nextId
ipc.nextId++
call := &pendingCall{
resultChan: make(chan mo.Result[Vals], 1),
}
ipc.pendingCalls[id] = call
ipc.mu.Unlock()
for i := range params {
params[i] = ipc.serialize(params[i])
}
msg := Message{
Type: MsgCall,
Id: id,
Method: method,
Args: params,
}
if err := ipc.sendMsg(msg); err != nil {
ipc.mu.Lock()
delete(ipc.pendingCalls, id)
ipc.mu.Unlock()
return nil, fmt.Errorf("send call: %w", err)
}
select {
case result := <-call.resultChan:
return result.Get()
case <-ipc.ctx.Done():
ipc.mu.Lock()
delete(ipc.pendingCalls, id)
ipc.mu.Unlock()
return nil, ipc.ctx.Err()
}
}
func (ipc *ipcCommon) raiseErr(err error) {
select {
case ipc.errCh <- err:
default:
}
}
func (ipc *ipcCommon) closeConn() {
_ = ipc.conn.Close()
ipc.mu.Lock()
pending := ipc.pendingCalls
ipc.pendingCalls = make(map[int64]*pendingCall)
ipc.mu.Unlock()
for _, call := range pending {
call.resultChan <- mo.Err[Vals](fmt.Errorf("call cancelled due to ipc termination"))
close(call.resultChan)
}
}
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
switch needType.Kind() {
case reflect.Int:
// JSON decodes any number to float64. If we need int, we should check and convert
if gotType.Kind() == reflect.Float64 {
floatArg := arg.(float64)
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
arg = int(floatArg)
}
}
}
return arg
}
func (ipc *ipcCommon) serialize(arg any) any {
t := reflect.TypeOf(arg)
switch t.Kind() {
case reflect.Slice:
switch t.Elem().Name() {
case "uint8":
return map[string]any{
"t": types.TBlob,
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
}
}
}
return arg
}
type ParentIPC struct {
*ipcCommon
cmd *exec.Cmd
listener net.Listener
cmdDone chan struct{}
cmdErr error
}
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
return NewParentWithContext(context.Background(), cmd, localApis...)
}
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
p := ParentIPC{
ipcCommon: &ipcCommon{
localApis: mapTypeNames(localApis),
pendingCalls: make(map[int64]*pendingCall),
errCh: make(chan error, 1),
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
ctx: ctx,
},
cmd: cmd,
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if slices.Contains(cmd.Args, ipcSocketArg) {
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
}
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
p.errCh = make(chan error, 1)
p.cmdDone = make(chan struct{})
return &p, nil
}
func (p *ParentIPC) Start() error {
_ = os.Remove(p.socketPath)
listener, err := net.Listen("unix", p.socketPath)
if err != nil {
return fmt.Errorf("listen unix socket: %w", err)
}
p.listener = listener
defer p.listener.Close()
err = p.cmd.Start()
if err != nil {
return fmt.Errorf("cmd start: %w", err)
}
go func() {
p.cmdErr = p.cmd.Wait()
close(p.cmdDone)
}()
return p.acceptConn()
}
func (p *ParentIPC) acceptConn() error {
const acceptTimeout = time.Second * 10
res := make(chan mo.Result[net.Conn], 1)
go func() {
conn, err := p.listener.Accept()
if err != nil {
res <- mo.Err[net.Conn](err)
} else {
res <- mo.Ok[net.Conn](conn)
}
close(res)
}()
select {
case <-time.After(acceptTimeout):
_ = p.cmd.Process.Kill()
return fmt.Errorf("accept timeout")
case <-p.cmdDone:
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
case res := <-res:
if res.IsError() {
_ = p.cmd.Process.Kill()
return fmt.Errorf("accept: %w", res.Error())
}
p.conn = res.MustGet()
go p.readConn()
}
return nil
}
func (p *ParentIPC) Stop() error {
p.mu.Lock()
hasPending := len(p.pendingCalls) > 0
p.mu.Unlock()
if hasPending {
return fmt.Errorf("there are calls pending")
}
if p.processingCalls.Load() > 0 {
return fmt.Errorf("there are calls processing")
}
p.stopRequested.Store(true)
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
return fmt.Errorf("send SIGINT: %w", err)
}
return p.Wait()
}
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
const defaultTimeout = time.Duration(1<<63 - 1) // max duration in go
_timeout := variadicToOption(timeout).OrElse(defaultTimeout)
loop:
for {
select {
case err := <-p.errCh:
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
break loop
case <-p.cmdDone:
err := p.cmdErr
if err != nil {
var exitErr *exec.ExitError
if ok := errors.As(err, &exitErr); ok {
if !exitErr.Success() {
ws, ok := exitErr.Sys().(syscall.WaitStatus)
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
}
}
} else {
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
}
}
break loop
case <-time.After(_timeout):
p.stopRequested.Store(true)
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
}
}
}
p.closeConn()
_ = os.Remove(p.socketPath)
return retErr
}
type ChildIPC struct {
*ipcCommon
}
func NewChild(localApis ...any) (*ChildIPC, error) {
c := ChildIPC{
ipcCommon: &ipcCommon{
localApis: mapTypeNames(localApis),
pendingCalls: make(map[int64]*pendingCall),
errCh: make(chan error, 1),
ctx: context.Background(), // todo: NewChildWithContext
},
}
socketPath := flag.String("ipc-socket", "", "Path to IPC socket")
flag.Parse()
if *socketPath == "" {
return nil, fmt.Errorf("ipc socket path is missing")
}
c.socketPath = *socketPath
return &c, nil
}
func (c *ChildIPC) Start() error {
conn, err := net.Dial("unix", c.socketPath)
if err != nil {
return fmt.Errorf("connect to parent socket: %w", err)
}
c.conn = conn
go c.readConn()
return nil
}
func (c *ChildIPC) Wait() error {
err := <-c.errCh
if err != nil {
return fmt.Errorf("ipc error: %w", err)
}
return nil
}
func mergeErr(errs ...error) (ret error) {
for _, err := range errs {
if err != nil {
if ret == nil {
ret = err
} else {
ret = fmt.Errorf("%w; %w", ret, err)
}
}
}
return
}
func mapTypeNames(types []any) map[string]any {
result := make(map[string]any)
for _, t := range types {
if reflect.TypeOf(t).Kind() != reflect.Pointer {
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
}
typeName := reflect.TypeOf(t).Elem().Name()
result[typeName] = t
}
return result
}
func variadicToOption[T any](variadic []T) mo.Option[T] {
if len(variadic) >= 2 {
panic("variadic param count must be 0 or 1")
}
if len(variadic) == 0 {
return mo.None[T]()
}
return mo.Some(variadic[0])
}

165
lib/golang/parent.go Normal file
View File

@ -0,0 +1,165 @@
package golang
import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"os"
"os/exec"
"path/filepath"
"slices"
"syscall"
"time"
)
type ParentIPC struct {
*ipcCommon
cmd *exec.Cmd
listener net.Listener
cmdDone chan struct{}
cmdErr error
}
func NewParent(cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
return NewParentWithContext(context.Background(), cmd, localApis...)
}
func NewParentWithContext(ctx context.Context, cmd *exec.Cmd, localApis ...any) (*ParentIPC, error) {
p := ParentIPC{
ipcCommon: &ipcCommon{
localApis: mapTypeNames(localApis),
pendingCalls: make(map[int64]*pendingCall),
errCh: make(chan error, 1),
socketPath: filepath.Join(os.TempDir(), fmt.Sprintf("kitten-ipc-%d-%d.sock", os.Getpid(), rand.Int63())),
ctx: ctx,
},
cmd: cmd,
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if slices.Contains(cmd.Args, ipcSocketArg) {
return nil, fmt.Errorf("you should not use `%s` argument in your command", ipcSocketArg)
}
cmd.Args = append(cmd.Args, ipcSocketArg, p.socketPath)
p.errCh = make(chan error, 1)
p.cmdDone = make(chan struct{})
return &p, nil
}
func (p *ParentIPC) Start() error {
_ = os.Remove(p.socketPath)
listener, err := net.Listen("unix", p.socketPath)
if err != nil {
return fmt.Errorf("listen unix socket: %w", err)
}
p.listener = listener
defer p.listener.Close()
err = p.cmd.Start()
if err != nil {
return fmt.Errorf("cmd start: %w", err)
}
go func() {
p.cmdErr = p.cmd.Wait()
close(p.cmdDone)
}()
return p.acceptConn()
}
type connResult struct {
conn net.Conn
err error
}
func (p *ParentIPC) acceptConn() error {
res := make(chan connResult, 1)
go func() {
conn, err := p.listener.Accept()
res <- connResult{conn: conn, err: err}
close(res)
}()
select {
case <-time.After(time.Duration(defaultAcceptTimeout) * time.Second):
_ = p.cmd.Process.Kill()
return fmt.Errorf("accept timeout")
case <-p.cmdDone:
return fmt.Errorf("cmd exited before accepting connection: %w", p.cmdErr)
case r := <-res:
if r.err != nil {
_ = p.cmd.Process.Kill()
return fmt.Errorf("accept: %w", r.err)
}
p.conn = r.conn
go p.readConn()
}
return nil
}
func (p *ParentIPC) Stop() error {
p.mu.Lock()
hasPending := len(p.pendingCalls) > 0
p.mu.Unlock()
if hasPending {
return fmt.Errorf("there are calls pending")
}
if p.processingCalls.Load() > 0 {
return fmt.Errorf("there are calls processing")
}
p.stopRequested.Store(true)
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
return fmt.Errorf("send SIGINT: %w", err)
}
return p.Wait()
}
func (p *ParentIPC) Wait(timeout ...time.Duration) (retErr error) {
const maxDuration = time.Duration(1<<63 - 1)
_timeout := maxDuration
if len(timeout) > 0 {
_timeout = timeout[0]
}
loop:
for {
select {
case err := <-p.errCh:
retErr = mergeErr(retErr, fmt.Errorf("ipc internal error: %w", err))
break loop
case <-p.cmdDone:
err := p.cmdErr
if err != nil {
var exitErr *exec.ExitError
if ok := errors.As(err, &exitErr); ok {
if !exitErr.Success() {
ws, ok := exitErr.Sys().(syscall.WaitStatus)
if !(ok && ws.Signaled() && ws.Signal() == syscall.SIGINT && p.stopRequested.Load()) {
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
}
}
} else {
retErr = mergeErr(retErr, fmt.Errorf("cmd wait: %w", err))
}
}
break loop
case <-time.After(_timeout):
p.stopRequested.Store(true)
if err := p.cmd.Process.Signal(syscall.SIGINT); err != nil {
retErr = mergeErr(retErr, fmt.Errorf("send SIGINT: %w", err))
}
}
}
p.closeConn()
_ = os.Remove(p.socketPath)
return retErr
}

23
lib/golang/protocol.go Normal file
View File

@ -0,0 +1,23 @@
package golang
const ipcSocketArg = "--ipc-socket"
const maxMessageLength = 1 << 30 // 1 GB
const defaultAcceptTimeout = 10 // seconds
type MsgType int
type Vals []any
const (
MsgCall MsgType = 1
MsgResponse MsgType = 2
)
type Message struct {
Type MsgType `json:"type"`
Id int64 `json:"id"`
Method string `json:"method"`
Args Vals `json:"args"`
Result Vals `json:"result"`
Error string `json:"error"`
}

35
lib/golang/serialize.go Normal file
View File

@ -0,0 +1,35 @@
package golang
import (
"encoding/base64"
"reflect"
)
func (ipc *ipcCommon) serialize(arg any) any {
t := reflect.TypeOf(arg)
switch t.Kind() {
case reflect.Slice:
switch t.Elem().Name() {
case "uint8":
return map[string]any{
"t": "blob",
"d": base64.StdEncoding.EncodeToString(arg.([]byte)),
}
}
}
return arg
}
func (ipc *ipcCommon) ConvType(needType reflect.Type, gotType reflect.Type, arg any) any {
switch needType.Kind() {
case reflect.Int:
// JSON decodes any number to float64. If we need int, we should check and convert
if gotType.Kind() == reflect.Float64 {
floatArg := arg.(float64)
if float64(int64(floatArg)) == floatArg && !needType.OverflowInt(int64(floatArg)) {
arg = int(floatArg)
}
}
}
return arg
}

View File

@ -0,0 +1,61 @@
package golang
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestSerialize(t *testing.T) {
ipc := &ipcCommon{}
t.Run("primitives pass through", func(t *testing.T) {
assert.Equal(t, 42, ipc.serialize(42))
assert.Equal(t, "hello", ipc.serialize("hello"))
assert.Equal(t, true, ipc.serialize(true))
assert.Equal(t, 3.14, ipc.serialize(3.14))
})
t.Run("byte slice serializes to blob", func(t *testing.T) {
data := []byte{0x01, 0x02, 0x03}
result := ipc.serialize(data)
m, ok := result.(map[string]any)
assert.True(t, ok)
assert.Equal(t, "blob", m["t"])
assert.Equal(t, "AQID", m["d"]) // base64 of {1,2,3}
})
t.Run("empty byte slice serializes to blob", func(t *testing.T) {
data := []byte{}
result := ipc.serialize(data)
m, ok := result.(map[string]any)
assert.True(t, ok)
assert.Equal(t, "blob", m["t"])
assert.Equal(t, "", m["d"])
})
}
func TestConvType(t *testing.T) {
ipc := &ipcCommon{}
t.Run("float64 to int", func(t *testing.T) {
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42))
assert.Equal(t, 42, result)
})
t.Run("float64 with fractional part stays float", func(t *testing.T) {
result := ipc.ConvType(reflect.TypeOf(0), reflect.TypeOf(0.0), float64(42.5))
assert.Equal(t, float64(42.5), result)
})
t.Run("string passes through", func(t *testing.T) {
result := ipc.ConvType(reflect.TypeOf(""), reflect.TypeOf(""), "hello")
assert.Equal(t, "hello", result)
})
t.Run("bool passes through", func(t *testing.T) {
result := ipc.ConvType(reflect.TypeOf(true), reflect.TypeOf(true), true)
assert.Equal(t, true, result)
})
}

View File

@ -1,12 +0,0 @@
package types
type ValType string
const (
TNoType ValType = "" // zero value constant for ValType (please don't use just "" as zero value!)
TInt ValType = "int"
TString ValType = "string"
TBool ValType = "bool"
TBlob ValType = "blob"
TArray ValType = "array"
)

31
lib/golang/util.go Normal file
View File

@ -0,0 +1,31 @@
package golang
import (
"fmt"
"reflect"
)
func mergeErr(errs ...error) (ret error) {
for _, err := range errs {
if err != nil {
if ret == nil {
ret = err
} else {
ret = fmt.Errorf("%w; %w", ret, err)
}
}
}
return
}
func mapTypeNames(types []any) map[string]any {
result := make(map[string]any)
for _, t := range types {
if reflect.TypeOf(t).Kind() != reflect.Pointer {
panic(fmt.Sprintf("LocalAPI argument must be pointer"))
}
typeName := reflect.TypeOf(t).Elem().Name()
result[typeName] = t
}
return result
}

46
lib/golang/util_test.go Normal file
View File

@ -0,0 +1,46 @@
package golang
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestMergeErr(t *testing.T) {
t.Run("all nil returns nil", func(t *testing.T) {
assert.NoError(t, mergeErr(nil, nil, nil))
})
t.Run("single error returns it", func(t *testing.T) {
err := fmt.Errorf("one")
assert.EqualError(t, mergeErr(nil, err, nil), "one")
})
t.Run("multiple errors merged", func(t *testing.T) {
err1 := fmt.Errorf("one")
err2 := fmt.Errorf("two")
result := mergeErr(err1, err2)
assert.ErrorContains(t, result, "one")
assert.ErrorContains(t, result, "two")
})
}
func TestMapTypeNames(t *testing.T) {
type Foo struct{}
type Bar struct{}
t.Run("maps pointer types by name", func(t *testing.T) {
foo := &Foo{}
bar := &Bar{}
result := mapTypeNames([]any{foo, bar})
assert.Equal(t, foo, result["Foo"])
assert.Equal(t, bar, result["Bar"])
})
t.Run("panics on non-pointer", func(t *testing.T) {
assert.Panics(t, func() {
mapTypeNames([]any{Foo{}})
})
})
}

View File

@ -0,0 +1,27 @@
import {test} from 'vitest';
import {AsyncQueue} from './asyncqueue.js';
test('put then collect returns items', async ({expect}) => {
const q = new AsyncQueue<number>();
q.put(1);
q.put(2);
const items = await q.collect();
expect(items).toEqual([1, 2]);
});
test('collect then put resolves on first put', async ({expect}) => {
const q = new AsyncQueue<number>();
const collectPromise = q.collect();
q.put(42);
const items = await collectPromise;
expect(items).toEqual([42]);
});
test('multiple puts before collect', async ({expect}) => {
const q = new AsyncQueue<string>();
q.put('a');
q.put('b');
q.put('c');
const items = await q.collect();
expect(items).toEqual(['a', 'b', 'c']);
});

44
lib/ts/src/child.ts Normal file
View File

@ -0,0 +1,44 @@
import * as net from 'node:net';
import {IPCCommon} from './common.js';
import {socketPathFromArgs} from './util.js';
export class ChildIPC extends IPCCommon {
constructor(...localApis: object[]) {
super(localApis, socketPathFromArgs());
}
async start(): Promise<void> {
return new Promise((resolve, reject) => {
this.conn = net.createConnection(this.socketPath, () => {
this.readConn();
resolve();
});
this.conn.on('error', reject);
});
}
async wait(): Promise<void> {
const closePromise = new Promise<void>((resolve) => {
this.onClose = () => {
if (this.processingCalls === 0) {
this.conn?.destroy();
resolve();
}
};
if (this.stopRequested && this.processingCalls === 0) {
this.conn?.destroy();
resolve();
}
});
const errorPromise = this.errorQueue.collect().then((errors) => {
if (errors.length === 1) {
throw errors[0];
} else if (errors.length > 1) {
throw new Error(errors.map(e => e.toString()).join(', '));
}
});
await Promise.race([closePromise, errorPromise]);
}
}

View File

@ -1,46 +1,10 @@
import * as net from 'node:net';
import * as readline from 'node:readline';
import {type ChildProcess, spawn} from 'node:child_process';
import * as os from 'node:os';
import * as path from 'node:path';
import * as fs from 'node:fs';
import * as util from 'node:util';
import * as crypto from 'node:crypto';
import {AsyncQueue} from './asyncqueue.js';
import type {CallMessage, CallResult, Message, ResponseMessage, Vals} from './protocol.js';
import {MsgType} from './protocol.js';
const IPC_SOCKET_ARG = 'ipc-socket';
type JSONSerializable = string | number | boolean;
enum MsgType {
Call = 1,
Response = 2,
}
type Vals = any[];
interface CallMessage {
type: MsgType.Call,
id: number,
method: string;
args: Vals;
}
interface ResponseMessage {
type: MsgType.Response,
id: number,
result?: Vals;
error?: string;
}
type Message = CallMessage | ResponseMessage;
interface CallResult {
result: Vals;
error: Error | null;
}
abstract class IPCCommon {
export abstract class IPCCommon {
protected localApis: Record<string, any>;
protected socketPath: string;
protected conn: net.Socket | null = null;
@ -75,6 +39,7 @@ abstract class IPCCommon {
});
this.conn.on('close', (hadError: boolean) => {
this.rejectPendingCalls(new Error('connection closed'));
if (hadError) {
this.raiseErr(new Error('connection closed due to error'));
}
@ -194,7 +159,6 @@ abstract class IPCCommon {
}
public serialize(arg: any): any {
// noinspection FallThroughInSwitchStatementJS
switch (typeof arg) {
case 'string':
case 'boolean':
@ -212,7 +176,6 @@ abstract class IPCCommon {
}
public deserialize(arg: any): any {
// noinspection FallThroughInSwitchStatementJS
switch (typeof arg) {
case 'string':
case 'boolean':
@ -248,201 +211,15 @@ abstract class IPCCommon {
if (this.onClose) this.onClose();
}
protected rejectPendingCalls(err: Error): void {
const pending = this.pendingCalls;
this.pendingCalls = {};
for (const callback of Object.values(pending)) {
callback({result: [], error: err});
}
}
protected raiseErr(err: Error): void {
this.errorQueue.put(err);
}
}
export class ParentIPC extends IPCCommon {
private readonly cmdPath: string;
private readonly cmdArgs: string[];
private cmd: ChildProcess | null = null;
private readonly listener: net.Server;
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
super(localApis, socketPath);
this.cmdPath = cmdPath;
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
}
this.cmdArgs = cmdArgs;
this.listener = net.createServer();
}
async start(): Promise<void> {
try {
fs.unlinkSync(this.socketPath);
} catch {
}
await new Promise<void>((resolve, reject) => {
this.listener.listen(this.socketPath, () => {
resolve();
});
this.listener.on('error', reject);
});
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
this.cmd.on('error', (err) => {
this.raiseErr(err);
});
this.cmd.on('close', (code, signal) => {
const result = { code, signal };
this.cmdExitResult = result;
for (const cb of this.cmdExitCallbacks) cb(result);
this.cmdExitCallbacks = [];
});
await this.acceptConn();
}
private async acceptConn(): Promise<void> {
const acceptTimeout = 10000;
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
this.listener.once('connection', (conn) => {
resolve(conn);
});
this.listener.once('error', reject);
});
const exitPromise = new Promise<net.Socket>((_, reject) => {
if (this.cmdExitResult) {
reject(new Error(`command exited before connection established`));
} else {
this.cmdExitCallbacks.push(() => {
reject(new Error(`command exited before connection established`));
});
}
});
try {
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), acceptTimeout);
this.readConn();
} catch (e) {
if (this.cmd) this.cmd.kill();
throw e;
}
}
async wait(): Promise<void> {
if (!this.cmd) {
throw new Error('Command is not started yet');
}
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
if (this.cmdExitResult) {
resolve(this.cmdExitResult);
} else {
this.cmdExitCallbacks.push(resolve);
}
});
try {
await Promise.race([
exitPromise.then(({ code, signal }) => {
if (signal || code) {
if (signal) throw new Error(`Process exited with signal ${ signal }`);
else throw new Error(`Process exited with code ${ code }`);
} else if (!this.ready) {
throw new Error('command exited before connection established');
}
}),
this.errorQueue.collect().then((errors) => {
if (errors.length === 1) {
throw errors[0];
} else if (errors.length > 1) {
throw new Error(errors.map(e => e.toString()).join(', '));
}
}),
]);
} finally {
try { fs.unlinkSync(this.socketPath); } catch {}
}
}
}
export class ChildIPC extends IPCCommon {
constructor(...localApis: object[]) {
super(localApis, socketPathFromArgs());
}
async start(): Promise<void> {
return new Promise((resolve, reject) => {
this.conn = net.createConnection(this.socketPath, () => {
this.readConn();
resolve();
});
this.conn.on('error', reject);
});
}
async wait(): Promise<void> {
const closePromise = new Promise<void>((resolve) => {
this.onClose = () => {
if (this.processingCalls === 0) {
this.conn?.destroy();
resolve();
}
};
if (this.stopRequested && this.processingCalls === 0) {
this.conn?.destroy();
resolve();
}
});
const errorPromise = this.errorQueue.collect().then((errors) => {
if (errors.length === 1) {
throw errors[0];
} else if (errors.length > 1) {
throw new Error(errors.map(e => e.toString()).join(', '));
}
});
await Promise.race([closePromise, errorPromise]);
}
}
function socketPathFromArgs(): string {
const {values} = util.parseArgs({
options: {
[IPC_SOCKET_ARG]: {
type: 'string',
}
}
});
if (!values[IPC_SOCKET_ARG]) {
throw new Error('ipc socket path is missing');
}
return values[IPC_SOCKET_ARG];
}
function sleep(ms: number): Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms));
}
// throws on timeout
function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
return Promise.race(
[
prom,
new Promise((res, reject) => {
setTimeout(() => {reject(new Error('timed out'))}, ms)
})]
) as Promise<T>;
}

View File

@ -1 +1,2 @@
export {ParentIPC, ChildIPC} from './lib.js';
export {ParentIPC} from './parent.js';
export {ChildIPC} from './child.js';

View File

@ -1,5 +1,5 @@
import {test} from 'vitest';
import {ParentIPC} from './lib.js';
import {ParentIPC} from './parent.js';
test('test connection timeout', async ({expect}) => {
const parentIpc = new ParentIPC('../testdata/sleep15.sh', []);

126
lib/ts/src/parent.ts Normal file
View File

@ -0,0 +1,126 @@
import * as net from 'node:net';
import * as os from 'node:os';
import * as path from 'node:path';
import * as fs from 'node:fs';
import * as crypto from 'node:crypto';
import {type ChildProcess, spawn} from 'node:child_process';
import {IPCCommon} from './common.js';
import {timeout} from './util.js';
const IPC_SOCKET_ARG = 'ipc-socket';
const ACCEPT_TIMEOUT_MS = 10000;
export class ParentIPC extends IPCCommon {
private readonly cmdPath: string;
private readonly cmdArgs: string[];
private cmd: ChildProcess | null = null;
private readonly listener: net.Server;
private cmdExitResult: { code: number | null, signal: string | null } | null = null;
private cmdExitCallbacks: ((result: { code: number | null, signal: string | null }) => void)[] = [];
constructor(cmdPath: string, cmdArgs: string[], ...localApis: object[]) {
const socketPath = path.join(os.tmpdir(), `kitten-ipc-${ process.pid }-${ crypto.randomInt(2**48 - 1) }.sock`);
super(localApis, socketPath);
this.cmdPath = cmdPath;
if (cmdArgs.includes(`--${ IPC_SOCKET_ARG }`)) {
throw new Error(`you should not use '--${ IPC_SOCKET_ARG }' argument in your command`);
}
this.cmdArgs = cmdArgs;
this.listener = net.createServer();
}
async start(): Promise<void> {
try {
fs.unlinkSync(this.socketPath);
} catch {
}
await new Promise<void>((resolve, reject) => {
this.listener.listen(this.socketPath, () => {
resolve();
});
this.listener.on('error', reject);
});
const cmdArgs = [...this.cmdArgs, `--${ IPC_SOCKET_ARG }`, this.socketPath];
this.cmd = spawn(this.cmdPath, cmdArgs, {stdio: 'inherit'});
this.cmd.on('error', (err) => {
this.raiseErr(err);
});
this.cmd.on('close', (code, signal) => {
const result = { code, signal };
this.cmdExitResult = result;
for (const cb of this.cmdExitCallbacks) cb(result);
this.cmdExitCallbacks = [];
});
await this.acceptConn();
}
private async acceptConn(): Promise<void> {
const acceptPromise = new Promise<net.Socket>((resolve, reject) => {
this.listener.once('connection', (conn) => {
resolve(conn);
});
this.listener.once('error', reject);
});
const exitPromise = new Promise<net.Socket>((_, reject) => {
if (this.cmdExitResult) {
reject(new Error(`command exited before connection established`));
} else {
this.cmdExitCallbacks.push(() => {
reject(new Error(`command exited before connection established`));
});
}
});
try {
this.conn = await timeout(Promise.race([acceptPromise, exitPromise]), ACCEPT_TIMEOUT_MS);
this.readConn();
} catch (e) {
if (this.cmd) this.cmd.kill();
throw e;
}
}
async wait(): Promise<void> {
if (!this.cmd) {
throw new Error('Command is not started yet');
}
const exitPromise = new Promise<{ code: number | null, signal: string | null }>((resolve) => {
if (this.cmdExitResult) {
resolve(this.cmdExitResult);
} else {
this.cmdExitCallbacks.push(resolve);
}
});
try {
await Promise.race([
exitPromise.then(({ code, signal }) => {
if (signal || code) {
if (signal) throw new Error(`Process exited with signal ${ signal }`);
else throw new Error(`Process exited with code ${ code }`);
} else if (!this.ready) {
throw new Error('command exited before connection established');
}
}),
this.errorQueue.collect().then((errors) => {
if (errors.length === 1) {
throw errors[0];
} else if (errors.length > 1) {
throw new Error(errors.map(e => e.toString()).join(', '));
}
}),
]);
} finally {
try { fs.unlinkSync(this.socketPath); } catch {}
}
}
}

27
lib/ts/src/protocol.ts Normal file
View File

@ -0,0 +1,27 @@
export enum MsgType {
Call = 1,
Response = 2,
}
export type Vals = any[];
export interface CallMessage {
type: MsgType.Call,
id: number,
method: string;
args: Vals;
}
export interface ResponseMessage {
type: MsgType.Response,
id: number,
result?: Vals;
error?: string;
}
export type Message = CallMessage | ResponseMessage;
export interface CallResult {
result: Vals;
error: Error | null;
}

View File

@ -0,0 +1,68 @@
import {test} from 'vitest';
import {ChildIPC} from './child.js';
// Access serialize/deserialize through a ChildIPC instance's public methods
// We create a minimal wrapper to test them
class TestableIPC {
private ipc: any;
constructor() {
// Access the prototype methods directly
this.ipc = Object.create(ChildIPC.prototype);
}
serialize(arg: any): any {
return this.ipc.serialize(arg);
}
deserialize(arg: any): any {
return this.ipc.deserialize(arg);
}
}
test('serialize primitives', ({expect}) => {
const t = new TestableIPC();
expect(t.serialize(42)).toBe(42);
expect(t.serialize('hello')).toBe('hello');
expect(t.serialize(true)).toBe(true);
expect(t.serialize(3.14)).toBe(3.14);
});
test('serialize buffer to base64', ({expect}) => {
const t = new TestableIPC();
const buf = Buffer.from([1, 2, 3]);
expect(t.serialize(buf)).toBe('AQID');
});
test('deserialize primitives', ({expect}) => {
const t = new TestableIPC();
expect(t.deserialize(42)).toBe(42);
expect(t.deserialize('hello')).toBe('hello');
expect(t.deserialize(true)).toBe(true);
});
test('deserialize blob to buffer', ({expect}) => {
const t = new TestableIPC();
const result = t.deserialize({t: 'blob', d: 'AQID'});
expect(Buffer.isBuffer(result)).toBe(true);
expect(result).toEqual(Buffer.from([1, 2, 3]));
});
test('serialize then deserialize blob round-trip', ({expect}) => {
const t = new TestableIPC();
const original = Buffer.from([0xDE, 0xAD, 0xBE, 0xEF]);
// Go serializes as {t: 'blob', d: base64}, TS serializes as just base64 string
// The TS serialize returns base64 string directly for Buffer
const serialized = t.serialize(original);
expect(typeof serialized).toBe('string');
});
test('deserialize unknown object throws', ({expect}) => {
const t = new TestableIPC();
expect(() => t.deserialize({foo: 'bar'})).toThrow('cannot deserialize');
});
test('serialize unsupported object throws', ({expect}) => {
const t = new TestableIPC();
expect(() => t.serialize({foo: 'bar'})).toThrow('cannot serialize');
});

29
lib/ts/src/util.ts Normal file
View File

@ -0,0 +1,29 @@
import * as util from 'node:util';
const IPC_SOCKET_ARG = 'ipc-socket';
export function socketPathFromArgs(): string {
const {values} = util.parseArgs({
options: {
[IPC_SOCKET_ARG]: {
type: 'string',
}
}
});
if (!values[IPC_SOCKET_ARG]) {
throw new Error('ipc socket path is missing');
}
return values[IPC_SOCKET_ARG];
}
export function timeout<T>(prom: Promise<T>, ms: number): Promise<T> {
return Promise.race(
[
prom,
new Promise((res, reject) => {
setTimeout(() => {reject(new Error('timed out'))}, ms)
})]
) as Promise<T>;
}