forked from apache/arrow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor.cc
More file actions
122 lines (103 loc) · 4.03 KB
/
tensor.cc
File metadata and controls
122 lines (103 loc) · 4.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "arrow/tensor.h"
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include "arrow/compare.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/util/logging.h"
namespace arrow {
/// \brief Compute the byte strides of a row-major layout for `shape`.
///
/// \param[in] type fixed-width type; the element size is taken to be
///   `type.bit_width() / 8` bytes
/// \param[in] shape extent of each dimension
/// \param[out] strides overwritten with one stride per entry of `shape`
///
/// When the total byte count is zero (for instance because one extent is 0),
/// every stride is set to the element byte width.
static void ComputeRowMajorStrides(const FixedWidthType& type,
                                   const std::vector<int64_t>& shape,
                                   std::vector<int64_t>* strides) {
  const int64_t byte_width = type.bit_width() / 8;

  // Total number of bytes covered by all the dimensions together.
  int64_t remaining = byte_width;
  for (int64_t dimsize : shape) {
    remaining *= dimsize;
  }

  if (remaining == 0) {
    // The division loop below would divide by a zero extent; fall back to a
    // uniform stride instead.
    strides->assign(shape.size(), byte_width);
    return;
  }

  // Start from an empty vector so this path overwrites the output exactly like
  // the assign() path above, instead of appending to existing content.
  strides->clear();
  strides->reserve(shape.size());
  for (int64_t dimsize : shape) {
    // Peeling off one dimension leaves the byte size of everything inside it.
    remaining /= dimsize;
    strides->push_back(remaining);
  }
}
/// \brief Compute the byte strides of a column-major layout for `shape`.
///
/// \param[in] type fixed-width type; the element size is taken to be
///   `type.bit_width() / 8` bytes
/// \param[in] shape extent of each dimension
/// \param[out] strides overwritten with one stride per entry of `shape`
///
/// When any extent is 0, every stride is set to the element byte width.
static void ComputeColumnMajorStrides(const FixedWidthType& type,
                                      const std::vector<int64_t>& shape,
                                      std::vector<int64_t>* strides) {
  const int64_t byte_width = type.bit_width() / 8;

  for (int64_t dimsize : shape) {
    if (dimsize == 0) {
      strides->assign(shape.size(), byte_width);
      return;
    }
  }

  // Start from an empty vector so this path overwrites the output exactly like
  // the assign() path above, instead of appending to existing content.
  strides->clear();
  strides->reserve(shape.size());

  // The first dimension gets the element width; each later stride is the
  // running product of the extents seen so far.
  int64_t total = byte_width;
  for (int64_t dimsize : shape) {
    strides->push_back(total);
    total *= dimsize;
  }
}
/// Constructor with strides and dimension names.
///
/// If `shape` is non-empty and `strides` is empty, row-major strides are
/// computed from `shape` and `type` and stored in place of the empty ones.
Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
               const std::vector<int64_t>& shape, const std::vector<int64_t>& strides,
               const std::vector<std::string>& dim_names)
    : type_(type), data_(data), shape_(shape), strides_(strides), dim_names_(dim_names) {
  DCHECK(is_tensor_supported(type->id()));

  const bool needs_default_strides = !shape.empty() && strides.empty();
  if (!needs_default_strides) {
    return;
  }

  const auto& fixed_width_type = static_cast<const FixedWidthType&>(*type_);
  ComputeRowMajorStrides(fixed_width_type, shape, &strides_);
}
/// Constructor with explicit strides and no dimension names; forwards to the
/// five-argument constructor with an empty name list.
Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
               const std::vector<int64_t>& shape, const std::vector<int64_t>& strides)
    : Tensor(type, data, shape, strides, std::vector<std::string>()) {}
/// Constructor with neither strides nor dimension names; forwards to the
/// five-argument constructor with an empty stride list and an empty name list.
Tensor::Tensor(const std::shared_ptr<DataType>& type, const std::shared_ptr<Buffer>& data,
               const std::vector<int64_t>& shape)
    : Tensor(type, data, shape, std::vector<int64_t>(), std::vector<std::string>()) {}
const std::string& Tensor::dim_name(int i) const {
static const std::string kEmpty = "";
if (dim_names_.size() == 0) {
return kEmpty;
} else {
DCHECK_LT(i, static_cast<int>(dim_names_.size()));
return dim_names_[i];
}
}
/// Return the product of all extents in the shape (1 when the shape is empty).
int64_t Tensor::size() const {
  int64_t num_elements = 1;
  for (const int64_t extent : shape_) {
    num_elements *= extent;
  }
  return num_elements;
}
/// Return true when the strides match either the row-major or the
/// column-major layout for this shape and type.
bool Tensor::is_contiguous() const {
  if (is_row_major()) {
    return true;
  }
  return is_column_major();
}
bool Tensor::is_row_major() const {
std::vector<int64_t> c_strides;
const auto& fw_type = static_cast<const FixedWidthType&>(*type_);
ComputeRowMajorStrides(fw_type, shape_, &c_strides);
return strides_ == c_strides;
}
bool Tensor::is_column_major() const {
std::vector<int64_t> f_strides;
const auto& fw_type = static_cast<const FixedWidthType&>(*type_);
ComputeColumnMajorStrides(fw_type, shape_, &f_strides);
return strides_ == f_strides;
}
/// Return the id reported by the tensor's data type.
Type::type Tensor::type_id() const {
  const Type::type id = type_->id();
  return id;
}
bool Tensor::Equals(const Tensor& other) const { return TensorEquals(*this, other); }
} // namespace arrow