From f365a88fc0b3532885125a81939b68923497676b Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Wed, 25 Sep 2024 16:34:36 +0800 Subject: [PATCH 1/8] feat: add spec v2 --- specs-go/v2/architecture.go | 51 ++++++++++++++++++++++++++++++++ specs-go/v2/config.go | 40 +++++++++++++++++++++++++ specs-go/v2/engine.go | 32 ++++++++++++++++++++ specs-go/v2/manifest.go | 36 ++++++++++++++++++++++ specs-go/v2/mediatype.go | 59 +++++++++++++++++++++++++++++++++++++ specs-go/v2/processor.go | 42 ++++++++++++++++++++++++++ specs-go/v2/weights.go | 29 ++++++++++++++++++ 7 files changed, 289 insertions(+) create mode 100644 specs-go/v2/architecture.go create mode 100644 specs-go/v2/config.go create mode 100644 specs-go/v2/engine.go create mode 100644 specs-go/v2/manifest.go create mode 100644 specs-go/v2/mediatype.go create mode 100644 specs-go/v2/processor.go create mode 100644 specs-go/v2/weights.go diff --git a/specs-go/v2/architecture.go b/specs-go/v2/architecture.go new file mode 100644 index 0000000..f7069b9 --- /dev/null +++ b/specs-go/v2/architecture.go @@ -0,0 +1,51 @@ +package v2 + +// TransformerForCausalLLM represents the configuration of a transformer model for causal language modeling. +// It defines the architecture and hyperparameters of the model. +// +// Supported features: +// - Attention mechanisms: Multi-Head Attention (MHA) and Grouped Query Attention (GQA) +// - Activation functions: GELU, ReLU +// - Position embeddings: Rotary Position Embedding (RoPE) +// - Normalization: RMSNorm (Root Mean Square Layer Normalization) +// +// This structure is designed to be flexible and accommodate various transformer architectures +// used in state-of-the-art language models. +type TransformerForCausalLLM struct { + // Version of the transformer architecture config + Version string `json:"version"` + + // Vocabulary size of the model + VocabularySize int `json:"vocabulary_size"` + + // The hidden size of the model, e.g. 768, 1024, 2048, etc. 
+ HiddenSize int `json:"hidden_size"` + + // The number of transformer layers of the model. + NumHiddenLayers int `json:"num_hidden_layers"` + + // The number of attention heads, e.g. 12, 16, 32, etc. + NumAttentionHeads int `json:"num_attention_heads"` + + // The number of key value heads, e.g. 1, 2, 4, etc. + // Only used by the GQA attention mechanism. + NumKeyValueHeads int `json:"num_key_value_heads"` + + // The activation function used by the pointwise feed-forward layers, e.g. 'gelu', 'relu', 'tanh', etc. + Activation string `json:"activation"` + + // The intermediate size in the feed-forward layers. The non-linearity is applied at this intermediate size. + IntermediateSize int `json:"intermediate_size"` + + // The epsilon used by the RMSNorm normalization layers. + NormEpsilon float64 `json:"norm_epsilon"` + + // The position embedding type, for example 'rope', 'sinusoidal', 'alibi', etc. + PositionEmbedding string `json:"position_embedding"` + + // The base used to compute the rotary embedding period. + RotaryEmbeddingBase int `json:"rotary_embedding_base,omitempty"` + + // Fraction of hidden size to apply rotary embeddings to. Must be in [0,1]. + RotaryEmbeddingFraction float64 `json:"rotary_embedding_fraction,omitempty"` +} diff --git a/specs-go/v2/config.go b/specs-go/v2/config.go new file mode 100644 index 0000000..e88db6d --- /dev/null +++ b/specs-go/v2/config.go @@ -0,0 +1,40 @@ +package v2 + +import ( + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Config represents the JSON structure that encapsulates essential metadata and configuration details of a machine learning model. +type Config struct { + // Name specifies the unique identifier or title of the model. + Name string `json:"name"` + + // Family indicates the broader category or lineage of the model, such as 'GPT', 'LLAMA', or 'QWEN'. + // This helps in grouping related models or identifying their general capabilities. 
+ Family string `json:"family"` + + // Architecture defines the fundamental structure or design of the model, + // such as 'transformer', 'CNN' (Convolutional Neural Network), 'RNN' (Recurrent Neural Network), etc. + // This information is crucial for understanding the model's underlying principles and potential applications. + Architecture string `json:"architecture"` + + // Description provides detailed information about the model's purpose, capabilities, and usage. + // It is represented as an array of Descriptors, allowing for rich, structured content. + Description []oci.Descriptor `json:"description,omitempty"` + + // License contains the legal and usage terms associated with the model. + // It includes policies and grants that govern how the model can be used, distributed, or modified. + // Represented as an array of Descriptors to accommodate multiple or complex licensing terms. + License []oci.Descriptor `json:"license,omitempty"` + + // Extensions allows for the inclusion of additional, model-specific configuration details. + // Each extension is represented by a Descriptor, enabling flexible and extensible metadata. + // This field accommodates unique requirements or features of different model types, such as: + // - Generation configuration: Parameters for text generation in language models + // - Quantization configuration: Details about model weight quantization + // - Transformer configuration: Specific architectural details for transformer models + // - Domain-specific settings: Configurations relevant to particular application domains + // The use of Descriptors ensures that each extension can be properly identified and processed, + // allowing for seamless integration of diverse model configurations within a unified structure. 
+ Extensions []oci.Descriptor `json:"extensions,omitempty"` +} diff --git a/specs-go/v2/engine.go b/specs-go/v2/engine.go new file mode 100644 index 0000000..5fb1e19 --- /dev/null +++ b/specs-go/v2/engine.go @@ -0,0 +1,32 @@ +package v2 + +import oci "github.com/opencontainers/image-spec/specs-go/v1" + +// Engine provides the structure for the `application/vnd.cnai.models.engine.v0+json` mediatype when marshalled to JSON. +// It encapsulates the details necessary to describe and configure the execution environment for a model. +type Engine struct { + // Name specifies the engine or framework used, such as 'transformers', 'tensorrt', or 'vllm'. + // This field is crucial for identifying the runtime environment required for the model. + Name string `json:"name,omitempty"` + + // Version indicates the specific version of the engine or framework. + // Examples include '4.44.0', '8.10', '1.0', etc. This ensures compatibility and reproducibility. + Version string `json:"version,omitempty"` + + // Dependencies lists the additional packages or libraries required by the engine. + // This optional field is used to specify and install necessary components for the engine's operation. + Dependencies []string `json:"dependencies,omitempty"` + + // Environment defines key-value pairs for environment variables. + // These variables are used to configure the runtime environment for the engine executor. + Environment map[string]string `json:"environment,omitempty"` + + // EntryPoint specifies the command or script to initiate the engine. + // This optional field provides the starting point for executing the model within the engine. + EntryPoint string `json:"entrypoint,omitempty"` + + // Extensions allows for additional, engine-specific configuration details. + // Each extension is represented by a Descriptor, enabling flexible and extensible metadata + // to accommodate unique requirements or features of different engine types. 
+ Extensions []oci.Descriptor `json:"extensions,omitempty"` +} diff --git a/specs-go/v2/manifest.go b/specs-go/v2/manifest.go new file mode 100644 index 0000000..c59ab97 --- /dev/null +++ b/specs-go/v2/manifest.go @@ -0,0 +1,36 @@ +package v2 + +import ( + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Manifest represents the structure for the `application/vnd.cnai.model.manifest.v2+json` mediatype when marshalled to JSON. +// It encapsulates all the essential components and metadata for a machine learning model. +type Manifest struct { + // Version specifies the version of the manifest schema. + Version string `json:"version"` + + // MediaType indicates the specific type of this document's data structure. + // It should be set to `application/vnd.cnai.model.manifest.v2+json` or an applicable IANA media type. + MediaType string `json:"mediaType"` + + // Config holds the configuration of the model. + // It contains essential setup information used by the runtime. + Config Config `json:"config"` + + // Processor references the pre-processor object(s) by digest. + // It's used for any data preparation or transformation required before model inference. + Processor []oci.Descriptor `json:"processor"` + + // Weights describes the model's weights, including their format and precision. + // Its File descriptors reference the blobs containing the trained parameters of the model. + Weights Weights `json:"weights"` + + // Engine is an optional field that describes the engine used to run the model. + // The engine structure contains information for setting up the runtime environment. + Engine Engine `json:"engine,omitempty"` + + // Annotations is an optional map for storing arbitrary metadata related to the model manifest. + // This can include information like creation date, author, or custom tags. 
+ Annotations map[string]string `json:"annotations,omitempty"` +} diff --git a/specs-go/v2/mediatype.go b/specs-go/v2/mediatype.go new file mode 100644 index 0000000..da57c25 --- /dev/null +++ b/specs-go/v2/mediatype.go @@ -0,0 +1,59 @@ +package v2 + +// manifest +const ( + // MediaTypeModelManifest specifies the media type for a models manifest. + MediaTypeModelManifest = "application/vnd.cnai.model.manifest.v2+json" +) + +// configs +const ( + // MediaTypeModelConfig specifies the media type for model configuration. + MediaTypeModelConfig = "application/vnd.cnai.model.config.v2+json" + + // MediaTypeModelLicense specifies the media type for model license. + MediaTypeModelLicense = "application/vnd.cnai.model.license.v2+json" + + // MediaTypeModelDescription specifies the media type for model description. + MediaTypeModelDescription = "application/vnd.cnai.model.description.v2+json" + + // MediaTypeModelExtension specifies the media type for model configuration extension. + MediaTypeModelExtension = "application/vnd.cnai.model.extension.v2+json" +) + +// processors +const ( + // MediaTypeModelProcessorText specifies the media type for text processors. + // This includes tokenizers like sentencepiece, used for processing textual input. + MediaTypeModelProcessorText = "application/vnd.cnai.model.processor.text.v2.tar" + + // MediaTypeModelProcessorAudio specifies the media type for audio processors. + // These are used for processing audio input, such as speech-to-text models. + MediaTypeModelProcessorAudio = "application/vnd.cnai.model.processor.audio.v2.tar" + + // MediaTypeModelProcessorImage specifies the media type for image processors. + // These are used for processing image input, such as in computer vision models. + MediaTypeModelProcessorImage = "application/vnd.cnai.model.processor.image.v2.tar" + + // MediaTypeModelProcessorMultiModal specifies the media type for multi-modal processors. 
+ // These are used for models that can process multiple types of input (e.g., text and images). + MediaTypeModelProcessorMultiModal = "application/vnd.cnai.model.processor.multimodal.v2.tar" +) + +// weights +const ( + // MediaTypeModelWeights specifies the media type for model weights. + MediaTypeModelWeights = "application/vnd.cnai.model.weights.v2.tar" +) + +// engine +const ( + // MediaTypeModelEngine specifies the media type for model engine. + MediaTypeModelEngine = "application/vnd.cnai.model.engine.v2.tar" +) + +// transformer architecture +const ( + // MediaTypeModelArchitectureTransformer specifies the media type for model architecture. + MediaTypeModelArchitectureTransformer = "application/vnd.cnai.model.architecture.transformer.v2.tar" +) diff --git a/specs-go/v2/processor.go b/specs-go/v2/processor.go new file mode 100644 index 0000000..9e9e11a --- /dev/null +++ b/specs-go/v2/processor.go @@ -0,0 +1,42 @@ +package v2 + +import ( + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +// TextProcessor represents the structure for the `application/vnd.cnai.models.tokenizer.v0+json` mediatype when marshalled to JSON. +// It encapsulates the essential components of a text tokenizer used in natural language processing models. +type TextProcessor struct { + // TokenizerConfig is a Descriptor referencing the configuration file(s) for the tokenizer. + // This can be a single file or multiple files containing essential information such as: + // - Vocabulary: The set of tokens used by the tokenizer + // - Settings: Parameters that control tokenization behavior + // - Special tokens: Tokens with specific meanings or functions (e.g., [PAD], [CLS], [SEP]) + // Modern tokenizers often consolidate all configuration into a single file for simplicity, + // while some may still use separate files for different components. 
+ TokenizerConfig oci.Descriptor `json:"tokenizer_config,omitempty"` + + // Algorithm is the tokenization algorithm used by the tokenizer, such as BPE, WordPiece, Unigram, etc. + Algorithm string `json:"algorithm,omitempty"` + + // Library is the library used by the tokenizer, such as sentencepiece, tiktoken, huggingface tokenizers, etc. + Library string `json:"library,omitempty"` +} + +// AudioProcessor represents the structure for the `application/vnd.cnai.models.processor.audio.v2+json` mediatype when marshalled to JSON. +// It encapsulates the essential components of an audio processor used in audio processing models. +type AudioProcessor struct { + // TODO: to be defined +} + +// ImageProcessor represents the structure for the `application/vnd.cnai.models.processor.image.v2+json` mediatype when marshalled to JSON. +// It encapsulates the essential components of an image processor used in image processing models. +type ImageProcessor struct { + // TODO: to be defined +} + +// MultiModalProcessor represents the structure for the `application/vnd.cnai.models.processor.multimodal.v2+json` mediatype when marshalled to JSON. +// It encapsulates the essential components of a multi-modal processor used in multi-modal processing models. +type MultiModalProcessor struct { + // TODO: to be defined +} diff --git a/specs-go/v2/weights.go b/specs-go/v2/weights.go new file mode 100644 index 0000000..eccc2fc --- /dev/null +++ b/specs-go/v2/weights.go @@ -0,0 +1,29 @@ +package v2 + +import ( + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +// Weights represents the structure for the `application/vnd.cnai.models.weights.v0+json` mediatype when marshalled to JSON. +// It encapsulates the essential information about a model's weights, including their storage format, numerical precision, and file references. +type Weights struct { + // File is an array of Descriptors referencing the inline files or directories containing the model weights. 
+ // These Descriptors provide details such as the file size, digest, and media type of each weight file or directory. + File []oci.Descriptor `json:"file,omitempty"` + + // Format specifies the storage format of the weights. This field can include values such as: + // - 'safetensors': A fast and safe format for storing tensors + // - 'gguf': GPT-Generated Unified Format, used by some language models + // - 'onnx': Open Neural Network Exchange format + // - 'pytorch': PyTorch's native serialization format + // The format information is crucial for correctly loading and interpreting the weight data. + Format string `json:"format,omitempty"` + + // Precision indicates the numerical precision of the weights. This field can include values such as: + // - 'bf16': Brain Floating Point (bfloat16) + // - 'fp16': Half-precision floating-point + // - 'fp32': Single-precision floating-point + // - 'int8': 8-bit integer quantization + // The precision information is essential for memory management and computational efficiency. + Precision string `json:"precision,omitempty"` +} From 83b6fd577920f654ff19f5c772a49150d689d12c Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 11:22:11 +0800 Subject: [PATCH 2/8] docs: add docs for spec v2 --- docs/v2/modelfile.md | 82 ++++++++++++++++++++++++++++++++++++++++++++ docs/v2/tool.md | 51 +++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 docs/v2/modelfile.md create mode 100644 docs/v2/tool.md diff --git a/docs/v2/modelfile.md b/docs/v2/modelfile.md new file mode 100644 index 0000000..c14a2c1 --- /dev/null +++ b/docs/v2/modelfile.md @@ -0,0 +1,82 @@ + +### Modelfile +A Modelfile is a text file containing all commands, in order, needed to build a given model image. It automates the process of building model images. 
+ +#### Modelfile Instructions +| **Instruction** | **Description** | +| --- | --- | +| CREATE | Create a new model image | +| FROM | Specify the base model image to use | +| NAME | Specify model name | +| FAMILY | Specify model family | +| ARCHITECTURE | Specify model architecture | +| LICENSE | Specify the legal license under which the model is used | +| CONFIG | Specify model configuration file | +| WEIGHTS | Specify model weights file | +| FORMAT | Specify model weights format | +| TOKENIZER | Specify tokenizer configuration | + +#### Modelfile Example +```plain +CREATE registry.cnai.com/sys/gemma-2b:latest + +# Model Information +NAME gemma-2b +FAMILY gemma +ARCHITECTURE transformer +FORMAT safetensors + +# Model License +LICENSE examples/huggingface/gemma-2b/LICENSE + +# Model Configuration +CONFIG examples/huggingface/gemma-2b/config.json +CONFIG examples/huggingface/gemma-2b/generation_config.json + +# Model Tokenizer +TOKENIZER examples/huggingface/gemma-2b/tokenizer.json + +# Model Weights +WEIGHTS examples/huggingface/gemma-2b/model.safetensors.index.json +WEIGHTS examples/huggingface/gemma-2b/model-00001-of-00002.safetensors +WEIGHTS examples/huggingface/gemma-2b/model-00002-of-00002.safetensors + +``` + +### Management tool +We propose a model management tool, which is a command-line tool for building, managing, and running AI models. + +#### build +We can use a Modelfile to build a model image. + +```plain +mdctl build -f ./Modelfile +``` + +#### list +We can list all the model images available in local storage. + +```plain +mdctl list +``` + +#### push +We can push the built model image to a model repository. + +```plain +mdctl push -n registry.cnai.com/sys/gemma-2b:latest +``` + +#### pull +We can pull the model image from the model repository to local storage. + +```plain +mdctl pull -n registry.cnai.com/sys/gemma-2b:latest +``` + +#### unpack +After pulling the model image to local storage, we can use mdctl to unpack it so that the model can be run. 
+ +```plain +mdctl unpack -n registry.cnai.com/sys/gemma-2b:latest +``` diff --git a/docs/v2/tool.md b/docs/v2/tool.md new file mode 100644 index 0000000..7c32598 --- /dev/null +++ b/docs/v2/tool.md @@ -0,0 +1,51 @@ +# mdctl - Model Control Tool + +`mdctl` is a command-line tool for building, managing, and running AI models. + +## Installation + +To install `mdctl`, clone the repository and build the binary: + +``` +git clone https://github.com/CloudNativeAI/model-spec.git +cd model-spec/tools/mdctl +go build +``` + +## Usage + +To build a model, use the `build` command: + +``` +./mdctl build -f Modelfile +``` + +To list all local models, use the `list` command: + +``` +./mdctl list +``` + +To push a model, use the `push` command. Before pushing, you need to set the model registry credentials: + +``` +export MODEL_REGISTRY_USER=<username> +export MODEL_REGISTRY_PASSWORD=<password> +export MODEL_REGISTRY_URL=<registry-url> +``` + +``` +./mdctl push -n <model-url> +``` + +To pull a model, use the `pull` command: + +``` +./mdctl pull -n <model-url> +``` + +To unpack a pulled model so that it can be run, use the `unpack` command: + +``` +./mdctl unpack -n <model-url> +``` From f2fe71097f75ee0b9b7ec3cb9e0c0d5e11ccfba9 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 11:25:33 +0800 Subject: [PATCH 3/8] chore: add lint for tools --- .github/workflows/tools-lint.yml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .github/workflows/tools-lint.yml diff --git a/.github/workflows/tools-lint.yml b/.github/workflows/tools-lint.yml new file mode 100644 index 0000000..ce23006 --- /dev/null +++ b/.github/workflows/tools-lint.yml @@ -0,0 +1,31 @@ +name: Lint + +on: + push: + branches: [main, release-*] + pull_request: + branches: [main, release-*] + +permissions: + contents: read + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - name: Checkout code + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 + + - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 + with: + go-version-file: tools/mdctl/go.mod + cache: false + + - name: Golangci 
lint + uses: golangci/golangci-lint-action@aaa42aa0628b4ae2578232a66b541047968fac86 + with: + version: v1.54 + args: --verbose + working-directory: tools/mdctl From b739be61fc9072d6635ddd9a1d13310dd9ab1c60 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 11:26:28 +0800 Subject: [PATCH 4/8] feat: add tool for demo --- tools/mdctl/.gitignore | 5 + tools/mdctl/cmd/build.go | 168 ++++++++++++ tools/mdctl/cmd/cmd.go | 145 ++++++++++ tools/mdctl/cmd/list.go | 36 +++ tools/mdctl/cmd/pull.go | 115 ++++++++ tools/mdctl/cmd/push.go | 89 ++++++ tools/mdctl/cmd/run.go | 152 +++++++++++ .../examples/huggingface/gemma-2b/Modelfile | 22 ++ .../examples/huggingface/gemma-2b/README.md | 1 + .../examples/huggingface/gemma-2b/run.py | 10 + tools/mdctl/format/bytes.go | 47 ++++ tools/mdctl/format/format.go | 25 ++ tools/mdctl/format/parse.go | 135 ++++++++++ tools/mdctl/format/time.go | 68 +++++ tools/mdctl/go.mod | 25 ++ tools/mdctl/go.sum | 26 ++ tools/mdctl/main.go | 12 + tools/mdctl/models/archiver.go | 253 ++++++++++++++++++ tools/mdctl/models/descriptor.go | 154 +++++++++++ tools/mdctl/models/image.go | 45 ++++ tools/mdctl/models/layers.go | 165 ++++++++++++ tools/mdctl/models/manifests.go | 60 +++++ tools/mdctl/models/modelpath.go | 187 +++++++++++++ tools/mdctl/progress/bar.go | 215 +++++++++++++++ tools/mdctl/progress/progress.go | 113 ++++++++ tools/mdctl/progress/spinner.go | 73 +++++ tools/mdctl/registry/client.go | 193 +++++++++++++ tools/mdctl/version/version.go | 3 + 28 files changed, 2542 insertions(+) create mode 100644 tools/mdctl/.gitignore create mode 100644 tools/mdctl/cmd/build.go create mode 100644 tools/mdctl/cmd/cmd.go create mode 100644 tools/mdctl/cmd/list.go create mode 100644 tools/mdctl/cmd/pull.go create mode 100644 tools/mdctl/cmd/push.go create mode 100644 tools/mdctl/cmd/run.go create mode 100644 tools/mdctl/examples/huggingface/gemma-2b/Modelfile create mode 100644 tools/mdctl/examples/huggingface/gemma-2b/README.md create mode 100755 
tools/mdctl/examples/huggingface/gemma-2b/run.py create mode 100644 tools/mdctl/format/bytes.go create mode 100644 tools/mdctl/format/format.go create mode 100644 tools/mdctl/format/parse.go create mode 100644 tools/mdctl/format/time.go create mode 100644 tools/mdctl/go.mod create mode 100644 tools/mdctl/go.sum create mode 100644 tools/mdctl/main.go create mode 100644 tools/mdctl/models/archiver.go create mode 100644 tools/mdctl/models/descriptor.go create mode 100644 tools/mdctl/models/image.go create mode 100644 tools/mdctl/models/layers.go create mode 100644 tools/mdctl/models/manifests.go create mode 100644 tools/mdctl/models/modelpath.go create mode 100644 tools/mdctl/progress/bar.go create mode 100644 tools/mdctl/progress/progress.go create mode 100644 tools/mdctl/progress/spinner.go create mode 100644 tools/mdctl/registry/client.go create mode 100644 tools/mdctl/version/version.go diff --git a/tools/mdctl/.gitignore b/tools/mdctl/.gitignore new file mode 100644 index 0000000..a921e72 --- /dev/null +++ b/tools/mdctl/.gitignore @@ -0,0 +1,5 @@ +mdctl +vendor +*.safetensors +*.model +gemma-2b:* diff --git a/tools/mdctl/cmd/build.go b/tools/mdctl/cmd/build.go new file mode 100644 index 0000000..5300c33 --- /dev/null +++ b/tools/mdctl/cmd/build.go @@ -0,0 +1,168 @@ +package cmd + +import ( + "fmt" + + v2 "github.com/CloudNativeAI/model-spec/specs-go/v2" + "github.com/CloudNativeAI/model-spec/tools/mdctl/format" + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +func BuildModel(commands []format.Command) error { + manifest := v2.Manifest{MediaType: v2.MediaTypeModelManifest} + config := v2.Config{} + weights := v2.Weights{} + engine := v2.Engine{} + + if len(commands) == 0 { + return fmt.Errorf("modelfile has no command") + } + + var modelName string + if commands[0].Name == format.CREATE { + modelName = commands[0].Args + fmt.Println("Create", modelName) + } else if commands[0].Name == 
format.FROM { + modelName = commands[0].Args + fmt.Println("From ", modelName) + if err := PullModel(modelName); err != nil { + return fmt.Errorf("failed to pull base model") + } + _, err := FetchManifest(modelName, &manifest, &config) + if err != nil { + return fmt.Errorf("failed to get remote manifest") + } + } else { + return fmt.Errorf("first command should be %s or %s", format.CREATE, format.FROM) + } + for _, c := range commands { + switch c.Name { + case format.CREATE, format.FROM: + config.Name = c.Args + + case format.NAME: + config.Name = c.Args + + case format.DESCRIPTION: + layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelDescription, "Description") + if err != nil { + return fmt.Errorf("failed to build description layer: %w", err) + } + config.Description = append(config.Description, *layer) + fmt.Printf("Add description [%s]\n", c.Args) + + case format.LICENSE: + layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelLicense, "License") + if err != nil { + return fmt.Errorf("failed to build license layer: %w", err) + } + config.License = append(config.License, *layer) + fmt.Printf("Add license [%s]\n", c.Args) + + case format.ARCHITECTURE: + config.Architecture = c.Args + + case format.FAMILY: + config.Family = c.Args + + case format.CONFIG: + layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelConfig, "") + if err != nil { + return fmt.Errorf("failed to build config layer: %w", err) + } + config.Extensions = append(config.Extensions, *layer) + fmt.Printf("Add config [%s]\n", c.Args) + + case format.PARAM_SIZE: + engine.Name = c.Args + + case format.FORMAT: + weights.Format = c.Args + + case format.WEIGHTS: + layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelWeights, "") + if err != nil { + return fmt.Errorf("failed to build weights layer: %w", err) + } + weights.File = append(weights.File, *layer) + fmt.Printf("Add weights [%s]\n", c.Args) + + case 
format.TOKENIZER: + layer, err := models.BuildDescriptor(models.TAR, c.Args, v2.MediaTypeModelProcessorText, "") + if err != nil { + return fmt.Errorf("failed to build tokenizer layer: %w", err) + } + manifest.Processor = append(manifest.Processor, *layer) + fmt.Printf("Add tokenizer [%s]\n", c.Args) + + default: + fmt.Printf("WARN: [%s] - [%s] not handled\n", c.Name, c.Args) + } + } + + manifest.Config = config + manifest.Weights = weights + manifest.Engine = engine + + // Commit layers + _, err := Commit(&manifest) + if err != nil { + return fmt.Errorf("failed to commit layers: %w", err) + } + + // Commit manifest layer + if err := models.WriteManifest(modelName, &manifest); err != nil { + return fmt.Errorf("failed to write manifest: %w", err) + } + + fmt.Println("Build succeed") + return nil +} + +func Commit(m *v2.Manifest) (bool, error) { + layerGroups := []struct { + name string + layers []oci.Descriptor + }{ + {"Description", m.Config.Description}, + {"License", m.Config.License}, + {"Extensions", m.Config.Extensions}, + {"Weights", m.Weights.File}, + {"Tokenizer", m.Processor}, + } + + var committed bool + for _, group := range layerGroups { + if len(group.layers) == 0 { + continue // Skip empty layer groups + } + groupCommitted, err := commitLayers(group.name, group.layers) + if err != nil { + return false, fmt.Errorf("failed to commit %s layers: %w", group.name, err) + } + committed = committed || groupCommitted + } + return committed, nil +} + +func commitLayers(groupName string, layers []oci.Descriptor) (bool, error) { + var groupCommitted bool + for _, layer := range layers { + layerCommitted, err := commitSingleLayer(groupName, layer) + if err != nil { + return false, err + } + groupCommitted = groupCommitted || layerCommitted + } + return groupCommitted, nil +} + +func commitSingleLayer(groupName string, layer oci.Descriptor) (bool, error) { + committed, err := models.Commit(layer) + if err != nil { + return false, fmt.Errorf("failed to commit %s 
layer: %w", groupName, err) + } + + return committed, nil +} diff --git a/tools/mdctl/cmd/cmd.go b/tools/mdctl/cmd/cmd.go new file mode 100644 index 0000000..de970e6 --- /dev/null +++ b/tools/mdctl/cmd/cmd.go @@ -0,0 +1,145 @@ +package cmd + +import ( + "bytes" + "fmt" + "log" + "os" + "path/filepath" + + "github.com/CloudNativeAI/model-spec/tools/mdctl/format" + "github.com/CloudNativeAI/model-spec/tools/mdctl/progress" + "github.com/spf13/cobra" +) + +func BuildHandler(cmd *cobra.Command, args []string) error { + filename, _ := cmd.Flags().GetString("file") + filename, err := filepath.Abs(filename) + if err != nil { + return fmt.Errorf("failed to get absolute path: %w", err) + } + + p := progress.NewProgress(os.Stderr) + // defer p.Stop() + // bars := make(map[string]*progress.Bar) + + modelFile, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("failed to read modelfile: %w", err) + } + + commands, err := format.Parse(bytes.NewReader(modelFile)) + if err != nil { + return fmt.Errorf("failed to parse modelfile: %w", err) + } + + // status := "building" + // spinner := progress.NewSpinner(status) + // p.Add(status, spinner) + + if err := BuildModel(commands); err != nil { + return fmt.Errorf("failed to build model: %w", err) + } + p.StopAndClear() + + return nil +} + +func RunHandler(cmd *cobra.Command, _ []string) error { + name, _ := cmd.Flags().GetString("name") + fmt.Println("Unpack Model: ", name) + if err := RunModel(name); err != nil { + return fmt.Errorf("failed to unpack model: %w", err) + } + fmt.Println("Unpack succeed") + return nil +} + +func PushHandler(cmd *cobra.Command, _ []string) error { + name, _ := cmd.Flags().GetString("name") + fmt.Println("Push Model:", name) + if err := PushModel(name); err != nil { + return fmt.Errorf("failed to push model: %w", err) + } + return nil +} + +func PullHandler(cmd *cobra.Command, _ []string) error { + name, _ := cmd.Flags().GetString("name") + fmt.Println("Pull Model:", name) + if err := 
PullModel(name); err != nil { + return fmt.Errorf("failed to pull model: %w", err) + } + return nil +} + +func ListHandler(cmd *cobra.Command, args []string) error { + return ListModel() +} + +func NewCLI() *cobra.Command { + log.SetFlags(log.LstdFlags | log.Lshortfile) + cobra.EnableCommandSorting = false + + rootCmd := &cobra.Command{ + Use: "mdctl", + Short: "Model management tool", + SilenceUsage: true, + SilenceErrors: true, + CompletionOptions: cobra.CompletionOptions{ + DisableDefaultCmd: true, + }, + Run: func(cmd *cobra.Command, args []string) { + cmd.Print(cmd.UsageString()) + }, + } + + buildCmd := &cobra.Command{ + Use: "build", + Short: "build models from a Modelfile", + Args: cobra.ExactArgs(0), + RunE: BuildHandler, + } + buildCmd.Flags().StringP("file", "f", "Modelfile", "Path to the Modelfile") + + runCmd := &cobra.Command{ + Use: "unpack", + Short: "run a model", + Args: cobra.ExactArgs(0), + RunE: RunHandler, + } + runCmd.Flags().StringP("name", "n", "", "URL of the model") + + pushCmd := &cobra.Command{ + Use: "push", + Short: "push a model", + Args: cobra.ExactArgs(0), + RunE: PushHandler, + } + pushCmd.Flags().StringP("name", "n", "", "URL of the model") + + pullCmd := &cobra.Command{ + Use: "pull", + Short: "pull a model", + Args: cobra.ExactArgs(0), + RunE: PullHandler, + } + pullCmd.Flags().StringP("name", "n", "", "URL of the model") + + listCmd := &cobra.Command{ + Use: "list", + Short: "list models", + Args: cobra.ExactArgs(0), + RunE: ListHandler, + } + + rootCmd.AddCommand( + buildCmd, + listCmd, + runCmd, + pushCmd, + pullCmd, + ) + + return rootCmd +} diff --git a/tools/mdctl/cmd/list.go b/tools/mdctl/cmd/list.go new file mode 100644 index 0000000..d54a11e --- /dev/null +++ b/tools/mdctl/cmd/list.go @@ -0,0 +1,36 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" +) + +func ListModel() error { + dir, err := models.GetManifestRoot() + if err != nil { + 
return err + } + err = filepath.Walk(dir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return fmt.Errorf("failed to walk model: %w", err) + } + if !info.IsDir() { + term := dir + string(os.PathSeparator) + name := strings.TrimPrefix(path, term) + lastSeparatorIndex := strings.LastIndex(name, string(os.PathSeparator)) + if lastSeparatorIndex != -1 { + name = name[:lastSeparatorIndex] + ":" + name[lastSeparatorIndex+1:] + } + fmt.Println(name) + } + return nil + }) + if err != nil { + return fmt.Errorf("failed to list model: %w", err) + } + return nil +} diff --git a/tools/mdctl/cmd/pull.go b/tools/mdctl/cmd/pull.go new file mode 100644 index 0000000..8b34449 --- /dev/null +++ b/tools/mdctl/cmd/pull.go @@ -0,0 +1,115 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + spec "github.com/CloudNativeAI/model-spec/specs-go/v2" + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" + "github.com/CloudNativeAI/model-spec/tools/mdctl/registry" + v1 "github.com/opencontainers/image-spec/specs-go/v1" +) + +func FetchManifest(name string, manifest *spec.Manifest, config *spec.Config) (*v1.Manifest, error) { + mp := models.ParseModelPath(name) + repo, err := registry.NewRepo(mp.Namespace, mp.Name) + if err != nil { + return nil, fmt.Errorf("failed to new repo: %w", err) + } + ctx := context.Background() + imageManifest, err := registry.PullImageManifest(repo, ctx, mp.Tag) + if err != nil { + return nil, fmt.Errorf("failed to pull image manifest: %w", err) + } + + // Fetch layers + for _, layer := range imageManifest.Layers { + switch layer.MediaType { + case spec.MediaTypeModelManifest, spec.MediaTypeModelConfig: + default: + continue + } + + // create temp file + // TODO: use []byte + tempRoot, err := models.GetBlobsPath("") + if err != nil { + return nil, fmt.Errorf("failed to get blobs path: %w", err) + } + delimiter := ":" + pattern := strings.Join([]string{"sha256", "*-temp"}, 
delimiter) + temp, err := os.CreateTemp(tempRoot, pattern) + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + defer temp.Close() + + //fmt.Println("Pull layer: ", layer.Digest, layer.Size) + err = registry.PullLayer(repo, ctx, layer.Digest.String(), layer.Size, temp.Name()) + if err != nil { + return nil, fmt.Errorf("failed to pull layer: %w", err) + } + content, err := os.ReadFile(temp.Name()) + if err != nil { + return nil, fmt.Errorf("failed to read temp file: %w", err) + } + + switch layer.MediaType { + case spec.MediaTypeModelManifest: + if err := json.Unmarshal(content, &manifest); err != nil { + return nil, fmt.Errorf("failed to unmarshal manifest: %w", err) + } + case spec.MediaTypeModelConfig: + if err := json.Unmarshal(content, &config); err != nil { + return nil, fmt.Errorf("failed to unmarshal config: %w", err) + } + } + } + return imageManifest, nil +} + +func PullModel(name string) error { + mp := models.ParseModelPath(name) + + repo, err := registry.NewRepo(mp.Namespace, mp.Name) + if err != nil { + return fmt.Errorf("failed to new repo: %w", err) + } + ctx := context.Background() + + image_manifest, err := registry.PullImageManifest(repo, ctx, mp.Tag) + if err != nil { + return fmt.Errorf("failed to pull image manifest: %w", err) + } + + for _, layer := range image_manifest.Layers { + fmt.Println("Pull layer:", layer.Digest, layer.Size) + digest := layer.Digest.String() + + var targetPath string + if layer.MediaType == spec.MediaTypeModelManifest { + targetPath, err = mp.GetManifestPath() + targetDir := filepath.Dir(targetPath) + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("failed to mkdir: %w", err) + } + } else { + targetPath, err = models.GetBlobsPath(digest) + } + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + + err = registry.PullLayer(repo, ctx, digest, layer.Size, targetPath) + if err != nil { + return fmt.Errorf("failed to pull layer: %w", err) 
+ } + } + fmt.Println("Pull succeed") + + return nil +} diff --git a/tools/mdctl/cmd/push.go b/tools/mdctl/cmd/push.go new file mode 100644 index 0000000..cc580e3 --- /dev/null +++ b/tools/mdctl/cmd/push.go @@ -0,0 +1,89 @@ +package cmd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os" + + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v2" + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" + "github.com/CloudNativeAI/model-spec/tools/mdctl/registry" + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +func PushModel(name string) error { + mp := models.ParseModelPath(name) + + manifestPath, err := mp.GetManifestPath() + if err != nil { + return fmt.Errorf("failed to get manifest path: %w", err) + } + + manifestFile, err := os.Open(manifestPath) + if err != nil { + return fmt.Errorf("failed to open manifest file: %w", err) + } + defer manifestFile.Close() + + manifest := modelspec.Manifest{} + if err := json.NewDecoder(manifestFile).Decode(&manifest); err != nil { + return fmt.Errorf("failed to decode manifest: %w", err) + } + + repo, err := registry.NewRepo(mp.Namespace, mp.Name) + if err != nil { + return fmt.Errorf("failed to new repo: %w", err) + } + ctx := context.Background() + + var layers []oci.Descriptor + + layerGroups := []struct { + name string + layers []oci.Descriptor + }{ + {"Description", manifest.Config.Description}, + {"License", manifest.Config.License}, + {"Extensions", manifest.Config.Extensions}, + {"Weights", manifest.Weights.File}, + {"Tokenizer", manifest.Processor}, + } + + for _, group := range layerGroups { + if len(group.layers) == 0 { + continue // Skip empty layer groups + } + layers = append(layers, group.layers...) 
+ + for _, layer := range group.layers { + fmt.Println("Push layer:", layer.Digest, layer.Size) + _, err := registry.PushLayer(repo, ctx, &layer) + if err != nil { + return fmt.Errorf("failed to push layer: %w", err) + } + } + } + + manifestDesc, err := registry.PushModelManifest(repo, ctx, manifestPath) + if err != nil { + return fmt.Errorf("failed to push model manifest: %w", err) + } + + // push empty layer + err = repo.Push(ctx, oci.DescriptorEmptyJSON, bytes.NewReader(oci.DescriptorEmptyJSON.Data)) + if err != nil { + return fmt.Errorf("failed to push empty layer: %w", err) + } + + // assemble descriptors and model manifest to a image manifest + layers = append(layers, *manifestDesc) + err = registry.PushModel(repo, mp.Tag, ctx, layers) + if err != nil { + return fmt.Errorf("failed to push oci image manifest: %w", err) + } + + fmt.Println("Push succeed") + return nil +} diff --git a/tools/mdctl/cmd/run.go b/tools/mdctl/cmd/run.go new file mode 100644 index 0000000..383c234 --- /dev/null +++ b/tools/mdctl/cmd/run.go @@ -0,0 +1,152 @@ +package cmd + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" +) + +const ( + DOT_GITS_DIR = ".gits" + DOT_VOLUMES_DIR = ".volumes" + MODEL_DIR = "model" + DATASET_DIR = "dataset" + SOURCE_DIR = "source" + TASK_DIR = "task" + ENTRYPOINT = "run.py" + SETUP = "setup.sh" + CONFIG = "config.json" + INFO = "info.json" + LICENSE = "LICENSE" +) + +func RunModel(name string) error { + mp := models.ParseModelPath(name) + manifest, _, err := models.GetManifest(mp) + if err != nil { + return fmt.Errorf("failed to get manifest: %w", err) + } + + root, err := filepath.Abs(mp.Name + ":" + mp.Tag) + if err != nil { + return fmt.Errorf("failed to get absolute path: %w", err) + } + + if err := os.MkdirAll(root, 0o755); err != nil { + return fmt.Errorf("failed to mkdir: %w", err) + } + + for _, layer := range manifest.Weights.File { + filename, err := 
models.GetBlobsPath(layer.Digest.String()) + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + if err := models.UnTar(filename, root); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + } + + for _, layer := range manifest.Processor { + filename, err := models.GetBlobsPath(layer.Digest.String()) + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + if err := models.UnTar(filename, root); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + } + + for _, layer := range manifest.Config.Description { + filename, err := models.GetBlobsPath(layer.Digest.String()) + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + if err := models.UnTar(filename, root); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + } + + for _, layer := range manifest.Config.License { + filename, err := models.GetBlobsPath(layer.Digest.String()) + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + if err := models.UnTar(filename, root); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + } + + for _, layer := range manifest.Config.Extensions { + filename, err := models.GetBlobsPath(layer.Digest.String()) + if err != nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + if err := models.UnTar(filename, root); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + } + + err = os.Chdir(root) + if err != nil { + return fmt.Errorf("failed to change workdir: %w", err) + } + + // stdout, stderr, err = executeScript(entrypoint, []string{}) + // if err != nil { + // fmt.Printf("Error: %v\n", err) + // if stderr != "" { + // fmt.Printf("Stderr: %v\n", stderr) + // } + // return err + // } + // fmt.Printf("Stdout: %v\n", stdout) + + return nil +} + +func executeBinary(binaryPath string, args []string) (stdout string, stderr string, err error) { + cmd := exec.Command(binaryPath, args...) 
// executeScript runs scriptPath with args through an interpreter chosen by
// the file extension: "bash" for ".sh" and "python3" for ".py". Any other
// extension is rejected without running anything.
//
// It returns everything the script wrote to stdout and stderr. When the
// script fails, the captured output is still returned alongside the error so
// the caller can show why it failed.
func executeScript(scriptPath string, args []string) (stdout string, stderr string, err error) {
	var interpreter string
	switch filepath.Ext(scriptPath) {
	case ".sh":
		interpreter = "bash"
	case ".py":
		interpreter = "python3"
	default:
		return "", "", fmt.Errorf("unsupported script type: %s", scriptPath)
	}

	cmd := exec.Command(interpreter, append([]string{scriptPath}, args...)...)

	var outBuf, errBuf bytes.Buffer
	cmd.Stdout = &outBuf
	cmd.Stderr = &errBuf

	if runErr := cmd.Run(); runErr != nil {
		return outBuf.String(), errBuf.String(), fmt.Errorf("failed to execute script: %w", runErr)
	}

	return outBuf.String(), errBuf.String(), nil
}
// Decimal (SI) byte units: each step is a factor of 1000, not 1024.
const (
	Byte     = 1
	KiloByte = Byte * 1000
	MegaByte = KiloByte * 1000
	GigaByte = MegaByte * 1000
	TeraByte = GigaByte * 1000
)

// HumanBytes renders a byte count using the largest decimal unit (KB, MB, GB,
// TB) that does not exceed it, e.g. "1.5 KB" or "123 MB".
//
// Counts below one kilobyte (including negative ones) are printed verbatim
// with a "B" suffix. Scaled values below 10 that are not whole numbers keep
// one decimal place; everything else is truncated to an integer.
func HumanBytes(b int64) string {
	var value float64
	var unit string

	switch {
	case b >= TeraByte:
		value = float64(b) / TeraByte
		unit = "TB"
	case b >= GigaByte:
		value = float64(b) / GigaByte
		unit = "GB"
	case b >= MegaByte:
		value = float64(b) / MegaByte
		unit = "MB"
	case b >= KiloByte:
		value = float64(b) / KiloByte
		unit = "KB"
	default:
		return fmt.Sprintf("%d B", b)
	}

	// Only small, non-integral values are worth a decimal place.
	if value < 10 && value != math.Trunc(value) {
		return fmt.Sprintf("%.1f %s", value, unit)
	}
	return fmt.Sprintf("%d %s", int(value), unit)
}
// Modelfile directive names, in the lower-case form stored in Command.Name.
const (
	CREATE       = "create"
	FROM         = "from"
	NAME         = "name"
	FAMILY       = "family"
	ARCHITECTURE = "architecture"
	LICENSE      = "license"
	DESCRIPTION  = "description"
	PARAM_SIZE   = "param_size"
	WEIGHTS      = "weights"
	TOKENIZER    = "tokenizer"
	PRECISION    = "precision"
	FORMAT       = "format"
	QUANTIZATION = "quantization"
	CONFIG       = "config"
)

// Command is one parsed Modelfile directive: its lower-cased name and the
// trimmed remainder of the line (or quoted block) as its argument.
type Command struct {
	Name string
	Args string
}

// Reset clears both fields so the value can be reused.
func (c *Command) Reset() {
	c.Name = ""
	c.Args = ""
}

// Parse reads a Modelfile from reader and returns its directives in file
// order.
//
// Each line is split at the first space into a directive name (matched
// case-insensitively) and its argument. A directive with no argument yields a
// Command with empty Args. Lines starting with "#" are skipped silently;
// other unknown directives are skipped with a logged warning.
//
// An error is returned when scanning fails (for example on an unterminated
// quote) or when the last CREATE/FROM directive is missing or has an empty
// argument.
func Parse(reader io.Reader) ([]Command, error) {
	var commands []Command
	var modelCommand Command

	scanner := bufio.NewScanner(reader)
	scanner.Buffer(make([]byte, 0, bufio.MaxScanTokenSize), bufio.MaxScanTokenSize)
	scanner.Split(scanModelfile)
	for scanner.Scan() {
		// Cut never indexes past the end, so a bare directive such as
		// "NAME" is handled instead of panicking.
		rawName, rawArgs, _ := bytes.Cut(scanner.Bytes(), []byte(" "))
		if len(rawName) == 0 {
			continue
		}

		command := Command{
			Name: strings.ToLower(string(rawName)),
			Args: string(bytes.TrimSpace(rawArgs)),
		}

		switch command.Name {
		case CREATE, FROM:
			modelCommand = command
		case NAME, FAMILY, ARCHITECTURE, LICENSE, DESCRIPTION, FORMAT,
			PRECISION, QUANTIZATION, PARAM_SIZE, WEIGHTS, CONFIG, TOKENIZER:
		default:
			if !bytes.HasPrefix(rawName, []byte("#")) {
				log.Printf("WARNING: unknown command: %s", rawName)
			}
			continue
		}

		commands = append(commands, command)
	}

	// Report a scanning failure first; otherwise it would be hidden behind
	// the less specific "no FROM or CREATE" error below.
	if err := scanner.Err(); err != nil {
		return nil, err
	}

	if modelCommand.Args == "" {
		return nil, fmt.Errorf("no FROM or CREATE line was specified")
	}

	return commands, nil
}

// scanModelfile is the bufio.SplitFunc used by Parse. It first tries to
// return a line containing a triple-quoted block, then a line containing a
// double-quoted string (delimiters removed in both cases), and otherwise
// falls back to plain line splitting.
func scanModelfile(data []byte, atEOF bool) (advance int, token []byte, err error) {
	for _, quote := range [][]byte{[]byte(`"""`), []byte(`"`)} {
		advance, token, err = scan(quote, quote, data, atEOF)
		if err != nil {
			return 0, nil, fmt.Errorf("failed to scan modelfile: %w", err)
		}
		if advance > 0 && token != nil {
			return advance, token, nil
		}
	}

	return bufio.ScanLines(data, atEOF)
}

// scan looks for openBytes on the first line of data. If found, it returns a
// token made of the text before the opening delimiter followed by everything
// up to the matching closeBytes (which may lie on a later line), with both
// delimiters removed; advance points just past the closing delimiter.
//
// It returns (0, nil, nil) when the first line has no opening delimiter, or
// when the closing delimiter has not been read yet and more input may
// follow. At end of input a missing closing delimiter is an error.
//
// The token is a fresh slice; data is not modified.
func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token []byte, err error) {
	newline := bytes.IndexByte(data, '\n')
	if newline < 0 && atEOF {
		// Final line without a trailing newline: treat the end of the
		// input as the end of the line so quotes on it are still handled.
		newline = len(data)
	}

	start := bytes.Index(data, openBytes)
	if start < 0 || start >= newline {
		return 0, nil, nil
	}

	bodyStart := start + len(openBytes)
	end := bytes.Index(data[bodyStart:], closeBytes)
	if end < 0 {
		if atEOF {
			return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes)
		}
		return 0, nil, nil
	}
	bodyEnd := bodyStart + end

	token = make([]byte, 0, start+end)
	token = append(token, data[:start]...)
	token = append(token, data[bodyStart:bodyEnd]...)
	return bodyEnd + len(closeBytes), token, nil
}
github.com/klauspost/compress v1.17.7 + github.com/opencontainers/go-digest v1.0.0 + github.com/opencontainers/image-spec v1.1.0 + github.com/spf13/cobra v1.7.0 + oras.land/oras-go/v2 v2.4.0 +) + +require ( + github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02 // indirect + golang.org/x/sync v0.6.0 // indirect +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/term v0.15.0 +) + +replace github.com/CloudNativeAI/model-spec/specs-go/ => ../../specs-go/ diff --git a/tools/mdctl/go.sum b/tools/mdctl/go.sum new file mode 100644 index 0000000..9ad0dc4 --- /dev/null +++ b/tools/mdctl/go.sum @@ -0,0 +1,26 @@ +github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02 h1:hldWO7cYXMsfCFjlQ2VcGd9PfQYt79sPN2mSjtHVrdc= +github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02/go.mod h1:aqXPi8WPdmWT8sUAQYi7gStLYBhiud0dIT75PskIYpE= +github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= +github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra 
v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= +golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +oras.land/oras-go/v2 v2.4.0 h1:i+Wt5oCaMHu99guBD0yuBjdLvX7Lz8ukPbwXdR7uBMs= +oras.land/oras-go/v2 v2.4.0/go.mod h1:osvtg0/ClRq1KkydMAEu/IxFieyjItcsQ4ut4PPF+f8= diff --git a/tools/mdctl/main.go b/tools/mdctl/main.go new file mode 100644 index 0000000..d5c33fe --- /dev/null +++ b/tools/mdctl/main.go @@ -0,0 +1,12 @@ +package main + +import ( + "context" + + "github.com/CloudNativeAI/model-spec/tools/mdctl/cmd" + "github.com/spf13/cobra" +) + +func main() { + cobra.CheckErr(cmd.NewCLI().ExecuteContext(context.Background())) +} diff --git a/tools/mdctl/models/archiver.go b/tools/mdctl/models/archiver.go new file mode 100644 index 0000000..1420a99 --- /dev/null +++ b/tools/mdctl/models/archiver.go @@ -0,0 +1,253 @@ +package models + +import ( + "archive/tar" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/klauspost/compress/zstd" +) + +func Tar(src, dst, newName string) error { + fi, err := os.Stat(src) + if err != nil { + return fmt.Errorf("failed to stat src: %w", err) + } + + if fi.IsDir() { + return 
// tarFile writes a tar archive containing the single file src to every writer
// in writers. The entry is named newName, or the base name of src when
// newName is empty.
//
// The archive is always finalised before returning; a failure while flushing
// the trailing padding and footer is reported instead of being dropped.
func tarFile(src, newName string, writers ...io.Writer) (err error) {
	tw := tar.NewWriter(io.MultiWriter(writers...))
	defer func() {
		// Close writes the entry padding and the archive footer, so its
		// error matters; keep an earlier error if there is one.
		if cerr := tw.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("failed to close tar writer: %w", cerr)
		}
	}()

	file, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("failed to open src: %w", err)
	}
	defer file.Close()

	stat, err := file.Stat()
	if err != nil {
		return fmt.Errorf("failed to stat src: %w", err)
	}

	// todo: check the link arg?
	header, err := tar.FileInfoHeader(stat, stat.Name())
	if err != nil {
		return fmt.Errorf("failed to get file info header: %w", err)
	}
	if newName != "" {
		header.Name = newName
	}

	if err := tw.WriteHeader(header); err != nil {
		return fmt.Errorf("failed to write header: %w", err)
	}

	if _, err := io.Copy(tw, file); err != nil {
		return fmt.Errorf("failed to copy file: %w", err)
	}

	return nil
}
{ + // Skip the root directory entry + return nil + } + header, err := tar.FileInfoHeader(fi, relPath) + if err != nil { + return fmt.Errorf("failed to get file info header: %w", err) + } + header.Name = relPath + if err := tw.WriteHeader(header); err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + if fi.IsDir() { + return nil + } + data, err := os.Open(file) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer data.Close() + if _, err := io.Copy(tw, data); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + return nil + } + + srcInfo, err := os.Stat(srcPath) + if err != nil { + return fmt.Errorf("failed to stat src: %w", err) + } + if srcInfo.IsDir() { + // Walk the source path and tar each file and directory + if err := filepath.Walk(srcPath, tarWalkFn); err != nil { + return fmt.Errorf("failed to walk file: %w", err) + } + } + + return nil +} + +func compressFile(srcPath, dstPath string) error { + srcFile, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer srcFile.Close() + + dstFile, err := os.Create(dstPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer dstFile.Close() + + encoder, err := zstd.NewWriter(dstFile) + if err != nil { + return fmt.Errorf("failed to create zstd encoder: %w", err) + } + defer encoder.Close() + + if _, err := io.Copy(encoder, srcFile); err != nil { + return fmt.Errorf("failed to compress file: %w", err) + } + + return nil +} + +func decompressFile(srcPath, dstPath string) error { + srcFile, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("failed to open source file: %w", err) + } + defer srcFile.Close() + + dstFile, err := os.Create(dstPath) + if err != nil { + return fmt.Errorf("failed to create destination file: %w", err) + } + defer dstFile.Close() + + decoder, err := zstd.NewReader(srcFile) + if err != nil { + return fmt.Errorf("failed to 
// untar extracts the tar archive at tarPath into the directory dstPath.
//
// Only directory and regular-file entries are extracted; all other entry
// types are skipped. An entry whose name would resolve outside dstPath (for
// example via "..") aborts the extraction with an error. Existing files are
// overwritten and truncated, and missing parent directories are created.
func untar(tarPath, dstPath string) error {
	fmt.Printf("Unpack layer: %s\n", filepath.Base(tarPath))

	archive, err := os.Open(tarPath)
	if err != nil {
		return fmt.Errorf("failed to open tar file: %w", err)
	}
	defer archive.Close()

	tr := tar.NewReader(archive)
	for {
		header, err := tr.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("failed to read tar file: %w", err)
		}

		// Archive contents are untrusted: refuse any entry that does not
		// stay below dstPath once joined and cleaned.
		target := filepath.Join(dstPath, header.Name)
		rel, err := filepath.Rel(dstPath, target)
		if err != nil || !filepath.IsLocal(rel) {
			return fmt.Errorf("failed to untar: entry %q escapes destination directory", header.Name)
		}

		switch header.Typeflag {
		case tar.TypeDir:
			if err := os.MkdirAll(target, os.FileMode(header.Mode)); err != nil {
				return fmt.Errorf("failed to create directory: %w", err)
			}
		case tar.TypeReg:
			if err := extractRegularFile(target, os.FileMode(header.Mode), tr); err != nil {
				return err
			}
		}
	}
}

// extractRegularFile writes everything readable from r to target, creating
// parent directories as needed and replacing any previous content.
func extractRegularFile(target string, mode os.FileMode, r io.Reader) error {
	// Archives are not required to carry an entry for every directory.
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	// O_TRUNC: without it, overwriting a longer existing file would leave
	// its old tail behind the new content.
	file, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
	if err != nil {
		return fmt.Errorf("failed to create file: %w", err)
	}
	if _, err := io.Copy(file, r); err != nil {
		_ = file.Close()
		return fmt.Errorf("failed to write file: %w", err)
	}
	if err := file.Close(); err != nil {
		return fmt.Errorf("failed to write file: %w", err)
	}
	return nil
}

// UnTar extracts the tar archive srcFilePath into destDir. See untar for the
// extraction rules.
func UnTar(srcFilePath, destDir string) error {
	return untar(srcFilePath, destDir)
}
} + defer os.Remove(tempTarPath) + + if err := untar(tempTarPath, destDir); err != nil { + return fmt.Errorf("failed to untar: %w", err) + } + + return nil +} diff --git a/tools/mdctl/models/descriptor.go b/tools/mdctl/models/descriptor.go new file mode 100644 index 0000000..22772b2 --- /dev/null +++ b/tools/mdctl/models/descriptor.go @@ -0,0 +1,154 @@ +package models + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/opencontainers/go-digest" + oci "github.com/opencontainers/image-spec/specs-go/v1" +) + +const ( + TAR = 1 + ZSTD = 2 +) + +func FastDescriptor(srcPath string, media string) (*oci.Descriptor, error) { + var err error + srcPath, err = filepath.Abs(srcPath) + if err != nil { + return nil, fmt.Errorf("get real path failed: %w", err) + } + + bin, err := os.Open(srcPath) + if err != nil { + return nil, fmt.Errorf("failed to open source file: %w", err) + } + defer bin.Close() + + sha256sum := sha256.New() + n, err := io.Copy(sha256sum, bin) + if err != nil { + return nil, fmt.Errorf("failed to read for sha256: %w", err) + } + + return &oci.Descriptor{ + MediaType: media, + Digest: digest.Digest(fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))), + Size: n, + Annotations: map[string]string{ + "temp_name": srcPath, + }, + }, nil +} + +func BuildDescriptor(method int, src, media, newName string) (*oci.Descriptor, error) { + var err error + src, err = filepath.Abs(src) + if err != nil { + return nil, fmt.Errorf("get real path failed: %w", err) + } + + blobs, err := GetBlobsPath("") + if err != nil { + return nil, fmt.Errorf("failed to get blobs path: %w", err) + } + + delimiter := ":" + pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter) + temp, err := os.CreateTemp(blobs, pattern) + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + defer temp.Close() + + switch method { + case TAR: + err = Tar(src, temp.Name(), newName) + case ZSTD: + err = Compress(src, temp.Name()) 
+ default: + err = Tar(src, temp.Name(), newName) + } + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + + bin, err := os.Open(temp.Name()) + if err != nil { + return nil, fmt.Errorf("failed to open source file: %w", err) + } + defer bin.Close() + + sha256sum := sha256.New() + n, err := io.Copy(sha256sum, bin) + if err != nil { + return nil, fmt.Errorf("failed to read for sha256: %w", err) + } + + return &oci.Descriptor{ + MediaType: media, + Digest: digest.Digest(fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))), + Size: n, + Annotations: map[string]string{ + "temp_name": temp.Name(), + }, + }, nil +} + +func NewDescriptor(r io.Reader, mediatype string) (*oci.Descriptor, error) { + blobs, err := GetBlobsPath("") + if err != nil { + return nil, fmt.Errorf("failed to get blobs path: %w", err) + } + + delimiter := ":" + pattern := strings.Join([]string{"sha256", "*-partial"}, delimiter) + temp, err := os.CreateTemp(blobs, pattern) + if err != nil { + return nil, fmt.Errorf("failed to create temp file: %w", err) + } + defer temp.Close() + + sha256sum := sha256.New() + n, err := io.Copy(io.MultiWriter(temp, sha256sum), r) + if err != nil { + return nil, fmt.Errorf("failed to read for sha256: %w", err) + } + + return &oci.Descriptor{ + MediaType: mediatype, + //Digest: fmt.Sprintf("sha256:%x", sha256sum.Sum(nil)), + Digest: digest.Digest(fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))), + Size: n, + Annotations: map[string]string{ + "temp_name": temp.Name(), + }, + }, nil +} + +func Commit(l oci.Descriptor) (bool, error) { + tempFileName := l.Annotations["temp_name"] + if tempFileName == "" { + return false, fmt.Errorf("temp file name is empty") + } + + // always remove temp + defer os.Remove(tempFileName) + defer delete(l.Annotations, "temp_name") + + blob, err := GetBlobsPath(l.Digest.String()) + if err != nil { + return false, fmt.Errorf("failed to get blobs path: %w", err) + } + + if _, err := os.Stat(blob); err != nil { + return true, 
// GetSHA256Digest consumes r until EOF and reports the SHA-256 digest of
// everything it read, formatted as "sha256:<hex>", together with the number of
// bytes consumed. A read error terminates the process through log.Fatal.
func GetSHA256Digest(r io.Reader) (string, int64) {
	hasher := sha256.New()

	size, copyErr := io.Copy(hasher, r)
	if copyErr != nil {
		log.Fatal(copyErr)
	}

	sum := hasher.Sum(nil)
	return fmt.Sprintf("sha256:%x", sum), size
}
item := range ls.Items { + if item.MediaType != layer.MediaType { + newItems = append(newItems, item) + } + } + ls.Items = newItems +} + +func (ls *Descriptors) AddFile(input string, media string) (*oci.Descriptor, error) { + var err error + fileInfo, err := os.Stat(input) + if err != nil { + return nil, fmt.Errorf("failed to stat input: %w", err) + } + if fileInfo.Mode().IsRegular() { + bin, err := os.Open(input) + if err != nil { + return nil, fmt.Errorf("failed to open input: %w", err) + } + defer bin.Close() + layer, err := NewDescriptor(bin, media) + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + ls.Add(layer) + return layer, nil + } + return nil, errors.New("not a regular file type") +} + +func (ls *Descriptors) AddCompress(srcDir string, media string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(ZSTD, srcDir, media, "") + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + //fmt.Printf("add layer: %v\n", layer) + ls.Add(layer) + return layer, nil +} + +func (ls *Descriptors) ReplaceCompress(srcDir string, media string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(ZSTD, srcDir, media, "") + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + //fmt.Printf("add layer: %v\n", layer) + ls.Replace(layer) + return layer, nil +} + +func (ls *Descriptors) AddTar(srcDir string, media string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(TAR, srcDir, media, "") + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + ls.Add(layer) + return layer, nil +} + +func (ls *Descriptors) ReplaceTar(srcDir string, media string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(TAR, srcDir, media, "") + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + ls.Replace(layer) + return layer, nil +} + +func (ls *Descriptors) AddTarWithNewName(src, media, newName 
string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(TAR, src, media, newName) + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + ls.Add(layer) + return layer, nil +} + +func (ls *Descriptors) ReplaceTarWithNewName(src, media, newName string) (*oci.Descriptor, error) { + layer, err := BuildDescriptor(TAR, src, media, newName) + if err != nil { + return nil, fmt.Errorf("failed to build descriptor: %w", err) + } + ls.Replace(layer) + return layer, nil +} + +// +//func (ls *Descriptors) CreateDescriptor(input string, mediatype string) (*spec.BuildDescriptor, error) { +// var layer *spec.BuildDescriptor +// var err error +// fileInfo, err := os.Stat(input) +// if err != nil { // Check error type +// bin := strings.NewReader(input) +// layer, err = NewDescriptor(bin, mediatype) +// if err != nil { +// return nil, err +// } +// ls.Add(layer) +// return layer, nil +// } +// if fileInfo.Mode().IsRegular() { +// bin, err := os.Open(input) +// if err != nil { +// return nil, err +// } +// layer, err := NewDescriptor(bin, mediatype) +// if err != nil { +// bin.Close() +// return nil, err +// } +// ls.Add(layer) +// bin.Close() +// } +// return layer, nil +//} + +func (ls *Descriptors) Commit() error { + // Commit every layer + for _, layer := range ls.Items { + committed, err := Commit(layer) + if err != nil { + return fmt.Errorf("failed to commit layer: %w", err) + } + status := "writing layer" + if !committed { + status = "layer already exists" + } + fmt.Printf("%s %s\n", status, layer.Digest) + } + return nil +} diff --git a/tools/mdctl/models/manifests.go b/tools/mdctl/models/manifests.go new file mode 100644 index 0000000..4b5b5d1 --- /dev/null +++ b/tools/mdctl/models/manifests.go @@ -0,0 +1,60 @@ +package models + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "os" + "path/filepath" + + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v2" +) + +func WriteManifest(name string, 
manifest *modelspec.Manifest) error { + var b bytes.Buffer + if err := json.NewEncoder(&b).Encode(manifest); err != nil { + return fmt.Errorf("failed to encode manifest: %w", err) + } + + modelPath := ParseModelPath(name) + // modelPath := ParseModelPath("") + manifestPath, err := modelPath.GetManifestPath() + if err != nil { + return fmt.Errorf("failed to get manifest path: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil { + return fmt.Errorf("failed to mkdir all: %w", err) + } + + return os.WriteFile(manifestPath, b.Bytes(), 0644) +} + +func GetManifest(mp ModelPath) (*modelspec.Manifest, string, error) { + fp, err := mp.GetManifestPath() + if err != nil { + return nil, "", fmt.Errorf("failed to get manifest path: %w", err) + } + + if _, err = os.Stat(fp); err != nil { + return nil, "", fmt.Errorf("failed to stat manifest: %w", err) + } + + var manifest *modelspec.Manifest + + bts, err := os.ReadFile(fp) + if err != nil { + return nil, "", fmt.Errorf("failed to read manifest: %w", err) + } + + shaSum := sha256.Sum256(bts) + shaStr := hex.EncodeToString(shaSum[:]) + + if err := json.Unmarshal(bts, &manifest); err != nil { + return nil, "", fmt.Errorf("failed to unmarshal manifest: %w", err) + } + + return manifest, shaStr, nil +} diff --git a/tools/mdctl/models/modelpath.go b/tools/mdctl/models/modelpath.go new file mode 100644 index 0000000..a2711f9 --- /dev/null +++ b/tools/mdctl/models/modelpath.go @@ -0,0 +1,187 @@ +package models + +import ( + "errors" + "fmt" + "net/url" + "os" + "path/filepath" + "strings" +) + +const ( + DefaultProtocolScheme = "https" + DefaultRegistry = "registry.cnai.com" + DefaultNamespace = "sys" + DefaultTag = "latest" +) + +var ( + ErrInvalidImageFormat = errors.New("invalid models format") + ErrInvalidProtocol = errors.New("invalid protocol scheme") + ErrInsecureProtocol = errors.New("insecure protocol http") +) + +var errModelPathInvalid = errors.New("invalid models path") + +func 
realpath(mfDir, from string) string { + abspath, err := filepath.Abs(from) + if err != nil { + return from + } + + home, err := os.UserHomeDir() + if err != nil { + return abspath + } + + if from == "~" { + return home + } else if strings.HasPrefix(from, "~/") { + return filepath.Join(home, from[2:]) + } + + if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil { + // this is a file relative to the Modelfile + return filepath.Join(mfDir, from) + } + + return abspath +} + +func ParseModelPath(name string) ModelPath { + mp := ModelPath{ + ProtocolScheme: DefaultProtocolScheme, + Registry: DefaultRegistry, + Namespace: DefaultNamespace, + Name: "", + Tag: DefaultTag, + } + + before, after, found := strings.Cut(name, "://") + if found { + mp.ProtocolScheme = before + name = after + } + + parts := strings.Split(name, string(os.PathSeparator)) + switch len(parts) { + case 3: + mp.Registry = parts[0] + mp.Namespace = parts[1] + mp.Name = parts[2] + case 2: + mp.Namespace = parts[0] + mp.Name = parts[1] + case 1: + mp.Name = parts[0] + } + + if repo, tag, found := strings.Cut(mp.Name, ":"); found { + mp.Name = repo + mp.Tag = tag + } + + return mp +} + +// ModelsDir returns the path to the models directory. 
// ModelsDir returns the path to the models directory.
//
// The MODELS_DIR environment variable wins when it is set to a non-empty
// value. Otherwise the directory is ".models" under the current user's home
// directory. An empty MODELS_DIR is treated as unset: returning "" would make
// every derived blob and manifest path relative to the working directory.
func ModelsDir() (string, error) {
	if models := os.Getenv("MODELS_DIR"); models != "" {
		return models, nil
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("failed to get user home dir: %w", err)
	}
	return filepath.Join(home, ".models"), nil
}
// formatDuration limits the rendering of a time.Duration to 2 units.
//
// The duration is rounded to whole seconds before it is classified, so the
// branch is chosen from the value that will actually be shown:
//   - 100h and above render as "99h+",
//   - 1h up to 100h render as "<h>h<m>m",
//   - anything shorter uses time.Duration's own formatting ("59m59s", "5s").
//
// Rounding first matters near the hour boundary: 59m59.6s rounds to 1h and
// must render as "1h0m", not as the three-unit "1h0m0s".
func formatDuration(d time.Duration) string {
	d = d.Round(time.Second)

	switch {
	case d >= 100*time.Hour:
		return "99h+"
	case d >= time.Hour:
		return fmt.Sprintf("%dh%dm", int(d.Hours()), int(d.Minutes())%60)
	default:
		return d.String()
	}
}
len(b.message) > 0 { + message := strings.TrimSpace(b.message) + if b.messageWidth > 0 && len(message) > b.messageWidth { + message = message[:b.messageWidth] + } + + fmt.Fprintf(&pre, "%s", message) + if padding := b.messageWidth - pre.Len(); padding > 0 { + pre.WriteString(repeat(" ", padding)) + } + + pre.WriteString(" ") + } + + fmt.Fprintf(&pre, "%3.0f%%", b.percent()) + + var suf strings.Builder + // max 13 characters: "999 MB/999 MB" + if b.stopped.IsZero() { + curValue := format.HumanBytes(b.currentValue) + suf.WriteString(repeat(" ", 6-len(curValue))) + suf.WriteString(curValue) + suf.WriteString("/") + + maxValue := format.HumanBytes(b.maxValue) + suf.WriteString(repeat(" ", 6-len(maxValue))) + suf.WriteString(maxValue) + } else { + maxValue := format.HumanBytes(b.maxValue) + suf.WriteString(repeat(" ", 6-len(maxValue))) + suf.WriteString(maxValue) + suf.WriteString(repeat(" ", 7)) + } + + rate := b.rate() + // max 10 characters: " 999 MB/s" + if b.stopped.IsZero() && rate > 0 { + suf.WriteString(" ") + humanRate := format.HumanBytes(int64(rate)) + suf.WriteString(repeat(" ", 6-len(humanRate))) + suf.WriteString(humanRate) + suf.WriteString("/s") + } else { + suf.WriteString(repeat(" ", 10)) + } + + // max 8 characters: " 59m59s" + if b.stopped.IsZero() && rate > 0 { + suf.WriteString(" ") + var remaining time.Duration + if rate > 0 { + remaining = time.Duration(int64(float64(b.maxValue-b.currentValue)/rate)) * time.Second + } + + humanRemaining := formatDuration(remaining) + suf.WriteString(repeat(" ", 6-len(humanRemaining))) + suf.WriteString(humanRemaining) + } else { + suf.WriteString(repeat(" ", 8)) + } + + var mid strings.Builder + // add 5 extra spaces: 2 boundary characters and 1 space at each end + f := termWidth - pre.Len() - suf.Len() - 5 + n := int(float64(f) * b.percent() / 100) + + mid.WriteString(" ▕") + + if n > 0 { + mid.WriteString(repeat("█", n)) + } + + if f-n > 0 { + mid.WriteString(repeat(" ", f-n)) + } + + mid.WriteString("▏ ") + + 
return pre.String() + mid.String() + suf.String() +} + +func (b *Bar) Set(value int64) { + if value >= b.maxValue { + value = b.maxValue + } + + b.currentValue = value + if b.currentValue >= b.maxValue { + b.stopped = time.Now() + } + + // throttle bucket updates to 1 per second + if len(b.buckets) == 0 || time.Since(b.buckets[len(b.buckets)-1].updated) > time.Second { + b.buckets = append(b.buckets, bucket{ + updated: time.Now(), + value: value, + }) + + if len(b.buckets) > b.maxBuckets { + b.buckets = b.buckets[1:] + } + } +} + +func (b *Bar) percent() float64 { + if b.maxValue > 0 { + return float64(b.currentValue) / float64(b.maxValue) * 100 + } + + return 0 +} + +func (b *Bar) rate() float64 { + var numerator, denominator float64 + + if !b.stopped.IsZero() { + numerator = float64(b.currentValue - b.initialValue) + denominator = b.stopped.Sub(b.started).Round(time.Second).Seconds() + } else { + switch len(b.buckets) { + case 0: + // noop + case 1: + numerator = float64(b.buckets[0].value - b.initialValue) + denominator = b.buckets[0].updated.Sub(b.started).Round(time.Second).Seconds() + default: + first, last := b.buckets[0], b.buckets[len(b.buckets)-1] + numerator = float64(last.value - first.value) + denominator = last.updated.Sub(first.updated).Round(time.Second).Seconds() + } + } + + if denominator != 0 { + return numerator / denominator + } + + return 0 +} + +func repeat(s string, n int) string { + if n > 0 { + return strings.Repeat(s, n) + } + + return "" +} diff --git a/tools/mdctl/progress/progress.go b/tools/mdctl/progress/progress.go new file mode 100644 index 0000000..78917e9 --- /dev/null +++ b/tools/mdctl/progress/progress.go @@ -0,0 +1,113 @@ +package progress + +import ( + "fmt" + "io" + "sync" + "time" +) + +type State interface { + String() string +} + +type Progress struct { + mu sync.Mutex + w io.Writer + + pos int + + ticker *time.Ticker + states []State +} + +func NewProgress(w io.Writer) *Progress { + p := &Progress{w: w} + go p.start() + 
return p +} + +func (p *Progress) stop() bool { + for _, state := range p.states { + if spinner, ok := state.(*Spinner); ok { + spinner.Stop() + } + } + + if p.ticker != nil { + p.ticker.Stop() + p.ticker = nil + p.render() + return true + } + + return false +} + +func (p *Progress) Stop() bool { + stopped := p.stop() + if stopped { + fmt.Fprint(p.w, "\n") + } + return stopped +} + +func (p *Progress) StopAndClear() bool { + fmt.Fprint(p.w, "\033[?25l") + defer fmt.Fprint(p.w, "\033[?25h") + + stopped := p.stop() + if stopped { + // clear all progress lines + for i := 0; i < p.pos; i++ { + if i > 0 { + fmt.Fprint(p.w, "\033[A") + } + fmt.Fprint(p.w, "\033[2K\033[1G") + } + } + + return stopped +} + +func (p *Progress) Add(key string, state State) { + p.mu.Lock() + defer p.mu.Unlock() + + p.states = append(p.states, state) +} + +func (p *Progress) render() error { + p.mu.Lock() + defer p.mu.Unlock() + + fmt.Fprint(p.w, "\033[?25l") + defer fmt.Fprint(p.w, "\033[?25h") + + // clear already rendered progress lines + for i := 0; i < p.pos; i++ { + if i > 0 { + fmt.Fprint(p.w, "\033[A") + } + fmt.Fprint(p.w, "\033[2K\033[1G") + } + + // render progress lines + for i, state := range p.states { + fmt.Fprint(p.w, state.String()) + if i < len(p.states)-1 { + fmt.Fprint(p.w, "\n") + } + } + + p.pos = len(p.states) + + return nil +} + +func (p *Progress) start() { + p.ticker = time.NewTicker(100 * time.Millisecond) + for range p.ticker.C { + p.render() + } +} diff --git a/tools/mdctl/progress/spinner.go b/tools/mdctl/progress/spinner.go new file mode 100644 index 0000000..02f3f9f --- /dev/null +++ b/tools/mdctl/progress/spinner.go @@ -0,0 +1,73 @@ +package progress + +import ( + "fmt" + "strings" + "time" +) + +type Spinner struct { + message string + messageWidth int + + parts []string + + value int + + ticker *time.Ticker + started time.Time + stopped time.Time +} + +func NewSpinner(message string) *Spinner { + s := &Spinner{ + message: message, + parts: []string{ + "⠋", 
"⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏", + }, + started: time.Now(), + } + go s.start() + return s +} + +func (s *Spinner) String() string { + var sb strings.Builder + if len(s.message) > 0 { + message := strings.TrimSpace(s.message) + if s.messageWidth > 0 && len(message) > s.messageWidth { + message = message[:s.messageWidth] + } + + fmt.Fprintf(&sb, "%s", message) + if padding := s.messageWidth - sb.Len(); padding > 0 { + sb.WriteString(strings.Repeat(" ", padding)) + } + + sb.WriteString(" ") + } + + if s.stopped.IsZero() { + spinner := s.parts[s.value] + sb.WriteString(spinner) + sb.WriteString(" ") + } + + return sb.String() +} + +func (s *Spinner) start() { + s.ticker = time.NewTicker(100 * time.Millisecond) + for range s.ticker.C { + s.value = (s.value + 1) % len(s.parts) + if !s.stopped.IsZero() { + return + } + } +} + +func (s *Spinner) Stop() { + if s.stopped.IsZero() { + s.stopped = time.Now() + } +} diff --git a/tools/mdctl/registry/client.go b/tools/mdctl/registry/client.go new file mode 100644 index 0000000..bd3e687 --- /dev/null +++ b/tools/mdctl/registry/client.go @@ -0,0 +1,193 @@ +package registry + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + + modelspec "github.com/CloudNativeAI/model-spec/specs-go/v2" + "github.com/CloudNativeAI/model-spec/tools/mdctl/models" + "github.com/opencontainers/image-spec/specs-go" + "oras.land/oras-go/v2/content" + + oci "github.com/opencontainers/image-spec/specs-go/v1" + "oras.land/oras-go/v2/registry/remote" + "oras.land/oras-go/v2/registry/remote/auth" + "oras.land/oras-go/v2/registry/remote/retry" +) + +const ( + ArtifactType = "application/vnd.test.artifact" +) + +func NewRepo(ns, name string) (*remote.Repository, error) { + user, exists := os.LookupEnv("MODEL_REGISTRY_USER") + if !exists { + return nil, fmt.Errorf("username not found") + } + password, exists := os.LookupEnv("MODEL_REGISTRY_PASSWORD") + if !exists { + return nil, fmt.Errorf("password not found") + } 
+ reg, exists := os.LookupEnv("MODEL_REGISTRY_URL") + if !exists { + return nil, fmt.Errorf("registry url not found") + } + + repo, err := remote.NewRepository(reg + "/" + ns + "/" + name) + if err != nil { + return nil, fmt.Errorf("failed to new repository: %w", err) + } + repo.PlainHTTP = true + + repo.Client = &auth.Client{ + Client: retry.DefaultClient, + Cache: auth.NewCache(), + Credential: auth.StaticCredential(reg, auth.Credential{ + Username: user, + Password: password, + }), + } + + return repo, nil +} + +func NewOciManifest(layers []oci.Descriptor) ([]byte, error) { + content := oci.Manifest{ + MediaType: oci.MediaTypeImageManifest, + ArtifactType: ArtifactType, + Config: oci.DescriptorEmptyJSON, + Layers: layers, + Versioned: specs.Versioned{SchemaVersion: 2}, + } + return json.Marshal(content) +} + +func PushLayer(repo *remote.Repository, ctx context.Context, descriptor *oci.Descriptor) (bool, error) { + layerPath, err := models.GetBlobsPath(descriptor.Digest.String()) + if err != nil { + return false, fmt.Errorf("failed to get blobs path: %w", err) + } + + layerFile, err := os.Open(layerPath) + if err != nil { + return false, fmt.Errorf("failed to open layer file: %w", err) + } + defer layerFile.Close() + + exist, err := repo.Exists(ctx, *descriptor) + if err != nil { + return false, fmt.Errorf("failed to check if layer exists: %w", err) + } + if exist { + return true, nil + } + + return false, repo.Push(ctx, *descriptor, layerFile) +} + +func PushModelManifest(repo *remote.Repository, ctx context.Context, modelManifestPath string) (*oci.Descriptor, error) { + modelManifestFile, err := os.Open(modelManifestPath) + if err != nil { + return nil, fmt.Errorf("failed to open model manifest file: %w", err) + } + defer modelManifestFile.Close() + + descriptor, err := models.FastDescriptor(modelManifestPath, modelspec.MediaTypeModelManifest) + if err != nil { + return nil, fmt.Errorf("failed to fast descriptor: %w", err) + } + + fmt.Println("Push manifest:", 
descriptor.Digest, descriptor.Size) + err = repo.Push(ctx, *descriptor, modelManifestFile) + if err != nil { + return nil, fmt.Errorf("failed to push model manifest: %w", err) + } + + return descriptor, nil +} + +func PushModel(repo *remote.Repository, tag string, ctx context.Context, layers []oci.Descriptor) error { + manifestBlob, err := NewOciManifest(layers) + if err != nil { + return fmt.Errorf("failed to new oci manifest: %w", err) + } + manifestDesc := content.NewDescriptorFromBytes(oci.MediaTypeImageManifest, manifestBlob) + + err = repo.PushReference(ctx, manifestDesc, bytes.NewReader(manifestBlob), tag) + if err != nil { + return fmt.Errorf("failed to push model: %w", err) + } + + return nil +} + +func PullImageManifest(repo *remote.Repository, ctx context.Context, tag string) (*oci.Manifest, error) { + descriptor, err := repo.Resolve(ctx, tag) + if err != nil { + return nil, fmt.Errorf("failed to resolve image manifest: %w", err) + } + + pulledBlob, err := content.FetchAll(ctx, repo, descriptor) + if err != nil { + return nil, fmt.Errorf("failed to fetch all: %w", err) + } + + manifest := oci.Manifest{} + if err := json.NewDecoder(bytes.NewReader(pulledBlob)).Decode(&manifest); err != nil { + return nil, fmt.Errorf("failed to decode manifest: %w", err) + } + + return &manifest, nil +} + +func PullLayer(repo *remote.Repository, ctx context.Context, digest string, size int64, targetPath string) error { + if fi, err := os.Stat(targetPath); err == nil && fi.Mode().IsRegular() && fi.Size() == size { + return nil + } + + descriptor, err := repo.Blobs().Resolve(ctx, digest) + if err != nil { + return fmt.Errorf("failed to resolve blob: %w", err) + } + + rc, err := repo.Fetch(ctx, descriptor) + if err != nil { + return fmt.Errorf("failed to fetch blob: %w", err) + } + defer rc.Close() + + pulledBlob, err := content.ReadAll(rc, descriptor) + if err != nil { + return fmt.Errorf("failed to read all: %w", err) + } + + blobs, err := models.GetBlobsPath("") + if err 
!= nil { + return fmt.Errorf("failed to get blobs path: %w", err) + } + + delimiter := ":" + pattern := strings.Join([]string{"sha256", "*-downloading"}, delimiter) + temp, err := os.CreateTemp(blobs, pattern) + if err != nil { + return fmt.Errorf("failed to create temp file: %w", err) + } + defer temp.Close() + + _, err = io.Copy(temp, bytes.NewReader(pulledBlob)) + if err != nil { + return fmt.Errorf("failed to copy blob to temp file: %w", err) + } + + err = os.Rename(temp.Name(), targetPath) + if err != nil { + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} diff --git a/tools/mdctl/version/version.go b/tools/mdctl/version/version.go new file mode 100644 index 0000000..820e2f7 --- /dev/null +++ b/tools/mdctl/version/version.go @@ -0,0 +1,3 @@ +package version + +var Version string = "0.0.0" From c19996fead3350466023ffc4f2714ea34b2f498d Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 11:54:51 +0800 Subject: [PATCH 5/8] docs: improve readme for examples --- .../mdctl/examples/huggingface/gemma-2b/README.md | 14 +++++++++++++- tools/mdctl/examples/huggingface/gemma-2b/run.py | 10 ---------- 2 files changed, 13 insertions(+), 11 deletions(-) delete mode 100755 tools/mdctl/examples/huggingface/gemma-2b/run.py diff --git a/tools/mdctl/examples/huggingface/gemma-2b/README.md b/tools/mdctl/examples/huggingface/gemma-2b/README.md index 11a27f5..eda6742 100644 --- a/tools/mdctl/examples/huggingface/gemma-2b/README.md +++ b/tools/mdctl/examples/huggingface/gemma-2b/README.md @@ -1 +1,13 @@ -gemma-2b \ No newline at end of file +# How to use mdctl examples +## Download model +Download model from huggingface: +``` +git lfs install +git clone https://huggingface.co/gemma-ai/gemma-2b +``` + +## Build model image +Put the modelfile to the model directory and build model image: +``` +mdctl build -f Modelfile +``` diff --git a/tools/mdctl/examples/huggingface/gemma-2b/run.py b/tools/mdctl/examples/huggingface/gemma-2b/run.py deleted 
file mode 100755 index 4788a6d..0000000 --- a/tools/mdctl/examples/huggingface/gemma-2b/run.py +++ /dev/null @@ -1,10 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForCausalLM - -tokenizer = AutoTokenizer.from_pretrained("gemma-2b:latest") -model = AutoModelForCausalLM.from_pretrained("gemma-2b:latest") - -input_text = "Who are you?" -input_ids = tokenizer(input_text, return_tensors="pt") - -outputs = model.generate(**input_ids, max_length=64) -print(tokenizer.decode(outputs[0])) From ee6431c4261190ac3e62a8be812556107343812c Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 11:56:08 +0800 Subject: [PATCH 6/8] build: improve go mod --- .github/workflows/tools-lint.yml | 31 ------------------------------- go.mod | 13 +++++++++++++ tools/mdctl/go.sum => go.sum | 24 +++++++++++------------- tools/mdctl/go.mod | 25 ------------------------- 4 files changed, 24 insertions(+), 69 deletions(-) delete mode 100644 .github/workflows/tools-lint.yml rename tools/mdctl/go.sum => go.sum (52%) delete mode 100644 tools/mdctl/go.mod diff --git a/.github/workflows/tools-lint.yml b/.github/workflows/tools-lint.yml deleted file mode 100644 index ce23006..0000000 --- a/.github/workflows/tools-lint.yml +++ /dev/null @@ -1,31 +0,0 @@ -name: Lint - -on: - push: - branches: [main, release-*] - pull_request: - branches: [main, release-*] - -permissions: - contents: read - -jobs: - lint: - name: Lint - runs-on: ubuntu-latest - timeout-minutes: 30 - steps: - - name: Checkout code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 - - - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 - with: - go-version-file: tools/mdctl/go.mod - cache: false - - - name: Golangci lint - uses: golangci/golangci-lint-action@aaa42aa0628b4ae2578232a66b541047968fac86 - with: - version: v1.54 - args: --verbose - working-directory: tools/mdctl diff --git a/go.mod b/go.mod index 8e6befa..9bd5e0e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,16 @@ module 
github.com/CloudNativeAI/model-spec go 1.22.4 + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/klauspost/compress v1.17.10 // indirect + github.com/opencontainers/go-digest v1.0.0 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect + github.com/spf13/cobra v1.8.1 // indirect + github.com/spf13/pflag v1.0.5 // indirect + golang.org/x/sync v0.6.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/term v0.24.0 // indirect + oras.land/oras-go/v2 v2.5.0 // indirect +) diff --git a/tools/mdctl/go.sum b/go.sum similarity index 52% rename from tools/mdctl/go.sum rename to go.sum index 9ad0dc4..6bf2313 100644 --- a/tools/mdctl/go.sum +++ b/go.sum @@ -1,26 +1,24 @@ -github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02 h1:hldWO7cYXMsfCFjlQ2VcGd9PfQYt79sPN2mSjtHVrdc= -github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02/go.mod h1:aqXPi8WPdmWT8sUAQYi7gStLYBhiud0dIT75PskIYpE= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= -github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.10 h1:oXAz+Vh0PMUvJczoi+flxpnBEPxoER1IaAnU/NMPtT0= +github.com/klauspost/compress v1.17.10/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.0 
h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.24.0 h1:Mh5cbb+Zk2hqqXNO7S1iTjEphVL+jb8ZWaqh/g+JWkM= +golang.org/x/term v0.24.0/go.mod h1:lOBK/LVxemqiMij05LGJ0tzNr8xlmwBRJ81PX6wVLH8= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -oras.land/oras-go/v2 v2.4.0 h1:i+Wt5oCaMHu99guBD0yuBjdLvX7Lz8ukPbwXdR7uBMs= -oras.land/oras-go/v2 v2.4.0/go.mod h1:osvtg0/ClRq1KkydMAEu/IxFieyjItcsQ4ut4PPF+f8= +oras.land/oras-go/v2 v2.5.0 h1:o8Me9kLY74Vp5uw07QXPiitjsw7qNXi8Twd+19Zf02c= +oras.land/oras-go/v2 v2.5.0/go.mod 
h1:z4eisnLP530vwIOUOJeBIj0aGI0L1C3d53atvCBqZHg= diff --git a/tools/mdctl/go.mod b/tools/mdctl/go.mod deleted file mode 100644 index 5ab4d68..0000000 --- a/tools/mdctl/go.mod +++ /dev/null @@ -1,25 +0,0 @@ -module github.com/CloudNativeAI/model-spec/tools/mdctl - -go 1.22.4 - -require ( - github.com/klauspost/compress v1.17.7 - github.com/opencontainers/go-digest v1.0.0 - github.com/opencontainers/image-spec v1.1.0 - github.com/spf13/cobra v1.7.0 - oras.land/oras-go/v2 v2.4.0 -) - -require ( - github.com/CloudNativeAI/model-spec/specs-go v0.0.0-20240925072522-ca68e666bb02 // indirect - golang.org/x/sync v0.6.0 // indirect -) - -require ( - github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/sys v0.15.0 // indirect - golang.org/x/term v0.15.0 -) - -replace github.com/CloudNativeAI/model-spec/specs-go/ => ../../specs-go/ From f7899cf1a97794767e22e49d3f7a9009a321277f Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 16:15:31 +0800 Subject: [PATCH 7/8] chore: fix lint error --- tools/mdctl/cmd/build.go | 2 +- tools/mdctl/cmd/pull.go | 10 ++-- tools/mdctl/cmd/push.go | 6 +-- tools/mdctl/cmd/run.go | 88 +++++++++++++++----------------- tools/mdctl/format/parse.go | 7 ++- tools/mdctl/models/archiver.go | 14 ++--- tools/mdctl/models/image.go | 39 +++++++------- tools/mdctl/models/modelpath.go | 58 ++++++++++----------- tools/mdctl/progress/progress.go | 4 +- tools/mdctl/registry/client.go | 13 +++-- 10 files changed, 117 insertions(+), 124 deletions(-) diff --git a/tools/mdctl/cmd/build.go b/tools/mdctl/cmd/build.go index 5300c33..297feef 100644 --- a/tools/mdctl/cmd/build.go +++ b/tools/mdctl/cmd/build.go @@ -74,7 +74,7 @@ func BuildModel(commands []format.Command) error { config.Extensions = append(config.Extensions, *layer) fmt.Printf("Add config [%s]\n", c.Args) - case format.PARAM_SIZE: + case format.PARAMSIZE: engine.Name = c.Args case format.FORMAT: diff --git a/tools/mdctl/cmd/pull.go 
b/tools/mdctl/cmd/pull.go index 8b34449..526419e 100644 --- a/tools/mdctl/cmd/pull.go +++ b/tools/mdctl/cmd/pull.go @@ -21,7 +21,7 @@ func FetchManifest(name string, manifest *spec.Manifest, config *spec.Config) (* return nil, fmt.Errorf("failed to new repo: %w", err) } ctx := context.Background() - imageManifest, err := registry.PullImageManifest(repo, ctx, mp.Tag) + imageManifest, err := registry.PullImageManifest(ctx, repo, mp.Tag) if err != nil { return nil, fmt.Errorf("failed to pull image manifest: %w", err) } @@ -49,7 +49,7 @@ func FetchManifest(name string, manifest *spec.Manifest, config *spec.Config) (* defer temp.Close() //fmt.Println("Pull layer: ", layer.Digest, layer.Size) - err = registry.PullLayer(repo, ctx, layer.Digest.String(), layer.Size, temp.Name()) + err = registry.PullLayer(ctx, repo, layer.Digest.String(), layer.Size, temp.Name()) if err != nil { return nil, fmt.Errorf("failed to pull layer: %w", err) } @@ -81,12 +81,12 @@ func PullModel(name string) error { } ctx := context.Background() - image_manifest, err := registry.PullImageManifest(repo, ctx, mp.Tag) + imageManifest, err := registry.PullImageManifest(ctx, repo, mp.Tag) if err != nil { return fmt.Errorf("failed to pull image manifest: %w", err) } - for _, layer := range image_manifest.Layers { + for _, layer := range imageManifest.Layers { fmt.Println("Pull layer:", layer.Digest, layer.Size) digest := layer.Digest.String() @@ -104,7 +104,7 @@ func PullModel(name string) error { return fmt.Errorf("failed to get blobs path: %w", err) } - err = registry.PullLayer(repo, ctx, digest, layer.Size, targetPath) + err = registry.PullLayer(ctx, repo, digest, layer.Size, targetPath) if err != nil { return fmt.Errorf("failed to pull layer: %w", err) } diff --git a/tools/mdctl/cmd/push.go b/tools/mdctl/cmd/push.go index cc580e3..ad8a909 100644 --- a/tools/mdctl/cmd/push.go +++ b/tools/mdctl/cmd/push.go @@ -59,14 +59,14 @@ func PushModel(name string) error { for _, layer := range group.layers { 
fmt.Println("Push layer:", layer.Digest, layer.Size) - _, err := registry.PushLayer(repo, ctx, &layer) + _, err := registry.PushLayer(ctx, repo, &layer) if err != nil { return fmt.Errorf("failed to push layer: %w", err) } } } - manifestDesc, err := registry.PushModelManifest(repo, ctx, manifestPath) + manifestDesc, err := registry.PushModelManifest(ctx, repo, manifestPath) if err != nil { return fmt.Errorf("failed to push model manifest: %w", err) } @@ -79,7 +79,7 @@ func PushModel(name string) error { // assemble descriptors and model manifest to a image manifest layers = append(layers, *manifestDesc) - err = registry.PushModel(repo, mp.Tag, ctx, layers) + err = registry.PushModel(ctx, repo, mp.Tag, layers) if err != nil { return fmt.Errorf("failed to push oci image manifest: %w", err) } diff --git a/tools/mdctl/cmd/run.go b/tools/mdctl/cmd/run.go index 383c234..9621dcd 100644 --- a/tools/mdctl/cmd/run.go +++ b/tools/mdctl/cmd/run.go @@ -1,27 +1,23 @@ package cmd import ( - "bytes" "fmt" "os" - "os/exec" "path/filepath" "github.com/CloudNativeAI/model-spec/tools/mdctl/models" ) const ( - DOT_GITS_DIR = ".gits" - DOT_VOLUMES_DIR = ".volumes" - MODEL_DIR = "model" - DATASET_DIR = "dataset" - SOURCE_DIR = "source" - TASK_DIR = "task" - ENTRYPOINT = "run.py" - SETUP = "setup.sh" - CONFIG = "config.json" - INFO = "info.json" - LICENSE = "LICENSE" + MODELDIR = "model" + DATASETDIR = "dataset" + SOURCEDIR = "source" + TASKDIR = "task" + ENTRYPOINT = "run.py" + SETUP = "setup.sh" + CONFIG = "config.json" + INFO = "info.json" + LICENSE = "LICENSE" ) func RunModel(name string) error { @@ -108,45 +104,45 @@ func RunModel(name string) error { return nil } -func executeBinary(binaryPath string, args []string) (stdout string, stderr string, err error) { - cmd := exec.Command(binaryPath, args...) +// func executeBinary(binaryPath string, args []string) (stdout string, stderr string, err error) { +// cmd := exec.Command(binaryPath, args...) 
- var outBuf, errBuf bytes.Buffer - cmd.Stdout = &outBuf - cmd.Stderr = &errBuf +// var outBuf, errBuf bytes.Buffer +// cmd.Stdout = &outBuf +// cmd.Stderr = &errBuf - err = cmd.Run() - if err != nil { - return "", "", fmt.Errorf("failed to execute binary: %w", err) - } +// err = cmd.Run() +// if err != nil { +// return "", "", fmt.Errorf("failed to execute binary: %w", err) +// } - stdout = outBuf.String() - stderr = errBuf.String() +// stdout = outBuf.String() +// stderr = errBuf.String() - return stdout, stderr, nil -} +// return stdout, stderr, nil +// } -func executeScript(scriptPath string, args []string) (stdout string, stderr string, err error) { - var cmd *exec.Cmd - if bytes.HasSuffix([]byte(scriptPath), []byte(".sh")) { - cmd = exec.Command("bash", append([]string{scriptPath}, args...)...) - } else if bytes.HasSuffix([]byte(scriptPath), []byte(".py")) { - cmd = exec.Command("python3", append([]string{scriptPath}, args...)...) - } else { - return "", "", fmt.Errorf("unsupported script type: %s", scriptPath) - } +// func executeScript(scriptPath string, args []string) (stdout string, stderr string, err error) { +// var cmd *exec.Cmd +// if bytes.HasSuffix([]byte(scriptPath), []byte(".sh")) { +// cmd = exec.Command("bash", append([]string{scriptPath}, args...)...) +// } else if bytes.HasSuffix([]byte(scriptPath), []byte(".py")) { +// cmd = exec.Command("python3", append([]string{scriptPath}, args...)...) 
+// } else { +// return "", "", fmt.Errorf("unsupported script type: %s", scriptPath) +// } - var outBuf, errBuf bytes.Buffer - cmd.Stdout = &outBuf - cmd.Stderr = &errBuf +// var outBuf, errBuf bytes.Buffer +// cmd.Stdout = &outBuf +// cmd.Stderr = &errBuf - err = cmd.Run() - if err != nil { - return "", "", fmt.Errorf("failed to execute script: %w", err) - } +// err = cmd.Run() +// if err != nil { +// return "", "", fmt.Errorf("failed to execute script: %w", err) +// } - stdout = outBuf.String() - stderr = errBuf.String() +// stdout = outBuf.String() +// stderr = errBuf.String() - return stdout, stderr, nil -} +// return stdout, stderr, nil +// } diff --git a/tools/mdctl/format/parse.go b/tools/mdctl/format/parse.go index 222daae..58c96c6 100644 --- a/tools/mdctl/format/parse.go +++ b/tools/mdctl/format/parse.go @@ -17,7 +17,7 @@ const ( ARCHITECTURE = "architecture" LICENSE = "license" DESCRIPTION = "description" - PARAM_SIZE = "param_size" + PARAMSIZE = "param_size" WEIGHTS = "weights" TOKENIZER = "tokenizer" PRECISION = "precision" @@ -64,7 +64,7 @@ func Parse(reader io.Reader) ([]Command, error) { strings.ToUpper(FORMAT), strings.ToUpper(PRECISION), strings.ToUpper(QUANTIZATION), - strings.ToUpper(PARAM_SIZE), + strings.ToUpper(PARAMSIZE), strings.ToUpper(WEIGHTS), strings.ToUpper(CONFIG), strings.ToUpper(TOKENIZER): @@ -119,9 +119,8 @@ func scan(openBytes, closeBytes, data []byte, atEOF bool) (advance int, token [] if end < 0 { if atEOF { return 0, nil, fmt.Errorf("unterminated %s: expecting %s", openBytes, closeBytes) - } else { - return 0, nil, nil } + return 0, nil, nil } n := start + len(openBytes) + end + len(closeBytes) diff --git a/tools/mdctl/models/archiver.go b/tools/mdctl/models/archiver.go index 1420a99..2f3487e 100644 --- a/tools/mdctl/models/archiver.go +++ b/tools/mdctl/models/archiver.go @@ -18,14 +18,14 @@ func Tar(src, dst, newName string) error { if fi.IsDir() { return tarDirectory(src, dst) - } else { - out, err := os.Create(dst) - if err 
!= nil { - return fmt.Errorf("failed to create tar file: %w", err) - } - defer out.Close() - return tarFile(src, newName, out) } + + out, err := os.Create(dst) + if err != nil { + return fmt.Errorf("failed to create tar file: %w", err) + } + defer out.Close() + return tarFile(src, newName, out) } func tarFile(src, newName string, writers ...io.Writer) error { diff --git a/tools/mdctl/models/image.go b/tools/mdctl/models/image.go index f9ea751..c6e9157 100644 --- a/tools/mdctl/models/image.go +++ b/tools/mdctl/models/image.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log" - "os" ) type RootFS struct { @@ -24,22 +23,22 @@ func GetSHA256Digest(r io.Reader) (string, int64) { return fmt.Sprintf("sha256:%x", h.Sum(nil)), n } -func verifyBlob(digest string) error { - fp, err := GetBlobsPath(digest) - if err != nil { - return fmt.Errorf("failed to get blobs path: %w", err) - } - - f, err := os.Open(fp) - if err != nil { - return fmt.Errorf("failed to open blob: %w", err) - } - defer f.Close() - - fileDigest, _ := GetSHA256Digest(f) - if digest != fileDigest { - return fmt.Errorf("digest mismatch: want %s, got %s", digest, fileDigest) - } - - return nil -} +// func verifyBlob(digest string) error { +// fp, err := GetBlobsPath(digest) +// if err != nil { +// return fmt.Errorf("failed to get blobs path: %w", err) +// } + +// f, err := os.Open(fp) +// if err != nil { +// return fmt.Errorf("failed to open blob: %w", err) +// } +// defer f.Close() + +// fileDigest, _ := GetSHA256Digest(f) +// if digest != fileDigest { +// return fmt.Errorf("digest mismatch: want %s, got %s", digest, fileDigest) +// } + +// return nil +// } diff --git a/tools/mdctl/models/modelpath.go b/tools/mdctl/models/modelpath.go index a2711f9..d65f856 100644 --- a/tools/mdctl/models/modelpath.go +++ b/tools/mdctl/models/modelpath.go @@ -24,30 +24,30 @@ var ( var errModelPathInvalid = errors.New("invalid models path") -func realpath(mfDir, from string) string { - abspath, err := filepath.Abs(from) - if err != nil 
{ - return from - } - - home, err := os.UserHomeDir() - if err != nil { - return abspath - } - - if from == "~" { - return home - } else if strings.HasPrefix(from, "~/") { - return filepath.Join(home, from[2:]) - } - - if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil { - // this is a file relative to the Modelfile - return filepath.Join(mfDir, from) - } - - return abspath -} +// func realpath(mfDir, from string) string { +// abspath, err := filepath.Abs(from) +// if err != nil { +// return from +// } + +// home, err := os.UserHomeDir() +// if err != nil { +// return abspath +// } + +// if from == "~" { +// return home +// } else if strings.HasPrefix(from, "~/") { +// return filepath.Join(home, from[2:]) +// } + +// if _, err := os.Stat(filepath.Join(mfDir, from)); err == nil { +// // this is a file relative to the Modelfile +// return filepath.Join(mfDir, from) +// } + +// return abspath +// } func ParseModelPath(name string) ModelPath { mp := ModelPath{ @@ -85,8 +85,8 @@ func ParseModelPath(name string) ModelPath { return mp } -// ModelsDir returns the path to the models directory. -func ModelsDir() (string, error) { +// ModelDir returns the path to the models directory. 
+func ModelDir() (string, error) { if models, exists := os.LookupEnv("MODELS_DIR"); exists { return models, nil } @@ -98,7 +98,7 @@ func ModelsDir() (string, error) { } func GetManifestRoot() (string, error) { - dir, err := ModelsDir() + dir, err := ModelDir() if err != nil { return "", fmt.Errorf("failed to get models dir: %w", err) } @@ -112,7 +112,7 @@ func GetManifestRoot() (string, error) { } func GetBlobsPath(digest string) (string, error) { - dir, err := ModelsDir() + dir, err := ModelDir() if err != nil { return "", fmt.Errorf("failed to get models dir: %w", err) } @@ -171,7 +171,7 @@ func (mp ModelPath) GetShortTagname() string { // GetManifestRoot returns the path to the manifest file for the given models path, // it is up to the caller to create the directory if it does not exist. func (mp ModelPath) GetManifestPath() (string, error) { - dir, err := ModelsDir() + dir, err := ModelDir() if err != nil { return "", fmt.Errorf("failed to get models dir: %w", err) } diff --git a/tools/mdctl/progress/progress.go b/tools/mdctl/progress/progress.go index 78917e9..b2572b7 100644 --- a/tools/mdctl/progress/progress.go +++ b/tools/mdctl/progress/progress.go @@ -37,7 +37,7 @@ func (p *Progress) stop() bool { if p.ticker != nil { p.ticker.Stop() p.ticker = nil - p.render() + _ = p.render() return true } @@ -108,6 +108,6 @@ func (p *Progress) render() error { func (p *Progress) start() { p.ticker = time.NewTicker(100 * time.Millisecond) for range p.ticker.C { - p.render() + _ = p.render() } } diff --git a/tools/mdctl/registry/client.go b/tools/mdctl/registry/client.go index bd3e687..c760e41 100644 --- a/tools/mdctl/registry/client.go +++ b/tools/mdctl/registry/client.go @@ -12,9 +12,8 @@ import ( modelspec "github.com/CloudNativeAI/model-spec/specs-go/v2" "github.com/CloudNativeAI/model-spec/tools/mdctl/models" "github.com/opencontainers/image-spec/specs-go" - "oras.land/oras-go/v2/content" - oci "github.com/opencontainers/image-spec/specs-go/v1" + 
"oras.land/oras-go/v2/content" "oras.land/oras-go/v2/registry/remote" "oras.land/oras-go/v2/registry/remote/auth" "oras.land/oras-go/v2/registry/remote/retry" @@ -67,7 +66,7 @@ func NewOciManifest(layers []oci.Descriptor) ([]byte, error) { return json.Marshal(content) } -func PushLayer(repo *remote.Repository, ctx context.Context, descriptor *oci.Descriptor) (bool, error) { +func PushLayer(ctx context.Context, repo *remote.Repository, descriptor *oci.Descriptor) (bool, error) { layerPath, err := models.GetBlobsPath(descriptor.Digest.String()) if err != nil { return false, fmt.Errorf("failed to get blobs path: %w", err) @@ -90,7 +89,7 @@ func PushLayer(repo *remote.Repository, ctx context.Context, descriptor *oci.Des return false, repo.Push(ctx, *descriptor, layerFile) } -func PushModelManifest(repo *remote.Repository, ctx context.Context, modelManifestPath string) (*oci.Descriptor, error) { +func PushModelManifest(ctx context.Context, repo *remote.Repository, modelManifestPath string) (*oci.Descriptor, error) { modelManifestFile, err := os.Open(modelManifestPath) if err != nil { return nil, fmt.Errorf("failed to open model manifest file: %w", err) @@ -111,7 +110,7 @@ func PushModelManifest(repo *remote.Repository, ctx context.Context, modelManife return descriptor, nil } -func PushModel(repo *remote.Repository, tag string, ctx context.Context, layers []oci.Descriptor) error { +func PushModel(ctx context.Context, repo *remote.Repository, tag string, layers []oci.Descriptor) error { manifestBlob, err := NewOciManifest(layers) if err != nil { return fmt.Errorf("failed to new oci manifest: %w", err) @@ -126,7 +125,7 @@ func PushModel(repo *remote.Repository, tag string, ctx context.Context, layers return nil } -func PullImageManifest(repo *remote.Repository, ctx context.Context, tag string) (*oci.Manifest, error) { +func PullImageManifest(ctx context.Context, repo *remote.Repository, tag string) (*oci.Manifest, error) { descriptor, err := repo.Resolve(ctx, tag) if err 
!= nil { return nil, fmt.Errorf("failed to resolve image manifest: %w", err) @@ -145,7 +144,7 @@ func PullImageManifest(repo *remote.Repository, ctx context.Context, tag string) return &manifest, nil } -func PullLayer(repo *remote.Repository, ctx context.Context, digest string, size int64, targetPath string) error { +func PullLayer(ctx context.Context, repo *remote.Repository, digest string, size int64, targetPath string) error { if fi, err := os.Stat(targetPath); err == nil && fi.Mode().IsRegular() && fi.Size() == size { return nil } From aea8ab31014540009968901294b923f2f6988e48 Mon Sep 17 00:00:00 2001 From: Zhao Chen Date: Thu, 26 Sep 2024 17:44:43 +0800 Subject: [PATCH 8/8] docs: fix lint error --- docs/v2/modelfile.md | 31 +++++++++++++------ docs/v2/tool.md | 14 ++++----- .../examples/huggingface/gemma-2b/README.md | 9 ++++-- 3 files changed, 36 insertions(+), 18 deletions(-) diff --git a/docs/v2/modelfile.md b/docs/v2/modelfile.md index c14a2c1..f461890 100644 --- a/docs/v2/modelfile.md +++ b/docs/v2/modelfile.md @@ -1,8 +1,9 @@ +# Introduction to Modelfile -### Modelfile A Modelfile is a text file containing all commands, in order, needed to build a given model image. It automates the process of building model images. 
-#### Modelfile Instructions +## Modelfile Instructions + | **Instruction** | **Description** | | --- | --- | | CREATE | Create a new model image | @@ -16,65 +17,77 @@ A Modelfile is a text file containing all commands, in order, needed to build a | FORMAT | Specify model weights format | | TOKENIZER | Specify tokenizer configuration | -#### Modelfile Example +## Modelfile Example + ```plain CREATE registry.cnai.com/sys/gemma-2b:latest # Model Information + NAME gemma-2b FAMILY gemma ARCHITECTURE transformer FORMAT safetensors # Model License + LICENSE examples/huggingface/gemma-2b/LICENSE # Model Configuration + CONFIG examples/huggingface/gemma-2b/config.json CONFIG examples/huggingface/gemma-2b/generation_config.json # Model Tokenizer + TOKENIZER examples/huggingface/gemma-2b/tokenizer.json # Model Weights + WEIGHTS examples/huggingface/gemma-2b/model.safetensors.index.json WEIGHTS examples/huggingface/gemma-2b/model-00001-of-00002.safetensors WEIGHTS examples/huggingface/gemma-2b/model-00002-of-00002.safetensors ``` -### Management tool +## Management tool + We propose a model management tool, which is a command-line tool for building, managing, and running AI models. -#### build +### build + We can use Modelfile to build model images. ```plain mdctl build -f ./Modelfile ``` -#### list +### list + We can list all the model images that have been pushed. ```plain mdctl list ``` -#### push +### push + We can push the built model image to a model repository. ```plain mdctl push ``` -#### pull +### pull + We can pull the model image from the model repository to local storage. ```plain mdctl pull ``` -#### unpack +### unpack + We can pull the model image to local storage and then use mdctl to run the model. 
```plain diff --git a/docs/v2/tool.md b/docs/v2/tool.md index 7c32598..4aead27 100644 --- a/docs/v2/tool.md +++ b/docs/v2/tool.md @@ -6,7 +6,7 @@ To install `mdctl`, clone the repository and build the binary: -``` +```plain git clone https://github.com/CloudNativeAI/mdctl.git cd mdctl go build @@ -16,36 +16,36 @@ go build To build a model, use the `build` command: -``` +```plain ./mdctl build -f Modelfile ``` To list all models, use the `list` command: -``` +```plain ./mdctl list ``` To push a model, use the `push` command. Before pushing, you need to set the model registry credentials: -``` +```plain export MODEL_REGISTRY_USER= export MODEL_REGISTRY_PASSWORD= export MODEL_REGISTRY_URL= ``` -``` +```plain ./mdctl push ``` To pull a model, use the `pull` command: -``` +```plain ./mdctl pull ``` To run a model, use the `unpack` command: -``` +```plain ./mdctl unpack -n ``` diff --git a/tools/mdctl/examples/huggingface/gemma-2b/README.md b/tools/mdctl/examples/huggingface/gemma-2b/README.md index eda6742..422e95f 100644 --- a/tools/mdctl/examples/huggingface/gemma-2b/README.md +++ b/tools/mdctl/examples/huggingface/gemma-2b/README.md @@ -1,13 +1,18 @@ # How to use mdctl examples + ## Download model + Download model from huggingface: -``` + +```plain git lfs install git clone https://huggingface.co/gemma-ai/gemma-2b ``` ## Build model image + Put the modelfile to the model directory and build model image: -``` + +```plain mdctl build -f Modelfile ```