tensorflow - ActiveState ActiveGo 1.8
...

Package tensorflow

import "github.com/tensorflow/tensorflow/tensorflow/go"

Overview

Package tensorflow is a Go binding to TensorFlow.

The API is subject to change and may break at any time.

TensorFlow (www.tensorflow.org) is an open source software library for numerical computation using data flow graphs. This package provides functionality to build and execute such graphs and depends on TensorFlow being available. For installation instructions see https://www.tensorflow.org/code/tensorflow/go/README.md

Example

Code:

package tensorflow_test

import (
    "archive/zip"
    "bufio"
    "flag"
    "fmt"
    "io"
    "io/ioutil"
    "log"
    "net/http"
    "os"
    "path/filepath"

    tf "github.com/tensorflow/tensorflow/tensorflow/go"
    "github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func Example() {
    // An example for using the TensorFlow Go API for image recognition
    // using a pre-trained inception model (http://arxiv.org/abs/1512.00567).
    //
    // Sample usage: <program> -dir=/tmp/modeldir -image=/path/to/some/jpeg
    //
    // The pre-trained model takes input in the form of a 4-dimensional
    // tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ],
    // where:
    // - BATCH_SIZE allows for inference of multiple images in one pass through the graph
    // - IMAGE_HEIGHT is the height of the images on which the model was trained
    // - IMAGE_WIDTH is the width of the images on which the model was trained
    // - 3 is the (R, G, B) values of the pixel colors represented as a float.
    //
    // And produces as output a tensor with shape [ BATCH_SIZE, NUM_LABELS ].
    // output[b][i] is the probability that the b-th input image was
    // recognized as having the i-th label.
    //
    // A separate file contains a list of string labels corresponding to the
    // integer indices of the output.
    //
    // This example:
    // - Loads the serialized representation of the pre-trained model into a Graph
    // - Creates a Session to execute operations on the Graph
    // - Converts an image file to a Tensor to provide as input to a Session run
    // - Executes the Session and prints out the label with the highest probability
    //
    // To convert an image file to a Tensor suitable for input to the Inception model,
    // this example:
    // - Constructs another TensorFlow graph to normalize the image into a
    //   form suitable for the model (for example, resizing the image)
    // - Creates and executes a Session to obtain a Tensor in this normalized form.
    modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary")
    imagefile := flag.String("image", "", "Path of a JPEG-image to extract labels for")
    flag.Parse()
    if *modeldir == "" || *imagefile == "" {
        flag.Usage()
        return
    }
    // Load the serialized GraphDef from a file.
    modelfile, labelsfile, err := modelFiles(*modeldir)
    if err != nil {
        log.Fatal(err)
    }
    model, err := ioutil.ReadFile(modelfile)
    if err != nil {
        log.Fatal(err)
    }

    // Construct an in-memory graph from the serialized form.
    graph := tf.NewGraph()
    if err := graph.Import(model, ""); err != nil {
        log.Fatal(err)
    }

    // Create a session for inference over graph.
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        log.Fatal(err)
    }
    defer session.Close()

    // Run inference on *imagefile.
    // For multiple images, session.Run() can be called in a loop (and
    // concurrently). Alternatively, images can be batched since the model
    // accepts batches of image data as input.
    tensor, err := makeTensorFromImage(*imagefile)
    if err != nil {
        log.Fatal(err)
    }
    output, err := session.Run(
        map[tf.Output]*tf.Tensor{
            graph.Operation("input").Output(0): tensor,
        },
        []tf.Output{
            graph.Operation("output").Output(0),
        },
        nil)
    if err != nil {
        log.Fatal(err)
    }
    // output[0].Value() is a [batch][label] matrix of probabilities, one
    // row per image in the "batch". The batch size was 1, so take row 0
    // and find the most probable label index.
    probabilities := output[0].Value().([][]float32)[0]
    printBestLabel(probabilities, labelsfile)
}

func printBestLabel(probabilities []float32, labelsFile string) {
    bestIdx := 0
    for i, p := range probabilities {
        if p > probabilities[bestIdx] {
            bestIdx = i
        }
    }
    // Found the best match. Read the string from labelsFile, which
    // contains one line per label.
    file, err := os.Open(labelsFile)
    if err != nil {
        log.Fatal(err)
    }
    defer file.Close()
    scanner := bufio.NewScanner(file)
    var labels []string
    for scanner.Scan() {
        labels = append(labels, scanner.Text())
    }
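    if err := scanner.Err(); err != nil {
        log.Fatalf("ERROR: failed to read %s: %v", labelsFile, err)
    }
    if bestIdx >= len(labels) {
        log.Fatalf("ERROR: label index %d out of range (%d labels in %s)", bestIdx, len(labels), labelsFile)
    }
    fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx])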
}

// Convert the image in filename to a Tensor suitable as input to the Inception model.
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
    bytes, err := ioutil.ReadFile(filename)
    if err != nil {
        return nil, err
    }
    // DecodeJpeg uses a scalar String-valued tensor as input.
    tensor, err := tf.NewTensor(string(bytes))
    if err != nil {
        return nil, err
    }
    // Construct a graph to normalize the image
    graph, input, output, err := constructGraphToNormalizeImage()
    if err != nil {
        return nil, err
    }
    // Execute that graph to normalize this one image
    session, err := tf.NewSession(graph, nil)
    if err != nil {
        return nil, err
    }
    defer session.Close()
    normalized, err := session.Run(
        map[tf.Output]*tf.Tensor{input: tensor},
        []tf.Output{output},
        nil)
    if err != nil {
        return nil, err
    }
    return normalized[0], nil
}

// The inception model takes as input the image described by a Tensor in a very
// specific normalized format (a particular image size, shape of the input tensor,
// normalized pixel values etc.).
//
// This function constructs a graph of TensorFlow operations which takes as
// input a JPEG-encoded string and returns a tensor suitable as input to the
// inception model.
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) {
    // Some constants specific to the pre-trained model at:
    // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
    //
    // - The model was trained with images scaled to 224x224 pixels.
    // - The colors, represented as R, G, B in 1-byte each were converted to
    //   float using (value - Mean)/Scale.
    const (
        H, W  = 224, 224
        Mean  = float32(117)
        Scale = float32(1)
    )
    // - input is a String-Tensor, where the string is the JPEG-encoded image.
    // - The inception model takes a 4D tensor of shape
    //   [BatchSize, Height, Width, Colors=3], where each pixel is
    //   represented as a triplet of floats
    // - Apply normalization on each pixel and use ExpandDims to make
    //   this single image be a "batch" of size 1 for ResizeBilinear.
    s := op.NewScope()
    input = op.Placeholder(s, tf.String)
    output = op.Div(s,
        op.Sub(s,
            op.ResizeBilinear(s,
                op.ExpandDims(s,
                    op.Cast(s,
                        op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), tf.Float),
                    op.Const(s.SubScope("make_batch"), int32(0))),
                op.Const(s.SubScope("size"), []int32{H, W})),
            op.Const(s.SubScope("mean"), Mean)),
        op.Const(s.SubScope("scale"), Scale))
    graph, err = s.Finalize()
    return graph, input, output, err
}

func modelFiles(dir string) (modelfile, labelsfile string, err error) {
    const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
    var (
        model   = filepath.Join(dir, "tensorflow_inception_graph.pb")
        labels  = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt")
        zipfile = filepath.Join(dir, "inception5h.zip")
    )
    if filesExist(model, labels) == nil {
        return model, labels, nil
    }
    log.Println("Did not find model in", dir, "downloading from", URL)
    if err := os.MkdirAll(dir, 0755); err != nil {
        return "", "", err
    }
    if err := download(URL, zipfile); err != nil {
        return "", "", fmt.Errorf("failed to download %v - %v", URL, err)
    }
    if err := unzip(dir, zipfile); err != nil {
        return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err)
    }
    os.Remove(zipfile)
    return model, labels, filesExist(model, labels)
}

func filesExist(files ...string) error {
    for _, f := range files {
        if _, err := os.Stat(f); err != nil {
            return fmt.Errorf("unable to stat %s: %v", f, err)
        }
    }
    return nil
}

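func download(URL, filename string) error {
    resp, err := http.Get(URL)
    if err != nil {
        return err
    }
    defer resp.Body.Close()
    if resp.StatusCode != http.StatusOK {
        return fmt.Errorf("unexpected HTTP status %s", resp.Status)
    }
    // os.Create truncates any partial file left behind by an earlier attempt.
    file, err := os.Create(filename)
    if err != nil {
        return err
    }
    if _, err := io.Copy(file, resp.Body); err != nil {
        file.Close()
        return err
    }
    // Report errors from Close, which may signal a failed final write.
    return file.Close()
}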

func unzip(dir, zipfile string) error {
    r, err := zip.OpenReader(zipfile)
    if err != nil {
        return err
    }
    defer r.Close()
    for _, f := range r.File {
        log.Println("Extracting", f.Name)
        src, err := f.Open()
        if err != nil {
            return err
        }
        // os.Create truncates existing files so stale bytes are not left behind.
        dst, err := os.Create(filepath.Join(dir, f.Name))
        if err != nil {
            src.Close()
            return err
        }
        _, err = io.Copy(dst, src)
        src.Close()
        // Close dst on every path and keep the first error encountered.
        if cerr := dst.Close(); err == nil {
            err = cerr
        }
        if err != nil {
            return err
        }
    }
    return nil
}
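The two numeric steps around the model can be checked without TensorFlow. The following is a minimal pure-Go sketch, assuming the Mean=117 and Scale=1 constants from constructGraphToNormalizeImage and a [batch][label] probability matrix like the one returned by output[0].Value(). The names normalizePixel and bestLabels are hypothetical and are not part of the tensorflow package; bestLabels generalizes printBestLabel to a batch of several images.

```go
package main

import "fmt"

// Constants taken from constructGraphToNormalizeImage in the example above.
const (
	mean  = float32(117)
	scale = float32(1)
)

// normalizePixel applies (value - Mean) / Scale to one 8-bit color channel,
// mirroring the Sub and Div ops in the normalization graph.
func normalizePixel(v uint8) float32 {
	return (float32(v) - mean) / scale
}

// bestLabels returns, for each image in a batch, the index of the label with
// the highest probability. probs has shape [batch][numLabels].
func bestLabels(probs [][]float32) []int {
	best := make([]int, len(probs))
	for b, row := range probs {
		for i, p := range row {
			if p > row[best[b]] {
				best[b] = i
			}
		}
	}
	return best
}

func main() {
	fmt.Println(normalizePixel(0), normalizePixel(117), normalizePixel(255))
	batch := [][]float32{
		{0.1, 0.7, 0.2}, // image 0: label 1 is most likely
		{0.5, 0.2, 0.3}, // image 1: label 0 is most likely
	}
	fmt.Println(bestLabels(batch))
}
```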

Index

func Version() string
type DataType
type Graph
    func NewGraph() *Graph
    func (g *Graph) AddOperation(args OpSpec) (*Operation, error)
    func (g *Graph) Import(def []byte, prefix string) error
    func (g *Graph) Operation(name string) *Operation
    func (g *Graph) WriteTo(w io.Writer) (int64, error)
type Input
type OpSpec
type Operation
    func (op *Operation) Name() string
    func (op *Operation) NumOutputs() int
    func (op *Operation) Output(i int) Output
    func (op *Operation) OutputListSize(output string) (int, error)
    func (op *Operation) Type() string
type Output
    func (p Output) DataType() DataType
    func (p Output) Shape() Shape
type OutputList
type PartialRun
    func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error)
type SavedModel
    func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error)
type Session
    func NewSession(graph *Graph, options *SessionOptions) (*Session, error)
    func (s *Session) Close() error
    func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error)
    func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error)
type SessionOptions
type Shape
    func MakeShape(shape ...int64) Shape
    func ScalarShape() Shape
    func (s Shape) IsFullySpecified() bool
    func (s Shape) NumDimensions() int
    func (s Shape) Size(dim int) int64
    func (s Shape) String() string
    func (s Shape) ToSlice() ([]int64, error)
type Tensor
    func NewTensor(value interface{}) (*Tensor, error)
    func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error)
    func (t *Tensor) DataType() DataType
    func (t *Tensor) Shape() []int64
    func (t *Tensor) Value() interface{}
    func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error)

Examples

Package
PartialRun

Package files

doc.go graph.go lib.go operation.go saved_model.go session.go shape.go status.go tensor.go version.go

func Version

func Version() string

Version returns a string describing the version of the underlying TensorFlow runtime.

type DataType

DataType holds the type of a scalar value, e.g., one slot in a tensor.

type DataType C.TF_DataType

Types of scalar values in the TensorFlow type system.

const (
    Float      DataType = C.TF_FLOAT
    Double     DataType = C.TF_DOUBLE
    Int32      DataType = C.TF_INT32
    Uint8      DataType = C.TF_UINT8
    Int16      DataType = C.TF_INT16
    Int8       DataType = C.TF_INT8
    String     DataType = C.TF_STRING
    Complex64  DataType = C.TF_COMPLEX64
    Complex    DataType = C.TF_COMPLEX
    Int64      DataType = C.TF_INT64
    Bool       DataType = C.TF_BOOL
    Qint8      DataType = C.TF_QINT8
    Quint8     DataType = C.TF_QUINT8
    Qint32     DataType = C.TF_QINT32
    Bfloat16   DataType = C.TF_BFLOAT16
    Qint16     DataType = C.TF_QINT16
    Quint16    DataType = C.TF_QUINT16
    Uint16     DataType = C.TF_UINT16
    Complex128 DataType = C.TF_COMPLEX128
    Half       DataType = C.TF_HALF
)

type Graph

Graph represents a computation graph. Graphs may be shared between sessions.

type Graph struct {
    // contains filtered or unexported fields
}

func NewGraph

func NewGraph() *Graph

NewGraph returns a new Graph.

func (*Graph) AddOperation

func (g *Graph) AddOperation(args OpSpec) (*Operation, error)

AddOperation adds an operation to g.

func (*Graph) Import

func (g *Graph) Import(def []byte, prefix string) error

Import imports the nodes and edges from a serialized representation of another Graph into g.

Names of imported nodes will be prefixed with prefix.

func (*Graph) Operation

func (g *Graph) Operation(name string) *Operation

Operation returns the Operation named name in the Graph, or nil if no such operation is present.

func (*Graph) WriteTo

func (g *Graph) WriteTo(w io.Writer) (int64, error)

WriteTo writes out a serialized representation of g to w.

Implements the io.WriterTo interface.

type Input

Input is the interface for specifying inputs to an operation being added to a Graph.

Operations can have multiple inputs, each of which could be either a tensor produced by another operation (an Output object), or a list of tensors produced by other operations (an OutputList). Thus, this interface is implemented by both Output and OutputList.

See OpSpec.Input for more information.

type Input interface {
    // contains filtered or unexported methods
}

type OpSpec

OpSpec is the specification of an Operation to be added to a Graph (using Graph.AddOperation).

type OpSpec struct {
    // Type of the operation (e.g., "Add", "MatMul").
    Type string

    // Name by which the added operation will be referred to in the Graph.
    // If omitted, defaults to Type.
    Name string

    // Inputs to this operation, which in turn must be outputs
    // of other operations already added to the Graph.
    //
    // An operation may have multiple inputs with individual inputs being
    // either a single tensor produced by another operation or a list of
    // tensors produced by multiple operations. For example, the "Concat"
    // operation takes two inputs: (1) the dimension along which to
    // concatenate and (2) a list of tensors to concatenate. Thus, for
    // Concat, len(Input) must be 2, with the first element being an Output
    // and the second being an OutputList.
    Input []Input

    // Map from attribute name to its value that will be attached to this
    // operation.
    Attrs map[string]interface{}
}

type Operation

Operation represents an operation that has been added to a Graph.

type Operation struct {
    // contains filtered or unexported fields
}

func (*Operation) Name

func (op *Operation) Name() string

Name returns the name of the operation.

func (*Operation) NumOutputs

func (op *Operation) NumOutputs() int

NumOutputs returns the number of outputs of op.

func (*Operation) Output

func (op *Operation) Output(i int) Output

Output returns the i-th output of op.

func (*Operation) OutputListSize

func (op *Operation) OutputListSize(output string) (int, error)

OutputListSize returns the size of the list of Outputs that is produced by a named output of op.

An Operation has multiple named outputs, each of which produces either a single tensor or a list of tensors. This method returns the size of the list of tensors for a specific output of the operation, identified by its name.

func (*Operation) Type

func (op *Operation) Type() string

Type returns the name of the operator used by this operation.

type Output

Output represents one of the outputs of an operation in the graph. An Output has a DataType (and eventually a Shape). It may be passed as an input argument to a function for adding operations to a graph, or to a Session's Run() method to fetch that output as a tensor.

type Output struct {
    // Op is the Operation that produces this Output.
    Op *Operation

    // Index specifies the index of the output within the Operation.
    Index int
}

func (Output) DataType

func (p Output) DataType() DataType

DataType returns the type of elements in the tensor produced by p.

func (Output) Shape

func (p Output) Shape() Shape

Shape returns the (possibly incomplete) shape of the tensor produced by p.

type OutputList

OutputList represents a list of Outputs that can be provided as input to another operation.

type OutputList []Output

type PartialRun

PartialRun enables incremental evaluation of graphs.

PartialRun allows the caller to pause the evaluation of a graph, run arbitrary code that depends on the intermediate computation of the graph, and then resume graph execution. The results of the arbitrary code can be fed into the graph when resuming execution. In contrast, Session.Run executes the graph to compute the requested fetches using the provided feeds and discards all intermediate state (e.g., value of intermediate tensors) when it returns.

For example, consider a graph for unsupervised training of a neural network model. PartialRun can be used to pause execution after the forward pass of the network, let the caller act on the output (e.g., play a game or actuate a robot), determine the error/loss, and then feed this calculated loss when resuming the backward pass of the graph.

type PartialRun struct {
    // contains filtered or unexported fields
}

Example

Code:

var (
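    // Create a graph: a + 2 + 3 + b.
    //
    // Placeholder, Const and Add are small helpers defined in this
    // package's test files (they wrap Graph.AddOperation); they are not
    // part of the exported API listed in the index.
    //
    // Skipping error handling for brevity of this example.
    // The 'op' package can be used to make graph construction code
    // with error handling more succinct.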
    g        = NewGraph()
    a, _     = Placeholder(g, "a", Int32)
    b, _     = Placeholder(g, "b", Int32)
    two, _   = Const(g, "Two", int32(2))
    three, _ = Const(g, "Three", int32(3))

    plus2, _ = Add(g, "plus2", a, two)       // a + 2
    plus3, _ = Add(g, "plus3", plus2, three) // (a + 2) + 3
    plusB, _ = Add(g, "plusB", plus3, b)     // ((a + 2) + 3) + b

)
sess, err := NewSession(g, nil)
if err != nil {
    panic(err)
}
defer sess.Close()

// All the feeds, fetches and targets for subsequent PartialRun.Run
// calls must be provided at setup.
pr, err := sess.NewPartialRun(
    []Output{a, b},
    []Output{plus2, plusB},
    []*Operation{plus3.Op},
)
if err != nil {
    panic(err)
}

// Feed 'a=1', fetch 'plus2', and compute (but do not fetch) 'plus3'.
// Imagine this to be the forward pass of unsupervised neural network
// training of a robot.
val, _ := NewTensor(int32(1))
fetches, err := pr.Run(
    map[Output]*Tensor{a: val},
    []Output{plus2},
    nil)
if err != nil {
    panic(err)
}
v1 := fetches[0].Value().(int32)

// Now, feed 'b=4', fetch 'plusB=a+2+3+b'
// Imagine this to be the result of actuating the robot to determine
// the error produced by the current state of the neural network.
val, _ = NewTensor(int32(4))
fetches, err = pr.Run(
    map[Output]*Tensor{b: val},
    []Output{plusB},
    nil)
if err != nil {
    panic(err)
}
v2 := fetches[0].Value().(int32)

fmt.Println(v1, v2)

Output:

3 10

func (*PartialRun) Run

func (pr *PartialRun) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error)

Run resumes execution of the graph to compute the requested fetches and targets with the provided feeds.

type SavedModel

SavedModel represents the contents of a loaded SavedModel. TODO(jhseu): Add and document metagraphdef when we pregenerate protobufs.

type SavedModel struct {
    Session *Session
    Graph   *Graph
}

func LoadSavedModel

func LoadSavedModel(exportDir string, tags []string, options *SessionOptions) (*SavedModel, error)

LoadSavedModel creates a new SavedModel from a model previously exported to a directory on disk.

Exported models contain a set of graphs and, optionally, variable values. Tags in the model identify a single graph. LoadSavedModel initializes a session with the identified graph and with variables initialized from the checkpoints on disk.

The tensorflow package currently does not have the ability to export a model to a directory from Go. This function thus currently targets loading models exported in other languages, such as using tf.saved_model.builder in Python. See: https://www.tensorflow.org/code/tensorflow/python/saved_model/

type Session

Session drives a TensorFlow graph computation.

When a Session is created with a given target, a new Session object is bound to the universe of resources specified by that target. Those resources are available to this session to perform computation described in the GraphDef. After creating the session with a graph, the caller uses the Run() API to perform the computation and potentially fetch outputs as Tensors. A Session allows concurrent calls to Run().

type Session struct {
    // contains filtered or unexported fields
}

func NewSession

func NewSession(graph *Graph, options *SessionOptions) (*Session, error)

NewSession creates a new execution session with the associated graph. options may be nil to use the default options.

func (*Session) Close

func (s *Session) Close() error

Close closes the session. This contacts any other processes associated with this session, if applicable, and blocks until all previous calls to Run have returned.

func (*Session) NewPartialRun

func (s *Session) NewPartialRun(feeds, fetches []Output, targets []*Operation) (*PartialRun, error)

NewPartialRun sets up the graph for incremental evaluation.

All values of feeds, fetches and targets that may be provided to Run calls on the returned PartialRun need to be provided to NewPartialRun.

See documentation for the PartialRun type.

func (*Session) Run

func (s *Session) Run(feeds map[Output]*Tensor, fetches []Output, targets []*Operation) ([]*Tensor, error)

Run runs the graph with the associated session, starting with the supplied feeds, to compute the value of the requested fetches. Operations specified in targets are run, but their Tensors are not returned.

On success, Run returns the fetched Tensors in the same order as supplied in the fetches argument. If fetches is nil, the returned slice is empty.

type SessionOptions

SessionOptions contains configuration information for a session.

type SessionOptions struct {
    // Target indicates the TensorFlow runtime to connect to.
    //
    // If 'target' is empty or unspecified, the local TensorFlow runtime
    // implementation will be used.  Otherwise, the TensorFlow engine
    // defined by 'target' will be used to perform all computations.
    //
    // "target" can be either a single entry or a comma separated list
    // of entries. Each entry is a resolvable address of one of the
    // following formats:
    //   local
    //   ip:port
    //   host:port
    //   ... other system-specific formats to identify tasks and jobs ...
    //
    // NOTE: at the moment 'local' maps to an in-process service-based
    // runtime.
    //
    // Upon creation, a single session affines itself to one of the
    // remote processes, with possible load balancing choices when the
    // "target" resolves to a list of possible processes.
    //
    // If the session disconnects from the remote process during its
    // lifetime, session calls may fail immediately.
    Target string

    // Config is a binary-serialized representation of the
    // tensorflow.ConfigProto protocol message
    // (https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto).
    Config []byte
}

type Shape

Shape represents the (possibly partially known) shape of a tensor that will be produced by an operation.

The zero-value of a Shape represents a shape with an unknown number of dimensions.

type Shape struct {
    // contains filtered or unexported fields
}

func MakeShape

func MakeShape(shape ...int64) Shape

MakeShape returns a Shape with the provided size of each dimension.

A value of -1 implies that the size of the corresponding dimension is not known.
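To illustrate the -1 convention, the hypothetical helper below computes how many elements a tensor of a given shape holds, and reports false when any dimension is unknown. It works on a plain []int64 (for example, the slice returned by Shape.ToSlice) and is a sketch under that assumption, not part of the package.

```go
package main

import "fmt"

// numElements returns the product of dims and true, or 0 and false if any
// dimension is -1 (unknown). An empty slice describes a scalar, which holds
// exactly one element.
func numElements(dims []int64) (int64, bool) {
	n := int64(1)
	for _, d := range dims {
		if d < 0 {
			return 0, false
		}
		n *= d
	}
	return n, true
}

func main() {
	fmt.Println(numElements([]int64{1, 224, 224, 3}))  // one 224x224 RGB image
	fmt.Println(numElements([]int64{-1, 224, 224, 3})) // batch size unknown
	fmt.Println(numElements([]int64{}))                // scalar
}
```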

func ScalarShape

func ScalarShape() Shape

ScalarShape returns a Shape representing a scalar.

func (Shape) IsFullySpecified

func (s Shape) IsFullySpecified() bool

IsFullySpecified returns true iff the sizes of all dimensions of s are known.

func (Shape) NumDimensions

func (s Shape) NumDimensions() int

NumDimensions returns the number of dimensions represented by s, or -1 if unknown.

func (Shape) Size

func (s Shape) Size(dim int) int64

Size returns the size of the dim-th dimension of the shape, or -1 if it is unknown.

REQUIRES: 0 <= dim < s.NumDimensions()

func (Shape) String

func (s Shape) String() string

func (Shape) ToSlice

func (s Shape) ToSlice() ([]int64, error)

ToSlice returns the (possibly partially known) shape represented by s as a slice, or an error if the number of dimensions is not known.

type Tensor

Tensor holds a multi-dimensional array of elements of a single data type.

type Tensor struct {
    // contains filtered or unexported fields
}

func NewTensor

func NewTensor(value interface{}) (*Tensor, error)

NewTensor converts from a Go value to a Tensor. Valid values are scalars, slices, and arrays. Every element of a slice must have the same length so that the resulting Tensor has a valid shape.
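The shape of the resulting Tensor follows from the nesting of the Go value, which is why ragged slices are rejected. The hypothetical shapeOf below sketches that rule using the reflect package; it is an illustration of the documented constraint, not the package's implementation, and it does not check element types.

```go
package main

import (
	"fmt"
	"reflect"
)

// shapeOf returns the dimensions of a scalar, slice, or array value, or an
// error if sibling slices have different lengths (a "ragged" value).
func shapeOf(v interface{}) ([]int64, error) {
	return shapeOfValue(reflect.ValueOf(v))
}

func shapeOfValue(v reflect.Value) ([]int64, error) {
	if v.Kind() != reflect.Slice && v.Kind() != reflect.Array {
		return []int64{}, nil // a scalar has no dimensions
	}
	n := v.Len()
	if n == 0 {
		// Sketch limitation: deeper dimensions of an empty slice are not inspected.
		return []int64{0}, nil
	}
	var inner []int64
	for i := 0; i < n; i++ {
		s, err := shapeOfValue(v.Index(i))
		if err != nil {
			return nil, err
		}
		if i == 0 {
			inner = s
		} else if !reflect.DeepEqual(s, inner) {
			return nil, fmt.Errorf("ragged value: element %d has shape %v, want %v", i, s, inner)
		}
	}
	return append([]int64{int64(n)}, inner...), nil
}

func main() {
	fmt.Println(shapeOf([][]float32{{1, 2, 3}, {4, 5, 6}}))
	_, err := shapeOf([][]int32{{1, 2}, {3}})
	fmt.Println(err)
}
```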

func ReadTensor

func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error)

ReadTensor constructs a Tensor with the provided type and shape from the serialized tensor contents in r.

See also WriteContentsTo.

func (*Tensor) DataType

func (t *Tensor) DataType() DataType

DataType returns the scalar datatype of the Tensor.

func (*Tensor) Shape

func (t *Tensor) Shape() []int64

Shape returns the shape of the Tensor.

func (*Tensor) Value

func (t *Tensor) Value() interface{}

Value converts the Tensor to a Go value. For now, not all Tensor types are supported, and this function may panic if it encounters an unsupported DataType.

The type of the output depends on the Tensor type and dimensions. For example:

    Tensor(int64, 0):   int64
    Tensor(float64, 3): [][][]float64
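Because Value returns interface{}, callers recover the concrete type with a type assertion, as the Inception example does with output[0].Value().([][]float32). A comma-ok assertion or a type switch avoids a panic when the tensor has an unexpected type or rank. In this self-contained sketch, the hypothetical describe function receives plain Go values standing in for what Value() might return.

```go
package main

import "fmt"

// describe reports the rank implied by the dynamic type of v, for float32
// values of rank 0 to 2. Any other type is reported as unsupported.
func describe(v interface{}) string {
	switch x := v.(type) {
	case float32:
		return fmt.Sprintf("scalar %v", x)
	case []float32:
		return fmt.Sprintf("vector of %d", len(x))
	case [][]float32:
		return fmt.Sprintf("matrix %dx%d", len(x), cols(x))
	default:
		return fmt.Sprintf("unsupported %T", v)
	}
}

// cols returns the length of the first row, or 0 for an empty matrix.
func cols(m [][]float32) int {
	if len(m) == 0 {
		return 0
	}
	return len(m[0])
}

func main() {
	var v interface{} = [][]float32{{0.1, 0.9}}
	// Comma-ok form: no panic if v is not a [][]float32.
	if probs, ok := v.([][]float32); ok {
		fmt.Println(probs[0])
	}
	fmt.Println(describe(v), describe(float32(2)), describe(int64(3)))
}
```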

func (*Tensor) WriteContentsTo

func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error)

WriteContentsTo writes the serialized contents of t to w.

Returns the number of bytes written. See ReadTensor for reconstructing a Tensor from the serialized form.

WARNING: WriteContentsTo is not comprehensive and will fail if t.DataType() is non-numeric (e.g., String). See https://github.com/tensorflow/tensorflow/issues/6003.

Subdirectories

Name Synopsis
genop Command genop generates a Go source file with functions for TensorFlow ops.
op Package op defines functions for adding TensorFlow operations to a Graph.