package api import ( "bufio" "context" "encoding/binary" "fmt" "io" "runtime/debug" "strconv" "sync" "time" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/bundled" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/core" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/lsp/lsproto" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/project" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/project/logging" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/vfs" "efprojects.com/kitten-ipc/kitcom/internal/tsgo/vfs/osvfs" "github.com/go-json-experiment/json" ) //go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageType -output=stringer_generated.go //go:generate go tool mvdan.cc/gofumpt -w stringer_generated.go type MessageType uint8 const ( MessageTypeUnknown MessageType = iota MessageTypeRequest MessageTypeCallResponse MessageTypeCallError MessageTypeResponse MessageTypeError MessageTypeCall ) func (m MessageType) IsValid() bool { return m >= MessageTypeRequest && m <= MessageTypeCall } type MessagePackType uint8 const ( MessagePackTypeFixedArray3 MessagePackType = 0x93 MessagePackTypeBin8 MessagePackType = 0xC4 MessagePackTypeBin16 MessagePackType = 0xC5 MessagePackTypeBin32 MessagePackType = 0xC6 MessagePackTypeU8 MessagePackType = 0xCC ) type Callback int const ( CallbackDirectoryExists Callback = 1 << iota CallbackFileExists CallbackGetAccessibleEntries CallbackReadFile CallbackRealpath ) type ServerOptions struct { In io.Reader Out io.Writer Err io.Writer Cwd string DefaultLibraryPath string } var _ vfs.FS = (*Server)(nil) type Server struct { r *bufio.Reader w *bufio.Writer stderr io.Writer cwd string newLine string fs vfs.FS defaultLibraryPath string callbackMu sync.Mutex enabledCallbacks Callback logger logging.Logger api *API requestId int } func NewServer(options *ServerOptions) *Server { if options.Cwd == "" { panic("Cwd is required") } server := &Server{ r: bufio.NewReader(options.In), w: bufio.NewWriter(options.Out), stderr: options.Err, cwd: options.Cwd, fs: bundled.WrapFS(osvfs.FS()), defaultLibraryPath: options.DefaultLibraryPath, } logger := logging.NewLogger(options.Err) server.logger = logger server.api = NewAPI(&APIInit{ Logger: logger, FS: server, SessionOptions: &project.SessionOptions{ CurrentDirectory: options.Cwd, DefaultLibraryPath: options.DefaultLibraryPath, PositionEncoding: lsproto.PositionEncodingKindUTF8, LoggingEnabled: true, }, }) return server } // DefaultLibraryPath implements APIHost. func (s *Server) DefaultLibraryPath() string { return s.defaultLibraryPath } // FS implements APIHost. func (s *Server) FS() vfs.FS { return s } // GetCurrentDirectory implements APIHost. func (s *Server) GetCurrentDirectory() string { return s.cwd } func (s *Server) Run() error { for { messageType, method, payload, err := s.readRequest("") if err != nil { return err } switch messageType { case MessageTypeRequest: defer func() { if r := recover(); r != nil { stack := debug.Stack() err = fmt.Errorf("panic handling request: %v\n%s", r, string(stack)) if fatalErr := s.sendError(method, err); fatalErr != nil { panic("fatal error sending panic response") } } }() result, err := s.handleRequest(method, payload) if err != nil { if err := s.sendError(method, err); err != nil { return err } } else { if err := s.sendResponse(method, result); err != nil { return err } } default: return fmt.Errorf("%w: expected request, received: %s", ErrInvalidRequest, messageType.String()) } } } func (s *Server) readRequest(expectedMethod string) (messageType MessageType, method string, payload []byte, err error) { t, err := s.r.ReadByte() if err != nil { return messageType, method, payload, err } if MessagePackType(t) != MessagePackTypeFixedArray3 { return messageType, method, payload, fmt.Errorf("%w: expected message to be encoded as fixed 3-element array (0x93), received: 0x%2x", ErrInvalidRequest, t) } t, err = s.r.ReadByte() if err != nil { return messageType, method, payload, err } if MessagePackType(t) != MessagePackTypeU8 { return messageType, method, payload, fmt.Errorf("%w: expected first element of message tuple to be encoded as unsigned 8-bit int (0xcc), received: 0x%2x", ErrInvalidRequest, t) } rawMessageType, err := s.r.ReadByte() if err != nil { return messageType, method, payload, err } messageType = MessageType(rawMessageType) if !messageType.IsValid() { return messageType, method, payload, fmt.Errorf("%w: unknown message type: %d", ErrInvalidRequest, messageType) } rawMethod, err := s.readBin() if err != nil { return messageType, method, payload, err } method = string(rawMethod) if expectedMethod != "" && method != expectedMethod { return messageType, method, payload, fmt.Errorf("%w: expected method %q, received %q", ErrInvalidRequest, expectedMethod, method) } payload, err = s.readBin() return messageType, method, payload, err } func (s *Server) readBin() ([]byte, error) { // https://github.com/msgpack/msgpack/blob/master/spec.md#bin-format-family t, err := s.r.ReadByte() if err != nil { return nil, err } var size uint switch MessagePackType(t) { case MessagePackTypeBin8: var size8 uint8 if err = binary.Read(s.r, binary.BigEndian, &size8); err != nil { return nil, err } size = uint(size8) case MessagePackTypeBin16: var size16 uint16 if err = binary.Read(s.r, binary.BigEndian, &size16); err != nil { return nil, err } size = uint(size16) case MessagePackTypeBin32: var size32 uint32 if err = binary.Read(s.r, binary.BigEndian, &size32); err != nil { return nil, err } size = uint(size32) default: return nil, fmt.Errorf("%w: expected binary data length (0xc4-0xc6), received: 0x%2x", ErrInvalidRequest, t) } payload := make([]byte, size) bytesRead, err := io.ReadFull(s.r, payload) if err != nil { return nil, err } if bytesRead != int(size) { return nil, fmt.Errorf("%w: expected %d bytes, read %d", ErrInvalidRequest, size, bytesRead) } return payload, nil } func (s *Server) enableCallback(callback string) error { switch callback { case "directoryExists": s.enabledCallbacks |= CallbackDirectoryExists case "fileExists": s.enabledCallbacks |= CallbackFileExists case "getAccessibleEntries": s.enabledCallbacks |= CallbackGetAccessibleEntries case "readFile": s.enabledCallbacks |= CallbackReadFile case "realpath": s.enabledCallbacks |= CallbackRealpath default: return fmt.Errorf("unknown callback: %s", callback) } return nil } func (s *Server) handleRequest(method string, payload []byte) ([]byte, error) { s.requestId++ switch method { case "configure": return nil, s.handleConfigure(payload) case "echo": return payload, nil default: return s.api.HandleRequest(core.WithRequestID(context.Background(), strconv.Itoa(s.requestId)), method, payload) } } func (s *Server) handleConfigure(payload []byte) error { var params *ConfigureParams if err := json.Unmarshal(payload, ¶ms); err != nil { return fmt.Errorf("%w: %w", ErrInvalidRequest, err) } for _, callback := range params.Callbacks { if err := s.enableCallback(callback); err != nil { return err } } // !!! if params.LogFile != "" { // s.logger.SetFile(params.LogFile) } else { // s.logger.SetFile("") } return nil } func (s *Server) sendResponse(method string, result []byte) error { return s.writeMessage(MessageTypeResponse, method, result) } func (s *Server) sendError(method string, err error) error { return s.writeMessage(MessageTypeError, method, []byte(err.Error())) } func (s *Server) writeMessage(messageType MessageType, method string, payload []byte) error { if err := s.w.WriteByte(byte(MessagePackTypeFixedArray3)); err != nil { return err } if err := s.w.WriteByte(byte(MessagePackTypeU8)); err != nil { return err } if err := s.w.WriteByte(byte(messageType)); err != nil { return err } if err := s.writeBin([]byte(method)); err != nil { return err } if err := s.writeBin(payload); err != nil { return err } return s.w.Flush() } func (s *Server) writeBin(payload []byte) error { length := len(payload) if length < 256 { if err := s.w.WriteByte(byte(MessagePackTypeBin8)); err != nil { return err } if err := s.w.WriteByte(byte(length)); err != nil { return err } } else if length < 1<<16 { if err := s.w.WriteByte(byte(MessagePackTypeBin16)); err != nil { return err } if err := binary.Write(s.w, binary.BigEndian, uint16(length)); err != nil { return err } } else { if err := s.w.WriteByte(byte(MessagePackTypeBin32)); err != nil { return err } if err := binary.Write(s.w, binary.BigEndian, uint32(length)); err != nil { return err } } _, err := s.w.Write(payload) return err } func (s *Server) call(method string, payload any) ([]byte, error) { s.callbackMu.Lock() defer s.callbackMu.Unlock() jsonPayload, err := json.Marshal(payload) if err != nil { return nil, err } if err = s.writeMessage(MessageTypeCall, method, jsonPayload); err != nil { return nil, err } messageType, _, responsePayload, err := s.readRequest(method) if err != nil { return nil, err } if messageType != MessageTypeCallResponse && messageType != MessageTypeCallError { return nil, fmt.Errorf("%w: expected call-response or call-error, received: %s", ErrInvalidRequest, messageType.String()) } if messageType == MessageTypeCallError { return nil, fmt.Errorf("%w: %s", ErrClientError, responsePayload) } return responsePayload, nil } // DirectoryExists implements vfs.FS. func (s *Server) DirectoryExists(path string) bool { if s.enabledCallbacks&CallbackDirectoryExists != 0 { result, err := s.call("directoryExists", path) if err != nil { panic(err) } if len(result) > 0 { return string(result) == "true" } } return s.fs.DirectoryExists(path) } // FileExists implements vfs.FS. func (s *Server) FileExists(path string) bool { if s.enabledCallbacks&CallbackFileExists != 0 { result, err := s.call("fileExists", path) if err != nil { panic(err) } if len(result) > 0 { return string(result) == "true" } } return s.fs.FileExists(path) } // GetAccessibleEntries implements vfs.FS. func (s *Server) GetAccessibleEntries(path string) vfs.Entries { if s.enabledCallbacks&CallbackGetAccessibleEntries != 0 { result, err := s.call("getAccessibleEntries", path) if err != nil { panic(err) } if len(result) > 0 { var rawEntries *struct { Files []string `json:"files"` Directories []string `json:"directories"` } if err := json.Unmarshal(result, &rawEntries); err != nil { panic(err) } if rawEntries != nil { return vfs.Entries{ Files: rawEntries.Files, Directories: rawEntries.Directories, } } } } return s.fs.GetAccessibleEntries(path) } // ReadFile implements vfs.FS. func (s *Server) ReadFile(path string) (contents string, ok bool) { if s.enabledCallbacks&CallbackReadFile != 0 { data, err := s.call("readFile", path) if err != nil { panic(err) } if string(data) == "null" { return "", false } if len(data) > 0 { var result string if err := json.Unmarshal(data, &result); err != nil { panic(err) } return result, true } } return s.fs.ReadFile(path) } // Realpath implements vfs.FS. func (s *Server) Realpath(path string) string { if s.enabledCallbacks&CallbackRealpath != 0 { data, err := s.call("realpath", path) if err != nil { panic(err) } if len(data) > 0 { var result string if err := json.Unmarshal(data, &result); err != nil { panic(err) } return result } } return s.fs.Realpath(path) } // UseCaseSensitiveFileNames implements vfs.FS. func (s *Server) UseCaseSensitiveFileNames() bool { return s.fs.UseCaseSensitiveFileNames() } // WriteFile implements vfs.FS. func (s *Server) WriteFile(path string, data string, writeByteOrderMark bool) error { return s.fs.WriteFile(path, data, writeByteOrderMark) } // WalkDir implements vfs.FS. func (s *Server) WalkDir(root string, walkFn vfs.WalkDirFunc) error { panic("unimplemented") } // Stat implements vfs.FS. func (s *Server) Stat(path string) vfs.FileInfo { panic("unimplemented") } // Remove implements vfs.FS. func (s *Server) Remove(path string) error { panic("unimplemented") } // Chtimes implements vfs.FS. func (s *Server) Chtimes(path string, aTime time.Time, mTime time.Time) error { panic("unimplemented") }