Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
func tfGPUPackage(ver string, cuda string) (name string, cpuVersion string, err error) {
for _, compat := range TFCompatibilityMatrix {
if compat.TF == ver && version.Equal(compat.CUDA, cuda) {
return splitPythonPackage(compat.TFGPUPackage)
return splitPinnedPythonRequirement(compat.TFGPUPackage)
}
}
// We've already warned user if they're doing something stupid in validateAndCompleteCUDA(), so fail silently
Expand Down
105 changes: 57 additions & 48 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package config

import (
"bufio"
"fmt"
"os"
"path"
"regexp"
"strings"

"gopkg.in/yaml.v2"
Expand All @@ -13,19 +17,20 @@ import (
// TODO(andreas): support conda packages
// TODO(andreas): support dockerfiles
// TODO(andreas): custom cpu/gpu installs
// TODO(andreas): validate python_requirements
// TODO(andreas): suggest valid torchvision versions (e.g. if the user wants to use 0.8.0, suggest 0.8.1)

type Build struct {
GPU bool `json:"gpu,omitempty" yaml:"gpu"`
PythonVersion string `json:"python_version,omitempty" yaml:"python_version"`
PythonRequirements string `json:"python_requirements,omitempty" yaml:"python_requirements"`
PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages"`
PythonPackages []string `json:"python_packages,omitempty" yaml:"python_packages"` // Deprecated, but included for backwards compatibility
Run []string `json:"run,omitempty" yaml:"run"`
SystemPackages []string `json:"system_packages,omitempty" yaml:"system_packages"`
PreInstall []string `json:"pre_install,omitempty" yaml:"pre_install"` // Deprecated, but included for backwards compatibility
CUDA string `json:"cuda,omitempty" yaml:"cuda"`
CuDNN string `json:"cudnn,omitempty" yaml:"cudnn"`

pythonRequirementsContent []string
}

type Example struct {
Expand Down Expand Up @@ -92,11 +97,9 @@ func (c *Config) cudaFromTF() (tfVersion string, tfCUDA string, tfCuDNN string,
}

func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
for _, pkg := range c.Build.PythonPackages {
pkgName, version, err := splitPythonPackage(pkg)
for _, pkg := range c.Build.pythonRequirementsContent {
pkgName, version, err := splitPinnedPythonRequirement(pkg)
if err != nil {
// this should be caught by validation earlier
console.Warnf("Python package %s is malformed.", pkg)
return "", false
}
if pkgName == name {
Expand All @@ -106,7 +109,7 @@ func (c *Config) pythonPackageVersion(name string) (version string, ok bool) {
return "", false
}

func (c *Config) ValidateAndCompleteConfig() error {
func (c *Config) ValidateAndComplete(projectDir string) error {
// TODO(andreas): return all errors at once, rather than
// whack-a-mole one at a time with errs := []error{}, etc.

Expand All @@ -124,47 +127,72 @@ func (c *Config) ValidateAndCompleteConfig() error {
}
}

if err := c.validatePythonPackagesHaveVersions(); err != nil {
return err
if len(c.Build.PythonPackages) > 0 && c.Build.PythonRequirements != "" {
return fmt.Errorf("Only one of python_packages or python_requirements can be set in your cog.yaml, not both")
}

if c.Build.GPU {
if err := c.validateAndCompleteCUDA(); err != nil {
// Load python_requirements into memory to simplify reading it multiple times
if c.Build.PythonRequirements != "" {
fh, err := os.Open(path.Join(projectDir, c.Build.PythonRequirements))
if err != nil {
return err
}
// Use scanner to handle CRLF endings
scanner := bufio.NewScanner(fh)
for scanner.Scan() {
c.Build.pythonRequirementsContent = append(c.Build.pythonRequirementsContent, scanner.Text())
}
}

if len(c.Build.PythonPackages) > 0 && c.Build.PythonRequirements != "" {
return fmt.Errorf("Only one of python_packages or python_requirements can be set in your cog.yaml, not both")
// Backwards compatibility
if len(c.Build.PythonPackages) > 0 {
c.Build.pythonRequirementsContent = c.Build.PythonPackages
}

if c.Build.GPU {
if err := c.validateAndCompleteCUDA(); err != nil {
return err
}
}

return nil
}

func (c *Config) PythonPackagesForArch(goos string, goarch string) (packages []string, indexURLs []string, err error) {
packages = []string{}
// PythonRequirementsForArch returns a requirements.txt file with all the GPU packages resolved for given OS and architecture.
func (c *Config) PythonRequirementsForArch(goos string, goarch string) (string, error) {
packages := []string{}
indexURLSet := map[string]bool{}
for _, pkg := range c.Build.PythonPackages {
for _, pkg := range c.Build.pythonRequirementsContent {
archPkg, indexURL, err := c.pythonPackageForArch(pkg, goos, goarch)
if err != nil {
return nil, nil, err
return "", err
}
packages = append(packages, archPkg)
if indexURL != "" {
indexURLSet[indexURL] = true
}
}
indexURLs = []string{}

// Create final requirements.txt output
// Put index URLs first
lines := []string{}
for indexURL := range indexURLSet {
indexURLs = append(indexURLs, indexURL)
lines = append(lines, "--find-links "+indexURL)
Comment thread
nickstenning marked this conversation as resolved.
}
return packages, indexURLs, nil

// Then, everything else
lines = append(lines, packages...)

return strings.Join(lines, "\n"), nil
}

// pythonPackageForArch takes a package==version line and
// returns a package==version and index URL resolved to the correct GPU package for the given OS and architecture
func (c *Config) pythonPackageForArch(pkg string, goos string, goarch string) (actualPackage string, indexURL string, err error) {
name, version, err := splitPythonPackage(pkg)
name, version, err := splitPinnedPythonRequirement(pkg)
if err != nil {
return "", "", err
// It's not pinned, so just return the line verbatim
return pkg, "", nil
}
if name == "tensorflow" {
if c.Build.GPU {
Expand Down Expand Up @@ -292,34 +320,15 @@ Compatible cuDNN version is: %s`,
return nil
}

func (c *Config) validatePythonPackagesHaveVersions() error {
packagesWithoutVersions := []string{}
for _, pkg := range c.Build.PythonPackages {
_, _, err := splitPythonPackage(pkg)
if err != nil {
packagesWithoutVersions = append(packagesWithoutVersions, pkg)
}
}
if len(packagesWithoutVersions) > 0 {
return fmt.Errorf(`All Python packages must have pinned versions, e.g. mypkg==1.0.0.
The following packages are missing pinned versions: %s`, strings.Join(packagesWithoutVersions, ","))
}
return nil
}
// splitPythonPackage returns the name and version from a requirements.txt line in the form name==version
func splitPinnedPythonRequirement(requirement string) (name string, version string, err error) {
pinnedPackageRe := regexp.MustCompile(`^([a-zA-Z0-9\-_]+)==([\d\.]+)$`)

func splitPythonPackage(pkg string) (name string, version string, err error) {
if strings.HasPrefix(pkg, "git+") {
return pkg, "", nil
}

if !strings.Contains(pkg, "==") {
return "", "", fmt.Errorf("Package %s is not in the format 'name==version'", pkg)
}
parts := strings.Split(pkg, "==")
if len(parts) != 2 {
return "", "", fmt.Errorf("Package %s is not in the format 'name==version'", pkg)
match := pinnedPackageRe.FindStringSubmatch(requirement)
if match == nil {
return "", "", fmt.Errorf("Package %s is not in the format 'name==version'", requirement)
}
return parts[0], parts[1], nil
return match[1], match[2], nil
}

func sliceContains(slice []string, s string) bool {
Expand Down
Loading