498 lines
13 KiB
Go
498 lines
13 KiB
Go
package api
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"runtime/debug"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/bundled"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/core"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/lsp/lsproto"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/project"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/project/logging"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/vfs"
|
|
"efprojects.com/kitten-ipc/kitcom/internal/tsgo/vfs/osvfs"
|
|
"github.com/go-json-experiment/json"
|
|
)
|
|
|
|
//go:generate go tool golang.org/x/tools/cmd/stringer -type=MessageType -output=stringer_generated.go
|
|
//go:generate go tool mvdan.cc/gofumpt -w stringer_generated.go
|
|
|
|
// MessageType identifies the kind of frame on the wire. It is the first
// element of every three-element message tuple read by readRequest and
// written by writeMessage.
type MessageType uint8

// The numeric values are part of the wire protocol and are spelled out
// explicitly so that reordering the declarations cannot change them.
const (
	// MessageTypeUnknown is the zero value and is never valid on the wire.
	MessageTypeUnknown MessageType = 0
	// MessageTypeRequest is a request sent by the client and served by Run.
	MessageTypeRequest MessageType = 1
	// MessageTypeCallResponse is the client's successful reply to a server call.
	MessageTypeCallResponse MessageType = 2
	// MessageTypeCallError is the client's failure reply to a server call.
	MessageTypeCallError MessageType = 3
	// MessageTypeResponse is the server's successful reply to a request.
	MessageTypeResponse MessageType = 4
	// MessageTypeError is the server's failure reply to a request.
	MessageTypeError MessageType = 5
	// MessageTypeCall is a callback invocation sent from the server to the client.
	MessageTypeCall MessageType = 6
)

// IsValid reports whether m is one of the defined message types, excluding
// MessageTypeUnknown.
func (m MessageType) IsValid() bool {
	return MessageTypeRequest <= m && m <= MessageTypeCall
}
|
|
|
|
// MessagePackType is a msgpack format byte. Only the formats used by this
// wire protocol are named; see the msgpack spec for the full table.
type MessagePackType uint8

const (
	// MessagePackTypeFixedArray3 is the fixarray header (0x90 | length) for a
	// three-element array, which frames every message.
	MessagePackTypeFixedArray3 MessagePackType = 0x90 | 3
	// MessagePackTypeBin8 prefixes binary data with a 1-byte length.
	MessagePackTypeBin8 MessagePackType = 0xc4
	// MessagePackTypeBin16 prefixes binary data with a 2-byte big-endian length.
	MessagePackTypeBin16 MessagePackType = 0xc5
	// MessagePackTypeBin32 prefixes binary data with a 4-byte big-endian length.
	MessagePackTypeBin32 MessagePackType = 0xc6
	// MessagePackTypeU8 prefixes a single unsigned 8-bit integer.
	MessagePackTypeU8 MessagePackType = 0xcc
)
|
|
|
|
// Callback is a bit set of filesystem operations that the client has asked to
// handle itself (enabled through the "configure" request).
type Callback int

// Each callback occupies its own bit so that several can be combined with |.
const (
	CallbackDirectoryExists      Callback = 1 << 0
	CallbackFileExists           Callback = 1 << 1
	CallbackGetAccessibleEntries Callback = 1 << 2
	CallbackReadFile             Callback = 1 << 3
	CallbackRealpath             Callback = 1 << 4
)
|
|
|
|
// ServerOptions configures NewServer.
//
// In carries framed requests from the client and Out carries framed replies
// and callback invocations back to it. Err receives log output. Cwd is
// required (NewServer panics if it is empty). DefaultLibraryPath is passed
// through to the API session.
type ServerOptions struct {
	In       io.Reader
	Out, Err io.Writer

	Cwd, DefaultLibraryPath string
}
|
|
|
|
var _ vfs.FS = (*Server)(nil)
|
|
|
|
type Server struct {
|
|
r *bufio.Reader
|
|
w *bufio.Writer
|
|
stderr io.Writer
|
|
|
|
cwd string
|
|
newLine string
|
|
fs vfs.FS
|
|
defaultLibraryPath string
|
|
|
|
callbackMu sync.Mutex
|
|
enabledCallbacks Callback
|
|
logger logging.Logger
|
|
api *API
|
|
|
|
requestId int
|
|
}
|
|
|
|
func NewServer(options *ServerOptions) *Server {
|
|
if options.Cwd == "" {
|
|
panic("Cwd is required")
|
|
}
|
|
|
|
server := &Server{
|
|
r: bufio.NewReader(options.In),
|
|
w: bufio.NewWriter(options.Out),
|
|
stderr: options.Err,
|
|
cwd: options.Cwd,
|
|
fs: bundled.WrapFS(osvfs.FS()),
|
|
defaultLibraryPath: options.DefaultLibraryPath,
|
|
}
|
|
logger := logging.NewLogger(options.Err)
|
|
server.logger = logger
|
|
server.api = NewAPI(&APIInit{
|
|
Logger: logger,
|
|
FS: server,
|
|
SessionOptions: &project.SessionOptions{
|
|
CurrentDirectory: options.Cwd,
|
|
DefaultLibraryPath: options.DefaultLibraryPath,
|
|
PositionEncoding: lsproto.PositionEncodingKindUTF8,
|
|
LoggingEnabled: true,
|
|
},
|
|
})
|
|
return server
|
|
}
|
|
|
|
// DefaultLibraryPath implements APIHost.
|
|
func (s *Server) DefaultLibraryPath() string {
|
|
return s.defaultLibraryPath
|
|
}
|
|
|
|
// FS implements APIHost.
|
|
func (s *Server) FS() vfs.FS {
|
|
return s
|
|
}
|
|
|
|
// GetCurrentDirectory implements APIHost.
|
|
func (s *Server) GetCurrentDirectory() string {
|
|
return s.cwd
|
|
}
|
|
|
|
func (s *Server) Run() error {
|
|
for {
|
|
messageType, method, payload, err := s.readRequest("")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch messageType {
|
|
case MessageTypeRequest:
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
stack := debug.Stack()
|
|
err = fmt.Errorf("panic handling request: %v\n%s", r, string(stack))
|
|
if fatalErr := s.sendError(method, err); fatalErr != nil {
|
|
panic("fatal error sending panic response")
|
|
}
|
|
}
|
|
}()
|
|
|
|
result, err := s.handleRequest(method, payload)
|
|
|
|
if err != nil {
|
|
if err := s.sendError(method, err); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := s.sendResponse(method, result); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
default:
|
|
return fmt.Errorf("%w: expected request, received: %s", ErrInvalidRequest, messageType.String())
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) readRequest(expectedMethod string) (messageType MessageType, method string, payload []byte, err error) {
|
|
t, err := s.r.ReadByte()
|
|
if err != nil {
|
|
return messageType, method, payload, err
|
|
}
|
|
if MessagePackType(t) != MessagePackTypeFixedArray3 {
|
|
return messageType, method, payload, fmt.Errorf("%w: expected message to be encoded as fixed 3-element array (0x93), received: 0x%2x", ErrInvalidRequest, t)
|
|
}
|
|
t, err = s.r.ReadByte()
|
|
if err != nil {
|
|
return messageType, method, payload, err
|
|
}
|
|
if MessagePackType(t) != MessagePackTypeU8 {
|
|
return messageType, method, payload, fmt.Errorf("%w: expected first element of message tuple to be encoded as unsigned 8-bit int (0xcc), received: 0x%2x", ErrInvalidRequest, t)
|
|
}
|
|
rawMessageType, err := s.r.ReadByte()
|
|
if err != nil {
|
|
return messageType, method, payload, err
|
|
}
|
|
messageType = MessageType(rawMessageType)
|
|
if !messageType.IsValid() {
|
|
return messageType, method, payload, fmt.Errorf("%w: unknown message type: %d", ErrInvalidRequest, messageType)
|
|
}
|
|
rawMethod, err := s.readBin()
|
|
if err != nil {
|
|
return messageType, method, payload, err
|
|
}
|
|
method = string(rawMethod)
|
|
if expectedMethod != "" && method != expectedMethod {
|
|
return messageType, method, payload, fmt.Errorf("%w: expected method %q, received %q", ErrInvalidRequest, expectedMethod, method)
|
|
}
|
|
payload, err = s.readBin()
|
|
return messageType, method, payload, err
|
|
}
|
|
|
|
func (s *Server) readBin() ([]byte, error) {
|
|
// https://github.com/msgpack/msgpack/blob/master/spec.md#bin-format-family
|
|
t, err := s.r.ReadByte()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var size uint
|
|
switch MessagePackType(t) {
|
|
case MessagePackTypeBin8:
|
|
var size8 uint8
|
|
if err = binary.Read(s.r, binary.BigEndian, &size8); err != nil {
|
|
return nil, err
|
|
}
|
|
size = uint(size8)
|
|
case MessagePackTypeBin16:
|
|
var size16 uint16
|
|
if err = binary.Read(s.r, binary.BigEndian, &size16); err != nil {
|
|
return nil, err
|
|
}
|
|
size = uint(size16)
|
|
case MessagePackTypeBin32:
|
|
var size32 uint32
|
|
if err = binary.Read(s.r, binary.BigEndian, &size32); err != nil {
|
|
return nil, err
|
|
}
|
|
size = uint(size32)
|
|
default:
|
|
return nil, fmt.Errorf("%w: expected binary data length (0xc4-0xc6), received: 0x%2x", ErrInvalidRequest, t)
|
|
}
|
|
payload := make([]byte, size)
|
|
bytesRead, err := io.ReadFull(s.r, payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if bytesRead != int(size) {
|
|
return nil, fmt.Errorf("%w: expected %d bytes, read %d", ErrInvalidRequest, size, bytesRead)
|
|
}
|
|
return payload, nil
|
|
}
|
|
|
|
func (s *Server) enableCallback(callback string) error {
|
|
switch callback {
|
|
case "directoryExists":
|
|
s.enabledCallbacks |= CallbackDirectoryExists
|
|
case "fileExists":
|
|
s.enabledCallbacks |= CallbackFileExists
|
|
case "getAccessibleEntries":
|
|
s.enabledCallbacks |= CallbackGetAccessibleEntries
|
|
case "readFile":
|
|
s.enabledCallbacks |= CallbackReadFile
|
|
case "realpath":
|
|
s.enabledCallbacks |= CallbackRealpath
|
|
default:
|
|
return fmt.Errorf("unknown callback: %s", callback)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) handleRequest(method string, payload []byte) ([]byte, error) {
|
|
s.requestId++
|
|
switch method {
|
|
case "configure":
|
|
return nil, s.handleConfigure(payload)
|
|
case "echo":
|
|
return payload, nil
|
|
default:
|
|
return s.api.HandleRequest(core.WithRequestID(context.Background(), strconv.Itoa(s.requestId)), method, payload)
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleConfigure(payload []byte) error {
|
|
var params *ConfigureParams
|
|
if err := json.Unmarshal(payload, ¶ms); err != nil {
|
|
return fmt.Errorf("%w: %w", ErrInvalidRequest, err)
|
|
}
|
|
for _, callback := range params.Callbacks {
|
|
if err := s.enableCallback(callback); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// !!!
|
|
if params.LogFile != "" {
|
|
// s.logger.SetFile(params.LogFile)
|
|
} else {
|
|
// s.logger.SetFile("")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) sendResponse(method string, result []byte) error {
|
|
return s.writeMessage(MessageTypeResponse, method, result)
|
|
}
|
|
|
|
func (s *Server) sendError(method string, err error) error {
|
|
return s.writeMessage(MessageTypeError, method, []byte(err.Error()))
|
|
}
|
|
|
|
func (s *Server) writeMessage(messageType MessageType, method string, payload []byte) error {
|
|
if err := s.w.WriteByte(byte(MessagePackTypeFixedArray3)); err != nil {
|
|
return err
|
|
}
|
|
if err := s.w.WriteByte(byte(MessagePackTypeU8)); err != nil {
|
|
return err
|
|
}
|
|
if err := s.w.WriteByte(byte(messageType)); err != nil {
|
|
return err
|
|
}
|
|
if err := s.writeBin([]byte(method)); err != nil {
|
|
return err
|
|
}
|
|
if err := s.writeBin(payload); err != nil {
|
|
return err
|
|
}
|
|
return s.w.Flush()
|
|
}
|
|
|
|
func (s *Server) writeBin(payload []byte) error {
|
|
length := len(payload)
|
|
if length < 256 {
|
|
if err := s.w.WriteByte(byte(MessagePackTypeBin8)); err != nil {
|
|
return err
|
|
}
|
|
if err := s.w.WriteByte(byte(length)); err != nil {
|
|
return err
|
|
}
|
|
} else if length < 1<<16 {
|
|
if err := s.w.WriteByte(byte(MessagePackTypeBin16)); err != nil {
|
|
return err
|
|
}
|
|
if err := binary.Write(s.w, binary.BigEndian, uint16(length)); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
if err := s.w.WriteByte(byte(MessagePackTypeBin32)); err != nil {
|
|
return err
|
|
}
|
|
if err := binary.Write(s.w, binary.BigEndian, uint32(length)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
_, err := s.w.Write(payload)
|
|
return err
|
|
}
|
|
|
|
func (s *Server) call(method string, payload any) ([]byte, error) {
|
|
s.callbackMu.Lock()
|
|
defer s.callbackMu.Unlock()
|
|
jsonPayload, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if err = s.writeMessage(MessageTypeCall, method, jsonPayload); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
messageType, _, responsePayload, err := s.readRequest(method)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if messageType != MessageTypeCallResponse && messageType != MessageTypeCallError {
|
|
return nil, fmt.Errorf("%w: expected call-response or call-error, received: %s", ErrInvalidRequest, messageType.String())
|
|
}
|
|
|
|
if messageType == MessageTypeCallError {
|
|
return nil, fmt.Errorf("%w: %s", ErrClientError, responsePayload)
|
|
}
|
|
|
|
return responsePayload, nil
|
|
}
|
|
|
|
// DirectoryExists implements vfs.FS.
|
|
func (s *Server) DirectoryExists(path string) bool {
|
|
if s.enabledCallbacks&CallbackDirectoryExists != 0 {
|
|
result, err := s.call("directoryExists", path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if len(result) > 0 {
|
|
return string(result) == "true"
|
|
}
|
|
}
|
|
return s.fs.DirectoryExists(path)
|
|
}
|
|
|
|
// FileExists implements vfs.FS.
|
|
func (s *Server) FileExists(path string) bool {
|
|
if s.enabledCallbacks&CallbackFileExists != 0 {
|
|
result, err := s.call("fileExists", path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if len(result) > 0 {
|
|
return string(result) == "true"
|
|
}
|
|
}
|
|
return s.fs.FileExists(path)
|
|
}
|
|
|
|
// GetAccessibleEntries implements vfs.FS.
|
|
func (s *Server) GetAccessibleEntries(path string) vfs.Entries {
|
|
if s.enabledCallbacks&CallbackGetAccessibleEntries != 0 {
|
|
result, err := s.call("getAccessibleEntries", path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if len(result) > 0 {
|
|
var rawEntries *struct {
|
|
Files []string `json:"files"`
|
|
Directories []string `json:"directories"`
|
|
}
|
|
if err := json.Unmarshal(result, &rawEntries); err != nil {
|
|
panic(err)
|
|
}
|
|
if rawEntries != nil {
|
|
return vfs.Entries{
|
|
Files: rawEntries.Files,
|
|
Directories: rawEntries.Directories,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return s.fs.GetAccessibleEntries(path)
|
|
}
|
|
|
|
// ReadFile implements vfs.FS.
|
|
func (s *Server) ReadFile(path string) (contents string, ok bool) {
|
|
if s.enabledCallbacks&CallbackReadFile != 0 {
|
|
data, err := s.call("readFile", path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if string(data) == "null" {
|
|
return "", false
|
|
}
|
|
if len(data) > 0 {
|
|
var result string
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
panic(err)
|
|
}
|
|
return result, true
|
|
}
|
|
}
|
|
return s.fs.ReadFile(path)
|
|
}
|
|
|
|
// Realpath implements vfs.FS.
|
|
func (s *Server) Realpath(path string) string {
|
|
if s.enabledCallbacks&CallbackRealpath != 0 {
|
|
data, err := s.call("realpath", path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if len(data) > 0 {
|
|
var result string
|
|
if err := json.Unmarshal(data, &result); err != nil {
|
|
panic(err)
|
|
}
|
|
return result
|
|
}
|
|
}
|
|
return s.fs.Realpath(path)
|
|
}
|
|
|
|
// UseCaseSensitiveFileNames implements vfs.FS.
|
|
func (s *Server) UseCaseSensitiveFileNames() bool {
|
|
return s.fs.UseCaseSensitiveFileNames()
|
|
}
|
|
|
|
// WriteFile implements vfs.FS.
|
|
func (s *Server) WriteFile(path string, data string, writeByteOrderMark bool) error {
|
|
return s.fs.WriteFile(path, data, writeByteOrderMark)
|
|
}
|
|
|
|
// WalkDir implements vfs.FS.
|
|
func (s *Server) WalkDir(root string, walkFn vfs.WalkDirFunc) error {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
// Stat implements vfs.FS.
|
|
func (s *Server) Stat(path string) vfs.FileInfo {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
// Remove implements vfs.FS.
|
|
func (s *Server) Remove(path string) error {
|
|
panic("unimplemented")
|
|
}
|
|
|
|
// Chtimes implements vfs.FS.
|
|
func (s *Server) Chtimes(path string, aTime time.Time, mTime time.Time) error {
|
|
panic("unimplemented")
|
|
}
|